├── data
├── ind.cora.tx
├── ind.cora.ty
├── ind.cora.x
├── ind.cora.y
├── ind.cora.allx
├── ind.cora.ally
├── ind.cora.graph
└── ind.cora.test.index
├── assets
├── CCGC_model.png
├── CCGC_tsne.png
├── HSAN_model.png
├── HSAN_tsne.png
├── CCGC_result.png
├── HSAN_result.png
└── HSAN_parameter.png
├── requirements.txt
├── model.py
├── layers.py
├── README.md
├── kmeans_gpu.py
├── train.py
└── utils.py
/data/ind.cora.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.tx
--------------------------------------------------------------------------------
/data/ind.cora.ty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.ty
--------------------------------------------------------------------------------
/data/ind.cora.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.x
--------------------------------------------------------------------------------
/data/ind.cora.y:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.y
--------------------------------------------------------------------------------
/data/ind.cora.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.allx
--------------------------------------------------------------------------------
/data/ind.cora.ally:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.ally
--------------------------------------------------------------------------------
/data/ind.cora.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/data/ind.cora.graph
--------------------------------------------------------------------------------
/assets/CCGC_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/CCGC_model.png
--------------------------------------------------------------------------------
/assets/CCGC_tsne.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/CCGC_tsne.png
--------------------------------------------------------------------------------
/assets/HSAN_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/HSAN_model.png
--------------------------------------------------------------------------------
/assets/HSAN_tsne.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/HSAN_tsne.png
--------------------------------------------------------------------------------
/assets/CCGC_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/CCGC_result.png
--------------------------------------------------------------------------------
/assets/HSAN_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/HSAN_result.png
--------------------------------------------------------------------------------
/assets/HSAN_parameter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xihongyang1999/CCGC/HEAD/assets/HSAN_parameter.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.0
2 | tqdm==4.61.2
3 | numpy==1.21.0
4 | munkres==1.1.4
5 | scikit_learn==1.0
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers import *
5 | from torch.nn import Linear
6 | from torch_geometric.nn import GATConv, GCNConv, ChebConv
7 | from torch_geometric.nn import JumpingKnowledge
8 | from torch_geometric.nn import MessagePassing, APPNP
9 |
10 |
class Encoder_Net(nn.Module):
    """Two independent linear encoders that produce two views of one input.

    Both branches map ``dims[0]`` input features to ``dims[1]`` outputs and
    do not share weights. ``layers`` is accepted but not used by this class.
    """

    def __init__(self, layers, dims):
        super().__init__()
        in_dim, out_dim = dims[0], dims[1]
        # layers1 is created before layers2 so seeded initialisation is stable.
        self.layers1 = nn.Linear(in_dim, out_dim)
        self.layers2 = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Return both branch outputs, each L2-normalised along dim 1."""
        return tuple(
            F.normalize(branch(x), dim=1, p=2)
            for branch in (self.layers1, self.layers2)
        )
25 |
26 |
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.modules.module import Module
5 | from torch.nn.parameter import Parameter
6 | import numpy as np
7 |
class GraphConvolution(Module):
    """
    Basic graph convolution layer in the spirit of https://arxiv.org/abs/1609.02907

    Computes ``act(adj @ (dropout(input) @ weight))`` with a bias-free weight
    of shape ``(in_features, out_features)`` initialised Xavier-uniform.
    """

    def __init__(self, in_features, out_features, dropout=0., act=F.relu):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.dropout, self.act = dropout, act
        # Allocated uninitialised; reset_parameters() fills it right after.
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight matrix from a Xavier-uniform distribution."""
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        """Apply dropout, the linear projection, multiplication by ``adj`` and ``act``."""
        dropped = F.dropout(input, self.dropout, self.training)
        projected = torch.mm(dropped, self.weight)
        return self.act(torch.spmm(adj, projected))

    def __repr__(self):
        return f"{type(self).__name__} ({self.in_features} -> {self.out_features})"
36 |
class SampleDecoder(Module):
    """Score paired embeddings: row-wise inner product passed through ``act``."""

    def __init__(self, act=torch.sigmoid):
        super().__init__()
        self.act = act

    def forward(self, zx, zy):
        """Return ``act`` applied to the sum over dim 1 of ``zx * zy``."""
        inner = torch.mul(zx, zy).sum(1)
        return self.act(inner)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [stars-img]: https://img.shields.io/github/stars/xihongyang1999/CCGC?color=yellow
2 | [stars-url]: https://github.com/xihongyang1999/CCGC/stargazers
3 | [fork-img]: https://img.shields.io/github/forks/xihongyang1999/CCGC?color=lightblue&label=fork
4 | [fork-url]: https://github.com/xihongyang1999/CCGC/network/members
5 | [visitors-img]: https://visitor-badge.glitch.me/badge?page_id=xihongyang1999.CCGC
6 | [adgc-url]: https://github.com/xihongyang1999/CCGC
7 |
8 | # Cluster-guided Contrastive Graph Clustering Network
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | [![GitHub stars][stars-img]][stars-url]
19 | [![GitHub forks][fork-img]][fork-url]
20 | [![visitors][visitors-img]][adgc-url]
21 |
22 |
23 | The official source code for the paper "Cluster-guided Contrastive Graph Clustering Network", accepted by AAAI 2023. Any communication or issues are welcome. Please contact xihong_edu@163.com. If you find this repository useful for your research or work, we would really appreciate it if you starred this repository. :heart:
24 |
25 | -------------
26 |
27 | ### Overview
28 |
29 |
30 | We propose a Cluster-guided Contrastive deep Graph Clustering network (CCGC) by mining the intrinsic supervision information in the high-confidence clustering results. Specifically, instead of conducting complex node or edge perturbation, we construct two views of the graph by designing special Siamese encoders whose weights are not shared between the sibling sub-networks. Then, guided by the high-confidence clustering information, we carefully select and construct the positive samples from the same high-confidence cluster in two views. Moreover, to construct semantically meaningful negative sample pairs, we regard the centers of different high-confidence clusters as negative samples, thus improving the discriminative capability and reliability of the constructed sample pairs. Lastly, we design an objective function that pulls together the samples from the same cluster while pushing away those from other clusters, by maximizing the cross-view cosine similarity between positive samples and minimizing it between negative samples. Extensive experimental results on six datasets demonstrate the effectiveness of CCGC compared with the existing state-of-the-art algorithms.
31 |
32 |
33 |
34 |
35 |
36 |

37 |
38 |
39 |
40 | Figure 1: Overall framework of CCGC.
41 |
42 |
43 |
44 |
45 | ### Requirements
46 |
47 | The proposed CCGC is implemented with Python 3.7 on an NVIDIA 2080Ti GPU.
48 |
49 | Python package information is summarized in **requirements.txt**:
50 |
51 | - torch==1.8.0
52 | - tqdm==4.61.2
53 | - numpy==1.21.0
54 | - munkres==1.1.4
55 | - scikit_learn==1.0
56 |
57 |
58 |
59 |
60 |
61 | ### Quick Start
62 |
63 | python train.py
64 |
65 |
66 |
67 | ### Clustering Results
68 |
69 |
70 |

