├── data ├── ind.cora.tx ├── ind.cora.ty ├── ind.cora.x ├── ind.cora.y ├── ind.cora.allx ├── ind.cora.ally ├── ind.cora.graph └── ind.cora.test.index ├── assets ├── CCGC_model.png ├── CCGC_tsne.png ├── HSAN_model.png ├── HSAN_tsne.png ├── CCGC_result.png ├── HSAN_result.png └── HSAN_parameter.png ├── requirements.txt ├── model.py ├── layers.py ├── README.md ├── kmeans_gpu.py ├── train.py └── utils.py /data/ind.cora.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.tx -------------------------------------------------------------------------------- /data/ind.cora.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.ty -------------------------------------------------------------------------------- /data/ind.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.x -------------------------------------------------------------------------------- /data/ind.cora.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.y -------------------------------------------------------------------------------- /data/ind.cora.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.allx -------------------------------------------------------------------------------- /data/ind.cora.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.ally -------------------------------------------------------------------------------- /data/ind.cora.graph: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.graph -------------------------------------------------------------------------------- /assets/CCGC_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/CCGC_model.png -------------------------------------------------------------------------------- /assets/CCGC_tsne.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/CCGC_tsne.png -------------------------------------------------------------------------------- /assets/HSAN_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/HSAN_model.png -------------------------------------------------------------------------------- /assets/HSAN_tsne.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/HSAN_tsne.png -------------------------------------------------------------------------------- /assets/CCGC_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/CCGC_result.png -------------------------------------------------------------------------------- /assets/HSAN_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/HSAN_result.png -------------------------------------------------------------------------------- /assets/HSAN_parameter.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/HSAN_parameter.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.0 2 | tqdm==4.61.2 3 | numpy==1.21.0 4 | munkres==1.1.4 5 | scikit_learn==1.0 -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers import * 5 | from torch.nn import Linear 6 | from torch_geometric.nn import GATConv, GCNConv, ChebConv 7 | from torch_geometric.nn import JumpingKnowledge 8 | from torch_geometric.nn import MessagePassing, APPNP 9 | 10 | 11 | class Encoder_Net(nn.Module): 12 | def __init__(self, layers, dims): 13 | super(Encoder_Net, self).__init__() 14 | self.layers1 = nn.Linear(dims[0], dims[1]) 15 | self.layers2 = nn.Linear(dims[0], dims[1]) 16 | 17 | def forward(self, x): 18 | out1 = self.layers1(x) 19 | out2 = self.layers2(x) 20 | 21 | out1 = F.normalize(out1, dim=1, p=2) 22 | out2 = F.normalize(out2, dim=1, p=2) 23 | 24 | return out1, out2 25 | 26 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.module import Module 5 | from torch.nn.parameter import Parameter 6 | import numpy as np 7 | 8 | class GraphConvolution(Module): 9 | """ 10 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 11 | """ 12 | 13 | def __init__(self, in_features, out_features, dropout=0., act=F.relu): 14 | super(GraphConvolution, self).__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | self.dropout = dropout 18 | 
self.act = act 19 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 20 | self.reset_parameters() 21 | 22 | def reset_parameters(self): 23 | torch.nn.init.xavier_uniform_(self.weight) 24 | 25 | def forward(self, input, adj): 26 | input = F.dropout(input, self.dropout, self.training) 27 | support = torch.mm(input, self.weight) 28 | output = torch.spmm(adj, support) 29 | output = self.act(output) 30 | return output 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + ' (' \ 34 | + str(self.in_features) + ' -> ' \ 35 | + str(self.out_features) + ')' 36 | 37 | class SampleDecoder(Module): 38 | def __init__(self, act=torch.sigmoid): 39 | super(SampleDecoder, self).__init__() 40 | self.act = act 41 | 42 | def forward(self, zx, zy): 43 | 44 | sim = (zx * zy).sum(1) 45 | sim = self.act(sim) 46 | 47 | return sim -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [stars-img]: https://img.shields.io/github/stars/xihongyang1999/CCGC?color=yellow 2 | [stars-url]: https://github.com/xihongyang1999/CCGC/stargazers 3 | [fork-img]: https://img.shields.io/github/forks/xihongyang1999/CCGC?color=lightblue&label=fork 4 | [fork-url]: https://github.com/xihongyang1999/CCGC/network/members 5 | [visitors-img]: https://visitor-badge.glitch.me/badge?page_id=xihongyang1999.CCGC 6 | [adgc-url]: https://github.com/xihongyang1999/CCGC 7 | 8 | # Cluster-guided Contrastive Graph Clustering Network 9 | 10 |

11 | 12 | 13 | 14 | 15 |

16 | 17 | 18 | [![GitHub stars][stars-img]][stars-url] 19 | [![GitHub forks][fork-img]][fork-url] 20 | [![visitors][visitors-img]][adgc-url] 21 | 22 | 23 | The official source code for the paper Cluster-guided Contrastive Graph Clustering Network, accepted by AAAI 2023. Any communications or issues are welcome. Please contact xihong_edu@163.com. If you find this repository useful for your research or work, we would really appreciate it if you starred this repository. :heart: 24 | 25 | ------------- 26 | 27 | ### Overview 28 | 29 |

30 | We propose a Cluster-guided Contrastive deep Graph Clustering network (CCGC) by mining the intrinsic supervision information in the high-confidence clustering results. Specifically, instead of conducting complex node or edge perturbation, we construct two views of the graph by designing special Siamese encoders whose weights are not shared between the sibling sub-networks. Then, guided by the high-confidence clustering information, we carefully select and construct the positive samples from the same high-confidence cluster in two views. Moreover, to construct semantically meaningful negative sample pairs, we regard the centers of different high-confidence clusters as negative samples, thus improving the discriminative capability and reliability of the constructed sample pairs. Lastly, we design an objective function that pulls samples from the same cluster closer together while pushing away those from other clusters, by maximizing the cross-view cosine similarity between positive samples and minimizing it between negative samples. Extensive experimental results on six datasets demonstrate the effectiveness of CCGC compared with the existing state-of-the-art algorithms. 31 | 32 | 33 | 34 | 35 |

