├── AND
│   ├── adjacency.py
│   ├── config.yml
│   ├── net
│   │   └── gcn_v.py
│   ├── train.py
│   └── util
│       ├── confidence.py
│       ├── deduce.py
│       ├── evaluate.py
│       └── metrics.py
├── GCN
│   ├── adjacency.py
│   ├── cluster.py
│   ├── config.yml
│   ├── net
│   │   ├── gat.py
│   │   └── optim_modules.py
│   ├── train.py
│   └── util
│       ├── confidence.py
│       ├── deduce.py
│       ├── evaluate.py
│       ├── graph.py
│       └── metrics.py
├── LICENSE
├── README.md
├── image
│   ├── fig.png
│   ├── results.png
│   └── results2.png
├── requirements.txt
├── script
│   ├── cluster.sh
│   ├── faiss_search.sh
│   ├── gene_adj.sh
│   ├── max_Q_ind.sh
│   ├── structure_space.sh
│   ├── train_AND.sh
│   └── train_GCN.sh
└── tool
    ├── adjacency.py
    ├── faiss_search.py
    ├── gene_adj.py
    ├── gene_adj_adanets.py
    ├── knn.py
    ├── max_Q_ind.py
    └── struct_space.py
/AND/adjacency.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | import scipy.sparse as sp
6 |
7 |
8 | def row_normalize(mx):
9 | """Row-normalize sparse matrix"""
10 | rowsum = np.array(mx.sum(1))
11 | # if rowsum <= 0, keep its previous value
12 | rowsum[rowsum <= 0] = 1
13 | r_inv = np.power(rowsum, -1).flatten()
14 | r_inv[np.isinf(r_inv)] = 0.
15 | r_mat_inv = sp.diags(r_inv)
16 | mx = r_mat_inv.dot(mx)
17 | return mx
18 |
19 |
20 | def build_symmetric_adj(adj, self_loop=True):
21 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
22 | if self_loop:
23 | adj = adj + sp.eye(adj.shape[0])
24 | return adj
25 |
26 |
27 | def sparse_mx_to_indices_values(sparse_mx):
28 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
29 | indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
30 | values = sparse_mx.data
31 | shape = np.array(sparse_mx.shape)
32 | return indices, values, shape
33 |
34 |
35 | def indices_values_to_sparse_tensor(indices, values, shape):
36 | import torch
37 | indices = torch.from_numpy(indices)
38 | values = torch.from_numpy(values)
39 | shape = torch.Size(shape)
40 | return torch.sparse.FloatTensor(indices, values, shape)
41 |
42 |
43 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
44 | """Convert a scipy sparse matrix to a torch sparse tensor."""
45 | indices, values, shape = sparse_mx_to_indices_values(sparse_mx)
46 | return indices_values_to_sparse_tensor(indices, values, shape)
47 |
--------------------------------------------------------------------------------
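A minimal usage sketch for these helpers (not part of the repository), assuming it is run from the AND directory; the 4-node affinity matrix is a made-up stand-in for a real k-NN graph:

    import numpy as np
    import scipy.sparse as sp
    from adjacency import build_symmetric_adj, row_normalize, sparse_mx_to_torch_sparse_tensor

    # toy directed k-NN affinity matrix (a nonzero entry is a similarity to a neighbor)
    knn = sp.csr_matrix(np.array([[0., .9, 0., 0.],
                                  [0., 0., .8, 0.],
                                  [.7, 0., 0., 0.],
                                  [0., 0., .6, 0.]], dtype=np.float32))

    adj = build_symmetric_adj(knn, self_loop=True)   # symmetrize and add self-loops
    adj = row_normalize(adj)                         # D^-1 A row normalization
    adj_t = sparse_mx_to_torch_sparse_tensor(adj)    # sparse tensor ready for a GCN
    print(adj_t.shape)                               # torch.Size([4, 4])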
/AND/config.yml:
--------------------------------------------------------------------------------
1 | # model
2 | feat_dim: 256
3 | nhid: 512
4 | nclass: 1
5 |
6 | # optimizer
7 | lr: 0.01
8 | sgd_momentum: 0.9
9 | sgd_weight_decay: 0.00001
10 | lr_step : [0.5, 0.8, 0.9]
11 | factor: 0.1
12 | total_step: 40000
13 | cuda: True
14 | warmup_step: 128
15 | batchsize: 1024
16 |
17 | # output
18 | save_freq: 20000
19 | log_freq: 1
20 | val_freq: 1000
21 | # resume
22 | resume_path:
23 |
--------------------------------------------------------------------------------
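train.py loads this file with PyYAML, wraps it in an EasyDict, and then overrides keys with command-line arguments. A short sketch of how the values above become attributes (the relative path assumes the AND directory as working directory):

    import yaml
    from easydict import EasyDict

    with open('config.yml') as f:
        cfg = EasyDict(yaml.load(f, Loader=yaml.FullLoader))

    print(cfg.feat_dim, cfg.nhid, cfg.lr)    # 256 512 0.01
    # lr decay milestones as computed in train.py
    cfg.step_number = [int(r * cfg.total_step) for r in cfg.lr_step]
    print(cfg.step_number)                   # [20000, 32000, 36000]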
/AND/net/gcn_v.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | #from .utils import GraphConv, MeanAggregator
7 | from dgl.nn.pytorch import GraphConv
8 | import torch.nn.functional as F
9 | import math
10 |
11 |
12 | class huberloss(nn.Module):
13 | def __init__(self, delta):
14 | super(huberloss, self).__init__()
15 | self.delta = delta
16 |
17 | def forward(self, input_arr, target_arr):
18 | rate = input_arr / target_arr - 1
19 | loss = torch.where(torch.abs(rate) <= self.delta, 0.5*rate*rate, (torch.abs(rate) - 0.5*self.delta) * self.delta)
20 | return loss.mean()
21 |
22 | class MREloss(nn.Module):
23 | def __init__(self):
24 | super(MREloss, self).__init__()
25 |
26 | def forward(self, input_arr, target_arr):
27 | loss = torch.abs(input_arr / target_arr - 1)
28 | return loss.mean()
29 |
30 |
31 | class GCN_V(nn.Module):
32 | def __init__(self, feature_dim, nhid, nclass, dropout=0):
33 | super(GCN_V, self).__init__()
34 | self.lstm = nn.LSTM(input_size=feature_dim, hidden_size=feature_dim, num_layers=1, batch_first=True, dropout=dropout, bidirectional=True)
35 | self.out_proj = nn.Linear(2*feature_dim, feature_dim, bias=True)
36 |
37 | self.nclass = nclass
38 | self.mlp = nn.Sequential(
39 | nn.Linear(feature_dim, nhid), nn.PReLU(nhid), nn.Dropout(p=dropout),
40 | nn.Linear(nhid, feature_dim), nn.PReLU(feature_dim), nn.Dropout(p=dropout),
41 | )
42 | self.regressor = nn.Linear(feature_dim, 1)
43 | #self.loss = torch.nn.MSELoss()
44 | #self.loss = MREloss()
45 | self.loss = huberloss(delta=1.0)
46 |
47 | def forward(self, data, output_feat=False, return_loss=False):
48 | assert not output_feat or not return_loss
49 | batch_feat, batch_label = data[0], data[1]
50 |
51 | # lstm block
52 | out, (hn, cn) = self.lstm(batch_feat)
53 | out = self.out_proj(out)
54 | out = (out + batch_feat) / math.sqrt(2.)
55 |
56 | # normalize before mean
57 | out = F.normalize(out, 2, dim=-1)
58 | out = out.mean(dim=1)
59 | out = F.normalize(out, 2, dim=-1)
60 |
61 | # mlp block
62 | residual = out
63 | out = self.mlp(out)
64 | out = (residual + out ) / math.sqrt(2.)
65 |
66 | # regressor block
67 | pred = self.regressor(out).view(-1)
68 |
69 | if output_feat:
70 | return pred, residual
71 |
72 | if return_loss:
73 | loss = self.loss(pred, batch_label)
74 | return pred, loss
75 |
76 | return pred
77 |
78 |
79 | def gcn_v(feature_dim, nhid, nclass=1, dropout=0., **kwargs):
80 | model = GCN_V(feature_dim=feature_dim,
81 | nhid=nhid,
82 | nclass=nclass,
83 | dropout=dropout)
84 | return model
85 |
--------------------------------------------------------------------------------
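Despite its name, GCN_V here is an LSTM-plus-MLP regressor over the stacked features of a node's k nearest neighbors, trained to predict a scaled neighborhood size. A forward-pass sketch on random tensors, assuming it is run from the AND directory; batch size and neighborhood size are illustrative:

    import torch
    from net.gcn_v import gcn_v

    model = gcn_v(feature_dim=256, nhid=512)          # nclass defaults to 1
    feats = torch.randn(8, 80, 256)                   # (batch, k neighbors, feat_dim)
    labels = torch.rand(8).clamp_min(0.1)             # scaled neighborhood sizes in (0, 1]

    pred, loss = model((feats, labels), return_loss=True)   # huberloss on the pred/label ratio
    pred, feat = model((feats, labels), output_feat=True)   # pred plus the pooled 256-d feature
    print(pred.shape, feat.shape)                            # torch.Size([8]) torch.Size([8, 256])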
/AND/train.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | from __future__ import division
3 | import torch
4 | import torch.optim as optim
5 | from adjacency import sparse_mx_to_torch_sparse_tensor
6 | from net.gcn_v import GCN_V
7 | import yaml
8 | from easydict import EasyDict
9 | from tensorboardX import SummaryWriter
10 | import numpy as np
11 | import scipy.sparse as sp
12 | import time
13 | import pprint
14 | import sys
15 | import os
16 | import argparse
17 | import math
18 | import pandas as pd
19 | import dgl
20 | import warnings
21 | from tqdm import tqdm
22 |
23 |
24 | class node_dataset(torch.utils.data.Dataset):
25 | def __init__(self, node_list, **kwargs):
26 | self.node_list = node_list
27 |
28 | def __getitem__(self, index):
29 | return self.node_list[index]
30 |
31 | def __len__(self):
32 | return len(self.node_list)
33 |
34 | def row_normalize(mx):
35 | """Row-normalize sparse matrix"""
36 | rowsum = np.array(mx.sum(1))
37 | # if rowsum <= 0, keep its previous value
38 | rowsum[rowsum <= 0] = 1
39 | r_inv = np.power(rowsum, -1).flatten()
40 | r_inv[np.isinf(r_inv)] = 0.
41 | r_mat_inv = sp.diags(r_inv)
42 | mx = r_mat_inv.dot(mx)
43 | return mx
44 |
45 | class AverageMeter(object):
46 | def __init__(self):
47 | self.val = 0
48 | self.avg = 0
49 | self.sum = 0
50 | self.count = 0
51 | def reset(self):
52 | self.val = 0
53 | self.avg = 0
54 | self.sum = 0
55 | self.count = 0
56 | def update(self, val, n=1):
57 | self.val = val
58 | self.sum += val * n
59 | self.count += n
60 | self.avg = float(self.sum) / (self.count + 1e-10)
61 |
62 | class Timer():
63 | def __init__(self, name='task', verbose=True):
64 | self.name = name
65 | self.verbose = verbose
66 |
67 | def __enter__(self):
68 | print('[begin {}]'.format(self.name))
69 | self.start = time.time()
70 | return self
71 |
72 | def __exit__(self, exc_type, exc_val, exc_tb):
73 | if self.verbose:
74 | print('[done {}] use {:.3f} s'.format(self.name, time.time() - self.start))
75 | return exc_type is None
76 |
77 | def adjust_lr(cur_epoch, optimizer, cfg):
78 | if cur_epoch not in cfg.step_number:
79 | return
80 | ind = cfg.step_number.index(cur_epoch)
81 | for each in optimizer.param_groups:
82 | each['lr'] = cfg.lr * cfg.factor ** (ind + 1)
83 |
84 | def cos_lr(current_step, optimizer, cfg):
85 | if current_step < cfg.warmup_step:
86 | rate = 1.0 * current_step / cfg.warmup_step
87 | lr = cfg.lr * rate
88 | else:
89 | n1 = cfg.total_step - cfg.warmup_step
90 | n2 = current_step - cfg.warmup_step
91 | rate = (1 + math.cos(math.pi * n2 / n1)) / 2
92 | lr = cfg.lr * rate
93 | for each in optimizer.param_groups:
94 | each['lr'] = lr
95 |
96 | if __name__ == "__main__":
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument('--config_file', type=str)
99 | parser.add_argument('--outpath', type=str)
100 | parser.add_argument('--phase', type=str)
101 | parser.add_argument('--train_featfile', type=str)
102 | parser.add_argument('--train_Ifile', type=str)
103 | parser.add_argument('--train_labelfile', type=str)
104 | parser.add_argument('--test_featfile', type=str)
105 | parser.add_argument('--test_Ifile', type=str)
106 | parser.add_argument('--test_labelfile', type=str)
107 | parser.add_argument('--resume_path', type=str)
108 | args = parser.parse_args()
109 |
110 | beg_time = time.time()
111 | config = yaml.load(open(args.config_file, "r"), Loader=yaml.FullLoader)
112 | cfg = EasyDict(config)
113 | cfg.step_number = [int(r * cfg.total_step) for r in cfg.lr_step]
114 |
115 | # force assignment
116 | for key, value in args._get_kwargs():
117 | cfg[key] = value
118 | #cfg[list(dict(train_adjfile=train_adjfile).keys())[0]] = train_adjfile
119 | #cfg[list(dict(train_labelfile=train_labelfile).keys())[0]] = train_labelfile
120 | #cfg[list(dict(test_adjfile=test_adjfile).keys())[0]] = test_adjfile
121 | #cfg[list(dict(test_labelfile=test_labelfile).keys())[0]] = test_labelfile
122 | print("train hyper parameter list")
123 | pprint.pprint(cfg)
124 |
125 | # get model
126 | model = GCN_V(feature_dim=cfg.feat_dim, nhid=cfg.nhid, nclass=cfg.nclass, dropout=0.5)
127 | model.cuda()
128 |
129 | # get dataset
130 | scale_max = 80.
131 | with Timer('load data'):
132 | train_feature = np.load(cfg.train_featfile)
133 | train_feature = train_feature / np.linalg.norm(train_feature, axis=1, keepdims=True)
134 | train_adj = np.load(cfg.train_Ifile)[:, :int(scale_max)]
135 | train_label_k = np.load(cfg.train_labelfile).astype(np.float32)
136 | train_label_s = train_label_k / scale_max
137 | train_feature = torch.FloatTensor(train_feature).cuda()
138 | train_label_s = torch.FloatTensor(train_label_s).cuda()
139 | train_data = (train_feature, train_adj, train_label_s)
140 |
141 | test_feature = np.load(cfg.test_featfile)
142 | test_feature = test_feature / np.linalg.norm(test_feature, axis=1, keepdims=True)
143 | test_adj = np.load(cfg.test_Ifile)[:, :int(scale_max)]
144 | test_label_k = np.load(cfg.test_labelfile).astype(np.float32)
145 | test_label_s = test_label_k / scale_max
146 | test_feature = torch.FloatTensor(test_feature).cuda()
147 | test_label_s = torch.FloatTensor(test_label_s).cuda()
148 |
149 | train_dataset = node_dataset(range(len(train_feature)))
150 | test_dataset = node_dataset(range(len(test_feature)))
151 | train_dataloader = torch.utils.data.DataLoader(
152 | dataset=train_dataset,
153 | batch_size=cfg.batchsize,
154 | shuffle=True,
155 | num_workers=16,
156 | pin_memory=True,
157 | drop_last=False)
158 |
159 | test_dataloader = torch.utils.data.DataLoader(
160 | dataset=test_dataset,
161 | batch_size=cfg.batchsize,
162 | shuffle=False,
163 | num_workers=16,
164 | pin_memory=True,
165 | drop_last=False)
166 |
167 | if cfg.phase == 'train':
168 | optimizer = optim.SGD(model.parameters(), cfg.lr, momentum=cfg.sgd_momentum, weight_decay=cfg.sgd_weight_decay)
169 | beg_step = 0
170 | if cfg.resume_path != None:
171 | beg_step = int(os.path.splitext(os.path.basename(cfg.resume_path))[0].split('_')[1])
172 | with Timer('resume model from %s'%cfg.resume_path):
173 | ckpt = torch.load(cfg.resume_path, map_location='cpu')
174 | model.load_state_dict(ckpt['state_dict'])
175 |
176 | train_loss_meter = AverageMeter()
177 | train_kdiff_meter = AverageMeter()
178 | train_mre_meter = AverageMeter()
179 | test_loss_meter = AverageMeter()
180 | test_kdiff_meter = AverageMeter()
181 | test_mre_meter = AverageMeter()
182 | writer = SummaryWriter(os.path.join(cfg.outpath), filename_suffix='')
183 |
184 | current_step = beg_step
185 | break_flag = False
186 | while 1:
187 | if break_flag:
188 | break
189 | iter_begtime = time.time()
190 | for _, index in enumerate(train_dataloader):
191 | if current_step > cfg.total_step:
192 | break_flag = True
193 | break
194 | current_step += 1
195 | cos_lr(current_step, optimizer, cfg)
196 |
197 | batch_feature = train_feature[train_adj[index]]
198 | batch_label = train_label_s[index]
199 | batch_k = train_label_k[index]
200 | batch_data = (batch_feature, batch_label)
201 |
202 | model.train()
203 | pred_arr, train_loss = model(batch_data, return_loss=True)
204 | optimizer.zero_grad()
205 | train_loss.backward()
206 | optimizer.step()
207 |
208 | train_loss_meter.update(train_loss.item())
209 | pred_arr = pred_arr.data.cpu().numpy()
210 |
211 | # add this clip
212 | k_hat = np.round(pred_arr * scale_max)
213 | k_hat[np.where(k_hat < 1)[0]] = 1
214 | k_hat[np.where(k_hat > scale_max)[0]] = scale_max
215 |
216 | train_kdiff = np.abs(k_hat - batch_k)
217 | train_kdiff_meter.update(train_kdiff.mean())
218 | train_mre = train_kdiff / batch_k
219 | train_mre_meter.update(train_mre.mean())
220 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step=current_step)
221 | writer.add_scalar('loss/train', train_loss.item(), global_step=current_step)
222 | writer.add_scalar('kdiff/train', train_kdiff_meter.val, global_step=current_step)
223 | writer.add_scalar('mre/train', train_mre_meter.val, global_step=current_step)
224 | if current_step % cfg.log_freq == 0:
225 | log = "step:{}, step_time:{:.3f}, lr:{:.8f}, trainloss:{:.4f}({:.4f}), trainkdiff:{:.2f}({:.2f}), trainmre:{:.2f}({:.2f}), testloss:{:.4f}({:.4f}), testkdiff:{:.2f}({:.2f}), testmre:{:.2f}({:.2f})".format(current_step, time.time()-iter_begtime, optimizer.param_groups[0]['lr'], train_loss_meter.val, train_loss_meter.avg, train_kdiff_meter.val, train_kdiff_meter.avg, train_mre_meter.val, train_mre_meter.avg, test_loss_meter.val, test_loss_meter.avg, test_kdiff_meter.val, test_kdiff_meter.avg, test_mre_meter.val, test_mre_meter.avg)
226 | print(log)
227 | iter_begtime = time.time()
228 | if (current_step) % cfg.save_freq == 0 and current_step > 0:
229 | torch.save({'state_dict' : model.state_dict(), 'step': current_step},
230 | os.path.join(cfg.outpath, "ckpt_%s.pth"%(current_step)))
231 |
232 | if (current_step) % cfg.val_freq == 0 and current_step > 0:
233 | pred_list = []
234 | model.eval()
235 | testloss_list = []
236 | for step, index in enumerate(tqdm(test_dataloader, desc='test phase', disable=False)):
237 |
238 | batch_feature = test_feature[test_adj[index]]
239 | batch_label = test_label_s[index]
240 | batch_data = (batch_feature, batch_label)
241 |
242 | pred, test_loss = model(batch_data, return_loss=True)
243 | pred_list.append(pred.data.cpu().numpy())
244 | testloss_list.append(test_loss.item())
245 |
246 | pred_list = np.concatenate(pred_list)
247 | k_hat, k_arr = pred_list * scale_max, test_label_k
248 |
249 | # add this clip before eval
250 | k_hat = np.round(k_hat)
251 | k_hat[np.where(k_hat < 1)[0]] = 1
252 | k_hat[np.where(k_hat > scale_max)[0]] = scale_max
253 |
254 | test_kdiff = np.abs(np.round(k_hat) - k_arr.reshape(-1))
255 | test_mre = test_kdiff / k_arr.reshape(-1)
256 | test_kdiff_meter.update(test_kdiff.mean())
257 | test_mre_meter.update(test_mre.mean())
258 | test_loss_meter.update(np.mean(testloss_list))
259 | writer.add_scalar('loss/test', test_loss_meter.val, global_step=current_step)
260 | writer.add_scalar('kdiff/test', test_kdiff_meter.val, global_step=current_step)
261 | writer.add_scalar('mre/test', test_mre_meter.val, global_step=current_step)
262 |
263 | writer.close()
264 | else:
265 | ckpt = torch.load(cfg.resume_path, map_location='cpu')
266 | model.load_state_dict(ckpt['state_dict'])
267 |
268 | pred_list, gcnfeat_list = [], []
269 | model.eval()
270 | beg_time = time.time()
271 | for step, index in enumerate(test_dataloader):
272 | batch_feature = test_feature[test_adj[index]]
273 | batch_label = test_label_s[index]
274 | batch_data = (batch_feature, batch_label)
275 |
276 | pred, gcnfeat = model(batch_data, output_feat=True)
277 | pred_list.append(pred.data.cpu().numpy())
278 | gcnfeat_list.append(gcnfeat.data.cpu().numpy())
279 | print("time use %.4f"%(time.time()-beg_time))
280 |
281 | pred_list = np.concatenate(pred_list)
282 | gcnfeat_arr = np.vstack(gcnfeat_list)
283 | gcnfeat_arr = gcnfeat_arr / np.linalg.norm(gcnfeat_arr, axis=1, keepdims=True)
284 | tag = os.path.splitext(os.path.basename(cfg.resume_path))[0]
285 |
286 | print("stat")
287 | k_hat, k_arr = pred_list * scale_max, test_label_k
288 |
289 | # add this clip before eval
290 | k_hat = np.round(k_hat)
291 | k_hat[np.where(k_hat < 1)[0]] = 1
292 | k_hat[np.where(k_hat > scale_max)[0]] = scale_max
293 | np.save(os.path.join(cfg.outpath, 'k_infer_pred'), np.round(k_hat))
294 |
295 | print("time use", time.time() - beg_time)
296 |
--------------------------------------------------------------------------------
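For reference, the warmup-plus-cosine schedule produced by cos_lr above, with the AND/config.yml values (lr=0.01, warmup_step=128, total_step=40000), behaves as in the sketch below; the helper simply re-implements the same formula for illustration:

    import math

    # mirrors cos_lr above; defaults match AND/config.yml
    def lr_at(step, lr=0.01, warmup_step=128, total_step=40000):
        if step < warmup_step:
            return lr * step / warmup_step
        n1, n2 = total_step - warmup_step, step - warmup_step
        return lr * (1 + math.cos(math.pi * n2 / n1)) / 2

    print(lr_at(64))       # 0.005  (halfway through warmup)
    print(lr_at(128))      # 0.01   (peak, warmup just finished)
    print(lr_at(20064))    # 0.005  (midpoint of the cosine decay)
    print(lr_at(40000))    # 0.0    (end of training)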
/AND/util/confidence.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 | from itertools import groupby
7 |
8 | __all__ = ['density', 'confidence', 'confidence_to_peaks']
9 |
10 |
11 | def density(dists, radius=0.3, use_weight=True):
12 | row, col = (dists < radius).nonzero()
13 |
14 | num, _ = dists.shape
15 | if use_weight:
16 | density = np.zeros((num, ), dtype=np.float32)
17 | for r, c in zip(row, col):
18 | density[r] += 1 - dists[r, c]
19 | else:
20 | density = np.zeros((num, ), dtype=np.int32)
21 | for k, g in groupby(row):
22 | density[k] = len(list(g))
23 | return density
24 |
25 |
26 | def s_nbr(dists, nbrs, idx2lb, **kwargs):
27 | ''' use supervised confidence defined on the neighborhood
28 | '''
29 | num, _ = dists.shape
30 | conf = np.zeros((num, ), dtype=np.float32)
31 | contain_neg = 0
32 | for i, (nbr, dist) in enumerate(zip(nbrs, dists)):
33 | lb = idx2lb[i]
34 | pos, neg = 0, 0
35 | for j, n in enumerate(nbr):
36 | if idx2lb[n] == lb:
37 | pos += 1 - dist[j]
38 | else:
39 | neg += 1 - dist[j]
40 | conf[i] = pos - neg
41 | if neg > 0:
42 | contain_neg += 1
43 | print('#contain_neg:', contain_neg)
44 | conf /= np.abs(conf).max()
45 | return conf
46 |
47 |
48 | def s_nbr_size_norm(dists, nbrs, idx2lb, **kwargs):
49 | ''' use supervised confidence defined on the neighborhood (normalized by size)
50 | '''
51 | num, _ = dists.shape
52 | conf = np.zeros((num, ), dtype=np.float32)
53 | contain_neg = 0
54 | max_size = 0
55 | for i, (nbr, dist) in enumerate(zip(nbrs, dists)):
56 | size = 0
57 | pos, neg = 0, 0
58 | lb = idx2lb[i]
59 | for j, n in enumerate(nbr):
60 | if idx2lb[n] == lb:
61 | pos += 1 - dist[j]
62 | else:
63 | neg += 1 - dist[j]
64 | size += 1
65 | conf[i] = pos - neg
66 | max_size = max(max_size, size)
67 | if neg > 0:
68 | contain_neg += 1
69 | print('#contain_neg:', contain_neg)
70 | print('max_size: {}'.format(max_size))
71 | conf /= max_size
72 | return conf
73 |
74 |
75 | def s_avg(feats, idx2lb, lb2idxs, **kwargs):
76 | ''' use average similarity of intra-nodes
77 | '''
78 | num = len(idx2lb)
79 | conf = np.zeros((num, ), dtype=np.float32)
80 | for i in range(num):
81 | lb = idx2lb[i]
82 | idxs = lb2idxs[lb]
83 | idxs.remove(i)
84 | if len(idxs) == 0:
85 | continue
86 | feat = feats[i, :]
87 | conf[i] = feat.dot(feats[idxs, :].T).mean()
88 | eps = 1e-6
89 | assert -1 - eps <= conf.min() <= conf.max(
90 | ) <= 1 + eps, "min: {}, max: {}".format(conf.min(), conf.max())
91 | return conf
92 |
93 |
94 | def s_center(feats, idx2lb, lb2idxs, **kwargs):
95 | ''' use similarity to the intra-cluster center
96 | '''
97 | num = len(idx2lb)
98 | conf = np.zeros((num, ), dtype=np.float32)
99 | for i in range(num):
100 | lb = idx2lb[i]
101 | idxs = lb2idxs[lb]
102 | if len(idxs) == 0:
103 | continue
104 | feat = feats[i, :]
105 | feat_center = feats[idxs, :].mean(axis=0)
106 | conf[i] = feat.dot(feat_center.T)
107 | eps = 1e-6
108 | assert -1 - eps <= conf.min() <= conf.max(
109 | ) <= 1 + eps, "min: {}, max: {}".format(conf.min(), conf.max())
110 | return conf
111 |
112 |
113 | def confidence(metric='s_nbr', **kwargs):
114 | metric2func = {
115 | 's_nbr': s_nbr,
116 | 's_nbr_size_norm': s_nbr_size_norm,
117 | 's_avg': s_avg,
118 | 's_center': s_center,
119 | }
120 | if metric in metric2func:
121 | func = metric2func[metric]
122 | else:
123 | raise KeyError('Only support confidence metrics: {}'.format(
124 | metric2func.keys()))
125 |
126 | conf = func(**kwargs)
127 | return conf
128 |
129 |
130 | def confidence_to_peaks(dists, nbrs, confidence, max_conn=1):
131 | # Note that dists has been sorted in ascending order
132 | assert dists.shape[0] == confidence.shape[0]
133 | assert dists.shape == nbrs.shape
134 |
135 | num, _ = dists.shape
136 | dist2peak = {i: [] for i in range(num)}
137 | peaks = {i: [] for i in range(num)}
138 |
139 | for i, nbr in tqdm(enumerate(nbrs)):
140 | nbr_conf = confidence[nbr]
141 | for j, c in enumerate(nbr_conf):
142 | nbr_idx = nbr[j]
143 | if i == nbr_idx or c <= confidence[i]:
144 | continue
145 | dist2peak[i].append(dists[i, j])
146 | peaks[i].append(nbr_idx)
147 | if len(dist2peak[i]) >= max_conn:
148 | break
149 | return dist2peak, peaks
150 |
--------------------------------------------------------------------------------
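A small sketch of the unsupervised path through this module, assuming it is run from the AND directory: compute a density-style confidence from toy k-NN arrays, then link each node to its nearest neighbor of strictly higher confidence (in the pipeline these arrays come from the faiss search):

    import numpy as np
    from util.confidence import density, confidence_to_peaks

    # toy k-NN output: row i lists node i's neighbors (itself first) and the cosine distances
    nbrs  = np.array([[0, 1, 2], [1, 0, 2], [2, 1, 0], [3, 2, 1]])
    dists = np.array([[0.0, 0.1, 0.4], [0.0, 0.1, 0.3], [0.0, 0.3, 0.4], [0.0, 0.5, 0.6]])

    conf = density(dists, radius=0.45)               # weighted count of close neighbors
    dist2peak, peaks = confidence_to_peaks(dists, nbrs, conf, max_conn=1)
    print(peaks)   # {0: [1], 1: [], 2: [1], 3: [2]}: nearest higher-confidence neighbor per node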
/AND/util/deduce.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | __all__ = ['peaks_to_labels']
4 |
5 |
6 | def _find_parent(parent, u):
7 | idx = []
8 | # parent is a fixed point
9 | while (u != parent[u]):
10 | idx.append(u)
11 | u = parent[u]
12 | for i in idx:
13 | parent[i] = u
14 | return u
15 |
16 |
17 | def edge_to_connected_graph(edges, num):
18 | parent = list(range(num))
19 | for u, v in edges:
20 | p_u = _find_parent(parent, u)
21 | p_v = _find_parent(parent, v)
22 | parent[p_u] = p_v
23 |
24 | for i in range(num):
25 | parent[i] = _find_parent(parent, i)
26 | remap = {}
27 | uf = np.unique(np.array(parent))
28 | for i, f in enumerate(uf):
29 | remap[f] = i
30 | cluster_id = np.array([remap[f] for f in parent])
31 | return cluster_id
32 |
33 |
34 | def peaks_to_edges(peaks, dist2peak, tau):
35 | edges = []
36 | for src in peaks:
37 | dsts = peaks[src]
38 | dists = dist2peak[src]
39 | for dst, dist in zip(dsts, dists):
40 | if src == dst or dist >= 1 - tau:
41 | continue
42 | edges.append([src, dst])
43 | return edges
44 |
45 |
46 | def peaks_to_labels(peaks, dist2peak, tau, inst_num):
47 | edges = peaks_to_edges(peaks, dist2peak, tau)
48 | pred_labels = edge_to_connected_graph(edges, inst_num)
49 | return pred_labels
50 |
--------------------------------------------------------------------------------
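Continuing the toy example from confidence.py, peaks_to_labels keeps only links whose distance is below 1 - tau and merges the survivors with union-find; the numbers below are illustrative and assume the AND directory as working directory:

    from util.deduce import peaks_to_labels

    # peaks/dist2peak in the format produced by confidence_to_peaks
    peaks     = {0: [1], 1: [], 2: [1], 3: [2]}
    dist2peak = {0: [0.1], 1: [], 2: [0.3], 3: [0.5]}

    labels = peaks_to_labels(peaks, dist2peak, tau=0.6, inst_num=4)
    print(labels)   # [0 0 0 1]: the 0-1-2 chain is merged; node 3's link (dist 0.5 >= 1 - tau) is cut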
/AND/util/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import inspect
5 | import argparse
6 | import numpy as np
7 | import util.metrics as metrics
8 | import time
9 |
10 | class TextColors:
11 | #HEADER = '\033[35m'
12 | #OKBLUE = '\033[34m'
13 | #OKGREEN = '\033[32m'
14 | #WARNING = '\033[33m'
15 | #FATAL = '\033[31m'
16 | #ENDC = '\033[0m'
17 | #BOLD = '\033[1m'
18 | #UNDERLINE = '\033[4m'
19 | HEADER = ''
20 | OKBLUE = ''
21 | OKGREEN = ''
22 | WARNING = ''
23 | FATAL = ''
24 | ENDC = ''
25 | BOLD = ''
26 | UNDERLINE = ''
27 |
28 | class Timer():
29 | def __init__(self, name='task', verbose=True):
30 | self.name = name
31 | self.verbose = verbose
32 |
33 | def __enter__(self):
34 | self.start = time.time()
35 | return self
36 |
37 | def __exit__(self, exc_type, exc_val, exc_tb):
38 | if self.verbose:
39 | print('[Time] {} consumes {:.4f} s'.format(
40 | self.name,
41 | time.time() - self.start))
42 | return exc_type is None
43 |
44 |
45 | def _read_meta(fn):
46 | labels = list()
47 | lb_set = set()
48 | with open(fn) as f:
49 | for lb in f.readlines():
50 | lb = int(lb.strip())
51 | labels.append(lb)
52 | lb_set.add(lb)
53 | return np.array(labels), lb_set
54 |
55 |
56 | def evaluate(gt_labels, pred_labels, metric='pairwise'):
57 | if isinstance(gt_labels, str) and isinstance(pred_labels, str):
58 | print('[gt_labels] {}'.format(gt_labels))
59 | print('[pred_labels] {}'.format(pred_labels))
60 | gt_labels, gt_lb_set = _read_meta(gt_labels)
61 | pred_labels, pred_lb_set = _read_meta(pred_labels)
62 |
63 | print('#inst: gt({}) vs pred({})'.format(len(gt_labels),
64 | len(pred_labels)))
65 | print('#cls: gt({}) vs pred({})'.format(len(gt_lb_set),
66 | len(pred_lb_set)))
67 |
68 | metric_func = metrics.__dict__[metric]
69 |
70 | with Timer('evaluate with {}{}{}'.format(TextColors.FATAL, metric,
71 | TextColors.ENDC)):
72 | result = metric_func(gt_labels, pred_labels)
73 | if isinstance(result, float):  # np.float was removed in NumPy 1.24; np.float64 is a float subclass
74 | print('{}{}: {:.4f}{}'.format(TextColors.OKGREEN, metric, result,
75 | TextColors.ENDC))
76 | else:
77 | ave_pre, ave_rec, fscore = result
78 | print('{}ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}'.format(
79 | TextColors.OKGREEN, ave_pre, ave_rec, fscore, TextColors.ENDC))
80 |
81 |
82 | if __name__ == '__main__':
83 | metric_funcs = inspect.getmembers(metrics, inspect.isfunction)
84 | metric_names = [n for n, _ in metric_funcs]
85 |
86 | parser = argparse.ArgumentParser(description='Evaluate Cluster')
87 | parser.add_argument('--gt_labels', type=str, required=True)
88 | parser.add_argument('--pred_labels', type=str, required=True)
89 | parser.add_argument('--metric', default='pairwise', choices=metric_names)
90 | args = parser.parse_args()
91 |
92 | evaluate(args.gt_labels, args.pred_labels, args.metric)
93 |
--------------------------------------------------------------------------------
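Besides the command-line entry point, evaluate also accepts label arrays directly (this is how GCN/cluster.py calls it). A hedged sketch, assuming it is run from the AND directory:

    import numpy as np
    from util.evaluate import evaluate

    gt   = np.array([0, 0, 1, 1, 1])
    pred = np.array([0, 0, 0, 1, 1])
    evaluate(gt, pred, metric='pairwise')   # prints ave_pre, ave_rec and fscore
    evaluate(gt, pred, metric='nmi')        # prints a single scalar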
/AND/util/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import division
5 |
6 | import numpy as np
7 | from sklearn.metrics.cluster import (contingency_matrix,
8 | normalized_mutual_info_score)
9 | from sklearn.metrics import (precision_score, recall_score)
10 |
11 | __all__ = ['pairwise', 'bcubed', 'nmi', 'precision', 'recall', 'accuracy']
12 |
13 |
14 | def _check(gt_labels, pred_labels):
15 | if gt_labels.ndim != 1:
16 | raise ValueError("gt_labels must be 1D: shape is %r" %
17 | (gt_labels.shape, ))
18 | if pred_labels.ndim != 1:
19 | raise ValueError("pred_labels must be 1D: shape is %r" %
20 | (pred_labels.shape, ))
21 | if gt_labels.shape != pred_labels.shape:
22 | raise ValueError(
23 | "gt_labels and pred_labels must have same size, got %d and %d" %
24 | (gt_labels.shape[0], pred_labels.shape[0]))
25 | return gt_labels, pred_labels
26 |
27 |
28 | def _get_lb2idxs(labels):
29 | lb2idxs = {}
30 | for idx, lb in enumerate(labels):
31 | if lb not in lb2idxs:
32 | lb2idxs[lb] = []
33 | lb2idxs[lb].append(idx)
34 | return lb2idxs
35 |
36 |
37 | def _compute_fscore(pre, rec):
38 | return 2. * pre * rec / (pre + rec)
39 |
40 |
41 | def fowlkes_mallows_score(gt_labels, pred_labels, sparse=True):
42 | ''' The original function is from `sklearn.metrics.fowlkes_mallows_score`.
43 | We output the pairwise precision, pairwise recall and F-measure,
44 | instead of calculating the geometric mean of precision and recall.
45 | '''
46 | n_samples, = gt_labels.shape
47 |
48 | c = contingency_matrix(gt_labels, pred_labels, sparse=sparse)
49 | tk = np.dot(c.data, c.data) - n_samples
50 | pk = np.sum(np.asarray(c.sum(axis=0)).ravel()**2) - n_samples
51 | qk = np.sum(np.asarray(c.sum(axis=1)).ravel()**2) - n_samples
52 |
53 | avg_pre = tk / pk
54 | avg_rec = tk / qk
55 | fscore = _compute_fscore(avg_pre, avg_rec)
56 |
57 | return avg_pre, avg_rec, fscore
58 |
59 |
60 | def pairwise(gt_labels, pred_labels, sparse=True):
61 | _check(gt_labels, pred_labels)
62 | return fowlkes_mallows_score(gt_labels, pred_labels, sparse)
63 |
64 |
65 | def bcubed(gt_labels, pred_labels):
66 | _check(gt_labels, pred_labels)
67 |
68 | gt_lb2idxs = _get_lb2idxs(gt_labels)
69 | pred_lb2idxs = _get_lb2idxs(pred_labels)
70 |
71 | num_lbs = len(gt_lb2idxs)
72 | pre = np.zeros(num_lbs)
73 | rec = np.zeros(num_lbs)
74 | gt_num = np.zeros(num_lbs)
75 |
76 | for i, gt_idxs in enumerate(gt_lb2idxs.values()):
77 | all_pred_lbs = np.unique(pred_labels[gt_idxs])
78 | gt_num[i] = len(gt_idxs)
79 | for pred_lb in all_pred_lbs:
80 | pred_idxs = pred_lb2idxs[pred_lb]
81 | n = 1. * np.intersect1d(gt_idxs, pred_idxs).size
82 | pre[i] += n**2 / len(pred_idxs)
83 | rec[i] += n**2 / gt_num[i]
84 |
85 | gt_num = gt_num.sum()
86 | avg_pre = pre.sum() / gt_num
87 | avg_rec = rec.sum() / gt_num
88 | fscore = _compute_fscore(avg_pre, avg_rec)
89 |
90 | return avg_pre, avg_rec, fscore
91 |
92 |
93 | def nmi(gt_labels, pred_labels):
94 | return normalized_mutual_info_score(pred_labels, gt_labels)
95 |
96 |
97 | def precision(gt_labels, pred_labels):
98 | return precision_score(gt_labels, pred_labels)
99 |
100 |
101 | def recall(gt_labels, pred_labels):
102 | return recall_score(gt_labels, pred_labels)
103 |
104 |
105 | def accuracy(gt_labels, pred_labels):
106 | return np.mean(gt_labels == pred_labels)
107 |
--------------------------------------------------------------------------------
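A worked toy example, assuming it is run from the AND directory; the approximate values in the comments follow from the definitions above (one correctly predicted pair out of three predicted pairs and two ground-truth pairs):

    import numpy as np
    from util.metrics import pairwise, bcubed

    gt   = np.array([0, 0, 1, 1])
    pred = np.array([0, 0, 0, 1])

    print(pairwise(gt, pred))   # ≈ (0.333, 0.500, 0.400): pair precision, pair recall, F-score
    print(bcubed(gt, pred))     # ≈ (0.667, 0.750, 0.706): per-sample scores averaged over gt clusters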
/GCN/adjacency.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | import scipy.sparse as sp
6 |
7 |
8 | def row_normalize(mx):
9 | """Row-normalize sparse matrix"""
10 | rowsum = np.array(mx.sum(1))
11 | # if rowsum <= 0, keep its previous value
12 | rowsum[rowsum <= 0] = 1
13 | r_inv = np.power(rowsum, -1).flatten()
14 | r_inv[np.isinf(r_inv)] = 0.
15 | r_mat_inv = sp.diags(r_inv)
16 | mx = r_mat_inv.dot(mx)
17 | return mx
18 |
19 |
20 | def build_symmetric_adj(adj, self_loop=True):
21 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
22 | if self_loop:
23 | adj = adj + sp.eye(adj.shape[0])
24 | return adj
25 |
26 |
27 | def sparse_mx_to_indices_values(sparse_mx):
28 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
29 | indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
30 | values = sparse_mx.data
31 | shape = np.array(sparse_mx.shape)
32 | return indices, values, shape
33 |
34 |
35 | def indices_values_to_sparse_tensor(indices, values, shape):
36 | import torch
37 | indices = torch.from_numpy(indices)
38 | values = torch.from_numpy(values)
39 | shape = torch.Size(shape)
40 | return torch.sparse.FloatTensor(indices, values, shape)
41 |
42 |
43 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
44 | """Convert a scipy sparse matrix to a torch sparse tensor."""
45 | indices, values, shape = sparse_mx_to_indices_values(sparse_mx)
46 | return indices_values_to_sparse_tensor(indices, values, shape)
47 |
--------------------------------------------------------------------------------
/GCN/cluster.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import sys
3 | import time
4 | from util.confidence import confidence_to_peaks
5 | from util.deduce import peaks_to_labels
6 | from util.evaluate import evaluate
7 | from multiprocessing import Process, Manager
8 | import numpy as np
9 | import os
10 | import torch
11 | import torch.nn.functional as F
12 | from util.graph import graph_propagation_onecut
13 | from multiprocessing import Pool
14 | from util.deduce import edge_to_connected_graph
15 |
16 | metric_list = ['bcubed', 'pairwise', 'nmi']
17 | topN = 121
18 | def worker(param):
19 | i, pdict = param
20 | query_nodeid = ngbr_arr[i, 0]
21 | for j in range(1, dist_arr.shape[1]):
22 | doc_nodeid = ngbr_arr[i, j]
23 | tpl = (query_nodeid, doc_nodeid)
24 | dist = dist_arr[query_nodeid, j]
25 | if dist > cos_dist_thres:
26 | continue
27 | pdict[tpl] = dist
28 |
29 | def format(dist_arr, ngbr_arr):
30 | edge_list, score_list = [], []
31 | for i in range(dist_arr.shape[0]):
32 | query_nodeid = ngbr_arr[i, 0]
33 | for j in range(1, dist_arr.shape[1]):
34 | doc_nodeid = ngbr_arr[i, j]
35 | tpl = (query_nodeid, doc_nodeid)
36 | score = 1 - dist_arr[query_nodeid, j]
37 | if score < cos_sim_thres:
38 | continue
39 | edge_list.append(tpl)
40 | score_list.append(score)
41 | edge_arr, score_arr = np.array(edge_list), np.array(score_list)
42 | return edge_arr, score_arr
43 |
44 | def clusters2labels(clusters, n_nodes):
45 | labels = (-1)* np.ones((n_nodes,))
46 | for ci, c in enumerate(clusters):
47 | for xid in c:
48 | labels[xid.name] = ci
49 |
50 | cnt = len(clusters)
51 | idx_list = np.where(labels < 0)[0]
52 | for idx in idx_list:
53 | labels[idx] = cnt
54 | cnt += 1
55 | assert np.sum(labels<0) < 1
56 | return labels
57 |
58 | def disjoint_set_onecut(sim_dict, thres, num):
59 | edge_arr = []
60 | for edge, score in sim_dict.items():
61 | if score < thres:
62 | continue
63 | edge_arr.append(edge)
64 | pred_arr = edge_to_connected_graph(edge_arr, num)
65 | return pred_arr
66 |
67 | def get_eval(cos_sim_thres):
68 | pred_arr = disjoint_set_onecut(sim_dict, cos_sim_thres, len(gt_arr))
69 | print("now is %s done"%cos_sim_thres)
70 | res_str = ""
71 | for metric in metric_list:
72 | res_str += evaluate(gt_arr, pred_arr, metric)
73 | res_str += "\n"
74 | return res_str
75 |
76 | if __name__ == "__main__":
77 | Ifile, Dfile, gtfile = sys.argv[1], sys.argv[2], sys.argv[3]
78 |
79 | gt_arr = np.load(gtfile)
80 | nbr_arr = np.load(Ifile).astype(np.int32)[:, :topN]
81 | dist_arr = np.load(Dfile)[:, :topN]
82 | sim_dict = {}
83 | for query_nodeid, (nbr, dist) in enumerate(zip(nbr_arr, dist_arr)):
84 | for j, doc_nodeid in enumerate(nbr):  # j starts at 0, so the node itself is included
85 | if query_nodeid < doc_nodeid:
86 | tpl = (query_nodeid, doc_nodeid)
87 | else:
88 | tpl = (doc_nodeid, query_nodeid)
89 | sim_dict[tpl] = 1 - dist[j]
90 |
91 | sim_thres = 0.96
92 | print('now sim thres %.2f'%sim_thres)
93 | pred_arr = disjoint_set_onecut(sim_dict, sim_thres, len(gt_arr))
94 | for metric in metric_list:
95 | print(evaluate(gt_arr, pred_arr, metric))
96 |
--------------------------------------------------------------------------------
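A toy check of disjoint_set_onecut, assuming it is run from the GCN directory and that util.deduce.edge_to_connected_graph behaves like the AND version shown earlier (importing cluster only defines the functions; the __main__ block is guarded). Edges below the similarity threshold are dropped and the rest are merged by union-find:

    from cluster import disjoint_set_onecut

    # toy similarity dict in the same (i, j) -> cosine-similarity format built in __main__ above
    sim_dict = {(0, 1): 0.98, (1, 2): 0.97, (2, 3): 0.60, (4, 5): 0.99}
    print(disjoint_set_onecut(sim_dict, 0.96, 6))   # [0 0 0 1 2 2]: edge (2, 3) falls below the cut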
/GCN/config.yml:
--------------------------------------------------------------------------------
1 | # model
2 | feat_dim: 256
3 | nhid: 512
4 | nclass: 8573
5 |
6 | # optimizer
7 | lr: 0.1 #0.01
8 | sgd_momentum: 0.9
9 | sgd_weight_decay: 0.00001
10 | lr_step : [0.5, 0.8, 0.9]
11 | factor: 0.1
12 | total_step: 35000 #
13 | cuda: True
14 | fp16: False
15 | batchsize: 1 #
16 | warmup_step: 1024 #
17 |
18 | # output
19 | save_freq: 5000 #
20 | log_freq: 1
21 | # resume
22 | resume_path:
23 |
--------------------------------------------------------------------------------
/GCN/net/gat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 | from dgl.nn.pytorch import SAGEConv
6 | from .optim_modules import BallClusterLearningLoss, ClusterLoss
7 | import torch.nn.functional as F
8 |
9 | class GCN_V(nn.Module):
10 | def __init__(self, feature_dim, nhid, nclass, dropout=0, losstype='allall', margin=1., pweight=4., pmargin=1.0):
11 | super(GCN_V, self).__init__()
12 |
13 | self.sage1 = SAGEConv(feature_dim, nhid, aggregator_type='gcn', activation=F.relu)
14 | self.sage2 = SAGEConv(nhid, nhid, aggregator_type='gcn', activation=F.relu)
15 |
16 | self.nclass = nclass
17 | self.fc = nn.Sequential(nn.Linear(nhid, nhid), nn.PReLU(nhid))
18 | self.loss = torch.nn.MSELoss()
19 | self.bclloss = ClusterLoss(losstype=losstype, margin=margin, alpha_pos=pweight, pmargin=pmargin)
20 |
21 | def forward(self, data, output_feat=False, return_loss=False):
22 | assert not output_feat or not return_loss
23 | x, block_list, label, idlabel = data[0], data[1], data[2], data[3]
24 |
25 | # layer1
26 | gcnfeat = self.sage1(block_list[0], x)
27 | gcnfeat = F.normalize(gcnfeat, p=2, dim=1)
28 |
29 | # layer2
30 | gcnfeat = self.sage2(block_list[1], gcnfeat)
31 |
32 | # layer3
33 | fcfeat = self.fc(gcnfeat)
34 | fcfeat = F.normalize(fcfeat, dim=1)
35 |
36 | if output_feat:
37 | return fcfeat, gcnfeat
38 |
39 | if return_loss:
40 | bclloss_dict = self.bclloss(fcfeat, label)
41 | return bclloss_dict
42 |
43 | return fcfeat
44 |
--------------------------------------------------------------------------------
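A forward-pass sketch under stated assumptions: it mirrors the two-block neighbor sampling done by multigraph_NodeCollator in GCN/train.py, but on a small random homogeneous graph; the fan-outs, feature size, pseudo labels, and the reuse of the labels as both `label` and `idlabel` are all illustrative:

    import dgl
    import torch
    from net.gat import GCN_V

    # toy k-NN-style graph: 100 nodes, random edges, 256-d features, 10 pseudo classes
    g = dgl.rand_graph(100, 1000)
    feat = torch.randn(100, 256)
    lab = torch.randint(0, 10, (100,))

    seeds = torch.arange(0, 8)                                  # batch of seed nodes
    frontier1 = dgl.sampling.sample_neighbors(g, seeds, 4)      # 1-hop sample
    block1 = dgl.to_block(frontier1, seeds)
    frontier2 = dgl.sampling.sample_neighbors(g, block1.srcdata[dgl.NID], 4)  # 2-hop sample
    block2 = dgl.to_block(frontier2, block1.srcdata[dgl.NID])
    blocks = [block2, block1]                                   # outermost block first

    x = feat[blocks[0].srcdata[dgl.NID]]                        # features of the sampled subgraph
    y = lab[blocks[-1].dstdata[dgl.NID]]                        # labels of the seed nodes
    model = GCN_V(feature_dim=256, nhid=512, nclass=10)
    loss_dict = model((x, blocks, y, y), return_loss=True)      # {'ctrd_pos': ..., 'ctrd_neg': ...}
    fcfeat, gcnfeat = model((x, blocks, y, y), output_feat=True)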
/GCN/net/optim_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Classes for all models and loss functions for clustering.
3 |
4 | FC --> ReLU --> BN --> Dropout --> FC
5 | """
6 |
7 | import warnings
8 | import numpy as np
9 |
10 | # Torch
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | # Local imports
16 | #import utils
17 | #import config
18 | #import lorentz
19 |
20 |
21 | def sqeuclidean_pdist(x, y=None):
22 | """Fast and efficient implementation of ||X - Y||^2 = ||X||^2 + ||Y||^2 - 2 X^T Y
23 | Input: x is a Nxd matrix
24 | y is an optional Mxd matrix
25 | Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
26 | if y is not given then use 'y=x'.
27 | i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
28 | """
29 |
30 | x_norm = (x**2).sum(1).unsqueeze(1)
31 | if y is not None:
32 | y_t = torch.transpose(y, 0, 1)
33 | y_norm = (y**2).sum(1).unsqueeze(0)
34 | else:
35 | y_t = torch.transpose(x, 0, 1)
36 | y_norm = x_norm.squeeze().unsqueeze(0)
37 |
38 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
39 | # get rid of NaNs
40 | dist[torch.isnan(dist)] = 0.
41 | # clamp negative stuff to 0
42 | dist = torch.clamp(dist, 0., np.inf)
43 | # ensure diagonal is 0
44 | if y is None:
45 | dist[dist == torch.diag(dist)] = 0.
46 |
47 | return dist
48 |
49 | # ============================================================================ #
50 | # LOSS FUNCTIONS #
51 | # ============================================================================ #
52 |
53 | def get_pos_loss(decision_mat, label_mat, beta, k=1):
54 | if label_mat.sum() == 0:
55 | print('cut pos 0')
56 | return torch.tensor(0.)
57 | # a decision value is not a confidence: a confidence is always positive in [0, 1] and carries no class information
58 | # a decision value is a real number whose sign carries the class information
59 | # on the decision matrix, the smaller the value of a positive sample, the harder the case
60 | decision_arr = decision_mat[label_mat].topk(k=k, largest=False)[0]
61 | loss = F.relu(beta - decision_arr)
62 | loss = loss.mean()
63 | return loss
64 |
65 | def get_neg_loss(decision_mat, label_mat, beta, k=1):
66 | if label_mat.sum() == 0:
67 | print('cut neg 0')
68 | return torch.tensor(0.)
69 | # on the decision matrix, the larger the value of a negative sample, the harder the case
70 | decision_arr = -1 * decision_mat[label_mat].topk(k=k, largest=True)[0]
71 | loss = F.relu(beta - decision_arr)
72 | loss = loss.mean()
73 | return loss
74 |
75 | def cosine_sim(x, y):
76 | return torch.mm(x, y.T)
77 |
78 | def euclidean_dist(x, y):
79 | m, n = x.size(0), y.size(0)
80 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
81 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
82 | dist = xx + yy
83 | dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = dist - 2 * x @ y.T (keyword form; the positional beta/alpha form is removed in recent torch)
84 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
85 | return dist
86 |
87 | class ClusterLoss(nn.Module):
88 | def __init__(self, beta_pos=0.5, beta_neg=0.5, alpha_pos=4., alpha_neg=1., gamma_eps=0.05, losstype='allall', margin=1., pmargin=1.):
89 | super(ClusterLoss, self).__init__()
90 | #self.gamma_eps = gamma_eps
91 | self.alpha_pos = alpha_pos
92 | self.alpha_neg = alpha_neg
93 |
94 | self.beta_pos = nn.Parameter(torch.tensor(beta_pos))
95 | #self.beta_neg = nn.Parameter(torch.tensor(beta_neg))
96 | self.losstype = losstype
97 | self.margin = margin
98 | self.pmargin = pmargin
99 |
100 | def forward(self, X, labels):
101 | #beta_pos = F.softplus(self.beta_pos)
102 | #beta_neg = F.softplus(self.beta_neg)
103 | #beta_neg = beta_pos
104 | beta_pos = self.pmargin
105 | beta_neg = self.margin
106 |
107 | #X_copy = X.clone().detach()
108 | #decision_mat = (X.unsqueeze(1) - X_copy.unsqueeze(0)).pow(2).sum(2).sqrt() # euclidean distance
109 | #decision_mat = euclidean_dist(X, X)
110 | decision_mat = cosine_sim(X, X)
111 | label_mat = (labels.unsqueeze(0) == labels.unsqueeze(1))
112 | #print("beta pos", beta_pos.item(), 'beta neg', beta_neg.item())
113 | print("losstype", self.losstype, "margin", self.margin, 'pweight', self.alpha_pos, 'pmargin', self.pmargin)
114 |
115 | neg_label_mat = (1-label_mat.float()).bool()
116 | if self.losstype == 'maxmax':
117 | pos_loss = get_pos_loss(decision_mat, label_mat, beta_pos, k=1)
118 | neg_loss = get_neg_loss(decision_mat, neg_label_mat, beta_neg, k=1)
119 | elif self.losstype == 'allmax':
120 | pos_loss = get_pos_loss(decision_mat, label_mat, beta_pos, k=label_mat.sum().item())
121 | neg_loss = get_neg_loss(decision_mat, neg_label_mat, beta_neg, k=1)
122 | elif self.losstype == 'allall':
123 | pos_loss = get_pos_loss(decision_mat, label_mat, beta_pos, k=label_mat.sum().item())
124 | neg_loss = get_neg_loss(decision_mat, neg_label_mat, beta_neg, k=neg_label_mat.sum().item())
125 | elif self.losstype == 'alltopk':
126 | pos_loss = get_pos_loss(decision_mat, label_mat, beta_pos, k=label_mat.sum().item())
127 | neg_loss = get_neg_loss(decision_mat, neg_label_mat, beta_neg, k=min(len(labels), neg_label_mat.sum().item()) )
128 | else:
129 | raise ValueError('loss type %s not implemented' % self.losstype)
130 |
131 | losses = {'ctrd_pos': pos_loss * self.alpha_pos, 'ctrd_neg': neg_loss * self.alpha_neg}
132 | return losses
133 |
134 | class BallClusterLearningLoss(nn.Module):
135 | """Final BCL method
136 | space: 'sqeuclidean' or 'lorentz'
137 | init_bias: initialize bias to this value
138 | temperature: sampling temperature (decayed in main training loop)
139 | beta: Lorentz beta for comparison in Lorentz space
140 | """
141 |
142 | def __init__(self, gpuid=0, space='sqeuclidean', l2norm=True, gamma_eps=0.05,
143 | init_bias=0.1, learn_bias=True, beta=0.01, alpha_pos=4., alpha_neg=1., mult_bias=0.):
144 | """Initialize
145 | """
146 | super(BallClusterLearningLoss, self).__init__()
147 | self.space = space
148 | self.learn_bias = learn_bias
149 | self.l2norm = l2norm
150 | self.beta = beta
151 | self.gamma_eps = gamma_eps
152 | self.alpha_pos = alpha_pos
153 | self.alpha_neg = alpha_neg
154 | self.mult_bias = mult_bias
155 | self.gpuid = gpuid
156 |
157 | self.h_bias = nn.Parameter(torch.tensor(init_bias))
158 | self.bias = F.softplus(self.h_bias)
159 |
160 | def forward(self, Xemb, labels):
161 | """
162 | Xemb: N x D, N features, D embedding dimension
163 | labels: ground-truth cluster indices
164 | NOTE: labels are not necessarily ordered indices, just unique ones, don't use for indexing!
165 | """
166 |
167 | self.bias = F.softplus(self.h_bias)
168 |
169 | # get unique labels to loop over clusters
170 | unique_labels = labels.unique() # torch vector on cuda
171 | K = unique_labels.numel()
172 | N = Xemb.size(0)
173 |
174 | # collect centroids, cluster-assignment matrix, and positive cluster index
175 | centroids = []
176 | pos_idx = -1 * torch.ones_like(labels) # N vector, each in [0 .. K-1]
177 | clst_assignments = torch.zeros(N, K).to(self.gpuid) # NxK {0, 1} matrix
178 | for k, clid in enumerate(unique_labels):
179 | idx = labels == clid
180 | # assign all samples with cluster clid as k
181 | pos_idx[idx] = k
182 | clst_assignments[idx, k] = 1
183 | # collect all features
184 | Xclst = Xemb[idx, :]
185 | centroid = Xclst.mean(0)
186 | centroid = centroid / centroid.norm()
187 | # collect centroids
188 | centroids.append(centroid)
189 | centroids = torch.stack(centroids, dim=0)
190 |
191 | # pairwise distances between all embeddings of the batch and the centroids
192 | XC_dist = (Xemb.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(2)
193 |
194 | # add bias to the distances indexed appropriately
195 | pos_bias = self.bias
196 | neg_bias = 9 * self.bias + self.gamma_eps
197 |
198 | # add bias and use "cross-entropy" loss on pos_idx
199 | bias_adds = clst_assignments * pos_bias + (1 - clst_assignments) * neg_bias
200 | final_distance = (-XC_dist + bias_adds) * 0.1
201 | # when not using bias, just ignore
202 | if self.bias == 0.:
203 | final_distance = -XC_dist * 0.1
204 |
205 | # make sure positive distances are below the pos-bias
206 | pos_distances = XC_dist.gather(1, pos_idx.unsqueeze(1))
207 | pos_sample_loss = F.relu(pos_distances - pos_bias)
208 |
209 | # make sure negative distances are more than neg-bias
210 | #avg_neg_distances = XC_dist[1 - clst_assignments.byte()].view(N, K-1).mean(1)
211 | #min_neg_distances = XC_dist[1 - clst_assignments.byte()].view(N, K-1).min(1)[0] # [0] returns values not indices
212 | if len(XC_dist[(1 - clst_assignments).bool()]) == 0:
213 | neg_sample_loss = torch.Tensor([0])
214 | print("===== neg sample is 0")
215 | else:
216 | min_neg_distances = XC_dist[(1 - clst_assignments).bool()].view(N, K-1).min(1)[0] # [0] returns values not indices
217 | neg_sample_loss = F.relu(neg_bias - min_neg_distances)
218 |
219 | pos_loss = pos_sample_loss.mean()
220 | neg_loss = neg_sample_loss.mean()
221 |
222 | losses = {'ctrd_pos': pos_loss * self.alpha_pos, 'ctrd_neg': neg_loss * self.alpha_neg}
223 | #losses = pos_loss * self.alpha_pos + neg_loss * self.alpha_neg
224 |
225 | return losses
226 |
227 |
228 | class PrototypicalLoss(nn.Module):
229 | """Prototypical network like loss with bias
230 | p_ik = exp(- d(x^k_i, c^k) + b) / (exp(- d(x^k_i, c^k) + b) + sum_j exp(- d(x^k_i, c^j) + 2b))
231 | Loss = -mean_k( mean_i ( -log p_ik ))
232 | space: 'sqeuclidean' or 'lorentz'
233 | init_bias: initialize bias to this value
234 | temperature: sampling temperature (decayed in main training loop)
235 | beta: Lorentz beta for comparison in Lorentz space
236 | """
237 |
238 | def __init__(self, device, space='sqeuclidean', l2norm=False, gamma_eps=0.05,
239 | init_bias=0., learn_bias=False, beta=0.01, alpha_pos=1., alpha_neg=1., mult_bias=0.):
240 | """Initialize
241 | """
242 | super(PrototypicalLoss, self).__init__()
243 | self.device = device
244 | self.space = space
245 | self.learn_bias = learn_bias
246 | self.l2norm = l2norm
247 | self.beta = beta
248 | self.gamma_eps = gamma_eps
249 | self.alpha_pos = alpha_pos
250 | self.alpha_neg = alpha_neg
251 | self.mult_bias = mult_bias
252 |
253 | self.bias = torch.tensor(init_bias).to(self.device)
254 |
255 | def forward(self, Xemb, scores, labels):
256 | """
257 | Xemb: N x D, N features, D embedding dimension
258 | labels: ground-truth cluster indices
259 | NOTE: labels are not necessarily ordered indices, just unique ones, don't use for indexing!
260 | """
261 |
262 | unique_labels = labels.unique() # torch vector on cuda
263 | K = unique_labels.numel()
264 | N = Xemb.size(0)
265 |
266 | # collect centroids, cluster-assignment matrix, and positive cluster index
267 | centroids = []
268 | pos_idx = -1 * torch.ones_like(labels) # N vector, each in [0 .. K-1]
269 | clst_assignments = torch.zeros(N, K).to(self.device) # NxK {0, 1} matrix
270 | for k, clid in enumerate(unique_labels):
271 | idx = labels == clid
272 | # assign all samples with cluster clid as k
273 | pos_idx[idx] = k
274 | clst_assignments[idx, k] = 1
275 | # collect all features
276 | Xclst = Xemb[idx, :]
277 | centroid = Xclst.mean(0)
278 | # collect centroids
279 | centroids.append(centroid)
280 | centroids = torch.stack(centroids, dim=0)
281 |
282 | # pairwise distances between all embeddings of the batch and the centroids
283 | XC_dist = (Xemb.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(2)
284 |
285 | # add bias to the distances indexed appropriately
286 | pos_bias = self.bias
287 | neg_bias = 9 * self.bias + self.gamma_eps
288 | final_distance = -XC_dist * 0.1
289 |
290 | # compute cross-entropy
291 | pro_sample_loss = F.cross_entropy(final_distance, pos_idx, reduction='none')
292 |
293 | # do mean of means to get final loss value
294 | pro_loss = torch.tensor(0.).to(self.device)
295 | for clid in unique_labels:
296 | pro_loss += pro_sample_loss[labels == clid].mean()
297 | pro_loss /= K
298 |
299 | losses = {'ctrd_pro': pro_loss}
300 |
301 | return losses
302 |
303 |
304 | class ContrastiveLoss(nn.Module):
305 | """
306 | In the original paper http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
307 | Y = 0 for similar pairs ("positive")
308 | Y = 1 for dissimilar pairs ("negatives")
309 | L(Y, X1, X2) = (1 - Y) * 0.5 * D^2 + Y * 0.5 * (max(0, m - D))^2
310 | NOTE: distance is in Euclidean space, not sqeuclidean!
311 | """
312 |
313 | def __init__(self, device, l2norm=True,
314 | init_bias=1., learn_bias=True):
315 | """Initialize
316 | """
317 | super(ContrastiveLoss, self).__init__()
318 | self.device = device
319 | self.learn_bias = learn_bias
320 | self.l2norm = l2norm
321 |
322 | self.h_bias = nn.Parameter(torch.tensor(init_bias))
323 | self.bias = F.softplus(self.h_bias)
324 |
325 | def forward(self, Xemb, scores, labels):
326 | """
327 | Xemb: N x D, N features, D embedding dimension
328 | labels: ground-truth cluster indices
329 | """
330 |
331 | self.bias = F.softplus(self.h_bias)
332 |
333 | N = Xemb.size(0)
334 | match = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() # a NxN {0,1} matrix
335 |
336 | ### generate positive pairs, and pull corresponding features
337 | diag_mask = 1 - torch.eye(N).to(self.device)
338 | pos_idx = (diag_mask * match).nonzero()
339 | X1_pos = Xemb.index_select(0, pos_idx[:, 0])
340 | X2_pos = Xemb.index_select(0, pos_idx[:, 1])
341 |
342 | ### generate random negatives
343 | neg_idx = []
344 | while len(neg_idx) < X1_pos.size(0): # match pairs for negatives
345 | idx = torch.randint(N, (2,)).long()
346 | if match[idx[0], idx[1]] == 0:
347 | neg_idx.append(idx)
348 | neg_idx = torch.stack(neg_idx).to(self.device)
349 |
350 | X1_neg = Xemb.index_select(0, neg_idx[:, 0])
351 | X2_neg = Xemb.index_select(0, neg_idx[:, 1])
352 |
353 | # compute distances (Euclidean!)
354 | pos_distances_sq = ((X1_pos - X2_pos) ** 2).sum(1)
355 | neg_distances = ((X1_neg - X2_neg) ** 2).sum(1).sqrt()
356 |
357 | # Loss = 0.5 * pos_distances_sq + 0.5 * (max(0, m - neg_distances))^2
358 | pos_loss = 0.5 * pos_distances_sq.mean()
359 | neg_loss = 0.5 * (F.relu(self.bias - neg_distances) ** 2).mean()
360 |
361 | return {'cont_pos': pos_loss, 'cont_neg': neg_loss}
362 |
363 |
364 | class TripletLoss(nn.Module):
365 | """
366 | In the FaceNet paper https://arxiv.org/pdf/1503.03832.pdf
367 | L = max(0, d+ - d- + alpha)
368 | NOTE: distance is in sqeuclidean space!
369 | """
370 |
371 | def __init__(self, device, space='sqeuclidean', l2norm=True,
372 | init_bias=0.5, learn_bias=False):
373 | """Initialize
374 | """
375 | super(TripletLoss, self).__init__()
376 | self.device = device
377 | self.space = space
378 | self.learn_bias = learn_bias
379 | self.l2norm = l2norm
380 |
381 | self.bias = torch.tensor(init_bias).to(self.device)
382 |
383 | def forward(self, Xemb, scores, labels):
384 | """
385 | Xemb: N x D, N features, D embedding dimension
386 | labels: ground-truth cluster indices
387 | """
388 |
389 | N = Xemb.size(0)
390 | match = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() # a NxN {0,1} matrix
391 |
392 | ### generate positive pairs, and pull corresponding features
393 | diag_mask = 1 - torch.eye(N).to(self.device)
394 | pos_idx = (diag_mask * match).nonzero()
395 | anc_idx = pos_idx[:, 0]
396 | pos_idx = pos_idx[:, 1]
397 |
398 | ### generate negatives for the same anchors as positive
399 | neg_idx = torch.zeros_like(pos_idx).long()
400 | for k in range(pos_idx.size(0)):
401 | this_negs = torch.nonzero(1 - match[pos_idx[k]]).squeeze()
402 | neg_idx[k] = this_negs[torch.randperm(this_negs.size(0))][0]
403 |
404 | X_anc = Xemb.index_select(0, anc_idx)
405 | X_pos = Xemb.index_select(0, pos_idx)
406 | X_neg = Xemb.index_select(0, neg_idx)
407 |
408 | # compute distances
409 | pos_distances = ((X_anc - X_pos) ** 2).sum(1)
410 | neg_distances = ((X_anc - X_neg) ** 2).sum(1)
411 |
412 | # loss
413 | loss = F.relu(self.bias + pos_distances - neg_distances).mean()
414 |
415 | return {'trip': loss}
416 |
417 |
418 | class LogisticDiscriminantLoss(nn.Module):
419 | """Pairwise distance between samples, using logistic regression
420 | https://hal.inria.fr/file/index/docid/439290/filename/GVS09.pdf
421 | space: 'sqeuclidean' or 'lorentz'
422 | init_bias: initialize bias to this value (or as set by radius)
423 | temperature: sampling temperature (decayed in main training loop)
424 | with_ball: loss being used along with ball loss?
425 | beta: Lorentz beta for comparison in Lorentz space
426 | """
427 |
428 | def __init__(self, device, space='sqeuclidean',
429 | init_bias=0.5, learn_bias=True, temperature=1., beta=0.01,
430 | with_ball=False):
431 | """Initialize
432 | """
433 | super(LogisticDiscriminantLoss, self).__init__()
434 | self.device = device
435 | self.space = space
436 | self.temperature = temperature
437 |
438 | self.bias = nn.Parameter(torch.tensor(init_bias))
439 |
440 | def forward(self, Xemb, scores, labels):
441 | """
442 | Xemb: N x D, N features, D embedding dimension
443 | labels: ground-truth cluster indices
444 | """
445 |
446 | N = Xemb.size(0)
447 | match = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() # a NxN {0,1} matrix
448 |
449 | ### generate positive pairs, and pull corresponding features
450 | diag_mask = 1 - torch.eye(N).to(self.device)
451 | pos_idx = (diag_mask * match).nonzero()
452 | X1_pos = Xemb.index_select(0, pos_idx[:, 0])
453 | X2_pos = Xemb.index_select(0, pos_idx[:, 1])
454 |
455 | ### generate random negatives
456 | neg_idx = []
457 | while len(neg_idx) < X1_pos.size(0): # match pairs for negatives
458 | idx = torch.randint(N, (2,)).long()
459 | if match[idx[0], idx[1]] == 0:
460 | neg_idx.append(idx)
461 | neg_idx = torch.stack(neg_idx).to(self.device)
462 |
463 | X1_neg = Xemb.index_select(0, neg_idx[:, 0])
464 | X2_neg = Xemb.index_select(0, neg_idx[:, 1])
465 |
466 | # compute distances
467 | pos_distances = ((X1_pos - X2_pos) ** 2).sum(1)
468 | neg_distances = ((X1_neg - X2_neg) ** 2).sum(1)
469 |
470 | # Loss = -y log(p) - (1-y) log(1-p)
471 | pos_logprobs = torch.sigmoid((self.bias - pos_distances)/self.temperature)
472 | neg_logprobs = torch.sigmoid((self.bias - neg_distances)/self.temperature)
473 | pos_loss = -(pos_logprobs).log().mean()
474 | neg_loss = -(1 - neg_logprobs).log().mean()
475 |
476 | return {'ldml_pos': pos_loss, 'ldml_neg': neg_loss}
477 |
478 |
479 |
--------------------------------------------------------------------------------
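A sketch of ClusterLoss on a toy batch, assuming it is run from the GCN directory; the embedding size, labels, and hyper-parameters are made up and only meant to show the input/output contract used by gat.py:

    import torch
    import torch.nn.functional as F
    from net.optim_modules import ClusterLoss

    # toy batch: 6 embeddings from 3 ground-truth clusters, L2-normalized as in gat.py
    emb = torch.randn(6, 256, requires_grad=True)
    X = F.normalize(emb, dim=1)
    labels = torch.tensor([0, 0, 1, 1, 2, 2])

    criterion = ClusterLoss(losstype='allall', margin=1., alpha_pos=4., pmargin=1.)
    losses = criterion(X, labels)                    # {'ctrd_pos': ..., 'ctrd_neg': ...}
    loss = losses['ctrd_pos'] + losses['ctrd_neg']
    loss.backward()
    print(emb.grad.shape)                            # torch.Size([6, 256])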
/GCN/train.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | from __future__ import division
3 | import torch
4 | import torch.optim as optim
5 | from adjacency import sparse_mx_to_torch_sparse_tensor
6 | from net.gat import GCN_V
7 | #from net.gcn_v import GCN_V
8 | #from net.softmaxloss import GCN_V
9 | import yaml
10 | from easydict import EasyDict
11 | from tensorboardX import SummaryWriter
12 | import numpy as np
13 | import scipy.sparse as sp
14 | import time
15 | import sys
16 | import os
17 | import apex
18 | from apex import amp
19 | import dgl
20 | import math
21 | import argparse
22 | import pprint
23 | from abc import ABC, abstractproperty, abstractmethod
24 | from collections.abc import Mapping
25 |
26 | class Collator(ABC):
27 | @abstractproperty
28 | def dataset(self):
29 | raise NotImplementedError
30 |
31 | @abstractmethod
32 | def collate(self, items):
33 | raise NotImplementedError
34 |
35 | class multigraph_NodeCollator(Collator):
36 | def __init__(self, order_graph, ngbr_graph, nids, block_sampler):
37 | self.order_graph = order_graph
38 | self.ngbr_graph = ngbr_graph
39 | self.nids = nids
40 | self.block_sampler = block_sampler
41 | self._dataset = nids
42 |
43 | @property
44 | def dataset(self):
45 | return self._dataset
46 |
47 | def collate(self, items):
48 | # custom collate used to speed up block sampling
49 | #blocks = self.block_sampler.sample_blocks(self.g, items)
50 |
51 | seed_node0 = items
52 | frontier0 = dgl.sampling.sample_neighbors(self.order_graph, seed_node0, 128, replace=False) # sample 128 from 256
53 | block0 = dgl.to_block(frontier0, seed_node0)
54 |
55 | seed_node1 = {ntype: block0.srcnodes[ntype].data[dgl.NID] for ntype in block0.srctypes}
56 | frontier1 = dgl.sampling.sample_neighbors(self.ngbr_graph, seed_node1, 80, replace=False)
57 | block1 = dgl.to_block(frontier1, seed_node1)
58 | block1.create_format_()
59 |
60 | seed_node2 = {ntype: block1.srcnodes[ntype].data[dgl.NID] for ntype in block1.srctypes}
61 | frontier2 = dgl.sampling.sample_neighbors(self.ngbr_graph, seed_node2, 80, replace=False)
62 | block2 = dgl.to_block(frontier2, seed_node2)
63 | block2.create_format_()
64 |
65 | blocks = [block2, block1]
66 | input_nodes = blocks[0].srcdata[dgl.NID]
67 | output_nodes = blocks[-1].dstdata[dgl.NID]
68 | return input_nodes, output_nodes, blocks
69 |
70 | class AverageMeter(object):
71 | def __init__(self):
72 | self.val = 0
73 | self.avg = 0
74 | self.sum = 0
75 | self.count = 0
76 | def reset(self):
77 | self.val = 0
78 | self.avg = 0
79 | self.sum = 0
80 | self.count = 0
81 | def update(self, val, n=1):
82 | self.val = val
83 | self.sum += val * n
84 | self.count += n
85 | self.avg = float(self.sum) / (self.count + 1e-10)
86 |
87 | def row_normalize(mx):
88 | """Row-normalize sparse matrix"""
89 | rowsum = np.array(mx.sum(1))
90 | # if rowsum <= 0, keep its previous value
91 | rowsum[rowsum <= 0] = 1
92 | r_inv = np.power(rowsum, -1).flatten()
93 | r_inv[np.isinf(r_inv)] = 0.
94 | r_mat_inv = sp.diags(r_inv)
95 | mx = r_mat_inv.dot(mx)
96 | return mx
97 |
98 | class Timer():
99 | def __init__(self, name='task', verbose=True):
100 | self.name = name
101 | self.verbose = verbose
102 |
103 | def __enter__(self):
104 | print('[begin {}]'.format(self.name))
105 | self.start = time.time()
106 | return self
107 |
108 | def __exit__(self, exc_type, exc_val, exc_tb):
109 | if self.verbose:
110 | print('[done {}] use {:.3f} s'.format(self.name, time.time() - self.start))
111 | return exc_type is None
112 |
113 | def adjust_lr(cur_epoch, optimizer, cfg):
114 | if cur_epoch not in cfg.step_number:
115 | return
116 | ind = cfg.step_number.index(cur_epoch)
117 | for each in optimizer.param_groups:
118 | each['lr'] = cfg.lr * cfg.factor ** (ind+1)
119 |
120 | def cos_lr(current_step, optimizer, cfg):
121 | if current_step < cfg.warmup_step:
122 | rate = 1.0 * current_step / cfg.warmup_step
123 | lr = cfg.lr * rate
124 | else:
125 | n1 = cfg.total_step - cfg.warmup_step
126 | n2 = current_step - cfg.warmup_step
127 | rate = (1 + math.cos(math.pi * n2 / n1)) / 2
128 | lr = cfg.lr * rate
129 | for each in optimizer.param_groups:
130 | each['lr'] = lr
131 |
132 | if __name__ == "__main__":
133 | parser = argparse.ArgumentParser()
134 | parser.add_argument('--config_file', type=str)
135 | parser.add_argument('--outpath', type=str)
136 | parser.add_argument('--phase', type=str)
137 | parser.add_argument('--train_featfile', type=str)
138 | parser.add_argument('--train_adjfile', type=str)
139 | parser.add_argument('--train_orderadjfile', type=str)
140 | parser.add_argument('--train_labelfile', type=str)
141 | parser.add_argument('--test_featfile', type=str)
142 | parser.add_argument('--test_adjfile', type=str)
143 | parser.add_argument('--test_labelfile', type=str)
144 | parser.add_argument('--resume_path', type=str)
145 | parser.add_argument('--losstype', type=str)
146 | parser.add_argument('--margin', type=float)
147 | parser.add_argument('--pweight', type=float)
148 | parser.add_argument('--pmargin', type=float)
149 | parser.add_argument('--topk', type=int)
150 | args = parser.parse_args()
151 |
152 | beg_time = time.time()
153 | config = yaml.load(open(args.config_file, "r"), Loader=yaml.FullLoader)
154 | cfg = EasyDict(config)
155 | cfg.step_number = [int(r * cfg.total_step) for r in cfg.lr_step]
156 |
157 | # command-line arguments override the values loaded from the config file
158 | for key, value in args._get_kwargs():
159 | cfg[key] = value
160 | #cfg[list(dict(train_adjfile=train_adjfile).keys())[0]] = train_adjfile
161 | #cfg[list(dict(train_labelfile=train_labelfile).keys())[0]] = train_labelfile
162 | #cfg[list(dict(test_adjfile=test_adjfile).keys())[0]] = test_adjfile
163 | #cfg[list(dict(test_labelfile=test_labelfile).keys())[0]] = test_labelfile
164 | cfg.var = EasyDict()
165 | print("train hyper parameter list")
166 | pprint.pprint(cfg)
167 |
168 |
169 | # get model
170 | model = GCN_V(feature_dim=cfg.feat_dim, nhid=cfg.nhid, nclass=cfg.nclass, dropout=0., losstype=cfg.losstype, margin=cfg.margin,
171 | pweight=cfg.pweight, pmargin=cfg.pmargin)
172 |
173 | # get dataset
174 | with Timer('load data'):
175 | if cfg.phase == 'train':
176 | featfile, adjfile, labelfile = cfg.train_featfile, cfg.train_adjfile, cfg.train_labelfile
177 | order_adj = sp.load_npz(cfg.train_orderadjfile).astype(np.float32)
178 | order_graph = dgl.from_scipy(order_adj)
179 | else:
180 | featfile, adjfile, labelfile = cfg.test_featfile, cfg.test_adjfile, cfg.test_labelfile
181 | features = np.load(featfile)
182 | features = features / np.linalg.norm(features, axis=1, keepdims=True)
183 | adj = sp.load_npz(adjfile).astype(np.float32)
184 | graph = dgl.from_scipy(adj)
185 | label_arr = np.load(labelfile)
186 | features = torch.FloatTensor(features)
187 | #adj = sparse_mx_to_torch_sparse_tensor(adj)
188 | label_cpu = torch.LongTensor(label_arr)
189 | if cfg.cuda:
190 | model.cuda()
191 | features = features.cuda()
192 | #adj = adj.cuda()
193 | labels = label_cpu.cuda()
194 | #data = (features, adj, labels)
195 |
196 | # get train
197 | if cfg.phase == 'train':
198 | # get optimizer
199 | pretrain_pool = True
200 | pretrain_pool = False
201 | if pretrain_pool:
202 | pool_weight, net_weight = [], []
203 | for k, v in model.named_parameters():
204 | if 'pool.' in k:
205 | pool_weight += [v]
206 | else:
207 | net_weight += [v]
208 | param_list = [{'params': pool_weight}, {'params': net_weight, 'lr': 0.}]
209 | optimizer = optim.SGD(param_list, cfg.lr, momentum=cfg.sgd_momentum, weight_decay=cfg.sgd_weight_decay)
210 | else:
211 | optimizer = optim.SGD(model.parameters(), cfg.lr, momentum=cfg.sgd_momentum, weight_decay=cfg.sgd_weight_decay)
212 |
213 | if cfg.fp16:
214 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1", keep_batchnorm_fp32=None, loss_scale='dynamic')
215 |
216 | beg_step = 0
217 | if cfg.resume_path is not None:
218 | beg_step = int(os.path.splitext(os.path.basename(cfg.resume_path))[0].split('_')[1])
219 | with Timer('resume model from %s'%cfg.resume_path):
220 | ckpt = torch.load(cfg.resume_path, map_location='cpu')
221 | model.load_state_dict(ckpt['state_dict'])
222 |
223 | totalloss_meter = AverageMeter()
224 | bclloss_pos_meter = AverageMeter()
225 | bclloss_neg_meter = AverageMeter()
226 | keeploss_meter = AverageMeter()
227 | before_edge_num_meter = AverageMeter()
228 | after_edge_num_meter = AverageMeter()
229 | acc_meter = AverageMeter()
230 | prec_meter = AverageMeter()
231 | recall_meter = AverageMeter()
232 | leftprec_meter = AverageMeter()
233 | writer = SummaryWriter(os.path.join(cfg.outpath), filename_suffix='')
234 | #sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
235 | #sampler = dgl.dataloading.MultiLayerNeighborSampler([cfg.topk, cfg.topk, 128])
236 | #sampler = dgl.dataloading.MultiLayerNeighborSampler([None, None, 128])
237 | #dataloader = dgl.dataloading.NodeDataLoader(
238 | # order_graph,
239 | # np.arange(order_graph.number_of_nodes()),
240 | # sampler,
241 | # batch_size=cfg.batchsize,
242 | # shuffle=True,
243 | # drop_last=False,
244 | # num_workers=4)
245 | #sampler = dgl.dataloading.MultiLayerNeighborSampler([128])
246 | sampler = None
247 | collator = multigraph_NodeCollator(order_graph, graph, np.arange(len(features)), sampler) # v4
248 | dataloader = torch.utils.data.DataLoader(
249 | dataset=collator.dataset,
250 | batch_size=cfg.batchsize,
251 | shuffle=True,
252 | num_workers=4,
253 | pin_memory=True,
254 | drop_last=False,
255 | collate_fn=collator.collate,
256 | )
257 |
258 | current_step = beg_step
259 | break_flag = False
260 | while 1:
261 | #adjust_lr(current_step, optimizer.param_groups, cfg)
262 | if break_flag:
263 | break
264 | for _, (src_idx, dst_idx, blocks) in enumerate(dataloader):
265 | if current_step > cfg.total_step:
266 | break_flag = True
267 | break
268 | iter_begtime = time.time()
269 | current_step += 1
270 | cos_lr(current_step, optimizer, cfg)
271 | #src_idx = blocks[0].srcdata[dgl.NID].numpy()
272 | #dst_idx = blocks[-1].dstdata[dgl.NID].numpy()
273 |
274 | batch_feature = features[src_idx, :]
275 | #batch_adj = sparse_mx_to_torch_sparse_tensor(adj[src_idx, :][:, src_idx]).cuda()
276 | #batch_adj = torch.from_numpy( row_normalize(adj[src_idx, :][:, src_idx]).todense() ).cuda()
277 | batch_block = [block.to(0) for block in blocks] # no row normalization needed: the GAT learns per-edge attention weights
278 | batch_label = labels[dst_idx]
279 | batch_idlabel = labels[src_idx]
280 | batch_data = (batch_feature, batch_block, batch_label, batch_idlabel)
281 | bclloss_dict = model(batch_data, return_loss=True)
282 | loss = bclloss_dict['ctrd_pos'] + bclloss_dict['ctrd_neg']
283 |
284 | optimizer.zero_grad()
285 | if cfg.fp16:
286 | with amp.scale_loss(loss, optimizer) as scaled_loss:
287 | scaled_loss.backward()
288 | else:
289 | loss.backward()
290 | optimizer.step()
291 |
292 | totalloss_meter.update(loss.item())
293 | bclloss_pos_meter.update(bclloss_dict['ctrd_pos'].item())
294 | bclloss_neg_meter.update(bclloss_dict['ctrd_neg'].item())
295 |
296 | writer.add_scalar('loss/total', loss.item(), global_step=current_step)
297 | writer.add_scalar('loss/bcl_pos', bclloss_dict['ctrd_pos'].item(), global_step=current_step)
298 | writer.add_scalar('loss/bcl_neg', bclloss_dict['ctrd_neg'].item(), global_step=current_step)
299 | if current_step % cfg.log_freq == 0:
300 | log = "step{}/{}, iter_time:{:.3f}, lr:{:.4f}, loss:{:.4f}({:.4f}), bclloss_pos:{:.8f}({:.8f}), bclloss_neg:{:.4f}({:.4f}) ".format(current_step, cfg.total_step, time.time()-iter_begtime, optimizer.param_groups[0]['lr'], totalloss_meter.val, totalloss_meter.avg, bclloss_pos_meter.val, bclloss_pos_meter.avg, bclloss_neg_meter.val, bclloss_neg_meter.avg)
301 | print(log)
302 | if (current_step+1) % cfg.save_freq == 0 and current_step > 0:
303 | torch.save({'state_dict' : model.state_dict(), 'step': current_step+1},
304 | os.path.join(cfg.outpath, "ckpt_%s.pth"%(current_step+1)))
305 | writer.close()
306 | else:
307 | with Timer('resume model from %s'%cfg.resume_path):
308 | ckpt = torch.load(cfg.resume_path, map_location='cpu')
309 | model.load_state_dict(ckpt['state_dict'])
310 | model.eval()
311 |
312 | sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
313 | dataloader = dgl.dataloading.NodeDataLoader(
314 | graph,
315 | np.arange(graph.number_of_nodes()),
316 | sampler,
317 | batch_size=1024,
318 | shuffle=False,
319 | drop_last=False,
320 | num_workers=16)
321 |
322 | gcnfeat_list, fcfeat_list = [], []
323 | leftprec_meter = AverageMeter()
324 | beg_time = time.time()
325 | for step, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
326 | src_idx = blocks[0].srcdata[dgl.NID].numpy()
327 | dst_idx = blocks[-1].dstdata[dgl.NID].numpy()
328 | #zip(block.srcnodes(), block.srcdata[dgl.NID])
329 | #zip(block.dstnodes(), block.dstdata[dgl.NID])
330 | batch_feature = features[src_idx, :]
331 | #batch_adj = sparse_mx_to_torch_sparse_tensor(adj[src_idx, :][:, src_idx]).cuda()
332 |
333 | #batch_adj = torch.from_numpy(adj[src_idx, :][:, src_idx].todense()).cuda() # no sample and no row normalize again
334 | batch_block = [block.to(0) for block in blocks]
335 | batch_label = labels[dst_idx]
336 | batch_idlabel = labels[src_idx]
337 | batch_data = (batch_feature, batch_block, batch_label, batch_idlabel)
338 | #fcfeat, gcnfeat, before_edge_num, after_edge_num, acc_rate, prec, recall, left_prec = model(batch_data, output_feat=True)
339 | fcfeat, gcnfeat = model(batch_data, output_feat=True)
340 |
341 | fcfeat_list.append(fcfeat.data.cpu().numpy())
342 | gcnfeat_list.append(gcnfeat.data.cpu().numpy())
343 | #leftprec_meter.update(left_prec)
344 | #if step % 1 == 0:
345 | # log = "step %s/%s"%(step, len(dataloader))
346 | # print(log)
347 | print("time use %.4f"%(time.time()-beg_time))
348 |
349 | fcfeat_arr = np.vstack(fcfeat_list)
350 | gcnfeat_arr = np.vstack(gcnfeat_list)
351 | fcfeat_arr = fcfeat_arr / np.linalg.norm(fcfeat_arr, axis=1, keepdims=True)
352 | gcnfeat_arr = gcnfeat_arr / np.linalg.norm(gcnfeat_arr, axis=1, keepdims=True)
353 | tag = os.path.splitext(os.path.basename(cfg.resume_path))[0]
354 | np.save(os.path.join(cfg.outpath, 'fcfeat_%s'%tag), fcfeat_arr)
355 | np.save(os.path.join(cfg.outpath, 'gcnfeat_%s'%tag), gcnfeat_arr)
356 |
357 | print("time use", time.time() - beg_time)
358 |
--------------------------------------------------------------------------------
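
The `cos_lr` function in GCN/train.py implements linear warmup followed by cosine decay down to zero. A small standalone sketch of that schedule (the step counts and base learning rate below are illustrative, not necessarily the training configuration):

```
import math

def lr_at(step, base_lr=0.01, warmup_step=128, total_step=40000):
    if step < warmup_step:
        return base_lr * step / warmup_step                    # linear warmup
    n1 = total_step - warmup_step
    n2 = step - warmup_step
    return base_lr * (1 + math.cos(math.pi * n2 / n1)) / 2     # cosine decay

for s in (0, 64, 128, 20000, 40000):
    print(s, round(lr_at(s), 6))
```
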
/GCN/util/confidence.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 | from itertools import groupby
7 |
8 | __all__ = ['density', 'confidence', 'confidence_to_peaks']
9 |
10 |
11 | def density(dists, radius=0.3, use_weight=True):
12 | row, col = (dists < radius).nonzero()
13 |
14 | num, _ = dists.shape
15 | if use_weight:
16 | density = np.zeros((num, ), dtype=np.float32)
17 | for r, c in zip(row, col):
18 | density[r] += 1 - dists[r, c]
19 | else:
20 | density = np.zeros((num, ), dtype=np.int32)
21 | for k, g in groupby(row):
22 | density[k] = len(list(g))
23 | return density
24 |
25 |
26 | def s_nbr(dists, nbrs, idx2lb, **kwargs):
27 | ''' use supervised confidence defined on the neighbourhood
28 | '''
29 | num, _ = dists.shape
30 | conf = np.zeros((num, ), dtype=np.float32)
31 | contain_neg = 0
32 | for i, (nbr, dist) in enumerate(zip(nbrs, dists)):
33 | lb = idx2lb[i]
34 | pos, neg = 0, 0
35 | for j, n in enumerate(nbr):
36 | if idx2lb[n] == lb:
37 | pos += 1 - dist[j]
38 | else:
39 | neg += 1 - dist[j]
40 | conf[i] = pos - neg
41 | if neg > 0:
42 | contain_neg += 1
43 | print('#contain_neg:', contain_neg)
44 | conf /= np.abs(conf).max()
45 | return conf
46 |
47 |
48 | def s_nbr_size_norm(dists, nbrs, idx2lb, **kwargs):
49 | ''' use supervised confidence defined on the neighbourhood (normalized by size)
50 | '''
51 | num, _ = dists.shape
52 | conf = np.zeros((num, ), dtype=np.float32)
53 | contain_neg = 0
54 | max_size = 0
55 | for i, (nbr, dist) in enumerate(zip(nbrs, dists)):
56 | size = 0
57 | pos, neg = 0, 0
58 | lb = idx2lb[i]
59 | for j, n in enumerate(nbr):
60 | if idx2lb[n] == lb:
61 | pos += 1 - dist[j]
62 | else:
63 | neg += 1 - dist[j]
64 | size += 1
65 | conf[i] = pos - neg
66 | max_size = max(max_size, size)
67 | if neg > 0:
68 | contain_neg += 1
69 | print('#contain_neg:', contain_neg)
70 | print('max_size: {}'.format(max_size))
71 | conf /= max_size
72 | return conf
73 |
74 |
75 | def s_avg(feats, idx2lb, lb2idxs, **kwargs):
76 | ''' use average similarity of intra-nodes
77 | '''
78 | num = len(idx2lb)
79 | conf = np.zeros((num, ), dtype=np.float32)
80 | for i in range(num):
81 | lb = idx2lb[i]
82 | idxs = lb2idxs[lb]
83 | idxs.remove(i)
84 | if len(idxs) == 0:
85 | continue
86 | feat = feats[i, :]
87 | conf[i] = feat.dot(feats[idxs, :].T).mean()
88 | eps = 1e-6
89 | assert -1 - eps <= conf.min() <= conf.max(
90 | ) <= 1 + eps, "min: {}, max: {}".format(conf.min(), conf.max())
91 | return conf
92 |
93 |
94 | def s_center(feats, idx2lb, lb2idxs, **kwargs):
95 | ''' use similarity between each node and its cluster centre
96 | '''
97 | num = len(idx2lb)
98 | conf = np.zeros((num, ), dtype=np.float32)
99 | for i in range(num):
100 | lb = idx2lb[i]
101 | idxs = lb2idxs[lb]
102 | if len(idxs) == 0:
103 | continue
104 | feat = feats[i, :]
105 | feat_center = feats[idxs, :].mean(axis=0)
106 | conf[i] = feat.dot(feat_center.T)
107 | eps = 1e-6
108 | assert -1 - eps <= conf.min() <= conf.max(
109 | ) <= 1 + eps, "min: {}, max: {}".format(conf.min(), conf.max())
110 | return conf
111 |
112 |
113 | def confidence(metric='s_nbr', **kwargs):
114 | metric2func = {
115 | 's_nbr': s_nbr,
116 | 's_nbr_size_norm': s_nbr_size_norm,
117 | 's_avg': s_avg,
118 | 's_center': s_center,
119 | }
120 | if metric in metric2func:
121 | func = metric2func[metric]
122 | else:
123 | raise KeyError('Only support confidence metrics: {}'.format(
124 | metric2func.keys()))
125 |
126 | conf = func(**kwargs)
127 | return conf
128 |
129 |
130 | def confidence_to_peaks(dists, nbrs, confidence, max_conn=1):
131 | # Note that dists has been sorted in ascending order
132 | assert dists.shape[0] == confidence.shape[0]
133 | assert dists.shape == nbrs.shape
134 |
135 | num, _ = dists.shape
136 | dist2peak = {i: [] for i in range(num)}
137 | peaks = {i: [] for i in range(num)}
138 |
139 | for i, nbr in tqdm(enumerate(nbrs)):
140 | nbr_conf = confidence[nbr]
141 | for j, c in enumerate(nbr_conf):
142 | nbr_idx = nbr[j]
143 | if i == nbr_idx or c <= confidence[i]:
144 | continue
145 | dist2peak[i].append(dists[i, j])
146 | peaks[i].append(nbr_idx)
147 | if len(dist2peak[i]) >= max_conn:
148 | break
149 | return dist2peak, peaks
150 |
--------------------------------------------------------------------------------
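
A toy usage sketch of the confidence utilities above (the kNN arrays and labels are illustrative; the import assumes the `GCN` directory is on `PYTHONPATH`):

```
import numpy as np
from util.confidence import confidence, confidence_to_peaks

# toy kNN result, sorted by ascending distance, plus toy ground-truth labels
dists = np.array([[0.0, 0.1, 0.4],
                  [0.0, 0.1, 0.5],
                  [0.0, 0.4, 0.5]], dtype=np.float32)
nbrs = np.array([[0, 1, 2],
                 [1, 0, 2],
                 [2, 0, 1]], dtype=np.int32)
idx2lb = {0: 0, 1: 0, 2: 1}

conf = confidence(metric='s_nbr', dists=dists, nbrs=nbrs, idx2lb=idx2lb)
dist2peak, peaks = confidence_to_peaks(dists, nbrs, conf, max_conn=1)
print(conf, peaks)  # each node points to at most one higher-confidence neighbour
```
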
/GCN/util/deduce.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | __all__ = ['peaks_to_labels']
4 |
5 |
6 | def _find_parent(parent, u):
7 | idx = []
8 | # parent is a fixed point
9 | while (u != parent[u]):
10 | idx.append(u)
11 | u = parent[u]
12 | for i in idx:
13 | parent[i] = u
14 | return u
15 |
16 |
17 | def edge_to_connected_graph(edges, num):
18 | parent = list(range(num))
19 | for u, v in edges:
20 | p_u = _find_parent(parent, u)
21 | p_v = _find_parent(parent, v)
22 | parent[p_u] = p_v
23 |
24 | for i in range(num):
25 | parent[i] = _find_parent(parent, i)
26 | remap = {}
27 | uf = np.unique(np.array(parent))
28 | for i, f in enumerate(uf):
29 | remap[f] = i
30 | cluster_id = np.array([remap[f] for f in parent])
31 | return cluster_id
32 |
33 |
34 | def peaks_to_edges(peaks, dist2peak, tau):
35 | edges = []
36 | for src in peaks:
37 | dsts = peaks[src]
38 | dists = dist2peak[src]
39 | for dst, dist in zip(dsts, dists):
40 | if src == dst or dist >= 1 - tau:
41 | continue
42 | edges.append([src, dst])
43 | return edges
44 |
45 |
46 | def peaks_to_labels(peaks, dist2peak, tau, inst_num):
47 | edges = peaks_to_edges(peaks, dist2peak, tau)
48 | pred_labels = edge_to_connected_graph(edges, inst_num)
49 | return pred_labels
50 |
--------------------------------------------------------------------------------
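
The deduction step above is a small union-find with path compression; a toy run of `edge_to_connected_graph` (import path assumed as in the repository layout):

```
from util.deduce import edge_to_connected_graph

# the edges merge {0, 1, 2} and {3, 4}; node 5 stays a singleton cluster
edges = [[0, 1], [1, 2], [3, 4]]
cluster_id = edge_to_connected_graph(edges, num=6)
print(cluster_id)  # -> [0 0 0 1 1 2]
```
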
/GCN/util/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import inspect
5 | import argparse
6 | import numpy as np
7 | import util.metrics as metrics
8 | import time
9 |
10 | class TextColors:
11 | #HEADER = '\033[35m'
12 | #OKBLUE = '\033[34m'
13 | #OKGREEN = '\033[32m'
14 | #WARNING = '\033[33m'
15 | #FATAL = '\033[31m'
16 | #ENDC = '\033[0m'
17 | #BOLD = '\033[1m'
18 | #UNDERLINE = '\033[4m'
19 | HEADER = ''
20 | OKBLUE = ''
21 | OKGREEN = ''
22 | WARNING = ''
23 | FATAL = ''
24 | ENDC = ''
25 | BOLD = ''
26 | UNDERLINE = ''
27 |
28 | class Timer():
29 | def __init__(self, name='task', verbose=True):
30 | self.name = name
31 | self.verbose = verbose
32 |
33 | def __enter__(self):
34 | self.start = time.time()
35 | return self
36 |
37 | def __exit__(self, exc_type, exc_val, exc_tb):
38 | if self.verbose:
39 | print('[Time] {} consumes {:.4f} s'.format(
40 | self.name,
41 | time.time() - self.start))
42 | return exc_type is None
43 |
44 |
45 | def _read_meta(fn):
46 | labels = list()
47 | lb_set = set()
48 | with open(fn) as f:
49 | for lb in f.readlines():
50 | lb = int(lb.strip())
51 | labels.append(lb)
52 | lb_set.add(lb)
53 | return np.array(labels), lb_set
54 |
55 |
56 | def evaluate(gt_labels, pred_labels, metric='pairwise'):
57 | if isinstance(gt_labels, str) and isinstance(pred_labels, str):
58 | print('[gt_labels] {}'.format(gt_labels))
59 | print('[pred_labels] {}'.format(pred_labels))
60 | gt_labels, gt_lb_set = _read_meta(gt_labels)
61 | pred_labels, pred_lb_set = _read_meta(pred_labels)
62 |
63 | print('#inst: gt({}) vs pred({})'.format(len(gt_labels),
64 | len(pred_labels)))
65 | print('#cls: gt({}) vs pred({})'.format(len(gt_lb_set),
66 | len(pred_lb_set)))
67 |
68 | metric_func = metrics.__dict__[metric]
69 |
70 | with Timer('evaluate with {}{}{}'.format(TextColors.FATAL, metric,
71 | TextColors.ENDC), verbose=False):
72 | result = metric_func(gt_labels, pred_labels)
73 | if isinstance(result, np.float):
74 | #print('{}{}: {:.4f}{}'.format(TextColors.OKGREEN, metric, result, TextColors.ENDC))
75 | res_str = '{}{}: {:.4f}{}'.format(TextColors.OKGREEN, metric, result, TextColors.ENDC)
76 | else:
77 | from collections import Counter
78 | singleton_num = len( list( filter(lambda x: x==1, Counter(pred_labels).values()) ) )
79 | ave_pre, ave_rec, fscore = result
80 | #print('{}ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}, cluster_num: {}, singleton_num: {}'.format(
81 | # TextColors.OKGREEN, ave_pre, ave_rec, fscore, TextColors.ENDC, len(np.unique(pred_labels)), singleton_num))
82 | res_str = '{}{}: ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}, cluster_num: {}, singleton_num: {}'.format(
83 | TextColors.OKGREEN, metric, ave_pre, ave_rec, fscore, TextColors.ENDC, len(np.unique(pred_labels)), singleton_num)
84 | #return ave_pre, ave_rec, fscore
85 | return res_str
86 |
87 |
88 | if __name__ == '__main__':
89 | metric_funcs = inspect.getmembers(metrics, inspect.isfunction)
90 | metric_names = [n for n, _ in metric_funcs]
91 |
92 | parser = argparse.ArgumentParser(description='Evaluate Cluster')
93 | parser.add_argument('--gt_labels', type=str, required=True)
94 | parser.add_argument('--pred_labels', type=str, required=True)
95 | parser.add_argument('--metric', default='pairwise', choices=metric_names)
96 | args = parser.parse_args()
97 |
98 | evaluate(args.gt_labels, args.pred_labels, args.metric)
99 |
--------------------------------------------------------------------------------
/GCN/util/graph.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import time
7 |
8 | class Data(object):
9 | def __init__(self, name):
10 | self.__name = name
11 | self.__links = set()
12 |
13 | @property
14 | def name(self):
15 | return self.__name
16 |
17 | @property
18 | def links(self):
19 | return set(self.__links)
20 |
21 | def add_link(self, other, score):
22 | self.__links.add(other)
23 | other.__links.add(self)
24 |
25 | def connected_components(nodes, score_dict, th):
26 | '''
27 | conventional connected components searching
28 | '''
29 | result = []
30 | nodes = set(nodes)
31 | while nodes:
32 | n = nodes.pop()
33 | group = {n}
34 | queue = [n]
35 | while queue:
36 | n = queue.pop(0)
37 | if th is not None:
38 | neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th}
39 | else:
40 | neighbors = n.links
41 | neighbors.difference_update(group)
42 | nodes.difference_update(neighbors)
43 | group.update(neighbors)
44 | queue.extend(neighbors)
45 | result.append(group)
46 | return result
47 |
48 | def connected_components_constraint(nodes, max_sz, score_dict=None, th=None):
49 | '''
50 | only use edges whose scores are at least `th`
51 | if a component grows larger than `max_sz`, all of its nodes are added into `remain` and returned for the next iteration.
52 | '''
53 | result = []
54 | remain = set()
55 | nodes = set(nodes)
56 | while nodes:
57 | n = nodes.pop()
58 | group = {n}
59 | queue = [n]
60 | valid = True
61 | while queue:
62 | n = queue.pop(0)
63 | if th is not None:
64 | neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th}
65 | else:
66 | neighbors = n.links
67 | neighbors.difference_update(group)
68 | nodes.difference_update(neighbors)
69 | group.update(neighbors)
70 | queue.extend(neighbors)
71 | if len(group) > max_sz or len(remain.intersection(neighbors)) > 0:
72 | # if this group is larger than `max_sz`, add the nodes into `remain`
73 | valid = False
74 | remain.update(group)
75 | break
76 | if valid: # if this group is smaller than or equal to `max_sz`, finalize it.
77 | result.append(group)
78 | return result, remain
79 |
80 |
81 | def graph_propagation_naive(edges, score, th):
82 |
83 | edges = np.sort(edges, axis=1)
84 |
85 | # construct graph
86 | score_dict = {} # score lookup table
87 | for i,e in enumerate(edges):
88 | score_dict[e[0], e[1]] = score[i]
89 |
90 | nodes = np.sort(np.unique(edges.flatten()))
91 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int)
92 | mapping[nodes] = np.arange(nodes.shape[0])
93 | link_idx = mapping[edges]
94 | vertex = [Data(n) for n in nodes]
95 | for l, s in zip(link_idx, score):
96 | vertex[l[0]].add_link(vertex[l[1]], s)
97 |
98 | # first iteration
99 | comps = connected_components(vertex, score_dict,th)
100 |
101 | return comps
102 |
103 | def graph_propagation(edges, score, max_sz, step=0.1, beg_th=0.9, pool=None):
104 |
105 | edges = np.sort(edges, axis=1)
106 | th = score.min()
107 | #th = beg_th
108 | # construct graph
109 | score_dict = {} # score lookup table
110 | if pool is None:
111 | for i,e in enumerate(edges):
112 | score_dict[e[0], e[1]] = score[i]
113 | elif pool == 'avg':
114 | for i,e in enumerate(edges):
115 | #if score_dict.has_key((e[0],e[1])):
116 | if (e[0],e[1]) in score_dict:
117 | score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i])
118 | else:
119 | score_dict[e[0], e[1]] = score[i]
120 |
121 | elif pool == 'max':
122 | for i,e in enumerate(edges):
123 | #if score_dict.has_key((e[0],e[1])):
124 | if (e[0],e[1]) in score_dict:
125 | score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i])
126 | else:
127 | score_dict[e[0], e[1]] = score[i]
128 | else:
129 | raise ValueError('Pooling operation not supported')
130 |
131 | nodes = np.sort(np.unique(edges.flatten()))
132 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int)
133 | mapping[nodes] = np.arange(nodes.shape[0])
134 | link_idx = mapping[edges]
135 | vertex = [Data(n) for n in nodes]
136 | for l, s in zip(link_idx, score):
137 | vertex[l[0]].add_link(vertex[l[1]], s)
138 |
139 | # first iteration
140 | comps, remain = connected_components_constraint(vertex, max_sz)
141 |
142 | # iteration
143 | components = comps[:]
144 | while remain:
145 | th = th + (1 - th) * step
146 | comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
147 | components.extend(comps)
148 | return components
149 |
150 | def graph_propagation_begin(edges, score, max_sz, step=0.1, beg_th=0.9, pool=None):
151 |
152 | th = beg_th
153 | # construct graph
154 | score_dict = {} # score lookup table
155 | if pool is None:
156 | for i,e in enumerate(edges):
157 | score_dict[e[0], e[1]] = score[i]
158 | elif pool == 'avg':
159 | for i,e in enumerate(edges):
160 | #if score_dict.has_key((e[0],e[1])):
161 | if (e[0],e[1]) in score_dict:
162 | score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i])
163 | else:
164 | score_dict[e[0], e[1]] = score[i]
165 |
166 | elif pool == 'max':
167 | for i,e in enumerate(edges):
168 | #if score_dict.has_key((e[0],e[1])):
169 | if (e[0],e[1]) in score_dict:
170 | score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i])
171 | else:
172 | score_dict[e[0], e[1]] = score[i]
173 | else:
174 | raise ValueError('Pooling operation not supported')
175 |
176 | nodes = np.sort(np.unique(edges.flatten()))
177 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int)
178 | mapping[nodes] = np.arange(nodes.shape[0])
179 | link_idx = mapping[edges]
180 | vertex = [Data(n) for n in nodes]
181 | for l, s in zip(link_idx, score):
182 | vertex[l[0]].add_link(vertex[l[1]], s)
183 |
184 | # first iteration
185 | comps, remain = connected_components_constraint(vertex, max_sz)
186 |
187 | # iteration
188 | components = comps[:]
189 | while remain:
190 | th = th + (1 - th) * step
191 | comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
192 | components.extend(comps)
193 | return components
194 |
195 | def graph_propagation_onecut(edges, score, max_sz, th=0.4, pool=None):
196 | edges = np.sort(edges, axis=1)
197 |
198 | # construct graph
199 | score_dict = {} # score lookup table
200 | if pool is None:
201 | for i,e in enumerate(edges):
202 | score_dict[e[0], e[1]] = score[i]
203 | elif pool == 'avg':
204 | for i,e in enumerate(edges):
205 | #if score_dict.has_key((e[0],e[1])):
206 | if (e[0],e[1]) in score_dict:
207 | score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i])
208 | else:
209 | score_dict[e[0], e[1]] = score[i]
210 | elif pool == 'max':
211 | for i,e in enumerate(edges):
212 | #if score_dict.has_key((e[0],e[1])):
213 | if (e[0],e[1]) in score_dict:
214 | score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i])
215 | else:
216 | score_dict[e[0], e[1]] = score[i]
217 | else:
218 | raise ValueError('Pooling operation not supported')
219 |
220 | nodes = np.sort(np.unique(edges.flatten()))
221 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int)
222 | mapping[nodes] = np.arange(nodes.shape[0])
223 | link_idx = mapping[edges]
224 | vertex = [Data(n) for n in nodes]
225 | for l, s in zip(link_idx, score):
226 | vertex[l[0]].add_link(vertex[l[1]], s)
227 |
228 | comps, remain = connected_components_constraint(vertex, max_sz, score_dict, th)
229 | assert len(remain) == 0
230 | return comps
231 |
232 | def graph_propagation_soft(edges, score, max_sz, step=0.1, **kwargs):
233 |
234 | edges = np.sort(edges, axis=1)
235 | th = score.min()
236 |
237 | # construct graph
238 | score_dict = {} # score lookup table
239 | for i,e in enumerate(edges):
240 | score_dict[e[0], e[1]] = score[i]
241 |
242 | nodes = np.sort(np.unique(edges.flatten()))
243 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int)
244 | mapping[nodes] = np.arange(nodes.shape[0])
245 | link_idx = mapping[edges]
246 | vertex = [Data(n) for n in nodes]
247 | for l, s in zip(link_idx, score):
248 | vertex[l[0]].add_link(vertex[l[1]], s)
249 |
250 | # first iteration
251 | comps, remain = connected_components_constraint(vertex, max_sz)
252 | first_vertex_idx = np.array([mapping[n.name] for c in comps for n in c])
253 | fusion_vertex_idx = np.setdiff1d(np.arange(nodes.shape[0]), first_vertex_idx, assume_unique=True)
254 | # iteration
255 | components = comps[:]
256 | while remain:
257 | th = th + (1 - th) * step
258 | comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
259 | components.extend(comps)
260 | label_dict = {}
261 | for i,c in enumerate(components):
262 | for n in c:
263 | label_dict[n.name] = i
264 | print('Propagation ...')
265 | prop_vertex = [vertex[idx] for idx in fusion_vertex_idx]
266 | label, label_fusion = diffusion(prop_vertex, label_dict, score_dict, **kwargs)
267 | return label, label_fusion
268 |
269 | def diffusion(vertex, label, score_dict, max_depth=5, weight_decay=0.6, normalize=True):
270 | class BFSNode():
271 | def __init__(self, node, depth, value):
272 | self.node = node
273 | self.depth = depth
274 | self.value = value
275 |
276 | label_fusion = {}
277 | for name in label.keys():
278 | label_fusion[name] = {label[name]: 1.0}
279 | prog = 0
280 | prog_step = len(vertex) // 20
281 | start = time.time()
282 | for root in vertex:
283 | if prog % prog_step == 0:
284 | print("progress: {} / {}, elapsed time: {}".format(prog, len(vertex), time.time() - start))
285 | prog += 1
286 | #queue = {[root, 0, 1.0]}
287 | queue = {BFSNode(root, 0, 1.0)}
288 | visited = [root.name]
289 | root_label = label[root.name]
290 | while queue:
291 | curr = queue.pop()
292 | if curr.depth >= max_depth: # pruning
293 | continue
294 | neighbors = curr.node.links
295 | tmp_value = []
296 | tmp_neighbor = []
297 | for n in neighbors:
298 | if n.name not in visited:
299 | sub_value = score_dict[tuple(sorted([curr.node.name, n.name]))] * weight_decay * curr.value
300 | tmp_value.append(sub_value)
301 | tmp_neighbor.append(n)
302 | if root_label not in label_fusion[n.name].keys():
303 | label_fusion[n.name][root_label] = sub_value
304 | else:
305 | label_fusion[n.name][root_label] += sub_value
306 | visited.append(n.name)
307 | #queue.add([n, curr.depth+1, sub_value])
308 | sortidx = np.argsort(tmp_value)[::-1]
309 | for si in sortidx:
310 | queue.add(BFSNode(tmp_neighbor[si], curr.depth+1, tmp_value[si]))
311 | if normalize:
312 | for name in label_fusion.keys():
313 | summ = sum(label_fusion[name].values())
314 | for k in label_fusion[name].keys():
315 | label_fusion[name][k] /= summ
316 | return label, label_fusion
317 |
--------------------------------------------------------------------------------
/GCN/util/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import division
5 |
6 | import numpy as np
7 | from sklearn.metrics.cluster import (contingency_matrix,
8 | normalized_mutual_info_score)
9 | from sklearn.metrics import (precision_score, recall_score)
10 |
11 | __all__ = ['pairwise', 'bcubed', 'nmi', 'precision', 'recall', 'accuracy']
12 |
13 |
14 | def _check(gt_labels, pred_labels):
15 | if gt_labels.ndim != 1:
16 | raise ValueError("gt_labels must be 1D: shape is %r" %
17 | (gt_labels.shape, ))
18 | if pred_labels.ndim != 1:
19 | raise ValueError("pred_labels must be 1D: shape is %r" %
20 | (pred_labels.shape, ))
21 | if gt_labels.shape != pred_labels.shape:
22 | raise ValueError(
23 | "gt_labels and pred_labels must have same size, got %d and %d" %
24 | (gt_labels.shape[0], pred_labels.shape[0]))
25 | return gt_labels, pred_labels
26 |
27 |
28 | def _get_lb2idxs(labels):
29 | lb2idxs = {}
30 | for idx, lb in enumerate(labels):
31 | if lb not in lb2idxs:
32 | lb2idxs[lb] = []
33 | lb2idxs[lb].append(idx)
34 | return lb2idxs
35 |
36 |
37 | def _compute_fscore(pre, rec):
38 | return 2. * pre * rec / (pre + rec)
39 |
40 |
41 | def fowlkes_mallows_score(gt_labels, pred_labels, sparse=True):
42 | ''' The original function is from `sklearn.metrics.fowlkes_mallows_score`.
43 | We output the pairwise precision, pairwise recall and F-measure,
44 | instead of calculating the geometric mean of precision and recall.
45 | '''
46 | n_samples, = gt_labels.shape
47 |
48 | c = contingency_matrix(gt_labels, pred_labels, sparse=sparse)
49 | tk = np.dot(c.data, c.data) - n_samples
50 | pk = np.sum(np.asarray(c.sum(axis=0)).ravel()**2) - n_samples
51 | qk = np.sum(np.asarray(c.sum(axis=1)).ravel()**2) - n_samples
52 |
53 | avg_pre = tk / pk
54 | avg_rec = tk / qk
55 | fscore = _compute_fscore(avg_pre, avg_rec)
56 |
57 | return avg_pre, avg_rec, fscore
58 |
59 |
60 | def pairwise(gt_labels, pred_labels, sparse=True):
61 | _check(gt_labels, pred_labels)
62 | return fowlkes_mallows_score(gt_labels, pred_labels, sparse)
63 |
64 |
65 | def bcubed(gt_labels, pred_labels):
66 | _check(gt_labels, pred_labels)
67 |
68 | gt_lb2idxs = _get_lb2idxs(gt_labels)
69 | pred_lb2idxs = _get_lb2idxs(pred_labels)
70 |
71 | num_lbs = len(gt_lb2idxs)
72 | pre = np.zeros(num_lbs)
73 | rec = np.zeros(num_lbs)
74 | gt_num = np.zeros(num_lbs)
75 |
76 | for i, gt_idxs in enumerate(gt_lb2idxs.values()):
77 | all_pred_lbs = np.unique(pred_labels[gt_idxs])
78 | gt_num[i] = len(gt_idxs)
79 | for pred_lb in all_pred_lbs:
80 | pred_idxs = pred_lb2idxs[pred_lb]
81 | n = 1. * np.intersect1d(gt_idxs, pred_idxs).size
82 | pre[i] += n**2 / len(pred_idxs)
83 | rec[i] += n**2 / gt_num[i]
84 |
85 | gt_num = gt_num.sum()
86 | avg_pre = pre.sum() / gt_num
87 | avg_rec = rec.sum() / gt_num
88 | fscore = _compute_fscore(avg_pre, avg_rec)
89 |
90 | return avg_pre, avg_rec, fscore
91 |
92 |
93 | def nmi(gt_labels, pred_labels):
94 | return normalized_mutual_info_score(pred_labels, gt_labels)
95 |
96 |
97 | def precision(gt_labels, pred_labels):
98 | return precision_score(gt_labels, pred_labels)
99 |
100 |
101 | def recall(gt_labels, pred_labels):
102 | return recall_score(gt_labels, pred_labels)
103 |
104 |
105 | def accuracy(gt_labels, pred_labels):
106 | return np.mean(gt_labels == pred_labels)
107 |
--------------------------------------------------------------------------------
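
A toy usage sketch of the clustering metrics above (labels are illustrative; the import assumes the `GCN` directory is on `PYTHONPATH`):

```
import numpy as np
from util.metrics import pairwise, bcubed, nmi

gt = np.array([0, 0, 0, 1, 1])
pred = np.array([0, 0, 1, 1, 1])
print(pairwise(gt, pred))  # (avg_pre, avg_rec, fscore)
print(bcubed(gt, pred))    # (avg_pre, avg_rec, fscore)
print(nmi(gt, pred))
```
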
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Thomas-wyh
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ada-NETS
2 |
3 | This is an official implementation for "Ada-NETS: Face Clustering via Adaptive Neighbour Discovery in the Structure Space" accepted at ICLR 2022.
4 |
5 | ## News
6 | - 🔥 An improved face clustering method ([**B-Attention**](https://github.com/Thomas-wyh/B-Attention/)) is accepted by NeurIPS 2022!
7 | - 🔥 Ada-NETS is accepted by ICLR 2022!
8 |
9 |
10 | ## Introduction
11 |
12 | This paper presents the novel Ada-NETS algorithm, which addresses the noisy-edge problem that arises when building the graph in GCN-based face clustering. In Ada-NETS, features are first transformed into the structure space to improve the accuracy of the similarity metric. An adaptive neighbour discovery method then finds neighbours for each sample under the guidance of a heuristic quality criterion. Based on the discovered neighbour relations, a graph with clean yet rich edges is built as the input of the GCN, achieving state-of-the-art results on face, clothes, and person clustering tasks.
13 |
14 |
15 |
16 |
17 |
18 | ## Main Results
19 |
20 |
21 |
22 |
23 |
24 |
25 | ## Getting Started
26 |
27 | ### Install
28 |
29 | + Clone this repo
30 |
31 | ```
32 | git clone https://github.com/Thomas-wyh/Ada-NETS
33 | cd Ada-NETS
34 | ```
35 |
36 | + Create a conda virtual environment and activate it
37 |
38 | ```
39 | conda create -n adanets python=3.6 -y
40 | conda activate adanets
41 | ```
42 |
43 | + Install `Pytorch` , `cudatoolkit` and other requirements.
44 | ```
45 | conda install pytorch==1.2 torchvision==0.4.0a0 cudatoolkit=10.2 -c pytorch
46 | pip install -r requirements.txt
47 | ```
48 |
49 | - Install `Apex`:
50 |
51 | ```
52 | git clone https://github.com/NVIDIA/apex
53 | cd apex
54 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
55 | ```
56 |
57 | ### Data preparation
58 |
59 | The process of clustering on the MS-Celeb part1 is as follows:
60 |
61 | The original data files are from [here](https://github.com/yl-1993/learn-to-cluster/blob/master/DATASET.md#supported-datasets) (the feature and label files of MSMT17 used in Ada-NETS are [here](http://idstcv.oss-cn-zhangjiakou.aliyuncs.com/Ada-NETS/MSMT17/msmt17_feature_label.zip)). For convenience, we convert them to `.npy` format after L2 normalization; a minimal conversion sketch is shown after the directory tree below. The original feature dimension is 256. The file structure should look like:
62 |
63 | ```
64 | data
65 | ├── feature
66 | │ ├── part0_train.npy
67 | │ └── part1_test.npy
68 | └── label
69 | ├── part0_train.npy
70 | └── part1_test.npy
71 | ```
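
A minimal conversion sketch, assuming the raw features are stored as a flat `float32` binary file (the input and output paths below are placeholders for your own files):

```
import numpy as np

# placeholder paths; adapt to your own raw feature dump
feat = np.fromfile("raw_features.bin", dtype=np.float32).reshape(-1, 256)
feat = feat / np.linalg.norm(feat, axis=1, keepdims=True)  # L2 normalization per row
np.save("data/feature/part0_train.npy", feat)
```
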
72 |
73 | Build the $k$NN by faiss:
74 |
75 | ```
76 | sh script/faiss_search.sh
77 | ```
78 |
79 | Obtain the top$K$ neighbours and distances of each vertex in the structure space:
80 |
81 | ```
82 | sh script/struct_space.sh
83 | ```
84 |
85 | Obtain the best neighbours by the candidate neighbours quality criterion:
86 |
87 | ```
88 | sh script/max_Q_ind.sh
89 | ```
90 |
91 | ### Train the Adaptive Filter
92 |
93 | Train the adaptive filter based on the data prepared above:
94 |
95 | ```
96 | sh script/train_AND.sh
97 | ```
98 |
99 | ### Train the GCN and cluster faces
100 |
101 | Generate the clean yet rich graph:
102 |
103 | ```
104 | sh script/gene_adj.sh
105 | ```
106 |
107 | Train the GCN to obtain enhanced vertex features:
108 |
109 | ```
110 | sh script/train_GCN.sh
111 | ```
112 |
113 | Cluster the faces:
114 |
115 | ```
116 | sh script/cluster.sh
117 | ```
118 |
119 | It will print the clustering evaluation results. The BCubed F-score is about 91.4 and the pairwise F-score is about 92.7.
120 |
121 |
122 |
123 | ## Acknowledgement
124 |
125 | This code is based on the publicly available face clustering [codebase](https://github.com/yl-1993/learn-to-cluster), [codebase](https://github.com/makarandtapaswi/BallClustering_ICCV2019) and the [dmlc/dgl](https://github.com/dmlc/dgl).
126 |
127 | The k-nearest neighbor search tool uses [faiss](https://github.com/facebookresearch/faiss).
128 |
129 |
130 |
131 |
132 | ## Citing Ada-NETS
133 |
134 | ```
135 | @inproceedings{wang2022adanets,
136 | title={Ada-NETS: Face Clustering via Adaptive Neighbour Discovery in the Structure Space},
137 | author={Yaohua Wang and Yaobin Zhang and Fangyi Zhang and Senzhang Wang and Ming Lin and YuQi Zhang and Xiuyu Sun},
138 | booktitle={International conference on learning representations (ICLR)},
139 | year={2022}
140 | }
141 |
142 | @misc{wang2022adanets,
143 | title={Ada-NETS: Face Clustering via Adaptive Neighbour Discovery in the Structure Space},
144 | author={Yaohua Wang and Yaobin Zhang and Fangyi Zhang and Senzhang Wang and Ming Lin and YuQi Zhang and Xiuyu Sun},
145 | year={2022},
146 | eprint={2202.03800},
147 | archivePrefix={arXiv},
148 | primaryClass={cs.CV}
149 | }
150 | ```
151 |
--------------------------------------------------------------------------------
/image/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/Ada-NETS/42e445fbb3903059136b5dcac992ba85df4a1cf5/image/fig.png
--------------------------------------------------------------------------------
/image/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/Ada-NETS/42e445fbb3903059136b5dcac992ba85df4a1cf5/image/results.png
--------------------------------------------------------------------------------
/image/results2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/Ada-NETS/42e445fbb3903059136b5dcac992ba85df4a1cf5/image/results2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | easydict==1.9
2 | dgl-cu102==0.5.0
3 | apex==0.1
4 | faiss==1.6.3
5 | numpy==1.18.2
6 | yapf==0.30.0
7 | PyYAML==5.3.1
8 | tqdm
9 | scipy==1.2.1
10 | tensorboardX==2.0
11 |
--------------------------------------------------------------------------------
/script/cluster.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euxo pipefail
3 | set +x
4 |
5 | cd GCN
6 |
7 | # Note: outpath and test_labelfile are assumed to match the values used in script/train_GCN.sh
8 | outpath=outpath
9 | test_labelfile=../data/label/part1_test.npy
10 | featfile=$outpath/fcfeat_ckpt_35000.npy
11 | tag=fc
12 | python -W ignore ../tool/faiss_search.py $featfile $featfile $outpath $tag
13 |
14 | Ifile=$outpath/fcI.npy
15 | Dfile=$outpath/fcD.npy
16 | python cluster.py $Ifile $Dfile $test_labelfile
17 |
--------------------------------------------------------------------------------
/script/faiss_search.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euxo pipefail
3 | set +x
4 |
5 | mkdir -p data/knn/train data/knn/test
6 | featfile=data/feature/part0_train.npy
7 | outpath=data/knn/train
8 | python -W ignore tool/faiss_search.py $featfile $featfile $outpath
9 |
10 | featfile=data/feature/part1_test.npy
11 | outpath=data/knn/test
12 | python -W ignore tool/faiss_search.py $featfile $featfile $outpath
13 |
--------------------------------------------------------------------------------
/script/gene_adj.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euxo pipefail
3 | set +x
4 |
5 | mkdir -p data/adj/train data/adj/test
6 |
7 | knnfile=data/knn/train/data.npz
8 | topk=256
9 | outfile=data/adj/train/adj
10 | #python tool/gene_adj.py $knnfile $topk $outfile
11 |
12 | knnfile=data/ss/test/data.npz
13 | kfile=AND/outpath/k_infer_pred.npy
14 | outfile=data/adj/test/adj_adanets
15 | python tool/gene_adj_adanets.py $knnfile $kfile $outfile
16 |
--------------------------------------------------------------------------------
/script/max_Q_ind.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euxo pipefail
3 | set +x
4 |
5 | mkdir -p data/max_Q/train data/max_Q/test
6 |
7 | beta=0.50
8 | Ifile=data/ss/train/I.npy
9 | labelfile=data/label/part0_train.npy
10 | outfile=data/max_Q/train/ind
11 | python tool/max_Q_ind.py $Ifile $labelfile $beta $outfile
12 |
13 | Ifile=data/ss/test/I.npy
14 | labelfile=data/label/part1_test.npy
15 | outfile=data/max_Q/test/ind
16 | python tool/max_Q_ind.py $Ifile $labelfile $beta $outfile
17 |
--------------------------------------------------------------------------------
/script/structure_space.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euxo pipefail
3 | set +x
4 |
5 | mkdir -p data/ss/train data/ss/test
6 | python tool/struct_space.py data/knn/test/I.npy data/knn/test/D.npy 80 data/ss/test/I data/ss/test/D data/ss/test/data
7 | python tool/struct_space.py data/knn/train/I.npy data/knn/train/D.npy 80 data/ss/train/I data/ss/train/D data/ss/train/data
8 |
--------------------------------------------------------------------------------
/script/train_AND.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euxo pipefail
3 | set +x
4 |
5 | cd AND
6 | outpath=outpath
7 | mkdir -p $outpath
8 |
9 | train_featfile=../data/feature/part0_train.npy
10 | train_Ifile=../data/ss/train/I.npy
11 | train_labelfile=../data/max_Q/train/ind.npy
12 |
13 | test_featfile=../data/feature/part1_test.npy
14 | test_Ifile=../data/ss/test/I.npy
15 | test_labelfile=../data/max_Q/test/ind.npy
16 |
17 | phase=train
18 | param=" --config_file config.yml --outpath $outpath --phase $phase
19 | --train_featfile $train_featfile --train_Ifile $train_Ifile --train_labelfile $train_labelfile
20 | --test_featfile $test_featfile --test_Ifile $test_Ifile --test_labelfile $test_labelfile"
21 | #python -u train.py $param
22 |
23 | phase=test
24 | ckpt=ckpt_40000.pth
25 | param=" --config_file config.yml --outpath $outpath --phase $phase
26 | --train_featfile $train_featfile --train_Ifile $train_Ifile --train_labelfile $train_labelfile
27 | --test_featfile $test_featfile --test_Ifile $test_Ifile --test_labelfile $test_labelfile
28 | --resume_path ${outpath}/${ckpt}"
29 | python -u train.py ${param}
30 |
--------------------------------------------------------------------------------
/script/train_GCN.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euxo pipefail
3 | set +x
4 |
5 | cd GCN
6 |
7 | losstype=allmax
8 | margin=1.0
9 | pweight=1.0
10 | pmargin=0.9
11 | beta=0.50
12 |
13 | outpath=outpath
14 | train_featfile=../data/feature/part0_train.npy
15 | train_orderadjfile=../data/adj/train/adj.npz
16 | train_adjfile=../data/adj/train/adj.npz
17 | train_labelfile=../data/label/part0_train.npy
18 | test_featfile=../data/feature/part1_test.npy
19 | test_adjfile=../data/adj/test/adj_adanets.npz
20 | test_labelfile=../data/label/part1_test.npy
21 |
22 | phase=train
23 | param="--config_file config.yml --outpath $outpath --phase $phase
24 | --train_featfile $train_featfile --train_adjfile $train_adjfile --train_labelfile $train_labelfile --train_orderadjfile $train_orderadjfile
25 | --test_featfile $test_featfile --test_adjfile $test_adjfile --test_labelfile $test_labelfile
26 | --losstype $losstype --margin $margin --pweight $pweight --pmargin ${pmargin}"
27 | python -u train.py $param
28 |
29 | phase=test
30 | param="--config_file config.yml --outpath $outpath --phase $phase
31 | --train_featfile $train_featfile --train_adjfile $train_adjfile --train_labelfile $train_labelfile --train_orderadjfile $train_orderadjfile
32 | --test_featfile $test_featfile --test_adjfile $test_adjfile --test_labelfile $test_labelfile
33 | --losstype $losstype --margin $margin --pweight $pweight --pmargin ${pmargin} --resume_path $outpath/ckpt_35000.pth"
34 | python -u train.py $param
35 |
--------------------------------------------------------------------------------
/tool/adjacency.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | import scipy.sparse as sp
6 |
7 |
8 | def row_normalize(mx):
9 | """Row-normalize sparse matrix"""
10 | rowsum = np.array(mx.sum(1))
11 | # if rowsum <= 0, keep its previous value
12 | rowsum[rowsum <= 0] = 1
13 | r_inv = np.power(rowsum, -1).flatten()
14 | r_inv[np.isinf(r_inv)] = 0.
15 | r_mat_inv = sp.diags(r_inv)
16 | mx = r_mat_inv.dot(mx)
17 | return mx
18 |
19 |
20 | def build_symmetric_adj(adj, self_loop=True):
21 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
22 | if self_loop:
23 | adj = adj + sp.eye(adj.shape[0])
24 | return adj
25 |
26 |
27 | def sparse_mx_to_indices_values(sparse_mx):
28 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
29 | indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
30 | values = sparse_mx.data
31 | shape = np.array(sparse_mx.shape)
32 | return indices, values, shape
33 |
34 |
35 | def indices_values_to_sparse_tensor(indices, values, shape):
36 | import torch
37 | indices = torch.from_numpy(indices)
38 | values = torch.from_numpy(values)
39 | shape = torch.Size(shape)
40 | return torch.sparse.FloatTensor(indices, values, shape)
41 |
42 |
43 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
44 | """Convert a scipy sparse matrix to a torch sparse tensor."""
45 | indices, values, shape = sparse_mx_to_indices_values(sparse_mx)
46 | return indices_values_to_sparse_tensor(indices, values, shape)
47 |
--------------------------------------------------------------------------------
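
A minimal usage sketch of the adjacency helpers above on a tiny 3-node graph (illustrative values; the import assumes `tool/` is the working directory):

```
import numpy as np
import scipy.sparse as sp
from adjacency import build_symmetric_adj, row_normalize

adj = sp.csr_matrix(np.array([[0.0, 0.9, 0.0],
                              [0.0, 0.0, 0.8],
                              [0.0, 0.0, 0.0]], dtype=np.float32))
adj = build_symmetric_adj(adj, self_loop=True)  # symmetrize and add self-loops
adj = row_normalize(adj)                        # each row now sums to 1
print(adj.todense())
```
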
/tool/faiss_search.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import numpy as np
3 | import faiss
4 | from tqdm import tqdm
5 | import sys
6 | import time
7 | import os
8 |
9 | def batch_search(index, query, topk, bs, verbose=False):
10 | n = len(query)
11 | dists = np.zeros((n, topk), dtype=np.float32)
12 | nbrs = np.zeros((n, topk), dtype=np.int32)
13 |
14 | for sid in tqdm(range(0, n, bs), desc="faiss searching...", disable=not verbose):
15 | eid = min(n, sid + bs)
16 | dists[sid:eid], nbrs[sid:eid] = index.search(query[sid:eid], topk)
17 | cos_dist = dists / 2 # for unit-norm vectors, squared L2 distance = 2 - 2*cos_sim, so /2 gives the cosine distance
18 | return cos_dist, nbrs
19 |
20 |
21 | def search(query_arr, doc_arr, outpath, tag, save_file=True):
22 | ### parameter
23 | nlist = 100 # number of IVF clusters (roughly 1000 clusters per 1M vectors)
24 | nprobe = 100 # number of inverted lists probed at search time
25 | topk = 1024
26 | bs = 100
27 | ### end parameter
28 |
29 |
30 | #print("configure faiss")
31 | beg_time = time.time()
32 | num_gpu = faiss.get_num_gpus()
33 | dim = query_arr.shape[1]
34 | #cpu_index = faiss.index_factory(dim, 'IVF100', faiss.METRIC_INNER_PRODUCT)
35 | quantizer = faiss.IndexFlatL2(dim)
36 | cpu_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
37 | cpu_index.nprobe = nprobe
38 |
39 | co = faiss.GpuMultipleClonerOptions()
40 | co.useFloat16 = True
41 | co.usePrecomputed = False
42 | co.shard = True
43 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co, ngpu=num_gpu)
44 |
45 | # start IVF
46 | #print("build index")
47 | gpu_index.train(doc_arr)
48 | gpu_index.add(doc_arr)
49 | #print(gpu_index.ntotal)
50 |
51 | # start query
52 | #print("start query")
53 | gpu_index.nprobe = nprobe # default nprobe is 1, try a few more
54 | print("beg search")
55 | D, I = batch_search(gpu_index, query_arr, topk, bs, verbose=True)
56 | print("time use %.4f"%(time.time()-beg_time))
57 |
58 | if save_file:
59 | np.save(os.path.join(outpath, tag+'D'), D)
60 | np.save(os.path.join(outpath, tag+'I'), I)
61 | data = np.concatenate((I[:,None,:], D[:,None,:]), axis=1)
62 | np.savez(os.path.join(outpath,'data'), data=data)
63 | print("time use", time.time()-beg_time)
64 |
65 | if __name__ == "__main__":
66 | queryfile, docfile, outpath = sys.argv[1], sys.argv[2], sys.argv[3]
67 | if len(sys.argv) == 5:
68 | tag = sys.argv[4]
69 | else:
70 | tag = ""
71 |
72 | query_arr = np.load(queryfile)
73 | doc_arr = np.load(docfile)
74 | query_arr = query_arr / np.linalg.norm(query_arr, axis=1, keepdims=True)
75 | doc_arr = doc_arr / np.linalg.norm(doc_arr, axis=1, keepdims=True)
76 |
77 | search(query_arr, doc_arr, outpath, tag)
78 |
--------------------------------------------------------------------------------
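
`batch_search` above halves the squared L2 distances returned by faiss; for unit-norm vectors this equals the cosine distance, which a short check confirms (toy vectors):

```
import numpy as np

a = np.array([1.0, 0.0])
b = np.array([0.6, 0.8])            # both vectors are unit norm
sq_l2 = ((a - b) ** 2).sum()        # equals 2 - 2 * cos_sim
print(sq_l2 / 2, 1 - a.dot(b))      # both are ~0.4
```
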
/tool/gene_adj.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import numpy as np
3 | from knn import fast_knns2spmat, knns2ordered_nbrs
4 | from adjacency import build_symmetric_adj, row_normalize
5 | from scipy.sparse import coo_matrix, save_npz
6 | import sys
7 |
8 | th_sim = 0.0
9 | if __name__ == "__main__":
10 | knnfile, topk, outfile = sys.argv[1], int(sys.argv[2]), sys.argv[3]
11 | knn_arr = np.load(knnfile)['data'][:, :, :topk]
12 |
13 | adj = fast_knns2spmat(knn_arr, topk, th_sim, use_sim=True)
14 |
15 | # build symmetric adjacency matrix
16 | adj = build_symmetric_adj(adj, self_loop=True)
17 | adj = row_normalize(adj)
18 |
19 | adj_coo = adj.tocoo()
20 | print("edge num", adj_coo.row.shape)
21 | print("mat shape", adj_coo.shape)
22 |
23 | save_npz(outfile, adj)
24 |
--------------------------------------------------------------------------------
/tool/gene_adj_adanets.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import numpy as np
3 | from knn import fast_knns2spmat, knns2ordered_nbrs, fast_knns2spmat_adaptivek
4 | from adjacency import build_symmetric_adj, row_normalize
5 | from scipy.sparse import coo_matrix, save_npz
6 | import sys
7 |
8 | th_sim = 0.0
9 | if __name__ == "__main__":
10 | knnfile, kfile, outfile = sys.argv[1], sys.argv[2], sys.argv[3]
11 | knn_arr = np.load(knnfile)['data']
12 | k_arr = np.load(kfile)
13 |
14 | adj = fast_knns2spmat_adaptivek(knn_arr, k_arr, th_sim)
15 |
16 | # build symmetric adjacency matrix
17 | adj = build_symmetric_adj(adj, self_loop=True)
18 | adj = row_normalize(adj)
19 |
20 | adj_coo = adj.tocoo()
21 | print("edge num", adj_coo.row.shape)
22 | print("mat shape", adj_coo.shape)
23 |
24 | save_npz(outfile, adj)
25 |
--------------------------------------------------------------------------------
/tool/knn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import math
6 | import numpy as np
7 | import multiprocessing as mp
8 | from tqdm import tqdm
9 |
10 | __all__ = [
11 | 'knn_brute_force', 'knn_hnsw', 'knn_faiss', 'knn_faiss_gpu', 'knns2spmat',
12 | 'fast_knns2spmat', 'knns2sub_spmat', 'build_knns', 'filter_knns',
13 | 'knns2ordered_nbrs', 'fast_knns2spmat_adaptivek'
14 | ]
15 |
16 |
17 | def knns_recall(nbrs, idx2lb, lb2idxs):
18 | with Timer('compute recall'):
19 | recs = []
20 | cnt = 0
21 | for idx, (n, _) in enumerate(nbrs):
22 | lb = idx2lb[idx]
23 | idxs = lb2idxs[lb]
24 | n = list(n)
25 | if len(n) == 1:
26 | cnt += 1
27 | s = set(idxs) & set(n)
28 | recs += [1. * len(s) / len(idxs)]
29 | print('there are {} / {} = {:.3f} isolated anchors.'.format(
30 | cnt, len(nbrs), 1. * cnt / len(nbrs)))
31 | recall = np.mean(recs)
32 | return recall
33 |
34 |
35 | def filter_knns(knns, k, th):
36 | pairs = []
37 | scores = []
38 | n = len(knns)
39 | nbrs = np.zeros([n, k], dtype=np.int32) - 1
40 | simi = np.zeros([n, k]) - 1
41 | for i, (nbr, dist) in enumerate(knns):
42 | assert len(nbr) == len(dist)
43 | nbrs[i, :len(nbr)] = nbr
44 | simi[i, :len(nbr)] = 1. - dist
45 | anchor = np.tile(np.arange(n).reshape(n, 1), (1, k))
46 |
47 | # filter
48 | selidx = np.where((simi >= th) & (nbrs != -1) & (nbrs != anchor))
49 | pairs = np.hstack((anchor[selidx].reshape(-1,
50 | 1), nbrs[selidx].reshape(-1, 1)))
51 | scores = simi[selidx]
52 |
53 | if len(pairs) > 0:
54 | # keep uniq pairs
55 | pairs = np.sort(pairs, axis=1)
56 | pairs, unique_idx = np.unique(pairs, return_index=True, axis=0)
57 | scores = scores[unique_idx]
58 | return pairs, scores
59 |
60 |
61 | def knns2ordered_nbrs(knns, sort=True):
62 | if isinstance(knns, list):
63 | knns = np.array(knns)
64 | nbrs = knns[:, 0, :].astype(np.int32)
65 | dists = knns[:, 1, :]
66 | if sort:
67 | # sort dists from low to high
68 | nb_idx = np.argsort(dists, axis=1)
69 | idxs = np.arange(nb_idx.shape[0]).reshape(-1, 1)
70 | dists = dists[idxs, nb_idx]
71 | nbrs = nbrs[idxs, nb_idx]
72 | return dists, nbrs
73 |
74 |
75 | def knns2spmat(knns, k, th_sim=0.7, use_sim=False):
76 | # convert knns to symmetric sparse matrix
77 | from scipy.sparse import csr_matrix
78 | eps = 1e-5
79 | n = len(knns)
80 | row, col, data = [], [], []
81 | for row_i, knn in enumerate(knns):
82 | nbrs, dists = knn
83 | for nbr, dist in zip(nbrs, dists):
84 | assert -eps <= dist <= 1 + eps, "{}: {}".format(row_i, dist)
85 | w = dist
86 | if 1 - w < th_sim or nbr == -1:
87 | continue
88 | if row_i == nbr:
89 | assert abs(dist) < eps
90 | continue
91 | row.append(row_i)
92 | col.append(nbr)
93 | if use_sim:
94 | w = 1 - w
95 | data.append(w)
96 | assert len(row) == len(col) == len(data)
97 | spmat = csr_matrix((data, (row, col)), shape=(n, n))
98 | return spmat
99 |
100 | def fast_knns2spmat_adaptivek(knns, k_arr, th_sim=0.7):  # variant of fast_knns2spmat where row i keeps only its first k_arr[i] neighbours
101 | # convert knns to symmetric sparse matrix
102 | from scipy.sparse import csr_matrix
103 | eps = 1e-5
104 | n = len(knns)
105 |
106 | nbrs = knns[:, 0, :]
107 | dists = knns[:, 1, :]
108 | assert -eps <= dists.min() <= dists.max() <= 1 + eps, "min: {}, max: {}".format(dists.min(), dists.max())
109 | sims = 1. - dists
110 |
111 |     row, col = np.where(sims >= th_sim)  # keep only edges whose similarity passes the threshold
112 | new_row, new_col = [], []
113 | for row_idx, col_idx in zip(row, col):
114 | thresK = k_arr[row_idx]
115 | if col_idx >= thresK:
116 | continue
117 | new_row.append(row_idx)
118 | new_col.append(col_idx)
119 | row, col = np.array(new_row), np.array(new_col)
120 |
121 | # remove the self-loop
122 | idxs = np.where(row != nbrs[row, col])
123 | row = row[idxs]
124 | col = col[idxs]
125 | data = sims[row, col]
126 | col = nbrs[row, col] # convert to absolute column of the FULL N*N adj matrix
127 | assert len(row) == len(col) == len(data)
128 | spmat = csr_matrix((data, (row, col)), shape=(n, n))
129 | return spmat
130 |
131 | def fast_knns2spmat(knns, k, th_sim=0.7, use_sim=False, fill_value=None):
132 | # convert knns to symmetric sparse matrix
133 | from scipy.sparse import csr_matrix
134 | eps = 1e-5
135 | n = len(knns)
136 | if isinstance(knns, list):
137 | knns = np.array(knns)
138 | if len(knns.shape) == 2:
139 | # knns saved by hnsw has different shape
140 | n = len(knns)
141 | ndarr = np.ones([n, 2, k])
142 | ndarr[:, 0, :] = -1 # assign unknown dist to 1 and nbr to -1
143 | for i, (nbr, dist) in enumerate(knns):
144 | size = len(nbr)
145 | assert size == len(dist)
146 | ndarr[i, 0, :size] = nbr[:size]
147 | ndarr[i, 1, :size] = dist[:size]
148 | knns = ndarr
149 | nbrs = knns[:, 0, :]
150 | dists = knns[:, 1, :]
151 | assert -eps <= dists.min() <= dists.max(
152 | ) <= 1 + eps, "min: {}, max: {}".format(dists.min(), dists.max())
153 | if use_sim:
154 | sims = 1. - dists
155 | else:
156 | sims = dists
157 | if fill_value is not None:
158 | print('[fast_knns2spmat] edge fill value:', fill_value)
159 | sims.fill(fill_value)
160 |     row, col = np.where(sims >= th_sim)  # keep only edges whose similarity passes the threshold
161 | # remove the self-loop
162 | idxs = np.where(row != nbrs[row, col])
163 | row = row[idxs]
164 | col = col[idxs]
165 | data = sims[row, col]
166 | col = nbrs[row, col] # convert to absolute column of the FULL N*N adj matrix
167 | assert len(row) == len(col) == len(data)
168 | spmat = csr_matrix((data, (row, col)), shape=(n, n))
169 | return spmat
170 |
171 |
172 | def knns2sub_spmat(idxs, knns, th_sim=0.7, use_sim=False):
173 | # convert knns to symmetric sparse sub-matrix
174 | from scipy.sparse import csr_matrix
175 | n = len(idxs)
176 | row, col, data = [], [], []
177 | abs2rel = {}
178 | for rel_i, abs_i in enumerate(idxs):
179 | assert abs_i not in abs2rel
180 | abs2rel[abs_i] = rel_i
181 |
182 | for row_i, idx in enumerate(idxs):
183 | nbrs, dists = knns[idx]
184 | for nbr, dist in zip(nbrs, dists):
185 | if idx == nbr:
186 | assert abs(dist) < 1e-6, "{}: {}".format(idx, dist)
187 | continue
188 | if nbr not in abs2rel:
189 | continue
190 | col_i = abs2rel[nbr]
191 | assert -1e-6 <= dist <= 1
192 | w = dist
193 | if 1 - w < th_sim or nbr == -1:
194 | continue
195 | row.append(row_i)
196 | col.append(col_i)
197 | if use_sim:
198 | w = 1 - w
199 | data.append(w)
200 | assert len(row) == len(col) == len(data)
201 | spmat = csr_matrix((data, (row, col)), shape=(n, n))
202 | return spmat
203 |
204 |
205 | def build_knns(knn_prefix,
206 | feats,
207 | knn_method,
208 | k,
209 | num_process=None,
210 | is_rebuild=False,
211 | feat_create_time=None):
212 | knn_prefix = os.path.join(knn_prefix, '{}_k_{}'.format(knn_method, k))
213 | mkdir_if_no_exists(knn_prefix)
214 | knn_path = knn_prefix + '.npz'
215 | if os.path.isfile(
216 | knn_path) and not is_rebuild and feat_create_time is not None:
217 | knn_create_time = os.path.getmtime(knn_path)
218 | if knn_create_time <= feat_create_time:
219 | print('[warn] knn is created before feats ({} vs {})'.format(
220 | format_time(knn_create_time), format_time(feat_create_time)))
221 | is_rebuild = True
222 | if not os.path.isfile(knn_path) or is_rebuild:
223 | index_path = knn_prefix + '.index'
224 | with Timer('build index'):
225 | if knn_method == 'hnsw':
226 | index = knn_hnsw(feats, k, index_path)
227 | elif knn_method == 'faiss':
228 | index = knn_faiss(feats,
229 | k,
230 | index_path,
231 | omp_num_threads=num_process,
232 | rebuild_index=True)
233 | elif knn_method == 'faiss_gpu':
234 | index = knn_faiss_gpu(feats,
235 | k,
236 | index_path,
237 | num_process=num_process)
238 | else:
239 | raise KeyError(
240 | 'Only support hnsw and faiss currently ({}).'.format(
241 | knn_method))
242 | knns = index.get_knns()
243 | with Timer('dump knns to {}'.format(knn_path)):
244 | dump_data(knn_path, knns, force=True)
245 | else:
246 | print('read knn from {}'.format(knn_path))
247 | knns = load_data(knn_path)
248 | return knns
249 |
250 |
251 | class knn():  # base class shared by the knn_* wrappers defined below
252 |     def __init__(self, feats, k, index_path='', verbose=True):
253 |         pass
254 |
255 |     def filter_by_th(self, i):
256 |         th_nbrs = []
257 |         th_dists = []
258 |         nbrs, dists = self.knns[i]
259 |         for n, dist in zip(nbrs, dists):
260 |             if 1 - dist < self.th:
261 |                 continue
262 |             th_nbrs.append(n)
263 |             th_dists.append(dist)
264 |         th_nbrs = np.array(th_nbrs)
265 |         th_dists = np.array(th_dists)
266 |         return (th_nbrs, th_dists)
267 |
268 |     def get_knns(self, th=None):
269 |         if th is None or th <= 0.:
270 |             return self.knns
271 |         # TODO: optimize the filtering process by numpy
272 |         # nproc = mp.cpu_count()
273 |         nproc = 1
274 |         with Timer('filter edges by th {} (CPU={})'.format(th, nproc),
275 |                    self.verbose):
276 |             self.th = th
277 |             self.th_knns = []
278 |             tot = len(self.knns)
279 |             if nproc > 1:
280 |                 pool = mp.Pool(nproc)
281 |                 th_knns = list(
282 |                     tqdm(pool.imap(self.filter_by_th, range(tot)), total=tot))
283 |                 pool.close()
284 |             else:
285 |                 th_knns = [self.filter_by_th(i) for i in range(tot)]
286 |         return th_knns
287 |
288 |
289 | class knn_brute_force(knn):
290 | def __init__(self, feats, k, index_path='', verbose=True):
291 | self.verbose = verbose
292 | with Timer('[brute force] build index', verbose):
293 | feats = feats.astype('float32')
294 | sim = feats.dot(feats.T)
295 | with Timer('[brute force] query topk {}'.format(k), verbose):
296 | nbrs = np.argpartition(-sim, kth=k)[:, :k]
297 | idxs = np.array([i for i in range(nbrs.shape[0])])
298 | dists = 1 - sim[idxs.reshape(-1, 1), nbrs]
299 | self.knns = [(np.array(nbr, dtype=np.int32),
300 | np.array(dist, dtype=np.float32))
301 | for nbr, dist in zip(nbrs, dists)]
302 |
303 |
304 | class knn_hnsw(knn):
305 | def __init__(self, feats, k, index_path='', verbose=True, **kwargs):
306 | import nmslib
307 | self.verbose = verbose
308 | with Timer('[hnsw] build index', verbose):
309 | ''' higher ef leads to better accuracy, but slower search
310 | higher M leads to higher accuracy/run_time at fixed ef,
311 | but consumes more memory
312 | '''
313 | # space_params = {
314 | # 'ef': 100,
315 | # 'M': 16,
316 | # }
317 | # index = nmslib.init(method='hnsw',
318 | # space='cosinesimil',
319 | # space_params=space_params)
320 | index = nmslib.init(method='hnsw', space='cosinesimil')
321 | if index_path != '' and os.path.isfile(index_path):
322 | index.loadIndex(index_path)
323 | else:
324 | index.addDataPointBatch(feats)
325 | index.createIndex({
326 | 'post': 2,
327 | 'indexThreadQty': 1
328 | },
329 | print_progress=verbose)
330 | if index_path:
331 | print('[hnsw] save index to {}'.format(index_path))
332 | mkdir_if_no_exists(index_path)
333 | index.saveIndex(index_path)
334 | with Timer('[hnsw] query topk {}'.format(k), verbose):
335 | knn_ofn = index_path + '.npz'
336 | if os.path.exists(knn_ofn):
337 | print('[hnsw] read knns from {}'.format(knn_ofn))
338 | self.knns = np.load(knn_ofn)['data']
339 | else:
340 | self.knns = index.knnQueryBatch(feats, k=k)
341 |
342 |
343 | class knn_faiss(knn):
344 | def __init__(self,
345 | feats,
346 | k,
347 | index_path='',
348 | index_key='',
349 | nprobe=128,
350 | omp_num_threads=None,
351 | rebuild_index=True,
352 | verbose=True,
353 | **kwargs):
354 | import faiss
355 | if omp_num_threads is not None:
356 | faiss.omp_set_num_threads(omp_num_threads)
357 | self.verbose = verbose
358 | with Timer('[faiss] build index', verbose):
359 | if index_path != '' and not rebuild_index and os.path.exists(
360 | index_path):
361 | print('[faiss] read index from {}'.format(index_path))
362 | index = faiss.read_index(index_path)
363 | else:
364 | feats = feats.astype('float32')
365 | size, dim = feats.shape
366 | index = faiss.IndexFlatIP(dim)
367 | if index_key != '':
368 | assert index_key.find(
369 |                         'HNSW') < 0, 'HNSW returns distances instead of sims'
370 | metric = faiss.METRIC_INNER_PRODUCT
371 | nlist = min(4096, 8 * round(math.sqrt(size)))
372 | if index_key == 'IVF':
373 | quantizer = index
374 | index = faiss.IndexIVFFlat(quantizer, dim, nlist,
375 | metric)
376 | else:
377 | index = faiss.index_factory(dim, index_key, metric)
378 | if index_key.find('Flat') < 0:
379 | assert not index.is_trained
380 | index.train(feats)
381 | index.nprobe = min(nprobe, nlist)
382 | assert index.is_trained
383 | print('nlist: {}, nprobe: {}'.format(nlist, nprobe))
384 | index.add(feats)
385 | if index_path != '':
386 | print('[faiss] save index to {}'.format(index_path))
387 | mkdir_if_no_exists(index_path)
388 | faiss.write_index(index, index_path)
389 | with Timer('[faiss] query topk {}'.format(k), verbose):
390 | knn_ofn = index_path + '.npz'
391 | if os.path.exists(knn_ofn):
392 | print('[faiss] read knns from {}'.format(knn_ofn))
393 | self.knns = np.load(knn_ofn)['data']
394 | else:
395 | sims, nbrs = index.search(feats, k=k)
396 | self.knns = [(np.array(nbr, dtype=np.int32),
397 | 1 - np.array(sim, dtype=np.float32))
398 | for nbr, sim in zip(nbrs, sims)]
399 |
400 |
401 | class knn_faiss_gpu(knn):
402 | def __init__(self,
403 | feats,
404 | k,
405 | index_path='',
406 | index_key='',
407 | nprobe=128,
408 | num_process=4,
409 | is_precise=True,
410 | sort=True,
411 | verbose=True,
412 | **kwargs):
413 | with Timer('[faiss_gpu] query topk {}'.format(k), verbose):
414 | knn_ofn = index_path + '.npz'
415 | if os.path.exists(knn_ofn):
416 | print('[faiss_gpu] read knns from {}'.format(knn_ofn))
417 | self.knns = np.load(knn_ofn)['data']
418 | else:
419 | dists, nbrs = faiss_search_knn(feats,
420 | k=k,
421 | nprobe=nprobe,
422 | num_process=num_process,
423 | is_precise=is_precise,
424 | sort=sort,
425 | verbose=False)
426 |
427 | self.knns = [(np.array(nbr, dtype=np.int32),
428 | np.array(dist, dtype=np.float32))
429 | for nbr, dist in zip(nbrs, dists)]
430 |
431 |
432 | if __name__ == '__main__':
433 | from utils import l2norm
434 |
435 | k = 30
436 | d = 256
437 | nfeat = 10000
438 | np.random.seed(42)
439 |
440 | feats = np.random.random((nfeat, d)).astype('float32')
441 | feats = l2norm(feats)
442 |
443 | index1 = knn_hnsw(feats, k)
444 | index2 = knn_faiss(feats, k)
445 | index3 = knn_faiss(feats, k, index_key='Flat')
446 | index4 = knn_faiss(feats, k, index_key='IVF')
447 | index5 = knn_faiss(feats, k, index_key='IVF100,PQ32')
448 |
449 | print(index1.knns[0])
450 | print(index2.knns[0])
451 | print(index3.knns[0])
452 | print(index4.knns[0])
453 | print(index5.knns[0])
454 |
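455 | # Hedged example (assumes the self-test above has run): the knns produced by any wrapper can be
456 | # converted into a sparse similarity matrix the same way tool/gene_adj.py does:
457 | #   spmat = fast_knns2spmat(index2.knns, k, th_sim=0.0, use_sim=True)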
--------------------------------------------------------------------------------
/tool/max_Q_ind.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import sys
3 | import numpy as np
4 | from multiprocessing import Pool
5 | import pandas as pd
6 |
7 | total_k = 80
8 | def get_topK(query_nodeid):  # relies on module-level I, label_arr and beta set in __main__ (inherited by Pool workers via fork)
9 | query_label = label_arr[query_nodeid]
10 |
11 | total_num = len(np.where(label_arr == query_label)[0])
12 | prec_list, recall_list, fscore_list = [], [], []
13 | for topK in range(1, total_k + 1):
14 | result_list = []
15 | for i in range(0, topK):
16 | doc_nodeid = I[query_nodeid][i]
17 | doc_label = label_arr[doc_nodeid]
18 | result = 1 if doc_label == query_label else 0
19 | if i == 0:
20 | result = 1
21 | result_list.append(result)
22 | prec = np.mean(result_list)
23 | recall = np.sum(result_list) / total_num
24 | fscore = (1 + beta*beta) * prec * recall / (beta*beta*prec + recall)
25 | prec_list.append(prec)
26 | recall_list.append(recall)
27 | fscore_list.append(fscore)
28 | fscore_arr = np.array(fscore_list)
29 | idx = fscore_arr.argmax()
30 | thres_topK = idx + 1
31 | return thres_topK, prec_list[idx], recall_list[idx], fscore_list[idx]
32 |
33 | if __name__ == "__main__":
34 | Ifile, labelfile, beta, outfile = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]
35 |
36 | I = np.load(Ifile)
37 | label_arr = np.load(labelfile)
38 | beta = float(beta)
39 |
40 |     # flip to True to run in a single process (easier to debug)
41 |     debug = False
42 | if debug:
43 | res = []
44 | for query_nodeid in range(len(I)):
45 | item = get_topK(query_nodeid)
46 | res.append(item)
47 | else:
48 | pool = Pool(48)
49 | res = pool.map(get_topK, range(len(I)))
50 | pool.close()
51 | pool.join()
52 |
53 | topK_list, prec_list, recall_list, fscore_list = list(zip(*res))
54 | np.save(outfile, topK_list)
55 |
56 |
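57 | # Hedged usage sketch (file names are hypothetical): I.npy is the neighbour-index matrix from the
58 | # faiss search, and the output stores, per query, the top-K that maximises the F_beta score above.
59 | #   python max_Q_ind.py I.npy labels.npy 1.0 adaptive_k.npy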
--------------------------------------------------------------------------------
/tool/struct_space.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import sys
3 | import numpy as np
4 | from multiprocessing import Pool, Manager
5 | import copy
6 | from functools import partial
7 |
8 | k = 80
9 | lamb = 0.3
10 |
11 | def worker(k, queue, query_nodeid):  # re-rank the top-k of query_nodeid: blend Jaccard distance of R* sets with the original distance (weight lamb)
12 | docnodeid_list = I[query_nodeid, :k]
13 | query_Rstarset = Rstarset_list[query_nodeid]
14 | outlist = []
15 | for idx, doc_nodeid in enumerate(docnodeid_list):
16 | doc_Rstarset = Rstarset_list[doc_nodeid]
17 | sim = 1.0 * len(query_Rstarset & doc_Rstarset) / len(query_Rstarset | doc_Rstarset)
18 | jd = 1 - sim
19 | cd = D[query_nodeid, idx]
20 | nd = (1-lamb) * jd + lamb * cd
21 | tpl = (doc_nodeid, nd)
22 | outlist.append(tpl)
23 | outlist = sorted(outlist, key=lambda x:x[1])
24 | queue.put(query_nodeid)
25 | fn_name = sys._getframe().f_code.co_name
26 | #if queue.qsize() % 1000 == 0:
27 | # print("==>", fn_name, queue.qsize())
28 | return list(zip(*outlist))
29 |
30 | def get_Kngbr(query_nodeid, k):
31 | Kngbr = I[query_nodeid, :k]
32 | return set(Kngbr)
33 |
34 | def get_Rset(k, queue, query_nodeid):  # reciprocal k-NN set: neighbours of query_nodeid that also rank it among their own top-k
35 | docnodeid_set = get_Kngbr(query_nodeid, k)
36 | Rset = set()
37 | for doc_nodeid in docnodeid_set:
38 | if query_nodeid not in get_Kngbr(doc_nodeid, k):
39 | continue
40 | Rset.add(doc_nodeid)
41 | queue.put(query_nodeid)
42 | fn_name = sys._getframe().f_code.co_name
43 | #if queue.qsize() % 1000 == 0:
44 | # print("==>", fn_name, queue.qsize())
45 | return Rset
46 |
47 | def get_Rstarset(queue, query_nodeid):  # expand Rset with half-k reciprocal sets that overlap it by at least 2/3
48 | Rset = Rset_list[query_nodeid]
49 | Rstarset = copy.deepcopy(Rset)
50 | for doc_nodeid in Rset:
51 | doc_Rset = half_Rset_list[doc_nodeid]
52 | if len(doc_Rset & Rset) >= len(doc_Rset) * 2 / 3:
53 | Rstarset |= doc_Rset
54 | queue.put(query_nodeid)
55 | fn_name = sys._getframe().f_code.co_name
56 | #if queue.qsize() % 1000 == 0:
57 | # print("==>", fn_name, queue.qsize())
58 | return Rstarset
59 |
60 | if __name__ == "__main__":
61 | Ifile, Dfile, topk, outIfile, outDfile, outDatafile = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6]
62 | k = int(topk)
63 | print("use topk", k)
64 | I = np.load(Ifile)
65 | D = np.load(Dfile)
66 |
67 | queue1 = Manager().Queue()
68 | queue2 = Manager().Queue()
69 | queue3 = Manager().Queue()
70 | queue4 = Manager().Queue()
71 |     # flip to True to run in a single process (easier to debug)
72 |     debug = False
73 | if debug:
74 | for query_nodeid in range(len(I)):
75 |             res = worker(k, queue4, query_nodeid)
76 | else:
77 | pool = Pool(52)
78 | get_Rset_partial = partial(get_Rset, k, queue1)
79 | Rset_list = pool.map(get_Rset_partial, range(len(I)))
80 | pool.close()
81 | pool.join()
82 |
83 | pool = Pool(52)
84 | k2 = k // 2
85 | get_Rset_partial = partial(get_Rset, k2, queue2)
86 | half_Rset_list = pool.map(get_Rset_partial, range(len(I)))
87 | pool.close()
88 | pool.join()
89 |
90 | pool = Pool(52)
91 | get_Rstarset_partial = partial(get_Rstarset, queue3)
92 | Rstarset_list = pool.map(get_Rstarset_partial, range(len(I)))
93 | pool.close()
94 | pool.join()
95 |
96 | pool = Pool(52)
97 | worker_partial = partial(worker, k, queue4)
98 | res = pool.map(worker_partial, range(len(I)))
99 | pool.close()
100 | pool.join()
101 |
102 | newI, newD = list(zip(*res))
103 | newI = np.array(newI)
104 | newD = np.array(newD)
105 | newdata = np.concatenate((newI[:,None,:], newD[:,None,:]), axis=1)
106 | np.save(outIfile, newI)
107 | np.save(outDfile, newD)
108 | np.savez(outDatafile, data=newdata)
109 |
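110 | # Hedged usage sketch (file names are hypothetical): re-rank the faiss neighbour lists I.npy / D.npy
111 | # with the reciprocal-neighbour ("structure space") distance and save the re-ordered results.
112 | #   python struct_space.py I.npy D.npy 80 newI.npy newD.npy new_knns.npz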
--------------------------------------------------------------------------------