├── README.md
├── data
│   ├── bat
│   │   ├── bat_adj.npy
│   │   ├── bat_feat.npy
│   │   └── bat_label.npy
│   ├── citeseer.mat
│   ├── cora.mat
│   └── eat
│       ├── eat_adj.npy
│       ├── eat_feat.npy
│       └── eat_label.npy
├── dataset.py
├── demo_bash.sh
├── env.txt
├── envPy37.txt
├── function_laplacian_diffusion.py
├── log
│   └── 20230730_221042.log
├── main.py
├── model.py
├── run.sh
├── task.py
└── utilis.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Self-Contrastive Graph Diffusion Network
Contrastive learning has proven to be a successful approach to graph self-supervised learning. Augmentation techniques and sampling strategies are crucial in contrastive learning, but in most existing works the augmentation techniques require careful design, and the sampling strategies capture only a small amount of intrinsic supervision information. In addition, existing methods require complex designs to obtain two different representations of the data. To overcome these limitations, we propose a novel framework called the Self-Contrastive Graph Diffusion Network (SCGDN). Our framework consists of two main components: the Attentional Module (AttM) and the Diffusion Module (DiFM). AttM aggregates higher-order structure and feature information to obtain an expressive embedding, while DiFM balances the state of each node in the graph through Laplacian diffusion learning and allows adjacency and feature information in the graph to evolve cooperatively. Unlike existing methodologies, SCGDN is an augmentation-free approach that avoids "sampling bias" and semantic drift, without the need for pre-training. We sample high-quality positive and negative pairs based on structure and feature information: if two nodes are neighbors, they are treated as positive samples of each other, and if two disconnected nodes are also unrelated on the $k$NN graph, they are treated as negative samples of each other. The contrastive objective makes full use of these sampling strategies, and a redundancy-reduction term minimizes redundant information in the embedding while retaining the discriminative information. Within this framework, the graph self-contrastive learning paradigm proves to be powerful: SCGDN effectively balances preserving high-order structure information against overfitting, and the results show that SCGDN consistently outperforms both contrastive and classical methods.
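
As a quick illustration, the sampling rule and the training objective can be summarized in a short sketch (a simplified reading of `main.py`; the function names and the dense-matrix formulation here are illustrative, not the exact implementation, which works with the normalized Laplacians built by `cal_norm`/`cal_Neg` in `dataset.py`):

```python
import numpy as np
import torch
import torch.nn as nn

def sample_pairs(adj: np.ndarray, knn: np.ndarray):
    """Sampling rule: neighbors are positives; pairs that are neither
    adjacent nor related on the kNN graph are negatives."""
    positive = adj > 0
    negative = (adj == 0) & (knn == 0)
    np.fill_diagonal(negative, False)  # a node is never its own negative
    return positive, negative

def scgdn_loss(emb, lap_pos, lap_neg, beta, gamma):
    """Objective from main.py: pull positive pairs together (trace term on
    the positive Laplacian), push negative pairs apart (negative-Laplacian
    term), and reduce redundancy by driving the Gram matrix toward the
    identity."""
    n = emb.shape[0]
    eye = torch.eye(n, device=emb.device)
    pos = torch.trace(emb.t() @ lap_pos @ emb)
    neg = torch.trace(emb.t() @ lap_neg @ emb)
    red = nn.MSELoss()(emb @ emb.t(), eye)
    return (pos - beta * neg + gamma * red) / n
```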

## Requirements
> Dependencies (Python >= 3.7); the main dependencies are:
```
torch==1.8.1
torch-cluster==1.5.9
torch-geometric==1.7.0
torch-scatter==2.0.6
torch-sparse==0.6.9
torch-spline-conv==1.2.1
torchdiffeq==0.2.1
```
> Commands to install all the dependencies in a new conda environment:
```bash
conda create --name grand python=3.7
conda activate grand

pip install ogb pykeops
pip install torch==1.8.1
pip install torchdiffeq -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.1+cu102.html

# or
conda create --name grand --file env.txt
```

Create the directories used for model checkpoints (`./best`) and training logs (`./log`):

```bash
mkdir best log
```

```bash
python main.py --dataname cora --beta=3 --epochs=50 --hid_dim=512 --knn=21 --time=100
# ACC:74.79$\pm$0.38 NMI:56.86$\pm$0.42 ARI:52.61$\pm$0.33 F1:70.42$\pm$0.48

python main.py --dataname citeseer --beta=7 --epochs=50 --hid_dim=512 --knn=150 --time=160
# ACC:69.62$\pm$0.02 NMI:44.35$\pm$0.03 ARI:45.43$\pm$0.02 F1:65.50$\pm$0.06

python main.py --dataname amap --epochs=20 --knn=19 --hid_dim=512 --time=10.0 --beta=5.0
# ACC:78.91$\pm$0.00 NMI:72.53$\pm$0.02 ARI:63.41$\pm$0.01 F1:75.27$\pm$0.01

python main.py --dataname bat --beta 0.7 --epochs=25 --hid_dim=64 --knn=21 --time=200.0
# ACC:74.73$\pm$0.23 NMI:52.63$\pm$0.11 ARI:47.65$\pm$0.18 F1:74.49$\pm$0.26

python main.py --dataname eat --beta=6 --gamma=1.5 --epochs=30 --hid_dim=512 --knn=155 --time=15
# ACC:51.88$\pm$0.00 NMI:32.49$\pm$0.21 ARI:23.86$\pm$0.08 F1:47.62$\pm$0.02

python main.py --dataname corafull --beta=2 --epochs=12 --hid_dim=1024 --knn=73 --time=5 --gpu=-1
```

## Citation

If you find this project useful, please consider citing:

```bibtex
@InProceedings{MaMM2023,
  author    = {Yixuan Ma and Kun Zhan},
  booktitle = {Proceedings of the 31st ACM International Conference on Multimedia},
  title     = {Self-contrastive graph diffusion network},
  year      = {2023},
}
```

## Contact
https://kunzhan.github.io/

If you have any questions, feel free to contact me.
(Email: `ice.echo#gmail.com`)
--------------------------------------------------------------------------------
/data/bat/bat_adj.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/bat/bat_adj.npy
--------------------------------------------------------------------------------
/data/bat/bat_feat.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/bat/bat_feat.npy
--------------------------------------------------------------------------------
/data/bat/bat_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/bat/bat_label.npy
--------------------------------------------------------------------------------
/data/citeseer.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/citeseer.mat
--------------------------------------------------------------------------------
/data/cora.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/cora.mat
--------------------------------------------------------------------------------
/data/eat/eat_adj.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/eat/eat_adj.npy
--------------------------------------------------------------------------------
/data/eat/eat_feat.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/eat/eat_feat.npy
--------------------------------------------------------------------------------
/data/eat/eat_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/SCDGN/988e352c96a3a5bc97ccb1c7344913c78ac840e3/data/eat/eat_label.npy
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from sklearn.metrics import pairwise
import scipy
import scipy.sparse as sp
from torch_scatter import scatter_add
from torch_geometric.utils import to_undirected, to_scipy_sparse_matrix, degree, add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes

from ipdb import set_trace

def load_data(dataset_name, show_details=False):
    load_path = "./data/" + dataset_name + "/" + dataset_name
    feat = np.load(load_path+"_feat.npy", allow_pickle=True)
    label = np.load(load_path+"_label.npy", allow_pickle=True)
    adj = np.load(load_path+"_adj.npy", allow_pickle=True)
    if show_details:
        print("++++++++++++++++++++++++++++++")
        print("---details of graph dataset---")
        print("++++++++++++++++++++++++++++++")
        print("dataset name: ", dataset_name)
        print("feature shape: ", feat.shape)
        print("label shape: ", label.shape)
        print("adj shape: ", adj.shape)
        print("undirected edge num: ", int(np.nonzero(adj)[0].shape[0]/2))
        print("category num: ", max(label)-min(label)+1)
        print("category distribution: ")
        for i in range(max(label)+1):
            print("label", i, end=":")
            print(len(label[np.where(label == i)]))
        print("++++++++++++++++++++++++++++++")

    return feat, label, adj


def get_rw_adj(edge_index, norm_dim=1, fill_value=0., num_nodes=None, type='sys'):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_weight = torch.ones((edge_index.size(1),), dtype=torch.float32, device=edge_index.device)

    if not fill_value == 0:
        edge_index, tmp_edge_weight = add_remaining_self_loops(edge_index, edge_weight, fill_value, num_nodes)
        assert tmp_edge_weight is not None
        edge_weight = tmp_edge_weight

    row, col = edge_index[0], edge_index[1]
    indices = row if norm_dim == 0 else col
    deg = scatter_add(edge_weight, indices, dim=0, dim_size=num_nodes)
    # deg_inv_sqrt = deg.pow_(-1)
    # edge_weight = deg_inv_sqrt[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv_sqrt[indices]

    if type=='sys':
        deg_inv_sqrt = deg.pow_(-0.5)
        edge_weight = deg_inv_sqrt[indices] * edge_weight * deg_inv_sqrt[indices]
    else:
        deg_inv_sqrt = deg.pow_(-1)
        edge_weight = deg_inv_sqrt[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv_sqrt[indices]
    return edge_index, edge_weight

def adj_normalized(adj, type='sys'):
    row_sum = torch.sum(adj, dim=1)
    row_sum = (row_sum==0)*1+row_sum
    if type=='sys':
        d_inv_sqrt = torch.pow(row_sum, -0.5).flatten()
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
        return d_mat_inv_sqrt.mm(adj).mm(d_mat_inv_sqrt)
    else:
        d_inv = torch.pow(row_sum, -1).flatten()
        d_inv[torch.isinf(d_inv)] = 0.
        d_mat_inv = torch.diag(d_inv)
        return d_mat_inv.mm(adj)

def FeatureNormalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    rowsum = (rowsum == 0) * 1 + rowsum  # guard against all-zero rows
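    # Note (added for clarity): rows that sum to zero are mapped to 1 above so
    # that np.power(rowsum, -1) below stays finite; those all-zero rows are
    # then left unchanged by the normalization.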
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def compute_knn(args, features, distribution='t-distribution'):
    features = FeatureNormalize(features)
    # Dis = pairwise.euclidean_distances(self.data, self.data)
    # Dis = pairwise.manhattan_distances(self.data, self.data)
    # Dis = pairwise.haversine_distances(self.data, self.data)
    Dis = pairwise.cosine_distances(features, features)
    Dis = Dis/np.max(np.max(Dis, 1))
    if distribution=='t-distribution':
        gamma = CalGamma(args.v_input)
        sim = gamma * np.sqrt(2 * np.pi) * np.power((1 + args.sigma*np.power(Dis,2) / args.v_input), -1 * (args.v_input + 1) / 2)
    else:
        sim = np.exp(-Dis/(args.sigma**2))

    K = args.knn
    if K>0:
        idx = sim.argsort()[:,::-1]
        sim_new = np.zeros_like(sim)
        for ii in range(0, len(sim_new)):
            sim_new[ii, idx[ii,0:K]] = sim[ii, idx[ii,0:K]]
        Disknn = (sim_new + sim_new.T)/2
    else:
        Disknn = (sim + sim.T)/2

    Disknn = torch.from_numpy(Disknn).type(torch.FloatTensor)
    Disknn = torch.add(torch.eye(Disknn.shape[0]), Disknn)
    Disknn = adj_normalized(Disknn)

    return Disknn

def CalGamma(v):
    a = scipy.special.gamma((v + 1) / 2)
    b = np.sqrt(v * np.pi) * scipy.special.gamma(v / 2)
    out = a / b
    return out


# def cal_norm(edge_index0, args, feat=None, cut=False, num_nodes=None):
#     # calculate normalization factors: (2*D)^{-1/2} or (D)^{-1/2}
#     edge_index0 = sp.coo_matrix(edge_index0)
#     values = edge_index0.data
#     indices = np.vstack((edge_index0.row, edge_index0.col))
#     edge_index0 = torch.LongTensor(indices).to(args.device)

#     edge_weight = torch.ones((edge_index0.size(1),), dtype=torch.float32, device=args.device)
#     edge_index, _ = add_remaining_self_loops(edge_index0, edge_weight, 0, args.N)

#     if num_nodes is None:
#         num_nodes = edge_index.max()+1
#     D = degree(edge_index[0], num_nodes)  # pass edge_index[0] to compute node out-degrees; the graph here is undirected, so this is simply the node degree

#     if cut:
#         D = torch.sqrt(1/D)
#         D[D == float("inf")] = 0.
#     edge_index = to_undirected(edge_index, num_nodes=num_nodes)
#     row, col = edge_index
#     mask = row
--------------------------------------------------------------------------------
/demo_bash.sh:
--------------------------------------------------------------------------------
[truncated in this dump; only the tail of the final command survives]
&1 | tee ./log/$TRAINING_LOG
--------------------------------------------------------------------------------
/env.txt:
--------------------------------------------------------------------------------
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
asttokens=2.2.1=pypi_0
backcall=0.2.0=pypi_0
blas=1.0=mkl
brotlipy=0.7.0=py38h27cfd23_1003
ca-certificates=2022.10.11=h06a4308_0
certifi=2022.12.7=py38h06a4308_0
cffi=1.15.1=py38h5eee18b_3
charset-normalizer=2.0.4=pyhd3eb1b0_0
contourpy=1.0.7=pypi_0
cryptography=38.0.4=py38h9ce1e76_0
cudatoolkit=11.3.1=h2bc3f7f_2
cycler=0.11.0=pypi_0
decorator=5.1.1=pypi_0
dgl-cuda10.2=0.9.1post1=py38_0
executing=1.2.0=pypi_0
fftw=3.3.9=h27cfd23_1
flit-core=3.6.0=pyhd3eb1b0_0
fonttools=4.38.0=pypi_0
gcl=0.6.11=pypi_0
idna=3.4=py38h06a4308_0
imageio=2.24.0=pypi_0
intel-openmp=2021.4.0=h06a4308_3561
ipdb=0.13.11=pypi_0
ipython=8.8.0=pypi_0
jedi=0.18.2=pypi_0
jinja2=3.1.2=py38h06a4308_0
joblib=1.1.1=py38h06a4308_0
kiwisolver=1.4.4=pypi_0
ld_impl_linux-64=2.38=h1181459_1
libffi=3.4.2=h6a678d5_6
libgcc-ng=11.2.0=h1234567_1
libgfortran-ng=11.2.0=h00389a5_1
libgfortran5=11.2.0=h1234567_1
libgomp=11.2.0=h1234567_1
libstdcxx-ng=11.2.0=h1234567_1
littleutils=0.2.2=pypi_0
markupsafe=2.1.1=py38h7f8727e_0
matplotlib=3.6.3=pypi_0
matplotlib-inline=0.1.6=pypi_0
mkl=2021.4.0=h06a4308_640
mkl-service=2.4.0=py38h7f8727e_0
mkl_fft=1.3.1=py38hd3c417c_0
mkl_random=1.2.2=py38h51133e4_0
munkres=1.1.4=pypi_0
ncurses=6.3=h5eee18b_3
networkx=3.0=pypi_0
numpy=1.23.5=py38h14f4228_0
numpy-base=1.23.5=py38h31eccc5_0
ogb=1.3.5=pypi_0
openssl=1.1.1s=h7f8727e_0
outdated=0.2.2=pypi_0
packaging=23.0=pypi_0
pandas=1.5.3=pypi_0
parso=0.8.3=pypi_0
pexpect=4.8.0=pypi_0
pickleshare=0.7.5=pypi_0
pillow=9.4.0=pypi_0
pip=22.3.1=py38h06a4308_0
prompt-toolkit=3.0.36=pypi_0
psutil=5.9.0=py38h5eee18b_0
ptyprocess=0.7.0=pypi_0
pure-eval=0.2.2=pypi_0
pycparser=2.21=pyhd3eb1b0_0
pyg=2.2.0=py38_torch_1.12.0_cu113
pygcl=0.1.2=pypi_0
pygments=2.14.0=pypi_0
pyopenssl=22.0.0=pyhd3eb1b0_0
pyparsing=3.0.9=py38h06a4308_0
pysocks=1.7.1=py38h06a4308_0
python=3.8.16=h7a1cb2a_2
python-dateutil=2.8.2=pypi_0
pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0
pytorch-cluster=1.6.0=py38_torch_1.12.0_cu113
pytorch-mutex=1.0=cuda
pytorch-scatter=2.1.0=py38_torch_1.12.0_cu113
pytorch-sparse=0.6.16=py38_torch_1.12.0_cu113
pytz=2022.7.1=pypi_0
readline=8.2=h5eee18b_0
requests=2.28.1=py38h06a4308_0
scikit-learn=1.2.0=py38h6a678d5_0
scipy=1.9.3=py38h14f4228_0
seaborn=0.12.2=pypi_0
setuptools=65.6.3=py38h06a4308_0
six=1.16.0=pyhd3eb1b0_1
sqlite=3.40.1=h5082296_0
stack-data=0.6.2=pypi_0
threadpoolctl=2.2.0=pyh0d69192_0
tk=8.6.12=h1ccaba5_0
tomli=2.0.1=pypi_0
torchdiffeq=0.2.3=pypi_0
tqdm=4.64.1=py38h06a4308_0
traitlets=5.8.1=pypi_0
typing_extensions=4.4.0=py38h06a4308_0
urllib3=1.26.14=py38h06a4308_0
wcwidth=0.2.6=pypi_0
wheel=0.37.1=pyhd3eb1b0_0
xz=5.2.10=h5eee18b_1
zlib=1.2.13=h5eee18b_0
--------------------------------------------------------------------------------
/envPy37.txt:
--------------------------------------------------------------------------------
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
ase=3.22.1=pypi_0
backcall=0.2.0=pypi_0
ca-certificates=2023.05.30=h06a4308_0
certifi=2022.12.7=py37h06a4308_0
charset-normalizer=3.2.0=pypi_0
cycler=0.11.0=pypi_0
decorator=5.1.1=pypi_0
fonttools=4.38.0=pypi_0
googledrivedownloader=0.4=pypi_0
h5py=3.8.0=pypi_0
idna=3.4=pypi_0
imageio=2.31.1=pypi_0
importlib-metadata=4.13.0=pypi_0
ipdb=0.13.13=pypi_0
ipython=7.34.0=pypi_0
isodate=0.6.1=pypi_0
jedi=0.19.0=pypi_0
jinja2=3.1.2=pypi_0
joblib=1.3.1=pypi_0
keopscore=2.1.2=pypi_0
kiwisolver=1.4.4=pypi_0
ld_impl_linux-64=2.38=h1181459_1
libffi=3.4.4=h6a678d5_0
libgcc-ng=11.2.0=h1234567_1
libgomp=11.2.0=h1234567_1
libstdcxx-ng=11.2.0=h1234567_1
littleutils=0.2.2=pypi_0
llvmlite=0.39.1=pypi_0
markupsafe=2.1.3=pypi_0
matplotlib=3.5.3=pypi_0
matplotlib-inline=0.1.6=pypi_0
munkres=1.1.4=pypi_0
ncurses=6.4=h6a678d5_0
networkx=2.6.3=pypi_0
numba=0.56.4=pypi_0
numpy=1.21.6=pypi_0
nvidia-cublas-cu11=11.10.3.66=pypi_0
nvidia-cuda-nvrtc-cu11=11.7.99=pypi_0
nvidia-cuda-runtime-cu11=11.7.99=pypi_0
nvidia-cudnn-cu11=8.5.0.96=pypi_0
ogb=1.3.6=pypi_0
openssl=1.1.1u=h7f8727e_0
outdated=0.2.2=pypi_0
packaging=23.1=pypi_0
pandas=1.3.5=pypi_0
parso=0.8.3=pypi_0
pexpect=4.8.0=pypi_0
pickleshare=0.7.5=pypi_0
pillow=9.5.0=pypi_0
pip=22.3.1=py37h06a4308_0
prompt-toolkit=3.0.39=pypi_0
psutil=5.9.5=pypi_0
ptyprocess=0.7.0=pypi_0
pybind11=2.11.1=pypi_0
pygments=2.15.1=pypi_0
pykeops=2.1.2=pypi_0
pyparsing=3.1.0=pypi_0
python=3.7.16=h7a1cb2a_0
python-dateutil=2.8.2=pypi_0
python-louvain=0.16=pypi_0
pytz=2023.3=pypi_0
rdflib=6.3.2=pypi_0
readline=8.2=h5eee18b_0
requests=2.31.0=pypi_0
scikit-learn=1.0.2=pypi_0
scipy=1.7.3=pypi_0
setuptools=65.6.3=py37h06a4308_0
six=1.16.0=pypi_0
sqlite=3.41.2=h5eee18b_0
threadpoolctl=3.1.0=pypi_0
tk=8.6.12=h1ccaba5_0
tomli=2.0.1=pypi_0
torch=1.8.1=pypi_0
torch-cluster=1.5.9=pypi_0
torch-geometric=1.7.0=pypi_0
torch-scatter=2.1.1=pypi_0
torch-sparse=0.6.9=pypi_0
torch-spline-conv=1.2.2=pypi_0
torchdiffeq=0.2.3=pypi_0
tqdm=4.65.0=pypi_0
traitlets=5.9.0=pypi_0
typing-extensions=4.7.1=pypi_0
urllib3=2.0.4=pypi_0
wcwidth=0.2.6=pypi_0
wheel=0.38.4=py37h06a4308_0
xz=5.4.2=h5eee18b_0
zipp=3.15.0=pypi_0
zlib=1.2.13=h5eee18b_0
--------------------------------------------------------------------------------
/function_laplacian_diffusion.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch_sparse
from torch_geometric.nn.conv import MessagePassing
from ipdb import set_trace

class ODEFunc(MessagePassing):
    # currently requires in_features = out_features
    def __init__(self, opt, data, device):
        super(ODEFunc, self).__init__()
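        # Note (added for clarity): alpha_train and beta_train below are
        # learnable scalars; LaplacianODEFunc.forward uses sigmoid(alpha_train)
        # to scale the diffusion term and beta_train to weight the source x0.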
        self.opt = opt
        self.device = device
        self.edge_index = None
        self.edge_weight = None
        self.attention_weights = None
        self.alpha_train = nn.Parameter(torch.tensor(0.0))
        self.beta_train = nn.Parameter(torch.tensor(0.0))
        self.x0 = None
        self.nfe = 0
        self.alpha_sc = nn.Parameter(torch.ones(1))
        self.beta_sc = nn.Parameter(torch.ones(1))

    def __repr__(self):
        return self.__class__.__name__

# Define the ODE function.
# Input:
# --- t: A tensor with shape [], meaning the current time.
# --- x: A tensor with shape [#batches, dims], meaning the value of x at t.
# Output:
# --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t.
class LaplacianODEFunc(ODEFunc):
    # currently requires in_features = out_features
    def __init__(self, args, data, device):
        super(LaplacianODEFunc, self).__init__(args, data, device)
        self.args = args

    def forward(self, t, x):  # the t param is needed by the ODE solver.
        self.nfe += 1
        ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x)

        alpha = torch.sigmoid(self.alpha_train)

        f = alpha * (ax - x)
        if self.args.add_source:
            f = f + self.beta_train * self.x0

        return f
--------------------------------------------------------------------------------
/log/20230730_221042.log:
--------------------------------------------------------------------------------
Namespace(add_source=True, beta=3.0, cut=False, dataname='cora', device='cuda:0', dropout=0.0, epochs=50, exp_lr=1e-05, exp_wd=1e-05, gamma=1, gpu=0, hid_dim=512, imp_lr=0.001, imp_wd=1e-05, knn=21, method='dopri5', n_layers=2, patience=20, seed=42, sigma=0.5, time=100.0, tol_scale=200, train=1, type='sys', v_input=1)
./best/cora_best.pt
[75.14771049 74.29837518 74.77843427 74.29837518 75.18463811 74.29837518
 75.11078287 75.11078287 75.07385524 75.14771049]

Clustering result: ACC:74.84 $\pm$ 0.37 NMI:56.89 $\pm$ 0.39 ARI:52.64 $\pm$ 0.34 F1:70.50 $\pm$ 0.46

Namespace(add_source=True, beta=7.0, cut=False, dataname='citeseer', device='cuda:0', dropout=0.0, epochs=50, exp_lr=1e-05, exp_wd=1e-05, gamma=1, gpu=0, hid_dim=512, imp_lr=0.001, imp_wd=1e-05, knn=150, method='dopri5', n_layers=2, patience=20, seed=42, sigma=0.5, time=160.0, tol_scale=200, train=1, type='sys', v_input=1)
./best/citeseer_best.pt
[69.58220619 69.55214908 69.55214908 69.55214908 69.55214908 69.55214908
 69.58220619 69.55214908 69.55214908 69.55214908]

Clustering result: ACC:69.56 $\pm$ 0.01 NMI:44.29 $\pm$ 0.02 ARI:45.34 $\pm$ 0.00 F1:65.43 $\pm$ 0.04

Namespace(add_source=True, beta=0.7, cut=False, dataname='bat', device='cuda:0', dropout=0.0, epochs=25, exp_lr=1e-05, exp_wd=1e-05, gamma=1, gpu=0, hid_dim=64, imp_lr=0.001, imp_wd=1e-05, knn=21, method='dopri5', n_layers=2, patience=20, seed=42, sigma=0.5, time=200.0, tol_scale=200, train=1, type='sys', v_input=1)
./best/bat_best.pt
[74.04580153 74.80916031 74.80916031 74.80916031 74.80916031 74.80916031
 74.80916031 74.80916031 74.80916031 74.80916031]

Clustering result: ACC:74.73 $\pm$ 0.23 NMI:52.63 $\pm$ 0.11 ARI:47.65 $\pm$ 0.18 F1:74.49 $\pm$ 0.26

Namespace(add_source=True, beta=6.0, cut=False, dataname='eat', device='cuda:0', dropout=0.0, epochs=30, exp_lr=1e-05, exp_wd=1e-05, gamma=1.5, gpu=0, hid_dim=512, imp_lr=0.001, imp_wd=1e-05, knn=155, method='dopri5', n_layers=2, patience=20, seed=42, sigma=0.5, time=15.0, tol_scale=200, train=1, type='sys', v_input=1)
./best/eat_best.pt
[51.87969925 51.87969925 51.87969925 51.87969925 51.87969925 51.87969925
 51.87969925 51.87969925 51.87969925 51.87969925]

Clustering result: ACC:51.88 $\pm$ 0.00 NMI:32.44 $\pm$ 0.19 ARI:23.87 $\pm$ 0.08 F1:47.63 $\pm$ 0.02
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from scipy import io
import math
import argparse
from time import *
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
import warnings
import numpy as np

from model import SCDGN as Net
from dataset import *
from task import *
from utilis import *

from ipdb import set_trace

warnings.filterwarnings('ignore')


if __name__ == '__main__':
    # import faulthandler
    # faulthandler.enable()
    parser = argparse.ArgumentParser(description='SCDGN')

    parser.add_argument('--dataname', type=str, default='cora', help='Name of dataset.')
    parser.add_argument('--gpu', type=int, default=0, help='GPU index.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    parser.add_argument('--epochs', type=int, default=40, help='Training epochs.')
    parser.add_argument('--patience', type=int, default=20, help='Patience for early stop.')
    parser.add_argument('--train', type=int, default=1, help='Train or not.')

    # Dataset args
    parser.add_argument('--cut', type=str, default=False, help='The type of degree.')  ##
    parser.add_argument('--type', type=str, default='sys', help='sys or rw')  ##
    parser.add_argument('--knn', type=int, default=25, help='The K of KNN graph.')  ##
    parser.add_argument('--v_input', type=int, default=1, help='Degree of freedom of T distribution')  ##
    parser.add_argument('--sigma', type=float, default=0.5, help='Weight parameters for knn.')  ##

    # Optimizer args
    parser.add_argument('--imp_lr', type=float, default=1e-3, help='Learning rate of the implicit (diffusion) parameters.')  ##
    parser.add_argument('--exp_lr', type=float, default=1e-5, help='Learning rate of the explicit (AttM + extractor) parameters.')  ##
    parser.add_argument('--imp_wd', type=float, default=1e-5, help='Weight decay of the implicit parameters.')
    parser.add_argument('--exp_wd', type=float, default=1e-5, help='Weight decay of the explicit parameters.')

    # GNN args
    parser.add_argument("--hid_dim", type=int, default=512, help='Hidden layer dim.')  ##
    parser.add_argument('--time', type=float, default=18, help='End time of ODE integrator.')  ##
    parser.add_argument('--method', type=str, default='dopri5', help="set the numerical solver: dopri5, euler, rk4, midpoint")
    parser.add_argument('--tol_scale', type=float, default=200, help='Scale for the solver tolerances.')  ##
    parser.add_argument('--add_source', type=str, default=True, help='Add source.')
    parser.add_argument('--dropout', type=float, default=0., help='drop rate.')  ##
    parser.add_argument('--n_layers', type=int, default=2, help='number of Linear layers.')  ##

    # Loss args
    parser.add_argument('--beta', type=float, default=1, help='Weight parameter for the negative-pair term of the loss.')
    parser.add_argument('--gamma', type=float, default=1, help='Weight parameter for the redundancy-reduction term.')

    args = parser.parse_args()

    # check cuda
    if args.gpu != -1 and torch.cuda.is_available():
        args.device = 'cuda:{}'.format(args.gpu)
    else:
        args.device = 'cpu'
    set_seed(args.seed)
    print(args)
    begin_time = time()

    # Load data
    if args.dataname in ['amap', 'bat', 'eat', 'uat', 'corafull']:
        feat, label, A = load_data(args.dataname)
        labels = torch.from_numpy(label)
    else:
        data = io.loadmat('./data/{}.mat'.format(args.dataname))
        if args.dataname == 'wiki':
            feat = data['fea'].todense()
            A = data['W'].todense()
        elif args.dataname == 'pubmed':
            feat = data['fea']
            A = data['W'].todense()
        else:
            feat = data['fea']
            A = np.mat(data['W'])
        gnd = data['gnd'].T - 1
        labels = torch.from_numpy(gnd[0, :])

    feat = torch.from_numpy(feat).type(torch.FloatTensor)
    in_dim = feat.shape[1]
    args.N = N = feat.shape[0]
    norm_factor, edge_index, edge_weight, adj_norm, knn, Lap = cal_norm(A, args, feat)
    Lap_Neg = cal_Neg(adj_norm, knn, args)
    feat = feat.to(args.device)


    # Initial
    model = Net(N, edge_index, edge_weight, args).to(args.device)
    optimizer = optim.Adam([{'params': model.params_imp, 'weight_decay': args.imp_wd, 'lr': args.imp_lr},
                            {'params': model.params_exp, 'weight_decay': args.exp_wd, 'lr': args.exp_lr}])


    checkpt_file = './best/'+args.dataname+'_best.pt'
    print(checkpt_file)
    if args.train:
        cnt_wait = 0
        best_loss = 1e9
        best_epoch = 0
        best_acc = 0
        EYE = torch.eye(args.N).to(args.device)
        for epoch in range(1, args.epochs+1):
            model.train()
            optimizer.zero_grad()

            emb = model(knn, adj_norm, norm_factor)
            loss = (torch.trace(torch.mm(torch.mm(emb.t(), Lap), emb)) \
                    - args.beta*(torch.trace(torch.mm(torch.mm(emb.t(), Lap_Neg), emb))) \
                    + args.gamma*nn.MSELoss()(torch.mm(emb, emb.t()), EYE))/args.N

            loss.backward()
            optimizer.step()

            if loss <= best_loss:
                best_loss = loss
                best_epoch = epoch
                cnt_wait = 0
                # acc, nmi, ari, f1 = clustering(emb.cpu().detach(), labels)
                # print(style.YELLOW + '\nClustering result: ACC:%1.2f || NMI:%1.2f || RI:%1.2f || F-score:%1.2f '%(acc, nmi, ari, f1))
                torch.save(model.state_dict(), checkpt_file)
            else:
                cnt_wait += 1
                if cnt_wait == args.patience or math.isnan(loss):
                    print('\nEarly stopping!', end='')
                    break

            # print(style.MAGENTA + '\r\rEpoch={:03d}, loss={:.4f}'.format(epoch, loss.item()), end=' ')
            # print('')

    model.load_state_dict(torch.load(checkpt_file))
    model.eval()
    with torch.no_grad():
        emb = model(knn, adj_norm, norm_factor)

    # Clustering
    emb = emb.cpu().detach().numpy()
    # TSNE_plot(emb, labels, args.dataname)
    clustering(emb, labels, args)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint
from torch.nn import Parameter


from ipdb import set_trace
from utilis import glorot_init
from function_laplacian_diffusion import *

class AGCN(nn.Module):
    def __init__(self, num_nodes):
        super(AGCN, self).__init__()
        self.n = num_nodes
        self.w1 = Parameter(torch.FloatTensor(self.n, self.n))
        self.w1.data = torch.eye(self.n)
        self.w2 = Parameter(torch.FloatTensor(self.n, self.n))
        self.w2.data = torch.eye(self.n)

    def forward(self, X, A):
        H = torch.mm(torch.mm(A, self.w1), A.T)  # self-attention over the graph structure
        H = torch.mm(torch.mm(H, self.w2), X)
        embed = torch.mm(H, H.T)
        embed = F.normalize(embed, dim=1)

        return embed


class GCN(nn.Module):
    def __init__(self, input_dim, activation=F.relu, **kwargs):
        super(GCN, self).__init__(**kwargs)
        self.weight = glorot_init(input_dim, input_dim)
        self.activation = activation

    def forward(self, x, adj):
        x = torch.mm(x, self.weight)
        x = torch.mm(adj, x)
        outputs = self.activation(x)
        return outputs


class ConstantODEblock(nn.Module):
    def __init__(self, args, edge_index, edge_weight):
        super(ConstantODEblock, self).__init__()
        self.args = args
        self.t = torch.tensor([0, args.time]).to(args.device)

        self.odefunc = LaplacianODEFunc(args, edge_index, args.device)

        self.odefunc.edge_index = edge_index.to(args.device)
        self.odefunc.edge_weight = edge_weight.to(args.device)

        self.train_integrator = odeint
        self.test_integrator = odeint
        self.atol = args.tol_scale * 1e-7  # Absolute tolerance.
        self.rtol = args.tol_scale * 1e-9  # Relative tolerance.

    def set_x0(self, x0):  ## set the initial condition
        self.odefunc.x0 = x0.clone().detach()

    def forward(self, x):
        t = self.t.type_as(x)  # total integration time t
        integrator = self.train_integrator if self.training else self.test_integrator  # the odeint solver

        func = self.odefunc
        state = x

        state_dt = integrator(
            func, state, t,
            method=self.args.method,
            options=dict(step_size=1, max_iters=100),
            atol=self.atol,  # Absolute tolerance.
            rtol=self.rtol)  # Relative tolerance.
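        # Note (added for clarity): odeint returns one state per time point in
        # t = [0, args.time]; state_dt[1] below is therefore the diffused state
        # at the end time args.time (state_dt[0] is the initial state x).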
        z = state_dt[1]
        return z


class SCDGN(nn.Module):
    def __init__(self, N, edge_index, edge_weight, args):
        super().__init__()
        self.edge_weight = edge_weight
        self.edge_index = edge_index
        self.n_layers = args.n_layers

        self.AttenGCN = AGCN(N)

        self.extractor = nn.ModuleList()
        self.extractor.append(nn.Linear(N, args.hid_dim))
        for i in range(self.n_layers - 1):
            self.extractor.append(nn.Linear(args.hid_dim, args.hid_dim))
        self.dropout = nn.Dropout(p=args.dropout)

        self.diffusion = ConstantODEblock(args, edge_index, edge_weight)

        self.init_weights()

        self.params_imp = list(self.diffusion.parameters())
        self.params_exp = list(self.AttenGCN.parameters()) \
                          + list(self.extractor.parameters())

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, knn, adj, norm_factor):
        # jointly learn structure and features
        h = self.AttenGCN(knn, adj)

        for i, layer in enumerate(self.extractor):
            if i != 0:
                h = self.dropout(h)
            h = layer(h)

        # implicit diffusion
        self.diffusion.set_x0(h)  # set the initial condition to h
        new_z = self.diffusion(h)
        # z = norm_factor * new_z + h
        # z = F.tanh(norm_factor * new_z + h)  # sum of the input features (initial value) and the equilibrium state (the flux: how much information flows between two nodes)
        z = F.relu(norm_factor * new_z + h)
        z = (z - z.mean(0)) / z.std()

        return z

--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash

python main.py --dataname cora --beta=3 --epochs=50 --hid_dim=512 --knn=21 --time=100
# ACC:74.79$\pm$0.38 NMI:56.86$\pm$0.42 ARI:52.61$\pm$0.33 F1:70.42$\pm$0.48

python main.py --dataname citeseer --beta=7 --epochs=50 --hid_dim=512 --knn=150 --time=160
# ACC:69.62$\pm$0.02 NMI:44.35$\pm$0.03 ARI:45.43$\pm$0.02 F1:65.50$\pm$0.06

# python main.py --dataname amap --epochs=20 --knn=19 --hid_dim=512 --time=10.0 --beta=5.0
# ACC:78.91$\pm$0.00 NMI:72.53$\pm$0.02 ARI:63.41$\pm$0.01 F1:75.27$\pm$0.01

python main.py --dataname bat --beta 0.7 --epochs=25 --hid_dim=64 --knn=21 --time=200.0
# ACC:74.73$\pm$0.23 NMI:52.63$\pm$0.11 ARI:47.65$\pm$0.18 F1:74.49$\pm$0.26

python main.py --dataname eat --beta=6 --gamma=1.5 --epochs=30 --hid_dim=512 --knn=155 --time=15
# ACC:51.88$\pm$0.00 NMI:32.49$\pm$0.21 ARI:23.86$\pm$0.08 F1:47.62$\pm$0.02
--------------------------------------------------------------------------------
/task.py:
--------------------------------------------------------------------------------


import torch
import torch.nn as nn

import matplotlib.pyplot as plt
from time import *
import os
import imageio
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import numpy as np
from sklearn import metrics
from munkres import Munkres
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from utilis import style

from ipdb import set_trace


class GIFPloter():
    def __init__(self):
        self.path_list = []

    def PlotOtherLayer(self, fig, data, label, title='', fig_position0=1, fig_position1=1, fig_position2=1, s=0.1, graph=None, link=None):
        color_list = []
        for i in range(label.shape[0]):
            color_list.append(int(label[i]))

        if data.shape[1] > 3:
            pca = PCA(n_components=2)
            data_em = pca.fit_transform(data)
        else:
            data_em = data

        # data_em = data_em-data_em.mean(axis=0)

        if data_em.shape[1] == 3:
            ax = fig.add_subplot(fig_position0, fig_position1, fig_position2, projection='3d')
            ax.scatter(data_em[:, 0], data_em[:, 1], data_em[:, 2], c=color_list, s=s, cmap='rainbow')

        if data_em.shape[1] == 2:
            ax = fig.add_subplot(fig_position0, fig_position1, fig_position2)

            if graph is not None:
                self.PlotGraph(data, graph, link)

            s = ax.scatter(data_em[:, 0], data_em[:, 1], c=label, s=s, cmap='rainbow')
            plt.axis('equal')
            if None:
                list_i_n = len(set(label.tolist()))
                # print(list_i_n)
                legend1 = ax.legend(*s.legend_elements(num=list_i_n),
                                    loc="upper left",
                                    title="Ranking")
                ax.add_artist(legend1)
        # ax.spines['top'].set_visible(False)
        # ax.spines['right'].set_visible(False)
        # ax.spines['bottom'].set_visible(False)
        # ax.spines['left'].set_visible(False)
        # plt.xticks([])
        # plt.yticks([])
        # plt.title(title)

    def AddNewFig(self, latent, label, link=None, graph=None, his_loss=None, title_='', path='./', dataset=None):
        fig = plt.figure(figsize=(5, 5))

        if latent.shape[0] <= 1000: s = 3
        elif latent.shape[0] <= 10000: s = 1
        else: s = 0.1

        # if latent.shape[1] <= 3:
        self.PlotOtherLayer(fig, latent, label, title=title_, fig_position0=1, fig_position1=1, fig_position2=1, graph=graph, link=link, s=s)
        plt.tight_layout()
        path_c = path + title_

        self.path_list.append(path_c)

        plt.savefig(path_c, dpi=100)
        plt.close()

    def PlotGraph(self, latent, graph, link):
        for i in range(graph.shape[0]):
            for j in range(graph.shape[0]):
                if graph[i, j] == True:
                    p1 = latent[i]
                    p2 = latent[j]
                    lik = link[i, j]
                    plt.plot([p1[0], p2[0]], [p1[1], p2[1]],
                             'gray',
                             lw=1 / lik)
                    if lik > link.min() * 1.01:
                        plt.text((p1[0] + p2[0]) / 2, (p1[1] + p2[1]) / 2,
                                 str(lik)[:4],
                                 fontsize=5)

    def SaveGIF(self):
        gif_images = []
        for i, path_ in enumerate(self.path_list):
            gif_images.append(imageio.imread(path_))
            if i > 0 and i < len(self.path_list)-2:
                os.remove(path_)
        imageio.mimsave(path_[:-4] + ".gif", gif_images, fps=3)

def TSNE_plot(X, label, str):
    em = TSNE(n_components=2, random_state=6).fit_transform(X)
    ploter = GIFPloter()
    ploter.AddNewFig(em, label, title_=str+".png", path='./figure/')


# Clustering metrics
def spectral(W, k):
    """
    SPECTRAL spectral clustering
    :param W: Adjacency matrix, N-by-N matrix
    :param k: number of clusters
    :return: data point cluster labels, n-by-1 vector.
    """
    w_sum = np.array(W.sum(axis=1)).reshape(-1)
    D = np.diag(w_sum)
    _D = np.diag((w_sum + np.finfo(float).eps) ** (-1 / 2))
    L = D - W
    L = _D @ L @ _D
    eigval, eigvec = np.linalg.eig(L)
    eigval_argsort = eigval.real.astype(np.float32).argsort()
    F = np.take(eigvec.real.astype(np.float32), eigval_argsort[:k], axis=-1)
    idx = KMeans(n_clusters=k).fit(F).labels_
    return idx

def bestMap(L1, L2):
    '''
    bestMap: permute the labels of L2 to match L1 as well as possible
    INPUT:
        L1: labels of L1, shape of (N,) vector
        L2: labels of L2, shape of (N,) vector
    OUTPUT:
        new_L2: best matched permuted L2, shape of (N,) vector
    version 1.0 --December/2018
    Modified from bestMap.m (written by Deng Cai)
    '''
    if L1.shape[0] != L2.shape[0] or len(L1.shape) > 1 or len(L2.shape) > 1:
        raise Exception('L1 shape must equal L2 shape')
    Label1 = np.unique(L1)
    nClass1 = Label1.shape[0]
    Label2 = np.unique(L2)
    nClass2 = Label2.shape[0]
    nClass = max(nClass1, nClass2)
    G = np.zeros((nClass, nClass))
    for i in range(nClass1):
        for j in range(nClass2):
            G[j, i] = np.sum((np.logical_and(L1 == Label1[i], L2 == Label2[j])).astype(np.int64))
    c, t = linear_sum_assignment(-G)
    newL2 = np.zeros(L2.shape)
    for i in range(nClass2):
        newL2[L2 == Label2[i]] = Label1[t[i]]
    return newL2

def clustering_metrics(true_label, pred_label):
    l1 = list(set(true_label))
    numclass1 = len(l1)

    l2 = list(set(pred_label))
    numclass2 = len(l2)
    if numclass1 != numclass2:
        print('Class Not equal, Error!!!!')
        return 0, 0, 0, 0, 0

    cost = np.zeros((numclass1, numclass2), dtype=int)
    for i, c1 in enumerate(l1):
        mps = [i1 for i1, e1 in enumerate(true_label) if e1 == c1]
        for j, c2 in enumerate(l2):
            mps_d = [i1 for i1 in mps if pred_label[i1] == c2]
            cost[i][j] = len(mps_d)

    # match two clustering results by the Munkres algorithm
    m = Munkres()
    cost = cost.__neg__().tolist()

    indexes = m.compute(cost)
    idx = indexes[2][1]
    # get the match results
    new_predict = np.zeros(len(pred_label))
    for i, c in enumerate(l1):
        # corresponding label in l2:
        c2 = l2[indexes[i][1]]
        # ai is the index with label==c2 in the predict list
        ai = [ind for ind, elm in enumerate(pred_label) if elm == c2]
        new_predict[ai] = c

    acc = metrics.accuracy_score(true_label, new_predict)
    f1_macro = metrics.f1_score(true_label, new_predict, average='macro')
    nmi = metrics.normalized_mutual_info_score(true_label, pred_label)
    ari = metrics.adjusted_rand_score(true_label, pred_label)

    return acc*100, f1_macro*100, nmi*100, ari*100, idx

def clustering(embeds, labels, args):
    # labels = torch.from_numpy(labels).type(torch.LongTensor)
    num_classes = torch.max(labels).item()+1

    # print('=================================== KMeans Clustering. ========================================')
    # u, s, v = sp.linalg.svds(embeds, k=num_classes, which='LM')
    # predY = KMeans(n_clusters=num_classes).fit(u).labels_

    accs = []
    nmis = []
    aris = []
    f1s = []
    for i in range(10):
        # best_loss = 1e9
        best_acc = 0
        best_f1 = 0
        best_nmi = 0
        best_ari = 0
        for j in range(10):
            # set_trace()
            predY = KMeans(n_clusters=num_classes).fit(embeds).labels_
            gnd_Y = bestMap(predY, labels.cpu().detach().numpy())
            # predY_temp = torch.tensor(predY, dtype=torch.float)
            # gnd_Y_temp = torch.tensor(gnd_Y)
            # loss = nn.MSELoss()(predY_temp, gnd_Y_temp)
            # if loss <= best_loss:
            #     best_loss = loss
            #     acc_temp, f1_temp, nmi_temp, ari_temp, _ = clustering_metrics(gnd_Y, predY)

            acc_temp, f1_temp, nmi_temp, ari_temp, _ = clustering_metrics(gnd_Y, predY)
            if acc_temp > best_acc:
                best_acc = acc_temp
                best_f1 = f1_temp
                best_nmi = nmi_temp
                best_ari = ari_temp

        accs.append(best_acc)
        nmis.append(best_nmi)
        aris.append(best_ari)
        f1s.append(best_f1)

        # accs.append(acc_temp)
        # nmis.append(nmi_temp)
        # aris.append(ari_temp)
        # f1s.append(f1_temp)

    accs = np.stack(accs)
    nmis = np.stack(nmis)
    aris = np.stack(aris)
    f1s = np.stack(f1s)
    print(accs)
    print(style.YELLOW + '\nClustering result: ACC:{:.2f}'.format(accs.mean().item()), '$\pm$', '{:.2f}'.format(accs.std().item()),
          'NMI:{:.2f}'.format(nmis.mean().item()), '$\pm$', '{:.2f}'.format(nmis.std().item()),
          'ARI:{:.2f}'.format(aris.mean().item()), '$\pm$', '{:.2f}'.format(aris.std().item()),
          'F1:{:.2f}'.format(f1s.mean().item()), '$\pm$', '{:.2f}'.format(f1s.std().item()))
--------------------------------------------------------------------------------
/utilis.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import random
import torch.nn as nn

def set_seed(seed):
    """
    setup random seed to fix the result
    Args:
        seed: random seed
    Returns: None
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class style():
    BLACK = '\033[30m'
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'
    UNDERLINE = '\033[4m'
    RESET = '\033[0m'


def glorot_init(input_dim, output_dim):
    init_range = np.sqrt(6.0/(input_dim + output_dim))
    initial = torch.rand(input_dim, output_dim)*2*init_range - init_range
    return nn.Parameter(initial)
--------------------------------------------------------------------------------