├── README.md ├── data └── link ├── model ├── graph_encoder.py ├── init_helios.py └── model.py ├── reader ├── data_loader.py └── data_reader.py ├── run.py └── utils ├── args.py └── evaluation.py /README.md: -------------------------------------------------------------------------------- 1 | # HELIOS (Hyper-relational Schema Model) 2 | 3 | HELIOS is a hyper-relational schema model, which directly learns from hyper-relational schema tuples in a KG. HELIOS captures not only the correlation between multiple types of a single entity, but also the correlation between types of different entities and relations in a hyper-relational schema tuple. Please see the details in our paper below: 4 | - Yuhuan Lu, Bangchao Deng, Weijian Yu, and Dingqi Yang. 2023. HELIOS: Hyper-Relational Schema Modeling from Knowledge Graphs. In Proceedings of the 31st ACM International Conference on Multimedia (MM ’23), October 29–November 3, 2023, Ottawa, ON, Canada. 5 | 6 | ## How to run the code 7 | ###### Train and evaluate the model (suggested parameters for the JF17k, WikiPeople and WD50K datasets) 8 | ``` 9 | python run.py --dataset jf17k --gpu 0 10 | 11 | python run.py --dataset wikipeople --gpu 0 12 | 13 | python run.py --dataset wd50k --gpu 0 14 | ``` 15 | The datasets are available here: https://www.dropbox.com/s/iz5wxp0uldx5i05/data.zip?dl=0 . Download them and put them into the `data` folder. 
16 | 17 | ###### Parameter setting: 18 | In `run.py`, you can set: 19 | 20 | `--dataset`: input dataset 21 | 22 | `--epochs`: number of training epochs 23 | 24 | `--batch_size`: batch size of training set 25 | 26 | `--dim`: embedding size 27 | 28 | `--learning_rate`: learning rate 29 | 30 | `--self_attention_layers`: number of self-attention layers 31 | 32 | `--gat_layers`: number of GAT layers 33 | 34 | `--gpu`: GPU used to train and test the model 35 | 36 | `--num_attention_heads`: number of attention heads 37 | 38 | # Python lib versions 39 | Python: 3.7.13 40 | 41 | torch: 1.11.0 42 | 43 | # Reference 44 | If you use our code or datasets, please cite: 45 | ``` 46 | @inproceedings{lu2023helios, 47 | title={HELIOS: Hyper-Relational Schema Modeling from Knowledge Graphs}, 48 | author={Lu, Yuhuan and Deng, Bangchao and Yu, Weijian and Yang, Dingqi}, 49 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia}, 50 | pages={4053--4064}, 51 | year={2023} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /data/link: -------------------------------------------------------------------------------- 1 | The datasets are available here: https://www.dropbox.com/s/iz5wxp0uldx5i05/data.zip?dl=0 . Download them and put them into the data folder. 
# ---- /model/graph_encoder.py ----
import torch
import torch.nn
import numpy as np


def truncated_normal(t, mean=0.0, std=0.01):
    """Fill ``t`` with normal samples restricted to ``mean +/- 2*std``.

    ``t`` is first initialised in place from N(mean, std). Entries that fall
    outside two standard deviations are then redrawn until none remain.

    Args:
        t: tensor to initialise (modified in place by the first draw).
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.

    Returns:
        The initialised tensor. It is a new tensor whenever at least one
        entry had to be redrawn, so callers must use the return value.
    """
    torch.nn.init.normal_(t, mean=mean, std=std)
    lower = mean - 2 * std
    upper = mean + 2 * std
    while True:
        cond = torch.logical_or(t < lower, t > upper)
        if not cond.any():
            break
        # Draw replacements on the same device and with the same dtype as
        # ``t``; a plain ``torch.ones(t.shape)`` is always CPU/float32 and
        # breaks ``torch.where`` for CUDA or non-float32 tensors.
        resampled = torch.empty_like(t).normal_(mean=mean, std=std)
        t = torch.where(cond, resampled, t)
    return t


def _truncated_linear(in_features, out_features, std=0.02):
    """Build a Linear layer with truncated-normal weights and a zero bias."""
    layer = torch.nn.Linear(in_features, out_features)
    layer.weight.data = truncated_normal(layer.weight.data, std=std)
    torch.nn.init.constant_(layer.bias, 0.0)
    return layer