36 | 37 |
38 | 39 |
40 | Figure 1: Overall framework of CCGC. 41 |
42 | 43 | 44 | 45 | ### Requirements 46 | 47 | The proposed CCGC is implemented with Python 3.7 on an NVIDIA 2080Ti GPU. 48 | 49 | Python package information is summarized in **requirements.txt**: 50 | 51 | - torch==1.7.1 52 | - tqdm==4.59.0 53 | - numpy==1.19.2 54 | - munkres==1.1.4 55 | - scikit_learn==1.2.0 56 | 57 | 58 | 59 | 60 | 61 | ### Quick Start 62 | 63 | python train.py 64 | 65 | 66 | 67 | ### Clustering Results 68 | 69 |
70 | 71 |
72 |
73 | Table 1: Clustering results of our proposed CCGC and twelve baselines on six datasets. 74 |
75 | 76 | 77 |
78 | 79 |
80 | 81 | 82 |
83 | Figure 2: 2D t-SNE visualization of six methods on two datasets. 84 |
85 | 86 | 87 | 88 | ### Citation 89 | 90 | If you find this project useful for your research, please cite your paper with the following BibTeX entry. 91 | 92 | ``` 93 | @inproceedings{CCGC, 94 | title={Cluster-guided Contrastive Graph Clustering Network}, 95 | author={Yang, Xihong and Liu, Yue and Zhou, Sihang and Wang, Siwei and Tu, Wenxuan and Zheng, Qun and Liu, Xinwang and Fang, Liming and Zhu, En}, 96 | booktitle={Proceedings of the AAAI conference on artificial intelligence}, 97 | volume={37}, 98 | number={9}, 99 | pages={10834--10842}, 100 | year={2023} 101 | } 102 | ``` 103 | -------------------------------------------------------------------------------- /kmeans_gpu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | 6 | def setup_seed(seed): 7 | """ 8 | setup random seed to fix the result 9 | Args: 10 | seed: random seed 11 | Returns: None 12 | """ 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.backends.cudnn.benchmark = False 19 | torch.backends.cudnn.deterministic = True 20 | 21 | 22 | def initialize(X, num_clusters): 23 | """ 24 | initialize cluster centers 25 | :param X: (torch.tensor) matrix 26 | :param num_clusters: (int) number of clusters 27 | :return: (np.array) initial state 28 | """ 29 | num_samples = len(X) 30 | indices = np.random.choice(num_samples, num_clusters, replace=False) 31 | initial_state = X[indices] 32 | return initial_state 33 | 34 | 35 | def kmeans( 36 | X, 37 | num_clusters, 38 | distance='euclidean', 39 | tol=1e-4, 40 | device=torch.device('cuda') 41 | ): 42 | """ 43 | perform kmeans 44 | :param X: (torch.tensor) matrix 45 | :param num_clusters: (int) number of clusters 46 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 47 | :param tol: (float) threshold [default: 
0.0001] 48 | :param device: (torch.device) device [default: cpu] 49 | :return: (torch.tensor, torch.tensor) cluster ids, cluster centers 50 | """ 51 | # print(f'running k-means on {device}..') 52 | if distance == 'euclidean': 53 | pairwise_distance_function = pairwise_distance 54 | elif distance == 'cosine': 55 | pairwise_distance_function = pairwise_cosine 56 | else: 57 | raise NotImplementedError 58 | 59 | # convert to float 60 | X = X.float() 61 | 62 | # transfer to device 63 | X = X.to(device) 64 | 65 | # initialize 66 | dis_min = float('inf') 67 | initial_state_best = None 68 | for i in range(20): 69 | initial_state = initialize(X, num_clusters) 70 | dis = pairwise_distance_function(X, initial_state).sum() 71 | if dis < dis_min: 72 | dis_min = dis 73 | initial_state_best = initial_state 74 | 75 | initial_state = initial_state_best 76 | iteration = 0 77 | while True: 78 | dis = pairwise_distance_function(X, initial_state) 79 | 80 | choice_cluster = torch.argmin(dis, dim=1) 81 | 82 | initial_state_pre = initial_state.clone() 83 | 84 | for index in range(num_clusters): 85 | selected = torch.nonzero(choice_cluster == index).squeeze().to(device) 86 | 87 | selected = torch.index_select(X, 0, selected) 88 | initial_state[index] = selected.mean(dim=0) 89 | 90 | center_shift = torch.sum( 91 | torch.sqrt( 92 | torch.sum((initial_state - initial_state_pre) ** 2, dim=1) 93 | )) 94 | 95 | # increment iteration 96 | iteration = iteration + 1 97 | 98 | if iteration > 500: 99 | break 100 | if center_shift ** 2 < tol: 101 | break 102 | 103 | return choice_cluster.cpu(), dis.cpu(), initial_state.cpu() 104 | 105 | 106 | def kmeans_predict( 107 | X, 108 | cluster_centers, 109 | distance='euclidean', 110 | device=torch.device('cuda') 111 | ): 112 | """ 113 | predict using cluster centers 114 | :param X: (torch.tensor) matrix 115 | :param cluster_centers: (torch.tensor) cluster centers 116 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 117 
| :param device: (torch.device) device [default: 'cpu'] 118 | :return: (torch.tensor) cluster ids 119 | """ 120 | # print(f'predicting on {device}..') 121 | 122 | if distance == 'euclidean': 123 | pairwise_distance_function = pairwise_distance 124 | elif distance == 'cosine': 125 | pairwise_distance_function = pairwise_cosine 126 | else: 127 | raise NotImplementedError 128 | 129 | # convert to float 130 | X = X.float() 131 | 132 | # transfer to device 133 | X = X.to(device) 134 | 135 | dis = pairwise_distance_function(X, cluster_centers) 136 | choice_cluster = torch.argmin(dis, dim=1) 137 | 138 | return choice_cluster.cpu() 139 | 140 | 141 | def pairwise_distance(data1, data2, device=torch.device('cuda')): 142 | # transfer to device 143 | data1, data2 = data1.to(device), data2.to(device) 144 | 145 | # N*1*M 146 | A = data1.unsqueeze(dim=1) 147 | 148 | # 1*N*M 149 | B = data2.unsqueeze(dim=0) 150 | 151 | dis = (A - B) ** 2.0 152 | # return N*N matrix for pairwise distance 153 | dis = dis.sum(dim=-1).squeeze() 154 | return dis 155 | 156 | 157 | def pairwise_cosine(data1, data2, device=torch.device('cuda')): 158 | # transfer to device 159 | data1, data2 = data1.to(device), data2.to(device) 160 | 161 | # N*1*M 162 | A = data1.unsqueeze(dim=1) 163 | 164 | # 1*N*M 165 | B = data2.unsqueeze(dim=0) 166 | 167 | # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5] 168 | A_normalized = A / A.norm(dim=-1, keepdim=True) 169 | B_normalized = B / B.norm(dim=-1, keepdim=True) 170 | 171 | cosine = A_normalized * B_normalized 172 | 173 | # return N*N matrix for pairwise distance 174 | cosine_dis = 1 - cosine.sum(dim=-1).squeeze() 175 | return cosine_dis 176 | 177 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import * 3 | from tqdm import tqdm 4 | from torch import 
optim 5 | from model import Encoder_Net 6 | import torch.nn.functional as F 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--t', type=int, default=4, help="Number of gnn layers") 10 | parser.add_argument('--linlayers', type=int, default=1, help="Number of hidden layers") 11 | parser.add_argument('--epochs', type=int, default=400, help='Number of epochs to train.') 12 | parser.add_argument('--dims', type=int, default=500, help='feature dim') 13 | parser.add_argument('--lr', type=float, default=1e-4, help='Initial learning rate.') 14 | parser.add_argument('--dataset', type=str, default='cora', help='name of dataset.') 15 | parser.add_argument('--cluster_num', type=int, default=7, help='number of cluster.') 16 | parser.add_argument('--device', type=str, default='cuda', help='the training device') 17 | parser.add_argument('--threshold', type=float, default=0.5, help='the threshold of high-confidence') 18 | parser.add_argument('--alpha', type=float, default=0.5, help='trade-off of loss') 19 | args = parser.parse_args() 20 | 21 | #load data 22 | adj, features, true_labels, idx_train, idx_val, idx_test = load_data(args.dataset) 23 | adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape) 24 | adj.eliminate_zeros() 25 | 26 | # Laplacian Smoothing 27 | adj_norm_s = preprocess_graph(adj, args.t, norm='sym', renorm=True) 28 | smooth_fea = sp.csr_matrix(features).toarray() 29 | for a in adj_norm_s: 30 | smooth_fea = a.dot(smooth_fea) 31 | smooth_fea = torch.FloatTensor(smooth_fea) 32 | 33 | acc_list = [] 34 | nmi_list = [] 35 | ari_list = [] 36 | f1_list = [] 37 | 38 | for seed in range(10): 39 | 40 | setup_seed(seed) 41 | 42 | # init 43 | best_acc, best_nmi, best_ari, best_f1, predict_labels, dis= clustering(smooth_fea, true_labels, args.cluster_num) 44 | 45 | # MLP 46 | model = Encoder_Net(args.linlayers, [features.shape[1]] + [args.dims]) 47 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 48 | 49 | # GPU 50 | 
model.to(args.device) 51 | smooth_fea = smooth_fea.to(args.device) 52 | sample_size = features.shape[0] 53 | target = torch.eye(smooth_fea.shape[0]).to(args.device) 54 | 55 | for epoch in tqdm(range(args.epochs)): 56 | model.train() 57 | z1, z2 = model(smooth_fea) 58 | if epoch > 50: 59 | 60 | high_confidence = torch.min(dis, dim=1).values 61 | threshold = torch.sort(high_confidence).values[int(len(high_confidence) * args.threshold)] 62 | high_confidence_idx = np.argwhere(high_confidence < threshold)[0] 63 | 64 | # pos samples 65 | index = torch.tensor(range(smooth_fea.shape[0]), device=args.device)[high_confidence_idx] 66 | y_sam = torch.tensor(predict_labels, device=args.device)[high_confidence_idx] 67 | index = index[torch.argsort(y_sam)] 68 | class_num = {} 69 | 70 | for label in torch.sort(y_sam).values: 71 | label = label.item() 72 | if label in class_num.keys(): 73 | class_num[label] += 1 74 | else: 75 | class_num[label] = 1 76 | key = sorted(class_num.keys()) 77 | if len(class_num) < 2: 78 | continue 79 | pos_contrastive = 0 80 | centers_1 = torch.tensor([], device=args.device) 81 | centers_2 = torch.tensor([], device=args.device) 82 | 83 | 84 | for i in range(len(key[:-1])): 85 | class_num[key[i + 1]] = class_num[key[i]] + class_num[key[i + 1]] 86 | now = index[class_num[key[i]]:class_num[key[i + 1]]] 87 | pos_embed_1 = z1[np.random.choice(now.cpu(), size=int((now.shape[0] * 0.8)), replace=False)] 88 | pos_embed_2 = z2[np.random.choice(now.cpu(), size=int((now.shape[0] * 0.8)), replace=False)] 89 | pos_contrastive += (2 - 2 * torch.sum(pos_embed_1 * pos_embed_2, dim=1)).sum() 90 | centers_1 = torch.cat([centers_1, torch.mean(z1[now], dim=0).unsqueeze(0)], dim=0) 91 | centers_2 = torch.cat([centers_2, torch.mean(z2[now], dim=0).unsqueeze(0)], dim=0) 92 | 93 | pos_contrastive = pos_contrastive / args.cluster_num 94 | if pos_contrastive == 0: 95 | continue 96 | if len(class_num) < 2: 97 | loss = pos_contrastive 98 | else: 99 | centers_1 = 
F.normalize(centers_1, dim=1, p=2) 100 | centers_2 = F.normalize(centers_2, dim=1, p=2) 101 | S = centers_1 @ centers_2.T 102 | S_diag = torch.diag_embed(torch.diag(S)) 103 | S = S - S_diag 104 | neg_contrastive = F.mse_loss(S, torch.zeros_like(S)) 105 | loss = pos_contrastive + args.alpha * neg_contrastive 106 | 107 | else: 108 | S = z1 @ z2.T 109 | loss = F.mse_loss(S, target) 110 | 111 | loss.backward(retain_graph=True) 112 | optimizer.step() 113 | 114 | if epoch % 10 == 0: 115 | model.eval() 116 | z1, z2 = model(smooth_fea) 117 | 118 | hidden_emb = (z1 + z2) / 2 119 | 120 | acc, nmi, ari, f1, predict_labels, dis = clustering(hidden_emb, true_labels, args.cluster_num) 121 | if acc >= best_acc: 122 | best_acc = acc 123 | best_nmi = nmi 124 | best_ari = ari 125 | best_f1 = f1 126 | 127 | acc_list.append(best_acc) 128 | nmi_list.append(best_nmi) 129 | ari_list.append(best_ari) 130 | f1_list.append(best_f1) 131 | 132 | acc_list = np.array(acc_list) 133 | nmi_list = np.array(nmi_list) 134 | ari_list = np.array(ari_list) 135 | f1_list = np.array(f1_list) 136 | print(acc_list.mean(), "±", acc_list.std()) 137 | print(nmi_list.mean(), "±", nmi_list.std()) 138 | print(ari_list.mean(), "±", ari_list.std()) 139 | print(f1_list.mean(), "±", f1_list.std()) 140 | -------------------------------------------------------------------------------- /data/ind.cora.test.index: -------------------------------------------------------------------------------- 1 | 2692 2 | 2532 3 | 2050 4 | 1715 5 | 2362 6 | 2609 7 | 2622 8 | 1975 9 | 2081 10 | 1767 11 | 2263 12 | 1725 13 | 2588 14 | 2259 15 | 2357 16 | 1998 17 | 2574 18 | 2179 19 | 2291 20 | 2382 21 | 1812 22 | 1751 23 | 2422 24 | 1937 25 | 2631 26 | 2510 27 | 2378 28 | 2589 29 | 2345 30 | 1943 31 | 1850 32 | 2298 33 | 1825 34 | 2035 35 | 2507 36 | 2313 37 | 1906 38 | 1797 39 | 2023 40 | 2159 41 | 2495 42 | 1886 43 | 2122 44 | 2369 45 | 2461 46 | 1925 47 | 2565 48 | 1858 49 | 2234 50 | 2000 51 | 1846 52 | 2318 53 | 1723 54 | 2559 55 | 
2258 56 | 1763 57 | 1991 58 | 1922 59 | 2003 60 | 2662 61 | 2250 62 | 2064 63 | 2529 64 | 1888 65 | 2499 66 | 2454 67 | 2320 68 | 2287 69 | 2203 70 | 2018 71 | 2002 72 | 2632 73 | 2554 74 | 2314 75 | 2537 76 | 1760 77 | 2088 78 | 2086 79 | 2218 80 | 2605 81 | 1953 82 | 2403 83 | 1920 84 | 2015 85 | 2335 86 | 2535 87 | 1837 88 | 2009 89 | 1905 90 | 2636 91 | 1942 92 | 2193 93 | 2576 94 | 2373 95 | 1873 96 | 2463 97 | 2509 98 | 1954 99 | 2656 100 | 2455 101 | 2494 102 | 2295 103 | 2114 104 | 2561 105 | 2176 106 | 2275 107 | 2635 108 | 2442 109 | 2704 110 | 2127 111 | 2085 112 | 2214 113 | 2487 114 | 1739 115 | 2543 116 | 1783 117 | 2485 118 | 2262 119 | 2472 120 | 2326 121 | 1738 122 | 2170 123 | 2100 124 | 2384 125 | 2152 126 | 2647 127 | 2693 128 | 2376 129 | 1775 130 | 1726 131 | 2476 132 | 2195 133 | 1773 134 | 1793 135 | 2194 136 | 2581 137 | 1854 138 | 2524 139 | 1945 140 | 1781 141 | 1987 142 | 2599 143 | 1744 144 | 2225 145 | 2300 146 | 1928 147 | 2042 148 | 2202 149 | 1958 150 | 1816 151 | 1916 152 | 2679 153 | 2190 154 | 1733 155 | 2034 156 | 2643 157 | 2177 158 | 1883 159 | 1917 160 | 1996 161 | 2491 162 | 2268 163 | 2231 164 | 2471 165 | 1919 166 | 1909 167 | 2012 168 | 2522 169 | 1865 170 | 2466 171 | 2469 172 | 2087 173 | 2584 174 | 2563 175 | 1924 176 | 2143 177 | 1736 178 | 1966 179 | 2533 180 | 2490 181 | 2630 182 | 1973 183 | 2568 184 | 1978 185 | 2664 186 | 2633 187 | 2312 188 | 2178 189 | 1754 190 | 2307 191 | 2480 192 | 1960 193 | 1742 194 | 1962 195 | 2160 196 | 2070 197 | 2553 198 | 2433 199 | 1768 200 | 2659 201 | 2379 202 | 2271 203 | 1776 204 | 2153 205 | 1877 206 | 2027 207 | 2028 208 | 2155 209 | 2196 210 | 2483 211 | 2026 212 | 2158 213 | 2407 214 | 1821 215 | 2131 216 | 2676 217 | 2277 218 | 2489 219 | 2424 220 | 1963 221 | 1808 222 | 1859 223 | 2597 224 | 2548 225 | 2368 226 | 1817 227 | 2405 228 | 2413 229 | 2603 230 | 2350 231 | 2118 232 | 2329 233 | 1969 234 | 2577 235 | 2475 236 | 2467 237 | 2425 238 | 1769 239 | 2092 240 | 2044 241 
| 2586 242 | 2608 243 | 1983 244 | 2109 245 | 2649 246 | 1964 247 | 2144 248 | 1902 249 | 2411 250 | 2508 251 | 2360 252 | 1721 253 | 2005 254 | 2014 255 | 2308 256 | 2646 257 | 1949 258 | 1830 259 | 2212 260 | 2596 261 | 1832 262 | 1735 263 | 1866 264 | 2695 265 | 1941 266 | 2546 267 | 2498 268 | 2686 269 | 2665 270 | 1784 271 | 2613 272 | 1970 273 | 2021 274 | 2211 275 | 2516 276 | 2185 277 | 2479 278 | 2699 279 | 2150 280 | 1990 281 | 2063 282 | 2075 283 | 1979 284 | 2094 285 | 1787 286 | 2571 287 | 2690 288 | 1926 289 | 2341 290 | 2566 291 | 1957 292 | 1709 293 | 1955 294 | 2570 295 | 2387 296 | 1811 297 | 2025 298 | 2447 299 | 2696 300 | 2052 301 | 2366 302 | 1857 303 | 2273 304 | 2245 305 | 2672 306 | 2133 307 | 2421 308 | 1929 309 | 2125 310 | 2319 311 | 2641 312 | 2167 313 | 2418 314 | 1765 315 | 1761 316 | 1828 317 | 2188 318 | 1972 319 | 1997 320 | 2419 321 | 2289 322 | 2296 323 | 2587 324 | 2051 325 | 2440 326 | 2053 327 | 2191 328 | 1923 329 | 2164 330 | 1861 331 | 2339 332 | 2333 333 | 2523 334 | 2670 335 | 2121 336 | 1921 337 | 1724 338 | 2253 339 | 2374 340 | 1940 341 | 2545 342 | 2301 343 | 2244 344 | 2156 345 | 1849 346 | 2551 347 | 2011 348 | 2279 349 | 2572 350 | 1757 351 | 2400 352 | 2569 353 | 2072 354 | 2526 355 | 2173 356 | 2069 357 | 2036 358 | 1819 359 | 1734 360 | 1880 361 | 2137 362 | 2408 363 | 2226 364 | 2604 365 | 1771 366 | 2698 367 | 2187 368 | 2060 369 | 1756 370 | 2201 371 | 2066 372 | 2439 373 | 1844 374 | 1772 375 | 2383 376 | 2398 377 | 1708 378 | 1992 379 | 1959 380 | 1794 381 | 2426 382 | 2702 383 | 2444 384 | 1944 385 | 1829 386 | 2660 387 | 2497 388 | 2607 389 | 2343 390 | 1730 391 | 2624 392 | 1790 393 | 1935 394 | 1967 395 | 2401 396 | 2255 397 | 2355 398 | 2348 399 | 1931 400 | 2183 401 | 2161 402 | 2701 403 | 1948 404 | 2501 405 | 2192 406 | 2404 407 | 2209 408 | 2331 409 | 1810 410 | 2363 411 | 2334 412 | 1887 413 | 2393 414 | 2557 415 | 1719 416 | 1732 417 | 1986 418 | 2037 419 | 2056 420 | 1867 421 | 2126 422 | 1932 
423 | 2117 424 | 1807 425 | 1801 426 | 1743 427 | 2041 428 | 1843 429 | 2388 430 | 2221 431 | 1833 432 | 2677 433 | 1778 434 | 2661 435 | 2306 436 | 2394 437 | 2106 438 | 2430 439 | 2371 440 | 2606 441 | 2353 442 | 2269 443 | 2317 444 | 2645 445 | 2372 446 | 2550 447 | 2043 448 | 1968 449 | 2165 450 | 2310 451 | 1985 452 | 2446 453 | 1982 454 | 2377 455 | 2207 456 | 1818 457 | 1913 458 | 1766 459 | 1722 460 | 1894 461 | 2020 462 | 1881 463 | 2621 464 | 2409 465 | 2261 466 | 2458 467 | 2096 468 | 1712 469 | 2594 470 | 2293 471 | 2048 472 | 2359 473 | 1839 474 | 2392 475 | 2254 476 | 1911 477 | 2101 478 | 2367 479 | 1889 480 | 1753 481 | 2555 482 | 2246 483 | 2264 484 | 2010 485 | 2336 486 | 2651 487 | 2017 488 | 2140 489 | 1842 490 | 2019 491 | 1890 492 | 2525 493 | 2134 494 | 2492 495 | 2652 496 | 2040 497 | 2145 498 | 2575 499 | 2166 500 | 1999 501 | 2434 502 | 1711 503 | 2276 504 | 2450 505 | 2389 506 | 2669 507 | 2595 508 | 1814 509 | 2039 510 | 2502 511 | 1896 512 | 2168 513 | 2344 514 | 2637 515 | 2031 516 | 1977 517 | 2380 518 | 1936 519 | 2047 520 | 2460 521 | 2102 522 | 1745 523 | 2650 524 | 2046 525 | 2514 526 | 1980 527 | 2352 528 | 2113 529 | 1713 530 | 2058 531 | 2558 532 | 1718 533 | 1864 534 | 1876 535 | 2338 536 | 1879 537 | 1891 538 | 2186 539 | 2451 540 | 2181 541 | 2638 542 | 2644 543 | 2103 544 | 2591 545 | 2266 546 | 2468 547 | 1869 548 | 2582 549 | 2674 550 | 2361 551 | 2462 552 | 1748 553 | 2215 554 | 2615 555 | 2236 556 | 2248 557 | 2493 558 | 2342 559 | 2449 560 | 2274 561 | 1824 562 | 1852 563 | 1870 564 | 2441 565 | 2356 566 | 1835 567 | 2694 568 | 2602 569 | 2685 570 | 1893 571 | 2544 572 | 2536 573 | 1994 574 | 1853 575 | 1838 576 | 1786 577 | 1930 578 | 2539 579 | 1892 580 | 2265 581 | 2618 582 | 2486 583 | 2583 584 | 2061 585 | 1796 586 | 1806 587 | 2084 588 | 1933 589 | 2095 590 | 2136 591 | 2078 592 | 1884 593 | 2438 594 | 2286 595 | 2138 596 | 1750 597 | 2184 598 | 1799 599 | 2278 600 | 2410 601 | 2642 602 | 2435 603 | 1956 604 | 
2399 605 | 1774 606 | 2129 607 | 1898 608 | 1823 609 | 1938 610 | 2299 611 | 1862 612 | 2420 613 | 2673 614 | 1984 615 | 2204 616 | 1717 617 | 2074 618 | 2213 619 | 2436 620 | 2297 621 | 2592 622 | 2667 623 | 2703 624 | 2511 625 | 1779 626 | 1782 627 | 2625 628 | 2365 629 | 2315 630 | 2381 631 | 1788 632 | 1714 633 | 2302 634 | 1927 635 | 2325 636 | 2506 637 | 2169 638 | 2328 639 | 2629 640 | 2128 641 | 2655 642 | 2282 643 | 2073 644 | 2395 645 | 2247 646 | 2521 647 | 2260 648 | 1868 649 | 1988 650 | 2324 651 | 2705 652 | 2541 653 | 1731 654 | 2681 655 | 2707 656 | 2465 657 | 1785 658 | 2149 659 | 2045 660 | 2505 661 | 2611 662 | 2217 663 | 2180 664 | 1904 665 | 2453 666 | 2484 667 | 1871 668 | 2309 669 | 2349 670 | 2482 671 | 2004 672 | 1965 673 | 2406 674 | 2162 675 | 1805 676 | 2654 677 | 2007 678 | 1947 679 | 1981 680 | 2112 681 | 2141 682 | 1720 683 | 1758 684 | 2080 685 | 2330 686 | 2030 687 | 2432 688 | 2089 689 | 2547 690 | 1820 691 | 1815 692 | 2675 693 | 1840 694 | 2658 695 | 2370 696 | 2251 697 | 1908 698 | 2029 699 | 2068 700 | 2513 701 | 2549 702 | 2267 703 | 2580 704 | 2327 705 | 2351 706 | 2111 707 | 2022 708 | 2321 709 | 2614 710 | 2252 711 | 2104 712 | 1822 713 | 2552 714 | 2243 715 | 1798 716 | 2396 717 | 2663 718 | 2564 719 | 2148 720 | 2562 721 | 2684 722 | 2001 723 | 2151 724 | 2706 725 | 2240 726 | 2474 727 | 2303 728 | 2634 729 | 2680 730 | 2055 731 | 2090 732 | 2503 733 | 2347 734 | 2402 735 | 2238 736 | 1950 737 | 2054 738 | 2016 739 | 1872 740 | 2233 741 | 1710 742 | 2032 743 | 2540 744 | 2628 745 | 1795 746 | 2616 747 | 1903 748 | 2531 749 | 2567 750 | 1946 751 | 1897 752 | 2222 753 | 2227 754 | 2627 755 | 1856 756 | 2464 757 | 2241 758 | 2481 759 | 2130 760 | 2311 761 | 2083 762 | 2223 763 | 2284 764 | 2235 765 | 2097 766 | 1752 767 | 2515 768 | 2527 769 | 2385 770 | 2189 771 | 2283 772 | 2182 773 | 2079 774 | 2375 775 | 2174 776 | 2437 777 | 1993 778 | 2517 779 | 2443 780 | 2224 781 | 2648 782 | 2171 783 | 2290 784 | 2542 785 | 2038 786 
| 1855 787 | 1831 788 | 1759 789 | 1848 790 | 2445 791 | 1827 792 | 2429 793 | 2205 794 | 2598 795 | 2657 796 | 1728 797 | 2065 798 | 1918 799 | 2427 800 | 2573 801 | 2620 802 | 2292 803 | 1777 804 | 2008 805 | 1875 806 | 2288 807 | 2256 808 | 2033 809 | 2470 810 | 2585 811 | 2610 812 | 2082 813 | 2230 814 | 1915 815 | 1847 816 | 2337 817 | 2512 818 | 2386 819 | 2006 820 | 2653 821 | 2346 822 | 1951 823 | 2110 824 | 2639 825 | 2520 826 | 1939 827 | 2683 828 | 2139 829 | 2220 830 | 1910 831 | 2237 832 | 1900 833 | 1836 834 | 2197 835 | 1716 836 | 1860 837 | 2077 838 | 2519 839 | 2538 840 | 2323 841 | 1914 842 | 1971 843 | 1845 844 | 2132 845 | 1802 846 | 1907 847 | 2640 848 | 2496 849 | 2281 850 | 2198 851 | 2416 852 | 2285 853 | 1755 854 | 2431 855 | 2071 856 | 2249 857 | 2123 858 | 1727 859 | 2459 860 | 2304 861 | 2199 862 | 1791 863 | 1809 864 | 1780 865 | 2210 866 | 2417 867 | 1874 868 | 1878 869 | 2116 870 | 1961 871 | 1863 872 | 2579 873 | 2477 874 | 2228 875 | 2332 876 | 2578 877 | 2457 878 | 2024 879 | 1934 880 | 2316 881 | 1841 882 | 1764 883 | 1737 884 | 2322 885 | 2239 886 | 2294 887 | 1729 888 | 2488 889 | 1974 890 | 2473 891 | 2098 892 | 2612 893 | 1834 894 | 2340 895 | 2423 896 | 2175 897 | 2280 898 | 2617 899 | 2208 900 | 2560 901 | 1741 902 | 2600 903 | 2059 904 | 1747 905 | 2242 906 | 2700 907 | 2232 908 | 2057 909 | 2147 910 | 2682 911 | 1792 912 | 1826 913 | 2120 914 | 1895 915 | 2364 916 | 2163 917 | 1851 918 | 2391 919 | 2414 920 | 2452 921 | 1803 922 | 1989 923 | 2623 924 | 2200 925 | 2528 926 | 2415 927 | 1804 928 | 2146 929 | 2619 930 | 2687 931 | 1762 932 | 2172 933 | 2270 934 | 2678 935 | 2593 936 | 2448 937 | 1882 938 | 2257 939 | 2500 940 | 1899 941 | 2478 942 | 2412 943 | 2107 944 | 1746 945 | 2428 946 | 2115 947 | 1800 948 | 1901 949 | 2397 950 | 2530 951 | 1912 952 | 2108 953 | 2206 954 | 2091 955 | 1740 956 | 2219 957 | 1976 958 | 2099 959 | 2142 960 | 2671 961 | 2668 962 | 2216 963 | 2272 964 | 2229 965 | 2666 966 | 2456 967 | 2534 
"""Utility helpers: dataset loading, graph normalisation and clustering metrics."""
import torch
import random
import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
from sklearn import metrics
from munkres import Munkres
import matplotlib.pyplot as plt
from kmeans_gpu import kmeans
import sklearn.preprocessing as preprocess
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score


def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    Args:
        idx: index or sequence of indices to switch on
        l: length of the mask

    Returns: a boolean numpy array of shape ``(l,)``
    """
    # np.bool was removed in NumPy 1.24; the builtin bool is the portable dtype.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask


