├── LICENSE ├── README.md ├── dataset ├── jf17k │ ├── n-ary_test.json │ ├── n-ary_train.json │ └── n-ary_valid.json ├── wd50k │ ├── n-ary_test.json │ ├── n-ary_train.json │ └── n-ary_valid.json └── wikipeople │ ├── n-ary_test.json │ ├── n-ary_train.json │ └── n-ary_valid.json ├── model ├── NYLON.py ├── NYLONModel.py └── graph_encoder.py ├── reader ├── data_loader.py └── data_reader.py ├── run.py └── utils ├── args.py └── evaluation.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 UM-Data-Intelligence-Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NYLON (Robust HKG Model) 2 | 3 | NYLON is a robust link prediction model over noisy hyper-relational knowledge graphs. Trained with an active learning strategy, NYLON evaluates the confidence of facts and rebalances the loss for each fact with its confidence to alleviate the negative impact of less confident facts. Please see our paper below for details: 4 | - Weijian Yu, Jie Yang and Dingqi Yang, Robust Link Prediction over Noisy Hyper-Relational Knowledge Graphs via Active Learning. In Proceedings of the ACM Web Conference 2024 (WWW'24), May 13-17, 2024, Singapore. 5 | 6 | ## How to run the code 7 | ###### Train and evaluate the model (suggested parameters for the JF17k, WikiPeople and WD50K datasets) 8 | ``` 9 | python3 -u run.py --input "dataset/jf17k" 10 | 11 | python3 -u run.py --input "dataset/wikipeople" 12 | 13 | python3 -u run.py --input "dataset/wd50k" 14 | ``` 15 | 16 | ###### Parameter settings: 17 | In `run.py`, you can set: 18 | 19 | `--input`: input dataset. 20 | 21 | `--epochs`: number of training epochs. 22 | 23 | `--batch_size`: batch size of the training set. 24 | 25 | `--learning_rate`: learning rate. 26 | 27 | `--noise_level`: noise level as a float, e.g., 1.0 means generating noisy facts amounting to 100% of the positive facts. 28 | 29 | `--active_sample_per_epoch`: active labeling budget as a float, e.g., 0.0025 means labeling 0.25% of elements per epoch. 30 | 31 | `--aug_amount`: number of pseudo-labeled positive facts for every positive fact when training the confidence evaluator. For more details, please refer to our paper.
32 | 33 | # Python lib versions 34 | Python: 3.10.12 35 | 36 | torch: 2.1.0 37 | 38 | # Reference 39 | If you use our code or datasets, please cite: 40 | ``` 41 | @inproceedings{yu2024robust, 42 | title={Robust Link Prediction over Noisy Hyper-Relational Knowledge Graphs via Active Learning}, 43 | author={Yu, Weijian and Yang, Jie and Yang, Dingqi}, 44 | booktitle={Proceedings of the ACM Web Conference 2024}, 45 | pages={}, 46 | year={2024} 47 | } 48 | ``` 49 | -------------------------------------------------------------------------------- /model/NYLON.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import logging 6 | # from model.HGNN_encoder import HGNNLayer 7 | import time 8 | from utils.evaluation import eval_type_hyperbolic 9 | import torch 10 | import torch.nn 11 | from model.NYLONModel import NYLONModel 12 | from model.graph_encoder import truncated_normal 13 | torch.set_printoptions(precision=16) 14 | 15 | logging.basicConfig( 16 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 17 | datefmt='%m/%d/%Y %H:%M:%S') 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.DEBUG) 20 | logger.info(logger.getEffectiveLevel()) 21 | 22 | class NYLON(torch.nn.Module): 23 | def __init__(self, ins_info, config): 24 | super(NYLON,self).__init__() 25 | #CONFIG SETTING 26 | self.config = config 27 | self.ins_node_num = ins_info["node_num"] 28 | 29 | #INIT EMBEDDING 30 | self.ins_node_embeddings = torch.nn.Embedding(self.ins_node_num, self.config['dim']) 31 | self.ins_node_embeddings.weight.data=truncated_normal(self.ins_node_embeddings.weight.data,std=0.02) 32 | 33 | 34 | ##GRAN_LAYER 35 | self.ins_config=dict() 36 | self.ins_config['num_hidden_layers']=self.config['num_hidden_layers'] 37 | self.ins_config['num_attention_heads']=self.config['num_attention_heads'] 38 | 
self.ins_config['hidden_size']=self.config['dim'] 39 | self.ins_config['intermediate_size']=self.config['ins_intermediate_size'] 40 | self.ins_config['hidden_dropout_prob']=self.config['hidden_dropout_prob'] 41 | self.ins_config['attention_dropout_prob']=self.config['attention_dropout_prob'] 42 | self.ins_config['vocab_size']=self.ins_node_num 43 | self.ins_config['num_relations']=ins_info["rel_num"] 44 | self.ins_config['num_edges']=self.config['num_edges'] 45 | self.ins_config['max_arity']=ins_info['max_n'] 46 | self.ins_config['device']=self.config['device'] 47 | self.ins_granlayer=NYLONModel(self.ins_config,self.ins_node_embeddings).to(self.config['device']) 48 | 49 | 50 | 51 | def forward_E(self,ins_pos,ins_edge_labels, tag, confidence, correct_rate): 52 | # print(len(ins_pos)) 53 | ins_input_ids, ins_input_mask, ins_mask_pos, ins_mask_label, ins_mask_type = ins_pos 54 | if tag == "normal": 55 | self.ins_triple_loss, self.ins_fc_out = self.ins_granlayer(ins_input_ids, ins_input_mask, ins_edge_labels, 56 | ins_mask_pos, ins_mask_label, ins_mask_type, tag, 57 | None,None, confidence, correct_rate) 58 | return self.ins_triple_loss , self.ins_fc_out 59 | else: 60 | self.ins_triple_loss, self.ins_fc_out, fc_out_vector, embeddings = self.ins_granlayer(ins_input_ids, ins_mask_label, ins_edge_labels, 61 | None, None, None, tag, 62 | ins_input_mask, ins_mask_pos, confidence, correct_rate) 63 | return self.ins_triple_loss, self.ins_fc_out, fc_out_vector, embeddings 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /model/NYLONModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import copy 6 | import logging 7 | import torch 8 | import torch.nn 9 | from model.graph_encoder import encoder,truncated_normal 10 | 11 | logging.basicConfig( 12 | 
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 13 | datefmt='%m/%d/%Y %H:%M:%S') 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(logging.DEBUG) 16 | logger.info(logger.getEffectiveLevel()) 17 | 18 | class NYLONModel(torch.nn.Module): 19 | def __init__(self,config,node_embeddings): 20 | super(NYLONModel,self).__init__() 21 | 22 | self._n_layer = config['num_hidden_layers'] 23 | self._n_head = config['num_attention_heads'] 24 | self._emb_size = config['hidden_size'] 25 | self._intermediate_size = config['intermediate_size'] 26 | self._prepostprocess_dropout = config['hidden_dropout_prob'] 27 | self._attention_dropout = config['attention_dropout_prob'] 28 | 29 | self._voc_size = config['vocab_size'] 30 | self._n_relation = config['num_relations'] 31 | self._n_edge = config['num_edges'] 32 | self._max_arity = config['max_arity'] 33 | self._max_seq_len = self._max_arity*2-1 34 | 35 | self._device=config["device"] 36 | 37 | self.node_embedding=node_embeddings 38 | self.layer_norm1=torch.nn.LayerNorm(normalized_shape=self._emb_size,eps=1e-12,elementwise_affine=True) 39 | self.edge_embedding_k=torch.nn.Embedding(self._n_edge, self._emb_size // self._n_head) 40 | self.edge_embedding_k.weight.data=truncated_normal(self.edge_embedding_k.weight.data,std=0.02) 41 | self.edge_embedding_v=torch.nn.Embedding(self._n_edge, self._emb_size // self._n_head) 42 | self.edge_embedding_v.weight.data=truncated_normal(self.edge_embedding_v.weight.data,std=0.02) 43 | self.encoder_model=encoder( 44 | n_layer=self._n_layer, 45 | n_head=self._n_head, 46 | d_key=self._emb_size // self._n_head, 47 | d_value=self._emb_size // self._n_head, 48 | d_model=self._emb_size, 49 | d_inner_hid=self._intermediate_size, 50 | prepostprocess_dropout=self._prepostprocess_dropout, 51 | attention_dropout=self._attention_dropout) 52 | self.fc1=torch.nn.Linear(self._emb_size, self._emb_size) 53 | self.fc1.weight.data=truncated_normal(self.fc1.weight.data,std=0.02) 54 | 
torch.nn.init.constant_(self.fc1.bias, 0.0) 55 | self.layer_norm2=torch.nn.LayerNorm(normalized_shape=self._emb_size,eps=1e-7,elementwise_affine=True) 56 | self.fc2_bias = torch.nn.init.constant_(torch.nn.parameter.Parameter(torch.Tensor(self._voc_size)), 0.0) 57 | self.fc3 = torch.nn.Linear(self._emb_size, self._emb_size) 58 | torch.nn.init.xavier_uniform_(self.fc3.weight) 59 | torch.nn.init.constant_(self.fc3.bias, 0.0) 60 | self.fc4 = torch.nn.Linear(self._emb_size, self._emb_size) 61 | torch.nn.init.xavier_uniform_(self.fc4.weight) 62 | torch.nn.init.constant_(self.fc4.bias, 0.0) 63 | self.fc5 = torch.nn.Linear(self._emb_size, self._max_seq_len) 64 | torch.nn.init.xavier_uniform_(self.fc5.weight) 65 | torch.nn.init.constant_(self.fc5.bias, 0.0) 66 | self.fc_tuple = torch.nn.Linear(self._emb_size*self._max_seq_len, self._emb_size) 67 | torch.nn.init.xavier_uniform_(self.fc_tuple.weight) 68 | torch.nn.init.constant_(self.fc_tuple.bias, 0.0) 69 | self.scaler = torch.nn.Parameter(torch.tensor(0.3)) 70 | self.myloss = softmax_with_cross_entropy() 71 | self.loss_conf = torch.nn.BCELoss() 72 | 73 | def forward(self,input_ids, input_mask, edge_labels, mask_pos,mask_label, mask_type, mode, is_true, is_shown, confidence, correct_rate): 74 | emb_out = self.node_embedding(input_ids) 75 | batch_size = emb_out.shape[0] 76 | emb_out = torch.nn.Dropout(self._prepostprocess_dropout)(self.layer_norm1(emb_out)) 77 | edges_key = self.edge_embedding_k(edge_labels) 78 | edges_value = self.edge_embedding_v(edge_labels) 79 | edge_mask = torch.sign(edge_labels).unsqueeze(2) 80 | edges_key = torch.mul(edges_key, edge_mask) 81 | edges_value = torch.mul(edges_value, edge_mask) 82 | input_mask=input_mask.unsqueeze(2) 83 | self_attn_mask = torch.matmul(input_mask,input_mask.transpose(1,2)) 84 | self_attn_mask=1000000.0*(self_attn_mask-1.0) 85 | n_head_self_attn_mask = torch.stack([self_attn_mask] * self._n_head, dim=1) 86 | _enc_out = self.encoder_model( 87 | enc_input=emb_out, 88 | 
edges_key=edges_key, 89 | edges_value=edges_value, 90 | attn_bias=n_head_self_attn_mask) 91 | if mode == "normal": 92 | mask_pos = mask_pos.unsqueeze(1) 93 | mask_pos = mask_pos[:, :, None].expand(-1, -1, self._emb_size) 94 | h_masked = torch.gather(input=_enc_out, dim=1, index=mask_pos).reshape([-1, _enc_out.size(-1)]) 95 | fc_out = self.fc1(h_masked) 96 | h_masked = torch.nn.GELU()(h_masked) 97 | h_masked = self.layer_norm2(h_masked) 98 | fc_out = torch.nn.functional.linear(h_masked, self.node_embedding.weight, self.fc2_bias) 99 | special_indicator = torch.empty(input_ids.size(0), 2).to(self._device) 100 | torch.nn.init.constant_(special_indicator, -1) 101 | relation_indicator = torch.empty(input_ids.size(0), self._n_relation).to(self._device) 102 | torch.nn.init.constant_(relation_indicator, -1) 103 | entity_indicator = torch.empty(input_ids.size(0), (self._voc_size - self._n_relation - 2)).to(self._device) 104 | torch.nn.init.constant_(entity_indicator, 1) 105 | type_indicator = torch.cat((relation_indicator, entity_indicator), dim=1).to(self._device) 106 | mask_type = mask_type.unsqueeze(1) 107 | type_indicator = torch.mul(type_indicator, mask_type) 108 | type_indicator = torch.cat([special_indicator, type_indicator], dim=1) 109 | type_indicator = torch.nn.functional.relu(type_indicator) 110 | fc_out_mask = 1000000.0 * (type_indicator - 1.0) 111 | fc_out = torch.add(fc_out, fc_out_mask) 112 | one_hot_labels = torch.nn.functional.one_hot(mask_label, self._voc_size) # _voc_size = node_num 113 | type_indicator = torch.sub(type_indicator, one_hot_labels) 114 | num_candidates = torch.sum(type_indicator, dim=1) 115 | soft_labels = ((1 + mask_type) * 0.9 + 116 | (1 - mask_type) * 0.9) / 2.0 117 | soft_labels = soft_labels.expand(-1, self._voc_size) 118 | soft_labels = soft_labels * one_hot_labels + (1.0 - soft_labels) * \ 119 | torch.mul(type_indicator, 1.0 / torch.unsqueeze(num_candidates, 1)) 120 | mean_mask_lm_loss = self.myloss(logits=fc_out, label=soft_labels, 
weight=confidence) 121 | return mean_mask_lm_loss, fc_out 122 | 123 | else: 124 | _enc_out = _enc_out.view(batch_size, -1) 125 | out_embeddings = _enc_out.clone().detach() 126 | _enc_out = torch.nn.GELU()(self.fc_tuple(_enc_out)) 127 | h_masked = self.fc3(_enc_out) 128 | h_masked = torch.nn.GELU()(h_masked) 129 | h_masked = self.layer_norm2(h_masked) 130 | h_masked = torch.nn.GELU()(self.fc4(h_masked)) 131 | fc_out = self.fc5(h_masked) 132 | fc_out_vactor = torch.nn.Sigmoid()(fc_out) 133 | fc_out_vactor = 1 + fc_out_vactor * (input_mask.squeeze()) - input_mask.squeeze() 134 | fc_out = torch.min(fc_out_vactor, dim=1).values 135 | if len(input_mask.shape) == 3: 136 | is_true_sum = torch.sum(is_true * input_mask.squeeze(), dim=1) 137 | input_mask_sum = torch.sum(input_mask.squeeze(), dim=1) 138 | else: 139 | is_true_sum = torch.sum(is_true * input_mask, dim=1) 140 | input_mask_sum = torch.sum(input_mask, dim=1) 141 | tags = is_true_sum == input_mask_sum 142 | loss_conf = self.loss_conf(fc_out, tags.float()) 143 | return loss_conf, fc_out, fc_out_vactor, out_embeddings 144 | 145 | class softmax_with_cross_entropy(torch.nn.Module): 146 | def __init__(self): 147 | super(softmax_with_cross_entropy,self).__init__() 148 | 149 | def forward(self,logits, label, weight): 150 | logprobs=torch.nn.functional.log_softmax(logits,dim=1) 151 | loss=-1.0*torch.sum(torch.mul(label,logprobs),dim=1).squeeze() 152 | loss=torch.mean(loss*weight) 153 | return loss -------------------------------------------------------------------------------- /model/graph_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn 7 | import numpy as np 8 | 9 | def truncated_normal(t, mean=0.0, std=0.01): 10 | torch.nn.init.normal_(t, mean=mean, std=std) 11 | while True: 12 | cond = torch.logical_or(t < mean - 
2*std, t > mean + 2*std) 13 | if not torch.sum(cond): 14 | break 15 | t = torch.where(cond, torch.nn.init.normal_(torch.ones(t.shape), mean=mean, std=std), t) 16 | return t 17 | 18 | 19 | 20 | class multi_head_attention(torch.nn.Module): 21 | def __init__(self,d_key,d_value,d_model,n_head,attention_dropout): 22 | super(multi_head_attention,self).__init__() 23 | self.d_key=d_key 24 | self.d_value=d_value 25 | self.d_model=d_model 26 | self.n_head=n_head 27 | self.attention_dropout=attention_dropout 28 | 29 | self.layer_q=torch.nn.Linear(self.d_model,self.d_key * self.n_head) 30 | self.layer_q.weight.data=truncated_normal(self.layer_q.weight.data,std=0.02) 31 | #torch.nn.init.xavier_uniform_(self.layer_q.weight) 32 | torch.nn.init.constant_(self.layer_q.bias, 0.0) 33 | self.layer_k=torch.nn.Linear(self.d_model,self.d_key * self.n_head) 34 | self.layer_k.weight.data=truncated_normal(self.layer_k.weight.data,std=0.02) 35 | #torch.nn.init.xavier_uniform_(self.layer_k.weight) 36 | torch.nn.init.constant_(self.layer_k.bias, 0.0) 37 | self.layer_v=torch.nn.Linear(self.d_model,self.d_value * self.n_head) 38 | self.layer_v.weight.data=truncated_normal(self.layer_v.weight.data,std=0.02) 39 | #torch.nn.init.xavier_uniform_(self.layer_v.weight) 40 | torch.nn.init.constant_(self.layer_v.bias, 0.0) 41 | self.project_layer=torch.nn.Linear(d_value * n_head,self.d_model) 42 | self.project_layer.weight.data=truncated_normal(self.project_layer.weight.data,std=0.02) 43 | #torch.nn.init.xavier_uniform_(self.project_layer.weight) 44 | torch.nn.init.constant_(self.project_layer.bias, 0.0) 45 | 46 | def forward(self, 47 | queries, 48 | edges_key, 49 | edges_value, 50 | attn_bias): 51 | #B is batch_size, M is max_seq_len, N is n_head, H is d_key 52 | batch_size=queries.size(0) 53 | max_seq_len=queries.size(1) 54 | #query,key,value is [B,M,N*H], edges_key,edges_value is [M,M,H], attn_bias is [B,N,M,M] 55 | keys = queries 56 | values = keys 57 | #q,k,v is [B,N,M,H] 58 | 
q=self.layer_q(queries).view(batch_size,-1,self.n_head,self.d_key).transpose(1,2) 59 | k=self.layer_k(keys).view(batch_size,-1,self.n_head,self.d_key).transpose(1,2) 60 | v=self.layer_v(values).view(batch_size,-1,self.n_head,self.d_value).transpose(1,2) 61 | #scores1,scores2,scores is [B,N,M,M] 62 | scores1 = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(self.d_key) 63 | scores2 = torch.matmul(q.permute(2,0,1,3).contiguous().view(max_seq_len,-1,self.d_key),edges_key.transpose(-1,-2)).view(max_seq_len,-1,self.n_head,max_seq_len).permute(1,2,0,3)/ np.sqrt(self.d_key) 64 | scores=torch.add(scores1,scores2) 65 | scores=torch.add(scores,attn_bias) 66 | #weights is [B,N,M,M] 67 | weights=torch.nn.Dropout(self.attention_dropout)(torch.nn.Softmax(dim=-1)(scores)) 68 | #context1,context2,context is [B,N,M,H] 69 | context1= torch.matmul(weights,v) 70 | context2= torch.matmul(weights.permute(2,0,1,3).contiguous().view(max_seq_len,-1,max_seq_len),edges_value).view(max_seq_len,-1,self.n_head,self.d_value).permute(1,2,0,3) 71 | context=torch.add(context1,context2) 72 | #output is [B,M,N*H] 73 | output=context.transpose(1,2).contiguous().view(batch_size,-1,self.n_head*self.d_value) 74 | output=self.project_layer(output) 75 | return output 76 | 77 | 78 | class positionwise_feed_forward(torch.nn.Module): 79 | def __init__(self,d_inner_hid,d_model): 80 | super(positionwise_feed_forward,self).__init__() 81 | self.d_inner_hid=d_inner_hid 82 | self.d_hid=d_model 83 | 84 | self.fc1=torch.nn.Linear(self.d_hid,self.d_inner_hid) 85 | self.fc1.weight.data=truncated_normal(self.fc1.weight.data,std=0.02) 86 | #torch.nn.init.xavier_uniform_(self.fc1.weight) 87 | torch.nn.init.constant_(self.fc1.bias, 0.0) 88 | self.fc2=torch.nn.Linear(self.d_inner_hid,self.d_hid) 89 | self.fc2.weight.data=truncated_normal(self.fc2.weight.data,std=0.02) 90 | #torch.nn.init.xavier_uniform_(self.fc2.weight) 91 | torch.nn.init.constant_(self.fc2.bias, 0.0) 92 | 93 | def forward(self,x): 94 | return 
self.fc2(torch.nn.GELU()(self.fc1(x))) 95 | 96 | class encoder_layer(torch.nn.Module): 97 | def __init__(self, 98 | n_head, 99 | d_key, 100 | d_value, 101 | d_model, 102 | d_inner_hid, 103 | prepostprocess_dropout, 104 | attention_dropout): 105 | super(encoder_layer,self).__init__() 106 | self.n_head=n_head 107 | self.d_key=d_key 108 | self.d_value=d_value 109 | self.d_model=d_model 110 | self.d_inner_hid=d_inner_hid 111 | self.prepostprocess_dropout=prepostprocess_dropout 112 | self.attention_dropout=attention_dropout 113 | 114 | self.multi_head_attention=multi_head_attention( 115 | self.d_key, 116 | self.d_value, 117 | self.d_model, 118 | self.n_head, 119 | self.attention_dropout) 120 | self.layer_norm1=torch.nn.LayerNorm(normalized_shape=self.d_model,eps=1e-7,elementwise_affine=True) 121 | 122 | self.positionwise_feed_forward=positionwise_feed_forward( 123 | self.d_inner_hid, 124 | self.d_model) 125 | self.layer_norm2=torch.nn.LayerNorm(normalized_shape=self.d_model,eps=1e-7,elementwise_affine=True) 126 | 127 | def forward(self,enc_input, 128 | edges_key, 129 | edges_value, 130 | attn_bias): 131 | attn_output = self.multi_head_attention( 132 | enc_input, 133 | edges_key, 134 | edges_value, 135 | attn_bias) 136 | attn_output=self.layer_norm1(torch.add(enc_input,torch.nn.Dropout(self.prepostprocess_dropout)(attn_output))) 137 | 138 | ffd_output = self.positionwise_feed_forward(attn_output) 139 | ffd_output=self.layer_norm2(torch.add(attn_output,torch.nn.Dropout(self.prepostprocess_dropout)(ffd_output))) 140 | return ffd_output 141 | 142 | 143 | class encoder(torch.nn.Module): 144 | def __init__(self,n_layer,n_head,d_key,d_value,d_model, 145 | d_inner_hid,prepostprocess_dropout,attention_dropout): 146 | super(encoder,self).__init__() 147 | self.n_layer=n_layer 148 | self.n_head=n_head 149 | self.d_key=d_key 150 | self.d_value=d_value 151 | self.d_model=d_model 152 | self.d_inner_hid=d_inner_hid 153 | self.prepostprocess_dropout=prepostprocess_dropout 154 | 
self.attention_dropout=attention_dropout 155 | 156 | for nl in range(self.n_layer): 157 | setattr(self,"encoder_layer{}".format(nl),encoder_layer( 158 | self.n_head, 159 | self.d_key, 160 | self.d_value, 161 | self.d_model, 162 | self.d_inner_hid, 163 | self.prepostprocess_dropout, 164 | self.attention_dropout)) 165 | 166 | def forward(self,enc_input,edges_key,edges_value,attn_bias): 167 | for nl in range(self.n_layer): 168 | enc_output = getattr(self,"encoder_layer{}".format(nl))( 169 | enc_input, 170 | edges_key, 171 | edges_value, 172 | attn_bias) 173 | enc_input = enc_output 174 | return enc_output 175 | -------------------------------------------------------------------------------- /reader/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import logging 6 | import torch 7 | import numpy as np 8 | import time 9 | import numpy as np 10 | import scipy.sparse as sp 11 | 12 | logging.basicConfig( 13 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 14 | datefmt='%m/%d/%Y %H:%M:%S') 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.DEBUG) 17 | logger.info(logger.getEffectiveLevel()) 18 | 19 | def get_edge_labels(max_n): 20 | edge_labels = [] 21 | max_seq_length=2*max_n-1 22 | max_aux = max_n - 2 23 | 24 | edge_labels.append([0, 1, 0] + [5,0] * max_aux) 25 | edge_labels.append([1, 0, 2] + [3,0] * max_aux) 26 | edge_labels.append([0, 2, 0] + [0,0] * max_aux) 27 | for idx in range(max_aux): 28 | edge_labels.append( 29 | [5, 3, 0] + [0,0] * idx + [0,4] + [0,0] * (max_aux - idx - 1)) 30 | edge_labels.append( 31 | [0, 0, 0] + [0,0] * idx + [4,0] + [0,0] * (max_aux - idx - 1)) 32 | edge_labels = np.asarray(edge_labels).astype("int64").reshape( 33 | [max_seq_length, max_seq_length]) 34 | edge_labels=torch.from_numpy(edge_labels) 35 | ''' 36 | for i in range(edge_labels.shape[0]): 37 | if edge_labels[1, 
i] == 0: 38 | edge_labels[1, i] = 5 39 | if edge_labels[i, 1] == 0: 40 | edge_labels[i, 1] = 5 41 | print(edge_labels) 42 | ''' 43 | return edge_labels.long() 44 | 45 | def prepare_EC_info(ins_info, device): 46 | instance_info=dict() 47 | instance_info["node_num"]=ins_info['node_num'] 48 | instance_info["rel_num"]=ins_info['rel_num'] 49 | instance_info["max_n"]=ins_info['max_n'] 50 | 51 | return instance_info 52 | ''' 53 | def normalize_adj(adj): 54 | """Symmetrically normalize adjacency matrix.""" 55 | adj = sp.coo_matrix(adj) 56 | rowsum = np.array(adj.sum(1)) 57 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 58 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 59 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 60 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 61 | ''' 62 | def sparse_to_tuple(sparse_mx): 63 | def to_tuple(mx): 64 | if not sp.isspmatrix_coo(mx): 65 | mx = mx.tocoo() 66 | coords = np.vstack((mx.row, mx.col)).transpose() 67 | values = mx.data 68 | shape = mx.shape 69 | return coords, values, shape 70 | 71 | if isinstance(sparse_mx, list): 72 | for i in range(len(sparse_mx)): 73 | sparse_mx[i] = to_tuple(sparse_mx[i]) 74 | else: 75 | sparse_mx = to_tuple(sparse_mx) 76 | return sparse_mx 77 | ''' 78 | def preprocess_adj(adj): 79 | """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation.""" 80 | adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) 81 | return sparse_to_tuple(adj_normalized) 82 | 83 | def no_weighted_adj(total_ent_num, triple_list): 84 | start = time.time() 85 | edge = dict() 86 | for item in triple_list: 87 | if item[0] not in edge.keys(): 88 | edge[item[0]] = set() 89 | if item[2] not in edge.keys(): 90 | edge[item[2]] = set() 91 | edge[item[0]].add(item[2]) 92 | edge[item[2]].add(item[0]) 93 | row = list() 94 | col = list() 95 | for i in range(total_ent_num): 96 | if i not in edge.keys(): 97 | continue 98 | key = i 99 | value = edge[key] 100 | add_key_len = len(value) 101 | 
add_key = (key * np.ones(add_key_len)).tolist() 102 | row.extend(add_key) 103 | col.extend(list(value)) 104 | data_len = len(row) 105 | data = np.ones(data_len) 106 | one_adj = sp.coo_matrix((data, (row, col)), shape=(total_ent_num, total_ent_num)) 107 | one_adj = preprocess_adj(one_adj) 108 | print('generating one-adj costs time: {:.4f}s'.format(time.time() - start)) 109 | return one_adj 110 | 111 | def gen_adj(total_e_num, triples): 112 | adj_triples=list() 113 | for i,item in enumerate(triples[0]): 114 | if triples[2][i]==0: 115 | item[0]=triples[3][i] 116 | adj_triples.append(item) 117 | one_adj = no_weighted_adj(total_e_num, adj_triples) 118 | adj = one_adj 119 | return adj 120 | ''' 121 | def gen_hadj(total_ent_num, statements): 122 | start = time.time() 123 | total_state_num=len(statements) 124 | row = list() 125 | col = list() 126 | for j,item in enumerate(statements): 127 | for p,i in enumerate(item): 128 | if p%2==0: 129 | row.append(i) 130 | col.append(j) 131 | 132 | data_len = len(row) 133 | data = np.ones(data_len) 134 | H = sp.coo_matrix((data, (row, col)), shape=(total_ent_num, total_state_num)) 135 | 136 | n_edge = H.shape[1]# 超边矩阵 137 | # the weight of the hyperedge 138 | W = np.ones(n_edge) # 超边权重矩阵 139 | # the degree of the node 140 | DV = np.array(H.sum(1)) # 节点度; (12311,) 141 | # the degree of the hyperedge 142 | DE = np.array(H.sum(0)) # 超边的度; (24622,) 143 | 144 | invDE = sp.diags(np.power(DE, -1).flatten()) # DE^-1; 建立对角阵 145 | DV2 = sp.diags(np.power(DV, -0.5).flatten()) # DV^-1/2 146 | W = sp.diags(W) # 超边权重矩阵 147 | HT = H.transpose() 148 | 149 | G = DV2 * H * W * invDE * HT * DV2 150 | 151 | logger.info('generating G costs time: {:.4f}s'.format(time.time() - start)) 152 | return sparse_to_tuple(G) 153 | 154 | def prepare_adj_info(ins_info, onto_info, device): 155 | ins_adj = gen_hadj(ins_info['node_num'], ins_info['all_fact_ids']) 156 | ins_adj_info=dict() 157 | ins_adj_info['indices']=torch.tensor(ins_adj[0]).t().to(device) 158 | 
ins_adj_info['values']=torch.tensor(ins_adj[1]).to(device) 159 | ins_adj_info['size']=torch.tensor(ins_adj[2]).to(device) 160 | onto_adj = gen_hadj(onto_info['node_num'], onto_info['all_fact_ids']) 161 | onto_adj_info=dict() 162 | onto_adj_info['indices']=torch.tensor(onto_adj[0]).t().to(device) 163 | onto_adj_info['values']=torch.tensor(onto_adj[1]).to(device) 164 | onto_adj_info['size']=torch.tensor(onto_adj[2]).to(device) 165 | return ins_adj_info, onto_adj_info -------------------------------------------------------------------------------- /reader/data_reader.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import random 6 | from abc import abstractclassmethod 7 | 8 | import logging 9 | import collections 10 | 11 | 12 | import json 13 | import copy 14 | import numpy as np 15 | import torch 16 | 17 | 18 | logging.basicConfig( 19 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 20 | datefmt='%m/%d/%Y %H:%M:%S') 21 | logger = logging.getLogger(__name__) 22 | logger.setLevel(logging.DEBUG) 23 | logger.info(logger.getEffectiveLevel()) 24 | 25 | #读取三元组 26 | 27 | def read_facts_new(file): 28 | facts_list = list() 29 | max_n=0 30 | entity_list = list() 31 | relation_list = list() 32 | with open(file, 'r', encoding='utf8') as f: 33 | for line in f: 34 | fact=list() 35 | # print(line) 36 | obj = json.loads(line) 37 | if obj['N']>max_n: 38 | max_n=obj['N'] 39 | flag = 0 40 | for key in obj: 41 | if flag == 0: 42 | fact.append(obj[key][0]) 43 | fact.append(key) 44 | fact.append(obj[key][1]) 45 | relation_list.append(key) 46 | entity_list.append(obj[key][0]) 47 | entity_list.append(obj[key][1]) 48 | break 49 | if obj['N']>2: 50 | for kv in obj.keys(): 51 | if kv!='N' and kv!=key: 52 | if isinstance(obj[kv], list): 53 | for item in obj[kv]: 54 | fact.append(kv) 55 | fact.append(item) 56 | relation_list.append(kv) 57 | 
entity_list.append(item) 58 | else: 59 | fact.append(kv) 60 | fact.append(obj[kv]) 61 | relation_list.append(kv) 62 | entity_list.append(obj[kv]) 63 | facts_list.append(fact) 64 | return facts_list,max_n,relation_list,entity_list 65 | 66 | def read_dict(ent_file,rel_file): 67 | dict_id=dict() 68 | dict_id['PAD']=0 69 | dict_id['MASK']=1 70 | dict_num=2 71 | rel_num=0 72 | with open(rel_file, 'r', encoding='utf8') as f: 73 | for line in f: 74 | line=line.strip('\n') 75 | dict_id[line]=dict_num 76 | dict_num+=1 77 | rel_num+=1 78 | 79 | with open(ent_file, 'r', encoding='utf8') as f: 80 | for line in f: 81 | line=line.strip('\n') 82 | dict_id[line]=dict_num 83 | dict_num+=1 84 | 85 | return dict_id,dict_num,rel_num 86 | 87 | 88 | def read_dict_new(e_ls,r_ls): 89 | dict_id=dict() 90 | dict_id['PAD']=0 91 | dict_id['MASK']=1 92 | dict_num=2 93 | rel_num=0 94 | 95 | for item in r_ls: 96 | dict_id[item] = dict_num 97 | dict_num += 1 98 | rel_num += 1 99 | 100 | for item in e_ls: 101 | dict_id[item] = dict_num 102 | dict_num += 1 103 | 104 | return dict_id,dict_num,rel_num 105 | 106 | 107 | def facts_to_id(facts,max_n,node_dict, is_true_out): 108 | id_facts=list() 109 | id_masks=list() 110 | mask_labels=list() 111 | mask_pos=list() 112 | mask_types=list() 113 | real_triples = list() 114 | is_true = list() 115 | is_shown = list() 116 | real_masks = list() 117 | replace_masks = list() 118 | for fact in facts: 119 | id_fact=list() 120 | id_mask=list() 121 | for i,item in enumerate(fact): 122 | id_fact.append(node_dict[item]) 123 | id_mask.append(1.0) 124 | 125 | for j,mask_label in enumerate(id_fact): 126 | x=copy.copy(id_fact) 127 | x[j]=1 128 | y=copy.copy(id_mask) 129 | z = copy.copy(id_mask) 130 | if j%2==0: 131 | mask_type=1 132 | else: 133 | mask_type=-1 134 | while len(x)<(2*max_n-1): 135 | x.append(0) 136 | y.append(0.0) 137 | z.append(1.0) 138 | id_facts.append(x) 139 | id_masks.append(y) 140 | replace_masks.append(z) 141 | mask_pos.append(j) 142 | if j == 0: 143 | 
is_true.append(0) 144 | is_shown.append(0) 145 | x_copy = copy.copy(x) 146 | x_copy[0] = mask_label 147 | real_triples.append(x_copy) 148 | real_masks.append(y) 149 | mask_labels.append(mask_label) 150 | mask_types.append(mask_type) 151 | return [id_facts,id_masks,mask_pos,mask_labels,mask_types], [real_triples, is_true_out.numpy().tolist(), torch.zeros(is_true_out.shape).numpy().tolist(), real_masks, is_true_out.numpy().tolist()] 152 | 153 | def get_truth(all_facts,max_n,node_dict): 154 | max_aux=max_n-2 155 | max_seq_length = 2 * max_aux + 3 156 | gt_dict = collections.defaultdict(lambda: collections.defaultdict(list)) 157 | all_fact_ids=list() 158 | for fact in all_facts: 159 | id_fact=list() 160 | for i,item in enumerate(fact): 161 | id_fact.append(node_dict[item]) 162 | all_fact_id=copy.copy(id_fact) 163 | all_fact_ids.append(all_fact_id) 164 | while len(id_fact)<(2*max_n-1): 165 | id_fact.append(0) 166 | for pos in range(max_seq_length): 167 | if id_fact[pos]==0: 168 | continue 169 | key = " ".join([ 170 | str(id_fact[x]) for x in range(max_seq_length) if x != pos 171 | ]) 172 | gt_dict[pos][key].append(id_fact[pos]) 173 | 174 | return gt_dict,all_fact_ids 175 | 176 | def get_input(train_file, valid_file, test_file, noise_file, initiatial_amount): 177 | 178 | train_facts,max_train,train_r,train_e = read_facts_new(train_file) 179 | valid_facts,max_valid,valid_r,valid_e = read_facts_new(valid_file) 180 | test_facts,max_test,test_r,test_e = read_facts_new(test_file) 181 | 182 | max_n = max(max_train, max_valid, max_test) 183 | e_list = list(set(train_e + valid_e + test_e)) 184 | r_list = list(set(train_r + valid_r + test_r)) 185 | noise_facts = [] 186 | is_true = torch.ones([0, max_n*2-1]) 187 | train_fact_num = len(train_facts) 188 | noise_amount = int(train_fact_num * initiatial_amount) 189 | for _ in range(noise_amount): 190 | temp_fact = copy.copy(random.choice(train_facts)) 191 | is_true_temp = torch.ones([1, max_n*2-1]) 192 | replace_num = 
int(random.randint(1, len(temp_fact) - 1) / 2) 193 | if replace_num == 0: 194 | replace_num = 1 195 | for j in range(replace_num): 196 | replace_index = random.randint(0, len(temp_fact) - 1) 197 | is_true_temp[0, replace_index] = 0 198 | if replace_index % 2 == 0: 199 | random_replace = random.choice(e_list) 200 | while True: 201 | if random_replace != temp_fact[replace_index]: 202 | break 203 | random_replace = random.choice(e_list) 204 | temp_fact[replace_index] = random_replace 205 | else: 206 | random_replace = random.choice(r_list) 207 | while True: 208 | if random_replace != temp_fact[replace_index]: 209 | break 210 | random_replace = random.choice(r_list) 211 | temp_fact[replace_index] = random_replace 212 | is_true = torch.cat([is_true, is_true_temp], dim=0) 213 | noise_facts.append(temp_fact) 214 | 215 | noise_test = [] 216 | is_true_test = torch.ones([0, max_n * 2 - 1]) 217 | test_fact_num = len(test_facts) 218 | noise_amount_test = int(test_fact_num * initiatial_amount) 219 | for _ in range(noise_amount_test): 220 | temp_fact = copy.copy(random.choice(test_facts)) 221 | is_true_temp = torch.ones([1, max_n * 2 - 1]) 222 | replace_num = int(random.randint(1, len(temp_fact) - 1) / 2) 223 | if replace_num == 0: 224 | replace_num = 1 225 | for j in range(replace_num): 226 | replace_index = random.randint(0, len(temp_fact) - 1) 227 | is_true_temp[0, replace_index] = 0 228 | if replace_index % 2 == 0: 229 | random_replace = random.choice(e_list) 230 | while True: 231 | if random_replace != temp_fact[replace_index]: 232 | break 233 | random_replace = random.choice(e_list) 234 | temp_fact[replace_index] = random_replace 235 | else: 236 | random_replace = random.choice(r_list) 237 | while True: 238 | if random_replace != temp_fact[replace_index]: 239 | break 240 | random_replace = random.choice(r_list) 241 | temp_fact[replace_index] = random_replace 242 | is_true_test = torch.cat([is_true_test, is_true_temp], dim=0) 243 | noise_test.append(temp_fact) 244 | 
all_facts = train_facts + valid_facts + test_facts + noise_facts 245 | node_dict, node_num, rel_num=read_dict_new(e_list,r_list) 246 | all_facts,all_fact_ids= get_truth(all_facts,max_n,node_dict) 247 | train_facts, train_real = facts_to_id(train_facts,max_n,node_dict, torch.ones([len(train_facts), 2*max_n-1])) 248 | valid_facts, _= facts_to_id(valid_facts,max_n,node_dict, torch.ones(is_true.shape)) 249 | test_facts, test_real= facts_to_id(test_facts,max_n,node_dict, torch.ones([len(test_facts), 2*max_n-1])) 250 | noise_facts, noise_real = facts_to_id(noise_facts,max_n,node_dict, is_true) 251 | _, test_noise_real = facts_to_id(noise_test, max_n, node_dict, is_true_test) 252 | for i in range(len(train_facts)): 253 | train_facts[i] += noise_facts[i] 254 | for i in range(len(train_real)): 255 | train_real[i] += noise_real[i] 256 | for i in range(len(test_real)): 257 | test_real[i] = test_noise_real[i] + test_real[i] 258 | input_info=dict() 259 | input_info['all_facts']=all_facts 260 | input_info['all_fact_ids']=all_fact_ids 261 | input_info['train_facts']=train_facts 262 | input_info['train_real'] = train_real 263 | input_info['test_real'] = test_real 264 | input_info['valid_facts']=valid_facts 265 | input_info['test_facts']=test_facts 266 | input_info['node_dict']=node_dict 267 | input_info['node_num']=node_num 268 | input_info['rel_num']=rel_num 269 | input_info['max_n']=max_n 270 | return input_info 271 | 272 | def truth_to_id(all_facts, ins_ent_ids, onto_ent_ids): 273 | typing=dict() 274 | for fact in all_facts: 275 | if fact[0] not in typing.keys(): 276 | typing[fact[0]]=list() 277 | if onto_ent_ids[fact[2]] not in typing[fact[0]]: 278 | typing[fact[0]].append(onto_ent_ids[fact[2]]) 279 | typing_id=dict() 280 | for key in typing.keys(): 281 | typing_id[str(ins_ent_ids[key])]=typing[key] 282 | return typing_id 283 | 284 | def read_input(folder, initiatial_amount): 285 | ins_info = get_input(folder +"/n-ary_train.json", folder + "/n-ary_valid.json", folder + 
"/n-ary_test.json", folder + "/n-ary_noise.json", initiatial_amount) 286 | 287 | logger.info("Number of ins_all fact_ids: "+str(len(ins_info['all_fact_ids']))) 288 | logger.info("Number of ins_train facts: "+str(len(ins_info['train_facts'][0]))) 289 | logger.info("Number of ins_valid facts: "+str(len(ins_info['valid_facts'][0]))) 290 | logger.info("Number of ins_test facts: "+str(len(ins_info['test_facts'][0]))) 291 | logger.info("Number of ins nodes: "+str(ins_info['node_num'])) 292 | logger.info("Number of ins relations: "+str(ins_info['rel_num'])) 293 | logger.info("Number of ins max_n: "+str(ins_info['max_n'])) 294 | logger.info("Number of ins max_seq_length: "+str(2*ins_info['max_n']-1)) 295 | 296 | 297 | return ins_info 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import os 4 | 5 | # os.environ['CUDA_VISIBLE_DEVICES'] = '3' 6 | from utils.args import ArgumentGroup, print_arguments 7 | import logging 8 | from reader.data_reader import read_input 9 | from reader.data_loader import prepare_EC_info, get_edge_labels 10 | from model.NYLON import NYLON 11 | import time 12 | import math 13 | import random 14 | import torch 15 | import torch.utils.data.dataset as Dataset 16 | import torch.utils.data.dataloader as DataLoader 17 | import copy 18 | import numpy as np 19 | from itertools import cycle 20 | from utils.evaluation import batch_evaluation, compute_metrics 21 | 22 | torch.set_printoptions(precision=8) 23 | 24 | logging.basicConfig( 25 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 26 | datefmt='%m/%d/%Y %H:%M:%S') 27 | logger = logging.getLogger(__name__) 28 | logger.setLevel(logging.INFO) 29 | logger.info(logger.getEffectiveLevel()) 30 | 31 | parser = argparse.ArgumentParser(description='HyperKE4TI') 32 | NYLON_g = 
ArgumentGroup(parser, "model", "model and checkpoint configuration.") 33 | NYLON_g.add_arg('input', type=str, default='dataset/jf17k', help="") # db 34 | NYLON_g.add_arg('output', type=str, default='./', help="") 35 | 36 | NYLON_g.add_arg('dim', type=int, default=256, help="") 37 | NYLON_g.add_arg('onto_dim', type=int, default=256, help="") 38 | NYLON_g.add_arg('ins_layer_num', type=int, default=3, help="") 39 | NYLON_g.add_arg('onto_layer_num', type=int, default=3, help="") 40 | NYLON_g.add_arg('neg_typing_margin', type=float, default=0.1, help="") 41 | NYLON_g.add_arg('neg_triple_margin', type=float, default=0.2, help="") 42 | 43 | NYLON_g.add_arg('nums_neg', type=int, default=30, help="") 44 | NYLON_g.add_arg('mapping_neg_nums', type=int, default=30, help="") 45 | 46 | NYLON_g.add_arg('learning_rate', type=float, default=1e-4, help="") 47 | NYLON_g.add_arg('batch_size', type=int, default=1024, help="") 48 | NYLON_g.add_arg('epochs', type=int, default=100, help="") 49 | 50 | NYLON_g.add_arg('combine', type=ast.literal_eval, default=True, help="") 51 | NYLON_g.add_arg('ent_top_k', type=list, default=[1, 3, 5, 10], help="") 52 | NYLON_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.") 53 | 54 | NYLON_g.add_arg('ins_intermediate_size', type=int, default=512, help="") 55 | NYLON_g.add_arg('onto_intermediate_size', type=int, default=512, help="") 56 | NYLON_g.add_arg('num_hidden_layers', type=int, default=12, help="") 57 | NYLON_g.add_arg('num_attention_heads', type=int, default=4, help="") 58 | NYLON_g.add_arg('hidden_dropout_prob', type=float, default=0.1, help="") 59 | NYLON_g.add_arg('attention_dropout_prob', type=float, default=0.1, help="") 60 | NYLON_g.add_arg('num_edges', type=int, default=6, help="") 61 | 62 | NYLON_g.add_arg('noise_level', type=float, default=1.0, help="") 63 | NYLON_g.add_arg('active_sample_per_epoch', type=float, default=0.0025, help="") 64 | NYLON_g.add_arg('meta_lr', type=float, default=0.1, help="") 65 | 
NYLON_g.add_arg('error_detection_every_x_epochs', type=int, default=1, help="") 66 | NYLON_g.add_arg('aug_amount', type=int, default=0, help="") 67 | 68 | args = parser.parse_args() 69 | 70 | 71 | class EDataset(Dataset.Dataset): 72 | def __init__(self, triples1): 73 | self.triples1 = triples1 74 | 75 | def __len__(self): 76 | return len(self.triples1[0]) 77 | 78 | def __getitem__(self, index): 79 | return self.triples1[0][index], self.triples1[1][index], self.triples1[2][index], self.triples1[3][index], \ 80 | self.triples1[4][index] 81 | 82 | class EDataset6(Dataset.Dataset): 83 | def __init__(self, triples1): 84 | self.triples1 = triples1 85 | 86 | def __len__(self): 87 | return len(self.triples1[0]) 88 | 89 | def __getitem__(self, index): 90 | return self.triples1[0][index], self.triples1[1][index], self.triples1[2][index], self.triples1[3][index], \ 91 | self.triples1[4][index], self.triples1[5][index] 92 | 93 | def is_same(tensor1, tensor2, n_arity): 94 | list1 = tensor1.cpu().tolist() 95 | list2 = tensor2.cpu().tolist() 96 | for i in range(n_arity-1): 97 | if i == 0: 98 | tag1 = list1[0] + list1[1] + list1[2] 99 | tag2 = list2[0] + list2[1] + list2[2] 100 | tag1 = (tag1 == 3) 101 | tag2 = (tag2 == 3) 102 | if tag1 != tag2: 103 | return 0 104 | else: 105 | tag1 = list1[2 * i + 1] + list1[2 * i + 2] 106 | tag2 = list2[2 * i + 1] + list2[2 * i + 2] 107 | tag1 = (tag1 == 2) 108 | tag2 = (tag2 == 2) 109 | if tag1 != tag2: 110 | return 0 111 | return 1 112 | 113 | 114 | 115 | def main(args): 116 | config = vars(args) 117 | if args.use_cuda: 118 | device = torch.device("cuda") 119 | config["device"] = "cuda" 120 | else: 121 | device = torch.device("cpu") 122 | config["device"] = "cpu" 123 | 124 | ins_info = read_input(args.input, args.noise_level) 125 | 126 | instance_info = prepare_EC_info(ins_info, device) 127 | ins_edge_labels = get_edge_labels(ins_info['max_n']).to(device) 128 | 129 | model_normal = NYLON(instance_info, config).to(device) 130 | 131 | # 
E_train_dataloader 132 | ins_train_facts = list() 133 | for ins_train_fact in ins_info['train_facts']: 134 | ins_train_fact = torch.tensor(ins_train_fact).to(device) 135 | ins_train_facts.append(ins_train_fact) 136 | ins_train_facts.append(torch.ones(ins_train_facts[0].shape[0]).to(device)) 137 | train_data_E_reader = EDataset6(ins_train_facts) 138 | train_E_pyreader = DataLoader.DataLoader(train_data_E_reader, batch_size=args.batch_size, shuffle=True, 139 | drop_last=False) 140 | 141 | # train_information 142 | logging.info("train_ins_batch_size: " + str(args.batch_size)) 143 | logging.info("train_onto_batch_size: " + str(args.batch_size)) 144 | steps = math.ceil(len(ins_info['train_facts']) / args.batch_size) 145 | logging.info("train_steps_per_epoch: " + str(steps)) 146 | 147 | # E_valid_dataloader 148 | ins_valid_facts = list() 149 | for ins_valid_fact in ins_info['valid_facts']: 150 | ins_valid_fact = torch.tensor(ins_valid_fact).to(device) 151 | ins_valid_facts.append(ins_valid_fact) 152 | valid_data_E_reader = EDataset(ins_valid_facts) 153 | valid_E_pyreader = DataLoader.DataLoader( 154 | valid_data_E_reader, 155 | batch_size=args.batch_size, 156 | shuffle=True, 157 | drop_last=False) 158 | 159 | # E_valid_dataloader 160 | ins_test_reals = list() 161 | for ins_test_fact in ins_info['test_real']: 162 | ins_test_fact = torch.tensor(ins_test_fact).to(device) 163 | ins_test_reals.append(ins_test_fact) 164 | test_real_data_E_reader = EDataset(ins_test_reals) 165 | test_real_E_pyreader = DataLoader.DataLoader( 166 | test_real_data_E_reader, 167 | batch_size=args.batch_size, 168 | shuffle=True, 169 | drop_last=False) 170 | 171 | ins_test_facts = list() 172 | for ins_test_fact in ins_info['test_facts']: 173 | ins_test_fact = torch.tensor(ins_test_fact).to(device) 174 | ins_test_facts.append(ins_test_fact) 175 | test_data_E_reader = EDataset(ins_test_facts) 176 | test_E_pyreader = DataLoader.DataLoader( 177 | test_data_E_reader, 178 | batch_size=args.batch_size, 179 
| shuffle=True, 180 | drop_last=False) 181 | 182 | ins_real_facts = list() 183 | for ins_real_fact in ins_info['train_real']: 184 | ins_real_fact = torch.tensor(ins_real_fact).to(device) 185 | ins_real_facts.append(ins_real_fact) 186 | real_data_E_reader = EDataset(ins_real_facts) 187 | real_E_pyreader = DataLoader.DataLoader( 188 | real_data_E_reader, 189 | batch_size=args.batch_size, 190 | shuffle=True, 191 | drop_last=False) 192 | no_shuffle_reader = DataLoader.DataLoader( 193 | real_data_E_reader, 194 | batch_size=args.batch_size, 195 | shuffle=False, 196 | drop_last=False) 197 | samples = int(torch.sum(ins_real_facts[3]).int().item() * args.active_sample_per_epoch) 198 | print("sample amount : ", samples) 199 | 200 | real_indexing = dict() 201 | for index in range(ins_train_facts[0].shape[0]): 202 | key_real_indexing = ins_train_facts[0][index].cpu().numpy().tolist() 203 | key_real_indexing[ins_train_facts[2][index].item()] = ins_train_facts[3][index].item() 204 | key_real_indexing = [str(x) for x in key_real_indexing] 205 | key_real_indexing = "_".join(key_real_indexing) 206 | if key_real_indexing in real_indexing: 207 | real_indexing[key_real_indexing].append(index) 208 | else: 209 | real_indexing[key_real_indexing] = [index] 210 | 211 | 212 | # ECS_optimizers 213 | ins_optimizer_normal = torch.optim.Adam([{"params": model_normal.parameters()}], lr=config['learning_rate']) 214 | al_list = list() 215 | rel_num = ins_info["rel_num"] 216 | node_num = ins_info["node_num"] 217 | global_true_list = list() 218 | 219 | confidences = torch.empty([0, ]).to(device) 220 | positional_confidences = torch.empty([0, ]).to(device) 221 | model_normal.eval() 222 | all_reviewed_true = (torch.sum(ins_real_facts[2], dim=1) >= 1).int() + (torch.sum(ins_real_facts[1] * ins_real_facts[3], dim=1) == torch.sum(ins_real_facts[3], dim=1)).int() == 2 223 | all_reviewed_true = torch.nonzero(all_reviewed_true).cpu().numpy().tolist() 224 | for j, data in enumerate(no_shuffle_reader): 225 | 
[real_triples, is_true, is_shown, real_masks, _] = data 226 | ins_optimizer_normal.zero_grad() 227 | _, confidence, positional_confidence, _ = model_normal.forward_E(data, ins_edge_labels, "conf", 228 | None, 0) 229 | positional_confidence_temp = positional_confidence.clone().detach() 230 | positional_confidences = torch.cat([positional_confidences, positional_confidence_temp], dim=0) 231 | confidence_temp = confidence.clone().detach() 232 | confidences = torch.cat([confidences, confidence_temp]) 233 | confidences = (confidences - 0.5) ** 2 234 | confidences = confidences.view(-1) 235 | ranks = torch.argsort(confidences) 236 | true_list = list() 237 | false_list = list() 238 | global_counter = 0 239 | for index in ranks: 240 | is_right = True 241 | if torch.sum(ins_real_facts[2][index]) > 0: 242 | continue 243 | positional_confidence_selected = positional_confidences[index, :] 244 | local_ranks = torch.argsort(positional_confidence_selected) 245 | local_counter = 0 246 | for position in local_ranks: 247 | if ins_real_facts[3][index, position] == 0: 248 | continue 249 | ins_real_facts[2][index, position] = 1 250 | local_counter += 1 251 | if ins_real_facts[1][index, position] == 0: 252 | is_right = False 253 | break 254 | global_counter += local_counter 255 | if torch.sum(ins_real_facts[1][index] * ins_real_facts[3][index]) == torch.sum(ins_real_facts[3][index]): 256 | true_list.append(index.item()) 257 | else: 258 | false_list.append(index.item()) 259 | if global_counter >= samples: 260 | break 261 | 262 | true_facts = ins_real_facts[0][true_list].clone() 263 | true_true = ins_real_facts[1][true_list].clone() 264 | true_shown = ins_real_facts[2][true_list].clone() 265 | true_mask = ins_real_facts[3][true_list].clone() 266 | true_place = ins_real_facts[4][true_list].clone() 267 | false_facts = ins_real_facts[0][false_list].clone() 268 | false_true = ins_real_facts[1][false_list].clone() 269 | false_shown = ins_real_facts[2][false_list].clone() 270 | false_mask = 
ins_real_facts[3][false_list].clone() 271 | false_place = ins_real_facts[4][false_list].clone() 272 | true_num = len(true_list) 273 | false_num = len(false_list) 274 | global_true_list += true_list 275 | al_facts = torch.cat([true_facts, false_facts], dim=0) 276 | al_true = torch.cat([true_true, false_true], dim=0) 277 | al_shown = torch.cat([true_shown, false_shown], dim=0) 278 | al_mask = torch.cat([true_mask, false_mask], dim=0) 279 | al_place = torch.cat([true_place, false_place], dim=0) 280 | al_list.append([al_facts.long(), al_true, al_shown, al_mask, al_true, true_list]) 281 | if len(al_list) > 10: 282 | del al_list[0] 283 | torch.cuda.empty_cache() 284 | 285 | # Start Training 286 | iterations = 1 287 | 288 | for iteration in range(1, args.epochs // iterations + 1): 289 | logger.info("iteration " + str(iteration)) 290 | model_normal.train() 291 | correct_rate = torch.sum(ins_real_facts[2] * ins_real_facts[1]) / torch.sum(ins_real_facts[2]) 292 | for i in range(iterations): 293 | ins_epoch_loss = 0 294 | start = time.time() 295 | model_normal.train() 296 | if iteration % args.error_detection_every_x_epochs == 0: 297 | print("Start training the cross-grained confidence evaluator") 298 | tuple_embeddings = torch.empty( 299 | [0, model_normal.ins_config["hidden_size"] * (2 * ins_info['max_n'] - 1)]).to(device) 300 | for j, data in enumerate(no_shuffle_reader): 301 | _, _, _, temp_embeddings = model_normal.forward_E(data, ins_edge_labels, "conf", None, 0) 302 | tuple_embeddings = torch.cat([tuple_embeddings, temp_embeddings.clone().detach()], dim=0) 303 | 304 | for j in range(len(al_list)): 305 | id_facts_aug = torch.empty([0, al_facts.shape[1]]).to(device) 306 | is_true_aug = torch.empty([0, al_true.shape[1]]).to(device) 307 | is_shown_aug = torch.empty([0, al_shown.shape[1]]).to(device) 308 | id_masks_aug = torch.empty([0, al_mask.shape[1]]).to(device) 309 | id_facts_aug_false = torch.empty([0, al_facts.shape[1]]).to(device) 310 | is_true_aug_false = 
torch.empty([0, al_true.shape[1]]).to(device) 311 | is_shown_aug_false = torch.empty([0, al_shown.shape[1]]).to(device) 312 | id_masks_aug_false = torch.empty([0, al_mask.shape[1]]).to(device) 313 | last_loss = 0 314 | true_list = al_list[j][5] 315 | root_embeddings = tuple_embeddings[true_list] 316 | belong_list = dict() 317 | for embedding_index in range(tuple_embeddings.shape[0]): 318 | closest_center = true_list[torch.argmin( 319 | torch.norm(root_embeddings - tuple_embeddings[embedding_index].view(1, -1), dim=1))] 320 | if closest_center in belong_list: 321 | belong_list[closest_center].append(embedding_index) 322 | else: 323 | belong_list[closest_center] = [embedding_index] 324 | correct_2 = 0 325 | total_2 = 0 326 | correct_5 = 0 327 | total_5 = 0 328 | correct_10 = 0 329 | total_10 = 0 330 | correct_20 = 0 331 | total_20 = 0 332 | correct_50 = 0 333 | total_50 = 0 334 | aug_amount = 0 335 | for center in belong_list: 336 | if torch.sum(ins_real_facts[1][center] * ins_real_facts[3][center]) != torch.sum(ins_real_facts[3][center]): 337 | print("error occur!") 338 | print(ins_real_facts[1][center]) 339 | continue 340 | belong_embeddings = tuple_embeddings[belong_list[center]] 341 | aug_distance = torch.norm(belong_embeddings - tuple_embeddings[center], dim=1) 342 | aug_distance[torch.sum(ins_real_facts[2][belong_list[center]], dim=1) > 0] = 999999999 343 | temp_ranks = torch.tensor(belong_list[center], device=device)[ 344 | torch.argsort(aug_distance, descending=False)] 345 | correct_2 += torch.sum((torch.sum(ins_real_facts[1][temp_ranks[:2]] * ins_real_facts[3][temp_ranks[:2]], dim=1) == torch.sum(ins_real_facts[3][temp_ranks[:2]], dim=1)).int()) 346 | correct_5 += torch.sum((torch.sum(ins_real_facts[1][temp_ranks[:5]] * ins_real_facts[3][temp_ranks[:5]], dim=1) == torch.sum(ins_real_facts[3][temp_ranks[:5]], dim=1)).int()) 347 | correct_10 += torch.sum((torch.sum(ins_real_facts[1][temp_ranks[:10]] * ins_real_facts[3][temp_ranks[:10]], dim=1) == 
torch.sum(ins_real_facts[3][temp_ranks[:10]], dim=1)).int()) 348 | correct_20 += torch.sum((torch.sum(ins_real_facts[1][temp_ranks[:20]] * ins_real_facts[3][temp_ranks[:20]], dim=1) == torch.sum(ins_real_facts[3][temp_ranks[:20]], dim=1)).int()) 349 | correct_50 += torch.sum((torch.sum(ins_real_facts[1][temp_ranks[:50]] * ins_real_facts[3][temp_ranks[:50]], dim=1) == torch.sum(ins_real_facts[3][temp_ranks[:50]], dim=1)).int()) 350 | total_2 += 2 351 | total_5 += 5 352 | total_10 += 10 353 | total_20 += 20 354 | total_50 += 50 355 | temp_ranks = temp_ranks.cpu().numpy().tolist() 356 | # print(temp_ranks) 357 | while True: 358 | if len(temp_ranks) < args.aug_amount: 359 | temp_ranks.append(random.choice(temp_ranks)) 360 | else: 361 | break 362 | temp_ranks = torch.tensor(temp_ranks, dtype=torch.long, device=device) 363 | temp_ranks = temp_ranks[:args.aug_amount] 364 | id_facts_aug = torch.cat([id_facts_aug, ins_real_facts[0][temp_ranks].clone().detach()], dim=0) 365 | is_true_aug = torch.cat( 366 | [is_true_aug, ins_real_facts[3][temp_ranks].clone().detach()], dim=0) 367 | is_shown_aug = torch.cat( 368 | [is_shown_aug, ins_real_facts[3][temp_ranks].clone().detach()], dim=0) 369 | id_masks_aug = torch.cat([id_masks_aug, ins_real_facts[3][temp_ranks].clone().detach()], dim=0) 370 | aug_amount += temp_ranks.shape[0] * 2 371 | for temp_rank_index in temp_ranks: 372 | temp_fact = ins_real_facts[0][temp_rank_index].clone().detach() 373 | # print("temp_fact_before: ", temp_fact) 374 | temp_true = ins_real_facts[3][temp_rank_index].clone().detach() 375 | temp_shown = ins_real_facts[3][temp_rank_index].clone().detach() 376 | temp_fact_mask = ins_real_facts[3][temp_rank_index].clone().detach() 377 | replace_num = int(random.randint(1, torch.sum(temp_fact_mask).item() - 1) / 2) 378 | if replace_num == 0: 379 | replace_num = 1 380 | for replace_index in range(replace_num): 381 | replace_pos = random.randint(0, torch.sum(temp_fact_mask).item() - 1) 382 | temp_true[replace_pos] = 
0 383 | if replace_pos % 2 == 0: 384 | random_replace = random.randint(rel_num + 2, node_num - 1) 385 | while True: 386 | if random_replace != temp_fact[replace_pos]: 387 | break 388 | random_replace = random.randint(rel_num + 2, node_num - 1) 389 | temp_fact[replace_pos] = random_replace 390 | else: 391 | random_replace = random.randint(2, rel_num + 1) 392 | while True: 393 | if random_replace != temp_fact[replace_pos]: 394 | break 395 | random_replace = random.randint(2, rel_num + 1) 396 | temp_fact[replace_pos] = random_replace 397 | id_facts_aug_false = torch.cat([id_facts_aug_false, temp_fact.view(1, -1)], dim=0) 398 | is_true_aug_false = torch.cat([is_true_aug_false, temp_true.view(1, -1)], dim=0) 399 | is_shown_aug_false = torch.cat([is_shown_aug_false, temp_shown.view(1, -1)], dim=0) 400 | id_masks_aug_false = torch.cat([id_masks_aug_false, temp_fact_mask.view(1, -1)], dim=0) 401 | correct_rate_j = torch.sum((torch.sum(al_list[j][1]*al_list[j][3], dim=1) == torch.sum(al_list[j][3], dim=1)).int()).item() 402 | total_j = al_list[j][0].shape[0] 403 | # print("id_facts_aug_false: ", id_facts_aug_false.shape[0]) 404 | # print("id_facts_aug: ", id_facts_aug.shape[0]) 405 | # print("correct_rate_j: ", correct_rate_j) 406 | # print("total_j: ", total_j) 407 | if correct_rate_j > 0.5 * total_j: 408 | retain_count = int((total_j - correct_rate_j) * id_facts_aug_false.shape[0] / correct_rate_j) 409 | retain_index = random.sample(range(id_facts_aug_false.shape[0]), retain_count) 410 | # print(max(retain_index)) 411 | # print(id_facts_aug_false.shape[0]) 412 | id_facts_aug_false = id_facts_aug_false[retain_index] 413 | is_true_aug_false = is_true_aug_false[retain_index] 414 | is_shown_aug_false = is_shown_aug_false[retain_index] 415 | id_masks_aug_false = id_masks_aug_false[retain_index] 416 | else: 417 | retain_count = int(correct_rate_j * id_facts_aug.shape[0] / (total_j - correct_rate_j)) 418 | retain_index = random.sample(range(id_facts_aug.shape[0]), retain_count) 
419 | # print(max(retain_index)) 420 | # print(id_facts_aug.shape[0]) 421 | id_facts_aug = id_facts_aug[retain_index] 422 | is_true_aug = is_true_aug[retain_index] 423 | is_shown_aug = is_shown_aug[retain_index] 424 | id_masks_aug = id_masks_aug[retain_index] 425 | id_facts_aug = torch.cat([id_facts_aug, id_facts_aug_false], dim=0) 426 | is_true_aug = torch.cat([is_true_aug, is_true_aug_false], dim=0) 427 | is_shown_aug = torch.cat([is_shown_aug, is_shown_aug_false], dim=0) 428 | id_masks_aug = torch.cat([id_masks_aug, id_masks_aug_false], dim=0) 429 | # print("closest 2: ", correct_2 / total_2) 430 | # print("closest 5: ", correct_5 / total_5) 431 | # print("closest 10: ", correct_10 / total_10) 432 | # print("closest 20: ", correct_20 / total_20) 433 | # print("closest 50: ", correct_50 / total_50) 434 | # print("aug_amount: ", aug_amount) 435 | # print("correct_rate: ", correct_rate_j/total_j) 436 | 437 | # temp_states = list() 438 | raw_state = model_normal.state_dict() 439 | while True: 440 | [id_facts, is_true, is_shown, id_masks, place_mask, true_list] = al_list[j] 441 | loss_item = 0 442 | # print(id_facts_aug.shape[0]) 443 | if id_facts.shape[0] == 0: 444 | pass 445 | else: 446 | ins_pos_final = [id_facts_aug.long(), is_true_aug, is_shown_aug, id_masks_aug, is_true_aug] 447 | for concat_index in range(len(ins_pos_final)): 448 | ins_pos_final[concat_index] = torch.cat([ins_pos_final[concat_index], al_list[j][concat_index]], dim=0) 449 | aug_dataset = EDataset(ins_pos_final) 450 | aug_reader = DataLoader.DataLoader( 451 | aug_dataset, 452 | batch_size=args.batch_size, 453 | shuffle=True, 454 | drop_last=False) 455 | for aug_index, data in enumerate(aug_reader): 456 | ins_optimizer_normal.zero_grad() 457 | ins_loss_conf, _, fc_out_vector, _ = model_normal.forward_E(data, 458 | ins_edge_labels, 459 | "conf", None, 460 | correct_rate) 461 | ins_loss_pos_nong = torch.nn.BCELoss(reduction="none")(fc_out_vector * data[2], 462 | data[1] * data[ 463 | 2]) 464 | 
ins_loss_pos_nong = torch.sum(ins_loss_pos_nong, dim=1) / torch.sum(data[2], 465 | dim=1) 466 | ins_loss_pos = torch.mean(ins_loss_pos_nong) 467 | ins_loss = ins_loss_pos + ins_loss_conf 468 | ins_loss.backward() 469 | ins_optimizer_normal.step() 470 | loss_item += ins_loss.item() 471 | if j % 1 == 0: 472 | logger.info( 473 | str(j) + ' , ins_loss_conf: ' + str(ins_loss.item()) + " with memory " + str( 474 | torch.cuda.memory_allocated(device=device) / 1024 / 1024)) 475 | if abs(loss_item - last_loss) / loss_item <= 0.05: 476 | break 477 | else: 478 | last_loss = loss_item 479 | temp_state = model_normal.state_dict() 480 | for key in temp_state: 481 | temp_state[key] = raw_state[key] + args.meta_lr * (temp_state[key] - raw_state[key]) 482 | model_normal.load_state_dict(temp_state) 483 | del temp_state 484 | del raw_state 485 | torch.cuda.empty_cache() 486 | 487 | print("Start active learning with effort-efficient active labeler") 488 | confidences = torch.empty([0, ]).to(device) 489 | positional_confidences = torch.empty([0, ]).to(device) 490 | model_normal.eval() 491 | all_reviewed_true = (torch.sum(ins_real_facts[2], dim=1) >= 1).int() + ( 492 | torch.sum(ins_real_facts[1] * ins_real_facts[3], dim=1) == torch.sum(ins_real_facts[3], 493 | dim=1)).int() == 2 494 | for j, data in enumerate(no_shuffle_reader): 495 | ins_optimizer_normal.zero_grad() 496 | _, confidence, positional_confidence, _ = model_normal.forward_E(data, ins_edge_labels, "conf", 497 | None, 0) 498 | positional_confidence_temp = positional_confidence.clone().detach() 499 | positional_confidences = torch.cat([positional_confidences, positional_confidence_temp], dim=0) 500 | confidence_temp = confidence.clone().detach() 501 | for update_index in range(data[0].shape[0]): 502 | update_key = data[0][update_index].cpu().numpy().tolist() 503 | update_key = [str(x) for x in update_key] 504 | update_key = "_".join(update_key) 505 | ins_train_facts[5][real_indexing[update_key]] = 
confidence_temp[update_index].clone() 506 | confidences = torch.cat([confidences, confidence_temp]) 507 | confidences = (confidences - 0.5) ** 2 508 | confidences = confidences.view(-1) 509 | ranks = torch.argsort(confidences) 510 | true_list = list() 511 | false_list = list() 512 | global_counter = 0 513 | for index in ranks: 514 | is_right = True 515 | if torch.sum(ins_real_facts[2][index]) > 0: 516 | continue 517 | positional_confidence_selected = positional_confidences[index, :] 518 | local_ranks = torch.argsort(positional_confidence_selected) 519 | local_counter = 0 520 | for position in local_ranks: 521 | if ins_real_facts[3][index, position] == 0: 522 | continue 523 | ins_real_facts[2][index, position] = 1 524 | local_counter += 1 525 | if ins_real_facts[1][index, position] == 0: 526 | is_right = False 527 | break 528 | global_counter += local_counter 529 | if torch.sum(ins_real_facts[1][index] * ins_real_facts[3][index]) == torch.sum( 530 | ins_real_facts[3][index]): 531 | true_list.append(index.item()) 532 | else: 533 | false_list.append(index.item()) 534 | if global_counter >= samples: 535 | break 536 | 537 | true_facts = ins_real_facts[0][true_list].clone() 538 | true_true = ins_real_facts[1][true_list].clone() 539 | true_shown = ins_real_facts[2][true_list].clone() 540 | true_mask = ins_real_facts[3][true_list].clone() 541 | true_place = ins_real_facts[4][true_list].clone() 542 | false_facts = ins_real_facts[0][false_list].clone() 543 | false_true = ins_real_facts[1][false_list].clone() 544 | false_shown = ins_real_facts[2][false_list].clone() 545 | false_mask = ins_real_facts[3][false_list].clone() 546 | false_place = ins_real_facts[4][false_list].clone() 547 | global_true_list += true_list 548 | 549 | al_facts = torch.cat([true_facts, false_facts], dim=0) 550 | al_true = torch.cat([true_true, false_true], dim=0) 551 | al_shown = torch.cat([true_shown, false_shown], dim=0) 552 | al_mask = torch.cat([true_mask, false_mask], dim=0) 553 | 
al_list.append([al_facts.long(), al_true, al_shown, al_mask, al_true, true_list]) 554 | if len(al_list) > 10: 555 | del al_list[0] 556 | torch.cuda.empty_cache() 557 | 558 | batch_num = 0 559 | sum_cos = 0 560 | average_margin = 0 561 | right_sum = 0 562 | wrong_sum = 0 563 | right_cos = 0 564 | wrong_cos = 0 565 | 566 | right_right_fact = 0 567 | right_false_fact = 0 568 | false_false_fact = 0 569 | false_right_fact = 0 570 | 571 | right_right_element = 0 572 | right_false_element = 0 573 | false_false_element = 0 574 | false_right_element = 0 575 | 576 | model_normal.eval() 577 | for _, data in enumerate(test_real_E_pyreader): 578 | _, confidences_tag, fc_out_vector, _ = model_normal.forward_E(data, ins_edge_labels, "conf", None, 579 | correct_rate) 580 | tags = data[4] 581 | fc_out_vector[fc_out_vector >= 0.5] = 1 582 | fc_out_vector[fc_out_vector < 0.5] = 0 583 | confidences_tag[confidences_tag >= 0.5] = 1 584 | confidences_tag[confidences_tag < 0.5] = 0 585 | input_masks = data[3].squeeze() 586 | fc_out_vector = fc_out_vector * input_masks 587 | tags = tags * input_masks 588 | for i in range(fc_out_vector.shape[0]): 589 | output = fc_out_vector[i, :] 590 | tag = tags[i, :] 591 | mask = input_masks[i, :] 592 | conf_tag = confidences_tag[i] 593 | if torch.sum(tag) == torch.sum(mask): 594 | total_tag = 1 595 | else: 596 | total_tag = 0 597 | batch_num += 1 598 | is_same_num = is_same(output, tag, ins_info['max_n']) 599 | sum_cos += is_same_num 600 | if mask.equal(tag): 601 | right_sum += 1 602 | right_cos += is_same_num 603 | else: 604 | wrong_sum += 1 605 | wrong_cos += is_same_num 606 | 607 | if total_tag == 1 and conf_tag == 1: 608 | right_right_fact += 1 609 | if total_tag == 1 and conf_tag == 0: 610 | right_false_fact += 1 611 | if total_tag == 0 and conf_tag == 1: 612 | false_right_fact += 1 613 | if total_tag == 0 and conf_tag == 0: 614 | false_false_fact += 1 615 | 616 | for index_conf in range(torch.sum(mask).int().item()): 617 | if tag[index_conf] == 1 
and output[index_conf] == 1: 618 | right_right_element += 1 619 | if tag[index_conf] == 1 and output[index_conf] == 0: 620 | right_false_element += 1 621 | if tag[index_conf] == 0 and output[index_conf] == 1: 622 | false_right_element += 1 623 | if tag[index_conf] == 0 and output[index_conf] == 0: 624 | false_false_element += 1 625 | 626 | model_normal.train() 627 | # conf_sum = 0 628 | print("Start training hyper-relational link predictor") 629 | for j, data in enumerate(train_E_pyreader): 630 | [id_facts, id_masks, mask_pos, mask_labels, mask_types, confidence] = data 631 | id_facts_temp = copy.deepcopy(id_facts) 632 | id_facts_len = id_facts_temp.shape[0] 633 | id_facts_temp[list(range(id_facts_len)), mask_pos] = mask_labels 634 | ins_optimizer_normal.zero_grad() 635 | bs = id_facts_temp.shape[0] 636 | ins_pos_normal = [id_facts, id_masks, mask_pos, mask_labels, mask_types] 637 | ins_loss, _ = model_normal.forward_E(ins_pos_normal, ins_edge_labels, "normal", 2 * confidence, correct_rate) 638 | ins_loss.backward() 639 | ins_optimizer_normal.step() 640 | ins_epoch_loss += ins_loss 641 | 642 | # print_ECS_loss_per_step 643 | if j % 100 == 0: 644 | logger.info(str(j) + ' , ins_loss: ' + str(ins_loss.item()) + " with memory " + str(torch.cuda.memory_allocated(device=device) / 1024 / 1024)) 645 | 646 | ins_epoch_loss /= steps 647 | end = time.time() 648 | t2 = round(end - start, 2) 649 | logger.info("ins_epoch_loss = {:.3f}, time = {:.3f} s".format(ins_epoch_loss, t2)) 650 | 651 | 652 | # Start validation and testing 653 | with torch.no_grad(): 654 | h2E = predict( 655 | model=model_normal, 656 | ins_test_pyreader=test_E_pyreader, 657 | ins_all_facts=ins_info['all_facts'], 658 | ins_edge_labels=ins_edge_labels, 659 | device=device) 660 | 661 | if iteration % args.error_detection_every_x_epochs == 0: 662 | print("accuracy_fact: ", (right_right_fact + false_false_fact) / (right_right_fact + false_false_fact + false_right_fact + right_false_fact)) 663 | precision_fact = 
right_right_fact / (right_right_fact + false_right_fact) 664 | recall_fact = right_right_fact / (right_right_fact + right_false_fact) 665 | print("precision_fact: ", precision_fact) 666 | print("recall_fact: ", recall_fact) 667 | print("F1_fact: ", (2 * precision_fact * recall_fact) / (precision_fact + recall_fact)) 668 | print("accuracy_element: ", (right_right_element + false_false_element) / (right_right_element + false_false_element + false_right_element + right_false_element)) 669 | precision_element = right_right_element / (right_right_element + false_right_element) 670 | recall_element = right_right_element / (right_right_element + right_false_element) 671 | print("precision_element: ", precision_element) 672 | print("recall_element: ", recall_element) 673 | print("F1_element: ", (2 * precision_element * recall_element) / (precision_element + recall_element)) 674 | 675 | logger.info("stop") 676 | 677 | 678 | def predict(model, ins_test_pyreader, 679 | ins_all_facts, 680 | ins_edge_labels, device): 681 | start = time.time() 682 | 683 | step = 0 684 | ins_ret_ranks = dict() 685 | ins_ret_ranks['entity'] = torch.empty(0).to(device) 686 | ins_ret_ranks['relation'] = torch.empty(0).to(device) 687 | ins_ret_ranks['2-r'] = torch.empty(0).to(device) 688 | ins_ret_ranks['2-ht'] = torch.empty(0).to(device) 689 | ins_ret_ranks['n-r'] = torch.empty(0).to(device) 690 | ins_ret_ranks['n-ht'] = torch.empty(0).to(device) 691 | ins_ret_ranks['n-a'] = torch.empty(0).to(device) 692 | ins_ret_ranks['n-v'] = torch.empty(0).to(device) 693 | 694 | # while steps < max_train_steps: 695 | for i, data in enumerate(ins_test_pyreader): 696 | ins_pos = data 697 | length = data[0].shape[0] 698 | _, ins_np_fc_out = model.forward_E(ins_pos, ins_edge_labels, "normal", torch.ones(length).to("cuda:0"), 0) 699 | 700 | ins_ret_ranks = batch_evaluation(ins_np_fc_out, ins_pos, ins_all_facts, ins_ret_ranks, device) 701 | 702 | step += 1 703 | 704 | ins_eval_performance = 
compute_metrics(ins_ret_ranks) 705 | 706 | ins_all_entity = "ENTITY\t\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % ( 707 | ins_eval_performance['entity']['mrr'], 708 | ins_eval_performance['entity']['hits1'], 709 | ins_eval_performance['entity']['hits3'], 710 | ins_eval_performance['entity']['hits5'], 711 | ins_eval_performance['entity']['hits10']) 712 | 713 | ins_all_relation = "RELATION\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % ( 714 | ins_eval_performance['relation']['mrr'], 715 | ins_eval_performance['relation']['hits1'], 716 | ins_eval_performance['relation']['hits3'], 717 | ins_eval_performance['relation']['hits5'], 718 | ins_eval_performance['relation']['hits10']) 719 | 720 | ins_all_ht = "HEAD/TAIL\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % ( 721 | ins_eval_performance['ht']['mrr'], 722 | ins_eval_performance['ht']['hits1'], 723 | ins_eval_performance['ht']['hits3'], 724 | ins_eval_performance['ht']['hits5'], 725 | ins_eval_performance['ht']['hits10']) 726 | 727 | ins_all_r = "PRIMARY_R\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % ( 728 | ins_eval_performance['r']['mrr'], 729 | ins_eval_performance['r']['hits1'], 730 | ins_eval_performance['r']['hits3'], 731 | ins_eval_performance['r']['hits5'], 732 | ins_eval_performance['r']['hits10']) 733 | 734 | logger.info("\n-------- E Evaluation Performance --------\n%s\n%s\n%s\n%s\n%s" % ( 735 | "\t".join(["TASK\t", "MRR", "Hits@1", "Hits@3", "Hits@5", "Hits@10"]), 736 | ins_all_ht, ins_all_r, ins_all_entity, ins_all_relation)) 737 | 738 | end = time.time() 739 | logger.info("INS time: " + str(round(end - start, 3)) + 's') 740 | 741 | return ins_eval_performance['entity']['hits1'] 742 | 743 | 744 | if __name__ == '__main__': 745 | print_arguments(args) 746 | main(args) 747 | -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arguments for configuration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import six