class multi_head_attention(torch.nn.Module):
    """Multi-head self-attention with additive edge terms.

    Besides the usual query/key scores, a second score term is computed from
    the queries and ``edges_key``, and a second context term from the
    attention weights and ``edges_value``. ``pos_matrix`` is added to the
    scores before the softmax (the caller uses it as an additive mask).
    """

    def __init__(self, d_key, d_value, d_model, n_head, attention_dropout):
        super(multi_head_attention, self).__init__()
        self.d_key = d_key
        self.d_value = d_value
        self.d_model = d_model
        self.n_head = n_head
        self.attention_dropout = attention_dropout

        # Creation order is kept (q, k, v, projection) so seeded runs draw
        # the same initial weights as before.
        self.layer_q = _truncated_linear(self.d_model, self.d_key * self.n_head)
        self.layer_k = _truncated_linear(self.d_model, self.d_key * self.n_head)
        self.layer_v = _truncated_linear(self.d_model, self.d_value * self.n_head)
        self.project_layer = _truncated_linear(d_value * n_head, self.d_model)

    def forward(self, queries, edges_key, edges_value, pos_matrix):
        """Attend over ``queries`` (used as queries, keys and values).

        Args:
            queries: tensor indexed as (batch, seq_len, d_model).
            edges_key: edge embeddings for the score term; presumably
                (seq_len, seq_len, d_key) -- TODO confirm against caller.
            edges_value: edge embeddings for the context term; presumably
                (seq_len, seq_len, d_value) -- TODO confirm against caller.
            pos_matrix: additive term for the raw scores; must broadcast
                against (batch, n_head, seq_len, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size = queries.size(0)
        max_seq_len = queries.size(1)

        keys = queries
        values = keys

        q = self.layer_q(queries).view(batch_size, -1, self.n_head, self.d_key).transpose(1, 2)
        k = self.layer_k(keys).view(batch_size, -1, self.n_head, self.d_key).transpose(1, 2)
        v = self.layer_v(values).view(batch_size, -1, self.n_head, self.d_value).transpose(1, 2)

        scale = np.sqrt(self.d_key)
        scores1 = torch.matmul(q, k.transpose(-1, -2)) / scale
        # Put the sequence axis first so each query position is multiplied
        # with its own slice of edge-key embeddings, then restore the layout.
        q_by_pos = q.permute(2, 0, 1, 3).contiguous().view(max_seq_len, -1, self.d_key)
        scores2 = torch.matmul(q_by_pos, edges_key.transpose(-1, -2))
        scores2 = scores2.view(max_seq_len, -1, self.n_head, max_seq_len).permute(1, 2, 0, 3) / scale
        scores = torch.add(scores1, scores2)
        scores = torch.add(scores, pos_matrix)

        # Functional dropout honours self.training. Building a fresh
        # torch.nn.Dropout here would always be in training mode and keep
        # dropping activations even after model.eval().
        weights = torch.nn.functional.dropout(
            torch.softmax(scores, dim=-1),
            p=self.attention_dropout,
            training=self.training)

        context1 = torch.matmul(weights, v)
        w_by_pos = weights.permute(2, 0, 1, 3).contiguous().view(max_seq_len, -1, max_seq_len)
        context2 = torch.matmul(w_by_pos, edges_value)
        context2 = context2.view(max_seq_len, -1, self.n_head, self.d_value).permute(1, 2, 0, 3)
        context = torch.add(context1, context2)

        output = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_value)
        output = self.project_layer(output)
        return output


class positionwise_feed_forward(torch.nn.Module):
    """Two-layer feed-forward network (Linear -> GELU -> Linear)."""

    def __init__(self, d_inner_hid, d_model):
        super(positionwise_feed_forward, self).__init__()
        self.d_inner_hid = d_inner_hid
        self.d_hid = d_model

        self.fc1 = _truncated_linear(self.d_hid, self.d_inner_hid)
        self.fc2 = _truncated_linear(self.d_inner_hid, self.d_hid)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))


class encoder_layer(torch.nn.Module):
    """One encoder layer: attention then feed-forward.

    Each sub-layer output goes through dropout, a residual add and a
    LayerNorm, in that order.
    """

    def __init__(self, n_head, d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout, attention_dropout):
        super(encoder_layer, self).__init__()
        self.n_head = n_head
        self.d_key = d_key
        self.d_value = d_value
        self.d_model = d_model
        self.d_inner_hid = d_inner_hid
        self.prepostprocess_dropout = prepostprocess_dropout
        self.attention_dropout = attention_dropout

        self.multi_head_attention = multi_head_attention(
            self.d_key, self.d_value, self.d_model, self.n_head, self.attention_dropout)
        self.layer_norm1 = torch.nn.LayerNorm(normalized_shape=self.d_model, eps=1e-7, elementwise_affine=True)

        self.positionwise_feed_forward = positionwise_feed_forward(self.d_inner_hid, self.d_model)
        self.layer_norm2 = torch.nn.LayerNorm(normalized_shape=self.d_model, eps=1e-7, elementwise_affine=True)

    def _residual(self, residual, sublayer_out, norm):
        """Dropout the sub-layer output, add the residual, then normalise."""
        dropped = torch.nn.functional.dropout(
            sublayer_out,
            p=self.prepostprocess_dropout,
            training=self.training)  # disabled under eval(), unlike a fresh nn.Dropout
        return norm(torch.add(residual, dropped))

    def forward(self, enc_input, edges_key, edges_value, pos_matrix):
        """Run the layer; arguments are forwarded to ``multi_head_attention``."""
        attn_output = self.multi_head_attention(
            enc_input,
            edges_key,
            edges_value,
            pos_matrix)
        attn_output = self._residual(enc_input, attn_output, self.layer_norm1)

        ffd_output = self.positionwise_feed_forward(attn_output)
        ffd_output = self._residual(attn_output, ffd_output, self.layer_norm2)
        return ffd_output


class encoder(torch.nn.Module):
    """Stack of ``n_layer`` encoder layers applied in sequence.

    Layers are registered as attributes ``encoder_layer0`` ...
    ``encoder_layer{n_layer-1}``; these names are part of the state_dict
    keys and are therefore kept.
    """

    def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout, attention_dropout):
        super(encoder, self).__init__()
        self.n_layer = n_layer
        self.n_head = n_head
        self.d_key = d_key
        self.d_value = d_value
        self.d_model = d_model
        self.d_inner_hid = d_inner_hid
        self.prepostprocess_dropout = prepostprocess_dropout
        self.attention_dropout = attention_dropout

        for nl in range(self.n_layer):
            setattr(self, "encoder_layer{}".format(nl), encoder_layer(
                self.n_head,
                self.d_key,
                self.d_value,
                self.d_model,
                self.d_inner_hid,
                self.prepostprocess_dropout,
                self.attention_dropout))

    def forward(self, enc_input, edges_key, edges_value, pos_matrix):
        """Feed ``enc_input`` through every layer and return the last output.

        With ``n_layer == 0`` the input is returned unchanged.
        """
        enc_output = enc_input
        for nl in range(self.n_layer):
            enc_output = getattr(self, "encoder_layer{}".format(nl))(
                enc_output,
                edges_key,
                edges_value,
                pos_matrix)
        return enc_output


class GATlayer(torch.nn.Module):
    """Attention layer over groups of ``sparsifier`` vectors.

    For every ordered pair (i, j) inside a group, a score is computed from
    the concatenation of the projected vectors i and j. Scores are
    softmax-normalised over j and used to sum the projected vectors j.
    """

    def __init__(self, input_size, sparsifier):
        super(GATlayer, self).__init__()
        self.input_size = input_size
        self.sparsifier = sparsifier
        self.fc1 = torch.nn.Linear(self.input_size, self.input_size)
        self.attn_fc = torch.nn.Linear(self.input_size * 2, 1)
        self.leaky = torch.nn.LeakyReLU(0.1)
        self.softmax = torch.nn.Softmax(dim=2)
        self.elu = torch.nn.ELU()

    def forward(self, input, pos_matrix):
        """Aggregate each group of vectors with pairwise attention.

        Args:
            input: tensor viewed as (groups, sparsifier, input_size).
            pos_matrix: additive term for the pair scores; must broadcast
                against (groups, sparsifier, sparsifier, 1).

        Returns:
            Tensor of shape (groups, sparsifier, input_size).
        """
        n_groups = input.size(0)
        Wh = self.fc1(input).view(-1, self.sparsifier, self.input_size)
        # head_type[:, i*S + j] == Wh[:, i] and end_type[:, i*S + j] == Wh[:, j],
        # so the concatenation enumerates every ordered (i, j) pair.
        head_type = Wh.repeat(1, 1, self.sparsifier).view(n_groups, -1, self.input_size)
        end_type = Wh.repeat(1, self.sparsifier, 1)
        type_matrix = torch.cat((head_type, end_type), dim=-1)
        type_matrix = type_matrix.view(n_groups, self.sparsifier, self.sparsifier, -1)

        base = end_type.view(n_groups, self.sparsifier, self.sparsifier, self.input_size)

        scores = self.leaky(self.attn_fc(type_matrix))
        scores = scores.view(-1, self.sparsifier, self.sparsifier, 1)

        scores = torch.add(scores, pos_matrix)
        weight = self.softmax(scores)
        weight_matrix = torch.mul(base, weight)
        output = self.elu(torch.sum(weight_matrix, dim=2))
        return output


class GAT_attention(torch.nn.Module):
    """Stack of ``gat_layers`` GATlayer modules.

    Layers are registered as attributes ``GATlayer0`` ...
    ``GATlayer{gat_layers-1}``.
    """

    def __init__(self, input_size, sparsifier, gat_layers):
        super(GAT_attention, self).__init__()
        self.input_size = input_size
        self.sparsifier = sparsifier
        self.gat_layers = gat_layers
        for nl in range(self.gat_layers):
            setattr(self, "GATlayer{}".format(nl), GATlayer(
                self.input_size,
                self.sparsifier))

def forward(self, 192 | input, 193 | pos_matrix, 194 | pos): 195 | attn_input = input 196 | pos_matrix=pos_matrix 197 | for nl in range(self.gat_layers): 198 | attn_out = getattr(self,"GATlayer{}".format(nl))( 199 | attn_input, 200 | pos_matrix) 201 | attn_input = attn_out 202 | 203 | return torch.sum(torch.mul(attn_out,pos),dim=1) -------------------------------------------------------------------------------- /model/init_helios.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | from model.model import HeliosModel 4 | from model.graph_encoder import truncated_normal 5 | torch.set_printoptions(precision=16) 6 | 7 | class HELIOS(torch.nn.Module): 8 | def __init__(self, info, config, ent_types): 9 | super(HELIOS, self).__init__() 10 | self.config = config 11 | self.node_num = info["node_num"] 12 | 13 | self.node_embeddings = torch.nn.Embedding(self.node_num, self.config['dim']) 14 | self.node_embeddings.weight.data = truncated_normal(self.node_embeddings.weight.data, std=0.02) 15 | 16 | self.ent_types = ent_types 17 | self.edge_embedding_k = torch.nn.Embedding(self.config['num_edges'],self.config['dim'] // self.config['num_attention_heads']) 18 | self.edge_embedding_k.weight.data = truncated_normal(self.edge_embedding_k.weight.data, std=0.02) 19 | 20 | self.edge_embedding_v = torch.nn.Embedding(self.config['num_edges'],self.config['dim'] // self.config['num_attention_heads']) 21 | self.edge_embedding_v.weight.data = truncated_normal(self.edge_embedding_v.weight.data, std=0.02) 22 | 23 | 24 | self.heliosconfig = dict() 25 | self.heliosconfig['self_attention_layers'] = self.config['self_attention_layers'] 26 | self.heliosconfig['gat_layers'] = self.config['gat_layers'] 27 | self.heliosconfig['num_attention_heads'] = self.config['num_attention_heads'] 28 | self.heliosconfig['hidden_size'] = self.config['dim'] 29 | self.heliosconfig['intermediate_size'] = self.config['intermediate_size'] 30 | 
self.heliosconfig['hidden_dropout_prob'] = self.config['hidden_dropout_prob'] 31 | self.heliosconfig['attention_dropout_prob'] = self.config['attention_dropout_prob'] 32 | self.heliosconfig['vocab_size'] = self.node_num 33 | self.heliosconfig['num_relations'] = info['rel_num'] 34 | self.heliosconfig['num_types'] = info['type_num'] 35 | self.heliosconfig['num_edges'] = self.config['num_edges'] 36 | self.heliosconfig['max_arity'] = info['max_n'] 37 | self.heliosconfig['device'] = self.config['device'] 38 | self.heliosconfig['sparsifier'] = self.config['sparsifier'] 39 | self.model = HeliosModel(self.heliosconfig, self.node_embeddings, self.edge_embedding_k,self.edge_embedding_v).to(self.heliosconfig['device']) 40 | 41 | 42 | 43 | def forward(self,data,edge_labels,type_attn_l2_matrix): 44 | 45 | input_ids, input_mask, mask_pos, mask_label, mask_type, mask_t_label, groud_truth = data 46 | 47 | loss, sortidx = self.model(input_ids, input_mask, edge_labels, mask_pos, mask_label, mask_type, self.ent_types, groud_truth, type_attn_l2_matrix) 48 | 49 | 50 | return loss, sortidx -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn 4 | import copy 5 | from model.graph_encoder import encoder,truncated_normal,GAT_attention 6 | 7 | class HeliosModel(torch.nn.Module): 8 | def __init__(self, config, node_embeddings,edge_embedding_k,edge_embedding_v): 9 | super(HeliosModel, self).__init__() 10 | 11 | self._n_layer = config['self_attention_layers'] 12 | self._gat_layer=config['gat_layers'] 13 | self._n_head = config['num_attention_heads'] 14 | self._emb_size = config['hidden_size'] 15 | self._intermediate_size = config['intermediate_size'] 16 | self._prepostprocess_dropout = config['hidden_dropout_prob'] 17 | self._attention_dropout = config['attention_dropout_prob'] 18 | 19 | self._voc_size = 
config['vocab_size'] - config['num_types'] 20 | self._node_num = config['vocab_size'] 21 | self._n_relation = config['num_relations'] 22 | self._n_type = config['num_types'] 23 | self._n_edge = config['num_edges'] 24 | self._max_arity = config['max_arity'] 25 | self._max_seq_len = self._max_arity*2-1 26 | self._sparsifier = config['sparsifier'] 27 | 28 | self._device=config["device"] 29 | self.node_embedding=node_embeddings 30 | self.layer_norm1=torch.nn.LayerNorm(normalized_shape=self._emb_size,eps=1e-12,elementwise_affine=True) 31 | 32 | self.edge_embedding_k=edge_embedding_k 33 | self.edge_embedding_v=edge_embedding_v 34 | 35 | self.encoder_model = encoder( 36 | n_layer=self._n_layer, 37 | n_head=self._n_head, 38 | d_key=self._emb_size // self._n_head, 39 | d_value=self._emb_size // self._n_head, 40 | d_model=self._emb_size, 41 | d_inner_hid=self._intermediate_size, 42 | prepostprocess_dropout=self._prepostprocess_dropout, 43 | attention_dropout=self._attention_dropout) 44 | 45 | self.GAT_attention=GAT_attention( 46 | input_size=self._emb_size, 47 | sparsifier=self._sparsifier, 48 | gat_layers=self._gat_layer 49 | ) 50 | 51 | self.GAT_attention_l2=GAT_attention( 52 | input_size=self._emb_size, 53 | sparsifier=self._max_arity + self._sparsifier - 1, 54 | gat_layers=self._gat_layer 55 | ) 56 | 57 | 58 | 59 | self.layer_norm2=torch.nn.LayerNorm(normalized_shape=self._emb_size,eps=1e-7,elementwise_affine=True) 60 | 61 | self.fc2_bias = torch.nn.init.constant_(torch.nn.parameter.Parameter(torch.Tensor(self._n_relation)), 0.0) 62 | self.fc3_bias = torch.nn.init.constant_(torch.nn.parameter.Parameter(torch.Tensor(self._n_type)), 0.0) 63 | 64 | self.loss = torch.nn.CrossEntropyLoss() 65 | 66 | 67 | 68 | def forward(self, input_ids, input_mask, edge_labels, mask_pos, mask_label, mask_type, ent_types, groud_truth, type_attn_l2_matrix): 69 | 70 | mask_type_tmp=copy.copy(mask_type) 71 | emb_out = self.node_embedding(input_ids) 72 | emb_out = 
torch.nn.Dropout(self._prepostprocess_dropout)(self.layer_norm1(emb_out)) 73 | 74 | type_matrix = ent_types[:,1:] 75 | type_matrix = torch.cat((torch.tensor([2]*self._sparsifier).unsqueeze(0).to(self._device), type_matrix), 0) 76 | type_matrix = torch.cat((torch.tensor([1]*self._sparsifier).unsqueeze(0).to(self._device), type_matrix), 0) 77 | type_matrix = torch.cat((torch.tensor([0]*self._sparsifier).unsqueeze(0).to(self._device), type_matrix), 0) 78 | 79 | input_type_ids = input_ids[:, 0::2] - self._n_relation 80 | 81 | input_type_tmp = type_matrix.index_select(0,input_type_ids.view(-1)).view(-1,self._sparsifier) 82 | 83 | 84 | h_tmp = torch.sign(input_type_tmp).unsqueeze(2).float() 85 | pos = torch.sign(input_type_tmp).unsqueeze(2).float() 86 | h_tmp = torch.matmul(h_tmp,h_tmp.transpose(1,2)).unsqueeze(3) 87 | h_tmp = 1000000.0*(h_tmp-1.0) 88 | input_type = self.node_embedding(input_type_tmp) 89 | 90 | type_attn_output = self.GAT_attention( 91 | input=input_type, 92 | pos_matrix=h_tmp, 93 | pos=pos) 94 | 95 | type_attn_output = type_attn_output.view(input_ids.size(0),self._max_arity,self._emb_size) 96 | type_attn_output = type_attn_output.unsqueeze(1).repeat(1,self._max_arity,1,1) 97 | type_attn_output = type_attn_output.view(-1,self._max_arity,self._emb_size) 98 | 99 | 100 | type_attn_l2_matrix0 = type_attn_l2_matrix.unsqueeze(0).repeat(input_ids.size(0),1,1) 101 | type_attn_l2_matrix = type_attn_l2_matrix0.unsqueeze(3).repeat(1,1,1,self._emb_size) 102 | type_attn_l2_matrix = type_attn_l2_matrix.view(-1,self._max_arity-1,self._emb_size) 103 | 104 | type_attn_output = torch.gather(type_attn_output, 1, type_attn_l2_matrix) 105 | 106 | 107 | type_attn_output_tmp = torch.abs(torch.sum(type_attn_output, 2)) 108 | 109 | type_attn_output = torch.cat((input_type, type_attn_output), 1) 110 | 111 | input_type_tmp = torch.cat((input_type_tmp, type_attn_output_tmp), 1) 112 | h_tmp = torch.sign(input_type_tmp).unsqueeze(2).float() 113 | pos = 
torch.sign(input_type_tmp).unsqueeze(2).float() 114 | h_tmp = torch.matmul(h_tmp,h_tmp.transpose(1,2)).unsqueeze(3) 115 | h_tmp = 1000000.0*(h_tmp-1.0) 116 | 117 | type_attn_output = self.GAT_attention_l2( 118 | input=type_attn_output, 119 | pos_matrix=h_tmp, 120 | pos=pos) 121 | 122 | type_attn_output = type_attn_output.view(input_ids.size(0),self._max_arity,self._emb_size) 123 | 124 | emb_out[:,0::2,:] = type_attn_output 125 | 126 | mask_1=torch.tensor(1).to(self._device) 127 | mask_2=torch.tensor(2).to(self._device) 128 | 129 | 130 | mask_matrix=torch.where(mask_type==1,mask_2,mask_1) 131 | 132 | emb_out[range(input_ids.size(0)),mask_pos,:]=self.node_embedding(mask_matrix) 133 | 134 | edges_key = self.edge_embedding_k(edge_labels) 135 | edges_value = self.edge_embedding_v(edge_labels) 136 | edge_mask = torch.sign(edge_labels).unsqueeze(2) 137 | edges_key = torch.mul(edges_key, edge_mask) 138 | edges_value = torch.mul(edges_value, edge_mask) 139 | 140 | input_mask=input_mask.unsqueeze(2) 141 | self_attn_mask = torch.matmul(input_mask,input_mask.transpose(1,2)) 142 | self_attn_mask=1000000.0*(self_attn_mask-1.0) 143 | n_head_self_attn_mask = torch.stack([self_attn_mask] * self._n_head, dim=1) 144 | 145 | _enc_out = self.encoder_model( 146 | enc_input=emb_out, 147 | edges_key=edges_key, 148 | edges_value=edges_value, 149 | pos_matrix=n_head_self_attn_mask) 150 | 151 | mask_pos=mask_pos.unsqueeze(1) 152 | mask_pos=mask_pos[:,:,None].expand(-1,-1,self._emb_size) 153 | h_masked = torch.gather(input=_enc_out, dim=1, index=mask_pos).reshape([-1, _enc_out.size(-1)]) 154 | 155 | h_masked = torch.nn.GELU()(h_masked) 156 | h_masked = self.layer_norm2(h_masked) 157 | 158 | fc_out1 = torch.nn.functional.linear(h_masked, self.node_embedding.weight[self._voc_size:self._node_num,:], self.fc3_bias) 159 | 160 | fc_out2 = torch.nn.functional.linear(h_masked, self.node_embedding.weight[3:(3+self._n_relation),:], self.fc2_bias) 161 | 162 | fc_out = torch.cat((fc_out2,fc_out1),1) 163 
| 164 | relation_indicator = torch.empty(input_ids.size(0), self._n_relation).to(self._device) 165 | torch.nn.init.constant_(relation_indicator,-1) 166 | 167 | entity_indicator = torch.empty(input_ids.size(0), (self._n_type)).to(self._device) 168 | torch.nn.init.constant_(entity_indicator,1) 169 | type_indicator = torch.cat((relation_indicator, entity_indicator), dim=1).to(self._device) 170 | mask_type=mask_type.unsqueeze(1) 171 | type_indicator = torch.mul(type_indicator, mask_type) 172 | type_indicator=torch.nn.functional.relu(type_indicator) 173 | 174 | fc_out_mask=1000000.0*(type_indicator-1.0) 175 | fc_out = torch.add(fc_out, fc_out_mask) 176 | 177 | rel_index=torch.where(mask_type_tmp==-1)[0] 178 | type_index=torch.where(mask_type_tmp==1)[0] 179 | 180 | one_hot_labels=groud_truth.float().to(self._device) 181 | t_one_hot_labels=one_hot_labels[type_index,self._n_relation:] 182 | r_one_hot_labels=one_hot_labels[rel_index,:self._n_relation] 183 | 184 | t_one_hot_labels=t_one_hot_labels /torch.sum(t_one_hot_labels,dim=1).unsqueeze(1) 185 | r_one_hot_labels=r_one_hot_labels /torch.sum(r_one_hot_labels,dim=1).unsqueeze(1) 186 | 187 | fc_out2=fc_out2[rel_index,:] 188 | fc_out1=fc_out1[type_index,:] 189 | 190 | loss_r=self.loss(fc_out2,r_one_hot_labels) 191 | loss_t = self.loss(fc_out1,t_one_hot_labels) 192 | loss=(loss_r+loss_t) 193 | 194 | return loss, fc_out 195 | 196 | 197 | -------------------------------------------------------------------------------- /reader/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import time 4 | import numpy as np 5 | 6 | def get_edge_labels(max_n): 7 | 8 | edge_labels = [] 9 | max_seq_length=2*max_n-1 10 | max_aux = max_n - 2 11 | 12 | edge_labels.append([0, 1, 0] + [0,0] * max_aux) 13 | edge_labels.append([1, 0, 2] + [3,0] * max_aux) 14 | edge_labels.append([0, 2, 0] + [0,0] * max_aux) 15 | for idx in range(max_aux): 16 | edge_labels.append( 
17 | [0, 3, 0] + [0,0] * idx + [0,4] + [0,0] * (max_aux - idx - 1)) 18 | edge_labels.append( 19 | [0, 0, 0] + [0,0] * idx + [4,0] + [0,0] * (max_aux - idx - 1)) 20 | edge_labels = np.asarray(edge_labels).astype("int64").reshape( 21 | [max_seq_length, max_seq_length]) 22 | edge_labels=torch.from_numpy(edge_labels) 23 | 24 | return edge_labels 25 | 26 | 27 | 28 | 29 | 30 | def prepare_EC_info(ins_info, device): 31 | instance_info=dict() 32 | instance_info["node_num"]=ins_info['node_num'] 33 | instance_info["rel_num"]=ins_info['rel_num'] 34 | instance_info["type_num"] = ins_info['type_num'] 35 | instance_info["max_n"]=ins_info['max_n'] 36 | return instance_info -------------------------------------------------------------------------------- /reader/data_reader.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import copy 4 | from collections import Counter 5 | import logging 6 | import numpy as np 7 | import time 8 | 9 | logging.basicConfig( 10 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | datefmt='%m/%d/%Y %H:%M:%S') 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | logger.info(logger.getEffectiveLevel()) 15 | 16 | def build_entity2types_dictionaries(dataset_name, entities_values2id): 17 | entityName2entityTypes = {} 18 | entityId2entityTypes = {} 19 | entityType2entityNames = {} 20 | entityType2entityIds = {} 21 | 22 | entity2type_file = open(dataset_name, "r") 23 | 24 | for line in entity2type_file: 25 | splitted_line = line.strip().split("\t") 26 | entity_name = splitted_line[0][8:] 27 | entity_type = splitted_line[1][6:] 28 | 29 | if entity_name not in entityName2entityTypes: 30 | entityName2entityTypes[entity_name] = [] 31 | if entity_type not in entityName2entityTypes[entity_name]: 32 | entityName2entityTypes[entity_name].append(entity_type) 33 | 34 | if entity_type not in entityType2entityNames: 35 | 
entityType2entityNames[entity_type] = [] 36 | if entity_name not in entityType2entityNames[entity_type]: 37 | entityType2entityNames[entity_type].append(entity_name) 38 | 39 | entity_id = entities_values2id[entity_name] 40 | if entity_id not in entityId2entityTypes: 41 | entityId2entityTypes[entity_id] = [] 42 | if entity_type not in entityId2entityTypes[entity_id]: 43 | entityId2entityTypes[entity_id].append(entity_type) 44 | 45 | if entity_type not in entityType2entityIds: 46 | entityType2entityIds[entity_type] = [] 47 | if entity_id not in entityType2entityIds[entity_type]: 48 | entityType2entityIds[entity_type].append(entity_id) 49 | 50 | entity2type_file.close() 51 | 52 | return entityName2entityTypes, entityId2entityTypes, entityType2entityNames, entityType2entityIds 53 | 54 | 55 | def build_type2id_v2(inputData): 56 | type2id = {} 57 | id2type = {} 58 | type_counter = 0 59 | with open(inputData) as entity2type_file: 60 | for line in entity2type_file: 61 | splitted_line = line.strip().split("\t") 62 | entity_type = splitted_line[1][6:] 63 | 64 | if entity_type not in type2id: 65 | type2id[entity_type] = type_counter 66 | id2type[type_counter] = entity_type 67 | type_counter += 1 68 | 69 | entity2type_file.close() 70 | return type2id, id2type 71 | 72 | 73 | def build_typeId2frequency(dataset_name, type2id): 74 | if "type2relation2type_ttv" in dataset_name: 75 | typeId2frequency = {} 76 | type_relation_type_file = open(dataset_name, "r") 77 | 78 | for line in type_relation_type_file: 79 | splitted_line = line.strip().split("\t") 80 | head_type = splitted_line[0][6:] 81 | tail_type = splitted_line[2][6:] 82 | head_type_id = type2id[head_type] 83 | tail_type_id = type2id[tail_type] 84 | 85 | if head_type_id not in typeId2frequency: 86 | typeId2frequency[head_type_id] = 0 87 | if tail_type_id not in typeId2frequency: 88 | typeId2frequency[tail_type_id] = 0 89 | 90 | typeId2frequency[head_type_id] += 1 91 | typeId2frequency[tail_type_id] += 1 92 | 93 | 
type_relation_type_file.close() 94 | elif "type2relation2type2key2type_ttv" in dataset_name: 95 | typeId2frequency = {} 96 | type_relation_type_file = open(dataset_name, "r") 97 | for line in type_relation_type_file: 98 | splitted_line = line.strip().split("\t") 99 | for i in range(0, len(splitted_line), 2): 100 | value_type = splitted_line[i][6:] 101 | value_type_id = type2id[value_type] 102 | if value_type_id not in typeId2frequency: 103 | typeId2frequency[value_type_id] = 0 104 | typeId2frequency[value_type_id] += 1 105 | type_relation_type_file.close() 106 | return typeId2frequency 107 | 108 | 109 | def build_entityId2SparsifierType(entities_values2id, type2id, entityId2entityTypes, sparsifier, typeId2frequency, 110 | unk_entity_type_id): 111 | entityId2SparsifierType = {} 112 | if sparsifier > 0: 113 | for i in entities_values2id: 114 | entityId = entities_values2id[i] 115 | if entityId in entityId2entityTypes: 116 | entityTypes = entityId2entityTypes[entityId] 117 | entityTypeIds = [] 118 | for j in entityTypes: 119 | entityTypeIds.append(type2id[j]) 120 | 121 | current_freq = {} 122 | for typeId in entityTypeIds: 123 | current_freq[typeId] = typeId2frequency[typeId] 124 | 125 | sorted_current_freq = sorted(current_freq.items(), key=lambda kv: kv[1], 126 | reverse=True)[:sparsifier] 127 | topNvalueTypes = [item[0] for item in sorted_current_freq] 128 | entityId2SparsifierType[entityId] = topNvalueTypes 129 | 130 | else: 131 | entityId2SparsifierType[entityId] = [unk_entity_type_id] 132 | return entityId2SparsifierType 133 | else: 134 | logger.info("SPARSIFIER ERROR!") 135 | 136 | 137 | def read_facts_new(file, entityName2SparsifierType): 138 | facts_list = list() 139 | max_n = 0 140 | entity_list = list() 141 | relation_list = list() 142 | type_list = list() 143 | with open(file, 'r', encoding='utf8') as f: 144 | for line in f: 145 | fact = list() 146 | obj = json.loads(line) 147 | if obj['N'] > 7: 148 | continue 149 | if obj['N'] > max_n: 150 | max_n = 
# NOTE(review): the comment line below preserves, verbatim, the tail of a
# definition whose header lies before this chunk. It is kept untouched (not
# re-indented or rewritten) because its beginning is not visible from here.
# obj['N'] 151 | flag = 0 152 | for key in obj: 153 | if flag == 0: 154 | fact.append(obj[key][0]) 155 | fact.append(key) 156 | fact.append(obj[key][1]) 157 | relation_list.append(key) 158 | entity_list.append(obj[key][0]) 159 | entity_list.append(obj[key][1]) 160 | break 161 | 162 | if obj['N'] > 2: 163 | for kv in obj.keys(): 164 | if kv != 'N' and kv != key: 165 | if isinstance(obj[kv], list): 166 | for item in obj[kv]: 167 | fact.append(kv) 168 | fact.append(item) 169 | relation_list.append(kv) 170 | entity_list.append(item) 171 | else: 172 | fact.append(kv) 173 | fact.append(obj[kv]) 174 | relation_list.append(kv) 175 | entity_list.append(obj[kv]) 176 | values = fact[0::2] 177 | for i in values: 178 | types = entityName2SparsifierType[i] 179 | type_list.extend(types) 180 | facts_list.append(fact) 181 | return facts_list, max_n, relation_list, entity_list, type_list 182 | 183 |


def read_dict_new(e_ls, r_ls, t_ls):
    """Assign one integer id to every special token, relation, entity and type.

    Ids 0, 1 and 2 are reserved for 'PAD', 'E_MASK' and 'T_MASK'. The remaining
    ids are handed out consecutively: first every item of ``r_ls``, then
    ``e_ls``, then ``t_ls`` (an item appearing twice keeps its last id).

    Returns:
        (dict_id, dict_num, rel_num, ent_num, type_num) where ``dict_num`` is
        the next free id and the three counts are the lengths of the inputs.
    """
    dict_id = {'PAD': 0, 'E_MASK': 1, 'T_MASK': 2}
    dict_num = 3
    for group in (r_ls, e_ls, t_ls):
        for item in group:
            dict_id[item] = dict_num
            dict_num += 1
    return dict_id, dict_num, len(r_ls), len(e_ls), len(t_ls)


def facts_to_id(facts, max_n, node_dict, rel_num, ent_num, ent_types, type_num):
    """Expand every fact into one masked training example per position.

    A fact is a flat token list with entities at even indices and relations at
    odd indices. For each position ``j`` a copy of the id sequence is emitted
    with slot ``j`` overwritten (``2 + rel_num`` for an even slot, ``1`` for an
    odd slot) and right-padded with ``[0, rel_num]`` pairs up to ``max_n``
    entities. ``ent_num``, ``ent_types`` and ``type_num`` are accepted for
    interface compatibility and are not read.

    Returns:
        [id_t_facts, id_t_masks, mask_pos, mask_labels, mask_types,
        mask_t_labels], six parallel lists with one entry per example.
        ``mask_types`` is 1 for even slots and -1 for odd slots.
    """
    id_t_facts, id_t_masks = [], []
    mask_pos, mask_labels, mask_types, mask_t_labels = [], [], [], []

    for fact in facts:
        id_fact = [node_dict[item] for item in fact]
        arity = (len(id_fact) + 1) // 2
        missing = max_n - arity  # a non-positive count yields no padding
        fact_padding = [0, rel_num] * missing
        mask_padding = [0] * (2 * missing)

        for j, mask_label in enumerate(id_fact):
            masked = list(id_fact)
            if j % 2 == 0:
                masked[j] = 2 + rel_num
                mask_type = 1
                mask_t_label = mask_label
            else:
                masked[j] = 1
                mask_type = -1
                mask_t_label = rel_num

            id_t_facts.append(masked + fact_padding)
            id_t_masks.append([1.0] * len(id_fact) + mask_padding)
            mask_pos.append(j)
            mask_labels.append(mask_label)
            mask_types.append(mask_type)
            mask_t_labels.append(mask_t_label)

    return [id_t_facts, id_t_masks, mask_pos, mask_labels, mask_types, mask_t_labels]


def update(train_facts, train_ground_truth, train_ground_truth_keys, train_max_type_num):
    """Append the per-example ground-truth label list to a ``facts_to_id`` result.

    Example ``i`` receives ``train_ground_truth[mask_pos[i]][keys[i]]``, so
    ``train_ground_truth_keys`` must be in the same order as the examples.
    ``train_max_type_num`` is accepted but not used.

    Raises:
        AssertionError: if an example has no ground-truth entry.
        IndexError: if fewer keys than examples are supplied.
    """
    id_t_facts, id_t_masks, mask_pos, mask_labels, mask_types, mask_t_labels = train_facts
    groud_truth = []
    for i, pos in enumerate(mask_pos):
        labels = train_ground_truth[pos][train_ground_truth_keys[i]]
        assert len(labels) > 0
        groud_truth.append(labels)
    return [id_t_facts, id_t_masks, mask_pos, mask_labels, mask_types, mask_t_labels, groud_truth]


def _padded_ids(fact, node_dict, seq_len):
    """Map a fact's tokens to ids and right-pad with 0 up to ``seq_len``."""
    id_fact = [node_dict[item] for item in fact]
    id_fact.extend([0] * (seq_len - len(id_fact)))
    return id_fact


