├── .gitignore ├── .ipynb_checkpoints └── graph_visualization-checkpoint.ipynb ├── LICENSE ├── README.md ├── asset └── pipeline.jpg ├── feeder ├── __init__.py ├── __init__.pyc ├── feeder.py ├── feeder.pyc ├── feeder_visualization.py ├── feeder_visualization.pyc └── test_feeder.py ├── logs ├── epoch_0.ckpt ├── log.txt └── logs │ └── best.ckpt ├── model ├── __init__.py ├── __init__.pyc ├── gcn.py └── gcn.pyc ├── test.py ├── train.py ├── utils ├── __init__.py ├── __init__.pyc ├── graph.py ├── graph.pyc ├── logging.py ├── logging.pyc ├── meters.py ├── meters.pyc ├── osutils.py ├── osutils.pyc ├── serialization.py ├── serialization.pyc ├── utils.py └── utils.pyc └── visulization.ipynb
/.gitignore: --------------------------------------------------------------------------------
1 | logs
2 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/graph_visualization-checkpoint.ipynb: --------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 2
6 | }
7 |
--------------------------------------------------------------------------------
/LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 ZhongdaoWang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md: --------------------------------------------------------------------------------
1 | # Linkage-based Face Clustering via Graph Convolution Network
2 | This repository contains the code for our CVPR'19 paper [Linkage-based Face Clustering via GCN](https://arxiv.org/abs/1903.11306), by Zhongdao Wang, Liang Zheng, Yali Li and Shengjin Wang, Tsinghua University and Australian National University.
3 |
4 | ![](https://github.com/Zhongdao/gcn_clustering/blob/master/asset/pipeline.jpg)
5 |
6 | ## Introduction
7 | We present an accurate and scalable approach to the face clustering task. We aim to group a set of faces by their potential identities. We formulate this task as a link prediction problem: a link exists between two faces if they are of the same identity. The key idea is that
8 | the local context in the feature space around an instance (face) contains rich information about the linkage relationship between this instance and its neighbors.
By constructing
9 | sub-graphs around each instance as input data,
10 | which depict the local context, we utilize the graph convolution
11 | network (GCN) to perform reasoning and infer the
12 | likelihood of linkage between pairs in the sub-graphs.
13 |
14 | ## Requirements
15 | - PyTorch 0.4.0
16 | - Python 2.7
17 | - sklearn >= 0.19.1
18 |
19 | ## Data Format
20 | First, extract features for the IJB-B data and save them as an NxD dimensional `.npy` file, in which each row is a D-dimensional feature for one sample. Then, save the labels as an Nx1 dimensional `.npy` file, in which each row is an integer identity label. Lastly, generate the KNN graph (either by brute force or ANN) and save it as an Nx(K+1) dimensional `.npy` file; in each row, the first element is the node index and the following K elements are the indices of its K nearest neighbors. A minimal preparation example is given in the appendix at the end of this README.
21 |
22 | For training, features + labels + knn_graphs are needed. For testing, only features + knn_graphs are needed, but the labels are also required if you want to compute accuracy.
23 | We also provide the ArcFace features / labels / knn_graphs of the IJB-B/CASIA datasets at [OneDrive](https://1drv.ms/u/s!Ai0390AjdQNVhUbCRARo8PVc1m3j) and [Baidu NetDisk](https://pan.baidu.com/s/1wmMct86Izubw7d2hgBga7A) (extraction code: 8wj1).
24 |
25 | ## Testing
26 | ```
27 | python test.py --val_feat_path path/to/features --val_knn_graph_path path/to/knn/graph --val_label_path path/to/labels --checkpoint path/to/gcn_weights
28 | ```
29 | During inference, the test script dynamically outputs the pairwise precision/recall/accuracy. After all sub-graphs are processed, it outputs the final B-Cubed precision/recall/F-score (note that these are not the same as the pairwise p/r/acc) and the NMI score. A small example of how these metrics are computed is given in the appendix at the end of this README.
30 |
31 | ## Training
32 | ```
33 | python train.py --feat_path path/to/features --knn_graph_path path/to/knn/graph --label_path path/to/labels
34 | ```
35 | We employ the CASIA dataset to train the GCN. Usually, 4 epochs are sufficient. We provide pre-trained model weights in `logs/logs/best.ckpt`.
36 |
37 | ## Citation
38 | If you find GCN-Clustering helpful in your research, please cite our paper:
39 | ```
40 | @inproceedings{wang2019gncclust,
41 | title={Linkage-based Face Clustering via Graph Convolution Network},
42 | author={Zhongdao Wang and Liang Zheng and Yali Li and Shengjin Wang},
43 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
44 | year={2019}
45 | }
46 | ```
47 | ## Acknowledgement
48 | I borrowed some code on pseudo label propagation from [CDP](https://github.com/XiaohangZhan/cdp), many thanks to [Xiaohang Zhan](https://github.com/XiaohangZhan)!
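
## Appendix: Preparing the Input Files
The snippet below is a minimal sketch, not part of the released code, of how the three `.npy` files described in the Data Format section can be produced with sklearn's brute-force KNN search. The feature extractor, the file names (`features.npy`, `labels.npy`, `knn_graph.npy`) and the choice of K are placeholders; substitute your own.
```
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Assumed inputs: an NxD feature matrix from your face feature extractor and
# an Nx1 integer identity array (file names here are placeholders).
features = np.load('features.npy')
labels = np.load('labels.npy')  # only needed for training / accuracy computation

K = 200  # neighbors kept per node; train.py defaults to k_at_hop[0] = 200
nbrs = NearestNeighbors(n_neighbors=K + 1, algorithm='brute').fit(features)
_, indices = nbrs.kneighbors(features)  # shape Nx(K+1)

# Required format: each row is [node_index, neighbor_1, ..., neighbor_K].
knn_graph = indices.astype(np.int64)
knn_graph[:, 0] = np.arange(features.shape[0])  # force column 0 to be the node index
np.save('knn_graph.npy', knn_graph)
```
If the features are L2-normalized (as ArcFace features usually are), the Euclidean ranking used here coincides with the cosine ranking; for large N an ANN library can replace the brute-force search.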
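
## Appendix: Computing the Evaluation Metrics
The following is a minimal sketch of how the B-Cubed and NMI numbers reported by `test.py` are obtained from a predicted cluster assignment and the ground-truth identities, using `bcubed` from `utils/utils.py`. The toy label arrays are invented purely for illustration; run it from the repository root so that `utils.utils` is importable.
```
import numpy as np
from sklearn.metrics import normalized_mutual_info_score
from utils.utils import bcubed

# Toy example: 6 samples, ground-truth identities vs. predicted cluster ids.
labels = np.array([0, 0, 0, 1, 1, 2])      # reference identities
final_pred = np.array([0, 0, 1, 1, 1, 2])  # predicted cluster labels

p, r, f = bcubed(labels, final_pred)                    # B-Cubed precision / recall / F-score
nmi = normalized_mutual_info_score(final_pred, labels)  # NMI is symmetric in its arguments
print(('{:.4f} ' * 4).format(p, r, f, nmi))
```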
49 | -------------------------------------------------------------------------------- /asset/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/asset/pipeline.jpg -------------------------------------------------------------------------------- /feeder/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################################### 2 | # File Name: __init__.py 3 | # Author: Zhongdao Wang 4 | # mail: wcd17@mails.tsinghua.edu.cn 5 | # Created Time: Fri 07 Sep 2018 12:57:24 PM CST 6 | ################################################################### 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | -------------------------------------------------------------------------------- /feeder/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/feeder/__init__.pyc -------------------------------------------------------------------------------- /feeder/feeder.py: -------------------------------------------------------------------------------- 1 | ################################################################### 2 | # File Name: feeder.py 3 | # Author: Zhongdao Wang 4 | # mail: wcd17@mails.tsinghua.edu.cn 5 | # Created Time: Thu 06 Sep 2018 01:06:16 PM CST 6 | ################################################################### 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | 12 | import numpy as np 13 | import random 14 | import torch 15 | import torch.utils.data as data 16 | class Feeder(data.Dataset): 17 | ''' 18 | Generate a sub-graph from the feature graph centered at some node, 19 | and now the sub-graph has a fixed depth, i.e. 
2 20 | ''' 21 | def __init__(self, feat_path, knn_graph_path, label_path, seed=1, 22 | k_at_hop=[200,5], active_connection=5, train=True): 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | self.features = np.load(feat_path) 26 | self.knn_graph = np.load(knn_graph_path)[:,:k_at_hop[0]+1] 27 | self.labels = np.load(label_path) 28 | self.num_samples = len(self.features) 29 | self.depth = len(k_at_hop) 30 | self.k_at_hop = k_at_hop 31 | self.active_connection = active_connection 32 | self.train = train 33 | assert np.mean(k_at_hop)>=active_connection 34 | 35 | def __len__(self): 36 | return self.num_samples 37 | 38 | def __getitem__(self, index): 39 | ''' 40 | return the vertex feature and the adjacent matrix A, together 41 | with the indices of the center node and its 1-hop nodes 42 | ''' 43 | # hops[0] for 1-hop neighbors, hops[1] for 2-hop neighbors 44 | hops = list() 45 | center_node = index 46 | hops.append(set(self.knn_graph[center_node][1:])) 47 | 48 | # Actually we dont need the loop since the depth is fixed here, 49 | # But we still remain the code for further revision 50 | for d in range(1,self.depth): 51 | hops.append(set()) 52 | for h in hops[-2]: 53 | hops[-1].update(set(self.knn_graph[h][1:self.k_at_hop[d]+1])) 54 | 55 | 56 | hops_set = set([h for hop in hops for h in hop]) 57 | hops_set.update([center_node,]) 58 | unique_nodes_list = list(hops_set) 59 | unique_nodes_map = {j:i for i,j in enumerate(unique_nodes_list)} 60 | 61 | center_idx = torch.Tensor([unique_nodes_map[center_node],]).type(torch.long) 62 | one_hop_idcs = torch.Tensor([unique_nodes_map[i] for i in hops[0]]).type(torch.long) 63 | center_feat = torch.Tensor(self.features[center_node]).type(torch.float) 64 | feat = torch.Tensor(self.features[unique_nodes_list]).type(torch.float) 65 | feat = feat - center_feat 66 | 67 | max_num_nodes = self.k_at_hop[0] * (self.k_at_hop[1] + 1) + 1 68 | num_nodes = len(unique_nodes_list) 69 | A = torch.zeros(num_nodes, num_nodes) 70 | 71 | _, fdim = feat.shape 72 | feat = torch.cat([feat, torch.zeros(max_num_nodes - num_nodes, fdim)], dim=0) 73 | 74 | for node in unique_nodes_list: 75 | neighbors = self.knn_graph[node, 1:self.active_connection+1] 76 | for n in neighbors: 77 | if n in unique_nodes_list: 78 | A[unique_nodes_map[node], unique_nodes_map[n]] = 1 79 | A[unique_nodes_map[n], unique_nodes_map[node]] = 1 80 | 81 | D = A.sum(1, keepdim=True) 82 | A = A.div(D) 83 | A_ = torch.zeros(max_num_nodes,max_num_nodes) 84 | A_[:num_nodes,:num_nodes] = A 85 | 86 | 87 | labels = self.labels[np.asarray(unique_nodes_list)] 88 | labels = torch.from_numpy(labels).type(torch.long) 89 | #edge_labels = labels.expand(num_nodes,num_nodes).eq( 90 | # labels.expand(num_nodes,num_nodes).t()) 91 | one_hop_labels = labels[one_hop_idcs] 92 | center_label = labels[center_idx] 93 | edge_labels = (center_label == one_hop_labels).long() 94 | 95 | if self.train: 96 | return (feat, A_, center_idx, one_hop_idcs), edge_labels 97 | 98 | # Testing 99 | unique_nodes_list = torch.Tensor(unique_nodes_list) 100 | unique_nodes_list = torch.cat( 101 | [unique_nodes_list, torch.zeros(max_num_nodes-num_nodes)], dim=0) 102 | return(feat, A_, center_idx, one_hop_idcs, unique_nodes_list), edge_labels 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /feeder/feeder.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/feeder/feeder.pyc 
-------------------------------------------------------------------------------- /feeder/feeder_visualization.py: -------------------------------------------------------------------------------- 1 | ################################################################### 2 | # File Name: feeder.py 3 | # Author: Zhongdao Wang 4 | # mail: wcd17@mails.tsinghua.edu.cn 5 | # Created Time: Thu 06 Sep 2018 01:06:16 PM CST 6 | ################################################################### 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | 12 | import numpy as np 13 | import random 14 | import torch 15 | import torch.utils.data as data 16 | class Feeder(data.Dataset): 17 | ''' 18 | Generate a sub-graph from the feature graph centered at some node, 19 | and now the sub-graph has a fixed depth, i.e. 2 20 | ''' 21 | def __init__(self, feat_path, knn_graph_path, label_path, seed=1, 22 | k_at_hop=[200,5], active_connection=5, train=True): 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | self.features = np.load(feat_path) 26 | self.knn_graph = np.load(knn_graph_path)[:,:k_at_hop[0]+1] 27 | self.labels = np.load(label_path) 28 | self.num_samples = len(self.features) 29 | self.depth = len(k_at_hop) 30 | self.k_at_hop = k_at_hop 31 | self.active_connection = active_connection 32 | self.train = train 33 | assert np.mean(k_at_hop)>=active_connection 34 | 35 | def __len__(self): 36 | return self.num_samples 37 | 38 | def __getitem__(self, index): 39 | ''' 40 | return the vertex feature and the adjacent matrix A, together 41 | with the indices of the center node and its 1-hop nodes 42 | ''' 43 | # hops[0] for 1-hop neighbors, hops[1] for 2-hop neighbors 44 | hops = list() 45 | center_node = index 46 | hops.append(set(self.knn_graph[center_node][1:])) 47 | 48 | # Actually we dont need the loop since the depth is fixed here, 49 | # But we still remain the code for further revision 50 | for d in range(1,self.depth): 51 | hops.append(set()) 52 | for h in hops[-2]: 53 | hops[-1].update(set(self.knn_graph[h][1:self.k_at_hop[d]+1])) 54 | 55 | 56 | hops_set = set([h for hop in hops for h in hop]) 57 | hops_set.update([center_node,]) 58 | unique_nodes_list = list(hops_set) 59 | unique_nodes_map = {j:i for i,j in enumerate(unique_nodes_list)} 60 | 61 | center_idx = torch.Tensor([unique_nodes_map[center_node],]).type(torch.long) 62 | one_hop_idcs = torch.Tensor([unique_nodes_map[i] for i in hops[0]]).type(torch.long) 63 | center_feat = torch.Tensor(self.features[center_node]).type(torch.float) 64 | feat = torch.Tensor(self.features[unique_nodes_list]).type(torch.float) 65 | feat = feat - center_feat 66 | 67 | max_num_nodes = self.k_at_hop[0] * (self.k_at_hop[1] + 1) + 1 68 | num_nodes = len(unique_nodes_list) 69 | A = torch.zeros(num_nodes, num_nodes) 70 | 71 | _, fdim = feat.shape 72 | feat = torch.cat([feat, torch.zeros(max_num_nodes - num_nodes, fdim)], dim=0) 73 | 74 | for node in unique_nodes_list: 75 | neighbors = self.knn_graph[node, 1:self.active_connection+1] 76 | for n in neighbors: 77 | if n in unique_nodes_list: 78 | A[unique_nodes_map[node], unique_nodes_map[n]] = 1 79 | A[unique_nodes_map[n], unique_nodes_map[node]] = 1 80 | 81 | D = A.sum(1, keepdim=True) 82 | A = A.div(D) 83 | A_ = torch.zeros(max_num_nodes,max_num_nodes) 84 | A_[:num_nodes,:num_nodes] = A 85 | 86 | 87 | labels = self.labels[np.asarray(unique_nodes_list)] 88 | labels = torch.from_numpy(labels).type(torch.long) 89 | #edge_labels = 
labels.expand(num_nodes,num_nodes).eq( 90 | # labels.expand(num_nodes,num_nodes).t()) 91 | one_hop_labels = labels[one_hop_idcs] 92 | center_label = labels[center_idx] 93 | edge_labels = (center_label == one_hop_labels).long() 94 | 95 | if self.train: 96 | return (feat, A_, center_idx, one_hop_idcs), edge_labels, labels 97 | 98 | # Testing 99 | unique_nodes_list = torch.Tensor(unique_nodes_list) 100 | unique_nodes_list = torch.cat( 101 | [unique_nodes_list, torch.zeros(max_num_nodes-num_nodes)], dim=0) 102 | return(feat, A_, center_idx, one_hop_idcs, unique_nodes_list), edge_labels, labels 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /feeder/feeder_visualization.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/feeder/feeder_visualization.pyc -------------------------------------------------------------------------------- /feeder/test_feeder.py: -------------------------------------------------------------------------------- 1 | ################################################################### 2 | # File Name: test_feeder.py 3 | # Author: Zhongdao Wang 4 | # mail: wcd17@mails.tsinghua.edu.cn 5 | # Created Time: Thu 06 Sep 2018 04:09:46 PM CST 6 | ################################################################### 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | 12 | from feeder import Feeder 13 | import torch 14 | torch.set_printoptions(threshold=10000000, linewidth=500) 15 | import time 16 | if __name__ == '__main__': 17 | feeder = Feeder('../../facedata/1024.fea.npy', 18 | '../../facedata/knn.graph.1024.kdtree.npy', 19 | '../../facedata/1024.labels.npy', 20 | seed=2111112, 21 | k_at_hop=[5,5], 22 | active_connection=3) 23 | (feat, A, cn, oh), edge_labels = feeder[0] 24 | 25 | print(oh) 26 | print(edge_labels) 27 | length = feat.norm(2,dim=1) 28 | print(length) 29 | #print(torch.sum(A,dim=1)) 30 | #print(torch.sum(edge_labels,dim=0)) 31 | #print(A) 32 | #print(edge_labels) 33 | #print(feat.shape) 34 | #feat =feat.div(feat.norm(2,dim=1,keepdim=True).expand_as(feat)) 35 | #print(torch.mm(feat,feat.t())) 36 | -------------------------------------------------------------------------------- /logs/epoch_0.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/logs/epoch_0.ckpt -------------------------------------------------------------------------------- /logs/log.txt: -------------------------------------------------------------------------------- 1 | Current lr 0.01 2 | -------------------------------------------------------------------------------- /logs/logs/best.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/logs/logs/best.ckpt -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################################### 2 | # File Name: __init__.py 3 | # Author: Zhongdao Wang 4 | # mail: wcd17@mails.tsinghua.edu.cn 5 | # Created Time: Fri 07 Sep 2018 12:57:08 PM CST 6 | 
################################################################### 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | 12 | from .gcn import * 13 | -------------------------------------------------------------------------------- /model/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/model/__init__.pyc -------------------------------------------------------------------------------- /model/gcn.py: -------------------------------------------------------------------------------- 1 | ################################################################### 2 | # File Name: gcn.py 3 | # Author: Zhongdao Wang 4 | # mail: wcd17@mails.tsinghua.edu.cn 5 | # Created Time: Fri 07 Sep 2018 01:16:31 PM CST 6 | ################################################################### 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.nn import init 16 | 17 | 18 | class MeanAggregator(nn.Module): 19 | def __init__(self): 20 | super(MeanAggregator, self).__init__() 21 | def forward(self, features, A ): 22 | x = torch.bmm(A, features) 23 | return x 24 | 25 | class GraphConv(nn.Module): 26 | def __init__(self, in_dim, out_dim, agg): 27 | super(GraphConv, self).__init__() 28 | self.in_dim = in_dim 29 | self.out_dim = out_dim 30 | self.weight = nn.Parameter( 31 | torch.FloatTensor(in_dim *2, out_dim)) 32 | self.bias = nn.Parameter(torch.FloatTensor(out_dim)) 33 | init.xavier_uniform_(self.weight) 34 | init.constant_(self.bias, 0) 35 | self.agg = agg() 36 | 37 | def forward(self, features, A): 38 | b, n, d = features.shape 39 | assert(d==self.in_dim) 40 | agg_feats = self.agg(features,A) 41 | cat_feats = torch.cat([features, agg_feats], dim=2) 42 | out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight)) 43 | out = F.relu(out + self.bias) 44 | return out 45 | 46 | 47 | class gcn(nn.Module): 48 | def __init__(self): 49 | super(gcn, self).__init__() 50 | self.bn0 = nn.BatchNorm1d(512, affine=False) 51 | self.conv1 = GraphConv(512, 512, MeanAggregator) 52 | self.conv2 = GraphConv(512, 512, MeanAggregator) 53 | self.conv3 = GraphConv(512, 256, MeanAggregator) 54 | self.conv4 = GraphConv(256, 256,MeanAggregator) 55 | 56 | self.classifier = nn.Sequential( 57 | nn.Linear(256, 256), 58 | nn.PReLU(256), 59 | nn.Linear(256, 2)) 60 | 61 | def forward(self, x, A, one_hop_idcs, train=True): 62 | # data normalization l2 -> bn 63 | B,N,D = x.shape 64 | #xnorm = x.norm(2,2,keepdim=True) + 1e-8 65 | #xnorm = xnorm.expand_as(x) 66 | #x = x.div(xnorm) 67 | 68 | x = x.view(-1, D) 69 | x = self.bn0(x) 70 | x = x.view(B,N,D) 71 | 72 | 73 | x = self.conv1(x,A) 74 | x = self.conv2(x,A) 75 | x = self.conv3(x,A) 76 | x = self.conv4(x,A) 77 | k1 = one_hop_idcs.size(-1) 78 | dout = x.size(-1) 79 | edge_feat = torch.zeros(B,k1,dout).cuda() 80 | for b in range(B): 81 | edge_feat[b,:,:] = x[b, one_hop_idcs[b]] 82 | edge_feat = edge_feat.view(-1,dout) 83 | pred = self.classifier(edge_feat) 84 | 85 | # shape: (B*k1)x2 86 | return pred 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /model/gcn.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/model/gcn.pyc --------------------------------------------------------------------------------
/test.py: --------------------------------------------------------------------------------
1 | ###################################################################
2 | # File Name: test.py
3 | # Author: Zhongdao Wang
4 | # mail: wcd17@mails.tsinghua.edu.cn
5 | # Created Time: Thu 06 Sep 2018 10:08:49 PM CST
6 | ###################################################################
7 |
8 | from __future__ import print_function
9 | from __future__ import division
10 | from __future__ import absolute_import
11 |
12 | import os
13 | import os.path as osp
14 | import sys
15 | import time
16 | import argparse
17 | import numpy as np
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | from torch.backends import cudnn
23 | from torch.utils.data import DataLoader
24 |
25 | import model
26 | from feeder.feeder import Feeder
27 | from utils import to_numpy
28 | from utils.meters import AverageMeter
29 | from utils.serialization import load_checkpoint
30 | from utils.utils import bcubed
31 | from utils.graph import graph_propagation, graph_propagation_soft, graph_propagation_naive
32 |
33 | from sklearn.metrics import normalized_mutual_info_score, precision_score, recall_score
34 |
35 | def single_remove(Y, pred):
36 |     single_idcs = np.zeros_like(pred)
37 |     pred_unique = np.unique(pred)
38 |     for u in pred_unique:
39 |         idcs = pred == u
40 |         if np.sum(idcs) == 1:
41 |             single_idcs[np.where(idcs)[0][0]] = 1
42 |     remain_idcs = [i for i in range(len(pred)) if not single_idcs[i]]
43 |     remain_idcs = np.asarray(remain_idcs)
44 |     return Y[remain_idcs], pred[remain_idcs]
45 |
46 | def main(args):
47 |     np.random.seed(args.seed)
48 |     torch.manual_seed(args.seed)
49 |     cudnn.benchmark = True
50 |
51 |     valset = Feeder(args.val_feat_path,
52 |                     args.val_knn_graph_path,
53 |                     args.val_label_path,
54 |                     args.seed,
55 |                     args.k_at_hop,
56 |                     args.active_connection,
57 |                     train=False)
58 |     valloader = DataLoader(
59 |             valset, batch_size=args.batch_size,
60 |             num_workers=args.workers, shuffle=False, pin_memory=True)
61 |
62 |     ckpt = load_checkpoint(args.checkpoint)
63 |     net = model.gcn()
64 |     net.load_state_dict(ckpt['state_dict'])
65 |     net = net.cuda()
66 |
67 |     knn_graph = valset.knn_graph
68 |     knn_graph_dict = list()
69 |     for neighbors in knn_graph:
70 |         knn_graph_dict.append(dict())
71 |         for n in neighbors[1:]:
72 |             knn_graph_dict[-1][n] = []
73 |
74 |     criterion = nn.CrossEntropyLoss().cuda()
75 |     edges, scores = validate(valloader, net, criterion)
76 |
77 |     np.save('edges', edges)
78 |     np.save('scores', scores)
79 |     #edges=np.load('edges.npy')
80 |     #scores = np.load('scores.npy')
81 |
82 |     clusters = graph_propagation(edges, scores, max_sz=900, step=0.6, pool='avg' )
83 |     final_pred = clusters2labels(clusters, len(valset))
84 |     labels = valset.labels
85 |
86 |     print('------------------------------------')
87 |     print('Number of nodes: ', len(labels))
88 |     print('Precision   Recall   F-Score   NMI')
89 |     p,r,f = bcubed(labels, final_pred)
90 |     nmi = normalized_mutual_info_score(final_pred, labels)
91 |     print(('{:.4f}    '*4).format(p,r,f, nmi))
92 |
93 |     labels, final_pred = single_remove(labels, final_pred)
94 |     print('------------------------------------')
95 |     print('After removing singleton clusters, number of nodes: ', len(labels))
96 |     print('Precision   Recall   F-Score   NMI')
97 |     p,r,f = bcubed(labels,
final_pred) 98 | nmi = normalized_mutual_info_score(final_pred, labels) 99 | print(('{:.4f} '*4).format(p,r,f, nmi)) 100 | 101 | 102 | def clusters2labels(clusters, n_nodes): 103 | labels = (-1)* np.ones((n_nodes,)) 104 | for ci, c in enumerate(clusters): 105 | for xid in c: 106 | labels[xid.name] = ci 107 | assert np.sum(labels<0) < 1 108 | return labels 109 | 110 | def make_labels(gtmat): 111 | return gtmat.view(-1) 112 | 113 | def validate(loader, net, crit): 114 | batch_time = AverageMeter() 115 | data_time = AverageMeter() 116 | losses = AverageMeter() 117 | accs = AverageMeter() 118 | precisions = AverageMeter() 119 | recalls = AverageMeter() 120 | 121 | net.eval() 122 | end = time.time() 123 | edges = list() 124 | scores = list() 125 | for i, ((feat, adj, cid, h1id, node_list), gtmat) in enumerate(loader): 126 | data_time.update(time.time() - end) 127 | feat, adj, cid, h1id, gtmat = map(lambda x: x.cuda(), 128 | (feat, adj, cid, h1id, gtmat)) 129 | pred = net(feat, adj, h1id) 130 | labels = make_labels(gtmat).long() 131 | loss = crit(pred, labels) 132 | pred = F.softmax(pred, dim=1) 133 | p,r, acc = accuracy(pred, labels) 134 | 135 | losses.update(loss.item(),feat.size(0)) 136 | accs.update(acc.item(),feat.size(0)) 137 | precisions.update(p, feat.size(0)) 138 | recalls.update(r,feat.size(0)) 139 | 140 | batch_time.update(time.time()- end) 141 | end = time.time() 142 | if i % args.print_freq == 0: 143 | print('[{0}/{1}]\t' 144 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 145 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 146 | 'Loss {losses.val:.3f} ({losses.avg:.3f})\n' 147 | 'Accuracy {accs.val:.3f} ({accs.avg:.3f})\t' 148 | 'Precison {precisions.val:.3f} ({precisions.avg:.3f})\t' 149 | 'Recall {recalls.val:.3f} ({recalls.avg:.3f})'.format( 150 | i, len(loader), batch_time=batch_time, 151 | data_time=data_time, losses=losses, accs=accs, 152 | precisions=precisions, recalls=recalls)) 153 | 154 | node_list = node_list.long().squeeze().numpy() 155 | bs = feat.size(0) 156 | for b in range(bs): 157 | cidb = cid[b].int().item() 158 | nl = node_list[b] 159 | 160 | for j,n in enumerate(h1id[b]): 161 | n = n.item() 162 | edges.append([nl[cidb], nl[n]]) 163 | scores.append(pred[b*args.k_at_hop[0]+j,1].item()) 164 | edges = np.asarray(edges) 165 | scores = np.asarray(scores) 166 | return edges, scores 167 | 168 | def accuracy(pred, label): 169 | pred = torch.argmax(pred, dim=1).long() 170 | acc = torch.mean((pred == label).float()) 171 | pred = to_numpy(pred) 172 | label = to_numpy(label) 173 | p = precision_score(label, pred) 174 | r = recall_score(label, pred) 175 | return p,r,acc 176 | 177 | if __name__ == '__main__': 178 | parser = argparse.ArgumentParser() 179 | # misc 180 | working_dir = osp.dirname(osp.abspath(__file__)) 181 | parser.add_argument('--seed', default=1, type=int) 182 | parser.add_argument('--workers', default=16, type=int) 183 | parser.add_argument('--print_freq', default=40, type=int) 184 | 185 | # Optimization args 186 | parser.add_argument('--lr', type=float, default=1e-5) 187 | parser.add_argument('--momentum', type=float, default=0.9) 188 | parser.add_argument('--weight_decay', type=float, default=1e-4) 189 | parser.add_argument('--epochs', type=int, default=20) 190 | 191 | parser.add_argument('--batch_size', type=int, default=32) 192 | parser.add_argument('--k-at-hop', type=int, nargs='+', default=[20,5]) 193 | parser.add_argument('--active_connection', type=int, default=5) 194 | 195 | # Validation args 196 | parser.add_argument('--val_feat_path', 
type=str, metavar='PATH', 197 | default=osp.join(working_dir, '../facedata/1024.fea.npy')) 198 | parser.add_argument('--val_knn_graph_path', type=str, metavar='PATH', 199 | default=osp.join(working_dir, '../facedata/knn.graph.1024.bf.npy')) 200 | parser.add_argument('--val_label_path', type=str, metavar='PATH', 201 | default=osp.join(working_dir, '../facedata/1024.labels.npy')) 202 | 203 | # Test args 204 | parser.add_argument('--checkpoint', type=str, metavar='PATH', default='./logs/logs/best.ckpt') 205 | args = parser.parse_args() 206 | main(args) 207 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ################################################################### 2 | # File Name: train.py 3 | # Author: Zhongdao Wang 4 | # mail: wcd17@mails.tsinghua.edu.cn 5 | # Created Time: Thu 06 Sep 2018 10:08:49 PM CST 6 | ################################################################### 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | 12 | import os 13 | import os.path as osp 14 | import sys 15 | import time 16 | import argparse 17 | import numpy as np 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.backends import cudnn 22 | from torch.utils.data import DataLoader 23 | 24 | import model 25 | from feeder.feeder import Feeder 26 | from utils import to_numpy 27 | from utils.logging import Logger 28 | from utils.meters import AverageMeter 29 | from utils.serialization import save_checkpoint 30 | 31 | from sklearn.metrics import precision_score, recall_score 32 | 33 | 34 | def main(args): 35 | np.random.seed(args.seed) 36 | torch.manual_seed(args.seed) 37 | cudnn.benchmark = True 38 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 39 | 40 | trainset = Feeder(args.feat_path, 41 | args.knn_graph_path, 42 | args.label_path, 43 | args.seed, 44 | args.k_at_hop, 45 | args.active_connection) 46 | trainloader = DataLoader( 47 | trainset, batch_size=args.batch_size, 48 | num_workers=args.workers, shuffle=True, pin_memory=True) 49 | 50 | net = model.gcn().cuda() 51 | opt = torch.optim.SGD(net.parameters(), args.lr, 52 | momentum=args.momentum, 53 | weight_decay=args.weight_decay) 54 | 55 | criterion = nn.CrossEntropyLoss().cuda() 56 | 57 | save_checkpoint({ 58 | 'state_dict':net.state_dict(), 59 | 'epoch': 0,}, False, 60 | fpath=osp.join(args.logs_dir, 'epoch_{}.ckpt'.format(0))) 61 | for epoch in range(args.epochs): 62 | adjust_lr(opt, epoch) 63 | 64 | train(trainloader, net, criterion, opt, epoch) 65 | save_checkpoint({ 66 | 'state_dict':net.state_dict(), 67 | 'epoch': epoch+1,}, False, 68 | fpath=osp.join(args.logs_dir, 'epoch_{}.ckpt'.format(epoch+1))) 69 | 70 | 71 | def train(loader, net, crit, opt, epoch): 72 | batch_time = AverageMeter() 73 | data_time = AverageMeter() 74 | losses = AverageMeter() 75 | accs = AverageMeter() 76 | precisions = AverageMeter() 77 | recalls = AverageMeter() 78 | 79 | net.train() 80 | end = time.time() 81 | for i, ((feat, adj, cid, h1id), gtmat) in enumerate(loader): 82 | data_time.update(time.time() - end) 83 | feat, adj, cid, h1id, gtmat = map(lambda x: x.cuda(), 84 | (feat, adj, cid, h1id, gtmat)) 85 | pred = net(feat, adj, h1id) 86 | labels = make_labels(gtmat).long() 87 | loss = crit(pred, labels) 88 | p,r, acc = accuracy(pred, labels) 89 | 90 | opt.zero_grad() 91 | loss.backward() 92 | opt.step() 93 | 94 | 
losses.update(loss.item(),feat.size(0)) 95 | accs.update(acc.item(),feat.size(0)) 96 | precisions.update(p, feat.size(0)) 97 | recalls.update(r,feat.size(0)) 98 | 99 | batch_time.update(time.time()- end) 100 | end = time.time() 101 | if i % args.print_freq == 0: 102 | print('Epoch:[{0}][{1}/{2}]\t' 103 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 104 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 105 | 'Loss {losses.val:.3f} ({losses.avg:.3f})\t' 106 | 'Accuracy {accs.val:.3f} ({accs.avg:.3f})\t' 107 | 'Precison {precisions.val:.3f} ({precisions.avg:.3f})\t' 108 | 'Recall {recalls.val:.3f} ({recalls.avg:.3f})'.format( 109 | epoch, i, len(loader), batch_time=batch_time, 110 | data_time=data_time, losses=losses, accs=accs, 111 | precisions=precisions, recalls=recalls)) 112 | 113 | 114 | def make_labels(gtmat): 115 | return gtmat.view(-1) 116 | 117 | def adjust_lr(opt, epoch): 118 | scale = 0.1 119 | print('Current lr {}'.format(args.lr)) 120 | if epoch in [1,2,3,4]: 121 | args.lr *=0.1 122 | print('Change lr to {}'.format(args.lr)) 123 | for param_group in opt.param_groups: 124 | param_group['lr'] = param_group['lr'] * scale 125 | 126 | 127 | 128 | def accuracy(pred, label): 129 | pred = torch.argmax(pred, dim=1).long() 130 | acc = torch.mean((pred == label).float()) 131 | pred = to_numpy(pred) 132 | label = to_numpy(label) 133 | p = precision_score(label, pred) 134 | r = recall_score(label, pred) 135 | return p,r,acc 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser() 139 | # misc 140 | working_dir = osp.dirname(osp.abspath(__file__)) 141 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 142 | default=osp.join(working_dir, 'logs')) 143 | parser.add_argument('--seed', default=1, type=int) 144 | parser.add_argument('--workers', default=16, type=int) 145 | parser.add_argument('--print_freq', default=200, type=int) 146 | 147 | # Optimization args 148 | parser.add_argument('--lr', type=float, default=1e-2) 149 | parser.add_argument('--momentum', type=float, default=0.9) 150 | parser.add_argument('--weight_decay', type=float, default=1e-4) 151 | parser.add_argument('--epochs', type=int, default=4) 152 | 153 | # Training args 154 | parser.add_argument('--batch_size', type=int, default=16) 155 | parser.add_argument('--feat_path', type=str, metavar='PATH', 156 | default=osp.join(working_dir, '../facedata/CASIA.feas.npy')) 157 | parser.add_argument('--knn_graph_path', type=str, metavar='PATH', 158 | default=osp.join(working_dir, '../facedata/knn.graph.CASIA.kdtree.npy')) 159 | parser.add_argument('--label_path', type=str, metavar='PATH', 160 | default=osp.join(working_dir, '../facedata/CASIA.labels.npy')) 161 | parser.add_argument('--k-at-hop', type=int, nargs='+', default=[200,10]) 162 | parser.add_argument('--active_connection', type=int, default=10) 163 | 164 | args = parser.parse_args() 165 | main(args) 166 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not 
torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /utils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/utils/__init__.pyc -------------------------------------------------------------------------------- /utils/graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import time 7 | 8 | class Data(object): 9 | def __init__(self, name): 10 | self.__name = name 11 | self.__links = set() 12 | 13 | @property 14 | def name(self): 15 | return self.__name 16 | 17 | @property 18 | def links(self): 19 | return set(self.__links) 20 | 21 | def add_link(self, other, score): 22 | self.__links.add(other) 23 | other.__links.add(self) 24 | 25 | def connected_components(nodes, score_dict, th): 26 | ''' 27 | conventional connected components searching 28 | ''' 29 | result = [] 30 | nodes = set(nodes) 31 | while nodes: 32 | n = nodes.pop() 33 | group = {n} 34 | queue = [n] 35 | while queue: 36 | n = queue.pop(0) 37 | if th is not None: 38 | neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th} 39 | else: 40 | neighbors = n.links 41 | neighbors.difference_update(group) 42 | nodes.difference_update(neighbors) 43 | group.update(neighbors) 44 | queue.extend(neighbors) 45 | result.append(group) 46 | return result 47 | 48 | def connected_components_constraint(nodes, max_sz, score_dict=None, th=None): 49 | ''' 50 | only use edges whose scores are above `th` 51 | if a component is larger than `max_sz`, all the nodes in this component are added into `remain` and returned for next iteration. 52 | ''' 53 | result = [] 54 | remain = set() 55 | nodes = set(nodes) 56 | while nodes: 57 | n = nodes.pop() 58 | group = {n} 59 | queue = [n] 60 | valid = True 61 | while queue: 62 | n = queue.pop(0) 63 | if th is not None: 64 | neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th} 65 | else: 66 | neighbors = n.links 67 | neighbors.difference_update(group) 68 | nodes.difference_update(neighbors) 69 | group.update(neighbors) 70 | queue.extend(neighbors) 71 | if len(group) > max_sz or len(remain.intersection(neighbors)) > 0: 72 | # if this group is larger than `max_sz`, add the nodes into `remain` 73 | valid = False 74 | remain.update(group) 75 | break 76 | if valid: # if this group is smaller than or equal to `max_sz`, finalize it. 
77 | result.append(group) 78 | return result, remain 79 | 80 | 81 | def graph_propagation_naive(edges, score, th): 82 | 83 | edges = np.sort(edges, axis=1) 84 | 85 | # construct graph 86 | score_dict = {} # score lookup table 87 | for i,e in enumerate(edges): 88 | score_dict[e[0], e[1]] = score[i] 89 | 90 | nodes = np.sort(np.unique(edges.flatten())) 91 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int) 92 | mapping[nodes] = np.arange(nodes.shape[0]) 93 | link_idx = mapping[edges] 94 | vertex = [Data(n) for n in nodes] 95 | for l, s in zip(link_idx, score): 96 | vertex[l[0]].add_link(vertex[l[1]], s) 97 | 98 | # first iteration 99 | comps = connected_components(vertex, score_dict,th) 100 | 101 | return comps 102 | 103 | def graph_propagation(edges, score, max_sz, step=0.1, beg_th=0.9, pool=None): 104 | 105 | edges = np.sort(edges, axis=1) 106 | th = score.min() 107 | #th = beg_th 108 | # construct graph 109 | score_dict = {} # score lookup table 110 | if pool is None: 111 | for i,e in enumerate(edges): 112 | score_dict[e[0], e[1]] = score[i] 113 | elif pool == 'avg': 114 | for i,e in enumerate(edges): 115 | if score_dict.has_key((e[0],e[1])): 116 | score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i]) 117 | else: 118 | score_dict[e[0], e[1]] = score[i] 119 | 120 | elif pool == 'max': 121 | for i,e in enumerate(edges): 122 | if score_dict.has_key((e[0],e[1])): 123 | score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i]) 124 | else: 125 | score_dict[e[0], e[1]] = score[i] 126 | else: 127 | raise ValueError('Pooling operation not supported') 128 | 129 | nodes = np.sort(np.unique(edges.flatten())) 130 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int) 131 | mapping[nodes] = np.arange(nodes.shape[0]) 132 | link_idx = mapping[edges] 133 | vertex = [Data(n) for n in nodes] 134 | for l, s in zip(link_idx, score): 135 | vertex[l[0]].add_link(vertex[l[1]], s) 136 | 137 | # first iteration 138 | comps, remain = connected_components_constraint(vertex, max_sz) 139 | 140 | # iteration 141 | components = comps[:] 142 | while remain: 143 | th = th + (1 - th) * step 144 | comps, remain = connected_components_constraint(remain, max_sz, score_dict, th) 145 | components.extend(comps) 146 | return components 147 | 148 | def graph_propagation_soft(edges, score, max_sz, step=0.1, **kwargs): 149 | 150 | edges = np.sort(edges, axis=1) 151 | th = score.min() 152 | 153 | # construct graph 154 | score_dict = {} # score lookup table 155 | for i,e in enumerate(edges): 156 | score_dict[e[0], e[1]] = score[i] 157 | 158 | nodes = np.sort(np.unique(edges.flatten())) 159 | mapping = -1 * np.ones((nodes.max()+1), dtype=np.int) 160 | mapping[nodes] = np.arange(nodes.shape[0]) 161 | link_idx = mapping[edges] 162 | vertex = [Data(n) for n in nodes] 163 | for l, s in zip(link_idx, score): 164 | vertex[l[0]].add_link(vertex[l[1]], s) 165 | 166 | # first iteration 167 | comps, remain = connected_components_constraint(vertex, max_sz) 168 | first_vertex_idx = np.array([mapping[n.name] for c in comps for n in c]) 169 | fusion_vertex_idx = np.setdiff1d(np.arange(nodes.shape[0]), first_vertex_idx, assume_unique=True) 170 | # iteration 171 | components = comps[:] 172 | while remain: 173 | th = th + (1 - th) * step 174 | comps, remain = connected_components_constraint(remain, max_sz, score_dict, th) 175 | components.extend(comps) 176 | label_dict = {} 177 | for i,c in enumerate(components): 178 | for n in c: 179 | label_dict[n.name] = i 180 | print('Propagation ...') 181 | prop_vertex = [vertex[idx] for idx 
in fusion_vertex_idx] 182 | label, label_fusion = diffusion(prop_vertex, label_dict, score_dict, **kwargs) 183 | return label, label_fusion 184 | 185 | def diffusion(vertex, label, score_dict, max_depth=5, weight_decay=0.6, normalize=True): 186 | class BFSNode(): 187 | def __init__(self, node, depth, value): 188 | self.node = node 189 | self.depth = depth 190 | self.value = value 191 | 192 | label_fusion = {} 193 | for name in label.keys(): 194 | label_fusion[name] = {label[name]: 1.0} 195 | prog = 0 196 | prog_step = len(vertex) // 20 197 | start = time.time() 198 | for root in vertex: 199 | if prog % prog_step == 0: 200 | print("progress: {} / {}, elapsed time: {}".format(prog, len(vertex), time.time() - start)) 201 | prog += 1 202 | #queue = {[root, 0, 1.0]} 203 | queue = {BFSNode(root, 0, 1.0)} 204 | visited = [root.name] 205 | root_label = label[root.name] 206 | while queue: 207 | curr = queue.pop() 208 | if curr.depth >= max_depth: # pruning 209 | continue 210 | neighbors = curr.node.links 211 | tmp_value = [] 212 | tmp_neighbor = [] 213 | for n in neighbors: 214 | if n.name not in visited: 215 | sub_value = score_dict[tuple(sorted([curr.node.name, n.name]))] * weight_decay * curr.value 216 | tmp_value.append(sub_value) 217 | tmp_neighbor.append(n) 218 | if root_label not in label_fusion[n.name].keys(): 219 | label_fusion[n.name][root_label] = sub_value 220 | else: 221 | label_fusion[n.name][root_label] += sub_value 222 | visited.append(n.name) 223 | #queue.add([n, curr.depth+1, sub_value]) 224 | sortidx = np.argsort(tmp_value)[::-1] 225 | for si in sortidx: 226 | queue.add(BFSNode(tmp_neighbor[si], curr.depth+1, tmp_value[si])) 227 | if normalize: 228 | for name in label_fusion.keys(): 229 | summ = sum(label_fusion[name].values()) 230 | for k in label_fusion[name].keys(): 231 | label_fusion[name][k] /= summ 232 | return label, label_fusion 233 | -------------------------------------------------------------------------------- /utils/graph.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/utils/graph.pyc -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /utils/logging.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/utils/logging.pyc 
-------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /utils/meters.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/utils/meters.pyc -------------------------------------------------------------------------------- /utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /utils/osutils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/utils/osutils.pyc -------------------------------------------------------------------------------- /utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in 
state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /utils/serialization.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/utils/serialization.pyc -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | ################################################################### 2 | # File Name: utils.py 3 | # Author: Zhongdao Wang 4 | # mail: wcd17@mails.tsinghua.edu.cn 5 | # Created Time: Tue 28 Aug 2018 04:57:29 PM CST 6 | ################################################################### 7 | 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | import numpy as np 12 | from scipy.sparse import coo_matrix 13 | 14 | import matplotlib 15 | matplotlib.use('Agg') 16 | import matplotlib.pyplot as plt 17 | 18 | def norm(X): 19 | for ix,x in enumerate(X): 20 | X[ix]/=np.linalg.norm(x) 21 | return X 22 | 23 | def plot_embedding(X,Y): 24 | x_min, x_max = np.min(X,0), np.max(X,0) 25 | X = (X - x_min) / (x_max - x_min) 26 | plt.figure(figsize=(10,10)) 27 | for i in xrange(X.shape[0]): 28 | plt.text(X[i,0],X[i,1], str(Y[i]), 29 | color=plt.cm.Set1(Y[i]/10.), 30 | fontdict={'weight':'bold','size':12}) 31 | plt.savefig('a.jpg') 32 | 33 | 34 | EPS = np.finfo(float).eps 35 | 36 | 37 | def contingency_matrix(ref_labels, sys_labels): 38 | """Return contingency matrix between ``ref_labels`` and ``sys_labels``.""" 39 | ref_classes, ref_class_inds = np.unique(ref_labels, return_inverse=True) 40 | sys_classes, sys_class_inds = np.unique(sys_labels, return_inverse=True) 41 | n_frames = ref_labels.size 42 | # Following works because coo_matrix sums duplicate entries. Is roughly 43 | # twice as fast as np.histogram2d. 44 | cmatrix = coo_matrix( 45 | (np.ones(n_frames), (ref_class_inds, sys_class_inds)), 46 | shape=(ref_classes.size, sys_classes.size), 47 | dtype=np.int) 48 | cmatrix = cmatrix.toarray() 49 | return cmatrix, ref_classes, sys_classes 50 | 51 | 52 | def bcubed(ref_labels, sys_labels, cm=None): 53 | """Return B-cubed precision, recall, and F1. 54 | 55 | The B-cubed precision of an item is the proportion of items with its 56 | system label that share its reference label (Bagga and Baldwin, 1998). 57 | Similarly, the B-cubed recall of an item is the proportion of items 58 | with its reference label that share its system label. The overall B-cubed 59 | precision and recall, then, are the means of the precision and recall for 60 | each item. 61 | 62 | Parameters 63 | ---------- 64 | ref_labels : ndarray, (n_frames,) 65 | Reference labels. 66 | 67 | sys_labels : ndarray, (n_frames,) 68 | System labels. 69 | 70 | cm : ndarray, (n_ref_classes, n_sys_classes) 71 | Contingency matrix between reference and system labelings. If None, 72 | will be computed automatically from ``ref_labels`` and ``sys_labels``. 73 | Otherwise, the given value will be used and ``ref_labels`` and 74 | ``sys_labels`` ignored. 75 | (Default: None) 76 | 77 | Returns 78 | ------- 79 | precision : float 80 | B-cubed precision. 81 | 82 | recall : float 83 | B-cubed recall. 84 | 85 | f1 : float 86 | B-cubed F1. 87 | 88 | References 89 | ---------- 90 | Bagga, A. and Baldwin, B. (1998). "Algorithms for scoring coreference 91 | chains." 
Proceedings of LREC 1998. 92 | """ 93 | if cm is None: 94 | cm, _, _ = contingency_matrix(ref_labels, sys_labels) 95 | cm = cm.astype('float64') 96 | cm_norm = cm / cm.sum() 97 | precision = np.sum(cm_norm * (cm / cm.sum(axis=0))) 98 | recall = np.sum(cm_norm * (cm / np.expand_dims(cm.sum(axis=1), 1))) 99 | f1 = 2*(precision*recall)/(precision + recall) 100 | return precision, recall, f1 101 | -------------------------------------------------------------------------------- /utils/utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhongdao/gcn_clustering/3068dba9b12caf3f3cb4968c8045a2f1602171a5/utils/utils.pyc --------------------------------------------------------------------------------