├── README.md
├── data
│   └── readme.txt
├── dataloader.py
├── figures
│   ├── performance.png
│   ├── view_change.png
│   └── workflow.png
├── loss.py
├── metric.py
├── network.py
├── requirements.txt
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
## Self-Weighted Contrastive Fusion for Deep Multi-View Clustering
> **Authors:**
Song Wu, Yan Zheng, Yazhou Ren, Jing He, Xiaorong Pu, Shudong Huang, Zhifeng Hao, Lifang He.

This repository contains the code and data of our paper published in *IEEE Transactions on Multimedia (TMM)*: [Self-Weighted Contrastive Fusion for Deep Multi-View Clustering](https://ieeexplore.ieee.org/document/10499831).

## 1. Workflow of SCMVC

The framework of SCMVC. We propose a hierarchical network architecture to separate the consistency objective from the reconstruction objective. Specifically, the feature learning autoencoders first project the raw data into a low-dimensional latent space $\mathbf{Z}$. Then, two feature MLPs learn view-consensus features $\mathbf{R}$ and global features $\mathbf{H}$, respectively. In particular, a novel self-weighting method adaptively strengthens useful views and weakens unreliable views during feature fusion, which implements multi-view contrastive fusion.

## 2. Requirements

- python==3.7.13
- pytorch==1.12.0
- numpy==1.21.5
- scikit-learn==0.22.2.post1
- scipy==1.7.3

## 3. Datasets

- All datasets can be downloaded from [cloud](https://pan.baidu.com/s/18If7bx2ZOVZhyijtzycjXA). key: data

- Thanks to the valuable works [MFLVC](https://github.com/SubmissionsIn/MFLVC) and [GCFAggMVC](https://github.com/Galaxy922/GCFAggMVC) for providing these datasets.

## 4. Usage

### Paper:
Self-Weighted Contrastive Fusion for Deep Multi-View Clustering: https://ieeexplore.ieee.org/document/10499831.

### To test the trained model, run:
```bash
python test.py
```

### To train a new model, run:
```bash
python train.py
```

The experiments were conducted on a Windows PC with an Intel(R) Core(TM) i5-9300H CPU @ 2.40 GHz, 16.0 GB RAM, and a TITAN X GPU (12 GB memory).

## 5. Experiment Results

We compare our proposed SCMVC with 10 state-of-the-art multi-view clustering methods:
- CGD: [multi-view clustering via cross-view graph diffusion](https://github.com/ChangTang/CGD)
- LMVSC: [large-scale multi-view subspace clustering](https://github.com/sckangz/LMVSC)
- EOMSC-CA: [efficient one-pass multi-view subspace clustering](https://github.com/Tracesource/EOMSC-CA)
- DEMVC: [deep embedded multi-view clustering with collaborative training](https://github.com/SubmissionsIn/DEMVC)
- CoMVC: [contrastive multi-view clustering](https://github.com/DanielTrosten/mvc)
- CONAN: [contrastive fusion networks for multi-view clustering](https://github.com/Guanzhou-Ke/conan)
- MFLVC: [multi-level feature learning for contrastive multi-view clustering](https://github.com/SubmissionsIn/MFLVC)
- DSMVC: [deep safe multi-view clustering](https://github.com/Gasteinh/DSMVC)
- GCFAggMVC: [global and cross-view feature aggregation for multi-view clustering](https://github.com/Galaxy922/GCFAggMVC)
- DealMVC: [dual contrastive calibration for multi-view clustering](https://github.com/xihongyang1999/DealMVC)

## 6. Acknowledgments

Our proposed SCMVC is inspired by [MFLVC](https://github.com/SubmissionsIn/MFLVC), [GCFAggMVC](https://github.com/Galaxy922/GCFAggMVC), and [SEM](https://github.com/SubmissionsIn/SEM). Thanks to the authors of these valuable works.

## 7. Citation

If you use the code or datasets in this repository for your research, please cite our papers.
```latex
@ARTICLE{10499831,
  author={Wu, Song and Zheng, Yan and Ren, Yazhou and He, Jing and Pu, Xiaorong and Huang, Shudong and Hao, Zhifeng and He, Lifang},
  journal={IEEE Transactions on Multimedia},
  title={Self-Weighted Contrastive Fusion for Deep Multi-View Clustering},
  year={2024},
  volume={},
  number={},
  pages={1-13},
  doi={10.1109/TMM.2024.3387298}
}

@article{xu2024self,
  title={Self-weighted contrastive learning among multiple views for mitigating representation degeneration},
  author={Xu, Jie and Chen, Shuo and Ren, Yazhou and Shi, Xiaoshuang and Shen, Hengtao and Niu, Gang and Zhu, Xiaofeng},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}

@InProceedings{Xu_2022_CVPR,
  author    = {Xu, Jie and Tang, Huayi and Ren, Yazhou and Peng, Liang and Zhu, Xiaofeng and He, Lifang},
  title     = {Multi-Level Feature Learning for Contrastive Multi-View Clustering},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2022},
  pages     = {16051-16060}
}
```

If you have any problems, please contact me at songwu.work@outlook.com.
--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
Please read README.md to download the corresponding datasets, and try different seeds to explore the best performance.
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from torch.utils.data import Dataset
import scipy.io
import torch

# Every dataset returns (list_of_view_tensors, label, sample_index) per item.


class BDGP(Dataset):
    def __init__(self, path):
        # Load the .mat file once instead of re-reading it for every field.
        mat = scipy.io.loadmat(path + 'BDGP.mat')
        self.x1 = mat['X1'].astype(np.float32)
        self.x2 = mat['X2'].astype(np.float32)
        self.y = mat['Y'].transpose()

    def __len__(self):
        return self.x1.shape[0]

    def __getitem__(self, idx):
        return [torch.from_numpy(self.x1[idx]), torch.from_numpy(self.x2[idx])], \
            torch.from_numpy(self.y[idx]), torch.from_numpy(np.array(idx)).long()


class CCV(Dataset):
    def __init__(self, path):
        # Only the STIP view is min-max scaled.
        self.data1 = np.load(path + 'STIP.npy').astype(np.float32)
        scaler = MinMaxScaler()
        self.data1 = scaler.fit_transform(self.data1)
        self.data2 = np.load(path + 'SIFT.npy').astype(np.float32)
        self.data3 = np.load(path + 'MFCC.npy').astype(np.float32)
        self.labels = np.load(path + 'label.npy')

    def __len__(self):
        return 6773

    def __getitem__(self, idx):
        x1 = self.data1[idx]
        x2 = self.data2[idx]
        x3 = self.data3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()


class MNIST_USPS(Dataset):
    def __init__(self, path):
        mat = scipy.io.loadmat(path + 'MNIST_USPS.mat')
        self.Y = mat['Y'].astype(np.int32).reshape(5000,)
        self.V1 = mat['X1'].astype(np.float32)
        self.V2 = mat['X2'].astype(np.float32)

    def __len__(self):
        return 5000

    def __getitem__(self, idx):
        # Flatten each 28x28 image into a 784-dimensional vector.
        x1 = self.V1[idx].reshape(784)
        x2 = self.V2[idx].reshape(784)
        return [torch.from_numpy(x1), torch.from_numpy(x2)], self.Y[idx], torch.from_numpy(np.array(idx)).long()


class Fashion(Dataset):
    def __init__(self, path):
        mat = scipy.io.loadmat(path + 'Fashion.mat')
        self.Y = mat['Y'].astype(np.int32).reshape(10000,)
        self.V1 = mat['X1'].astype(np.float32)
        self.V2 = mat['X2'].astype(np.float32)
        self.V3 = mat['X3'].astype(np.float32)

    def __len__(self):
        return 10000

    def __getitem__(self, idx):
        x1 = self.V1[idx].reshape(784)
        x2 = self.V2[idx].reshape(784)
        x3 = self.V3[idx].reshape(784)
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            self.Y[idx], torch.from_numpy(np.array(idx)).long()


class Caltech(Dataset):
    def __init__(self, path, view):
        data = scipy.io.loadmat(path)
        scaler = MinMaxScaler()
        self.view1 = scaler.fit_transform(data['X1'].astype(np.float32))
        self.view2 = scaler.fit_transform(data['X2'].astype(np.float32))
        self.view3 = scaler.fit_transform(data['X3'].astype(np.float32))
        self.view4 = scaler.fit_transform(data['X4'].astype(np.float32))
        self.view5 = scaler.fit_transform(data['X5'].astype(np.float32))
        self.labels = data['Y'].transpose()
        self.view = view

    def __len__(self):
        return 1400

    def __getitem__(self, idx):
        # Views are served in the order X1, X2, X5, X4, X3; the first `self.view` of them are returned.
        ordered = [self.view1, self.view2, self.view5, self.view4, self.view3][:self.view]
        return [torch.from_numpy(v[idx]) for v in ordered], \
            torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()


class cifar_10(Dataset):
    def __init__(self, path):
        data = scipy.io.loadmat(path + 'cifar10.mat')
        self.Y = data['truelabel'][0][0].astype(np.int32).reshape(50000,)
        self.V1 = data['data'][0][0].T.astype(np.float32)
        self.V2 = data['data'][1][0].T.astype(np.float32)
        self.V3 = data['data'][2][0].T.astype(np.float32)

    def __len__(self):
        return 50000

    def __getitem__(self, idx):
        x1 = self.V1[idx]
        x2 = self.V2[idx]
        x3 = self.V3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            self.Y[idx], torch.from_numpy(np.array(idx)).long()


class cifar_100(Dataset):
    def __init__(self, path):
        data = scipy.io.loadmat(path + 'cifar100.mat')
        self.Y = data['truelabel'][0][0].astype(np.int32).reshape(50000,)
        self.V1 = data['data'][0][0].T.astype(np.float32)
        self.V2 = data['data'][1][0].T.astype(np.float32)
        self.V3 = data['data'][2][0].T.astype(np.float32)

    def __len__(self):
        return 50000

    def __getitem__(self, idx):
        x1 = self.V1[idx]
        x2 = self.V2[idx]
        x3 = self.V3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            self.Y[idx], torch.from_numpy(np.array(idx)).long()


class synthetic3d(Dataset):
    def __init__(self, path):
        data = scipy.io.loadmat(path + 'synthetic3d.mat')
        self.Y = data['Y'].astype(np.int32).reshape(600,)
        self.V1 = data['X'][0][0].astype(np.float32)
        self.V2 = data['X'][1][0].astype(np.float32)
        self.V3 = data['X'][2][0].astype(np.float32)

    def __len__(self):
        return 600

    def __getitem__(self, idx):
        x1 = self.V1[idx]
        x2 = self.V2[idx]
        x3 = self.V3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            self.Y[idx], torch.from_numpy(np.array(idx)).long()


class prokaryotic(Dataset):
    def __init__(self, path):
        data = scipy.io.loadmat(path + 'prokaryotic.mat')
        self.Y = data['Y'].astype(np.int32).reshape(551,)
        self.V1 = data['X'][0][0].astype(np.float32)
        self.V2 = data['X'][1][0].astype(np.float32)
        self.V3 = data['X'][2][0].astype(np.float32)

    def __len__(self):
        return 551

    def __getitem__(self, idx):
        x1 = self.V1[idx]
        x2 = self.V2[idx]
        x3 = self.V3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            self.Y[idx], torch.from_numpy(np.array(idx)).long()


def load_data(dataset):
    """Return (dataset, per-view input dims, number of views, sample count, class count)."""
    if dataset == "BDGP":
        dataset = BDGP('./data/')
        dims = [1750, 79]
        view = 2
        data_size = 2500
        class_num = 5
    elif dataset == "MNIST-USPS":
        dataset = MNIST_USPS('./data/')
        dims = [784, 784]
        view = 2
        class_num = 10
        data_size = 5000
    elif dataset == "CCV":
        dataset = CCV('./data/')
        dims = [5000, 5000, 4000]
        view = 3
        data_size = 6773
        class_num = 20
    elif dataset == "Fashion":
        dataset = Fashion('./data/')
        dims = [784, 784, 784]
        view = 3
        data_size = 10000
        class_num = 10
    elif dataset == "Caltech-2V":
        dataset = Caltech('data/Caltech-5V.mat', view=2)
        dims = [40, 254]
        view = 2
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-3V":
        dataset = Caltech('data/Caltech-5V.mat', view=3)
        dims = [40, 254, 928]
        view = 3
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-4V":
        dataset = Caltech('data/Caltech-5V.mat', view=4)
        dims = [40, 254, 928, 512]
        view = 4
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-5V":
        dataset = Caltech('data/Caltech-5V.mat', view=5)
        dims = [40, 254, 928, 512, 1984]
        view = 5
        data_size = 1400
        class_num = 7
    elif dataset == "Synthetic3d":
        dataset = synthetic3d('./data/')
        dims = [3, 3, 3]
        view = 3
        data_size = 600
        class_num = 3
    elif dataset == "Prokaryotic":
        dataset = prokaryotic('./data/')
        dims = [438, 3, 393]
        view = 3
        data_size = 551
        class_num = 4
    elif dataset == "Cifar10":
        dataset = cifar_10('./data/')
        dims = [512, 2048, 1024]
        view = 3
        data_size = 50000
        class_num = 10
    elif dataset == "Cifar100":
        dataset = cifar_100('./data/')
        dims = [512, 2048, 1024]
        view = 3
        data_size = 50000
        class_num = 100
    else:
        raise NotImplementedError
    return dataset, dims, view, data_size, class_num
--------------------------------------------------------------------------------
/figures/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongwuJob/SCMVC/26963adc4964fb1cc6277fea8d6cfb5152989d37/figures/performance.png
--------------------------------------------------------------------------------
/figures/view_change.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongwuJob/SCMVC/26963adc4964fb1cc6277fea8d6cfb5152989d37/figures/view_change.png
--------------------------------------------------------------------------------
/figures/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongwuJob/SCMVC/26963adc4964fb1cc6277fea8d6cfb5152989d37/figures/workflow.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ContrastiveLoss(nn.Module):
    def __init__(self, batch_size, temperature, device):
        super(ContrastiveLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device

    def forward(self, h_i, h_j, weight=None):
        # Use the actual batch size so a batch smaller than self.batch_size does not break the mask.
        N = h_i.shape[0]
        similarity_matrix = torch.matmul(h_i, h_j.T) / self.temperature
        # Diagonal entries are the positive pairs (same sample, two feature spaces).
        positives = torch.diag(similarity_matrix)
        # The mask removes the positive pairs from the denominator.
        mask = torch.ones((N, N), device=h_i.device).fill_diagonal_(0)

        numerator = torch.exp(positives)
        denominator = mask * torch.exp(similarity_matrix)

        loss_partial = -torch.log(numerator / torch.sum(denominator, dim=1))
        loss = torch.sum(loss_partial) / N
        # Optional per-view weight from the self-weighting step.
        loss = weight * loss if weight is not None else loss

        return loss
--------------------------------------------------------------------------------
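The loss in `loss.py` can be checked by hand on a tiny batch. The following is a minimal pure-Python sketch, not part of the repository: the function name `contrastive_loss_sketch` and the toy inputs are assumptions made for illustration. It mirrors the same steps (scaled dot-product similarities, diagonal entries as positives, off-diagonal entries of each row as the denominator) and needs at least two samples per batch.

```python
import math


def contrastive_loss_sketch(h, r, temperature=1.0, weight=None):
    """Hypothetical re-implementation of loss.py for lists of row vectors."""
    n = len(h)
    total = 0.0
    for i in range(n):
        # Scaled dot-product similarity between h[i] and every r[j].
        sims = [sum(a * b for a, b in zip(h[i], r[j])) / temperature for j in range(n)]
        positive = math.exp(sims[i])
        # The positive pair is excluded from the denominator.
        negatives = sum(math.exp(sims[j]) for j in range(n) if j != i)
        total += -math.log(positive / negatives)
    loss = total / n
    return weight * loss if weight is not None else loss


# Toy batch: two orthonormal feature vectors shared by both inputs.
features = [[1.0, 0.0], [0.0, 1.0]]
loss_value = contrastive_loss_sketch(features, features)
weighted_value = contrastive_loss_sketch(features, features, weight=0.5)
```

Because the positive pair is left out of the denominator, the value can be negative; in this toy case each row contributes `-log(e / 1) = -1`.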
/metric.py:
--------------------------------------------------------------------------------
from sklearn.metrics import v_measure_score, adjusted_rand_score, accuracy_score
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
import numpy as np
import torch


def cluster_acc(y_true, y_pred):
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    # w[c, k] counts the samples assigned to cluster c whose true class is k.
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    # Hungarian matching: find the cluster-to-class assignment with the most agreements.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size


def purity(y_true, y_pred):
    # Work on a copy so the caller's label array is not modified in place.
    y_true = y_true.copy()
    y_voted_labels = np.zeros(y_true.shape)
    labels = np.unique(y_true)
    ordered_labels = np.arange(labels.shape[0])
    # Relabel the classes as 0..K-1.
    for k in range(labels.shape[0]):
        y_true[y_true == labels[k]] = ordered_labels[k]
    labels = np.unique(y_true)
    bins = np.concatenate((labels, [np.max(labels) + 1]), axis=0)

    # Each cluster votes for its most frequent true class.
    for cluster in np.unique(y_pred):
        hist, _ = np.histogram(y_true[y_pred == cluster], bins=bins)
        winner = np.argmax(hist)
        y_voted_labels[y_pred == cluster] = winner

    return accuracy_score(y_true, y_voted_labels)


def evaluate(label, pred):
    # v_measure_score equals NMI with arithmetic-mean normalization.
    nmi = v_measure_score(label, pred)
    ari = adjusted_rand_score(label, pred)
    acc = cluster_acc(label, pred)
    pur = purity(label, pred)
    return nmi, ari, acc, pur


def valid(model, device, dataset, view, data_size, class_num, eval_h=False, epoch=None):
    # A single batch holds the whole dataset.
    test_loader = DataLoader(
        dataset,
        batch_size=data_size,
        shuffle=False,
    )
    for batch_idx, (xs, y, _) in enumerate(test_loader):
        for v in range(view):
            xs[v] = xs[v].to(device)
        labels = y.cpu().numpy().squeeze()

    # inference
    with torch.no_grad():
        xrs, zs, rs, H = model(xs)

    if eval_h:
        kmeans = KMeans(n_clusters=class_num, n_init=100)
        print("Clustering results on low-level features of each view:")
        for v in range(view):
            y_pred = kmeans.fit_predict(zs[v].cpu().numpy())
            nmi, ari, acc, pur = evaluate(labels, y_pred)
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{}={:.4f}'.format(v + 1, acc,
                                                                                     v + 1, nmi,
                                                                                     v + 1, ari,
                                                                                     v + 1, pur))
        print("Clustering results on view-consensus features of each view:")
        for v in range(view):
            y_pred = kmeans.fit_predict(rs[v].cpu().numpy())
            nmi, ari, acc, pur = evaluate(labels, y_pred)
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{}={:.4f}'.format(v + 1, acc,
                                                                                     v + 1, nmi,
                                                                                     v + 1, ari,
                                                                                     v + 1, pur))

    # Clustering results on global features
    kmeans = KMeans(n_clusters=class_num, n_init=100)
    y_pred = kmeans.fit_predict(H.cpu().numpy())
    nmi, ari, acc, pur = evaluate(labels, y_pred)
    if epoch is not None:
        print('Epoch {}'.format(epoch), 'The clustering performance: ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR={:.4f}'.format(acc, nmi, ari, pur))
    else:
        print('The clustering performance: ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR={:.4f}'.format(acc, nmi, ari, pur))
    return acc, nmi, pur
--------------------------------------------------------------------------------
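`cluster_acc` in `metric.py` scores a clustering under the best one-to-one mapping between cluster ids and class labels. As a rough illustration of what that matching computes, the sketch below brute-forces every mapping instead of using the Hungarian algorithm. The name `cluster_acc_bruteforce` and the toy labels are hypothetical, and the approach is only practical for a handful of clusters.

```python
from itertools import permutations


def cluster_acc_bruteforce(y_true, y_pred):
    """Hypothetical helper: best accuracy over all cluster-to-class mappings."""
    assert len(y_true) == len(y_pred)
    # Same id range as metric.py: 0 .. max(label or cluster id).
    d = max(max(y_true), max(y_pred)) + 1
    best = 0
    for mapping in permutations(range(d)):
        # mapping[c] is the class label assigned to cluster c.
        hits = sum(1 for t, p in zip(y_true, y_pred) if mapping[p] == t)
        best = max(best, hits)
    return best / len(y_true)


# A pure relabelling of the classes still scores perfectly.
acc_perfect = cluster_acc_bruteforce([0, 0, 1, 1, 2, 2], [1, 1, 2, 2, 0, 0])
# One sample lands in the wrong cluster.
acc_partial = cluster_acc_bruteforce([0, 0, 1, 1], [0, 1, 1, 1])
```

The Hungarian solver reaches the same optimum without enumerating all `D!` mappings, which is why the repository uses `linear_sum_assignment`.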
/network.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn.functional import normalize
import torch


# Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2000),
            nn.ReLU(),
            nn.Linear(2000, feature_dim),
        )

    def forward(self, x):
        return self.encoder(x)


# Decoder
class Decoder(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(feature_dim, 2000),
            nn.ReLU(),
            nn.Linear(2000, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, input_dim)
        )

    def forward(self, x):
        return self.decoder(x)


# SCMVC Network
class Network(nn.Module):
    def __init__(self, view, input_size, feature_dim, high_feature_dim, device):
        super(Network, self).__init__()
        self.view = view
        # One autoencoder per view.
        encoders = []
        decoders = []
        for v in range(view):
            encoders.append(Encoder(input_size[v], feature_dim).to(device))
            decoders.append(Decoder(input_size[v], feature_dim).to(device))
        self.encoders = nn.ModuleList(encoders)
        self.decoders = nn.ModuleList(decoders)

        # global features fusion layer
        self.feature_fusion_module = nn.Sequential(
            nn.Linear(self.view * feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, high_feature_dim)
        )

        # view-consensus features learning layer
        self.common_information_module = nn.Sequential(
            nn.Linear(feature_dim, high_feature_dim)
        )

    # global feature fusion
    def feature_fusion(self, zs, zs_gradient):
        # Concatenate the latent features of all views; optionally stop gradients to the encoders.
        fused_input = torch.cat(zs, dim=1) if zs_gradient else torch.cat(zs, dim=1).detach()
        return normalize(self.feature_fusion_module(fused_input), dim=1)

    def forward(self, xs, zs_gradient=True):
        rs = []
        xrs = []
        zs = []
        for v in range(self.view):
            x = xs[v]
            z = self.encoders[v](x)      # latent features Z
            xr = self.decoders[v](z)     # reconstruction
            r = normalize(self.common_information_module(z), dim=1)  # view-consensus features R

            rs.append(r)
            zs.append(z)
            xrs.append(xr)

        H = self.feature_fusion(zs, zs_gradient)  # global features H
        return xrs, zs, rs, H
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# Python 3.7.13 (the interpreter version is not a pip package)
torch==1.12.0
numpy==1.21.5
scikit-learn==0.22.2.post1
scipy==1.7.3
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from network import Network
from metric import valid
import argparse
from dataloader import load_data

# MNIST-USPS
# BDGP
# CCV
# Fashion
# Caltech-2V
# Caltech-3V
# Caltech-4V
# Caltech-5V
# Cifar10
# Cifar100
# Prokaryotic
# Synthetic3d
Dataname = 'MNIST-USPS'
parser = argparse.ArgumentParser(description='test')
parser.add_argument('--dataset', default=Dataname)
# Explicit types so values passed on the command line are not left as strings.
parser.add_argument("--feature_dim", default=64, type=int)
# Named high_feature_dim to match train.py.
parser.add_argument("--high_feature_dim", default=20, type=int)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset, dims, view, data_size, class_num = load_data(args.dataset)
model = Network(view, dims, args.feature_dim, args.high_feature_dim, device)
model = model.to(device)

print("Dataset:{}".format(args.dataset))
print("Datasize:" + str(data_size))
print("Loading models...")
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load('./models/' + args.dataset + '.pth', map_location=device)
model.load_state_dict(checkpoint)

model.eval()
acc, nmi, pur = valid(model, device, dataset, view, data_size, class_num, eval_h=True)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from network import Network
from metric import valid
from torch.utils.data import DataLoader
import numpy as np
import argparse
import random
from loss import ContrastiveLoss
from dataloader import load_data

# MNIST-USPS
# BDGP
# CCV
# Fashion
# Caltech-2V
# Caltech-3V
# Caltech-4V
# Caltech-5V
# Cifar10
# Cifar100
# Prokaryotic
# Synthetic3d
Dataname = 'MNIST-USPS'
parser = argparse.ArgumentParser(description='train')
parser.add_argument('--dataset', default=Dataname)
parser.add_argument('--batch_size', default=256, type=int)
# Explicit types so values passed on the command line are not left as strings.
parser.add_argument("--learning_rate", default=0.0003, type=float)
parser.add_argument("--weight_decay", default=0., type=float)
parser.add_argument("--pre_epochs", default=200, type=int)
parser.add_argument("--con_epochs", default=50, type=int)
parser.add_argument("--feature_dim", default=64, type=int)
parser.add_argument("--high_feature_dim", default=20, type=int)
parser.add_argument("--temperature", default=1, type=float)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-dataset settings: (contrastive epochs, random seed).
# These override any --con_epochs value given on the command line.
DATASET_CONFIG = {
    "MNIST-USPS": (50, 10),
    "BDGP": (10, 30),          # 20
    "CCV": (50, 100),          # 100
    "Fashion": (50, 10),       # 100
    "Caltech-2V": (100, 200),
    "Caltech-3V": (100, 30),
    "Caltech-4V": (100, 100),
    "Caltech-5V": (100, 1000000),
    "Cifar10": (10, 10),
    "Cifar100": (20, 10),
    "Prokaryotic": (20, 10000),
    "Synthetic3d": (100, 100),
}
if args.dataset not in DATASET_CONFIG:
    raise NotImplementedError
args.con_epochs, seed = DATASET_CONFIG[args.dataset]
if args.dataset in ("CCV", "Caltech-2V"):
    # Set in the original script but not used anywhere below.
    args.tune_epochs = 200


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(seed)

dataset, dims, view, data_size, class_num = load_data(args.dataset)
# drop_last=True keeps every batch at exactly args.batch_size samples.
data_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
)


def compute_view_value(rs, H, view):
    N = H.shape[0]
    w = []
    # all features are normalized
    global_sim = torch.matmul(H, H.t())
    for v in range(view):
        view_sim = torch.matmul(rs[v], rs[v].t())
        related_sim = torch.matmul(rs[v], H.t())
        # The implementation of MMD
        w_v = (torch.sum(view_sim) + torch.sum(global_sim) - 2 * torch.sum(related_sim)) / (N * N)
        # A smaller discrepancy to the global features gives a larger weight.
        w.append(torch.exp(-w_v))
    w = torch.stack(w)
    # Normalize so the weights sum to 1.
    w = w / torch.sum(w)
    return w.squeeze()


def pretrain(epoch):
    # Reconstruction-only pre-training of the per-view autoencoders.
    tot_loss = 0.
    criterion = torch.nn.MSELoss()
    for batch_idx, (xs, _, _) in enumerate(data_loader):
        for v in range(view):
            xs[v] = xs[v].to(device)
        optimizer.zero_grad()
        xrs, _, _, _ = model(xs)
        loss_list = []
        for v in range(view):
            loss_list.append(criterion(xs[v], xrs[v]))
        loss = sum(loss_list)
        loss.backward()
        optimizer.step()
        tot_loss += loss.item()
    print('Epoch {}'.format(epoch), 'Loss:{:.6f}'.format(tot_loss / len(data_loader)))


def contrastive_train(epoch):
    tot_loss = 0.
    mse = torch.nn.MSELoss()
    for batch_idx, (xs, _, _) in enumerate(data_loader):
        for v in range(view):
            xs[v] = xs[v].to(device)
        optimizer.zero_grad()
        xrs, zs, rs, H = model(xs)
        loss_list = []

        # compute adaptive weights for each view
        with torch.no_grad():
            w = compute_view_value(rs, H, view)

        for v in range(view):
            # Self-weighted contrastive learning loss
            loss_list.append(contrastiveloss(H, rs[v], w[v]))
            # Reconstruction loss
            loss_list.append(mse(xs[v], xrs[v]))
        loss = sum(loss_list)
        loss.backward()
        optimizer.step()
        tot_loss += loss.item()
    print('Epoch {}'.format(epoch), 'Loss:{:.6f}'.format(tot_loss / len(data_loader)))


accs = []
nmis = []
purs = []
if not os.path.exists('./models'):
    os.makedirs('./models')
T = 1
for i in range(T):
    print("ROUND:{}".format(i + 1))
    setup_seed(seed)
    model = Network(view, dims, args.feature_dim, args.high_feature_dim, device)
    print(model)
    model = model.to(device)
    state = model.state_dict()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    contrastiveloss = ContrastiveLoss(args.batch_size, args.temperature, device).to(device)
    best_acc, best_nmi, best_pur = 0, 0, 0

    # Stage 1: reconstruction pre-training.
    epoch = 1
    while epoch <= args.pre_epochs:
        pretrain(epoch)
        epoch += 1
    # acc, nmi, pur = valid(model, device, dataset, view, data_size, class_num, eval_h=True, epoch=epoch)

    # Stage 2: self-weighted contrastive training; the checkpoint with the best ACC is saved.
    while epoch <= args.pre_epochs + args.con_epochs:
        contrastive_train(epoch)
        acc, nmi, pur = valid(model, device, dataset, view, data_size, class_num, eval_h=False, epoch=epoch)

        if acc > best_acc:
            best_acc, best_nmi, best_pur = acc, nmi, pur
            state = model.state_dict()
            torch.save(state, './models/' + args.dataset + '.pth')
        epoch += 1

    # The final result
    accs.append(best_acc)
    nmis.append(best_nmi)
    purs.append(best_pur)
    print('The best clustering performance: ACC = {:.4f} NMI = {:.4f} PUR={:.4f}'.format(best_acc, best_nmi, best_pur))
--------------------------------------------------------------------------------
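The self-weighting step in `compute_view_value` can be traced by hand on a toy batch. The sketch below is a pure-Python illustration under stated assumptions: the name `view_weights_sketch` and the toy vectors are hypothetical, and plain lists stand in for tensors. It follows the same formula: a discrepancy per view built from summed dot-product similarities, turned into `exp(-discrepancy)` and normalized across views.

```python
import math


def total_similarity(a_rows, b_rows):
    """Sum of dot products between every row of a_rows and every row of b_rows."""
    return sum(sum(x * y for x, y in zip(a, b)) for a in a_rows for b in b_rows)


def view_weights_sketch(rs, H):
    """Hypothetical mirror of compute_view_value for lists of row vectors."""
    n = len(H)
    global_sim = total_similarity(H, H)
    raw = []
    for R in rs:
        # Discrepancy between view-consensus features R and global features H.
        d = (total_similarity(R, R) + global_sim - 2 * total_similarity(R, H)) / (n * n)
        raw.append(math.exp(-d))
    norm = sum(raw)
    return [w / norm for w in raw]


# Toy setup: view 0 matches the global features exactly, view 1 points the opposite way.
H_toy = [[1.0, 0.0], [0.0, 1.0]]
views_toy = [[[1.0, 0.0], [0.0, 1.0]], [[-1.0, 0.0], [0.0, -1.0]]]
weights = view_weights_sketch(views_toy, H_toy)
```

With dot-product similarities, the discrepancy equals the squared distance between the mean of `R` and the mean of `H`, so in this toy case view 0 has discrepancy 0 and view 1 has discrepancy 2, and view 0 receives the larger weight.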