def _key_tokens(id_fact, ent_types, rel_num, seq_len):
    """Return the textual token of every position used to build context keys.

    Non-zero even positions become the space-joined type ids of that entity
    (``ent_types`` row without its first column); every other position is the
    id itself.
    """
    tokens = []
    for j in range(seq_len):
        if j % 2 == 0 and id_fact[j] != 0:
            types = ent_types[id_fact[j] - 3 - rel_num][1:]
            tokens.append(' '.join(str(x) for x in types))
        else:
            tokens.append(str(id_fact[j]))
    return tokens


def _leave_one_out_key(tokens, pos):
    """Join all position tokens except ``pos`` into one context key."""
    return " ".join(tokens[:pos] + tokens[pos + 1:])


def get_truth_eval_new(all_facts, max_n, node_dict, ent_types, sparsifier, num_rel):
    """Collect, per masked position and context key, every valid raw answer id.

    For even positions the answers are the non-zero type ids of the entity at
    that position; for odd positions the id found at that position.
    ``sparsifier`` is accepted but not used.

    Returns:
        (gt_dict, all_fact_ids): ``gt_dict[pos][key]`` is a list of ids and
        ``all_fact_ids`` holds the unpadded id sequence of every fact.
    """
    seq_len = 2 * max_n - 1
    gt_dict = collections.defaultdict(lambda: collections.defaultdict(list))
    all_fact_ids = []
    for fact in all_facts:
        id_fact = [node_dict[item] for item in fact]
        all_fact_ids.append(list(id_fact))
        id_fact.extend([0] * (seq_len - len(id_fact)))
        # Tokens are computed once per fact instead of once per masked position.
        tokens = _key_tokens(id_fact, ent_types, num_rel, seq_len)
        for pos in range(seq_len):
            if id_fact[pos] == 0:
                continue
            key = _leave_one_out_key(tokens, pos)
            bucket = gt_dict[pos]
            if pos % 2 == 0:
                value = set(ent_types[id_fact[pos] - 3 - num_rel][1:])
                value.discard(0)
                bucket[key] = list(set(bucket[key]) | value)
            else:
                bucket[key].append(id_fact[pos])
    return gt_dict, all_fact_ids