71 |
72 |
73 | Table 1: Clustering results of our proposed CCGC and twelve baselines on six datasets.
74 |
75 |
76 |
77 |
78 |

79 |
80 |
81 |
82 |
83 | Figure 2: 2D t-SNE visualization of six methods on two datasets.
84 |
85 |
86 |
87 |
88 | ### Citation
89 |
90 | If you find this project useful for your research, please cite our paper with the following BibTeX entry.
91 |
92 | ```
93 | @inproceedings{CCGC,
94 | title={Cluster-guided Contrastive Graph Clustering Network},
95 | author={Yang, Xihong and Liu, Yue and Zhou, Sihang and Wang, Siwei and Tu, Wenxuan and Zheng, Qun and Liu, Xinwang and Fang, Liming and Zhu, En},
96 | booktitle={Proceedings of the AAAI conference on artificial intelligence},
97 | volume={37},
98 | number={9},
99 | pages={10834--10842},
100 | year={2023}
101 | }
102 | ```
103 |
--------------------------------------------------------------------------------
/kmeans_gpu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 |
5 |
def setup_seed(seed):
    """
    Seed the NumPy and PyTorch (CPU and CUDA) random generators and put cuDNN
    into deterministic mode so repeated runs give the same result.
    Args:
        seed: random seed
    Returns: None
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Turn off cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
20 |
21 |
def initialize(X, num_clusters):
    """
    pick the starting cluster centers by sampling distinct rows of X
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :return: the rows of X at the sampled indices
    """
    # Sampling without replacement guarantees num_clusters different rows.
    chosen = np.random.choice(len(X), num_clusters, replace=False)
    return X[chosen]
33 |
34 |
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        tol=1e-4,
        device=torch.device('cuda')
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param tol: (float) threshold on the squared total center shift [default: 0.0001]
    :param device: (torch.device) device [default: cuda]
    :return: (torch.tensor, torch.tensor, torch.tensor) cluster ids, distances of
        the samples to the centers, cluster centers -- all moved to the CPU
    :raises NotImplementedError: if ``distance`` is not one of the listed options
    """
    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError

    # convert to float, then transfer to device
    X = X.float().to(device)

    # initialize: draw 20 candidate center sets and keep the one with the
    # smallest summed distance to the data
    dis_min = float('inf')
    initial_state_best = None
    for _ in range(20):
        initial_state = initialize(X, num_clusters)
        # forward `device`; otherwise the helper falls back to its own default
        dis = pairwise_distance_function(X, initial_state, device).sum()
        if dis < dis_min:
            dis_min = dis
            initial_state_best = initial_state

    initial_state = initial_state_best
    iteration = 0
    while True:
        dis = pairwise_distance_function(X, initial_state, device)

        choice_cluster = torch.argmin(dis, dim=1)

        initial_state_pre = initial_state.clone()

        for index in range(num_clusters):
            members = X[choice_cluster == index]
            if members.shape[0] == 0:
                # An empty cluster has no mean (it would be NaN and poison every
                # later distance); keep its previous center instead.
                continue
            initial_state[index] = members.mean(dim=0)

        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        # increment iteration
        iteration = iteration + 1

        # stop after the iteration cap or once the centers have settled
        if iteration > 500:
            break
        if center_shift ** 2 < tol:
            break

    return choice_cluster.cpu(), dis.cpu(), initial_state.cpu()