# Importing this module configures the root logger at INFO level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger(__name__)


def str2bool(v):
    """Convert a command-line string to a boolean.

    argparse cannot turn strings such as "true" or "False" into ``bool`` by
    itself (``bool("False")`` is ``True``), so this is used as the ``type``
    callable instead. Matching is case-insensitive: "true", "t" and "1" give
    True; any other string gives False.
    """
    normalized = v.lower()
    truthy = ("true", "t", "1")
    return normalized in truthy


class ArgumentGroup(object):
    """Wrapper that registers ``--name`` options on a single argparse group."""

    def __init__(self, parser, title, des):
        # ``des`` is shown as the group description in --help output.
        self._group = parser.add_argument_group(title=title, description=des)

    def add_arg(self, name, type, default, help, **kwargs):
        """Add the option ``--<name>`` to the wrapped group.

        A ``bool`` type is replaced by :func:`str2bool` so that strings such
        as "False" parse as expected. The help text is suffixed with the
        default value (argparse expands ``%(default)s``). Any extra keyword
        arguments are forwarded to ``add_argument`` unchanged.
        """
        if type == bool:
            type = str2bool
        help_text = help + ' Default: %(default)s.'
        flag = "--" + name
        self._group.add_argument(
            flag, default=default, type=type, help=help_text, **kwargs)


