├── README.md
├── data
    └── readme.txt
├── dataloader.py
├── figures
    ├── performance.png
    ├── view_change.png
    └── workflow.png
├── loss.py
├── metric.py
├── network.py
├── requirements.txt
├── test.py
└── train.py


/README.md:
--------------------------------------------------------------------------------
## Self-Weighted Contrastive Fusion for Deep Multi-View Clustering
> **Authors:**
Song Wu, Yan Zheng, Yazhou Ren, Jing He, Xiaorong Pu, Shudong Huang, Zhifeng Hao, Lifang He.

This repository contains the code and data of our paper published in *IEEE Transactions on Multimedia (TMM)*: [Self-Weighted Contrastive Fusion for Deep Multi-View Clustering](https://ieeexplore.ieee.org/document/10499831).

## 1. Workflow of SCMVC

The framework of SCMVC. We propose a hierarchical network architecture to separate the consistency objective from the reconstruction objective. Specifically, the feature learning autoencoders first project the raw data into a low-dimensional latent space $\mathbf{Z}$. Then, two feature MLPs learn view-consensus features $\mathbf{R}$ and global features $\mathbf{H}$, respectively. In particular, a novel self-weighting method adaptively strengthens useful views and weakens unreliable ones during feature fusion, implementing multi-view contrastive fusion.

## 2. Requirements
- python==3.7.13

- pytorch==1.12.0

- numpy==1.21.5

- scikit-learn==0.22.2.post1

- scipy==1.7.3

## 3. Datasets

- All datasets can be downloaded from [cloud](https://pan.baidu.com/s/18If7bx2ZOVZhyijtzycjXA). key: data

- Special thanks to the valuable works [MFLVC](https://github.com/SubmissionsIn/MFLVC) and [GCFAggMVC](https://github.com/Galaxy922/GCFAggMVC) for providing these datasets.
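
As a quick reference, the dataset names accepted by `load_data` in `dataloader.py` can be summarized with their per-view input dimensions, sample counts, and class counts. The values below are copied from `load_data`; the `DATASETS` dictionary and the `describe` helper are illustrative names that do not exist in the repository.

```python
# Summary of the dataset names handled by load_data() in dataloader.py.
# DATASETS and describe() are hypothetical names used only for this sketch.
DATASETS = {
    # name: (per-view input dims, number of samples, number of classes)
    "BDGP": ([1750, 79], 2500, 5),
    "MNIST-USPS": ([784, 784], 5000, 10),
    "CCV": ([5000, 5000, 4000], 6773, 20),
    "Fashion": ([784, 784, 784], 10000, 10),
    "Caltech-2V": ([40, 254], 1400, 7),
    "Caltech-3V": ([40, 254, 928], 1400, 7),
    "Caltech-4V": ([40, 254, 928, 512], 1400, 7),
    "Caltech-5V": ([40, 254, 928, 512, 1984], 1400, 7),
    "Synthetic3d": ([3, 3, 3], 600, 3),
    "Prokaryotic": ([438, 3, 393], 551, 4),
    "Cifar10": ([512, 2048, 1024], 50000, 10),
    "Cifar100": ([512, 2048, 1024], 50000, 100),
}


def describe(name):
    """Return a one-line summary for a dataset name."""
    if name not in DATASETS:
        # load_data() raises the same exception type for unknown names.
        raise NotImplementedError(name)
    dims, size, classes = DATASETS[name]
    return "{}: {} views {}, {} samples, {} classes".format(
        name, len(dims), dims, size, classes)


if __name__ == "__main__":
    for dataset_name in DATASETS:
        print(describe(dataset_name))
```

`load_data` expects the downloaded files under `./data/`, for example `BDGP.mat`, `MNIST_USPS.mat`, `Fashion.mat`, and `Caltech-5V.mat` (the last one backs all four Caltech variants).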

## 4. Usage

### Paper:
Self-Weighted Contrastive Fusion for Deep Multi-View Clustering: https://ieeexplore.ieee.org/document/10499831.

### To test the trained model, run:
```bash
python test.py
```

### To train a new model, run:
```bash
python train.py
```

Both scripts accept a `--dataset` argument (default: `MNIST-USPS`). `train.py` saves the best checkpoint to `./models/<dataset>.pth`, which `test.py` then loads.

The experiments were conducted on a Windows PC with an Intel(R) Core(TM) i5-9300H CPU @ 2.40 GHz, 16.0 GB RAM, and a TITAN X GPU (12 GB memory).


## 5. Experiment Results
We compare our proposed SCMVC with 10 state-of-the-art multi-view clustering methods:
- CGD: [multi-view clustering via cross-view graph diffusion](https://github.com/ChangTang/CGD)
- LMVSC: [large-scale multi-view subspace clustering](https://github.com/sckangz/LMVSC)
- EOMSC-CA: [efficient one-pass multi-view subspace clustering](https://github.com/Tracesource/EOMSC-CA)
- DEMVC: [deep embedded multi-view clustering with collaborative training](https://github.com/SubmissionsIn/DEMVC)
- CoMVC: [contrastive multi-view clustering](https://github.com/DanielTrosten/mvc)
- CONAN: [contrastive fusion networks for multi-view clustering](https://github.com/Guanzhou-Ke/conan)
- MFLVC: [multi-level feature learning for contrastive multi-view clustering](https://github.com/SubmissionsIn/MFLVC)
- DSMVC: [deep safe multi-view clustering](https://github.com/Gasteinh/DSMVC)
- GCFAggMVC: [global and cross-view feature aggregation for multi-view clustering](https://github.com/Galaxy922/GCFAggMVC)
- DealMVC: [dual contrastive calibration for multi-view clustering](https://github.com/xihongyang1999/DealMVC)


## 6. Acknowledgments

Our proposed SCMVC is inspired by [MFLVC](https://github.com/SubmissionsIn/MFLVC), [GCFAggMVC](https://github.com/Galaxy922/GCFAggMVC), and [SEM](https://github.com/SubmissionsIn/SEM). Thanks for these valuable works.
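
For readers who want a feel for the self-weighting step described in the workflow section (implemented as `compute_view_value` in `train.py`), the following is a minimal pure-Python sketch rather than the repository's code. It assumes each view's consensus features and the global features are given as lists of L2-normalized row vectors; `pairwise_dot_sum` and `view_weights` are illustrative names. The discrepancy it computes is algebraically the squared distance between the mean view feature and the mean global feature, and a smaller discrepancy yields a larger weight.

```python
import math


def pairwise_dot_sum(A, B):
    """Sum of all entries of A @ B.T, for A and B given as lists of rows."""
    return sum(
        sum(a_k * b_k for a_k, b_k in zip(a, b))
        for a in A
        for b in B
    )


def view_weights(rs, H):
    """Return one normalized weight per view.

    rs: list of views, each a list of N feature rows (assumed normalized).
    H:  list of N global feature rows (assumed normalized).
    """
    n = len(H)
    scores = []
    for R in rs:
        # Same three terms as in compute_view_value: view-view, global-global,
        # and view-global similarity sums, averaged over N * N pairs.
        discrepancy = (
            pairwise_dot_sum(R, R)
            + pairwise_dot_sum(H, H)
            - 2.0 * pairwise_dot_sum(R, H)
        ) / (n * n)
        scores.append(math.exp(-discrepancy))
    total = sum(scores)
    return [s / total for s in scores]


if __name__ == "__main__":
    H = [[1.0, 0.0], [0.0, 1.0]]
    aligned_view = [[1.0, 0.0], [0.0, 1.0]]      # identical to H
    opposed_view = [[-1.0, 0.0], [0.0, -1.0]]    # points away from H
    print(view_weights([aligned_view, opposed_view], H))
```

In this toy input the aligned view receives the larger weight, which is the behavior the workflow description attributes to the self-weighting method.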

## 7. Citation
If you use our code or datasets in this repository for your research, please cite our papers.
```latex
@ARTICLE{10499831,
  author={Wu, Song and Zheng, Yan and Ren, Yazhou and He, Jing and Pu, Xiaorong and Huang, Shudong and Hao, Zhifeng and He, Lifang},
  journal={IEEE Transactions on Multimedia},
  title={Self-Weighted Contrastive Fusion for Deep Multi-View Clustering},
  year={2024},
  volume={},
  number={},
  pages={1-13},
  doi={10.1109/TMM.2024.3387298}
}

@article{xu2024self,
  title={Self-weighted contrastive learning among multiple views for mitigating representation degeneration},
  author={Xu, Jie and Chen, Shuo and Ren, Yazhou and Shi, Xiaoshuang and Shen, Hengtao and Niu, Gang and Zhu, Xiaofeng},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}

@InProceedings{Xu_2022_CVPR,
    author    = {Xu, Jie and Tang, Huayi and Ren, Yazhou and Peng, Liang and Zhu, Xiaofeng and He, Lifang},
    title     = {Multi-Level Feature Learning for Contrastive Multi-View Clustering},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year      = {2022},
    pages     = {16051-16060}
}
```

If you have any problems, please contact me at songwu.work@outlook.com.


--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
Please read README.md to download the corresponding datasets, and try different seeds to explore the best performance.


--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from torch.utils.data import Dataset
import scipy.io
import torch

class BDGP(Dataset):
    def __init__(self, path):
        data1 = scipy.io.loadmat(path+'BDGP.mat')['X1'].astype(np.float32)
        data2 = scipy.io.loadmat(path+'BDGP.mat')['X2'].astype(np.float32)
        labels = scipy.io.loadmat(path+'BDGP.mat')['Y'].transpose()
        self.x1 = data1
        self.x2 = data2
        self.y = labels

    def __len__(self):
        return self.x1.shape[0]

    def __getitem__(self, idx):
        return [torch.from_numpy(self.x1[idx]), torch.from_numpy(
           self.x2[idx])], torch.from_numpy(self.y[idx]), torch.from_numpy(np.array(idx)).long()


class CCV(Dataset):
    def __init__(self, path):
        self.data1 = np.load(path+'STIP.npy').astype(np.float32)
        scaler = MinMaxScaler()
        self.data1 = scaler.fit_transform(self.data1)
        self.data2 = np.load(path+'SIFT.npy').astype(np.float32)
        self.data3 = np.load(path+'MFCC.npy').astype(np.float32)
        self.labels = np.load(path+'label.npy')

    def __len__(self):
        return 6773

    def __getitem__(self, idx):
        x1 = self.data1[idx]
        x2 = self.data2[idx]
        x3 = self.data3[idx]

        return [torch.from_numpy(x1), torch.from_numpy(
           x2), torch.from_numpy(x3)], torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()


class MNIST_USPS(Dataset):
    def __init__(self, path):
        self.Y = scipy.io.loadmat(path + 'MNIST_USPS.mat')['Y'].astype(np.int32).reshape(5000,)
        self.V1 = scipy.io.loadmat(path + 'MNIST_USPS.mat')['X1'].astype(np.float32)
        self.V2 = scipy.io.loadmat(path + 'MNIST_USPS.mat')['X2'].astype(np.float32)

    def __len__(self):
        return 5000

    def __getitem__(self, idx):

        x1 = self.V1[idx].reshape(784)
        x2 = self.V2[idx].reshape(784)
        return [torch.from_numpy(x1), torch.from_numpy(x2)], self.Y[idx], torch.from_numpy(np.array(idx)).long()


class Fashion(Dataset):
    def __init__(self, path):
        self.Y = scipy.io.loadmat(path + 'Fashion.mat')['Y'].astype(np.int32).reshape(10000,)
        self.V1 = scipy.io.loadmat(path + 'Fashion.mat')['X1'].astype(np.float32)
        self.V2 = scipy.io.loadmat(path + 'Fashion.mat')['X2'].astype(np.float32)
        self.V3 = scipy.io.loadmat(path + 'Fashion.mat')['X3'].astype(np.float32)

    def __len__(self):
        return 10000

    def __getitem__(self, idx):

        x1 = self.V1[idx].reshape(784)
        x2 = self.V2[idx].reshape(784)
        x3 = self.V3[idx].reshape(784)

        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], self.Y[idx], torch.from_numpy(np.array(idx)).long()

class Caltech(Dataset):
    def __init__(self, path, view):
        data = scipy.io.loadmat(path)
        scaler = MinMaxScaler()
        self.view1 = scaler.fit_transform(data['X1'].astype(np.float32))
        self.view2 = scaler.fit_transform(data['X2'].astype(np.float32))
        self.view3 = scaler.fit_transform(data['X3'].astype(np.float32))
        self.view4 = scaler.fit_transform(data['X4'].astype(np.float32))
        self.view5 = scaler.fit_transform(data['X5'].astype(np.float32))
        self.labels = scipy.io.loadmat(path)['Y'].transpose()
        self.view = view

    def __len__(self):
        return 1400

    def __getitem__(self, idx):
        if self.view == 2:
            return [torch.from_numpy(
                self.view1[idx]), torch.from_numpy(self.view2[idx])], torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()
        if self.view == 3:
            return [torch.from_numpy(self.view1[idx]), torch.from_numpy(
                self.view2[idx]), torch.from_numpy(self.view5[idx])], torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()
        if self.view == 4:
            return [torch.from_numpy(self.view1[idx]), torch.from_numpy(self.view2[idx]), torch.from_numpy(
                self.view5[idx]), torch.from_numpy(self.view4[idx])], torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()
        if self.view == 5:
            return [torch.from_numpy(self.view1[idx]), torch.from_numpy(
                self.view2[idx]), torch.from_numpy(self.view5[idx]), torch.from_numpy(
                self.view4[idx]), torch.from_numpy(self.view3[idx])], torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()

class cifar_10():
    def __init__(self, path):
        data = scipy.io.loadmat(path + 'cifar10.mat')
        self.Y = data['truelabel'][0][0].astype(np.int32).reshape(50000,)
        self.V1 = data['data'][0][0].T.astype(np.float32)
        self.V2 = data['data'][1][0].T.astype(np.float32)
        self.V3 = data['data'][2][0].T.astype(np.float32)
    def __len__(self):
        return 50000
    def __getitem__(self, idx):
        x1 = self.V1[idx]
        x2 = self.V2[idx]
        x3 = self.V3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], self.Y[idx], torch.from_numpy(np.array(idx)).long()

class cifar_100():
    def __init__(self, path):
        data = scipy.io.loadmat(path + 'cifar100.mat')
        self.Y = data['truelabel'][0][0].astype(np.int32).reshape(50000,)
        self.V1 = data['data'][0][0].T.astype(np.float32)
        self.V2 = data['data'][1][0].T.astype(np.float32)
        self.V3 = data['data'][2][0].T.astype(np.float32)
    def __len__(self):
        return 50000
    def __getitem__(self, idx):
        x1 = self.V1[idx]
        x2 = self.V2[idx]
        x3 = self.V3[idx]

        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)],self.Y[idx], torch.from_numpy(np.array(idx)).long()

class synthetic3d():
    def __init__(self, path):
        data = scipy.io.loadmat(path + 'synthetic3d.mat')
        self.Y = data['Y'].astype(np.int32).reshape(600,)
        self.V1 = data['X'][0][0].astype(np.float32)
        self.V2 = data['X'][1][0].astype(np.float32)
        self.V3 = data['X'][2][0].astype(np.float32)
    def __len__(self):
        return 600
    def __getitem__(self, idx):
        x1 = self.V1[idx]
        x2 = self.V2[idx]
        x3 = self.V3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
               self.Y[idx], torch.from_numpy(np.array(idx)).long()

class prokaryotic():
    def __init__(self, path):
        data = scipy.io.loadmat(path + 'prokaryotic.mat')
        self.Y = data['Y'].astype(np.int32).reshape(551,)
        self.V1 = data['X'][0][0].astype(np.float32)
        self.V2 = data['X'][1][0].astype(np.float32)
        self.V3 = data['X'][2][0].astype(np.float32)
    def __len__(self):
        return 551
    def __getitem__(self, idx):
        x1 = self.V1[idx]
        x2 = self.V2[idx]
        x3 = self.V3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
               self.Y[idx], torch.from_numpy(np.array(idx)).long()

def load_data(dataset):
    if dataset == "BDGP":
        dataset = BDGP('./data/')
        dims = [1750, 79]
        view = 2
        data_size = 2500
        class_num = 5
    elif dataset == "MNIST-USPS":
        dataset = MNIST_USPS('./data/')
        dims = [784, 784]
        view = 2
        class_num = 10
        data_size = 5000
    elif dataset == "CCV":
        dataset = CCV('./data/')
        dims = [5000, 5000, 4000]
        view = 3
        data_size = 6773
        class_num = 20
    elif dataset == "Fashion":
        dataset = Fashion('./data/')
        dims = [784, 784, 784]
        view = 3
        data_size = 10000
        class_num = 10
    elif dataset == "Caltech-2V":
        dataset = Caltech('data/Caltech-5V.mat', view=2)
        dims = [40, 254]
        view = 2
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-3V":
        dataset = Caltech('data/Caltech-5V.mat', view=3)
        dims = [40, 254, 928]
        view = 3
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-4V":
        dataset = Caltech('data/Caltech-5V.mat', view=4)
        dims = [40, 254, 928, 512]
        view = 4
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-5V":
        dataset = Caltech('data/Caltech-5V.mat', view=5)
        dims = [40, 254, 928, 512, 1984]
        view = 5
        data_size = 1400
        class_num = 7
    elif dataset == "Synthetic3d":
        dataset = synthetic3d('./data/')
        dims = [3,3,3]
        view = 3
        data_size = 600
        class_num = 3
    elif dataset == "Prokaryotic":
        dataset = prokaryotic('./data/')
        dims = [438, 3, 393]
        view = 3
        data_size = 551
        class_num = 4
    elif dataset == "Cifar10":
        dataset = cifar_10('./data/')
        dims = [512, 2048, 1024]
        view = 3
        data_size = 50000
        class_num = 10
    elif dataset == "Cifar100":
        dataset = cifar_100('./data/')
        dims = [512, 2048, 1024]
        view = 3
        data_size = 50000
        class_num = 100
    else:
        raise NotImplementedError
    return dataset, dims, view, data_size, class_num


--------------------------------------------------------------------------------
/figures/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongwuJob/SCMVC/26963adc4964fb1cc6277fea8d6cfb5152989d37/figures/performance.png


--------------------------------------------------------------------------------
/figures/view_change.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongwuJob/SCMVC/26963adc4964fb1cc6277fea8d6cfb5152989d37/figures/view_change.png


--------------------------------------------------------------------------------
/figures/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongwuJob/SCMVC/26963adc4964fb1cc6277fea8d6cfb5152989d37/figures/workflow.png


--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class ContrastiveLoss(nn.Module):
    def __init__(self, batch_size, temperature, device):
        super(ContrastiveLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device

    def forward(self, h_i, h_j, weight=None):
        N = self.batch_size
        similarity_matrix = torch.matmul(h_i, h_j.T) / self.temperature
        # diagonal entries pair each sample with itself across the two inputs
        positives = torch.diag(similarity_matrix)
        mask = torch.ones((N, N)).to(self.device)
        mask = mask.fill_diagonal_(0)

        nominator = torch.exp(positives)
        # off-diagonal entries act as negatives
        denominator = (mask.bool()) * torch.exp(similarity_matrix)

        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
        loss = torch.sum(loss_partial) / N
        loss = weight * loss if weight is not None else loss

        return loss


--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
from sklearn.metrics import v_measure_score, adjusted_rand_score, accuracy_score
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
import numpy as np
import torch

def cluster_acc(y_true, y_pred):
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    # best one-to-one mapping between predicted clusters and true labels
    u = linear_sum_assignment(w.max() - w)
    ind = np.concatenate([u[0].reshape(u[0].shape[0], 1), u[1].reshape([u[0].shape[0], 1])], axis=1)
    return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size

def purity(y_true, y_pred):
    # work on a copy so the caller's label array is not relabeled in place
    y_true = y_true.copy()
    y_voted_labels = np.zeros(y_true.shape)
    labels = np.unique(y_true)
    ordered_labels = np.arange(labels.shape[0])
    for k in range(labels.shape[0]):
        y_true[y_true == labels[k]] = ordered_labels[k]
    labels = np.unique(y_true)
    bins = np.concatenate((labels, [np.max(labels)+1]), axis=0)

    for cluster in np.unique(y_pred):
        hist, _ = np.histogram(y_true[y_pred == cluster], bins=bins)
        winner = np.argmax(hist)
        y_voted_labels[y_pred == cluster] = winner

    return accuracy_score(y_true, y_voted_labels)

def evaluate(label, pred):
    nmi = v_measure_score(label, pred)
    ari = adjusted_rand_score(label, pred)
    acc = cluster_acc(label, pred)
    pur = purity(label, pred)
    return nmi, ari, acc, pur

def valid(model, device, dataset, view, data_size, class_num, eval_h=False, epoch=None):
    # a single batch holding the whole dataset
    test_loader = DataLoader(
            dataset,
            batch_size=data_size,
            shuffle=False,
        )
    for batch_idx, (xs, y, _) in enumerate(test_loader):
        for v in range(view):
            xs[v] = xs[v].to(device)
        labels = y.cpu().detach().data.numpy().squeeze()

    # inference
    with torch.no_grad():
        xrs, zs, rs, H = model(xs)

    if eval_h:
        print("Clustering results on low-level features of each view:")
        for v in range(view):
            kmeans = KMeans(n_clusters=class_num, n_init=100)
            y_pred = kmeans.fit_predict(zs[v].cpu().data.numpy())
            nmi, ari, acc, pur = evaluate(labels, y_pred)
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{}={:.4f}'.format(v + 1, acc,
                                                                                     v + 1, nmi,
                                                                                     v + 1, ari,
                                                                                     v + 1, pur))
        print("Clustering results on view-consensus features of each view:")
        for v in range(view):
            kmeans = KMeans(n_clusters=class_num, n_init=100)
            y_pred = kmeans.fit_predict(rs[v].cpu().data.numpy())
            nmi, ari, acc, pur = evaluate(labels, y_pred)
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{}={:.4f}'.format(v + 1, acc,
                                                                                     v + 1, nmi,
                                                                                     v + 1, ari,
                                                                                     v + 1, pur))

    # Clustering results on global features
    kmeans = KMeans(n_clusters=class_num, n_init=100)
    y_pred = kmeans.fit_predict(H.cpu().data.numpy())
    nmi, ari, acc, pur = evaluate(labels, y_pred)
    if epoch is not None:
        print('Epoch {}'.format(epoch),'The clustering performance: ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR={:.4f}'.format(acc, nmi, ari, pur))
    else:
        print('The clustering performance: ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR={:.4f}'.format(acc, nmi, ari, pur))
    return acc, nmi, pur


--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn.functional import normalize
import torch

# Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2000),
            nn.ReLU(),
            nn.Linear(2000, feature_dim),
        )

    def forward(self, x):
        return self.encoder(x)

# Decoder
class Decoder(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(feature_dim, 2000),
            nn.ReLU(),
            nn.Linear(2000, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, input_dim)
        )

    def forward(self, x):
        return self.decoder(x)

# SCMVC Network
class Network(nn.Module):
    def __init__(self, view, input_size, feature_dim, high_feature_dim, device):
        super(Network, self).__init__()
        self.view = view
        self.encoders = []
        self.decoders = []
        for v in range(view):
            self.encoders.append(Encoder(input_size[v], feature_dim).to(device))
            self.decoders.append(Decoder(input_size[v], feature_dim).to(device))
        self.encoders = nn.ModuleList(self.encoders)
        self.decoders = nn.ModuleList(self.decoders)

        # global features fusion layer
        self.feature_fusion_module = nn.Sequential(
            nn.Linear(self.view * feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, high_feature_dim)
        )

        # view-consensus features learning layer
        self.common_information_module = nn.Sequential(
            nn.Linear(feature_dim, high_feature_dim)
        )

    # global feature fusion
    def feature_fusion(self, zs, zs_gradient):
        input = torch.cat(zs, dim=1) if zs_gradient else torch.cat(zs, dim=1).detach()
        return normalize(self.feature_fusion_module(input),dim=1)

    def forward(self, xs, zs_gradient=True):
        rs = []
        xrs = []
        zs = []
        for v in range(self.view):
            x = xs[v]
            z = self.encoders[v](x)
            xr = self.decoders[v](z)
            r = normalize(self.common_information_module(z),dim=1)

            rs.append(r)
            zs.append(z)
            xrs.append(xr)

        H = self.feature_fusion(zs,zs_gradient)
        return xrs,zs,rs,H


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
python==3.7.13
pytorch==1.12.0
numpy==1.21.5
scikit-learn==0.22.2.post1
scipy==1.7.3


--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from network import Network
from metric import valid
import argparse
from dataloader import load_data

# MNIST-USPS
# BDGP
# CCV
# Fashion
# Caltech-2V
# Caltech-3V
# Caltech-4V
# Caltech-5V
# Cifar10
# Cifar100
# Prokaryotic
# Synthetic3d
Dataname = 'MNIST-USPS'
parser = argparse.ArgumentParser(description='test')
parser.add_argument('--dataset', default=Dataname)
# explicit types so values passed on the command line are not kept as strings
parser.add_argument("--feature_dim", default=64, type=int)
parser.add_argument("--hide_feature_dim", default=20, type=int)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset, dims, view, data_size, class_num = load_data(args.dataset)
model = Network(view, dims, args.feature_dim, args.hide_feature_dim, device)
model = model.to(device)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load('./models/' + args.dataset + '.pth', map_location=device)
model.load_state_dict(checkpoint)

model.eval()
print("Dataset:{}".format(args.dataset))
print("Datasize:" + str(data_size))
print("Loading models...")
acc, nmi, pur = valid(model, device, dataset, view, data_size, class_num, eval_h=True)


--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from network import Network
from metric import valid
from torch.utils.data import Dataset
import numpy as np
import argparse
import random
from loss import ContrastiveLoss
from dataloader import load_data

# MNIST-USPS
# BDGP
# CCV
# Fashion
# Caltech-2V
# Caltech-3V
# Caltech-4V
# Caltech-5V
# Cifar10
# Cifar100
# Prokaryotic
# Synthetic3d
Dataname = 'MNIST-USPS'
parser = argparse.ArgumentParser(description='train')
parser.add_argument('--dataset', default=Dataname)
parser.add_argument('--batch_size', default=256, type=int)
# explicit types so values passed on the command line are not kept as strings
parser.add_argument("--learning_rate", default=0.0003, type=float)
parser.add_argument("--weight_decay", default=0., type=float)
parser.add_argument("--pre_epochs", default=200, type=int)
parser.add_argument("--con_epochs", default=50, type=int)
parser.add_argument("--feature_dim", default=64, type=int)
parser.add_argument("--high_feature_dim", default=20, type=int)
parser.add_argument("--temperature", default=1, type=float)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# per-dataset contrastive epochs and random seed
if args.dataset == "MNIST-USPS":
    args.con_epochs = 50
    seed = 10
if args.dataset == "BDGP":
    args.con_epochs = 10 # 20
    seed = 30
if args.dataset == "CCV":
    args.con_epochs = 50 # 100
    seed = 100
    args.tune_epochs = 200
if args.dataset == "Fashion":
    args.con_epochs = 50 # 100
    seed = 10
if args.dataset == "Caltech-2V":
    args.con_epochs = 100
    seed = 200
    args.tune_epochs = 200
if args.dataset == "Caltech-3V":
    args.con_epochs = 100
    seed = 30
if args.dataset == "Caltech-4V":
    args.con_epochs = 100
    seed = 100
if args.dataset == "Caltech-5V":
    args.con_epochs = 100
    seed = 1000000
if args.dataset == "Cifar10":
    args.con_epochs = 10
    seed = 10
if args.dataset == "Cifar100":
    args.con_epochs = 20
    seed = 10
if args.dataset == "Prokaryotic":
    args.con_epochs = 20
    seed = 10000
if args.dataset == "Synthetic3d":
    args.con_epochs = 100
    seed = 100

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

setup_seed(seed)

dataset, dims, view, data_size, class_num = load_data(args.dataset)
data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
    )

def compute_view_value(rs, H, view):
    N = H.shape[0]
    w = []
    # all features are normalized
    global_sim = torch.matmul(H,H.t())
    for v in range(view):
        view_sim = torch.matmul(rs[v],rs[v].t())
        related_sim = torch.matmul(rs[v],H.t())
        # The implementation of MMD
        w_v = (torch.sum(view_sim) + torch.sum(global_sim) - 2 * torch.sum(related_sim)) / (N*N)
        w.append(torch.exp(-w_v))
    w = torch.stack(w)
    w = w / torch.sum(w)
    return w.squeeze()


def pretrain(epoch):
    tot_loss = 0.
    criterion = torch.nn.MSELoss()
    for batch_idx, (xs, _, _) in enumerate(data_loader):
        for v in range(view):
            xs[v] = xs[v].to(device)
        optimizer.zero_grad()
        xrs,_,_,_ = model(xs)
        loss_list = []
        for v in range(view):
            loss_list.append(criterion(xs[v], xrs[v]))
        loss = sum(loss_list)
        loss.backward()
        optimizer.step()
        tot_loss += loss.item()
    print('Epoch {}'.format(epoch), 'Loss:{:.6f}'.format(tot_loss/len(data_loader)))

def contrastive_train(epoch):
    tot_loss = 0.
    mse = torch.nn.MSELoss()
    for batch_idx, (xs, _, _) in enumerate(data_loader):
        for v in range(view):
            xs[v] = xs[v].to(device)
        optimizer.zero_grad()
        xrs, zs, rs, H = model(xs)
        loss_list = []

        # compute adaptive weights for each view
        with torch.no_grad():
            w = compute_view_value(rs, H, view)

        for v in range(view):
            # Self-weighted contrastive learning loss
            loss_list.append(contrastiveloss(H, rs[v], w[v]))
            # Reconstruction loss
            loss_list.append(mse(xs[v], xrs[v]))
        loss = sum(loss_list)
        loss.backward()
        optimizer.step()
        tot_loss += loss.item()
    print('Epoch {}'.format(epoch), 'Loss:{:.6f}'.format(tot_loss/len(data_loader)))

accs = []
nmis = []
purs = []
if not os.path.exists('./models'):
    os.makedirs('./models')
T = 1
for i in range(T):
    print("ROUND:{}".format(i+1))
    setup_seed(seed)
    model = Network(view, dims, args.feature_dim, args.high_feature_dim, device)
    print(model)
    model = model.to(device)
    state = model.state_dict()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    contrastiveloss = ContrastiveLoss(args.batch_size, args.temperature, device).to(device)
    best_acc, best_nmi, best_pur = 0, 0, 0

    # stage 1: reconstruction-only pretraining
    epoch = 1
    while epoch <= args.pre_epochs:
        pretrain(epoch)
        epoch += 1
    # acc, nmi, pur = valid(model, device, dataset, view, data_size, class_num, eval_h=True, epoch=epoch)

    # stage 2: self-weighted contrastive training; keep the best checkpoint by ACC
    while epoch <= args.pre_epochs + args.con_epochs:
        contrastive_train(epoch)
        acc, nmi, pur = valid(model, device, dataset, view, data_size, class_num, eval_h=False, epoch=epoch)

        if acc > best_acc:
            best_acc, best_nmi, best_pur = acc, nmi, pur
            state = model.state_dict()
            torch.save(state, './models/' + args.dataset + '.pth')
        epoch += 1

    # The final result
    accs.append(best_acc)
    nmis.append(best_nmi)
    purs.append(best_pur)
    print('The best clustering performance: ACC = {:.4f} NMI = {:.4f} PUR={:.4f}'.format(best_acc, best_nmi, best_pur))

--------------------------------------------------------------------------------