104 |
105 |
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cuda')
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: cuda]
    :return: (torch.tensor) cluster ids, moved to the CPU
    :raises NotImplementedError: if ``distance`` is not one of the listed options
    """
    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError

    # convert to float, then transfer to device
    X = X.float().to(device)

    # forward `device`; otherwise the helper falls back to its own default
    dis = pairwise_distance_function(X, cluster_centers, device)
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster.cpu()
139 |
140 |
def pairwise_distance(data1, data2, device=torch.device('cuda')):
    """
    squared Euclidean distance between every row of data1 and every row of data2
    :param data1: (torch.tensor) matrix
    :param data2: (torch.tensor) matrix
    :param device: (torch.device) device both inputs are moved to
    :return: (torch.tensor) distances after ``squeeze()``, so size-1 dims are dropped
    """
    # A singleton axis on each side lets the subtraction broadcast over all pairs.
    rows = data1.to(device).unsqueeze(dim=1)
    cols = data2.to(device).unsqueeze(dim=0)

    squared_diff = (rows - cols) ** 2.0
    return squared_diff.sum(dim=-1).squeeze()
155 |
156 |
def pairwise_cosine(data1, data2, device=torch.device('cuda')):
    """
    cosine distance (1 - cosine similarity) between every row of data1 and
    every row of data2
    :param data1: (torch.tensor) matrix
    :param data2: (torch.tensor) matrix
    :param device: (torch.device) device both inputs are moved to
    :return: (torch.tensor) distances after ``squeeze()``, so size-1 dims are dropped
    """
    # A singleton axis on each side lets the product broadcast over all pairs.
    rows = data1.to(device).unsqueeze(dim=1)
    cols = data2.to(device).unsqueeze(dim=0)

    # Scale every vector to unit length, e.g. [0.3, 0.4] -> [0.6, 0.8].
    unit_rows = rows / rows.norm(dim=-1, keepdim=True)
    unit_cols = cols / cols.norm(dim=-1, keepdim=True)

    similarity = (unit_rows * unit_cols).sum(dim=-1)
    return 1 - similarity.squeeze()
176 |
177 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils import *
3 | from tqdm import tqdm
4 | from torch import optim
5 | from model import Encoder_Net
6 | import torch.nn.functional as F
7 |
# Command-line options for a training run; the parsed values are read through
# `args` by the rest of this script.
parser = argparse.ArgumentParser()
# --t is handed to preprocess_graph() as its second argument.
parser.add_argument('--t', type=int, default=4, help="Number of gnn layers")
# --linlayers is forwarded to Encoder_Net as its `layers` argument.
parser.add_argument('--linlayers', type=int, default=1, help="Number of hidden layers")
parser.add_argument('--epochs', type=int, default=400, help='Number of epochs to train.')
# --dims is the output width given to Encoder_Net.
parser.add_argument('--dims', type=int, default=500, help='feature dim')
parser.add_argument('--lr', type=float, default=1e-4, help='Initial learning rate.')
parser.add_argument('--dataset', type=str, default='cora', help='name of dataset.')
parser.add_argument('--cluster_num', type=int, default=7, help='number of cluster.')
parser.add_argument('--device', type=str, default='cuda', help='the training device')
# --threshold is the fraction of the sorted nearest-center distances used to
# pick the high-confidence cut-off.
parser.add_argument('--threshold', type=float, default=0.5, help='the threshold of high-confidence')
# --alpha weights the negative (center) term in the contrastive loss.
parser.add_argument('--alpha', type=float, default=0.5, help='trade-off of loss')
args = parser.parse_args()
20 |
# Load data
adj, features, true_labels, idx_train, idx_val, idx_test = load_data(args.dataset)
# Subtract the diagonal entries from adj, then drop the explicit zeros left behind.
adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
adj.eliminate_zeros()

# Laplacian Smoothing
adj_norm_s = preprocess_graph(adj, args.t, norm='sym', renorm=True)
smooth_fea = sp.csr_matrix(features).toarray()
# Left-multiply the dense features by each matrix from preprocess_graph, in order.
for a in adj_norm_s:
    smooth_fea = a.dot(smooth_fea)
smooth_fea = torch.FloatTensor(smooth_fea)
32 |
# Best metrics reached in each seeded run; summarised as mean ± std at the end.
acc_list = []
nmi_list = []
ari_list = []
f1_list = []

for seed in range(10):

    setup_seed(seed)

    # init: cluster the smoothed features for starting metrics, labels and distances
    best_acc, best_nmi, best_ari, best_f1, predict_labels, dis = clustering(smooth_fea, true_labels, args.cluster_num)

    # MLP
    model = Encoder_Net(args.linlayers, [features.shape[1]] + [args.dims])
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # GPU
    model.to(args.device)
    smooth_fea = smooth_fea.to(args.device)
    target = torch.eye(smooth_fea.shape[0]).to(args.device)

    for epoch in tqdm(range(args.epochs)):
        model.train()
        z1, z2 = model(smooth_fea)
        if epoch > 50:

            # Distance of every sample to its nearest center, from the latest
            # clustering() call; samples below the quantile cut-off are kept.
            high_confidence = torch.min(dis, dim=1).values
            threshold = torch.sort(high_confidence).values[int(len(high_confidence) * args.threshold)]
            # NOTE(review): np.argwhere is applied to a torch tensor here; the shape
            # of its result depends on NumPy/PyTorch interop, so the line is left
            # untouched -- confirm it selects every qualifying index.
            high_confidence_idx = np.argwhere(high_confidence < threshold)[0]

            # pos samples
            index = torch.tensor(range(smooth_fea.shape[0]), device=args.device)[high_confidence_idx]
            y_sam = torch.tensor(predict_labels, device=args.device)[high_confidence_idx]
            index = index[torch.argsort(y_sam)]
            class_num = {}

            # Count the high-confidence samples per predicted label.
            for label in torch.sort(y_sam).values:
                label = label.item()
                if label in class_num.keys():
                    class_num[label] += 1
                else:
                    class_num[label] = 1
            key = sorted(class_num.keys())
            # With fewer than two labels there are no center pairs to contrast.
            if len(class_num) < 2:
                continue
            pos_contrastive = 0
            centers_1 = torch.tensor([], device=args.device)
            centers_2 = torch.tensor([], device=args.device)

            # class_num is turned into running totals so that consecutive
            # entries delimit one label's slice of `index`.
            # NOTE(review): the slice for i = 0 starts at the first label's count,
            # so the samples of the first label are never used below -- confirm
            # whether that is intended.
            for i in range(len(key[:-1])):
                class_num[key[i + 1]] = class_num[key[i]] + class_num[key[i + 1]]
                now = index[class_num[key[i]]:class_num[key[i + 1]]]
                # Two independent 80% draws of the cluster, one per view.
                pos_embed_1 = z1[np.random.choice(now.cpu(), size=int((now.shape[0] * 0.8)), replace=False)]
                pos_embed_2 = z2[np.random.choice(now.cpu(), size=int((now.shape[0] * 0.8)), replace=False)]
                pos_contrastive += (2 - 2 * torch.sum(pos_embed_1 * pos_embed_2, dim=1)).sum()
                centers_1 = torch.cat([centers_1, torch.mean(z1[now], dim=0).unsqueeze(0)], dim=0)
                centers_2 = torch.cat([centers_2, torch.mean(z2[now], dim=0).unsqueeze(0)], dim=0)

            pos_contrastive = pos_contrastive / args.cluster_num
            if pos_contrastive == 0:
                continue

            # Negative term: push cross-view similarities between different
            # cluster centers towards zero (the diagonal is masked out).
            centers_1 = F.normalize(centers_1, dim=1, p=2)
            centers_2 = F.normalize(centers_2, dim=1, p=2)
            S = centers_1 @ centers_2.T
            S_diag = torch.diag_embed(torch.diag(S))
            S = S - S_diag
            neg_contrastive = F.mse_loss(S, torch.zeros_like(S))
            loss = pos_contrastive + args.alpha * neg_contrastive

        else:
            # Warm-up: pull the cross-view similarity matrix towards the identity.
            S = z1 @ z2.T
            loss = F.mse_loss(S, target)

        # Clear the gradients of the previous step; without this they would be
        # summed into the current ones. The graph is rebuilt by the forward
        # pass of every epoch, so it does not need to be retained.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            model.eval()
            # Evaluation only: no autograd graph is needed for clustering.
            with torch.no_grad():
                z1, z2 = model(smooth_fea)

                hidden_emb = (z1 + z2) / 2

                acc, nmi, ari, f1, predict_labels, dis = clustering(hidden_emb, true_labels, args.cluster_num)
            if acc >= best_acc:
                best_acc = acc
                best_nmi = nmi
                best_ari = ari
                best_f1 = f1

    acc_list.append(best_acc)
    nmi_list.append(best_nmi)
    ari_list.append(best_ari)
    f1_list.append(best_f1)

acc_list = np.array(acc_list)
nmi_list = np.array(nmi_list)
ari_list = np.array(ari_list)
f1_list = np.array(f1_list)
print(acc_list.mean(), "±", acc_list.std())
print(nmi_list.mean(), "±", nmi_list.std())
print(ari_list.mean(), "±", ari_list.std())
print(f1_list.mean(), "±", f1_list.std())
140 |
--------------------------------------------------------------------------------
/data/ind.cora.test.index:
--------------------------------------------------------------------------------
1 | 2692
2 | 2532
3 | 2050
4 | 1715
5 | 2362
6 | 2609
7 | 2622
8 | 1975
9 | 2081
10 | 1767
11 | 2263
12 | 1725
13 | 2588
14 | 2259
15 | 2357
16 | 1998
17 | 2574
18 | 2179
19 | 2291
20 | 2382
21 | 1812
22 | 1751
23 | 2422
24 | 1937
25 | 2631
26 | 2510
27 | 2378
28 | 2589
29 | 2345
30 | 1943
31 | 1850
32 | 2298
33 | 1825
34 | 2035
35 | 2507
36 | 2313
37 | 1906
38 | 1797
39 | 2023
40 | 2159
41 | 2495
42 | 1886
43 | 2122
44 | 2369
45 | 2461
46 | 1925
47 | 2565
48 | 1858
49 | 2234
50 | 2000
51 | 1846
52 | 2318
53 | 1723
54 | 2559
55 | 2258
56 | 1763
57 | 1991
58 | 1922
59 | 2003
60 | 2662
61 | 2250
62 | 2064
63 | 2529
64 | 1888
65 | 2499
66 | 2454
67 | 2320
68 | 2287
69 | 2203
70 | 2018
71 | 2002
72 | 2632
73 | 2554
74 | 2314
75 | 2537
76 | 1760
77 | 2088
78 | 2086
79 | 2218
80 | 2605
81 | 1953
82 | 2403
83 | 1920
84 | 2015
85 | 2335
86 | 2535
87 | 1837
88 | 2009
89 | 1905
90 | 2636
91 | 1942
92 | 2193
93 | 2576
94 | 2373
95 | 1873
96 | 2463
97 | 2509
98 | 1954
99 | 2656
100 | 2455
101 | 2494
102 | 2295
103 | 2114
104 | 2561
105 | 2176
106 | 2275
107 | 2635
108 | 2442
109 | 2704
110 | 2127
111 | 2085
112 | 2214
113 | 2487
114 | 1739
115 | 2543
116 | 1783
117 | 2485
118 | 2262
119 | 2472
120 | 2326
121 | 1738
122 | 2170
123 | 2100
124 | 2384
125 | 2152
126 | 2647
127 | 2693
128 | 2376
129 | 1775
130 | 1726
131 | 2476
132 | 2195
133 | 1773
134 | 1793
135 | 2194
136 | 2581
137 | 1854
138 | 2524
139 | 1945
140 | 1781
141 | 1987
142 | 2599
143 | 1744
144 | 2225
145 | 2300
146 | 1928
147 | 2042
148 | 2202
149 | 1958
150 | 1816
151 | 1916
152 | 2679
153 | 2190
154 | 1733
155 | 2034
156 | 2643
157 | 2177
158 | 1883
159 | 1917
160 | 1996
161 | 2491
162 | 2268
163 | 2231
164 | 2471
165 | 1919
166 | 1909
167 | 2012
168 | 2522
169 | 1865
170 | 2466
171 | 2469
172 | 2087
173 | 2584
174 | 2563
175 | 1924
176 | 2143
177 | 1736
178 | 1966
179 | 2533
180 | 2490
181 | 2630
182 | 1973
183 | 2568
184 | 1978
185 | 2664
186 | 2633
187 | 2312
188 | 2178
189 | 1754
190 | 2307
191 | 2480
192 | 1960
193 | 1742
194 | 1962
195 | 2160
196 | 2070
197 | 2553
198 | 2433
199 | 1768
200 | 2659
201 | 2379
202 | 2271
203 | 1776
204 | 2153
205 | 1877
206 | 2027
207 | 2028
208 | 2155
209 | 2196
210 | 2483
211 | 2026
212 | 2158
213 | 2407
214 | 1821
215 | 2131
216 | 2676
217 | 2277
218 | 2489
219 | 2424
220 | 1963
221 | 1808
222 | 1859
223 | 2597
224 | 2548
225 | 2368
226 | 1817
227 | 2405
228 | 2413
229 | 2603
230 | 2350
231 | 2118
232 | 2329
233 | 1969
234 | 2577
235 | 2475
236 | 2467
237 | 2425
238 | 1769
239 | 2092
240 | 2044
241 | 2586
242 | 2608
243 | 1983
244 | 2109
245 | 2649
246 | 1964
247 | 2144
248 | 1902
249 | 2411
250 | 2508
251 | 2360
252 | 1721
253 | 2005
254 | 2014
255 | 2308
256 | 2646
257 | 1949
258 | 1830
259 | 2212
260 | 2596
261 | 1832
262 | 1735
263 | 1866
264 | 2695
265 | 1941
266 | 2546
267 | 2498
268 | 2686
269 | 2665
270 | 1784
271 | 2613
272 | 1970
273 | 2021
274 | 2211
275 | 2516
276 | 2185
277 | 2479
278 | 2699
279 | 2150
280 | 1990
281 | 2063
282 | 2075
283 | 1979
284 | 2094
285 | 1787
286 | 2571
287 | 2690
288 | 1926
289 | 2341
290 | 2566
291 | 1957
292 | 1709
293 | 1955
294 | 2570
295 | 2387
296 | 1811
297 | 2025
298 | 2447
299 | 2696
300 | 2052
301 | 2366
302 | 1857
303 | 2273
304 | 2245
305 | 2672
306 | 2133
307 | 2421
308 | 1929
309 | 2125
310 | 2319
311 | 2641
312 | 2167
313 | 2418
314 | 1765
315 | 1761
316 | 1828
317 | 2188
318 | 1972
319 | 1997
320 | 2419
321 | 2289
322 | 2296
323 | 2587
324 | 2051
325 | 2440
326 | 2053
327 | 2191
328 | 1923
329 | 2164
330 | 1861
331 | 2339
332 | 2333
333 | 2523
334 | 2670
335 | 2121
336 | 1921
337 | 1724
338 | 2253
339 | 2374
340 | 1940
341 | 2545
342 | 2301
343 | 2244
344 | 2156
345 | 1849
346 | 2551
347 | 2011
348 | 2279
349 | 2572
350 | 1757
351 | 2400
352 | 2569
353 | 2072
354 | 2526
355 | 2173
356 | 2069
357 | 2036
358 | 1819
359 | 1734
360 | 1880
361 | 2137
362 | 2408
363 | 2226
364 | 2604
365 | 1771
366 | 2698
367 | 2187
368 | 2060
369 | 1756
370 | 2201
371 | 2066
372 | 2439
373 | 1844
374 | 1772
375 | 2383
376 | 2398
377 | 1708
378 | 1992
379 | 1959
380 | 1794
381 | 2426
382 | 2702
383 | 2444
384 | 1944
385 | 1829
386 | 2660
387 | 2497
388 | 2607
389 | 2343
390 | 1730
391 | 2624
392 | 1790
393 | 1935
394 | 1967
395 | 2401
396 | 2255
397 | 2355
398 | 2348
399 | 1931
400 | 2183
401 | 2161
402 | 2701
403 | 1948
404 | 2501
405 | 2192
406 | 2404
407 | 2209
408 | 2331
409 | 1810
410 | 2363
411 | 2334
412 | 1887
413 | 2393
414 | 2557
415 | 1719
416 | 1732
417 | 1986
418 | 2037
419 | 2056
420 | 1867
421 | 2126
422 | 1932
423 | 2117
424 | 1807
425 | 1801
426 | 1743
427 | 2041
428 | 1843
429 | 2388
430 | 2221
431 | 1833
432 | 2677
433 | 1778
434 | 2661
435 | 2306
436 | 2394
437 | 2106
438 | 2430
439 | 2371
440 | 2606
441 | 2353
442 | 2269
443 | 2317
444 | 2645
445 | 2372
446 | 2550
447 | 2043
448 | 1968
449 | 2165
450 | 2310
451 | 1985
452 | 2446
453 | 1982
454 | 2377
455 | 2207
456 | 1818
457 | 1913
458 | 1766
459 | 1722
460 | 1894
461 | 2020
462 | 1881
463 | 2621
464 | 2409
465 | 2261
466 | 2458
467 | 2096
468 | 1712
469 | 2594
470 | 2293
471 | 2048
472 | 2359
473 | 1839
474 | 2392
475 | 2254
476 | 1911
477 | 2101
478 | 2367
479 | 1889
480 | 1753
481 | 2555
482 | 2246
483 | 2264
484 | 2010
485 | 2336
486 | 2651
487 | 2017
488 | 2140
489 | 1842
490 | 2019
491 | 1890
492 | 2525
493 | 2134
494 | 2492
495 | 2652
496 | 2040
497 | 2145
498 | 2575
499 | 2166
500 | 1999
501 | 2434
502 | 1711
503 | 2276
504 | 2450
505 | 2389
506 | 2669
507 | 2595
508 | 1814
509 | 2039
510 | 2502
511 | 1896
512 | 2168
513 | 2344
514 | 2637
515 | 2031
516 | 1977
517 | 2380
518 | 1936
519 | 2047
520 | 2460
521 | 2102
522 | 1745
523 | 2650
524 | 2046
525 | 2514
526 | 1980
527 | 2352
528 | 2113
529 | 1713
530 | 2058
531 | 2558
532 | 1718
533 | 1864
534 | 1876
535 | 2338
536 | 1879
537 | 1891
538 | 2186
539 | 2451
540 | 2181
541 | 2638
542 | 2644
543 | 2103
544 | 2591
545 | 2266
546 | 2468
547 | 1869
548 | 2582
549 | 2674
550 | 2361
551 | 2462
552 | 1748
553 | 2215
554 | 2615
555 | 2236
556 | 2248
557 | 2493
558 | 2342
559 | 2449
560 | 2274
561 | 1824
562 | 1852
563 | 1870
564 | 2441
565 | 2356
566 | 1835
567 | 2694
568 | 2602
569 | 2685
570 | 1893
571 | 2544
572 | 2536
573 | 1994
574 | 1853
575 | 1838
576 | 1786
577 | 1930
578 | 2539
579 | 1892
580 | 2265
581 | 2618
582 | 2486
583 | 2583
584 | 2061
585 | 1796
586 | 1806
587 | 2084
588 | 1933
589 | 2095
590 | 2136
591 | 2078
592 | 1884
593 | 2438
594 | 2286
595 | 2138
596 | 1750
597 | 2184
598 | 1799
599 | 2278
600 | 2410
601 | 2642
602 | 2435
603 | 1956
604 | 2399
605 | 1774
606 | 2129
607 | 1898
608 | 1823
609 | 1938
610 | 2299
611 | 1862
612 | 2420
613 | 2673
614 | 1984
615 | 2204
616 | 1717
617 | 2074
618 | 2213
619 | 2436
620 | 2297
621 | 2592
622 | 2667
623 | 2703
624 | 2511
625 | 1779
626 | 1782
627 | 2625
628 | 2365
629 | 2315
630 | 2381
631 | 1788
632 | 1714
633 | 2302
634 | 1927
635 | 2325
636 | 2506
637 | 2169
638 | 2328
639 | 2629
640 | 2128
641 | 2655
642 | 2282
643 | 2073
644 | 2395
645 | 2247
646 | 2521
647 | 2260
648 | 1868
649 | 1988
650 | 2324
651 | 2705
652 | 2541
653 | 1731
654 | 2681
655 | 2707
656 | 2465
657 | 1785
658 | 2149
659 | 2045
660 | 2505
661 | 2611
662 | 2217
663 | 2180
664 | 1904
665 | 2453
666 | 2484
667 | 1871
668 | 2309
669 | 2349
670 | 2482
671 | 2004
672 | 1965
673 | 2406
674 | 2162
675 | 1805
676 | 2654
677 | 2007
678 | 1947
679 | 1981
680 | 2112
681 | 2141
682 | 1720
683 | 1758
684 | 2080
685 | 2330
686 | 2030
687 | 2432
688 | 2089
689 | 2547
690 | 1820
691 | 1815
692 | 2675
693 | 1840
694 | 2658
695 | 2370
696 | 2251
697 | 1908
698 | 2029
699 | 2068
700 | 2513
701 | 2549
702 | 2267
703 | 2580
704 | 2327
705 | 2351
706 | 2111
707 | 2022
708 | 2321
709 | 2614
710 | 2252
711 | 2104
712 | 1822
713 | 2552
714 | 2243
715 | 1798
716 | 2396
717 | 2663
718 | 2564
719 | 2148
720 | 2562
721 | 2684
722 | 2001
723 | 2151
724 | 2706
725 | 2240
726 | 2474
727 | 2303
728 | 2634
729 | 2680
730 | 2055
731 | 2090
732 | 2503
733 | 2347
734 | 2402
735 | 2238
736 | 1950
737 | 2054
738 | 2016
739 | 1872
740 | 2233
741 | 1710
742 | 2032
743 | 2540
744 | 2628
745 | 1795
746 | 2616
747 | 1903
748 | 2531
749 | 2567
750 | 1946
751 | 1897
752 | 2222
753 | 2227
754 | 2627
755 | 1856
756 | 2464
757 | 2241
758 | 2481
759 | 2130
760 | 2311
761 | 2083
762 | 2223
763 | 2284
764 | 2235
765 | 2097
766 | 1752
767 | 2515
768 | 2527
769 | 2385
770 | 2189
771 | 2283
772 | 2182
773 | 2079
774 | 2375
775 | 2174
776 | 2437
777 | 1993
778 | 2517
779 | 2443
780 | 2224
781 | 2648
782 | 2171
783 | 2290
784 | 2542
785 | 2038
786 | 1855
787 | 1831
788 | 1759
789 | 1848
790 | 2445
791 | 1827
792 | 2429
793 | 2205
794 | 2598
795 | 2657
796 | 1728
797 | 2065
798 | 1918
799 | 2427
800 | 2573
801 | 2620
802 | 2292
803 | 1777
804 | 2008
805 | 1875
806 | 2288
807 | 2256
808 | 2033
809 | 2470
810 | 2585
811 | 2610
812 | 2082
813 | 2230
814 | 1915
815 | 1847
816 | 2337
817 | 2512
818 | 2386
819 | 2006
820 | 2653
821 | 2346
822 | 1951
823 | 2110
824 | 2639
825 | 2520
826 | 1939
827 | 2683
828 | 2139
829 | 2220
830 | 1910
831 | 2237
832 | 1900
833 | 1836
834 | 2197
835 | 1716
836 | 1860
837 | 2077
838 | 2519
839 | 2538
840 | 2323
841 | 1914
842 | 1971
843 | 1845
844 | 2132
845 | 1802
846 | 1907
847 | 2640
848 | 2496
849 | 2281
850 | 2198
851 | 2416
852 | 2285
853 | 1755
854 | 2431
855 | 2071
856 | 2249
857 | 2123
858 | 1727
859 | 2459
860 | 2304
861 | 2199
862 | 1791
863 | 1809
864 | 1780
865 | 2210
866 | 2417
867 | 1874
868 | 1878
869 | 2116
870 | 1961
871 | 1863
872 | 2579
873 | 2477
874 | 2228
875 | 2332
876 | 2578
877 | 2457
878 | 2024
879 | 1934
880 | 2316
881 | 1841
882 | 1764
883 | 1737
884 | 2322
885 | 2239
886 | 2294
887 | 1729
888 | 2488
889 | 1974
890 | 2473
891 | 2098
892 | 2612
893 | 1834
894 | 2340
895 | 2423
896 | 2175
897 | 2280
898 | 2617
899 | 2208
900 | 2560
901 | 1741
902 | 2600
903 | 2059
904 | 1747
905 | 2242
906 | 2700
907 | 2232
908 | 2057
909 | 2147
910 | 2682
911 | 1792
912 | 1826
913 | 2120
914 | 1895
915 | 2364
916 | 2163
917 | 1851
918 | 2391
919 | 2414
920 | 2452
921 | 1803
922 | 1989
923 | 2623
924 | 2200
925 | 2528
926 | 2415
927 | 1804
928 | 2146
929 | 2619
930 | 2687
931 | 1762
932 | 2172
933 | 2270
934 | 2678
935 | 2593
936 | 2448
937 | 1882
938 | 2257
939 | 2500
940 | 1899
941 | 2478
942 | 2412
943 | 2107
944 | 1746
945 | 2428
946 | 2115
947 | 1800
948 | 1901
949 | 2397
950 | 2530
951 | 1912
952 | 2108
953 | 2206
954 | 2091
955 | 1740
956 | 2219
957 | 1976
958 | 2099
959 | 2142
960 | 2671
961 | 2668
962 | 2216
963 | 2272
964 | 2229
965 | 2666
966 | 2456
967 | 2534
968 | 2697
969 | 2688
970 | 2062
971 | 2691
972 | 2689
973 | 2154
974 | 2590
975 | 2626
976 | 2390
977 | 1813
978 | 2067
979 | 1952
980 | 2518
981 | 2358
982 | 1789
983 | 2076
984 | 2049
985 | 2119
986 | 2013
987 | 2124
988 | 2556
989 | 2105
990 | 2093
991 | 1885
992 | 2305
993 | 2354
994 | 2135
995 | 2601
996 | 1770
997 | 1995
998 | 2504
999 | 1749
1000 | 2157
1001 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import pickle as pkl
5 | import networkx as nx
6 | import scipy.sparse as sp
7 | from sklearn import metrics
8 | from munkres import Munkres
9 | import matplotlib.pyplot as plt
10 | from kmeans_gpu import kmeans
11 | import sklearn.preprocessing as preprocess
12 | from sklearn.metrics import adjusted_rand_score as ari_score
13 | from sklearn.metrics import roc_auc_score, average_precision_score
14 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
15 |
16 |
def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    Args:
        idx: an index or sequence of indices to mark.
        l: length of the returned mask.

    Returns:
        A 1-D ``numpy.ndarray`` of dtype ``bool``.
    """
    # The ``np.bool`` alias was removed in NumPy 1.24; the builtin ``bool``
    # names the same dtype and works on every NumPy version.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask
