├── LICENSE
├── README.md
├── data
│   └── the_datasets.txt
├── dataloader.py
├── loss.py
├── metric.py
├── models
│   └── the_trained_models.txt
├── network.py
├── test.py
└── train.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Submissions in here

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Self-Weighted Contrastive Learning among Multiple Views for Mitigating Representation Degeneration

Prepare a multi-view/modal dataset, for example:

The format of a multi-view dataset ($N$ samples and $V$ views) should be $\{\mathbf{X}^1, \mathbf{X}^2, \dots, \mathbf{X}^v, \dots, \mathbf{X}^V, \mathbf{Y}\}$, where the $v$-th view data is $\mathbf{X}^v\in \mathbb{R}^{N\times d_v}$ and the class labels are $\mathbf{Y}\in \mathbb{R}^{N\times 1}$ (the labels are only used to evaluate the learned representations in the unsupervised setting). This format is suitable for the fully connected networks used here; for other data types, the model in "network.py" needs to be modified.

The public datasets and our trained models are available at **[Download](https://drive.google.com/drive/folders/1JBhb66b_z2wB4xWcuvrRvaINhgeCxiDS?usp=drive_link)** or **[the mirror for users in China](https://pan.baidu.com/s/1m8Vi3RShRMDUTjs-TZCiAQ?pwd=0928)**.
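For reference, a toy dataset in this format can be written to a `.mat` file as follows (a minimal sketch; the file name `Toy.mat`, the dimensions, and the random values are placeholders). The loaders in "dataloader.py" expect the keys `X1, ..., XV` and `Y`, with `Y` stored as $1\times N$ since they transpose it:

```python
import numpy as np
import scipy.io

N, d1, d2 = 1000, 64, 32                       # samples and per-view dimensions
X1 = np.random.rand(N, d1).astype(np.float32)  # view 1: N x d1
X2 = np.random.rand(N, d2).astype(np.float32)  # view 2: N x d2
Y = np.random.randint(0, 10, size=(1, N))      # labels (used for evaluation only)

scipy.io.savemat('data/Toy.mat', {'X1': X1, 'X2': X2, 'Y': Y})
```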
Requirements:

    python==3.7.11
    pytorch==1.9.0
    numpy==1.20.1
    scikit-learn==0.22.2.post1
    scipy==1.6.2

To test the trained model, run:
```bash
python test.py
```

To train a new model, run:
```bash
python train.py
```
--------------------------------------------------------------------------------
/data/the_datasets.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from torch.utils.data import Dataset
import scipy.io
import torch


class CCV(Dataset):
    """CCV: three views (STIP, SIFT, MFCC) stored as separate .npy files."""
    def __init__(self, path):
        self.data1 = np.load(path + 'STIP.npy').astype(np.float32)
        scaler = MinMaxScaler()
        self.data1 = scaler.fit_transform(self.data1)
        self.data2 = np.load(path + 'SIFT.npy').astype(np.float32)
        self.data3 = np.load(path + 'MFCC.npy').astype(np.float32)
        self.labels = np.load(path + 'label.npy')
        print(self.data1.shape)
        print(self.data2.shape)
        print(self.data3.shape)
        # scipy.io.savemat('CCV.mat', {'X1': self.data1, 'X2': self.data2, 'X3': self.data3, 'Y': self.labels})

    def __len__(self):
        return self.data1.shape[0]  # 6773 samples

    def __getitem__(self, idx):
        x1 = self.data1[idx]
        x2 = self.data2[idx]
        x3 = self.data3[idx]

        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()
class Caltech_6V(Dataset):
    """Caltech with up to six views in one .mat file (keys X1..X6 and Y)."""
    def __init__(self, path, view):
        data = scipy.io.loadmat(path)
        scaler = MinMaxScaler()
        self.view = view
        self.multi_view = []
        self.labels = data['Y'].T
        self.dims = []
        self.class_num = len(np.unique(self.labels))
        for i in range(view):
            # for i in [0, 3]:
            self.multi_view.append(scaler.fit_transform(data['X' + str(i + 1)].astype(np.float32)))
            print(data['X' + str(i + 1)].shape)
            self.dims.append(data['X' + str(i + 1)].shape[1])
        self.data_size = self.multi_view[0].shape[0]

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        data_getitem = []
        for i in range(self.view):
            data_getitem.append(torch.from_numpy(self.multi_view[i][idx]))
        return data_getitem, torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()


class NUSWIDE(Dataset):
    """NUS-WIDE; the last column of each view matrix is dropped here."""
    def __init__(self, path, view):
        data = scipy.io.loadmat(path)
        # scaler = MinMaxScaler()
        self.view = view
        self.multi_view = []
        self.labels = data['Y'].T
        self.dims = []
        self.class_num = len(np.unique(self.labels))
        for i in range(view):
            self.multi_view.append(data['X' + str(i + 1)][:, :-1].astype(np.float32))
            # self.multi_view.append(scaler.fit_transform(data['X' + str(i + 1)].astype(np.float32)))
            print(data['X' + str(i + 1)][:, :-1].shape)
            self.dims.append(data['X' + str(i + 1)][:, :-1].shape[1])
        self.data_size = self.multi_view[0].shape[0]

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        data_getitem = []
        for i in range(self.view):
            data_getitem.append(torch.from_numpy(self.multi_view[i][idx]))
        return data_getitem, torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()


class DHA(Dataset):
    def __init__(self, path, view):
        data = scipy.io.loadmat(path)
        self.view = view
        self.multi_view = []
        self.labels = data['Y'].T
        self.dims = []
        self.class_num = len(np.unique(self.labels))
        for i in range(view):
            self.multi_view.append(data['X' + str(i + 1)].astype(np.float32))
            print(data['X' + str(i + 1)].shape)
            self.dims.append(data['X' + str(i + 1)].shape[1])
        self.data_size = self.multi_view[0].shape[0]

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        data_getitem = []
        for i in range(self.view):
            data_getitem.append(torch.from_numpy(self.multi_view[i][idx]))
        return data_getitem, torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()


class YoutubeVideo(Dataset):
    def __init__(self, path, view):
        data = scipy.io.loadmat(path)
        # scaler = MinMaxScaler()
        self.view = view
        self.multi_view = []
        self.labels = data['Y'].T
        self.dims = []
        self.class_num = len(np.unique(self.labels))
        print(self.class_num)
        for i in range(view):
            self.multi_view.append(data['X' + str(i + 1)].astype(np.float32))
            # self.multi_view.append(scaler.fit_transform(data['X' + str(i + 1)].astype(np.float32)))
            print(data['X' + str(i + 1)].shape)
            self.dims.append(data['X' + str(i + 1)].shape[1])

        self.data_size = self.multi_view[0].shape[0]

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        data_getitem = []
        for i in range(self.view):
            data_getitem.append(torch.from_numpy(self.multi_view[i][idx]))
        return data_getitem, torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()


def load_data(dataset):
    # Returns: (dataset, per-view dimensions, number of views, number of samples, number of classes)
    if dataset == "CCV":
        dataset = CCV('./data/')
        dims = [5000, 5000, 4000]
        view = 3
        data_size = 6773
        class_num = 20
    elif dataset == "Caltech":
        dataset = Caltech_6V('data/Caltech.mat', view=6)
        dims = dataset.dims
        view = dataset.view
        data_size = dataset.data_size
        class_num = dataset.class_num
    elif dataset == "NUSWIDE":
        dataset = NUSWIDE('data/NUSWIDE.mat', view=5)
        dims = dataset.dims
        view = dataset.view
        data_size = dataset.data_size
        class_num = dataset.class_num
    elif dataset == "DHA":
        dataset = DHA('data/DHA.mat', view=2)
        dims = dataset.dims
        view = dataset.view
        data_size = dataset.data_size
        class_num = dataset.class_num
    elif dataset == "YoutubeVideo":
        dataset = YoutubeVideo("./data/Video-3V.mat", view=3)
        dims = dataset.dims
        view = dataset.view
        data_size = dataset.data_size
        class_num = dataset.class_num
    else:
        raise NotImplementedError
    return dataset, dims, view, data_size, class_num
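
# Minimal smoke test (a sketch; assumes data/Caltech.mat has been downloaded
# into ./data/ as described in the README).
if __name__ == '__main__':
    dataset, dims, view, data_size, class_num = load_data("Caltech")
    xs, y, idx = dataset[0]
    print(view, dims, data_size, class_num, [x.shape for x in xs])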
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class Loss(nn.Module):
    def __init__(self, batch_size, class_num, temperature_f, device):
        super(Loss, self).__init__()
        self.batch_size = batch_size
        self.class_num = class_num
        self.temperature_f = temperature_f
        # self.temperature_l = temperature_l
        self.device = device

        self.mask = self.mask_correlated_samples(batch_size)
        # self.similarity = nn.CosineSimilarity(dim=2)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, N):
        # Mask the diagonal and the two positive positions (i, i + N/2) of each
        # sample, so that only the negative pairs remain.
        mask = torch.ones((N, N))
        mask = mask.fill_diagonal_(0)
        for i in range(N // 2):
            mask[i, N // 2 + i] = 0
            mask[N // 2 + i, i] = 0
        mask = mask.bool()
        return mask

    def forward_feature_InfoNCE(self, h_i, h_j, batch_size=256):
        self.batch_size = batch_size

        N = 2 * self.batch_size
        h = torch.cat((h_i, h_j), dim=0)

        sim = torch.matmul(h, h.T) / self.temperature_f
        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        mask = self.mask_correlated_samples(N)
        negative_samples = sim[mask].reshape(N, -1)

        # Each row is [positive logit | negative logits]; the target class is 0.
        labels = torch.zeros(N).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss

    def forward_feature_PSCL(self, z1, z2, r=3.0):  # r=3.0
        # Clip embeddings that lie outside the L2 ball of radius sqrt(r).
        mask1 = (torch.norm(z1, p=2, dim=1) < np.sqrt(r)).float().unsqueeze(1)
        mask2 = (torch.norm(z2, p=2, dim=1) < np.sqrt(r)).float().unsqueeze(1)
        z1 = mask1 * z1 + (1 - mask1) * F.normalize(z1, dim=1) * np.sqrt(r)
        z2 = mask2 * z2 + (1 - mask2) * F.normalize(z2, dim=1) * np.sqrt(r)
        # Alignment term over paired views, plus a term over the off-diagonal pairs.
        loss_part1 = -2 * torch.mean(z1 * z2) * z1.shape[1]
        square_term = torch.matmul(z1, z2.T) ** 2
        loss_part2 = torch.mean(torch.triu(square_term, diagonal=1) + torch.tril(square_term, diagonal=-1)) * \
            z1.shape[0] / (z1.shape[0] - 1)

        return loss_part1 + loss_part2

    def forward_feature_RINCE(self, out_1, out_2, lam=0.001, q=0.5, temperature=0.5, batch_size=256):
        """
        Robust InfoNCE; assumes out_1 and out_2 are normalized.
        out_1: [batch_size, dim]
        out_2: [batch_size, dim]
        lam, q: robustness hyperparameters; temperature: softmax temperature
        """
        # # gather representations in case of distributed training
        # # out_1_dist: [batch_size * world_size, dim]
        # # out_2_dist: [batch_size * world_size, dim]
        # if torch.distributed.is_available() and torch.distributed.is_initialized():
        #     out_1_dist = SyncFunction.apply(out_1)
        #     out_2_dist = SyncFunction.apply(out_2)
        # else:
        self.batch_size = batch_size

        out_1_dist = out_1
        out_2_dist = out_2

        # out: [2 * batch_size, dim]
        # out_dist: [2 * batch_size * world_size, dim]
        out = torch.cat([out_1, out_2], dim=0)
        out_dist = torch.cat([out_1_dist, out_2_dist], dim=0)

        similarity = torch.exp(torch.mm(out, out_dist.t()) / temperature)
        # neg_mask = self.compute_neg_mask()
        N = 2 * self.batch_size
        neg_mask = self.mask_correlated_samples(N)
        neg = torch.sum(similarity * neg_mask.to(self.device), 1)

        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        pos = torch.cat([pos, pos], dim=0)

        # InfoNCE loss (for reference):
        # loss = -(torch.mean(torch.log(pos / (pos + neg))))

        # RINCE loss
        neg = ((lam * (pos + neg)) ** q) / q
        pos = -(pos ** q) / q
        loss = pos.mean() + neg.mean()

        return loss
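
# Minimal smoke test (a sketch with arbitrary values): checks that the three
# contrastive losses run on random normalized features on the CPU.
if __name__ == '__main__':
    torch.manual_seed(0)
    crit = Loss(batch_size=8, class_num=10, temperature_f=1.0, device='cpu')
    z1 = F.normalize(torch.randn(8, 16), dim=1)
    z2 = F.normalize(torch.randn(8, 16), dim=1)
    print(crit.forward_feature_InfoNCE(z1, z2, batch_size=8).item())
    print(crit.forward_feature_PSCL(z1, z2).item())
    print(crit.forward_feature_RINCE(z1, z2, batch_size=8).item())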
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
from sklearn.metrics import v_measure_score, adjusted_rand_score, accuracy_score, normalized_mutual_info_score
from sklearn.cluster import KMeans, MiniBatchKMeans
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
import numpy as np
import torch
import torch.nn.functional as F


def cluster_acc(y_true, y_pred):
    # Clustering accuracy: find the best cluster-to-class assignment with the
    # Hungarian algorithm, then measure accuracy under that assignment.
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    u = linear_sum_assignment(w.max() - w)
    ind = np.concatenate([u[0].reshape(u[0].shape[0], 1), u[1].reshape([u[0].shape[0], 1])], axis=1)
    return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size


def purity(y_true, y_pred):
    y_true = y_true.copy()  # relabel a copy so the caller's array is untouched
    y_voted_labels = np.zeros(y_true.shape)
    labels = np.unique(y_true)
    ordered_labels = np.arange(labels.shape[0])
    for k in range(labels.shape[0]):
        y_true[y_true == labels[k]] = ordered_labels[k]
    labels = np.unique(y_true)
    bins = np.concatenate((labels, [np.max(labels) + 1]), axis=0)

    for cluster in np.unique(y_pred):
        hist, _ = np.histogram(y_true[y_pred == cluster], bins=bins)
        winner = np.argmax(hist)
        y_voted_labels[y_pred == cluster] = winner

    return accuracy_score(y_true, y_voted_labels)


def evaluate(label, pred):
    # v_measure = v_measure_score(label, pred)
    nmi = normalized_mutual_info_score(label, pred)
    ari = adjusted_rand_score(label, pred)
    acc = cluster_acc(label, pred)
    pur = purity(label, pred)
    return nmi, ari, acc, pur


def inference(loader, model, device, view, data_size):
    """
    Collect, per view: the inputs Xs, the high-level features Zs, and the
    low-level features Hs. (pred_vectors and Qs are placeholders kept for
    interface compatibility and returned empty.)
    """
    model.eval()
    pred_vectors = []
    Xs = []
    Zs = []
    Hs = []
    Qs = []
    for v in range(view):
        pred_vectors.append([])
        Xs.append([])
        Zs.append([])
        Hs.append([])
        Qs.append([])
    labels_vector = []

    for step, (xs, y, _) in enumerate(loader):
        for v in range(view):
            xs[v] = xs[v].to(device)
        with torch.no_grad():
            zs, _, _, hs, _ = model.forward(xs)
        for v in range(view):
            zs[v] = zs[v].detach()
            hs[v] = hs[v].detach()
            Xs[v].extend(xs[v].cpu().detach().numpy())
            Zs[v].extend(zs[v].cpu().detach().numpy())
            Hs[v].extend(hs[v].cpu().detach().numpy())
        labels_vector.extend(y.numpy())

    labels_vector = np.array(labels_vector).reshape(data_size)
    for v in range(view):
        Xs[v] = np.array(Xs[v])
        Zs[v] = np.array(Zs[v])
        Hs[v] = np.array(Hs[v])
        Qs[v] = np.array(Qs[v])
        pred_vectors[v] = np.array(pred_vectors[v])
    return Xs, [], Zs, labels_vector, Hs


def js_div(p_output, q_output, get_softmax=True):
    """
    Measure the JS divergence between two batches of outputs; softmax is
    applied first when get_softmax is True (i.e., the inputs are logits).
    """
    KLDivLoss = torch.nn.KLDivLoss(reduction='batchmean')
    if get_softmax:
        p_output = F.softmax(p_output, dim=1)
        q_output = F.softmax(q_output, dim=1)
    log_mean_output = ((p_output + q_output) / 2).log()
    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output)) / 2
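
# Illustrative example: for identical inputs the divergence vanishes, e.g.
#   p = torch.randn(4, 3)
#   js_div(p, p).item()  # ~0.0 (up to numerical error)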

def gaussian_kernel_mmd(source, target, kernel_mul=2, kernel_num=4, fix_sigma=None):
    """Multi-kernel Gram matrix for MMD.
    source: sample_size_1 * feature_size
    target: sample_size_2 * feature_size
    kernel_mul: multiplicative step between kernel bandwidths
    kernel_num: number of kernels
    return: (sample_size_1 + sample_size_2) * (sample_size_1 + sample_size_2)
            [ K_ss K_st
              K_ts K_tt ]
    """
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)

    # Pairwise squared Euclidean distances between all rows of `total`.
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0 - total1) ** 2).sum(2)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Heuristic bandwidth: mean pairwise distance over off-diagonal pairs.
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]

    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]

    return sum(kernel_val)


def MMD(source, target, kernel_mul=2, kernel_num=4, fix_sigma=None):
    n = int(source.size()[0])
    m = int(target.size()[0])

    kernels = gaussian_kernel_mmd(source, target,
                                  kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    XX = kernels[:n, :n]
    YY = kernels[n:, n:]
    XY = kernels[:n, n:]  # cross blocks split at n (the original m-based indexing only worked when n == m)
    YX = kernels[n:, :n]

    XX = torch.div(XX, n * n).sum(dim=1).view(1, -1)   # K_ss, Source<->Source
    XY = torch.div(XY, -n * m).sum(dim=1).view(1, -1)  # K_st, Source<->Target

    YX = torch.div(YX, -m * n).sum(dim=1).view(1, -1)  # K_ts, Target<->Source
    YY = torch.div(YY, m * m).sum(dim=1).view(1, -1)   # K_tt, Target<->Target

    loss = XX.sum() + XY.sum() + YX.sum() + YY.sum()
    return loss
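
# Illustrative example: the MMD between two samples drawn from the same
# distribution is close to 0 and grows as the distributions separate, e.g.
#   a = torch.randn(100, 8); b = torch.randn(100, 8) + 3.0
#   MMD(a, torch.randn(100, 8)).item()  # small
#   MMD(a, b).item()                    # clearly larger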

def valid(model, device, dataset, view, data_size, class_num, eval_h=True, eval_z=True,
          times_for_K=1.0, Measure='CMI', test=True, sample_num=1000):
    test_loader = DataLoader(dataset, batch_size=256, shuffle=False)
    X_vectors, pred_vectors, z_vectors, labels_vector, h_vectors = inference(test_loader, model, device, view, data_size)
    final_z_features = []
    h_clusters = []
    z_clusters = []
    nmi_matrix_h = np.zeros((view, view))
    nmi_matrix_z = np.zeros((view, view))

    if eval_h and Measure == 'CMI':
        print("Clustering results on each view (H^v):")
        acc_avg, nmi_avg, ari_avg, pur_avg = 0, 0, 0, 0
        for v in range(view):
            kmeans = KMeans(n_clusters=int(class_num * times_for_K), n_init=100)
            if len(labels_vector) > 10000:
                kmeans = MiniBatchKMeans(n_clusters=int(class_num * times_for_K), batch_size=5000, n_init=100)
            y_pred = kmeans.fit_predict(h_vectors[v])
            h_clusters.append(y_pred)
            nmi, ari, acc, pur = evaluate(labels_vector, y_pred)
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{}={:.4f}'.format(v + 1, acc,
                                                                                     v + 1, nmi,
                                                                                     v + 1, ari,
                                                                                     v + 1, pur))
            acc_avg += acc
            nmi_avg += nmi
            ari_avg += ari
            pur_avg += pur

        print('Mean = {:.4f} Mean = {:.4f} Mean = {:.4f} Mean={:.4f}'.format(acc_avg / view,
                                                                             nmi_avg / view,
                                                                             ari_avg / view,
                                                                             pur_avg / view))
        kmeans = KMeans(n_clusters=class_num, n_init=100)
        if len(labels_vector) > 10000:
            kmeans = MiniBatchKMeans(n_clusters=int(class_num), batch_size=5000, n_init=100)
        z = np.concatenate(h_vectors, axis=1)
        pseudo_label = kmeans.fit_predict(z)
        print("Clustering results on all views ([H^1...H^V]): " + str(labels_vector.shape[0]))
        nmi, ari, acc, pur = evaluate(labels_vector, pseudo_label)
        print('ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR={:.4f}'.format(acc, nmi, ari, pur))
        # Pairwise cluster mutual information between views, mapped monotonically via exp(x) - 1.
        for i in range(view):
            for j in range(view):
                if Measure == 'CMI':
                    cnmi, _, _, _ = evaluate(h_clusters[i], h_clusters[j])
                    nmi_matrix_h[i][j] = np.exp(cnmi) - 1
        print(nmi_matrix_h)

    if eval_h and Measure != 'CMI':
        for i in range(view):
            for j in range(view):
                if Measure == 'JSD':
                    P = torch.tensor(h_vectors[i])
                    Q = torch.tensor(h_vectors[j])
                    divergence = js_div(P, Q).item()
                    nmi_matrix_h[i][j] = np.exp(1 - divergence) - 1
                if Measure == 'MMD':
                    # MMD is expensive; subsample when the dataset is large.
                    if len(labels_vector) > sample_num:
                        P = torch.tensor(h_vectors[i][0: sample_num])
                        Q = torch.tensor(h_vectors[j][0: sample_num])
                    else:
                        P = torch.tensor(h_vectors[i])
                        Q = torch.tensor(h_vectors[j])
                    mmd = MMD(P, Q, kernel_mul=4, kernel_num=4)
                    nmi_matrix_h[i][j] = np.exp(-mmd)
        print(nmi_matrix_h)

    if eval_z and Measure == 'CMI':
        print("Clustering results on each view (Z^v):")
        acc_avg, nmi_avg, ari_avg, pur_avg = 0, 0, 0, 0
        for v in range(view):
            kmeans = KMeans(n_clusters=int(class_num * times_for_K), n_init=100)
            if len(labels_vector) > 10000:
                kmeans = MiniBatchKMeans(n_clusters=int(class_num * times_for_K), batch_size=5000, n_init=100)
            y_pred = kmeans.fit_predict(z_vectors[v])
            final_z_features.append(z_vectors[v])
            z_clusters.append(y_pred)
            nmi, ari, acc, pur = evaluate(labels_vector, y_pred)
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{}={:.4f}'.format(v + 1, acc,
                                                                                     v + 1, nmi,
                                                                                     v + 1, ari,
                                                                                     v + 1, pur))
            acc_avg += acc
            nmi_avg += nmi
            ari_avg += ari
            pur_avg += pur

        print('Mean = {:.4f} Mean = {:.4f} Mean = {:.4f} Mean={:.4f}'.format(acc_avg / view,
                                                                             nmi_avg / view,
                                                                             ari_avg / view,
                                                                             pur_avg / view))
        kmeans = KMeans(n_clusters=class_num, n_init=100)
        if len(labels_vector) > 10000:
            kmeans = MiniBatchKMeans(n_clusters=int(class_num), batch_size=5000, n_init=100)
        h = np.concatenate(final_z_features, axis=1)
        pseudo_label = kmeans.fit_predict(h)
        print("Clustering results on all views ([Z^1...Z^V]): " + str(labels_vector.shape[0]))
        nmi, ari, acc, pur = evaluate(labels_vector, pseudo_label)
        print('ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR={:.4f}'.format(acc, nmi, ari, pur))
        for i in range(view):
            for j in range(view):
                if Measure == 'CMI':
                    cnmi, _, _, _ = evaluate(z_clusters[i], z_clusters[j])
                    nmi_matrix_z[i][j] = np.exp(cnmi) - 1
        print(nmi_matrix_z)
    if eval_z and Measure != 'CMI':
        for i in range(view):
            for j in range(view):
                if Measure == 'JSD':
                    P = torch.tensor(z_vectors[i])
                    Q = torch.tensor(z_vectors[j])
                    divergence = js_div(P, Q).item()
                    nmi_matrix_z[i][j] = np.exp(1 - divergence) - 1
                if Measure == 'MMD':
                    if len(labels_vector) > sample_num:
                        P = torch.tensor(z_vectors[i][0: sample_num])
                        Q = torch.tensor(z_vectors[j][0: sample_num])
                    else:
                        P = torch.tensor(z_vectors[i])
                        Q = torch.tensor(z_vectors[j])
                    mmd = MMD(P, Q, kernel_mul=4, kernel_num=4)
                    nmi_matrix_z[i][j] = np.exp(-mmd)

        print(nmi_matrix_z)

    if test or Measure != 'CMI':
        kmeans = KMeans(n_clusters=class_num, n_init=100)
        if len(labels_vector) > 10000:
            kmeans = MiniBatchKMeans(n_clusters=int(class_num), batch_size=5000, n_init=100)
        h = np.concatenate(z_vectors, axis=1)
        pseudo_label = kmeans.fit_predict(h)
        print("Clustering results on all views ([Z^1...Z^V]): " + str(labels_vector.shape[0]))
        nmi, ari, acc, pur = evaluate(labels_vector, pseudo_label)
        print('ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR={:.4f}'.format(acc, nmi, ari, pur))

    return acc, nmi, ari, pur, nmi_matrix_h, nmi_matrix_z
--------------------------------------------------------------------------------
/models/the_trained_models.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn.functional import normalize


class Encoder(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2000),
            nn.ReLU(),
            nn.Linear(2000, feature_dim),
        )

    def forward(self, x):
        return self.encoder(x)


class Decoder(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(feature_dim, 2000),
            nn.ReLU(),
            nn.Linear(2000, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, input_dim)
        )

    def forward(self, x):
        return self.decoder(x)


class Network(nn.Module):
    # class_num and device are kept in the signature for interface
    # compatibility; they are not used inside this module.
    def __init__(self, view, input_size, feature_dim, high_feature_dim, class_num, device):
        super(Network, self).__init__()
        self.view = view
        self.encoders = []
        self.decoders = []
        self.feature_contrastive_modules = []
        for v in range(view):
            self.encoders.append(Encoder(input_size[v], feature_dim))
            self.decoders.append(Decoder(input_size[v], feature_dim))
            self.feature_contrastive_modules.append(
                nn.Sequential(
                    nn.Linear(feature_dim, high_feature_dim),
                    # a deeper projection head is also possible:
                    # nn.Linear(feature_dim, feature_dim),
                    # nn.ReLU(),
                    # nn.Linear(feature_dim, high_feature_dim),
                )
            )
        self.encoders = nn.ModuleList(self.encoders)
        self.decoders = nn.ModuleList(self.decoders)
        self.feature_contrastive_modules = nn.ModuleList(self.feature_contrastive_modules)

    def forward(self, xs):
        hs = []   # low-level features H^v (encoder outputs)
        qs = []   # placeholder, kept for interface compatibility
        xrs = []  # reconstructions
        zs = []   # high-level features Z^v (normalized projections)
        for v in range(self.view):
            x = xs[v]
            h = self.encoders[v](x)
            z = normalize(self.feature_contrastive_modules[v](h), dim=1)
            xr = self.decoders[v](h)
            hs.append(h)
            zs.append(z)
            xrs.append(xr)
        return zs, qs, xrs, hs, []
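
# Minimal smoke test (a sketch with arbitrary shapes): forwards two random
# views through the network and prints the output shapes.
if __name__ == '__main__':
    import torch
    net = Network(view=2, input_size=[20, 30], feature_dim=512,
                  high_feature_dim=128, class_num=10, device='cpu')
    xs = [torch.randn(4, 20), torch.randn(4, 30)]
    zs, qs, xrs, hs, _ = net(xs)
    print([z.shape for z in zs], [xr.shape for xr in xrs])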
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from network import Network
from metric import valid
import argparse
from dataloader import load_data
from sklearn.metrics import accuracy_score, f1_score, recall_score

test_train = 0.7  # fraction of samples used as the SVM test split
# Dataname = 'DHA'
# Dataname = 'CCV'
# Dataname = 'NUSWIDE'
Dataname = 'Caltech'
# Dataname = 'YoutubeVideo'

CL_Loss = ['InfoNCE', 'PSCL', 'RINCE']
Measure_M_N = ['CMI', 'JSD', 'MMD']
Reconstruction = ['AE', 'DAE', 'MAE']
parser = argparse.ArgumentParser(description='test')
parser.add_argument('--dataset', default=Dataname)
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument("--temperature_f", default=1.0)
parser.add_argument("--contrastive_loss", default=CL_Loss[0])
parser.add_argument("--measurement", default=Measure_M_N[0])
parser.add_argument("--Recon", default=Reconstruction[0])
parser.add_argument("--bi_level_iteration", default=4)
parser.add_argument("--times_for_K", default=1)
parser.add_argument("--Lambda", default=1)
parser.add_argument("--learning_rate", default=0.0003)
parser.add_argument("--weight_decay", default=0.)
parser.add_argument("--workers", default=8)
parser.add_argument("--mse_epochs", default=100)
parser.add_argument("--con_epochs", default=100)
parser.add_argument("--feature_dim", default=512)
parser.add_argument("--high_feature_dim", default=128)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


import numpy as np
dataset, dims, view, data_size, class_num = load_data(args.dataset)
model = Network(view, dims, args.feature_dim, args.high_feature_dim, class_num, device)
model = model.to(device)
checkpoint = torch.load('./models/' + args.dataset + '.pth')
model.load_state_dict(checkpoint)
print("Dataset:{}".format(args.dataset))
print("Datasize:" + str(data_size))
print("Loading model...")

valid(model, device, dataset, view, data_size, class_num, eval_h=False, eval_z=False, test=True)

if Dataname == 'YoutubeVideo':
    exit(0)

from torch.utils.data import DataLoader
from metric import inference
from sklearn import svm
from sklearn import model_selection

np.random.seed(80)


def SVM_Classification(x, y, seed=1, test_r=0.3):
    # Note: `seed` only takes effect with the commented-out call below; the
    # active call leaves the train/test split unseeded.
    # data_train, data_test, tag_train, tag_test = model_selection.train_test_split(x, y, random_state=seed, test_size=test_r)
    data_train, data_test, tag_train, tag_test = model_selection.train_test_split(x, y, test_size=test_r)

    def classifier():
        clf = svm.SVC(C=1,
                      kernel='linear',
                      decision_function_shape='ovr')
        return clf

    clf = classifier()

    def train(clf, x_train, y_train):
        clf.fit(x_train, y_train.ravel())

    train(clf, data_train, tag_train)

    def print_accuracy(clf, x_train, y_train, x_test, y_test):
        y_pre = clf.predict(x_test)
        acc = accuracy_score(y_test, y_pre)
        # f1 = f1_score(y_test, y_pre, average='macro')
        recall = recall_score(y_test, y_pre, average='macro')
        return acc, recall

    acc, recall = print_accuracy(clf, data_train, tag_train, data_test, tag_test)

    return acc, recall

test_loader = DataLoader(dataset, batch_size=256, shuffle=False)
X_vectors, pred_vectors, high_level_vectors, labels_vector, low_level_vectors = inference(test_loader, model, device, view, data_size)
ACC_H = []
ACC_ALLH = []
REC_ALLH = []
ALLH = np.concatenate(high_level_vectors, axis=1)
for seed in range(10):
    y = labels_vector
    acc, recall = SVM_Classification(ALLH, y, seed=seed, test_r=test_train)
    ACC_ALLH.append(acc / 0.01)     # store as percentages
    REC_ALLH.append(recall / 0.01)

print(ACC_ALLH, np.mean(ACC_ALLH), np.std(ACC_ALLH))
print(REC_ALLH, np.mean(REC_ALLH), np.std(REC_ALLH))
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
from network import Network
from metric import valid
from torch.utils.data import Dataset
import numpy as np
import argparse
from loss import Loss
from dataloader import load_data
import os
import time

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Dataname = 'DHA'
# Dataname = 'CCV'
# Dataname = 'NUSWIDE'
Dataname = 'Caltech'
# Dataname = 'YoutubeVideo'

CL_Loss = ['InfoNCE', 'PSCL', 'RINCE']  # three kinds of contrastive losses
Measure_M_N = ['CMI', 'JSD', 'MMD']  # Class Mutual Information (CMI), Jensen–Shannon Divergence (JSD), Maximum Mean Discrepancy (MMD)
sample_mmd = 2000  # compute MMD on a subset of samples, since its cost grows quadratically and can run out of memory
Reconstruction = ['AE', 'DAE', 'MAE']  # autoencoder (AE), denoising autoencoder (DAE), masked autoencoder (MAE)
per = 0.3  # the ratio of masked entries for the masked AE, e.g., 30%

parser = argparse.ArgumentParser(description='train')
parser.add_argument('--dataset', default=Dataname)
parser.add_argument('--batch_size', default=256, type=int)  # 256
parser.add_argument("--temperature_f", default=1.0)  # 1.0
parser.add_argument("--contrastive_loss", default=CL_Loss[0])  # 0, 1, 2
parser.add_argument("--measurement", default=Measure_M_N[0])  # 0, 1, 2
parser.add_argument("--Recon", default=Reconstruction[0])  # 0, 1, 2
parser.add_argument("--bi_level_iteration", default=4)  # 4
parser.add_argument("--times_for_K", default=1)  # 0.5 1 2 4
parser.add_argument("--Lambda", default=1)  # 0.001 0.01 0.1 1 10 100 1000
parser.add_argument("--learning_rate", default=0.0003)  # 0.0003
parser.add_argument("--weight_decay", default=0.)  # 0.
parser.add_argument("--workers", default=8)  # 8
parser.add_argument("--mse_epochs", default=100)  # 100
parser.add_argument("--con_epochs", default=100)  # 100
parser.add_argument("--feature_dim", default=512)  # 512
parser.add_argument("--high_feature_dim", default=128)  # 128
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('SEM + ' + args.contrastive_loss + ' + ' + args.measurement + ' + ' + args.Recon)

if args.dataset == "DHA":
    args.con_epochs = 300
    args.bi_level_iteration = 1

if args.dataset == "CCV":
    args.con_epochs = 50
    args.bi_level_iteration = 4

if args.dataset == "YoutubeVideo":
    args.con_epochs = 25
    args.bi_level_iteration = 1

if args.dataset == "NUSWIDE":
    args.con_epochs = 25
    args.bi_level_iteration = 4

if args.dataset == "Caltech":
    args.con_epochs = 100
    # args.bi_level_iteration = 4  # alternative setting
    args.bi_level_iteration = 3

Total_con_epochs = args.con_epochs * args.bi_level_iteration


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    torch.backends.cudnn.deterministic = True


accs = []
nmis = []
aris = []
purs = []
ACC_tmp = 0

for Runs in range(1):  # 10
    print("ROUND:{}".format(Runs + 1))

    t1 = time.time()
    # setup_seed(5)  # if the network initialization turns out to be sensitive, set a seed for stable performance
    dataset, dims, view, data_size, class_num = load_data(args.dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # drop_last=True,
        drop_last=False,
    )
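    # Training pipeline (overview):
    #   1) Low_level_rec_train pre-trains the autoencoders with a reconstruction
    #      loss (AE / DAE / MAE variant) for mse_epochs.
    #   2) High_level_contrastive_train performs self-weighted contrastive
    #      learning: the pairwise loss between views v and w is scaled by
    #      nmi_matrix[v][w], and the weight matrix is re-estimated by valid()
    #      after every con_epochs (one bi-level iteration).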

    def Low_level_rec_train(epoch, rec='AE', p=0.3, mask_ones_full=[], mask_ones_not_full=[]):
        tot_loss = 0.
        criterion = torch.nn.MSELoss()
        Vones_full = []
        Vones_not_full = []
        flag_full = 0
        flag_not_full = 0
        for batch_idx, (xs, _, _) in enumerate(data_loader):
            for v in range(view):
                xs[v] = xs[v].to(device)

            xnum = xs[0].shape[0]

            if rec == 'AE':
                optimizer.zero_grad()
                _, _, xrs, _, _ = model(xs)
            if rec == 'DAE':
                # Add Gaussian noise to the inputs; reconstruct the clean data.
                noise_x = []
                for v in range(view):
                    noise = torch.randn(xs[v].shape).to(device)
                    noise = noise + xs[v]
                    noise_x.append(noise)
                optimizer.zero_grad()
                _, _, xrs, _, _ = model(noise_x)
            if rec == 'MAE':
                # Mask a fraction p of the input entries; reconstruct the clean
                # data. One 0/1 template per view and batch shape is built in
                # epoch 1 and re-shuffled for every batch afterwards.
                noise_x = []
                for v in range(view):
                    if xnum == args.batch_size and flag_full == 0 and epoch == 1:
                        num = xs[v].shape[0] * xs[v].shape[1]
                        ones = torch.ones([1, num]).to(device)
                        ones[0, :int(num * p)] = 0
                        Vones_full.append(ones)
                    if xnum != args.batch_size and flag_not_full == 0 and epoch == 1:
                        num = xs[v].shape[0] * xs[v].shape[1]
                        ones = torch.ones([1, num]).to(device)
                        ones[0, :int(num * p)] = 0
                        Vones_not_full.append(ones)

                    if xnum == args.batch_size and epoch == 1:
                        noise = Vones_full[v][:, torch.randperm(Vones_full[v].size(1))]
                    if xnum != args.batch_size and epoch == 1:
                        noise = Vones_not_full[v][:, torch.randperm(Vones_not_full[v].size(1))]

                    if xnum == args.batch_size and epoch != 1:
                        noise = mask_ones_full[v][:, torch.randperm(mask_ones_full[v].size(1))]
                    if xnum != args.batch_size and epoch != 1:
                        noise = mask_ones_not_full[v][:, torch.randperm(mask_ones_not_full[v].size(1))]
                    noise = torch.reshape(noise, xs[v].shape)
                    noise = noise * xs[v]
                    noise_x.append(noise)

                if xnum == args.batch_size:
                    flag_full = 1
                else:
                    flag_not_full = 1

                optimizer.zero_grad()
                _, _, xrs, _, _ = model(noise_x)

            loss_list = []
            for v in range(view):
                loss_list.append(criterion(xs[v], xrs[v]))
            loss = sum(loss_list)
            loss.backward()
            optimizer.step()
            tot_loss += loss.item()
        # print('Epoch {}'.format(epoch), 'Loss:{:.6f}'.format(tot_loss / len(data_loader)))
        return Vones_full, Vones_not_full
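    # Illustrative sketch (hypothetical helper, not called by the training
    # functions): the MAE masking above builds one 0/1 template per view with
    # a fraction p of zero entries, then applies a fresh random permutation of
    # that template to every batch before multiplying it into the inputs.
    def make_mask_template(num_entries, p, device):
        ones = torch.ones([1, num_entries]).to(device)
        ones[0, :int(num_entries * p)] = 0
        return ones  # per batch: ones[:, torch.randperm(ones.size(1))]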

    def High_level_contrastive_train(epoch, nmi_matrix, Lambda=1.0, rec='AE', p=0.3, mask_ones_full=[], mask_ones_not_full=[]):
        tot_loss = 0.
        mes = torch.nn.MSELoss()
        record_loss_con = []
        Vones_full = []
        Vones_not_full = []
        flag_full = 0
        flag_not_full = 0

        for v in range(view):
            record_loss_con.append([])
            for w in range(view):
                record_loss_con[v].append([])

        # Sim = 0
        # cos = torch.nn.CosineSimilarity(dim=0)

        for batch_idx, (xs, _, _) in enumerate(data_loader):
            for v in range(view):
                xs[v] = xs[v].to(device)

            optimizer.zero_grad()
            zs, qs, xrs, hs, re_h = model(xs)
            loss_list = []

            xnum = xs[0].shape[0]
            # ------------------------
            # P = zs[0]
            # Q = zs[1]
            # for i in range(xnum):
            #     Sim += cos(P[i], Q[i]).item()
            # ------------------------
            if rec == 'DAE':
                noise_x = []
                for v in range(view):
                    noise = torch.randn(xs[v].shape).to(device)
                    noise = noise + xs[v]
                    noise_x.append(noise)
                optimizer.zero_grad()
                _, _, xrs, _, _ = model(noise_x)
            if rec == 'MAE':
                # Same masking scheme as in Low_level_rec_train.
                noise_x = []
                for v in range(view):
                    if xnum == args.batch_size and flag_full == 0 and epoch == 1:
                        num = xs[v].shape[0] * xs[v].shape[1]
                        ones = torch.ones([1, num]).to(device)
                        ones[0, :int(num * p)] = 0
                        Vones_full.append(ones)
                    if xnum != args.batch_size and flag_not_full == 0 and epoch == 1:
                        num = xs[v].shape[0] * xs[v].shape[1]
                        ones = torch.ones([1, num]).to(device)
                        ones[0, :int(num * p)] = 0
                        Vones_not_full.append(ones)

                    if xnum == args.batch_size and epoch == 1:
                        noise = Vones_full[v][:, torch.randperm(Vones_full[v].size(1))]
                    if xnum != args.batch_size and epoch == 1:
                        noise = Vones_not_full[v][:, torch.randperm(Vones_not_full[v].size(1))]

                    if xnum == args.batch_size and epoch != 1:
                        noise = mask_ones_full[v][:, torch.randperm(mask_ones_full[v].size(1))]
                    if xnum != args.batch_size and epoch != 1:
                        noise = mask_ones_not_full[v][:, torch.randperm(mask_ones_not_full[v].size(1))]

                    noise = torch.reshape(noise, xs[v].shape)
                    noise = noise * xs[v]
                    noise_x.append(noise)

                if xnum == args.batch_size:
                    flag_full = 1
                else:
                    flag_not_full = 1

                optimizer.zero_grad()
                _, _, xrs, _, _ = model(noise_x)

            for v in range(view):
                # for w in range(v + 1, view):
                for w in range(view):
                    # if v == w:
                    #     continue
                    if args.contrastive_loss == 'InfoNCE':
                        tmp = criterion.forward_feature_InfoNCE(zs[v], zs[w], batch_size=xnum)
                    if args.contrastive_loss == 'PSCL':
                        tmp = criterion.forward_feature_PSCL(zs[v], zs[w])
                    if args.contrastive_loss == 'RINCE':
                        tmp = criterion.forward_feature_RINCE(zs[v], zs[w], batch_size=xnum)

                    # Self-weighting: the pairwise loss is scaled by the view-similarity weight.
                    # loss_list.append(tmp)
                    loss_list.append(tmp * nmi_matrix[v][w])
                    record_loss_con[v][w].append(tmp)

                loss_list.append(Lambda * mes(xs[v], xrs[v]))
            loss = sum(loss_list)
            loss.backward()
            optimizer.step()
            tot_loss += loss.item()

        # print(Sim / 1400)  # 1400 is the data size of Caltech

        for v in range(view):
            for w in range(view):
                record_loss_con[v][w] = sum(record_loss_con[v][w])
                record_loss_con[v][w] = record_loss_con[v][w].item() / len(data_loader)

        return Vones_full, Vones_not_full, record_loss_con, None  # cosine records disabled
    if not os.path.exists('./models'):
        os.makedirs('./models')

    model = Network(view, dims, args.feature_dim, args.high_feature_dim, class_num, device)
    # print(model)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    criterion = Loss(args.batch_size, class_num, args.temperature_f, device).to(device)

    print("Initialization......")
    epoch = 0
    while epoch < args.mse_epochs:
        epoch += 1
        if epoch == 1:
            mask_ones_full, mask_ones_not_full = Low_level_rec_train(epoch,
                                                                     rec=args.Recon,
                                                                     p=per,
                                                                     )
        else:
            Low_level_rec_train(epoch,
                                rec=args.Recon,
                                p=per,
                                mask_ones_full=mask_ones_full,
                                mask_ones_not_full=mask_ones_not_full,
                                )

    acc, nmi, ari, pur, nmi_matrix_1, _ = valid(model, device, dataset, view, data_size, class_num,
                                                eval_h=True, eval_z=False, times_for_K=args.times_for_K,
                                                Measure=args.measurement, test=False, sample_num=sample_mmd)

    print("Self-Weighted Multi-view Contrastive Learning with Reconstruction Regularization...")
    Iteration = 1
    print("Iteration " + str(Iteration) + ":")
    epoch = 0
    record_loss_con = []
    record_cos = []
    while epoch < Total_con_epochs:
        epoch += 1
        if epoch == 1:
            mask_ones_full, mask_ones_not_full, record_loss_con_, record_cos_ = High_level_contrastive_train(
                epoch, nmi_matrix_1, args.Lambda, rec=args.Recon, p=per)
        else:
            _, _, record_loss_con_, record_cos_ = High_level_contrastive_train(
                epoch, nmi_matrix_1, args.Lambda, rec=args.Recon, p=per,
                mask_ones_full=mask_ones_full, mask_ones_not_full=mask_ones_not_full)

        record_loss_con.append(record_loss_con_)
        record_cos.append(record_cos_)
        if epoch % args.con_epochs == 0:
            if epoch == args.mse_epochs + Total_con_epochs:
                break

            # print(nmi_matrix_1)

            acc, nmi, ari, pur, _, nmi_matrix_2 = valid(model, device, dataset, view, data_size, class_num,
                                                        eval_h=False, eval_z=True, times_for_K=args.times_for_K,
                                                        Measure=args.measurement, test=False, sample_num=sample_mmd)
            nmi_matrix_1 = nmi_matrix_2
            if epoch < Total_con_epochs:
                Iteration += 1
                print("Iteration " + str(Iteration) + ":")

                pg = [p for p in model.parameters() if p.requires_grad]
                # this code matters: re-initialize the optimizer for each bi-level iteration
                optimizer = torch.optim.Adam(pg, lr=args.learning_rate, weight_decay=args.weight_decay)

    accs.append(acc)
    nmis.append(nmi)
    aris.append(ari)
    purs.append(pur)

    # if acc > ACC_tmp:
    #     ACC_tmp = acc
    #     state = model.state_dict()
    #     torch.save(state, './models/' + args.dataset + '.pth')

    t2 = time.time()
    print("Time cost: " + str(t2 - t1))
    print('End......')


print(accs, np.mean(accs) / 0.01, np.std(accs) / 0.01)  # reported as percentages
print(nmis, np.mean(nmis) / 0.01, np.std(nmis) / 0.01)
# print(aris, np.mean(aris)/0.01, np.std(aris)/0.01)
# print(purs, np.mean(purs)/0.01, np.std(purs)/0.01)

def PLOT_LOSS(record_loss_con=[]):
    import matplotlib.pyplot as plt
    from matplotlib import pyplot
    plt.style.use('seaborn-whitegrid')
    palette = pyplot.get_cmap('Set1')
    font1 = {'family': 'Times New Roman',
             'weight': 'normal',
             'size': 50,
             }
    fontsize = 60
    fig = plt.figure()
    ax = fig.add_subplot()
    length = 133  # 400 epochs, one recorded point every 3 epochs (400 / 3)
    # print(len(record_loss_con))
    iters = np.linspace(0, length - 1, length, dtype=int)

    loss = []
    v1 = 0
    v2 = 3
    for i in range(length):
        loss.append(record_loss_con[3 * i + 1][v1][v2])
    ax.plot(iters, loss, color=palette(1), linestyle='-', label='InfoNCE loss (View 1; View 4)', linewidth=4.0)

    loss = []
    v1 = 3
    v2 = 4
    for i in range(length):
        loss.append(record_loss_con[3 * i + 1][v1][v2])
    ax.plot(iters, loss, color=palette(2), linestyle='-', label='InfoNCE loss (View 4; View 5)', linewidth=4.0)

    ax.legend(prop=font1, frameon=1, fancybox=0, framealpha=1)
    ax.set_xlabel('Training epochs', fontsize=fontsize)
    ax.set_ylabel('Loss values', fontsize=fontsize)
    plt.xticks([0, 33, 66, 99, 132], [0, 100, 200, 300, 400], rotation=0, fontsize=fontsize)
    plt.yticks(rotation=0, fontsize=fontsize)
    plt.show()


if Dataname == 'Caltech' and args.contrastive_loss == 'InfoNCE' and args.bi_level_iteration == 4:
    PLOT_LOSS(record_loss_con)
--------------------------------------------------------------------------------