├── AND ├── adjacency.py ├── config.yml ├── net │ └── gcn_v.py ├── train.py └── util │ ├── confidence.py │ ├── deduce.py │ ├── evaluate.py │ └── metrics.py ├── GCN ├── adjacency.py ├── cluster.py ├── config.yml ├── net │ ├── gat.py │ └── optim_modules.py ├── train.py └── util │ ├── confidence.py │ ├── deduce.py │ ├── evaluate.py │ ├── graph.py │ └── metrics.py ├── LICENSE ├── README.md ├── image ├── fig.png ├── results.png └── results2.png ├── requirements.txt ├── script ├── cluster.sh ├── faiss_search.sh ├── gene_adj.sh ├── max_Q_ind.sh ├── structure_space.sh ├── train_AND.sh └── train_GCN.sh └── tool ├── adjacency.py ├── faiss_search.py ├── gene_adj.py ├── gene_adj_adanets.py ├── knn.py ├── max_Q_ind.py └── struct_space.py /AND/adjacency.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import scipy.sparse as sp 6 | 7 | 8 | def row_normalize(mx): 9 | """Row-normalize sparse matrix""" 10 | rowsum = np.array(mx.sum(1)) 11 | # if rowsum <= 0, keep its previous value 12 | rowsum[rowsum <= 0] = 1 13 | r_inv = np.power(rowsum, -1).flatten() 14 | r_inv[np.isinf(r_inv)] = 0. 15 | r_mat_inv = sp.diags(r_inv) 16 | mx = r_mat_inv.dot(mx) 17 | return mx 18 | 19 | 20 | def build_symmetric_adj(adj, self_loop=True): 21 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 22 | if self_loop: 23 | adj = adj + sp.eye(adj.shape[0]) 24 | return adj 25 | 26 | 27 | def sparse_mx_to_indices_values(sparse_mx): 28 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 29 | indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64) 30 | values = sparse_mx.data 31 | shape = np.array(sparse_mx.shape) 32 | return indices, values, shape 33 | 34 | 35 | def indices_values_to_sparse_tensor(indices, values, shape): 36 | import torch 37 | indices = torch.from_numpy(indices) 38 | values = torch.from_numpy(values) 39 | shape = torch.Size(shape) 40 | return torch.sparse.FloatTensor(indices, values, shape) 41 | 42 | 43 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 44 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 45 | indices, values, shape = sparse_mx_to_indices_values(sparse_mx) 46 | return indices_values_to_sparse_tensor(indices, values, shape) 47 | -------------------------------------------------------------------------------- /AND/config.yml: -------------------------------------------------------------------------------- 1 | # model 2 | feat_dim: 256 3 | nhid: 512 4 | nclass: 1 5 | 6 | # optimizer 7 | lr: 0.01 8 | sgd_momentum: 0.9 9 | sgd_weight_decay: 0.00001 10 | lr_step : [0.5, 0.8, 0.9] 11 | factor: 0.1 12 | total_step: 40000 13 | cuda: True 14 | warmup_step: 128 15 | batchsize: 1024 16 | 17 | # output 18 | save_freq: 20000 19 | log_freq: 1 20 | val_freq: 1000 21 | # resume 22 | resume_path: 23 | -------------------------------------------------------------------------------- /AND/net/gcn_v.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | #from .utils import GraphConv, MeanAggregator 7 | from dgl.nn.pytorch import GraphConv 8 | import torch.nn.functional as F 9 | import math 10 | 11 | 12 | class huberloss(nn.Module): 13 | def __init__(self, delta): 14 | super(huberloss, self).__init__() 15 | self.delta = delta 16 | 17 | def forward(self, input_arr, target_arr): 18 | rate = input_arr / target_arr - 1 19 | loss = 
torch.where(torch.abs(rate) <= self.delta, 0.5*rate*rate, (torch.abs(rate) - 0.5*self.delta) * self.delta) 20 | return loss.mean() 21 | 22 | class MREloss(nn.Module): 23 | def __init__(self): 24 | super(MREloss, self).__init__() 25 | 26 | def forward(self, input_arr, target_arr): 27 | loss = torch.abs(input_arr / target_arr - 1) 28 | return loss.mean() 29 | 30 | 31 | class GCN_V(nn.Module): 32 | def __init__(self, feature_dim, nhid, nclass, dropout=0): 33 | super(GCN_V, self).__init__() 34 | self.lstm = nn.LSTM(input_size=feature_dim, hidden_size=feature_dim, num_layers=1, batch_first=True, dropout=dropout, bidirectional=True) 35 | self.out_proj = nn.Linear(2*feature_dim, feature_dim, bias=True) 36 | 37 | self.nclass = nclass 38 | self.mlp = nn.Sequential( 39 | nn.Linear(feature_dim, nhid), nn.PReLU(nhid), nn.Dropout(p=dropout), 40 | nn.Linear(nhid, feature_dim), nn.PReLU(feature_dim), nn.Dropout(p=dropout), 41 | ) 42 | self.regressor = nn.Linear(feature_dim, 1) 43 | #self.loss = torch.nn.MSELoss() 44 | #self.loss = MREloss() 45 | self.loss = huberloss(delta=1.0) 46 | 47 | def forward(self, data, output_feat=False, return_loss=False): 48 | assert not output_feat or not return_loss 49 | batch_feat, batch_label = data[0], data[1] 50 | 51 | # lstm block 52 | out, (hn, cn) = self.lstm(batch_feat) 53 | out = self.out_proj(out) 54 | out = (out + batch_feat) / math.sqrt(2.) 55 | 56 | # normalize before mean 57 | out = F.normalize(out, 2, dim=-1) 58 | out = out.mean(dim=1) 59 | out = F.normalize(out, 2, dim=-1) 60 | 61 | # mlp block 62 | residual = out 63 | out = self.mlp(out) 64 | out = (residual + out ) / math.sqrt(2.) 65 | 66 | # regressor block 67 | pred = self.regressor(out).view(-1) 68 | 69 | if output_feat: 70 | return pred, residual 71 | 72 | if return_loss: 73 | loss = self.loss(pred, batch_label) 74 | return pred, loss 75 | 76 | return pred 77 | 78 | 79 | def gcn_v(feature_dim, nhid, nclass=1, dropout=0., **kwargs): 80 | model = GCN_V(feature_dim=feature_dim, 81 | nhid=nhid, 82 | nclass=nclass, 83 | dropout=dropout) 84 | return model 85 | -------------------------------------------------------------------------------- /AND/train.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | from __future__ import division 3 | import torch 4 | import torch.optim as optim 5 | from adjacency import sparse_mx_to_torch_sparse_tensor 6 | from net.gcn_v import GCN_V 7 | import yaml 8 | from easydict import EasyDict 9 | from tensorboardX import SummaryWriter 10 | import numpy as np 11 | import scipy.sparse as sp 12 | import time 13 | import pprint 14 | import sys 15 | import os 16 | import argparse 17 | import math 18 | import pandas as pd 19 | import dgl 20 | import warnings 21 | from tqdm import tqdm 22 | 23 | 24 | class node_dataset(torch.utils.data.Dataset): 25 | def __init__(self, node_list, **kwargs): 26 | self.node_list = node_list 27 | 28 | def __getitem__(self, index): 29 | return self.node_list[index] 30 | 31 | def __len__(self): 32 | return len(self.node_list) 33 | 34 | def row_normalize(mx): 35 | """Row-normalize sparse matrix""" 36 | rowsum = np.array(mx.sum(1)) 37 | # if rowsum <= 0, keep its previous value 38 | rowsum[rowsum <= 0] = 1 39 | r_inv = np.power(rowsum, -1).flatten() 40 | r_inv[np.isinf(r_inv)] = 0. 
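# For intuition: each row of the matrix is rescaled to sum to 1, e.g. a row [2, 1, 1]
# with rowsum 4 becomes [0.5, 0.25, 0.25]. All-zero rows are left unchanged because
# rowsum was clamped to 1 above, so the isinf guard is only an extra safety net.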
41 | r_mat_inv = sp.diags(r_inv) 42 | mx = r_mat_inv.dot(mx) 43 | return mx 44 | 45 | class AverageMeter(object): 46 | def __init__(self): 47 | self.val = 0 48 | self.avg = 0 49 | self.sum = 0 50 | self.count = 0 51 | def reset(self): 52 | self.val = 0 53 | self.avg = 0 54 | self.sum = 0 55 | self.count = 0 56 | def update(self, val, n=1): 57 | self.val = val 58 | self.sum += val * n 59 | self.count += n 60 | self.avg = float(self.sum) / (self.count + 1e-10) 61 | 62 | class Timer(): 63 | def __init__(self, name='task', verbose=True): 64 | self.name = name 65 | self.verbose = verbose 66 | 67 | def __enter__(self): 68 | print('[begin {}]'.format(self.name)) 69 | self.start = time.time() 70 | return self 71 | 72 | def __exit__(self, exc_type, exc_val, exc_tb): 73 | if self.verbose: 74 | print('[done {}] use {:.3f} s'.format(self.name, time.time() - self.start)) 75 | return exc_type is None 76 | 77 | def adjust_lr(cur_epoch, param, cfg): 78 | if cur_epoch not in cfg.step_number: 79 | return 80 | ind = cfg.step_number.index(cur_epoch) 81 | for each in optimizer.param_groups: 82 | each['lr'] = lr 83 | 84 | def cos_lr(current_step, optimizer, cfg): 85 | if current_step < cfg.warmup_step: 86 | rate = 1.0 * current_step / cfg.warmup_step 87 | lr = cfg.lr * rate 88 | else: 89 | n1 = cfg.total_step - cfg.warmup_step 90 | n2 = current_step - cfg.warmup_step 91 | rate = (1 + math.cos(math.pi * n2 / n1)) / 2 92 | lr = cfg.lr * rate 93 | for each in optimizer.param_groups: 94 | each['lr'] = lr 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument('--config_file', type=str) 99 | parser.add_argument('--outpath', type=str) 100 | parser.add_argument('--phase', type=str) 101 | parser.add_argument('--train_featfile', type=str) 102 | parser.add_argument('--train_Ifile', type=str) 103 | parser.add_argument('--train_labelfile', type=str) 104 | parser.add_argument('--test_featfile', type=str) 105 | parser.add_argument('--test_Ifile', type=str) 106 | parser.add_argument('--test_labelfile', type=str) 107 | parser.add_argument('--resume_path', type=str) 108 | args = parser.parse_args() 109 | 110 | beg_time = time.time() 111 | config = yaml.load(open(args.config_file, "r"), Loader=yaml.FullLoader) 112 | cfg = EasyDict(config) 113 | cfg.step_number = [int(r * cfg.total_step) for r in cfg.lr_step] 114 | 115 | # force assignment 116 | for key, value in args._get_kwargs(): 117 | cfg[key] = value 118 | #cfg[list(dict(train_adjfile=train_adjfile).keys())[0]] = train_adjfile 119 | #cfg[list(dict(train_labelfile=train_labelfile).keys())[0]] = train_labelfile 120 | #cfg[list(dict(test_adjfile=test_adjfile).keys())[0]] = test_adjfile 121 | #cfg[list(dict(test_labelfile=test_labelfile).keys())[0]] = test_labelfile 122 | print("train hyper parameter list") 123 | pprint.pprint(cfg) 124 | 125 | # get model 126 | model = GCN_V(feature_dim=cfg.feat_dim, nhid=cfg.nhid, nclass=cfg.nclass, dropout=0.5) 127 | model.cuda() 128 | 129 | # get dataset 130 | scale_max = 80. 
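# scale_max bounds the regression target: the ground-truth k label is divided by
# scale_max below (train_label_s = train_label_k / scale_max), so e.g. k = 20 maps to
# a target of 20 / 80 = 0.25, and only the first int(scale_max) kNN columns are kept.
# At evaluation time the prediction is mapped back via k_hat = clip(round(pred * scale_max), 1, scale_max).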
131 | with Timer('load data'): 132 | train_feature = np.load(cfg.train_featfile) 133 | train_feature = train_feature / np.linalg.norm(train_feature, axis=1, keepdims=True) 134 | train_adj = np.load(cfg.train_Ifile)[:, :int(scale_max)] 135 | train_label_k = np.load(cfg.train_labelfile).astype(np.float32) 136 | train_label_s = train_label_k / scale_max 137 | train_feature = torch.FloatTensor(train_feature).cuda() 138 | train_label_s = torch.FloatTensor(train_label_s).cuda() 139 | train_data = (train_feature, train_adj, train_label_s) 140 | 141 | test_feature = np.load(cfg.test_featfile) 142 | test_feature = test_feature / np.linalg.norm(test_feature, axis=1, keepdims=True) 143 | test_adj = np.load(cfg.test_Ifile)[:, :int(scale_max)] 144 | test_label_k = np.load(cfg.test_labelfile).astype(np.float32) 145 | test_label_s = test_label_k / scale_max 146 | test_feature = torch.FloatTensor(test_feature).cuda() 147 | test_label_s = torch.FloatTensor(test_label_s).cuda() 148 | 149 | train_dataset = node_dataset(range(len(train_feature))) 150 | test_dataset = node_dataset(range(len(test_feature))) 151 | train_dataloader = torch.utils.data.DataLoader( 152 | dataset=train_dataset, 153 | batch_size=cfg.batchsize, 154 | shuffle=True, 155 | num_workers=16, 156 | pin_memory=True, 157 | drop_last=False) 158 | 159 | test_dataloader = torch.utils.data.DataLoader( 160 | dataset=test_dataset, 161 | batch_size=cfg.batchsize, 162 | shuffle=False, 163 | num_workers=16, 164 | pin_memory=True, 165 | drop_last=False) 166 | 167 | if cfg.phase == 'train': 168 | optimizer = optim.SGD(model.parameters(), cfg.lr, momentum=cfg.sgd_momentum, weight_decay=cfg.sgd_weight_decay) 169 | beg_step = 0 170 | if cfg.resume_path != None: 171 | beg_step = int(os.path.splitext(os.path.basename(cfg.resume_path))[0].split('_')[1]) 172 | with Timer('resume model from %s'%cfg.resume_path): 173 | ckpt = torch.load(cfg.resume_path, map_location='cpu') 174 | model.load_state_dict(ckpt['state_dict']) 175 | 176 | train_loss_meter = AverageMeter() 177 | train_kdiff_meter = AverageMeter() 178 | train_mre_meter = AverageMeter() 179 | test_loss_meter = AverageMeter() 180 | test_kdiff_meter = AverageMeter() 181 | test_mre_meter = AverageMeter() 182 | writer = SummaryWriter(os.path.join(cfg.outpath), filename_suffix='') 183 | 184 | current_step = beg_step 185 | break_flag = False 186 | while 1: 187 | if break_flag: 188 | break 189 | iter_begtime = time.time() 190 | for _, index in enumerate(train_dataloader): 191 | if current_step > cfg.total_step: 192 | break_flag = True 193 | break 194 | current_step += 1 195 | cos_lr(current_step, optimizer, cfg) 196 | 197 | batch_feature = train_feature[train_adj[index]] 198 | batch_label = train_label_s[index] 199 | batch_k = train_label_k[index] 200 | batch_data = (batch_feature, batch_label) 201 | 202 | model.train() 203 | pred_arr, train_loss = model(batch_data, return_loss=True) 204 | optimizer.zero_grad() 205 | train_loss.backward() 206 | optimizer.step() 207 | 208 | train_loss_meter.update(train_loss.item()) 209 | pred_arr = pred_arr.data.cpu().numpy() 210 | 211 | # add this clip 212 | k_hat = np.round(pred_arr * scale_max) 213 | k_hat[np.where(k_hat < 1)[0]] = 1 214 | k_hat[np.where(k_hat > scale_max)[0]] = scale_max 215 | 216 | train_kdiff = np.abs(k_hat - batch_k) 217 | train_kdiff_meter.update(train_kdiff.mean()) 218 | train_mre = train_kdiff / batch_k 219 | train_mre_meter.update(train_mre.mean()) 220 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step=current_step) 221 | 
writer.add_scalar('loss/train', train_loss.item(), global_step=current_step) 222 | writer.add_scalar('kdiff/train', train_kdiff_meter.val, global_step=current_step) 223 | writer.add_scalar('mre/train', train_mre_meter.val, global_step=current_step) 224 | if current_step % cfg.log_freq == 0: 225 | log = "step:{}, step_time:{:.3f}, lr:{:.8f}, trainloss:{:.4f}({:.4f}), trainkdiff:{:.2f}({:.2f}), trainmre:{:.2f}({:.2f}), testloss:{:.4f}({:.4f}), testkdiff:{:.2f}({:.2f}), testmre:{:.2f}({:.2f})".format(current_step, time.time()-iter_begtime, optimizer.param_groups[0]['lr'], train_loss_meter.val, train_loss_meter.avg, train_kdiff_meter.val, train_kdiff_meter.avg, train_mre_meter.val, train_mre_meter.avg, test_loss_meter.val, test_loss_meter.avg, test_kdiff_meter.val, test_kdiff_meter.avg, test_mre_meter.val, test_mre_meter.avg) 226 | print(log) 227 | iter_begtime = time.time() 228 | if (current_step) % cfg.save_freq == 0 and current_step > 0: 229 | torch.save({'state_dict' : model.state_dict(), 'step': current_step}, 230 | os.path.join(cfg.outpath, "ckpt_%s.pth"%(current_step))) 231 | 232 | if (current_step) % cfg.val_freq == 0 and current_step > 0: 233 | pred_list = [] 234 | model.eval() 235 | testloss_list = [] 236 | for step, index in enumerate(tqdm(test_dataloader, desc='test phase', disable=False)): 237 | 238 | batch_feature = test_feature[test_adj[index]] 239 | batch_label = test_label_s[index] 240 | batch_data = (batch_feature, batch_label) 241 | 242 | pred, test_loss = model(batch_data, return_loss=True) 243 | pred_list.append(pred.data.cpu().numpy()) 244 | testloss_list.append(test_loss.item()) 245 | 246 | pred_list = np.concatenate(pred_list) 247 | k_hat, k_arr = pred_list * scale_max, test_label_k 248 | 249 | # add this clip before eval 250 | k_hat = np.round(k_hat) 251 | k_hat[np.where(k_hat < 1)[0]] = 1 252 | k_hat[np.where(k_hat > scale_max)[0]] = scale_max 253 | 254 | test_kdiff = np.abs(np.round(k_hat) - k_arr.reshape(-1)) 255 | test_mre = test_kdiff / k_arr.reshape(-1) 256 | test_kdiff_meter.update(test_kdiff.mean()) 257 | test_mre_meter.update(test_mre.mean()) 258 | test_loss_meter.update(np.mean(testloss_list)) 259 | writer.add_scalar('loss/test', test_loss_meter.val, global_step=current_step) 260 | writer.add_scalar('kdiff/test', test_kdiff_meter.val, global_step=current_step) 261 | writer.add_scalar('mre/test', test_mre_meter.val, global_step=current_step) 262 | 263 | writer.close() 264 | else: 265 | ckpt = torch.load(cfg.resume_path, map_location='cpu') 266 | model.load_state_dict(ckpt['state_dict']) 267 | 268 | pred_list, gcnfeat_list = [], [] 269 | model.eval() 270 | beg_time = time.time() 271 | for step, index, in enumerate(test_dataloader): 272 | batch_feature = test_feature[test_adj[index]] 273 | batch_label = test_label_s[index] 274 | batch_data = (batch_feature, batch_label) 275 | 276 | pred, gcnfeat = model(batch_data, output_feat=True) 277 | pred_list.append(pred.data.cpu().numpy()) 278 | gcnfeat_list.append(gcnfeat.data.cpu().numpy()) 279 | print("time use %.4f"%(time.time()-beg_time)) 280 | 281 | pred_list = np.concatenate(pred_list) 282 | gcnfeat_arr = np.vstack(gcnfeat_list) 283 | gcnfeat_arr = gcnfeat_arr / np.linalg.norm(gcnfeat_arr, axis=1, keepdims=True) 284 | tag = os.path.splitext(os.path.basename(cfg.resume_path))[0] 285 | 286 | print("stat") 287 | k_hat, k_arr = pred_list * scale_max, test_label_k 288 | 289 | # add this clip before eval 290 | k_hat = np.round(k_hat) 291 | k_hat[np.where(k_hat < 1)[0]] = 1 292 | k_hat[np.where(k_hat > scale_max)[0]] = 
scale_max 293 | np.save(os.path.join(cfg.outpath, 'k_infer_pred'), np.round(k_hat)) 294 | 295 | print("time use", time.time() - beg_time) 296 | -------------------------------------------------------------------------------- /AND/util/confidence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | from itertools import groupby 7 | 8 | __all__ = ['density', 'confidence', 'confidence_to_peaks'] 9 | 10 | 11 | def density(dists, radius=0.3, use_weight=True): 12 | row, col = (dists < radius).nonzero() 13 | 14 | num, _ = dists.shape 15 | if use_weight: 16 | density = np.zeros((num, ), dtype=np.float32) 17 | for r, c in zip(row, col): 18 | density[r] += 1 - dists[r, c] 19 | else: 20 | density = np.zeros((num, ), dtype=np.int32) 21 | for k, g in groupby(row): 22 | density[k] = len(list(g)) 23 | return density 24 | 25 | 26 | def s_nbr(dists, nbrs, idx2lb, **kwargs): 27 | ''' use supervised confidence defined on neigborhood 28 | ''' 29 | num, _ = dists.shape 30 | conf = np.zeros((num, ), dtype=np.float32) 31 | contain_neg = 0 32 | for i, (nbr, dist) in enumerate(zip(nbrs, dists)): 33 | lb = idx2lb[i] 34 | pos, neg = 0, 0 35 | for j, n in enumerate(nbr): 36 | if idx2lb[n] == lb: 37 | pos += 1 - dist[j] 38 | else: 39 | neg += 1 - dist[j] 40 | conf[i] = pos - neg 41 | if neg > 0: 42 | contain_neg += 1 43 | print('#contain_neg:', contain_neg) 44 | conf /= np.abs(conf).max() 45 | return conf 46 | 47 | 48 | def s_nbr_size_norm(dists, nbrs, idx2lb, **kwargs): 49 | ''' use supervised confidence defined on neigborhood (norm by size) 50 | ''' 51 | num, _ = dists.shape 52 | conf = np.zeros((num, ), dtype=np.float32) 53 | contain_neg = 0 54 | max_size = 0 55 | for i, (nbr, dist) in enumerate(zip(nbrs, dists)): 56 | size = 0 57 | pos, neg = 0, 0 58 | lb = idx2lb[i] 59 | for j, n in enumerate(nbr): 60 | if idx2lb[n] == lb: 61 | pos += 1 - dist[j] 62 | else: 63 | neg += 1 - dist[j] 64 | size += 1 65 | conf[i] = pos - neg 66 | max_size = max(max_size, size) 67 | if neg > 0: 68 | contain_neg += 1 69 | print('#contain_neg:', contain_neg) 70 | print('max_size: {}'.format(max_size)) 71 | conf /= max_size 72 | return conf 73 | 74 | 75 | def s_avg(feats, idx2lb, lb2idxs, **kwargs): 76 | ''' use average similarity of intra-nodes 77 | ''' 78 | num = len(idx2lb) 79 | conf = np.zeros((num, ), dtype=np.float32) 80 | for i in range(num): 81 | lb = idx2lb[i] 82 | idxs = lb2idxs[lb] 83 | idxs.remove(i) 84 | if len(idxs) == 0: 85 | continue 86 | feat = feats[i, :] 87 | conf[i] = feat.dot(feats[idxs, :].T).mean() 88 | eps = 1e-6 89 | assert -1 - eps <= conf.min() <= conf.max( 90 | ) <= 1 + eps, "min: {}, max: {}".format(conf.min(), conf.max()) 91 | return conf 92 | 93 | 94 | def s_center(feats, idx2lb, lb2idxs, **kwargs): 95 | ''' use average similarity of intra-nodes 96 | ''' 97 | num = len(idx2lb) 98 | conf = np.zeros((num, ), dtype=np.float32) 99 | for i in range(num): 100 | lb = idx2lb[i] 101 | idxs = lb2idxs[lb] 102 | if len(idxs) == 0: 103 | continue 104 | feat = feats[i, :] 105 | feat_center = feats[idxs, :].mean(axis=0) 106 | conf[i] = feat.dot(feat_center.T) 107 | eps = 1e-6 108 | assert -1 - eps <= conf.min() <= conf.max( 109 | ) <= 1 + eps, "min: {}, max: {}".format(conf.min(), conf.max()) 110 | return conf 111 | 112 | 113 | def confidence(metric='s_nbr', **kwargs): 114 | metric2func = { 115 | 's_nbr': s_nbr, 116 | 's_nbr_size_norm': s_nbr_size_norm, 117 | 's_avg': s_avg, 118 | 
's_center': s_center, 119 | } 120 | if metric in metric2func: 121 | func = metric2func[metric] 122 | else: 123 | raise KeyError('Only support confidence metircs: {}'.format( 124 | metric2func.keys())) 125 | 126 | conf = func(**kwargs) 127 | return conf 128 | 129 | 130 | def confidence_to_peaks(dists, nbrs, confidence, max_conn=1): 131 | # Note that dists has been sorted in ascending order 132 | assert dists.shape[0] == confidence.shape[0] 133 | assert dists.shape == nbrs.shape 134 | 135 | num, _ = dists.shape 136 | dist2peak = {i: [] for i in range(num)} 137 | peaks = {i: [] for i in range(num)} 138 | 139 | for i, nbr in tqdm(enumerate(nbrs)): 140 | nbr_conf = confidence[nbr] 141 | for j, c in enumerate(nbr_conf): 142 | nbr_idx = nbr[j] 143 | if i == nbr_idx or c <= confidence[i]: 144 | continue 145 | dist2peak[i].append(dists[i, j]) 146 | peaks[i].append(nbr_idx) 147 | if len(dist2peak[i]) >= max_conn: 148 | break 149 | return dist2peak, peaks 150 | -------------------------------------------------------------------------------- /AND/util/deduce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | __all__ = ['peaks_to_labels'] 4 | 5 | 6 | def _find_parent(parent, u): 7 | idx = [] 8 | # parent is a fixed point 9 | while (u != parent[u]): 10 | idx.append(u) 11 | u = parent[u] 12 | for i in idx: 13 | parent[i] = u 14 | return u 15 | 16 | 17 | def edge_to_connected_graph(edges, num): 18 | parent = list(range(num)) 19 | for u, v in edges: 20 | p_u = _find_parent(parent, u) 21 | p_v = _find_parent(parent, v) 22 | parent[p_u] = p_v 23 | 24 | for i in range(num): 25 | parent[i] = _find_parent(parent, i) 26 | remap = {} 27 | uf = np.unique(np.array(parent)) 28 | for i, f in enumerate(uf): 29 | remap[f] = i 30 | cluster_id = np.array([remap[f] for f in parent]) 31 | return cluster_id 32 | 33 | 34 | def peaks_to_edges(peaks, dist2peak, tau): 35 | edges = [] 36 | for src in peaks: 37 | dsts = peaks[src] 38 | dists = dist2peak[src] 39 | for dst, dist in zip(dsts, dists): 40 | if src == dst or dist >= 1 - tau: 41 | continue 42 | edges.append([src, dst]) 43 | return edges 44 | 45 | 46 | def peaks_to_labels(peaks, dist2peak, tau, inst_num): 47 | edges = peaks_to_edges(peaks, dist2peak, tau) 48 | pred_labels = edge_to_connected_graph(edges, inst_num) 49 | return pred_labels 50 | -------------------------------------------------------------------------------- /AND/util/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import inspect 5 | import argparse 6 | import numpy as np 7 | import util.metrics as metrics 8 | import time 9 | 10 | class TextColors: 11 | #HEADER = '\033[35m' 12 | #OKBLUE = '\033[34m' 13 | #OKGREEN = '\033[32m' 14 | #WARNING = '\033[33m' 15 | #FATAL = '\033[31m' 16 | #ENDC = '\033[0m' 17 | #BOLD = '\033[1m' 18 | #UNDERLINE = '\033[4m' 19 | HEADER = '' 20 | OKBLUE = '' 21 | OKGREEN = '' 22 | WARNING = '' 23 | FATAL = '' 24 | ENDC = '' 25 | BOLD = '' 26 | UNDERLINE = '' 27 | 28 | class Timer(): 29 | def __init__(self, name='task', verbose=True): 30 | self.name = name 31 | self.verbose = verbose 32 | 33 | def __enter__(self): 34 | self.start = time.time() 35 | return self 36 | 37 | def __exit__(self, exc_type, exc_val, exc_tb): 38 | if self.verbose: 39 | print('[Time] {} consumes {:.4f} s'.format( 40 | self.name, 41 | time.time() - self.start)) 42 | return exc_type is None 43 | 44 | 45 | def _read_meta(fn): 46 | labels = 
list() 47 | lb_set = set() 48 | with open(fn) as f: 49 | for lb in f.readlines(): 50 | lb = int(lb.strip()) 51 | labels.append(lb) 52 | lb_set.add(lb) 53 | return np.array(labels), lb_set 54 | 55 | 56 | def evaluate(gt_labels, pred_labels, metric='pairwise'): 57 | if isinstance(gt_labels, str) and isinstance(pred_labels, str): 58 | print('[gt_labels] {}'.format(gt_labels)) 59 | print('[pred_labels] {}'.format(pred_labels)) 60 | gt_labels, gt_lb_set = _read_meta(gt_labels) 61 | pred_labels, pred_lb_set = _read_meta(pred_labels) 62 | 63 | print('#inst: gt({}) vs pred({})'.format(len(gt_labels), 64 | len(pred_labels))) 65 | print('#cls: gt({}) vs pred({})'.format(len(gt_lb_set), 66 | len(pred_lb_set))) 67 | 68 | metric_func = metrics.__dict__[metric] 69 | 70 | with Timer('evaluate with {}{}{}'.format(TextColors.FATAL, metric, 71 | TextColors.ENDC)): 72 | result = metric_func(gt_labels, pred_labels) 73 | if isinstance(result, np.float): 74 | print('{}{}: {:.4f}{}'.format(TextColors.OKGREEN, metric, result, 75 | TextColors.ENDC)) 76 | else: 77 | ave_pre, ave_rec, fscore = result 78 | print('{}ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}'.format( 79 | TextColors.OKGREEN, ave_pre, ave_rec, fscore, TextColors.ENDC)) 80 | 81 | 82 | if __name__ == '__main__': 83 | metric_funcs = inspect.getmembers(metrics, inspect.isfunction) 84 | metric_names = [n for n, _ in metric_funcs] 85 | 86 | parser = argparse.ArgumentParser(description='Evaluate Cluster') 87 | parser.add_argument('--gt_labels', type=str, required=True) 88 | parser.add_argument('--pred_labels', type=str, required=True) 89 | parser.add_argument('--metric', default='pairwise', choices=metric_names) 90 | args = parser.parse_args() 91 | 92 | evaluate(args.gt_labels, args.pred_labels, args.metric) 93 | -------------------------------------------------------------------------------- /AND/util/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import division 5 | 6 | import numpy as np 7 | from sklearn.metrics.cluster import (contingency_matrix, 8 | normalized_mutual_info_score) 9 | from sklearn.metrics import (precision_score, recall_score) 10 | 11 | __all__ = ['pairwise', 'bcubed', 'nmi', 'precision', 'recall', 'accuracy'] 12 | 13 | 14 | def _check(gt_labels, pred_labels): 15 | if gt_labels.ndim != 1: 16 | raise ValueError("gt_labels must be 1D: shape is %r" % 17 | (gt_labels.shape, )) 18 | if pred_labels.ndim != 1: 19 | raise ValueError("pred_labels must be 1D: shape is %r" % 20 | (pred_labels.shape, )) 21 | if gt_labels.shape != pred_labels.shape: 22 | raise ValueError( 23 | "gt_labels and pred_labels must have same size, got %d and %d" % 24 | (gt_labels.shape[0], pred_labels.shape[0])) 25 | return gt_labels, pred_labels 26 | 27 | 28 | def _get_lb2idxs(labels): 29 | lb2idxs = {} 30 | for idx, lb in enumerate(labels): 31 | if lb not in lb2idxs: 32 | lb2idxs[lb] = [] 33 | lb2idxs[lb].append(idx) 34 | return lb2idxs 35 | 36 | 37 | def _compute_fscore(pre, rec): 38 | return 2. * pre * rec / (pre + rec) 39 | 40 | 41 | def fowlkes_mallows_score(gt_labels, pred_labels, sparse=True): 42 | ''' The original function is from `sklearn.metrics.fowlkes_mallows_score`. 43 | We output the pairwise precision, pairwise recall and F-measure, 44 | instead of calculating the geometry mean of precision and recall. 
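    In terms of the contingency matrix c (rows: ground-truth clusters, columns:
    predicted clusters) and n samples, the code below computes
        tk = sum_ij c_ij^2 - n,  pk = sum_j (sum_i c_ij)^2 - n,  qk = sum_i (sum_j c_ij)^2 - n,
    i.e. the numbers of ordered same-cluster pairs shared by both partitions, in the
    prediction, and in the ground truth, and reports precision tk / pk, recall tk / qk
    and their harmonic mean as the F-score.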
45 | ''' 46 | n_samples, = gt_labels.shape 47 | 48 | c = contingency_matrix(gt_labels, pred_labels, sparse=sparse) 49 | tk = np.dot(c.data, c.data) - n_samples 50 | pk = np.sum(np.asarray(c.sum(axis=0)).ravel()**2) - n_samples 51 | qk = np.sum(np.asarray(c.sum(axis=1)).ravel()**2) - n_samples 52 | 53 | avg_pre = tk / pk 54 | avg_rec = tk / qk 55 | fscore = _compute_fscore(avg_pre, avg_rec) 56 | 57 | return avg_pre, avg_rec, fscore 58 | 59 | 60 | def pairwise(gt_labels, pred_labels, sparse=True): 61 | _check(gt_labels, pred_labels) 62 | return fowlkes_mallows_score(gt_labels, pred_labels, sparse) 63 | 64 | 65 | def bcubed(gt_labels, pred_labels): 66 | _check(gt_labels, pred_labels) 67 | 68 | gt_lb2idxs = _get_lb2idxs(gt_labels) 69 | pred_lb2idxs = _get_lb2idxs(pred_labels) 70 | 71 | num_lbs = len(gt_lb2idxs) 72 | pre = np.zeros(num_lbs) 73 | rec = np.zeros(num_lbs) 74 | gt_num = np.zeros(num_lbs) 75 | 76 | for i, gt_idxs in enumerate(gt_lb2idxs.values()): 77 | all_pred_lbs = np.unique(pred_labels[gt_idxs]) 78 | gt_num[i] = len(gt_idxs) 79 | for pred_lb in all_pred_lbs: 80 | pred_idxs = pred_lb2idxs[pred_lb] 81 | n = 1. * np.intersect1d(gt_idxs, pred_idxs).size 82 | pre[i] += n**2 / len(pred_idxs) 83 | rec[i] += n**2 / gt_num[i] 84 | 85 | gt_num = gt_num.sum() 86 | avg_pre = pre.sum() / gt_num 87 | avg_rec = rec.sum() / gt_num 88 | fscore = _compute_fscore(avg_pre, avg_rec) 89 | 90 | return avg_pre, avg_rec, fscore 91 | 92 | 93 | def nmi(gt_labels, pred_labels): 94 | return normalized_mutual_info_score(pred_labels, gt_labels) 95 | 96 | 97 | def precision(gt_labels, pred_labels): 98 | return precision_score(gt_labels, pred_labels) 99 | 100 | 101 | def recall(gt_labels, pred_labels): 102 | return recall_score(gt_labels, pred_labels) 103 | 104 | 105 | def accuracy(gt_labels, pred_labels): 106 | return np.mean(gt_labels == pred_labels) 107 | -------------------------------------------------------------------------------- /GCN/adjacency.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import scipy.sparse as sp 6 | 7 | 8 | def row_normalize(mx): 9 | """Row-normalize sparse matrix""" 10 | rowsum = np.array(mx.sum(1)) 11 | # if rowsum <= 0, keep its previous value 12 | rowsum[rowsum <= 0] = 1 13 | r_inv = np.power(rowsum, -1).flatten() 14 | r_inv[np.isinf(r_inv)] = 0. 
15 | r_mat_inv = sp.diags(r_inv) 16 | mx = r_mat_inv.dot(mx) 17 | return mx 18 | 19 | 20 | def build_symmetric_adj(adj, self_loop=True): 21 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 22 | if self_loop: 23 | adj = adj + sp.eye(adj.shape[0]) 24 | return adj 25 | 26 | 27 | def sparse_mx_to_indices_values(sparse_mx): 28 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 29 | indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64) 30 | values = sparse_mx.data 31 | shape = np.array(sparse_mx.shape) 32 | return indices, values, shape 33 | 34 | 35 | def indices_values_to_sparse_tensor(indices, values, shape): 36 | import torch 37 | indices = torch.from_numpy(indices) 38 | values = torch.from_numpy(values) 39 | shape = torch.Size(shape) 40 | return torch.sparse.FloatTensor(indices, values, shape) 41 | 42 | 43 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 44 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 45 | indices, values, shape = sparse_mx_to_indices_values(sparse_mx) 46 | return indices_values_to_sparse_tensor(indices, values, shape) 47 | -------------------------------------------------------------------------------- /GCN/cluster.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import sys 3 | import time 4 | from util.confidence import confidence_to_peaks 5 | from util.deduce import peaks_to_labels 6 | from util.evaluate import evaluate 7 | from multiprocessing import Process, Manager 8 | import numpy as np 9 | import os 10 | import torch 11 | import torch.nn.functional as F 12 | from util.graph import graph_propagation_onecut 13 | from multiprocessing import Pool 14 | from util.deduce import edge_to_connected_graph 15 | 16 | metric_list = ['bcubed', 'pairwise', 'nmi'] 17 | topN = 121 18 | def worker(param): 19 | i, pdict = param 20 | query_nodeid = ngbr_arr[i, 0] 21 | for j in range(1, dist_arr.shape[1]): 22 | doc_nodeid = ngbr_arr[i, j] 23 | tpl = (query_nodeid, doc_nodeid) 24 | dist = dist_arr[query_nodeid, j] 25 | if dist > cos_dist_thres: 26 | continue 27 | pdict[tpl] = dist 28 | 29 | def format(dist_arr, ngbr_arr): 30 | edge_list, score_list = [], [] 31 | for i in range(dist_arr.shape[0]): 32 | query_nodeid = ngbr_arr[i, 0] 33 | for j in range(1, dist_arr.shape[1]): 34 | doc_nodeid = ngbr_arr[i, j] 35 | tpl = (query_nodeid, doc_nodeid) 36 | score = 1 - dist_arr[query_nodeid, j] 37 | if score < cos_sim_thres: 38 | continue 39 | edge_list.append(tpl) 40 | score_list.append(score) 41 | edge_arr, score_arr = np.array(edge_list), np.array(score_list) 42 | return edge_arr, score_arr 43 | 44 | def clusters2labels(clusters, n_nodes): 45 | labels = (-1)* np.ones((n_nodes,)) 46 | for ci, c in enumerate(clusters): 47 | for xid in c: 48 | labels[xid.name] = ci 49 | 50 | cnt = len(clusters) 51 | idx_list = np.where(labels < 0)[0] 52 | for idx in idx_list: 53 | labels[idx] = cnt 54 | cnt += 1 55 | assert np.sum(labels<0) < 1 56 | return labels 57 | 58 | def disjoint_set_onecut(sim_dict, thres, num): 59 | edge_arr = [] 60 | for edge, score in sim_dict.items(): 61 | if score < thres: 62 | continue 63 | edge_arr.append(edge) 64 | pred_arr = edge_to_connected_graph(edge_arr, num) 65 | return pred_arr 66 | 67 | def get_eval(cos_sim_thres): 68 | pred_arr = disjoint_set_onecut(sim_dict, cos_sim_thres, len(gt_arr)) 69 | print("now is %s done"%cos_sim_thres) 70 | res_str = "" 71 | for metric in metric_list: 72 | res_str += evaluate(gt_arr, pred_arr, metric) 73 | res_str += "\n" 74 | return 
res_str
75 | 
76 | if __name__ == "__main__":
77 |     Ifile, Dfile, gtfile = sys.argv[1], sys.argv[2], sys.argv[3]
78 | 
79 |     gt_arr = np.load(gtfile)
80 |     nbr_arr = np.load(Ifile).astype(np.int32)[:, :topN]
81 |     dist_arr = np.load(Dfile)[:, :topN]
82 |     sim_dict = {}
83 |     for query_nodeid, (nbr, dist) in enumerate(zip(nbr_arr, dist_arr)):
84 |         for j, doc_nodeid in enumerate(nbr): # neighbors start at index 0 and include the node itself
85 |             if query_nodeid < doc_nodeid:
86 |                 tpl = (query_nodeid, doc_nodeid)
87 |             else:
88 |                 tpl = (doc_nodeid, query_nodeid)
89 |             sim_dict[tpl] = 1 - dist[j]
90 | 
91 |     sim_thres = 0.96
92 |     print('now sim thres %.2f'%sim_thres)
93 |     pred_arr = disjoint_set_onecut(sim_dict, sim_thres, len(gt_arr))
94 |     for metric in metric_list:
95 |         print(evaluate(gt_arr, pred_arr, metric))
96 | 
--------------------------------------------------------------------------------
/GCN/config.yml:
--------------------------------------------------------------------------------
1 | # model
2 | feat_dim: 256
3 | nhid: 512
4 | nclass: 8573
5 | 
6 | # optimizer
7 | lr: 0.1 #0.01
8 | sgd_momentum: 0.9
9 | sgd_weight_decay: 0.00001
10 | lr_step : [0.5, 0.8, 0.9]
11 | factor: 0.1
12 | total_step: 35000 #
13 | cuda: True
14 | fp16: False
15 | batchsize: 1 #
16 | warmup_step: 1024 #
17 | 
18 | # output
19 | save_freq: 5000 #
20 | log_freq: 1
21 | # resume
22 | resume_path:
23 | 
--------------------------------------------------------------------------------
/GCN/net/gat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 | from dgl.nn.pytorch import SAGEConv
6 | from .optim_modules import BallClusterLearningLoss, ClusterLoss
7 | import torch.nn.functional as F
8 | 
9 | class GCN_V(nn.Module):
10 |     def __init__(self, feature_dim, nhid, nclass, dropout=0, losstype='allall', margin=1., pweight=4., pmargin=1.0):
11 |         super(GCN_V, self).__init__()
12 | 
13 |         self.sage1 = SAGEConv(feature_dim, nhid, aggregator_type='gcn', activation=F.relu)
14 |         self.sage2 = SAGEConv(nhid, nhid, aggregator_type='gcn', activation=F.relu)
15 | 
16 |         self.nclass = nclass
17 |         self.fc = nn.Sequential(nn.Linear(nhid, nhid), nn.PReLU(nhid))
18 |         self.loss = torch.nn.MSELoss()
19 |         self.bclloss = ClusterLoss(losstype=losstype, margin=margin, alpha_pos=pweight, pmargin=pmargin)
20 | 
21 |     def forward(self, data, output_feat=False, return_loss=False):
22 |         assert not output_feat or not return_loss
23 |         x, block_list, label, idlabel = data[0], data[1], data[2], data[3]
24 | 
25 |         # layer1
26 |         gcnfeat = self.sage1(block_list[0], x)
27 |         gcnfeat = F.normalize(gcnfeat, p=2, dim=1)
28 | 
29 |         # layer2
30 |         gcnfeat = self.sage2(block_list[1], gcnfeat)
31 | 
32 |         # layer3
33 |         fcfeat = self.fc(gcnfeat)
34 |         fcfeat = F.normalize(fcfeat, dim=1)
35 | 
36 |         if output_feat:
37 |             return fcfeat, gcnfeat
38 | 
39 |         if return_loss:
40 |             bclloss_dict = self.bclloss(fcfeat, label)
41 |             return bclloss_dict
42 | 
43 |         return fcfeat
44 | 
--------------------------------------------------------------------------------
/GCN/net/optim_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Classes for all models and loss functions for clustering.
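The losses defined here are ClusterLoss, BallClusterLearningLoss, PrototypicalLoss,
ContrastiveLoss, TripletLoss and LogisticDiscriminantLoss; GCN/net/gat.py imports only
BallClusterLearningLoss and ClusterLoss.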
3 | 4 | FC --> ReLU --> BN --> Dropout --> FC 5 | """ 6 | 7 | import warnings 8 | import numpy as np 9 | 10 | # Torch 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | # Local imports 16 | #import utils 17 | #import config 18 | #import lorentz 19 | 20 | 21 | def sqeuclidean_pdist(x, y=None): 22 | """Fast and efficient implementation of ||X - Y||^2 = ||X||^2 + ||Y||^2 - 2 X^T Y 23 | Input: x is a Nxd matrix 24 | y is an optional Mxd matirx 25 | Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:] 26 | if y is not given then use 'y=x'. 27 | i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2 28 | """ 29 | 30 | x_norm = (x**2).sum(1).unsqueeze(1) 31 | if y is not None: 32 | y_t = torch.transpose(y, 0, 1) 33 | y_norm = (y**2).sum(1).unsqueeze(0) 34 | else: 35 | y_t = torch.transpose(x, 0, 1) 36 | y_norm = x_norm.squeeze().unsqueeze(0) 37 | 38 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) 39 | # get rid of NaNs 40 | dist[torch.isnan(dist)] = 0. 41 | # clamp negative stuff to 0 42 | dist = torch.clamp(dist, 0., np.inf) 43 | # ensure diagonal is 0 44 | if y is None: 45 | dist[dist == torch.diag(dist)] = 0. 46 | 47 | return dist 48 | 49 | # ============================================================================ # 50 | # LOSS FUNCTIONS # 51 | # ============================================================================ # 52 | 53 | def get_pos_loss(decision_mat, label_mat, beta, k=1): 54 | if label_mat.sum() == 0: 55 | print('cut pos 0') 56 | return torch.tensor(0.) 57 | # decision is not confidence, which is always to positive in 0~1, and not contain the class information 58 | # decision is a real value and the class infomation in the sign 59 | # this is decision mat, for pos sample, the smaller of the val the harder of the case 60 | decision_arr = decision_mat[label_mat].topk(k=k, largest=False)[0] 61 | loss = F.relu(beta - decision_arr) 62 | loss = loss.mean() 63 | return loss 64 | 65 | def get_neg_loss(decision_mat, label_mat, beta, k=1): 66 | if label_mat.sum() == 0: 67 | print('cut neg 0') 68 | return torch.tensor(0.) 
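    # Worked example, assuming decision_mat holds cosine similarities as in ClusterLoss
    # below and beta is the negative margin: a negative pair with similarity 0.3 has
    # decision -0.3 and incurs relu(beta - (-0.3)) = beta + 0.3, while a negative pair
    # with similarity below -beta incurs nothing.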
69 | # this is decision mat, for neg sample, the larger of the val, the harder of the case 70 | decision_arr = -1 * decision_mat[label_mat].topk(k=k, largest=True)[0] 71 | loss = F.relu(beta - decision_arr) 72 | loss = loss.mean() 73 | return loss 74 | 75 | def cosine_sim(x, y): 76 | return torch.mm(x, y.T) 77 | 78 | def euclidean_dist(x, y): 79 | m, n = x.size(0), y.size(0) 80 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 81 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 82 | dist = xx + yy 83 | dist.addmm_(1, -2, x, y.t()) 84 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 85 | return dist 86 | 87 | class ClusterLoss(nn.Module): 88 | def __init__(self, beta_pos=0.5, beta_neg=0.5, alpha_pos=4., alpha_neg=1., gamma_eps=0.05, losstype='allall', margin=1., pmargin=1.): 89 | super(ClusterLoss, self).__init__() 90 | #self.gamma_eps = gamma_eps 91 | self.alpha_pos = alpha_pos 92 | self.alpha_neg = alpha_neg 93 | 94 | self.beta_pos = nn.Parameter(torch.tensor(beta_pos)) 95 | #self.beta_neg = nn.Parameter(torch.tensor(beta_neg)) 96 | self.losstype = losstype 97 | self.margin = margin 98 | self.pmargin = pmargin 99 | 100 | def forward(self, X, labels): 101 | #beta_pos = F.softplus(self.beta_pos) 102 | #beta_neg = F.softplus(self.beta_neg) 103 | #beta_neg = beta_pos 104 | beta_pos = self.pmargin 105 | beta_neg = self.margin 106 | 107 | #X_copy = X.clone().detach() 108 | #decision_mat = (X.unsqueeze(1) - X_copy.unsqueeze(0)).pow(2).sum(2).sqrt() # euclidean distance 109 | #decision_mat = euclidean_dist(X, X) 110 | decision_mat = cosine_sim(X, X) 111 | label_mat = (labels.unsqueeze(0) == labels.unsqueeze(1)) 112 | #print("beta pos", beta_pos.item(), 'beta neg', beta_neg.item()) 113 | print("losstype", self.losstype, "margin", self.margin, 'pweight', self.alpha_pos, 'pmargin', self.pmargin) 114 | 115 | neg_label_mat = (1-label_mat.float()).bool() 116 | if self.losstype == 'maxmax': 117 | pos_loss = get_pos_loss(decision_mat, label_mat, beta_pos, k=1) 118 | neg_loss = get_neg_loss(decision_mat, neg_label_mat, beta_neg, k=1) 119 | elif self.losstype == 'allmax': 120 | pos_loss = get_pos_loss(decision_mat, label_mat, beta_pos, k=label_mat.sum().item()) 121 | neg_loss = get_neg_loss(decision_mat, neg_label_mat, beta_neg, k=1) 122 | elif self.losstype == 'allall': 123 | pos_loss = get_pos_loss(decision_mat, label_mat, beta_pos, k=label_mat.sum().item()) 124 | neg_loss = get_neg_loss(decision_mat, neg_label_mat, beta_neg, k=neg_label_mat.sum().item()) 125 | elif self.losstype == 'alltopk': 126 | pos_loss = get_pos_loss(decision_mat, label_mat, beta_pos, k=label_mat.sum().item()) 127 | neg_loss = get_neg_loss(decision_mat, neg_label_mat, beta_neg, k=min(len(labels), neg_label_mat.sum().item()) ) 128 | else: 129 | raise ValueError('loss type %s not implement'%self.lossstype) 130 | 131 | losses = {'ctrd_pos': pos_loss * self.alpha_pos, 'ctrd_neg': neg_loss * self.alpha_neg} 132 | return losses 133 | 134 | class BallClusterLearningLoss(nn.Module): 135 | """Final BCL method 136 | space: 'sqeuclidean' or 'lorentz' 137 | init_bias: initialize bias to this value 138 | temperature: sampling temperature (decayed in main training loop) 139 | beta: Lorentz beta for comparison in Lorentz space 140 | """ 141 | 142 | def __init__(self, gpuid=0, space='sqeuclidean', l2norm=True, gamma_eps=0.05, 143 | init_bias=0.1, learn_bias=True, beta=0.01, alpha_pos=4., alpha_neg=1., mult_bias=0.): 144 | """Initialize 145 | """ 146 | super(BallClusterLearningLoss, self).__init__() 147 | 
self.space = space 148 | self.learn_bias = learn_bias 149 | self.l2norm = l2norm 150 | self.beta = beta 151 | self.gamma_eps = gamma_eps 152 | self.alpha_pos = alpha_pos 153 | self.alpha_neg = alpha_neg 154 | self.mult_bias = mult_bias 155 | self.gpuid = gpuid 156 | 157 | self.h_bias = nn.Parameter(torch.tensor(init_bias)) 158 | self.bias = F.softplus(self.h_bias) 159 | 160 | def forward(self, Xemb, labels): 161 | """ 162 | Xemb: N x D, N features, D embedding dimension 163 | labels: ground-truth cluster indices 164 | NOTE: labels are not necessarily ordered indices, just unique ones, don't use for indexing! 165 | """ 166 | 167 | self.bias = F.softplus(self.h_bias) 168 | 169 | # get unique labels to loop over clusters 170 | unique_labels = labels.unique() # torch vector on cuda 171 | K = unique_labels.numel() 172 | N = Xemb.size(0) 173 | 174 | # collect centroids, cluster-assignment matrix, and positive cluster index 175 | centroids = [] 176 | pos_idx = -1 * torch.ones_like(labels) # N vector, each in [0 .. K-1] 177 | clst_assignments = torch.zeros(N, K).to(self.gpuid) # NxK {0, 1} matrix 178 | for k, clid in enumerate(unique_labels): 179 | idx = labels == clid 180 | # assign all samples with cluster clid as k 181 | pos_idx[idx] = k 182 | clst_assignments[idx, k] = 1 183 | # collect all features 184 | Xclst = Xemb[idx, :] 185 | centroid = Xclst.mean(0) 186 | centroid = centroid / centroid.norm() 187 | # collect centroids 188 | centroids.append(centroid) 189 | centroids = torch.stack(centroids, dim=0) 190 | 191 | # pairwise distances between all embeddings of the batch and the centroids 192 | XC_dist = (Xemb.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(2) 193 | 194 | # add bias to the distances indexed appropriately 195 | pos_bias = self.bias 196 | neg_bias = 9 * self.bias + self.gamma_eps 197 | 198 | # add bias and use "cross-entropy" loss on pos_idx 199 | bias_adds = clst_assignments * pos_bias + (1 - clst_assignments) * neg_bias 200 | final_distance = (-XC_dist + bias_adds) * 0.1 201 | # when not using bias, just ignore 202 | if self.bias == 0.: 203 | final_distance = -XC_dist * 0.1 204 | 205 | # make sure positive distances are below the pos-bias 206 | pos_distances = XC_dist.gather(1, pos_idx.unsqueeze(1)) 207 | pos_sample_loss = F.relu(pos_distances - pos_bias) 208 | 209 | # make sure negative distances are more than neg-bias 210 | #avg_neg_distances = XC_dist[1 - clst_assignments.byte()].view(N, K-1).mean(1) 211 | #min_neg_distances = XC_dist[1 - clst_assignments.byte()].view(N, K-1).min(1)[0] # [0] returns values not indices 212 | if len(XC_dist[(1 - clst_assignments).bool()]) == 0: 213 | neg_sample_loss = torch.Tensor([0]) 214 | print("===== neg sample is 0") 215 | else: 216 | min_neg_distances = XC_dist[(1 - clst_assignments).bool()].view(N, K-1).min(1)[0] # [0] returns values not indices 217 | neg_sample_loss = F.relu(neg_bias - min_neg_distances) 218 | 219 | pos_loss = pos_sample_loss.mean() 220 | neg_loss = neg_sample_loss.mean() 221 | 222 | losses = {'ctrd_pos': pos_loss * self.alpha_pos, 'ctrd_neg': neg_loss * self.alpha_neg} 223 | #losses = pos_loss * self.alpha_pos + neg_loss * self.alpha_neg 224 | 225 | return losses 226 | 227 | 228 | class PrototypicalLoss(nn.Module): 229 | """Prototypical network like loss with bias 230 | p_ik = exp(- d(x^k_i, c^k) + b) / (exp(- d(x^k_i, c^k) + b) + sum_j exp(- d(x^k_i, c^j) + 2b)) 231 | Loss = -mean_k( mean_i ( -log p_ik )) 232 | space: 'sqeuclidean' or 'lorentz' 233 | init_bias: initialize bias to this value 234 | temperature: 
sampling temperature (decayed in main training loop) 235 | beta: Lorentz beta for comparison in Lorentz space 236 | """ 237 | 238 | def __init__(self, device, space='sqeuclidean', l2norm=False, gamma_eps=0.05, 239 | init_bias=0., learn_bias=False, beta=0.01, alpha_pos=1., alpha_neg=1., mult_bias=0.): 240 | """Initialize 241 | """ 242 | super(PrototypicalLoss, self).__init__() 243 | self.device = device 244 | self.space = space 245 | self.learn_bias = learn_bias 246 | self.l2norm = l2norm 247 | self.beta = beta 248 | self.gamma_eps = gamma_eps 249 | self.alpha_pos = alpha_pos 250 | self.alpha_neg = alpha_neg 251 | self.mult_bias = mult_bias 252 | 253 | self.bias = torch.tensor(init_bias).to(self.device) 254 | 255 | def forward(self, Xemb, scores, labels): 256 | """ 257 | Xemb: N x D, N features, D embedding dimension 258 | labels: ground-truth cluster indices 259 | NOTE: labels are not necessarily ordered indices, just unique ones, don't use for indexing! 260 | """ 261 | 262 | unique_labels = labels.unique() # torch vector on cuda 263 | K = unique_labels.numel() 264 | N = Xemb.size(0) 265 | 266 | # collect centroids, cluster-assignment matrix, and positive cluster index 267 | centroids = [] 268 | pos_idx = -1 * torch.ones_like(labels) # N vector, each in [0 .. K-1] 269 | clst_assignments = torch.zeros(N, K).to(self.device) # NxK {0, 1} matrix 270 | for k, clid in enumerate(unique_labels): 271 | idx = labels == clid 272 | # assign all samples with cluster clid as k 273 | pos_idx[idx] = k 274 | clst_assignments[idx, k] = 1 275 | # collect all features 276 | Xclst = Xemb[idx, :] 277 | centroid = Xclst.mean(0) 278 | # collect centroids 279 | centroids.append(centroid) 280 | centroids = torch.stack(centroids, dim=0) 281 | 282 | # pairwise distances between all embeddings of the batch and the centroids 283 | XC_dist = (Xemb.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(2) 284 | 285 | # add bias to the distances indexed appropriately 286 | pos_bias = self.bias 287 | neg_bias = 9 * self.bias + self.gamma_eps 288 | final_distance = -XC_dist * 0.1 289 | 290 | # compute cross-entropy 291 | pro_sample_loss = F.cross_entropy(final_distance, pos_idx, reduction='none') 292 | 293 | # do mean of means to get final loss value 294 | pro_loss = torch.tensor(0.).to(self.device) 295 | for clid in unique_labels: 296 | pro_loss += pro_sample_loss[labels == clid].mean() 297 | pro_loss /= K 298 | 299 | losses = {'ctrd_pro': pro_loss} 300 | 301 | return losses 302 | 303 | 304 | class ContrastiveLoss(nn.Module): 305 | """ 306 | In the original paper http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf 307 | Y = 0 for similar pairs ("positive") 308 | Y = 1 for dissimilar pairs ("negatives") 309 | L(Y, X1, X2) = (1 - Y) * 0.5 * D^2 + Y * 0.5 * (max(0, m - D))^2 310 | NOTE: distance is in Euclidean space, not sqeuclidean! 
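    Worked example with m = 1: a positive pair at distance D = 0.6 contributes
    0.5 * 0.6^2 = 0.18, a negative pair at the same distance contributes
    0.5 * (1 - 0.6)^2 = 0.08, and a negative pair farther apart than m contributes 0.
    (Here the margin m is the learnable softplus(h_bias).)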
311 | """ 312 | 313 | def __init__(self, device, l2norm=True, 314 | init_bias=1., learn_bias=True): 315 | """Initialize 316 | """ 317 | super(ContrastiveLoss, self).__init__() 318 | self.device = device 319 | self.learn_bias = learn_bias 320 | self.l2norm = l2norm 321 | 322 | self.h_bias = nn.Parameter(torch.tensor(init_bias)) 323 | self.bias = F.softplus(self.h_bias) 324 | 325 | def forward(self, Xemb, scores, labels): 326 | """ 327 | Xemb: N x D, N features, D embedding dimension 328 | labels: ground-truth cluster indices 329 | """ 330 | 331 | self.bias = F.softplus(self.h_bias) 332 | 333 | N = Xemb.size(0) 334 | match = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() # a NxN {0,1} matrix 335 | 336 | ### generate positive pairs, and pull corresponding features 337 | diag_mask = 1 - torch.eye(N).to(self.device) 338 | pos_idx = (diag_mask * match).nonzero() 339 | X1_pos = Xemb.index_select(0, pos_idx[:, 0]) 340 | X2_pos = Xemb.index_select(0, pos_idx[:, 1]) 341 | 342 | ### generate random negatives 343 | neg_idx = [] 344 | while len(neg_idx) < X1_pos.size(0): # match pairs for negatives 345 | idx = torch.randint(N, (2,)).long() 346 | if match[idx[0], idx[1]] == 0: 347 | neg_idx.append(idx) 348 | neg_idx = torch.stack(neg_idx).to(self.device) 349 | 350 | X1_neg = Xemb.index_select(0, neg_idx[:, 0]) 351 | X2_neg = Xemb.index_select(0, neg_idx[:, 1]) 352 | 353 | # compute distances (Euclidean!) 354 | pos_distances_sq = ((X1_pos - X2_pos) ** 2).sum(1) 355 | neg_distances = ((X1_neg - X2_neg) ** 2).sum(1).sqrt() 356 | 357 | # Loss = 0.5 * pos_distances_sq + 0.5 * (max(0, m - neg_distances))^2 358 | pos_loss = 0.5 * pos_distances_sq.mean() 359 | neg_loss = 0.5 * (F.relu(self.bias - neg_distances) ** 2).mean() 360 | 361 | return {'cont_pos': pos_loss, 'cont_neg': neg_loss} 362 | 363 | 364 | class TripletLoss(nn.Module): 365 | """ 366 | In the FaceNet paper https://arxiv.org/pdf/1503.03832.pdf 367 | L = max(0, d+ - d- + alpha) 368 | NOTE: distance is in sqeuclidean space! 
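    Worked example with alpha = 0.5 (the init_bias used below): d+ = 0.3 and d- = 0.6
    give L = max(0, 0.3 - 0.6 + 0.5) = 0.2; once d- exceeds d+ by at least alpha, the
    triplet contributes no loss.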
369 | """ 370 | 371 | def __init__(self, device, space='sqeuclidean', l2norm=True, 372 | init_bias=0.5, learn_bias=False): 373 | """Initialize 374 | """ 375 | super(TripletLoss, self).__init__() 376 | self.device = device 377 | self.space = space 378 | self.learn_bias = learn_bias 379 | self.l2norm = l2norm 380 | 381 | self.bias = torch.tensor(init_bias).to(self.device) 382 | 383 | def forward(self, Xemb, scores, labels): 384 | """ 385 | Xemb: N x D, N features, D embedding dimension 386 | labels: ground-truth cluster indices 387 | """ 388 | 389 | N = Xemb.size(0) 390 | match = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() # a NxN {0,1} matrix 391 | 392 | ### generate positive pairs, and pull corresponding features 393 | diag_mask = 1 - torch.eye(N).to(self.device) 394 | pos_idx = (diag_mask * match).nonzero() 395 | anc_idx = pos_idx[:, 0] 396 | pos_idx = pos_idx[:, 1] 397 | 398 | ### generate negatives for the same anchors as positive 399 | neg_idx = torch.zeros_like(pos_idx).long() 400 | for k in range(pos_idx.size(0)): 401 | this_negs = torch.nonzero(1 - match[pos_idx[k]]).squeeze() 402 | neg_idx[k] = this_negs[torch.randperm(this_negs.size(0))][0] 403 | 404 | X_anc = Xemb.index_select(0, anc_idx) 405 | X_pos = Xemb.index_select(0, pos_idx) 406 | X_neg = Xemb.index_select(0, neg_idx) 407 | 408 | # compute distances 409 | pos_distances = ((X_anc - X_pos) ** 2).sum(1) 410 | neg_distances = ((X_anc - X_neg) ** 2).sum(1) 411 | 412 | # loss 413 | loss = F.relu(self.bias + pos_distances - neg_distances).mean() 414 | 415 | return {'trip': loss} 416 | 417 | 418 | class LogisticDiscriminantLoss(nn.Module): 419 | """Pairwise distance between samples, using logistic regression 420 | https://hal.inria.fr/file/index/docid/439290/filename/GVS09.pdf 421 | space: 'sqeuclidean' or 'lorentz' 422 | init_bias: initialize bias to this value (or as set by radius) 423 | temperature: sampling temperature (decayed in main training loop) 424 | with_ball: loss being used along with ball loss? 
425 | beta: Lorentz beta for comparison in Lorentz space 426 | """ 427 | 428 | def __init__(self, device, space='sqeuclidean', 429 | init_bias=0.5, learn_bias=True, temperature=1., beta=0.01, 430 | with_ball=False): 431 | """Initialize 432 | """ 433 | super(LogisticDiscriminantLoss, self).__init__() 434 | self.device = device 435 | self.space = space 436 | self.temperature = temperature 437 | 438 | self.bias = nn.Parameter(torch.tensor(init_bias)) 439 | 440 | def forward(self, Xemb, scores, labels): 441 | """ 442 | Xemb: N x D, N features, D embedding dimension 443 | labels: ground-truth cluster indices 444 | """ 445 | 446 | N = Xemb.size(0) 447 | match = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() # a NxN {0,1} matrix 448 | 449 | ### generate positive pairs, and pull corresponding features 450 | diag_mask = 1 - torch.eye(N).to(self.device) 451 | pos_idx = (diag_mask * match).nonzero() 452 | X1_pos = Xemb.index_select(0, pos_idx[:, 0]) 453 | X2_pos = Xemb.index_select(0, pos_idx[:, 1]) 454 | 455 | ### generate random negatives 456 | neg_idx = [] 457 | while len(neg_idx) < X1_pos.size(0): # match pairs for negatives 458 | idx = torch.randint(N, (2,)).long() 459 | if match[idx[0], idx[1]] == 0: 460 | neg_idx.append(idx) 461 | neg_idx = torch.stack(neg_idx).to(self.device) 462 | 463 | X1_neg = Xemb.index_select(0, neg_idx[:, 0]) 464 | X2_neg = Xemb.index_select(0, neg_idx[:, 1]) 465 | 466 | # compute distances 467 | pos_distances = ((X1_pos - X2_pos) ** 2).sum(1) 468 | neg_distances = ((X1_neg - X2_neg) ** 2).sum(1) 469 | 470 | # Loss = -y log(p) - (1-y) log(1-p) 471 | pos_logprobs = torch.sigmoid((self.bias - pos_distances)/self.temperature) 472 | neg_logprobs = torch.sigmoid((self.bias - neg_distances)/self.temperature) 473 | pos_loss = -(pos_logprobs).log().mean() 474 | neg_loss = -(1 - neg_logprobs).log().mean() 475 | 476 | return {'ldml_pos': pos_loss, 'ldml_neg': neg_loss} 477 | 478 | 479 | -------------------------------------------------------------------------------- /GCN/train.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | from __future__ import division 3 | import torch 4 | import torch.optim as optim 5 | from adjacency import sparse_mx_to_torch_sparse_tensor 6 | from net.gat import GCN_V 7 | #from net.gcn_v import GCN_V 8 | #from net.softmaxloss import GCN_V 9 | import yaml 10 | from easydict import EasyDict 11 | from tensorboardX import SummaryWriter 12 | import numpy as np 13 | import scipy.sparse as sp 14 | import time 15 | import sys 16 | import os 17 | import apex 18 | from apex import amp 19 | import dgl 20 | import math 21 | import argparse 22 | import pprint 23 | from abc import ABC, abstractproperty, abstractmethod 24 | from collections.abc import Mapping 25 | 26 | class Collator(ABC): 27 | @abstractproperty 28 | def dataset(self): 29 | raise NotImplementedError 30 | 31 | @abstractmethod 32 | def collate(self, items): 33 | raise NotImplementedError 34 | 35 | class multigraph_NodeCollator(Collator): 36 | def __init__(self, order_graph, ngbr_graph, nids, block_sampler): 37 | self.order_graph = order_graph 38 | self.ngbr_graph = ngbr_graph 39 | self.nids = nids 40 | self.block_sampler = block_sampler 41 | self._dataset = nids 42 | 43 | @property 44 | def dataset(self): 45 | return self._dataset 46 | 47 | def collate(self, items): 48 | # use collate to fasten 49 | #blocks = self.block_sampler.sample_blocks(self.g, items) 50 | 51 | seed_node0 = items 52 | frontier0 = 
dgl.sampling.sample_neighbors(self.order_graph, seed_node0, 128, replace=False) # sample 128 from 256 53 | block0 = dgl.to_block(frontier0, seed_node0) 54 | 55 | seed_node1 = {ntype: block0.srcnodes[ntype].data[dgl.NID] for ntype in block0.srctypes} 56 | frontier1 = dgl.sampling.sample_neighbors(self.ngbr_graph, seed_node1, 80, replace=False) 57 | block1 = dgl.to_block(frontier1, seed_node1) 58 | block1.create_format_() 59 | 60 | seed_node2 = {ntype: block1.srcnodes[ntype].data[dgl.NID] for ntype in block1.srctypes} 61 | frontier2 = dgl.sampling.sample_neighbors(self.ngbr_graph, seed_node2, 80, replace=False) 62 | block2 = dgl.to_block(frontier2, seed_node2) 63 | block2.create_format_() 64 | 65 | blocks = [block2, block1] 66 | input_nodes = blocks[0].srcdata[dgl.NID] 67 | output_nodes = blocks[-1].dstdata[dgl.NID] 68 | return input_nodes, output_nodes, blocks 69 | 70 | class AverageMeter(object): 71 | def __init__(self): 72 | self.val = 0 73 | self.avg = 0 74 | self.sum = 0 75 | self.count = 0 76 | def reset(self): 77 | self.val = 0 78 | self.avg = 0 79 | self.sum = 0 80 | self.count = 0 81 | def update(self, val, n=1): 82 | self.val = val 83 | self.sum += val * n 84 | self.count += n 85 | self.avg = float(self.sum) / (self.count + 1e-10) 86 | 87 | def row_normalize(mx): 88 | """Row-normalize sparse matrix""" 89 | rowsum = np.array(mx.sum(1)) 90 | # if rowsum <= 0, keep its previous value 91 | rowsum[rowsum <= 0] = 1 92 | r_inv = np.power(rowsum, -1).flatten() 93 | r_inv[np.isinf(r_inv)] = 0. 94 | r_mat_inv = sp.diags(r_inv) 95 | mx = r_mat_inv.dot(mx) 96 | return mx 97 | 98 | class Timer(): 99 | def __init__(self, name='task', verbose=True): 100 | self.name = name 101 | self.verbose = verbose 102 | 103 | def __enter__(self): 104 | print('[begin {}]'.format(self.name)) 105 | self.start = time.time() 106 | return self 107 | 108 | def __exit__(self, exc_type, exc_val, exc_tb): 109 | if self.verbose: 110 | print('[done {}] use {:.3f} s'.format(self.name, time.time() - self.start)) 111 | return exc_type is None 112 | 113 | def adjust_lr(cur_epoch, optimizer, cfg): 114 | if cur_epoch not in cfg.step_number: 115 | return 116 | ind = cfg.step_number.index(cur_epoch) 117 | for each in optimizer.param_groups: 118 | each['lr'] = cfg.lr * cfg.factor ** (ind+1) 119 | 120 | def cos_lr(current_step, optimizer, cfg): 121 | if current_step < cfg.warmup_step: 122 | rate = 1.0 * current_step / cfg.warmup_step 123 | lr = cfg.lr * rate 124 | else: 125 | n1 = cfg.total_step - cfg.warmup_step 126 | n2 = current_step - cfg.warmup_step 127 | rate = (1 + math.cos(math.pi * n2 / n1)) / 2 128 | lr = cfg.lr * rate 129 | for each in optimizer.param_groups: 130 | each['lr'] = lr 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument('--config_file', type=str) 135 | parser.add_argument('--outpath', type=str) 136 | parser.add_argument('--phase', type=str) 137 | parser.add_argument('--train_featfile', type=str) 138 | parser.add_argument('--train_adjfile', type=str) 139 | parser.add_argument('--train_orderadjfile', type=str) 140 | parser.add_argument('--train_labelfile', type=str) 141 | parser.add_argument('--test_featfile', type=str) 142 | parser.add_argument('--test_adjfile', type=str) 143 | parser.add_argument('--test_labelfile', type=str) 144 | parser.add_argument('--resume_path', type=str) 145 | parser.add_argument('--losstype', type=str) 146 | parser.add_argument('--margin', type=float) 147 | parser.add_argument('--pweight', type=float) 148 | 
parser.add_argument('--pmargin', type=float) 149 | parser.add_argument('--topk', type=int) 150 | args = parser.parse_args() 151 | 152 | beg_time = time.time() 153 | config = yaml.load(open(args.config_file, "r"), Loader=yaml.FullLoader) 154 | cfg = EasyDict(config) 155 | cfg.step_number = [int(r * cfg.total_step) for r in cfg.lr_step] 156 | 157 | # force assignment 158 | for key, value in args._get_kwargs(): 159 | cfg[key] = value 160 | #cfg[list(dict(train_adjfile=train_adjfile).keys())[0]] = train_adjfile 161 | #cfg[list(dict(train_labelfile=train_labelfile).keys())[0]] = train_labelfile 162 | #cfg[list(dict(test_adjfile=test_adjfile).keys())[0]] = test_adjfile 163 | #cfg[list(dict(test_labelfile=test_labelfile).keys())[0]] = test_labelfile 164 | cfg.var = EasyDict() 165 | print("train hyper parameter list") 166 | pprint.pprint(cfg) 167 | 168 | 169 | # get model 170 | model = GCN_V(feature_dim=cfg.feat_dim, nhid=cfg.nhid, nclass=cfg.nclass, dropout=0., losstype=cfg.losstype, margin=cfg.margin, 171 | pweight=cfg.pweight, pmargin=cfg.pmargin) 172 | 173 | # get dataset 174 | with Timer('load data'): 175 | if cfg.phase == 'train': 176 | featfile, adjfile, labelfile = cfg.train_featfile, cfg.train_adjfile, cfg.train_labelfile 177 | order_adj = sp.load_npz(cfg.train_orderadjfile).astype(np.float32) 178 | order_graph = dgl.from_scipy(order_adj) 179 | else: 180 | featfile, adjfile, labelfile = cfg.test_featfile, cfg.test_adjfile, cfg.test_labelfile 181 | features = np.load(featfile) 182 | features = features / np.linalg.norm(features, axis=1, keepdims=True) 183 | adj = sp.load_npz(adjfile).astype(np.float32) 184 | graph = dgl.from_scipy(adj) 185 | label_arr = np.load(labelfile) 186 | features = torch.FloatTensor(features) 187 | #adj = sparse_mx_to_torch_sparse_tensor(adj) 188 | label_cpu = torch.LongTensor(label_arr) 189 | if cfg.cuda: 190 | model.cuda() 191 | features = features.cuda() 192 | #adj = adj.cuda() 193 | labels = label_cpu.cuda() 194 | #data = (features, adj, labels) 195 | 196 | # get train 197 | if cfg.phase == 'train': 198 | # get optimizer 199 | pretrain_pool = True 200 | pretrain_pool = False 201 | if pretrain_pool: 202 | pool_weight, net_weight = [], [] 203 | for k, v in model.named_parameters(): 204 | if 'pool.' 
in k: 205 | pool_weight += [v] 206 | else: 207 | net_weight += [v] 208 | param_list = [{'params': pool_weight}, {'params': net_weight, 'lr': 0.}] 209 | optimizer = optim.SGD(param_list, cfg.lr, momentum=cfg.sgd_momentum, weight_decay=cfg.sgd_weight_decay) 210 | else: 211 | optimizer = optim.SGD(model.parameters(), cfg.lr, momentum=cfg.sgd_momentum, weight_decay=cfg.sgd_weight_decay) 212 | 213 | if cfg.fp16: 214 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1", keep_batchnorm_fp32=None, loss_scale='dynamic') 215 | 216 | beg_step = 0 217 | if cfg.resume_path != None: 218 | beg_step = int(os.path.splitext(os.path.basename(cfg.resume_path))[0].split('_')[1]) 219 | with Timer('resume model from %s'%cfg.resume_path): 220 | ckpt = torch.load(cfg.resume_path, map_location='cpu') 221 | model.load_state_dict(ckpt['state_dict']) 222 | 223 | totalloss_meter = AverageMeter() 224 | bclloss_pos_meter = AverageMeter() 225 | bclloss_neg_meter = AverageMeter() 226 | keeploss_meter = AverageMeter() 227 | before_edge_num_meter = AverageMeter() 228 | after_edge_num_meter = AverageMeter() 229 | acc_meter = AverageMeter() 230 | prec_meter = AverageMeter() 231 | recall_meter = AverageMeter() 232 | leftprec_meter = AverageMeter() 233 | writer = SummaryWriter(os.path.join(cfg.outpath), filename_suffix='') 234 | #sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) 235 | #sampler = dgl.dataloading.MultiLayerNeighborSampler([cfg.topk, cfg.topk, 128]) 236 | #sampler = dgl.dataloading.MultiLayerNeighborSampler([None, None, 128]) 237 | #dataloader = dgl.dataloading.NodeDataLoader( 238 | # order_graph, 239 | # np.arange(order_graph.number_of_nodes()), 240 | # sampler, 241 | # batch_size=cfg.batchsize, 242 | # shuffle=True, 243 | # drop_last=False, 244 | # num_workers=4) 245 | #sampler = dgl.dataloading.MultiLayerNeighborSampler([128]) 246 | sampler = None 247 | collator = multigraph_NodeCollator(order_graph, graph, np.arange(len(features)), sampler) # v4 248 | dataloader = torch.utils.data.DataLoader( 249 | dataset=collator.dataset, 250 | batch_size=cfg.batchsize, 251 | shuffle=True, 252 | num_workers=4, 253 | pin_memory=True, 254 | drop_last=False, 255 | collate_fn=collator.collate, 256 | ) 257 | 258 | current_step = beg_step 259 | break_flag = False 260 | while 1: 261 | #adjust_lr(current_step, optimizer.param_groups, cfg) 262 | if break_flag: 263 | break 264 | for _, (src_idx, dst_idx, blocks) in enumerate(dataloader): 265 | if current_step > cfg.total_step: 266 | break_flag = True 267 | break 268 | iter_begtime = time.time() 269 | current_step += 1 270 | cos_lr(current_step, optimizer, cfg) 271 | #src_idx = blocks[0].srcdata[dgl.NID].numpy() 272 | #dst_idx = blocks[-1].dstdata[dgl.NID].numpy() 273 | 274 | batch_feature = features[src_idx, :] 275 | #batch_adj = sparse_mx_to_torch_sparse_tensor(adj[src_idx, :][:, src_idx]).cuda() 276 | #batch_adj = torch.from_numpy( row_normalize(adj[src_idx, :][:, src_idx]).todense() ).cuda() 277 | batch_block = [block.to(0) for block in blocks] # need not row normalize, because the attention weight edge 278 | batch_label = labels[dst_idx] 279 | batch_idlabel = labels[src_idx] 280 | batch_data = (batch_feature, batch_block, batch_label, batch_idlabel) 281 | bclloss_dict = model(batch_data, return_loss=True) 282 | loss = bclloss_dict['ctrd_pos'] + bclloss_dict['ctrd_neg'] 283 | 284 | optimizer.zero_grad() 285 | if cfg.fp16: 286 | with amp.scale_loss(loss, optimizer) as scaled_loss: 287 | scaled_loss.backward() 288 | else: 289 | loss.backward() 290 | 
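                # with apex amp (opt_level O1), scale_loss() multiplies the loss by a dynamic factor
                # before backward() so that small fp16 gradients do not underflow; the gradients are
                # unscaled again when the context exits, so the optimizer step below sees true gradients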
optimizer.step() 291 | 292 | totalloss_meter.update(loss.item()) 293 | bclloss_pos_meter.update(bclloss_dict['ctrd_pos'].item()) 294 | bclloss_neg_meter.update(bclloss_dict['ctrd_neg'].item()) 295 | 296 | writer.add_scalar('loss/total', loss.item(), global_step=current_step) 297 | writer.add_scalar('loss/bcl_pos', bclloss_dict['ctrd_pos'].item(), global_step=current_step) 298 | writer.add_scalar('loss/bcl_neg', bclloss_dict['ctrd_neg'].item(), global_step=current_step) 299 | if current_step % cfg.log_freq == 0: 300 | log = "step{}/{}, iter_time:{:.3f}, lr:{:.4f}, loss:{:.4f}({:.4f}), bclloss_pos:{:.8f}({:.8f}), bclloss_neg:{:.4f}({:.4f}) ".format(current_step, cfg.total_step, time.time()-iter_begtime, optimizer.param_groups[0]['lr'], totalloss_meter.val, totalloss_meter.avg, bclloss_pos_meter.val, bclloss_pos_meter.avg, bclloss_neg_meter.val, bclloss_neg_meter.avg) 301 | print(log) 302 | if (current_step+1) % cfg.save_freq == 0 and current_step > 0: 303 | torch.save({'state_dict' : model.state_dict(), 'step': current_step+1}, 304 | os.path.join(cfg.outpath, "ckpt_%s.pth"%(current_step+1))) 305 | writer.close() 306 | else: 307 | with Timer('resume model from %s'%cfg.resume_path): 308 | ckpt = torch.load(cfg.resume_path, map_location='cpu') 309 | model.load_state_dict(ckpt['state_dict']) 310 | model.eval() 311 | 312 | sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2) 313 | dataloader = dgl.dataloading.NodeDataLoader( 314 | graph, 315 | np.arange(graph.number_of_nodes()), 316 | sampler, 317 | batch_size=1024, 318 | shuffle=False, 319 | drop_last=False, 320 | num_workers=16) 321 | 322 | gcnfeat_list, fcfeat_list = [], [] 323 | leftprec_meter = AverageMeter() 324 | beg_time = time.time() 325 | for step, (input_nodes, output_nodes, blocks) in enumerate(dataloader): 326 | src_idx = blocks[0].srcdata[dgl.NID].numpy() 327 | dst_idx = blocks[-1].dstdata[dgl.NID].numpy() 328 | #zip(block.srcnodes(), block.srcdata[dgl.NID]) 329 | #zip(block.dstnodes(), block.dstdata[dgl.NID]) 330 | batch_feature = features[src_idx, :] 331 | #batch_adj = sparse_mx_to_torch_sparse_tensor(adj[src_idx, :][:, src_idx]).cuda() 332 | 333 | #batch_adj = torch.from_numpy(adj[src_idx, :][:, src_idx].todense()).cuda() # no sample and no row normalize again 334 | batch_block = [block.to(0) for block in blocks] 335 | batch_label = labels[dst_idx] 336 | batch_idlabel = labels[src_idx] 337 | batch_data = (batch_feature, batch_block, batch_label, batch_idlabel) 338 | #fcfeat, gcnfeat, before_edge_num, after_edge_num, acc_rate, prec, recall, left_prec = model(batch_data, output_feat=True) 339 | fcfeat, gcnfeat = model(batch_data, output_feat=True) 340 | 341 | fcfeat_list.append(fcfeat.data.cpu().numpy()) 342 | gcnfeat_list.append(gcnfeat.data.cpu().numpy()) 343 | #leftprec_meter.update(left_prec) 344 | #if step % 1 == 0: 345 | # log = "step %s/%s"%(step, len(dataloader)) 346 | # print(log) 347 | print("time use %.4f"%(time.time()-beg_time)) 348 | 349 | fcfeat_arr = np.vstack(fcfeat_list) 350 | gcnfeat_arr = np.vstack(gcnfeat_list) 351 | fcfeat_arr = fcfeat_arr / np.linalg.norm(fcfeat_arr, axis=1, keepdims=True) 352 | gcnfeat_arr = gcnfeat_arr / np.linalg.norm(gcnfeat_arr, axis=1, keepdims=True) 353 | tag = os.path.splitext(os.path.basename(cfg.resume_path))[0] 354 | np.save(os.path.join(cfg.outpath, 'fcfeat_%s'%tag), fcfeat_arr) 355 | np.save(os.path.join(cfg.outpath, 'gcnfeat_%s'%tag), gcnfeat_arr) 356 | 357 | print("time use", time.time() - beg_time) 358 | 
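The training loop above drives the optimizer with the warmup-plus-cosine schedule implemented in `cos_lr`. The standalone sketch below reproduces that schedule for illustration only; the function name `cosine_warmup_lr` and the values of `base_lr`, `warmup_step` and `total_step` are placeholders, not settings taken from the GCN config:

import math

def cosine_warmup_lr(step, base_lr=0.01, warmup_step=128, total_step=40000):
    # linear warmup from 0 to base_lr, then cosine decay back towards 0
    if step < warmup_step:
        return base_lr * step / warmup_step
    n1 = total_step - warmup_step
    n2 = step - warmup_step
    return base_lr * (1 + math.cos(math.pi * n2 / n1)) / 2

for step in (0, 64, 128, 20000, 40000):
    print(step, round(cosine_warmup_lr(step), 6))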
-------------------------------------------------------------------------------- /GCN/util/confidence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | from itertools import groupby 7 | 8 | __all__ = ['density', 'confidence', 'confidence_to_peaks'] 9 | 10 | 11 | def density(dists, radius=0.3, use_weight=True): 12 | row, col = (dists < radius).nonzero() 13 | 14 | num, _ = dists.shape 15 | if use_weight: 16 | density = np.zeros((num, ), dtype=np.float32) 17 | for r, c in zip(row, col): 18 | density[r] += 1 - dists[r, c] 19 | else: 20 | density = np.zeros((num, ), dtype=np.int32) 21 | for k, g in groupby(row): 22 | density[k] = len(list(g)) 23 | return density 24 | 25 | 26 | def s_nbr(dists, nbrs, idx2lb, **kwargs): 27 | ''' use supervised confidence defined on neigborhood 28 | ''' 29 | num, _ = dists.shape 30 | conf = np.zeros((num, ), dtype=np.float32) 31 | contain_neg = 0 32 | for i, (nbr, dist) in enumerate(zip(nbrs, dists)): 33 | lb = idx2lb[i] 34 | pos, neg = 0, 0 35 | for j, n in enumerate(nbr): 36 | if idx2lb[n] == lb: 37 | pos += 1 - dist[j] 38 | else: 39 | neg += 1 - dist[j] 40 | conf[i] = pos - neg 41 | if neg > 0: 42 | contain_neg += 1 43 | print('#contain_neg:', contain_neg) 44 | conf /= np.abs(conf).max() 45 | return conf 46 | 47 | 48 | def s_nbr_size_norm(dists, nbrs, idx2lb, **kwargs): 49 | ''' use supervised confidence defined on neigborhood (norm by size) 50 | ''' 51 | num, _ = dists.shape 52 | conf = np.zeros((num, ), dtype=np.float32) 53 | contain_neg = 0 54 | max_size = 0 55 | for i, (nbr, dist) in enumerate(zip(nbrs, dists)): 56 | size = 0 57 | pos, neg = 0, 0 58 | lb = idx2lb[i] 59 | for j, n in enumerate(nbr): 60 | if idx2lb[n] == lb: 61 | pos += 1 - dist[j] 62 | else: 63 | neg += 1 - dist[j] 64 | size += 1 65 | conf[i] = pos - neg 66 | max_size = max(max_size, size) 67 | if neg > 0: 68 | contain_neg += 1 69 | print('#contain_neg:', contain_neg) 70 | print('max_size: {}'.format(max_size)) 71 | conf /= max_size 72 | return conf 73 | 74 | 75 | def s_avg(feats, idx2lb, lb2idxs, **kwargs): 76 | ''' use average similarity of intra-nodes 77 | ''' 78 | num = len(idx2lb) 79 | conf = np.zeros((num, ), dtype=np.float32) 80 | for i in range(num): 81 | lb = idx2lb[i] 82 | idxs = lb2idxs[lb] 83 | idxs.remove(i) 84 | if len(idxs) == 0: 85 | continue 86 | feat = feats[i, :] 87 | conf[i] = feat.dot(feats[idxs, :].T).mean() 88 | eps = 1e-6 89 | assert -1 - eps <= conf.min() <= conf.max( 90 | ) <= 1 + eps, "min: {}, max: {}".format(conf.min(), conf.max()) 91 | return conf 92 | 93 | 94 | def s_center(feats, idx2lb, lb2idxs, **kwargs): 95 | ''' use average similarity of intra-nodes 96 | ''' 97 | num = len(idx2lb) 98 | conf = np.zeros((num, ), dtype=np.float32) 99 | for i in range(num): 100 | lb = idx2lb[i] 101 | idxs = lb2idxs[lb] 102 | if len(idxs) == 0: 103 | continue 104 | feat = feats[i, :] 105 | feat_center = feats[idxs, :].mean(axis=0) 106 | conf[i] = feat.dot(feat_center.T) 107 | eps = 1e-6 108 | assert -1 - eps <= conf.min() <= conf.max( 109 | ) <= 1 + eps, "min: {}, max: {}".format(conf.min(), conf.max()) 110 | return conf 111 | 112 | 113 | def confidence(metric='s_nbr', **kwargs): 114 | metric2func = { 115 | 's_nbr': s_nbr, 116 | 's_nbr_size_norm': s_nbr_size_norm, 117 | 's_avg': s_avg, 118 | 's_center': s_center, 119 | } 120 | if metric in metric2func: 121 | func = metric2func[metric] 122 | else: 123 | raise KeyError('Only support 
confidence metircs: {}'.format( 124 | metric2func.keys())) 125 | 126 | conf = func(**kwargs) 127 | return conf 128 | 129 | 130 | def confidence_to_peaks(dists, nbrs, confidence, max_conn=1): 131 | # Note that dists has been sorted in ascending order 132 | assert dists.shape[0] == confidence.shape[0] 133 | assert dists.shape == nbrs.shape 134 | 135 | num, _ = dists.shape 136 | dist2peak = {i: [] for i in range(num)} 137 | peaks = {i: [] for i in range(num)} 138 | 139 | for i, nbr in tqdm(enumerate(nbrs)): 140 | nbr_conf = confidence[nbr] 141 | for j, c in enumerate(nbr_conf): 142 | nbr_idx = nbr[j] 143 | if i == nbr_idx or c <= confidence[i]: 144 | continue 145 | dist2peak[i].append(dists[i, j]) 146 | peaks[i].append(nbr_idx) 147 | if len(dist2peak[i]) >= max_conn: 148 | break 149 | return dist2peak, peaks 150 | -------------------------------------------------------------------------------- /GCN/util/deduce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | __all__ = ['peaks_to_labels'] 4 | 5 | 6 | def _find_parent(parent, u): 7 | idx = [] 8 | # parent is a fixed point 9 | while (u != parent[u]): 10 | idx.append(u) 11 | u = parent[u] 12 | for i in idx: 13 | parent[i] = u 14 | return u 15 | 16 | 17 | def edge_to_connected_graph(edges, num): 18 | parent = list(range(num)) 19 | for u, v in edges: 20 | p_u = _find_parent(parent, u) 21 | p_v = _find_parent(parent, v) 22 | parent[p_u] = p_v 23 | 24 | for i in range(num): 25 | parent[i] = _find_parent(parent, i) 26 | remap = {} 27 | uf = np.unique(np.array(parent)) 28 | for i, f in enumerate(uf): 29 | remap[f] = i 30 | cluster_id = np.array([remap[f] for f in parent]) 31 | return cluster_id 32 | 33 | 34 | def peaks_to_edges(peaks, dist2peak, tau): 35 | edges = [] 36 | for src in peaks: 37 | dsts = peaks[src] 38 | dists = dist2peak[src] 39 | for dst, dist in zip(dsts, dists): 40 | if src == dst or dist >= 1 - tau: 41 | continue 42 | edges.append([src, dst]) 43 | return edges 44 | 45 | 46 | def peaks_to_labels(peaks, dist2peak, tau, inst_num): 47 | edges = peaks_to_edges(peaks, dist2peak, tau) 48 | pred_labels = edge_to_connected_graph(edges, inst_num) 49 | return pred_labels 50 | -------------------------------------------------------------------------------- /GCN/util/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import inspect 5 | import argparse 6 | import numpy as np 7 | import util.metrics as metrics 8 | import time 9 | 10 | class TextColors: 11 | #HEADER = '\033[35m' 12 | #OKBLUE = '\033[34m' 13 | #OKGREEN = '\033[32m' 14 | #WARNING = '\033[33m' 15 | #FATAL = '\033[31m' 16 | #ENDC = '\033[0m' 17 | #BOLD = '\033[1m' 18 | #UNDERLINE = '\033[4m' 19 | HEADER = '' 20 | OKBLUE = '' 21 | OKGREEN = '' 22 | WARNING = '' 23 | FATAL = '' 24 | ENDC = '' 25 | BOLD = '' 26 | UNDERLINE = '' 27 | 28 | class Timer(): 29 | def __init__(self, name='task', verbose=True): 30 | self.name = name 31 | self.verbose = verbose 32 | 33 | def __enter__(self): 34 | self.start = time.time() 35 | return self 36 | 37 | def __exit__(self, exc_type, exc_val, exc_tb): 38 | if self.verbose: 39 | print('[Time] {} consumes {:.4f} s'.format( 40 | self.name, 41 | time.time() - self.start)) 42 | return exc_type is None 43 | 44 | 45 | def _read_meta(fn): 46 | labels = list() 47 | lb_set = set() 48 | with open(fn) as f: 49 | for lb in f.readlines(): 50 | lb = int(lb.strip()) 51 | labels.append(lb) 52 | 
lb_set.add(lb) 53 | return np.array(labels), lb_set 54 | 55 | 56 | def evaluate(gt_labels, pred_labels, metric='pairwise'): 57 | if isinstance(gt_labels, str) and isinstance(pred_labels, str): 58 | print('[gt_labels] {}'.format(gt_labels)) 59 | print('[pred_labels] {}'.format(pred_labels)) 60 | gt_labels, gt_lb_set = _read_meta(gt_labels) 61 | pred_labels, pred_lb_set = _read_meta(pred_labels) 62 | 63 | print('#inst: gt({}) vs pred({})'.format(len(gt_labels), 64 | len(pred_labels))) 65 | print('#cls: gt({}) vs pred({})'.format(len(gt_lb_set), 66 | len(pred_lb_set))) 67 | 68 | metric_func = metrics.__dict__[metric] 69 | 70 | with Timer('evaluate with {}{}{}'.format(TextColors.FATAL, metric, 71 | TextColors.ENDC), verbose=False): 72 | result = metric_func(gt_labels, pred_labels) 73 | if isinstance(result, np.float): 74 | #print('{}{}: {:.4f}{}'.format(TextColors.OKGREEN, metric, result, TextColors.ENDC)) 75 | res_str = '{}{}: {:.4f}{}'.format(TextColors.OKGREEN, metric, result, TextColors.ENDC) 76 | else: 77 | from collections import Counter 78 | singleton_num = len( list( filter(lambda x: x==1, Counter(pred_labels).values()) ) ) 79 | ave_pre, ave_rec, fscore = result 80 | #print('{}ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}, cluster_num: {}, singleton_num: {}'.format( 81 | # TextColors.OKGREEN, ave_pre, ave_rec, fscore, TextColors.ENDC, len(np.unique(pred_labels)), singleton_num)) 82 | res_str = '{}{}: ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}, cluster_num: {}, singleton_num: {}'.format( 83 | TextColors.OKGREEN, metric, ave_pre, ave_rec, fscore, TextColors.ENDC, len(np.unique(pred_labels)), singleton_num) 84 | #return ave_pre, ave_rec, fscore 85 | return res_str 86 | 87 | 88 | if __name__ == '__main__': 89 | metric_funcs = inspect.getmembers(metrics, inspect.isfunction) 90 | metric_names = [n for n, _ in metric_funcs] 91 | 92 | parser = argparse.ArgumentParser(description='Evaluate Cluster') 93 | parser.add_argument('--gt_labels', type=str, required=True) 94 | parser.add_argument('--pred_labels', type=str, required=True) 95 | parser.add_argument('--metric', default='pairwise', choices=metric_names) 96 | args = parser.parse_args() 97 | 98 | evaluate(args.gt_labels, args.pred_labels, args.metric) 99 | -------------------------------------------------------------------------------- /GCN/util/graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import time 7 | 8 | class Data(object): 9 | def __init__(self, name): 10 | self.__name = name 11 | self.__links = set() 12 | 13 | @property 14 | def name(self): 15 | return self.__name 16 | 17 | @property 18 | def links(self): 19 | return set(self.__links) 20 | 21 | def add_link(self, other, score): 22 | self.__links.add(other) 23 | other.__links.add(self) 24 | 25 | def connected_components(nodes, score_dict, th): 26 | ''' 27 | conventional connected components searching 28 | ''' 29 | result = [] 30 | nodes = set(nodes) 31 | while nodes: 32 | n = nodes.pop() 33 | group = {n} 34 | queue = [n] 35 | while queue: 36 | n = queue.pop(0) 37 | if th is not None: 38 | neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th} 39 | else: 40 | neighbors = n.links 41 | neighbors.difference_update(group) 42 | nodes.difference_update(neighbors) 43 | group.update(neighbors) 44 | queue.extend(neighbors) 45 | result.append(group) 46 | return 
result 47 | 48 | def connected_components_constraint(nodes, max_sz, score_dict=None, th=None): 49 | ''' 50 | only use edges whose scores are above `th` 51 | if a component is larger than `max_sz`, all the nodes in this component are added into `remain` and returned for next iteration. 52 | ''' 53 | result = [] 54 | remain = set() 55 | nodes = set(nodes) 56 | while nodes: 57 | n = nodes.pop() 58 | group = {n} 59 | queue = [n] 60 | valid = True 61 | while queue: 62 | n = queue.pop(0) 63 | if th is not None: 64 | neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th} 65 | else: 66 | neighbors = n.links 67 | neighbors.difference_update(group) 68 | nodes.difference_update(neighbors) 69 | group.update(neighbors) 70 | queue.extend(neighbors) 71 | if len(group) > max_sz or len(remain.intersection(neighbors)) > 0: 72 | # if this group is larger than `max_sz`, add the nodes into `remain` 73 | valid = False 74 | remain.update(group) 75 | break 76 | if valid: # if this group is smaller than or equal to `max_sz`, finalize it. 77 | result.append(group) 78 | return result, remain 79 | 80 | 81 | def graph_propagation_naive(edges, score, th): 82 | 83 | edges = np.sort(edges, axis=1) 84 | 85 | # construct graph 86 | score_dict = {} # score lookup table 87 | for i,e in enumerate(edges): 88 | score_dict[e[0], e[1]] = score[i] 89 | 90 | nodes = np.sort(np.unique(edges.flatten())) 91 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int) 92 | mapping[nodes] = np.arange(nodes.shape[0]) 93 | link_idx = mapping[edges] 94 | vertex = [Data(n) for n in nodes] 95 | for l, s in zip(link_idx, score): 96 | vertex[l[0]].add_link(vertex[l[1]], s) 97 | 98 | # first iteration 99 | comps = connected_components(vertex, score_dict,th) 100 | 101 | return comps 102 | 103 | def graph_propagation(edges, score, max_sz, step=0.1, beg_th=0.9, pool=None): 104 | 105 | edges = np.sort(edges, axis=1) 106 | th = score.min() 107 | #th = beg_th 108 | # construct graph 109 | score_dict = {} # score lookup table 110 | if pool is None: 111 | for i,e in enumerate(edges): 112 | score_dict[e[0], e[1]] = score[i] 113 | elif pool == 'avg': 114 | for i,e in enumerate(edges): 115 | #if score_dict.has_key((e[0],e[1])): 116 | if (e[0],e[1]) in score_dict: 117 | score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i]) 118 | else: 119 | score_dict[e[0], e[1]] = score[i] 120 | 121 | elif pool == 'max': 122 | for i,e in enumerate(edges): 123 | #if score_dict.has_key((e[0],e[1])): 124 | if (e[0],e[1]) in score_dict: 125 | score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i]) 126 | else: 127 | score_dict[e[0], e[1]] = score[i] 128 | else: 129 | raise ValueError('Pooling operation not supported') 130 | 131 | nodes = np.sort(np.unique(edges.flatten())) 132 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int) 133 | mapping[nodes] = np.arange(nodes.shape[0]) 134 | link_idx = mapping[edges] 135 | vertex = [Data(n) for n in nodes] 136 | for l, s in zip(link_idx, score): 137 | vertex[l[0]].add_link(vertex[l[1]], s) 138 | 139 | # first iteration 140 | comps, remain = connected_components_constraint(vertex, max_sz) 141 | 142 | # iteration 143 | components = comps[:] 144 | while remain: 145 | th = th + (1 - th) * step 146 | comps, remain = connected_components_constraint(remain, max_sz, score_dict, th) 147 | components.extend(comps) 148 | return components 149 | 150 | def graph_propagation_begin(edges, score, max_sz, step=0.1, beg_th=0.9, pool=None): 151 | 152 | th = beg_th 153 | # construct graph 154 | score_dict = 
{} # score lookup table 155 | if pool is None: 156 | for i,e in enumerate(edges): 157 | score_dict[e[0], e[1]] = score[i] 158 | elif pool == 'avg': 159 | for i,e in enumerate(edges): 160 | #if score_dict.has_key((e[0],e[1])): 161 | if (e[0],e[1]) in score_dict: 162 | score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i]) 163 | else: 164 | score_dict[e[0], e[1]] = score[i] 165 | 166 | elif pool == 'max': 167 | for i,e in enumerate(edges): 168 | #if score_dict.has_key((e[0],e[1])): 169 | if (e[0],e[1]) in score_dict: 170 | score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i]) 171 | else: 172 | score_dict[e[0], e[1]] = score[i] 173 | else: 174 | raise ValueError('Pooling operation not supported') 175 | 176 | nodes = np.sort(np.unique(edges.flatten())) 177 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int) 178 | mapping[nodes] = np.arange(nodes.shape[0]) 179 | link_idx = mapping[edges] 180 | vertex = [Data(n) for n in nodes] 181 | for l, s in zip(link_idx, score): 182 | vertex[l[0]].add_link(vertex[l[1]], s) 183 | 184 | # first iteration 185 | comps, remain = connected_components_constraint(vertex, max_sz) 186 | 187 | # iteration 188 | components = comps[:] 189 | while remain: 190 | th = th + (1 - th) * step 191 | comps, remain = connected_components_constraint(remain, max_sz, score_dict, th) 192 | components.extend(comps) 193 | return components 194 | 195 | def graph_propagation_onecut(edges, score, max_sz, th=0.4, pool=None): 196 | edges = np.sort(edges, axis=1) 197 | 198 | # construct graph 199 | score_dict = {} # score lookup table 200 | if pool is None: 201 | for i,e in enumerate(edges): 202 | score_dict[e[0], e[1]] = score[i] 203 | elif pool == 'avg': 204 | for i,e in enumerate(edges): 205 | #if score_dict.has_key((e[0],e[1])): 206 | if (e[0],e[1]) in score_dict: 207 | score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i]) 208 | else: 209 | score_dict[e[0], e[1]] = score[i] 210 | elif pool == 'max': 211 | for i,e in enumerate(edges): 212 | #if score_dict.has_key((e[0],e[1])): 213 | if (e[0],e[1]) in score_dict: 214 | score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i]) 215 | else: 216 | score_dict[e[0], e[1]] = score[i] 217 | else: 218 | raise ValueError('Pooling operation not supported') 219 | 220 | nodes = np.sort(np.unique(edges.flatten())) 221 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int) 222 | mapping[nodes] = np.arange(nodes.shape[0]) 223 | link_idx = mapping[edges] 224 | vertex = [Data(n) for n in nodes] 225 | for l, s in zip(link_idx, score): 226 | vertex[l[0]].add_link(vertex[l[1]], s) 227 | 228 | comps, remain = connected_components_constraint(vertex, max_sz, score_dict, th) 229 | assert len(remain) == 0 230 | return comps 231 | 232 | def graph_propagation_soft(edges, score, max_sz, step=0.1, **kwargs): 233 | 234 | edges = np.sort(edges, axis=1) 235 | th = score.min() 236 | 237 | # construct graph 238 | score_dict = {} # score lookup table 239 | for i,e in enumerate(edges): 240 | score_dict[e[0], e[1]] = score[i] 241 | 242 | nodes = np.sort(np.unique(edges.flatten())) 243 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int) 244 | mapping[nodes] = np.arange(nodes.shape[0]) 245 | link_idx = mapping[edges] 246 | vertex = [Data(n) for n in nodes] 247 | for l, s in zip(link_idx, score): 248 | vertex[l[0]].add_link(vertex[l[1]], s) 249 | 250 | # first iteration 251 | comps, remain = connected_components_constraint(vertex, max_sz) 252 | first_vertex_idx = np.array([mapping[n.name] for c in comps for n in c]) 253 | 
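    # first_vertex_idx covers nodes already fixed inside valid (small enough) components;
    # the complement computed next (fusion_vertex_idx) are the unresolved nodes that will
    # receive soft labels from diffusion() further down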
fusion_vertex_idx = np.setdiff1d(np.arange(nodes.shape[0]), first_vertex_idx, assume_unique=True) 254 | # iteration 255 | components = comps[:] 256 | while remain: 257 | th = th + (1 - th) * step 258 | comps, remain = connected_components_constraint(remain, max_sz, score_dict, th) 259 | components.extend(comps) 260 | label_dict = {} 261 | for i,c in enumerate(components): 262 | for n in c: 263 | label_dict[n.name] = i 264 | print('Propagation ...') 265 | prop_vertex = [vertex[idx] for idx in fusion_vertex_idx] 266 | label, label_fusion = diffusion(prop_vertex, label_dict, score_dict, **kwargs) 267 | return label, label_fusion 268 | 269 | def diffusion(vertex, label, score_dict, max_depth=5, weight_decay=0.6, normalize=True): 270 | class BFSNode(): 271 | def __init__(self, node, depth, value): 272 | self.node = node 273 | self.depth = depth 274 | self.value = value 275 | 276 | label_fusion = {} 277 | for name in label.keys(): 278 | label_fusion[name] = {label[name]: 1.0} 279 | prog = 0 280 | prog_step = len(vertex) // 20 281 | start = time.time() 282 | for root in vertex: 283 | if prog % prog_step == 0: 284 | print("progress: {} / {}, elapsed time: {}".format(prog, len(vertex), time.time() - start)) 285 | prog += 1 286 | #queue = {[root, 0, 1.0]} 287 | queue = {BFSNode(root, 0, 1.0)} 288 | visited = [root.name] 289 | root_label = label[root.name] 290 | while queue: 291 | curr = queue.pop() 292 | if curr.depth >= max_depth: # pruning 293 | continue 294 | neighbors = curr.node.links 295 | tmp_value = [] 296 | tmp_neighbor = [] 297 | for n in neighbors: 298 | if n.name not in visited: 299 | sub_value = score_dict[tuple(sorted([curr.node.name, n.name]))] * weight_decay * curr.value 300 | tmp_value.append(sub_value) 301 | tmp_neighbor.append(n) 302 | if root_label not in label_fusion[n.name].keys(): 303 | label_fusion[n.name][root_label] = sub_value 304 | else: 305 | label_fusion[n.name][root_label] += sub_value 306 | visited.append(n.name) 307 | #queue.add([n, curr.depth+1, sub_value]) 308 | sortidx = np.argsort(tmp_value)[::-1] 309 | for si in sortidx: 310 | queue.add(BFSNode(tmp_neighbor[si], curr.depth+1, tmp_value[si])) 311 | if normalize: 312 | for name in label_fusion.keys(): 313 | summ = sum(label_fusion[name].values()) 314 | for k in label_fusion[name].keys(): 315 | label_fusion[name][k] /= summ 316 | return label, label_fusion 317 | -------------------------------------------------------------------------------- /GCN/util/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import division 5 | 6 | import numpy as np 7 | from sklearn.metrics.cluster import (contingency_matrix, 8 | normalized_mutual_info_score) 9 | from sklearn.metrics import (precision_score, recall_score) 10 | 11 | __all__ = ['pairwise', 'bcubed', 'nmi', 'precision', 'recall', 'accuracy'] 12 | 13 | 14 | def _check(gt_labels, pred_labels): 15 | if gt_labels.ndim != 1: 16 | raise ValueError("gt_labels must be 1D: shape is %r" % 17 | (gt_labels.shape, )) 18 | if pred_labels.ndim != 1: 19 | raise ValueError("pred_labels must be 1D: shape is %r" % 20 | (pred_labels.shape, )) 21 | if gt_labels.shape != pred_labels.shape: 22 | raise ValueError( 23 | "gt_labels and pred_labels must have same size, got %d and %d" % 24 | (gt_labels.shape[0], pred_labels.shape[0])) 25 | return gt_labels, pred_labels 26 | 27 | 28 | def _get_lb2idxs(labels): 29 | lb2idxs = {} 30 | for idx, lb in enumerate(labels): 31 | if lb not in 
lb2idxs: 32 | lb2idxs[lb] = [] 33 | lb2idxs[lb].append(idx) 34 | return lb2idxs 35 | 36 | 37 | def _compute_fscore(pre, rec): 38 | return 2. * pre * rec / (pre + rec) 39 | 40 | 41 | def fowlkes_mallows_score(gt_labels, pred_labels, sparse=True): 42 | ''' The original function is from `sklearn.metrics.fowlkes_mallows_score`. 43 | We output the pairwise precision, pairwise recall and F-measure, 44 | instead of calculating the geometry mean of precision and recall. 45 | ''' 46 | n_samples, = gt_labels.shape 47 | 48 | c = contingency_matrix(gt_labels, pred_labels, sparse=sparse) 49 | tk = np.dot(c.data, c.data) - n_samples 50 | pk = np.sum(np.asarray(c.sum(axis=0)).ravel()**2) - n_samples 51 | qk = np.sum(np.asarray(c.sum(axis=1)).ravel()**2) - n_samples 52 | 53 | avg_pre = tk / pk 54 | avg_rec = tk / qk 55 | fscore = _compute_fscore(avg_pre, avg_rec) 56 | 57 | return avg_pre, avg_rec, fscore 58 | 59 | 60 | def pairwise(gt_labels, pred_labels, sparse=True): 61 | _check(gt_labels, pred_labels) 62 | return fowlkes_mallows_score(gt_labels, pred_labels, sparse) 63 | 64 | 65 | def bcubed(gt_labels, pred_labels): 66 | _check(gt_labels, pred_labels) 67 | 68 | gt_lb2idxs = _get_lb2idxs(gt_labels) 69 | pred_lb2idxs = _get_lb2idxs(pred_labels) 70 | 71 | num_lbs = len(gt_lb2idxs) 72 | pre = np.zeros(num_lbs) 73 | rec = np.zeros(num_lbs) 74 | gt_num = np.zeros(num_lbs) 75 | 76 | for i, gt_idxs in enumerate(gt_lb2idxs.values()): 77 | all_pred_lbs = np.unique(pred_labels[gt_idxs]) 78 | gt_num[i] = len(gt_idxs) 79 | for pred_lb in all_pred_lbs: 80 | pred_idxs = pred_lb2idxs[pred_lb] 81 | n = 1. * np.intersect1d(gt_idxs, pred_idxs).size 82 | pre[i] += n**2 / len(pred_idxs) 83 | rec[i] += n**2 / gt_num[i] 84 | 85 | gt_num = gt_num.sum() 86 | avg_pre = pre.sum() / gt_num 87 | avg_rec = rec.sum() / gt_num 88 | fscore = _compute_fscore(avg_pre, avg_rec) 89 | 90 | return avg_pre, avg_rec, fscore 91 | 92 | 93 | def nmi(gt_labels, pred_labels): 94 | return normalized_mutual_info_score(pred_labels, gt_labels) 95 | 96 | 97 | def precision(gt_labels, pred_labels): 98 | return precision_score(gt_labels, pred_labels) 99 | 100 | 101 | def recall(gt_labels, pred_labels): 102 | return recall_score(gt_labels, pred_labels) 103 | 104 | 105 | def accuracy(gt_labels, pred_labels): 106 | return np.mean(gt_labels == pred_labels) 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Thomas-wyh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ada-NETS 2 | 3 | This is an official implementation for "Ada-NETS: Face Clustering via Adaptive Neighbour Discovery in the Structure Space" accepted at ICLR 2022. 4 | 5 | ## News 6 | - 🔥 An improved method on face clustering ([**B-Attenion**](https://github.com/Thomas-wyh/B-Attention/)) is accepted by NeurIPS 2022! 7 | - 🔥 Ada-NETS is accepted by ICLR 2022! 8 | 9 | 10 | ## Introduction 11 | 12 | This paper presents a novel Ada-NETS algorithm to deal with the noise edges problem when building the graph in GCN-based face clustering. In Ada-NETS, the features are first transformed to the structure space to enhance the accuracy of the similarity metrics. Then an adaptive neighbour discovery method is used to find neighbours for all samples adaptively with the guidance of a heuristic quality criterion. Based on the discovered neighbour relations, a graph with clean and rich edges is built as the input of GCNs to obtain state-of-the-art on the face, clothes, and person clustering tasks. 13 | 14 | 15 | 16 | 17 | 18 | ## Main Results 19 | 20 | 21 | 22 | 23 | 24 | 25 | ## Getting Started 26 | 27 | ### Install 28 | 29 | + Clone this repo 30 | 31 | ``` 32 | git clone https://github.com/Thomas-wyh/Ada-NETS 33 | cd Ada-NETS 34 | ``` 35 | 36 | + Create a conda virtual environment and activate it 37 | 38 | ``` 39 | conda create -n adanets python=3.6 -y 40 | conda activate adanets 41 | ``` 42 | 43 | + Install `Pytorch` , `cudatoolkit` and other requirements. 44 | ``` 45 | conda install pytorch==1.2 torchvision==0.4.0a0 cudatoolkit=10.2 -c pytorch 46 | pip install -r requirements.txt 47 | ``` 48 | 49 | - Install `Apex`: 50 | 51 | ``` 52 | git clone https://github.com/NVIDIA/apex 53 | cd apex 54 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 55 | ``` 56 | 57 | ### Data preparation 58 | 59 | The process of clustering on the MS-Celeb part1 is as follows: 60 | 61 | The original data files are from [here](https://github.com/yl-1993/learn-to-cluster/blob/master/DATASET.md#supported-datasets)(The feature and label files of MSMT17 used in Ada-NETS are [here](http://idstcv.oss-cn-zhangjiakou.aliyuncs.com/Ada-NETS/MSMT17/msmt17_feature_label.zip)). For convenience, we convert them to `.npy` format after L2 normalized. The original features' dimension is 256. 
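A conversion along these lines produces the normalized `.npy` files (a minimal sketch; the input file name and array shape are placeholders, not files shipped with this repo):

```
import numpy as np

feats = np.load("raw_part0_train_features.npy").astype(np.float32)  # placeholder input, shape (N, 256)
feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)        # L2-normalize each row
np.save("data/feature/part0_train.npy", feats)
```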
The file structure should look like: 62 | 63 | ``` 64 | data 65 | ├── feature 66 | │   ├── part0_train.npy 67 | │   └── part1_test.npy 68 | └── label 69 | ├── part0_train.npy 70 | └── part1_test.npy 71 | ``` 72 | 73 | Build the $k$NN by faiss: 74 | 75 | ``` 76 | sh script/faiss_search.sh 77 | ``` 78 | 79 | Obtain the top$K$ neighbours and distances of each vertex in the structure space: 80 | 81 | ``` 82 | sh script/struct_space.sh 83 | ``` 84 | 85 | Obtain the best neigbours by the candidate neighbours quality criterion: 86 | 87 | ``` 88 | sh script/max_Q_ind.sh 89 | ``` 90 | 91 | ### Train the Adaptive Filter 92 | 93 | Train the adaptive filter based on the data prepared above: 94 | 95 | ``` 96 | sh script/train_AND.sh 97 | ``` 98 | 99 | ### Train the GCN and cluster faces 100 | 101 | Generate the clean yet rich Graph: 102 | 103 | ``` 104 | sh script/gene_adj.sh 105 | ``` 106 | 107 | Train the GCN to obtain enhanced vertex features: 108 | 109 | ``` 110 | sh script/train_GCN.sh 111 | ``` 112 | 113 | Perform cluster faces: 114 | 115 | ``` 116 | sh script/cluster.sh 117 | ``` 118 | 119 | It will print the evaluation results of clustering. The Bcubed F-socre is about 91.4 and the Pairwise F-score is about 92.7. 120 | 121 | 122 | 123 | ## Acknowledgement 124 | 125 | This code is based on the publicly available face clustering [codebase](https://github.com/yl-1993/learn-to-cluster), [codebase](https://github.com/makarandtapaswi/BallClustering_ICCV2019) and the [dmlc/dgl](https://github.com/dmlc/dgl). 126 | 127 | The k-nearest neighbor search tool uses [faiss](https://github.com/facebookresearch/faiss). 128 | 129 | 130 | 131 | 132 | ## Citing Ada-NETS 133 | 134 | ``` 135 | @inproceedings{wang2022adanets, 136 | title={Ada-NETS: Face Clustering via Adaptive Neighbour Discovery in the Structure Space}, 137 | author={Yaohua Wang and Yaobin Zhang and Fangyi Zhang and Senzhang Wang and Ming Lin and YuQi Zhang and Xiuyu Sun}, 138 | booktitle={International conference on learning representations (ICLR)}, 139 | year={2022} 140 | } 141 | 142 | @misc{wang2022adanets, 143 | title={Ada-NETS: Face Clustering via Adaptive Neighbour Discovery in the Structure Space}, 144 | author={Yaohua Wang and Yaobin Zhang and Fangyi Zhang and Senzhang Wang and Ming Lin and YuQi Zhang and Xiuyu Sun}, 145 | year={2022}, 146 | eprint={2202.03800}, 147 | archivePrefix={arXiv}, 148 | primaryClass={cs.CV} 149 | } 150 | ``` 151 | -------------------------------------------------------------------------------- /image/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/Ada-NETS/42e445fbb3903059136b5dcac992ba85df4a1cf5/image/fig.png -------------------------------------------------------------------------------- /image/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/Ada-NETS/42e445fbb3903059136b5dcac992ba85df4a1cf5/image/results.png -------------------------------------------------------------------------------- /image/results2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/Ada-NETS/42e445fbb3903059136b5dcac992ba85df4a1cf5/image/results2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | dgl-cu102==0.5.0 3 | apex==0.1 4 | faiss==1.6.3 5 
| numpy==1.18.2 6 | yapf==0.30.0 7 | PyYAML==5.3.1 8 | tqdm 9 | scipy==1.2.1 10 | tensorboardX==2.0 11 | -------------------------------------------------------------------------------- /script/cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | set +x 4 | 5 | cd GCN 6 | 7 | featfile=outpath/fcfeat_ckpt_35000.npy 8 | tag=fc 9 | python -W ignore ../tool/faiss_search.py $featfile $featfile $outpath $tag 10 | 11 | Ifile=outpath/fcI.npy 12 | Dfile=outpath/fcD.npy 13 | python cluster.py $Ifile $Dfile $test_labelfile 14 | -------------------------------------------------------------------------------- /script/faiss_search.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | set +x 4 | 5 | mkdir -p data/knn/train data/knn/test 6 | featfile=data/feature/part0_train.npy 7 | outpath=data/knn/train 8 | python -W ignore tool/faiss_search.py $featfile $featfile $outpath 9 | 10 | featfile=data/feature/part1_test.npy 11 | outpath=data/knn/test 12 | python -W ignore tool/faiss_search.py $featfile $featfile $outpath 13 | -------------------------------------------------------------------------------- /script/gene_adj.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | set +x 4 | 5 | mkdir -p data/adj/train data/adj/test 6 | 7 | knnfile=data/knn/train/data.npz 8 | topk=256 9 | outfile=data/adj/train/adj 10 | #python tool/gene_adj.py $knnfile $topk $outfile 11 | 12 | knnfile=data/ss/test/data.npz 13 | kfile=AND/outpath/k_infer_pred.npy 14 | outfile=data/adj/test/adj_adanets 15 | python tool/gene_adj_adanets.py $knnfile $kfile $outfile 16 | -------------------------------------------------------------------------------- /script/max_Q_ind.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | set +x 4 | 5 | mkdir -p data/max_Q/train data/max_Q/test 6 | 7 | beta=0.50 8 | Ifile=data/ss/train/I.npy 9 | labelfile=data/label/part0_train.npy 10 | outfile=data/max_Q/train/ind 11 | python tool/max_Q_ind.py $Ifile $labelfile $beta $outfile 12 | 13 | Ifile=data/ss/test/I.npy 14 | labelfile=data/label/part1_test.npy 15 | outfile=data/max_Q/test/ind 16 | python tool/max_Q_ind.py $Ifile $labelfile $beta $outfile 17 | -------------------------------------------------------------------------------- /script/structure_space.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | set +x 4 | 5 | mkdir -p data/ss/train data/ss/test 6 | python tool/struct_space.py data/knn/test/I.npy data/knn/test/D.npy 80 data/ss/test/I data/ss/test/D data/ss/test/data 7 | python tool/struct_space.py data/knn/train/I.npy data/knn/train/D.npy 80 data/ss/train/I data/ss/train/D data/ss/train/data 8 | -------------------------------------------------------------------------------- /script/train_AND.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | set +x 4 | 5 | cd AND 6 | outpath=outpath 7 | mkdir -p $outpath 8 | 9 | train_featfile=../data/feature/part0_train.npy 10 | train_Ifile=../data/ss/train/I.npy 11 | train_labelfile=../data/max_Q/train/ind.npy 12 | 13 | test_featfile=../data/feature/part1_test.npy 14 | test_Ifile=../data/ss/test/I.npy 15 | test_labelfile=../data/max_Q/test/ind.npy 16 | 17 | 
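# two invocations are prepared below: a training pass (phase=train) and an inference pass
# (phase=test) that resumes from ${outpath}/ckpt_40000.pth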
phase=train 18 | param=" --config_file config.yml --outpath $outpath --phase $phase 19 | --train_featfile $train_featfile --train_Ifile $train_Ifile --train_labelfile $train_labelfile 20 | --test_featfile $test_featfile --test_Ifile $test_Ifile --test_labelfile $test_labelfile" 21 | #python -u train.py $param 22 | 23 | phase=test 24 | ckpt=ckpt_40000.pth 25 | param=" --config_file config.yml --outpath $outpath --phase $phase 26 | --train_featfile $train_featfile --train_Ifile $train_Ifile --train_labelfile $train_labelfile 27 | --test_featfile $test_featfile --test_Ifile $test_Ifile --test_labelfile $test_labelfile 28 | --resume_path ${outpath}/ckpt_40000.pth" 29 | python -u train.py ${param} 30 | -------------------------------------------------------------------------------- /script/train_GCN.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | set +x 4 | 5 | cd GCN 6 | 7 | losstype=allmax 8 | margin=1.0 9 | pweight=1.0 10 | pmargin=0.9 11 | beta=0.50 12 | 13 | outpath=outpath 14 | train_featfile=../data/feature/part0_train.npy 15 | train_orderadjfile=../data/adj/train/adj.npz 16 | train_adjfile=../data/adj/train/adj.npz 17 | train_labelfile=../data/label/part0_train.npy 18 | test_featfile=../data/feature/part1_test.npy 19 | test_adjfile=../data/adj/test/adj_adanets.npz 20 | test_labelfile=../data/label/part1_test.npy 21 | 22 | phase=train 23 | param="--config_file config.yml --outpath $outpath --phase $phase 24 | --train_featfile $train_featfile --train_adjfile $train_adjfile --train_labelfile $train_labelfile --train_orderadjfile $train_orderadjfile 25 | --test_featfile $test_featfile --test_adjfile $test_adjfile --test_labelfile $test_labelfile 26 | --losstype $losstype --margin $margin --pweight $pweight --pmargin ${pmargin}" 27 | python -u train.py $param 28 | 29 | phase=test 30 | param="--config_file config.yml --outpath $outpath --phase $phase 31 | --train_featfile $train_featfile --train_adjfile $train_adjfile --train_labelfile $train_labelfile --train_orderadjfile $train_orderadjfile 32 | --test_featfile $test_featfile --test_adjfile $test_adjfile --test_labelfile $test_labelfile 33 | --losstype $losstype --margin $margin --pweight $pweight --pmargin ${pmargin} --resume_path $outpath/ckpt_35000.pth" 34 | python -u train.py $param 35 | -------------------------------------------------------------------------------- /tool/adjacency.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import scipy.sparse as sp 6 | 7 | 8 | def row_normalize(mx): 9 | """Row-normalize sparse matrix""" 10 | rowsum = np.array(mx.sum(1)) 11 | # if rowsum <= 0, keep its previous value 12 | rowsum[rowsum <= 0] = 1 13 | r_inv = np.power(rowsum, -1).flatten() 14 | r_inv[np.isinf(r_inv)] = 0. 
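    # r_inv now holds 1/row_sum with divide-by-zero results zeroed; putting these factors on
    # a diagonal and left-multiplying scales every row of mx so that it sums to 1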
15 | r_mat_inv = sp.diags(r_inv) 16 | mx = r_mat_inv.dot(mx) 17 | return mx 18 | 19 | 20 | def build_symmetric_adj(adj, self_loop=True): 21 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 22 | if self_loop: 23 | adj = adj + sp.eye(adj.shape[0]) 24 | return adj 25 | 26 | 27 | def sparse_mx_to_indices_values(sparse_mx): 28 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 29 | indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64) 30 | values = sparse_mx.data 31 | shape = np.array(sparse_mx.shape) 32 | return indices, values, shape 33 | 34 | 35 | def indices_values_to_sparse_tensor(indices, values, shape): 36 | import torch 37 | indices = torch.from_numpy(indices) 38 | values = torch.from_numpy(values) 39 | shape = torch.Size(shape) 40 | return torch.sparse.FloatTensor(indices, values, shape) 41 | 42 | 43 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 44 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 45 | indices, values, shape = sparse_mx_to_indices_values(sparse_mx) 46 | return indices_values_to_sparse_tensor(indices, values, shape) 47 | -------------------------------------------------------------------------------- /tool/faiss_search.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import numpy as np 3 | import faiss 4 | from tqdm import tqdm 5 | import sys 6 | import time 7 | import os 8 | 9 | def batch_search(index, query, topk, bs, verbose=False): 10 | n = len(query) 11 | dists = np.zeros((n, topk), dtype=np.float32) 12 | nbrs = np.zeros((n, topk), dtype=np.int32) 13 | 14 | for sid in tqdm(range(0, n, bs), desc="faiss searching...", disable=not verbose): 15 | eid = min(n, sid + bs) 16 | dists[sid:eid], nbrs[sid:eid] = index.search(query[sid:eid], topk) 17 | cos_dist = dists / 2 18 | return cos_dist, nbrs 19 | 20 | 21 | def search(query_arr, doc_arr, outpath, tag, save_file=True): 22 | ### parameter 23 | nlist = 100 # 1000 cluster for 100w 24 | nprobe = 100 # test 10 cluster 25 | topk = 1024 26 | bs = 100 27 | ### end parameter 28 | 29 | 30 | #print("configure faiss") 31 | beg_time = time.time() 32 | num_gpu = faiss.get_num_gpus() 33 | dim = query_arr.shape[1] 34 | #cpu_index = faiss.index_factory(dim, 'IVF100', faiss.METRIC_INNER_PRODUCT) 35 | quantizer = faiss.IndexFlatL2(dim) 36 | cpu_index = faiss.IndexIVFFlat(quantizer, dim, nlist) 37 | cpu_index.nprobe = nprobe 38 | 39 | co = faiss.GpuMultipleClonerOptions() 40 | co.useFloat16 = True 41 | co.usePrecomputed = False 42 | co.shard = True 43 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co, ngpu=num_gpu) 44 | 45 | # start IVF 46 | #print("build index") 47 | gpu_index.train(doc_arr) 48 | gpu_index.add(doc_arr) 49 | #print(gpu_index.ntotal) 50 | 51 | # start query 52 | #print("start query") 53 | gpu_index.nprobe = nprobe # default nprobe is 1, try a few more 54 | print("beg search") 55 | D, I = batch_search(gpu_index, query_arr, topk, bs, verbose=True) 56 | print("time use %.4f"%(time.time()-beg_time)) 57 | 58 | if save_file: 59 | np.save(os.path.join(outpath, tag+'D'), D) 60 | np.save(os.path.join(outpath, tag+'I'), I) 61 | data = np.concatenate((I[:,None,:], D[:,None,:]), axis=1) 62 | np.savez(os.path.join(outpath,'data'), data=data) 63 | print("time use", time.time()-beg_time) 64 | 65 | if __name__ == "__main__": 66 | queryfile, docfile, outpath = sys.argv[1], sys.argv[2], sys.argv[3] 67 | if len(sys.argv) == 5: 68 | tag = sys.argv[4] 69 | else: 70 | tag = "" 71 | 72 | query_arr = np.load(queryfile) 73 | doc_arr 
= np.load(docfile) 74 | query_arr = query_arr / np.linalg.norm(query_arr, axis=1, keepdims=True) 75 | doc_arr = doc_arr / np.linalg.norm(doc_arr, axis=1, keepdims=True) 76 | 77 | search(query_arr, doc_arr, outpath, tag) 78 | -------------------------------------------------------------------------------- /tool/gene_adj.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import numpy as np 3 | from knn import fast_knns2spmat, knns2ordered_nbrs 4 | from adjacency import build_symmetric_adj, row_normalize 5 | from scipy.sparse import coo_matrix, save_npz 6 | import sys 7 | 8 | th_sim = 0.0 9 | if __name__ == "__main__": 10 | knnfile, topk, outfile = sys.argv[1], int(sys.argv[2]), sys.argv[3] 11 | knn_arr = np.load(knnfile)['data'][:, :, :topk] 12 | 13 | adj = fast_knns2spmat(knn_arr, topk, th_sim, use_sim=True) 14 | 15 | # build symmetric adjacency matrix 16 | adj = build_symmetric_adj(adj, self_loop=True) 17 | adj = row_normalize(adj) 18 | 19 | adj_coo = adj.tocoo() 20 | print("edge num", adj_coo.row.shape) 21 | print("mat shape", adj_coo.shape) 22 | 23 | save_npz(outfile, adj) 24 | -------------------------------------------------------------------------------- /tool/gene_adj_adanets.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import numpy as np 3 | from knn import fast_knns2spmat, knns2ordered_nbrs, fast_knns2spmat_adaptivek 4 | from adjacency import build_symmetric_adj, row_normalize 5 | from scipy.sparse import coo_matrix, save_npz 6 | import sys 7 | 8 | th_sim = 0.0 9 | if __name__ == "__main__": 10 | knnfile, kfile, outfile = sys.argv[1], sys.argv[2], sys.argv[3] 11 | knn_arr = np.load(knnfile)['data'] 12 | k_arr = np.load(kfile) 13 | 14 | adj = fast_knns2spmat_adaptivek(knn_arr, k_arr, th_sim) 15 | 16 | # build symmetric adjacency matrix 17 | adj = build_symmetric_adj(adj, self_loop=True) 18 | adj = row_normalize(adj) 19 | 20 | adj_coo = adj.tocoo() 21 | print("edge num", adj_coo.row.shape) 22 | print("mat shape", adj_coo.shape) 23 | 24 | save_npz(outfile, adj) 25 | -------------------------------------------------------------------------------- /tool/knn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import math 6 | import numpy as np 7 | import multiprocessing as mp 8 | from tqdm import tqdm 9 | 10 | __all__ = [ 11 | 'knn_brute_force', 'knn_hnsw', 'knn_faiss', 'knn_faiss_gpu', 'knns2spmat', 12 | 'fast_knns2spmat', 'knns2sub_spmat', 'build_knns', 'filter_knns', 13 | 'knns2ordered_nbrs', 'fast_knns2spmat_adaptivek' 14 | ] 15 | 16 | 17 | def knns_recall(nbrs, idx2lb, lb2idxs): 18 | with Timer('compute recall'): 19 | recs = [] 20 | cnt = 0 21 | for idx, (n, _) in enumerate(nbrs): 22 | lb = idx2lb[idx] 23 | idxs = lb2idxs[lb] 24 | n = list(n) 25 | if len(n) == 1: 26 | cnt += 1 27 | s = set(idxs) & set(n) 28 | recs += [1. * len(s) / len(idxs)] 29 | print('there are {} / {} = {:.3f} isolated anchors.'.format( 30 | cnt, len(nbrs), 1. * cnt / len(nbrs))) 31 | recall = np.mean(recs) 32 | return recall 33 | 34 | 35 | def filter_knns(knns, k, th): 36 | pairs = [] 37 | scores = [] 38 | n = len(knns) 39 | nbrs = np.zeros([n, k], dtype=np.int32) - 1 40 | simi = np.zeros([n, k]) - 1 41 | for i, (nbr, dist) in enumerate(knns): 42 | assert len(nbr) == len(dist) 43 | nbrs[i, :len(nbr)] = nbr 44 | simi[i, :len(nbr)] = 1. 
- dist 45 | anchor = np.tile(np.arange(n).reshape(n, 1), (1, k)) 46 | 47 | # filter 48 | selidx = np.where((simi >= th) & (nbrs != -1) & (nbrs != anchor)) 49 | pairs = np.hstack((anchor[selidx].reshape(-1, 50 | 1), nbrs[selidx].reshape(-1, 1))) 51 | scores = simi[selidx] 52 | 53 | if len(pairs) > 0: 54 | # keep uniq pairs 55 | pairs = np.sort(pairs, axis=1) 56 | pairs, unique_idx = np.unique(pairs, return_index=True, axis=0) 57 | scores = scores[unique_idx] 58 | return pairs, scores 59 | 60 | 61 | def knns2ordered_nbrs(knns, sort=True): 62 | if isinstance(knns, list): 63 | knns = np.array(knns) 64 | nbrs = knns[:, 0, :].astype(np.int32) 65 | dists = knns[:, 1, :] 66 | if sort: 67 | # sort dists from low to high 68 | nb_idx = np.argsort(dists, axis=1) 69 | idxs = np.arange(nb_idx.shape[0]).reshape(-1, 1) 70 | dists = dists[idxs, nb_idx] 71 | nbrs = nbrs[idxs, nb_idx] 72 | return dists, nbrs 73 | 74 | 75 | def knns2spmat(knns, k, th_sim=0.7, use_sim=False): 76 | # convert knns to symmetric sparse matrix 77 | from scipy.sparse import csr_matrix 78 | eps = 1e-5 79 | n = len(knns) 80 | row, col, data = [], [], [] 81 | for row_i, knn in enumerate(knns): 82 | nbrs, dists = knn 83 | for nbr, dist in zip(nbrs, dists): 84 | assert -eps <= dist <= 1 + eps, "{}: {}".format(row_i, dist) 85 | w = dist 86 | if 1 - w < th_sim or nbr == -1: 87 | continue 88 | if row_i == nbr: 89 | assert abs(dist) < eps 90 | continue 91 | row.append(row_i) 92 | col.append(nbr) 93 | if use_sim: 94 | w = 1 - w 95 | data.append(w) 96 | assert len(row) == len(col) == len(data) 97 | spmat = csr_matrix((data, (row, col)), shape=(n, n)) 98 | return spmat 99 | 100 | def fast_knns2spmat_adaptivek(knns, k_arr, th_sim=0.7): 101 | # convert knns to symmetric sparse matrix 102 | from scipy.sparse import csr_matrix 103 | eps = 1e-5 104 | n = len(knns) 105 | 106 | nbrs = knns[:, 0, :] 107 | dists = knns[:, 1, :] 108 | assert -eps <= dists.min() <= dists.max() <= 1 + eps, "min: {}, max: {}".format(dists.min(), dists.max()) 109 | sims = 1. 
- dists 110 | 111 | row, col = np.where(sims >= th_sim) # 这里划分阈值了 112 | new_row, new_col = [], [] 113 | for row_idx, col_idx in zip(row, col): 114 | thresK = k_arr[row_idx] 115 | if col_idx >= thresK: 116 | continue 117 | new_row.append(row_idx) 118 | new_col.append(col_idx) 119 | row, col = np.array(new_row), np.array(new_col) 120 | 121 | # remove the self-loop 122 | idxs = np.where(row != nbrs[row, col]) 123 | row = row[idxs] 124 | col = col[idxs] 125 | data = sims[row, col] 126 | col = nbrs[row, col] # convert to absolute column of the FULL N*N adj matrix 127 | assert len(row) == len(col) == len(data) 128 | spmat = csr_matrix((data, (row, col)), shape=(n, n)) 129 | return spmat 130 | 131 | def fast_knns2spmat(knns, k, th_sim=0.7, use_sim=False, fill_value=None): 132 | # convert knns to symmetric sparse matrix 133 | from scipy.sparse import csr_matrix 134 | eps = 1e-5 135 | n = len(knns) 136 | if isinstance(knns, list): 137 | knns = np.array(knns) 138 | if len(knns.shape) == 2: 139 | # knns saved by hnsw has different shape 140 | n = len(knns) 141 | ndarr = np.ones([n, 2, k]) 142 | ndarr[:, 0, :] = -1 # assign unknown dist to 1 and nbr to -1 143 | for i, (nbr, dist) in enumerate(knns): 144 | size = len(nbr) 145 | assert size == len(dist) 146 | ndarr[i, 0, :size] = nbr[:size] 147 | ndarr[i, 1, :size] = dist[:size] 148 | knns = ndarr 149 | nbrs = knns[:, 0, :] 150 | dists = knns[:, 1, :] 151 | assert -eps <= dists.min() <= dists.max( 152 | ) <= 1 + eps, "min: {}, max: {}".format(dists.min(), dists.max()) 153 | if use_sim: 154 | sims = 1. - dists 155 | else: 156 | sims = dists 157 | if fill_value is not None: 158 | print('[fast_knns2spmat] edge fill value:', fill_value) 159 | sims.fill(fill_value) 160 | row, col = np.where(sims >= th_sim) # 这里划分阈值了 161 | # remove the self-loop 162 | idxs = np.where(row != nbrs[row, col]) 163 | row = row[idxs] 164 | col = col[idxs] 165 | data = sims[row, col] 166 | col = nbrs[row, col] # convert to absolute column of the FULL N*N adj matrix 167 | assert len(row) == len(col) == len(data) 168 | spmat = csr_matrix((data, (row, col)), shape=(n, n)) 169 | return spmat 170 | 171 | 172 | def knns2sub_spmat(idxs, knns, th_sim=0.7, use_sim=False): 173 | # convert knns to symmetric sparse sub-matrix 174 | from scipy.sparse import csr_matrix 175 | n = len(idxs) 176 | row, col, data = [], [], [] 177 | abs2rel = {} 178 | for rel_i, abs_i in enumerate(idxs): 179 | assert abs_i not in abs2rel 180 | abs2rel[abs_i] = rel_i 181 | 182 | for row_i, idx in enumerate(idxs): 183 | nbrs, dists = knns[idx] 184 | for nbr, dist in zip(nbrs, dists): 185 | if idx == nbr: 186 | assert abs(dist) < 1e-6, "{}: {}".format(idx, dist) 187 | continue 188 | if nbr not in abs2rel: 189 | continue 190 | col_i = abs2rel[nbr] 191 | assert -1e-6 <= dist <= 1 192 | w = dist 193 | if 1 - w < th_sim or nbr == -1: 194 | continue 195 | row.append(row_i) 196 | col.append(col_i) 197 | if use_sim: 198 | w = 1 - w 199 | data.append(w) 200 | assert len(row) == len(col) == len(data) 201 | spmat = csr_matrix((data, (row, col)), shape=(n, n)) 202 | return spmat 203 | 204 | 205 | def build_knns(knn_prefix, 206 | feats, 207 | knn_method, 208 | k, 209 | num_process=None, 210 | is_rebuild=False, 211 | feat_create_time=None): 212 | knn_prefix = os.path.join(knn_prefix, '{}_k_{}'.format(knn_method, k)) 213 | mkdir_if_no_exists(knn_prefix) 214 | knn_path = knn_prefix + '.npz' 215 | if os.path.isfile( 216 | knn_path) and not is_rebuild and feat_create_time is not None: 217 | knn_create_time = os.path.getmtime(knn_path) 
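        # a cached knn file created no later than the feature file is treated as stale,
        # so the knns will be rebuilt instead of being loaded from disk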
218 | if knn_create_time <= feat_create_time: 219 | print('[warn] knn is created before feats ({} vs {})'.format( 220 | format_time(knn_create_time), format_time(feat_create_time))) 221 | is_rebuild = True 222 | if not os.path.isfile(knn_path) or is_rebuild: 223 | index_path = knn_prefix + '.index' 224 | with Timer('build index'): 225 | if knn_method == 'hnsw': 226 | index = knn_hnsw(feats, k, index_path) 227 | elif knn_method == 'faiss': 228 | index = knn_faiss(feats, 229 | k, 230 | index_path, 231 | omp_num_threads=num_process, 232 | rebuild_index=True) 233 | elif knn_method == 'faiss_gpu': 234 | index = knn_faiss_gpu(feats, 235 | k, 236 | index_path, 237 | num_process=num_process) 238 | else: 239 | raise KeyError( 240 | 'Only support hnsw and faiss currently ({}).'.format( 241 | knn_method)) 242 | knns = index.get_knns() 243 | with Timer('dump knns to {}'.format(knn_path)): 244 | dump_data(knn_path, knns, force=True) 245 | else: 246 | print('read knn from {}'.format(knn_path)) 247 | knns = load_data(knn_path) 248 | return knns 249 | 250 | 251 | #class knn(): 252 | # def __init__(self, feats, k, index_path='', verbose=True): 253 | # pass 254 | # 255 | # def filter_by_th(self, i): 256 | # th_nbrs = [] 257 | # th_dists = [] 258 | # nbrs, dists = self.knns[i] 259 | # for n, dist in zip(nbrs, dists): 260 | # if 1 - dist < self.th: 261 | # continue 262 | # th_nbrs.append(n) 263 | # th_dists.append(dist) 264 | # th_nbrs = np.array(th_nbrs) 265 | # th_dists = np.array(th_dists) 266 | # return (th_nbrs, th_dists) 267 | # 268 | # def get_knns(self, th=None): 269 | # if th is None or th <= 0.: 270 | # return self.knns 271 | # # TODO: optimize the filtering process by numpy 272 | # # nproc = mp.cpu_count() 273 | # nproc = 1 274 | # with Timer('filter edges by th {} (CPU={})'.format(th, nproc), 275 | # self.verbose): 276 | # self.th = th 277 | # self.th_knns = [] 278 | # tot = len(self.knns) 279 | # if nproc > 1: 280 | # pool = mp.Pool(nproc) 281 | # th_knns = list( 282 | # tqdm(pool.imap(self.filter_by_th, range(tot)), total=tot)) 283 | # pool.close() 284 | # else: 285 | # th_knns = [self.filter_by_th(i) for i in range(tot)] 286 | # return th_knns 287 | 288 | 289 | class knn_brute_force(knn): 290 | def __init__(self, feats, k, index_path='', verbose=True): 291 | self.verbose = verbose 292 | with Timer('[brute force] build index', verbose): 293 | feats = feats.astype('float32') 294 | sim = feats.dot(feats.T) 295 | with Timer('[brute force] query topk {}'.format(k), verbose): 296 | nbrs = np.argpartition(-sim, kth=k)[:, :k] 297 | idxs = np.array([i for i in range(nbrs.shape[0])]) 298 | dists = 1 - sim[idxs.reshape(-1, 1), nbrs] 299 | self.knns = [(np.array(nbr, dtype=np.int32), 300 | np.array(dist, dtype=np.float32)) 301 | for nbr, dist in zip(nbrs, dists)] 302 | 303 | 304 | class knn_hnsw(knn): 305 | def __init__(self, feats, k, index_path='', verbose=True, **kwargs): 306 | import nmslib 307 | self.verbose = verbose 308 | with Timer('[hnsw] build index', verbose): 309 | ''' higher ef leads to better accuracy, but slower search 310 | higher M leads to higher accuracy/run_time at fixed ef, 311 | but consumes more memory 312 | ''' 313 | # space_params = { 314 | # 'ef': 100, 315 | # 'M': 16, 316 | # } 317 | # index = nmslib.init(method='hnsw', 318 | # space='cosinesimil', 319 | # space_params=space_params) 320 | index = nmslib.init(method='hnsw', space='cosinesimil') 321 | if index_path != '' and os.path.isfile(index_path): 322 | index.loadIndex(index_path) 323 | else: 324 | 
index.addDataPointBatch(feats) 325 | index.createIndex({ 326 | 'post': 2, 327 | 'indexThreadQty': 1 328 | }, 329 | print_progress=verbose) 330 | if index_path: 331 | print('[hnsw] save index to {}'.format(index_path)) 332 | mkdir_if_no_exists(index_path) 333 | index.saveIndex(index_path) 334 | with Timer('[hnsw] query topk {}'.format(k), verbose): 335 | knn_ofn = index_path + '.npz' 336 | if os.path.exists(knn_ofn): 337 | print('[hnsw] read knns from {}'.format(knn_ofn)) 338 | self.knns = np.load(knn_ofn)['data'] 339 | else: 340 | self.knns = index.knnQueryBatch(feats, k=k) 341 | 342 | 343 | class knn_faiss(knn): 344 | def __init__(self, 345 | feats, 346 | k, 347 | index_path='', 348 | index_key='', 349 | nprobe=128, 350 | omp_num_threads=None, 351 | rebuild_index=True, 352 | verbose=True, 353 | **kwargs): 354 | import faiss 355 | if omp_num_threads is not None: 356 | faiss.omp_set_num_threads(omp_num_threads) 357 | self.verbose = verbose 358 | with Timer('[faiss] build index', verbose): 359 | if index_path != '' and not rebuild_index and os.path.exists( 360 | index_path): 361 | print('[faiss] read index from {}'.format(index_path)) 362 | index = faiss.read_index(index_path) 363 | else: 364 | feats = feats.astype('float32') 365 | size, dim = feats.shape 366 | index = faiss.IndexFlatIP(dim) 367 | if index_key != '': 368 | assert index_key.find( 369 | 'HNSW') < 0, 'HNSW returns distances insted of sims' 370 | metric = faiss.METRIC_INNER_PRODUCT 371 | nlist = min(4096, 8 * round(math.sqrt(size))) 372 | if index_key == 'IVF': 373 | quantizer = index 374 | index = faiss.IndexIVFFlat(quantizer, dim, nlist, 375 | metric) 376 | else: 377 | index = faiss.index_factory(dim, index_key, metric) 378 | if index_key.find('Flat') < 0: 379 | assert not index.is_trained 380 | index.train(feats) 381 | index.nprobe = min(nprobe, nlist) 382 | assert index.is_trained 383 | print('nlist: {}, nprobe: {}'.format(nlist, nprobe)) 384 | index.add(feats) 385 | if index_path != '': 386 | print('[faiss] save index to {}'.format(index_path)) 387 | mkdir_if_no_exists(index_path) 388 | faiss.write_index(index, index_path) 389 | with Timer('[faiss] query topk {}'.format(k), verbose): 390 | knn_ofn = index_path + '.npz' 391 | if os.path.exists(knn_ofn): 392 | print('[faiss] read knns from {}'.format(knn_ofn)) 393 | self.knns = np.load(knn_ofn)['data'] 394 | else: 395 | sims, nbrs = index.search(feats, k=k) 396 | self.knns = [(np.array(nbr, dtype=np.int32), 397 | 1 - np.array(sim, dtype=np.float32)) 398 | for nbr, sim in zip(nbrs, sims)] 399 | 400 | 401 | class knn_faiss_gpu(knn): 402 | def __init__(self, 403 | feats, 404 | k, 405 | index_path='', 406 | index_key='', 407 | nprobe=128, 408 | num_process=4, 409 | is_precise=True, 410 | sort=True, 411 | verbose=True, 412 | **kwargs): 413 | with Timer('[faiss_gpu] query topk {}'.format(k), verbose): 414 | knn_ofn = index_path + '.npz' 415 | if os.path.exists(knn_ofn): 416 | print('[faiss_gpu] read knns from {}'.format(knn_ofn)) 417 | self.knns = np.load(knn_ofn)['data'] 418 | else: 419 | dists, nbrs = faiss_search_knn(feats, 420 | k=k, 421 | nprobe=nprobe, 422 | num_process=num_process, 423 | is_precise=is_precise, 424 | sort=sort, 425 | verbose=False) 426 | 427 | self.knns = [(np.array(nbr, dtype=np.int32), 428 | np.array(dist, dtype=np.float32)) 429 | for nbr, dist in zip(nbrs, dists)] 430 | 431 | 432 | if __name__ == '__main__': 433 | from utils import l2norm 434 | 435 | k = 30 436 | d = 256 437 | nfeat = 10000 438 | np.random.seed(42) 439 | 440 | feats = 
np.random.random((nfeat, d)).astype('float32') 441 | feats = l2norm(feats) 442 | 443 | index1 = knn_hnsw(feats, k) 444 | index2 = knn_faiss(feats, k) 445 | index3 = knn_faiss(feats, k, index_key='Flat') 446 | index4 = knn_faiss(feats, k, index_key='IVF') 447 | index5 = knn_faiss(feats, k, index_key='IVF100,PQ32') 448 | 449 | print(index1.knns[0]) 450 | print(index2.knns[0]) 451 | print(index3.knns[0]) 452 | print(index4.knns[0]) 453 | print(index5.knns[0]) 454 | -------------------------------------------------------------------------------- /tool/max_Q_ind.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import sys 3 | import numpy as np 4 | from multiprocessing import Pool 5 | import pandas as pd 6 | 7 | total_k = 80 8 | def get_topK(query_nodeid): 9 | query_label = label_arr[query_nodeid] 10 | 11 | total_num = len(np.where(label_arr == query_label)[0]) 12 | prec_list, recall_list, fscore_list = [], [], [] 13 | for topK in range(1, total_k + 1): 14 | result_list = [] 15 | for i in range(0, topK): 16 | doc_nodeid = I[query_nodeid][i] 17 | doc_label = label_arr[doc_nodeid] 18 | result = 1 if doc_label == query_label else 0 19 | if i == 0: 20 | result = 1 21 | result_list.append(result) 22 | prec = np.mean(result_list) 23 | recall = np.sum(result_list) / total_num 24 | fscore = (1 + beta*beta) * prec * recall / (beta*beta*prec + recall) 25 | prec_list.append(prec) 26 | recall_list.append(recall) 27 | fscore_list.append(fscore) 28 | fscore_arr = np.array(fscore_list) 29 | idx = fscore_arr.argmax() 30 | thres_topK = idx + 1 31 | return thres_topK, prec_list[idx], recall_list[idx], fscore_list[idx] 32 | 33 | if __name__ == "__main__": 34 | Ifile, labelfile, beta, outfile = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4] 35 | 36 | I = np.load(Ifile) 37 | label_arr = np.load(labelfile) 38 | beta = float(beta) 39 | 40 | debug = True 41 | debug = False 42 | if debug: 43 | res = [] 44 | for query_nodeid in range(len(I)): 45 | item = get_topK(query_nodeid) 46 | res.append(item) 47 | else: 48 | pool = Pool(48) 49 | res = pool.map(get_topK, range(len(I))) 50 | pool.close() 51 | pool.join() 52 | 53 | topK_list, prec_list, recall_list, fscore_list = list(zip(*res)) 54 | np.save(outfile, topK_list) 55 | 56 | -------------------------------------------------------------------------------- /tool/struct_space.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import sys 3 | import numpy as np 4 | from multiprocessing import Pool, Manager 5 | import copy 6 | from functools import partial 7 | 8 | k = 80 9 | lamb = 0.3 10 | 11 | def worker(k, queue, query_nodeid): 12 | docnodeid_list = I[query_nodeid, :k] 13 | query_Rstarset = Rstarset_list[query_nodeid] 14 | outlist = [] 15 | for idx, doc_nodeid in enumerate(docnodeid_list): 16 | doc_Rstarset = Rstarset_list[doc_nodeid] 17 | sim = 1.0 * len(query_Rstarset & doc_Rstarset) / len(query_Rstarset | doc_Rstarset) 18 | jd = 1 - sim 19 | cd = D[query_nodeid, idx] 20 | nd = (1-lamb) * jd + lamb * cd 21 | tpl = (doc_nodeid, nd) 22 | outlist.append(tpl) 23 | outlist = sorted(outlist, key=lambda x:x[1]) 24 | queue.put(query_nodeid) 25 | fn_name = sys._getframe().f_code.co_name 26 | #if queue.qsize() % 1000 == 0: 27 | # print("==>", fn_name, queue.qsize()) 28 | return list(zip(*outlist)) 29 | 30 | def get_Kngbr(query_nodeid, k): 31 | Kngbr = I[query_nodeid, :k] 32 | return set(Kngbr) 33 | 34 | def get_Rset(k, queue, query_nodeid): 35 | docnodeid_set = 
get_Kngbr(query_nodeid, k)
36 |     Rset = set()
37 |     for doc_nodeid in docnodeid_set:
38 |         if query_nodeid not in get_Kngbr(doc_nodeid, k):
39 |             continue
40 |         Rset.add(doc_nodeid)
41 |     queue.put(query_nodeid)
42 |     fn_name = sys._getframe().f_code.co_name
43 |     #if queue.qsize() % 1000 == 0:
44 |     #    print("==>", fn_name, queue.qsize())
45 |     return Rset
46 |
47 | def get_Rstarset(queue, query_nodeid):
48 |     Rset = Rset_list[query_nodeid]
49 |     Rstarset = copy.deepcopy(Rset)
50 |     for doc_nodeid in Rset:
51 |         doc_Rset = half_Rset_list[doc_nodeid]
52 |         if len(doc_Rset & Rset) >= len(doc_Rset) * 2 / 3:
53 |             Rstarset |= doc_Rset
54 |     queue.put(query_nodeid)
55 |     fn_name = sys._getframe().f_code.co_name
56 |     #if queue.qsize() % 1000 == 0:
57 |     #    print("==>", fn_name, queue.qsize())
58 |     return Rstarset
59 |
60 | if __name__ == "__main__":
61 |     Ifile, Dfile, topk, outIfile, outDfile, outDatafile = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6]
62 |     k = int(topk)
63 |     print("use topk", k)
64 |     I = np.load(Ifile)
65 |     D = np.load(Dfile)
66 |
67 |     queue1 = Manager().Queue()
68 |     queue2 = Manager().Queue()
69 |     queue3 = Manager().Queue()
70 |     queue4 = Manager().Queue()
71 |     debug = True
72 |     debug = False
73 |     if debug:
74 |         for query_nodeid in range(len(I)):
75 |             res = worker(k, queue4, query_nodeid)  # note: debug path still assumes Rset_list/Rstarset_list are already built
76 |     else:
77 |         pool = Pool(52)
78 |         get_Rset_partial = partial(get_Rset, k, queue1)
79 |         Rset_list = pool.map(get_Rset_partial, range(len(I)))
80 |         pool.close()
81 |         pool.join()
82 |
83 |         pool = Pool(52)
84 |         k2 = k // 2
85 |         get_Rset_partial = partial(get_Rset, k2, queue2)
86 |         half_Rset_list = pool.map(get_Rset_partial, range(len(I)))
87 |         pool.close()
88 |         pool.join()
89 |
90 |         pool = Pool(52)
91 |         get_Rstarset_partial = partial(get_Rstarset, queue3)
92 |         Rstarset_list = pool.map(get_Rstarset_partial, range(len(I)))
93 |         pool.close()
94 |         pool.join()
95 |
96 |         pool = Pool(52)
97 |         worker_partial = partial(worker, k, queue4)
98 |         res = pool.map(worker_partial, range(len(I)))
99 |         pool.close()
100 |         pool.join()
101 |
102 |     newI, newD = list(zip(*res))
103 |     newI = np.array(newI)
104 |     newD = np.array(newD)
105 |     newdata = np.concatenate((newI[:,None,:], newD[:,None,:]), axis=1)
106 |     np.save(outIfile, newI)
107 |     np.save(outDfile, newD)
108 |     np.savez(outDatafile, data=newdata)
109 |
--------------------------------------------------------------------------------
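
Putting the kNN utilities together, a typical call chain is: search neighbours with knn_faiss, convert the result into a thresholded sparse adjacency with fast_knns2spmat, then symmetrise and row-normalise it before feeding a graph model. The sketch below is illustrative only: the feature path, the tool.knn import, and the scipy-based symmetrisation are assumptions, not code from this repository, and it assumes top-k cosine distances fall in [0, 1] as fast_knns2spmat asserts.

# usage sketch (assumptions: 'features.npy' is a placeholder for an L2-normalised
# float32 feature file, faiss is installed, and this file is importable as tool.knn)
import numpy as np
import scipy.sparse as sp
from tool.knn import knn_faiss, fast_knns2spmat

feats = np.load('features.npy').astype('float32')           # placeholder path
index = knn_faiss(feats, k=80)                               # inner-product search, dist = 1 - sim
spmat = fast_knns2spmat(index.knns, 80, th_sim=0.7, use_sim=True)

# symmetrise and row-normalise (plain scipy, shown only for illustration)
adj = spmat.maximum(spmat.T)                                 # element-wise max keeps the stronger edge
adj = adj + sp.eye(adj.shape[0])                             # add self-loops
rowsum = np.asarray(adj.sum(1)).flatten()
rowsum[rowsum <= 0] = 1                                      # avoid division by zero
adj = sp.diags(1.0 / rowsum).dot(adj)                        # row-normalised adjacency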
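
tool/max_Q_ind.py picks, for every query node, the neighbourhood size K that maximises the F-beta score of its top-K list against the ground-truth labels (the query itself is always counted as a hit). The helper below is a vectorised restatement of get_topK for illustration, with beta defaulting to 1; the file names in the usage lines are placeholders, and the link to fast_knns2spmat_adaptivek's k_arr is inferred from that function's signature rather than stated in the code.

import numpy as np

def best_k_for_node(I, labels, q, total_k=80, beta=1.0):
    """Return (best_K, precision, recall, F_beta) for query node q,
    mirroring the sweep in get_topK: the query itself counts as a hit."""
    same = (labels[I[q, :total_k]] == labels[q]).astype(float)
    same[0] = 1.0                                  # force the self-match to count
    hits = np.cumsum(same)                         # hits within top-K for K = 1..total_k
    prec = hits / np.arange(1, total_k + 1)
    recall = hits / (labels == labels[q]).sum()
    f = (1 + beta**2) * prec * recall / (beta**2 * prec + recall)
    best = int(np.argmax(f))
    return best + 1, prec[best], recall[best], f[best]

# hypothetical wiring: the per-node K values appear to be what
# fast_knns2spmat_adaptivek in tool/knn.py consumes as `k_arr`
I = np.load('I.npy')              # placeholder: neighbour indices from faiss
labels = np.load('labels.npy')    # placeholder: ground-truth ids
k_arr = np.array([best_k_for_node(I, labels, q)[0] for q in range(len(I))])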
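
tool/struct_space.py re-ranks each kNN list by blending a Jaccard distance over expanded reciprocal-neighbour sets (R*-sets) with the original cosine distance, weighted by lamb. The function below is a compact, illustrative restatement of the per-pair computation in worker(); R_star stands for a precomputed list of R*-sets such as Rstarset_list, and it assumes every node belongs to its own R*-set so the union is never empty.

def blended_distance(q, idx, I, D, R_star, lamb=0.3):
    """Distance between query q and its idx-th neighbour:
    (1 - lamb) * Jaccard distance over R*-sets + lamb * original cosine distance."""
    doc = int(I[q, idx])
    inter = len(R_star[q] & R_star[doc])
    union = len(R_star[q] | R_star[doc])           # non-empty if q is in its own R*-set
    jaccard_dist = 1.0 - inter / union             # 1 - |R*q ∩ R*d| / |R*q ∪ R*d|
    return (1 - lamb) * jaccard_dist + lamb * D[q, idx]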