├── Code
│   ├── attack_HAN.png
│   ├── attack_HAN_RoHe.png
│   ├── HAN
│   │   ├── __pycache__
│   │   │   ├── HAN.cpython-36.pyc
│   │   │   └── utils.cpython-36.pyc
│   │   ├── model.py
│   │   └── utils.py
│   ├── HAN_RoHe
│   │   ├── __pycache__
│   │   │   ├── model.cpython-36.pyc
│   │   │   ├── utils.cpython-36.pyc
│   │   │   └── gatconv_rohe.cpython-36.pyc
│   │   ├── model.py
│   │   ├── gatconv_rohe.py
│   │   └── utils.py
│   ├── data
│   │   ├── generated_attacks
│   │   │   ├── adv_acm_pap_pa_1.pkl
│   │   │   ├── adv_acm_pap_pa_3.pkl
│   │   │   ├── adv_acm_pap_pa_5.pkl
│   │   │   ├── adv_aminer_prp_pr_1.pkl
│   │   │   ├── adv_aminer_prp_pr_3.pkl
│   │   │   ├── adv_aminer_prp_pr_5.pkl
│   │   │   ├── adv_dblp_apa_pa_1.pkl
│   │   │   ├── adv_dblp_apa_pa_3.pkl
│   │   │   └── adv_dblp_apa_pa_5.pkl
│   │   └── preprocess
│   │       ├── split_data
│   │       │   ├── data_split_acm.pkl
│   │       │   ├── data_split_dblp.pkl
│   │       │   └── data_split_aminer.pkl
│   │       ├── target_nodes
│   │       │   ├── acm_r_target0.pkl
│   │       │   ├── acm_r_target1.pkl
│   │       │   ├── acm_r_target2.pkl
│   │       │   ├── acm_r_target3.pkl
│   │       │   ├── acm_r_target4.pkl
│   │       │   ├── dblp_r_target0.pkl
│   │       │   ├── dblp_r_target1.pkl
│   │       │   ├── dblp_r_target2.pkl
│   │       │   ├── dblp_r_target3.pkl
│   │       │   ├── dblp_r_target4.pkl
│   │       │   ├── aminer_r_target0.pkl
│   │       │   ├── aminer_r_target1.pkl
│   │       │   ├── aminer_r_target2.pkl
│   │       │   ├── aminer_r_target3.pkl
│   │       │   └── aminer_r_target4.pkl
│   │       └── pseudo_labels
│   │           ├── acm_pseudo_labels.pkl
│   │           ├── dblp_pseudo_labels.pkl
│   │           └── aminer_pseudo_labels.pkl
│   ├── attack_HAN.ipynb
│   └── attack_HAN-RoHe.ipynb
└── README.md
/Code/attack_HAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/attack_HAN.png --------------------------------------------------------------------------------
/Code/attack_HAN_RoHe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/attack_HAN_RoHe.png --------------------------------------------------------------------------------
/Code/HAN/__pycache__/HAN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/HAN/__pycache__/HAN.cpython-36.pyc --------------------------------------------------------------------------------
/Code/HAN/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/HAN/__pycache__/utils.cpython-36.pyc --------------------------------------------------------------------------------
/Code/HAN_RoHe/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/HAN_RoHe/__pycache__/model.cpython-36.pyc --------------------------------------------------------------------------------
/Code/HAN_RoHe/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/HAN_RoHe/__pycache__/utils.cpython-36.pyc --------------------------------------------------------------------------------
/Code/data/generated_attacks/adv_acm_pap_pa_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_acm_pap_pa_1.pkl --------------------------------------------------------------------------------
/Code/data/generated_attacks/adv_acm_pap_pa_3.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_acm_pap_pa_3.pkl -------------------------------------------------------------------------------- /Code/data/generated_attacks/adv_acm_pap_pa_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_acm_pap_pa_5.pkl -------------------------------------------------------------------------------- /Code/data/generated_attacks/adv_aminer_prp_pr_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_aminer_prp_pr_1.pkl -------------------------------------------------------------------------------- /Code/data/generated_attacks/adv_aminer_prp_pr_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_aminer_prp_pr_3.pkl -------------------------------------------------------------------------------- /Code/data/generated_attacks/adv_aminer_prp_pr_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_aminer_prp_pr_5.pkl -------------------------------------------------------------------------------- /Code/data/generated_attacks/adv_dblp_apa_pa_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_dblp_apa_pa_1.pkl -------------------------------------------------------------------------------- /Code/data/generated_attacks/adv_dblp_apa_pa_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_dblp_apa_pa_3.pkl -------------------------------------------------------------------------------- /Code/data/generated_attacks/adv_dblp_apa_pa_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/generated_attacks/adv_dblp_apa_pa_5.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/split_data/data_split_acm.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/split_data/data_split_acm.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/split_data/data_split_dblp.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/split_data/data_split_dblp.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/acm_r_target0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/acm_r_target0.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/acm_r_target1.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/acm_r_target1.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/acm_r_target2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/acm_r_target2.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/acm_r_target3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/acm_r_target3.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/acm_r_target4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/acm_r_target4.pkl -------------------------------------------------------------------------------- /Code/HAN_RoHe/__pycache__/gatconv_rohe.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/HAN_RoHe/__pycache__/gatconv_rohe.cpython-36.pyc -------------------------------------------------------------------------------- /Code/data/preprocess/split_data/data_split_aminer.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/split_data/data_split_aminer.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/dblp_r_target0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/dblp_r_target0.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/dblp_r_target1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/dblp_r_target1.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/dblp_r_target2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/dblp_r_target2.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/dblp_r_target3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/dblp_r_target3.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/dblp_r_target4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/dblp_r_target4.pkl -------------------------------------------------------------------------------- 
/Code/data/preprocess/pseudo_labels/acm_pseudo_labels.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/pseudo_labels/acm_pseudo_labels.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/aminer_r_target0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/aminer_r_target0.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/aminer_r_target1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/aminer_r_target1.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/aminer_r_target2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/aminer_r_target2.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/aminer_r_target3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/aminer_r_target3.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/target_nodes/aminer_r_target4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/target_nodes/aminer_r_target4.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/pseudo_labels/dblp_pseudo_labels.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/pseudo_labels/dblp_pseudo_labels.pkl -------------------------------------------------------------------------------- /Code/data/preprocess/pseudo_labels/aminer_pseudo_labels.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/HEAD/Code/data/preprocess/pseudo_labels/aminer_pseudo_labels.pkl -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RoHe 2 | 3 | The code of the paper "Robust Heterogeneous Graph Neural Networks against Adversarial Attacks" (AAAI 2022). 4 | 5 | Here we show the attack process in "attack_HAN.ipynb" and "attack_HAN-RoHe.ipynb". 6 | 7 | To clearly compare the robustness of HAN and our HAN-RoHe, we also provide their visualizations in attack_HAN.png and attack_HAN_RoHe.png. 8 | *For simplicity, the code takes 100 target nodes as an example; this can be changed to 500 in the code.
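For reference, a minimal sketch of how the notebooks assemble those target nodes and read a pre-generated attack file (paths as in this repository; the loop bound is the only knob):

import pickle as pkl
import numpy as np

dataname = 'acm'
tar_idx = []
for i in range(1):  # range(5) would load all 500 target nodes
    with open('data/preprocess/target_nodes/' + dataname + '_r_target' + str(i) + '.pkl', 'rb') as f:
        tar_idx.extend(np.sort(pkl.load(f)))

# each entry of an adv_*.pkl, as used in the notebooks: items[0] is the
# target node, items[2]/items[3] are the edges to delete/add
with open('data/generated_attacks/adv_acm_pap_pa_1.pkl', 'rb') as f:
    modified_opt = pkl.load(f)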
9 | 10 | 11 | # Usage 12 | Run attack_HAN.ipynb and attack_HAN-RoHe.ipynb. 13 | 14 | 15 | # Requirements 16 | - torch 1.3.1 17 | 18 | - dgl 19 | 20 | - deeprobust 21 | 22 | - sklearn 23 | 24 | 25 | 26 | 27 | More details will be published soon (^▽^) 28 | -------------------------------------------------------------------------------- /Code/HAN/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl 5 | from dgl.nn.pytorch import GATConv 6 | 7 | class SemanticAttention(nn.Module): 8 | def __init__(self, in_size, hidden_size=128): 9 | super(SemanticAttention, self).__init__() 10 | 11 | self.project = nn.Sequential( 12 | nn.Linear(in_size, hidden_size), 13 | nn.Tanh(), 14 | nn.Linear(hidden_size, 1, bias=False) 15 | ) 16 | 17 | def forward(self, z): 18 | w = self.project(z).mean(0) # (M, 1) 19 | beta = torch.softmax(w, dim=0) # (M, 1) 20 | beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1) 21 | 22 | return (beta * z).sum(1) # (N, D * K) 23 | 24 | class HANLayer(nn.Module): 25 | def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout): 26 | super(HANLayer, self).__init__() 27 | 28 | # One GAT layer for each meta path based adjacency matrix 29 | self.gat_layers = nn.ModuleList() 30 | for i in range(len(meta_paths)): 31 | self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads, 32 | dropout, dropout, activation=F.elu)) 33 | self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads) 34 | self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths) 35 | 36 | self._cached_graph = None 37 | self._cached_coalesced_graph = {} 38 | 39 | def forward(self, g, h): 40 | semantic_embeddings = [] 41 | 42 | if self._cached_graph is None or self._cached_graph is not g: 43 | self._cached_graph = g 44 | self._cached_coalesced_graph.clear() 45 | for meta_path in self.meta_paths: 46 | self._cached_coalesced_graph[meta_path] = dgl.metapath_reachable_graph( 47 | g, meta_path) 48 | 49 | for i, meta_path in enumerate(self.meta_paths): 50 | new_g = self._cached_coalesced_graph[meta_path] 51 | semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1)) 52 | semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K) 53 | 54 | return self.semantic_attention(semantic_embeddings) # (N, D * K) 55 | 56 | class HAN(nn.Module): 57 | def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout): 58 | super(HAN, self).__init__() 59 | self.layers = nn.ModuleList() 60 | self.layers.append(HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout)) 61 | for l in range(1, len(num_heads)): 62 | self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l-1], 63 | hidden_size, num_heads[l], dropout)) 64 | self.predict = nn.Linear(hidden_size * num_heads[-1], out_size) 65 | 66 | def forward(self, g, h): 67 | for gnn in self.layers: 68 | h = gnn(g, h) 69 | 70 | return self.predict(h) --------------------------------------------------------------------------------
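Before the RoHe variant below, a minimal usage sketch of this vanilla HAN (mirroring attack_HAN.ipynb; hidden_size, num_heads and dropout follow default_configure in HAN/utils.py):

import torch
from HAN.model import HAN
from HAN.utils import load_acm_raw

# loader returns the heterograph, relation adjacencies, features, labels and splits
g, hete_adjs, features, labels, num_classes, *splits = load_acm_raw(False)
model = HAN(meta_paths=[['pa', 'ap'], ['pf', 'fp']],  # PAP / PFP meta-paths for ACM
            in_size=features.shape[1], hidden_size=8,
            out_size=num_classes, num_heads=[8], dropout=0.6)
with torch.no_grad():
    logits = model(g, features)  # (num_papers, num_classes)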
/Code/HAN_RoHe/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import pickle as pkl 4 | import scipy.sparse as sp 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import dgl 9 | from HAN_RoHe.gatconv_rohe import GATConv 10 | import random 11 | import re 12 | import copy 13 | 14 | 15 | class SemanticAttention(nn.Module): 16 | def __init__(self, in_size, hidden_size=128): 17 | super(SemanticAttention, self).__init__() 18 | 19 | self.project = nn.Sequential( 20 | nn.Linear(in_size, hidden_size), 21 | nn.Tanh(), 22 | nn.Linear(hidden_size, 1, bias=False) 23 | ) 24 | 25 | def forward(self, z): 26 | w = self.project(z).mean(0) # (M, 1) 27 | beta = torch.softmax(w, dim=0) # (M, 1) 28 | beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1) 29 | 30 | return (beta * z).sum(1) # (N, D * K) 31 | 32 | class HANLayer(nn.Module): 33 | def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout, settings): 34 | super(HANLayer, self).__init__() 35 | 36 | self.gat_layers = nn.ModuleList() 37 | for i in range(len(meta_paths)): 38 | self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads, 39 | dropout, dropout, activation=F.elu, settings=settings[i])) 40 | self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads) 41 | self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths) 42 | 43 | self._cached_graph = None 44 | self._cached_coalesced_graph = {} 45 | 46 | def forward(self, g, h): 47 | semantic_embeddings = [] 48 | 49 | if self._cached_graph is None or self._cached_graph is not g: 50 | self._cached_graph = g 51 | self._cached_coalesced_graph.clear() 52 | for meta_path in self.meta_paths: 53 | self._cached_coalesced_graph[meta_path] = dgl.metapath_reachable_graph( 54 | g, meta_path) 55 | for i, meta_path in enumerate(self.meta_paths): 56 | new_g = self._cached_coalesced_graph[meta_path] 57 | semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1)) 58 | semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K) 59 | 60 | return self.semantic_attention(semantic_embeddings) # (N, D * K) 61 | 62 | class HAN(nn.Module): 63 | def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout, settings): 64 | super(HAN, self).__init__() 65 | self.layers = nn.ModuleList() 66 | self.layers.append(HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout, settings)) 67 | for l in range(1, len(num_heads)): 68 | self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l-1], 69 | hidden_size, num_heads[l], dropout, settings)) # settings was missing here; without it this call raises a TypeError when len(num_heads) > 1 70 | self.predict = nn.Linear(hidden_size * num_heads[-1], out_size) 71 | def forward(self, g, h): 72 | for gnn in self.layers: 73 | h = gnn(g, h) 74 | return self.predict(h) 75 | 76 | --------------------------------------------------------------------------------
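The RoHe variant above differs from vanilla HAN only in the per-meta-path `settings` dicts handed to GATConv: as used in attack_HAN-RoHe.ipynb, each carries the top-T purification threshold 'T', the torch 'device', and 'TransM', the row-normalized transition matrix of that meta-path. A hedged sketch of building them for ACM (mirroring get_transition in the notebook; `hete_adjs` is the relation-to-adjacency dict returned by load_acm_raw, and the device index is whatever you run on):

import numpy as np
import scipy.sparse as sp

def get_transition(hete_adjs, metapaths):
    # row-normalize each relation, then chain the relations along every meta-path
    norm = {}
    for key, adj in hete_adjs.items():
        deg = adj.sum(1)
        norm[key] = adj / np.where(deg > 0, deg, 1)  # guard against zero degrees
    out = []
    for mp in metapaths:
        adj = norm[mp[0]]
        for etype in mp[1:]:
            adj = adj.dot(norm[etype])
        out.append(sp.csc_matrix(adj))
    return out

settings = [{'T': 2, 'device': 0}, {'T': 5, 'device': 0}]  # T values for ACM's two meta-paths, as in the notebook
for s, trans in zip(settings, get_transition(hete_adjs, [['pa', 'ap'], ['pf', 'fp']])):
    s['TransM'] = trans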
/Code/HAN_RoHe/gatconv_rohe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from dgl import function as fn 5 | from dgl.utils import expand_as_pair 6 | from dgl.nn.pytorch.softmax import edge_softmax 7 | from dgl.nn.pytorch.utils import Identity 8 | 9 | 10 | class GATConv(nn.Module): 11 | def __init__(self, 12 | in_feats, 13 | out_feats, 14 | num_heads, 15 | feat_drop=0., 16 | attn_drop=0., 17 | negative_slope=0.2, 18 | residual=False, 19 | activation=None, 20 | settings={'K':10,'P':0.6,'tau':0.1,'Flag':"None"}): # in this repo the callers always pass {'T', 'device', 'TransM'} instead (see the notebooks) 21 | 22 | super(GATConv, self).__init__() 23 | self._num_heads = num_heads 24 | self.settings = settings 25 | self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) 26 | self._out_feats = out_feats 27 | if isinstance(in_feats, tuple): 28 | self.fc_src = nn.Linear( 29 | self._in_src_feats, out_feats * num_heads, bias=False) 30 | self.fc_dst = nn.Linear( 31 | self._in_dst_feats, out_feats * num_heads, bias=False) 32 | else: 33 | self.fc = nn.Linear( 34 | self._in_src_feats, out_feats * num_heads, bias=False) 35 | self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 36 | self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 37 | self.feat_drop = nn.Dropout(feat_drop) 38 | self.attn_drop = nn.Dropout(0.0) # note: the attn_drop argument is effectively ignored here 39 | self.leaky_relu = nn.LeakyReLU(negative_slope) 40 | if residual: 41 | if self._in_dst_feats != out_feats: 42 | self.res_fc = nn.Linear( 43 | self._in_dst_feats, num_heads * out_feats, bias=False) 44 | else: 45 | self.res_fc = Identity() 46 | else: 47 | self.register_buffer('res_fc', None) 48 | self.reset_parameters() 49 | self.activation = activation 50 | 51 | def reset_parameters(self): 52 | gain = nn.init.calculate_gain('relu') 53 | if hasattr(self, 'fc'): 54 | nn.init.xavier_normal_(self.fc.weight, gain=gain) 55 | else: 56 | nn.init.xavier_normal_(self.fc_src.weight, gain=gain) 57 | nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) 58 | nn.init.xavier_normal_(self.attn_l, gain=gain) 59 | nn.init.xavier_normal_(self.attn_r, gain=gain) 60 | if isinstance(self.res_fc, nn.Linear): 61 | nn.init.xavier_normal_(self.res_fc.weight, gain=gain) 62 | 63 | def mask(self, attM): 64 | T = self.settings['T'] 65 | indices_to_remove = attM < torch.clamp(torch.topk(attM, T)[0][..., -1, None],min=0) 66 | attM[indices_to_remove] = -9e15 67 | return attM 68 | 69 | 70 | def forward(self, graph, feat): 71 | graph = graph.local_var() 72 | if isinstance(feat, tuple): 73 | h_src = self.feat_drop(feat[0]) 74 | h_dst = self.feat_drop(feat[1]) 75 | feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats) 76 | feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats) 77 | else: 78 | h_src = h_dst = self.feat_drop(feat) 79 | feat_src = feat_dst = self.fc(h_src).view( 80 | -1, self._num_heads, self._out_feats) 81 | N = graph.nodes().shape[0] 82 | N_e = graph.edges()[0].shape[0] 83 | graph.srcdata.update({'ft': feat_src}) 84 | 85 | # introduce transiting prior 86 | e_trans = torch.FloatTensor(self.settings['TransM'].data).view(N_e,1) 87 | e_trans = e_trans.repeat(1,self._num_heads).resize_(N_e,self._num_heads,1) # use the real head count (was hard-coded to 8) 88 | 89 | # feature-based similarity 90 | e = torch.cat([torch.matmul(feat_src[:,i,:].view(N,self._out_feats),\ 91 | feat_src[:,i,:].t().view(self._out_feats,N))[graph.edges()[0], graph.edges()[1]].view(N_e,1)\ 92 | for i in range(self._num_heads)],dim=1).view(N_e,self._num_heads,1) 93 | 94 | total_edge = torch.cat((graph.edges()[0].view(1,N_e),graph.edges()[1].view(1,N_e)),0) 95 | # confidence score in Eq(7) 96 | attn = torch.sparse.FloatTensor(total_edge,\ 97 | torch.squeeze((e.to('cpu') * e_trans).sum(-2)), torch.Size([N,N])).to(self.settings['device']) 98 | 99 | # purification mask in Eq(8) 100 | attn = self.mask(attn.to_dense()).t() 101 | e[attn[graph.edges()[0],graph.edges()[1]].view(N_e,1).repeat(1,self._num_heads).view(N_e,self._num_heads,1)<-100] = -9e15 102 | 103 | # obtain purified final attention in Eq(9) 104 | graph.edata['a'] = edge_softmax(graph, e) 105 | 106 | # message passing 107 | graph.update_all(fn.u_mul_e('ft', 'a', 'm'), 108 | fn.sum('m', 'ft')) 109 | rst = graph.dstdata['ft'] 110 | 111 | # residual 112 | if self.res_fc is not None: 113 | resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats) 114 | rst = rst + resval 115 | 116 | # activation 117 | if self.activation: 118 | rst = self.activation(rst) 119 | return rst 120 | -------------------------------------------------------------------------------- /Code/attack_HAN.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import argparse\n", 11 | "import pickle as pkl\n", 12 | "import scipy.sparse as sp\n", 13 | "import torch.nn as nn\n", 14 | "import torch.nn.functional as F\n", 15 | "import numpy as np\n", 16 | "import dgl\n", 17 | "import random\n", 18 | "import re\n", 19 | "import copy \n", 20 | "from HAN.model import *\n", 21 | "from HAN.utils import *" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "dataname = 'acm'\n", 31 | "device = 1\n", 32 | "meta_paths_dict = {'acm':[['pa','ap'],['pf','fp']], \\\n", 33 | " 'dblp':[['ap','pa'],['ap','pc','cp','pa'],['ap','pt','tp','pa']],\\\n", 34 | " 'aminer':[['pa','ap'],['pr','rp']]}\n", 35 | "#1.init\n", 36 | "args = {}\n", 37 | "args['seed'] = 2\n", 38 | "args['hetero'] = True\n", 39 | "args['log_dir'] = 'results'\n", 40 | "args = setup(args)# following original setup in dgl\n", 41 | "g, hete_adjs, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, val_mask, test_mask = load_acm_raw(False)\n", 42 | "g = g.to(device)\n", 43 | "if hasattr(torch, 'BoolTensor'):\n", 44 | " train_mask = train_mask.bool()\n", 45 | " val_mask = val_mask.bool()\n", 46 | " test_mask = test_mask.bool()\n", 47 | "features = features.to(device)\n", 48 | "labels = labels.to(device)\n", 49 | "train_mask = train_mask.to(device)\n", 50 | "val_mask = val_mask.to(device)\n", 51 | "test_mask = test_mask.to(device)\n", 52 | "\n", 53 | "#2.train model\n", 54 | "model = HAN(meta_paths=meta_paths_dict[dataname],\n", 55 | " in_size=features.shape[1],\n", 56 | " hidden_size=args['hidden_units'],\n", 57 | " out_size=num_classes,\n", 58 | " num_heads=args['num_heads'],\n", 59 | " dropout=args['dropout']).to(device)\n", 60 | "\n", 61 | "stopper = EarlyStopping(patience=args['patience'])\n", 62 | "stopper.filename = 'save_model/mid_dglHan_'+dataname+'.pth'\n", 63 | "loss_fcn = torch.nn.CrossEntropyLoss()\n", 64 | "optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],\n", 65 | " weight_decay=args['weight_decay'])\n", 66 | "for epoch in range(args['num_epochs']):\n", 67 | " model.train()\n", 68 | " logits = model(g, features)\n", 69 | " loss = loss_fcn(logits[train_mask], labels[train_mask])\n", 70 | " optimizer.zero_grad()\n", 71 | " loss.backward()\n", 72 | " optimizer.step()\n", 73 | " train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])\n", 74 | " val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn)\n", 75 | " early_stop = stopper.step(val_loss.data.item(), val_acc, model)\n", 76 | " print(epoch,\"|VAL Micro-F1:\",val_micro_f1,\", Macro-F1:\",val_macro_f1)\n", 77 | " if early_stop:\n", 78 | " break\n", 79 | "\n", 80 | "#3.test model\n", 81 | "stopper.load_checkpoint(model)\n", 82 | "test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn)\n", 83 | "print(\"@@@@test:\",test_acc, test_micro_f1, test_macro_f1)\n", 84 | "\n", 85 | "# return model, g, hete_adjs, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, val_mask, test_mask,test_acc, test_micro_f1, test_macro_f1\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | 
"outputs": [], 93 | "source": [ 94 | "# load target node ID:\n", 95 | "tar_idx = []\n", 96 | "for i in range(1): # can attack 500 target nodes by seting range(5)\n", 97 | " with open('data/preprocess/target_nodes/'+dataname+'_r_target' + str(i) + '.pkl', 'rb') as f:\n", 98 | " tar_tmp = np.sort(pkl.load(f))\n", 99 | " tar_idx.extend(tar_tmp) \n", 100 | "\n", 101 | "# evaluate result\n", 102 | "with torch.no_grad():\n", 103 | " logits = model(g, features)\n", 104 | "logits_clean = logits[tar_idx]\n", 105 | "labels_clean = labels[tar_idx]\n", 106 | "_, tar_micro_f1_clean, tar_macro_f1_clean = score(logits_clean, labels_clean)\n", 107 | "print(\"Clean data: Micro-F1:\", tar_micro_f1_clean, \" Macro-F1:\",tar_macro_f1_clean)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "n_perturbation = 1\n", 117 | "adv_filename = 'data/generated_attacks/adv_acm_pap_pa_'+str(n_perturbation)+'.pkl'\n", 118 | "tar_mask = get_binary_mask(train_mask.shape[0], tar_idx)\n", 119 | "micro_f1_list_adv = []\n", 120 | "macro_f1_list_adv = []\n", 121 | "# load adversarial attacks for each target node\n", 122 | "with open(adv_filename,'rb') as f:\n", 123 | " modified_opt = pkl.load(f)\n", 124 | "#2.attack\n", 125 | "logits_adv = []\n", 126 | "labels_adv = []\n", 127 | "for items in modified_opt:\n", 128 | " #2.1 init\n", 129 | " target_node = items[0]\n", 130 | " del_list = items[2]\n", 131 | " add_list = items[3]\n", 132 | " if target_node not in tar_idx:\n", 133 | " continue\n", 134 | " #2.2 attack adjs\n", 135 | " mod_hete_adj_dict = copy.deepcopy(hete_adjs)\n", 136 | " for edge in del_list:\n", 137 | " mod_hete_adj_dict['pa'][edge[0],edge[1]] = 0\n", 138 | " mod_hete_adj_dict['ap'][edge[1],edge[0]] = 0\n", 139 | " for edge in add_list:\n", 140 | " mod_hete_adj_dict['pa'][edge[0],edge[1]] = 1\n", 141 | " mod_hete_adj_dict['ap'][edge[1],edge[0]] = 1\n", 142 | " hg_atk = get_hg(dataname, mod_hete_adj_dict).to(device)\n", 143 | " #2.3 run model\n", 144 | " with torch.no_grad():\n", 145 | " logits = model(hg_atk, features)\n", 146 | " #2.4 evaluate\n", 147 | " logits_adv.append(logits[np.array([[target_node]])])\n", 148 | " labels_adv.append(labels[np.array([[target_node]])])\n", 149 | "logits_adv = torch.cat(logits_adv,0)\n", 150 | "labels_adv = torch.cat(labels_adv)\n", 151 | "_, tar_micro_f1_atk, tar_macro_f1_atk = score(logits_adv, labels_adv)\n", 152 | "print(\"Attacked data: Micro-F1:\", tar_micro_f1_atk, \" Macro-F1:\",tar_macro_f1_atk)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "import matplotlib.pyplot as plt\n", 162 | "y_testAcc_name = 'Results of HAN (%)'\n", 163 | "plt.figure(figsize=(9, 10))#dpi=xx\n", 164 | "tick_label = ['Mi-F1','Ma-F1']\n", 165 | "Y_clean = [tar_micro_f1_clean*100, tar_macro_f1_clean*100]\n", 166 | "Y_attack = [tar_micro_f1_atk*100, tar_macro_f1_atk*100]\n", 167 | "font_size = 35\n", 168 | "X = np.arange(len(Y_attack))\n", 169 | "plt.ylim(0,120) \n", 170 | "bar_width = 0.2\n", 171 | "for x,y in zip(X,Y_clean):\n", 172 | " plt.text(x+0.05,y+0.005,'%d' %y, ha='center',va='bottom',fontsize=font_size)\n", 173 | "for x,y in zip(X,Y_attack):\n", 174 | " plt.text(x+0.25,y+0.005,'%d' %y, ha='center',va='bottom',fontsize=font_size)\n", 175 | "\n", 176 | "\n", 177 | "clean = plt.bar(X, Y_clean, width=bar_width, color = 'g',edgecolor='black')\n", 178 | "attack = plt.bar([x+0.2 for x in X], 
Y_attack, width=bar_width, color = 'red',edgecolor='black')\n", 179 | "\n", 180 | "\n", 181 | "# add x/y axis labels\n", 182 | "# plt.xlabel('',font)\n", 183 | "plt.ylabel(y_testAcc_name,{'family' : 'Times New Roman','weight':'bold', 'size':font_size})\n", 184 | "# x-axis ticks\n", 185 | "print(X,tick_label)\n", 186 | "plt.xticks(X, tick_label, size=font_size+3)\n", 187 | "plt.yticks([0,20,40,60,80,100],fontsize = font_size)\n", 188 | "font_legend={'family' : 'Times New Roman', 'size':font_size}\n", 189 | "plt.legend([clean, attack],[r'clean', r'attack'],ncol=2,loc='upper center',bbox_to_anchor=(0.5,1.025), prop=font_legend)\n", 190 | "plt.savefig(\"attack_HAN.png\", bbox_inches='tight',dpi=400,pad_inches=0.0)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "Python 3", 204 | "language": "python", 205 | "name": "python3" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 3 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython3", 217 | "version": "3.6.8" 218 | } 219 | }, 220 | "nbformat": 4, 221 | "nbformat_minor": 2 222 | } 223 | -------------------------------------------------------------------------------- /Code/attack_HAN-RoHe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import argparse\n", 11 | "import pickle as pkl\n", 12 | "import scipy.sparse as sp\n", 13 | "import torch.nn as nn\n", 14 | "import torch.nn.functional as F\n", 15 | "import numpy as np\n", 16 | "import dgl\n", 17 | "import random\n", 18 | "import re\n", 19 | "import copy \n", 20 | "from HAN_RoHe.model import *\n", 21 | "from HAN_RoHe.utils import *" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "dataname = 'acm'\n", 31 | "settings_pap = {'T':2, 'device':2}# acm\n", 32 | "settings_psp = {'T':5, 'device':2}# acm\n", 33 | "settings = [settings_pap, settings_psp]\n", 34 | "device = 2\n", 35 | "meta_paths_dict = {'acm':[['pa','ap'],['pf','fp']], \\\n", 36 | " 'dblp':[['ap','pa'],['ap','pc','cp','pa'],['ap','pt','tp','pa']],\\\n", 37 | " 'aminer':[['pa','ap'],['pr','rp']]}\n", 38 | "#1.init\n", 39 | "args = {}\n", 40 | "args['seed'] = 2\n", 41 | "args['hetero'] = True\n", 42 | "args['log_dir'] = 'results'\n", 43 | "args = setup(args)\n", 44 | "g, hete_adjs, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, val_mask, test_mask = load_acm_raw(False)\n", 45 | "g = g.to(device)\n", 46 | "if hasattr(torch, 'BoolTensor'):\n", 47 | " train_mask = train_mask.bool()\n", 48 | " val_mask = val_mask.bool()\n", 49 | " test_mask = test_mask.bool()\n", 50 | "features = features.to(device)\n", 51 | "labels = labels.to(device)\n", 52 | "train_mask = train_mask.to(device)\n", 53 | "val_mask = val_mask.to(device)\n", 54 | "test_mask = test_mask.to(device)\n", 55 | "\n", 56 | "#2.generate transition matrix\n", 57 | "def get_transition(given_hete_adjs, metapath_info):\n", 58 | " # transition\n", 59 | " hete_adj_dict_tmp = {}\n", 60 | " for key in given_hete_adjs.keys():\n", 61 | " deg = 
given_hete_adjs[key].sum(1)\n", 62 | " hete_adj_dict_tmp[key] = given_hete_adjs[key]/(np.where(deg > 0, deg, 1))# make sure deg > 0\n", 63 | " homo_adj_list = []\n", 64 | " for i in range(len(metapath_info)):\n", 65 | " adj = hete_adj_dict_tmp[metapath_info[i][0]]\n", 66 | " for etype in metapath_info[i][1:]:\n", 67 | " adj = adj.dot(hete_adj_dict_tmp[etype])\n", 68 | " homo_adj_list.append(sp.csc_matrix(adj))\n", 69 | " return homo_adj_list\n", 70 | "trans_adj_list = get_transition(hete_adjs, meta_paths_dict[dataname]) \n", 71 | "for i in range(len(trans_adj_list)):\n", 72 | " settings[i]['device'] = device\n", 73 | " settings[i]['TransM'] = trans_adj_list[i]\n", 74 | "\n", 75 | "\n", 76 | "#3.train model\n", 77 | "model = HAN(meta_paths=meta_paths_dict[dataname],\n", 78 | " in_size=features.shape[1],\n", 79 | " hidden_size=args['hidden_units'],\n", 80 | " out_size=num_classes,\n", 81 | " num_heads=args['num_heads'],\n", 82 | " dropout=args['dropout'],\n", 83 | " settings = settings).to(device)\n", 84 | "\n", 85 | "stopper = EarlyStopping(patience=args['patience'])\n", 86 | "# stopper.filename = 'atk_result/mid_routdglHan_hyper_'+dataname+'.pth'\n", 87 | "loss_fcn = torch.nn.CrossEntropyLoss()\n", 88 | "optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],\n", 89 | " weight_decay=args['weight_decay'])\n", 90 | "for epoch in range(args['num_epochs']):\n", 91 | " model.train()\n", 92 | " logits = model(g, features)\n", 93 | " loss = loss_fcn(logits[train_mask], labels[train_mask])\n", 94 | " optimizer.zero_grad()\n", 95 | " loss.backward()\n", 96 | " optimizer.step()\n", 97 | " train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])\n", 98 | " val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn)\n", 99 | " early_stop = stopper.step(val_loss.data.item(), val_acc, model)\n", 100 | " print(epoch,\"|\",val_micro_f1, val_macro_f1)\n", 101 | " if early_stop:\n", 102 | " break\n", 103 | " \n", 104 | "#3.test model\n", 105 | "stopper.load_checkpoint(model)\n", 106 | "test_loss, _, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn)\n", 107 | "print(\"@@@@test:\", test_micro_f1, test_macro_f1) " 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# load target node ID:\n", 117 | "tar_idx = []\n", 118 | "for i in range(1): # can attack 500 target nodes by setting range(5)\n", 119 | " with open('data/preprocess/target_nodes/'+dataname+'_r_target' + str(i) + '.pkl', 'rb') as f:\n", 120 | " tar_tmp = np.sort(pkl.load(f))\n", 121 | " tar_idx.extend(tar_tmp) \n", 122 | "\n", 123 | "# evaluate result\n", 124 | "with torch.no_grad():\n", 125 | " logits = model(g, features)\n", 126 | "logits_clean = logits[tar_idx]\n", 127 | "labels_clean = labels[tar_idx]\n", 128 | "_, tar_micro_f1_clean, tar_macro_f1_clean = score(logits_clean, labels_clean)\n", 129 | "print(\"Clean data: Micro-F1:\", tar_micro_f1_clean, \" Macro-F1:\",tar_macro_f1_clean)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "n_perturbation = 1\n", 139 | "adv_filename = 'data/generated_attacks/adv_acm_pap_pa_'+str(n_perturbation)+'.pkl'\n", 140 | "tar_mask = get_binary_mask(train_mask.shape[0], tar_idx)\n", 141 | "micro_f1_list_adv = []\n", 142 | "macro_f1_list_adv = []\n", 143 | "# load adversarial attacks for each 
target node\n", 144 | "with open(adv_filename,'rb') as f:\n", 145 | " modified_opt = pkl.load(f)\n", 146 | "#2.attack\n", 147 | "logits_adv = []\n", 148 | "labels_adv = []\n", 149 | "for items in modified_opt:\n", 150 | " #2.1 init\n", 151 | " target_node = items[0]\n", 152 | " del_list = items[2]\n", 153 | " add_list = items[3]\n", 154 | " if target_node not in tar_idx:\n", 155 | " continue\n", 156 | " #2.2 attack adjs\n", 157 | " mod_hete_adj_dict = copy.deepcopy(hete_adjs)\n", 158 | " for edge in del_list:\n", 159 | " mod_hete_adj_dict['pa'][edge[0],edge[1]] = 0\n", 160 | " mod_hete_adj_dict['ap'][edge[1],edge[0]] = 0\n", 161 | " for edge in add_list:\n", 162 | " mod_hete_adj_dict['pa'][edge[0],edge[1]] = 1\n", 163 | " mod_hete_adj_dict['ap'][edge[1],edge[0]] = 1\n", 164 | " trans_adj_list = get_transition(mod_hete_adj_dict, meta_paths_dict[dataname]) \n", 165 | " for i in range(len(trans_adj_list)):\n", 166 | " model.layers[0].gat_layers[i].settings['device'] = device\n", 167 | " model.layers[0].gat_layers[i].settings['TransM'] = trans_adj_list[i]\n", 168 | " hg_atk = get_hg(dataname, mod_hete_adj_dict).to(device)\n", 169 | " #2.3 run model\n", 170 | " with torch.no_grad():\n", 171 | " logits = model(hg_atk, features)\n", 172 | " #2.4 evaluate\n", 173 | " logits_adv.append(logits[np.array([[target_node]])])\n", 174 | " labels_adv.append(labels[np.array([[target_node]])])\n", 175 | "logits_adv = torch.cat(logits_adv,0)\n", 176 | "labels_adv = torch.cat(labels_adv)\n", 177 | "_, tar_micro_f1_atk, tar_macro_f1_atk = score(logits_adv, labels_adv)\n", 178 | "print(\"Attacked data: Micro-F1:\", tar_micro_f1_atk, \" Macro-F1:\",tar_macro_f1_atk)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "import matplotlib.pyplot as plt\n", 188 | "y_testAcc_name = 'Results of HAN-RoHe (%)'\n", 189 | "plt.figure(figsize=(9, 10))#dpi=xx\n", 190 | "tick_label = ['Mi-F1','Ma-F1']\n", 191 | "Y_clean = [tar_micro_f1_clean*100, tar_macro_f1_clean*100]\n", 192 | "Y_attack = [tar_micro_f1_atk*100, tar_macro_f1_atk*100]\n", 193 | "font_size = 35\n", 194 | "X = np.arange(len(Y_attack))\n", 195 | "plt.ylim(0,120) \n", 196 | "bar_width = 0.2\n", 197 | "for x,y in zip(X,Y_clean):\n", 198 | " plt.text(x+0.05,y+0.005,'%d' %y, ha='center',va='bottom',fontsize=font_size)\n", 199 | "for x,y in zip(X,Y_attack):\n", 200 | " plt.text(x+0.25,y+0.005,'%d' %y, ha='center',va='bottom',fontsize=font_size)\n", 201 | "\n", 202 | "\n", 203 | "clean = plt.bar(X, Y_clean, width=bar_width, color = 'g',edgecolor='black')\n", 204 | "attack = plt.bar([x+0.2 for x in X], Y_attack, width=bar_width, color = 'red',edgecolor='black')\n", 205 | "plt.ylabel(y_testAcc_name,{'family' : 'Times New Roman','weight':'bold', 'size':font_size})\n", 206 | "plt.xticks(X, tick_label, size=font_size+3)\n", 207 | "plt.yticks([0,20,40,60,80,100],fontsize = font_size)\n", 208 | "font_legend={'family' : 'Times New Roman', 'size':font_size}\n", 209 | "plt.legend([clean, attack],[r'clean', r'attack'],ncol=2,loc='upper center',bbox_to_anchor=(0.5,1.025), prop=font_legend)\n", 210 | "plt.savefig(\"attack_HAN_RoHe.png\", bbox_inches='tight',dpi=400,pad_inches=0.0)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python 3", 224 | "language": "python", 225 | "name": "python3" 226 | }, 227 | 
"language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.6.8" 238 | } 239 | }, 240 | "nbformat": 4, 241 | "nbformat_minor": 2 242 | } 243 | -------------------------------------------------------------------------------- /Code/HAN_RoHe/utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import f1_score 2 | import datetime 3 | import dgl 4 | import errno 5 | import numpy as np 6 | import os 7 | import pickle 8 | import random 9 | import torch 10 | 11 | from dgl.data.utils import download, get_download_dir, _get_dgl_url 12 | from pprint import pprint 13 | from scipy import sparse 14 | from scipy import io as sio 15 | 16 | 17 | def score(logits, labels): 18 | _, indices = torch.max(logits, dim=1) 19 | prediction = indices.long().cpu().numpy() 20 | labels = labels.cpu().numpy() 21 | accuracy = (prediction == labels).sum() / len(prediction) 22 | micro_f1 = f1_score(labels, prediction, average='micro') 23 | macro_f1 = f1_score(labels, prediction, average='macro') 24 | return accuracy, micro_f1, macro_f1 25 | 26 | def score_detail(logits, labels): 27 | _, indices = torch.max(logits, dim=1) 28 | prediction = indices.long().cpu().numpy() 29 | labels = labels.cpu().numpy() 30 | acc_detail = np.array(prediction == labels,dtype='int') 31 | accuracy = (prediction == labels).sum() / len(prediction) 32 | micro_f1 = f1_score(labels, prediction, average='micro') 33 | macro_f1 = f1_score(labels, prediction, average='macro') 34 | return accuracy, micro_f1, macro_f1,acc_detail 35 | 36 | def evaluate(model, g, features, labels, mask, loss_func, detail=False): 37 | model.eval() 38 | with torch.no_grad(): 39 | logits = model(g, features) 40 | loss = loss_func(logits[mask], labels[mask]) 41 | if detail: 42 | accuracy, micro_f1, macro_f1, acc_detail = score_detail(logits[mask], labels[mask]) 43 | return acc_detail, accuracy, micro_f1, macro_f1 44 | else: 45 | accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask]) 46 | return loss, accuracy, micro_f1, macro_f1 47 | 48 | 49 | def get_hg(dataname,given_adj_dict): 50 | if dataname == 'acm': 51 | hg_new = dgl.heterograph({ 52 | ('paper', 'pa', 'author'): given_adj_dict['pa'].nonzero(), 53 | ('author', 'ap', 'paper'): given_adj_dict['ap'].nonzero(), 54 | ('paper', 'pf', 'field'): given_adj_dict['pf'].nonzero(), 55 | ('field', 'fp', 'paper'): given_adj_dict['fp'].nonzero(), 56 | }) 57 | if dataname == 'aminer': 58 | hg_new = dgl.heterograph({ 59 | ('paper', 'pa', 'author'): given_adj_dict['pa'].nonzero(), 60 | ('author', 'ap', 'paper'): given_adj_dict['ap'].nonzero(), 61 | ('paper', 'pr', 'ref'): given_adj_dict['pr'].nonzero(), 62 | ('ref', 'rp', 'paper'): given_adj_dict['rp'].nonzero(), 63 | }) 64 | if dataname == 'dblp': 65 | hg_new = dgl.heterograph({ 66 | ('paper', 'pa', 'author'): given_adj_dict['pa'].nonzero(), 67 | ('author', 'ap', 'paper'): given_adj_dict['ap'].nonzero(), 68 | ('paper', 'pc', 'conf'): given_adj_dict['pc'].nonzero(), 69 | ('conf', 'cp', 'paper'): given_adj_dict['cp'].nonzero(), 70 | ('paper', 'pt', 'term'): given_adj_dict['pt'].nonzero(), 71 | ('term', 'tp', 'paper'): given_adj_dict['tp'].nonzero() 72 | }) 73 | if dataname == 'yelp': 74 | hg_new = dgl.heterograph({ 75 | ('b', 'bu', 'u'): given_adj_dict['bu'].nonzero(), 76 | ('u', 'ub', 'b'): 
given_adj_dict['ub'].nonzero(), 77 | ('b', 'bs', 's'): given_adj_dict['bs'].nonzero(), 78 | ('s', 'sb', 'b'): given_adj_dict['sb'].nonzero(), 79 | ('b', 'bl', 'l'): given_adj_dict['bl'].nonzero(), 80 | ('l', 'lb', 'b'): given_adj_dict['lb'].nonzero(), 81 | }) 82 | return hg_new 83 | 84 | 85 | def set_random_seed(seed=0): 86 | """Set random seed. 87 | Parameters 88 | ---------- 89 | seed : int 90 | Random seed to use 91 | """ 92 | random.seed(seed) 93 | np.random.seed(seed) 94 | torch.manual_seed(seed) 95 | if torch.cuda.is_available(): 96 | torch.cuda.manual_seed(seed) 97 | 98 | def mkdir_p(path, log=True): 99 | """Create a directory for the specified path. 100 | Parameters 101 | ---------- 102 | path : str 103 | Path name 104 | log : bool 105 | Whether to print result for directory creation 106 | """ 107 | try: 108 | os.makedirs(path) 109 | if log: 110 | print('Created directory {}'.format(path)) 111 | except OSError as exc: 112 | if exc.errno == errno.EEXIST and os.path.isdir(path) and log: 113 | print('Directory {} already exists.'.format(path)) 114 | else: 115 | raise 116 | 117 | def get_date_postfix(): 118 | """Get a date based postfix for directory name. 119 | Returns 120 | ------- 121 | post_fix : str 122 | """ 123 | dt = datetime.datetime.now() 124 | post_fix = '{}_{:02d}-{:02d}-{:02d}'.format( 125 | dt.date(), dt.hour, dt.minute, dt.second) 126 | 127 | return post_fix 128 | 129 | def setup_log_dir(args, sampling=False): 130 | """Name and create directory for logging. 131 | Parameters 132 | ---------- 133 | args : dict 134 | Configuration 135 | Returns 136 | ------- 137 | log_dir : str 138 | Path for logging directory 139 | sampling : bool 140 | Whether we are using sampling based training 141 | """ 142 | date_postfix = get_date_postfix() 143 | log_dir = os.path.join( 144 | args['log_dir'], 145 | '{}_{}'.format(args['dataset'], date_postfix)) 146 | 147 | if sampling: 148 | log_dir = log_dir + '_sampling' 149 | 150 | mkdir_p(log_dir) 151 | return log_dir 152 | 153 | # The configuration below is from the paper. 
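# For orientation, a hedged sketch of how these defaults are consumed (argument values from the notebooks):
#   args = setup({'seed': 2, 'hetero': True, 'log_dir': 'results'})
# setup() merges default_configure below into args, seeds all RNGs via set_random_seed,
# picks 'cuda:0' when available, and creates a timestamped log directory.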
154 | default_configure = { 155 | 'lr': 0.005, # Learning rate 156 | 'num_heads': [8], # Number of attention heads for node-level attention 157 | 'hidden_units': 8, 158 | 'dropout': 0.6, 159 | 'weight_decay': 0.001, 160 | 'num_epochs': 200, 161 | 'patience': 100 162 | } 163 | 164 | sampling_configure = { 165 | 'batch_size': 20 166 | } 167 | 168 | def setup(args): 169 | args.update(default_configure) 170 | set_random_seed(args['seed']) 171 | args['dataset'] = 'ACMRaw' if args['hetero'] else 'ACM' 172 | args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu' 173 | args['log_dir'] = setup_log_dir(args) 174 | return args 175 | 176 | def setup_for_sampling(args): 177 | args.update(default_configure) 178 | args.update(sampling_configure) 179 | set_random_seed() 180 | args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu' 181 | args['log_dir'] = setup_log_dir(args, sampling=True) 182 | return args 183 | 184 | def get_binary_mask(total_size, indices): 185 | mask = torch.zeros(total_size) 186 | mask[indices] = 1 187 | return mask.byte() 188 | 189 | def load_acm(remove_self_loop): 190 | url = 'dataset/ACM3025.pkl' 191 | data_path = get_download_dir() + '/ACM3025.pkl' 192 | download(_get_dgl_url(url), path=data_path) 193 | 194 | with open(data_path, 'rb') as f: 195 | data = pickle.load(f) 196 | 197 | labels, features = torch.from_numpy(data['label'].todense()).long(), \ 198 | torch.from_numpy(data['feature'].todense()).float() 199 | num_classes = labels.shape[1] 200 | labels = labels.nonzero()[:, 1] 201 | 202 | if remove_self_loop: 203 | num_nodes = data['label'].shape[0] 204 | data['PAP'] = sparse.csr_matrix(data['PAP'] - np.eye(num_nodes)) 205 | data['PLP'] = sparse.csr_matrix(data['PLP'] - np.eye(num_nodes)) 206 | 207 | # Adjacency matrices for meta path based neighbors 208 | # (Mufei): I verified both of them are binary adjacency matrices with self loops 209 | author_g = dgl.from_scipy(data['PAP']) 210 | subject_g = dgl.from_scipy(data['PLP']) 211 | gs = [author_g, subject_g] 212 | 213 | train_idx = torch.from_numpy(data['train_idx']).long().squeeze(0) 214 | val_idx = torch.from_numpy(data['val_idx']).long().squeeze(0) 215 | test_idx = torch.from_numpy(data['test_idx']).long().squeeze(0) 216 | 217 | num_nodes = author_g.number_of_nodes() 218 | train_mask = get_binary_mask(num_nodes, train_idx) 219 | val_mask = get_binary_mask(num_nodes, val_idx) 220 | test_mask = get_binary_mask(num_nodes, test_idx) 221 | 222 | print('dataset loaded') 223 | pprint({ 224 | 'dataset': 'ACM', 225 | 'train': train_mask.sum().item() / num_nodes, 226 | 'val': val_mask.sum().item() / num_nodes, 227 | 'test': test_mask.sum().item() / num_nodes 228 | }) 229 | 230 | return gs, features, labels, num_classes, train_idx, val_idx, test_idx, \ 231 | train_mask, val_mask, test_mask 232 | 233 | # def load_acm_raw(remove_self_loop): 234 | # assert not remove_self_loop 235 | # url = 'dataset/ACM.mat' 236 | # data_path = get_download_dir() + '/ACM.mat' 237 | # download(_get_dgl_url(url), path=data_path) 238 | 239 | # data = sio.loadmat(data_path) 240 | # p_vs_l = data['PvsL'] # paper-field? 
241 | # p_vs_a = data['PvsA'] # paper-author 242 | # p_vs_t = data['PvsT'] # paper-term, bag of words 243 | # p_vs_c = data['PvsC'] # paper-conference, labels come from that 244 | 245 | # # We assign 246 | # # (1) KDD papers as class 0 (data mining), 247 | # # (2) SIGMOD and VLDB papers as class 1 (database), 248 | # # (3) SIGCOMM and MOBICOMM papers as class 2 (communication) 249 | # conf_ids = [0, 1, 9, 10, 13] 250 | # label_ids = [0, 1, 2, 2, 1] 251 | 252 | # p_vs_c_filter = p_vs_c[:, conf_ids] 253 | # p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0] 254 | # p_vs_l = p_vs_l[p_selected] 255 | # p_vs_a = p_vs_a[p_selected] 256 | # p_vs_t = p_vs_t[p_selected] 257 | # p_vs_c = p_vs_c[p_selected] 258 | 259 | # hg = dgl.heterograph({ 260 | # ('paper', 'pa', 'author'): p_vs_a.nonzero(), 261 | # ('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(), 262 | # ('paper', 'pf', 'field'): p_vs_l.nonzero(), 263 | # ('field', 'fp', 'paper'): p_vs_l.transpose().nonzero() 264 | # }) 265 | 266 | # features = torch.FloatTensor(p_vs_t.toarray()) 267 | 268 | # pc_p, pc_c = p_vs_c.nonzero() 269 | # labels = np.zeros(len(p_selected), dtype=np.int64) 270 | # for conf_id, label_id in zip(conf_ids, label_ids): 271 | # labels[pc_p[pc_c == conf_id]] = label_id 272 | # labels = torch.LongTensor(labels) 273 | 274 | # num_classes = 3 275 | 276 | # float_mask = np.zeros(len(pc_p)) 277 | # for conf_id in conf_ids: 278 | # pc_c_mask = (pc_c == conf_id) 279 | # float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum())) 280 | # train_idx = np.where(float_mask <= 0.2)[0] 281 | # val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0] 282 | # test_idx = np.where(float_mask > 0.3)[0] 283 | 284 | # num_nodes = hg.number_of_nodes('paper') 285 | # train_mask = get_binary_mask(num_nodes, train_idx) 286 | # val_mask = get_binary_mask(num_nodes, val_idx) 287 | # test_mask = get_binary_mask(num_nodes, test_idx) 288 | 289 | # return hg, features, labels, num_classes, train_idx, val_idx, test_idx, \ 290 | # train_mask, val_mask, test_mask 291 | 292 | 293 | def load_acm_raw(remove_self_loop): 294 | assert not remove_self_loop 295 | set_random_seed(1) 296 | url = 'dataset/ACM.mat' 297 | data_path = get_download_dir() + '/ACM.mat' 298 | 299 | data = sio.loadmat(data_path) 300 | p_vs_f = data['PvsL'] 301 | p_vs_a = data['PvsA'] 302 | p_vs_t = data['PvsT'] 303 | p_vs_c = data['PvsC'] 304 | 305 | # We assign 306 | # (1) KDD papers as class 0 (data mining), 307 | # (2) SIGMOD and VLDB papers as class 1 (database), 308 | # (3) SIGCOMM and MOBICOMM papers as class 2 (communication) 309 | conf_ids = [0, 1, 9, 10, 13] 310 | label_ids = [0, 1, 2, 2, 1] 311 | 312 | p_vs_c_filter = p_vs_c[:, conf_ids] 313 | p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0] 314 | p_vs_f = p_vs_f[p_selected] 315 | p_vs_a = p_vs_a[p_selected] 316 | p_vs_t = p_vs_t[p_selected] 317 | p_vs_c = p_vs_c[p_selected]#CSC 318 | hete_adjs = {'pa':p_vs_a, 'ap':p_vs_a.T,\ 319 | 'pf':p_vs_f, 'fp':p_vs_f.T} 320 | 321 | hg = dgl.heterograph({ 322 | ('paper', 'pa', 'author'): p_vs_a.nonzero(), 323 | ('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(), 324 | ('paper', 'pf', 'field'): p_vs_f.nonzero(), 325 | ('field', 'fp', 'paper'): p_vs_f.transpose().nonzero(), 326 | }) 327 | 328 | features = torch.FloatTensor(p_vs_t.toarray()) 329 | 330 | pc_p, pc_c = p_vs_c.nonzero() 331 | labels = np.zeros(len(p_selected), dtype=np.int64) 332 | for conf_id, label_id in zip(conf_ids, label_ids): 333 | labels[pc_p[pc_c == conf_id]] = label_id 334 | 
labels = torch.LongTensor(labels) 335 | 336 | num_classes = 3 337 | 338 | float_mask = np.zeros(len(pc_p)) 339 | for conf_id in conf_ids: 340 | pc_c_mask = (pc_c == conf_id) 341 | float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum())) 342 | train_idx = np.where(float_mask <= 0.2)[0] 343 | val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0] 344 | test_idx = np.where(float_mask > 0.3)[0] 345 | 346 | num_nodes = hg.number_of_nodes('paper') 347 | train_mask = get_binary_mask(num_nodes, train_idx) 348 | val_mask = get_binary_mask(num_nodes, val_idx) 349 | test_mask = get_binary_mask(num_nodes, test_idx) 350 | 351 | return hg, hete_adjs, features, labels, num_classes, train_idx, val_idx, test_idx, \ 352 | train_mask, val_mask, test_mask 353 | 354 | def load_data(dataset, remove_self_loop=False): 355 | if dataset == 'ACM': 356 | return load_acm(remove_self_loop) 357 | elif dataset == 'ACMRaw': 358 | return load_acm_raw(remove_self_loop) 359 | else: 360 | raise NotImplementedError('Unsupported dataset {}'.format(dataset)) 361 | 362 | class EarlyStopping(object): 363 | def __init__(self, patience=10): 364 | dt = datetime.datetime.now() 365 | self.filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format( 366 | dt.date(), dt.hour, dt.minute, dt.second) 367 | self.patience = patience 368 | self.counter = 0 369 | self.best_acc = None 370 | self.best_loss = None 371 | self.early_stop = False 372 | 373 | def step(self, loss, acc, model): 374 | if self.best_loss is None: 375 | self.best_acc = acc 376 | self.best_loss = loss 377 | self.save_checkpoint(model) 378 | elif (loss > self.best_loss) and (acc < self.best_acc): 379 | self.counter += 1 380 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 381 | if self.counter >= self.patience: 382 | self.early_stop = True 383 | else: 384 | if (loss <= self.best_loss) and (acc >= self.best_acc): 385 | self.save_checkpoint(model) 386 | self.best_loss = np.min((loss, self.best_loss)) 387 | self.best_acc = np.max((acc, self.best_acc)) 388 | self.counter = 0 389 | return self.early_stop 390 | 391 | def save_checkpoint(self, model): 392 | """Saves model when validation loss decreases.""" 393 | torch.save(model.state_dict(), self.filename) 394 | 395 | def load_checkpoint(self, model): 396 | """Load the latest checkpoint.""" 397 | model.load_state_dict(torch.load(self.filename)) -------------------------------------------------------------------------------- /Code/HAN/utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import f1_score 2 | import datetime 3 | import dgl 4 | import errno 5 | import numpy as np 6 | import os 7 | import pickle 8 | import random 9 | import torch 10 | 11 | from dgl.data.utils import download, get_download_dir, _get_dgl_url 12 | from pprint import pprint 13 | from scipy import sparse 14 | from scipy import io as sio 15 | 16 | 17 | 18 | def score(logits, labels): 19 | _, indices = torch.max(logits, dim=1) 20 | prediction = indices.long().cpu().numpy() 21 | labels = labels.cpu().numpy() 22 | accuracy = (prediction == labels).sum() / len(prediction) 23 | micro_f1 = f1_score(labels, prediction, average='micro') 24 | macro_f1 = f1_score(labels, prediction, average='macro') 25 | return accuracy, micro_f1, macro_f1 26 | 27 | def score_detail(logits, labels): 28 | _, indices = torch.max(logits, dim=1) 29 | prediction = indices.long().cpu().numpy() 30 | labels = labels.cpu().numpy() 31 | acc_detail = np.array(prediction == 
--------------------------------------------------------------------------------
/Code/HAN/utils.py:
--------------------------------------------------------------------------------
from sklearn.metrics import f1_score
import datetime
import dgl
import errno
import numpy as np
import os
import pickle
import random
import torch

from dgl.data.utils import download, get_download_dir, _get_dgl_url
from pprint import pprint
from scipy import sparse
from scipy import io as sio


def score(logits, labels):
    _, indices = torch.max(logits, dim=1)
    prediction = indices.long().cpu().numpy()
    labels = labels.cpu().numpy()
    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average='micro')
    macro_f1 = f1_score(labels, prediction, average='macro')
    return accuracy, micro_f1, macro_f1

def score_detail(logits, labels):
    _, indices = torch.max(logits, dim=1)
    prediction = indices.long().cpu().numpy()
    labels = labels.cpu().numpy()
    # Per-node 0/1 correctness vector, in addition to the aggregate metrics.
    acc_detail = np.array(prediction == labels, dtype='int')
    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average='micro')
    macro_f1 = f1_score(labels, prediction, average='macro')
    return accuracy, micro_f1, macro_f1, acc_detail

def evaluate(model, g, features, labels, mask, loss_func, detail=False):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
    loss = loss_func(logits[mask], labels[mask])
    # Note: the detail branch returns the per-node correctness vector instead
    # of the loss, so the two branches have different return signatures.
    if detail:
        accuracy, micro_f1, macro_f1, acc_detail = score_detail(logits[mask], labels[mask])
        return acc_detail, accuracy, micro_f1, macro_f1
    else:
        accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])
        return loss, accuracy, micro_f1, macro_f1


def get_hg(dataname, given_adj_dict):
    if dataname == 'acm':
        hg_new = dgl.heterograph({
            ('paper', 'pa', 'author'): given_adj_dict['pa'].nonzero(),
            ('author', 'ap', 'paper'): given_adj_dict['ap'].nonzero(),
            ('paper', 'pf', 'field'): given_adj_dict['pf'].nonzero(),
            ('field', 'fp', 'paper'): given_adj_dict['fp'].nonzero(),
        })
    elif dataname == 'aminer':
        hg_new = dgl.heterograph({
            ('paper', 'pa', 'author'): given_adj_dict['pa'].nonzero(),
            ('author', 'ap', 'paper'): given_adj_dict['ap'].nonzero(),
            ('paper', 'pr', 'ref'): given_adj_dict['pr'].nonzero(),
            ('ref', 'rp', 'paper'): given_adj_dict['rp'].nonzero(),
        })
    elif dataname == 'dblp':
        hg_new = dgl.heterograph({
            ('paper', 'pa', 'author'): given_adj_dict['pa'].nonzero(),
            ('author', 'ap', 'paper'): given_adj_dict['ap'].nonzero(),
            ('paper', 'pc', 'conf'): given_adj_dict['pc'].nonzero(),
            ('conf', 'cp', 'paper'): given_adj_dict['cp'].nonzero(),
            ('paper', 'pt', 'term'): given_adj_dict['pt'].nonzero(),
            ('term', 'tp', 'paper'): given_adj_dict['tp'].nonzero()
        })
    elif dataname == 'yelp':
        hg_new = dgl.heterograph({
            ('b', 'bu', 'u'): given_adj_dict['bu'].nonzero(),
            ('u', 'ub', 'b'): given_adj_dict['ub'].nonzero(),
            ('b', 'bs', 's'): given_adj_dict['bs'].nonzero(),
            ('s', 'sb', 'b'): given_adj_dict['sb'].nonzero(),
            ('b', 'bl', 'l'): given_adj_dict['bl'].nonzero(),
            ('l', 'lb', 'b'): given_adj_dict['lb'].nonzero(),
        })
    else:
        # Fail loudly instead of leaving hg_new unbound for unknown datasets.
        raise ValueError('Unknown dataset name {}'.format(dataname))
    return hg_new
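
# --- Hypothetical usage sketch, not part of the original file --------------
# get_hg lets the attack notebooks rebuild a DGL heterograph after the
# relation matrices have been perturbed. The pickle layout below is an
# assumption made for illustration only.
#
#     import pickle
#     with open('data/generated_attacks/adv_acm_pap_pa_1.pkl', 'rb') as f:
#         adv = pickle.load(f)        # assumed: adversarial 'pa' perturbations
#     adj_dict = dict(hete_adjs)      # relation matrices from load_acm_raw
#     # ... apply the perturbations to adj_dict['pa'] and adj_dict['ap'] ...
#     hg_attacked = get_hg('acm', adj_dict)
# ----------------------------------------------------------------------------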

def set_random_seed(seed=0):
    """Set random seed.

    Parameters
    ----------
    seed : int
        Random seed to use
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def mkdir_p(path, log=True):
    """Create a directory for the specified path.

    Parameters
    ----------
    path : str
        Path name
    log : bool
        Whether to print result for directory creation
    """
    try:
        os.makedirs(path)
        if log:
            print('Created directory {}'.format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
            print('Directory {} already exists.'.format(path))
        else:
            raise

def get_date_postfix():
    """Get a date based postfix for directory name.

    Returns
    -------
    post_fix : str
    """
    dt = datetime.datetime.now()
    post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(
        dt.date(), dt.hour, dt.minute, dt.second)

    return post_fix

def setup_log_dir(args, sampling=False):
    """Name and create directory for logging.

    Parameters
    ----------
    args : dict
        Configuration
    sampling : bool
        Whether we are using sampling based training

    Returns
    -------
    log_dir : str
        Path for logging directory
    """
    date_postfix = get_date_postfix()
    log_dir = os.path.join(
        args['log_dir'],
        '{}_{}'.format(args['dataset'], date_postfix))

    if sampling:
        log_dir = log_dir + '_sampling'

    mkdir_p(log_dir)
    return log_dir

# The configuration below is from the paper.
default_configure = {
    'lr': 0.005,          # Learning rate
    'num_heads': [8],     # Number of attention heads for node-level attention
    'hidden_units': 8,
    'dropout': 0.6,
    'weight_decay': 0.001,
    'num_epochs': 200,
    'patience': 100
}

sampling_configure = {
    'batch_size': 20
}

def setup(args):
    args.update(default_configure)
    set_random_seed(args['seed'])
    args['dataset'] = 'ACMRaw' if args['hetero'] else 'ACM'
    args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    args['log_dir'] = setup_log_dir(args)
    return args

def setup_for_sampling(args):
    args.update(default_configure)
    args.update(sampling_configure)
    set_random_seed()
    args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    args['log_dir'] = setup_log_dir(args, sampling=True)
    return args

def get_binary_mask(total_size, indices):
    # Dense 0/1 mask over all nodes, with ones at the given indices.
    mask = torch.zeros(total_size)
    mask[indices] = 1
    return mask.byte()
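
# --- Minimal sketch, not part of the original file -------------------------
# Round-trip behaviour of get_binary_mask; the values below are hypothetical:
#
#     >>> get_binary_mask(5, [0, 3])
#     tensor([1, 0, 0, 1, 0], dtype=torch.uint8)
# ----------------------------------------------------------------------------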

def load_acm(remove_self_loop):
    url = 'dataset/ACM3025.pkl'
    data_path = get_download_dir() + '/ACM3025.pkl'
    download(_get_dgl_url(url), path=data_path)

    with open(data_path, 'rb') as f:
        data = pickle.load(f)

    labels, features = torch.from_numpy(data['label'].todense()).long(), \
                       torch.from_numpy(data['feature'].todense()).float()
    num_classes = labels.shape[1]
    labels = labels.nonzero()[:, 1]

    if remove_self_loop:
        num_nodes = data['label'].shape[0]
        data['PAP'] = sparse.csr_matrix(data['PAP'] - np.eye(num_nodes))
        data['PLP'] = sparse.csr_matrix(data['PLP'] - np.eye(num_nodes))

    # Adjacency matrices for meta path based neighbors
    # (Mufei): I verified both of them are binary adjacency matrices with self loops
    author_g = dgl.from_scipy(data['PAP'])
    subject_g = dgl.from_scipy(data['PLP'])
    gs = [author_g, subject_g]

    train_idx = torch.from_numpy(data['train_idx']).long().squeeze(0)
    val_idx = torch.from_numpy(data['val_idx']).long().squeeze(0)
    test_idx = torch.from_numpy(data['test_idx']).long().squeeze(0)

    num_nodes = author_g.number_of_nodes()
    train_mask = get_binary_mask(num_nodes, train_idx)
    val_mask = get_binary_mask(num_nodes, val_idx)
    test_mask = get_binary_mask(num_nodes, test_idx)

    print('dataset loaded')
    pprint({
        'dataset': 'ACM',
        'train': train_mask.sum().item() / num_nodes,
        'val': val_mask.sum().item() / num_nodes,
        'test': test_mask.sum().item() / num_nodes
    })

    return gs, features, labels, num_classes, train_idx, val_idx, test_idx, \
           train_mask, val_mask, test_mask

# def load_acm_raw(remove_self_loop):
#     assert not remove_self_loop
#     url = 'dataset/ACM.mat'
#     data_path = get_download_dir() + '/ACM.mat'
#     download(_get_dgl_url(url), path=data_path)

#     data = sio.loadmat(data_path)
#     p_vs_l = data['PvsL']   # paper-field?
#     p_vs_a = data['PvsA']   # paper-author
#     p_vs_t = data['PvsT']   # paper-term, bag of words
#     p_vs_c = data['PvsC']   # paper-conference, labels come from that

#     # We assign
#     # (1) KDD papers as class 0 (data mining),
#     # (2) SIGMOD and VLDB papers as class 1 (database),
#     # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
#     conf_ids = [0, 1, 9, 10, 13]
#     label_ids = [0, 1, 2, 2, 1]

#     p_vs_c_filter = p_vs_c[:, conf_ids]
#     p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0]
#     p_vs_l = p_vs_l[p_selected]
#     p_vs_a = p_vs_a[p_selected]
#     p_vs_t = p_vs_t[p_selected]
#     p_vs_c = p_vs_c[p_selected]

#     hg = dgl.heterograph({
#         ('paper', 'pa', 'author'): p_vs_a.nonzero(),
#         ('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(),
#         ('paper', 'pf', 'field'): p_vs_l.nonzero(),
#         ('field', 'fp', 'paper'): p_vs_l.transpose().nonzero()
#     })

#     features = torch.FloatTensor(p_vs_t.toarray())

#     pc_p, pc_c = p_vs_c.nonzero()
#     labels = np.zeros(len(p_selected), dtype=np.int64)
#     for conf_id, label_id in zip(conf_ids, label_ids):
#         labels[pc_p[pc_c == conf_id]] = label_id
#     labels = torch.LongTensor(labels)

#     num_classes = 3

#     float_mask = np.zeros(len(pc_p))
#     for conf_id in conf_ids:
#         pc_c_mask = (pc_c == conf_id)
#         float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
#     train_idx = np.where(float_mask <= 0.2)[0]
#     val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
#     test_idx = np.where(float_mask > 0.3)[0]

#     num_nodes = hg.number_of_nodes('paper')
#     train_mask = get_binary_mask(num_nodes, train_idx)
#     val_mask = get_binary_mask(num_nodes, val_idx)
#     test_mask = get_binary_mask(num_nodes, test_idx)

#     return hg, features, labels, num_classes, train_idx, val_idx, test_idx, \
#            train_mask, val_mask, test_mask


def load_acm_raw(remove_self_loop):
    assert not remove_self_loop
    set_random_seed(1)
    url = 'dataset/ACM.mat'
    data_path = get_download_dir() + '/ACM.mat'
    # Fetch ACM.mat on first use (dgl's download skips files that already exist).
    download(_get_dgl_url(url), path=data_path)

    data = sio.loadmat(data_path)
    p_vs_f = data['PvsL']   # paper-field
    p_vs_a = data['PvsA']   # paper-author
    p_vs_t = data['PvsT']   # paper-term, bag of words
    p_vs_c = data['PvsC']   # paper-conference, labels come from that

    # We assign
    # (1) KDD papers as class 0 (data mining),
    # (2) SIGMOD and VLDB papers as class 1 (database),
    # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
    conf_ids = [0, 1, 9, 10, 13]
    label_ids = [0, 1, 2, 2, 1]

    p_vs_c_filter = p_vs_c[:, conf_ids]
    p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0]
    p_vs_f = p_vs_f[p_selected]
    p_vs_a = p_vs_a[p_selected]
    p_vs_t = p_vs_t[p_selected]
    p_vs_c = p_vs_c[p_selected]  # CSC format
    # Keep the raw relation matrices so callers can rebuild (or perturb) the graph.
    hete_adjs = {'pa': p_vs_a, 'ap': p_vs_a.T,
                 'pf': p_vs_f, 'fp': p_vs_f.T}

    hg = dgl.heterograph({
        ('paper', 'pa', 'author'): p_vs_a.nonzero(),
        ('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(),
        ('paper', 'pf', 'field'): p_vs_f.nonzero(),
        ('field', 'fp', 'paper'): p_vs_f.transpose().nonzero(),
    })

    features = torch.FloatTensor(p_vs_t.toarray())

    pc_p, pc_c = p_vs_c.nonzero()
    labels = np.zeros(len(p_selected), dtype=np.int64)
    for conf_id, label_id in zip(conf_ids, label_ids):
        labels[pc_p[pc_c == conf_id]] = label_id
    labels = torch.LongTensor(labels)

    num_classes = 3

    # Stratified per-conference split: roughly 20% train, 10% val, 70% test.
    float_mask = np.zeros(len(pc_p))
    for conf_id in conf_ids:
        pc_c_mask = (pc_c == conf_id)
        float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
    train_idx = np.where(float_mask <= 0.2)[0]
    val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
    test_idx = np.where(float_mask > 0.3)[0]

    num_nodes = hg.number_of_nodes('paper')
    train_mask = get_binary_mask(num_nodes, train_idx)
    val_mask = get_binary_mask(num_nodes, val_idx)
    test_mask = get_binary_mask(num_nodes, test_idx)

    return hg, hete_adjs, features, labels, num_classes, train_idx, val_idx, test_idx, \
           train_mask, val_mask, test_mask
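
# --- Worked micro-example, not part of the original file -------------------
# The float_mask split above is stratified per conference: papers of each
# conference receive the evenly spaced values of np.linspace(0, 1, n) in a
# random order, so the 20/10/70 train/val/test ratio holds (up to rounding)
# within every conference, while which paper lands where stays random.
#
#     vals = np.random.permutation(np.linspace(0, 1, 10))
#     ((vals <= 0.2).sum(),
#      ((vals > 0.2) & (vals <= 0.3)).sum(),
#      (vals > 0.3).sum())
#     # -> (2, 1, 7) for any permutation of 10 papers
# ----------------------------------------------------------------------------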

def load_data(dataset, remove_self_loop=False):
    if dataset == 'ACM':
        return load_acm(remove_self_loop)
    elif dataset == 'ACMRaw':
        return load_acm_raw(remove_self_loop)
    else:
        # Raise (not return) so unsupported datasets fail loudly.
        raise NotImplementedError('Unsupported dataset {}'.format(dataset))


class EarlyStopping(object):
    def __init__(self, patience=10):
        dt = datetime.datetime.now()
        self.filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
            dt.date(), dt.hour, dt.minute, dt.second)
        self.patience = patience
        self.counter = 0
        self.best_acc = None
        self.best_loss = None
        self.early_stop = False

    def step(self, loss, acc, model):
        if self.best_loss is None:
            self.best_acc = acc
            self.best_loss = loss
            self.save_checkpoint(model)
        elif (loss > self.best_loss) and (acc < self.best_acc):
            # Both validation metrics got worse: one step closer to stopping.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            if (loss <= self.best_loss) and (acc >= self.best_acc):
                self.save_checkpoint(model)
            self.best_loss = np.min((loss, self.best_loss))
            self.best_acc = np.max((acc, self.best_acc))
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        """Saves model when validation loss decreases."""
        torch.save(model.state_dict(), self.filename)

    def load_checkpoint(self, model):
        """Load the latest checkpoint."""
        model.load_state_dict(torch.load(self.filename))
--------------------------------------------------------------------------------