def get_ground_truth(train_facts, max_n, node_dict, ent_types, sparsifier, rel_num, ent_num, type_num):
    """Build label lists (shifted ids) for every masked example.

    Same traversal as ``get_truth_eval_new`` but answers are shifted: type ids
    by ``ent_num + 3`` and odd-position ids by ``3``. One key is recorded per
    non-padding position, in the order ``facts_to_id`` emits its examples.
    ``sparsifier`` and ``type_num`` are accepted but not used.

    Returns:
        (gt_dict, ground_truth_keys, max_len) where ``max_len`` is the largest
        label list seen at an even position.
    """
    seq_len = 2 * max_n - 1
    gt_dict = collections.defaultdict(lambda: collections.defaultdict(list))
    ground_truth_keys = []
    max_len = 0
    for fact in train_facts:
        id_fact = _padded_ids(fact, node_dict, seq_len)
        tokens = _key_tokens(id_fact, ent_types, rel_num, seq_len)
        for pos in range(seq_len):
            if id_fact[pos] == 0:
                continue
            key = _leave_one_out_key(tokens, pos)
            ground_truth_keys.append(key)
            bucket = gt_dict[pos]
            if pos % 2 == 0:
                types = ent_types[id_fact[pos] - 3 - rel_num][1:]
                # Plain int arithmetic; padding id 0 is dropped before shifting.
                value = {t - ent_num - 3 for t in types if t != 0}
                bucket[key] = list(set(bucket[key]) | value)
                max_len = max(max_len, len(bucket[key]))
            else:
                bucket[key].append(id_fact[pos] - 3)
    return gt_dict, ground_truth_keys, max_len