def print_arguments(args):
    """Log every parsed option as ``name: value`` in name order, framed by banner lines."""
    logger.info('----------- Configuration Arguments -----------')
    for pair in sorted(six.iteritems(vars(args))):
        logger.info('%s: %s' % pair)
    logger.info('------------------------------------------------')
# --------------------------------------------------------------------------------
# /utils/evaluation.py:
# --------------------------------------------------------------------------------
import time
import numpy as np
import torch
torch.set_printoptions(precision=8)

# Cut-offs reported by compute_metrics as 'hits1', 'hits3', 'hits5', 'hits10'.
_HITS_AT = (1, 3, 5, 10)


def compute_hyperbolic_distances(vectors_u, vectors_v):
    """
    Compute row-wise distances between two sets of vectors.
    Modified based on gensim code.

    NOTE: despite the name, this returns the plain Euclidean (L2) distance
    ``||u_i - v_i||``; no hyperbolic formula is applied. The inputs
    broadcast, so one (1, dim) row can be compared with every row of a
    (batch_size, dim) matrix.

    vectors_u: (batch_size, dim) or (1, dim)
    vectors_v: (batch_size, dim)
    Returns: (batch_size, ) array of distances.
    """
    return np.linalg.norm(vectors_u - vectors_v, axis=1)