22 |
23 |
def load_data(dataset):
    """Load a dataset from the ``data/`` directory.

    For ``dataset == 'wiki'`` the work is delegated to ``load_wiki`` and the
    three index slots of the result are filled with ``0``.  Otherwise the
    pickled files ``data/ind.<dataset>.{x,y,tx,ty,allx,ally,graph}`` and the
    text file ``data/ind.<dataset>.test.index`` are read.

    Args:
        dataset: dataset name used to build the file names.

    Returns:
        A tuple ``(adj, features, labels, idx_train, idx_val, idx_test)``:
        ``adj`` is the adjacency matrix built by networkx from ``graph``,
        ``features`` is a dense ``torch.FloatTensor``, ``labels`` is the
        argmax over the stacked label rows, ``idx_train`` is
        ``range(len(y))``, ``idx_val`` is the following 500 indices and
        ``idx_test`` is the sorted list of test indices.
    """
    if dataset == 'wiki':
        adj, features, label = load_wiki()
        return adj, features, label, 0, 0, 0

    # load the data: x, y, tx, ty, allx, ally, graph
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        with open("data/ind.{}.{}".format(dataset, name), 'rb') as rf:
            # encoding='latin1' fixes the pickle incompatibility of numpy
            # arrays between Python 2 and 3:
            # https://stackoverflow.com/questions/11305790/pickle-incompatibility-of-numpy-arrays-between-python-2-and-3
            # NOTE(security): unpickling can execute arbitrary code; only
            # load dataset files that come from a trusted source.
            objects.append(pkl.load(rf, encoding='latin1'))
    x, y, tx, ty, allx, ally, graph = objects

    test_idx_reorder = parse_index_file(
        "data/ind.{}.test.index".format(dataset))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(
            min(test_idx_reorder), max(test_idx_reorder) + 1)
        offset = test_idx_range - min(test_idx_range)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[offset, :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[offset, :] = ty
        ty = ty_extended

    # Stack train and test rows, then copy the rows found at the sorted test
    # positions into the positions listed (in file order) by the index file.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    features = torch.FloatTensor(np.array(features.todense()))
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y) + 500)

    return adj, features, np.argmax(labels, 1), idx_train, idx_val, idx_test