def helper(train_ground_truth):
    """De-duplicate every label list of a two-level ground-truth dict in place."""
    for per_key in train_ground_truth.values():
        for key, labels in per_key.items():
            per_key[key] = list(set(labels))
    return train_ground_truth


def _typed_signature(fact, node_dict, ent_types, rel_num, seq_len):
    """Return a hashable, type-level view of a fact.

    Non-zero even positions are replaced by the tuple of that entity's type
    ids; every other position keeps its id.
    """
    id_fact = _padded_ids(fact, node_dict, seq_len)
    signature = []
    for i in range(seq_len):
        if id_fact[i] == 0 or i % 2 != 0:
            signature.append(id_fact[i])
        else:
            signature.append(tuple(ent_types[id_fact[i] - 3 - rel_num][1:]))
    return tuple(signature)


def update_facts(train_facts, valid_facts, test_facts, ent_types, node_dict, rel_num, max_n):
    """Drop training facts whose type-level form also occurs in valid/test.

    Signatures of the held-out facts are stored in a set so that each training
    fact is checked in constant time.

    Returns:
        (filtered_train_facts, valid_facts, test_facts); the last two are
        returned unchanged.
    """
    seq_len = 2 * max_n - 1
    held_out = {
        _typed_signature(fact, node_dict, ent_types, rel_num, seq_len)
        for split in (test_facts, valid_facts)
        for fact in split
    }
    trian_facts_new = [
        fact for fact in train_facts
        if _typed_signature(fact, node_dict, ent_types, rel_num, seq_len) not in held_out
    ]
    return trian_facts_new, valid_facts, test_facts


def ent_to_type(e_list, entityName2SparsifierType, node_dict, sparsifier):
    """Build one fixed-width ``[entity_id, type_id, ...]`` row per entity.

    Rows are zero-padded to ``sparsifier + 1`` columns (longer rows are left as
    they are). The second result marks real entries with 1 and padding with 0.
    """
    width = sparsifier + 1
    type_ls, type_pos_ls = [], []
    for e in e_list:
        row = [node_dict[e]]
        row.extend(node_dict[name] for name in entityName2SparsifierType[e])
        flags = [1] * len(row)
        padding = width - len(row)
        row.extend([0] * padding)
        flags.extend([0] * padding)
        type_ls.append(row)
        type_pos_ls.append(flags)
    return type_ls, type_pos_ls


