├── pGRACE ├── __init__.py ├── __pycache__ │ ├── eval.cpython-38.pyc │ ├── eval.cpython-39.pyc │ ├── model.cpython-38.pyc │ ├── model.cpython-39.pyc │ ├── utils.cpython-38.pyc │ ├── utils.cpython-39.pyc │ ├── dataset.cpython-38.pyc │ ├── dataset.cpython-39.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── functional.cpython-38.pyc │ └── functional.cpython-39.pyc ├── dataset.py ├── functional.py ├── utils.py ├── eval.py ├── model_mb.py └── model.py ├── simple_param ├── __init__.py ├── __pycache__ │ ├── sp.cpython-38.pyc │ ├── sp.cpython-39.pyc │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc └── sp.py ├── .DS_Store ├── appendix.pdf ├── __pycache__ ├── eval.cpython-38.pyc ├── ssgc.cpython-38.pyc ├── ssgc.cpython-39.pyc ├── attack.cpython-38.pyc ├── utils.cpython-38.pyc ├── utils.cpython-39.pyc ├── layers_ar.cpython-38.pyc ├── model_ar.cpython-38.pyc └── utils_ar.cpython-38.pyc ├── README.md ├── param ├── wikics.json ├── amazon_photo.json ├── coauthor_phy.json ├── amazon_computers.json └── coauthor_cs.json ├── layers_ar.py ├── utils_ar.py ├── eval.py ├── config.yaml ├── MIE.py ├── utils.py ├── param.yaml ├── model_ar.py ├── info.py ├── train.py ├── main.py └── requirements.txt /pGRACE/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simple_param/__init__.py: -------------------------------------------------------------------------------- 1 | import simple_param.sp as sp -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/.DS_Store -------------------------------------------------------------------------------- /appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/appendix.pdf -------------------------------------------------------------------------------- /__pycache__/eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/eval.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/ssgc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/ssgc.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/ssgc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/ssgc.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/attack.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/attack.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/layers_ar.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/layers_ar.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/model_ar.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/model_ar.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/utils_ar.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/__pycache__/utils_ar.cpython-38.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/eval.cpython-38.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/eval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/eval.cpython-39.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/functional.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/functional.cpython-38.pyc -------------------------------------------------------------------------------- /pGRACE/__pycache__/functional.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/pGRACE/__pycache__/functional.cpython-39.pyc -------------------------------------------------------------------------------- /simple_param/__pycache__/sp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/simple_param/__pycache__/sp.cpython-38.pyc -------------------------------------------------------------------------------- /simple_param/__pycache__/sp.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/simple_param/__pycache__/sp.cpython-39.pyc -------------------------------------------------------------------------------- /simple_param/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/simple_param/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /simple_param/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GXM1141/MA-GCL/HEAD/simple_param/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MA-GCL: Model Augmention Tricks for Graph Contrastive Learning (AAAI2023) 2 | 3 | paper: https://arxiv.org/pdf/2212.07035.pdf 4 | 5 | For example, to run MA-GCL under Cora, execute: 6 | 7 | python main.py --device cuda:0 --dataset Cora 8 | 9 | -------------------------------------------------------------------------------- /param/wikics.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.01, 3 | "num_hidden": 256, 4 | "num_proj_hidden": 256, 5 | "activation": "prelu", 6 | "drop_edge_rate_1": 0.2, 7 | "drop_edge_rate_2": 0.3, 8 | "drop_feature_rate_1": 0.1, 9 | "drop_feature_rate_2": 0.1, 10 | "tau": 0.4, 11 | "num_epochs": 3000 12 | } -------------------------------------------------------------------------------- /param/amazon_photo.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.1, 3 | "num_hidden": 256, 4 | "num_proj_hidden": 64, 5 | "activation": "relu", 6 | "drop_edge_rate_1": 0.3, 7 | "drop_edge_rate_2": 0.5, 8 | "drop_feature_rate_1": 0.1, 9 | "drop_feature_rate_2": 0.1, 10 | "tau": 0.3, 11 | "num_epochs": 2000 12 | } -------------------------------------------------------------------------------- /param/coauthor_phy.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.01, 3 | "num_hidden": 128, 4 | "num_proj_hidden": 64, 5 | "activation": "rrelu", 6 | "drop_edge_rate_1": 0.4, 7 | "drop_edge_rate_2": 0.1, 8 | "drop_feature_rate_1": 0.1, 9 | "drop_feature_rate_2": 0.4, 10 | "tau": 0.5, 11 | "num_epochs": 1500 12 | } -------------------------------------------------------------------------------- /param/amazon_computers.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.01, 3 | "num_hidden": 128, 4 | "num_proj_hidden": 128, 5 | "activation": "rrelu", 6 | "drop_edge_rate_1": 0.6, 7 | "drop_edge_rate_2": 0.3, 8 | "drop_feature_rate_1": 0.2, 9 | "drop_feature_rate_2": 0.3, 10 | "tau": 0.2, 11 | "num_epochs": 2000 12 | } -------------------------------------------------------------------------------- /param/coauthor_cs.json: -------------------------------------------------------------------------------- 1 | { 2 | "learning_rate": 0.0005, 3 | "num_hidden": 256, 4 | "num_proj_hidden": 256, 5 | "activation": "rrelu", 6 | "drop_edge_rate_1": 0.3, 7 | "drop_edge_rate_2": 0.2, 8 | "drop_feature_rate_1": 0.3, 9 | "drop_feature_rate_2": 0.4, 10 | "tau": 0.4, 11 | "num_epochs": 1000 12 | } -------------------------------------------------------------------------------- /layers_ar.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from torch.nn.parameter import Parameter 6 | from torch.nn.modules.module import Module 7 | 8 | 9 | class GCNConv(Module): 10 | """ 11 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 12 | """ 13 | 14 | def __init__(self, in_features, out_features, bias=True): 15 | super(GCNConv, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 19 | if bias: 20 | self.bias = Parameter(torch.FloatTensor(out_features)) 21 | else: 22 | self.register_parameter('bias', None) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | stdv = 1. / math.sqrt(self.weight.size(1)) 27 | self.weight.data.uniform_(-stdv, stdv) 28 | if self.bias is not None: 29 | self.bias.data.uniform_(-stdv, stdv) 30 | 31 | def forward(self, input, adj): 32 | support = torch.mm(input, self.weight) 33 | output = torch.spmm(adj, support) 34 | if self.bias is not None: 35 | return output + self.bias 36 | else: 37 | return output 38 | 39 | def __repr__(self): 40 | return self.__class__.__name__ + ' (' \ 41 | + str(self.in_features) + ' -> ' \ 42 | + str(self.out_features) + ')' -------------------------------------------------------------------------------- /utils_ar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter_add 3 | from torch_geometric.utils import get_laplacian, add_self_loops 4 | 5 | 6 | def normalize_adj_tensor(adj): 7 | """Symmetrically normalize adjacency tensor.""" 8 | rowsum = torch.sum(adj,1) 9 | d_inv_sqrt = torch.pow(rowsum, -0.5) 10 | d_inv_sqrt[d_inv_sqrt == float("Inf")] = 0. 11 | d_mat_inv_sqrt = torch.diag(d_inv_sqrt) 12 | return torch.mm(torch.mm(adj,d_mat_inv_sqrt).transpose(0,1),d_mat_inv_sqrt) 13 | 14 | def normalize_adj_tensor_sp(adj): 15 | """Symmetrically normalize sparse adjacency tensor.""" 16 | device = adj.device 17 | adj = adj.to("cpu") 18 | rowsum = torch.spmm(adj, torch.ones((adj.size(0),1))).reshape(-1) 19 | d_inv_sqrt = torch.pow(rowsum, -0.5) 20 | d_inv_sqrt[d_inv_sqrt == float("Inf")] = 0. 21 | d_mat_inv_sqrt = torch.diag(d_inv_sqrt) 22 | adj = torch.mm(torch.smm(adj.transpose(0,1),d_mat_inv_sqrt.transpose(0,1)),d_mat_inv_sqrt) 23 | return adj.to(device) 24 | 25 | def edge2adj(x, edge_index): 26 | """Convert edge index to adjacency matrix""" 27 | num_nodes = x.shape[0] 28 | tmp, _ = add_self_loops(edge_index, num_nodes=num_nodes) 29 | edge_weight = torch.ones(tmp.size(1), dtype=None, 30 | device=edge_index.device) 31 | 32 | row, col = tmp[0], tmp[1] 33 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 34 | deg_inv_sqrt = deg.pow_(-0.5) 35 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) 36 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 37 | return torch.sparse.FloatTensor(tmp, edge_weight,torch.Size((num_nodes, num_nodes))) -------------------------------------------------------------------------------- /pGRACE/dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from torch_geometric.datasets import Planetoid, CitationFull, WikiCS, Coauthor, Amazon 4 | import torch_geometric.transforms as T 5 | 6 | from ogb.nodeproppred import PygNodePropPredDataset 7 | 8 | def get_dataset(path, name): 9 | assert name in ['Cora', 'CiteSeer', 'PubMed', 'DBLP', 'Karate', 'WikiCS', 'Coauthor-CS', 'Coauthor-Phy', 10 | 'Amazon-Computers', 'Amazon-Photo', 'ogbn-arxiv', 'ogbg-code'] 11 | name = 'dblp' if name == 'DBLP' else name 12 | root_path = osp.expanduser('~/datasets') 13 | print (root_path) 14 | if name == 'Coauthor-CS': 15 | return Coauthor(root=path, name='cs', transform=T.NormalizeFeatures()) 16 | # return Coauthor(root=path, name='cs') 17 | if name == 'Coauthor-Phy': 18 | return Coauthor(root=path, name='physics', transform=T.NormalizeFeatures()) 19 | 20 | if name == 'WikiCS': 21 | return WikiCS(root=path, transform=T.NormalizeFeatures()) 22 | 23 | if name == 'Amazon-Computers': 24 | return Amazon(root=path, name='computers', transform=T.NormalizeFeatures()) 25 | # return Amazon(root=path, name='computers') 26 | 27 | if name == 'Amazon-Photo': 28 | return Amazon(root=path, name='photo', transform=T.NormalizeFeatures()) 29 | # return Amazon(root=path, name='photo') 30 | 31 | if name.startswith('ogbn'): 32 | return PygNodePropPredDataset(root=osp.join(root_path, 'OGB'), name=name, transform=T.NormalizeFeatures()) 33 | 34 | return (CitationFull if name == 'dblp' else Planetoid)(osp.join(root_path, 'Citation'), name, transform=T.NormalizeFeatures()) 35 | 36 | 37 | def get_path(base_path, name): 38 | if name in ['Cora', 'CiteSeer', 'PubMed']: 39 | return base_path 40 | else: 41 | return osp.join(base_path, name) 42 | -------------------------------------------------------------------------------- /simple_param/sp.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import os.path as osp 3 | 4 | import json 5 | import yaml 6 | 7 | import nni 8 | 9 | 10 | class SimpleParam: 11 | def __init__(self, local_dir: str = 'param', default: Optional[dict] = None): 12 | if default is None: 13 | default = dict() 14 | 15 | self.local_dir = local_dir 16 | self.default = default 17 | 18 | def __call__(self, source: str, preprocess: str = 'none'): 19 | if source == 'nni': 20 | return {**self.default, **nni.get_next_parameter()} 21 | if source.startswith('local'): 22 | ts = source.split(':') 23 | assert len(ts) == 2, 'local parameter file should be specified in a form of `local:FILE_NAME`' 24 | path = ts[-1] 25 | path = osp.join(self.local_dir, path) 26 | if path.endswith('.json'): 27 | loaded = parse_json(path) 28 | elif path.endswith('.yaml') or path.endswith('.yml'): 29 | loaded = parse_yaml(path) 30 | else: 31 | raise Exception('Invalid file name. Should end with .yaml or .json.') 32 | 33 | if preprocess == 'nni': 34 | loaded = preprocess_nni(loaded) 35 | 36 | return {**self.default, **loaded} 37 | if source == 'default': 38 | return self.default 39 | 40 | raise Exception('invalid source') 41 | 42 | 43 | def preprocess_nni(params: dict): 44 | def process_key(key: str): 45 | xs = key.split('/') 46 | if len(xs) == 3: 47 | return xs[1] 48 | elif len(xs) == 1: 49 | return key 50 | else: 51 | raise Exception('Unexpected param name ' + key) 52 | 53 | return { 54 | process_key(k): v for k, v in params.items() 55 | } 56 | 57 | 58 | def parse_yaml(path: str): 59 | content = open(path).read() 60 | return yaml.load(content, Loader=yaml.Loader) 61 | 62 | 63 | def parse_json(path: str): 64 | content = open(path).read() 65 | return json.loads(content) 66 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import functools 3 | 4 | from sklearn.metrics import f1_score, roc_auc_score 5 | from sklearn.linear_model import LogisticRegression 6 | from sklearn.svm import SVC 7 | from sklearn.model_selection import train_test_split, GridSearchCV 8 | from sklearn.multiclass import OneVsRestClassifier 9 | from sklearn.preprocessing import normalize, OneHotEncoder 10 | 11 | 12 | def repeat(n_times): 13 | def decorator(f): 14 | @functools.wraps(f) 15 | def wrapper(*args, **kwargs): 16 | results = [f(*args, **kwargs) for _ in range(n_times)] 17 | statistics = {} 18 | for key in results[0].keys(): 19 | values = [r[key] for r in results] 20 | statistics[key] = { 21 | 'mean': np.mean(values), 22 | 'std': np.std(values)} 23 | print_statistics(statistics, f.__name__) 24 | return statistics 25 | return wrapper 26 | return decorator 27 | 28 | 29 | def prob_to_one_hot(y_pred): 30 | ret = np.zeros(y_pred.shape, np.bool) 31 | indices = np.argmax(y_pred, axis=1) 32 | for i in range(y_pred.shape[0]): 33 | ret[i][indices[i]] = True 34 | return ret 35 | 36 | 37 | def print_statistics(statistics, function_name): 38 | print(f'(E) | {function_name}:', end=' ') 39 | for i, key in enumerate(statistics.keys()): 40 | mean = statistics[key]['mean'] 41 | std = statistics[key]['std'] 42 | print(f'{key}={mean:.4f}+-{std:.4f}', end='') 43 | if i != len(statistics.keys()) - 1: 44 | print(',', end=' ') 45 | else: 46 | print() 47 | 48 | 49 | @repeat(20) 50 | def label_classification(embeddings, y, ratio): 51 | X = embeddings.detach().cpu().numpy() 52 | Y = y.detach().cpu().numpy() 53 | Y = Y.reshape(-1, 1) 54 | onehot_encoder = OneHotEncoder(categories='auto').fit(Y) 55 | Y = onehot_encoder.transform(Y).toarray().astype(np.bool) 56 | 57 | X = normalize(X, norm='l2') 58 | 59 | X_train, X_test, y_train, y_test = train_test_split(X, Y, 60 | test_size=1 - ratio) 61 | 62 | logreg = LogisticRegression(solver='liblinear') 63 | c = 2.0 ** np.arange(-10, 10) 64 | 65 | clf = GridSearchCV(estimator=OneVsRestClassifier(logreg), 66 | param_grid=dict(estimator__C=c), n_jobs=8, cv=5, 67 | verbose=0) 68 | clf.fit(X_train, y_train) 69 | 70 | y_pred = clf.predict_proba(X_test) 71 | y_pred = prob_to_one_hot(y_pred) 72 | 73 | acc = np.sum(np.where(y_test)[1]==np.where(y_pred)[1])/len(y_pred) 74 | 75 | return {"ACC": acc} 76 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | Cora: 2 | seed: 1998 3 | learning_rate: 0.0005 4 | num_hidden: 128 5 | num_proj_hidden: 128 6 | activation: 'relu' 7 | base_model: 'GCNConv' 8 | num_layers: 2 9 | drop_edge_rate_1: 0.2 10 | drop_edge_rate_2: 0.4 11 | drop_feature_rate_1: 0.3 12 | drop_feature_rate_2: 0.4 13 | tau: 0.4 14 | num_epochs: 300 15 | weight_decay: 0.00001 16 | eps: 0.5 17 | alpha: 0.1 18 | beta: 0.01 19 | lamb: 0 20 | CiteSeer: 21 | seed: 4 22 | learning_rate: 0.001 23 | num_hidden: 256 24 | num_proj_hidden: 256 25 | activation: 'prelu' 26 | base_model: 'GCNConv' 27 | num_layers: 2 28 | drop_edge_rate_1: 0.2 29 | drop_edge_rate_2: 0.0 30 | drop_feature_rate_1: 0.3 31 | drop_feature_rate_2: 0.2 32 | tau: 0.9 33 | num_epochs: 300 34 | weight_decay: 0.00001 35 | eps: 2 36 | alpha: 0.01 37 | beta: 0.001 38 | lamb: 1 39 | Amazon-Computers: 40 | seed: 3 41 | learning_rate: 0.01 42 | num_hidden: 128 43 | num_proj_hidden: 128 44 | activation: "rrelu" 45 | base_model: 'GCNConv' 46 | num_layers: 2 47 | drop_edge_rate_1: 0.6 48 | drop_edge_rate_2: 0.3 49 | drop_feature_rate_1: 0.2 50 | drop_feature_rate_2: 0.3 51 | tau: 0.2 52 | num_epochs: 2000 53 | weight_decay: 0.00001 54 | eps: 0.5 55 | alpha: 0.001 56 | beta: 0.01 57 | lamb: 0 58 | Amazon-Photo: 59 | seed: 4 60 | learning_rate: 0.1 61 | num_hidden: 256 62 | num_proj_hidden: 64 63 | activation: "relu" 64 | base_model: 'GCNConv' 65 | num_layers: 2 66 | drop_edge_rate_1: 0.3 67 | drop_edge_rate_2: 0.5 68 | drop_feature_rate_1: 0.1 69 | drop_feature_rate_2: 0.1 70 | tau: 0.3 71 | num_epochs: 2000 72 | weight_decay: 0.00001 73 | eps: 2 74 | alpha: 0.1 75 | beta: 0.01 76 | lamb: 1 77 | Coauthor-CS: 78 | seed: 1 79 | learning_rate: 0.0005 80 | num_hidden: 256 81 | num_proj_hidden: 256 82 | activation: "rrelu" 83 | base_model: 'GCNConv' 84 | num_layers: 2 85 | drop_edge_rate_1: 0.3 86 | drop_edge_rate_2: 0.2 87 | drop_feature_rate_1: 0.3 88 | drop_feature_rate_2: 0.4 89 | tau: 0.4 90 | num_epochs: 1000 91 | weight_decay: 0.00001 92 | eps: 1.5 93 | alpha: 0.1 94 | beta: 0.001 95 | lamb: 0 96 | Coauthor-Physics: 97 | seed: 1 98 | learning_rate: 0.01 99 | num_hidden: 128 100 | num_proj_hidden: 64 101 | activation: "rrelu" 102 | base_model: 'GCNConv' 103 | num_layers: 2 104 | drop_edge_rate_1: 0.4 105 | drop_edge_rate_2: 0.1 106 | drop_feature_rate_1: 0.1 107 | drop_feature_rate_2: 0.4 108 | tau: 0.5 109 | num_epochs: 1500 110 | weight_decay: 0.00001 111 | eps: 1 112 | alpha: 0.1 113 | beta: 0.001 114 | lamb: 0.5 115 | -------------------------------------------------------------------------------- /MIE.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import os.path as osp 3 | import random 4 | import torch 5 | from torch.autograd import Variable 6 | import tqdm 7 | from torch.optim import Adam 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | from pGRACE.dataset import get_dataset 12 | from pGRACE.eval import MulticlassEvaluator 13 | from pGRACE.model import LogReg 14 | from pGRACE.utils import get_base_model, get_activation, \ 15 | generate_split, compute_pr, eigenvector_centrality 16 | 17 | 18 | 19 | hidden_size = 64 20 | n_epoch = 1500 21 | x_size = 128 22 | 23 | class MINE(nn.Module): 24 | def __init__(self, in_size, hidden_size=10): 25 | super(MINE, self).__init__() 26 | self.layers = nn.Sequential(nn.Linear(2 * in_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)) 27 | self.xT = nn.Linear(x_size, num_classes) 28 | self.init_emb() 29 | 30 | def init_emb(self): 31 | for m in self.modules(): 32 | if isinstance(m, nn.Linear): 33 | torch.nn.init.xavier_uniform_(m.weight.data) 34 | if m.bias is not None: 35 | m.bias.data.fill_(0.0) 36 | def forward(self, x, y): 37 | #x = self.xT(x) 38 | #x = F.softmax(x, dim=1) 39 | batch_size = x.size(0) 40 | print (x.shape) 41 | tiled_x = torch.cat([x, x, ], dim=0) 42 | print (tiled_x.shape) 43 | idx = torch.randperm(batch_size) 44 | 45 | shuffled_y = y[idx] 46 | concat_y = torch.cat([y, shuffled_y], dim=0) 47 | print (concat_y.shape) 48 | inputs = torch.cat([tiled_x, concat_y], dim=1) 49 | print (inputs.shape) 50 | logits = self.layers(inputs) 51 | 52 | pred_xy = logits[:batch_size] 53 | pred_x_y = logits[batch_size:] 54 | loss = - np.log2(np.exp(1)) * (torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y)))) 55 | # compute loss, you'd better scale exp to bit 56 | return loss 57 | device = 'cuda:1' 58 | z1 = np.load('embedding/Amazon-Computersview1_embeddingfull.npy') 59 | z2 = np.load('embedding/Amazon-Computersview2_embeddingfull.npy') 60 | dataset = 'Cora' 61 | path = osp.expanduser('~/datasets') 62 | path = osp.join(path, dataset) 63 | dataset = get_dataset(path, dataset) 64 | label = dataset[0].y.view(-1) 65 | label = label.unsqueeze(1) 66 | num_classes = dataset[0].y.max().item() + 1 67 | onehot = torch.zeros(label.shape[0], num_classes) 68 | onehot.scatter_(1, label, 1) 69 | x_sample = z1 70 | y_sample = z2 71 | 72 | x_sample = torch.from_numpy(x_sample).float().to(device) 73 | #y_sample = onehot.float().to(device) 74 | y_sample = torch.from_numpy(y_sample).float().to(device) 75 | 76 | in_dim = y_sample.shape[1] 77 | 78 | model = MINE(in_dim, hidden_size).to(device) 79 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005) 80 | plot_loss = [] 81 | all_mi = [] 82 | 83 | dataset = 'PubMed' 84 | for epoch in range(n_epoch): 85 | 86 | loss = model(x_sample, y_sample) 87 | 88 | model.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | if (epoch > 900): 92 | all_mi.append(-loss.cpu().item()) 93 | 94 | print (np.mean(all_mi)) 95 | -------------------------------------------------------------------------------- /pGRACE/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import degree, to_undirected 3 | 4 | from pGRACE.utils import compute_pr, eigenvector_centrality 5 | 6 | 7 | def drop_feature(x, drop_prob): 8 | drop_mask = torch.empty((x.size(1),), dtype=torch.float32, device=x.device).uniform_(0, 1) < drop_prob 9 | x = x.clone() 10 | x[:, drop_mask] = 0 11 | 12 | return x 13 | 14 | 15 | def drop_feature_weighted(x, w, p: float, threshold: float = 0.7): 16 | w = w / w.mean() * p 17 | w = w.where(w < threshold, torch.ones_like(w) * threshold) 18 | drop_prob = w.repeat(x.size(0)).view(x.size(0), -1) 19 | 20 | drop_mask = torch.bernoulli(drop_prob).to(torch.bool) 21 | 22 | x = x.clone() 23 | x[drop_mask] = 0. 24 | 25 | return x 26 | 27 | 28 | def drop_feature_weighted_2(x, w, p: float, threshold: float = 0.7): 29 | w = w / w.mean() * p 30 | w = w.where(w < threshold, torch.ones_like(w) * threshold) 31 | drop_prob = w 32 | 33 | drop_mask = torch.bernoulli(drop_prob).to(torch.bool) 34 | 35 | x = x.clone() 36 | x[:, drop_mask] = 0. 37 | 38 | return x 39 | 40 | 41 | def feature_drop_weights(x, node_c): 42 | x = x.to(torch.bool).to(torch.float32) 43 | w = x.t() @ node_c 44 | w = w.log() 45 | s = (w.max() - w) / (w.max() - w.mean()) 46 | 47 | return s 48 | 49 | 50 | def feature_drop_weights_dense(x, node_c): 51 | x = x.abs() 52 | w = x.t() @ node_c 53 | w = w.log() 54 | s = (w.max() - w) / (w.max() - w.mean()) 55 | 56 | return s 57 | 58 | 59 | def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.): 60 | edge_weights = edge_weights / edge_weights.mean() * p 61 | edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold)#小于门限的保存,大于门限的用门限替代 62 | sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool)#越小选择1的概率越大 63 | 64 | return edge_index[:, sel_mask] 65 | 66 | 67 | def degree_drop_weights(edge_index): 68 | edge_index_ = to_undirected(edge_index) 69 | deg = degree(edge_index_[1]) 70 | deg_col = deg[edge_index[1]].to(torch.float32) 71 | s_col = torch.log(deg_col) 72 | weights = (s_col.max() - s_col) / (s_col.max() - s_col.mean()) 73 | 74 | return weights 75 | 76 | 77 | def pr_drop_weights(edge_index, aggr: str = 'sink', k: int = 10): 78 | pv = compute_pr(edge_index, k=k) 79 | pv_row = pv[edge_index[0]].to(torch.float32) 80 | pv_col = pv[edge_index[1]].to(torch.float32) 81 | s_row = torch.log(pv_row) 82 | s_col = torch.log(pv_col) 83 | if aggr == 'sink': 84 | s = s_col 85 | elif aggr == 'source': 86 | s = s_row 87 | elif aggr == 'mean': 88 | s = (s_col + s_row) * 0.5 89 | else: 90 | s = s_col 91 | weights = (s.max() - s) / (s.max() - s.mean()) 92 | 93 | return weights 94 | 95 | 96 | def evc_drop_weights(data): 97 | evc = eigenvector_centrality(data) 98 | evc = evc.where(evc > 0, torch.zeros_like(evc)) 99 | evc = evc + 1e-8 100 | s = evc.log() 101 | 102 | edge_index = data.edge_index 103 | s_row, s_col = s[edge_index[0]], s[edge_index[1]] 104 | s = s_col 105 | 106 | return (s.max() - s) / (s.max() - s.mean()) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import networkx as nx 5 | from scipy import sparse 6 | from torch_geometric.nn import GCNConv 7 | from sklearn.metrics import accuracy_score 8 | import random 9 | import os 10 | 11 | def create_adjacency_matrix(edge_index): 12 | """ 13 | Creating a sparse adjacency matrix. 14 | :param graph: NetworkX object. 15 | :return A: Adjacency matrix. 16 | """ 17 | edges = edge_index.t().cpu().numpy() 18 | index_1 = [edge[0] for edge in edges] + [edge[1] for edge in edges] 19 | index_2 = [edge[1] for edge in edges] + [edge[0] for edge in edges] 20 | values = [1 for edge in index_1] 21 | node_count = max(max(index_1)+1, max(index_2)+1) 22 | A = sparse.coo_matrix((np.array(values)/2, (np.array(index_1), np.array(index_2))), shape=(node_count, node_count), dtype=np.float32) 23 | return A 24 | 25 | def Rank(A, len = 1000, a = 0.8): 26 | I = sparse.eye(A.shape[0]) 27 | A_s = normalize_adjacency_matrix(A, I)[0].tocoo() 28 | A_s = normalize_adj(A).tocoo() 29 | Pi = (torch.ones(A.shape[0], 1)/A.shape[0]) 30 | A = torch.sparse.LongTensor(torch.LongTensor([A_s.row.tolist(), A_s.col.tolist()]), 31 | torch.FloatTensor(A_s.data.astype(np.float))) 32 | 33 | A = A.to_dense().float() 34 | A_s = A 35 | e = (torch.ones_like(A) / A.shape[0]) 36 | A = a * A + (1-a) * e 37 | for i in range (100): 38 | Pi = torch.mm(A, Pi) 39 | val, idx = torch.topk(Pi, len, dim = 0) 40 | idx = idx.t() 41 | val = val.t() 42 | idx, _ = torch.sort(idx) 43 | return idx 44 | 45 | def normalize_adj(adj): 46 | adj = sparse.coo_matrix(adj) 47 | row_sum = np.array(adj.sum(1)) 48 | d_inv_sqrt = np.power(row_sum, -0.5).flatten() 49 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 50 | d_mat_inv_sqrt = sparse.diags(d_inv_sqrt) 51 | return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt) 52 | 53 | def normalize_adjacency_matrix(A, I): 54 | 55 | A_tilde = A + I 56 | degrees = A_tilde.sum(axis=0)[0].tolist() 57 | D = sparse.diags(degrees, [0]) 58 | D = D.power(-0.5) 59 | A_tilde_hat = D.dot(A_tilde).dot(D) 60 | A_lap = D.dot(A).dot(D) 61 | return A_tilde_hat, A_lap 62 | 63 | def load_adj_neg(num_nodes, sample): 64 | col = np.random.randint(0, num_nodes, size=num_nodes * sample) 65 | row = np.repeat(range(num_nodes), sample) 66 | index = np.not_equal(col,row) 67 | col = col[index] 68 | row = row[index] 69 | new_col = np.concatenate((col,row),axis=0) 70 | new_row = np.concatenate((row,col),axis=0) 71 | #data = np.ones(num_nodes * sample*2) 72 | data = np.ones(new_col.shape[0]) 73 | adj_neg = sparse.coo_matrix((data, (new_row, new_col)), shape=(num_nodes, num_nodes)) 74 | #adj_neg = (sp.eye(adj_neg.shape[0]) * sample - adj_neg).toarray() 75 | #adj_neg = (sp.eye(adj_neg.shape[0]) - adj_neg/sample).toarray() 76 | #adj_neg = (adj_neg / sample).toarray() 77 | adj_neg = normalize_adj(adj_neg) 78 | 79 | return adj_neg.toarray() 80 | 81 | def seed_torch(seed=1029): 82 | random.seed(seed) 83 | os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化,使得实验可复现 84 | np.random.seed(seed) 85 | torch.manual_seed(seed) 86 | torch.cuda.manual_seed(seed) 87 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 88 | torch.backends.cudnn.benchmark = False 89 | torch.backends.cudnn.deterministic = True 90 | -------------------------------------------------------------------------------- /pGRACE/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import scipy.sparse as sp 4 | from scipy.special import iv 5 | from scipy.sparse.linalg import eigsh 6 | import os.path as osp 7 | from sklearn.cluster import KMeans, SpectralClustering 8 | from sklearn.manifold import SpectralEmbedding 9 | # from libKMCUDA import kmeans_cuda 10 | from tqdm import tqdm 11 | from matplotlib import cm 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.distributions.categorical import Categorical 17 | from torch.optim import Adam 18 | from torch.utils.data import random_split 19 | from torch_geometric.nn import GCNConv, SGConv, SAGEConv, GATConv, GraphConv, GINConv 20 | from torch_geometric.utils import sort_edge_index, degree, add_remaining_self_loops, remove_self_loops, get_laplacian, \ 21 | to_undirected, to_dense_adj, to_networkx 22 | from torch_geometric.datasets import KarateClub 23 | from torch_scatter import scatter 24 | import torch_sparse 25 | 26 | import networkx as nx 27 | import matplotlib.pyplot as plt 28 | 29 | 30 | def get_base_model(name: str): 31 | def gat_wrapper(in_channels, out_channels): 32 | return GATConv( 33 | in_channels=in_channels, 34 | out_channels=out_channels // 4, 35 | heads=4 36 | ) 37 | 38 | def gin_wrapper(in_channels, out_channels): 39 | mlp = nn.Sequential( 40 | nn.Linear(in_channels, 2 * out_channels), 41 | nn.ELU(), 42 | nn.Linear(2 * out_channels, out_channels) 43 | ) 44 | return GINConv(mlp) 45 | 46 | base_models = { 47 | 'GCNConv': GCNConv, 48 | 'SGConv': SGConv, 49 | 'SAGEConv': SAGEConv, 50 | 'GATConv': gat_wrapper, 51 | 'GraphConv': GraphConv, 52 | 'GINConv': gin_wrapper 53 | } 54 | 55 | return base_models[name] 56 | 57 | 58 | def get_activation(name: str): 59 | activations = { 60 | 'relu': F.relu, 61 | 'hardtanh': F.hardtanh, 62 | 'elu': F.elu, 63 | 'leakyrelu': F.leaky_relu, 64 | 'prelu': torch.nn.PReLU(), 65 | 'rrelu': F.rrelu 66 | } 67 | 68 | return activations[name] 69 | 70 | 71 | def compute_pr(edge_index, damp: float = 0.85, k: int = 10): 72 | num_nodes = edge_index.max().item() + 1 73 | deg_out = degree(edge_index[0]) 74 | x = torch.ones((num_nodes, )).to(edge_index.device).to(torch.float32) 75 | 76 | for i in range(k): 77 | edge_msg = x[edge_index[0]] / deg_out[edge_index[0]] 78 | agg_msg = scatter(edge_msg, edge_index[1], reduce='sum') 79 | 80 | x = (1 - damp) * x + damp * agg_msg 81 | 82 | return x 83 | 84 | 85 | def eigenvector_centrality(data): 86 | graph = to_networkx(data) 87 | x = nx.eigenvector_centrality_numpy(graph) 88 | x = [x[i] for i in range(data.num_nodes)] 89 | return torch.tensor(x, dtype=torch.float32).to(data.edge_index.device) 90 | 91 | 92 | def generate_split(num_samples: int, train_ratio: float, val_ratio: float): 93 | train_len = int(num_samples * train_ratio) 94 | val_len = int(num_samples * val_ratio) 95 | test_len = num_samples - train_len - val_len 96 | 97 | train_set, test_set, val_set = random_split(torch.arange(0, num_samples), (train_len, test_len, val_len)) 98 | 99 | idx_train, idx_test, idx_val = train_set.indices, test_set.indices, val_set.indices 100 | train_mask = torch.zeros((num_samples,)).to(torch.bool) 101 | test_mask = torch.zeros((num_samples,)).to(torch.bool) 102 | val_mask = torch.zeros((num_samples,)).to(torch.bool) 103 | 104 | train_mask[idx_train] = True 105 | test_mask[idx_test] = True 106 | val_mask[idx_val] = True 107 | 108 | return train_mask, test_mask, val_mask 109 | 110 | -------------------------------------------------------------------------------- /pGRACE/eval.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch.optim import Adam 5 | import torch.nn as nn 6 | 7 | from pGRACE.model import LogReg 8 | 9 | 10 | def get_idx_split(dataset, split, preload_split): 11 | if split[:4] == 'rand': 12 | train_ratio = float(split.split(':')[1]) 13 | num_nodes = dataset[0].x.size(0) 14 | train_size = int(num_nodes * train_ratio) 15 | indices = torch.randperm(num_nodes) 16 | return { 17 | 'train': indices[:train_size], 18 | 'val': indices[train_size:2 * train_size], 19 | 'test': indices[2 * train_size:] 20 | } 21 | elif split == 'ogb': 22 | return dataset.get_idx_split() 23 | elif split.startswith('wikics'): 24 | split_idx = int(split.split(':')[1]) 25 | return { 26 | 'train': dataset[0].train_mask[:, split_idx], 27 | 'test': dataset[0].test_mask, 28 | 'val': dataset[0].val_mask[:, split_idx] 29 | } 30 | elif split == 'preloaded': 31 | assert preload_split is not None, 'use preloaded split, but preloaded_split is None' 32 | train_mask, test_mask, val_mask = preload_split 33 | return { 34 | 'train': train_mask, 35 | 'test': test_mask, 36 | 'val': val_mask 37 | } 38 | else: 39 | raise RuntimeError(f'Unknown split type {split}') 40 | 41 | 42 | def log_regression(z, 43 | dataset, 44 | evaluator, 45 | num_epochs: int = 5000, 46 | test_device: Optional[str] = None, 47 | split: str = 'rand:0.1', 48 | verbose: bool = False, 49 | preload_split=None): 50 | test_device = z.device if test_device is None else test_device 51 | z = z.detach().to(test_device) 52 | num_hidden = z.size(1) 53 | y = dataset[0].y.view(-1).to(test_device) 54 | num_classes = dataset[0].y.max().item() + 1 55 | classifier = LogReg(num_hidden, num_classes).to(test_device) 56 | optimizer = Adam(classifier.parameters(), lr=0.01, weight_decay=0.00000) 57 | 58 | split = get_idx_split(dataset, split, preload_split) 59 | split = {k: v.to(test_device) for k, v in split.items()} 60 | f = nn.LogSoftmax(dim=-1) 61 | nll_loss = nn.NLLLoss() 62 | 63 | best_test_acc = 0 64 | best_val_acc = 0 65 | best_epoch = 0 66 | 67 | for epoch in range(3000): 68 | classifier.train() 69 | optimizer.zero_grad() 70 | 71 | output = classifier(z[split['train']]) 72 | loss = nll_loss(f(output), y[split['train']]) 73 | 74 | loss.backward() 75 | optimizer.step() 76 | 77 | if (epoch + 1) % 20 == 0: 78 | if 'val' in split: 79 | # val split is available 80 | test_acc = evaluator.eval({ 81 | 'y_true': y[split['test']].view(-1, 1), 82 | 'y_pred': classifier(z[split['test']]).argmax(-1).view(-1, 1) 83 | })['acc'] 84 | val_acc = evaluator.eval({ 85 | 'y_true': y[split['val']].view(-1, 1), 86 | 'y_pred': classifier(z[split['val']]).argmax(-1).view(-1, 1) 87 | })['acc'] 88 | if val_acc > best_val_acc: 89 | best_val_acc = val_acc 90 | best_test_acc = test_acc 91 | best_epoch = epoch 92 | else: 93 | acc = evaluator.eval({ 94 | 'y_true': y[split['test']].view(-1, 1), 95 | 'y_pred': classifier(z[split['test']]).argmax(-1).view(-1, 1) 96 | })['acc'] 97 | if best_test_acc < acc: 98 | best_test_acc = acc 99 | best_epoch = epoch 100 | if verbose: 101 | print(f'logreg epoch {epoch}: best test acc {best_test_acc}') 102 | 103 | return {'acc': best_test_acc} 104 | 105 | 106 | class MulticlassEvaluator: 107 | def __init__(self, *args, **kwargs): 108 | pass 109 | 110 | @staticmethod 111 | def _eval(y_true, y_pred): 112 | y_true = y_true.view(-1) 113 | y_pred = y_pred.view(-1) 114 | total = y_true.size(0) 115 | correct = (y_true == y_pred).to(torch.float32).sum() 116 | return (correct / total).item() 117 | 118 | def eval(self, res): 119 | return {'acc': self._eval(**res)} 120 | -------------------------------------------------------------------------------- /param.yaml: -------------------------------------------------------------------------------- 1 | WikiCS: 2 | learning_rate: 0.01 3 | num_hidden: 256 4 | num_proj_hidden: 256 5 | activation: 'prelu' 6 | base_model: 'GCNConv' 7 | num_layers: 2 8 | drop_edge_rate_1: 0.2 9 | drop_edge_rate_2: 0.3 10 | drop_feature_rate_1: 0.1 11 | drop_feature_rate_2: 0.1 12 | tau: 0.4 13 | num_epochs: 3000 14 | weight_decay: 0.00001 15 | drop_scheme: 'degree' 16 | rand_layers: 4 17 | 18 | Cora: 19 | learning_rate: 0.0005 20 | num_hidden: 512 21 | num_proj_hidden: 512 22 | activation: 'relu' 23 | base_model: 'GCNConv' 24 | num_layers: 2 25 | drop_embedding: 0.4 26 | drop_edge_rate_1: 0.4 27 | drop_edge_rate_2: 0.4 28 | drop_feature_rate_1: 0.5 29 | drop_feature_rate_2: 0.5 30 | tau: 0.4 31 | num_epochs: 500 32 | weight_decay: 0.00001 33 | drop_scheme: 'degree' 34 | k_neighbors: 2 35 | rand_layers: 4 36 | K_layers: 8 37 | 38 | Cora20: 39 | learning_rate: 0.0005 40 | num_hidden: 128 41 | num_proj_hidden: 128 42 | activation: 'relu' 43 | base_model: 'GCNConv' 44 | num_layers: 2 45 | drop_embedding: 0.4 46 | drop_edge_rate_1: 0.5 47 | drop_edge_rate_2: 0.4 48 | drop_feature_rate_1: 0.3 49 | drop_feature_rate_2: 0.5 50 | tau: 0.5 51 | num_epochs: 500 52 | weight_decay: 0.00005 53 | drop_scheme: 'degree' 54 | k_neighbors: 2 55 | rand_layers: 4 56 | K_layers: 8 57 | PubMed: 58 | learning_rate: 0.002 59 | num_hidden: 512 60 | num_proj_hidden: 128 61 | activation: 'relu' 62 | base_model: 'GCNConv' 63 | num_layers: 2 64 | drop_embedding: 0.4 65 | drop_edge_rate_1: 0.5 66 | drop_edge_rate_2: 0.4 67 | drop_feature_rate_1: 0.3 68 | drop_feature_rate_2: 0.5 69 | tau: 0.4 70 | num_epochs: 900 71 | weight_decay: 0.00001 72 | drop_scheme: 'degree' 73 | k_neighbors: 2 74 | rand_layers: 4 75 | K_layers: 8 76 | 77 | CiteSeer: 78 | learning_rate: 0.0001 79 | num_hidden: 512 80 | num_proj_hidden: 512 81 | activation: 'relu' 82 | base_model: 'GCNConv' 83 | num_layers: 2 84 | drop_embedding: 0.4 85 | drop_edge_rate_1: 0.4 86 | drop_edge_rate_2: 0.4 87 | drop_feature_rate_1: 0.3 88 | drop_feature_rate_2: 0.3 89 | rand_layers: 2 90 | tau: 0.8 91 | num_epochs: 600 92 | weight_decay: 0.00001 93 | drop_scheme: 'degree' 94 | K_layers: 4 95 | k_neighbors: 1 96 | 97 | CiteSeer20: 98 | learning_rate: 0.00001 99 | num_hidden: 512 100 | num_proj_hidden: 512 101 | activation: 'relu' 102 | base_model: 'GCNConv' 103 | num_layers: 2 104 | drop_embedding: 0.4 105 | drop_edge_rate_1: 0.3 106 | drop_edge_rate_2: 0.2 107 | drop_feature_rate_1: 0.3 108 | drop_feature_rate_2: 0.2 109 | rand_layers: 2 110 | tau: 0.9 111 | num_epochs: 400 112 | weight_decay: 0.000001 113 | drop_scheme: 'degree' 114 | K_layers: 4 115 | k_neighbors: 1 116 | ogbn-arxiv: 117 | learning_rate: 0.0001 118 | num_hidden: 128 119 | num_proj_hidden: 128 120 | activation: 'relu' 121 | base_model: 'GCNConv' 122 | num_layers: 2 123 | drop_embedding: 0.4 124 | drop_edge_rate_1: 0.2 125 | drop_edge_rate_2: 0.2 126 | drop_feature_rate_1: 0.2 127 | drop_feature_rate_2: 0.2 128 | tau: 0.8 129 | num_epochs: 500 130 | weight_decay: 0.0000 131 | drop_scheme: 'degree' 132 | K_layers: 4 133 | rand_layers: 2 134 | k_neighbors: 1 135 | Amazon-Computers: 136 | learning_rate: 0.0008 137 | num_hidden: 256 138 | num_proj_hidden: 256 139 | activation: "rrelu" 140 | base_model: 'GCNConv' 141 | num_layers: 2 142 | drop_edge_rate_1: 0.7 143 | drop_edge_rate_2: 0.2 144 | drop_feature_rate_1: 0.2 145 | drop_feature_rate_2: 0.2 146 | tau: 0.3 147 | num_epochs: 2000 148 | weight_decay: 0.00001 149 | drop_scheme: 'degree' 150 | rand_layers: 2 151 | Amazon-Photo: 152 | learning_rate: 0.001 153 | num_hidden: 256 154 | num_proj_hidden: 256 155 | activation: "prelu" 156 | base_model: 'GCNConv' 157 | num_layers: 2 158 | drop_embedding: 0.10 159 | drop_edge_rate_1: 0.3 160 | drop_edge_rate_2: 0.4 161 | drop_feature_rate_1: 0.2 162 | drop_feature_rate_2: 0.3 163 | tau: 0.3 164 | k_layers: 4 165 | k_neighbors: 2 166 | num_epochs: 1500 167 | weight_decay: 0.000000 168 | drop_scheme: 'degree' 169 | rand_layers: 4 170 | Coauthor-CS1: 171 | learning_rate: 0.0005 172 | num_hidden: 256 173 | num_proj_hidden: 256 174 | activation: "rrelu" 175 | base_model: 'GCNConv' 176 | num_layers: 2 177 | drop_embedding: 0.3 178 | drop_edge_rate_1: 0.3 179 | drop_edge_rate_2: 0.2 180 | drop_feature_rate_1: 0.3 181 | drop_feature_rate_2: 0.4 182 | tau: 0.4 183 | num_epochs: 1000 184 | weight_decay: 0.00001 185 | drop_scheme: 'degree' 186 | rand_layers: 4 187 | Coauthor-CS: 188 | learning_rate: 0.0005 189 | num_hidden: 256 190 | num_proj_hidden: 256 191 | activation: "rrelu" 192 | base_model: 'GCNConv' 193 | num_layers: 2 194 | drop_embedding: 0.3 195 | drop_edge_rate_1: 0.3 196 | drop_edge_rate_2: 0.2 197 | drop_feature_rate_1: 0.3 198 | drop_feature_rate_2: 0.4 199 | tau: 0.4 200 | num_epochs: 1000 201 | weight_decay: 0.00001 202 | drop_scheme: 'degree' 203 | rand_layers: 4 204 | Coauthor-Phy: 205 | learning_rate: 0.01 206 | num_hidden: 128 207 | num_proj_hidden: 64 208 | activation: "rrelu" 209 | base_model: 'GCNConv' 210 | num_layers: 2 211 | drop_edge_rate_1: 0.4 212 | drop_edge_rate_2: 0.1 213 | drop_feature_rate_1: 0.1 214 | drop_feature_rate_2: 0.4 215 | tau: 0.5 216 | num_epochs: 1500 217 | weight_decay: 0.00001 218 | drop_scheme: 'degree' 219 | rand_layers: 4 220 | -------------------------------------------------------------------------------- /model_ar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv, GATConv 5 | from torch_geometric.utils import to_dense_adj 6 | 7 | class Encoder(torch.nn.Module): 8 | def __init__(self, in_channels: int, out_channels: int, activation, 9 | base_model=GCNConv, k: int = 2): 10 | super(Encoder, self).__init__() 11 | self.base_model = base_model 12 | 13 | assert k >= 2 14 | self.k = k 15 | self.conv = [base_model(in_channels, 2 * out_channels)] 16 | for _ in range(1, k-1): 17 | self.conv.append(base_model(2 * out_channels, 2 * out_channels)) 18 | self.conv.append(base_model(2 * out_channels, out_channels)) 19 | self.conv = nn.ModuleList(self.conv) 20 | self.activation = activation 21 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor): 22 | for i in range(self.k): 23 | x = self.activation(self.conv[i](x, edge_index)) 24 | return x 25 | 26 | class GCN(torch.nn.Module): 27 | def __init__(self, in_channels: int, out_channels: int, n_class: int, activation, 28 | base_model=GCNConv, dropout: float=0.5): 29 | super(GCN, self).__init__() 30 | self.base_model = base_model 31 | 32 | self.conv1 = base_model(in_channels, out_channels) 33 | self.head = base_model(out_channels, n_class) 34 | self.dropout = dropout 35 | self.activation = activation 36 | 37 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor): 38 | x = F.dropout(x, self.dropout, training=self.training) 39 | x = self.activation(self.conv1(x, edge_index)) 40 | x = F.dropout(x, self.dropout, training=self.training) 41 | return F.log_softmax(self.head(x, edge_index), dim=1) 42 | 43 | class GAT(torch.nn.Module): 44 | def __init__(self, in_channels: int, out_channels: int, n_class: int, activation, 45 | base_model=GATConv, input_dropout: float=0.5, coef_dropout: float=0.5): 46 | super(GAT, self).__init__() 47 | self.base_model = base_model 48 | self.conv1 = base_model(in_channels, out_channels, 8, dropout=coef_dropout) 49 | self.head = base_model(out_channels*8, n_class, 1, dropout=coef_dropout) 50 | self.dropout = input_dropout 51 | self.activation = activation 52 | 53 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor): 54 | x = F.dropout(x, self.dropout, training=self.training) 55 | x = self.activation(self.conv1(x, edge_index)) 56 | x = F.dropout(x, self.dropout, training=self.training) 57 | return F.log_softmax(self.head(x, edge_index), dim=1) 58 | 59 | 60 | class Model(torch.nn.Module): 61 | def __init__(self, encoder: Encoder, num_hidden: int, num_proj_hidden: int, 62 | tau: float = 0.5): 63 | super(Model, self).__init__() 64 | self.encoder: Encoder = encoder 65 | self.tau: float = tau 66 | 67 | self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden) 68 | self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden) 69 | self.cos = nn.CosineSimilarity() 70 | 71 | 72 | def forward(self, x: torch.Tensor, 73 | adj: torch.Tensor) -> torch.Tensor: 74 | 75 | return self.encoder(x, adj) 76 | 77 | def projection(self, z: torch.Tensor) -> torch.Tensor: 78 | z = F.elu(self.fc1(z)) 79 | return self.fc2(z) 80 | 81 | def sim(self, z1: torch.Tensor, z2: torch.Tensor): 82 | z1 = F.normalize(z1) 83 | z2 = F.normalize(z2) 84 | return torch.mm(z1, z2.t()) 85 | 86 | def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor): 87 | f = lambda x: torch.exp(x / self.tau) 88 | refl_sim = f(self.sim(z1, z1)) 89 | between_sim = f(self.sim(z1, z2)) 90 | 91 | return -torch.log( 92 | between_sim.diag() 93 | / (refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag())) 94 | 95 | def batched_semi_loss(self, z1: torch.Tensor, z2: torch.Tensor, 96 | batch_size: int): 97 | # Space complexity: O(BN) (semi_loss: O(N^2)) 98 | device = z1.device 99 | num_nodes = z1.size(0) 100 | num_batches = (num_nodes - 1) // batch_size + 1 101 | f = lambda x: torch.exp(x / self.tau) 102 | indices = torch.arange(0, num_nodes).to(device) 103 | losses = [] 104 | 105 | for i in range(num_batches): 106 | mask = indices[i * batch_size:(i + 1) * batch_size] 107 | refl_sim = f(self.sim(z1[mask], z1)) # [B, N] 108 | between_sim = f(self.sim(z1[mask], z2)) # [B, N] 109 | 110 | losses.append(-torch.log( 111 | between_sim[:, i * batch_size:(i + 1) * batch_size].diag() 112 | / (refl_sim.sum(1) + between_sim.sum(1) 113 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag()))) 114 | 115 | return torch.cat(losses) 116 | 117 | def loss(self, z1: torch.Tensor, z2: torch.Tensor, 118 | mean: bool = True, batch_size: int = 0): 119 | h1 = self.projection(z1) 120 | h2 = self.projection(z2) 121 | simi = torch.exp(self.cos(h1,h2)/self.tau) 122 | 123 | if batch_size == 0: 124 | l1 = self.semi_loss(h1, h2) 125 | l2 = self.semi_loss(h2, h1) 126 | else: 127 | l1 = self.batched_semi_loss(h1, h2, batch_size) 128 | l2 = self.batched_semi_loss(h2, h1, batch_size) 129 | 130 | ret = (l1 + l2) * 0.5 131 | #ret = ret.mean() if mean else ret.sum() 132 | 133 | return ret, simi 134 | 135 | 136 | def drop_feature(x, drop_prob): 137 | drop_mask = torch.empty( 138 | (x.size(1), ), 139 | dtype=torch.float32, 140 | device=x.device).uniform_(0, 1) < drop_prob 141 | x = x.clone() 142 | x[:, drop_mask] = 0 143 | 144 | return x 145 | 146 | 147 | class LogReg(nn.Module): 148 | def __init__(self, ft_in, nb_classes): 149 | super(LogReg, self).__init__() 150 | self.fc = nn.Linear(ft_in, nb_classes) 151 | 152 | for m in self.modules(): 153 | self.weights_init(m) 154 | 155 | def weights_init(self, m): 156 | if isinstance(m, nn.Linear): 157 | torch.nn.init.xavier_uniform_(m.weight.data) 158 | if m.bias is not None: 159 | m.bias.data.fill_(0.0) 160 | 161 | def forward(self, seq): 162 | ret = self.fc(seq) 163 | return ret 164 | -------------------------------------------------------------------------------- /info.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional 4 | import os.path as osp 5 | import random 6 | import torch 7 | from torch.optim import Adam 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | from pGRACE.dataset import get_dataset 12 | from pGRACE.eval import MulticlassEvaluator 13 | from pGRACE.model import LogReg 14 | from pGRACE.utils import get_base_model, get_activation, \ 15 | generate_split, compute_pr, eigenvector_centrality 16 | def get_idx_split(dataset, split, preload_split): 17 | if split[:4] == 'rand': 18 | train_ratio = float(split.split(':')[1]) 19 | num_nodes = dataset[0].x.size(0) 20 | train_size = int(num_nodes * train_ratio) 21 | indices = torch.randperm(num_nodes) 22 | return { 23 | 'train': indices[:train_size], 24 | 'val': indices[train_size:2 * train_size], 25 | 'test': indices[2 * train_size:] 26 | } 27 | elif split == 'ogb': 28 | return dataset.get_idx_split() 29 | elif split.startswith('wikics'): 30 | split_idx = int(split.split(':')[1]) 31 | return { 32 | 'train': dataset[0].train_mask[:, split_idx], 33 | 'test': dataset[0].test_mask, 34 | 'val': dataset[0].val_mask[:, split_idx] 35 | } 36 | elif split == 'preloaded': 37 | assert preload_split is not None, 'use preloaded split, but preloaded_split is None' 38 | train_mask, test_mask, val_mask = preload_split 39 | return { 40 | 'train': train_mask, 41 | 'test': test_mask, 42 | 'val': val_mask 43 | } 44 | else: 45 | raise RuntimeError(f'Unknown split type {split}') 46 | 47 | class LogReg(nn.Module): 48 | def __init__(self, in_feat, num_class): 49 | super(LogReg, self).__init__() 50 | self.fc = nn.Linear(in_feat, num_class) 51 | for m in self.modules(): 52 | self.weight_init(m) 53 | 54 | def weight_init(self, m): 55 | if isinstance(m, nn.Linear): 56 | torch.nn.init.xavier_uniform_(m.weight.data) 57 | if m.bias != None: 58 | m.bias.data.fill_(0.0) 59 | 60 | def forward(self, feats): 61 | return self.fc(feats) 62 | 63 | def log_regression(z1, 64 | z2, 65 | dataset, 66 | evaluator, 67 | device, 68 | num_epochs: int = 5000, 69 | split: str = 'rand:0.1', 70 | verbose: bool = False, 71 | preload_split=None): 72 | z1 = torch.from_numpy(z1).to(device) 73 | z2 = torch.from_numpy(z2).to(device) 74 | num_hidden = z1.size(1) 75 | y = dataset[0].y.view(-1).to(device) 76 | num_classes = dataset[0].y.max().item() + 1 77 | classifier = LogReg(num_hidden, num_classes).to(device) 78 | classifier2 = LogReg(num_hidden, num_classes).to(device) 79 | optimizer = Adam(classifier.parameters(), lr=0.01, weight_decay=0.0) 80 | optimizer2 = Adam(classifier2.parameters(), lr=0.01, weight_decay=0.0) 81 | torch_seed = 0 82 | torch.manual_seed(torch_seed) 83 | random.seed(12345) 84 | split = get_idx_split(dataset, split, preload_split) 85 | split = {k: v.to(device) for k, v in split.items()} 86 | f = nn.LogSoftmax(dim=-1) 87 | 88 | nll_loss = nn.NLLLoss() 89 | 90 | best_test_acc = 0 91 | best_val_acc = 0 92 | best_epoch = 0 93 | 94 | for epoch in range(num_epochs): 95 | classifier.train() 96 | classifier2.train() 97 | optimizer.zero_grad() 98 | optimizer2.zero_grad() 99 | out1 = classifier(z1) 100 | out2 = classifier(z2) 101 | out3 = classifier2(z1) 102 | out4 = classifier2(z2) 103 | output = classifier(z1[split['train']]) 104 | loss = nll_loss(f(out1)[split['train']], y[split['train']]) + nll_loss(f(out2)[split['train']], y[split['train']])\ 105 | - F.kl_div(f(out3)[split['train']], F.softmax(out4,dim=-1)[split['train']]) - F.kl_div(f(out4)[split['train']], F.softmax(out3,dim=-1)[split['train']]) 106 | 107 | loss.backward() 108 | optimizer.step() 109 | optimizer2.step() 110 | 111 | if (epoch + 1) % 20 == 0: 112 | if 'val' in split: 113 | # val split is available 114 | test_acc = evaluator.eval({ 115 | 'y_true': y[split['test']].view(-1, 1), 116 | 'y_pred': classifier(z2[split['test']]).argmax(-1).view(-1, 1) 117 | })['acc'] 118 | val_acc = evaluator.eval({ 119 | 'y_true': y[split['val']].view(-1, 1), 120 | 'y_pred': classifier(z2[split['val']]).argmax(-1).view(-1, 1) 121 | })['acc'] 122 | if val_acc > best_val_acc: 123 | best_val_acc = val_acc 124 | best_test_acc = test_acc 125 | best_epoch = epoch 126 | info1 = nll_loss(f(out1), y).detach() 127 | info2 = nll_loss(f(out2), y).detach() 128 | info3 = F.kl_div(f(out3), F.softmax(out4,dim=-1)).detach() 129 | info4 = F.kl_div(f(out4), F.softmax(out3,dim=-1)).detach() 130 | else: 131 | acc = evaluator.eval({ 132 | 'y_true': y[split['test']].view(-1, 1), 133 | 'y_pred': classifier(z2[split['test']]).argmax(-1).view(-1, 1) 134 | })['acc'] 135 | if best_test_acc < acc: 136 | best_test_acc = acc 137 | best_epoch = epoch 138 | best_loss = loss 139 | info1 = nll_loss(f(out1), y).detach() 140 | info2 = nll_loss(f(out2), y).detach() 141 | info3 = F.kl_div(F.softmax(out3,dim=-1), F.softmax(out4,dim=-1)).detach() 142 | info4 = F.kl_div(F.softmax(out4,dim=-1), F.softmax(out3,dim=-1)).detach() 143 | if verbose: 144 | print(f'logreg epoch {epoch}: best test acc {best_test_acc}') 145 | 146 | return {'acc': best_test_acc}, info1, info2, info3, info4 147 | 148 | 149 | z1 = np.load('embedding/CiteSeerview1_embeddingnew2.npy') 150 | z2 = np.load('embedding/CiteSeerview2_embeddingnew2.npy') 151 | dataset = 'CiteSeer' 152 | 153 | path = osp.expanduser('~/datasets') 154 | path = osp.join(path, dataset) 155 | dataset = get_dataset(path, dataset) 156 | 157 | device = 'cuda:1' 158 | evaluator = MulticlassEvaluator() 159 | data = dataset[0] 160 | data = data.to(device) 161 | split = generate_split(data.num_nodes, train_ratio=0.1, val_ratio=0.1) 162 | res, info1, info2, info3, info4 = log_regression(z1, z2, dataset, evaluator, device, split='rand:0.1', num_epochs=3000, preload_split=split) 163 | 164 | print (res['acc']) 165 | print (info1) 166 | print (info2) 167 | print (info3) 168 | print (info4) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import nni 6 | 7 | import torch 8 | from torch_geometric.utils import dropout_adj, degree, to_undirected 9 | 10 | from simple_param.sp import SimpleParam 11 | from pGRACE.model import Encoder, GRACE 12 | from pGRACE.functional import drop_feature, drop_edge_weighted, \ 13 | degree_drop_weights, \ 14 | evc_drop_weights, pr_drop_weights, \ 15 | feature_drop_weights, drop_feature_weighted_2, feature_drop_weights_dense 16 | from pGRACE.eval import log_regression, MulticlassEvaluator 17 | from pGRACE.utils import get_base_model, get_activation, \ 18 | generate_split, compute_pr, eigenvector_centrality 19 | from pGRACE.dataset import get_dataset 20 | 21 | 22 | def train(): 23 | model.train() 24 | optimizer.zero_grad() 25 | 26 | def drop_edge(idx: int): 27 | global drop_weights 28 | 29 | if param['drop_scheme'] == 'uniform': 30 | return dropout_adj(data.edge_index, p=param[f'drop_edge_rate_{idx}'])[0] 31 | elif param['drop_scheme'] in ['degree', 'evc', 'pr']: 32 | return drop_edge_weighted(data.edge_index, drop_weights, p=param[f'drop_edge_rate_{idx}'], threshold=0.7) 33 | else: 34 | raise Exception(f'undefined drop scheme: {param["drop_scheme"]}') 35 | 36 | edge_index_1 = drop_edge(1) 37 | edge_index_2 = drop_edge(2) 38 | x_1 = drop_feature(data.x, param['drop_feature_rate_1']) 39 | x_2 = drop_feature(data.x, param['drop_feature_rate_2']) 40 | 41 | if param['drop_scheme'] in ['pr', 'degree', 'evc']: 42 | x_1 = drop_feature_weighted_2(data.x, feature_weights, param['drop_feature_rate_1']) 43 | x_2 = drop_feature_weighted_2(data.x, feature_weights, param['drop_feature_rate_2']) 44 | 45 | z1 = model(x_1, edge_index_1) 46 | z2 = model(x_2, edge_index_2) 47 | 48 | loss = model.loss(z1, z2, batch_size=1024 if args.dataset == 'Coauthor-Phy' else None) 49 | loss.backward() 50 | optimizer.step() 51 | 52 | return loss.item() 53 | 54 | 55 | def test(final=False): 56 | model.eval() 57 | z = model(data.x, data.edge_index) 58 | 59 | evaluator = MulticlassEvaluator() 60 | if args.dataset == 'WikiCS': 61 | accs = [] 62 | for i in range(20): 63 | acc = log_regression(z, dataset, evaluator, split=f'wikics:{i}', num_epochs=800)['acc'] 64 | accs.append(acc) 65 | acc = sum(accs) / len(accs) 66 | else: 67 | acc = log_regression(z, dataset, evaluator, split='rand:0.1', num_epochs=3000, preload_split=split)['acc'] 68 | 69 | if final and use_nni: 70 | nni.report_final_result(acc) 71 | elif use_nni: 72 | nni.report_intermediate_result(acc) 73 | 74 | return acc 75 | 76 | 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--device', type=str, default='cuda:2') 80 | parser.add_argument('--dataset', type=str, default='WikiCS') 81 | parser.add_argument('--param', type=str, default='local:wikics.json') 82 | parser.add_argument('--seed', type=int, default=39788) 83 | parser.add_argument('--verbose', type=str, default='train,eval,final') 84 | parser.add_argument('--save_split', type=str, nargs='?') 85 | parser.add_argument('--load_split', type=str, nargs='?') 86 | default_param = { 87 | 'learning_rate': 0.01, 88 | 'num_hidden': 256, 89 | 'num_proj_hidden': 32, 90 | 'activation': 'prelu', 91 | 'base_model': 'GCNConv', 92 | 'num_layers': 2, 93 | 'drop_edge_rate_1': 0.3, 94 | 'drop_edge_rate_2': 0.4, 95 | 'drop_feature_rate_1': 0.1, 96 | 'drop_feature_rate_2': 0.0, 97 | 'tau': 0.4, 98 | 'num_epochs': 3000, 99 | 'weight_decay': 1e-5, 100 | 'drop_scheme': 'degree', 101 | } 102 | 103 | # add hyper-parameters into parser 104 | param_keys = default_param.keys() 105 | for key in param_keys: 106 | parser.add_argument(f'--{key}', type=type(default_param[key]), nargs='?') 107 | args = parser.parse_args() 108 | 109 | # parse param 110 | sp = SimpleParam(default=default_param) 111 | param = sp(source=args.param, preprocess='nni') 112 | 113 | # merge cli arguments and parsed param 114 | for key in param_keys: 115 | if getattr(args, key) is not None: 116 | param[key] = getattr(args, key) 117 | 118 | use_nni = args.param == 'nni' 119 | if use_nni and args.device != 'cpu': 120 | args.device = 'cuda' 121 | 122 | torch_seed = args.seed 123 | torch.manual_seed(torch_seed) 124 | random.seed(12345) 125 | 126 | device = torch.device(args.device) 127 | 128 | path = osp.expanduser('~/datasets') 129 | path = osp.join(path, args.dataset) 130 | dataset = get_dataset(path, args.dataset) 131 | 132 | data = dataset[0] 133 | data = data.to(device) 134 | 135 | # generate split 136 | split = generate_split(data.num_nodes, train_ratio=0.1, val_ratio=0.1) 137 | 138 | if args.save_split: 139 | torch.save(split, args.save_split) 140 | elif args.load_split: 141 | split = torch.load(args.load_split) 142 | 143 | encoder = Encoder(dataset.num_features, param['num_hidden'], get_activation(param['activation']), 144 | base_model=get_base_model(param['base_model']), k=param['num_layers']).to(device) 145 | model = GRACE(encoder, param['num_hidden'], param['num_proj_hidden'], param['tau']).to(device) 146 | optimizer = torch.optim.Adam( 147 | model.parameters(), 148 | lr=param['learning_rate'], 149 | weight_decay=param['weight_decay'] 150 | ) 151 | 152 | if param['drop_scheme'] == 'degree': 153 | drop_weights = degree_drop_weights(data.edge_index).to(device) 154 | elif param['drop_scheme'] == 'pr': 155 | drop_weights = pr_drop_weights(data.edge_index, aggr='sink', k=200).to(device) 156 | elif param['drop_scheme'] == 'evc': 157 | drop_weights = evc_drop_weights(data).to(device) 158 | else: 159 | drop_weights = None 160 | 161 | if param['drop_scheme'] == 'degree': 162 | edge_index_ = to_undirected(data.edge_index) 163 | node_deg = degree(edge_index_[1]) 164 | if args.dataset == 'WikiCS': 165 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_deg).to(device) 166 | else: 167 | feature_weights = feature_drop_weights(data.x, node_c=node_deg).to(device) 168 | elif param['drop_scheme'] == 'pr': 169 | node_pr = compute_pr(data.edge_index) 170 | if args.dataset == 'WikiCS': 171 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_pr).to(device) 172 | else: 173 | feature_weights = feature_drop_weights(data.x, node_c=node_pr).to(device) 174 | elif param['drop_scheme'] == 'evc': 175 | node_evc = eigenvector_centrality(data) 176 | if args.dataset == 'WikiCS': 177 | feature_weights = feature_drop_weights_dense(data.x, node_c=node_evc).to(device) 178 | else: 179 | feature_weights = feature_drop_weights(data.x, node_c=node_evc).to(device) 180 | else: 181 | feature_weights = torch.ones((data.x.size(1),)).to(device) 182 | 183 | log = args.verbose.split(',') 184 | 185 | for epoch in range(1, param['num_epochs'] + 1): 186 | loss = train() 187 | if 'train' in log: 188 | print(f'(T) | Epoch={epoch:03d}, loss={loss:.4f}') 189 | 190 | if epoch % 100 == 0: 191 | acc = test() 192 | 193 | if 'eval' in log: 194 | print(f'(E) | Epoch={epoch:04d}, avg_acc = {acc}') 195 | 196 | acc = test(final=True) 197 | 198 | if 'final' in log: 199 | print(f'{acc}') 200 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import random 4 | import nni 5 | import yaml 6 | from yaml import SafeLoader 7 | import numpy as np 8 | import scipy 9 | import torch 10 | from torch_scatter import scatter_add 11 | import torch.nn as nn 12 | from torch_geometric.utils import dropout_adj, degree, to_undirected, get_laplacian 13 | import torch.nn.functional as F 14 | import networkx as nx 15 | from scipy.sparse.linalg import eigs, eigsh 16 | 17 | from torch_geometric.utils import get_laplacian, to_scipy_sparse_matrix 18 | from simple_param.sp import SimpleParam 19 | from pGRACE.model import Encoder, GRACE, NewGConv, NewEncoder, NewGRACE 20 | from pGRACE.functional import drop_feature, drop_edge_weighted, \ 21 | degree_drop_weights, \ 22 | evc_drop_weights, pr_drop_weights, \ 23 | feature_drop_weights, drop_feature_weighted_2, feature_drop_weights_dense 24 | from pGRACE.eval import log_regression, MulticlassEvaluator 25 | from pGRACE.utils import get_base_model, get_activation, \ 26 | generate_split, compute_pr, eigenvector_centrality 27 | from pGRACE.dataset import get_dataset 28 | from utils import normalize_adjacency_matrix, create_adjacency_matrix, load_adj_neg, Rank 29 | 30 | def train(): 31 | model.train() 32 | #view_learner.eval() 33 | optimizer.zero_grad() 34 | edge_index_1 = dropout_adj(data.edge_index, p=drop_edge_rate_1)[0] 35 | edge_index_2 = dropout_adj(data.edge_index, p=drop_edge_rate_2)[0] #adjacency with edge droprate 2 36 | 37 | x_1 = drop_feature(data.x, drop_feature_rate_1)#3 38 | x_2 = drop_feature(data.x, drop_feature_rate_2)#4 39 | #cora:3,3,6,3 40 | #CS:(1,2)(1,2)(2,3)(2,3) 41 | #AP:(3,4)(4,5)(1,2)(2,3) 42 | #Citseer(2,3)(3,4)(1,2)(1,2)(2,2) 43 | #CiteSeer(4,2)(3,2) 44 | #AC:(3,4)(1,4)(0,2)(1,3) 45 | #PubMed:(0,3)(1,3)(0,3)(0,2) 46 | #k2 = np.random.randint(0, 4) 47 | z1 = model(x_1, edge_index_1, [2, 2]) 48 | z2 = model(x_2, edge_index_2, [8, 8]) 49 | 50 | loss = model.loss(z1, z2, batch_size=64 if args.dataset == 'Coauthor-Phy' or args.dataset == 'ogbn-arxiv' else None) 51 | loss.backward() 52 | optimizer.step() 53 | 54 | return loss.item() 55 | 56 | 57 | def test(final=False): 58 | 59 | model.eval() 60 | z = model(data.x, data.edge_index, [1, 1], final=True) 61 | 62 | evaluator = MulticlassEvaluator() 63 | if args.dataset == 'WikiCS': 64 | accs = [] 65 | accs_1 = [] 66 | accs_2 = [] 67 | for i in range(20): 68 | acc = log_regression(z, dataset, evaluator, split=f'wikics:{i}', num_epochs=800)['acc'] 69 | accs.append(acc) 70 | acc = sum(accs) / len(accs) 71 | else: 72 | if args.dataset == 'Cora' or args.dataset == 'CiteSeer' or args.dataset == 'PubMed': 73 | #acc = log_regression(z, dataset, evaluator, split='preloaded', num_epochs=3000, preload_split=0)['acc'] 74 | acc = log_regression(z, dataset, evaluator, split='rand:0.1', num_epochs=3000, preload_split=0)['acc'] 75 | else : acc = log_regression(z, dataset, evaluator, split='rand:0.1', num_epochs=3000, preload_split=0)['acc'] 76 | #acc_2 = log_regression(z2, dataset, evaluator2, split='rand:0.1', num_epochs=3000, preload_split=split)['acc'] 77 | 78 | if final and use_nni: 79 | nni.report_final_result(acc) 80 | elif use_nni: 81 | nni.report_intermediate_result(acc) 82 | 83 | return acc#, acc_1, acc_2 84 | 85 | 86 | if __name__ == '__main__': 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--device', type=str, default='cuda:0') 89 | parser.add_argument('--dataset', type=str, default='Amazon-Computers') 90 | parser.add_argument('--config', type=str, default='param.yaml') 91 | parser.add_argument('--seed', type=int, default=0) 92 | parser.add_argument('--verbose', type=str, default='train,eval,final') 93 | parser.add_argument('--save_split', type=str, nargs='?') 94 | parser.add_argument('--load_split', type=str, nargs='?') 95 | args = parser.parse_args() 96 | 97 | config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset] 98 | 99 | torch.manual_seed(args.seed) 100 | random.seed(0) 101 | np.random.seed(args.seed) 102 | use_nni = args.config == 'nni' 103 | learning_rate = config['learning_rate'] 104 | num_hidden = config['num_hidden'] 105 | num_proj_hidden = config['num_proj_hidden'] 106 | activation = config['activation'] 107 | base_model = config['base_model'] 108 | num_layers = config['num_layers'] 109 | dataset = args.dataset 110 | drop_edge_rate_1 = config['drop_edge_rate_1'] 111 | drop_edge_rate_2 = config['drop_edge_rate_2'] 112 | drop_feature_rate_1 = config['drop_feature_rate_1'] 113 | drop_feature_rate_2 = config['drop_feature_rate_2'] 114 | drop_scheme = config['drop_scheme'] 115 | tau = config['tau'] 116 | num_epochs = config['num_epochs'] 117 | weight_decay = config['weight_decay'] 118 | rand_layers = config['rand_layers'] 119 | 120 | device = torch.device(args.device) 121 | 122 | path = osp.expanduser('~/datasets') 123 | path = osp.join(path, args.dataset) 124 | dataset = get_dataset(path, args.dataset) 125 | 126 | data = dataset[0] 127 | data = data.to(device) 128 | """ 129 | adj = create_adjacency_matrix(data.edge_index) 130 | I = scipy.sparse.eye(adj.shape[0]) 131 | adj, lap = normalize_adjacency_matrix(adj, I) 132 | adj = adj.tocoo() 133 | adj = torch.sparse.LongTensor(torch.LongTensor([adj.row.tolist(), adj.col.tolist()]), 134 | torch.FloatTensor(adj.data.astype(np.float))).to(device) 135 | 136 | from torch_geometric.utils import add_remaining_self_loops, add_self_loops 137 | edge_index = data.edge_index 138 | edge_weight = torch.ones((edge_index.size(1), ), 139 | device=edge_index.device) 140 | fill_value = 1. 141 | num_nodes = data.num_nodes 142 | edge_index, edge_weight = add_remaining_self_loops( 143 | edge_index, edge_weight, fill_value, num_nodes) 144 | 145 | edge_weight_x = edge_weight 146 | row, col = edge_index 147 | edge_index, edge_weight = get_laplacian(edge_index, edge_weight, num_nodes=num_nodes) 148 | deg = scatter_add(edge_weight_x, row, dim=0, dim_size=num_nodes) 149 | deg_inv_sqrt = deg.pow(-0.5) 150 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 151 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 152 | edge_index, edge_weight = add_self_loops( 153 | edge_index, -0.5*edge_weight, 1, num_nodes) 154 | L = to_scipy_sparse_matrix(edge_index, edge_weight, data.num_nodes) 155 | #L_3 = L.dot(L).dot(L) 156 | #L_6 = L.dot(L).dot(L).dot(L).dot(L).dot(L) 157 | eig_fn = eigs 158 | eig_3 = eig_fn(L, k=num_nodes-2, which='LM', return_eigenvectors=False) 159 | #eig_6 = eig_fn(L_6, k=num_nodes-2, which='LM', return_eigenvectors=False) 160 | np.savetxt('eig_4.txt', eig_3) 161 | #np.savetxt('eig_6.txt', eig_6) 162 | 163 | adj = create_adjacency_matrix(data.edge_index) 164 | #idx = Rank(adj, 1000, 0.9) 165 | #idx = idx.squeeze(0).numpy() 166 | I = scipy.sparse.eye(adj.shape[0]) 167 | TA = adj + I 168 | 169 | adj, lap = normalize_adjacency_matrix(adj, I) 170 | adj = adj.tocoo() 171 | TA = TA.tocoo() 172 | lap = lap.tocoo() 173 | TA = torch.sparse.LongTensor(torch.LongTensor([TA.row.tolist(), TA.col.tolist()]), 174 | torch.FloatTensor(TA.data.astype(np.float))) 175 | #TA = TA.to_dense().float() 176 | adj = torch.sparse.LongTensor(torch.LongTensor([adj.row.tolist(), adj.col.tolist()]), 177 | torch.FloatTensor(adj.data.astype(np.float))) 178 | #adj = adj.to_dense().float().to(device) 179 | adj = adj.float().to(device) 180 | #lap = torch.sparse.LongTensor(torch.LongTensor([lap.row.tolist(), lap.col.tolist()]), 181 | # torch.FloatTensor(lap.data.astype(np.float))) 182 | #lap = lap.to_dense().float() 183 | 184 | K = 8 185 | feat = data.x 186 | emb = feat 187 | for i in range(K): 188 | feat = torch.spmm(adj, feat) 189 | emb = emb + feat 190 | emb/=K 191 | """ 192 | adj = 0 193 | 194 | #if args.dataset == 'Cora' or args.dataset == 'CiteSeer' or args.dataset == 'PubMed': split = (data.train_mask, data.val_mask, data.test_mask) 195 | 196 | encoder = NewEncoder(dataset.num_features, num_hidden, get_activation(activation), 197 | base_model=NewGConv, k=num_layers).to(device) 198 | 199 | model = NewGRACE(encoder, adj, num_hidden, num_proj_hidden, tau).to(device) 200 | 201 | optimizer = torch.optim.Adam( 202 | model.parameters(), 203 | lr=learning_rate, 204 | weight_decay=weight_decay 205 | ) 206 | 207 | log = args.verbose.split(',') 208 | 209 | for epoch in range(1, num_epochs + 1): 210 | 211 | loss = train() 212 | if 'train' in log: 213 | print(f'(T) | Epoch={epoch:03d}, loss={loss:.4f}') 214 | if epoch % 100 == 0: 215 | acc = test() 216 | x_1 = drop_feature(data.x, drop_feature_rate_1)#3 217 | x_2 = drop_feature(data.x, drop_feature_rate_2)#4 218 | #x_3 = drop_feature(sub_x, drop_feature_rate_1) 219 | 220 | edge_index_2 = dropout_adj(data.edge_index, p=drop_edge_rate_2)[0] #adjacency with edge droprate 2 221 | edge_index_1 = dropout_adj(data.edge_index, p=drop_edge_rate_1)[0] #adjacency with edge droprate 2 222 | z = model(data.x, data.edge_index, [2, 2], final=True).detach().cpu().numpy() 223 | z1 = model(x_1, edge_index_1, [2, 2], final=True).detach().cpu().numpy() 224 | z2 = model(x_2, edge_index_2, [2, 2], final=True).detach().cpu().numpy() 225 | np.save('embedding/'+args.dataset + 'view1_embeddingfull.npy', z1) 226 | np.save('embedding/'+args.dataset + 'view2_embeddingfull.npy', z2) 227 | np.save('embedding/'+args.dataset + 'Graph_embeddingfull.npy', z) 228 | if 'eval' in log: 229 | print(f'(E) | Epoch={epoch:04d}, avg_acc = {acc}') 230 | 231 | 232 | acc = test(final=True) 233 | 234 | if 'final' in log: 235 | print(f'{acc}') 236 | 237 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster 0.7.12 2 | alembic 1.5.7 3 | anaconda-client 1.7.2 4 | anaconda-navigator 1.10.0 5 | anaconda-project 0.8.3 6 | appdirs 1.4.4 7 | argh 0.26.2 8 | argon2-cffi 20.1.0 9 | asn1crypto 1.4.0 10 | astor 0.8.1 11 | astroid 2.4.2 12 | astropy 4.0.2 13 | async-generator 1.10 14 | atomicwrites 1.4.0 15 | attrs 20.3.0 16 | autopep8 1.6.0 17 | Babel 2.8.1 18 | backcall 0.2.0 19 | backports.functools-lru-cache 1.6.4 20 | backports.shutil-get-terminal-size 1.0.0 21 | backports.tempfile 1.0 22 | backports.weakref 1.0.post1 23 | beautifulsoup4 4.9.3 24 | bitarray 1.6.1 25 | bkcharts 0.2 26 | bleach 3.2.1 27 | bokeh 2.2.3 28 | boto 2.49.0 29 | Bottleneck 1.3.2 30 | brotlipy 0.7.0 31 | certifi 2020.6.20 32 | cffi 1.14.3 33 | chardet 3.0.4 34 | click 7.1.2 35 | cliff 3.7.0 36 | cloudpickle 1.6.0 37 | clyent 1.2.2 38 | cmaes 0.8.2 39 | cmd2 1.5.0 40 | colorama 0.4.4 41 | colorlog 4.7.2 42 | conda 4.12.0 43 | conda-build 3.20.5 44 | conda-package-handling 1.7.3 45 | conda-verify 3.4.2 46 | contextlib2 0.6.0.post1 47 | cryptography 3.1.1 48 | cycler 0.10.0 49 | Cython 0.29.21 50 | cytoolz 0.11.0 51 | dask 2.30.0 52 | decorator 4.4.2 53 | defusedxml 0.6.0 54 | dgl 0.8.0.post1 55 | dgl-cu100 0.5.2 56 | dgl-cu101 0.8a211204 57 | dgl-cu110 0.6.1 58 | dgl-cu111 0.6.0.post1 59 | diff-match-patch 20200713 60 | distlib 0.3.1 61 | distributed 2.30.1 62 | docutils 0.16 63 | entrypoints 0.3 64 | et-xmlfile 1.0.1 65 | fastcache 1.1.0 66 | filelock 3.0.12 67 | flake8 3.8.4 68 | Flask 1.1.2 69 | fsspec 0.8.3 70 | future 0.18.2 71 | gensim 3.8.3 72 | gevent 20.9.0 73 | glob2 0.7 74 | gmpy2 2.0.8 75 | googledrivedownloader 0.4 76 | greenlet 0.4.17 77 | h5py 2.10.0 78 | HeapDict 1.0.1 79 | html5lib 1.1 80 | hyperopt 0.1.2 81 | idna 2.10 82 | imageio 2.9.0 83 | imagesize 1.2.0 84 | importlib-metadata 2.0.0 85 | iniconfig 1.1.1 86 | intervaltree 3.1.0 87 | ipykernel 5.3.4 88 | ipython 7.19.0 89 | ipython-genutils 0.2.0 90 | ipywidgets 7.5.1 91 | isodate 0.6.0 92 | isort 5.10.1 93 | itsdangerous 1.1.0 94 | jdcal 1.4.1 95 | jedi 0.17.1 96 | jeepney 0.5.0 97 | Jinja2 2.11.2 98 | joblib 0.17.0 99 | json-tricks 3.15.5 100 | json5 0.9.5 101 | jsonschema 3.2.0 102 | jupyter 1.0.0 103 | jupyter-client 6.1.7 104 | jupyter-console 6.2.0 105 | jupyter-core 4.6.3 106 | jupyterlab 2.2.6 107 | jupyterlab-pygments 0.1.2 108 | jupyterlab-server 1.2.0 109 | keras 2.8.0 110 | keyring 21.4.0 111 | kiwisolver 1.3.0 112 | lazy-object-proxy 1.4.3 113 | libarchive-c 2.9 114 | littleutils 0.2.2 115 | llvmlite 0.34.0 116 | locket 0.2.0 117 | lxml 4.6.1 118 | Mako 1.1.4 119 | MarkupSafe 1.1.1 120 | matplotlib 3.3.2 121 | mccabe 0.6.1 122 | memory-profiler 0.58.0 123 | mistune 0.8.4 124 | mkl-fft 1.2.0 125 | mkl-random 1.1.1 126 | mkl-service 2.3.0 127 | mock 4.0.2 128 | more-itertools 8.6.0 129 | mpmath 1.1.0 130 | msgpack 1.0.0 131 | multipledispatch 0.6.0 132 | munkres 1.1.4 133 | navigator-updater 0.2.1 134 | nbclient 0.5.1 135 | nbconvert 6.0.7 136 | nbformat 5.0.8 137 | nest-asyncio 1.4.2 138 | networkx 2.5 139 | nltk 3.5 140 | nose 1.3.7 141 | notebook 6.1.4 142 | numba 0.51.2 143 | numexpr 2.7.1 144 | numpy 1.19.2 145 | numpydoc 1.1.0 146 | ogb 1.3.1 147 | olefile 0.46 148 | openpyxl 3.0.5 149 | optuna 2.6.0 150 | outdated 0.2.1 151 | packaging 20.4 152 | pandas 1.1.3 153 | pandocfilters 1.4.3 154 | parso 0.7.0 155 | partd 1.1.0 156 | path 15.0.0 157 | pathlib2 2.3.5 158 | pathtools 0.1.2 159 | patsy 0.5.1 160 | pbr 5.5.1 161 | pep8 1.7.1 162 | pexpect 4.8.0 163 | pickleshare 0.7.5 164 | Pillow 8.0.1 165 | pip 20.2.4 166 | pkginfo 1.6.1 167 | pluggy 0.13.1 168 | ply 3.11 169 | prettytable 2.1.0 170 | prometheus-client 0.8.0 171 | prompt-toolkit 3.0.8 172 | protobuf 3.19.4 173 | psutil 5.7.2 174 | ptyprocess 0.6.0 175 | py 1.9.0 176 | pycodestyle 2.8.0 177 | pycosat 0.6.3 178 | pycparser 2.20 179 | pycurl 7.43.0.6 180 | pydocstyle 5.1.1 181 | pyflakes 2.2.0 182 | PyGCL 0.1.1 183 | Pygments 2.7.2 184 | pylint 2.6.0 185 | pymongo 4.0.1 186 | pyodbc 4.0.0-unsupported 187 | pyOpenSSL 19.1.0 188 | pyparsing 2.4.7 189 | pyperclip 1.8.2 190 | pyrsistent 0.17.3 191 | PySocks 1.7.1 192 | pytest 0.0.0 193 | python-dateutil 2.8.1 194 | python-editor 1.0.4 195 | python-jsonrpc-server 0.4.0 196 | python-language-server 0.35.1 197 | python-louvain 0.15 198 | PythonWebHDFS 0.2.3 199 | pytz 2020.1 200 | PyWavelets 1.1.1 201 | pyxdg 0.27 202 | PyYAML 5.3.1 203 | pyzmq 19.0.2 204 | QDarkStyle 2.8.1 205 | QtAwesome 1.0.1 206 | qtconsole 4.7.7 207 | QtPy 1.9.0 208 | rdflib 6.0.0 209 | regex 2020.10.15 210 | requests 2.24.0 211 | responses 0.16.0 212 | rope 0.18.0 213 | Rtree 0.9.4 214 | ruamel-yaml 0.15.87 215 | ruamel.yaml.clib 0.2.6 216 | schema 0.7.5 217 | scikit-image 0.17.2 218 | scikit-learn 1.0.1 219 | scipy 1.5.2 220 | seaborn 0.11.0 221 | SecretStorage 3.1.2 222 | Send2Trash 1.5.0 223 | setuptools 50.3.1.post20201107 224 | simplegeneric 0.8.1 225 | simplejson 3.17.6 226 | singledispatch 3.4.0.3 227 | sip 4.19.13 228 | six 1.15.0 229 | smart-open 4.1.0 230 | snowballstemmer 2.0.0 231 | sortedcollections 1.2.1 232 | sortedcontainers 2.2.2 233 | soupsieve 2.0.1 234 | Sphinx 3.2.1 235 | sphinxcontrib-applehelp 1.0.2 236 | sphinxcontrib-devhelp 1.0.2 237 | sphinxcontrib-htmlhelp 1.0.3 238 | sphinxcontrib-jsmath 1.0.1 239 | sphinxcontrib-qthelp 1.0.3 240 | sphinxcontrib-serializinghtml 1.1.4 241 | sphinxcontrib-websupport 1.2.4 242 | spyder 4.1.5 243 | spyder-kernels 1.9.4 244 | SQLAlchemy 1.3.20 245 | statsmodels 0.12.0 246 | stevedore 3.3.0 247 | sympy 1.6.2 248 | tables 3.6.1 249 | tabulate 0.8.7 250 | tblib 1.7.0 251 | tensorboardX 2.5 252 | terminado 0.9.1 253 | testpath 0.4.4 254 | threadpoolctl 2.1.0 255 | tifffile 2020.10.1 256 | toml 0.10.1 257 | toolz 0.11.1 258 | torch 1.9.0+cu111 259 | torch-cluster 1.5.9 260 | torch-geometric 2.0.3 261 | torch-scatter 2.0.9 262 | torch-sparse 0.6.12 263 | torch-spline-conv 1.2.1 264 | torchaudio 0.9.0 265 | torchvision 0.10.0+cu111 266 | tornado 6.0.4 267 | tqdm 4.50.2 268 | traitlets 5.0.5 269 | tsnecuda 3.0.1 270 | typing-extensions 3.7.4.3 271 | ujson 4.0.1 272 | unicodecsv 0.14.1 273 | urllib3 1.25.11 274 | virtualenv 20.4.0 275 | watchdog 0.10.3 276 | wcwidth 0.2.5 277 | webencodings 0.5.1 278 | websockets 10.1 279 | Werkzeug 1.0.1 280 | wheel 0.35.1 281 | widgetsnbextension 3.5.1 282 | wrapt 1.11.2 283 | wurlitzer 2.0.1 284 | xlrd 1.2.0 285 | XlsxWriter 1.3.7 286 | xlwt 1.3.0 287 | xmltodict 0.12.0 288 | yacs 0.1.8 289 | yapf 0.30.0 290 | zict 2.0.0 291 | zipp 3.4.0 292 | zope.event 4.5.0 293 | zope.interface 5.1.2 -------------------------------------------------------------------------------- /pGRACE/model_mb.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import random 7 | import numpy as np 8 | from torch_geometric.nn import GCNConv, GATConv 9 | from ssgc import Net 10 | 11 | from torch.nn import Sequential, Linear, ReLU 12 | from torch_geometric.nn import MessagePassing 13 | from torch_geometric.utils import add_self_loops, degree 14 | from torch_geometric.datasets import TUDataset 15 | 16 | import torch 17 | from torch.nn import Parameter 18 | from torch_scatter import scatter_add 19 | from torch_geometric.nn.conv import MessagePassing 20 | from torch_geometric.utils import add_remaining_self_loops 21 | import math 22 | from typing import Optional 23 | 24 | import torch 25 | from torch import nn 26 | import torch.nn.functional as F 27 | 28 | from torch_geometric.nn import GCNConv 29 | 30 | def glorot(tensor):#inits.py中 31 | if tensor is not None: 32 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 33 | tensor.data.uniform_(-stdv, stdv)#将tensor的值设置为-stdv, stdv之间 34 | def zeros(tensor): 35 | if tensor is not None: 36 | tensor.data.fill_(0) 37 | 38 | 39 | 40 | class NewGConv(MessagePassing): 41 | r"""The graph convolutional operator from the `"Semi-supervised 42 | Classification with Graph Convolutional Networks" 43 | `_ paper 44 | 45 | .. math:: 46 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 47 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 48 | 49 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 50 | adjacency matrix with inserted self-loops and 51 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 52 | 53 | Args: 54 | in_channels (int): Size of each input sample. 55 | out_channels (int): Size of each output sample. 56 | improved (bool, optional): If set to :obj:`True`, the layer computes 57 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 58 | (default: :obj:`False`) 59 | cached (bool, optional): If set to :obj:`True`, the layer will cache 60 | the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 61 | \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the 62 | cached version for further executions. 63 | This parameter should only be set to :obj:`True` in transductive 64 | learning scenarios. (default: :obj:`False`) 65 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 66 | an additive bias. (default: :obj:`True`) 67 | normalize (bool, optional): Whether to add self-loops and apply 68 | symmetric normalization. (default: :obj:`True`) 69 | **kwargs (optional): Additional arguments of 70 | :class:`torch_geometric.nn.conv.MessagePassing`. 71 | """ 72 | 73 | def __init__(self, in_channels, out_channels, improved=False, cached=False, 74 | bias=True, normalize=True, **kwargs): 75 | super(NewGConv, self).__init__(aggr='add', **kwargs) 76 | 77 | self.in_channels = in_channels#输入通道数,也就是X的shape[1] 78 | self.out_channels = out_channels#输出通道数 79 | self.improved = improved#$设置为true时A尖等于A+2I 80 | 81 | self.cached = cached#If set to True, the layer will cache the computation of D^−1/2A^D^−1/2 on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False) 82 | self.normalize = normalize#是否添加自环并应用对称归一化。 83 | 84 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 85 | 86 | if bias:#如果设置为False,则该层将不会学习加法偏差 87 | self.bias = Parameter(torch.Tensor(out_channels)) 88 | else: 89 | self.register_parameter('bias', None) 90 | 91 | self.reset_parameters() 92 | 93 | def reset_parameters(self): 94 | glorot(self.weight)#glorot函数下面有写,初始化weight矩阵 95 | zeros(self.bias)#zeros函数下面有写,初始化偏置矩阵 96 | self.cached_result = None 97 | self.cached_num_edges = None 98 | 99 | 100 | @staticmethod 101 | def norm(edge_index, num_nodes, edge_weight=None, improved=False, 102 | dtype=None): 103 | if edge_weight is None: 104 | edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, 105 | device=edge_index.device) 106 | 107 | fill_value = 1 if not improved else 2 108 | edge_index, edge_weight = add_remaining_self_loops( 109 | edge_index, edge_weight, fill_value, num_nodes) 110 | 111 | row, col = edge_index 112 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 113 | deg_inv_sqrt = deg.pow(-0.5) 114 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 115 | 116 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 117 | 118 | 119 | def forward(self, x, edge_index, c, edge_weight=None): 120 | """""" 121 | x = torch.matmul(x, self.weight)#将x与权重矩阵相乘 122 | 123 | if self.cached and self.cached_result is not None: 124 | if edge_index.size(1) != self.cached_num_edges: 125 | raise RuntimeError( 126 | 'Cached {} number of edges, but found {}. Please ' 127 | 'disable the caching behavior of this layer by removing ' 128 | 'the `cached=True` argument in its constructor.'.format( 129 | self.cached_num_edges, edge_index.size(1))) 130 | 131 | if not self.cached or self.cached_result is None: 132 | self.cached_num_edges = edge_index.size(1) 133 | if self.normalize: 134 | edge_index, norm = self.norm(edge_index, x.size( 135 | self.node_dim), edge_weight, self.improved, x.dtype) 136 | else: 137 | norm = edge_weight 138 | self.cached_result = edge_index, norm 139 | 140 | edge_index, norm = self.cached_result 141 | for _ in range(c): 142 | x = x + self.propagate(edge_index, x=x, norm=norm) 143 | return x 144 | 145 | 146 | def message(self, x_j, norm): 147 | return norm.view(-1, 1) * x_j if norm is not None else x_j 148 | 149 | def update(self, aggr_out): 150 | if self.bias is not None: 151 | aggr_out = aggr_out + self.bias 152 | return aggr_out 153 | 154 | def __repr__(self): 155 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 156 | self.out_channels) 157 | 158 | class NewEncoder(nn.Module): 159 | def __init__(self, in_channels: int, out_channels: int, activation, base_model=GATConv, k: int = 2, skip=False): 160 | super(NewEncoder, self).__init__() 161 | self.base_model = base_model 162 | assert k >= 1 163 | self.k = k 164 | self.skip = skip 165 | self.out = out_channels 166 | hi = 2 167 | if k == 1: 168 | self.conv = [base_model(in_channels, out_channels).jittable()] 169 | self.conv = nn.ModuleList(self.conv) 170 | self.activation = activation 171 | elif not self.skip: 172 | self.conv = [base_model(in_channels, hi * out_channels)] 173 | for _ in range(1, k - 1): 174 | self.conv.append(base_model(hi * out_channels, hi * out_channels)) 175 | self.conv.append(base_model(hi * out_channels, out_channels)) 176 | self.conv = nn.ModuleList(self.conv) 177 | 178 | self.activation = activation 179 | else: 180 | self.fc_skip = nn.Linear(in_channels, out_channels) 181 | self.conv = [base_model(in_channels, out_channels)] 182 | for _ in range(1, k): 183 | self.conv.append(base_model(out_channels, out_channels)) 184 | self.conv = nn.ModuleList(self.conv) 185 | 186 | self.activation = activation 187 | 188 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor, l = [1, 1]): 189 | 190 | for i in range(self.k): 191 | K = np.random.randint(0, 4) 192 | feat = x 193 | #emb = feat 194 | #print (K) 195 | x = self.activation(self.conv[i](feat, edge_index, K)) 196 | 197 | return x 198 | 199 | class NewGRACE(torch.nn.Module): 200 | def __init__(self, encoder: NewEncoder, A: torch.Tensor, adj:torch.Tensor, num_hidden: int, num_proj_hidden: int, tau: float = 0.5): 201 | super(NewGRACE, self).__init__() 202 | self.encoder: NewEncoder = encoder 203 | #self.encoder2: Encoder = encoder2 204 | self.BCE = torch.nn.BCELoss() 205 | self.tau: float = tau 206 | self.adj = adj 207 | self.A = A 208 | self.norm = (A.shape[0] * A.shape[0]) / (float((A.shape[0] * A.shape[0] - torch.sum(A))) * 2) 209 | self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden) 210 | self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden) 211 | 212 | self.num_hidden = num_hidden 213 | 214 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor, l = [0, 0]) -> torch.Tensor: 215 | return self.encoder(x, edge_index, l)#, self.encoder2(x, edge_index) 216 | 217 | def projection(self, z: torch.Tensor) -> torch.Tensor: 218 | z = F.elu(self.fc1(z)) 219 | #z = self.fc1(z) 220 | return self.fc2(z) 221 | 222 | def sim(self, z1: torch.Tensor, z2: torch.Tensor): 223 | z1 = F.normalize(z1) 224 | z2 = F.normalize(z2) 225 | return torch.mm(z1, z2.t()) 226 | 227 | def recLoss(self, z1: torch.Tensor, z2: torch.Tensor): 228 | f = lambda x: torch.exp(x / self.tau) 229 | between_sim = f(self.sim(z1, z2)) 230 | ret = -torch.log(between_sim.diag() / between_sim.sum(1)) 231 | ret = ret.mean() 232 | return ret 233 | 234 | def GAELoss(self, z: torch.Tensor): 235 | act = nn.Sigmoid() 236 | return self.norm * self.BCE(act(torch.mm(z, z.t())), self.A) 237 | 238 | 239 | def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor, k = 0, r=0.3): 240 | f = lambda x: torch.exp(x / self.tau) 241 | refl_sim = f(self.sim(z1, z1)) 242 | between_sim = f(self.sim(z1, z2)) 243 | sim_gate = torch.sigmoid(between_sim) 244 | 245 | sim_f = between_sim 246 | pos_sim = between_sim 247 | for i in range(k): 248 | between_sim = torch.mm(self.adj, between_sim) 249 | pos_sim = (1 - r**(i+1)) * pos_sim + (r**(i+1)) * between_sim 250 | between_sim = pos_sim 251 | 252 | return -torch.log(between_sim.diag() / (refl_sim.sum(1) + sim_f.sum(1) - refl_sim.diag())) 253 | 254 | def bn_loss(self, z1: torch.Tensor, z2: torch.Tensor, z3: torch.Tensor): 255 | f = lambda x: torch.exp(x / self.tau) 256 | refl_sim = f(self.sim(z1, z3)) 257 | refl2_sim = f(self.sim(z2, z3)) 258 | between_sim = f(self.sim(z1, z2)) 259 | 260 | #return -torch.log(refl_sim.diag() / refl_sim.sum(1)) - torch.log(refl2_sim.diag() / (refl2_sim.sum(1))) 261 | return -torch.log(refl_sim.diag() / (between_sim.sum(1) + refl_sim.sum(1))) - torch.log(refl2_sim.diag() / (between_sim.sum(1) + refl2_sim.sum(1))) 262 | 263 | def batched_semi_loss(self, z1: torch.Tensor, z2: torch.Tensor, batch_size: int): 264 | # Space complexity: O(BN) (semi_loss: O(N^2)) 265 | device = z1.device 266 | num_nodes = z1.size(0) 267 | num_batches = (num_nodes - 1) // batch_size + 1 268 | f = lambda x: torch.exp(x / self.tau) 269 | indices = torch.arange(0, num_nodes).to(device) 270 | losses = [] 271 | 272 | for i in range(num_batches): 273 | mask = indices[i * batch_size:(i + 1) * batch_size] 274 | refl_sim = f(self.sim(z1[mask], z1)) # [B, N] 275 | between_sim = f(self.sim(z1[mask], z2)) # [B, N] 276 | 277 | losses.append(-torch.log(between_sim[:, i * batch_size:(i + 1) * batch_size].diag() 278 | / (refl_sim.sum(1) + between_sim.sum(1) 279 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag()))) 280 | 281 | return torch.cat(losses) 282 | 283 | def loss(self, z1: torch.Tensor, z2: torch.Tensor, z3: torch.Tensor, mean: bool = True, batch_size: Optional[int] = None): 284 | h1 = self.projection(z1) 285 | h2 = self.projection(z2) 286 | h3 = self.projection(z3) 287 | if batch_size is None: 288 | l1 = self.semi_loss(h1, h2) 289 | l2 = self.semi_loss(h2, h1) 290 | else: 291 | l1 = self.batched_semi_loss(h1, h2, batch_size) 292 | l2 = self.batched_semi_loss(h2, h1, batch_size) 293 | 294 | ret = (l1 + l2) * 0.5 295 | #ret = l1 296 | ret = ret.mean() if mean else ret.sum() 297 | #ret = ret# + 0.1 * self.GAELoss(z1) + 0.1 * self.GAELoss(z2) 298 | return ret 299 | 300 | 301 | class LogReg(nn.Module): 302 | def __init__(self, ft_in, nb_classes): 303 | super(LogReg, self).__init__() 304 | self.fc = nn.Linear(ft_in, nb_classes) 305 | 306 | for m in self.modules(): 307 | self.weights_init(m) 308 | 309 | def weights_init(self, m): 310 | if isinstance(m, nn.Linear): 311 | torch.nn.init.xavier_uniform_(m.weight.data) 312 | if m.bias is not None: 313 | m.bias.data.fill_(0.0) 314 | 315 | def forward(self, seq): 316 | ret = self.fc(seq) 317 | return ret 318 | 319 | class ViewLearner(torch.nn.Module): 320 | def __init__(self, encoder: Encoder, A: torch.Tensor, mlp_edge_model_dim: int = 64): 321 | super(ViewLearner, self).__init__() 322 | self.encoder: Encoder = encoder 323 | self.input_dim = self.encoder.out 324 | self.hidden = 128 325 | self.tau = 0.4 326 | self.BCE = torch.nn.BCELoss() 327 | self.A = A 328 | self.norm = (A.shape[0] * A.shape[0]) / (float((A.shape[0] * A.shape[0] - torch.sum(A))) * 2) 329 | self.mlp_edge_model = Sequential( 330 | Linear(self.input_dim * 2, mlp_edge_model_dim), 331 | ReLU(), 332 | Linear(mlp_edge_model_dim, 1) 333 | ) 334 | self.predict = nn.Sequential(nn.Linear(self.input_dim, self.hidden), nn.PReLU(), nn.Linear(self.hidden, self.input_dim)) 335 | self.init_emb() 336 | 337 | def sim(self, z1: torch.Tensor, z2: torch.Tensor): 338 | z1 = F.normalize(z1) 339 | z2 = F.normalize(z2) 340 | return torch.mm(z1, z2.t()) 341 | def predLoss(self, z1: torch.Tensor, z2: torch.Tensor): 342 | x = F.normalize(z1, dim=-1, p=2) 343 | y = F.normalize(z2, dim=-1, p=2) 344 | ret = 2 - 2 * (x * y).sum(dim=-1) 345 | ret = ret.mean() 346 | return ret 347 | def recLoss(self, z1: torch.Tensor, z2: torch.Tensor): 348 | f = lambda x: torch.exp(x / self.tau) 349 | between_sim = f(self.sim(z1, z2)) 350 | ret = -torch.log(between_sim.diag() / between_sim.sum(1)) 351 | ret = ret.mean() 352 | return ret 353 | 354 | def GAELoss(self, z: torch.Tensor): 355 | act = nn.Sigmoid() 356 | return self.norm * self.BCE(act(torch.mm(z, z.t())), self.A) 357 | 358 | def init_emb(self): 359 | for m in self.modules(): 360 | if isinstance(m, Linear): 361 | torch.nn.init.xavier_uniform_(m.weight.data) 362 | if m.bias is not None: 363 | m.bias.data.fill_(0.0) 364 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor): 365 | node_emb = self.encoder(x, edge_index) 366 | src, dst = edge_index[0], edge_index[1] 367 | emb_src = node_emb[src] 368 | emb_dst = node_emb[dst] 369 | 370 | edge_emb = torch.cat([emb_src, emb_dst], 1) 371 | edge_logits = self.mlp_edge_model(edge_emb) 372 | 373 | return edge_logits 374 | -------------------------------------------------------------------------------- /pGRACE/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import random 7 | import numpy as np 8 | from torch_geometric.nn import GCNConv, GATConv 9 | from ssgc import Net 10 | 11 | from torch.nn import Sequential, Linear, ReLU 12 | from torch_geometric.nn import MessagePassing 13 | from torch_geometric.utils import add_self_loops, degree 14 | from torch_geometric.datasets import TUDataset 15 | 16 | import torch 17 | from torch.nn import Parameter 18 | from torch_scatter import scatter_add 19 | from torch_geometric.nn.conv import MessagePassing 20 | from torch_geometric.utils import add_remaining_self_loops 21 | import math 22 | from typing import Optional 23 | from typing import Optional, Tuple, Union 24 | 25 | import torch 26 | import torch.nn.functional as F 27 | from torch import Tensor 28 | from torch.nn import Parameter 29 | from torch_sparse import SparseTensor, set_diag 30 | 31 | from torch_geometric.nn.dense.linear import Linear 32 | from torch_geometric.typing import ( 33 | Adj, 34 | NoneType, 35 | OptPairTensor, 36 | OptTensor, 37 | Size, 38 | ) 39 | from torch_geometric.utils import add_self_loops, remove_self_loops, softmax 40 | 41 | #from ..inits import glorot, zeros 42 | import torch 43 | from torch import nn 44 | import torch.nn.functional as F 45 | 46 | from torch_geometric.nn import GCNConv 47 | 48 | 49 | class Encoder(nn.Module): 50 | def __init__(self, in_channels: int, out_channels: int, activation, base_model=GCNConv, k: int = 2, skip=False): 51 | super(Encoder, self).__init__() 52 | self.base_model = base_model 53 | 54 | assert k >= 2 55 | self.k = k 56 | self.skip = skip 57 | if not self.skip: 58 | self.conv = [base_model(in_channels, 2 * out_channels).jittable()] 59 | for _ in range(1, k - 1): 60 | self.conv.append(base_model(2 * out_channels, 2 * out_channels)) 61 | self.conv.append(base_model(2 * out_channels, out_channels)) 62 | self.conv = nn.ModuleList(self.conv) 63 | 64 | self.activation = activation 65 | else: 66 | self.fc_skip = nn.Linear(in_channels, out_channels) 67 | self.conv = [base_model(in_channels, out_channels)] 68 | for _ in range(1, k): 69 | self.conv.append(base_model(out_channels, out_channels)) 70 | self.conv = nn.ModuleList(self.conv) 71 | 72 | self.activation = activation 73 | 74 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor): 75 | if not self.skip: 76 | for i in range(self.k): 77 | x = self.activation(self.conv[i](x, edge_index)) 78 | return x 79 | else: 80 | h = self.activation(self.conv[0](x, edge_index)) 81 | hs = [self.fc_skip(x), h] 82 | for i in range(1, self.k): 83 | u = sum(hs) 84 | hs.append(self.activation(self.conv[i](u, edge_index))) 85 | return hs[-1] 86 | 87 | 88 | class GRACE(torch.nn.Module): 89 | def __init__(self, encoder: Encoder, num_hidden: int, num_proj_hidden: int, tau: float = 0.5): 90 | super(GRACE, self).__init__() 91 | self.encoder: Encoder = encoder 92 | self.tau: float = tau 93 | 94 | self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden) 95 | self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden) 96 | 97 | self.num_hidden = num_hidden 98 | 99 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: 100 | return self.encoder(x, edge_index) 101 | 102 | def projection(self, z: torch.Tensor) -> torch.Tensor: 103 | z = F.elu(self.fc1(z)) 104 | return self.fc2(z) 105 | 106 | def sim(self, z1: torch.Tensor, z2: torch.Tensor): 107 | z1 = F.normalize(z1) 108 | z2 = F.normalize(z2) 109 | return torch.mm(z1, z2.t()) 110 | 111 | def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor): 112 | f = lambda x: torch.exp(x / self.tau) 113 | refl_sim = f(self.sim(z1, z1)) 114 | between_sim = f(self.sim(z1, z2)) 115 | 116 | return -torch.log(between_sim.diag() / (refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag())) 117 | 118 | def batched_semi_loss(self, z1: torch.Tensor, z2: torch.Tensor, batch_size: int): 119 | # Space complexity: O(BN) (semi_loss: O(N^2)) 120 | device = z1.device 121 | num_nodes = z1.size(0) 122 | num_batches = (num_nodes - 1) // batch_size + 1 123 | f = lambda x: torch.exp(x / self.tau) 124 | indices = torch.arange(0, num_nodes).to(device) 125 | losses = [] 126 | 127 | for i in range(num_batches): 128 | mask = indices[i * batch_size:(i + 1) * batch_size] 129 | refl_sim = f(self.sim(z1[mask], z1)) # [B, N] 130 | between_sim = f(self.sim(z1[mask], z2)) # [B, N] 131 | 132 | losses.append(-torch.log(between_sim[:, i * batch_size:(i + 1) * batch_size].diag() 133 | / (refl_sim.sum(1) + between_sim.sum(1) 134 | - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag()))) 135 | 136 | return torch.cat(losses) 137 | 138 | def loss(self, z1: torch.Tensor, z2: torch.Tensor, mean: bool = True, batch_size: Optional[int] = None): 139 | h1 = self.projection(z1) 140 | h2 = self.projection(z2) 141 | 142 | if batch_size is None: 143 | l1 = self.semi_loss(h1, h2) 144 | l2 = self.semi_loss(h2, h1) 145 | else: 146 | l1 = self.batched_semi_loss(h1, h2, batch_size) 147 | l2 = self.batched_semi_loss(h2, h1, batch_size) 148 | 149 | ret = (l1 + l2) * 0.5 150 | ret = ret.mean() if mean else ret.sum() 151 | 152 | return ret 153 | 154 | def glorot(tensor):#inits.py中 155 | if tensor is not None: 156 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 157 | tensor.data.uniform_(-stdv, stdv)#将tensor的值设置为-stdv, stdv之间 158 | def zeros(tensor): 159 | if tensor is not None: 160 | tensor.data.fill_(0) 161 | 162 | 163 | 164 | class NewGConv(MessagePassing): 165 | r"""The graph convolutional operator from the `"Semi-supervised 166 | Classification with Graph Convolutional Networks" 167 | `_ paper 168 | 169 | .. math:: 170 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 171 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 172 | 173 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 174 | adjacency matrix with inserted self-loops and 175 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 176 | 177 | Args: 178 | in_channels (int): Size of each input sample. 179 | out_channels (int): Size of each output sample. 180 | improved (bool, optional): If set to :obj:`True`, the layer computes 181 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 182 | (default: :obj:`False`) 183 | cached (bool, optional): If set to :obj:`True`, the layer will cache 184 | the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 185 | \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the 186 | cached version for further executions. 187 | This parameter should only be set to :obj:`True` in transductive 188 | learning scenarios. (default: :obj:`False`) 189 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 190 | an additive bias. (default: :obj:`True`) 191 | normalize (bool, optional): Whether to add self-loops and apply 192 | symmetric normalization. (default: :obj:`True`) 193 | **kwargs (optional): Additional arguments of 194 | :class:`torch_geometric.nn.conv.MessagePassing`. 195 | """ 196 | 197 | def __init__(self, in_channels, out_channels, improved=False, cached=False, 198 | bias=True, normalize=True, **kwargs): 199 | super(NewGConv, self).__init__(aggr='add', **kwargs) 200 | 201 | self.in_channels = in_channels#输入通道数,也就是X的shape[1] 202 | self.out_channels = out_channels#输出通道数 203 | self.improved = improved#$设置为true时A尖等于A+2I 204 | 205 | self.cached = cached#If set to True, the layer will cache the computation of D^−1/2A^D^−1/2 on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False) 206 | self.normalize = normalize#是否添加自环并应用对称归一化。 207 | 208 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 209 | 210 | if bias:#如果设置为False,则该层将不会学习加法偏差 211 | self.bias = Parameter(torch.Tensor(out_channels)) 212 | else: 213 | self.register_parameter('bias', None) 214 | 215 | self.reset_parameters() 216 | 217 | def reset_parameters(self): 218 | glorot(self.weight)#glorot函数下面有写,初始化weight矩阵 219 | zeros(self.bias)#zeros函数下面有写,初始化偏置矩阵 220 | self.cached_result = None 221 | self.cached_num_edges = None 222 | 223 | 224 | @staticmethod 225 | def norm(edge_index, num_nodes, edge_weight=None, improved=False, 226 | dtype=None): 227 | if edge_weight is None: 228 | edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, 229 | device=edge_index.device) 230 | 231 | fill_value = 1 if not improved else 2 232 | edge_index, edge_weight = add_remaining_self_loops( 233 | edge_index, edge_weight, fill_value, num_nodes) 234 | 235 | row, col = edge_index 236 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 237 | deg_inv_sqrt = deg.pow(-0.5) 238 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 239 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 240 | 241 | return edge_index, edge_weight 242 | 243 | 244 | def forward(self, x, edge_index, c, edge_weight=None): 245 | """""" 246 | x = torch.matmul(x, self.weight)#将x与权重矩阵相乘 247 | 248 | if self.cached and self.cached_result is not None: 249 | if edge_index.size(1) != self.cached_num_edges: 250 | raise RuntimeError( 251 | 'Cached {} number of edges, but found {}. Please ' 252 | 'disable the caching behavior of this layer by removing ' 253 | 'the `cached=True` argument in its constructor.'.format( 254 | self.cached_num_edges, edge_index.size(1))) 255 | 256 | if not self.cached or self.cached_result is None: 257 | self.cached_num_edges = edge_index.size(1) 258 | if self.normalize: 259 | edge_index, norm = self.norm(edge_index, x.size( 260 | self.node_dim), edge_weight, self.improved, x.dtype) 261 | else: 262 | norm = edge_weight 263 | self.cached_result = edge_index, norm 264 | 265 | edge_index, norm = self.cached_result 266 | if c == 0: 267 | x = self.propagate(edge_index, x=x, norm=norm) 268 | for _ in range(c): 269 | x = 1 * x + 1 * self.propagate(edge_index, x=x, norm=norm) 270 | x = 0.5 * x 271 | return x 272 | 273 | 274 | def message(self, x_j, norm): 275 | return norm.view(-1, 1) * x_j if norm is not None else x_j 276 | 277 | def update(self, aggr_out): 278 | if self.bias is not None: 279 | aggr_out = aggr_out + self.bias 280 | return aggr_out 281 | 282 | def __repr__(self): 283 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 284 | self.out_channels) 285 | 286 | class NewEncoder(nn.Module): 287 | def __init__(self, in_channels: int, out_channels: int, activation, base_model=GATConv, k: int = 2, skip=False): 288 | super(NewEncoder, self).__init__() 289 | self.base_model = base_model 290 | #self.adj = adj 291 | assert k >= 1 292 | self.k = k 293 | self.skip = skip 294 | self.out = out_channels 295 | hi = 1 296 | if k == 1: 297 | self.conv = [base_model(in_channels, out_channels).jittable()] 298 | self.conv = nn.ModuleList(self.conv) 299 | self.activation = activation 300 | elif not self.skip: 301 | self.conv = [base_model(in_channels, hi * out_channels)] 302 | for _ in range(1, k - 1): 303 | self.conv.append(base_model(hi * out_channels, hi * out_channels)) 304 | self.conv.append(base_model(hi * out_channels, out_channels)) 305 | self.conv = nn.ModuleList(self.conv) 306 | 307 | self.activation = activation 308 | else: 309 | self.fc_skip = nn.Linear(in_channels, out_channels) 310 | self.conv = [base_model(in_channels, out_channels)] 311 | for _ in range(1, k): 312 | self.conv.append(base_model(out_channels, out_channels)) 313 | self.conv = nn.ModuleList(self.conv) 314 | 315 | self.activation = activation 316 | 317 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor, R = [1,2], final = False): 318 | if final == False: 319 | K2 = np.random.randint(0, R[1]) 320 | K1 = np.random.randint(0, R[0]) 321 | if final: 322 | K1 = 2 323 | K2 = 2 324 | feat = x 325 | x = self.activation(self.conv[0](feat, edge_index, K1)) 326 | x = self.conv[1](x, edge_index, K2) 327 | 328 | return x 329 | 330 | class NewGRACE(torch.nn.Module): 331 | def __init__(self, encoder: NewEncoder, Adj, num_hidden: int, num_proj_hidden: int, tau: float = 0.5): 332 | super(NewGRACE, self).__init__() 333 | self.encoder: NewEncoder = encoder 334 | #self.encoder2: Encoder = encoder2 335 | #self.BCE = torch.nn.BCELoss() 336 | self.tau: float = tau 337 | self.adj = Adj 338 | #self.adj = adj 339 | #self.A = A 340 | #self.norm = (A.shape[0] * A.shape[0]) / (float((A.shape[0] * A.shape[0] - torch.sum(A))) * 2) 341 | self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden) 342 | self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden) 343 | 344 | self.num_hidden = num_hidden 345 | 346 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor, R = [1,2], final = False) -> torch.Tensor: 347 | return self.encoder(x, edge_index, R, final)#, self.encoder2(x, edge_index) 348 | 349 | def projection(self, z: torch.Tensor) -> torch.Tensor: 350 | z = F.elu(self.fc1(z)) 351 | #z = self.fc1(z) 352 | return self.fc2(z) 353 | 354 | def sim(self, z1: torch.Tensor, z2: torch.Tensor): 355 | z1 = F.normalize(z1) 356 | z2 = F.normalize(z2) 357 | return torch.mm(z1, z2.t()) 358 | 359 | def recLoss(self, z1: torch.Tensor, z2: torch.Tensor): 360 | f = lambda x: torch.exp(x / self.tau) 361 | between_sim = f(self.sim(z1, z2)) 362 | iden = torch.tensor(np.eye(between_sim.shape[0])) 363 | ref_sim = f(self.sim(z1, z1) - iden) 364 | ret = -torch.log(between_sim.diag() / ref_sim.sum(1)) 365 | ret = ret.mean() 366 | return ret 367 | """ 368 | def GAELoss(self, z: torch.Tensor): 369 | act = nn.Sigmoid() 370 | return self.norm * self.BCE(act(torch.mm(z, z.t())), self.A) 371 | """ 372 | 373 | def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor, k = 0, r=0.9): 374 | 375 | f = lambda x: torch.exp(x / self.tau) 376 | 377 | 378 | z1 = F.normalize(z1) 379 | z2 = F.normalize(z2) 380 | 381 | between_sim = f(torch.mm(z1, z2.t())) 382 | 383 | return -torch.log(between_sim.diag() / between_sim.sum(1)) 384 | 385 | 386 | def loss(self, z1: torch.Tensor, z2: torch.Tensor, mean: bool = True, batch_size: Optional[int] = None): 387 | h1 = z1 388 | h2 = z2 389 | if batch_size is None: 390 | l = self.semi_loss(h1, h2) 391 | else: 392 | l = self.semi_loss(h1, h2) 393 | 394 | ret = l 395 | #ret = l1 396 | ret = ret.mean() if mean else ret.sum() 397 | #ret = ret# + 0.1 * self.GAELoss(z1) + 0.1 * self.GAELoss(z2) 398 | return ret 399 | 400 | 401 | class LogReg(nn.Module): 402 | def __init__(self, ft_in, nb_classes): 403 | super(LogReg, self).__init__() 404 | self.fc = nn.Linear(ft_in, nb_classes) 405 | 406 | for m in self.modules(): 407 | self.weights_init(m) 408 | 409 | def weights_init(self, m): 410 | if isinstance(m, nn.Linear): 411 | torch.nn.init.xavier_uniform_(m.weight.data) 412 | if m.bias is not None: 413 | m.bias.data.fill_(0.0) 414 | 415 | def forward(self, seq): 416 | ret = self.fc(seq) 417 | return ret 418 | --------------------------------------------------------------------------------