├── README.txt
├── data
│   └── bdgp.npz
├── dataloader.py
├── demo.py
├── evaluation.py
├── loss.py
├── make_mask.py
├── meta_network.py
├── network.py
└── requirements.txt

--------------------------------------------------------------------------------
/README.txt:
--------------------------------------------------------------------------------
# To start training a model on the provided datasets, e.g., BDGP, run:
python demo.py --dataset bdgp --miss_rate 0.1

# Acknowledgements
We thank the PyTorch implementations of DS3L (https://www.lamda.nju.edu.cn/code_DS3L.ashx?AspxAutoDetectCookieSupport=1), Meta-Weight-Net (https://github.com/xjtushujun/meta-weight-net) and learning-to-reweight-examples (https://github.com/danieltan07/learning-to-reweight-examples).

--------------------------------------------------------------------------------
/data/bdgp.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gasteinh/DSIMVC/9448423221d43f2a0f1951840515c326868a16ca/data/bdgp.npz

--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset, Sampler
import numpy as np
import torch


class MultiviewDataset(Dataset):
    def __init__(self, num_views, data_list, labels):
        self.num_views = num_views
        self.data_list = data_list
        self.labels = labels

    def __len__(self):
        return self.data_list[0].shape[0]

    def __getitem__(self, idx):
        data = []
        for i in range(self.num_views):
            data.append(torch.tensor(self.data_list[i][idx].astype('float32')))
        return data, torch.tensor(self.labels[idx]), torch.tensor(idx).long()


def load_data(name):
    """
    :param name: name of the dataset
    :return:
        data_list: Python list containing all views, each view a numpy array
        labels: ground-truth labels as a numpy array
        dims: Python list with the dimension of each view
        num_views: number of views
        data_size: number of samples
        class_num: number of categories
    """
    data_path = "./data/"
    path = data_path + name + '.npz'
    data = np.load(path)
    num_views = int(data['n_views'])
    data_list = []
    for i in range(num_views):
        x = data[f"view_{i}"]
        if len(x.shape) > 2:
            x = x.reshape([x.shape[0], -1])
        data_list.append(x.astype(np.float32))
    labels = data['labels']
    dims = []
    for i in range(num_views):
        dims.append(data_list[i].shape[1])
    class_num = int(labels.max()) + 1
    data_size = data_list[0].shape[0]

    return data_list, labels, dims, num_views, data_size, class_num


class RandomSampler(Sampler):
    """ Draws num_sample indices by concatenating random permutations,
    so each full pass over the data is without replacement. """
    def __init__(self, num_data, num_sample):
        iterations = num_sample // num_data + 1
        self.indices = torch.cat([torch.randperm(num_data) for _ in range(iterations)]).tolist()[:num_sample]

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)
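

# Usage sketch (illustration only; mirrors demo.py and assumes ./data/bdgp.npz is present):
# load a dataset, wrap it, and draw a fixed number of batches via RandomSampler.
def _usage_sketch():
    from torch.utils.data import DataLoader
    data_list, labels, dims, num_views, data_size, class_num = load_data('bdgp')
    dataset = MultiviewDataset(num_views, data_list, labels)
    loader = DataLoader(dataset, batch_size=256, drop_last=True,
                        sampler=RandomSampler(len(dataset), 100 * 256))
    xs, y, idx = next(iter(loader))  # xs: list of per-view (256, dim) float tensors
    return xs, y, idx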
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import faiss
import argparse
from meta_network import WNet, SafeNetwork, Online
from network import Network
from dataloader import load_data, MultiviewDataset, RandomSampler
from loss import Loss
from make_mask import get_mask
from evaluation import evaluate
import copy


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser(description='train')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--dataset', default='bdgp')
parser.add_argument("--view", type=int, default=2)
parser.add_argument("--feature_dim", type=int, default=512)
parser.add_argument("--high_feature_dim", type=int, default=128)
parser.add_argument('--lr_wnet', type=float, default=0.0004)
parser.add_argument('--meta_lr', type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=120)
parser.add_argument('--lr_decay_factor', type=float, default=0.2)
parser.add_argument('--lr_decay_iter', type=int, default=20)
parser.add_argument('--K', type=int, default=3)
parser.add_argument('--interval', type=int, default=1)
parser.add_argument('--initial_epochs', type=int, default=100)
parser.add_argument('--pretrain_epochs', type=int, default=100)
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--miss_rate', default=0.1, type=float)
parser.add_argument('--T', default=10, type=int)
parser.add_argument('--iterations', default=200, type=int)
args = parser.parse_args()


data_list, Y, dims, total_view, data_size, class_num = load_data(args.dataset)
view = total_view
miss_rate = args.miss_rate
incomplete_loader = None

if args.dataset not in ['ccv']:
    for v in range(total_view):
        min_max_scaler = MinMaxScaler()
        data_list[v] = min_max_scaler.fit_transform(data_list[v])
record_data_list = copy.deepcopy(data_list)


if args.dataset == 'bdgp':
    args.initial_epochs = 30
    args.pretrain_epochs = 100
    args.iterations = 100
if args.dataset == 'mnist_usps':
    args.initial_epochs = 80
    args.pretrain_epochs = 100
    args.iterations = 200
if args.dataset == 'ccv':
    args.initial_epochs = 30
    args.pretrain_epochs = 100
    args.iterations = 300
if args.dataset == 'multi-fashion':
    args.initial_epochs = 100
    args.pretrain_epochs = 200
    args.iterations = 300


def get_model():
    return SafeNetwork(view, dims, args.feature_dim, args.high_feature_dim, class_num).to(device)


def pretrain(com_dataset):
    """
    Pretrain per-view autoencoders on the complete data.
    :return: state dict of the pretrained model
    """
    print("Initializing network parameters...")
    pretrain_model = Online(view, dims, args.feature_dim).to(device)
    loader = DataLoader(com_dataset, batch_size=args.batch_size, shuffle=True)
    opti = torch.optim.Adam(pretrain_model.params(), lr=0.0003)
    criterion = torch.nn.MSELoss()
    for epoch in range(args.pretrain_epochs):
        for batch_idx, (xs, _, _) in enumerate(loader):
            for v in range(view):
                xs[v] = xs[v].to(device)
            xrs = pretrain_model(xs)
            loss_list = []
            for v in range(view):
                loss_list.append(criterion(xs[v], xrs[v]))
            loss = sum(loss_list)

            opti.zero_grad()
            loss.backward()
            opti.step()
    return pretrain_model.state_dict()
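

# Illustration (not called by the training pipeline): the weighting net maps each incomplete
# sample's view-averaged soft assignment to a scalar confidence in (0, 1) via its sigmoid output.
def _wnet_weight_sketch():
    wnet = WNet(class_num, 100, 1).to(device)
    q_mean = torch.softmax(torch.randn(args.batch_size, class_num, device=device), dim=1)
    return wnet(q_mean)  # shape (batch_size, 1), one weight per incomplete sample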


def bi_level_train(model, criterion, optimizer, class_num, view,
                   com_loader, full_loader, mask, incomplete_ind):
    wnet_label = WNet(class_num, 100, 1).to(device)
    memory = Memory()
    memory.bi = True
    wnet_label.train()
    iteration = 0

    optimizer_wnet_label = torch.optim.Adam(wnet_label.params(), lr=args.lr_wnet)

    for com_batch, incomplete_batch in zip(com_loader, incomplete_loader):
        xs, _, _ = com_batch
        incomplete_xs, _, _ = incomplete_batch
        iteration += 1
        for v in range(view):
            xs[v] = xs[v].to(device)
            incomplete_xs[v] = incomplete_xs[v].to(device)

        model.train()
        meta_net = get_model()
        meta_net.load_state_dict(model.state_dict())

        com_hs, com_qs, incomplete_hs, incomplete_qs = meta_net(xs, incomplete_xs)

        loss_list = []
        for v in range(view):
            for w in range(v + 1, view):
                loss_list.append(criterion.forward_feature(com_hs[v], com_hs[w]))
                loss_list.append(criterion.forward_label(com_qs[v], com_qs[w]))
        loss_hat = sum(loss_list)

        cost_w_labels = []
        cost_w_features = []
        for v in range(view):
            for w in range(v + 1, view):
                l_f, l_l = criterion.forward_feature2(incomplete_hs[v], incomplete_hs[w]), criterion.forward_label(incomplete_qs[v], incomplete_qs[w])
                cost_w_labels.append(l_l)
                cost_w_features.append(l_f)

        weight_label = wnet_label(sum(incomplete_qs) / view)
        norm_label = torch.sum(weight_label)

        for v in range(len(cost_w_labels)):
            if norm_label != 0:
                loss_hat += (torch.sum(cost_w_features[v] * weight_label) / norm_label
                             + torch.sum(cost_w_labels[v] * weight_label) / norm_label)
            else:
                loss_hat += torch.sum(cost_w_labels[v] * weight_label + cost_w_features[v] * weight_label)

        meta_net.zero_grad()
        grads = torch.autograd.grad(loss_hat, (meta_net.params()), create_graph=True)
        meta_net.update_params(lr_inner=args.meta_lr, source_params=grads)
        del grads

        com_hs, com_qs, _, _ = meta_net(xs, incomplete_xs)

        loss_list = []
        for v in range(view):
            for w in range(v + 1, view):
                loss_list.append(criterion.forward_feature(com_hs[v], com_hs[w]))
                loss_list.append(criterion.forward_label(com_qs[v], com_qs[w]))

        l_g_meta = sum(loss_list)

        optimizer_wnet_label.zero_grad()
        l_g_meta.backward()
        optimizer_wnet_label.step()

        com_hs, com_qs, incomplete_hs, incomplete_qs = model(xs, incomplete_xs)

        loss_list = []
        for v in range(view):
            for w in range(v + 1, view):
                loss_list.append(criterion.forward_feature(com_hs[v], com_hs[w]))
                loss_list.append(criterion.forward_label(com_qs[v], com_qs[w]))

        loss = sum(loss_list)

        cost_w_labels = []
        cost_w_features = []
        for v in range(view):
            for w in range(v + 1, view):
                l_f, l_l = criterion.forward_feature2(incomplete_hs[v], incomplete_hs[w]), criterion.forward_label(incomplete_qs[v], incomplete_qs[w])
                cost_w_labels.append(l_l)
                cost_w_features.append(l_f)

        with torch.no_grad():
            weight_label = wnet_label(sum(incomplete_qs) / view)
            norm_label = torch.sum(weight_label)

        for v in range(len(cost_w_labels)):
            if norm_label != 0:
                loss += (torch.sum(cost_w_labels[v] * weight_label) / norm_label
                         + torch.sum(cost_w_features[v] * weight_label) / norm_label)
            else:
                loss += torch.sum(cost_w_labels[v] * weight_label + cost_w_features[v] * weight_label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        memory.update_feature(model, full_loader, mask, incomplete_ind, iteration)

    acc, nmi, pur = valid(model, mask)

    return acc, nmi, pur
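

# Toy illustration of the fusion performed in valid() below: per-view soft assignments are
# summed with missing views zeroed, then each row is divided by its count of observed views.
def _fusion_sketch():
    q1 = np.array([[0.9, 0.1], [0.2, 0.8]])  # view-0 soft labels
    q2 = np.array([[0.6, 0.4], [0.0, 0.0]])  # view-1; sample 1 is unobserved, already zeroed
    m = np.array([[1, 1], [1, 0]])           # observation mask, shape (samples, views)
    fused = (q1 + q2) / np.sum(m, axis=1, keepdims=True)
    return np.argmax(fused, axis=1)          # -> array([0, 1])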


def valid(model, mask):
    pred_vec = []
    with torch.no_grad():
        input_data = []
        for v in range(view):
            data_v = torch.from_numpy(record_data_list[v]).to(device)
            input_data.append(data_v)
        output, _ = model.forward_cluster(input_data)
        for v in range(view):
            # zero out predictions from missing views (boolean mask moved to the model's device)
            miss_ind = torch.from_numpy(mask[:, v] == 0).to(device)
            output[v][miss_ind] = 0
        sum_ind = np.sum(mask, axis=1, keepdims=True)
        output = sum(output) / torch.from_numpy(sum_ind).to(device)
        pred_vec.extend(output.detach().cpu().numpy())

    pred_vec = np.argmax(np.array(pred_vec), axis=1)
    acc, nmi, pur = evaluate(Y, pred_vec)
    print('ACC = {:.4f} NMI = {:.4f} PUR = {:.4f}'.format(acc, nmi, pur))
    return acc, nmi, pur


class Memory:
    def __init__(self):
        self.features = None
        self.alpha = args.alpha
        self.interval = args.interval
        self.bi = False

    def cal_cur_feature(self, model, loader):
        features = []
        for v in range(view):
            features.append([])

        for _, (xs, y, _) in enumerate(loader):
            for v in range(view):
                xs[v] = xs[v].to(device)
            with torch.no_grad():
                if self.bi:
                    hs, _, _ = model.forward_xs(xs)
                else:
                    hs, _, _ = model(xs)
            for v in range(view):
                fea = hs[v].detach().cpu().numpy()
                features[v].extend(fea)

        for v in range(view):
            features[v] = np.array(features[v])

        return features

    def update_feature(self, model, loader, mask, incomplete_ind, epoch):
        topK = 600
        model.eval()
        cur_features = self.cal_cur_feature(model, loader)
        indices = []
        if epoch == 1:
            self.features = cur_features
            for v in range(view):
                fea = np.array(self.features[v])
                n, dim = fea.shape[0], fea.shape[1]
                index = faiss.IndexFlatIP(dim)
                index.add(fea)
                _, ind = index.search(fea, topK + 1)  # the sample itself is included
                indices.append(ind[:, 1:])
            return indices
        elif epoch % self.interval == 0:
            for v in range(view):
                # exponential moving average of features, renormalized to unit length
                f_v = (1 - self.alpha) * self.features[v] + self.alpha * cur_features[v]
                self.features[v] = f_v / np.linalg.norm(f_v, axis=1, keepdims=True)

                n, dim = self.features[v].shape[0], self.features[v].shape[1]
                index = faiss.IndexFlatIP(dim)
                index.add(self.features[v])
                _, ind = index.search(self.features[v], topK + 1)  # the sample itself is included
                indices.append(ind[:, 1:])
            if self.bi:
                make_imputation(mask, indices, incomplete_ind)
            return indices
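

# Stand-alone sketch of the neighbor search used in Memory.update_feature: an inner-product
# index over L2-normalized rows is cosine-similarity search, and column 0 of the result is
# each sample's match with itself, so it is dropped.
def _knn_sketch(fea, topK=600):
    fea = np.ascontiguousarray(fea, dtype=np.float32)      # faiss expects contiguous float32
    fea = fea / np.linalg.norm(fea, axis=1, keepdims=True)
    index = faiss.IndexFlatIP(fea.shape[1])
    index.add(fea)
    _, ind = index.search(fea, topK + 1)
    return ind[:, 1:]                                      # (n, topK) neighbor indices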


def make_imputation(mask, indices, incomplete_ind):
    global data_list

    for v in range(view):
        for i in range(data_size):
            if mask[i, v] == 0:
                predicts = []
                for w in range(view):
                    # neighbors are searched only in the observed views
                    if w != v and mask[i, w] != 0:
                        neigh_w = indices[w][i]
                        for n_w in range(neigh_w.shape[0]):
                            if mask[neigh_w[n_w], v] != 0 and mask[neigh_w[n_w], w] != 0:
                                predicts.append(data_list[v][neigh_w[n_w]])
                                if len(predicts) >= args.K:
                                    break

                assert len(predicts) >= args.K
                fill_sample = np.mean(predicts, axis=0)
                data_list[v][i] = fill_sample

    global incomplete_loader
    incomplete_data = []
    for v in range(view):
        incomplete_data.append(data_list[v][incomplete_ind])
    incomplete_label = Y[incomplete_ind]
    incomplete_dataset = MultiviewDataset(view, incomplete_data, incomplete_label)
    incomplete_loader = DataLoader(
        incomplete_dataset, args.batch_size, drop_last=True,
        sampler=RandomSampler(len(incomplete_dataset), args.iterations * args.batch_size)
    )


def initial(com_dataset, full_loader, criterion, mask, incomplete_ind):
    print("Initializing neighbors...")
    online_net = Network(view, dims, args.feature_dim, args.high_feature_dim, class_num).to(device)
    loader = DataLoader(com_dataset, batch_size=256, shuffle=True, drop_last=True)
    mse_loader = DataLoader(com_dataset, batch_size=256, shuffle=True)
    opti = torch.optim.Adam(online_net.parameters(), lr=0.0003, weight_decay=0.)
    mse = torch.nn.MSELoss()

    memory = Memory()
    memory.interval = 1
    epochs = args.initial_epochs

    # reconstruction pretraining on complete data
    for e in range(1, 201):
        for xs, _, _ in mse_loader:
            for v in range(view):
                xs[v] = xs[v].to(device)

            xrs = online_net.forward_mse(xs)

            loss_list = []
            for v in range(view):
                loss_list.append(mse(xrs[v], xs[v]))
            loss = sum(loss_list)

            opti.zero_grad()
            loss.backward()
            opti.step()

    for e in range(1, epochs + 1):
        for xs, _, _ in loader:
            for v in range(view):
                xs[v] = xs[v].to(device)

            hs, qs, _ = online_net(xs)

            loss_list = []
            for v in range(view):
                for w in range(v + 1, view):
                    loss_list.append(criterion.forward_feature(hs[v], hs[w]))
                    loss_list.append(criterion.forward_label(qs[v], qs[w]))
            loss = sum(loss_list)

            opti.zero_grad()
            loss.backward()
            opti.step()

    # initialize neighbors with the pretrained model
    indices = memory.update_feature(online_net, full_loader, mask, incomplete_ind, epoch=1)
    make_imputation(mask, indices, incomplete_ind)
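

# Toy illustration of the fill rule in make_imputation above: a missing entry in view v is
# replaced by the mean of the first K neighbor samples (found through an observed view w)
# that observe both views.
def _imputation_sketch():
    neighbor_rows = np.array([[0.2, 0.4], [0.0, 0.6], [0.1, 0.5]])  # K = 3 usable neighbors
    return np.mean(neighbor_rows, axis=0)                           # -> array([0.1, 0.5])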


def main():
    global data_list
    result_record = {"ACC": [], "NMI": [], "PUR": []}
    for t in range(1, args.T + 1):
        print("--------Iter:{}--------".format(t))

        # reset the (possibly imputed) views to their original values for each trial
        data_list = copy.deepcopy(record_data_list)
        mask = get_mask(view, data_size, miss_rate)
        sum_vec = np.sum(mask, axis=1, keepdims=True)
        complete_index = (sum_vec[:, 0]) == view
        mv_data = []
        for v in range(view):
            mv_data.append(data_list[v][complete_index])
        mv_label = Y[complete_index]
        com_dataset = MultiviewDataset(view, mv_data, mv_label)
        com_loader = DataLoader(
            com_dataset, args.batch_size, drop_last=True,
            sampler=RandomSampler(len(com_dataset), args.iterations * args.batch_size)
        )
        full_dataset = MultiviewDataset(view, data_list, Y)
        full_loader = DataLoader(full_dataset, batch_size=args.batch_size, shuffle=False)
        incomplete_ind = (sum_vec[:, 0]) != view

        model = get_model()
        state_dict = pretrain(com_dataset)
        model.load_state_dict(state_dict, strict=False)
        optimizer = torch.optim.Adam(model.params(), lr=0.0003, weight_decay=0.)
        criterion = Loss(args.batch_size, class_num, view, device)
        initial(com_dataset, full_loader, criterion, mask, incomplete_ind)
        acc, nmi, pur = bi_level_train(model, criterion, optimizer, class_num, view, com_loader,
                                       full_loader, mask, incomplete_ind)
        result_record["ACC"].append(acc)
        result_record["NMI"].append(nmi)
        result_record["PUR"].append(pur)

    print("----------------Training Finished----------------")
    print("----------------Final Results----------------")
    print("ACC (mean) = {:.4f} ACC (std) = {:.4f}".format(np.mean(result_record["ACC"]), np.std(result_record["ACC"])))
    print("NMI (mean) = {:.4f} NMI (std) = {:.4f}".format(np.mean(result_record["NMI"]), np.std(result_record["NMI"])))
    print("PUR (mean) = {:.4f} PUR (std) = {:.4f}".format(np.mean(result_record["PUR"]), np.std(result_record["PUR"])))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
from sklearn.metrics import v_measure_score, accuracy_score
from scipy.optimize import linear_sum_assignment
import numpy as np


def cluster_acc(y_true, y_pred):
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    # Hungarian matching between predicted clusters and ground-truth classes
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size


def purity(y_true, y_pred):
    y_true = y_true.astype(np.int64)  # astype returns a copy, so the caller's labels are not remapped in place
    y_voted_labels = np.zeros(y_true.shape)
    labels = np.unique(y_true)
    ordered_labels = np.arange(labels.shape[0])
    for k in range(labels.shape[0]):
        y_true[y_true == labels[k]] = ordered_labels[k]
    labels = np.unique(y_true)
    bins = np.concatenate((labels, [np.max(labels) + 1]), axis=0)

    for cluster in np.unique(y_pred):
        hist, _ = np.histogram(y_true[y_pred == cluster], bins=bins)
        winner = np.argmax(hist)
        y_voted_labels[y_pred == cluster] = winner

    return accuracy_score(y_true, y_voted_labels)


def evaluate(label, pred):
    nmi = v_measure_score(label, pred)
    acc = cluster_acc(label, pred)
    pur = purity(label, pred)

    return acc, nmi, pur
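

# Sanity-check sketch: all three metrics are invariant to a permutation of cluster ids, so a
# relabeled-but-identical clustering scores 1.0 on ACC, NMI and PUR.
def _permutation_invariance_sketch():
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([2, 2, 0, 0, 1, 1])  # the same partition under different ids
    return evaluate(y_true, y_pred)        # -> (1.0, 1.0, 1.0)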
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F


class Loss(nn.Module):
    def __init__(self, batch_size, class_num, view, device):
        super(Loss, self).__init__()
        self.batch_size = batch_size
        self.class_num = class_num
        self.device = device
        self.view = view

        self.mask = self.mask_correlated_samples(batch_size)
        self.similarity = nn.CosineSimilarity(dim=2)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, N):
        mask = torch.ones((N, N))
        mask = mask.fill_diagonal_(0)
        for i in range(N // 2):
            mask[i, N // 2 + i] = 0
            mask[N // 2 + i, i] = 0
        mask = mask.bool()
        return mask

    def mask_correlated_samples2(self, N):
        m1 = torch.ones((N // 2, N // 2))
        m1 = m1.fill_diagonal_(0)
        m2 = torch.zeros((N // 2, N // 2))
        mask1 = torch.cat([m1, m2], dim=1)
        mask2 = torch.cat([m2, m1], dim=1)
        mask = torch.cat([mask1, mask2], dim=0)
        mask = mask.bool()
        return mask

    def mask_correlated_samples3(self, N):
        m1 = torch.ones((N // 2, N // 2))
        m1 = m1.fill_diagonal_(0)
        m2 = torch.zeros((N // 2, N // 2))
        mask1 = torch.cat([m2, m1], dim=1)
        mask2 = torch.cat([m1, m2], dim=1)
        mask = torch.cat([mask1, mask2], dim=0)
        mask = mask.bool()
        return mask

    def forward_feature(self, z1, z2, r=3.0):
        mask1 = (torch.norm(z1, p=2, dim=1) < np.sqrt(r)).float().unsqueeze(1)
        mask2 = (torch.norm(z2, p=2, dim=1) < np.sqrt(r)).float().unsqueeze(1)
        z1 = mask1 * z1 + (1 - mask1) * F.normalize(z1, dim=1) * np.sqrt(r)
        z2 = mask2 * z2 + (1 - mask2) * F.normalize(z2, dim=1) * np.sqrt(r)
        loss_part1 = -2 * torch.mean(z1 * z2) * z1.shape[1]
        square_term = torch.matmul(z1, z2.T) ** 2
        loss_part2 = torch.mean(torch.triu(square_term, diagonal=1) + torch.tril(square_term, diagonal=-1)) * \
            z1.shape[0] / (z1.shape[0] - 1)

        return loss_part1 + loss_part2

    def forward_feature2(self, z1, z2, r=3.0):
        # per-sample variant of forward_feature: returns one loss value per row so the
        # weighting net can reweight individual samples
        mask1 = (torch.norm(z1, p=2, dim=1) < np.sqrt(r)).float().unsqueeze(1)
        mask2 = (torch.norm(z2, p=2, dim=1) < np.sqrt(r)).float().unsqueeze(1)
        z1 = mask1 * z1 + (1 - mask1) * F.normalize(z1, dim=1) * np.sqrt(r)
        z2 = mask2 * z2 + (1 - mask2) * F.normalize(z2, dim=1) * np.sqrt(r)
        loss_part1 = -2 * torch.sum(z1 * z2, dim=1, keepdim=True) / z1.shape[0]
        square_term = torch.matmul(z1, z2.T) ** 2
        loss_part2 = torch.sum(torch.triu(square_term, diagonal=1) + torch.tril(square_term, diagonal=-1), dim=1,
                               keepdim=True) / (z1.shape[0] * (z1.shape[0] - 1))

        return loss_part1 + loss_part2

    def forward_label(self, q_i, q_j):
        p_i = q_i.sum(0).view(-1)
        p_i /= p_i.sum()
        ne_i = math.log(p_i.size(0)) + (p_i * torch.log(p_i)).sum()
        p_j = q_j.sum(0).view(-1)
        p_j /= p_j.sum()
        ne_j = math.log(p_j.size(0)) + (p_j * torch.log(p_j)).sum()
        entropy = ne_i + ne_j

        q_i = q_i.t()
        q_j = q_j.t()
        N = 2 * self.class_num

        q = torch.cat((q_i, q_j), dim=0)

        sim = self.similarity(q.unsqueeze(1), q.unsqueeze(0))
        sim_i_j = torch.diag(sim, self.class_num)
        sim_j_i = torch.diag(sim, -self.class_num)

        positive_clusters = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        mask = self.mask_correlated_samples2(N)
        negative_clusters = sim[mask].reshape(N, -1)

        labels = torch.zeros(N).to(positive_clusters.device).long()
        logits = torch.cat((positive_clusters, negative_clusters), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N

        return loss + entropy
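

# Illustration: for class_num = 2 (so N = 4 cluster rows), mask_correlated_samples2 keeps only
# within-view off-diagonal pairs, giving each row of the contrastive logits N//2 - 1 negatives.
def _mask2_sketch():
    mask = Loss(batch_size=4, class_num=2, view=2, device='cpu').mask_correlated_samples2(4)
    # mask.int() ->
    # tensor([[0, 1, 0, 0],
    #         [1, 0, 0, 0],
    #         [0, 0, 0, 1],
    #         [0, 0, 1, 0]], dtype=torch.int32)
    return mask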
--------------------------------------------------------------------------------
/make_mask.py:
--------------------------------------------------------------------------------
import numpy as np
import random
import math


def get_mask(view_num, data_size, missing_ratio):
    """
    :param view_num: number of views
    :param data_size: size of data
    :param missing_ratio: missing ratio
    :return: mask matrix of shape (data_size, view_num); 1 = observed, 0 = missing
    """
    assert view_num >= 2
    miss_sample_num = math.floor(data_size * missing_ratio)
    data_ind = [i for i in range(data_size)]
    random.shuffle(data_ind)
    miss_ind = data_ind[:miss_sample_num]
    mask = np.ones([data_size, view_num])
    for j in range(miss_sample_num):
        while True:
            rand_v = np.random.rand(view_num)
            v_threshold = np.random.rand(1)
            observed_ind = (rand_v >= v_threshold)
            ind_ = ~observed_ind
            rand_v[observed_ind] = 1
            rand_v[ind_] = 0
            # keep only patterns where at least one view is observed and at least one is missing
            if np.sum(rand_v) > 0 and np.sum(rand_v) < view_num:
                break
        mask[miss_ind[j]] = rand_v

    return mask
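

# Sanity-check sketch: exactly floor(data_size * missing_ratio) rows lose at least one view,
# and no row ever loses every view.
def _mask_sanity_sketch():
    mask = get_mask(view_num=2, data_size=1000, missing_ratio=0.1)
    assert mask.shape == (1000, 2)
    assert (mask.sum(axis=1) > 0).all()    # every sample keeps at least one view
    return (mask.sum(axis=1) < 2).mean()   # -> 0.1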
--------------------------------------------------------------------------------
/meta_network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
EPS = 1e-10


def to_var(x, requires_grad=True):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad)


class MetaModule(nn.Module):
    # adapted from: Adrien Ecoffet https://github.com/AdrienLE

    def params(self):
        for name, param in self.named_params(self):
            yield param

    def named_leaves(self):
        return []

    def named_submodules(self):
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            for name, p in curr_module.named_leaves():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        else:
            for name, p in curr_module._parameters.items():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p

    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        if source_params is not None:
            for tgt, src in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                grad = src
                if first_order:
                    grad = to_var(grad.detach().data)
                tmp = param_t - lr_inner * grad
                self.set_param(self, name_t, tmp)
        else:
            for name, param in self.named_params(self):
                if not detach:
                    grad = param.grad
                    if first_order:
                        grad = to_var(grad.detach().data)
                    tmp = param - lr_inner * grad
                    self.set_param(self, name, tmp)
                else:
                    param = param.detach_()  # https://blog.csdn.net/qq_39709535/article/details/81866686
                    self.set_param(self, name, param)

    def set_param(self, curr_mod, name, param):
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for child_name, mod in curr_mod.named_children():
                if module_name == child_name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def detach_params(self):
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copy(self, other, same_var=False):
        for name, param in other.named_params(other):
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)


class MetaLinear(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Linear(*args, **kwargs)

        # registered as buffers (not nn.Parameters) so they can later be replaced by
        # graph-connected tensors during the differentiable inner update
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class WNet(MetaModule):
    def __init__(self, input, hidden, output):
        super(WNet, self).__init__()
        self.linear1 = MetaLinear(input, hidden)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = MetaLinear(hidden, output)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        out = self.linear2(x)
        return torch.sigmoid(out)
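

# Condensed sketch of the bi-level step that demo.py builds on these modules (illustration
# only: TinyMeta and the squared-norm losses are stand-ins for the real networks and
# contrastive objectives).
def _bilevel_sketch():
    class TinyMeta(MetaModule):
        def __init__(self):
            super().__init__()
            self.fc = MetaLinear(4, 2)

        def forward(self, x):
            return self.fc(x)

    model, wnet = TinyMeta(), WNet(2, 16, 1)
    opt = torch.optim.SGD(model.params(), lr=0.1)
    opt_w = torch.optim.SGD(wnet.params(), lr=0.01)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    x_noisy = torch.randn(8, 4, device=dev)   # stands in for imputed (unreliable) samples
    x_clean = torch.randn(8, 4, device=dev)   # stands in for complete (trusted) samples

    # 1) virtual step: update a differentiable copy under wnet's per-sample weights
    meta = TinyMeta()
    meta.load_state_dict(model.state_dict())
    w = wnet(torch.softmax(meta(x_noisy), dim=1))
    inner = (w * meta(x_noisy).pow(2).sum(1, keepdim=True)).mean()
    grads = torch.autograd.grad(inner, meta.params(), create_graph=True)
    meta.update_params(lr_inner=0.1, source_params=grads)

    # 2) meta step: the clean loss evaluated after the virtual step drives wnet
    outer = meta(x_clean).pow(2).sum(1).mean()
    opt_w.zero_grad()
    outer.backward()
    opt_w.step()

    # 3) real step: update the actual model with the (now detached) weights
    with torch.no_grad():
        w = wnet(torch.softmax(model(x_noisy), dim=1))
    loss = (w * model(x_noisy).pow(2).sum(1, keepdim=True)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()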


class Encoder(MetaModule):
    def __init__(self, input_dim, feature_dim):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            MetaLinear(input_dim, 500),
            nn.ReLU(),
            MetaLinear(500, 500),
            nn.ReLU(),
            MetaLinear(500, 2000),
            nn.ReLU(),
            MetaLinear(2000, feature_dim),
        )

    def forward(self, x):
        return self.encoder(x)


class Decoder(MetaModule):
    def __init__(self, input_dim, feature_dim):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            MetaLinear(feature_dim, 2000),
            nn.ReLU(),
            MetaLinear(2000, 500),
            nn.ReLU(),
            MetaLinear(500, 500),
            nn.ReLU(),
            MetaLinear(500, input_dim)
        )

    def forward(self, x):
        return self.decoder(x)


class SafeNetwork(MetaModule):
    def __init__(self, view, input_size, feature_dim, high_feature_dim, class_num):
        super(SafeNetwork, self).__init__()
        self.encoders = []
        for v in range(view):
            self.encoders.append(Encoder(input_size[v], feature_dim))
        self.encoders = nn.ModuleList(self.encoders)
        self.feature_submodule = nn.Sequential(
            MetaLinear(feature_dim, feature_dim),
            nn.ReLU(),
            MetaLinear(feature_dim, high_feature_dim)
        )
        self.label_submodule = nn.Sequential(
            MetaLinear(feature_dim, feature_dim),
            nn.ReLU(),
            MetaLinear(feature_dim, class_num),
            nn.Softmax(dim=1))
        self.view = view

    def forward(self, xs, xs_incomplete):
        qs = []
        qs_incomplete = []
        zs = []
        zs_incomplete = []
        for v in range(self.view):
            x = xs[v]
            z = self.encoders[v](x)
            h = self.feature_submodule(z)
            q = self.label_submodule(z)
            zs.append(h)
            qs.append(q)

            x_ = xs_incomplete[v]
            z_ = self.encoders[v](x_)
            h_ = self.feature_submodule(z_)
            q_ = self.label_submodule(z_)
            zs_incomplete.append(h_)
            qs_incomplete.append(q_)
        return zs, qs, zs_incomplete, qs_incomplete

    def forward_xs(self, xs):
        hs = []
        for v in range(self.view):
            z = self.encoders[v](xs[v])
            hs.append(self.feature_submodule(z))

        return hs, None, None

    def forward_s(self, xs):
        qs = []
        zs = []
        for v in range(self.view):
            x = xs[v]
            z = self.encoders[v](x)
            h = self.feature_submodule(z)
            q = self.label_submodule(z)
            zs.append(h)
            qs.append(q)

        return zs, qs

    def forward_cluster(self, xs):
        qs = []
        preds = []
        for v in range(self.view):
            x = xs[v]
            z = self.encoders[v](x)
            q = self.label_submodule(z)
            pred = torch.argmax(q, dim=1)
            qs.append(q)
            preds.append(pred)
        return qs, preds


class Online(MetaModule):
    def __init__(self, view, input_size, feature_dim):
        super(Online, self).__init__()
        self.encoders = []
        self.decoders = []
        for v in range(view):
            self.encoders.append(Encoder(input_size[v], feature_dim))
            self.decoders.append(Decoder(input_size[v], feature_dim))
        self.encoders = nn.ModuleList(self.encoders)
        self.decoders = nn.ModuleList(self.decoders)
        self.view = view

    def forward(self, xs):
        xrs = []
        for v in range(self.view):
            z = self.encoders[v](xs[v])
            xrs.append(self.decoders[v](z))

        return xrs
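

# Shape sketch for SafeNetwork.forward, which runs a complete and an incomplete batch through
# the same per-view encoders (toy dimensions assumed):
def _safenetwork_shape_sketch():
    net = SafeNetwork(view=2, input_size=[20, 30], feature_dim=512,
                      high_feature_dim=128, class_num=5)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    xs = [torch.randn(8, 20, device=dev), torch.randn(8, 30, device=dev)]
    xs_ = [torch.randn(4, 20, device=dev), torch.randn(4, 30, device=dev)]
    hs, qs, hs_, qs_ = net(xs, xs_)
    return hs[0].shape, qs[0].shape  # torch.Size([8, 128]), torch.Size([8, 5])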
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn.functional import normalize
import torch


class Encoder(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2000),
            nn.ReLU(),
            nn.Linear(2000, feature_dim),
        )

    def forward(self, x):
        return self.encoder(x)


class Decoder(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(feature_dim, 2000),
            nn.ReLU(),
            nn.Linear(2000, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, input_dim)
        )

    def forward(self, x):
        return self.decoder(x)


class Network(nn.Module):
    def __init__(self, view, input_size, feature_dim, high_feature_dim, class_num):
        super(Network, self).__init__()
        self.encoders = []
        self.decoders = []
        for v in range(view):
            self.encoders.append(Encoder(input_size[v], feature_dim))
            self.decoders.append(Decoder(input_size[v], feature_dim))
        self.encoders = nn.ModuleList(self.encoders)
        self.decoders = nn.ModuleList(self.decoders)
        self.feature_submodule = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, high_feature_dim),
        )
        self.label_submodule = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, class_num),
            nn.Softmax(dim=1)
        )
        self.view = view

    def forward(self, xs):
        hs = []
        qs = []
        xrs = []
        for v in range(self.view):
            x = xs[v]
            z = self.encoders[v](x)
            h = normalize(self.feature_submodule(z), dim=1)
            q = self.label_submodule(z)
            xr = self.decoders[v](z)
            hs.append(h)
            qs.append(q)
            xrs.append(xr)
        return hs, qs, xrs

    def forward_mse(self, xs):
        xrs = []
        for v in range(self.view):
            z = self.encoders[v](xs[v])
            xrs.append(self.decoders[v](z))

        return xrs

    def forward_cluster(self, xs):
        qs = []
        preds = []
        for v in range(self.view):
            x = xs[v]
            z = self.encoders[v](x)
            q = self.label_submodule(z)
            pred = torch.argmax(q, dim=1)
            qs.append(q)
            preds.append(pred)
        return qs, preds

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
python==3.8.10
pytorch==1.11.0
numpy==1.22.3
scikit-learn==1.0.2
scipy==1.8.0
faiss-gpu==1.7.2
--------------------------------------------------------------------------------