def get_input(train_file, valid_file, test_file, file1, file2, file3, sparsifier, entities_values2id):
    """Read the three splits and assemble everything the model needs.

    Returns a dict with the keys 'train_facts', 'valid_facts', 'test_facts',
    'ent_types', 'node_dict', 'node_num', 'rel_num', 'ent_num', 'type_num',
    'max_n' and 'all_facts_eval'.
    """
    entityName2entityTypes, entityId2entityTypes, entityType2entityNames, entityType2entityIds = \
        build_entity2types_dictionaries(file1, entities_values2id)

    type2id, id2type = build_type2id_v2(file1)
    n_types = len(type2id)

    entity_typeId2frequency = build_typeId2frequency(file2, type2id)
    value_typeId2frequency = build_typeId2frequency(file3, type2id)
    typeId2frequency = dict(Counter(entity_typeId2frequency) + Counter(value_typeId2frequency))

    unk_entity_type_id = len(type2id)

    entityId2SparsifierType = build_entityId2SparsifierType(entities_values2id,
                                                            type2id,
                                                            entityId2entityTypes,
                                                            sparsifier,
                                                            typeId2frequency,
                                                            unk_entity_type_id)

    # The "unknown" type is registered only after the sparsifier ran, under the
    # id that was passed to it as unk_entity_type_id.
    id2type[n_types] = "unknown"
    type2id["unknown"] = n_types

    # NOTE(review): entity ids are used as positions into the key order of
    # entities_values2id - presumably ids are 0..n-1 in insertion order; verify.
    e_ls = list(entities_values2id.keys())
    entityName2SparsifierType = {
        e_ls[ent_id]: [id2type[type_id] + '_type' for type_id in type_ids]
        for ent_id, type_ids in entityId2SparsifierType.items()
    }

    train_facts, max_train, train_r, train_e, train_t = read_facts_new(train_file, entityName2SparsifierType)
    valid_facts, max_valid, valid_r, valid_e, valid_t = read_facts_new(valid_file, entityName2SparsifierType)
    test_facts, max_test, test_r, test_e, test_t = read_facts_new(test_file, entityName2SparsifierType)

    max_n = max(max_train, max_valid, max_test)
    # Sorted so that the id assignment is identical on every run; iterating a
    # bare set of strings depends on the per-process hash seed.
    e_list = sorted(set(train_e + valid_e + test_e), key=str)
    r_list = sorted(set(train_r + valid_r + test_r), key=str)
    t_list = sorted(set(train_t + valid_t + test_t), key=str)
    node_dict, node_num, rel_num, ent_num, type_num = read_dict_new(e_list, r_list, t_list)

    ent_types, _ = ent_to_type(e_list, entityName2SparsifierType, node_dict, sparsifier)
    train_facts, valid_facts, test_facts = update_facts(
        train_facts, valid_facts, test_facts, ent_types, node_dict, rel_num, max_n)
    all_facts = train_facts + valid_facts + test_facts

    def _encode_split(raw_facts):
        """Mask-expand one split and attach its ground-truth label lists."""
        encoded = facts_to_id(raw_facts, max_n, node_dict, rel_num, ent_num, ent_types, type_num)
        ground_truth, keys, max_type_num = get_ground_truth(
            raw_facts, max_n, node_dict, ent_types, sparsifier, rel_num, ent_num, type_num)
        return update(encoded, helper(ground_truth), keys, max_type_num)

    train_encoded = _encode_split(train_facts)
    valid_encoded = _encode_split(valid_facts)
    test_encoded = _encode_split(test_facts)

    all_facts_eval, _ = get_truth_eval_new(all_facts, max_n, node_dict, ent_types, sparsifier, rel_num)

    input_info = dict()
    input_info['train_facts'] = train_encoded
    input_info['valid_facts'] = valid_encoded
    input_info['test_facts'] = test_encoded
    input_info['ent_types'] = ent_types
    input_info['node_dict'] = node_dict
    input_info['node_num'] = node_num
    input_info['rel_num'] = rel_num
    input_info['ent_num'] = ent_num
    input_info['type_num'] = type_num
    input_info['max_n'] = max_n
    input_info['all_facts_eval'] = all_facts_eval

    return input_info


def read_input(folder, sparsifier, entities_values2id):
    """Load a dataset folder through ``get_input`` and log its basic statistics.

    The "facts" counts that are logged are numbers of masked examples (one per
    position of every fact), not numbers of raw facts.
    """
    ins_info = get_input(folder + "/n-ary_train.json", folder + "/n-ary_valid.json", folder + "/n-ary_test.json",
                         folder + "/entity2types_ttv.txt", folder + "/type2relation2type_ttv.txt",
                         folder + "/type2relation2type2key2type_ttv.txt", sparsifier, entities_values2id)

    logger.info("Number of train facts: " + str(len(ins_info['train_facts'][0])))
    logger.info("Number of valid facts: " + str(len(ins_info['valid_facts'][0])))
    logger.info("Number of test facts: " + str(len(ins_info['test_facts'][0])))
    logger.info("Number of relations: " + str(ins_info['rel_num']))
    logger.info("Number of types: " + str(ins_info['type_num']))
    logger.info("Number of max_n: " + str(ins_info['max_n']))
    logger.info("Number of max_seq_length: " + str(2 * ins_info['max_n'] - 1))

    return ins_info