def normalization(data):
    """
    Min-max scale ``data`` into [0, 1] over all of its elements.

    A constant input (max == min) returns all zeros. Without this guard the
    0 / 0 division would produce NaNs and a RuntimeWarning.
    """
    data_min = np.min(data)
    _range = np.max(data) - data_min
    if _range == 0:
        return np.zeros_like(data, dtype=float)
    return (data - data_min) / _range


def compute_hyperbolic_similarity(embeds1, embeds2):
    """
    Return an (x1, x2) similarity matrix between the rows of two embedding sets.

    Similarity is the negated pairwise distance from
    ``compute_hyperbolic_distances``, min-max normalised over the whole
    matrix, so 1 marks the closest pair and 0 the farthest.

    embeds1: (x1, dim) array
    embeds2: (x2, dim) array
    Raises ValueError when the two embedding widths differ.
    """
    x1, y1 = embeds1.shape
    x2, y2 = embeds2.shape
    if y1 != y2:
        raise ValueError(
            "embedding widths differ: %d vs %d" % (y1, y2))
    if x1 == 0:
        # np.vstack cannot stack an empty list.
        return np.empty((0, x2))
    # Process one query row at a time. This keeps memory at O(x2 * dim)
    # rather than building the full (x1, x2, dim) difference tensor.
    # Broadcasting replaces the explicit np.repeat copy.
    dist_vec_list = [
        compute_hyperbolic_distances(embeds1[i:i + 1], embeds2)
        for i in range(x1)
    ]
    dis_mat = np.vstack(dist_vec_list)  # (x1, x2)
    return normalization(-dis_mat)


