├── Accuracy.txt ├── README.md ├── pGRACE ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── dataset.cpython-37.pyc │ ├── eval.cpython-37.pyc │ ├── functional.cpython-37.pyc │ ├── model.cpython-37.pyc │ └── utils.cpython-37.pyc ├── dataset.py ├── eval.py ├── functional.py ├── model.py └── utils.py ├── param ├── amazon_computers.json ├── amazon_photo.json ├── coauthor_cs.json └── wikics.json ├── simple_param ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── sp.cpython-37.pyc └── sp.py └── train.py /Accuracy.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/Accuracy.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProGCL: Rethinking Hard Negative Mining in Graph Contrastive Learning (ICML 2022) 2 | PyTorch implementation for [ProGCL: Rethinking Hard Negative Mining in Graph Contrastive Learning](https://arxiv.org/abs/2110.02027) accepted by ICML 2022. 3 | ## Requirements 4 | * Python 3.7.4 5 | * PyTorch 1.7.0 6 | * torch_geometric 1.5.0 7 | * tqdm 8 | ## Training & Evaluation 9 | ProGCL-weight: 10 | ``` 11 | python train.py --device cuda:0 --dataset Amazon-Computers --param local:amazon-computers.json --mode weight 12 | ``` 13 | ProGCL-mix: 14 | ``` 15 | python train.py --device cuda:0 --dataset Amazon-Computers --param local:amazon-computers.json --mode mix 16 | ``` 17 | ## Useful resources for Pretrained Graphs Neural Networks 18 | * The first comprehensive survey on this topic: [A Survey of Pretraining on Graphs: Taxonomy, Methods, and Applications](https://arxiv.org/abs/2202.07893v1) 19 | * [A curated list of must-read papers, open-source pretrained models and pretraining datasets.](https://github.com/junxia97/awesome-pretrain-on-graphs) 20 | 21 | ## Citation 22 | ``` 23 | @inproceedings{xia2022progcl, 24 | title={ProGCL: Rethinking Hard Negative Mining in Graph Contrastive Learning}, 25 | author={Xia, Jun and Wu, Lirong and Wang, Ge and Li, Stan Z.}, 26 | booktitle={International conference on machine learning}, 27 | year={2022}, 28 | organization={PMLR} 29 | } 30 | ``` 31 | -------------------------------------------------------------------------------- /pGRACE/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/pGRACE/__init__.py -------------------------------------------------------------------------------- /pGRACE/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/pGRACE/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/pGRACE/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/eval.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/pGRACE/__pycache__/eval.cpython-37.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/pGRACE/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/pGRACE/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/pGRACE/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /pGRACE/dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from torch_geometric.datasets import Planetoid, CitationFull, WikiCS, Coauthor, Amazon 3 | import torch_geometric.transforms as T 4 | from ogb.nodeproppred import PygNodePropPredDataset 5 | 6 | def get_dataset(path, name): 7 | assert name in ['Cora', 'CiteSeer', 'PubMed', 'DBLP', 'Karate', 'WikiCS', 'Coauthor-CS', 'Coauthor-Phy', 8 | 'Amazon-Computers', 'Amazon-Photo', 'ogbn-arxiv', 'ogbg-code'] 9 | name = 'dblp' if name == 'DBLP' else name 10 | root_path = osp.expanduser('~/datasets') 11 | 12 | if name == 'Coauthor-CS': 13 | return Coauthor(root=path, name='cs', transform=T.NormalizeFeatures()) 14 | 15 | if name == 'WikiCS': 16 | return WikiCS(root=path, transform=T.NormalizeFeatures()) 17 | 18 | if name == 'Amazon-Computers': 19 | return Amazon(root=path, name='computers', transform=T.NormalizeFeatures()) 20 | 21 | if name == 'Amazon-Photo': 22 | return Amazon(root=path, name='photo', transform=T.NormalizeFeatures()) 23 | 24 | if name.startswith('ogbn'): 25 | return PygNodePropPredDataset(root=osp.join(root_path, 'OGB'), name=name, transform=T.NormalizeFeatures()) 26 | 27 | return (CitationFull if name == 'dblp' else Planetoid)(osp.join(root_path, 'Citation'), name, transform=T.NormalizeFeatures()) 28 | 29 | 30 | def get_path(base_path, name): 31 | if name in ['Cora', 'CiteSeer', 'PubMed']: 32 | return base_path 33 | else: 34 | return osp.join(base_path, name) 35 | -------------------------------------------------------------------------------- /pGRACE/eval.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch.optim import Adam 5 | import torch.nn as nn 6 | 7 | from pGRACE.model import LogReg 8 | 9 | 10 | def get_idx_split(dataset, split, preload_split): 11 | if split[:4] == 'rand': 12 | train_ratio = float(split.split(':')[1]) 13 | num_nodes = dataset[0].x.size(0) 14 | train_size = int(num_nodes * train_ratio) 15 | indices = torch.randperm(num_nodes) 16 | return { 17 | 'train': indices[:train_size], 18 | 'val': indices[train_size:2 * train_size], 19 | 'test': indices[2 * train_size:] 20 | } 21 | elif split == 'ogb': 22 | return dataset.get_idx_split() 23 | 
elif split.startswith('wikics'): 24 | split_idx = int(split.split(':')[1]) 25 | return { 26 | 'train': dataset[0].train_mask[:, split_idx], 27 | 'test': dataset[0].test_mask, 28 | 'val': dataset[0].val_mask[:, split_idx] 29 | } 30 | elif split == 'preloaded': 31 | assert preload_split is not None, 'use preloaded split, but preloaded_split is None' 32 | train_mask, test_mask, val_mask = preload_split 33 | return { 34 | 'train': train_mask, 35 | 'test': test_mask, 36 | 'val': val_mask 37 | } 38 | else: 39 | raise RuntimeError(f'Unknown split type {split}') 40 | 41 | 42 | def log_regression(z, 43 | dataset, 44 | evaluator, 45 | num_epochs: int = 5000, 46 | test_device: Optional[str] = None, 47 | split: str = 'rand:0.1', 48 | verbose: bool = False, 49 | preload_split=None): 50 | test_device = z.device if test_device is None else test_device 51 | z = z.detach().to(test_device) 52 | num_hidden = z.size(1) 53 | y = dataset[0].y.view(-1).to(test_device) 54 | num_classes = dataset[0].y.max().item() + 1 55 | classifier = LogReg(num_hidden, num_classes).to(test_device) 56 | optimizer = Adam(classifier.parameters(), lr=0.01, weight_decay=0.0) 57 | 58 | split = get_idx_split(dataset, split, preload_split) 59 | split = {k: v.to(test_device) for k, v in split.items()} 60 | print(split['train'].sum()) 61 | print(split['test'].sum()) 62 | print(split['val'].sum()) 63 | f = nn.LogSoftmax(dim=-1) 64 | nll_loss = nn.NLLLoss() 65 | 66 | best_test_acc = 0 67 | best_val_acc = 0 68 | best_epoch = 0 69 | 70 | for epoch in range(num_epochs): 71 | classifier.train() 72 | optimizer.zero_grad() 73 | 74 | output = classifier(z[split['train']]) 75 | loss = nll_loss(f(output), y[split['train']]) 76 | 77 | loss.backward() 78 | optimizer.step() 79 | 80 | if (epoch + 1) % 20 == 0: 81 | if 'val' in split: 82 | # val split is available 83 | test_acc = evaluator.eval({ 84 | 'y_true': y[split['test']].view(-1, 1), 85 | 'y_pred': classifier(z[split['test']]).argmax(-1).view(-1, 1) 86 | })['acc'] 87 | val_acc = evaluator.eval({ 88 | 'y_true': y[split['val']].view(-1, 1), 89 | 'y_pred': classifier(z[split['val']]).argmax(-1).view(-1, 1) 90 | })['acc'] 91 | if val_acc > best_val_acc: 92 | best_val_acc = val_acc 93 | best_test_acc = test_acc 94 | best_epoch = epoch 95 | else: 96 | acc = evaluator.eval({ 97 | 'y_true': y[split['test']].view(-1, 1), 98 | 'y_pred': classifier(z[split['test']]).argmax(-1).view(-1, 1) 99 | })['acc'] 100 | if best_test_acc < acc: 101 | best_test_acc = acc 102 | best_epoch = epoch 103 | if verbose: 104 | print(f'logreg epoch {epoch}: best test acc {best_test_acc}') 105 | 106 | return {'acc': best_test_acc} 107 | 108 | 109 | class MulticlassEvaluator: 110 | def __init__(self, *args, **kwargs): 111 | pass 112 | 113 | @staticmethod 114 | def _eval(y_true, y_pred): 115 | y_true = y_true.view(-1) 116 | y_pred = y_pred.view(-1) 117 | total = y_true.size(0) 118 | correct = (y_true == y_pred).to(torch.float32).sum() 119 | return (correct / total).item() 120 | 121 | def eval(self, res): 122 | return {'acc': self._eval(**res)} 123 | -------------------------------------------------------------------------------- /pGRACE/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import degree, to_undirected 3 | from pGRACE.utils import compute_pr, eigenvector_centrality 4 | 5 | 6 | def drop_feature(x, drop_prob): 7 | drop_mask = torch.empty((x.size(1),), dtype=torch.float32, device=x.device).uniform_(0, 1) < drop_prob 8 | x = 
x.clone() 9 | x[:, drop_mask] = 0 10 | return x 11 | 12 | 13 | def drop_feature_weighted(x, w, p: float, threshold: float = 0.7): 14 | w = w / w.mean() * p 15 | w = w.where(w < threshold, torch.ones_like(w) * threshold) 16 | drop_prob = w.repeat(x.size(0)).view(x.size(0), -1) 17 | 18 | drop_mask = torch.bernoulli(drop_prob).to(torch.bool) 19 | 20 | x = x.clone() 21 | x[drop_mask] = 0. 22 | 23 | return x 24 | 25 | 26 | def drop_feature_weighted_2(x, w, p: float, threshold: float = 0.7): 27 | w = w / w.mean() * p 28 | w = w.where(w < threshold, torch.ones_like(w) * threshold) 29 | drop_prob = w 30 | 31 | drop_mask = torch.bernoulli(drop_prob).to(torch.bool) 32 | 33 | x = x.clone() 34 | x[:, drop_mask] = 0. 35 | 36 | return x 37 | 38 | 39 | def feature_drop_weights(x, node_c): 40 | x = x.to(torch.bool).to(torch.float32) 41 | w = x.t() @ node_c 42 | w = w.log() 43 | s = (w.max() - w) / (w.max() - w.mean()) 44 | 45 | return s 46 | 47 | 48 | def feature_drop_weights_dense(x, node_c): 49 | x = x.abs() 50 | w = x.t() @ node_c 51 | w = w.log() 52 | s = (w.max() - w) / (w.max() - w.mean()) 53 | 54 | return s 55 | 56 | 57 | def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.): 58 | edge_weights = edge_weights / edge_weights.mean() * p 59 | edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold) 60 | sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool) 61 | 62 | return edge_index[:, sel_mask] 63 | 64 | 65 | def degree_drop_weights(edge_index): 66 | edge_index_ = to_undirected(edge_index) 67 | deg = degree(edge_index_[1]) 68 | deg_col = deg[edge_index[1]].to(torch.float32) 69 | s_col = torch.log(deg_col) 70 | weights = (s_col.max() - s_col) / (s_col.max() - s_col.mean()) 71 | 72 | return weights 73 | 74 | 75 | def pr_drop_weights(edge_index, aggr: str = 'sink', k: int = 10): 76 | pv = compute_pr(edge_index, k=k) 77 | pv_row = pv[edge_index[0]].to(torch.float32) 78 | pv_col = pv[edge_index[1]].to(torch.float32) 79 | s_row = torch.log(pv_row) 80 | s_col = torch.log(pv_col) 81 | if aggr == 'sink': 82 | s = s_col 83 | elif aggr == 'source': 84 | s = s_row 85 | elif aggr == 'mean': 86 | s = (s_col + s_row) * 0.5 87 | else: 88 | s = s_col 89 | weights = (s.max() - s) / (s.max() - s.mean()) 90 | 91 | return weights 92 | 93 | 94 | def evc_drop_weights(data): 95 | evc = eigenvector_centrality(data) 96 | evc = evc.where(evc > 0, torch.zeros_like(evc)) 97 | evc = evc + 1e-8 98 | s = evc.log() 99 | 100 | edge_index = data.edge_index 101 | s_row, s_col = s[edge_index[0]], s[edge_index[1]] 102 | s = s_col 103 | 104 | return (s.max() - s) / (s.max() - s.mean()) -------------------------------------------------------------------------------- /pGRACE/model.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | from typing import Optional 3 | import matplotlib 4 | from torch._C import device 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from torch_geometric.nn import GCNConv 10 | import scipy.stats as stats 11 | from torch.distributions.beta import Beta 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, in_channels: int, out_channels: int, activation, base_model=GCNConv, k: int = 2, skip=False): 15 | super(Encoder, self).__init__() 16 | self.base_model = base_model 17 | assert k >= 2 18 | self.k = k 19 | self.skip = skip 20 | if not self.skip: 21 | # self.conv = [base_model(in_channels, 2 * 
out_channels).jittable()] 22 | self.conv = [base_model(in_channels, 2 * out_channels)] 23 | for _ in range(1, k - 1): 24 | self.conv.append(base_model(2 * out_channels, 2 * out_channels)) 25 | self.conv.append(base_model(2 * out_channels, out_channels)) 26 | self.conv = nn.ModuleList(self.conv) 27 | 28 | self.activation = activation 29 | else: 30 | self.fc_skip = nn.Linear(in_channels, out_channels) 31 | self.conv = [base_model(in_channels, out_channels)] 32 | for _ in range(1, k): 33 | self.conv.append(base_model(out_channels, out_channels)) 34 | self.conv = nn.ModuleList(self.conv) 35 | 36 | self.activation = activation 37 | 38 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor): 39 | if not self.skip: 40 | for i in range(self.k): 41 | x = self.activation(self.conv[i](x, edge_index)) 42 | return x 43 | else: 44 | h = self.activation(self.conv[0](x, edge_index)) 45 | hs = [self.fc_skip(x), h] 46 | for i in range(1, self.k): 47 | u = sum(hs) 48 | hs.append(self.activation(self.conv[i](u, edge_index))) 49 | return hs[-1] 50 | 51 | class GRACE(torch.nn.Module): 52 | def __init__(self, encoder: Encoder, num_hidden: int, num_proj_hidden: int, tau: float = 0.5): 53 | super(GRACE, self).__init__() 54 | self.encoder: Encoder = encoder 55 | self.tau: float = tau 56 | self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden) 57 | self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden) 58 | self.num_hidden = num_hidden 59 | 60 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: 61 | return self.encoder(x, edge_index) 62 | 63 | def projection(self, z: torch.Tensor) -> torch.Tensor: 64 | z = F.elu(self.fc1(z)) 65 | return self.fc2(z) 66 | 67 | def sim(self, z1: torch.Tensor, z2: torch.Tensor): 68 | z1 = F.normalize(z1) 69 | z2 = F.normalize(z2) 70 | return torch.mm(z1, z2.t()) 71 | 72 | def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor, epoch): 73 | f = lambda x: torch.exp(x / self.tau) 74 | refl_sim = self.sim(z1, z1) 75 | between_sim = self.sim(z1, z2) 76 | refl_sim = f(refl_sim) 77 | between_sim = f(between_sim) 78 | return -torch.log(between_sim.diag() / (between_sim.sum(1) + refl_sim.sum(1) - refl_sim.diag())) 79 | 80 | def semi_loss_bmm(self, z1: torch.Tensor, z2: torch.Tensor, epoch, args, bmm_model, fit = False): 81 | f = lambda x: torch.exp(x / self.tau) 82 | refl_sim = self.sim(z1, z1) 83 | between_sim = self.sim(z1, z2) 84 | N = between_sim.size(0) 85 | mask = torch.ones((N,N),dtype=bool).to(z1.device) 86 | mask[np.eye(N,dtype=bool)] = False 87 | if epoch == args.epoch_start and fit: 88 | global B 89 | N_sel = 100 90 | index_fit = np.random.randint(0, N, N_sel) 91 | sim_fit = between_sim[:,index_fit] 92 | sim_fit = (sim_fit + 1) / 2 # Min-Max Normalization 93 | bmm_model.fit(sim_fit.flatten()) 94 | between_sim_norm = between_sim.masked_select(mask).view(N, -1) 95 | between_sim_norm = (between_sim_norm + 1) / 2 96 | print('Computing positive probility,wait...') 97 | B = bmm_model.posterior(between_sim_norm,0) * between_sim_norm.detach() 98 | print('Over!') 99 | if args.mode == 'weight': 100 | refl_sim = f(refl_sim) 101 | between_sim = f(between_sim) 102 | ng_bet = (between_sim.masked_select(mask).view(N,-1) * B).sum(1) / B.mean(1) 103 | ng_refl = (refl_sim.masked_select(mask).view(N,-1) * B).sum(1) / B.mean(1) 104 | return -torch.log(between_sim.diag()/(between_sim.diag() + ng_bet + ng_refl)) 105 | elif args.mode == 'mix': 106 | eps = 1e-12 107 | sorted, indices = torch.sort(B, descending=True) 108 | N_sel = torch.gather(between_sim[mask].view(N,-1), 
-1, indices)[:,:args.sel_num] 109 | random_index = np.random.permutation(np.arange(args.sel_num)) 110 | N_random = N_sel[:,random_index] 111 | M = sorted[:,:args.sel_num] 112 | M_random = M[:,random_index] 113 | M = (N_sel * M + N_random * M_random) / (M + M_random + eps) 114 | refl_sim = f(refl_sim) 115 | between_sim = f(between_sim) 116 | M = f(M) 117 | return -torch.log(between_sim.diag()/(M.sum(1) + between_sim.sum(1) + refl_sim.sum(1) - refl_sim.diag())) 118 | else: 119 | print('Mode Error!') 120 | 121 | def batched_semi_loss(self, z1: torch.Tensor, z2: torch.Tensor, batch_size: int, epoch): 122 | # Space complexity: O(BN) (semi_loss: O(N^2)) 123 | device = z1.device 124 | num_nodes = z1.size(0) 125 | num_batches = (num_nodes - 1) // batch_size + 1 126 | f = lambda x: torch.exp(x / self.tau) 127 | indices = np.arange(0, num_nodes) 128 | losses = [] 129 | 130 | for i in range(num_batches): 131 | mask = indices[i * batch_size:(i + 1) * batch_size] 132 | refl_sim = self.sim(z1[mask], z1) # [B, N] 133 | between_sim = self.sim(z1[mask], z2) # [B, N] 134 | refl_sim = f(refl_sim) 135 | between_sim = f(refl_sim) 136 | losses.append(-torch.log(between_sim[:, i * batch_size:(i + 1) * batch_size].diag() 137 | / (refl_sim.sum(1) + between_sim.sum(1) 138 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag()))) 139 | 140 | return torch.cat(losses) 141 | 142 | def batched_semi_loss_bmm(self, z1: torch.Tensor, z2: torch.Tensor, batch_size: int, epoch, args, bmm_model, fit): 143 | device = z1.device 144 | num_nodes = z1.size(0) 145 | num_batches = (num_nodes - 1) // batch_size + 1 146 | f = lambda x: torch.exp(x / self.tau) 147 | indices = torch.arange(0, num_nodes).to(device) 148 | losses = [] 149 | global B 150 | B = [] 151 | for i in range(num_batches): 152 | index = indices[i * batch_size:(i + 1) * batch_size] 153 | neg_mask = torch.ones((batch_size, num_nodes),dtype=bool).to(device) 154 | pos_index = np.transpose(np.column_stack((np.arange(0,batch_size,1),np.arange(i*batch_size, (i + 1) * batch_size,1)))) 155 | neg_mask[pos_index] = False 156 | refl_sim = self.sim(z1[index], z1) 157 | between_sim = self.sim(z1[index], z2) 158 | if epoch == args.epoch_start and fit: 159 | N_sel = 100 160 | index_fit = np.random.randint(0, num_nodes, N_sel) 161 | sim_fit = between_sim[:,index_fit] 162 | sim_fit = (sim_fit - sim_fit.min()) / (sim_fit.max() - sim_fit.min()) 163 | bmm_model.fit(sim_fit.flatten()) 164 | between_sim_norm = between_sim.masked_select(neg_mask).view(batch_size,-1) 165 | between_sim_norm = (between_sim_norm - between_sim_norm.min()) / (between_sim_norm.max() - between_sim_norm.min()) 166 | print('Computing positive probility,wait...') 167 | B.append(bmm_model.posterior(between_sim_norm,0) * between_sim_norm.detach()) 168 | print('Over!') 169 | if args.mode == 'weight': 170 | refl_sim = f(refl_sim) 171 | between_sim = f(between_sim) 172 | ng_bet = (between_sim.masked_select(neg_mask).view(neg_mask.size(0),-1) * B[i]).sum(1) / B[i].mean(1) 173 | ng_refl = (refl_sim.masked_select(neg_mask).view(neg_mask.size(0),-1) * B[i]).sum(1) / B[i].mean(1) 174 | losses.append(-torch.log(between_sim.diag()/(between_sim.diag() + ng_bet + ng_refl))) 175 | return torch.cat(losses) 176 | elif args.mode == 'mix': 177 | eps = 1e-12 178 | B_sel, indices = torch.sort(B[i],descending=True) 179 | N_sel = torch.gather(between_sim, -1, indices) 180 | random_index = np.random.permutation(np.arange(N_sel.size(1))) 181 | N_sel_random = N_sel[:,random_index] 182 | B_sel_random = B_sel[:,random_index] 183 | M = 
(B_sel * N_sel + B_sel_random * N_sel_random) / (B_sel + B_sel_random + eps) 184 | refl_sim = f(refl_sim) 185 | between_sim = f(between_sim) 186 | M = f(M) 187 | losses.append(-torch.log(between_sim.diag()/(M.sum(1) + between_sim.sum(1) + refl_sim.sum(1) - refl_sim.diag()))) 188 | return torch.cat(losses) 189 | else: 190 | print('Mode Error!') 191 | return torch.cat(losses) 192 | 193 | def loss(self, z1: torch.Tensor, z2: torch.Tensor, epoch, args, bmm_model,mean: bool = True, batch_size: Optional[int] = None): 194 | 195 | h1 = self.projection(z1) 196 | h2 = self.projection(z2) 197 | if epoch < args.epoch_start: 198 | if batch_size is None: 199 | l1 = self.semi_loss(h1, h2, epoch) 200 | l2 = self.semi_loss(h2, h1, epoch) 201 | else: 202 | l1 = self.batched_semi_loss(h1, h2, batch_size, epoch) 203 | l2 = self.batched_semi_loss(h2, h1, batch_size, epoch) 204 | ret = (l1 + l2) * 0.5 205 | ret = ret.mean() if mean else ret.sum() 206 | else: 207 | if batch_size is None: 208 | l1 = self.semi_loss_bmm(h1, h2, epoch, args, bmm_model, fit = True) 209 | l2 = self.semi_loss_bmm(h2, h1, epoch, args, bmm_model) 210 | else: 211 | l1 = self.batched_semi_loss_bmm(h1, h2, batch_size, epoch, args, bmm_model, fit = True) 212 | l2 = self.batched_semi_loss_bmm(h2, h1, batch_size, epoch, args, bmm_model) 213 | ret = (l1 + l2) * 0.5 214 | ret = ret.mean() if mean else ret.sum() 215 | return ret 216 | 217 | class LogReg(nn.Module): 218 | def __init__(self, ft_in, nb_classes): 219 | super(LogReg, self).__init__() 220 | self.fc = nn.Linear(ft_in, nb_classes) 221 | 222 | for m in self.modules(): 223 | self.weights_init(m) 224 | 225 | def weights_init(self, m): 226 | if isinstance(m, nn.Linear): 227 | torch.nn.init.xavier_uniform_(m.weight.data) 228 | if m.bias is not None: 229 | m.bias.data.fill_(0.0) 230 | 231 | def forward(self, seq): 232 | ret = self.fc(seq) 233 | return ret 234 | 235 | def weighted_mean(x, w): 236 | return torch.sum(w * x) / torch.sum(w) 237 | 238 | def fit_beta_weighted(x, w): 239 | x_bar = weighted_mean(x, w) 240 | s2 = weighted_mean((x - x_bar)**2, w) 241 | alpha = x_bar * ((x_bar * (1 - x_bar)) / s2 - 1) 242 | beta = alpha * (1 - x_bar) /x_bar 243 | return alpha, beta 244 | 245 | class BetaMixture1D(object): 246 | def __init__(self, max_iters, 247 | alphas_init, 248 | betas_init, 249 | weights_init): 250 | self.alphas = alphas_init 251 | self.betas = betas_init 252 | self.weight = weights_init 253 | self.max_iters = max_iters 254 | self.eps_nan = 1e-12 255 | 256 | def likelihood(self, x, y): 257 | x_cpu = x.cpu().detach().numpy() 258 | alpha_cpu = self.alphas.cpu().detach().numpy() 259 | beta_cpu = self.betas.cpu().detach().numpy() 260 | return torch.from_numpy(stats.beta.pdf(x_cpu, alpha_cpu[y], beta_cpu[y])).to(x.device) 261 | 262 | def weighted_likelihood(self, x, y): 263 | return self.weight[y] * self.likelihood(x, y) 264 | 265 | def probability(self, x): 266 | return self.weighted_likelihood(x, 0) + self.weighted_likelihood(x, 1) 267 | 268 | def posterior(self, x, y): 269 | return self.weighted_likelihood(x, y) / (self.probability(x) + self.eps_nan) 270 | 271 | def responsibilities(self, x): 272 | r = torch.cat((self.weighted_likelihood(x, 0).view(1,-1),self.weighted_likelihood(x, 1).view(1,-1)),0) 273 | r[r <= self.eps_nan] = self.eps_nan 274 | r /= r.sum(0) 275 | return r 276 | 277 | def fit(self, x): 278 | eps = 1e-12 279 | x[x >= 1 - eps] = 1 - eps 280 | x[x <= eps] = eps 281 | 282 | for i in range(self.max_iters): 283 | # E-step 284 | r = self.responsibilities(x) 285 | # M-step 286 
| self.alphas[0], self.betas[0] = fit_beta_weighted(x, r[0]) 287 | self.alphas[1], self.betas[1] = fit_beta_weighted(x, r[1]) 288 | if self.betas[1] < 1: 289 | self.betas[1] = 1.01 290 | self.weight = r.sum(1) 291 | self.weight /= self.weight.sum() 292 | return self 293 | 294 | def predict(self, x): 295 | return self.posterior(x, 1) > 0.5 296 | 297 | def __str__(self): 298 | return 'BetaMixture1D(w={}, a={}, b={})'.format(self.weight, self.alphas, self.betas) 299 | -------------------------------------------------------------------------------- /pGRACE/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import scipy.sparse as sp 4 | from scipy.special import iv 5 | from scipy.sparse.linalg import eigsh 6 | import os.path as osp 7 | from sklearn.cluster import KMeans, SpectralClustering 8 | from sklearn.manifold import SpectralEmbedding 9 | from tqdm import tqdm 10 | from matplotlib import cm 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.distributions.categorical import Categorical 15 | from torch.optim import Adam 16 | from torch.utils.data import random_split 17 | from torch_geometric.nn import GCNConv, SGConv, SAGEConv, GATConv, GraphConv, GINConv 18 | from torch_geometric.utils import sort_edge_index, degree, add_remaining_self_loops, remove_self_loops, get_laplacian, \ 19 | to_undirected, to_dense_adj, to_networkx 20 | from torch_geometric.datasets import KarateClub 21 | from torch_scatter import scatter 22 | import torch_sparse 23 | 24 | import networkx as nx 25 | import matplotlib.pyplot as plt 26 | 27 | 28 | def get_base_model(name: str): 29 | def gat_wrapper(in_channels, out_channels): 30 | return GATConv( 31 | in_channels=in_channels, 32 | out_channels=out_channels // 4, 33 | heads=4 34 | ) 35 | 36 | def gin_wrapper(in_channels, out_channels): 37 | mlp = nn.Sequential( 38 | nn.Linear(in_channels, 2 * out_channels), 39 | nn.ELU(), 40 | nn.Linear(2 * out_channels, out_channels) 41 | ) 42 | return GINConv(mlp) 43 | 44 | base_models = { 45 | 'GCNConv': GCNConv, 46 | 'SGConv': SGConv, 47 | 'SAGEConv': SAGEConv, 48 | 'GATConv': gat_wrapper, 49 | 'GraphConv': GraphConv, 50 | 'GINConv': gin_wrapper 51 | } 52 | 53 | return base_models[name] 54 | 55 | 56 | def get_activation(name: str): 57 | activations = { 58 | 'relu': F.relu, 59 | 'hardtanh': F.hardtanh, 60 | 'elu': F.elu, 61 | 'leakyrelu': F.leaky_relu, 62 | 'prelu': torch.nn.PReLU(), 63 | 'rrelu': F.rrelu 64 | } 65 | 66 | return activations[name] 67 | 68 | 69 | def compute_pr(edge_index, damp: float = 0.85, k: int = 10): 70 | num_nodes = edge_index.max().item() + 1 71 | deg_out = degree(edge_index[0]) 72 | x = torch.ones((num_nodes, )).to(edge_index.device).to(torch.float32) 73 | 74 | for i in range(k): 75 | edge_msg = x[edge_index[0]] / deg_out[edge_index[0]] 76 | agg_msg = scatter(edge_msg, edge_index[1], reduce='sum') 77 | 78 | x = (1 - damp) * x + damp * agg_msg 79 | 80 | return x 81 | 82 | 83 | def eigenvector_centrality(data): 84 | graph = to_networkx(data) 85 | x = nx.eigenvector_centrality_numpy(graph) 86 | x = [x[i] for i in range(data.num_nodes)] 87 | return torch.tensor(x, dtype=torch.float32).to(data.edge_index.device) 88 | 89 | 90 | def generate_split(num_samples: int, train_ratio: float, val_ratio: float): 91 | train_len = int(num_samples * train_ratio) 92 | val_len = int(num_samples * val_ratio) 93 | test_len = num_samples - train_len - val_len 94 | 95 | train_set, test_set, val_set = 
random_split(torch.arange(0, num_samples), (train_len, test_len, val_len)) 96 | 97 | idx_train, idx_test, idx_val = train_set.indices, test_set.indices, val_set.indices 98 | train_mask = torch.zeros((num_samples,)).to(torch.bool) 99 | test_mask = torch.zeros((num_samples,)).to(torch.bool) 100 | val_mask = torch.zeros((num_samples,)).to(torch.bool) 101 | 102 | train_mask[idx_train] = True 103 | test_mask[idx_test] = True 104 | val_mask[idx_val] = True 105 | 106 | return train_mask, test_mask, val_mask 107 | 108 | -------------------------------------------------------------------------------- /param/amazon_computers.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.01, 3 | "num_hidden": 128, 4 | "num_proj_hidden": 128, 5 | "activation": "rrelu", 6 | "drop_edge_rate_1": 0.6, 7 | "drop_edge_rate_2": 0.3, 8 | "drop_feature_rate_1": 0.2, 9 | "drop_feature_rate_2": 0.3, 10 | "tau": 0.2, 11 | "num_epochs": 2000 12 | } -------------------------------------------------------------------------------- /param/amazon_photo.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.01, 3 | "num_hidden": 128, 4 | "num_proj_hidden": 128, 5 | "activation": "rrelu", 6 | "drop_edge_rate_1": 0.6, 7 | "drop_edge_rate_2": 0.3, 8 | "drop_feature_rate_1": 0.2, 9 | "drop_feature_rate_2": 0.3, 10 | "tau": 0.2, 11 | "num_epochs": 2500, 12 | "epoch_start": 400, 13 | "weight_init": 0.15 14 | } -------------------------------------------------------------------------------- /param/coauthor_cs.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.0001, 3 | "num_hidden": 256, 4 | "num_proj_hidden": 256, 5 | "activation": "rrelu", 6 | "drop_edge_rate_1": 0.3, 7 | "drop_edge_rate_2": 0.2, 8 | "drop_feature_rate_1": 0.3, 9 | "drop_feature_rate_2": 0.4, 10 | "tau": 0.2, 11 | "num_epochs": 1000, 12 | "epoch_start": 400, 13 | "weight_init": 0.05 14 | } -------------------------------------------------------------------------------- /param/wikics.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.01, 3 | "num_hidden": 256, 4 | "num_proj_hidden": 256, 5 | "activation": "prelu", 6 | "drop_edge_rate_1": 0.2, 7 | "drop_edge_rate_2": 0.3, 8 | "drop_feature_rate_1": 0.1, 9 | "drop_feature_rate_2": 0.1, 10 | "tau": 0.4, 11 | "num_epochs": 3000 12 | } -------------------------------------------------------------------------------- /simple_param/__init__.py: -------------------------------------------------------------------------------- 1 | import simple_param.sp as sp -------------------------------------------------------------------------------- /simple_param/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/simple_param/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /simple_param/__pycache__/sp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junxia97/ProGCL/3580f640b3e0c73dde6757b9a4196d5dba008ebd/simple_param/__pycache__/sp.cpython-37.pyc -------------------------------------------------------------------------------- /simple_param/sp.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import os.path as osp 3 | import json 4 | import yaml 5 | 6 | class SimpleParam: 7 | def __init__(self, local_dir: str = 'param', default: Optional[dict] = None): 8 | if default is None: 9 | default = dict() 10 | 11 | self.local_dir = local_dir 12 | self.default = default 13 | 14 | def __call__(self, source: str, preprocess: str = 'none'): 15 | if source.startswith('local'): 16 | ts = source.split(':') 17 | assert len(ts) == 2, 'local parameter file should be specified in a form of `local:FILE_NAME`' 18 | path = ts[-1] 19 | path = osp.join(self.local_dir, path) 20 | if path.endswith('.json'): 21 | loaded = parse_json(path) 22 | elif path.endswith('.yaml') or path.endswith('.yml'): 23 | loaded = parse_yaml(path) 24 | else: 25 | raise Exception('Invalid file name. Should end with .yaml or .json.') 26 | 27 | if preprocess == 'nni': 28 | loaded = preprocess_nni(loaded) 29 | 30 | return {**self.default, **loaded} 31 | if source == 'default': 32 | return self.default 33 | 34 | raise Exception('invalid source') 35 | 36 | 37 | def preprocess_nni(params: dict): 38 | def process_key(key: str): 39 | xs = key.split('/') 40 | if len(xs) == 3: 41 | return xs[1] 42 | elif len(xs) == 1: 43 | return key 44 | else: 45 | raise Exception('Unexpected param name ' + key) 46 | 47 | return { 48 | process_key(k): v for k, v in params.items() 49 | } 50 | 51 | 52 | def parse_yaml(path: str): 53 | content = open(path).read() 54 | return yaml.load(content, Loader=yaml.Loader) 55 | 56 | 57 | def parse_json(path: str): 58 | content = open(path).read() 59 | return json.loads(content) 60 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import torch 6 | from torch_geometric.utils import dropout_adj, degree, to_undirected 7 | from simple_param.sp import SimpleParam 8 | from pGRACE.model import Encoder, GRACE, BetaMixture1D 9 | from pGRACE.functional import drop_feature, drop_edge_weighted, \ 10 | degree_drop_weights, \ 11 | evc_drop_weights, pr_drop_weights, \ 12 | feature_drop_weights, drop_feature_weighted_2, feature_drop_weights_dense 13 | from pGRACE.eval import log_regression, MulticlassEvaluator 14 | from pGRACE.utils import get_base_model, get_activation, \ 15 | generate_split, compute_pr, eigenvector_centrality 16 | from pGRACE.dataset import get_dataset 17 | import logging 18 | 19 | LOG_FORMAT = "%(levelname)s - %(message)s" 20 | DATE_FORMAT = "%m/%d/%Y %H:%M:%S %p" 21 | logging.basicConfig(filename='Accuracy.txt',level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT) 22 | 23 | def train(epoch, bmm_model): 24 | model.train() 25 | optimizer.zero_grad() 26 | 27 | def drop_edge(idx: int): 28 | global drop_weights 29 | if param['drop_scheme'] == 'uniform': 30 | return dropout_adj(data.edge_index, p=param[f'drop_edge_rate_{idx}'])[0] 31 | elif param['drop_scheme'] in ['degree', 'evc', 'pr']: 32 | return drop_edge_weighted(data.edge_index, drop_weights, p=param[f'drop_edge_rate_{idx}'], threshold=0.7) 33 | else: 34 | raise Exception(f'undefined drop scheme: {param["drop_scheme"]}') 35 | edge_index_1 = drop_edge(1) 36 | edge_index_2 = drop_edge(2) 37 | 38 | x_1 = drop_feature(data.x, param['drop_feature_rate_1']) 39 | x_2 = drop_feature(data.x, param['drop_feature_rate_2']) 40 | 41 | if 
param['drop_scheme'] in ['pr', 'degree', 'evc']: 42 | x_1 = drop_feature_weighted_2(data.x, feature_weights, param['drop_feature_rate_1']) 43 | x_2 = drop_feature_weighted_2(data.x, feature_weights, param['drop_feature_rate_2']) 44 | 45 | z1 = model(x_1, edge_index_1) 46 | z2 = model(x_2, edge_index_2) 47 | loss = model.loss(z1, z2, epoch, args, bmm_model, batch_size=512 if args.dataset == 'Coauthor-Phy' else None) 48 | loss.backward() 49 | optimizer.step() 50 | return loss.item() 51 | 52 | def test(final=False): 53 | model.eval() 54 | z = model(data.x, data.edge_index) 55 | nclass = dataset[0].y.max().item() + 1 56 | evaluator = MulticlassEvaluator(n_clusters=nclass, random_state=0, n_jobs=8) 57 | if args.dataset == 'WikiCS': 58 | accs = [] 59 | for i in range(20): 60 | acc = log_regression(z, dataset, evaluator, split=f'wikics:{i}', num_epochs=800)['acc'] 61 | accs.append(acc) 62 | acc = sum(accs) / len(accs) 63 | else: 64 | acc = log_regression(z, dataset, evaluator, split='rand:0.1', num_epochs=3000, preload_split=split)['acc'] 65 | return acc 66 | 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('--device', type=str, default='cuda:0') 71 | parser.add_argument('--dataset', type=str, default='Amazon-Computers') 72 | parser.add_argument('--param', type=str, default='local:amazon_computers.json') 73 | parser.add_argument('--seed', type=int, default=39788) 74 | parser.add_argument('--verbose', type=str, default='train,eval,final') 75 | parser.add_argument('--save_split', type=str, nargs='?') 76 | parser.add_argument('--load_split', type=str, nargs='?') 77 | parser.add_argument('--epoch_start', type=int, default=400) 78 | parser.add_argument('--mode', type=str, default='weight') 79 | parser.add_argument('--sel_num', type=int, default=1000) 80 | parser.add_argument('--weight_init', type=float, default=0.05) 81 | parser.add_argument('--iters', type=int, default=10) 82 | default_param = { 83 | 'learning_rate': 0.01, 84 | 'num_hidden': 256, 85 | 'num_proj_hidden': 32, 86 | 'activation': 'prelu', 87 | 'base_model': 'GCNConv', 88 | 'num_layers': 2, 89 | 'drop_edge_rate_1': 0.3, 90 | 'drop_edge_rate_2': 0.4, 91 | 'drop_feature_rate_1': 0.1, 92 | 'drop_feature_rate_2': 0.0, 93 | 'tau': 0.4, 94 | 'num_epochs': 3000, 95 | 'weight_decay': 1e-5, 96 | 'drop_scheme': 'evc', 97 | } 98 | 99 | # add hyper-parameters into parser 100 | param_keys = default_param.keys() 101 | for key in param_keys: 102 | parser.add_argument(f'--{key}', type=type(default_param[key]), nargs='?') 103 | args = parser.parse_args() 104 | 105 | # parse param 106 | sp = SimpleParam(default=default_param) 107 | param = sp(source=args.param, preprocess='nni') 108 | 109 | # merge cli arguments and parsed param 110 | for key in param_keys: 111 | if getattr(args, key) is not None: 112 | param[key] = getattr(args, key) 113 | 114 | use_nni = args.param == 'nni' 115 | if use_nni and args.device != 'cpu': 116 | args.device = 'cuda' 117 | 118 | torch_seed = args.seed 119 | torch.manual_seed(torch_seed) 120 | random.seed(12345) 121 | device = torch.device(args.device) 122 | path = osp.expanduser('~/datasets') 123 | path = osp.join(path, args.dataset) 124 | dataset = get_dataset(path, args.dataset) 125 | y = dataset[0].y.view(-1).numpy() 126 | data = dataset[0] 127 | data = data.to(device) 128 | # generate split 129 | split = generate_split(data.num_nodes, train_ratio=0.1, val_ratio=0.1) 130 | 131 | if args.save_split: 132 | torch.save(split, args.save_split) 133 | elif args.load_split: 134 | 
split = torch.load(args.load_split) 135 | 136 | encoder = Encoder(dataset.num_features, param['num_hidden'], get_activation(param['activation']), 137 | base_model=get_base_model(param['base_model']), k=param['num_layers']).to(device) 138 | model = GRACE(encoder, param['num_hidden'], param['num_proj_hidden'], param['tau']).to(device) 139 | optimizer = torch.optim.Adam( 140 | model.parameters(), 141 | lr=param['learning_rate'], 142 | weight_decay=param['weight_decay'] 143 | ) 144 | 145 | if param['drop_scheme'] == 'degree': 146 | drop_weights = degree_drop_weights(data.edge_index).to(device) 147 | elif param['drop_scheme'] == 'pr': 148 | drop_weights = pr_drop_weights(data.edge_index, aggr='sink', k=200).to(device) 149 | elif param['drop_scheme'] == 'evc': 150 | drop_weights = evc_drop_weights(data).to(device) 151 | else: 152 | drop_weights = None 153 | 154 | if param['drop_scheme'] == 'degree': 155 | edge_index_ = to_undirected(data.edge_index) 156 | node_deg = degree(edge_index_[1]) 157 | if args.dataset == 'WikiCS': 158 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_deg).to(device) 159 | else: 160 | feature_weights = feature_drop_weights(data.x, node_c=node_deg).to(device) 161 | elif param['drop_scheme'] == 'pr': 162 | node_pr = compute_pr(data.edge_index) 163 | if args.dataset == 'WikiCS': 164 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_pr).to(device) 165 | else: 166 | feature_weights = feature_drop_weights(data.x, node_c=node_pr).to(device) 167 | elif param['drop_scheme'] == 'evc': 168 | node_evc = eigenvector_centrality(data) 169 | if args.dataset == 'WikiCS': 170 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_evc).to(device) 171 | else: 172 | feature_weights = feature_drop_weights(data.x, node_c=node_evc).to(device) 173 | else: 174 | feature_weights = torch.ones((data.x.size(1),)).to(device) 175 | 176 | log = args.verbose.split(',') 177 | alphas_init = torch.tensor([1, 2],dtype=torch.float64).to(device) 178 | betas_init = torch.tensor([2, 1],dtype=torch.float64).to(device) 179 | weights_init = torch.tensor([1-args.weight_init, args.weight_init], dtype=torch.float64).to(device) 180 | bmm_model = BetaMixture1D(args.iters, alphas_init, betas_init, weights_init) 181 | 182 | for epoch in range(1, param['num_epochs'] + 1): 183 | loss = train(epoch, bmm_model) 184 | if 'train' in log: 185 | print(f'(T) | Epoch={epoch:03d}, loss={loss:.4f}') 186 | 187 | if epoch % 100 == 0: 188 | acc = test() 189 | logging.info('\t%.4f'%(acc)) 190 | if 'eval' in log: 191 | print(f'(E) | Epoch={epoch:04d}, avg_acc = {acc}') 192 | 193 | acc = test(final=True) 194 | logging.info('Final:') 195 | logging.info('\t%.4f'%(acc)) 196 | 197 | if 'final' in log: 198 | print(f'{acc}') 199 | --------------------------------------------------------------------------------
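Note (illustrative addendum, not part of the original repository): the core ProGCL mechanism lives in `semi_loss_bmm` / `batched_semi_loss_bmm` in `pGRACE/model.py`. The inter-view similarities are min-max normalized, a two-component `BetaMixture1D` is fitted by EM, and each negative pair gets the weight `B = posterior(sim, 0) * sim`; with the initialization used in `train.py` (alphas `[1, 2]`, betas `[2, 1]`, weights `[1 - weight_init, weight_init]`), component 0 is the low-similarity mode, so `B` is roughly "probability of being a true negative times its hardness". In `weight` mode the negative term of the contrastive denominator then becomes `sum_j B_ij * exp(s_ij / tau) / mean_j(B_ij)`. The sketch below exercises `BetaMixture1D` on synthetic normalized similarities; the Beta(2, 8) / Beta(8, 2) generators and the 900/100 split are assumptions made up purely for this example.
```
# Illustrative sketch only -- not used anywhere in the repository.
# Synthetic "normalized similarities" in (0, 1): a large low-similarity mode
# (true negatives) plus a small high-similarity mode (candidate false negatives).
import torch
from pGRACE.model import BetaMixture1D

torch.manual_seed(0)
x = torch.cat([
    torch.distributions.Beta(2.0, 8.0).sample((900,)),   # mostly dissimilar pairs
    torch.distributions.Beta(8.0, 2.0).sample((100,)),   # suspiciously similar pairs
]).double()

# Same initialization pattern as train.py, with the default weight_init = 0.05.
bmm = BetaMixture1D(
    max_iters=10,
    alphas_init=torch.tensor([1.0, 2.0], dtype=torch.float64),
    betas_init=torch.tensor([2.0, 1.0], dtype=torch.float64),
    weights_init=torch.tensor([0.95, 0.05], dtype=torch.float64),
)
bmm.fit(x)
print(bmm)

# posterior(x, 0): estimated probability that a pair is a genuine negative.
# semi_loss_bmm multiplies this by the normalized similarity itself, so hard
# true negatives receive the largest weights while likely false negatives are damped.
weight = bmm.posterior(x, 0) * x
print('mean weight, low-sim pairs :', weight[:900].mean().item())
print('mean weight, high-sim pairs:', weight[900:].mean().item())
```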