# -------------------------------------------------------------------------------- /run.py: --------------------------------------------------------------------------------
import argparse
import logging
import time
import random
import torch
import pickle
import numpy as np
from utils.args import print_arguments
from reader.data_reader import read_input
from reader.data_loader import prepare_EC_info, get_edge_labels
from model.init_helios import HELIOS
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
from utils.evaluation import batch_evaluation

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.info(logger.getEffectiveLevel())


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` would turn every non-empty string, including "False", into
    True, so the accepted spellings are listed explicitly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


parser = argparse.ArgumentParser(description='Helios model')
parser.add_argument('--dataset', type=str, default='wd50k', help="")
parser.add_argument('--dim', type=int, default=256, help="")
parser.add_argument('--learning_rate', type=float, default=0.0001, help="")
parser.add_argument('--batch_size', type=int, default=1024, help="")
parser.add_argument('--epochs', type=int, default=101, help="")
parser.add_argument("--use_cuda", type=_str2bool, default=True, help="")
parser.add_argument("--gpu", type=int, default=0, help="")
parser.add_argument('--intermediate_size', type=int, default=512, help="")
parser.add_argument('--self_attention_layers', type=int, default=6, help="")
parser.add_argument('--gat_layers', type=int, default=2, help="")
parser.add_argument('--num_attention_heads', type=int, default=4, help="")
parser.add_argument('--hidden_dropout_prob', type=float, default=0.1, help="")
parser.add_argument('--attention_dropout_prob', type=float, default=0.1, help="")
parser.add_argument('--num_edges', type=int, default=5, help="")
parser.add_argument('--sparsifier', type=int, default=10, help="")
args = parser.parse_args()
args.dataset = './data/' + args.dataset


# Metric families accumulated by predict(); each family has one accumulator per
# metric name, stored under the key "<family>_<metric>".
_METRIC_GROUPS = ('entity', 'ht', 'v', 'relation', 'r', 'k')
_METRIC_NAMES = ('ap', 'ndcg', 'p_1', 'p_5', 'p_10', 'r_1', 'r_5', 'r_10')
# Report rows in print order: (row label including its tab padding, family).
_REPORT_ROWS = (
    ("H/T_TYPE\t\t", 'ht'),
    ("PRIMARY_R\t", 'r'),
    ("VALUE_TYPE\t", 'v'),
    ("KEY\t\t", 'k'),
    ("ENT_TYPE\t\t", 'entity'),
    ("RELATION\t", 'relation'),
)


class EDataset(Dataset.Dataset):
    """Dataset over the seven parallel columns of an encoded split.

    Columns 0-5 are returned as stored; column 6 (a list of label indices per
    example) is turned into a multi-hot float32 vector of length ``type_num``.
    """

    def __init__(self, data, type_num, device):
        self.data = data
        self.type_num = type_num
        self.device = device

    def __len__(self):
        return len(self.data[0])

    def __getitem__(self, index):
        label = np.zeros(self.type_num, dtype=np.float32)
        label[self.data[6][index]] = 1
        return self.data[0][index], self.data[1][index], self.data[2][index], self.data[3][index], \
            self.data[4][index], self.data[5][index], label


def _build_loader(split_facts, label_dim, device, batch_size, shuffle):
    """Wrap one encoded split in a DataLoader.

    The first six columns are converted to tensors on ``device``; any later
    column (the ragged ground-truth lists) stays a Python list.
    """
    columns = [
        torch.tensor(column).to(device) if idx < 6 else column
        for idx, column in enumerate(split_facts)
    ]
    dataset = EDataset(columns, label_dim, device)
    return DataLoader.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)


def main(args):
    """Train HELIOS and report validation/test metrics every fifth iteration."""
    config = vars(args)
    if args.use_cuda:
        device = torch.device("cuda", args.gpu)
        config["device"] = device
    else:
        device = torch.device("cpu")
        config["device"] = "cpu"

    # SECURITY NOTE: pickle.load executes arbitrary code contained in the file;
    # only load dataset archives obtained from a trusted source.
    with open(args.dataset + "/dictionaries_and_facts.bin", 'rb') as fin:
        data_info = pickle.load(fin)
    entities_values2id = data_info['values_indexes']

    helios_info = read_input(args.dataset, args.sparsifier, entities_values2id)

    instance_info = prepare_EC_info(helios_info, device)
    edge_labels = get_edge_labels(helios_info['max_n']).to(device)

    # Row i lists every index in 0..max_n-1 except i.
    max_n = helios_info['max_n']
    rows = [np.delete(np.arange(max_n), i) for i in range(max_n)]
    type_attn_l2_matrix = torch.tensor(np.stack(rows), dtype=torch.int64).to(device)

    model = HELIOS(instance_info, config, torch.tensor(helios_info['ent_types']).to(device)).to(device)

    label_dim = helios_info['type_num'] + helios_info['rel_num']
    train_E_pyreader = _build_loader(helios_info['train_facts'], label_dim, device, args.batch_size, True)
    valid_E_pyreader = _build_loader(helios_info['valid_facts'], label_dim, device, args.batch_size, False)
    test_E_pyreader = _build_loader(helios_info['test_facts'], label_dim, device, args.batch_size, False)

    helios_optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.StepLR(helios_optimizer, step_size=20, gamma=0.9)

    # The range starts at 1, so args.epochs - 1 passes are run.
    for iteration in range(1, args.epochs):
        logger.info("iteration " + str(iteration))
        model.train()
        helios_epoch_loss = 0.0
        start = time.time()
        for j, data in enumerate(train_E_pyreader):
            helios_pos = data

            helios_optimizer.zero_grad()
            helios_loss, _ = model.forward(helios_pos, edge_labels, type_attn_l2_matrix)
            helios_loss.backward()
            helios_optimizer.step()
            # Accumulate a plain float so no autograd history is kept alive.
            helios_epoch_loss += helios_loss.item()

            if j % 100 == 0:
                logger.info(str(j) + ' , loss: ' + str(helios_loss.item()))
        scheduler.step()
        end = time.time()
        t2 = round(end - start, 2)
        logger.info("epoch_loss = {:.3f}, time = {:.3f} s , lr= {}".format(
            helios_epoch_loss, t2, scheduler.get_last_lr()[0]))

        if iteration % 5 == 0:
            model.eval()
            with torch.no_grad():
                for loader, is_test in ((valid_E_pyreader, False), (test_E_pyreader, True)):
                    predict(
                        model=model,
                        helios_test_pyreader=loader,
                        helios_all_facts=helios_info['all_facts_eval'],
                        edge_labels=edge_labels,
                        type_attn_l2_matrix=type_attn_l2_matrix,
                        max_n=instance_info['max_n'],
                        ent_types=helios_info['ent_types'],
                        rel_num=instance_info['rel_num'],
                        ent_num=helios_info['ent_num'],
                        is_test=is_test,
                        device=device)

    logger.info("stop")


def predict(model, helios_test_pyreader, helios_all_facts, edge_labels, type_attn_l2_matrix, max_n, ent_types, rel_num, ent_num, is_test, device):
    """Score every batch of a loader and log the averaged ranking metrics.

    The caller is expected to have put the model in eval mode and to wrap the
    call in ``torch.no_grad()``. ``max_n`` is accepted but not used.

    Returns:
        None; results are written to the log only.
    """
    start = time.time()

    helios_ret_ranks = {
        group + '_' + metric: torch.empty(0).to(device)
        for group in _METRIC_GROUPS
        for metric in _METRIC_NAMES
    }

    for data in helios_test_pyreader:
        helios_pos = data
        _, helios_np_fc_out = model.forward(helios_pos, edge_labels, type_attn_l2_matrix)

        # Even slots holding a value <= rel_num + 2 are shifted down by rel_num
        # before the batch is handed to the evaluation code.
        even_slots = helios_pos[0][:, 0::2]
        helios_pos[0][:, 0::2] = torch.where(even_slots > rel_num + 2, even_slots, even_slots - rel_num)

        helios_ret_ranks = batch_evaluation(
            helios_np_fc_out, helios_pos, helios_all_facts, helios_ret_ranks,
            ent_types, rel_num, ent_num, device)

    report_rows = []
    for label, group in _REPORT_ROWS:
        values = [helios_ret_ranks[group + '_' + metric].mean().item() for metric in _METRIC_NAMES]
        report_rows.append(label + "\t".join("%.4f" % value for value in values))

    if is_test:
        option = 'Evaluation'
    else:
        option = 'Validation'
    header = "\t".join(["TASK\t", "mAP", "NDCG", "Prec@1", "Prec@5", "Prec@10", "Recall@1", "Recall@5", "Recall@10"])
    logger.info("\n-------- " + option + " Performance --------\n" + "\n".join([header] + report_rows))

    end = time.time()
    logger.info("time: " + str(round(end - start, 3)) + 's')

    return None


if __name__ == '__main__':
    print_arguments(args)
    main(args)
-------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | import six 2 | import logging 3 | 4 | logging.basicConfig( 5 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 6 | datefmt='%m/%d/%Y %H:%M:%S', 7 | level=logging.INFO) 8 | logging.getLogger().setLevel(logging.INFO) 9 | logger = logging.getLogger(__name__) 10 | 11 | def print_arguments(args): 12 | logger.info('----------- Configuration Arguments -----------') 13 | for arg, value in six.iteritems(vars(args)): 14 | logger.info('%s: %s' % (arg, value)) 15 | logger.info('------------------------------------------------') 16 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | torch.set_printoptions(precision=8) 5 | 6 | def batch_evaluation(batch_results, all_facts, gt_dict,ret_ranks,ent_types,rel_num,ent_num,device): 7 | for i, result in enumerate(batch_results): 8 | mask_type = all_facts[4][i] 9 | if mask_type == 1: 10 | target = ent_types[all_facts[3][i]-rel_num-3][1:] 11 | target = torch.LongTensor(target).to(device) 12 | target = target[torch.nonzero(target).squeeze(1)] 13 | pos = all_facts[2][i] 14 | id_fact=all_facts[0][i] 15 | keys=list() 16 | for j in range(id_fact.size(0)): 17 | if j ==pos : 18 | continue 19 | if j % 2 ==0 : 20 | if id_fact[j] ==0: 21 | keys.append(str(id_fact[j].item())) 22 | else: 23 | keys.append(' '.join (str(x) for x in ent_types[id_fact[j]-3-rel_num][1:])) 24 | else: 25 | keys.append(str(id_fact[j].item())) 26 | key=" ".join(keys[x] for x in range(len(keys))) 27 | 28 | 29 | rm_idx = torch.LongTensor(list(gt_dict[pos.item()][key])).to(device) 30 | rm_idx_tmp = torch.empty(0).to(device) 31 | for j in rm_idx: 32 | if j not in target: 33 | rm_idx_tmp = 
torch.cat([rm_idx_tmp,j.unsqueeze(0)],0) 34 | rm_idx = rm_idx_tmp.long() 35 | 36 | if rm_idx.shape != torch.Size([0]): 37 | result.index_fill_(0,rm_idx-ent_num-3,-np.Inf) 38 | 39 | sortidx = torch.argsort(result,dim=-1,descending=True) 40 | target = target - ent_num - 3 41 | ap, ndcg, p_1, p_5, p_10, r_1, r_5, r_10 = compute_metrics_new(sortidx, target) 42 | ap = torch.tensor(ap).float().unsqueeze(0).to(device) 43 | ret_ranks['entity_ap']=torch.cat([ret_ranks['entity_ap'],ap],dim=0) 44 | ndcg = torch.tensor(ndcg).float().unsqueeze(0).to(device) 45 | ret_ranks['entity_ndcg']=torch.cat([ret_ranks['entity_ndcg'],ndcg],dim=0) 46 | p_1 = torch.tensor(p_1).float().unsqueeze(0).to(device) 47 | ret_ranks['entity_p_1']=torch.cat([ret_ranks['entity_p_1'],p_1],dim=0) 48 | p_5 = torch.tensor(p_5).float().unsqueeze(0).to(device) 49 | ret_ranks['entity_p_5']=torch.cat([ret_ranks['entity_p_5'],p_5],dim=0) 50 | p_10 = torch.tensor(p_10).float().unsqueeze(0).to(device) 51 | ret_ranks['entity_p_10']=torch.cat([ret_ranks['entity_p_10'],p_10],dim=0) 52 | r_1 = torch.tensor(r_1).float().unsqueeze(0).to(device) 53 | ret_ranks['entity_r_1'] = torch.cat([ret_ranks['entity_r_1'], r_1], dim=0) 54 | r_5 = torch.tensor(r_5).float().unsqueeze(0).to(device) 55 | ret_ranks['entity_r_5'] = torch.cat([ret_ranks['entity_r_5'], r_5], dim=0) 56 | r_10 = torch.tensor(r_10).float().unsqueeze(0).to(device) 57 | ret_ranks['entity_r_10'] = torch.cat([ret_ranks['entity_r_10'], r_10], dim=0) 58 | if (pos == 0) or (pos == 2): 59 | ret_ranks['ht_ap'] = torch.cat([ret_ranks['ht_ap'], ap], dim=0) 60 | ret_ranks['ht_ndcg'] = torch.cat([ret_ranks['ht_ndcg'], ndcg], dim=0) 61 | ret_ranks['ht_p_1'] = torch.cat([ret_ranks['ht_p_1'], p_1], dim=0) 62 | ret_ranks['ht_p_5'] = torch.cat([ret_ranks['ht_p_5'], p_5], dim=0) 63 | ret_ranks['ht_p_10'] = torch.cat([ret_ranks['ht_p_10'], p_10], dim=0) 64 | ret_ranks['ht_r_1'] = torch.cat([ret_ranks['ht_r_1'], r_1], dim=0) 65 | ret_ranks['ht_r_5'] = 
torch.cat([ret_ranks['ht_r_5'], r_5], dim=0) 66 | ret_ranks['ht_r_10'] = torch.cat([ret_ranks['ht_r_10'], r_10], dim=0) 67 | else: 68 | ret_ranks['v_ap'] = torch.cat([ret_ranks['v_ap'], ap], dim=0) 69 | ret_ranks['v_ndcg'] = torch.cat([ret_ranks['v_ndcg'], ndcg], dim=0) 70 | ret_ranks['v_p_1'] = torch.cat([ret_ranks['v_p_1'], p_1], dim=0) 71 | ret_ranks['v_p_5'] = torch.cat([ret_ranks['v_p_5'], p_5], dim=0) 72 | ret_ranks['v_p_10'] = torch.cat([ret_ranks['v_p_10'], p_10], dim=0) 73 | ret_ranks['v_r_1'] = torch.cat([ret_ranks['v_r_1'], r_1], dim=0) 74 | ret_ranks['v_r_5'] = torch.cat([ret_ranks['v_r_5'], r_5], dim=0) 75 | ret_ranks['v_r_10'] = torch.cat([ret_ranks['v_r_10'], r_10], dim=0) 76 | 77 | 78 | 79 | else: 80 | target = all_facts[3][i] 81 | pos = all_facts[2][i] 82 | id_fact=all_facts[0][i] 83 | keys=list() 84 | for j in range(id_fact.size(0)): 85 | if j ==pos : 86 | continue 87 | if j % 2 ==0 : 88 | if id_fact[j] ==0: 89 | keys.append(str(id_fact[j].item())) 90 | else: 91 | keys.append(' '.join (str(x) for x in ent_types[id_fact[j]-3-rel_num][1:])) 92 | else: 93 | keys.append(str(id_fact[j].item())) 94 | key=" ".join(keys[x] for x in range(len(keys))) 95 | 96 | rm_idx = torch.LongTensor(gt_dict[pos.item()][key]).to(device) 97 | rm_idx=torch.where(rm_idx!=target,rm_idx,1) 98 | result.index_fill_(0,rm_idx-3,-np.Inf) 99 | sortidx = torch.argsort(result,dim=-1,descending=True) 100 | target = target - 3 101 | target = torch.LongTensor([target]).to(device) 102 | ap, ndcg, p_1, p_5, p_10, r_1, r_5, r_10 = compute_metrics_new(sortidx, target) 103 | ap = torch.tensor(ap).float().unsqueeze(0).to(device) 104 | ret_ranks['relation_ap'] = torch.cat([ret_ranks['relation_ap'], ap], dim=0) 105 | ndcg = torch.tensor(ndcg).float().unsqueeze(0).to(device) 106 | ret_ranks['relation_ndcg'] = torch.cat([ret_ranks['relation_ndcg'], ndcg], dim=0) 107 | p_1 = torch.tensor(p_1).float().unsqueeze(0).to(device) 108 | ret_ranks['relation_p_1'] = torch.cat([ret_ranks['relation_p_1'], 
p_1], dim=0) 109 | p_5 = torch.tensor(p_5).float().unsqueeze(0).to(device) 110 | ret_ranks['relation_p_5'] = torch.cat([ret_ranks['relation_p_5'], p_5], dim=0) 111 | p_10 = torch.tensor(p_10).float().unsqueeze(0).to(device) 112 | ret_ranks['relation_p_10'] = torch.cat([ret_ranks['relation_p_10'], p_10], dim=0) 113 | r_1 = torch.tensor(r_1).float().unsqueeze(0).to(device) 114 | ret_ranks['relation_r_1'] = torch.cat([ret_ranks['relation_r_1'], r_1], dim=0) 115 | r_5 = torch.tensor(r_5).float().unsqueeze(0).to(device) 116 | ret_ranks['relation_r_5'] = torch.cat([ret_ranks['relation_r_5'], r_5], dim=0) 117 | r_10 = torch.tensor(r_10).float().unsqueeze(0).to(device) 118 | ret_ranks['relation_r_10'] = torch.cat([ret_ranks['relation_r_10'], r_10], dim=0) 119 | if pos == 1: 120 | ret_ranks['r_ap'] = torch.cat([ret_ranks['r_ap'], ap], dim=0) 121 | ret_ranks['r_ndcg'] = torch.cat([ret_ranks['r_ndcg'], ndcg], dim=0) 122 | ret_ranks['r_p_1'] = torch.cat([ret_ranks['r_p_1'], p_1], dim=0) 123 | ret_ranks['r_p_5'] = torch.cat([ret_ranks['r_p_5'], p_5], dim=0) 124 | ret_ranks['r_p_10'] = torch.cat([ret_ranks['r_p_10'], p_10], dim=0) 125 | ret_ranks['r_r_1'] = torch.cat([ret_ranks['r_r_1'], r_1], dim=0) 126 | ret_ranks['r_r_5'] = torch.cat([ret_ranks['r_r_5'], r_5], dim=0) 127 | ret_ranks['r_r_10'] = torch.cat([ret_ranks['r_r_10'], r_10], dim=0) 128 | else: 129 | ret_ranks['k_ap'] = torch.cat([ret_ranks['k_ap'], ap], dim=0) 130 | ret_ranks['k_ndcg'] = torch.cat([ret_ranks['k_ndcg'], ndcg], dim=0) 131 | ret_ranks['k_p_1'] = torch.cat([ret_ranks['k_p_1'], p_1], dim=0) 132 | ret_ranks['k_p_5'] = torch.cat([ret_ranks['k_p_5'], p_5], dim=0) 133 | ret_ranks['k_p_10'] = torch.cat([ret_ranks['k_p_10'], p_10], dim=0) 134 | ret_ranks['k_r_1'] = torch.cat([ret_ranks['k_r_1'], r_1], dim=0) 135 | ret_ranks['k_r_5'] = torch.cat([ret_ranks['k_r_5'], r_5], dim=0) 136 | ret_ranks['k_r_10'] = torch.cat([ret_ranks['k_r_10'], r_10], dim=0) 137 | 138 | 139 | return ret_ranks 140 | 141 | def 
def compute_metrics_new(ranked_list, ground_truth):
    """Compute ranking metrics for one query against a set of relevant ids.

    Args:
        ranked_list: Candidate ids ordered from best to worst. The callers in
            this file pass the result of ``torch.argsort``; plain Python
            sequences are accepted as well.
        ground_truth: Ids of the relevant candidates (tensor, sequence, or a
            single 0-dim tensor / scalar id).

    Returns:
        An 8-tuple of floats, in this order:
        ``(AP, NDCG, P@1, P@5, P@10, R@1, R@5, R@10)``.

        * AP is the sum of precision values at every hit position divided by
          ``len(ground_truth)``.
        * NDCG uses binary gains and a ``log2(rank + 2)`` discount, normalised
          by the ideal DCG of ``len(ground_truth)`` relevant items.
        * P@k divides by the number of candidates actually present in the
          top-k slice (which is smaller than k for short rankings).
        * R@k divides by ``len(ground_truth)``.

        If either input is empty every metric is ``0.0`` instead of raising
        ``ZeroDivisionError`` / ``IndexError``.
    """
    ranked = _to_id_list(ranked_list)
    truth = _to_id_list(ground_truth)

    # Nothing to rank or nothing relevant: all metrics are defined as zero.
    if not ranked or not truth:
        return (0.0,) * 8

    # Build the lookup set once; list membership inside the loop was O(n * m).
    truth_set = set(truth)
    num_truth = len(truth)

    # AP and DCG both walk the full ranking, so accumulate them together.
    hits = 0
    precision_sum = 0.0
    dcg = 0.0
    for rank, item in enumerate(ranked):
        if item in truth_set:
            hits += 1
            precision_sum += hits / (rank + 1.0)
            dcg += 1.0 / np.log2(rank + 2)

    average_precision = precision_sum / num_truth if hits > 0 else 0.0

    # Ideal DCG: every relevant item placed at the top of the ranking.
    ideal_dcg = 0.0
    for rank in range(num_truth):
        ideal_dcg += 1.0 / np.log2(rank + 2)
    ndcg = dcg / ideal_dcg

    # Precision / recall at the fixed cut-offs 1, 5 and 10.
    precisions = []
    recalls = []
    for k in (1, 5, 10):
        top_k = ranked[:k]
        overlap = len(set(top_k) & truth_set)
        precisions.append(overlap / len(top_k))
        recalls.append(overlap / num_truth)

    return (
        float(average_precision),
        float(ndcg),
        float(precisions[0]),
        float(precisions[1]),
        float(precisions[2]),
        float(recalls[0]),
        float(recalls[1]),
        float(recalls[2]),
    )


def _to_id_list(values):
    """Return ``values`` (tensor, array, sequence or scalar) as a Python list."""
    if hasattr(values, "detach"):
        # Torch tensor: drop autograd history and move to host memory first.
        values = values.detach().cpu()
    if hasattr(values, "tolist"):
        values = values.tolist()
    try:
        return list(values)
    except TypeError:
        # A 0-dim tensor's ``tolist()`` yields a bare scalar, not a list.
        return [values]