def cal_rank_hyperbolic(frags, sub_embed, embed, multi_types_list, top_k, greedy):
    """
    Rank every candidate for each query and accumulate ranking statistics.

    frags: query ids. Row ``i`` of ``sub_embed`` belongs to ``frags[i]``,
        and ``multi_types_list[frags[i]]`` holds its gold candidate indices.
    sub_embed: (n_query, dim) query embeddings.
    embed: (n_candidate, dim) candidate embeddings.
    top_k: cut-offs for the hits counters.
    greedy: if True, each query is scored once, by its best-ranked gold
        candidate. A query with no gold candidates gets 0-based index
        ``n_candidate``. If False, every gold candidate is scored as its own
        test case.

    Returns (mr, mrr, hits, results, test_num):
        mr, mrr: summed (not averaged) rank and reciprocal rank.
        hits: hit counts per cut-off.
        results: set of (query id, top-1 candidate) pairs.
        test_num: number of scored cases.
    """
    onto_number = embed.shape[0]
    mr = 0
    mrr = 0
    hits = np.zeros(len(top_k), dtype=int)
    sim_mat = compute_hyperbolic_similarity(sub_embed, embed)
    results = set()
    test_num = sub_embed.shape[0]
    for i, ref in enumerate(frags):
        # Candidate indices ordered from most to least similar.
        rank = (-sim_mat[i, :]).argsort()
        results.add((ref, rank[0]))
        multi_types = multi_types_list[ref]
        positions = [np.where(rank == item)[0][0] for item in multi_types]
        if greedy:
            scored = [min(positions, default=onto_number)]
        else:
            scored = positions
            # Each extra gold type counts as an additional test case.
            test_num += len(multi_types) - 1
        for rank_index in scored:
            mr += rank_index + 1
            mrr += 1 / (rank_index + 1)
            for j, k in enumerate(top_k):
                if rank_index < k:
                    hits[j] += 1
    return mr, mrr, hits, results, test_num


