├── framework.png
├── scripts
│   ├── run_cora.sh
│   ├── run_texas.sh
│   ├── run_actor.sh
│   ├── run_citeseer.sh
│   ├── run_cornell.sh
│   ├── run_pubmed.sh
│   ├── run_squirrel.sh
│   ├── run_wiki_cs.sh
│   ├── run_chameleon.sh
│   ├── run_coauthor_cs.sh
│   ├── run_winsconsin.sh
│   ├── run_ama_photo.sh
│   ├── run_coauthor_phy.sh
│   └── run_ama_computers.sh
├── README.md
├── LICENSE
├── data_loader.py
├── utils.py
├── main.py
└── model.py
/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yixinliu233/GREET/HEAD/framework.png -------------------------------------------------------------------------------- /scripts/run_cora.sh: -------------------------------------------------------------------------------- 1 | # Cora 2 | python main.py -dataset cora -ntrials 10 -sparse 0 -epochs 400 -cl_batch_size 0 -nlayers_proj 1 -alpha 0.5 -k 20 -maskfeat_rate_1 0.8 -maskfeat_rate_2 0.1 -dropedge_rate_1 0.8 -dropedge_rate_2 0.8 -lr_disc 0.001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 2 -eval_freq 5 -------------------------------------------------------------------------------- /scripts/run_texas.sh: -------------------------------------------------------------------------------- 1 | # Texas 2 | python main.py -dataset texas -ntrials 10 -sparse 0 -epochs 400 -cl_batch_size 0 -nlayers_proj 2 -alpha 0.5 -k 20 -maskfeat_rate_1 0.5 -maskfeat_rate_2 0.1 -dropedge_rate_1 0.1 -dropedge_rate_2 0.1 -lr_disc 0.001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 2 -eval_freq 20 -------------------------------------------------------------------------------- /scripts/run_actor.sh: -------------------------------------------------------------------------------- 1 | # Actor 2 | python main.py -dataset actor -ntrials 10 -sparse 1 -epochs 1000 -cl_batch_size 0 -nlayers_proj 2 -alpha 0.1 -k 20 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.5 -dropedge_rate_1 0.5 -dropedge_rate_2 0.8 -lr_disc 0.001 -margin_hom 0.1 -margin_het 0.5 -cl_rounds 2 -eval_freq 50 3 | -------------------------------------------------------------------------------- /scripts/run_citeseer.sh: -------------------------------------------------------------------------------- 1 | # CiteSeer 2 | python main.py -dataset citeseer -ntrials 10 -sparse 0 -epochs 400 -cl_batch_size 0 -nlayers_proj 2 -alpha 0.1 -k 30 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.1 -dropedge_rate_1 0.8 -dropedge_rate_2 0.1 -lr_disc 0.001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 2 -eval_freq 5 -------------------------------------------------------------------------------- /scripts/run_cornell.sh: -------------------------------------------------------------------------------- 1 | # Cornell 2 | python main.py -dataset cornell -ntrials 10 -sparse 0 -epochs 400 -cl_batch_size 0 -nlayers_proj 2 -alpha 0.3 -k 25 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.5 -dropedge_rate_1 0.5 -dropedge_rate_2 0.1 -lr_disc 0.001 -margin_hom 0.3 -margin_het 0.3 -cl_rounds 2 -eval_freq 10 -------------------------------------------------------------------------------- /scripts/run_pubmed.sh: -------------------------------------------------------------------------------- 1 | # PubMed 2 | python main.py -dataset pubmed -ntrials 10 -sparse 1 -epochs 800 -cl_batch_size 5000 -nlayers_proj 2 -alpha 0.1 -k 0 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.5 -dropedge_rate_1 0.5 -dropedge_rate_2 0.1 -lr_disc 0.001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 2 -eval_freq 20 -------------------------------------------------------------------------------- /scripts/run_squirrel.sh:
-------------------------------------------------------------------------------- 1 | # Squirrel 2 | python main.py -dataset squirrel -ntrials 10 -sparse 0 -epochs 1000 -cl_batch_size 0 -nlayers_proj 2 -alpha 0.1 -k 0 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.1 -dropedge_rate_1 0.1 -dropedge_rate_2 0.8 -lr_disc 0.001 -margin_hom 0.1 -margin_het 0.3 -cl_rounds 2 -eval_freq 50 -------------------------------------------------------------------------------- /scripts/run_wiki_cs.sh: -------------------------------------------------------------------------------- 1 | # Wiki-CS 2 | python main.py -dataset wikics -ntrials 10 -sparse 1 -epochs 1500 -cl_batch_size 3000 -nlayers_proj 1 -alpha 0.1 -k 30 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.1 -dropedge_rate_1 0.5 -dropedge_rate_2 0.5 -lr_disc 0.001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 2 -eval_freq 50 -------------------------------------------------------------------------------- /scripts/run_chameleon.sh: -------------------------------------------------------------------------------- 1 | # Chameleon 2 | python main.py -dataset chameleon -ntrials 10 -sparse 0 -epochs 500 -cl_batch_size 0 -nlayers_proj 1 -alpha 0.1 -k 0 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.5 -dropedge_rate_1 0.5 -dropedge_rate_2 0.1 -lr_disc 0.001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 2 -eval_freq 20 -------------------------------------------------------------------------------- /scripts/run_coauthor_cs.sh: -------------------------------------------------------------------------------- 1 | # CoAuthor CS 2 | python main.py -dataset cs -ntrials 10 -sparse 1 -epochs 1500 -cl_batch_size 5000 -nlayers_proj 1 -alpha 0.5 -k 30 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.5 -dropedge_rate_1 0.5 -dropedge_rate_2 0.1 -lr_disc 0.001 -margin_hom 0.1 -margin_het 0.5 -cl_rounds 2 -eval_freq 50 -------------------------------------------------------------------------------- /scripts/run_winsconsin.sh: -------------------------------------------------------------------------------- 1 | # Wisconsin 2 | python main.py -dataset wisconsin -ntrials 10 -sparse 0 -epochs 400 -cl_batch_size 0 -nlayers_proj 2 -alpha 0.5 -k 25 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.1 -dropedge_rate_1 0.1 -dropedge_rate_2 0.3 -lr_disc 0.001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 2 -eval_freq 20 -------------------------------------------------------------------------------- /scripts/run_ama_photo.sh: -------------------------------------------------------------------------------- 1 | # Amazon Photo 2 | python main.py -dataset photo -ntrials 10 -sparse 1 -epochs 1500 -cl_batch_size 5000 -nlayers_proj 1 -alpha 0.3 -k 30 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.1 -dropedge_rate_1 0.8 -dropedge_rate_2 0.5 -lr_disc 0.0001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 3 -eval_freq 20 -------------------------------------------------------------------------------- /scripts/run_coauthor_phy.sh: -------------------------------------------------------------------------------- 1 | # CoAuthor Physics 2 | python main.py -dataset physics -ntrials 10 -sparse 1 -epochs 800 -cl_batch_size 2000 -nlayers_proj 1 -alpha 0.1 -k 25 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.5 -dropedge_rate_1 0.5 -dropedge_rate_2 0.1 -lr_disc 0.001 -margin_hom 0.5 -margin_het 0.5 -cl_rounds 2 -eval_freq 50 -------------------------------------------------------------------------------- /scripts/run_ama_computers.sh: -------------------------------------------------------------------------------- 1 | # Amazon Computers 2 | python main.py -dataset computers 
-ntrials 10 -sparse 1 -epochs 1500 -cl_batch_size 5000 -nlayers_proj 1 -alpha 0.3 -k 10 -maskfeat_rate_1 0.1 -maskfeat_rate_2 0.1 -dropedge_rate_1 0.5 -dropedge_rate_2 0.1 -lr_disc 0.0001 -margin_hom 0.1 -margin_het 0.5 -cl_rounds 3 -eval_freq 20 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GREET 2 | This is the source code of the AAAI'23 paper "Beyond Smoothing: Unsupervised Graph Representation Learning with Edge Heterophily Discriminating". 3 | 4 | ![The proposed framework](framework.png) 5 | 6 | ## Requirements 7 | This code requires the following: 8 | * Python==3.9 9 | * PyTorch==1.11.0 10 | * PyTorch Geometric==2.0.4 11 | * DGL==0.8.0 12 | * NumPy==1.21.2 13 | * Scikit-learn==1.0.2 14 | * SciPy==1.7.3 15 | 16 | ## Usage 17 | Run the script corresponding to the dataset you want. For instance: 18 | 19 | ``` 20 | bash scripts/run_cora.sh 21 | ``` 22 | 23 | ## Cite 24 | 25 | If you compare with, build on, or use aspects of this work, please cite the following: 26 | ``` 27 | @inproceedings{ 28 | liu2023beyond, 29 | title={Beyond Smoothing: Unsupervised Graph Representation Learning with Edge Heterophily Discriminating}, 30 | author={Liu, Yixin and Zheng, Yizhen and Zhang, Daokun and Lee, Vincent and Pan, Shirui}, 31 | booktitle={AAAI}, 32 | year={2023} 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yixin Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import scipy.sparse as sp 4 | import numpy as np 5 | import os 6 | import torch_geometric.transforms as T 7 | from torch_geometric.datasets import Planetoid, WikipediaNetwork, Actor, WebKB, Amazon, Coauthor, WikiCS 8 | from torch_geometric.utils import remove_self_loops 9 | 10 | warnings.simplefilter("ignore") 11 | 12 | 13 | def get_split(num_samples: int, train_ratio: float = 0.1, test_ratio: float = 0.8, num_splits: int = 10): 14 | 15 | assert train_ratio + test_ratio < 1 16 | train_size = int(num_samples * train_ratio) 17 | test_size = int(num_samples * test_ratio) 18 | 19 | trains, vals, tests = [], [], [] 20 | 21 | for _ in range(num_splits): 22 | indices = torch.randperm(num_samples) 23 | 24 | train_mask = torch.zeros(num_samples, dtype=torch.bool) 25 | train_mask.fill_(False) 26 | train_mask[indices[:train_size]] = True 27 | 28 | test_mask = torch.zeros(num_samples, dtype=torch.bool) 29 | test_mask.fill_(False) 30 | test_mask[indices[train_size: test_size + train_size]] = True 31 | 32 | val_mask = torch.zeros(num_samples, dtype=torch.bool) 33 | val_mask.fill_(False) 34 | val_mask[indices[test_size + train_size:]] = True 35 | 36 | trains.append(train_mask.unsqueeze(1)) 37 | vals.append(val_mask.unsqueeze(1)) 38 | tests.append(test_mask.unsqueeze(1)) 39 | 40 | train_mask_all = torch.cat(trains, 1) 41 | val_mask_all = torch.cat(vals, 1) 42 | test_mask_all = torch.cat(tests, 1) 43 | 44 | return train_mask_all, val_mask_all, test_mask_all 45 | 46 | 47 | def get_structural_encoding(edges, nnodes, str_enc_dim=16): 48 | 49 | row = edges[0, :].numpy() 50 | col = edges[1, :].numpy() 51 | data = np.ones_like(row) 52 | 53 | A = sp.csr_matrix((data, (row, col)), shape=(nnodes, nnodes)) 54 | D = (np.array(A.sum(1)).squeeze()) ** -1.0 55 | 56 | Dinv = sp.diags(D) 57 | RW = A * Dinv 58 | M = RW 59 | 60 | SE = [torch.from_numpy(M.diagonal()).float()] 61 | M_power = M 62 | for _ in range(str_enc_dim - 1): 63 | M_power = M_power * M 64 | SE.append(torch.from_numpy(M_power.diagonal()).float()) 65 | SE = torch.stack(SE, dim=-1) 66 | return SE 67 | 68 | 69 | def load_data(dataset_name): 70 | 71 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.', 'data', dataset_name) 72 | 73 | if dataset_name in ['cora', 'citeseer', 'pubmed']: 74 | dataset = Planetoid(path, dataset_name) 75 | elif dataset_name in ['chameleon']: 76 | dataset = WikipediaNetwork(path, dataset_name) 77 | elif dataset_name in ['squirrel']: 78 | dataset = WikipediaNetwork(path, dataset_name, transform=T.NormalizeFeatures()) 79 | elif dataset_name in ['actor']: 80 | dataset = Actor(path) 81 | elif dataset_name in ['cornell', 'texas', 'wisconsin']: 82 | dataset = WebKB(path, dataset_name) 83 | elif dataset_name in ['computers', 'photo']: 84 | dataset = Amazon(path, dataset_name, transform=T.NormalizeFeatures()) 85 | elif dataset_name in ['cs', 'physics']: 86 | dataset = Coauthor(path, dataset_name, transform=T.NormalizeFeatures()) 87 | elif dataset_name in ['wikics']: 88 | dataset = WikiCS(path) 89 | 90 | data = dataset[0] 91 | 92 | edges = remove_self_loops(data.edge_index)[0] 93 | 94 | features = data.x 95 | [nnodes, nfeats] = features.shape 96 | nclasses = torch.max(data.y).item() + 1 97 | 98 | if dataset_name in ['computers', 'photo', 'cs', 'physics', 'wikics']: 99 | train_mask, val_mask, test_mask = 
get_split(nnodes) 100 | else: 101 | train_mask, val_mask, test_mask = data.train_mask, data.val_mask, data.test_mask 102 | 103 | if len(train_mask.shape) < 2: 104 | train_mask = train_mask.unsqueeze(1) 105 | val_mask = val_mask.unsqueeze(1) 106 | test_mask = test_mask.unsqueeze(1) 107 | 108 | labels = data.y 109 | 110 | path = '../data/se/{}'.format(dataset_name) 111 | if not os.path.exists(path): 112 | os.makedirs(path) 113 | file_name = path + '/{}_{}.pt'.format(dataset_name, 16) 114 | if os.path.exists(file_name): 115 | se = torch.load(file_name) 116 | # print('Load exist structural encoding.') 117 | else: 118 | print('Computing structural encoding...') 119 | se = get_structural_encoding(edges, nnodes) 120 | torch.save(se, file_name) 121 | print('Done. The structural encoding is saved as: {}.'.format(file_name)) 122 | 123 | return features, edges, se, train_mask, val_mask, test_mask, labels, nnodes, nfeats 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch 4 | import torch.nn.functional as F 5 | import dgl 6 | import time 7 | from sklearn.neighbors import kneighbors_graph 8 | from sklearn.metrics import accuracy_score 9 | from sklearn.linear_model import LogisticRegression 10 | from sklearn.model_selection import GridSearchCV 11 | from sklearn.multiclass import OneVsRestClassifier 12 | from sklearn.preprocessing import normalize 13 | 14 | EOS = 1e-10 15 | 16 | 17 | def split_batch(init_list, batch_size): 18 | groups = zip(*(iter(init_list),) * batch_size) 19 | end_list = [list(i) for i in groups] 20 | count = len(init_list) % batch_size 21 | end_list.append(init_list[-count:]) if count != 0 else end_list 22 | return end_list 23 | 24 | 25 | def get_feat_mask(features, mask_rate): 26 | feat_node = features.shape[1] 27 | mask = torch.zeros(features.shape) 28 | samples = np.random.choice(feat_node, size=int(feat_node * mask_rate), replace=False) 29 | mask[:, samples] = 1 30 | if torch.cuda.is_available(): 31 | mask = mask.cuda() 32 | return mask, samples 33 | 34 | 35 | def normalize_adj(adj, mode, sparse=False): 36 | if not sparse: 37 | if mode == "sym": 38 | inv_sqrt_degree = 1. / (torch.sqrt(adj.sum(dim=1, keepdim=False)) + EOS) 39 | return inv_sqrt_degree[:, None] * adj * inv_sqrt_degree[None, :] 40 | elif mode == "row": 41 | inv_degree = 1. / (adj.sum(dim=1, keepdim=False) + EOS) 42 | return inv_degree[:, None] * adj 43 | else: 44 | exit("wrong norm mode") 45 | else: 46 | adj = adj.coalesce() 47 | if mode == "sym": 48 | inv_sqrt_degree = 1. / (torch.sqrt(torch.sparse.sum(adj, dim=1).values())) 49 | D_value = inv_sqrt_degree[adj.indices()[0]] * inv_sqrt_degree[adj.indices()[1]] 50 | 51 | elif mode == "row": 52 | inv_degree = 1. 
/ (torch.sparse.sum(adj, dim=1).values() + EOS) 53 | D_value = inv_degree[adj.indices()[0]] 54 | else: 55 | exit("wrong norm mode") 56 | new_values = adj.values() * D_value 57 | 58 | return torch.sparse.FloatTensor(adj.indices(), new_values, adj.size()) 59 | 60 | 61 | def get_adj_from_edges(edges, weights, nnodes): 62 | adj = torch.zeros(nnodes, nnodes).cuda() 63 | adj[edges[0], edges[1]] = weights 64 | return adj 65 | 66 | 67 | def augmentation(features_1, adj_1, features_2, adj_2, args, training): 68 | # view 1 69 | mask_1, _ = get_feat_mask(features_1, args.maskfeat_rate_1) 70 | features_1 = features_1 * (1 - mask_1) 71 | if not args.sparse: 72 | adj_1 = F.dropout(adj_1, p=args.dropedge_rate_1, training=training) 73 | else: 74 | adj_1.edata['w'] = F.dropout(adj_1.edata['w'], p=args.dropedge_rate_1, training=training) 75 | 76 | # # view 2 77 | mask_2, _ = get_feat_mask(features_1, args.maskfeat_rate_2) 78 | features_2 = features_2 * (1 - mask_2) 79 | if not args.sparse: 80 | adj_2 = F.dropout(adj_2, p=args.dropedge_rate_2, training=training) 81 | else: 82 | adj_2.edata['w'] = F.dropout(adj_2.edata['w'], p=args.dropedge_rate_2, training=training) 83 | 84 | return features_1, adj_1, features_2, adj_2 85 | 86 | 87 | def generate_random_node_pairs(nnodes, nedges, backup=300): 88 | rand_edges = np.random.choice(nnodes, size=(nedges + backup) * 2, replace=True) 89 | rand_edges = rand_edges.reshape((2, nedges + backup)) 90 | rand_edges = torch.from_numpy(rand_edges) 91 | rand_edges = rand_edges[:, rand_edges[0,:] != rand_edges[1,:]] 92 | rand_edges = rand_edges[:, 0: nedges] 93 | return rand_edges.cuda() 94 | 95 | 96 | def eval_debug_mode(embedding, labels, train_mask, val_mask, test_mask): 97 | 98 | t1 = time.time() 99 | 100 | X = embedding.detach().cpu().numpy() 101 | Y = labels.detach().cpu().numpy() 102 | 103 | X = normalize(X, norm='l2') 104 | 105 | nb_split = train_mask.shape[1] 106 | 107 | accs = [] 108 | for split in range(nb_split): 109 | X_train = X[train_mask.cpu()[:, split]] 110 | X_test = X[test_mask.cpu()[:, split]] 111 | y_train = Y[train_mask.cpu()[:, split]] 112 | y_test = Y[test_mask.cpu()[:, split]] 113 | 114 | logreg = LogisticRegression(solver='liblinear') 115 | c = 2.0 ** np.arange(-10, 10) 116 | clf = GridSearchCV(estimator=OneVsRestClassifier(logreg), 117 | param_grid=dict(estimator__C=c), n_jobs=8, cv=5, 118 | verbose=0) 119 | clf.fit(X_train, y_train) 120 | 121 | y_pred = clf.predict(X_test) 122 | acc = accuracy_score(y_test, y_pred) 123 | accs.append(acc) 124 | 125 | print('eval time:{:.4f}s'.format(time.time() - t1)) 126 | 127 | return accs 128 | 129 | 130 | def eval_test_mode(embedding, labels, train_mask, val_mask, test_mask): 131 | 132 | X = embedding.detach().cpu().numpy() 133 | Y = labels.detach().cpu().numpy() 134 | X = normalize(X, norm='l2') 135 | 136 | X_train = X[train_mask.cpu()] 137 | X_val = X[val_mask.cpu()] 138 | X_test = X[test_mask.cpu()] 139 | y_train = Y[train_mask.cpu()] 140 | y_val = Y[val_mask.cpu()] 141 | y_test = Y[test_mask.cpu()] 142 | 143 | logreg = LogisticRegression(solver='liblinear') 144 | c = 2.0 ** np.arange(-10, 10) 145 | clf = GridSearchCV(estimator=OneVsRestClassifier(logreg), 146 | param_grid=dict(estimator__C=c), n_jobs=8, cv=5, 147 | verbose=0) 148 | clf.fit(X_train, y_train) 149 | 150 | y_pred_test = clf.predict(X_test) 151 | acc_test = accuracy_score(y_test, y_pred_test) 152 | y_pred_val = clf.predict(X_val) 153 | acc_val = accuracy_score(y_val, y_pred_val) 154 | 155 | return acc_test * 100, acc_val * 100 156 | 157 | 
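The dense "sym" branch of normalize_adj in utils.py computes the symmetric normalization D^-1/2 A D^-1/2, with the small constant EOS guarding against zero degrees. Below is a minimal sanity-check sketch of that identity on a toy 3-node graph; it is illustrative only, assumes the repository's dependencies (PyTorch, DGL, scikit-learn) are installed so that utils.py imports cleanly, and runs entirely on CPU:

```
import torch
from utils import normalize_adj, EOS

# Toy undirected adjacency: node 0 is connected to nodes 1 and 2.
adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])
deg = adj.sum(dim=1)                   # degrees: [2., 1., 1.]
inv_sqrt = 1.0 / (deg.sqrt() + EOS)    # same smoothing constant as normalize_adj
expected = torch.diag(inv_sqrt) @ adj @ torch.diag(inv_sqrt)

print(torch.allclose(normalize_adj(adj, 'sym'), expected))  # True
```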
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import dgl 6 | import random 7 | 8 | from data_loader import load_data 9 | from model import * 10 | from utils import * 11 | 12 | EOS = 1e-10 13 | 14 | 15 | def setup_seed(seed): 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | torch.backends.cudnn.deterministic = True 19 | np.random.seed(seed) 20 | random.seed(seed) 21 | dgl.seed(seed) 22 | dgl.random.seed(seed) 23 | 24 | 25 | def train_cl(cl_model, discriminator, optimizer_cl, features, str_encodings, edges): 26 | 27 | cl_model.train() 28 | discriminator.eval() 29 | 30 | adj_1, adj_2, weights_lp, _ = discriminator(torch.cat((features, str_encodings), 1), edges) 31 | features_1, adj_1, features_2, adj_2 = augmentation(features, adj_1, features, adj_2, args, cl_model.training) 32 | cl_loss = cl_model(features_1, adj_1, features_2, adj_2) 33 | 34 | optimizer_cl.zero_grad() 35 | cl_loss.backward() 36 | optimizer_cl.step() 37 | 38 | return cl_loss.item() 39 | 40 | 41 | def train_discriminator(cl_model, discriminator, optimizer_disc, features, str_encodings, edges, args): 42 | 43 | cl_model.eval() 44 | discriminator.train() 45 | 46 | adj_1, adj_2, weights_lp, weights_hp = discriminator(torch.cat((features, str_encodings), 1), edges) 47 | rand_np = generate_random_node_pairs(features.shape[0], edges.shape[1]) 48 | psu_label = torch.ones(edges.shape[1]).cuda() 49 | 50 | embedding = cl_model.get_embedding(features, adj_1, adj_2) 51 | edge_emb_sim = F.cosine_similarity(embedding[edges[0]], embedding[edges[1]]) 52 | 53 | rnp_emb_sim_lp = F.cosine_similarity(embedding[rand_np[0]], embedding[rand_np[1]]) 54 | loss_lp = F.margin_ranking_loss(edge_emb_sim, rnp_emb_sim_lp, psu_label, margin=args.margin_hom, reduction='none') 55 | loss_lp *= torch.relu(weights_lp - 0.5) 56 | 57 | rnp_emb_sim_hp = F.cosine_similarity(embedding[rand_np[0]], embedding[rand_np[1]]) 58 | loss_hp = F.margin_ranking_loss(rnp_emb_sim_hp, edge_emb_sim, psu_label, margin=args.margin_het, reduction='none') 59 | loss_hp *= torch.relu(weights_hp - 0.5) 60 | 61 | rank_loss = (loss_lp.mean() + loss_hp.mean()) / 2 62 | 63 | optimizer_disc.zero_grad() 64 | rank_loss.backward() 65 | optimizer_disc.step() 66 | 67 | return rank_loss.item() 68 | 69 | 70 | def main(args): 71 | 72 | setup_seed(0) 73 | features, edges, str_encodings, train_mask, val_mask, test_mask, labels, nnodes, nfeats = load_data(args.dataset) 74 | results = [] 75 | 76 | for trial in range(args.ntrials): 77 | 78 | setup_seed(trial) 79 | 80 | cl_model = GCL(nlayers=args.nlayers_enc, nlayers_proj=args.nlayers_proj, in_dim=nfeats, emb_dim=args.emb_dim, 81 | proj_dim=args.proj_dim, dropout=args.dropout, sparse=args.sparse, batch_size=args.cl_batch_size).cuda() 82 | cl_model.set_mask_knn(features.cpu(), k=args.k, dataset=args.dataset) 83 | discriminator = Edge_Discriminator(nnodes, nfeats + str_encodings.shape[1], args.alpha, args.sparse).cuda() 84 | 85 | optimizer_cl = torch.optim.Adam(cl_model.parameters(), lr=args.lr_gcl, weight_decay=args.w_decay) 86 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=args.lr_disc, weight_decay=args.w_decay) 87 | 88 | features = features.cuda() 89 | str_encodings = str_encodings.cuda() 90 | edges = edges.cuda() 91 | 92 | best_acc_val = 0 93 | best_acc_test = 0 94 | 95 | for 
epoch in range(1, args.epochs + 1): 96 | 97 | for _ in range(args.cl_rounds): 98 | cl_loss = train_cl(cl_model, discriminator, optimizer_cl, features, str_encodings, edges) 99 | rank_loss = train_discriminator(cl_model, discriminator, optimizer_discriminator, features, str_encodings, edges, args) 100 | 101 | print("[TRAIN] Epoch:{:04d} | CL Loss {:.4f} | RANK loss:{:.4f} ".format(epoch, cl_loss, rank_loss)) 102 | 103 | if epoch % args.eval_freq == 0: 104 | cl_model.eval() 105 | discriminator.eval() 106 | adj_1, adj_2, _, _ = discriminator(torch.cat((features, str_encodings), 1), edges) 107 | embedding = cl_model.get_embedding(features, adj_1, adj_2) 108 | cur_split = 0 if (train_mask.shape[1]==1) else (trial % train_mask.shape[1]) 109 | acc_test, acc_val = eval_test_mode(embedding, labels, train_mask[:, cur_split], 110 | val_mask[:, cur_split], test_mask[:, cur_split]) 111 | print( 112 | '[TEST] Epoch:{:04d} | CL loss:{:.4f} | RANK loss:{:.4f} | VAL ACC:{:.2f} | TEST ACC:{:.2f}'.format( 113 | epoch, cl_loss, rank_loss, acc_val, acc_test)) 114 | 115 | if acc_val > best_acc_val: 116 | best_acc_val = acc_val 117 | best_acc_test = acc_test 118 | 119 | results.append(best_acc_test) 120 | 121 | print('\n[FINAL RESULT] Dataset:{} | Run:{} | ACC:{:.2f}+-{:.2f}'.format(args.dataset, args.ntrials, np.mean(results), 122 | np.std(results))) 123 | 124 | if __name__ == '__main__': 125 | parser = argparse.ArgumentParser() 126 | 127 | # ESSENTIAL 128 | parser.add_argument('-dataset', type=str, default='cornell', 129 | choices=['cora', 'citeseer', 'pubmed', 'chameleon', 'squirrel', 'actor', 'cornell', 130 | 'texas', 'wisconsin', 'computers', 'photo', 'cs', 'physics', 'wikics']) 131 | parser.add_argument('-ntrials', type=int, default=10) 132 | parser.add_argument('-sparse', type=int, default=0) 133 | parser.add_argument('-eval_freq', type=int, default=20) 134 | parser.add_argument('-epochs', type=int, default=400) 135 | parser.add_argument('-lr_gcl', type=float, default=0.001) 136 | parser.add_argument('-lr_disc', type=float, default=0.001) 137 | parser.add_argument('-cl_rounds', type=int, default=2) 138 | parser.add_argument('-w_decay', type=float, default=0.0) 139 | parser.add_argument('-dropout', type=float, default=0.5) 140 | 141 | # DISC Module - Hyper-param 142 | parser.add_argument('-alpha', type=float, default=0.1) 143 | parser.add_argument('-margin_hom', type=float, default=0.5) 144 | parser.add_argument('-margin_het', type=float, default=0.5) 145 | 146 | # GRL Module - Hyper-param 147 | parser.add_argument('-nlayers_enc', type=int, default=2) 148 | parser.add_argument('-nlayers_proj', type=int, default=1, choices=[1, 2]) 149 | parser.add_argument('-emb_dim', type=int, default=128) 150 | parser.add_argument('-proj_dim', type=int, default=128) 151 | parser.add_argument('-cl_batch_size', type=int, default=0) 152 | parser.add_argument('-k', type=int, default=20) 153 | parser.add_argument('-maskfeat_rate_1', type=float, default=0.1) 154 | parser.add_argument('-maskfeat_rate_2', type=float, default=0.5) 155 | parser.add_argument('-dropedge_rate_1', type=float, default=0.5) 156 | parser.add_argument('-dropedge_rate_2', type=float, default=0.1) 157 | 158 | args = parser.parse_args() 159 | 160 | print(args) 161 | main(args) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Sequential, Linear, ReLU 2 | from sklearn.neighbors import kneighbors_graph 3 | from 
scipy import sparse 4 | from dgl.nn import EdgeWeightNorm 5 | import random 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import dgl.function as fn 11 | 12 | from utils import * 13 | 14 | EOS = 1e-10 15 | norm = EdgeWeightNorm(norm='both') 16 | 17 | 18 | class GCL(nn.Module): 19 | def __init__(self, nlayers, nlayers_proj, in_dim, emb_dim, proj_dim, dropout, sparse, batch_size): 20 | super(GCL, self).__init__() 21 | 22 | self.encoder1 = SGC(nlayers, in_dim, emb_dim, dropout, sparse) 23 | self.encoder2 = SGC(nlayers, in_dim, emb_dim, dropout, sparse) 24 | 25 | if nlayers_proj == 1: 26 | self.proj_head1 = Sequential(Linear(emb_dim, proj_dim)) 27 | self.proj_head2 = Sequential(Linear(emb_dim, proj_dim)) 28 | elif nlayers_proj == 2: 29 | self.proj_head1 = Sequential(Linear(emb_dim, proj_dim), ReLU(inplace=True), Linear(proj_dim, proj_dim)) 30 | self.proj_head2 = Sequential(Linear(emb_dim, proj_dim), ReLU(inplace=True), Linear(proj_dim, proj_dim)) 31 | 32 | self.batch_size = batch_size 33 | 34 | 35 | def get_embedding(self, x, a1, a2, source='all'): 36 | emb1 = self.encoder1(x, a1) 37 | emb2 = self.encoder2(x, a2) 38 | return torch.cat((emb1, emb2), dim=1) 39 | 40 | 41 | def get_projection(self, x, a1, a2): 42 | emb1 = self.encoder1(x, a1) 43 | emb2 = self.encoder2(x, a2) 44 | proj1 = self.proj_head1(emb1) 45 | proj2 = self.proj_head2(emb2) 46 | return torch.cat((proj1, proj2), dim=1) 47 | 48 | 49 | def forward(self, x1, a1, x2, a2): 50 | emb1 = self.encoder1(x1, a1) 51 | emb2 = self.encoder2(x2, a2) 52 | proj1 = self.proj_head1(emb1) 53 | proj2 = self.proj_head2(emb2) 54 | loss = self.batch_nce_loss(proj1, proj2) 55 | return loss 56 | 57 | 58 | def set_mask_knn(self, X, k, dataset, metric='cosine'): 59 | if k != 0: 60 | path = '../data/knn/{}'.format(dataset) 61 | if not os.path.exists(path): 62 | os.makedirs(path) 63 | file_name = path + '/{}_{}.npz'.format(dataset, k) 64 | if os.path.exists(file_name): 65 | knn = sparse.load_npz(file_name) 66 | # print('Load exist knn graph.') 67 | else: 68 | print('Computing knn graph...') 69 | knn = kneighbors_graph(X, k, metric=metric) 70 | sparse.save_npz(file_name, knn) 71 | print('Done. 
The knn graph is saved as: {}.'.format(file_name)) 72 | knn = torch.tensor(knn.toarray()) + torch.eye(X.shape[0]) 73 | else: 74 | knn = torch.eye(X.shape[0]) 75 | self.pos_mask = knn 76 | self.neg_mask = 1 - self.pos_mask 77 | 78 | 79 | def batch_nce_loss(self, z1, z2, temperature=0.2, pos_mask=None, neg_mask=None): 80 | if pos_mask is None and neg_mask is None: 81 | pos_mask = self.pos_mask 82 | neg_mask = self.neg_mask 83 | 84 | nnodes = z1.shape[0] 85 | if (self.batch_size == 0) or (self.batch_size > nnodes): 86 | loss_0 = self.infonce(z1, z2, pos_mask, neg_mask, temperature) 87 | loss_1 = self.infonce(z2, z1, pos_mask, neg_mask, temperature) 88 | loss = (loss_0 + loss_1) / 2.0 89 | else: 90 | node_idxs = list(range(nnodes)) 91 | random.shuffle(node_idxs) 92 | batches = split_batch(node_idxs, self.batch_size) 93 | loss = 0 94 | for b in batches: 95 | weight = len(b) / nnodes 96 | loss_0 = self.infonce(z1[b], z2[b], pos_mask[:,b][b,:], neg_mask[:,b][b,:], temperature) 97 | loss_1 = self.infonce(z2[b], z1[b], pos_mask[:,b][b,:], neg_mask[:,b][b,:], temperature) 98 | loss += (loss_0 + loss_1) / 2.0 * weight 99 | return loss 100 | 101 | 102 | def infonce(self, anchor, sample, pos_mask, neg_mask, tau): 103 | pos_mask = pos_mask.cuda() 104 | neg_mask = neg_mask.cuda() 105 | sim = self.similarity(anchor, sample) / tau 106 | exp_sim = torch.exp(sim) * neg_mask 107 | log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True)) 108 | loss = log_prob * pos_mask 109 | loss = loss.sum(dim=1) / pos_mask.sum(dim=1) 110 | return -loss.mean() 111 | 112 | 113 | def similarity(self, h1: torch.Tensor, h2: torch.Tensor): 114 | h1 = F.normalize(h1) 115 | h2 = F.normalize(h2) 116 | return h1 @ h2.t() 117 | 118 | 119 | class Edge_Discriminator(nn.Module): 120 | def __init__(self, nnodes, input_dim, alpha, sparse, hidden_dim=128, temperature=1.0, bias=0.0 + 0.0001): 121 | super(Edge_Discriminator, self).__init__() 122 | 123 | self.embedding_layers = nn.ModuleList() 124 | self.embedding_layers.append(nn.Linear(input_dim, hidden_dim)) 125 | self.edge_mlp = nn.Linear(hidden_dim * 2, 1) 126 | 127 | self.temperature = temperature 128 | self.bias = bias 129 | self.nnodes = nnodes 130 | self.sparse = sparse 131 | self.alpha = alpha 132 | 133 | 134 | def get_node_embedding(self, h): 135 | for layer in self.embedding_layers: 136 | h = layer(h) 137 | h = F.relu(h) 138 | return h 139 | 140 | 141 | def get_edge_weight(self, embeddings, edges): 142 | s1 = self.edge_mlp(torch.cat((embeddings[edges[0]], embeddings[edges[1]]), dim=1)).flatten() 143 | s2 = self.edge_mlp(torch.cat((embeddings[edges[1]], embeddings[edges[0]]), dim=1)).flatten() 144 | return (s1 + s2) / 2 145 | 146 | 147 | def gumbel_sampling(self, edges_weights_raw): 148 | eps = (self.bias - (1 - self.bias)) * torch.rand(edges_weights_raw.size()) + (1 - self.bias) 149 | gate_inputs = torch.log(eps) - torch.log(1 - eps) 150 | gate_inputs = gate_inputs.cuda() 151 | gate_inputs = (gate_inputs + edges_weights_raw) / self.temperature 152 | return torch.sigmoid(gate_inputs).squeeze() 153 | 154 | 155 | def weight_forward(self, features, edges): 156 | embeddings = self.get_node_embedding(features) 157 | edges_weights_raw = self.get_edge_weight(embeddings, edges) 158 | weights_lp = self.gumbel_sampling(edges_weights_raw) 159 | weights_hp = 1 - weights_lp 160 | return weights_lp, weights_hp 161 | 162 | 163 | def weight_to_adj(self, edges, weights_lp, weights_hp): 164 | if not self.sparse: 165 | adj_lp = get_adj_from_edges(edges, weights_lp, self.nnodes) 166 | adj_lp += 
torch.eye(self.nnodes).cuda() 167 | adj_lp = normalize_adj(adj_lp, 'sym', self.sparse) 168 | 169 | adj_hp = get_adj_from_edges(edges, weights_hp, self.nnodes) 170 | adj_hp += torch.eye(self.nnodes).cuda() 171 | adj_hp = normalize_adj(adj_hp, 'sym', self.sparse) 172 | 173 | mask = torch.zeros(adj_lp.shape).cuda() 174 | mask[edges[0], edges[1]] = 1. 175 | mask.requires_grad = False 176 | adj_hp = torch.eye(self.nnodes).cuda() - adj_hp * mask * self.alpha 177 | else: 178 | adj_lp = dgl.graph((edges[0], edges[1]), num_nodes=self.nnodes, device='cuda') 179 | adj_lp = dgl.add_self_loop(adj_lp) 180 | weights_lp = torch.cat((weights_lp, torch.ones(self.nnodes).cuda())) + EOS 181 | weights_lp = norm(adj_lp, weights_lp) 182 | adj_lp.edata['w'] = weights_lp 183 | 184 | adj_hp = dgl.graph((edges[0], edges[1]), num_nodes=self.nnodes, device='cuda') 185 | adj_hp = dgl.add_self_loop(adj_hp) 186 | weights_hp = torch.cat((weights_hp, torch.ones(self.nnodes).cuda())) + EOS 187 | weights_hp = norm(adj_hp, weights_hp) 188 | weights_hp *= - self.alpha 189 | weights_hp[edges.shape[1]:] = 1 190 | adj_hp.edata['w'] = weights_hp 191 | return adj_lp, adj_hp 192 | 193 | 194 | def forward(self, features, edges): 195 | weights_lp, weights_hp = self.weight_forward(features, edges) 196 | adj_lp, adj_hp = self.weight_to_adj(edges, weights_lp, weights_hp) 197 | return adj_lp, adj_hp, weights_lp, weights_hp 198 | 199 | 200 | class SGC(nn.Module): 201 | def __init__(self, nlayers, in_dim, emb_dim, dropout, sparse): 202 | super(SGC, self).__init__() 203 | self.dropout = dropout 204 | self.sparse = sparse 205 | 206 | self.linear = nn.Linear(in_dim, emb_dim) 207 | self.k = nlayers 208 | 209 | def forward(self, x, g): 210 | x = torch.relu(self.linear(x)) 211 | 212 | if self.sparse: 213 | with g.local_scope(): 214 | g.ndata['h'] = x 215 | for _ in range(self.k): 216 | g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum(msg='m', out='h')) 217 | return g.ndata['h'] 218 | else: 219 | for _ in range(self.k): 220 | x = torch.matmul(g, x) 221 | return x --------------------------------------------------------------------------------
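The modules above fit together as in main.py: Edge_Discriminator splits the input edge set into a low-pass view and a high-pass view, and GCL contrasts the encodings of the two views. The following is a minimal smoke-test sketch on random data, not a reproduction of any experiment: the node count, feature dimension, edge count, hyperparameters, and the 'toy' dataset tag are made up for illustration, the structural encodings that main.py concatenates to the features are omitted, and a CUDA device is assumed because the modules hard-code .cuda() calls.

```
import torch
from model import GCL, Edge_Discriminator

nnodes, nfeats, nedges = 100, 32, 400
features = torch.rand(nnodes, nfeats).cuda()
edges = torch.randint(0, nnodes, (2, nedges)).cuda()  # random edge index of shape [2, E], for shape-checking only

# main.py feeds torch.cat((features, str_encodings), 1) to the discriminator;
# structural encodings are skipped here, so input_dim is just nfeats.
discriminator = Edge_Discriminator(nnodes, nfeats, alpha=0.1, sparse=0).cuda()
cl_model = GCL(nlayers=2, nlayers_proj=1, in_dim=nfeats, emb_dim=128, proj_dim=128,
               dropout=0.5, sparse=0, batch_size=0).cuda()
cl_model.set_mask_knn(features.cpu(), k=0, dataset='toy')  # k=0 gives an identity positive mask; the 'toy' tag is unused in that branch

adj_lp, adj_hp, weights_lp, weights_hp = discriminator(features, edges)  # low-pass / high-pass adjacencies
loss = cl_model(features, adj_lp, features, adj_hp)                      # contrastive loss between the two views
loss.backward()
print(loss.item())
```

In the full pipeline, main.py alternates the two updates: train_cl optimizes GCL with the discriminator frozen, and train_discriminator optimizes the discriminator with a margin ranking loss over the low-pass and high-pass edge weights.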