├── README.md
├── data
│   ├── bat
│   │   ├── bat_adj.npy
│   │   ├── bat_feat.npy
│   │   └── bat_label.npy
│   ├── citeseer.mat
│   ├── cora.mat
│   └── eat
│       ├── eat_adj.npy
│       ├── eat_feat.npy
│       └── eat_label.npy
├── dataset.py
├── demo_bash.sh
├── env.txt
├── envPy37.txt
├── function_laplacian_diffusion.py
├── log
│   └── 20230730_221042.log
├── main.py
├── model.py
├── run.sh
├── task.py
└── utilis.py
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Contrastive Graph Diffusion Network
2 | Contrastive learning has proven to be a successful approach to graph self-supervised learning. Augmentation techniques and sampling strategies are crucial in contrastive learning, but in most existing works the augmentation techniques require careful design, and the sampling strategies capture only a small amount of intrinsic supervision information. Additionally, existing methods require complex designs to obtain two different representations of the data. To overcome these limitations, we propose a novel framework called the Self-Contrastive Graph Diffusion Network (SCGDN). Our framework consists of two main components: the Attentional Module (AttM) and the Diffusion Module (DiFM). AttM aggregates higher-order structure and feature information to obtain an expressive embedding, while DiFM balances the state of each node in the graph through Laplacian diffusion learning and allows the adjacency and feature information in the graph to evolve cooperatively. Unlike existing methodologies, SCGDN is an augmentation-free approach that avoids "sampling bias" and semantic drift without requiring pre-training. We sample pairs based on structure and feature information: if two nodes are neighbors, they are treated as positive samples of each other; if two disconnected nodes are also unrelated on the $k$NN graph, they are treated as negative samples of each other. The contrastive objective exploits these sampling strategies, and a redundancy-reduction term minimizes redundant information in the embedding while retaining the discriminative information. In this framework, the graph self-contrastive learning paradigm proves powerful: SCGDN effectively balances preserving high-order structure information against overfitting. The results show that SCGDN consistently outperforms both contrastive and classical methods.
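
As a minimal sketch of the sampling rule described above (an illustration, not the exact implementation in this repo; it assumes dense NumPy arrays `A` for the adjacency matrix and `K` for the $k$NN graph):

```python
import numpy as np

def sample_masks(A: np.ndarray, K: np.ndarray):
    """Boolean masks over node pairs following the SCGDN sampling rule."""
    pos = A > 0                    # neighbors are positive samples of each other
    neg = (A == 0) & (K == 0)      # disconnected and kNN-unrelated nodes are negatives
    np.fill_diagonal(pos, False)   # exclude self-pairs
    np.fill_diagonal(neg, False)
    return pos, neg
```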
3 |
4 | ## Requirements
5 | > Dependencies (Python >= 3.7): the main dependencies are
6 | ```
7 | torch==1.8.1
8 | torch-cluster==1.5.9
9 | torch-geometric==1.7.0
10 | torch-scatter==2.0.6
11 | torch-sparse==0.6.9
12 | torch-spline-conv==1.2.1
13 | torchdiffeq==0.2.1
14 | ```
15 | > Commands to install all the dependencies in a new conda environment
16 | ```bash
17 | conda create --name grand python=3.7
18 | conda activate grand
19 |
20 | pip install ogb pykeops
21 | pip install torch==1.8.1
22 | pip install torchdiffeq -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html
23 |
24 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html
25 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html
26 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html
27 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html
28 |
29 | # or
30 | conda create --name grand --file env.txt
31 | ```
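
> Note: `env.txt` pins a Python 3.8 / PyTorch 1.12 environment, while `envPy37.txt` matches the Python 3.7 / PyTorch 1.8.1 dependencies listed above; use whichever suits your CUDA setup.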
32 |
33 | ```bash
34 | mkdir best log
35 | ```
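
> `main.py` saves checkpoints to `./best/` and `demo_bash.sh` tees its output to `./log/`, so both directories must exist before training.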
36 |
37 | ```bash
38 | python main.py --dataname cora --beta=3 --epochs=50 --hid_dim=512 --knn=21 --time=100
39 | # ACC:74.79$\pm$0.38 NMI:56.86$\pm$0.42 ARI:52.61$\pm$0.33 F1:70.42$\pm$0.48
40 |
41 | python main.py --dataname citeseer --beta=7 --epochs=50 --hid_dim=512 --knn=150 --time=160
42 | # ACC:69.62$\pm$0.02 NMI:44.35$\pm$0.03 ARI:45.43$\pm$0.02 F1:65.50$\pm$0.06
43 |
44 | python main.py --dataname amap --epochs=20 --knn=19 --hid_dim=512 --time=10.0 --beta=5.0
45 | # ACC:78.91$\pm$0.00 NMI:72.53$\pm$0.02 ARI:63.41$\pm$0.01 F1:75.27$\pm$0.01
46 |
47 | python main.py --dataname bat --beta 0.7 --epochs=25 --hid_dim=64 --knn=21 --time=200.0
48 | # ACC:74.73 $\pm$ 0.23 NMI:52.63 $\pm$ 0.11 ARI:47.65 $\pm$ 0.18 F1:74.49 $\pm$ 0.26
49 |
50 | python main.py --dataname eat --beta=6 --gamma=1.5 --epochs=30 --hid_dim=512 --knn=155 --time=15
51 | # ACC:51.88 $\pm$ 0.00 NMI:32.49 $\pm$ 0.21 ARI:23.86 $\pm$ 0.08 F1:47.62 $\pm$ 0.02
52 |
53 | python main.py --dataname corafull --beta=2 --epochs=12 --hid_dim=1024 --knn=73 --time=5 --gpu=-1
54 | ```
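
For reference, the training objective implemented in `main.py` is

```math
\mathcal{L} = \frac{1}{N}\Big(\operatorname{tr}\big(Z^\top L Z\big) - \beta\,\operatorname{tr}\big(Z^\top L^{-} Z\big) + \gamma\,\operatorname{MSE}\big(ZZ^\top, I\big)\Big)
```

where $Z$ is the learned embedding, $L$ is the Laplacian over positive pairs, and $L^{-}$ is the negative-sample Laplacian built from the sampling rule above; `--beta` and `--gamma` weight the negative and redundancy-reduction terms.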
55 |
56 | ## Citation
57 |
58 | If you find this project useful, please consider citing:
59 |
60 | ```bibtex
61 | @InProceedings{MaMM2023,
62 | author = {Yixuan Ma and Kun Zhan},
63 | booktitle = {ACM MM},
64 | title = {Self-contrastive graph diffusion network},
65 | year = {2023},
66 | volume = {31},
67 | }
68 | ```
69 | ## Contact
70 | https://kunzhan.github.io/
71 |
72 | If you have any questions, feel free to contact me. (Email: `ice.echo#gmail.com`)
73 |
--------------------------------------------------------------------------------
/data/bat/bat_adj.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/bat/bat_adj.npy
--------------------------------------------------------------------------------
/data/bat/bat_feat.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/bat/bat_feat.npy
--------------------------------------------------------------------------------
/data/bat/bat_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/bat/bat_label.npy
--------------------------------------------------------------------------------
/data/citeseer.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/citeseer.mat
--------------------------------------------------------------------------------
/data/cora.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/cora.mat
--------------------------------------------------------------------------------
/data/eat/eat_adj.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/eat/eat_adj.npy
--------------------------------------------------------------------------------
/data/eat/eat_feat.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/eat/eat_feat.npy
--------------------------------------------------------------------------------
/data/eat/eat_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/eat/eat_label.npy
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.metrics import pairwise
4 | import scipy
5 | import scipy.sparse as sp
6 | from torch_scatter import scatter_add
7 | from torch_geometric.utils import to_undirected, to_scipy_sparse_matrix, degree, add_remaining_self_loops
8 | from torch_geometric.utils.num_nodes import maybe_num_nodes
9 |
10 | from ipdb import set_trace
11 |
12 | def load_data(dataset_name, show_details=False):
13 | load_path = "./data/" + dataset_name + "/" + dataset_name
14 | feat = np.load(load_path+"_feat.npy", allow_pickle=True)
15 | label = np.load(load_path+"_label.npy", allow_pickle=True)
16 | adj = np.load(load_path+"_adj.npy", allow_pickle=True)
17 | if show_details:
18 | print("++++++++++++++++++++++++++++++")
19 | print("---details of graph dataset---")
20 | print("++++++++++++++++++++++++++++++")
21 | print("dataset name: ", dataset_name)
22 | print("feature shape: ", feat.shape)
23 | print("label shape: ", label.shape)
24 | print("adj shape: ", adj.shape)
25 | print("undirected edge num: ", int(np.nonzero(adj)[0].shape[0]/2))
26 | print("category num: ", max(label)-min(label)+1)
27 | print("category distribution: ")
28 | for i in range(max(label)+1):
29 | print("label", i, end=":")
30 | print(len(label[np.where(label == i)]))
31 | print("++++++++++++++++++++++++++++++")
32 |
33 | return feat, label, adj
34 |
35 |
36 | def get_rw_adj(edge_index, norm_dim=1, fill_value=0., num_nodes=None, type='sys'):
37 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
38 | edge_weight = torch.ones((edge_index.size(1),), dtype=torch.float32, device=edge_index.device)
39 |
40 | if not fill_value == 0:
41 | edge_index, tmp_edge_weight = add_remaining_self_loops(edge_index, edge_weight, fill_value, num_nodes)
42 | assert tmp_edge_weight is not None
43 | edge_weight = tmp_edge_weight
44 |
45 | row, col = edge_index[0], edge_index[1]
46 | indices = row if norm_dim == 0 else col
47 | deg = scatter_add(edge_weight, indices, dim=0, dim_size=num_nodes)
48 | # deg_inv_sqrt = deg.pow_(-1)
49 | # edge_weight = deg_inv_sqrt[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv_sqrt[indices]
50 |
51 | if type=='sys':
52 | deg_inv_sqrt = deg.pow_(-0.5)
53 | edge_weight = deg_inv_sqrt[indices] * edge_weight * deg_inv_sqrt[indices]
54 | else:
55 | deg_inv_sqrt = deg.pow_(-1)
56 | edge_weight = deg_inv_sqrt[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv_sqrt[indices]
57 | return edge_index, edge_weight
58 |
59 | def adj_normalized(adj, type='sys'):
60 | row_sum = torch.sum(adj, dim=1)
61 |     row_sum = (row_sum==0)*1 + row_sum  # avoid division by zero for isolated nodes
62 | if type=='sys':
63 | d_inv_sqrt = torch.pow(row_sum, -0.5).flatten()
64 | d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
65 | d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
66 | return d_mat_inv_sqrt.mm(adj).mm(d_mat_inv_sqrt)
67 | else:
68 | d_inv = torch.pow(row_sum, -1).flatten()
69 | d_inv[torch.isinf(d_inv)] = 0.
70 | d_mat_inv = torch.diag(d_inv)
71 | return d_mat_inv.mm(adj)
72 |
73 | def FeatureNormalize(mx):
74 | """Row-normalize sparse matrix"""
75 | rowsum = np.array(mx.sum(1))
76 |     rowsum = (rowsum == 0) * 1 + rowsum  # avoid division by zero for all-zero rows
77 | r_inv = np.power(rowsum, -1).flatten()
78 | r_inv[np.isinf(r_inv)] = 0.0
79 | r_mat_inv = sp.diags(r_inv)
80 | mx = r_mat_inv.dot(mx)
81 | return mx
82 |
83 | def compute_knn(args, features, distribution='t-distribution'):
84 | features = FeatureNormalize(features)
85 | # Dis = pairwise.euclidean_distances(self.data, self.data)
86 | # Dis = pairwise.manhattan_distances(self.data, self.data)
87 | # Dis = pairwise.haversine_distances(self.data, self.data)
88 | Dis = pairwise.cosine_distances(features, features)
89 | Dis = Dis/np.max(np.max(Dis, 1))
90 | if distribution=='t-distribution':
91 | gamma = CalGamma(args.v_input)
92 | sim = gamma * np.sqrt(2 * np.pi) * np.power((1 + args.sigma*np.power(Dis,2) / args.v_input), -1 * (args.v_input + 1) / 2)
93 | else:
94 | sim = np.exp(-Dis/(args.sigma**2))
95 |
96 | K = args.knn
97 | if K>0:
98 | idx = sim.argsort()[:,::-1]
99 | sim_new = np.zeros_like(sim)
100 | for ii in range(0, len(sim_new)):
101 | sim_new[ii, idx[ii,0:K]] = sim[ii, idx[ii,0:K]]
102 | Disknn = (sim_new + sim_new.T)/2
103 | else:
104 | Disknn = (sim + sim.T)/2
105 |
106 | Disknn = torch.from_numpy(Disknn).type(torch.FloatTensor)
107 | Disknn = torch.add(torch.eye(Disknn.shape[0]), Disknn)
108 | Disknn = adj_normalized(Disknn)
109 |
110 | return Disknn
111 |
112 | def CalGamma(v):
113 | a = scipy.special.gamma((v + 1) / 2)
114 | b = np.sqrt(v * np.pi) * scipy.special.gamma(v / 2)
115 | out = a / b
116 | return out
117 |
118 |
119 | # def cal_norm(edge_index0, args, feat=None, cut=False, num_nodes=None):
120 | # # calculate normalization factors: (2*D)^{-1/2} or (D)^{-1/2}
121 | # edge_index0 = sp.coo_matrix(edge_index0)
122 | # values = edge_index0.data
123 | # indices = np.vstack((edge_index0.row, edge_index0.col))
124 | # edge_index0 = torch.LongTensor(indices).to(args.device)
125 |
126 | # edge_weight = torch.ones((edge_index0.size(1),), dtype=torch.float32, device=args.device)
127 | # edge_index, _ = add_remaining_self_loops(edge_index0, edge_weight, 0, args.N)
128 |
129 | # if num_nodes is None:
130 | # num_nodes = edge_index.max()+1
131 | # D = degree(edge_index[0], num_nodes) # edge_index[0] gives the out-degree; the graph is undirected here, so this equals the node degree
132 |
133 | # if cut:
134 | # D = torch.sqrt(1/D)
135 | # D[D == float("inf")] = 0.
136 | # edge_index = to_undirected(edge_index, num_nodes=num_nodes)
137 | # row, col = edge_index
138 | # mask = row
[... remainder of dataset.py truncated in this dump ...]
--------------------------------------------------------------------------------
/demo_bash.sh:
--------------------------------------------------------------------------------
[... beginning of demo_bash.sh truncated in this dump; its visible tail is ...]
&1 | tee ./log/$TRAINING_LOG
7 |
--------------------------------------------------------------------------------
/env.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | _openmp_mutex=5.1=1_gnu
6 | asttokens=2.2.1=pypi_0
7 | backcall=0.2.0=pypi_0
8 | blas=1.0=mkl
9 | brotlipy=0.7.0=py38h27cfd23_1003
10 | ca-certificates=2022.10.11=h06a4308_0
11 | certifi=2022.12.7=py38h06a4308_0
12 | cffi=1.15.1=py38h5eee18b_3
13 | charset-normalizer=2.0.4=pyhd3eb1b0_0
14 | contourpy=1.0.7=pypi_0
15 | cryptography=38.0.4=py38h9ce1e76_0
16 | cudatoolkit=11.3.1=h2bc3f7f_2
17 | cycler=0.11.0=pypi_0
18 | decorator=5.1.1=pypi_0
19 | dgl-cuda10.2=0.9.1post1=py38_0
20 | executing=1.2.0=pypi_0
21 | fftw=3.3.9=h27cfd23_1
22 | flit-core=3.6.0=pyhd3eb1b0_0
23 | fonttools=4.38.0=pypi_0
24 | gcl=0.6.11=pypi_0
25 | idna=3.4=py38h06a4308_0
26 | imageio=2.24.0=pypi_0
27 | intel-openmp=2021.4.0=h06a4308_3561
28 | ipdb=0.13.11=pypi_0
29 | ipython=8.8.0=pypi_0
30 | jedi=0.18.2=pypi_0
31 | jinja2=3.1.2=py38h06a4308_0
32 | joblib=1.1.1=py38h06a4308_0
33 | kiwisolver=1.4.4=pypi_0
34 | ld_impl_linux-64=2.38=h1181459_1
35 | libffi=3.4.2=h6a678d5_6
36 | libgcc-ng=11.2.0=h1234567_1
37 | libgfortran-ng=11.2.0=h00389a5_1
38 | libgfortran5=11.2.0=h1234567_1
39 | libgomp=11.2.0=h1234567_1
40 | libstdcxx-ng=11.2.0=h1234567_1
41 | littleutils=0.2.2=pypi_0
42 | markupsafe=2.1.1=py38h7f8727e_0
43 | matplotlib=3.6.3=pypi_0
44 | matplotlib-inline=0.1.6=pypi_0
45 | mkl=2021.4.0=h06a4308_640
46 | mkl-service=2.4.0=py38h7f8727e_0
47 | mkl_fft=1.3.1=py38hd3c417c_0
48 | mkl_random=1.2.2=py38h51133e4_0
49 | munkres=1.1.4=pypi_0
50 | ncurses=6.3=h5eee18b_3
51 | networkx=3.0=pypi_0
52 | numpy=1.23.5=py38h14f4228_0
53 | numpy-base=1.23.5=py38h31eccc5_0
54 | ogb=1.3.5=pypi_0
55 | openssl=1.1.1s=h7f8727e_0
56 | outdated=0.2.2=pypi_0
57 | packaging=23.0=pypi_0
58 | pandas=1.5.3=pypi_0
59 | parso=0.8.3=pypi_0
60 | pexpect=4.8.0=pypi_0
61 | pickleshare=0.7.5=pypi_0
62 | pillow=9.4.0=pypi_0
63 | pip=22.3.1=py38h06a4308_0
64 | prompt-toolkit=3.0.36=pypi_0
65 | psutil=5.9.0=py38h5eee18b_0
66 | ptyprocess=0.7.0=pypi_0
67 | pure-eval=0.2.2=pypi_0
68 | pycparser=2.21=pyhd3eb1b0_0
69 | pyg=2.2.0=py38_torch_1.12.0_cu113
70 | pygcl=0.1.2=pypi_0
71 | pygments=2.14.0=pypi_0
72 | pyopenssl=22.0.0=pyhd3eb1b0_0
73 | pyparsing=3.0.9=py38h06a4308_0
74 | pysocks=1.7.1=py38h06a4308_0
75 | python=3.8.16=h7a1cb2a_2
76 | python-dateutil=2.8.2=pypi_0
77 | pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0
78 | pytorch-cluster=1.6.0=py38_torch_1.12.0_cu113
79 | pytorch-mutex=1.0=cuda
80 | pytorch-scatter=2.1.0=py38_torch_1.12.0_cu113
81 | pytorch-sparse=0.6.16=py38_torch_1.12.0_cu113
82 | pytz=2022.7.1=pypi_0
83 | readline=8.2=h5eee18b_0
84 | requests=2.28.1=py38h06a4308_0
85 | scikit-learn=1.2.0=py38h6a678d5_0
86 | scipy=1.9.3=py38h14f4228_0
87 | seaborn=0.12.2=pypi_0
88 | setuptools=65.6.3=py38h06a4308_0
89 | six=1.16.0=pyhd3eb1b0_1
90 | sqlite=3.40.1=h5082296_0
91 | stack-data=0.6.2=pypi_0
92 | threadpoolctl=2.2.0=pyh0d69192_0
93 | tk=8.6.12=h1ccaba5_0
94 | tomli=2.0.1=pypi_0
95 | torchdiffeq=0.2.3=pypi_0
96 | tqdm=4.64.1=py38h06a4308_0
97 | traitlets=5.8.1=pypi_0
98 | typing_extensions=4.4.0=py38h06a4308_0
99 | urllib3=1.26.14=py38h06a4308_0
100 | wcwidth=0.2.6=pypi_0
101 | wheel=0.37.1=pyhd3eb1b0_0
102 | xz=5.2.10=h5eee18b_1
103 | zlib=1.2.13=h5eee18b_0
104 |
--------------------------------------------------------------------------------
/envPy37.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | _openmp_mutex=5.1=1_gnu
6 | ase=3.22.1=pypi_0
7 | backcall=0.2.0=pypi_0
8 | ca-certificates=2023.05.30=h06a4308_0
9 | certifi=2022.12.7=py37h06a4308_0
10 | charset-normalizer=3.2.0=pypi_0
11 | cycler=0.11.0=pypi_0
12 | decorator=5.1.1=pypi_0
13 | fonttools=4.38.0=pypi_0
14 | googledrivedownloader=0.4=pypi_0
15 | h5py=3.8.0=pypi_0
16 | idna=3.4=pypi_0
17 | imageio=2.31.1=pypi_0
18 | importlib-metadata=4.13.0=pypi_0
19 | ipdb=0.13.13=pypi_0
20 | ipython=7.34.0=pypi_0
21 | isodate=0.6.1=pypi_0
22 | jedi=0.19.0=pypi_0
23 | jinja2=3.1.2=pypi_0
24 | joblib=1.3.1=pypi_0
25 | keopscore=2.1.2=pypi_0
26 | kiwisolver=1.4.4=pypi_0
27 | ld_impl_linux-64=2.38=h1181459_1
28 | libffi=3.4.4=h6a678d5_0
29 | libgcc-ng=11.2.0=h1234567_1
30 | libgomp=11.2.0=h1234567_1
31 | libstdcxx-ng=11.2.0=h1234567_1
32 | littleutils=0.2.2=pypi_0
33 | llvmlite=0.39.1=pypi_0
34 | markupsafe=2.1.3=pypi_0
35 | matplotlib=3.5.3=pypi_0
36 | matplotlib-inline=0.1.6=pypi_0
37 | munkres=1.1.4=pypi_0
38 | ncurses=6.4=h6a678d5_0
39 | networkx=2.6.3=pypi_0
40 | numba=0.56.4=pypi_0
41 | numpy=1.21.6=pypi_0
42 | nvidia-cublas-cu11=11.10.3.66=pypi_0
43 | nvidia-cuda-nvrtc-cu11=11.7.99=pypi_0
44 | nvidia-cuda-runtime-cu11=11.7.99=pypi_0
45 | nvidia-cudnn-cu11=8.5.0.96=pypi_0
46 | ogb=1.3.6=pypi_0
47 | openssl=1.1.1u=h7f8727e_0
48 | outdated=0.2.2=pypi_0
49 | packaging=23.1=pypi_0
50 | pandas=1.3.5=pypi_0
51 | parso=0.8.3=pypi_0
52 | pexpect=4.8.0=pypi_0
53 | pickleshare=0.7.5=pypi_0
54 | pillow=9.5.0=pypi_0
55 | pip=22.3.1=py37h06a4308_0
56 | prompt-toolkit=3.0.39=pypi_0
57 | psutil=5.9.5=pypi_0
58 | ptyprocess=0.7.0=pypi_0
59 | pybind11=2.11.1=pypi_0
60 | pygments=2.15.1=pypi_0
61 | pykeops=2.1.2=pypi_0
62 | pyparsing=3.1.0=pypi_0
63 | python=3.7.16=h7a1cb2a_0
64 | python-dateutil=2.8.2=pypi_0
65 | python-louvain=0.16=pypi_0
66 | pytz=2023.3=pypi_0
67 | rdflib=6.3.2=pypi_0
68 | readline=8.2=h5eee18b_0
69 | requests=2.31.0=pypi_0
70 | scikit-learn=1.0.2=pypi_0
71 | scipy=1.7.3=pypi_0
72 | setuptools=65.6.3=py37h06a4308_0
73 | six=1.16.0=pypi_0
74 | sqlite=3.41.2=h5eee18b_0
75 | threadpoolctl=3.1.0=pypi_0
76 | tk=8.6.12=h1ccaba5_0
77 | tomli=2.0.1=pypi_0
78 | torch=1.8.1=pypi_0
79 | torch-cluster=1.5.9=pypi_0
80 | torch-geometric=1.7.0=pypi_0
81 | torch-scatter=2.1.1=pypi_0
82 | torch-sparse=0.6.9=pypi_0
83 | torch-spline-conv=1.2.2=pypi_0
84 | torchdiffeq=0.2.3=pypi_0
85 | tqdm=4.65.0=pypi_0
86 | traitlets=5.9.0=pypi_0
87 | typing-extensions=4.7.1=pypi_0
88 | urllib3=2.0.4=pypi_0
89 | wcwidth=0.2.6=pypi_0
90 | wheel=0.38.4=py37h06a4308_0
91 | xz=5.4.2=h5eee18b_0
92 | zipp=3.15.0=pypi_0
93 | zlib=1.2.13=h5eee18b_0
94 |
--------------------------------------------------------------------------------
/function_laplacian_diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch_sparse
4 | from torch_geometric.nn.conv import MessagePassing
5 | from ipdb import set_trace
6 |
7 | class ODEFunc(MessagePassing):
8 | # currently requires in_features = out_features
9 | def __init__(self, opt, data, device):
10 | super(ODEFunc, self).__init__()
11 | self.opt = opt
12 | self.device = device
13 | self.edge_index = None
14 | self.edge_weight = None
15 | self.attention_weights = None
16 | self.alpha_train = nn.Parameter(torch.tensor(0.0))
17 | self.beta_train = nn.Parameter(torch.tensor(0.0))
18 | self.x0 = None
19 | self.nfe = 0
20 | self.alpha_sc = nn.Parameter(torch.ones(1))
21 | self.beta_sc = nn.Parameter(torch.ones(1))
22 |
23 | def __repr__(self):
24 | return self.__class__.__name__
25 |
26 | # Define the ODE function.
27 | # Input:
28 | # --- t: A tensor with shape [], meaning the current time.
29 | # --- x: A tensor with shape [#batches, dims], meaning the value of x at t.
30 | # Output:
31 | # --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t.
32 | class LaplacianODEFunc(ODEFunc):
33 | # currently requires in_features = out_features
34 | def __init__(self, args, data, device):
35 | super(LaplacianODEFunc, self).__init__(args, data, device)
36 | self.args = args
37 |
38 | def forward(self, t, x): # the t param is needed by the ODE solver.
39 | self.nfe += 1
40 | ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x)
41 |
42 | alpha = torch.sigmoid(self.alpha_train)
43 |
44 |         f = alpha * (ax - x)  # dx/dt = alpha * (A x - x) = -alpha * L x: Laplacian diffusion
45 | if self.args.add_source:
46 | f = f + self.beta_train * self.x0
47 |
48 | return f
49 |
--------------------------------------------------------------------------------
/log/20230730_221042.log:
--------------------------------------------------------------------------------
1 | Namespace(add_source=True, beta=3.0, cut=False, dataname='cora', device='cuda:0', dropout=0.0, epochs=50, exp_lr=1e-05, exp_wd=1e-05, gamma=1, gpu=0, hid_dim=512, imp_lr=0.001, imp_wd=1e-05, knn=21, method='dopri5', n_layers=2, patience=20, seed=42, sigma=0.5, time=100.0, tol_scale=200, train=1, type='sys', v_input=1)
2 | ./best/cora_best.pt
3 | [75.14771049 74.29837518 74.77843427 74.29837518 75.18463811 74.29837518
4 | 75.11078287 75.11078287 75.07385524 75.14771049]
5 |
6 | Clustering result: ACC:74.84 $\pm$ 0.37 NMI:56.89 $\pm$ 0.39 ARI:52.64 $\pm$ 0.34 F1:70.50 $\pm$ 0.46
7 |
8 | Namespace(add_source=True, beta=7.0, cut=False, dataname='citeseer', device='cuda:0', dropout=0.0, epochs=50, exp_lr=1e-05, exp_wd=1e-05, gamma=1, gpu=0, hid_dim=512, imp_lr=0.001, imp_wd=1e-05, knn=150, method='dopri5', n_layers=2, patience=20, seed=42, sigma=0.5, time=160.0, tol_scale=200, train=1, type='sys', v_input=1)
9 | ./best/citeseer_best.pt
10 | [69.58220619 69.55214908 69.55214908 69.55214908 69.55214908 69.55214908
11 | 69.58220619 69.55214908 69.55214908 69.55214908]
12 |
13 | Clustering result: ACC:69.56 $\pm$ 0.01 NMI:44.29 $\pm$ 0.02 ARI:45.34 $\pm$ 0.00 F1:65.43 $\pm$ 0.04
14 |
15 | Namespace(add_source=True, beta=0.7, cut=False, dataname='bat', device='cuda:0', dropout=0.0, epochs=25, exp_lr=1e-05, exp_wd=1e-05, gamma=1, gpu=0, hid_dim=64, imp_lr=0.001, imp_wd=1e-05, knn=21, method='dopri5', n_layers=2, patience=20, seed=42, sigma=0.5, time=200.0, tol_scale=200, train=1, type='sys', v_input=1)
16 | ./best/bat_best.pt
17 | [74.04580153 74.80916031 74.80916031 74.80916031 74.80916031 74.80916031
18 | 74.80916031 74.80916031 74.80916031 74.80916031]
19 |
20 | Clustering result: ACC:74.73 $\pm$ 0.23 NMI:52.63 $\pm$ 0.11 ARI:47.65 $\pm$ 0.18 F1:74.49 $\pm$ 0.26
21 |
22 | Namespace(add_source=True, beta=6.0, cut=False, dataname='eat', device='cuda:0', dropout=0.0, epochs=30, exp_lr=1e-05, exp_wd=1e-05, gamma=1.5, gpu=0, hid_dim=512, imp_lr=0.001, imp_wd=1e-05, knn=155, method='dopri5', n_layers=2, patience=20, seed=42, sigma=0.5, time=15.0, tol_scale=200, train=1, type='sys', v_input=1)
23 | ./best/eat_best.pt
24 | [51.87969925 51.87969925 51.87969925 51.87969925 51.87969925 51.87969925
25 | 51.87969925 51.87969925 51.87969925 51.87969925]
26 |
27 | Clustering result: ACC:51.88 $\pm$ 0.00 NMI:32.44 $\pm$ 0.19 ARI:23.87 $\pm$ 0.08 F1:47.63 $\pm$ 0.02
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from scipy import io
2 | import math
3 | import argparse
4 | from time import *
5 | import torch
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | from torch.optim.lr_scheduler import MultiStepLR
9 | import warnings
10 | import numpy as np
11 |
12 | from model import SCDGN as Net
13 | from dataset import *
14 | from task import *
15 | from utilis import *
16 |
17 | from ipdb import set_trace
18 |
19 | warnings.filterwarnings('ignore')
20 |
21 |
22 | if __name__ == '__main__':
23 | # import faulthandler
24 | # faulthandler.enable()
25 | parser = argparse.ArgumentParser(description='ICML')
26 |
27 | parser.add_argument('--dataname', type=str, default='cora', help='Name of dataset.')
28 | parser.add_argument('--gpu', type=int, default=0, help='GPU index.')
29 | parser.add_argument('--seed', type=int, default=42, help='Random seed.')
30 | parser.add_argument('--epochs', type=int, default=40, help='Training epochs.')
31 | parser.add_argument('--patience', type=int, default=20, help='Patience for early stop.')
32 | parser.add_argument('--train', type=int, default=1, help='Train or not.')
33 |
34 | # Dataset args
35 |     parser.add_argument('--cut', type=str, default=False, help='Use the (2D)^{-1/2} cut normalization instead of D^{-1/2}.') ##
36 | parser.add_argument('--type', type=str,default='sys',help='sys or rw') ##
37 | parser.add_argument('--knn', type=int, default=25, help='The K of KNN graph.') ##
38 | parser.add_argument('--v_input', type=int, default=1, help='Degree of freedom of T distribution') ##
39 | parser.add_argument('--sigma', type=float, default=0.5, help='Weight parameters for knn.') ##
40 |
41 | # Optimizer args
42 |     parser.add_argument('--imp_lr', type=float, default=1e-3, help='Learning rate for the implicit (diffusion) parameters.') ##
43 |     parser.add_argument('--exp_lr', type=float, default=1e-5, help='Learning rate for the explicit (attention/extractor) parameters.') ##
44 |     parser.add_argument('--imp_wd', type=float, default=1e-5, help='Weight decay for the implicit parameters.')
45 |     parser.add_argument('--exp_wd', type=float, default=1e-5, help='Weight decay for the explicit parameters.')
46 |
47 | # GNN args
48 | parser.add_argument("--hid_dim", type=int, default=512, help='Hidden layer dim.') ##
49 | parser.add_argument('--time', type=float, default=18, help='End time of ODE integrator.') ##
50 | parser.add_argument('--method', type=str, default='dopri5', help="set the numerical solver: dopri5, euler, rk4, midpoint")
51 | parser.add_argument('--tol_scale', type=float, default=200, help='tol_scale .') ##
52 | parser.add_argument('--add_source', type=str, default=True, help='Add source.')
53 | parser.add_argument('--dropout', type=float, default=0., help='drop rate.') ##
54 | parser.add_argument('--n_layers', type=int, default=2, help='number of Linear.') ##
55 |
56 | # Loss args
57 |     parser.add_argument('--beta', type=float, default=1, help='Weight of the negative-sample loss term.')
58 |     parser.add_argument('--gamma', type=float, default=1, help='Weight of the redundancy-reduction loss term.')
59 |
60 | args = parser.parse_args()
61 |
62 | # check cuda
63 | if args.gpu != -1 and torch.cuda.is_available():
64 | args.device = 'cuda:{}'.format(args.gpu)
65 | else:
66 | args.device = 'cpu'
67 | set_seed(args.seed)
68 | print(args)
69 | begin_time = time()
70 |
71 | # Load data
72 | if args.dataname in ['amap', 'bat', 'eat', 'uat','corafull']:
73 | feat, label, A = load_data(args.dataname)
74 | labels = torch.from_numpy(label)
75 | else:
76 | data = io.loadmat('./data/{}.mat'.format(args.dataname))
77 | if args.dataname == 'wiki':
78 | feat = data['fea'].todense()
79 | A = data['W'].todense()
80 | elif args.dataname == 'pubmed':
81 | feat = data['fea']
82 | A = data['W'].todense()
83 | else:
84 | feat = data['fea']
85 | A = np.mat(data['W'])
86 | gnd = data['gnd'].T - 1
87 | labels = torch.from_numpy(gnd[0, :])
88 |
89 | feat = torch.from_numpy(feat).type(torch.FloatTensor)
90 | in_dim = feat.shape[1]
91 | args.N = N = feat.shape[0]
92 | norm_factor, edge_index, edge_weight, adj_norm, knn, Lap = cal_norm(A, args, feat)
93 | Lap_Neg = cal_Neg(adj_norm, knn, args)
94 | feat = feat.to(args.device)
95 |
96 |
97 | # Initial
98 | model = Net(N, edge_index, edge_weight, args).to(args.device)
99 | optimizer = optim.Adam([{'params':model.params_imp,'weight_decay':args.imp_wd, 'lr': args.imp_lr},
100 | {'params':model.params_exp,'weight_decay':args.exp_wd, 'lr': args.exp_lr}])
101 |
102 |
103 | checkpt_file = './best/'+args.dataname+'_best.pt'
104 | print(checkpt_file)
105 | if args.train:
106 | cnt_wait = 0
107 | best_loss = 1e9
108 | best_epoch = 0
109 | best_acc = 0
110 | EYE = torch.eye(args.N).to(args.device)
111 | for epoch in range(1,args.epochs+1):
112 | model.train()
113 | optimizer.zero_grad()
114 |
115 | emb = model(knn, adj_norm, norm_factor)
116 |         loss = (torch.trace(torch.mm(torch.mm(emb.t(), Lap), emb))                    # positive pairs: smoothness over the Laplacian
117 |                 - args.beta*(torch.trace(torch.mm(torch.mm(emb.t(), Lap_Neg), emb)))  # negative pairs: separation, weighted by beta
118 |                 + args.gamma*nn.MSELoss()(torch.mm(emb,emb.t()), EYE))/args.N         # redundancy reduction: push emb @ emb.T toward I
119 |
120 | loss.backward()
121 | optimizer.step()
122 |
123 | if loss <= best_loss:
124 | best_loss = loss
125 | best_epoch = epoch
126 | cnt_wait = 0
127 | # acc, nmi, ari, f1 = clustering(emb.cpu().detach(), labels)
128 | # print(style.YELLOW + '\nClustering result: ACC:%1.2f || NMI:%1.2f || RI:%1.2f || F-score:%1.2f '%(acc, nmi, ari, f1))
129 | torch.save(model.state_dict(), checkpt_file)
130 | else:
131 | cnt_wait += 1
132 | if cnt_wait == args.patience or math.isnan(loss):
133 | print('\nEarly stopping!', end='')
134 | break
135 |
136 | # print(style.MAGENTA + '\r\rEpoch={:03d}, loss={:.4f}'.format(epoch, loss.item()), end=' ')
137 | # print('')
138 |
139 | model.load_state_dict(torch.load(checkpt_file))
140 | model.eval()
141 | with torch.no_grad():
142 | emb = model(knn, adj_norm, norm_factor)
143 |
144 | # Clustering
145 | emb = emb.cpu().detach().numpy()
146 | # TSNE_plot(emb, labels, args.dataname)
147 | clustering(emb, labels,args)
148 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchdiffeq import odeint
5 | from torch.nn import Parameter
6 |
7 |
8 | from ipdb import set_trace
9 | from utilis import glorot_init
10 | from function_laplacian_diffusion import *
11 |
12 | class AGCN(nn.Module):
13 | def __init__(self, num_nodes):
14 | super(AGCN, self).__init__()
15 | self.n = num_nodes
16 | self.w1 = Parameter(torch.FloatTensor(self.n,self.n))
17 | self.w1.data = torch.eye(self.n)
18 | self.w2 = Parameter(torch.FloatTensor(self.n,self.n))
19 | self.w2.data = torch.eye(self.n)
20 |
21 | def forward(self, X, A):
22 |         H = torch.mm(torch.mm(A, self.w1), A.T)  # structural self-attention over the graph
23 | H = torch.mm(torch.mm(H, self.w2), X)
24 | embed = torch.mm(H, H.T)
25 | embed = F.normalize(embed, dim=1)
26 |
27 | return embed
28 |
29 |
30 | class GCN(nn.Module):
31 | def __init__(self, input_dim, activation = F.relu, **kwargs):
32 | super(GCN, self).__init__(**kwargs)
33 | self.weight = glorot_init(input_dim, input_dim)
34 | self.activation = activation
35 |
36 | def forward(self, x, adj):
37 | x = torch.mm(x, self.weight)
38 | x = torch.mm(adj, x)
39 | outputs = self.activation(x)
40 | return outputs
41 |
42 |
43 | class ConstantODEblock(nn.Module):
44 | def __init__(self, args, edge_index, edge_weight):
45 | super(ConstantODEblock, self).__init__()
46 | self.args = args
47 | self.t = torch.tensor([0, args.time]).to(args.device)
48 |
49 | self.odefunc = LaplacianODEFunc(args, edge_index, args.device)
50 |
51 | self.odefunc.edge_index = edge_index.to(args.device)
52 | self.odefunc.edge_weight = edge_weight.to(args.device)
53 |
54 | self.train_integrator = odeint
55 | self.test_integrator = odeint
56 | self.atol = args.tol_scale * 1e-7 # Absolute tolerance.
57 | self.rtol = args.tol_scale * 1e-9 # Relative tolerance
58 |
59 |     def set_x0(self, x0):  # set the initial condition
60 | self.odefunc.x0 = x0.clone().detach()
61 |
62 | def forward(self, x):
63 |         t = self.t.type_as(x)  # total integration time span
64 |         integrator = self.train_integrator if self.training else self.test_integrator  # odeint integrator
65 |
66 | func = self.odefunc
67 | state = x
68 |
69 | state_dt = integrator(
70 | func, state, t,
71 | method= self.args.method,
72 | options = dict(step_size=1, max_iters=100),
73 | atol=self.atol, # Absolute tolerance.
74 | rtol=self.rtol) # Relative tolerance.
75 |
76 | z = state_dt[1]
77 | return z
78 |
79 |
80 | class SCDGN(nn.Module):
81 | def __init__(self, N, edge_index, edge_weight, args):
82 | super().__init__()
83 | self.edge_weight = edge_weight
84 | self.edge_index = edge_index
85 | self.n_layers = args.n_layers
86 |
87 | self.AttenGCN = AGCN(N)
88 |
89 | self.extractor = nn.ModuleList()
90 | self.extractor.append(nn.Linear(N, args.hid_dim))
91 | for i in range(self.n_layers - 1):
92 | self.extractor.append(nn.Linear(args.hid_dim, args.hid_dim))
93 | self.dropout = nn.Dropout(p=args.dropout)
94 |
95 | self.diffusion = ConstantODEblock(args, edge_index,edge_weight )
96 |
97 | self.init_weights()
98 |
99 | self.params_imp = list(self.diffusion.parameters())
100 | self.params_exp = list(self.AttenGCN.parameters()) \
101 | + list(self.extractor.parameters())
102 |
103 | def init_weights(self):
104 | for m in self.modules():
105 | if isinstance(m, nn.Linear):
106 | nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
107 | if m.bias is not None:
108 | nn.init.zeros_(m.bias)
109 |
110 | def forward(self, knn, adj, norm_factor):
111 |         # jointly learn structure and feature information
112 |         h = self.AttenGCN(knn, adj)
113 |
114 | for i, layer in enumerate(self.extractor):
115 | if i != 0:
116 | h = self.dropout(h)
117 | h = layer(h)
118 |
119 |         # implicit diffusion
120 |         self.diffusion.set_x0(h)  # set the initial state to h
121 | new_z = self.diffusion(h)
122 | # z = norm_factor * new_z + h
123 |         # z = F.tanh(norm_factor * new_z + h)  # sum of the input features (initial value) and the equilibrium state (the flux: how much information flows between two nodes)
124 | z = F.relu(norm_factor * new_z + h)
125 | z = (z - z.mean(0)) / z.std()
126 |
127 | return z
128 |
129 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python main.py --dataname cora --beta=3 --epochs=50 --hid_dim=512 --knn=21 --time=100
4 | # ACC:74.79$\pm$0.38 NMI:56.86$\pm$0.42 ARI:52.61$\pm$0.33 F1:70.42$\pm$0.48
5 |
6 | python main.py --dataname citeseer --beta=7 --epochs=50 --hid_dim=512 --knn=150 --time=160
7 | # ACC:69.62$\pm$0.02 NMI:44.35$\pm$0.03 ARI:45.43$\pm$0.02 F1:65.50$\pm$0.06
8 |
9 | # python main.py --dataname amap --epochs=20 --knn=19 --hid_dim=512 --time=10.0 --beta=5.0
10 | # ACC:78.91$\pm$0.00 NMI:72.53$\pm$0.02 ARI:63.41$\pm$0.01 F1:75.27$\pm$0.01
11 |
12 | python main.py --dataname bat --beta 0.7 --epochs=25 --hid_dim=64 --knn=21 --time=200.0
13 | # ACC:74.73 $\pm$ 0.23 NMI:52.63 $\pm$ 0.11 ARI:47.65 $\pm$ 0.18 F1:74.49 $\pm$ 0.26
14 |
15 | python main.py --dataname eat --beta=6 --gamma=1.5 --epochs=30 --hid_dim=512 --knn=155 --time=15
16 | # ACC:51.88 $\pm$ 0.00 NMI:32.49 $\pm$ 0.21 ARI:23.86 $\pm$ 0.08 F1:47.62 $\pm$ 0.02
17 |
--------------------------------------------------------------------------------
/task.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import matplotlib.pyplot as plt
7 | from time import *
8 | import os
9 | import imageio
10 | from sklearn.decomposition import PCA
11 | from sklearn.manifold import TSNE
12 |
13 | import numpy as np
14 | from sklearn import metrics
15 | from munkres import Munkres
16 | from sklearn.cluster import KMeans
17 | from scipy.optimize import linear_sum_assignment
18 | from utilis import style
19 |
20 | from ipdb import set_trace
21 |
22 |
23 | class GIFPloter():
24 | def __init__(self, ):
25 | self.path_list = []
26 |
27 | def PlotOtherLayer(self,fig,data,label,title='',fig_position0=1,fig_position1=1,fig_position2=1,s=0.1,graph=None,link=None,):
28 | color_list = []
29 | for i in range(label.shape[0]):
30 | color_list.append(int(label[i]))
31 |
32 | if data.shape[1] > 3:
33 | pca = PCA(n_components=2)
34 | data_em = pca.fit_transform(data)
35 | else:
36 | data_em = data
37 |
38 | # data_em = data_em-data_em.mean(axis=0)
39 |
40 | if data_em.shape[1] == 3:
41 | ax = fig.add_subplot(fig_position0, fig_position1, fig_position2, projection='3d')
42 |
43 | ax.scatter(data_em[:, 0], data_em[:, 1], data_em[:, 2], c=color_list, s=s, cmap='rainbow')
44 |
45 | if data_em.shape[1] == 2:
46 | ax = fig.add_subplot(fig_position0, fig_position1, fig_position2)
47 |
48 | if graph is not None:
49 | self.PlotGraph(data, graph, link)
50 |
51 | s = ax.scatter(data_em[:, 0], data_em[:, 1], c=label, s=s, cmap='rainbow')
52 | plt.axis('equal')
53 |         if False:  # legend rendering disabled
54 | list_i_n = len(set(label.tolist()))
55 | # print(list_i_n)
56 | legend1 = ax.legend(*s.legend_elements(num=list_i_n),
57 | loc="upper left",
58 | title="Ranking")
59 | ax.add_artist(legend1)
60 | # ax.spines['top'].set_visible(False)
61 | # ax.spines['right'].set_visible(False)
62 | # ax.spines['bottom'].set_visible(False)
63 | # ax.spines['left'].set_visible(False)
64 | # plt.xticks([])
65 | # plt.yticks([])
66 | # plt.title(title)
67 |
68 | def AddNewFig(self,latent,label,link=None,graph=None,his_loss=None,title_='',path='./',dataset=None):
69 | fig = plt.figure(figsize=(5, 5))
70 |
71 | if latent.shape[0] <= 1000: s=3
72 | elif latent.shape[0] <= 10000: s = 1
73 | else: s = 0.1
74 |
75 | # if latent.shape[1] <= 3:
76 | self.PlotOtherLayer(fig, latent, label, title=title_, fig_position0=1, fig_position1=1, fig_position2=1, graph=graph, link=link, s=s)
77 | plt.tight_layout()
78 | path_c = path + title_
79 |
80 | self.path_list.append(path_c)
81 |
82 | plt.savefig(path_c, dpi=100)
83 | plt.close()
84 |
85 | def PlotGraph(self, latent, graph, link):
86 | for i in range(graph.shape[0]):
87 | for j in range(graph.shape[0]):
88 | if graph[i, j] == True:
89 | p1 = latent[i]
90 | p2 = latent[j]
91 | lik = link[i, j]
92 | plt.plot([p1[0], p2[0]], [p1[1], p2[1]],
93 | 'gray',
94 | lw=1 / lik)
95 | if lik > link.min() * 1.01:
96 | plt.text((p1[0] + p2[0]) / 2, (p1[1] + p2[1]) / 2,
97 | str(lik)[:4],
98 | fontsize=5)
99 |
100 | def SaveGIF(self):
101 | gif_images = []
102 | for i, path_ in enumerate(self.path_list):
103 | gif_images.append(imageio.imread(path_))
104 | if i > 0 and i < len(self.path_list)-2:
105 | os.remove(path_)
106 | imageio.mimsave(path_[:-4] + ".gif", gif_images, fps=3)
107 |
108 | def TSNE_plot(X, label, str):
109 | em = TSNE(n_components=2,random_state=6).fit_transform(X)
110 | ploter = GIFPloter()
111 | ploter.AddNewFig(em, label, title_= str+".png", path='./figure/',)
112 |
113 |
114 | # Clustering metrics
115 | def spectral(W, k):
116 | """
117 |     SPECTRAL: spectral clustering
118 | :param W: Adjacency matrix, N-by-N matrix
119 | :param k: number of clusters
120 | :return: data point cluster labels, n-by-1 vector.
121 | """
122 | w_sum = np.array(W.sum(axis=1)).reshape(-1)
123 | D = np.diag(w_sum)
124 | _D = np.diag((w_sum + np.finfo(float).eps)** (-1 / 2))
125 | L = D - W
126 | L = _D @ L @ _D
127 | eigval, eigvec = np.linalg.eig(L)
128 | eigval_argsort = eigval.real.astype(np.float32).argsort()
129 | F = np.take(eigvec.real.astype(np.float32), eigval_argsort[:k], axis=-1)
130 | idx = KMeans(n_clusters=k).fit(F).labels_
131 | return idx
132 |
133 | def bestMap(L1,L2):
134 | '''
135 |     bestmap: permute the labels of L2 to best match L1
136 | INPUT:
137 | L1: labels of L1, shape of (N,) vector
138 | L2: labels of L2, shape of (N,) vector
139 | OUTPUT:
140 | new_L2: best matched permuted L2, shape of (N,) vector
141 | version 1.0 --December/2018
142 | Modified from bestMap.m (written by Deng Cai)
143 | '''
144 | if L1.shape[0] != L2.shape[0] or len(L1.shape) > 1 or len(L2.shape) > 1:
145 | raise Exception('L1 shape must equal L2 shape')
146 | return
147 | Label1 = np.unique(L1)
148 | nClass1 = Label1.shape[0]
149 | Label2 = np.unique(L2)
150 | nClass2 = Label2.shape[0]
151 | nClass = max(nClass1,nClass2)
152 | G = np.zeros((nClass, nClass))
153 | for i in range(nClass1):
154 | for j in range(nClass2):
155 | G[j,i] = np.sum((np.logical_and(L1 == Label1[i], L2 == Label2[j])).astype(np.int64))
156 | c,t = linear_sum_assignment(-G)
157 | newL2 = np.zeros(L2.shape)
158 | for i in range(nClass2):
159 | newL2[L2 == Label2[i]] = Label1[t[i]]
160 | return newL2
161 |
162 | def clustering_metrics(true_label, pred_label):
163 | l1 = list(set(true_label))
164 | numclass1 = len(l1)
165 |
166 | l2 = list(set(pred_label))
167 | numclass2 = len(l2)
168 | if numclass1 != numclass2:
169 | print('Class Not equal, Error!!!!')
170 | return 0, 0, 0, 0, 0
171 |
172 | cost = np.zeros((numclass1, numclass2), dtype=int)
173 | for i, c1 in enumerate(l1):
174 | mps = [i1 for i1, e1 in enumerate(true_label) if e1 == c1]
175 | for j, c2 in enumerate(l2):
176 | mps_d = [i1 for i1 in mps if pred_label[i1] == c2]
177 | cost[i][j] = len(mps_d)
178 |
179 | # match two clustering results by Munkres algorithm
180 | m = Munkres()
181 | cost = cost.__neg__().tolist()
182 |
183 | indexes = m.compute(cost)
184 | idx = indexes[2][1]
185 | # get the match results
186 | new_predict = np.zeros(len(pred_label))
187 | for i, c in enumerate(l1):
188 |         # corresponding label in l2:
189 | c2 = l2[indexes[i][1]]
190 | # ai is the index with label==c2 in the predict list
191 | ai = [ind for ind, elm in enumerate(pred_label) if elm == c2]
192 | new_predict[ai] = c
193 |
194 | acc = metrics.accuracy_score(true_label, new_predict)
195 | f1_macro = metrics.f1_score(true_label, new_predict, average='macro')
196 | nmi = metrics.normalized_mutual_info_score(true_label, pred_label)
197 | ari = metrics.adjusted_rand_score(true_label, pred_label)
198 |
199 | return acc* 100, f1_macro* 100, nmi* 100, ari* 100, idx
200 |
201 | def clustering(embeds,labels,args):
202 | # labels = torch.from_numpy(labels).type(torch.LongTensor)
203 | num_classes = torch.max(labels).item()+1
204 |
205 | # print('=================================== KMeans Clustering. ========================================')
206 | # u, s, v = sp.linalg.svds(embeds, k=num_classes, which='LM')
207 | # predY = KMeans(n_clusters=num_classes).fit(u).labels_
208 |
209 | accs = []
210 | nmis = []
211 | aris = []
212 | f1s = []
213 | for i in range(10):
214 | # best_loss = 1e9
215 | best_acc = 0
216 | best_f1 = 0
217 | best_nmi = 0
218 | best_ari = 0
219 | for j in range(10):
220 | # set_trace()
221 | predY = KMeans(n_clusters=num_classes).fit(embeds).labels_
222 | gnd_Y = bestMap(predY, labels.cpu().detach().numpy())
223 | # predY_temp = torch.tensor(predY, dtype=torch.float)
224 | # gnd_Y_temp = torch.tensor(gnd_Y)
225 | # loss = nn.MSELoss()(predY_temp, gnd_Y_temp)
226 | # if loss <= best_loss:
227 | # best_loss = loss
228 | # acc_temp, f1_temp, nmi_temp, ari_temp, _ = clustering_metrics(gnd_Y, predY)
229 |
230 | acc_temp, f1_temp, nmi_temp, ari_temp, _ = clustering_metrics(gnd_Y, predY)
231 | if acc_temp > best_acc:
232 | best_acc = acc_temp
233 | best_f1 = f1_temp
234 | best_nmi = nmi_temp
235 | best_ari = ari_temp
236 |
237 | accs.append(best_acc)
238 | nmis.append(best_nmi)
239 | aris.append(best_ari)
240 | f1s.append(best_f1)
241 |
242 | # accs.append(acc_temp)
243 | # nmis.append(nmi_temp)
244 | # aris.append(ari_temp)
245 | # f1s.append(f1_temp)
246 |
247 | accs = np.stack(accs)
248 | nmis = np.stack(nmis)
249 | aris = np.stack(aris)
250 | f1s = np.stack(f1s)
251 | print(accs)
252 | print(style.YELLOW + '\nClustering result: ACC:{:.2f}'.format(accs.mean().item()),'$\pm$','{:.2f}'.format(accs.std().item()),\
253 | 'NMI:{:.2f}'.format(nmis.mean().item()),'$\pm$','{:.2f}'.format(nmis.std().item()),\
254 | 'ARI:{:.2f}'.format(aris.mean().item()),'$\pm$','{:.2f}'.format(aris.std().item()),\
255 | 'F1:{:.2f}'.format(f1s.mean().item()),'$\pm$','{:.2f}'.format(f1s.std().item()),)
256 |
257 |
--------------------------------------------------------------------------------
/utilis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import torch.nn as nn
5 |
6 | def set_seed(seed):
7 | """
8 |     set random seeds so that results are reproducible
9 | Args:
10 | seed: random seed
11 | Returns: None
12 | """
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 | torch.backends.cudnn.benchmark = False
19 | torch.backends.cudnn.deterministic = True
20 |
21 | class style():
22 | BLACK = '\033[30m'
23 | RED = '\033[31m'
24 | GREEN = '\033[32m'
25 | YELLOW = '\033[33m'
26 | BLUE = '\033[34m'
27 | MAGENTA = '\033[35m'
28 | CYAN = '\033[36m'
29 | WHITE = '\033[37m'
30 | UNDERLINE = '\033[4m'
31 | RESET = '\033[0m'
32 |
33 |
34 | def glorot_init(input_dim, output_dim):
35 | init_range = np.sqrt(6.0/(input_dim + output_dim))
36 | initial = torch.rand(input_dim, output_dim)*2*init_range - init_range
37 | return nn.Parameter(initial)
38 |
--------------------------------------------------------------------------------