def eval_type_hyperbolic(embed1, embed2, ent_type, top_k, greedy=True):
    """
    Score ``embed2`` candidates for each ``embed1`` query; return MRR and hits@k.

    ent_type[i] lists the gold candidate indices for query ``i``. ``top_k``
    needs at least four cut-offs. They are reported by position as 'hits1',
    'hits3', 'hits5' and 'hits10', so normally top_k == [1, 3, 5, 10].
    Hits values are rounded to 4 decimals.
    """
    ref_num = len(embed1)
    frags = np.arange(ref_num)
    _, mrr, hits, _, total_test_num = cal_rank_hyperbolic(
        frags, embed1, embed2, ent_type, top_k, greedy)

    if greedy:
        # Greedy scoring yields exactly one test case per query.
        assert total_test_num == ref_num
    else:
        print("multi types:", total_test_num - ref_num)

    hits = np.round(hits / total_test_num, 4)
    mrr /= total_test_num

    eval_performance = dict()
    eval_performance['mrr'] = mrr
    eval_performance['hits1'] = hits[0]
    eval_performance['hits3'] = hits[1]
    eval_performance['hits5'] = hits[2]
    eval_performance['hits10'] = hits[3]
    return eval_performance


def _append_rank(ret_ranks, bucket, rank):
    """Concatenate the ``rank`` tensor onto ``ret_ranks[bucket]`` (dict updated in place)."""
    ret_ranks[bucket] = torch.cat([ret_ranks[bucket], rank], dim=0)