85 |
86 |
def load_wiki():
    """Load the wiki dataset from ``data/graph.txt``, ``data/group.txt`` and
    ``data/tfidf.txt``.

    ``graph.txt`` supplies one edge per line (two integer node ids),
    ``group.txt`` supplies a label in its second column, and ``tfidf.txt``
    supplies ``row col value`` triples for the feature matrix.

    Returns:
        A tuple ``(adj, features, label)``: ``adj`` is a symmetrized
        ``scipy.sparse.csr_matrix`` with ones on the edges, ``features`` is a
        min-max scaled dense ``torch.FloatTensor`` with 4973 columns, and
        ``label`` is an integer array with labels remapped to ``0..k-1``.
    """
    # Context managers guarantee the files are closed even if parsing fails.
    edges = []
    with open('data/graph.txt', 'r') as f:
        for line in f:
            parts = line.split()
            edges.append([int(parts[0]), int(parts[1])])

    label = []
    with open('data/group.txt', 'r') as f:
        for line in f:
            parts = line.split()
            label.append(int(parts[1]))

    fea_idx = []
    fea = []
    with open('data/tfidf.txt', 'r') as f:
        for line in f:
            parts = line.split()
            fea_idx.append([int(parts[0]), int(parts[1])])
            fea.append(float(parts[2]))

    # Add every edge in both directions and drop duplicates.
    adj = np.array(edges)
    adj = np.vstack((adj, adj[:, [1, 0]]))
    adj = np.unique(adj, axis=0)

    # Remap the raw label values onto consecutive integers starting at 0.
    labelset = np.unique(label)
    labeldict = dict(zip(labelset, range(len(labelset))))
    label = np.array([labeldict[x] for x in label])
    adj = sp.csr_matrix((np.ones(len(adj)), (adj[:, 0], adj[:, 1])), shape=(len(label), len(label)))

    fea_idx = np.array(fea_idx)
    features = sp.csr_matrix((fea, (fea_idx[:, 0], fea_idx[:, 1])), shape=(len(label), 4973)).toarray()
    scaler = preprocess.MinMaxScaler()
    # features = preprocess.normalize(features, norm='l2')
    features = scaler.fit_transform(features)
    features = torch.FloatTensor(features)

    return adj, features, label