def load_data(dataset):
    """Load a dataset stored as ``data/ind.<dataset>.*`` files.

    For ``dataset == 'wiki'`` the work is delegated to ``load_wiki`` and the
    three index slots of the result are filled with ``0``.

    Args:
        dataset: dataset name used to build the file names

    Returns: ``(adj, features, labels, idx_train, idx_val, idx_test)`` where
        ``features`` is a dense ``torch.FloatTensor`` and ``labels`` is the
        argmax of the stacked label rows.
    """
    if dataset == 'wiki':
        adj, features, label = load_wiki()
        return adj, features, label, 0, 0, 0

    # load the data: x, tx, allx, graph
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        # NOTE(security): unpickling executes arbitrary code; only load
        # dataset files that come from a trusted source.
        # encoding='latin1' fixes the pickle incompatibility of numpy arrays
        # between Python 2 and 3:
        # https://stackoverflow.com/questions/11305790/pickle-incompatibility-of-numpy-arrays-between-python-2-and-3
        with open("data/ind.{}.{}".format(dataset, name), 'rb') as rf:
            objects.append(pkl.load(rf, encoding='latin1'))
    x, y, tx, ty, allx, ally, graph = objects

    test_idx_reorder = parse_index_file(
        "data/ind.{}.test.index".format(dataset))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(
            min(test_idx_reorder), max(test_idx_reorder) + 1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range - min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range - min(test_idx_range), :] = ty
        ty = ty_extended

    # Stack the rows, then move the test rows to the positions listed in the
    # index file.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    features = torch.FloatTensor(np.array(features.todense()))
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y) + 500)

    return adj, features, np.argmax(labels, 1), idx_train, idx_val, idx_test