def batch_evaluation(batch_results, all_facts, gt_dict, ret_ranks, device):
    """
    Perform batch evaluation under the filtered setting.

    batch_results: one score vector over candidate ids per fact. Each row is
        modified in place: filtered candidates are set to -inf.
    all_facts: indexable batch with these entries for fact i:
        [0][i]: slot ids.
        [1][i]: slot mask; its sum is treated as the arity.
        [2][i]: masked position.
        [3][i]: gold id.
        [4][i]: mask type (1 = entity, -1 = relation).
    gt_dict: ``gt_dict[pos][key]`` lists the known answers for masked
        position ``pos`` given the other slots. ``key`` is the other slots'
        ids joined by spaces. Every listed answer except the gold id is
        filtered out.
    ret_ranks: dict of 1-D rank tensors, extended and returned.
    device: device on which the filter index tensors are created.

    Each fact appends its 1-based rank to 'entity' or 'relation', and to one
    arity/position bucket: '2-r' or '2-ht' when the arity is 3, otherwise
    'n-r', 'n-ht', 'n-a' or 'n-v' when the arity is greater than 3.

    Raises ValueError for an unknown mask type, masked position or arity.
    """
    # Substitute for the gold id inside the filter list (see below).
    zero = torch.tensor(0).to(device)
    for i, result in enumerate(batch_results):
        ids = all_facts[0][i]
        arity = torch.sum(all_facts[1][i])
        pos = all_facts[2][i]
        target = all_facts[3][i]
        mask_type = all_facts[4][i]
        key = " ".join(
            str(ids[x].item()) for x in range(len(ids)) if x != pos)

        # Filtered setting: every known answer except the gold one gets -inf.
        # The gold id is swapped for 0 so that it survives. As a result,
        # index 0 is filled as well (presumably a reserved/padding id;
        # TODO confirm). The explicit long dtype keeps an empty answer list
        # usable as an index.
        rm_idx = torch.tensor(gt_dict[pos.item()][key], dtype=torch.long).to(device)
        rm_idx = torch.where(rm_idx != target, rm_idx, zero)
        result.index_fill_(0, rm_idx, float("-inf"))  # np.Inf was removed in NumPy 2.0

        sortidx = torch.argsort(result, dim=-1, descending=True)
        rank = torch.where(sortidx == target)[0] + 1

        if mask_type == 1:
            _append_rank(ret_ranks, 'entity', rank)
        elif mask_type == -1:
            _append_rank(ret_ranks, 'relation', rank)
        else:
            raise ValueError("Invalid `feature.mask_type`.")

        if arity == 3:
            if pos == 1:
                bucket = '2-r'
            elif pos == 0 or pos == 2:
                bucket = '2-ht'
            else:
                raise ValueError("Invalid `feature.mask_position`.")
        elif arity > 3:
            if pos == 1:
                bucket = 'n-r'
            elif pos == 0 or pos == 2:
                bucket = 'n-ht'
            elif pos > 2 and mask_type == -1:
                bucket = 'n-a'
            elif pos > 2 and mask_type == 1:
                bucket = 'n-v'
            else:
                raise ValueError("Invalid `feature.mask_position`.")
        else:
            raise ValueError("Invalid `feature.arity`.")
        _append_rank(ret_ranks, bucket, rank)
    return ret_ranks


