├── misc ├── GIB-RGIB.png ├── digram-RGIB.png ├── digram-RGIB-REP.png ├── digram-RGIB-SSL.png └── bilateral-edge-noise.png ├── LICENSE ├── code ├── utils.py ├── models.py ├── standard-training.py ├── mixed-noise-generation.py ├── loss.py ├── RGIB-ssl-training.py └── RGIB-rep-training.py └── README.md /misc/GIB-RGIB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndrewZhou924/RGIB/HEAD/misc/GIB-RGIB.png -------------------------------------------------------------------------------- /misc/digram-RGIB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndrewZhou924/RGIB/HEAD/misc/digram-RGIB.png -------------------------------------------------------------------------------- /misc/digram-RGIB-REP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndrewZhou924/RGIB/HEAD/misc/digram-RGIB-REP.png -------------------------------------------------------------------------------- /misc/digram-RGIB-SSL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndrewZhou924/RGIB/HEAD/misc/digram-RGIB-SSL.png -------------------------------------------------------------------------------- /misc/bilateral-edge-noise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndrewZhou924/RGIB/HEAD/misc/bilateral-edge-noise.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zhanke Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | import torch_geometric.transforms as T 5 | from torch_geometric.datasets import Planetoid, WikipediaNetwork, AttributedGraphDataset 6 | import random 7 | from tqdm import tqdm 8 | import models 9 | import scipy.stats 10 | import copy 11 | import os 12 | 13 | def parser_add_main_args(parser): 14 | parser.add_argument('--dataset', type=str, default='Cora') 15 | parser.add_argument('--rel_path', type=str, default='./data') 16 | parser.add_argument('--repeat_times', type=int, default=5) 17 | parser.add_argument('--noise_type', type=str, default='mixed') 18 | parser.add_argument('--gnn_model', type=str, default='GCN') 19 | parser.add_argument('--num_gnn_layers', type=int, default=4) 20 | parser.add_argument('--noise_ratio', type=float, default=0.4) 21 | parser.add_argument('--scheduler', type=str, default='linear') 22 | parser.add_argument('--scheduler_param', type=float, default=1.0) 23 | parser.add_argument('--search_scheduler', action='store_true') 24 | parser.add_argument('--search_iteration', type=int, default=0) 25 | return 26 | 27 | def checkPath(path): 28 | if not os.path.exists(path): 29 | os.mkdir(path) 30 | return 31 | 32 | def getDataset(dataset_name, device, rel_path='./data'): 33 | assert dataset_name in ['Cora','Citeseer','Pubmed','chameleon','squirrel','facebook'] 34 | transform = T.Compose([ 35 | T.NormalizeFeatures(), 36 | T.ToDevice(device), 37 | T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, 38 | add_negative_train_samples=False),]) 39 | if dataset_name in ['Cora', 'Citeseer', 'Pubmed']: 40 | path = osp.join(rel_path, 'Planetoid') 41 | dataset = Planetoid(path, name=dataset_name, transform=transform) 42 | elif dataset_name in ['chameleon', 'squirrel']: 43 | path = osp.join(rel_path, 'WikipediaNetwork') 44 | dataset = WikipediaNetwork(path, name=dataset_name, transform=transform) 45 | elif dataset_name in ["facebook"]: 46 | path = osp.join(rel_path, 'AttributedGraphDataset') 47 | dataset = AttributedGraphDataset(path, name=dataset_name, transform=transform) 48 | else: 49 | exit() 50 | return path, dataset 51 | 52 | def getGNNArch(GNN_name): 53 | assert GNN_name in ['GCN', 'GAT', 'SAGE', 'MLP'] 54 | if GNN_name == 'GCN': 55 | return models.GCN 56 | elif GNN_name == 'GAT': 57 | return models.GAT 58 | elif GNN_name == 'SAGE': 59 | return models.SAGE 60 | elif GNN_name == 'MLP': 61 | return models.MLP 62 | 63 | def jensen_shannon_distance(p, q): 64 | """ 65 | method to compute the Jenson-Shannon Distance 66 | between two probability distributions 67 | """ 68 | # convert the vectors into numpy arrays in case that they aren't 69 | p = np.array(p) 70 | q = np.array(q) 71 | # calculate m 72 | m = (p + q) / 2 73 | # compute Jensen Shannon Divergence 74 | divergence = (scipy.stats.entropy(p, m) + scipy.stats.entropy(q, m)) / 2 75 | # compute the Jensen Shannon Distance 76 | distance = np.sqrt(divergence) 77 | return distance 78 | 79 | def calculateDistSim(res, savePath=None): 80 | r_edge, r_node, label, predict = res 81 | label = label.int().tolist() 82 | cos = torch.nn.CosineSimilarity(dim=0) 83 | pos_sim, neg_sim = [], [] 84 | for idx in range(r_node[0].shape[0]): 85 | label_idx = label[idx] 86 | sim = float(cos(r_node[0][idx], r_node[1][idx])) 87 | if label_idx == 1: 88 | pos_sim.append(sim+1) 89 | else: 90 | 
neg_sim.append(sim+1) 91 | js_dis = jensen_shannon_distance(pos_sim, neg_sim) 92 | ks_dis = scipy.stats.kstest(pos_sim, neg_sim).statistic 93 | kl_dis = np.mean(scipy.special.kl_div(sorted(pos_sim), sorted(neg_sim))) 94 | return [np.mean(pos_sim), np.mean(neg_sim), ks_dis] -------------------------------------------------------------------------------- /code/models.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | from torch.nn import ModuleList 5 | import torch.nn.functional as F 6 | from torch_geometric.nn import GCNConv, SAGEConv, GATConv 7 | from torch_geometric.utils import negative_sampling 8 | import random 9 | 10 | class GCN(torch.nn.Module): 11 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers): 12 | super().__init__() 13 | self.convs = ModuleList() 14 | self.convs.append(GCNConv(in_channels, hidden_channels)) 15 | for i in range(0, num_layers-2): 16 | self.convs.append(GCNConv(hidden_channels, hidden_channels)) 17 | self.convs.append(GCNConv(hidden_channels, out_channels)) 18 | 19 | def encode(self, x, edge_index): 20 | for conv in self.convs[:-1]: 21 | x = conv(x, edge_index).relu() 22 | x = self.convs[-1](x, edge_index) 23 | return x 24 | 25 | def decode(self, z, edge_label_index): 26 | hidden = z[edge_label_index[0]] * z[edge_label_index[1]] 27 | logits = (hidden).sum(dim=-1) 28 | hidden = F.normalize(hidden, dim=1) 29 | return hidden, logits 30 | 31 | def decode_all(self, z): 32 | prob_adj = z @ z.t() 33 | return (prob_adj > 0).nonzero(as_tuple=False).t() 34 | 35 | class SAGE(torch.nn.Module): 36 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers): 37 | super().__init__() 38 | self.convs = ModuleList() 39 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 40 | for i in range(0, num_layers-2): 41 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 42 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 43 | 44 | def encode(self, x, edge_index): 45 | for conv in self.convs[:-1]: 46 | x = conv(x, edge_index).relu() 47 | x = self.convs[-1](x, edge_index) 48 | return x 49 | 50 | def decode(self, z, edge_label_index): 51 | hidden = z[edge_label_index[0]] * z[edge_label_index[1]] 52 | logits = (hidden).sum(dim=-1) 53 | hidden = F.normalize(hidden, dim=1) 54 | return hidden, logits 55 | 56 | def decode_all(self, z): 57 | prob_adj = z @ z.t() 58 | return (prob_adj > 0).nonzero(as_tuple=False).t() 59 | 60 | class GAT(torch.nn.Module): 61 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, heads=8, att_dropout=0): 62 | super().__init__() 63 | self.convs = ModuleList() 64 | self.convs.append(GATConv(in_channels, hidden_channels//heads, heads=heads, dropout=att_dropout)) 65 | for i in range(0, num_layers-2): 66 | self.convs.append(GATConv(hidden_channels, hidden_channels//heads, heads=heads, dropout=att_dropout)) 67 | self.convs.append(GATConv(hidden_channels, out_channels, dropout=att_dropout)) 68 | 69 | def encode(self, x, edge_index): 70 | for conv in self.convs[:-1]: 71 | x = conv(x, edge_index).relu() 72 | x = self.convs[-1](x, edge_index) 73 | return x 74 | 75 | def decode(self, z, edge_label_index): 76 | hidden = z[edge_label_index[0]] * z[edge_label_index[1]] 77 | logits = (hidden).sum(dim=-1) 78 | hidden = F.normalize(hidden, dim=1) 79 | return hidden, logits 80 | 81 | def decode_all(self, z): 82 | prob_adj = z @ z.t() 83 | return (prob_adj > 0).nonzero(as_tuple=False).t() 
84 | 85 | class MLP(torch.nn.Module): 86 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers): 87 | super().__init__() 88 | self.layers = ModuleList() 89 | self.layers.append(torch.nn.Linear(in_channels, hidden_channels)) 90 | for i in range(0, num_layers-2): 91 | self.layers.append(torch.nn.Linear(hidden_channels, hidden_channels)) 92 | self.layers.append(torch.nn.Linear(hidden_channels, out_channels)) 93 | 94 | def encode(self, x, edge_index): 95 | for fc in self.layers[:-1]: 96 | x = fc(x).relu() 97 | x = self.layers[-1](x) 98 | return x 99 | 100 | def decode(self, z, edge_label_index): 101 | hidden = z[edge_label_index[0]] * z[edge_label_index[1]] 102 | logits = (hidden).sum(dim=-1) 103 | hidden = F.normalize(hidden, dim=1) 104 | return hidden, logits 105 | 106 | def decode_all(self, z): 107 | prob_adj = z @ z.t() 108 | return (prob_adj > 0).nonzero(as_tuple=False).t() -------------------------------------------------------------------------------- /code/standard-training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import roc_auc_score 4 | import torch_geometric.transforms as T 5 | from torch_geometric.datasets import Planetoid, WikipediaNetwork 6 | from torch_geometric.utils import negative_sampling 7 | import random 8 | import argparse 9 | from tqdm import tqdm 10 | from models import * 11 | from utils import * 12 | 13 | ''' 14 | Uasge of this python scipt, e.g., 15 | python3 standard-training.py --gnn_model GCN --num_gnn_layers 4 --dataset Cora --noise_ratio 0.2 16 | python3 standard-training.py --gnn_model GAT --num_gnn_layers 4 --dataset Cora --noise_ratio 0.2 17 | python3 standard-training.py --gnn_model SAGE --num_gnn_layers 4 --dataset Cora --noise_ratio 0.2 18 | ''' 19 | 20 | def standard_train_trial(rel_path, dataset_name, noise_type, noise_ratio, model_name, num_gnn_layers, device, repeat_times): 21 | path, dataset = getDataset(dataset_name, device, rel_path) 22 | Net = getGNNArch(model_name) 23 | test_auc_list, val_auc_list, best_epoch_list = [], [], [] 24 | 25 | for idx in tqdm(range(repeat_times), ncols=50, leave=False): 26 | savePath = f'{path}/{dataset_name}/processed/{noise_type}_noise_ratio_{noise_ratio}_repeat_{idx+1}.pt' 27 | data = torch.load(savePath) 28 | (train_data, val_data, test_data) = data 29 | model = Net(dataset.num_features, 128, 64, num_gnn_layers).to(device) 30 | optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01) 31 | criterion = torch.nn.BCEWithLogitsLoss() 32 | 33 | def train(): 34 | model.train() 35 | optimizer.zero_grad() 36 | z = model.encode(train_data.x, train_data.edge_index) 37 | 38 | # a new round of negative sampling for every training epoch 39 | neg_edge_index = negative_sampling( 40 | edge_index=train_data.edge_index, num_nodes=train_data.num_nodes, 41 | num_neg_samples=train_data.edge_label_index.size(1), method='sparse') 42 | edge_label_index = torch.cat( 43 | [train_data.edge_label_index, neg_edge_index], 44 | dim=-1, 45 | ) 46 | edge_label = torch.cat([ 47 | train_data.edge_label, 48 | train_data.edge_label.new_zeros(neg_edge_index.size(1)) 49 | ], dim=0) 50 | 51 | hidden, out = model.decode(z, edge_label_index) 52 | out = out.view(-1) 53 | loss = criterion(out, edge_label) 54 | loss.backward() 55 | optimizer.step() 56 | return loss 57 | 58 | @torch.no_grad() 59 | def test(data): 60 | model.eval() 61 | z = model.encode(data.x, data.edge_index) 62 | hidden, out = model.decode(z, 
data.edge_label_index) 63 | out = out.view(-1).sigmoid() 64 | return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()) 65 | 66 | best_val_auc = best_test_auc = 0 67 | best_epoch = 0 68 | for epoch in range(1, 1001): 69 | loss = train() 70 | val_auc = test(val_data) 71 | test_auc = test(test_data) 72 | 73 | if val_auc > best_val_auc: 74 | best_epoch = epoch 75 | best_val_auc = val_auc 76 | best_test_auc = test_auc 77 | 78 | test_auc_list.append(best_test_auc) 79 | val_auc_list.append(best_val_auc) 80 | best_epoch_list.append(best_epoch) 81 | 82 | # verbose results 83 | print(f'==> data={dataset_name}, type={noise_type}, ratio={noise_ratio}, model={model_name}, num_gnn_layers={num_gnn_layers}, repeat_time={repeat_times}') 84 | print(f'==> VAL: mean={np.mean(val_auc_list)}, std={np.std(val_auc_list)}, max={np.max(val_auc_list)}, min={np.min(val_auc_list)}') 85 | print(f'==> TEST: mean={np.mean(test_auc_list)}, std={np.std(test_auc_list)}, max={np.max(test_auc_list)}, min={np.min(test_auc_list)}') 86 | print('*'*50) 87 | return 88 | 89 | 90 | ### Parse args ### 91 | parser = argparse.ArgumentParser() 92 | parser_add_main_args(parser) 93 | args = parser.parse_args() 94 | print(args) 95 | print('*'*50) 96 | 97 | seed = 1 98 | torch.manual_seed(seed) 99 | random.seed(seed) 100 | np.random.seed(seed) 101 | torch.cuda.manual_seed(seed) 102 | torch.cuda.manual_seed_all(seed) 103 | torch.backends.cudnn.deterministic =True 104 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 105 | 106 | standard_train_trial(args.rel_path, args.dataset, args.noise_type, args.noise_ratio, args.gnn_model, args.num_gnn_layers, device, args.repeat_times) 107 | -------------------------------------------------------------------------------- /code/mixed-noise-generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | import torch_geometric.transforms as T 6 | from torch_geometric.datasets import Planetoid, WikipediaNetwork, AttributedGraphDataset 7 | import random 8 | import argparse 9 | import networkx as nx 10 | from utils import * 11 | 12 | ''' 13 | Uasge of this python scipt, e.g., 14 | python3 mixed-noise-generation.py --dataset Cora --noise_ratio 0.2 15 | python3 mixed-noise-generation.py --dataset Citeseer --noise_ratio 0.2 16 | python3 mixed-noise-generation.py --dataset Pubmed --noise_ratio 0.2 17 | python3 mixed-noise-generation.py --dataset chameleon --noise_ratio 0.2 18 | python3 mixed-noise-generation.py --dataset squirrel --noise_ratio 0.2 19 | python3 mixed-noise-generation.py --dataset facebook --noise_ratio 0.2 20 | ''' 21 | 22 | def getDistanceMatrix(G): 23 | ''' return a matrix with shape of (num_nodes, num_nodes) ''' 24 | num_nodes = G.number_of_nodes() 25 | distance = np.zeros((num_nodes, num_nodes)) 26 | distance[:,:] = num_nodes + 1 27 | for h in range(num_nodes): 28 | distance[h,h] = 0 29 | for (h,t) in G.edges(): 30 | distance[h,t] = 1 31 | distance[t,h] = 1 32 | return distance 33 | 34 | def generateNoisyEdges(distance_matrix, num_noisy_edges, filter_mask=None): 35 | select_edges = [] 36 | n_node = distance_matrix.shape[0] 37 | while len(select_edges) < num_noisy_edges: 38 | head, tail = np.random.randint(n_node, size=2) 39 | if distance_matrix[head, tail] >= 2 and [head, tail] not in select_edges: 40 | select_edges.append([head, tail]) 41 | heads = [h for [h,t] in select_edges] 42 | tails = [t for [h,t] in select_edges] 43 | # double direction 
44 | noisy_edges = [heads + tails, tails + heads] 45 | noisy_edges = torch.Tensor(noisy_edges).int().to(device) 46 | return noisy_edges 47 | 48 | ### Parse args ### 49 | parser = argparse.ArgumentParser() 50 | parser_add_main_args(parser) 51 | args = parser.parse_args() 52 | print(args) 53 | assert args.noise_ratio >= 0 54 | 55 | seed = 1 56 | torch.manual_seed(seed) 57 | random.seed(seed) 58 | np.random.seed(seed) 59 | 60 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 61 | 62 | transform = T.Compose([ 63 | T.NormalizeFeatures(), 64 | T.ToDevice(device), 65 | T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, 66 | add_negative_train_samples=False), 67 | ]) 68 | 69 | # load dataset 70 | if args.dataset in ['Cora', "Citeseer", 'Pubmed']: 71 | path = osp.join(args.rel_path, 'Planetoid') 72 | dataset = Planetoid(path, name=args.dataset, transform=transform) 73 | elif args.dataset in ["chameleon", "squirrel"]: 74 | path = osp.join(args.rel_path, 'WikipediaNetwork') 75 | dataset = WikipediaNetwork(path, name=args.dataset, transform=transform) 76 | elif args.dataset in ["facebook"]: 77 | path = osp.join(args.rel_path, 'AttributedGraphDataset') 78 | dataset = AttributedGraphDataset(path, name=args.dataset, transform=transform) 79 | else: 80 | exit() 81 | 82 | # check saving folder 83 | saving_folder = f'{path}/{args.dataset}/processed/' 84 | if not os.path.exists(saving_folder): os.mkdir(saving_folder) 85 | 86 | # load distance matrix 87 | distance_matrix_path = osp.join(path, args.dataset, 'distance.npy') 88 | if os.path.exists(distance_matrix_path): 89 | distance_matrix = np.load(distance_matrix_path) 90 | else: 91 | print('==> Generating the distance matrix...') 92 | G = nx.Graph() 93 | num_nodes = dataset.data.x.shape[0] 94 | for node_idx in range(num_nodes): 95 | G.add_node(node_idx) 96 | G.add_edges_from(dataset.data.edge_index.T.tolist()) 97 | distance_matrix = getDistanceMatrix(G) 98 | np.save(open(f'{path}/{args.dataset}/distance.npy', 'wb'), distance_matrix) 99 | print('==> Finished.') 100 | 101 | for idx in range(args.repeat_times): 102 | # save train_data, val_data, val_data 103 | savePath = f'{path}/{args.dataset}/processed/mixed_noise_ratio_{args.noise_ratio}_repeat_{idx+1}.pt' 104 | print('==> Save generated data to:', savePath) 105 | if os.path.exists(savePath): 106 | print(f'==> File already exists, skip path: {savePath}') 107 | continue 108 | 109 | # copy the original dataset 110 | train_data, val_data, test_data = [copy.deepcopy(d) for d in dataset[0]] 111 | if args.noise_ratio > 0: 112 | # label noise 113 | num_noisy_edges = int(args.noise_ratio * train_data.edge_label.shape[0] / 2) 114 | noisy_index = generateNoisyEdges(distance_matrix, num_noisy_edges) 115 | train_data.edge_label_index = torch.cat([train_data.edge_label_index, noisy_index], dim=1) 116 | noisy_edge_label = torch.ones(num_noisy_edges * 2).cuda() 117 | train_data.edge_label = torch.cat([train_data.edge_label, noisy_edge_label], dim=0) 118 | # input noise 119 | num_noisy_edges = int(args.noise_ratio * len(train_data.edge_index[0]) / 2) 120 | noisy_index = generateNoisyEdges(distance_matrix, num_noisy_edges) 121 | train_data.edge_index = torch.cat([train_data.edge_index, noisy_index],dim=1) 122 | val_data.edge_index = torch.cat([val_data.edge_index, noisy_index],dim=1) 123 | test_data.edge_index = torch.cat([test_data.edge_index, noisy_index],dim=1) 124 | 125 | torch.save((train_data, val_data, test_data), savePath) 126 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

RGIB: Robust Graph Information Bottleneck

2 |

3 | Paper · Github · License

12 | 13 | Official code for the paper "Combating Bilateral Edge Noise for Robust Link Prediction" (NeurIPS 2023). 14 | 15 | ## Introduction 16 | 17 | Although link prediction on graphs has achieved great success with the development of graph neural networks (GNNs), its robustness under edge noise remains much less investigated. To close this gap, we first conduct an empirical study showing that edge noise bilaterally perturbs both the input topology and the target labels, yielding severe performance degradation and representation collapse. 18 | 19 | 20 | 21 |
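Concretely, the bilateral noise is simulated by adding fake edges on both sides of the learning problem. Below is a minimal sketch condensed from `code/mixed-noise-generation.py`; the helper name is illustrative, and the actual script draws separate noisy edge sets for the two sides, inserts each fake edge in both directions, and also perturbs the validation/test graphs.

```python
import torch

def inject_bilateral_noise(train_data, noisy_index: torch.Tensor):
    # noisy_index: fake edges of shape (2, num_noisy_edges) between distant node pairs
    # label noise: register the fake edges as positive supervision targets
    train_data.edge_label_index = torch.cat([train_data.edge_label_index, noisy_index], dim=1)
    train_data.edge_label = torch.cat(
        [train_data.edge_label,
         torch.ones(noisy_index.size(1), device=train_data.edge_label.device)], dim=0)
    # input noise: add fake edges to the observed message-passing topology
    train_data.edge_index = torch.cat([train_data.edge_index, noisy_index], dim=1)
    return train_data
```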

Figure 1. Link prediction with bilateral edge noise. The GNN takes the graph as input, generates edge representations, and then predicts the existence of unseen edges under the supervision of edge labels.
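In this repository, every encoder in `code/models.py` scores a candidate edge by the inner product of its two endpoint embeddings. A minimal sketch of that decode step (the `score_edges` helper is illustrative; the real logic lives in the `decode` methods of `models.py`):

```python
import torch

def score_edges(z: torch.Tensor, edge_label_index: torch.Tensor) -> torch.Tensor:
    # z: node embeddings of shape (num_nodes, dim) produced by the GNN encoder
    # edge_label_index: candidate edges of shape (2, num_edges) as (head, tail) pairs
    hidden = z[edge_label_index[0]] * z[edge_label_index[1]]  # edge representation
    return hidden.sum(dim=-1)                                 # logit per candidate edge

# toy usage: 4 nodes with 8-dim embeddings, two candidate edges (0,2) and (1,3)
z = torch.randn(4, 8)
candidates = torch.tensor([[0, 1], [2, 3]])
print(torch.sigmoid(score_edges(z, candidates)))  # predicted existence probabilities
```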

22 | 23 | To address this dilemma, we propose an information-theory-guided principle, Robust Graph Information Bottleneck (RGIB), to extract reliable supervision signals and avoid representation collapse. Different from the basic information bottleneck, RGIB further decouples and balances the mutual dependence among the graph topology ($\tilde{A}$), the target labels ($\tilde{Y}$), and the representation ($\mathbf{H}$), building new learning objectives towards representations that are robust to the bilateral noise. 24 | 25 |
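For orientation, the basic GIB follows the standard information-bottleneck form (our paraphrase; see the paper for the exact RGIB objective and its derivation):

$$\min_{\mathbf{H}}\; -I(\mathbf{H};\tilde{Y}) \;+\; \lambda\, I(\mathbf{H};\tilde{A}),$$

which only compresses the dependence on the noisy topology $\tilde{A}$ and fully trusts the noisy labels $\tilde{Y}$; RGIB, by contrast, also regularizes the reliance on $\tilde{Y}$ while keeping $\mathbf{H}$ informative.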

26 | 27 |

Figure 2. The principles of basic GIB and the proposed RGIB.

28 | 29 | Two instantiations, RGIB-SSL and RGIB-REP, are explored to leverage the merits of different methodologies, i.e., self-supervised learning and data reparameterization, for implicit and explicit data denoising, respectively. 30 | 31 | 32 | 33 | 34 | 35 |
36 | 37 |

Figure 3. Diagrams of RGIB (left) and its two instantiations, RGIB-SSL (middle) and RGIB-REP (right).
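To make the two instantiations concrete before the usage steps, here is a condensed sketch of how their final training losses are assembled in `code/RGIB-ssl-training.py` and `code/RGIB-rep-training.py`. The function names are illustrative only; the individual terms (supervised losses on the original/augmented graphs, self-supervised alignment and uniformity losses, and the two information regularizers) are computed by the full scripts, and the coefficients come from the `--scheduler` described in Step 3 below.

```python
def rgib_ssl_objective(sup_ori, sup_aug, ssl_ori, ssl_aug, uniformity, lamb1):
    # RGIB-SSL: supervision on the original and augmented graphs, balanced against
    # self-supervised alignment and uniformity terms; lamb2 = lamb3 = 1 - lamb1,
    # with lamb1 scheduled over epochs (see RGIB-ssl-training.py).
    lamb2 = lamb3 = 1.0 - lamb1
    return lamb1 * (sup_ori + sup_aug) + lamb2 * (ssl_ori + ssl_aug) + lamb3 * uniformity

def rgib_rep_objective(sup_loss, regu_A, regu_Y, lamb):
    # RGIB-REP: reweighted supervised loss balanced against information regularizers
    # on the sampled topology (regu_A) and the sampled supervision (regu_Y),
    # as assembled in RGIB-rep-training.py.
    return lamb * sup_loss + (1 - lamb) * regu_A + (1 - lamb) * regu_Y
```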

38 | 39 | 40 | 41 | ## Usage 42 | 43 | ### Step 1. Installation (with pip) 44 | 45 | Create a new virtual environment. 46 | ``` 47 | python -m venv ~/RGIB 48 | source ~/RGIB/bin/activate 49 | ``` 50 | 51 | Install the essential dependencies. 52 | ``` 53 | pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 54 | 55 | pip install torch-scatter==2.0.9 torch-sparse==0.6.15 torch-cluster==1.6.0 torch-geometric==2.1.0 -f https://data.pyg.org/whl/torch-1.12.1+cu113.html 56 | 57 | pip install networkx pygcl dgl packaging pandas hyperopt 58 | ``` 59 | 60 | ### Step 2. Generate noisy data (in `code/`) 61 | 62 | Template: 63 | ``` 64 | python3 mixed-noise-generation.py --dataset "dataset name" --noise_ratio "a float" 65 | ``` 66 | Note that 67 | - `dataset` must be one of `[Cora, Citeseer, Pubmed, chameleon, squirrel, facebook]`. 68 | - `noise_ratio` is a float in the range `[0, 1]`. 69 | 70 | Examples: 71 | ``` 72 | python3 mixed-noise-generation.py --dataset Cora --noise_ratio 0.2 73 | 74 | python3 mixed-noise-generation.py --dataset Citeseer --noise_ratio 0.2 75 | ``` 76 | 77 | ### Step 3. Training with noisy data (in `code/`) 78 | 79 | #### Standard training with cross-entropy loss. 80 | Template: 81 | ``` 82 | python3 standard-training.py --gnn_model "GNN name" --num_gnn_layers "an integer" --dataset "dataset name" --noise_ratio "a float" 83 | ``` 84 | 85 | Examples: 86 | ``` 87 | python3 standard-training.py --gnn_model GCN --num_gnn_layers 4 --dataset Cora --noise_ratio 0.2 88 | 89 | python3 standard-training.py --gnn_model GAT --num_gnn_layers 4 --dataset Citeseer --noise_ratio 0.2 90 | ``` 91 | 92 | Note that 93 | - `gnn_model` can be `GCN`, `GAT`, or `SAGE`. 94 | - make sure the data for the chosen `dataset` and `noise_ratio` has been generated in Step 2. 95 | 96 | #### Training with RGIB-SSL / RGIB-REP. 97 | 98 | Template: 99 | ``` 100 | python3 RGIB-ssl-training.py --gnn_model "GNN name" --num_gnn_layers "an integer" --dataset "dataset name" --noise_ratio "a float" --scheduler "scheduler_name" --scheduler_param "a float" 101 | 102 | python3 RGIB-rep-training.py --gnn_model "GNN name" --num_gnn_layers "an integer" --dataset "dataset name" --noise_ratio "a float" --scheduler "scheduler_name" --scheduler_param "a float" 103 | ``` 104 | 105 | Examples: 106 | ``` 107 | python3 RGIB-ssl-training.py --gnn_model GCN --dataset Cora --noise_ratio 0.2 --scheduler linear --scheduler_param 1.0 108 | 109 | python3 RGIB-rep-training.py --gnn_model GCN --dataset Cora --noise_ratio 0.2 --scheduler constant --scheduler_param 1.0 110 | ``` 111 | 112 | Note that 113 | - `scheduler` can be `linear`, `exp`, `sin`, `cos`, or `constant`. 114 | - `scheduler_param` configures the chosen `scheduler`. 115 | 116 | In addition, we implement automated search over `scheduler` and `scheduler_param`, e.g., searching over 50 trials as follows. 117 | ``` 118 | python3 RGIB-ssl-training.py --gnn_model GCN --dataset Cora --noise_ratio 0.2 --search_scheduler --search_iteration 50 119 | 120 | python3 RGIB-rep-training.py --gnn_model GCN --dataset Cora --noise_ratio 0.2 --search_scheduler --search_iteration 50 121 | ``` 122 | 123 | 124 | 125 | ### Citation 126 | 127 | If you find our work useful, please kindly cite our paper.
128 | ```bibtex 129 | @inproceedings{zhou2023combating, 130 | title={Combating Bilateral Edge Noise for Robust Link Prediction}, 131 | author={Zhanke Zhou and Jiangchao Yao and Jiaxu Liu and Xiawei Guo and Quanming Yao and Li He and Liang Wang and Bo Zheng and Bo Han}, 132 | booktitle={Advances in Neural Information Processing Systems}, 133 | year={2023}, 134 | } 135 | ``` 136 | -------------------------------------------------------------------------------- /code/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class ContrastiveLoss(nn.Module): 6 | def __init__(self, m=2.0): 7 | super(ContrastiveLoss, self).__init__() # pre 3.3 syntax 8 | self.m = m # margin or radius 9 | 10 | def forward(self, y1, y2, d=0): 11 | # d = 0 means y1 and y2 are supposed to be same 12 | # d = 1 means y1 and y2 are supposed to be different 13 | euc_dist = nn.functional.pairwise_distance(y1, y2) 14 | if d == 0: 15 | return torch.mean(T.pow(euc_dist, 2)) # distance squared 16 | else: # d == 1 17 | delta = self.m - euc_dist # sort of reverse distance 18 | delta = torch.clamp(delta, min=0.0, max=None) 19 | return torch.mean(torch.pow(delta, 2)) # mean over all rows 20 | 21 | class NoiseRobustLoss(nn.Module): 22 | def __init__(self): 23 | super(NoiseRobustLoss, self).__init__() 24 | 25 | def forward(self, pair_dist, P, margin, use_robust_loss, start_fine): 26 | dist_sq = pair_dist * pair_dist 27 | N = len(P) 28 | if use_robust_loss == 1: 29 | if start_fine: 30 | loss = P * dist_sq + (1 - P) * (1 / margin) * torch.pow( 31 | torch.clamp(torch.pow(pair_dist, 0.5) * (margin - pair_dist), min=0.0), 2) 32 | else: 33 | loss = P * dist_sq + (1 - P) * torch.pow(torch.clamp(margin - pair_dist, min=0.0), 2) 34 | else: 35 | loss = P * dist_sq + (1 - P) * torch.pow(torch.clamp(margin - pair_dist, min=0.0), 2) 36 | loss = torch.sum(loss) / (2.0 * N) 37 | return loss 38 | 39 | class SelfAdversarialClLoss(nn.Module): 40 | def __init__(self): 41 | super(SelfAdversarialClLoss, self).__init__() 42 | 43 | def forward(self, pair_dist, P, margin, use_robust_loss, start_fine, alpha=1.0): 44 | dist_sq = pair_dist * pair_dist 45 | N = len(P) 46 | loss = P * dist_sq + (1 - P) * (margin - pair_dist) 47 | pos_index = torch.where(P == 1) 48 | neg_index = torch.where(P == 0) 49 | adv_prob = torch.zeros(N).cuda() 50 | adv_prob[pos_index] = F.softmax(pair_dist[pos_index] * alpha, dim=0).detach() 51 | adv_prob[neg_index] = F.softmax((1-pair_dist[neg_index]) * alpha, dim=0).detach() 52 | loss = torch.sum(loss * adv_prob) 53 | return loss 54 | 55 | class SupConLoss(nn.Module): 56 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 57 | It also supports the unsupervised contrastive loss in SimCLR""" 58 | def __init__(self, temperature=0.07, contrast_mode='all', 59 | base_temperature=0.07): 60 | super(SupConLoss, self).__init__() 61 | self.temperature = temperature 62 | self.contrast_mode = contrast_mode 63 | self.base_temperature = base_temperature 64 | 65 | def forward(self, features, labels=None, mask=None, reverse=False): 66 | """Compute loss for model. If both `labels` and `mask` are None, 67 | it degenerates to SimCLR unsupervised loss: 68 | https://arxiv.org/pdf/2002.05709.pdf 69 | 70 | Args: 71 | features: hidden vector of shape [bsz, n_views, ...]. 72 | labels: ground truth of shape [bsz]. 
73 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 74 | has the same class as sample i. Can be asymmetric. 75 | Returns: 76 | A loss scalar. 77 | """ 78 | device = (torch.device('cuda') 79 | if features.is_cuda 80 | else torch.device('cpu')) 81 | 82 | if len(features.shape) < 3: 83 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 84 | 'at least 3 dimensions are required') 85 | if len(features.shape) > 3: 86 | features = features.view(features.shape[0], features.shape[1], -1) 87 | 88 | batch_size = features.shape[0] 89 | if labels is not None and mask is not None: 90 | raise ValueError('Cannot define both `labels` and `mask`') 91 | elif labels is None and mask is None: 92 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 93 | elif labels is not None: 94 | labels = labels.contiguous().view(-1, 1) 95 | if labels.shape[0] != batch_size: 96 | raise ValueError('Num of labels does not match num of features') 97 | mask = torch.eq(labels, labels.T).float().to(device) 98 | else: 99 | mask = mask.float().to(device) 100 | 101 | if reverse: 102 | mask = (~mask.bool()).int() 103 | 104 | contrast_count = features.shape[1] 105 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 106 | if self.contrast_mode == 'one': 107 | anchor_feature = features[:, 0] 108 | anchor_count = 1 109 | elif self.contrast_mode == 'all': 110 | anchor_feature = contrast_feature 111 | anchor_count = contrast_count 112 | else: 113 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 114 | 115 | # compute logits 116 | anchor_dot_contrast = torch.div( 117 | torch.matmul(anchor_feature, contrast_feature.T), 118 | self.temperature) 119 | # for numerical stability 120 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 121 | logits = anchor_dot_contrast - logits_max.detach() 122 | # tile mask 123 | mask = mask.repeat(anchor_count, contrast_count) 124 | # mask-out self-contrast cases 125 | logits_mask = torch.scatter( 126 | torch.ones_like(mask), 127 | 1, 128 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 129 | 0 130 | ) 131 | mask = mask * logits_mask 132 | # compute log_prob 133 | exp_logits = torch.exp(logits) * logits_mask 134 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 135 | # compute mean of log-likelihood over positive 136 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 137 | # loss 138 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 139 | loss = loss.view(anchor_count, batch_size).mean() 140 | return loss 141 | 142 | class LabelSmoothingLoss(nn.Module): 143 | def __init__(self, classes, smoothing=0, dim=-1): 144 | super(LabelSmoothingLoss, self).__init__() 145 | self.confidence = 1.0 - smoothing 146 | self.smoothing = smoothing 147 | self.cls = classes 148 | self.dim = dim 149 | 150 | def forward(self, pred, target): 151 | pred = pred.log_softmax(dim=self.dim) 152 | with torch.no_grad(): 153 | # true_dist = pred.data.clone() 154 | true_dist = torch.zeros_like(pred) 155 | true_dist.fill_(self.smoothing / (self.cls - 1)) 156 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 157 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) 158 | 159 | 160 | LOSSES = {'SupCon': SupConLoss, 161 | 'LabelSmoothing': LabelSmoothingLoss, 162 | 'CrossEntropy': nn.CrossEntropyLoss} 163 | 164 | def lunif(x, t=2): 165 | sq_pdist = torch.pdist(x, p=2).pow(2) 166 | return sq_pdist.mul(-t).exp().mean() 167 | 168 | 
-------------------------------------------------------------------------------- /code/RGIB-ssl-training.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import argparse 4 | import os.path as osp 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch 8 | import torch.nn.functional as F 9 | from sklearn.metrics import roc_auc_score 10 | import torch_geometric.transforms as T 11 | from torch_geometric.utils import negative_sampling 12 | import GCL.augmentors as A 13 | from GCL.augmentors.functional import add_edge 14 | import models 15 | from loss import * 16 | from utils import * 17 | from hyperopt import hp, fmin, tpe 18 | 19 | ''' 20 | Uasge of this python scipt, e.g., 21 | python3 RGIB-ssl-training.py --gnn_model GCN --num_gnn_layers 4 --dataset Cora --noise_ratio 0.2 --scheduler linear --scheduler_param 1.0 22 | python3 RGIB-ssl-training.py --gnn_model GCN --num_gnn_layers 4 --dataset Cora --noise_ratio 0.2 --search_scheduler --search_iteration 50 23 | ''' 24 | 25 | def generate_augmentation_operator(n=2): 26 | search_space = [ 27 | (A.Identity, ()), 28 | (A.FeatureMasking, (0.0, 0.3)), 29 | (A.FeatureDropout, (0.0, 0.3)), 30 | (A.EdgeRemoving, (0.0, 0.5)) 31 | ] 32 | 33 | operator_list = [] 34 | index = list(range(len(search_space))) 35 | random.shuffle(index) 36 | sampled_index = index[:n] 37 | for idx in sampled_index: 38 | opt, hp_range = search_space[idx] 39 | if hp_range == (): 40 | operator_list.append(opt()) 41 | else: 42 | sampled_hp = random.uniform(hp_range[0], hp_range[1]) 43 | operator_list.append(opt(sampled_hp)) 44 | 45 | aug = A.Compose(operator_list) 46 | return aug 47 | 48 | def standard_train_trial(rel_path, dataset_name, noise_type, noise_ratio, model_name, num_gnn_layers, device, repeat_times, verbose=True): 49 | path, dataset = getDataset(dataset_name, device, rel_path) 50 | Net = getGNNArch(model_name) 51 | test_auc_list, val_auc_list, best_epoch_list = [], [], [] 52 | 53 | MAX_EPOCH = 1000 54 | if verbose: print(f'==> schedule={args.scheduler}, param={args.scheduler_param}') 55 | assert args.scheduler in ['linear', 'exp', 'sin', 'cos', 'constant'] 56 | if args.scheduler == 'linear': 57 | lamb_scheduler = np.linspace(0, 1, MAX_EPOCH) * args.scheduler_param 58 | elif args.scheduler == 'exp': 59 | lamb_scheduler = np.array([math.exp(-t/MAX_EPOCH) for t in range(MAX_EPOCH)]) * args.scheduler_param 60 | elif args.scheduler == 'sin': 61 | lamb_scheduler = np.array([math.sin(t/MAX_EPOCH * math.pi * 0.5) for t in range(MAX_EPOCH)]) * args.scheduler_param 62 | elif args.scheduler == 'cos': 63 | lamb_scheduler = np.array([math.cos(t/MAX_EPOCH * math.pi * 0.5) for t in range(MAX_EPOCH)]) * args.scheduler_param 64 | elif args.scheduler == 'constant': 65 | lamb_scheduler = np.array([args.scheduler_param] * MAX_EPOCH) 66 | 67 | for idx in tqdm(range(repeat_times), ncols=50, leave=False): 68 | savePath = f'{path}/{dataset_name}/processed/{noise_type}_noise_ratio_{noise_ratio}_repeat_{idx+1}.pt' 69 | data = torch.load(savePath) 70 | (train_data, val_data, test_data) = data 71 | model = Net(dataset.num_features, 128, 64, num_gnn_layers).to(device) 72 | optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01) 73 | criterion = torch.nn.BCEWithLogitsLoss(reduction='none') 74 | cl_criterion = SelfAdversarialClLoss() 75 | margin, use_robust_loss, start_fine = 5, 0, True 76 | 77 | def train(epoch_idx): 78 | aug1 = generate_augmentation_operator() 79 | aug2 = generate_augmentation_operator() 80 
| 81 | lamb1 = lamb_scheduler[epoch_idx] 82 | lamb2, lamb3 = 1-lamb1, 1-lamb1 83 | 84 | # a new round of negative sampling for every training epoch 85 | neg_edge_index = negative_sampling( 86 | edge_index=train_data.edge_index, num_nodes=train_data.num_nodes, 87 | num_neg_samples=train_data.edge_label_index.size(1), method='sparse') 88 | edge_label_index = torch.cat( 89 | [train_data.edge_label_index, neg_edge_index], 90 | dim=-1, 91 | ) 92 | edge_label = torch.cat([ 93 | train_data.edge_label, 94 | train_data.edge_label.new_zeros(neg_edge_index.size(1)) 95 | ], dim=0) 96 | 97 | model.train() 98 | optimizer.zero_grad() 99 | 100 | # forward with original graph 101 | z = model.encode(train_data.x, train_data.edge_index) 102 | hidden, out = model.decode(z, edge_label_index) 103 | out = out.view(-1) 104 | 105 | # forward with original augmented graph1 106 | x1, edge_index1, _ = aug1(train_data.x, train_data.edge_index) 107 | z1 = model.encode(x1, edge_index1) 108 | hidden1, out1 = model.decode(z1, edge_label_index) 109 | out1 = out1.view(-1) 110 | 111 | # forward with original augmented graph2 112 | x2, edge_index2, _ = aug2(train_data.x, train_data.edge_index) 113 | z2 = model.encode(x2, edge_index2) 114 | hidden2, out2 = model.decode(z2, edge_label_index) 115 | out2 = out2.view(-1) 116 | 117 | # loss1: supervised loss with original graph 118 | sup_loss_ori = criterion(out, edge_label).mean() 119 | 120 | # loss2: supervised loss with augmentated graphs 121 | sup_loss_aug = (criterion(out1, edge_label) + criterion(out2, edge_label)).mean() if lamb1 > 0 else 0 122 | 123 | # loss3: self-supervised loss with original graphs 124 | h1 = torch.cat([hidden, hidden], dim=0) 125 | h2 = torch.cat([hidden, hidden.flip([0])], dim=0) 126 | cl_label = torch.cat([torch.ones(hidden.size(0)), torch.zeros(hidden.size(0))], dim=0).to(device) 127 | pair_dist = F.pairwise_distance(h1, h2) 128 | ssl_loss_ori = cl_criterion(pair_dist, cl_label, margin, use_robust_loss, start_fine) if lamb2 > 0 else 0 129 | 130 | # loss4: self-supervised loss with augmentated graphs 131 | h1 = torch.cat([hidden1, hidden1], dim=0) 132 | h2 = torch.cat([hidden2, hidden2.flip([0])], dim=0) 133 | cl_label = torch.cat([torch.ones(hidden1.size(0)), torch.zeros(hidden1.size(0))], dim=0).to(device) 134 | pair_dist = F.pairwise_distance(h1, h2) 135 | ssl_loss_aug = cl_criterion(pair_dist, cl_label, margin, use_robust_loss, start_fine) if lamb2 > 0 else 0 136 | 137 | # loss5: uniformity I(H, H') 138 | batchsize = 1024 139 | tmp_cand = torch.randperm(len(edge_label)) 140 | sampled_index = tmp_cand[:batchsize] # boost uniformity with data sampling 141 | uniformity_loss = lunif(hidden[sampled_index, :]) if lamb3 > 0 else 0 142 | 143 | # final objective of RGIB-SSL 144 | sup_loss = lamb1 * (sup_loss_ori + sup_loss_aug) 145 | align_loss = lamb2 * (ssl_loss_ori + ssl_loss_aug) 146 | uni_loss = lamb3 * uniformity_loss 147 | loss = sup_loss + align_loss + uni_loss 148 | 149 | loss.backward() 150 | optimizer.step() 151 | return 152 | 153 | @torch.no_grad() 154 | def test(data): 155 | model.eval() 156 | z = model.encode(data.x, data.edge_index) 157 | hidden, out = model.decode(z, data.edge_label_index) 158 | out = out.view(-1).sigmoid() 159 | return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()) 160 | 161 | best_val_auc = best_test_auc = 0 162 | best_epoch = 0 163 | time_list = [] 164 | for epoch in range(MAX_EPOCH): 165 | train(epoch_idx=epoch) 166 | val_auc = test(val_data) 167 | test_auc = test(test_data) 168 | 169 | if val_auc > 
best_val_auc: 170 | best_epoch = epoch 171 | best_val_auc = val_auc 172 | best_test_auc = test_auc 173 | 174 | test_auc_list.append(best_test_auc) 175 | val_auc_list.append(best_val_auc) 176 | best_epoch_list.append(best_epoch) 177 | 178 | # verbose the training results 179 | if verbose: 180 | print(f'==> data={dataset_name}, type={noise_type}, ratio={noise_ratio}, model={model_name}, num_gnn_layers={num_gnn_layers}, repeat_time={repeat_times}') 181 | print(f'==> VAL: mean={np.mean(val_auc_list)}, std={np.std(val_auc_list)}, max={np.max(val_auc_list)}, min={np.min(val_auc_list)}') 182 | print(f'==> TEST: mean={np.mean(test_auc_list)}, std={np.std(test_auc_list)}, max={np.max(test_auc_list)}, min={np.min(test_auc_list)}') 183 | print('*'*50) 184 | 185 | return np.mean(val_auc_list) 186 | 187 | ### Parse args ### 188 | parser = argparse.ArgumentParser() 189 | parser_add_main_args(parser) 190 | args = parser.parse_args() 191 | print(args) 192 | print('*'*50) 193 | 194 | seed = 1 195 | torch.manual_seed(seed) 196 | random.seed(seed) 197 | np.random.seed(seed) 198 | torch.cuda.manual_seed(seed) 199 | torch.cuda.manual_seed_all(seed) 200 | torch.backends.cudnn.deterministic =True 201 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 202 | 203 | if args.search_scheduler: 204 | # mutiple trials to search the optimal scheduler 205 | assert args.search_iteration > 0 206 | candidate_schduler = ['linear', 'exp', 'sin', 'cos', 'constant'] 207 | schedule_search_space = { 208 | 'scheduler': hp.choice('scheduler', candidate_schduler), 209 | 'scheduler_param': hp.uniform('scheduler_param', 0, 1), 210 | } 211 | 212 | def run_model(params): 213 | args.scheduler = params['scheduler'] 214 | args.scheduler_param = params['scheduler_param'] 215 | mean_val_auc = standard_train_trial(args.rel_path, args.dataset, args.noise_type, args.noise_ratio, args.gnn_model, args.num_gnn_layers, device, args.repeat_times) 216 | return (1-mean_val_auc) 217 | 218 | # searching 219 | best = fmin(run_model, schedule_search_space, algo=tpe.suggest, max_evals=int(args.search_iteration), verbose=False) 220 | print(f'==> optimal scheduler:{best}') 221 | 222 | # final trial 223 | args.scheduler = candidate_schduler[best['scheduler']] 224 | args.scheduler_param = best['scheduler_param'] 225 | standard_train_trial(args.rel_path, args.dataset, args.noise_type, args.noise_ratio, args.gnn_model, args.num_gnn_layers, device, args.repeat_times) 226 | 227 | else: 228 | # single trial 229 | standard_train_trial(args.rel_path, args.dataset, args.noise_type, args.noise_ratio, args.gnn_model, args.num_gnn_layers, device, args.repeat_times) 230 | -------------------------------------------------------------------------------- /code/RGIB-rep-training.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import argparse 4 | import os.path as osp 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch 8 | from torch.nn import ModuleList 9 | import torch.nn.functional as F 10 | from sklearn.metrics import roc_auc_score 11 | import torch_geometric.transforms as T 12 | from torch_geometric.utils import negative_sampling 13 | from torch_geometric.nn import GCNConv, SAGEConv, GATConv, MLP 14 | from utils import * 15 | from loss import * 16 | from hyperopt import hp, fmin, tpe 17 | 18 | ''' 19 | Uasge of this python scipt, e.g., 20 | python3 RGIB-rep-training.py --gnn_model GCN --num_gnn_layers 4 --dataset Cora --noise_ratio 0.2 --scheduler constant 
--scheduler_param 1.0 21 | python3 RGIB-rep-training.py --gnn_model GCN --num_gnn_layers 4 --dataset Cora --noise_ratio 0.2 --search_scheduler --search_iteration 50 22 | ''' 23 | 24 | SAMPLING_RATIO = 1.0 25 | def sampling_MI(prob, tau=0.8, reduction='none'): 26 | prob = prob.clamp(1e-4, 1-1e-4) 27 | entropy1 = prob * torch.log(prob / tau) 28 | entropy2 = (1-prob) * torch.log((1-prob) / (1-tau)) 29 | res = entropy1 + entropy2 30 | if reduction == 'none': 31 | return res 32 | elif reduction == 'mean': 33 | return torch.mean(res) 34 | elif reduction == 'sum': 35 | return torch.sum(res) 36 | 37 | class GCN(torch.nn.Module): 38 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers): 39 | super().__init__() 40 | self.convs = ModuleList() 41 | self.convs.append(GCNConv(in_channels, hidden_channels)) 42 | for i in range(0, num_layers-2): 43 | self.convs.append(GCNConv(hidden_channels, hidden_channels)) 44 | self.convs.append(GCNConv(hidden_channels, out_channels)) 45 | 46 | def encode(self, x, edge_index): 47 | # get edge logits from original X 48 | z = x.clone() 49 | for conv in self.convs[:-1]: 50 | z = conv(z, edge_index).relu() 51 | self.tmp_z = self.convs[-1](z, edge_index) 52 | edge_logit = (self.tmp_z[edge_index[0]] * self.tmp_z[edge_index[1]]).sum(dim=-1) 53 | edge_weight = torch.nn.Sigmoid()(edge_logit) 54 | if self.training: self.encode_edge_weight = edge_weight 55 | # edge sampling 56 | sampled_index = (edge_weight > SAMPLING_RATIO * torch.rand_like(edge_weight)).detach() 57 | new_edge_index = edge_index[:, sampled_index] 58 | # forward with sampled edges 59 | for conv in self.convs[:-1]: 60 | x = conv(x, new_edge_index).relu() 61 | x = self.convs[-1](x, edge_index) 62 | return x 63 | 64 | def decode(self, x, z, edge_label_index): 65 | edge_logit = (self.tmp_z[edge_label_index[0]] * self.tmp_z[edge_label_index[1]]).sum(dim=-1) 66 | pos_weight = torch.nn.Sigmoid()(edge_logit) 67 | neg_weight = torch.ones_like(pos_weight) 68 | edge_weight = torch.cat([neg_weight.unsqueeze(1), pos_weight.unsqueeze(1)], dim=1) 69 | if self.training: self.decode_edge_weight = pos_weight 70 | 71 | hidden = z[edge_label_index[0]] * z[edge_label_index[1]] 72 | logits = (hidden).sum(dim=-1) 73 | hidden = F.normalize(hidden, dim=1) 74 | return hidden, logits, edge_weight 75 | 76 | class SAGE(torch.nn.Module): 77 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers): 78 | super().__init__() 79 | self.convs = ModuleList() 80 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 81 | for i in range(0, num_layers-2): 82 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 83 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 84 | 85 | def encode(self, x, edge_index): 86 | # get edge logits from original X 87 | z = x.clone() 88 | for conv in self.convs[:-1]: 89 | z = conv(z, edge_index).relu() 90 | self.tmp_z = self.convs[-1](z, edge_index) 91 | edge_logit = (self.tmp_z[edge_index[0]] * self.tmp_z[edge_index[1]]).sum(dim=-1) 92 | edge_weight = torch.nn.Sigmoid()(edge_logit) 93 | if self.training: self.encode_edge_weight = edge_weight 94 | # edge sampling 95 | sampled_index = (edge_weight > SAMPLING_RATIO * torch.rand_like(edge_weight)).detach() 96 | new_edge_index = edge_index[:, sampled_index] 97 | # forward with sampled edges 98 | for conv in self.convs[:-1]: 99 | x = conv(x, new_edge_index).relu() 100 | x = self.convs[-1](x, edge_index) 101 | return x 102 | 103 | def decode(self, x, z, edge_label_index): 104 | edge_logit = 
(self.tmp_z[edge_label_index[0]] * self.tmp_z[edge_label_index[1]]).sum(dim=-1) 105 | pos_weight = torch.nn.Sigmoid()(edge_logit) 106 | # neg_weight = 1 - pos_weight 107 | neg_weight = torch.ones_like(pos_weight) 108 | edge_weight = torch.cat([neg_weight.unsqueeze(1), pos_weight.unsqueeze(1)], dim=1) 109 | if self.training: self.decode_edge_weight = pos_weight 110 | hidden = z[edge_label_index[0]] * z[edge_label_index[1]] 111 | logits = (hidden).sum(dim=-1) 112 | hidden = F.normalize(hidden, dim=1) 113 | return hidden, logits, edge_weight 114 | 115 | class GAT(torch.nn.Module): 116 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, heads=8, att_dropout=0): 117 | super().__init__() 118 | self.convs = ModuleList() 119 | self.convs.append(GATConv(in_channels, hidden_channels//heads, heads=heads, dropout=att_dropout)) 120 | for i in range(0, num_layers-2): 121 | self.convs.append(GATConv(hidden_channels, hidden_channels//heads, heads=heads, dropout=att_dropout)) 122 | self.convs.append(GATConv(hidden_channels, out_channels, dropout=att_dropout)) 123 | 124 | def encode(self, x, edge_index): 125 | # get edge logits from original X 126 | z = x.clone() 127 | for conv in self.convs[:-1]: 128 | z = conv(z, edge_index).relu() 129 | self.tmp_z = self.convs[-1](z, edge_index) 130 | edge_logit = (self.tmp_z[edge_index[0]] * self.tmp_z[edge_index[1]]).sum(dim=-1) 131 | edge_weight = torch.nn.Sigmoid()(edge_logit) 132 | if self.training: self.encode_edge_weight = edge_weight 133 | # edge sampling 134 | sampled_index = (edge_weight > SAMPLING_RATIO * torch.rand_like(edge_weight)).detach() 135 | new_edge_index = edge_index[:, sampled_index] 136 | # forward with sampled edges 137 | for conv in self.convs[:-1]: 138 | x = conv(x, new_edge_index).relu() 139 | x = self.convs[-1](x, edge_index) 140 | return x 141 | 142 | def decode(self, x, z, edge_label_index): 143 | edge_logit = (self.tmp_z[edge_label_index[0]] * self.tmp_z[edge_label_index[1]]).sum(dim=-1) 144 | pos_weight = torch.nn.Sigmoid()(edge_logit) 145 | neg_weight = torch.ones_like(pos_weight) 146 | edge_weight = torch.cat([neg_weight.unsqueeze(1), pos_weight.unsqueeze(1)], dim=1) 147 | if self.training: self.decode_edge_weight = pos_weight 148 | hidden = z[edge_label_index[0]] * z[edge_label_index[1]] 149 | logits = (hidden).sum(dim=-1) 150 | hidden = F.normalize(hidden, dim=1) 151 | return hidden, logits, edge_weight 152 | 153 | def getGNNArch_tmp(GNN_name): 154 | assert GNN_name in ['GCN', 'GAT', 'SAGE'] 155 | if GNN_name == 'GCN': 156 | return GCN 157 | elif GNN_name == 'GAT': 158 | return GAT 159 | elif GNN_name == 'SAGE': 160 | return SAGE 161 | 162 | def standard_train_trial(rel_path, dataset_name, noise_type, noise_ratio, model_name, num_gnn_layers, device, repeat_times, verbose=True): 163 | path, dataset = getDataset(dataset_name, device, rel_path) 164 | Net = getGNNArch_tmp(model_name) 165 | test_auc_list, val_auc_list, best_epoch_list = [], [], [] 166 | 167 | MAX_EPOCH = 1000 168 | if verbose: print(f'==> schedule={args.scheduler}, param={args.scheduler_param}') 169 | assert args.scheduler in ['linear', 'exp', 'sin', 'cos', 'constant'] 170 | if args.scheduler == 'linear': 171 | lamb_scheduler = np.linspace(0, 1, MAX_EPOCH) * args.scheduler_param 172 | elif args.scheduler == 'exp': 173 | lamb_scheduler = np.array([math.exp(-t/MAX_EPOCH) for t in range(MAX_EPOCH)]) * args.scheduler_param 174 | elif args.scheduler == 'sin': 175 | lamb_scheduler = np.array([math.sin(t/MAX_EPOCH * math.pi * 0.5) for t in 
range(MAX_EPOCH)]) * args.scheduler_param 176 | elif args.scheduler == 'cos': 177 | lamb_scheduler = np.array([math.cos(t/MAX_EPOCH * math.pi * 0.5) for t in range(MAX_EPOCH)]) * args.scheduler_param 178 | elif args.scheduler == 'constant': 179 | lamb_scheduler = np.array([args.scheduler_param] * MAX_EPOCH) 180 | 181 | for idx in tqdm(range(repeat_times), ncols=50, leave=False): 182 | savePath = f'{path}/{dataset_name}/processed/{noise_type}_noise_ratio_{noise_ratio}_repeat_{idx+1}.pt' 183 | data = torch.load(savePath) 184 | (train_data, val_data, test_data) = data 185 | model = Net(dataset.num_features, 128, 64, num_gnn_layers).to(device) 186 | optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01) 187 | criterion = torch.nn.BCEWithLogitsLoss(reduction='none') 188 | 189 | def train(epoch): 190 | model.train() 191 | optimizer.zero_grad() 192 | z = model.encode(train_data.x, train_data.edge_index) 193 | 194 | # a new round of negative sampling for every training epoch 195 | neg_edge_index = negative_sampling( 196 | edge_index=train_data.edge_index, num_nodes=train_data.num_nodes, 197 | num_neg_samples=train_data.edge_label_index.size(1), method='sparse') 198 | edge_label_index = torch.cat( 199 | [train_data.edge_label_index, neg_edge_index], 200 | dim=-1, 201 | ) 202 | edge_label = torch.cat([ 203 | train_data.edge_label, 204 | train_data.edge_label.new_zeros(neg_edge_index.size(1)) 205 | ], dim=0) 206 | 207 | sample_idx = torch.arange(len(edge_label)) 208 | hidden, out, weight = model.decode(train_data.x, z, edge_label_index) 209 | out = out.view(-1) 210 | 211 | # loss1: supervised loss 212 | tmp_loss = criterion(out, edge_label) 213 | sample_weigtht = weight[sample_idx, edge_label.long()] 214 | sample_weigtht = sample_weigtht.detach() 215 | if args.gnn_model in ['GCN', 'SAGE']: 216 | sup_loss = torch.mean(tmp_loss * sample_weigtht) 217 | elif args.gnn_model == 'GAT': 218 | sup_loss = torch.mean(tmp_loss) 219 | 220 | # loss2: information regularizer 221 | regu_A = sampling_MI(model.encode_edge_weight, reduction='mean') 222 | regu_Y = sampling_MI(model.decode_edge_weight, reduction='mean') 223 | 224 | # final objective of RGIB-REP 225 | lamb = lamb_scheduler[epoch] 226 | loss = lamb * sup_loss + (1-lamb) * regu_A + (1-lamb) * regu_Y 227 | 228 | loss.backward() 229 | optimizer.step() 230 | return 231 | 232 | @torch.no_grad() 233 | def test(data): 234 | model.eval() 235 | z = model.encode(data.x, data.edge_index) 236 | hidden, out, weight = model.decode(data.x, z, data.edge_label_index) 237 | out = out.view(-1).sigmoid() 238 | return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()) 239 | 240 | best_val_auc = best_test_auc = 0 241 | best_epoch = 0 242 | time_list = [] 243 | for epoch in range(MAX_EPOCH): 244 | train(epoch) 245 | val_auc = test(val_data) 246 | test_auc = test(test_data) 247 | 248 | if val_auc > best_val_auc: 249 | best_epoch = epoch 250 | best_val_auc = val_auc 251 | best_test_auc = test_auc 252 | 253 | # print(best_test_auc, best_val_auc, best_epoch) 254 | test_auc_list.append(best_test_auc) 255 | val_auc_list.append(best_val_auc) 256 | best_epoch_list.append(best_epoch) 257 | 258 | # verbose the training results 259 | if verbose: 260 | print(f'==> data={dataset_name}, type={noise_type}, ratio={noise_ratio}, model={model_name}, num_gnn_layers={num_gnn_layers}, repeat_time={repeat_times}') 261 | print(f'==> VAL: mean={np.mean(val_auc_list)}, std={np.std(val_auc_list)}, max={np.max(val_auc_list)}, min={np.min(val_auc_list)}') 262 | print(f'==> TEST: 
mean={np.mean(test_auc_list)}, std={np.std(test_auc_list)}, max={np.max(test_auc_list)}, min={np.min(test_auc_list)}') 263 | print('*'*50) 264 | 265 | return np.mean(val_auc_list) 266 | 267 | ### Parse args ### 268 | parser = argparse.ArgumentParser() 269 | parser_add_main_args(parser) 270 | args = parser.parse_args() 271 | print(args) 272 | print('*'*50) 273 | 274 | seed = 1 275 | torch.manual_seed(seed) 276 | random.seed(seed) 277 | np.random.seed(seed) 278 | torch.cuda.manual_seed(seed) 279 | torch.cuda.manual_seed_all(seed) 280 | torch.backends.cudnn.deterministic =True 281 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 282 | 283 | if args.search_scheduler: 284 | # mutiple trials to search the optimal scheduler 285 | assert args.search_iteration > 0 286 | candidate_schduler = ['linear', 'exp', 'sin', 'cos', 'constant'] 287 | schedule_search_space = { 288 | 'scheduler': hp.choice('scheduler', candidate_schduler), 289 | 'scheduler_param': hp.uniform('scheduler_param', 0, 1), 290 | } 291 | 292 | def run_model(params): 293 | args.scheduler = params['scheduler'] 294 | args.scheduler_param = params['scheduler_param'] 295 | mean_val_auc = standard_train_trial(args.rel_path, args.dataset, args.noise_type, args.noise_ratio, args.gnn_model, args.num_gnn_layers, device, args.repeat_times) 296 | return (1-mean_val_auc) 297 | 298 | # searching 299 | best = fmin(run_model, schedule_search_space, algo=tpe.suggest, max_evals=int(args.search_iteration), verbose=False) 300 | print(f'==> optimal scheduler:{best}') 301 | 302 | # final trial 303 | args.scheduler = candidate_schduler[best['scheduler']] 304 | args.scheduler_param = best['scheduler_param'] 305 | standard_train_trial(args.rel_path, args.dataset, args.noise_type, args.noise_ratio, args.gnn_model, args.num_gnn_layers, device, args.repeat_times) 306 | 307 | else: 308 | # single trial 309 | standard_train_trial(args.rel_path, args.dataset, args.noise_type, args.noise_ratio, args.gnn_model, args.num_gnn_layers, device, args.repeat_times) 310 | --------------------------------------------------------------------------------