def load_wiki():
    """Load the wiki dataset from ``data/graph.txt``, ``data/group.txt`` and
    ``data/tfidf.txt``.

    Returns: ``(adj, features, label)`` where ``adj`` is a symmetrised scipy
        CSR matrix, ``features`` is a min-max scaled ``torch.FloatTensor`` and
        ``label`` is an integer array re-indexed to ``0..k-1``.
    """
    edges = []
    with open('data/graph.txt', 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:  # tolerate blank lines
                continue
            edges.append([int(parts[0]), int(parts[1])])

    label = []
    with open('data/group.txt', 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            label.append(int(parts[1]))

    fea_idx = []
    fea = []
    with open('data/tfidf.txt', 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            fea_idx.append([int(parts[0]), int(parts[1])])
            fea.append(float(parts[2]))

    # Add the reversed copy of every edge and drop duplicates so that the
    # adjacency matrix is symmetric with unit weights.
    adj = np.array(edges)
    adj = np.vstack((adj, adj[:, [1, 0]]))
    adj = np.unique(adj, axis=0)

    # Map the raw label values onto consecutive integers.
    labelset = np.unique(label)
    labeldict = dict(zip(labelset, range(len(labelset))))
    label = np.array([labeldict[x] for x in label])
    adj = sp.csr_matrix((np.ones(len(adj)), (adj[:, 0], adj[:, 1])),
                        shape=(len(label), len(label)))

    fea_idx = np.array(fea_idx)
    features = sp.csr_matrix((fea, (fea_idx[:, 0], fea_idx[:, 1])),
                             shape=(len(label), 4973)).toarray()
    scaler = preprocess.MinMaxScaler()
    features = scaler.fit_transform(features)
    features = torch.FloatTensor(features)

    return adj, features, label


def parse_index_file(filename):
    """Read a text file holding one integer per line.

    Args:
        filename: path of the index file

    Returns: list of the integers in file order
    """
    with open(filename) as f:
        return [int(line.strip()) for line in f]


def sparse_to_tuple(sparse_mx):
    """Convert a scipy sparse matrix to ``(coords, values, shape)``.

    Args:
        sparse_mx: scipy sparse matrix (converted to COO if needed)

    Returns: ``coords`` as an ``(nnz, 2)`` array of (row, col) pairs, the
        stored ``values`` and the matrix ``shape``.
    """
    if not sp.isspmatrix_coo(sparse_mx):
        sparse_mx = sparse_mx.tocoo()
    coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
    values = sparse_mx.data
    shape = sparse_mx.shape
    return coords, values, shape


def decompose(adj, dataset, norm='sym', renorm=True):
    """Compute, save and plot the eigenvalues of the normalised Laplacian.

    The eigenvalues are saved to ``<dataset>.npy``, the largest one is
    printed, and a 50-bin histogram is written to
    ``eig_renorm_<dataset>.png``.

    Args:
        adj: adjacency matrix (anything ``sp.coo_matrix`` accepts)
        dataset: name used as prefix/suffix of the output files
        norm: normalisation scheme; only ``'sym'`` is supported
        renorm: if True, add the identity (self loops) before normalising

    Returns: None

    Raises:
        ValueError: if ``norm`` is not ``'sym'``
    """
    if norm != 'sym':
        raise ValueError("unsupported norm {!r}; only 'sym' is implemented".format(norm))

    adj = sp.coo_matrix(adj)
    ident = sp.eye(adj.shape[0])
    adj_ = adj + ident if renorm else adj

    rowsum = np.array(adj_.sum(1))
    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    lap = ident - adj_normalized
    evalue, _ = np.linalg.eig(lap.toarray())
    np.save(dataset + ".npy", evalue)
    print(max(evalue))

    # A stray exit(1) used to terminate the process here, which made the
    # histogram below unreachable; it has been removed.
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.hist(evalue, 50, facecolor='g')
    plt.xlabel('Eigenvalues')
    plt.ylabel('Frequency')
    fig.savefig("eig_renorm_" + dataset + ".png")
    plt.close(fig)  # release the figure once it is on disk


def preprocess_graph(adj, layer, norm='sym', renorm=True):
    """Build ``layer`` copies of the normalised adjacency filter ``I - L``.

    Args:
        adj: adjacency matrix (anything ``sp.coo_matrix`` accepts)
        layer: number of filter matrices to return
        norm: ``'sym'`` for degree^-1/2 scaling on both sides, ``'left'`` for
            degree^-1 scaling on the left
        renorm: if True, add the identity (self loops) before normalising

    Returns: list of ``layer`` scipy sparse matrices

    Raises:
        ValueError: if ``norm`` is neither ``'sym'`` nor ``'left'``
    """
    adj = sp.coo_matrix(adj)
    ident = sp.eye(adj.shape[0])
    adj_ = adj + ident if renorm else adj

    rowsum = np.array(adj_.sum(1))

    if norm == 'sym':
        degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
        adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    elif norm == 'left':
        degree_mat_inv = sp.diags(np.power(rowsum, -1.).flatten())
        adj_normalized = degree_mat_inv.dot(adj_).tocoo()
    else:
        raise ValueError("unsupported norm {!r}; expected 'sym' or 'left'".format(norm))

    # Named ``lap`` so the module-level ``laplacian`` function is not shadowed.
    lap = ident - adj_normalized

    # Per-layer weight on the Laplacian (an alternative of 2 / 3 was tried).
    reg = [1] * layer

    return [ident - (weight * lap) for weight in reg]


def laplacian(adj):
    """Return the unnormalised Laplacian ``D - A`` as a dense FloatTensor.

    Args:
        adj: adjacency matrix; must support ``.sum(1)`` and the result of
            ``D - adj`` must support ``.toarray()``

    Returns: ``torch.FloatTensor`` holding ``D - A``
    """
    rowsum = np.array(adj.sum(1))
    degree_mat = sp.diags(rowsum.flatten())
    lap = degree_mat - adj
    return torch.FloatTensor(lap.toarray())


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is deprecated; sparse_coo_tensor is the
    # supported constructor and keeps the float32 dtype of ``values``.
    return torch.sparse_coo_tensor(indices, values, shape)


def get_roc_score(emb, adj_orig, edges_pos, edges_neg):
    """Score link prediction from the inner products of node embeddings.

    Args:
        emb: embedding matrix, one row per node
        adj_orig: original adjacency matrix (kept for interface
            compatibility; not used in the computation)
        edges_pos: iterable of ``(i, j)`` pairs that are true edges
        edges_neg: iterable of ``(i, j)`` pairs that are non-edges

    Returns: ``(roc_score, ap_score)``
    """
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # Predict on test set of edges
    adj_rec = np.dot(emb, emb.T)
    preds_pos = [sigmoid(adj_rec[e[0], e[1]]) for e in edges_pos]
    preds_neg = [sigmoid(adj_rec[e[0], e[1]]) for e in edges_neg]

    preds_all = np.hstack([preds_pos, preds_neg])
    # The zero block must match the number of negative samples; it was
    # previously sized with the number of positives.
    labels_all = np.hstack([np.ones(len(preds_pos)), np.zeros(len(preds_neg))])
    roc_score = roc_auc_score(labels_all, preds_all)
    ap_score = average_precision_score(labels_all, preds_all)

    return roc_score, ap_score


def cluster_acc(y_true, y_pred):
    """
    calculate clustering acc and f1-score
    Cluster ids are matched to ground-truth labels with the Munkres
    assignment algorithm before scoring. ``y_pred`` is not modified.
    Args:
        y_true: the ground truth
        y_pred: the clustering id

    Returns: acc and f1-score, or None (after printing 'error') when the
        number of clusters cannot be made equal to the number of classes
    """
    y_true = np.asarray(y_true)
    y_true = y_true - np.min(y_true)
    # Work on a copy: the patch-up below writes into y_pred, and callers go
    # on to use their own array (e.g. for NMI/ARI) after this call.
    y_pred = np.array(y_pred)

    l1 = np.unique(y_true).tolist()
    num_class1 = len(l1)
    l2 = np.unique(y_pred).tolist()
    if num_class1 != len(l2):
        # Plant every missing ground-truth label into the leading predictions
        # so that both sides expose the same number of classes.
        ind = 0
        for i in l1:
            if i not in l2:
                y_pred[ind] = i
                ind += 1
    l2 = np.unique(y_pred).tolist()
    numclass2 = len(l2)
    if num_class1 != numclass2:
        print('error')
        return

    # cost[i][j] = number of samples with true class l1[i] and cluster l2[j]
    cost = np.zeros((num_class1, numclass2), dtype=int)
    for i, c1 in enumerate(l1):
        in_c1 = y_true == c1
        for j, c2 in enumerate(l2):
            cost[i][j] = np.count_nonzero(in_c1 & (y_pred == c2))

    # Munkres minimises, so negate the counts to maximise the overlap.
    indexes = Munkres().compute((-cost).tolist())
    new_predict = np.zeros(len(y_pred))
    for i, c in enumerate(l1):
        c2 = l2[indexes[i][1]]
        new_predict[y_pred == c2] = c
    acc = metrics.accuracy_score(y_true, new_predict)
    f1_macro = metrics.f1_score(y_true, new_predict, average='macro')
    return acc, f1_macro


def eva(y_true, y_pred, show_details=True):
    """
    evaluate the clustering performance
    Args:
        y_true: the ground truth
        y_pred: the predicted label
        show_details: if print the details
    Returns: acc, nmi, ari and f1
    """
    acc, f1 = cluster_acc(y_true, y_pred)
    nmi = nmi_score(y_true, y_pred, average_method='arithmetic')
    ari = ari_score(y_true, y_pred)
    if show_details:
        print(':acc {:.4f}'.format(acc), ', nmi {:.4f}'.format(nmi), ', ari {:.4f}'.format(ari),
              ', f1 {:.4f}'.format(f1))
    return acc, nmi, ari, f1


def load_graph_data(dataset_name, show_details=False):
    """
    load graph data
    :param dataset_name: the name of the dataset
    :param show_details: if show the details of dataset
    - dataset name
    - features' shape
    - labels' shape
    - adj shape
    - edge num
    - category num
    - category distribution
    :return: the features, labels and adj
    """
    load_path = "dataset/" + dataset_name + "/" + dataset_name
    # NOTE(security): allow_pickle=True can execute arbitrary code; only load
    # .npy files that come from a trusted source.
    feat = np.load(load_path+"_feat.npy", allow_pickle=True)
    label = np.load(load_path+"_label.npy", allow_pickle=True)
    adj = np.load(load_path+"_adj.npy", allow_pickle=True)
    if show_details:
        print("++++++++++++++++++++++++++++++")
        print("---details of graph dataset---")
        print("++++++++++++++++++++++++++++++")
        print("dataset name:   ", dataset_name)
        print("feature shape:  ", feat.shape)
        print("label shape:    ", label.shape)
        print("adj shape:      ", adj.shape)
        # every undirected edge is stored twice in the adjacency matrix
        print("undirected edge num:   ", int(np.nonzero(adj)[0].shape[0]/2))
        print("category num:          ", max(label)-min(label)+1)
        print("category distribution: ")
        for i in range(max(label)+1):
            print("label", i, end=":")
            print(len(label[np.where(label == i)]))
        print("++++++++++++++++++++++++++++++")

    return feat, label, adj


def normalize_adj(adj, self_loop=True, symmetry=False):
    """
    normalize the adj matrix
    Nodes whose degree is zero get an inverse degree of zero, so their
    rows/columns stay zero.
    :param adj: input adj matrix (dense, square)
    :param self_loop: if add the self loop or not
    :param symmetry: symmetry normalize or not
    :return: the normalized adj matrix
    """
    # add the self_loop
    if self_loop:
        adj_tmp = adj + np.eye(adj.shape[0])
    else:
        adj_tmp = adj
    adj_tmp = np.asarray(adj_tmp)

    # Degrees are the column sums. Invert them element-wise instead of
    # building a dense diagonal matrix and inverting that.
    deg = np.asarray(adj_tmp.sum(0), dtype=float).ravel()
    inv_deg = np.zeros_like(deg)
    nonzero = deg != 0
    inv_deg[nonzero] = 1.0 / deg[nonzero]

    # symmetry normalize: D^{-0.5} A D^{-0.5}
    if symmetry:
        inv_sqrt_deg = np.sqrt(inv_deg)
        # The right-hand factor was previously the adjacency matrix itself,
        # which produced D^{-0.5} A A.
        norm_adj = inv_sqrt_deg[:, None] * adj_tmp * inv_sqrt_deg[None, :]

    # non-symmetry normalize: D^{-1} A
    else:
        norm_adj = inv_deg[:, None] * adj_tmp

    return norm_adj


def setup_seed(seed):
    """
    setup random seed to fix the result
    Args:
        seed: random seed
    Returns: None
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def clustering(feature, true_labels, cluster_num, device="cuda"):
    """
    run k-means on the features and evaluate against the ground truth
    Args:
        feature: input passed to ``kmeans`` as ``X``
        true_labels: the ground truth
        cluster_num: number of clusters
        device: device passed to ``kmeans`` (default "cuda")
    Returns: acc, nmi, ari and f1 scaled by 100, the predicted labels as a
        numpy array, and the second value returned by ``kmeans``
    """
    predict_labels, dis, _ = kmeans(X=feature, num_clusters=cluster_num, distance="euclidean", device=device)
    predicted = predict_labels.numpy()
    acc, nmi, ari, f1 = eva(true_labels, predicted, show_details=False)
    return 100 * acc, 100 * nmi, 100 * ari, 100 * f1, predicted, dis