132 |
133 |
def parse_index_file(filename):
    """Read a text file holding one integer per line.

    Args:
        filename: path of the index file.

    Returns:
        A list of the integers in file order.  Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line is not a valid integer.
    """
    # ``with`` closes the handle deterministically; the original left it open.
    with open(filename) as index_file:
        return [int(line) for line in index_file if line.strip()]
139 |
140 |
def sparse_to_tuple(sparse_mx):
    """Split a scipy sparse matrix into ``(coords, values, shape)``.

    The matrix is converted to COO format unless it already is one.

    Returns:
        coords: integer array with one ``(row, col)`` pair per stored entry.
        values: the stored entries.
        shape: the matrix shape.
    """
    coo = sparse_mx if sp.isspmatrix_coo(sparse_mx) else sparse_mx.tocoo()
    # One (row, col) pair per row of the coordinate array.
    coords = np.column_stack((coo.row, coo.col))
    return coords, coo.data, coo.shape
148 |
149 |
def decompose(adj, dataset, norm='sym', renorm=True):
    """Compute and plot the eigenvalue spectrum of the normalized Laplacian.

    The eigenvalues of ``I - D^-1/2 (A [+ I]) D^-1/2`` are saved to
    ``<dataset>.npy``, their maximum is printed, and a 50-bin histogram is
    written to ``eig_renorm_<dataset>.png``.

    Args:
        adj: adjacency matrix (anything ``scipy.sparse.coo_matrix`` accepts).
        dataset: name used as the prefix/suffix of the output files.
        norm: normalization scheme; only ``'sym'`` is supported.
        renorm: if True, add self-loops (identity) before normalizing.

    Raises:
        ValueError: if ``norm`` is not ``'sym'``.
    """
    # Fail early with a clear message; otherwise the Laplacian is never
    # assigned and the function dies with an UnboundLocalError.
    if norm != 'sym':
        raise ValueError("decompose only supports norm='sym', got {!r}".format(norm))

    adj = sp.coo_matrix(adj)
    ident = sp.eye(adj.shape[0])
    adj_ = adj + ident if renorm else adj

    rowsum = np.array(adj_.sum(1))
    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    lap = ident - adj_normalized

    evalue, _ = np.linalg.eig(lap.toarray())
    np.save(dataset + ".npy", evalue)
    print(max(evalue))

    # Plot the spectrum; the function must return normally for the figure to
    # be written, so it no longer terminates the process before this point.
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.hist(evalue, 50, facecolor='g')
    ax.set_xlabel('Eigenvalues')
    ax.set_ylabel('Frequncy')
    fig.savefig("eig_renorm_" + dataset + ".png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
174 |
175 |
def preprocess_graph(adj, layer, norm='sym', renorm=True):
    """Build ``layer`` copies of the filter ``I - L`` for a graph.

    ``L`` is the normalized Laplacian ``I - A_norm`` where ``A_norm`` is
    ``D^-1/2 A D^-1/2`` for ``norm='sym'`` and ``D^-1 A`` for ``norm='left'``
    (``A`` includes self-loops when ``renorm`` is True).

    Args:
        adj: adjacency matrix (anything ``scipy.sparse.coo_matrix`` accepts).
        layer: number of filter matrices to return.
        norm: ``'sym'`` or ``'left'``.
        renorm: if True, add the identity to ``adj`` before normalizing.

    Returns:
        A list of ``layer`` scipy sparse matrices, each equal to ``I - L``.

    Raises:
        ValueError: if ``norm`` is neither ``'sym'`` nor ``'left'``.
    """
    # Reject unknown schemes explicitly; otherwise the Laplacian is never
    # assigned and the failure surfaces as an UnboundLocalError.
    if norm not in ('sym', 'left'):
        raise ValueError("norm must be 'sym' or 'left', got {!r}".format(norm))

    adj = sp.coo_matrix(adj)
    ident = sp.eye(adj.shape[0])
    adj_ = adj + ident if renorm else adj

    rowsum = np.array(adj_.sum(1))

    if norm == 'sym':
        degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
        adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    else:
        degree_mat_inv = sp.diags(np.power(rowsum, -1.).flatten())
        adj_normalized = degree_mat_inv.dot(adj_).tocoo()
    lap = ident - adj_normalized

    # Per-layer coefficient applied to the Laplacian.
    # reg = [2 / 3] * (layer)
    reg = [1] * layer

    return [ident - (coeff * lap) for coeff in reg]
203 |
204 |
def laplacian(adj):
    """Return the unnormalized Laplacian ``D - A`` as a dense float tensor.

    Args:
        adj: scipy sparse adjacency matrix.

    Returns:
        ``torch.FloatTensor`` holding ``diag(row sums of adj) - adj``.
    """
    degrees = np.array(adj.sum(1)).flatten()
    dense_lap = (sp.diags(degrees) - adj).toarray()
    return torch.FloatTensor(dense_lap)
210 |
211 |
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    Args:
        sparse_mx: scipy sparse matrix in any format.

    Returns:
        A float32 sparse COO ``torch.Tensor`` with the same shape and entries.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # ``torch.sparse.FloatTensor(...)`` is a deprecated legacy constructor;
    # ``torch.sparse_coo_tensor`` is the supported equivalent.
    return torch.sparse_coo_tensor(indices, values, shape)
220 |
221 |
def get_roc_score(emb, adj_orig, edges_pos, edges_neg):
    """Score link prediction from node embeddings.

    Each candidate edge ``(i, j)`` is scored with ``sigmoid(emb[i] . emb[j])``.
    Edges in ``edges_pos`` are labelled 1 and edges in ``edges_neg`` are
    labelled 0.

    Args:
        emb: 2-D array of node embeddings, one row per node.
        adj_orig: unused; kept so existing callers keep working.
        edges_pos: iterable of ``(i, j)`` index pairs for positive edges.
        edges_neg: iterable of ``(i, j)`` index pairs for negative edges.

    Returns:
        ``(roc_score, ap_score)``: ROC-AUC and average precision.
    """
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # Predict on test set of edges
    adj_rec = np.dot(emb, emb.T)
    preds = [sigmoid(adj_rec[e[0], e[1]]) for e in edges_pos]
    preds_neg = [sigmoid(adj_rec[e[0], e[1]]) for e in edges_neg]

    preds_all = np.hstack([preds, preds_neg])
    # The zero block must match the number of NEGATIVE predictions; sizing it
    # by the positives misaligns labels whenever the two sets differ in size.
    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
    roc_score = roc_auc_score(labels_all, preds_all)
    ap_score = average_precision_score(labels_all, preds_all)

    return roc_score, ap_score
246 |
247 |
def cluster_acc(y_true, y_pred):
    """
    Compute clustering accuracy and macro F1 after matching cluster ids to
    ground-truth classes with the Munkres assignment algorithm.

    Args:
        y_true: the ground truth labels
        y_pred: the clustering ids; may be modified in place when the number
            of distinct ids differs from the number of true classes

    Returns: (acc, f1_macro), or None (after printing 'error') when the class
        counts still differ after the in-place patch-up
    """
    y_true = y_true - np.min(y_true)
    true_ids = list(set(y_true))
    pred_ids = list(set(y_pred))
    n_true = len(true_ids)

    if n_true != len(pred_ids):
        # Write each true class id that is missing from the predictions into
        # the leading positions of y_pred (in place).
        slot = 0
        for cls in true_ids:
            if cls not in pred_ids:
                y_pred[slot] = cls
                slot += 1

    pred_ids = list(set(y_pred))
    n_pred = len(pred_ids)
    if n_true != n_pred:
        print('error')
        return

    truth = np.asarray(y_true)
    pred = np.asarray(y_pred)

    # cost[i][j] = number of samples of true class i assigned to cluster j
    cost = np.zeros((n_true, n_pred), dtype=int)
    for row, true_cls in enumerate(true_ids):
        in_class = truth == true_cls
        for col, pred_cls in enumerate(pred_ids):
            cost[row][col] = np.count_nonzero(in_class & (pred == pred_cls))

    # Munkres minimizes, so negate the counts to maximize the overlap.
    assignment = Munkres().compute((-cost).tolist())

    new_predict = np.zeros(len(y_pred))
    for row, true_cls in enumerate(true_ids):
        matched_cluster = pred_ids[assignment[row][1]]
        new_predict[pred == matched_cluster] = true_cls

    acc = metrics.accuracy_score(y_true, new_predict)
    f1_macro = metrics.f1_score(y_true, new_predict, average='macro')
    return acc, f1_macro
292 |
293 |
def eva(y_true, y_pred, show_details=True):
    """
    Evaluate a clustering against the ground truth.
    Args:
        y_true: the ground truth
        y_pred: the predicted label
        show_details: if True, print the four scores
    Returns: (acc, nmi, ari, f1)
    """
    acc, f1 = cluster_acc(y_true, y_pred)
    nmi = nmi_score(y_true, y_pred, average_method='arithmetic')
    ari = ari_score(y_true, y_pred)
    if not show_details:
        return acc, nmi, ari, f1

    report = (
        ':acc {:.4f}'.format(acc),
        ', nmi {:.4f}'.format(nmi),
        ', ari {:.4f}'.format(ari),
        ', f1 {:.4f}'.format(f1),
    )
    print(*report)
    return acc, nmi, ari, f1
310 |
311 |
def load_graph_data(dataset_name, show_details=False):
    """
    Load the feature, label and adjacency arrays of a graph dataset from
    ``dataset/<name>/<name>_{feat,label,adj}.npy``.
    :param dataset_name: the name of the dataset
    :param show_details: if True, print a summary of the dataset
    - dataset name
    - features' shape
    - labels' shape
    - adj shape
    - edge num
    - category num
    - category distribution
    :return: the features, labels and adj
    """
    prefix = "dataset/" + dataset_name + "/" + dataset_name
    feat, label, adj = (
        np.load(prefix + suffix, allow_pickle=True)
        for suffix in ("_feat.npy", "_label.npy", "_adj.npy")
    )

    if not show_details:
        return feat, label, adj

    banner = "++++++++++++++++++++++++++++++"
    print(banner)
    print("---details of graph dataset---")
    print(banner)
    print("dataset name: ", dataset_name)
    print("feature shape: ", feat.shape)
    print("label shape: ", label.shape)
    print("adj shape: ", adj.shape)
    # Every undirected edge is stored twice in adj, hence the halving.
    print("undirected edge num: ", np.nonzero(adj)[0].size // 2)
    print("category num: ", max(label)-min(label)+1)
    print("category distribution: ")
    for category in range(max(label)+1):
        print("label", category, end=":")
        print(np.count_nonzero(label == category))
    print(banner)

    return feat, label, adj
347 |
348 |
def normalize_adj(adj, self_loop=True, symmetry=False):
    """
    normalize the adj matrix
    :param adj: input adj matrix (dense, square)
    :param self_loop: if add the self loop or not
    :param symmetry: symmetry normalize or not
    :return: the normalized adj matrix as a dense array:
        D^{-0.5} A D^{-0.5} when ``symmetry`` is True, otherwise D^{-1} A,
        where D holds the column sums of A (with self loops if requested)
    """
    adj_tmp = np.asarray(adj)

    # add the self_loop
    if self_loop:
        adj_tmp = adj_tmp + np.eye(adj_tmp.shape[0])

    # Degree vector and its element-wise inverse.  D is diagonal, so a dense
    # matrix inverse is unnecessary; zero-degree nodes get an inverse of 0
    # instead of making the whole computation fail.
    degree = adj_tmp.sum(0)
    with np.errstate(divide='ignore'):
        d_inv = np.where(degree != 0, 1.0 / degree, 0.0)

    # symmetry normalize: D^{-0.5} A D^{-0.5}
    if symmetry:
        sqrt_d_inv = np.sqrt(d_inv)
        # Scale rows AND columns by D^{-0.5}; multiplying by the adjacency a
        # second time (as before) does not yield the documented normalization.
        return sqrt_d_inv[:, None] * adj_tmp * sqrt_d_inv[None, :]

    # non-symmetry normalize: D^{-1} A
    return d_inv[:, None] * adj_tmp
377 |
378 |
def setup_seed(seed):
    """
    Seed Python's ``random``, NumPy and torch (CPU and CUDA) and switch cuDNN
    to deterministic, non-benchmark mode so runs are repeatable.
    Args:
        seed: random seed
    Returns: None
    """
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
394 |
395 |
def clustering(feature, true_labels, cluster_num, device="cuda"):
    """Cluster ``feature`` with k-means and score the result.

    Args:
        feature: data passed to ``kmeans`` as ``X``.
        true_labels: ground-truth labels used for evaluation.
        cluster_num: number of clusters.
        device: device passed through to ``kmeans``; defaults to ``"cuda"``
            (the previously hard-coded value) so existing callers are
            unaffected.

    Returns:
        ``(acc, nmi, ari, f1, predict_labels, dis)`` where the four scores are
        scaled to percentages, ``predict_labels`` is the NumPy array of
        predicted cluster ids and ``dis`` is the second value returned by
        ``kmeans``.
    """
    predict_labels, dis, _ = kmeans(X=feature, num_clusters=cluster_num, distance="euclidean", device=device)
    # Convert once and reuse the same array for evaluation and the result.
    predicted = predict_labels.numpy()
    acc, nmi, ari, f1 = eva(true_labels, predicted, show_details=False)
    return 100 * acc, 100 * nmi, 100 * ari, 100 * f1, predicted, dis
--------------------------------------------------------------------------------