def _rank_metrics(ranks):
    """
    Return MRR and hits@{1,3,5,10} for a 1-D tensor of 1-based ranks.

    Keys are 'mrr', 'hits1', 'hits3', 'hits5', 'hits10', in that order. An
    empty ``ranks`` tensor gives NaN for every metric.
    """
    metrics = {'mrr': torch.mean(1.0 / ranks).item()}
    for k in _HITS_AT:
        hit = torch.where(ranks <= float(k), 1.0, 0.0)
        metrics['hits%d' % k] = torch.mean(hit).item()
    return metrics


def compute_metrics(ret_ranks):
    """
    Combine the ranks from batches into final metrics.

    ret_ranks: the rank buckets filled by ``batch_evaluation``.
    Returns a dict mapping each bucket to its MRR and hits@{1,3,5,10}. It
    also includes two combined buckets: 'ht' ('2-ht' + 'n-ht') and
    'r' ('2-r' + 'n-r').
    """
    all_r_ranks = torch.cat([ret_ranks['2-r'], ret_ranks['n-r']], dim=0)
    all_ht_ranks = torch.cat([ret_ranks['2-ht'], ret_ranks['n-ht']], dim=0)
    return {
        'entity': _rank_metrics(ret_ranks['entity']),
        'relation': _rank_metrics(ret_ranks['relation']),
        'ht': _rank_metrics(all_ht_ranks),
        '2-ht': _rank_metrics(ret_ranks['2-ht']),
        'n-ht': _rank_metrics(ret_ranks['n-ht']),
        'r': _rank_metrics(all_r_ranks),
        '2-r': _rank_metrics(ret_ranks['2-r']),
        'n-r': _rank_metrics(ret_ranks['n-r']),
        'n-a': _rank_metrics(ret_ranks['n-a']),
        'n-v': _rank_metrics(ret_ranks['n-v']),
    }
# --------------------------------------------------------------------------------