├── data ├── NGs.mat └── README.md ├── readme.txt ├── loss.py ├── network.py ├── TSNE.py ├── metric.py ├── dataloader.py └── train.py /data/NGs.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SubmissionsIn/SCM/HEAD/data/NGs.mat -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | The complete documents for reproduction are provided in **[Baiduyunpan](https://pan.baidu.com/s/1bz8THF1FhE8cIp-4itebRQ?pwd=1s1h)**. 2 | -------------------------------------------------------------------------------- /readme.txt: -------------------------------------------------------------------------------- 1 | This is the PyTorch implementation of SCM: 2 | 3 | @InProceedings{Luo_2024_IJCAI, 4 | title = {Simple Contrastive Multi-View Clustering with Data-Level Fusion}, 5 | author = {Luo, Caixuan and Xu, Jie and Ren, Yazhou and Ma, Junbo and Zhu, Xiaofeng}, 6 | booktitle = {Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence}, 7 | pages = {4697--4705}, 8 | year = {2024} 9 | } -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class Loss(nn.Module): 7 | def __init__(self, batch_size, class_num, temperature_f, device): 8 | super(Loss, self).__init__() 9 | self.batch_size = batch_size 10 | self.class_num = class_num 11 | self.temperature_f = temperature_f 12 | self.device = device 13 | 14 | self.mask = self.mask_correlated_samples(batch_size) 15 | self.similarity = nn.CosineSimilarity(dim=2) 16 | self.criterion = nn.CrossEntropyLoss(reduction="sum") 17 | 18 | def mask_correlated_samples(self, N): 19 | mask = torch.ones((N, N)) 20 | mask = mask.fill_diagonal_(0) 21 | for i in range(N//2): 22 | mask[i, N//2 + i] = 0 23 | mask[N//2 + i, i] = 0 24 | mask = mask.bool() 25 | return mask 26 | 27 | def forward_feature(self, h_i, h_j): 28 | num_rows = h_i.shape[0] 29 | N = 2 * num_rows 30 | h = torch.cat((h_i, h_j), dim=0) 31 | 32 | sim = torch.matmul(h, h.T) / self.temperature_f 33 | sim_i_j = torch.diag(sim, num_rows) 34 | sim_j_i = torch.diag(sim, -num_rows) 35 | 36 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1) 37 | mask = self.mask_correlated_samples(N) 38 | negative_samples = sim[mask].reshape(N, -1) 39 | 40 | labels = torch.zeros(N).to(positive_samples.device).long() 41 | logits = torch.cat((positive_samples, negative_samples), dim=1) 42 | loss = self.criterion(logits, labels) 43 | loss /= N 44 | return loss 45 | 46 | 47 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.functional import normalize 3 | class Encoder(nn.Module): 4 | def __init__(self, input_dim, feature_dim): 5 | super(Encoder, self).__init__() 6 | self.encoder = nn.Sequential( 7 | nn.Linear(input_dim, 500), 8 | nn.ReLU(), 9 | nn.Linear(500, 500), 10 | nn.ReLU(), 11 | nn.Linear(500, 2000), 12 | nn.ReLU(), 13 | nn.Linear(2000, feature_dim), 14 | ) 15 | 16 | def forward(self, x): 17 | return self.encoder(x) 18 | class Decoder(nn.Module): 19 | def __init__(self, input_dim, feature_dim): 20 | super(Decoder, self).__init__() 21 | 
self.decoder = nn.Sequential( 22 | nn.Linear(feature_dim, 2000), 23 | nn.ReLU(), 24 | nn.Linear(2000, 500), 25 | nn.ReLU(), 26 | nn.Linear(500, 500), 27 | nn.ReLU(), 28 | nn.Linear(500, input_dim) 29 | ) 30 | 31 | def forward(self, x): 32 | return self.decoder(x) 33 | class Network(nn.Module): 34 | def __init__(self, input_size, feature_dim, high_feature_dim,device): 35 | super(Network, self).__init__() 36 | self.encoders = Encoder(input_size, feature_dim).to(device) 37 | self.decoders = Decoder(input_size, feature_dim).to(device) 38 | 39 | self.feature_contrastive_module = nn.Sequential( 40 | nn.Linear(feature_dim, high_feature_dim), 41 | ) 42 | self.label_contrastive_module = nn.Sequential( 43 | nn.Linear(high_feature_dim, 64), 44 | nn.Softmax(dim=1) 45 | ) 46 | 47 | def forward(self, x): 48 | h = self.encoders(x) 49 | z = normalize(self.feature_contrastive_module(h), dim=1) 50 | q = self.label_contrastive_module(z) 51 | xr = self.decoders(h) 52 | return xr, h, z, q 53 | -------------------------------------------------------------------------------- /TSNE.py: -------------------------------------------------------------------------------- 1 | from sklearn.manifold import TSNE 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | 6 | def plot_embedding(data, label, title): 7 | x_min, x_max = np.min(data, 0), np.max(data, 0) 8 | data = (data - x_min) / (x_max - x_min) 9 | fig = plt.figure() 10 | ax = plt.subplot(111) 11 | # RGB 12 | color0 = (0.8941176470588236, 0.10196078431372549, 0.10980392156862745, 1.0) 13 | color1 = (0.21568627450980393, 0.49411764705882355, 0.7215686274509804, 1.0) 14 | color2 = (0.30196078431372547, 0.6862745098039216, 0.2901960784313726, 1.0) 15 | color3 = (0.596078431372549, 0.3058823529411765, 0.6392156862745098, 1.0) 16 | color4 = (1.0, 0.4980392156862745, 0.0, 1.0) 17 | color5 = (1.0, 1.0, 0.2, 1.0) 18 | color6 = (0.6509803921568628, 0.33725490196078434, 0.1568627450980392, 1.0) 19 | color7 = (0.9686274509803922, 0.5058823529411764, 0.7490196078431373, 1.0) 20 | color8 = (0.6, 0.6, 0.6, 1.0) 21 | color9 = (0.3, 0.5, 0.4, 0.7) 22 | colorcenters = (0.1, 0.1, 0.1, 1.0) 23 | #c = [color0, color1, color2, color3, color4, color5, color6, color7, color8, color9] # fixed color of 10 classes 24 | c = [color0, color1, color2, color3, color4, color5] # un-fixed color 25 | print(np.unique(label)) 26 | for i in range(len(np.unique(label))): 27 | c.append((np.random.random(), np.random.random(), np.random.random(), 1.0)) 28 | # print(len(c)) 29 | for i in range(data.shape[0]): 30 | color = c[int(label[i])] 31 | plt.text(data[i, 0], data[i, 1], str(label[i]), color=color, # plt.cm.Set123 32 | fontdict={'weight': 'bold', 'size': 9}) 33 | # plt.legend() 34 | plt.xlim(-0.005, 1.02) 35 | plt.ylim(-0.005, 1.02) 36 | plt.xticks([]) 37 | plt.yticks([]) 38 | plt.title(title) 39 | return fig 40 | 41 | 42 | def TSNE_PLOT(Z, Y, name="xxx"): 43 | tsne = TSNE(n_components=2, init='pca', random_state=0) 44 | F = tsne.fit_transform(Z) 45 | fig1 = plot_embedding(F, Y, name) 46 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 47 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 48 | plt.subplots_adjust(top=0.99, bottom=0.01, right=0.99, left=0.01, hspace=0, wspace=0) 49 | plt.margins(0, 0) 50 | plt.show() 51 | #fig1.savefig("images/" + name + ".png", format='png', transparent=True, dpi=1000, pad_inches=0) 52 | -------------------------------------------------------------------------------- /metric.py: 
-------------------------------------------------------------------------------- 1 | from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, accuracy_score 2 | from sklearn.cluster import KMeans 3 | from scipy.optimize import linear_sum_assignment 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | import torch 7 | 8 | def scale_normalize_matrix(input_matrix, min_value=0, max_value=1): 9 | min_val = input_matrix.min() 10 | max_val = input_matrix.max() 11 | input_range = max_val - min_val 12 | scaled_matrix = (input_matrix - min_val) / input_range * (max_value - min_value) + min_value 13 | return scaled_matrix 14 | 15 | def cluster_acc(y_true, y_pred): 16 | y_true = y_true.astype(np.int64) 17 | assert y_pred.size == y_true.size 18 | D = max(y_pred.max(), y_true.max()) + 1 19 | w = np.zeros((D, D), dtype=np.int64) 20 | for i in range(y_pred.size): 21 | w[y_pred[i], y_true[i]] += 1 22 | u = linear_sum_assignment(w.max() - w) 23 | ind = np.concatenate([u[0].reshape(u[0].shape[0], 1), u[1].reshape([u[0].shape[0], 1])], axis=1) 24 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 25 | 26 | def purity(y_true, y_pred): 27 | y_voted_labels = np.zeros(y_true.shape) 28 | labels = np.unique(y_true) 29 | ordered_labels = np.arange(labels.shape[0]) 30 | for k in range(labels.shape[0]): 31 | y_true[y_true == labels[k]] = ordered_labels[k] 32 | labels = np.unique(y_true) 33 | bins = np.concatenate((labels, [np.max(labels)+1]), axis=0) 34 | 35 | for cluster in np.unique(y_pred): 36 | hist, _ = np.histogram(y_true[y_pred == cluster], bins=bins) 37 | winner = np.argmax(hist) 38 | y_voted_labels[y_pred == cluster] = winner 39 | 40 | return accuracy_score(y_true, y_voted_labels) 41 | 42 | def evaluate(label, pred): 43 | nmi = normalized_mutual_info_score(label, pred) 44 | ari = adjusted_rand_score(label, pred) 45 | acc = cluster_acc(label, pred) 46 | pur = purity(label, pred) 47 | return nmi, ari, acc, pur 48 | 49 | def inference(loader, model, device, view): 50 | model.eval() 51 | soft_vector = [] 52 | 53 | for step, (xs, y, _) in enumerate(loader): 54 | for v in range(view): 55 | xs[v] = xs[v].to(device) 56 | xs_all = torch.cat(xs, dim=1) 57 | with torch.no_grad(): 58 | _, h, z, q = model.forward(xs_all) 59 | z = z.cpu().detach().numpy() 60 | h = h.cpu().detach().numpy() 61 | q = q.detach() 62 | soft_vector.extend(q.cpu().detach().numpy()) 63 | total_pred = np.argmax(np.array(soft_vector), axis=1) 64 | 65 | y = y.numpy() 66 | y = y.flatten() 67 | return y, h, z, total_pred 68 | def valid(model, device, dataset, view, data_size, class_num, eval_q = False,eval_z = False): 69 | test_loader = DataLoader( 70 | dataset, 71 | batch_size=data_size, 72 | shuffle=False, 73 | ) 74 | labels_vector, h, z, q = inference(test_loader, model, device, view) 75 | kmeans = KMeans(n_clusters=class_num) 76 | print(str(len(labels_vector)) + " samples") 77 | if eval_q == True: 78 | nmi_q, ari_q, acc_q, pur_q = evaluate(labels_vector, q) 79 | print('ACC_q = {:.4f} NMI_q = {:.4f} ARI_q = {:.4f} PUR_q = {:.4f}'.format(acc_q, nmi_q, ari_q, pur_q)) 80 | return acc_q, nmi_q, ari_q, pur_q 81 | if eval_z == True: 82 | z_pred = kmeans.fit_predict(z) 83 | nmi_z, ari_z, acc_z, pur_z = evaluate(labels_vector, z_pred) 84 | print('ACC_z = {:.4f} NMI_z = {:.4f} ARI_z = {:.4f} PUR_z = {:.4f}'.format(acc_z, nmi_z, ari_z, pur_z)) 85 | return acc_z, nmi_z, ari_z, pur_z 86 | -------------------------------------------------------------------------------- /dataloader.py: 
-------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import MinMaxScaler 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import scipy.io 5 | import torch 6 | 7 | def scale_normalize_matrix(input_matrix, min_value=0, max_value=1): 8 | min_val = input_matrix.min() 9 | max_val = input_matrix.max() 10 | input_range = max_val - min_val 11 | scaled_matrix = (input_matrix - min_val) / input_range * (max_value - min_value) + min_value 12 | return scaled_matrix 13 | class BDGP(Dataset): 14 | def __init__(self, path): 15 | data1 = scipy.io.loadmat(path+'BDGP.mat')['X1'].astype(np.float32) 16 | data2 = scipy.io.loadmat(path+'BDGP.mat')['X2'].astype(np.float32) 17 | labels = scipy.io.loadmat(path+'BDGP.mat')['Y'].transpose() 18 | self.x1 = scale_normalize_matrix(data1) 19 | self.x2 = scale_normalize_matrix(data2) 20 | self.y = labels 21 | 22 | def __len__(self): 23 | return self.x1.shape[0] 24 | 25 | def __getitem__(self, idx): 26 | return [torch.from_numpy(self.x1[idx]), torch.from_numpy( 27 | self.x2[idx])], torch.from_numpy(self.y[idx]), torch.from_numpy(np.array(idx)).long() 28 | class MNIST_USPS(Dataset): 29 | def __init__(self, path): 30 | self.Y = scipy.io.loadmat(path + 'MNIST_USPS.mat')['Y'].astype(np.int32).reshape(5000,) 31 | self.V1 = scipy.io.loadmat(path + 'MNIST_USPS.mat')['X1'].astype(np.float32) 32 | self.V2 = scipy.io.loadmat(path + 'MNIST_USPS.mat')['X2'].astype(np.float32) 33 | 34 | def __len__(self): 35 | return 5000 36 | 37 | def __getitem__(self, idx): 38 | 39 | x1 = self.V1[idx].reshape(784) 40 | x2 = self.V2[idx].reshape(784) 41 | return [torch.from_numpy(x1), torch.from_numpy(x2)], self.Y[idx], torch.from_numpy(np.array(idx)).long() 42 | class Fashion(Dataset): 43 | def __init__(self, path): 44 | self.Y = scipy.io.loadmat(path + 'Fashion.mat')['Y'].astype(np.int32).reshape(10000,) 45 | self.V1 = scipy.io.loadmat(path + 'Fashion.mat')['X1'].astype(np.float32) 46 | self.V2 = scipy.io.loadmat(path + 'Fashion.mat')['X2'].astype(np.float32) 47 | self.V3 = scipy.io.loadmat(path + 'Fashion.mat')['X3'].astype(np.float32) 48 | 49 | def __len__(self): 50 | return 10000 51 | 52 | def __getitem__(self, idx): 53 | 54 | x1 = self.V1[idx].reshape(784) 55 | x2 = self.V2[idx].reshape(784) 56 | x3 = self.V3[idx].reshape(784) 57 | return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], self.Y[idx], torch.from_numpy(np.array(idx)).long() 58 | class DHA(Dataset): 59 | def __init__(self, path): 60 | self.Y = scipy.io.loadmat(path + 'DHA.mat')['Y'].astype(np.int32).reshape(483,) 61 | self.V1 = scipy.io.loadmat(path + 'DHA.mat')['X1'].astype(np.float32) 62 | self.V2 = scipy.io.loadmat(path + 'DHA.mat')['X2'].astype(np.float32) 63 | def __len__(self): 64 | return 483 65 | def __getitem__(self, idx): 66 | x1 = self.V1[idx] 67 | x2 = self.V2[idx] 68 | 69 | x1 = scale_normalize_matrix(x1) 70 | x2 = scale_normalize_matrix(x2) 71 | 72 | return [torch.from_numpy(x1), torch.from_numpy(x2)], self.Y[idx], torch.from_numpy(np.array(idx)).long() 73 | class WebKB(Dataset): 74 | def __init__(self,path): 75 | self.Y = scipy.io.loadmat(path + 'WebKB')['gnd'].astype(np.int32).reshape(1051,) 76 | self.V1 = scipy.io.loadmat(path + 'WebKB')['X'][0][0].astype(np.float32) 77 | self.V2 = scipy.io.loadmat(path + 'WebKB')['X'][0][1].astype(np.float32) 78 | def __len__(self): 79 | return 1051 80 | def __getitem__(self, idx): 81 | x1 = self.V1[idx] 82 | x2 = self.V2[idx] 83 | 84 | 
return[torch.from_numpy(x1),torch.from_numpy(x2)],self.Y[idx],torch.from_numpy(np.array(idx)).long() 85 | class NGs(Dataset): 86 | def __init__(self,path): 87 | self.Y = scipy.io.loadmat(path + 'NGs')['truelabel'][0][0].astype(np.int32).reshape(500,) 88 | self.V1 = scipy.io.loadmat(path + 'NGs')['data'][0][0].astype(np.float32) 89 | self.V2 = scipy.io.loadmat(path + 'NGs')['data'][0][1].astype(np.float32) 90 | self.V3 = scipy.io.loadmat(path + 'NGs')['data'][0][2].astype(np.float32) 91 | 92 | self.V1 = np.transpose(self.V1) 93 | self.V2 = np.transpose(self.V2) 94 | self.V3 = np.transpose(self.V3) 95 | 96 | self.v1 = scale_normalize_matrix(self.V1) 97 | self.v2 = scale_normalize_matrix(self.V2) 98 | self.v3 = scale_normalize_matrix(self.V3) 99 | 100 | def __len__(self): 101 | return 500 102 | def __getitem__(self, idx): 103 | x1 = self.V1[idx] 104 | x2 = self.V2[idx] 105 | x3 = self.V3[idx] 106 | 107 | return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \ 108 | self.Y[idx], torch.from_numpy(np.array(idx)).long() 109 | class VOC(Dataset): 110 | def __init__(self,path): 111 | self.Y = scipy.io.loadmat(path + 'VOC')['Y'].astype(np.int32).reshape(5649,) 112 | self.V1 = scipy.io.loadmat(path + 'VOC')['X1'].astype(np.float32) 113 | self.V2 = scipy.io.loadmat(path + 'VOC')['X2'].astype(np.float32) 114 | def __len__(self): 115 | return 5649 116 | def __getitem__(self, idx): 117 | x1 = self.V1[idx] 118 | x2 = self.V2[idx] 119 | 120 | return [torch.from_numpy(x1), torch.from_numpy(x2)], \ 121 | self.Y[idx], torch.from_numpy(np.array(idx)).long() 122 | class Fc_COIL_20(Dataset): 123 | def __init__(self,path): 124 | self.Y = scipy.io.loadmat(path + 'Fc_COIL_20')['Y'].astype(np.int32).reshape(1440, ) 125 | self.V1 = scipy.io.loadmat(path + 'Fc_COIL_20')['X1'].astype(np.float32) 126 | self.V2 = scipy.io.loadmat(path + 'Fc_COIL_20')['X2'].astype(np.float32) 127 | self.V3 = scipy.io.loadmat(path + 'Fc_COIL_20')['X3'].astype(np.float32) 128 | def __len__(self): 129 | return 1440 130 | def __getitem__(self, idx): 131 | x1 = self.V1[idx] 132 | x2 = self.V2[idx] 133 | x3 = self.V3[idx] 134 | return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \ 135 | self.Y[idx], torch.from_numpy(np.array(idx)).long() 136 | 137 | def load_data(dataset): 138 | if dataset == "BDGP": 139 | dataset = BDGP('./data/') 140 | dims = [1750, 79] 141 | dimss = 1829 142 | view = 2 143 | data_size = 2500 144 | class_num = 5 145 | elif dataset == "MNIST-USPS": 146 | dataset = MNIST_USPS('./data/') 147 | dims = [784, 784] 148 | dimss = 1568 149 | view = 2 150 | class_num = 10 151 | data_size = 5000 152 | elif dataset == "Fashion": 153 | dataset = Fashion('./data/') 154 | dims = [784, 784, 784] 155 | dimss = 2352 156 | view = 3 157 | data_size = 10000 158 | class_num = 10 159 | elif dataset == "DHA": 160 | dataset = DHA('./data/') 161 | dims = [110, 6144] 162 | dimss = 6254 163 | view = 2 164 | data_size = 483 165 | class_num = 23 166 | elif dataset == "WebKB": 167 | dataset = WebKB('./data/') 168 | dims = [2949, 334] 169 | dimss = 3283 170 | view = 2 171 | data_size = 1051 172 | class_num = 2 173 | elif dataset == "NGs": 174 | dataset = NGs('./data/') 175 | dims = [2000, 2000 , 2000] 176 | dimss = 6000 177 | view = 3 178 | data_size = 500 179 | class_num = 5 180 | elif dataset == "VOC": 181 | dataset = VOC('./data/') 182 | dims = [512, 399] 183 | dimss = 911 184 | view = 2 185 | data_size = 5649 186 | class_num = 20 187 | elif dataset == "Fc_COIL_20": 188 | dataset = Fc_COIL_20('./data/') 189 | 
dims = [1024, 1024, 1024] 190 | dimss = 3072 191 | view = 3 192 | data_size = 1440 193 | class_num = 20 194 | else: 195 | raise NotImplementedError 196 | return dataset, dims, view, data_size, class_num, dimss 197 | 198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from network import Network 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import argparse 6 | import random 7 | from loss import Loss 8 | from dataloader import load_data 9 | from scipy.spatial import distance 10 | from torch.utils.data import DataLoader 11 | from metric import valid 12 | from sklearn.metrics import confusion_matrix 13 | import time 14 | from TSNE import TSNE_PLOT as ttsne 15 | # MNIST-USPS (aka. DIGIT) 16 | # BDGP 17 | # Fashion 18 | # NGs 19 | # VOC 20 | # WebKB 21 | # DHA 22 | # Fc_COIL_20 (aka. COIL-20) 23 | 24 | # SCM_w/o_DA 25 | # SCM_w/o_NoiseDA 26 | # SCM_w/o_MaskDA 27 | # SCM 28 | # SCM_REC 29 | # SCM_REC_ETC 30 | # SCM_ETC 31 | 32 | Dataname = 'MNIST-USPS' 33 | MODE = 'SCM' 34 | miss_rate = 0.25 35 | noise_rate = 0.25 36 | Gaussian_noise = 0.4 37 | tsne = True # True / False 38 | T = 1 39 | 40 | parser = argparse.ArgumentParser(description='train') 41 | parser.add_argument('--dataset', default=Dataname) 42 | parser.add_argument('--batch_size', default=256, type=int) 43 | parser.add_argument("--temperature_f", default=0.5) 44 | parser.add_argument("--learning_rate", default=0.0003) 45 | parser.add_argument("--weight_decay", default=0.) 46 | parser.add_argument("--workers", default=8) 47 | parser.add_argument("--mse_iterations", default=200) 48 | parser.add_argument("--con_iterations", default=50) 49 | parser.add_argument("--tune_iterations", default=50) 50 | parser.add_argument("--feature_dim", default=256) 51 | parser.add_argument("--high_feature_dim", default=128) 52 | parser.add_argument('--mode', type=str, default=MODE) 53 | parser.add_argument('--miss_rate', type=str, default=miss_rate) 54 | parser.add_argument('--noise_rate', type=str, default=noise_rate) 55 | parser.add_argument('--Gaussian_noise', type=str, default=Gaussian_noise) 56 | args = parser.parse_args() 57 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 58 | 59 | 60 | if args.dataset == "MNIST-USPS": 61 | args.con_iterations = 1000 62 | args.mse_iterations = 1000 63 | args.gamma = 0.02 64 | args.alpha = 0.5 65 | args.beta = 0.5 66 | seed = 1 67 | if args.dataset == "BDGP": 68 | args.con_iterations = 400 69 | args.mse_iterations = 3000 70 | args.gamma = 0.1 71 | args.alpha = 0.5 72 | args.beta = 0.5 73 | seed = 4 74 | if args.dataset == "Fashion": 75 | args.con_iterations = 20000 76 | args.mse_iterations = 2500 77 | args.gamma = 0.003 78 | args.alpha = 0.2 79 | args.beta = 0.81 80 | seed = 1 81 | if args.dataset == "DHA": 82 | args.con_iterations = 500 83 | args.mse_iterations = 700 84 | args.gamma = 0.02 85 | args.alpha = 0.2 86 | args.beta = 0.5 87 | seed = 4 88 | if args.dataset == "WebKB": 89 | args.con_iterations = 200 90 | args.mse_iterations = 200 91 | args.gamma = 0.001 92 | args.alpha = 0.6 93 | args.beta = 0.6 94 | seed = 2 95 | if args.dataset == "NGs": 96 | args.con_iterations = 200 97 | args.mse_iterations = 800 98 | args.gamma = 0.00005 99 | args.alpha = 0.5 100 | args.beta = 0.5 101 | seed = 5 102 | if args.dataset == "VOC": 103 | args.con_iterations = 200 104 | args.mse_iterations = 900 105 | args.gamma = 0.002 
106 | args.alpha = 0.01 107 | args.beta = 0.37 108 | seed = 9 109 | if args.dataset == "Fc_COIL_20": 110 | args.con_iterations = 2000 111 | args.mse_iterations = 400 112 | args.gamma = 0.031 113 | args.alpha = 0.2 114 | args.beta = 0.5 115 | seed = 1 116 | 117 | def setup_seed(seed): 118 | torch.manual_seed(seed) 119 | torch.cuda.manual_seed_all(seed) 120 | #np.random.seed(seed) 121 | #random.seed(seed) 122 | torch.backends.cudnn.deterministic = True 123 | def mask(rows, cols, p): 124 | tensor = np.zeros((rows, cols), dtype=int) 125 | for i in range(rows): 126 | if i < int(rows * p): 127 | while True: 128 | row = np.random.randint(0, 2, size=cols) 129 | if np.count_nonzero(row) < cols and np.count_nonzero(row) > 0: 130 | tensor[i, :] = row 131 | break 132 | else: 133 | tensor[i, :] = 1 134 | np.random.shuffle(tensor) 135 | tensor = torch.tensor(tensor) 136 | return tensor 137 | def add_noise(matrix, std, p): 138 | rows, cols = matrix.shape 139 | noisy_matrix = matrix.clone() 140 | for i in range(rows): 141 | if random.random() < p: 142 | noise = torch.randn(cols, device=device) * std 143 | noisy_matrix[i] += noise 144 | return noisy_matrix 145 | def scale_normalize_matrix(input_matrix, min_value=0, max_value=1): 146 | min_val = input_matrix.min() 147 | max_val = input_matrix.max() 148 | input_range = max_val - min_val 149 | scaled_matrix = (input_matrix - min_val) / input_range * (max_value - min_value) + min_value 150 | 151 | return scaled_matrix 152 | dataset, _, view, data_size, class_num, dimss = load_data(args.dataset) 153 | data_loader = torch.utils.data.DataLoader( 154 | dataset, 155 | batch_size=args.batch_size, 156 | shuffle=True, 157 | drop_last=True 158 | ) 159 | def SCM(iteration, mode,miss_rate,noise_rate,Gaussian_noise): 160 | mse = torch.nn.MSELoss() 161 | for batch_idx, (xs, y, _) in enumerate(data_loader): 162 | # print(y) # different batches 163 | for v in range(view): 164 | xs[v] = xs[v].to(device) 165 | break 166 | masked_xs = [] 167 | noised_xs = [] 168 | num_rows = xs[0].shape[0] 169 | mask_tensor = mask(num_rows,view,miss_rate).to(device) 170 | for v in range(view): 171 | masked_x = mask_tensor[:,v].unsqueeze(1)*xs[v] 172 | masked_xs.append(masked_x) 173 | for v in range(view): 174 | noised_x = add_noise(xs[v],Gaussian_noise,noise_rate) 175 | noised_xs.append(noised_x) 176 | xs_all = torch.cat(xs,dim=1) 177 | mask_all = torch.cat(masked_xs, dim=1) 178 | noise_all = torch.cat(noised_xs, dim=1) 179 | optimizer.zero_grad() 180 | xrs,_,xs_z,q = model(xs_all) 181 | mask_xrs,_,mask_z,_ = model(mask_all) 182 | noise_xrs,_,noise_z,_ = model(noise_all) 183 | loss_xrs = mse(xs_all,xrs) 184 | loss_mask = mse(xs_all,mask_xrs) 185 | loss_noise = mse(xs_all,noise_xrs) 186 | if mode =='SCM' or mode == 'SCM_REC'or mode =='SCM_REC_ETC'or mode =='SCM_ETC': 187 | loss_con_1 = criterion.forward_feature(noise_z, mask_z) 188 | loss_con_2 = criterion.forward_feature(mask_z, noise_z) 189 | if mode =='SCM_w/o_MaskDA': 190 | loss_con_1 = criterion.forward_feature(noise_z, xs_z) 191 | loss_con_2 = criterion.forward_feature(xs_z, noise_z) 192 | if mode == 'SCM_w/o_NoiseDA': 193 | loss_con_1 = criterion.forward_feature(mask_z, xs_z) 194 | loss_con_2 = criterion.forward_feature(xs_z, mask_z) 195 | if mode == 'SCM_w/o_DA': 196 | loss_con_1 = criterion.forward_feature(xs_z, xs_z) 197 | loss_con_2 = criterion.forward_feature(xs_z, xs_z) 198 | if mode =='SCM_REC_ETC' or mode == 'SCM_REC': 199 | loss = loss_xrs + loss_mask + loss_noise + loss_con_1 + loss_con_2 200 | if mode == 'SCM' or mode == 
'SCM_ETC' or mode =='SCM_w/o_NoiseDA' or mode =='SCM_w/o_DA' or mode =='SCM_w/o_MaskDA': 201 | loss = loss_con_1+loss_con_2 202 | loss.backward() 203 | optimizer.step() 204 | print('Iteration {}'.format(iteration), 'Loss:{:.6f}'.format(loss)) 205 | def destiny_peak(model, device, gamma=args.gamma, alpha=args.alpha, beta=args.beta, metric='euclidean'): 206 | ALL_loader = DataLoader( 207 | dataset, 208 | batch_size=data_size, 209 | shuffle=False, 210 | ) 211 | for step, (xs, ys, _) in enumerate(ALL_loader): 212 | for v in range(view): 213 | xs[v] = xs[v].to(device) 214 | xs_all = torch.cat(xs, dim=1) 215 | with torch.no_grad(): 216 | _, _, z, _ = model.forward(xs_all) 217 | z = z.cpu().detach().numpy() 218 | 219 | condensed_distance = distance.pdist(z, metric=metric) 220 | d_c = np.sort(condensed_distance)[int(len(condensed_distance) * gamma)] 221 | redundant_distance = distance.squareform(condensed_distance) 222 | rho = np.sum(np.exp(-(redundant_distance / d_c) ** 2), axis=1) 223 | order_distance = np.argsort(redundant_distance, axis=1) 224 | delta = np.zeros_like(rho) 225 | nn = np.zeros_like(rho).astype(int) 226 | for i in range(len(delta)): 227 | mask = rho[order_distance[i]] > rho[i] 228 | if mask.sum() > 0: 229 | nn[i] = order_distance[i][mask][0] 230 | delta[i] = redundant_distance[i, nn[i]] 231 | else: 232 | nn[i] = order_distance[i, -1] 233 | delta[i] = redundant_distance[i, nn[i]] 234 | rho_c = min(rho) + (max(rho) - min(rho)) * alpha 235 | delta_c = min(delta) + (max(delta) - min(delta)) * beta 236 | centers = np.where(np.logical_and(rho > rho_c, delta > delta_c))[0] 237 | num_clusters = len(centers) 238 | cluster_points = z[centers] 239 | probabilities = np.zeros((z.shape[0], num_clusters)) 240 | for i in range(z.shape[0]): 241 | for j in range(num_clusters): 242 | probabilities[i, j] = np.exp(-np.linalg.norm(z[i] - cluster_points[j])) 243 | probabilities /= probabilities.sum(axis=1, keepdims=True) 244 | yyy = torch.from_numpy(probabilities) 245 | yyy = torch.argmax(yyy, dim=1) 246 | confusion = confusion_matrix(yyy, ys) 247 | per = np.sum(np.max(confusion, axis=0)) / np.sum(confusion) 248 | additional_columns = 64 - probabilities.shape[1] 249 | zero_columns = np.zeros((probabilities.shape[0], additional_columns)) 250 | probabilities = np.hstack((probabilities, zero_columns)) 251 | probabilities = torch.from_numpy(probabilities) 252 | print('num:{}'.format(num_clusters), 'accuracy:{:.6f}'.format(per)) 253 | return probabilities 254 | def end2end(iteration,probability_matrix,mode,miss_rate,noise_rate,Gaussian_noise): 255 | 256 | if iteration > args.mse_iterations: 257 | mse = torch.nn.MSELoss() 258 | masked_xs = [] 259 | noised_xs = [] 260 | for batch_idx, (xs, _, idx) in enumerate(data_loader): 261 | for v in range(view): 262 | xs[v] = xs[v].to(device) 263 | idx[v] = idx[v].to(device) 264 | num_rows = xs[0].shape[0] 265 | mask_tensor = mask(num_rows , view, miss_rate).to(device) 266 | for v in range(view): 267 | masked_x = mask_tensor[:, v].unsqueeze(1) * xs[v] 268 | masked_xs.append(masked_x) 269 | for v in range(view): 270 | noised_x = add_noise(xs[v], Gaussian_noise, noise_rate) 271 | noised_xs.append(noised_x) 272 | xs_all = torch.cat(xs, dim=1) 273 | mask_all = torch.cat(masked_xs, dim=1) 274 | noise_all = torch.cat(noised_xs, dim=1) 275 | optimizer.zero_grad() 276 | xrs, _, z_all, q = model(xs_all) 277 | mask_xrs, _, mask_z, mask_q = model(mask_all) 278 | noise_xrs, _, noise_z, noise_q = model(noise_all) 279 | select_rows = probability_matrix[idx] 280 | qs = 
np.vstack(select_rows) 281 | qs = torch.from_numpy(qs).float() 282 | qs = qs.to(device) 283 | qs = scale_normalize_matrix(qs) 284 | loss_xrs = mse(xs_all, xrs) 285 | loss_mask = mse(xs_all, mask_xrs) 286 | loss_noise = mse(xs_all, noise_xrs) 287 | loss_con_1 = criterion.forward_feature(noise_z, mask_z) 288 | loss_con_2 = criterion.forward_feature(mask_z, noise_z) 289 | loss_con = mse(qs, q) 290 | if mode == 'SCM_REC_ETC': 291 | loss = loss_xrs + loss_mask + loss_noise + loss_con_1 + loss_con_2 + loss_con 292 | if mode == 'SCM_ETC': 293 | loss = loss_con_1 + loss_con_2 + loss_con 294 | loss.backward() 295 | optimizer.step() 296 | print('Epoch {}'.format(iteration), 'Loss:{:.6f}'.format(loss)) 297 | accs = [] 298 | nmis = [] 299 | purs = [] 300 | aris = [] 301 | 302 | for i in range(T): 303 | print("ROUND:{}".format(i + 1)) 304 | setup_seed(seed) 305 | model = Network(dimss, args.feature_dim, args.high_feature_dim,device) 306 | model = model.to(device) 307 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 308 | criterion = Loss(args.batch_size, class_num, args.temperature_f, device).to(device) 309 | mode = args.mode 310 | miss_rate = args.miss_rate 311 | noise_rate = args.noise_rate 312 | Gaussian_noise = args.Gaussian_noise 313 | 314 | time0 = time.time() 315 | iteration = 1 316 | while iteration <= args.con_iterations: 317 | SCM(iteration,mode,miss_rate,noise_rate,Gaussian_noise) 318 | if iteration == args.con_iterations: 319 | if mode == 'SCM' or mode =='SCM_REC' or mode =='SCM_w/o_NoiseDA' or mode =='SCM_w/o_DA' or mode =='SCM_w/o_MaskDA': 320 | acc, nmi, ari, pur = valid(model, device, dataset, view, data_size, class_num, eval_z=True) 321 | accs.append(acc) 322 | nmis.append(nmi) 323 | purs.append(pur) 324 | aris.append(ari) 325 | iteration += 1 326 | 327 | if mode == 'SCM_ETC' or mode =='SCM_REC_ETC': 328 | probability_matrix = destiny_peak(model, device) 329 | while iteration <= args.mse_iterations + args.con_iterations: 330 | end2end(iteration, probability_matrix,mode,miss_rate,noise_rate,Gaussian_noise) 331 | if iteration == args.mse_iterations + args.con_iterations: 332 | acc, nmi, ari, pur = valid(model, device, dataset, view, data_size, class_num, eval_q =True) 333 | accs.append(acc) 334 | nmis.append(nmi) 335 | purs.append(pur) 336 | aris.append(ari) 337 | iteration += 1 338 | 339 | 340 | print('%.4f'% np.mean(accs), '%.4f'% np.std(accs), accs) 341 | print('%.4f'% np.mean(nmis), '%.4f'% np.std(nmis), nmis) 342 | print('%.4f'% np.mean(aris), '%.4f'% np.std(aris), aris) 343 | 344 | 345 | if tsne == True: 346 | miss_x = [] 347 | noise_x = [] 348 | model.eval() 349 | ALL_loader = DataLoader( 350 | dataset, 351 | batch_size=data_size, 352 | shuffle=False, 353 | ) 354 | for step, (xs, ys, _) in enumerate(ALL_loader): 355 | ys = ys.numpy() 356 | for v in range(view): 357 | xs[v] = xs[v].to(device) 358 | num_rows = xs[0].shape[0] 359 | miss = mask(num_rows,view,miss_rate).to(device) 360 | for v in range(view): 361 | miss = miss[:,v].unsqueeze(1)*xs[v] 362 | miss_x.append(miss) 363 | for v in range(view): 364 | noisedx = add_noise(xs[v],Gaussian_noise,noise_rate) 365 | noise_x.append(noisedx) 366 | 367 | xs_all = torch.cat(xs, dim=1) 368 | mask_xx = torch.cat(miss_x, dim=1) 369 | noise_xx = torch.cat(noise_x, dim=1) 370 | with torch.no_grad(): 371 | _, _, z, _ = model.forward(xs_all) 372 | _, _, zm, _ = model.forward(mask_xx) 373 | _, _, zn, _ = model.forward(noise_xx) 374 | z = z.cpu().detach().numpy() 375 | zm = 
zm.cpu().detach().numpy() 376 | zn = zn.cpu().detach().numpy() 377 | 378 | ttsne(z, ys, "z") 379 | ttsne(zm, ys, "zm") 380 | ttsne(zn, ys, "zn") 381 | --------------------------------------------------------------------------------
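For readers who only want to probe the released modules, the following quick-start sketch (not part of the original repository) loads one dataset, pushes the concatenated views through the network, and clusters the resulting embeddings the same way metric.valid does with eval_z=True. It assumes dataloader.py, network.py and metric.py are importable and that ./data/NGs.mat is in place; the file name quick_start.py, the choice of NGs, and the feature dimensions are illustrative, and since the model is untrained here the printed scores merely exercise the evaluation API.

# quick_start.py -- illustrative sketch only, not part of the original release.
import torch
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans

from dataloader import load_data
from network import Network
from metric import evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NGs: three 2000-dimensional views, 500 samples, 5 classes.
dataset, dims, view, data_size, class_num, dimss = load_data("NGs")

# Data-level fusion: the network consumes the concatenation of all views.
model = Network(dimss, 256, 128, device).to(device)

loader = DataLoader(dataset, batch_size=data_size, shuffle=False)
xs, y, _ = next(iter(loader))
xs_all = torch.cat([x.to(device) for x in xs], dim=1)   # shape: (data_size, dimss)

with torch.no_grad():
    xr, h, z, q = model(xs_all)   # reconstruction, feature, contrastive embedding, soft label

# K-means on the contrastive embedding z, as metric.valid does when eval_z=True.
pred = KMeans(n_clusters=class_num, n_init=10).fit_predict(z.cpu().numpy())
nmi, ari, acc, pur = evaluate(y.numpy().flatten(), pred)
print('ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR = {:.4f}'.format(acc, nmi, ari, pur))

To reproduce the full pipeline instead, run python train.py --dataset NGs; the remaining settings (MODE, miss_rate, noise_rate, Gaussian_noise) are the constants defined near the top of train.py.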
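The feature-level contrastive objective in loss.py follows the usual two-view InfoNCE layout: for B rows per view, the 2B x 2B similarity matrix has its positives on the +B and -B off-diagonals, and every other off-diagonal entry is a negative (mask_correlated_samples removes the diagonal and the positive pairs). Below is a toy shape check with random, L2-normalised embeddings, purely to illustrate how forward_feature is called twice with swapped arguments in train.py; the values themselves are meaningless.

# Toy check of the contrastive loss -- illustrative only.
import torch
import torch.nn.functional as F
from loss import Loss

B = 8                                                   # toy batch size
criterion = Loss(batch_size=B, class_num=5, temperature_f=0.5, device='cpu')

z_mask = F.normalize(torch.randn(B, 128), dim=1)        # stands in for mask_z
z_noise = F.normalize(torch.randn(B, 128), dim=1)       # stands in for noise_z

# Each of the 2B anchors has one positive (its counterpart in the other view)
# and 2B - 2 negatives, as selected by mask_correlated_samples.
loss = criterion.forward_feature(z_mask, z_noise) + criterion.forward_feature(z_noise, z_mask)
print(loss.item())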