├── figures ├── ACT.jpg └── MSMT.jpg ├── reid ├── utils │ ├── data │ │ ├── __init__.py │ │ ├── preprocessor.py │ │ ├── sampler.py │ │ ├── transforms.py │ │ ├── video_loader.py │ │ └── dataset.py │ ├── osutils.py │ ├── meters.py │ ├── __init__.py │ ├── logging.py │ └── serialization.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── feature_extraction │ ├── __init__.py │ ├── cnn.py │ └── database.py ├── __init__.py ├── loss │ ├── __init__.py │ ├── gan_loss.py │ ├── var_loss.py │ ├── neighbour_loss.py │ ├── center_loss.py │ ├── matchLoss.py │ ├── oim.py │ ├── triplet.py │ └── virtual_ce.py ├── models │ ├── module.py │ ├── functional.py │ ├── classifier.py │ ├── __init__.py │ ├── resnet.py │ └── inception.py ├── metric_learning │ ├── euclidean.py │ ├── __init__.py │ └── kissme.py ├── dataloader.py ├── dist_metric.py ├── datasets │ ├── __init__.py │ ├── utils.py │ ├── cuhk03.py │ ├── msmt17.py │ ├── market1501.py │ └── dukemtmc.py ├── evaluators.py └── rerank.py ├── requirements.txt ├── data └── readme.md ├── readme.md ├── selftrainingKmeans.py ├── selfNoise.py ├── selftrainingKmeansAsy.py ├── selftrainingCT.py ├── selftrainingACT.py └── selftrainingRCT.py /figures/ACT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyingRoastDuck/ACT_AAAI20/HEAD/figures/ACT.jpg -------------------------------------------------------------------------------- /figures/MSMT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyingRoastDuck/ACT_AAAI20/HEAD/figures/MSMT.jpg -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .dataset import Dataset 4 | from .preprocessor import Preprocessor 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.16.4 2 | matplotlib==3.1.1 3 | torch==1.3.1 4 | metric_learn==0.4.0 5 | tqdm==4.32.2 6 | torchvision==0.2.0 7 | scipy==1.1.0 8 | h5py==2.9.0 9 | Pillow==6.2.1 10 | six==1.13.0 11 | scikit_learn==0.21.3 12 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cnn import extract_cnn_feature 4 | from .database import FeatureDatabase 5 | 6 | __all__ = [ 7 | 'extract_cnn_feature', 8 | 'FeatureDatabase', 9 | ] 10 | -------------------------------------------------------------------------------- /reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | 
if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | # put all datasets under this folder like this: 2 | 3 | ``` 4 | data 5 | ----cuhk03 6 | --------raw 7 | ------------cuhk03_release.zip 8 | ----market1501 9 | --------raw 10 | ------------Market-1501-v15.09.15.zip 11 | ----dukemtmc 12 | --------raw 13 | ------------DukeMTMC-reID.zip 14 | ``` -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import feature_extraction 6 | from . import loss 7 | from . import metric_learning 8 | from . import models 9 | from . import utils 10 | from . import dist_metric 11 | from . import evaluators 12 | from . import trainers 13 | from . import rerank 14 | from . import dataloader 15 | 16 | __version__ = '0.2.0' 17 | -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .oim import oim, OIM, OIMLoss 4 | from .triplet import TripletLoss 5 | from .center_loss import CenterLoss 6 | from .matchLoss import lossMMD 7 | from .neighbour_loss import NeiLoss 8 | from .virtual_ce import VirtualCE, VirtualKCE 9 | 10 | __all__ = [ 11 | 'oim', 'OIM', 'OIMLoss','NeiLoss', 12 | 'TripletLoss', 'CenterLoss', 'lossMMD', 'VirtualCE', 'VirtualKCE' 13 | ] 14 | -------------------------------------------------------------------------------- /reid/models/module.py: -------------------------------------------------------------------------------- 1 | from .functional import revgrad 2 | from torch.nn import Module 3 | 4 | 5 | class RevGrad(Module): 6 | def __init__(self, *args, **kwargs): 7 | """ 8 | A gradient reversal layer. 9 | 10 | This layer has no parameters, and simply reverses the gradient 11 | in the backward pass. 
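The forward pass is the identity; only the upstream gradient is negated on the way back (see RevGrad in reid/models/functional.py). This is the standard building block for domain-adversarial training, where a feature extractor is trained to fool a domain discriminator.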
12 | """ 13 | 14 | super().__init__(*args, **kwargs) 15 | 16 | def forward(self, input_): 17 | return revgrad(input_) 18 | -------------------------------------------------------------------------------- /reid/models/functional.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | 3 | 4 | class RevGrad(Function): 5 | @staticmethod 6 | def forward(ctx, input_): 7 | ctx.save_for_backward(input_) 8 | output = input_ 9 | return output 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output): # pragma: no cover 13 | grad_input = None 14 | if ctx.needs_input_grad[0]: 15 | grad_input = -grad_output 16 | return grad_input 17 | 18 | 19 | revgrad = RevGrad.apply 20 | -------------------------------------------------------------------------------- /reid/metric_learning/euclidean.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from metric_learn.base_metric import BaseMetricLearner 5 | 6 | 7 | class Euclidean(BaseMetricLearner): 8 | def __init__(self): 9 | self.M_ = None 10 | 11 | def metric(self): 12 | return self.M_ 13 | 14 | def fit(self, X): 15 | self.M_ = np.eye(X.shape[1]) 16 | self.X_ = X 17 | 18 | def transform(self, X=None): 19 | if X is None: 20 | return self.X_ 21 | return X 22 | -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 18 | ret.append(correct_k.mul_(1. 
/ batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray).cuda() 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray.cuda() 22 | -------------------------------------------------------------------------------- /reid/loss/gan_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class GanLoss(nn.Module): 6 | def __init__(self, modelD, modelG, inDim=2048, outDim=2, lossD=nn.CrossEntropyLoss(), lossG=nn.CrossEntropyLoss()): 7 | super(GanLoss, self).__init__() 8 | self.inDim = inDim 9 | self.outDim = outDim 10 | self.modelD = modelD 11 | self.modelG = modelG 12 | self.lossG = lossG 13 | self.lossD = lossD 14 | 15 | def forward(self, x, labels, domainLab): 16 | dScore, gScore = self.modelD(x), self.modelG(x) 17 | lossDomain = self.lossD(dScore, domainLab) 18 | lossCls = self.lossG(gScore, labels) 19 | return lossDomain, lossCls 20 | -------------------------------------------------------------------------------- /reid/metric_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from metric_learn import (ITML_Supervised, LMNN, LSML_Supervised, 4 | SDML_Supervised, NCA, LFDA, RCA_Supervised) 5 | 6 | from .euclidean import Euclidean 7 | from .kissme import KISSME 8 | 9 | __factory = { 10 | 'euclidean': Euclidean, 11 | 'kissme': KISSME, 12 | 'itml': ITML_Supervised, 13 | 'lmnn': LMNN, 14 | 'lsml': LSML_Supervised, 15 | 'sdml': SDML_Supervised, 16 | 'nca': NCA, 17 | 'lfda': LFDA, 18 | 'rca': RCA_Supervised, 19 | } 20 | 21 | 22 | def get_metric(algorithm, *args, **kwargs): 23 | if algorithm not in __factory: 24 | raise KeyError("Unknown metric:", algorithm) 25 | return __factory[algorithm](*args, **kwargs) 26 | -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import OrderedDict 3 | 4 | from ..utils import to_torch 5 | from torch.autograd import Variable 6 | 7 | 8 | def extract_cnn_feature(model, inputs): 9 | model.eval() 10 | inputs = Variable(to_torch(inputs)) 11 | 12 | outputs = model(inputs)[0] 13 | outputs = outputs.data.cpu() 14 | return outputs # pool5 15 | 16 | # Register forward hook for each module 17 | outputs = OrderedDict() 18 | handles = [] 19 | for m in modules: 20 | outputs[id(m)] = None 21 | 22 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 23 | 24 | handles.append(m.register_forward_hook(func)) 25 | model(inputs) 26 | for h in handles: 27 | h.remove() 28 | return list(outputs.values()) 29 | -------------------------------------------------------------------------------- /reid/models/classifier.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn import init 8 | from torch.autograd import Variable 9 | import torchvision 10 | # from torch_deform_conv.layers import ConvOffset2D 11 | from reid.utils.serialization import load_checkpoint, save_checkpoint 12 | 13 | 14 | class classifier(nn.Module): 15 | def __init__(self, inDim, outDim): 16 | super(classifier,self).__init__() 17 | self.cls = nn.Linear(inDim, outDim) 18 | # self.cls = nn.Sequential( 19 | # nn.Linear(inDim, outDim), 20 | # # nn.ReLU() 21 | # ) 22 | init.normal(self.cls.weight, std=0.001) 23 | init.constant(self.cls.bias, 0) 24 | 25 | def forward(self, x): 26 | return self.cls(x) -------------------------------------------------------------------------------- /reid/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class GraphLoader(object): 7 | def __init__(self, trainList, labels, model, loss=None): 8 | # self.hw = [384,128] 9 | self.graphs = trainList 10 | self.ID = labels 11 | self.model = model 12 | self.loss = loss 13 | 14 | def __getitem__(self, idx): 15 | curGraph = self.graphs[idx] 16 | # featSize = curGraph.size(0) 17 | # useFeat = curGraph[np.random.choice(featSize, size=(int(0.8*featSize),), replace=False), :] 18 | gEmb, scores = self.model(curGraph) 19 | loss = self.loss(scores, torch.LongTensor([self.ID[idx]]).cuda()) if self.loss is not None else 0 20 | return gEmb.squeeze(), loss # return embedding and loss 21 | 22 | def __len__(self): 23 | return len(self.graphs) 24 | -------------------------------------------------------------------------------- /reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /reid/dist_metric.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from .evaluators import extract_features 6 | from .metric_learning import get_metric 7 | 8 | 9 | class DistanceMetric(object): 10 | def __init__(self, algorithm='euclidean', *args, **kwargs): 11 | super(DistanceMetric, self).__init__() 12 | self.algorithm = algorithm 13 | self.metric = get_metric(algorithm, *args, **kwargs) 14 | 15 | def train(self, model, data_loader): 16 | if self.algorithm == 'euclidean': return 17 | features, labels = 
extract_features(model, data_loader) 18 | features = torch.stack(features.values()).numpy() 19 | labels = torch.Tensor(list(labels.values())).numpy() 20 | self.metric.fit(features, labels) 21 | 22 | def transform(self, X): 23 | if torch.is_tensor(X): 24 | X = X.numpy() 25 | X = self.metric.transform(X) 26 | X = torch.from_numpy(X) 27 | else: 28 | X = self.metric.transform(X) 29 | return X 30 | 31 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | 4 | from PIL import Image 5 | 6 | 7 | class Preprocessor(object): 8 | def __init__(self, dataset, root=None, transform=None): 9 | super(Preprocessor, self).__init__() 10 | self.dataset = dataset 11 | self.root = root 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | def __getitem__(self, indices): 18 | if isinstance(indices, (tuple, list)): 19 | return [self._get_single_item(index) for index in indices] 20 | return self._get_single_item(indices) 21 | 22 | def _get_single_item(self, index): 23 | fname, pid, camid = self.dataset[index] 24 | fpath = fname 25 | if self.root is not None: 26 | fpath = osp.join(self.root, fname) 27 | if not osp.exists(fpath): 28 | fpath = fname 29 | 30 | img = Image.open(fpath).convert('RGB') 31 | if self.transform is not None: 32 | img = self.transform(img) 33 | return img, fname, pid, camid 34 | -------------------------------------------------------------------------------- /reid/loss/var_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import defaultdict 4 | 5 | 6 | class VarLoss(nn.Module): 7 | def __init__(self, feat_dim=768): 8 | super(VarLoss, self).__init__() 9 | self.feat_dim = feat_dim 10 | self.simiFunc = nn.Softmax(dim=0) 11 | 12 | def __calDis(self, x, y): # 246s 13 | # x, y = F.normalize(qFeature), F.normalize(gFeature) 14 | # x, y = qFeature, gFeature 15 | m, n = x.shape[0], y.shape[0] 16 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 17 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 18 | disMat.addmm_(1, -2, x, y.t()) 19 | return disMat 20 | 21 | def forward(self, x, labels): 22 | labelMap = defaultdict(list) 23 | # per-ID features 24 | labVal = [int(val) for val in labels.cpu()] 25 | for pid in set(labVal): 26 | labelMap[pid].append(x[labels == pid, :]) 27 | # cal loss 28 | loss = 0 29 | for keyNum in labelMap.keys(): 30 | meanVec = labelMap[keyNum][0].mean(dim=0, keepdim=True) 31 | dist = self.__calDis(meanVec, labelMap[keyNum][0]) 32 | import ipdb; 33 | ipdb.set_trace() 34 | loss += dist.mean() 35 | return loss 36 | -------------------------------------------------------------------------------- /reid/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | def __init__(self, data_source, num_instances=1): 13 | self.data_source = data_source 14 | self.num_instances = num_instances 15 | self.index_dic = defaultdict(list) 16 | for index, 
(_, pid, _) in enumerate(data_source): 17 | self.index_dic[pid].append(index) 18 | self.pids = list(self.index_dic.keys()) 19 | self.num_samples = len(self.pids) 20 | 21 | def __len__(self): 22 | return self.num_samples * self.num_instances 23 | 24 | def __iter__(self): 25 | indices = torch.randperm(self.num_samples) 26 | ret = [] 27 | for i in indices: 28 | pid = self.pids[i] 29 | t = self.index_dic[pid] 30 | if len(t) >= self.num_instances: 31 | t = np.random.choice(t, size=self.num_instances, replace=False) 32 | else: 33 | t = np.random.choice(t, size=self.num_instances, replace=True) 34 | ret.extend(t) 35 | return iter(ret) 36 | -------------------------------------------------------------------------------- /reid/loss/neighbour_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import defaultdict 4 | 5 | class NeiLoss(nn.Module): 6 | def __init__(self, feat_dim=768): 7 | super(NeiLoss, self).__init__() 8 | self.feat_dim = feat_dim 9 | self.simiFunc = nn.Softmax(dim=0) 10 | 11 | def __calDis(self, x, y):#246s 12 | # x, y = F.normalize(qFeature), F.normalize(gFeature) 13 | # x, y = qFeature, gFeature 14 | m, n = x.shape[0], y.shape[0] 15 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 16 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 17 | disMat.addmm_(1, -2, x, y.t()) 18 | return disMat 19 | 20 | def forward(self, x, labels): 21 | bSize = x.shape[0] 22 | labelMap = defaultdict(list) 23 | distmat = self.__calDis(x,x) 24 | # per-ID features 25 | labVal = [int(val) for val in labels.cpu()] 26 | for pid in set(labVal): 27 | labelMap[pid].append(labels==pid) 28 | # cal loss 29 | loss = 0 30 | for keyNum in labelMap.keys(): 31 | mask = labelMap[keyNum] 32 | curProb = distmat[labels==keyNum][0] 33 | loss += -torch.log(self.simiFunc(curProb)[mask].sum()).sum() 34 | return loss/len(labelMap.keys()) 35 | -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .cuhk03 import CUHK03 5 | from .dukemtmc import DukeMTMC 6 | from .market1501 import Market1501 7 | from .msmt17 import msmt17 8 | 9 | __factory = { 10 | 'cuhk03': CUHK03, 11 | 'market1501': Market1501, 12 | 'dukemtmc': DukeMTMC, 13 | 'msmt17': msmt17, 14 | } 15 | 16 | 17 | def names(): 18 | return sorted(__factory.keys()) 19 | 20 | 21 | def create(name, root, *args, **kwargs): 22 | """ 23 | Create a dataset instance. 24 | 25 | Parameters 26 | ---------- 27 | name : str 28 | The dataset name. Can be one of 'viper', 'cuhk01', 'cuhk03', 29 | 'market1501', and 'dukemtmc'. 30 | root : str 31 | The path to the dataset directory. 32 | split_id : int, optional 33 | The index of data split. Default: 0 34 | num_val : int or float, optional 35 | When int, it means the number of validation identities. When float, 36 | it means the proportion of validation to all the trainval. Default: 100 37 | download : bool, optional 38 | If True, will download the dataset. Default: False 39 | """ 40 | if name not in __factory: 41 | raise KeyError("Unknown dataset:", name) 42 | return __factory[name](root, *args, **kwargs) 43 | 44 | 45 | def get_dataset(name, root, *args, **kwargs): 46 | warnings.warn("get_dataset is deprecated. 
Use create instead.") 47 | return create(name, root, *args, **kwargs) 48 | -------------------------------------------------------------------------------- /reid/feature_extraction/database.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import h5py 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FeatureDatabase(Dataset): 9 | def __init__(self, *args, **kwargs): 10 | super(FeatureDatabase, self).__init__() 11 | self.fid = h5py.File(*args, **kwargs) 12 | 13 | def __enter__(self): 14 | return self 15 | 16 | def __exit__(self, exc_type, exc_val, exc_tb): 17 | self.close() 18 | 19 | def __getitem__(self, keys): 20 | if isinstance(keys, (tuple, list)): 21 | return [self._get_single_item(k) for k in keys] 22 | return self._get_single_item(keys) 23 | 24 | def _get_single_item(self, key): 25 | return np.asarray(self.fid[key]) 26 | 27 | def __setitem__(self, key, value): 28 | if key in self.fid: 29 | if self.fid[key].shape == value.shape and \ 30 | self.fid[key].dtype == value.dtype: 31 | self.fid[key][...] = value 32 | else: 33 | del self.fid[key] 34 | self.fid.create_dataset(key, data=value) 35 | else: 36 | self.fid.create_dataset(key, data=value) 37 | 38 | def __delitem__(self, key): 39 | del self.fid[key] 40 | 41 | def __len__(self): 42 | return len(self.fid) 43 | 44 | def __iter__(self): 45 | return iter(self.fid) 46 | 47 | def flush(self): 48 | self.fid.flush() 49 | 50 | def close(self): 51 | self.fid.close() 52 | -------------------------------------------------------------------------------- /reid/loss/center_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class CenterLoss(nn.Module): 5 | """Center loss. 6 | 7 | Reference: 8 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 9 | 10 | Args: 11 | num_classes (int): number of classes. 12 | feat_dim (int): feature dimension. 13 | """ 14 | def __init__(self, num_classes=10, feat_dim=2, use_gpu=True): 15 | super(CenterLoss, self).__init__() 16 | self.num_classes = num_classes 17 | self.feat_dim = feat_dim 18 | self.use_gpu = use_gpu 19 | 20 | if self.use_gpu: 21 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 22 | else: 23 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 24 | 25 | def forward(self, x, labels): 26 | """ 27 | Args: 28 | x: feature matrix with shape (batch_size, feat_dim). 29 | labels: ground truth labels with shape (batch_size). 
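Returns: scalar loss, the squared Euclidean distance between each feature and the center of its own class, averaged over the batch.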
30 | """ 31 | batch_size = x.size(0) 32 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 33 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 34 | distmat.addmm_(1, -2, x, self.centers.t()) 35 | 36 | classes = torch.arange(self.num_classes).long() 37 | if self.use_gpu: classes = classes.cuda() 38 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 39 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 40 | 41 | dist = distmat * mask.float() 42 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 43 | 44 | return loss 45 | -------------------------------------------------------------------------------- /reid/loss/matchLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def pairwiseDis(x, y):#246s 5 | # x, y = F.normalize(qFeature), F.normalize(gFeature) 6 | # x, y = qFeature, gFeature 7 | m, n = x.shape[0], y.shape[0] 8 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 9 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 10 | disMat.addmm_(1, -2, x, y.t()) 11 | return disMat 12 | 13 | 14 | def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 15 | n_samples = source.shape[0]+target.shape[0] 16 | L2_distance = pairwiseDis(source, target) 17 | if fix_sigma: 18 | bandwidth = fix_sigma 19 | else: 20 | bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples) 21 | bandwidth /= kernel_mul ** (kernel_num // 2) 22 | bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)] 23 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 24 | return sum(kernel_val)/len(kernel_val) 25 | 26 | def lossMMD(srcFeat, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 27 | srcBatch = srcFeat.shape[0] 28 | newFeat = torch.cat([srcFeat,target]) 29 | kernels = guassian_kernel(newFeat, newFeat, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 30 | XX = kernels[:srcBatch, :srcBatch] 31 | YY = kernels[srcBatch:, srcBatch:] 32 | XY = kernels[:srcBatch, srcBatch:] 33 | YX = kernels[srcBatch:, :srcBatch] 34 | return XX.mean() + YY.mean() - XY.mean() -YX.mean() 35 | 36 | 37 | # if __name__ == "__main__": 38 | # import numpy as np 39 | # srcFeat, tarFeat = np.load('E://gcn//srcGFeat.npy'), np.load('E://gcn//tarGFeat.npy') 40 | # disMat = lossMMD(torch.from_numpy(srcFeat), torch.from_numpy(tarFeat)) -------------------------------------------------------------------------------- /reid/metric_learning/kissme.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from metric_learn.base_metric import BaseMetricLearner 5 | 6 | 7 | def validate_cov_matrix(M): 8 | M = (M + M.T) * 0.5 9 | k = 0 10 | I = np.eye(M.shape[0]) 11 | while True: 12 | try: 13 | _ = np.linalg.cholesky(M) 14 | break 15 | except np.linalg.LinAlgError: 16 | # Find the nearest positive definite matrix for M. 
Modified from 17 | # http://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd 18 | # Might take several minutes 19 | k += 1 20 | w, v = np.linalg.eig(M) 21 | min_eig = v.min() 22 | M += (-min_eig * k * k + np.spacing(min_eig)) * I 23 | return M 24 | 25 | 26 | class KISSME(BaseMetricLearner): 27 | def __init__(self): 28 | self.M_ = None 29 | 30 | def metric(self): 31 | return self.M_ 32 | 33 | def fit(self, X, y=None): 34 | n = X.shape[0] 35 | if y is None: 36 | y = np.arange(n) 37 | X1, X2 = np.meshgrid(np.arange(n), np.arange(n)) 38 | X1, X2 = X1[X1 < X2], X2[X1 < X2] 39 | matches = (y[X1] == y[X2]) 40 | num_matches = matches.sum() 41 | num_non_matches = len(matches) - num_matches 42 | idxa = X1[matches] 43 | idxb = X2[matches] 44 | S = X[idxa] - X[idxb] 45 | C1 = S.transpose().dot(S) / num_matches 46 | p = np.random.choice(num_non_matches, num_matches, replace=False) 47 | idxa = X1[~matches] 48 | idxb = X2[~matches] 49 | idxa = idxa[p] 50 | idxb = idxb[p] 51 | S = X[idxa] - X[idxb] 52 | C0 = S.transpose().dot(S) / num_matches 53 | self.M_ = np.linalg.inv(C1) - np.linalg.inv(C0) 54 | self.M_ = validate_cov_matrix(self.M_) 55 | self.X_ = X 56 | -------------------------------------------------------------------------------- /reid/loss/oim.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, autograd 6 | 7 | 8 | class OIM(autograd.Function): 9 | def __init__(self, lut, momentum=0.5): 10 | super(OIM, self).__init__() 11 | self.lut = lut 12 | self.momentum = momentum 13 | 14 | def forward(self, inputs, targets): 15 | self.save_for_backward(inputs, targets) 16 | outputs = inputs.mm(self.lut.t()) 17 | return outputs 18 | 19 | def backward(self, grad_outputs): 20 | inputs, targets = self.saved_tensors 21 | grad_inputs = None 22 | if self.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(self.lut) 24 | for x, y in zip(inputs, targets): 25 | self.lut[y] = self.momentum * self.lut[y] + (1. 
- self.momentum) * x 26 | self.lut[y] /= self.lut[y].norm() 27 | return grad_inputs, None 28 | 29 | 30 | def oim(inputs, targets, lut, momentum=0.5): 31 | return OIM(lut, momentum=momentum)(inputs, targets) 32 | 33 | 34 | class OIMLoss(nn.Module): 35 | def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5, 36 | weight=None, size_average=True): 37 | super(OIMLoss, self).__init__() 38 | self.num_features = num_features 39 | self.num_classes = num_classes 40 | self.momentum = momentum 41 | self.scalar = scalar 42 | self.weight = weight 43 | self.size_average = size_average 44 | 45 | self.register_buffer('lut', torch.zeros(num_classes, num_features)) 46 | 47 | def forward(self, inputs, targets): 48 | inputs = oim(inputs, targets, self.lut, momentum=self.momentum) 49 | inputs *= self.scalar 50 | loss = F.cross_entropy(inputs, targets, weight=self.weight, 51 | size_average=self.size_average) 52 | return loss, inputs 53 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .inception import * 4 | from .resnet import * 5 | from .classifier import classifier 6 | 7 | __factory = { 8 | 'inception': inception, 9 | 'resnet18': resnet18, 10 | 'resnet34': resnet34, 11 | 'resnet50': resnet50, 12 | 'resnet101': resnet101, 13 | 'resnet152': resnet152, 14 | 'classifier': classifier 15 | } 16 | 17 | 18 | def names(): 19 | return sorted(__factory.keys()) 20 | 21 | 22 | def create(name, *args, **kwargs): 23 | """ 24 | 
Create a model instance. 25 | 26 | Parameters 27 | ---------- 28 | name : str 29 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 30 | 'resnet50', 'resnet101', and 'resnet152'. 31 | pretrained : bool, optional 32 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 33 | model. Default: True 34 | cut_at_pooling : bool, optional 35 | If True, will cut the model before the last global pooling layer and 36 | ignore the remaining kwargs. Default: False 37 | num_features : int, optional 38 | If positive, will append a Linear layer after the global pooling layer, 39 | with this number of output units, followed by a BatchNorm layer. 40 | Otherwise these layers will not be appended. Default: 256 for 41 | 'inception', 0 for 'resnet*' 42 | norm : bool, optional 43 | If True, will normalize the feature to be unit L2-norm for each sample. 44 | Otherwise will append a ReLU layer after the above Linear layer if 45 | num_features > 0. Default: False 46 | dropout : float, optional 47 | If positive, will append a Dropout layer with this dropout rate. 48 | Default: 0 49 | num_classes : int, optional 50 | If positive, will append a Linear layer at the end as the classifier 51 | with this number of output units. Default: 0 52 | """ 53 | if name not in __factory: 54 | raise KeyError("Unknown model:", name) 55 | return __factory[name](*args, **kwargs) 56 | -------------------------------------------------------------------------------- /reid/loss/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | def __init__(self, margin=0, num_instances=0, use_semi=True, isAvg=True): 10 | super(TripletLoss, self).__init__() 11 | self.margin = margin 12 | self.ranking_loss = nn.MarginRankingLoss(margin=self.margin, reduce=isAvg) 13 | self.K = num_instances 14 | self.use_semi = use_semi 15 | 16 | def forward(self, inputs, targets, epoch): 17 | n = inputs.size(0) 18 | P = n // self.K 19 | 20 | # Compute pairwise distance, replace by the official when merged 21 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 22 | dist = dist + dist.t() 23 | dist.addmm_(1, -2, inputs, inputs.t()) 24 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 25 | # For each anchor, find the hardest positive and negative 26 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 27 | dist_ap, dist_an = [], [] 28 | if self.use_semi: 29 | for i in range(P): 30 | for j in range(self.K): 31 | neg_examples = dist[i * self.K + j][mask[i * self.K + j] == 0] 32 | for pair in range(j + 1, self.K): 33 | ap = dist[i * self.K + j][i * self.K + pair] 34 | dist_ap.append(ap.unsqueeze(0)) 35 | dist_an.append(neg_examples.min().unsqueeze(0)) 36 | else: 37 | for i in range(n): 38 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 39 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 40 | dist_ap = torch.cat(dist_ap) 41 | dist_an = torch.cat(dist_an) 42 | # Compute ranking hinge loss 43 | y = dist_an.data.new() 44 | y.resize_as_(dist_an.data) 45 | y.fill_(1) 46 | y = Variable(y) 47 | loss = self.ranking_loss(dist_an, dist_ap, y) 48 | prec = (dist_an.data > dist_ap.data).sum() * 1. 
/ y.size(0) 49 | return loss, prec 50 | -------------------------------------------------------------------------------- /reid/loss/virtual_ce.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | from scipy.stats import norm 8 | 9 | import numpy as np 10 | 11 | 12 | class VirtualCE(nn.Module): 13 | def __init__(self, beta=0.1): 14 | super(VirtualCE, self).__init__() 15 | self.beta = beta 16 | 17 | def forward(self, inputs, targets): 18 | # norm first 19 | n = inputs.shape[0] 20 | inputs = F.normalize(inputs, p=2) 21 | allPids = targets.cpu().numpy().tolist() 22 | # All Centers 23 | centerHash = { 24 | pid: F.normalize(inputs[targets == pid, :].mean(dim=0, keepdim=True), p=2).detach() for pid in set(allPids) 25 | } 26 | allCenters = torch.autograd.Variable(torch.cat(list(centerHash.values()))).cuda() 27 | centerPID = torch.from_numpy(np.asarray(list(centerHash.keys()))) 28 | # sampler vs center 29 | samplerCenter = torch.autograd.Variable(torch.cat([allCenters[centerPID == pid, :] for pid in allPids])).cuda() 30 | # inputs--(128*1024), allCenters--(32*1024) 31 | vce = torch.diag(torch.exp(samplerCenter.mm(inputs.t()) / self.beta)) # 1*128 32 | centerScore = torch.exp(allCenters.mm(inputs.t()) / self.beta).sum(dim=0) # 32(center number)*128->1*128 33 | return -torch.log(vce.div(centerScore)).mean() 34 | 35 | 36 | class VirtualKCE(nn.Module): 37 | def __init__(self, beta=0.1): 38 | super(VirtualKCE, self).__init__() 39 | self.beta = beta 40 | 41 | def forward(self, inputs, targets): 42 | # norm first 43 | n = inputs.shape[0] 44 | inputs = F.normalize(inputs, p=2) 45 | allPids = targets.cpu().numpy().tolist() 46 | # All Centers 47 | centerHash = { 48 | pid: F.normalize(inputs[targets == pid, :].mean(dim=0, keepdim=True), p=2).detach() for pid in set(allPids) 49 | } 50 | allCenters = torch.autograd.Variable(torch.cat(list(centerHash.values()))).cuda() 51 | centerPID = torch.from_numpy(np.asarray(list(centerHash.keys()))) 52 | samplerCenter = torch.autograd.Variable(torch.cat([allCenters[centerPID == pid, :] for pid in allPids])).cuda() 53 | # inputs--(128*1024), allCenters--(32*1024) 54 | vce = torch.diag(torch.exp(samplerCenter.mm(inputs.t()) / self.beta)) # 1*128 55 | centerScore = torch.exp(allCenters.mm(inputs.t()) / self.beta).sum(dim=0) # 32*128->1*128 56 | kNegScore = torch.diag(inputs.mm(inputs.t())) 57 | return -torch.log(vce.div(kNegScore + centerScore)).mean() 58 | -------------------------------------------------------------------------------- /reid/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import errno 5 | import shutil 6 | import json 7 | import os.path as osp 8 | 9 | import torch 10 | 11 | 12 | def mkdir_if_missing(directory): 13 | if not osp.exists(directory): 14 | try: 15 | os.makedirs(directory) 16 | except OSError as e: 17 | if e.errno != errno.EEXIST: 18 | raise 19 | 20 | 21 | class AverageMeter(object): 22 | """Computes and stores the average and current value. 
23 | 24 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 25 | """ 26 | 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n=1): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | 42 | 43 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 44 | mkdir_if_missing(osp.dirname(fpath)) 45 | torch.save(state, fpath) 46 | if is_best: 47 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) 48 | 49 | 50 | class Logger(object): 51 | """ 52 | Write console output to external text file. 53 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 54 | """ 55 | 56 | def __init__(self, fpath=None): 57 | self.console = sys.stdout 58 | self.file = None 59 | if fpath is not None: 60 | mkdir_if_missing(os.path.dirname(fpath)) 61 | self.file = open(fpath, 'w') 62 | 63 | def __del__(self): 64 | self.close() 65 | 66 | def __enter__(self): 67 | pass 68 | 69 | def __exit__(self, *args): 70 | self.close() 71 | 72 | def write(self, msg): 73 | self.console.write(msg) 74 | if self.file is not None: 75 | self.file.write(msg) 76 | 77 | def flush(self): 78 | self.console.flush() 79 | if self.file is not None: 80 | self.file.flush() 81 | os.fsync(self.file.fileno()) 82 | 83 | def close(self): 84 | self.console.close() 85 | if self.file is not None: 86 | self.file.close() 87 | 88 | 89 | def read_json(fpath): 90 | with open(fpath, 'r') as f: 91 | obj = json.load(f) 92 | return obj 93 | 94 | 95 | def write_json(obj, fpath): 96 | mkdir_if_missing(osp.dirname(fpath)) 97 | with open(fpath, 'w') as f: 98 | json.dump(obj, f, indent=4, separators=(',', ': ')) 99 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Asymmetric Co-Teaching for Unsupervised Cross Domain Person Re-Identification (AAAI 2020) 2 | 3 | Code for AAAI 2020 paper [Asymmetric Co-Teaching for Unsupervised Cross Domain Person Re-Identification](https://arxiv.org/abs/1912.01349). 4 | 5 | ![Results](figures/ACT.jpg) 6 | 7 | ## Requirements 8 | * python 3.7 9 | * Server with 4 GPUs 10 | * Market1501, DukeMTMC-reID and other datasets. 11 | * Other necessary packages listed in [requirements.txt](requirements.txt) 12 | 13 | ## Adaptation with ACT 14 | 1. Download all necessary datasets and move them to 'data' by following instructions in 'data/readme.md' 15 | 16 | 2. If you want to train from the pre-adapted model for fast reproduction, 17 | please download all models in Resources and run the following command: 18 | 19 | ``` 20 | python selftrainingACT.py --src_dataset {src_dataset_name} --tgt_dataset {tgt_dataset_name} --resume {model's path} --data_dir ./data --logs_dir {path to save model} 21 | ``` 22 | 23 | avaliable choices to fill "src_dataset_name" and "tgt_dataset_name" are: 24 | market1501 (for Market1501), dukemtmc (for DukeMTMC-reID), cuhk03 (for CUHK03). 25 | 26 | 27 | 3. If you want to train from scratch, please train source model and adapted model by using code in 28 | [Adaptive-ReID](https://github.com/LcDog/DomainAdaptiveReID) and follow #2. 29 | 30 | ## Adaptation with other co-teaching-like structures. 31 | To reproduce the results in Tab. 2 of our paper, please run selftrainingRCT.py and selftrainingCT.py in similar way. 
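For example, a minimal invocation, assuming these scripts accept the same flags as selftrainingACT.py above (the exact arguments may differ, so check each script's argument parser):

```
python selftrainingRCT.py --src_dataset dukemtmc --tgt_dataset market1501 --resume {model's path} --data_dir ./data --logs_dir {path to save model}
```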
32 | 33 | ## Adaptation with other clustering methods. 34 | To reproduce Tab.3 of our paper, run selftrainingKmeans.py (co-teaching version) and selftrainingKmeansAsy.py (ACT version). 35 | 36 | If you find this code useful in your research, please consider citing: 37 | ``` 38 | @inproceedings{yang2020asymmetric, 39 | title={Asymmetric Co-Teaching for Unsupervised Cross-Domain Person Re-Identification.}, 40 | author={Yang, Fengxiang and Li, Ke and Zhong, Zhun and Luo, Zhiming and Sun, Xing and Cheng, Hao and Guo, Xiaowei and Huang, Feiyue and Ji, Rongrong and Li, Shaozi}, 41 | booktitle={AAAI}, 42 | pages={12597--12604}, 43 | year={2020} 44 | } 45 | ``` 46 | 47 | ## Acknowledgments 48 | Our code is based on [open-reid](https://github.com/Cysu/open-reid) and [Adaptive-ReID](https://arxiv.org/abs/1807.11334), 49 | if you use our code, please also cite their paper. 50 | ``` 51 | @article{song2018unsupervised, 52 | title={Unsupervised domain adaptive re-identification: Theory and practice}, 53 | author={Song, Liangchen and Wang, Cheng and Zhang, Lefei and Du, Bo and Zhang, Qian and Huang, Chang and Wang, Xinggang}, 54 | journal={arXiv preprint arXiv:1807.11334}, 55 | year={2018} 56 | } 57 | ``` 58 | 59 | 60 | 61 |

Resources:

62 | 63 | 1. Pretrained Models: 64 | 65 | all pre-adapted models are named by the following formula: 66 | ``` 67 | ada{src}2{tgt}.pth 68 | ``` 69 | where "src" and "tgt" are the initial letter of source and target dataset's name, i.e., M for Market1501, D for Duke and C for CUHK03. 70 | 71 | [Baidu NetDisk](https://pan.baidu.com/s/1uPjKpkdZjqSJdk3XxR1-Yg), Password: 9aba 72 | 73 | [Google Drive](https://drive.google.com/file/d/1W1BcmHjmzxR3TVj2rFpnV703Huat3AeA/view?usp=sharing) 74 | 75 | 2. Results on MSMT17(MS), Duke(D) and Market(M). 76 | 77 | ![Results](figures/MSMT.jpg) 78 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from PIL import Image, ImageOps, ImageEnhance 4 | import random 5 | import math 6 | from torchvision.transforms import * 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | class RandomErasing(object): 52 | ''' 53 | Class that performs Random Erasing in Random Erasing Data Augmentation by Zhong et al. 54 | ------------------------------------------------------------------------------------- 55 | probability: The probability that the operation will be performed. 
56 | sl: min erasing area 57 | sh: max erasing area 58 | r1: min aspect ratio 59 | mean: erasing value 60 | ------------------------------------------------------------------------------------- 61 | ''' 62 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.2, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 63 | self.probability = probability 64 | self.mean = mean 65 | self.sl = sl 66 | self.sh = sh 67 | self.r1 = r1 68 | 69 | def __call__(self, img): 70 | 71 | if random.uniform(0, 1) > self.probability: 72 | return img 73 | 74 | for attempt in range(100): 75 | area = img.size()[1] * img.size()[2] 76 | 77 | target_area = random.uniform(self.sl, self.sh) * area 78 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 79 | 80 | h = int(round(math.sqrt(target_area * aspect_ratio))) 81 | w = int(round(math.sqrt(target_area / aspect_ratio))) 82 | 83 | if w < img.size()[2] and h < img.size()[1]: 84 | x1 = random.randint(0, img.size()[1] - h) 85 | y1 = random.randint(0, img.size()[2] - w) 86 | if img.size()[0] == 3: 87 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 88 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 89 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 90 | else: 91 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 92 | return img 93 | 94 | return img 95 | -------------------------------------------------------------------------------- /reid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn import init 8 | from torch.autograd import Variable 9 | import torchvision 10 | # from torch_deform_conv.layers import ConvOffset2D 11 | from reid.utils.serialization import load_checkpoint, save_checkpoint 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 14 | 'resnet152'] 15 | 16 | 17 | class ResNet(nn.Module): 18 | __factory = { 19 | 18: torchvision.models.resnet18, 20 | 34: torchvision.models.resnet34, 21 | 50: torchvision.models.resnet50, 22 | 101: torchvision.models.resnet101, 23 | 152: torchvision.models.resnet152, 24 | } 25 | 26 | def __init__(self, depth, checkpoint=None, pretrained=True, num_features=2048, 27 | dropout=0.1, num_classes=0, numCams=0): 28 | super(ResNet, self).__init__() 29 | 30 | self.depth = depth 31 | self.checkpoint = checkpoint 32 | self.pretrained = pretrained 33 | self.num_features = num_features 34 | self.dropout = dropout 35 | self.num_classes = num_classes 36 | 37 | if self.dropout > 0: 38 | self.drop = nn.Dropout(self.dropout) 39 | # Construct base (pretrained) resnet 40 | if depth not in ResNet.__factory: 41 | raise KeyError("Unsupported depth:", depth) 42 | self.base = ResNet.__factory[depth](pretrained=pretrained) 43 | out_planes = self.base.fc.in_features 44 | 45 | # resume from pre-iteration training 46 | if self.checkpoint: 47 | state_dict = load_checkpoint(checkpoint) 48 | self.load_state_dict(state_dict['state_dict'], strict=False) 49 | 50 | self.feat = nn.Linear(out_planes, self.num_features, bias=False) 51 | self.feat_bn = nn.BatchNorm1d(self.num_features) 52 | self.relu = nn.ReLU(inplace=True) 53 | init.normal(self.feat.weight, std=0.001) 54 | init.constant(self.feat_bn.weight, 1) 55 | init.constant(self.feat_bn.bias, 0) 56 | 57 | # x2 classifier 58 | self.classifier_x2 = nn.Linear(self.num_features, self.num_classes) 59 | init.normal(self.classifier_x2.weight, std=0.001) 60 | init.constant(self.classifier_x2.bias, 0) 61 | 62 | if not 
self.pretrained: 63 | self.reset_params() 64 | 65 | def forward(self, x): 66 | for name, module in self.base._modules.items(): 67 | if name == 'avgpool': break 68 | x = module(x) 69 | # x with (bSize, 1024, 8, 4) 70 | x1 = F.avg_pool2d(x, x.size()[2:]) 71 | x1 = x1.view(x1.size(0), -1) 72 | # get classification 73 | x2 = F.avg_pool2d(x, x.size()[2:]) 74 | x2 = x2.view(x2.size(0), -1) 75 | x2 = self.feat(x2) 76 | x2 = self.feat_bn(x2) 77 | x2 = self.relu(x2) 78 | if self.num_classes != 0: 79 | return x1, x2, self.classifier_x2(x2) # pool5, fc2048, classifier 80 | return x1, x2 81 | 82 | def reset_params(self): 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | init.normal(m.weight, std=0.001) 86 | if m.bias is not None: 87 | init.constant(m.bias, 0) 88 | elif isinstance(m, nn.BatchNorm2d): 89 | init.constant(m.weight, 1) 90 | init.constant(m.bias, 0) 91 | elif isinstance(m, nn.Linear): 92 | init.normal(m.weight, std=0.001) 93 | if m.bias is not None: 94 | init.constant(m.bias, 0) 95 | 96 | 97 | def resnet18(**kwargs): 98 | return ResNet(18, **kwargs) 99 | 100 | 101 | def resnet34(**kwargs): 102 | return ResNet(34, **kwargs) 103 | 104 | 105 | def resnet50(**kwargs): 106 | return ResNet(50, **kwargs) 107 | 108 | 109 | def resnet101(**kwargs): 110 | return ResNet(101, **kwargs) 111 | 112 | 113 | def resnet152(**kwargs): 114 | return ResNet(152, **kwargs) 115 | -------------------------------------------------------------------------------- /reid/datasets/cuhk03.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..utils.data import Dataset 7 | from ..utils.osutils import mkdir_if_missing 8 | from ..utils.serialization import write_json 9 | 10 | 11 | class CUHK03(Dataset): 12 | url = 'https://docs.google.com/spreadsheet/viewform?usp=drive_web&formkey=dHRkMkFVSUFvbTJIRkRDLWRwZWpONnc6MA#gid=0' 13 | md5 = '728939e58ad9f0ff53e521857dd8fb43' 14 | 15 | def __init__(self, root, split_id=0, num_val=100, download=True): 16 | super(CUHK03, self).__init__(root, split_id=split_id) 17 | 18 | if download: 19 | self.download() 20 | 21 | if not self._check_integrity(): 22 | raise RuntimeError("Dataset not found or corrupted. 
" + 23 | "You can use download=True to download it.") 24 | 25 | self.load(num_val) 26 | 27 | def download(self): 28 | if self._check_integrity(): 29 | print("Files already downloaded and verified") 30 | return 31 | 32 | import h5py 33 | import hashlib 34 | from scipy.misc import imsave 35 | from zipfile import ZipFile 36 | 37 | raw_dir = osp.join(self.root, 'raw') 38 | mkdir_if_missing(raw_dir) 39 | 40 | # Download the raw zip file 41 | fpath = osp.join(raw_dir, 'cuhk03_release.zip') 42 | if osp.isfile(fpath) and \ 43 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 44 | print("Using downloaded file: " + fpath) 45 | else: 46 | raise RuntimeError("Please download the dataset manually from {} " 47 | "to {}".format(self.url, fpath)) 48 | 49 | # Extract the file 50 | exdir = osp.join(raw_dir, 'cuhk03_release') 51 | if not osp.isdir(exdir): 52 | print("Extracting zip file") 53 | with ZipFile(fpath) as z: 54 | z.extractall(path=raw_dir) 55 | 56 | # Format 57 | images_dir = osp.join(self.root, 'images') 58 | mkdir_if_missing(images_dir) 59 | matdata = h5py.File(osp.join(exdir, 'cuhk-03.mat'), 'r') 60 | 61 | def deref(ref): 62 | return matdata[ref][:].T 63 | 64 | def dump_(refs, pid, cam, fnames): 65 | for ref in refs: 66 | img = deref(ref) 67 | if img.size == 0 or img.ndim < 2: break 68 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, cam, len(fnames)) 69 | imsave(osp.join(images_dir, fname), img) 70 | fnames.append(fname) 71 | 72 | identities = [] 73 | for labeled, detected in zip( 74 | matdata['labeled'][0], matdata['detected'][0]): 75 | labeled, detected = deref(labeled), deref(detected) 76 | assert labeled.shape == detected.shape 77 | for i in range(labeled.shape[0]): 78 | pid = len(identities) 79 | images = [[], []] 80 | dump_(labeled[i, :5], pid, 0, images[0]) 81 | dump_(detected[i, :5], pid, 0, images[0]) 82 | dump_(labeled[i, 5:], pid, 1, images[1]) 83 | dump_(detected[i, 5:], pid, 1, images[1]) 84 | identities.append(images) 85 | 86 | # Save meta information into a json file 87 | meta = {'name': 'cuhk03', 'shot': 'multiple', 'num_cameras': 2, 88 | 'identities': identities} 89 | write_json(meta, osp.join(self.root, 'meta.json')) 90 | 91 | # Save training and test splits 92 | splits = [] 93 | view_counts = [deref(ref).shape[0] for ref in matdata['labeled'][0]] 94 | vid_offsets = np.r_[0, np.cumsum(view_counts)] 95 | for ref in matdata['testsets'][0]: 96 | test_info = deref(ref).astype(np.int32) 97 | test_pids = sorted( 98 | [int(vid_offsets[i - 1] + j - 1) for i, j in test_info]) 99 | trainval_pids = list(set(range(vid_offsets[-1])) - set(test_pids)) 100 | split = {'trainval': trainval_pids, 101 | 'query': test_pids, 102 | 'gallery': test_pids} 103 | splits.append(split) 104 | write_json(splits, osp.join(self.root, 'splits.json')) 105 | -------------------------------------------------------------------------------- /reid/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import os 4 | import numpy as np 5 | 6 | from ..utils.data import Dataset 7 | from ..utils.osutils import mkdir_if_missing 8 | from ..utils.serialization import read_json 9 | from ..utils.serialization import write_json 10 | 11 | 12 | class msmt17(Dataset): 13 | url = 'https://drive.google.com/file/d/1PduQX1OBuoXDh9JxybYBoDEcKhSx_Q8j/view?usp=sharing' 14 | md5 = 'ea5502ae9dd06c596ad866bd1db0280d' 15 | 16 | def __init__(self, root, split_id=0, num_val=100, 
download=True): 17 | super(msmt17, self).__init__(root, split_id=split_id) 18 | 19 | if download: 20 | self.download() 21 | 22 | self.load() 23 | 24 | def download(self): 25 | if self._check_integrity(): 26 | print("Files already downloaded and verified") 27 | return 28 | 29 | import re 30 | import hashlib 31 | import tarfile 32 | 33 | raw_dir = osp.join(self.root, 'raw') 34 | mkdir_if_missing(raw_dir) 35 | 36 | # Download the raw zip file 37 | fpath = osp.join(raw_dir, 'MSMT17_V1.tar.gz') 38 | if osp.isfile(fpath) and \ 39 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 40 | print("Using downloaded file: " + fpath) 41 | else: 42 | raise RuntimeError("Please download the dataset manually from {} " 43 | "to {}".format(self.url, fpath)) 44 | 45 | # Extract the file 46 | exdir = osp.join(raw_dir, 'MSMT17_V1') 47 | if not osp.isdir(exdir): 48 | print("Extracting tar file") 49 | with tarfile.open(fpath) as tar: 50 | tar.extractall(raw_dir) 51 | 52 | # MSMT17 files 53 | def register(typeName, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)'), test_query=False): 54 | assert typeName.lower() in ['gallery', 'query', 'train', 'val'] 55 | nameMap = { 56 | 'gallery': 'test', 'query': 'test', 57 | 'train': 'train', 'val': 'train' 58 | } 59 | with open(osp.join(exdir, 'list_{}.txt'.format(typeName.lower())), 'r') as f: 60 | fpaths = f.readlines() 61 | fpaths = [name.strip().split(' ')[0] for name in fpaths] 62 | fpaths = sorted([osp.join(exdir, nameMap[typeName.lower()], name) for name in fpaths]) 63 | curData = [] 64 | for fpath in fpaths: 65 | fname = osp.basename(fpath) 66 | pid, _, cam = map(int, pattern.search(fname).groups()) 67 | cam -= 1 68 | curData.append((fpath, pid, cam)) 69 | return curData 70 | 71 | self.train = register('train') 72 | self.val = register('val') 73 | self.trainval = self.train + self.val 74 | self.gallery = register('gallery') 75 | self.query = register('query') 76 | 77 | ######################## 78 | # Added 79 | def load(self, verbose=True): 80 | trainPids = [pid[1] for pid in self.train] 81 | valPids = [pid[1] for pid in self.val] 82 | trainvalPids = [pid[1] for pid in self.trainval] 83 | galleryPids = [pid[1] for pid in self.gallery] 84 | queryPids = [pid[1] for pid in self.query] 85 | self.num_train_ids = len(set(trainPids)) 86 | self.num_val_ids = len(set(valPids)) 87 | self.num_trainval_ids = len(set(trainvalPids)) 88 | self.num_query_ids = len(set(queryPids)) 89 | self.num_gallery_ids = len(set(galleryPids)) 90 | ########## 91 | if verbose: 92 | print(self.__class__.__name__, "dataset loaded") 93 | print(" subset | # ids | # images") 94 | print(" ---------------------------") 95 | print(" train | {:5d} | {:8d}".format(self.num_train_ids, len(self.train))) 96 | print(" val | {:5d} | {:8d}".format(self.num_val_ids, len(self.val))) 97 | print(" trainval | {:5d} | {:8d}".format(self.num_trainval_ids, len(self.trainval))) 98 | print(" query | {:5d} | {:8d}".format(self.num_query_ids, len(self.query))) 99 | print(" gallery | {:5d} | {:8d}".format(self.num_gallery_ids, len(self.gallery))) 100 | ######################## 101 | -------------------------------------------------------------------------------- /reid/utils/data/video_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | import random 9 | 10 | def read_image(img_path): 11 | 
"""Keep reading image until succeed. 12 | This can avoid IOError incurred by heavy IO process.""" 13 | got_img = False 14 | while not got_img: 15 | try: 16 | img = Image.open(img_path).convert('RGB') 17 | got_img = True 18 | except IOError: 19 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 20 | pass 21 | return img 22 | 23 | 24 | class VideoDataset(Dataset): 25 | """Video Person ReID Dataset. 26 | Note batch data has shape (batch, seq_len, channel, height, width). 27 | """ 28 | sample_methods = ['evenly', 'random', 'all'] 29 | 30 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None): 31 | self.dataset = dataset 32 | self.seq_len = seq_len 33 | self.sample = sample 34 | self.transform = transform 35 | 36 | def __len__(self): 37 | return len(self.dataset) 38 | 39 | def __getitem__(self, index): 40 | img_paths, pid, camid = self.dataset[index] 41 | num = len(img_paths) 42 | if self.sample == 'random': 43 | """ 44 | Randomly sample seq_len consecutive frames from num frames, 45 | if num is smaller than seq_len, then replicate items. 46 | This sampling strategy is used in training phase. 47 | """ 48 | frame_indices = range(num) 49 | rand_end = max(0, len(frame_indices) - self.seq_len - 1) 50 | begin_index = random.randint(0, rand_end) 51 | end_index = min(begin_index + self.seq_len, len(frame_indices)) 52 | 53 | indices = frame_indices[begin_index:end_index] 54 | 55 | for index in indices: 56 | if len(indices) >= self.seq_len: 57 | break 58 | indices.append(index) 59 | indices=np.array(indices) 60 | imgs = [] 61 | for index in indices: 62 | index=int(index) 63 | img_path = img_paths[index] 64 | img = read_image(img_path) 65 | if self.transform is not None: 66 | img = self.transform(img) 67 | img = img.unsqueeze(0) 68 | imgs.append(img) 69 | imgs = torch.cat(imgs, dim=0) 70 | #imgs=imgs.permute(1,0,2,3) 71 | return imgs, pid, camid 72 | 73 | elif self.sample == 'dense': 74 | """ 75 | Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1. 76 | This sampling strategy is used in test phase. 77 | """ 78 | cur_index=0 79 | frame_indices = range(num) 80 | indices_list=[] 81 | while num-cur_index > self.seq_len: 82 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len]) 83 | cur_index+=self.seq_len 84 | last_seq=frame_indices[cur_index:] 85 | for index in last_seq: 86 | if len(last_seq) >= self.seq_len: 87 | break 88 | last_seq.append(index) 89 | indices_list.append(last_seq) 90 | imgs_list=[] 91 | for indices in indices_list: 92 | imgs = [] 93 | for index in indices: 94 | index=int(index) 95 | img_path = img_paths[index] 96 | img = read_image(img_path) 97 | if self.transform is not None: 98 | img = self.transform(img) 99 | img = img.unsqueeze(0) 100 | imgs.append(img) 101 | imgs = torch.cat(imgs, dim=0) 102 | #imgs=imgs.permute(1,0,2,3) 103 | imgs_list.append(imgs) 104 | imgs_array = torch.stack(imgs_list) 105 | return imgs_array, pid, camid 106 | 107 | else: 108 | raise KeyError("Unknown sample method: {}. 
Expected one of {}".format(self.sample, self.sample_methods)) 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. 
/ (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /reid/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | from collections import namedtuple 5 | 6 | import torch 7 | 8 | from .evaluation_metrics import cmc, mean_ap 9 | from .feature_extraction import extract_cnn_feature 10 | from .utils.meters import AverageMeter 11 | 12 | 13 | def extract_features(model, data_loader, print_freq=1): 14 | model.eval() 15 | batch_time = AverageMeter() 16 | data_time = AverageMeter() 17 | 18 | features = OrderedDict() 19 | labels = OrderedDict() 20 | 21 | end = time.time() 22 | for i, (imgs, fnames, pids, _) in enumerate(data_loader): 23 | data_time.update(time.time() - end) 24 | 25 | outputs = extract_cnn_feature(model, imgs) 26 | for fname, output, pid in zip(fnames, outputs, pids): 27 | features[fname] = output 28 | labels[fname] = pid 29 | 30 | batch_time.update(time.time() - end) 31 | end = time.time() 32 | 33 | if (i + 1) % print_freq == 0: 34 | print('Extract Features: [{}/{}]\t' 35 | 'Time {:.3f} ({:.3f})\t' 36 | 'Data {:.3f} ({:.3f})\t' 37 | .format(i + 1, len(data_loader), 38 | batch_time.val, batch_time.avg, 39 | data_time.val, data_time.avg)) 40 | 41 | return features, labels 42 | 43 | 44 | def pairwise_distance(features, query=None, gallery=None, metric=None): 45 | if query is None and gallery is None: 46 | n = len(features) 47 | x = torch.cat(list(features.values())) 48 | x = x.view(n, -1) 49 | if metric is not None: 50 | x = metric.transform(x) 51 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 52 | dist = dist.expand(n, n) - 2 * torch.mm(x, x.t()) 53 | return dist 54 | 55 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 56 | y = torch.cat([features[f].unsqueeze(0) for 
f, _, _ in gallery], 0) 57 | m, n = x.size(0), y.size(0) 58 | x = x.view(m, -1) 59 | y = y.view(n, -1) 60 | if metric is not None: 61 | x = metric.transform(x) 62 | y = metric.transform(y) 63 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 64 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 65 | dist.addmm_(1, -2, x, y.t()) 66 | return dist 67 | 68 | 69 | def evaluate_all(distmat, query=None, gallery=None, 70 | query_ids=None, gallery_ids=None, 71 | query_cams=None, gallery_cams=None, 72 | cmc_topk=(1, 5, 10)): 73 | if query is not None and gallery is not None: 74 | query_ids = [pid for _, pid, _ in query] 75 | gallery_ids = [pid for _, pid, _ in gallery] 76 | query_cams = [cam for _, _, cam in query] 77 | gallery_cams = [cam for _, _, cam in gallery] 78 | else: 79 | assert (query_ids is not None and gallery_ids is not None 80 | and query_cams is not None and gallery_cams is not None) 81 | 82 | # Compute mean AP 83 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 84 | print('Mean AP: {:4.1%}'.format(mAP)) 85 | 86 | cmc_configs = { 87 | 'allshots': dict(separate_camera_set=False, 88 | single_gallery_shot=False, 89 | first_match_break=False), 90 | 'cuhk03': dict(separate_camera_set=True, 91 | single_gallery_shot=True, 92 | first_match_break=False), 93 | 'market1501': dict(separate_camera_set=False, 94 | single_gallery_shot=False, 95 | first_match_break=True)} 96 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 97 | query_cams, gallery_cams, **params) 98 | for name, params in cmc_configs.items()} 99 | 100 | print('CMC Scores{:>12}{:>12}{:>12}' 101 | .format('allshots', 'cuhk03', 'market1501')) 102 | rank_score = namedtuple( 103 | 'rank_score', 104 | ['map', 'allshots', 'cuhk03', 'market1501'], 105 | ) 106 | for k in cmc_topk: 107 | print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}' 108 | .format(k, cmc_scores['allshots'][k - 1], 109 | cmc_scores['cuhk03'][k - 1], 110 | cmc_scores['market1501'][k - 1])) 111 | score = rank_score( 112 | mAP, 113 | cmc_scores['allshots'], cmc_scores['cuhk03'], 114 | cmc_scores['market1501'], 115 | ) 116 | return score 117 | 118 | 119 | class Evaluator(object): 120 | def __init__(self, model, print_freq=1): 121 | super(Evaluator, self).__init__() 122 | self.model = model 123 | self.print_freq = print_freq 124 | 125 | def evaluate(self, data_loader, query, gallery, metric=None): 126 | features, _ = extract_features(self.model, data_loader, print_freq=self.print_freq) 127 | distmat = pairwise_distance(features, query, gallery, metric=metric) 128 | return evaluate_all(distmat, query=query, gallery=gallery) 129 | -------------------------------------------------------------------------------- /reid/models/inception.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn import init 7 | 8 | 9 | __all__ = ['InceptionNet', 'inception'] 10 | 11 | 12 | def _make_conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, 13 | bias=False): 14 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 15 | stride=stride, padding=padding, bias=bias) 16 | bn = nn.BatchNorm2d(out_planes) 17 | relu = nn.ReLU(inplace=True) 18 | return nn.Sequential(conv, bn, relu) 19 | 20 | 21 | class Block(nn.Module): 22 | def __init__(self, in_planes, out_planes, pool_method, stride): 23 | super(Block, self).__init__() 24 | self.branches = 
nn.ModuleList([ 25 | nn.Sequential( 26 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0), 27 | _make_conv(out_planes, out_planes, stride=stride) 28 | ), 29 | nn.Sequential( 30 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0), 31 | _make_conv(out_planes, out_planes), 32 | _make_conv(out_planes, out_planes, stride=stride)) 33 | ]) 34 | 35 | if pool_method == 'Avg': 36 | assert stride == 1 37 | self.branches.append( 38 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0)) 39 | self.branches.append(nn.Sequential( 40 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 41 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0))) 42 | else: 43 | self.branches.append( 44 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1)) 45 | 46 | def forward(self, x): 47 | return torch.cat([b(x) for b in self.branches], 1) 48 | 49 | 50 | class InceptionNet(nn.Module): 51 | def __init__(self, cut_at_pooling=False, num_features=256, norm=False, 52 | dropout=0, num_classes=0): 53 | super(InceptionNet, self).__init__() 54 | self.cut_at_pooling = cut_at_pooling 55 | 56 | self.conv1 = _make_conv(3, 32) 57 | self.conv2 = _make_conv(32, 32) 58 | self.conv3 = _make_conv(32, 32) 59 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) 60 | self.in_planes = 32 61 | self.inception4a = self._make_inception(64, 'Avg', 1) 62 | self.inception4b = self._make_inception(64, 'Max', 2) 63 | self.inception5a = self._make_inception(128, 'Avg', 1) 64 | self.inception5b = self._make_inception(128, 'Max', 2) 65 | self.inception6a = self._make_inception(256, 'Avg', 1) 66 | self.inception6b = self._make_inception(256, 'Max', 2) 67 | 68 | if not self.cut_at_pooling: 69 | self.num_features = num_features 70 | self.norm = norm 71 | self.dropout = dropout 72 | self.has_embedding = num_features > 0 73 | self.num_classes = num_classes 74 | 75 | self.avgpool = nn.AdaptiveAvgPool2d(1) 76 | 77 | if self.has_embedding: 78 | self.feat = nn.Linear(self.in_planes, self.num_features) 79 | self.feat_bn = nn.BatchNorm1d(self.num_features) 80 | else: 81 | # Change the num_features to CNN output channels 82 | self.num_features = self.in_planes 83 | if self.dropout > 0: 84 | self.drop = nn.Dropout(self.dropout) 85 | if self.num_classes > 0: 86 | self.classifier = nn.Linear(self.num_features, self.num_classes) 87 | 88 | self.reset_params() 89 | 90 | def forward(self, x): 91 | x = self.conv1(x) 92 | x = self.conv2(x) 93 | x = self.conv3(x) 94 | x = self.pool3(x) 95 | x = self.inception4a(x) 96 | x = self.inception4b(x) 97 | x = self.inception5a(x) 98 | x = self.inception5b(x) 99 | x = self.inception6a(x) 100 | x = self.inception6b(x) 101 | 102 | if self.cut_at_pooling: 103 | return x 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | if self.has_embedding: 109 | x = self.feat(x) 110 | x = self.feat_bn(x) 111 | if self.norm: 112 | x = F.normalize(x) 113 | elif self.has_embedding: 114 | x = F.relu(x) 115 | if self.dropout > 0: 116 | x = self.drop(x) 117 | if self.num_classes > 0: 118 | x = self.classifier(x) 119 | return x 120 | 121 | def _make_inception(self, out_planes, pool_method, stride): 122 | block = Block(self.in_planes, out_planes, pool_method, stride) 123 | self.in_planes = (out_planes * 4 if pool_method == 'Avg' else 124 | out_planes * 2 + self.in_planes) 125 | return block 126 | 127 | def reset_params(self): 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | init.kaiming_normal(m.weight, mode='fan_out') 131 | if m.bias is not None: 132 | 
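# NOTE: `init.kaiming_normal`, `init.normal` and `init.constant` are the pre-0.4
# PyTorch spellings, kept here for the pinned torch==1.3.1 (they still work but
# emit deprecation warnings). On newer releases the in-place variants are the
# equivalents -- a sketch only, not a change applied to this file:
#   init.kaiming_normal_(m.weight, mode='fan_out')
#   init.constant_(m.bias, 0.)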
init.constant(m.bias, 0) 133 | elif isinstance(m, nn.BatchNorm2d): 134 | init.constant(m.weight, 1) 135 | init.constant(m.bias, 0) 136 | elif isinstance(m, nn.Linear): 137 | init.normal(m.weight, std=0.001) 138 | if m.bias is not None: 139 | init.constant(m.bias, 0) 140 | 141 | 142 | def inception(**kwargs): 143 | return InceptionNet(**kwargs) 144 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | x, y, _ = map(int, name.split('_')) 17 | assert pid == x and camid == y 18 | if relabel: 19 | ret.append((fname, index, camid)) 20 | else: 21 | ret.append((fname, pid, camid)) 22 | return ret 23 | 24 | def _pluck_gallery(identities, indices, relabel=False): 25 | ret = [] 26 | for index, pid in enumerate(indices): 27 | pid_images = identities[pid] 28 | for camid, cam_images in enumerate(pid_images): 29 | if len(cam_images[:-1])==0: 30 | for fname in cam_images: 31 | name = osp.splitext(fname)[0] 32 | x, y, _ = map(int, name.split('_')) 33 | assert pid == x and camid == y 34 | if relabel: 35 | ret.append((fname, index, camid)) 36 | else: 37 | ret.append((fname, pid, camid)) 38 | else: 39 | for fname in cam_images[:-1]: 40 | name = osp.splitext(fname)[0] 41 | x, y, _ = map(int, name.split('_')) 42 | assert pid == x and camid == y 43 | if relabel: 44 | ret.append((fname, index, camid)) 45 | else: 46 | ret.append((fname, pid, camid)) 47 | return ret 48 | 49 | def _pluck_query(identities, indices, relabel=False): 50 | ret = [] 51 | for index, pid in enumerate(indices): 52 | pid_images = identities[pid] 53 | for camid, cam_images in enumerate(pid_images): 54 | for fname in cam_images[-1:]: 55 | name = osp.splitext(fname)[0] 56 | x, y, _ = map(int, name.split('_')) 57 | assert pid == x and camid == y 58 | if relabel: 59 | ret.append((fname, index, camid)) 60 | else: 61 | ret.append((fname, pid, camid)) 62 | return ret 63 | 64 | 65 | class Dataset(object): 66 | def __init__(self, root, split_id=0): 67 | self.root = root 68 | self.split_id = split_id 69 | self.meta = None 70 | self.split = None 71 | self.train, self.val, self.trainval = [], [], [] 72 | self.query, self.gallery = [], [] 73 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 74 | 75 | @property 76 | def images_dir(self): 77 | return osp.join(self.root, 'images') 78 | 79 | def load(self, num_val=0.3, verbose=True): 80 | splits = read_json(osp.join(self.root, 'splits.json')) 81 | if self.split_id >= len(splits): 82 | raise ValueError("split_id exceeds total splits {}" 83 | .format(len(splits))) 84 | self.split = splits[self.split_id] 85 | 86 | # Randomly split train / val 87 | trainval_pids = np.asarray(self.split['trainval']) 88 | np.random.shuffle(trainval_pids) 89 | num = len(trainval_pids) 90 | if isinstance(num_val, float): 91 | num_val = int(round(num * num_val)) 92 | if num_val >= num or num_val < 0: 93 | raise ValueError("num_val exceeds total identities {}" 94 | .format(num)) 95 | train_pids = sorted(trainval_pids[:-num_val]) 96 | val_pids = 
sorted(trainval_pids[-num_val:]) 97 | 98 | self.meta = read_json(osp.join(self.root, 'meta.json')) 99 | identities = self.meta['identities'] 100 | self.train = _pluck(identities, train_pids, relabel=True) 101 | self.val = _pluck(identities, val_pids, relabel=True) 102 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 103 | self.query = _pluck_query(identities, self.split['query']) 104 | #self.gallery = _pluck(identities, self.split['gallery']) 105 | self.gallery = _pluck_gallery(identities, self.split['gallery']) 106 | self.num_train_ids = len(train_pids) 107 | self.num_val_ids = len(val_pids) 108 | self.num_trainval_ids = len(trainval_pids) 109 | 110 | if verbose: 111 | print(self.__class__.__name__, "dataset loaded") 112 | print(" subset | # ids | # images") 113 | print(" ---------------------------") 114 | print(" train | {:5d} | {:8d}" 115 | .format(self.num_train_ids, len(self.train))) 116 | print(" val | {:5d} | {:8d}" 117 | .format(self.num_val_ids, len(self.val))) 118 | print(" trainval | {:5d} | {:8d}" 119 | .format(self.num_trainval_ids, len(self.trainval))) 120 | print(" query | {:5d} | {:8d}" 121 | .format(len(self.split['query']), len(self.query))) 122 | print(" gallery | {:5d} | {:8d}" 123 | .format(len(self.split['gallery']), len(self.gallery))) 124 | 125 | def _check_integrity(self): 126 | return osp.isdir(osp.join(self.root, 'images')) and \ 127 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 128 | osp.isfile(osp.join(self.root, 'splits.json')) 129 | -------------------------------------------------------------------------------- /reid/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # !/usr/bin/env python3 4 | # -*- coding: utf-8 -*- 5 | 6 | import numpy as np 7 | from scipy.spatial.distance import cdist 8 | import torch 9 | import tqdm 10 | import torch.nn.functional as F 11 | 12 | 13 | def pairwiseDis(qFeature, gFeature): # 246s 14 | x, y = F.normalize(qFeature), F.normalize(gFeature) 15 | # x, y = qFeature, gFeature 16 | m, n = x.shape[0], y.shape[0] 17 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 18 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 19 | disMat.addmm_(1, -2, x, y.t()) 20 | print('-----* Distance Matrix has been computed*-----') 21 | return disMat.clamp_(min=1e-5) 22 | 23 | 24 | def re_ranking(input_feature_source, input_feature, k1=20, k2=6, lambda_value=0.1): 25 | all_num = input_feature.shape[0] 26 | # feat = input_feature.astype(np.float16) 27 | feat = torch.from_numpy(input_feature) # target 28 | del input_feature 29 | 30 | if lambda_value != 0: 31 | print('Computing source distance...') 32 | srcFeat, tarFeat = input_feature_source, feat 33 | # all_num_source = input_feature_source.shape[0] 34 | # sour_tar_dist = np.power(cdist(input_feature, input_feature_source), 2).astype(np.float32) #608s 35 | sour_tar_dist = pairwiseDis(srcFeat, tarFeat).t().numpy() 36 | sour_tar_dist = 1 - np.exp(-sour_tar_dist) # tar-src 37 | source_dist_vec = np.min(sour_tar_dist, axis=1) 38 | source_dist_vec = source_dist_vec / (np.max(source_dist_vec) + 1e-3) # for trget 39 | source_dist = np.zeros([all_num, all_num]) # tar size 40 | for i in range(all_num): 41 | source_dist[i, :] = source_dist_vec + source_dist_vec[i] 42 | del sour_tar_dist 43 | del source_dist_vec 44 | 45 | print('Computing original distance...') 46 | original_dist = pairwiseDis(feat, feat).cpu().numpy() 47 | print('done...') 48 | # original_dist = 
np.power(original_dist,2).astype(np.float16) 49 | del feat 50 | # original_dist = np.concatenate(dist,axis=0) 51 | gallery_num = original_dist.shape[0] # gallery_num=all_num 52 | original_dist = np.transpose(original_dist / (np.max(original_dist, axis=0))) 53 | V = np.zeros_like(original_dist).astype(np.float16) 54 | initial_rank = np.argsort(original_dist).astype(np.int32) ## default axis=-1. 55 | 56 | print('Starting re_ranking...') 57 | for i in tqdm.tqdm(range(all_num)): 58 | # k-reciprocal neighbors 59 | forward_k_neigh_index = initial_rank[i, 60 | :k1 + 1] ## k1+1 because self always ranks first. forward_k_neigh_index.shape=[k1+1]. forward_k_neigh_index[0] == i. 61 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, 62 | :k1 + 1] ##backward.shape = [k1+1, k1+1]. For each ele in forward_k_neigh_index, find its rank k1 neighbors 63 | fi = np.where(backward_k_neigh_index == i)[0] 64 | k_reciprocal_index = forward_k_neigh_index[fi] ## get R(p,k) in the paper 65 | k_reciprocal_expansion_index = k_reciprocal_index 66 | for j in range(len(k_reciprocal_index)): 67 | candidate = k_reciprocal_index[j] 68 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 69 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 70 | :int(np.around(k1 / 2)) + 1] 71 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 72 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 73 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 74 | candidate_k_reciprocal_index): 75 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 76 | 77 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 78 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 79 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 80 | # original_dist = original_dist[:query_num,] 81 | if k2 != 1: 82 | V_qe = np.zeros_like(V, dtype=np.float16) 83 | for i in tqdm.tqdm(range(all_num)): 84 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 85 | V = V_qe 86 | del V_qe 87 | del initial_rank 88 | invIndex = [] 89 | for i in tqdm.tqdm(range(gallery_num)): 90 | invIndex.append(np.where(V[:, i] != 0)[0]) # len(invIndex)=all_num 91 | 92 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 93 | 94 | for i in tqdm.tqdm(range(all_num)): 95 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 96 | indNonZero = np.where(V[i, :] != 0)[0] 97 | indImages = [] 98 | indImages = [invIndex[ind] for ind in indNonZero] 99 | for j in range(len(indNonZero)): 100 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 101 | V[indImages[j], indNonZero[j]]) 102 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 103 | 104 | pos_bool = (jaccard_dist < 0) 105 | jaccard_dist[pos_bool] = 0.0 106 | 107 | if lambda_value == 0: 108 | return jaccard_dist 109 | else: 110 | final_dist = jaccard_dist * (1 - lambda_value) + source_dist * lambda_value 111 | return final_dist 112 | -------------------------------------------------------------------------------- /reid/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | 5 | from ..utils.data import Dataset 6 | from ..utils.osutils import 
mkdir_if_missing 7 | from ..utils.serialization import read_json 8 | from ..utils.serialization import write_json 9 | 10 | 11 | def _pluck(identities, indices, relabel=False): 12 | """Extract im names of given pids. 13 | Args: 14 | identities: containing im names 15 | indices: pids 16 | relabel: whether to transform pids to classification labels 17 | """ 18 | ret = [] 19 | for index, pid in enumerate(indices): 20 | pid_images = identities[pid] 21 | for camid, cam_images in enumerate(pid_images): 22 | for fname in cam_images: 23 | name = osp.splitext(fname)[0] 24 | x, y, _ = map(int, name.split('_')) 25 | assert pid == x and camid == y 26 | if relabel: 27 | ret.append((fname, index, camid)) 28 | else: 29 | ret.append((fname, pid, camid)) 30 | return ret 31 | 32 | 33 | ######################## 34 | 35 | 36 | class Market1501(Dataset): 37 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view' 38 | md5 = '65005ab7d12ec1c44de4eeafe813e68a' 39 | 40 | def __init__(self, root, split_id=0, num_val=100, download=True): 41 | super(Market1501, self).__init__(root, split_id=split_id) 42 | 43 | if download: 44 | self.download() 45 | 46 | if not self._check_integrity(): 47 | raise RuntimeError("Dataset not found or corrupted. " + 48 | "You can use download=True to download it.") 49 | 50 | self.load(num_val) 51 | 52 | def download(self): 53 | if self._check_integrity(): 54 | print("Files already downloaded and verified") 55 | return 56 | 57 | import re 58 | import hashlib 59 | import shutil 60 | from glob import glob 61 | from zipfile import ZipFile 62 | 63 | raw_dir = osp.join(self.root, 'raw') 64 | mkdir_if_missing(raw_dir) 65 | 66 | # Download the raw zip file 67 | fpath = osp.join(raw_dir, 'Market-1501-v15.09.15.zip') 68 | if osp.isfile(fpath) and \ 69 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 70 | print("Using downloaded file: " + fpath) 71 | else: 72 | raise RuntimeError("Please download the dataset manually from {} " 73 | "to {}".format(self.url, fpath)) 74 | 75 | # Extract the file 76 | exdir = osp.join(raw_dir, 'Market-1501-v15.09.15') 77 | if not osp.isdir(exdir): 78 | print("Extracting zip file") 79 | with ZipFile(fpath) as z: 80 | z.extractall(path=raw_dir) 81 | 82 | # Format 83 | images_dir = osp.join(self.root, 'images') 84 | mkdir_if_missing(images_dir) 85 | 86 | # 1501 identities (+1 for background) with 6 camera views each 87 | identities = [[[] for _ in range(6)] for _ in range(1502)] 88 | 89 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 90 | fnames = [] ######### Added. Names of images in new dir. 
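# Images copied below are renamed to '{pid:08d}_{cam:02d}_{idx:04d}.jpg'
# (cameras re-indexed from 0): e.g. the 3rd image of pid 23 under camera 5
# becomes '00000023_04_0002.jpg'. _pluck() above re-parses exactly this pattern:
#   pid, cam, idx = map(int, osp.splitext(fname)[0].split('_'))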
91 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 92 | pids = set() 93 | for fpath in fpaths: 94 | fname = osp.basename(fpath) 95 | pid, cam = map(int, pattern.search(fname).groups()) 96 | if pid == -1: continue # junk images are just ignored 97 | assert 0 <= pid <= 1501 # pid == 0 means background 98 | assert 1 <= cam <= 6 99 | cam -= 1 100 | pids.add(pid) 101 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 102 | .format(pid, cam, len(identities[pid][cam]))) 103 | identities[pid][cam].append(fname) 104 | shutil.copy(fpath, osp.join(images_dir, fname)) 105 | fnames.append(fname) ######### Added 106 | return pids, fnames 107 | 108 | trainval_pids, _ = register('bounding_box_train') 109 | gallery_pids, gallery_fnames = register('bounding_box_test') 110 | query_pids, query_fnames = register('query') 111 | assert query_pids <= gallery_pids 112 | assert trainval_pids.isdisjoint(gallery_pids) 113 | 114 | # Save meta information into a json file 115 | meta = {'name': 'Market1501', 'shot': 'multiple', 'num_cameras': 6, 116 | 'identities': identities, 117 | 'query_fnames': query_fnames, ######### Added 118 | 'gallery_fnames': gallery_fnames} ######### Added 119 | write_json(meta, osp.join(self.root, 'meta.json')) 120 | 121 | # Save the only training / test split 122 | splits = [{ 123 | 'trainval': sorted(list(trainval_pids)), 124 | 'query': sorted(list(query_pids)), 125 | 'gallery': sorted(list(gallery_pids))}] 126 | write_json(splits, osp.join(self.root, 'splits.json')) 127 | 128 | ######################## 129 | # Added 130 | def load(self, num_val=0.3, verbose=True): 131 | splits = read_json(osp.join(self.root, 'splits.json')) 132 | if self.split_id >= len(splits): 133 | raise ValueError("split_id exceeds total splits {}" 134 | .format(len(splits))) 135 | self.split = splits[self.split_id] 136 | 137 | # Randomly split train / val 138 | trainval_pids = np.asarray(self.split['trainval']) 139 | np.random.shuffle(trainval_pids) 140 | num = len(trainval_pids) 141 | if isinstance(num_val, float): 142 | num_val = int(round(num * num_val)) 143 | if num_val >= num or num_val < 0: 144 | raise ValueError("num_val exceeds total identities {}" 145 | .format(num)) 146 | train_pids = sorted(trainval_pids[:-num_val]) 147 | val_pids = sorted(trainval_pids[-num_val:]) 148 | 149 | self.meta = read_json(osp.join(self.root, 'meta.json')) 150 | identities = self.meta['identities'] 151 | 152 | self.train = _pluck(identities, train_pids, relabel=True) 153 | self.val = _pluck(identities, val_pids, relabel=True) 154 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 155 | self.num_train_ids = len(train_pids) 156 | self.num_val_ids = len(val_pids) 157 | self.num_trainval_ids = len(trainval_pids) 158 | 159 | ########## 160 | # Added 161 | query_fnames = self.meta['query_fnames'] 162 | gallery_fnames = self.meta['gallery_fnames'] 163 | self.query = [] 164 | for fname in query_fnames: 165 | name = osp.splitext(fname)[0] 166 | pid, cam, _ = map(int, name.split('_')) 167 | self.query.append((fname, pid, cam)) 168 | self.gallery = [] 169 | for fname in gallery_fnames: 170 | name = osp.splitext(fname)[0] 171 | pid, cam, _ = map(int, name.split('_')) 172 | self.gallery.append((fname, pid, cam)) 173 | ########## 174 | 175 | if verbose: 176 | print(self.__class__.__name__, "dataset loaded") 177 | print(" subset | # ids | # images") 178 | print(" ---------------------------") 179 | print(" train | {:5d} | {:8d}" 180 | .format(self.num_train_ids, len(self.train))) 181 | print(" val | {:5d} | {:8d}" 182 | 
.format(self.num_val_ids, len(self.val))) 183 | print(" trainval | {:5d} | {:8d}" 184 | .format(self.num_trainval_ids, len(self.trainval))) 185 | print(" query | {:5d} | {:8d}" 186 | .format(len(self.split['query']), len(self.query))) 187 | print(" gallery | {:5d} | {:8d}" 188 | .format(len(self.split['gallery']), len(self.gallery))) 189 | ######################## 190 | -------------------------------------------------------------------------------- /reid/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import read_json 7 | from ..utils.serialization import write_json 8 | 9 | 10 | def _pluck(identities, indices, relabel=False): 11 | """Extract im names of given pids. 12 | Args: 13 | identities: containing im names 14 | indices: pids 15 | relabel: whether to transform pids to classification labels 16 | """ 17 | ret = [] 18 | for index, pid in enumerate(indices): 19 | pid_images = identities[pid] 20 | for camid, cam_images in enumerate(pid_images): 21 | for fname in cam_images: 22 | name = osp.splitext(fname)[0] 23 | x, y, _ = map(int, name.split('_')) 24 | assert pid == x and camid == y 25 | if relabel: 26 | ret.append((fname, index, camid)) 27 | else: 28 | ret.append((fname, pid, camid)) 29 | return ret 30 | 31 | 32 | class DukeMTMC(Dataset): 33 | url = 'https://drive.google.com/uc?id=0B0VOCNYh8HeRdnBPa2ZWaVBYSVk' 34 | md5 = '2f93496f9b516d1ee5ef51c1d5e7d601' 35 | 36 | def __init__(self, root, split_id=0, num_val=100, download=True): 37 | super(DukeMTMC, self).__init__(root, split_id=split_id) 38 | 39 | if download: 40 | self.download() 41 | 42 | if not self._check_integrity(): 43 | raise RuntimeError("Dataset not found or corrupted. " + 44 | "You can use download=True to download it.") 45 | 46 | self.load(num_val) 47 | 48 | def download(self): 49 | if self._check_integrity(): 50 | print("Files already downloaded and verified") 51 | return 52 | 53 | import re 54 | import hashlib 55 | import shutil 56 | from glob import glob 57 | from zipfile import ZipFile 58 | 59 | raw_dir = osp.join(self.root, 'raw') 60 | mkdir_if_missing(raw_dir) 61 | 62 | # Download the raw zip file 63 | fpath = osp.join(raw_dir, 'DukeMTMC-reID.zip') 64 | if osp.isfile(fpath) and \ 65 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 66 | print("Using downloaded file: " + fpath) 67 | else: 68 | raise RuntimeError("Please download the dataset manually from {} " 69 | "to {}".format(self.url, fpath)) 70 | 71 | # Extract the file 72 | exdir = osp.join(raw_dir, 'DukeMTMC-reID') 73 | if not osp.isdir(exdir): 74 | print("Extracting zip file") 75 | with ZipFile(fpath) as z: 76 | z.extractall(path=raw_dir) 77 | 78 | # Format 79 | images_dir = osp.join(self.root, 'images') 80 | mkdir_if_missing(images_dir) 81 | 82 | identities = [] 83 | all_pids = {} 84 | 85 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 86 | fnames = [] ###### New Add. 
Names of images in new dir 87 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 88 | pids = set() 89 | for fpath in fpaths: 90 | fname = osp.basename(fpath) 91 | pid, cam = map(int, pattern.search(fname).groups()) 92 | assert 1 <= cam <= 8 93 | cam -= 1 94 | if pid not in all_pids: 95 | all_pids[pid] = len(all_pids) 96 | pid = all_pids[pid] 97 | pids.add(pid) 98 | if pid >= len(identities): 99 | assert pid == len(identities) 100 | identities.append([[] for _ in range(8)]) # 8 camera views 101 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 102 | .format(pid, cam, len(identities[pid][cam]))) 103 | identities[pid][cam].append(fname) 104 | shutil.copy(fpath, osp.join(images_dir, fname)) 105 | fnames.append(fname) ######## added 106 | return pids, fnames 107 | 108 | trainval_pids, _ = register('bounding_box_train') 109 | gallery_pids, gallery_fnames = register('bounding_box_test') 110 | query_pids, query_fnames = register('query') 111 | assert query_pids <= gallery_pids 112 | assert trainval_pids.isdisjoint(gallery_pids) 113 | 114 | # Save meta information into a json file 115 | meta = {'name': 'DukeMTMC', 'shot': 'multiple', 'num_cameras': 8, 116 | 'identities': identities, 117 | 'query_fnames': query_fnames, ########## Added 118 | 'gallery_fnames': gallery_fnames} ######### Added 119 | write_json(meta, osp.join(self.root, 'meta.json')) 120 | 121 | # Save the only training / test split 122 | splits = [{ 123 | 'trainval': sorted(list(trainval_pids)), 124 | 'query': sorted(list(query_pids)), 125 | 'gallery': sorted(list(gallery_pids))}] 126 | write_json(splits, osp.join(self.root, 'splits.json')) 127 | 128 | ######################## 129 | # Added 130 | def load(self, num_val=0.3, verbose=True): 131 | import numpy as np 132 | splits = read_json(osp.join(self.root, 'splits.json')) 133 | if self.split_id >= len(splits): 134 | raise ValueError("split_id exceeds total splits {}" 135 | .format(len(splits))) 136 | self.split = splits[self.split_id] 137 | 138 | # Randomly split train / val 139 | trainval_pids = np.asarray(self.split['trainval']) 140 | np.random.shuffle(trainval_pids) 141 | num = len(trainval_pids) 142 | if isinstance(num_val, float): 143 | num_val = int(round(num * num_val)) 144 | if num_val >= num or num_val < 0: 145 | raise ValueError("num_val exceeds total identities {}" 146 | .format(num)) 147 | train_pids = sorted(trainval_pids[:-num_val]) 148 | val_pids = sorted(trainval_pids[-num_val:]) 149 | 150 | self.meta = read_json(osp.join(self.root, 'meta.json')) 151 | identities = self.meta['identities'] 152 | 153 | self.train = _pluck(identities, train_pids, relabel=True) 154 | self.val = _pluck(identities, val_pids, relabel=True) 155 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 156 | self.num_train_ids = len(train_pids) 157 | self.num_val_ids = len(val_pids) 158 | self.num_trainval_ids = len(trainval_pids) 159 | 160 | ########## 161 | # Added 162 | query_fnames = self.meta['query_fnames'] 163 | gallery_fnames = self.meta['gallery_fnames'] 164 | self.query = [] 165 | for fname in query_fnames: 166 | name = osp.splitext(fname)[0] 167 | pid, cam, _ = map(int, name.split('_')) 168 | self.query.append((fname, pid, cam)) 169 | self.gallery = [] 170 | for fname in gallery_fnames: 171 | name = osp.splitext(fname)[0] 172 | pid, cam, _ = map(int, name.split('_')) 173 | self.gallery.append((fname, pid, cam)) 174 | ########## 175 | 176 | if verbose: 177 | print(self.__class__.__name__, "dataset loaded") 178 | print(" subset | # ids | # images") 179 | print(" 
---------------------------") 180 | print(" train | {:5d} | {:8d}" 181 | .format(self.num_train_ids, len(self.train))) 182 | print(" val | {:5d} | {:8d}" 183 | .format(self.num_val_ids, len(self.val))) 184 | print(" trainval | {:5d} | {:8d}" 185 | .format(self.num_trainval_ids, len(self.trainval))) 186 | print(" query | {:5d} | {:8d}" 187 | .format(len(self.split['query']), len(self.query))) 188 | print(" gallery | {:5d} | {:8d}" 189 | .format(len(self.split['gallery']), len(self.gallery))) 190 | ######################## 191 | -------------------------------------------------------------------------------- /selftrainingKmeans.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function, absolute_import 5 | import argparse 6 | import time 7 | import os.path as osp 8 | import os 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from torch.nn import init 13 | from torch.backends import cudnn 14 | from torch.utils.data import DataLoader 15 | from reid import datasets 16 | from reid import models 17 | from reid.dist_metric import DistanceMetric 18 | from reid.loss import TripletLoss 19 | from reid.trainers import Trainer 20 | from reid.evaluators import Evaluator, extract_features 21 | from reid.utils.data import transforms as T 22 | import torch.nn.functional as F 23 | from reid.utils.data.preprocessor import Preprocessor 24 | from reid.utils.data.sampler import RandomIdentitySampler 25 | from reid.utils.serialization import load_checkpoint, save_checkpoint 26 | 27 | from sklearn.cluster import KMeans 28 | from reid.rerank import re_ranking 29 | 30 | 31 | def get_data(name, data_dir, height, width, batch_size, 32 | workers): 33 | root = osp.join(data_dir, name) 34 | 35 | dataset = datasets.create(name, root, num_val=0.1) 36 | 37 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 38 | std=[0.229, 0.224, 0.225]) 39 | 40 | # use all training and validation images in target dataset 41 | train_set = dataset.trainval 42 | num_classes = dataset.num_trainval_ids 43 | 44 | transformer = T.Compose([ 45 | T.Resize((height,width)), 46 | T.ToTensor(), 47 | normalizer, 48 | ]) 49 | 50 | extfeat_loader = DataLoader( 51 | Preprocessor(train_set, root=dataset.images_dir, 52 | transform=transformer), 53 | batch_size=batch_size, num_workers=workers, 54 | shuffle=False, pin_memory=True) 55 | 56 | test_loader = DataLoader( 57 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 58 | root=dataset.images_dir, transform=transformer), 59 | batch_size=batch_size//2, num_workers=workers, 60 | shuffle=False, pin_memory=True) 61 | 62 | return dataset, num_classes, extfeat_loader, test_loader 63 | 64 | 65 | def get_source_data(name, data_dir, height, width, batch_size, workers): 66 | root = osp.join(data_dir, name) 67 | 68 | dataset = datasets.create(name, root, num_val=0.1) 69 | 70 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 71 | std=[0.229, 0.224, 0.225]) 72 | 73 | # use all training images on source dataset 74 | train_set = dataset.train 75 | num_classes = dataset.num_train_ids 76 | 77 | transformer = T.Compose([ 78 | T.Resize((height,width)), 79 | T.ToTensor(), 80 | normalizer, 81 | ]) 82 | 83 | extfeat_loader = DataLoader( 84 | Preprocessor(train_set, root=dataset.images_dir, 85 | transform=transformer), 86 | batch_size=batch_size, num_workers=workers, 87 | shuffle=False, pin_memory=True) 88 | 89 | return dataset, extfeat_loader 90 | 91 | 92 | def 
splitLowconfi(feature, labels, centers, ratio=0.2): 93 | # set bot 20% imsimilar samples to -1 94 | # center VS feature 95 | centerDis = calDis(torch.from_numpy(feature), torch.from_numpy(centers)).numpy() # center VS samples 96 | noiseLoc = [] 97 | for ii, pid in enumerate(set(labels)): 98 | curDis = centerDis[:,ii] 99 | curDis[labels!=pid] = 100 100 | smallLossIdx = curDis.argsort() 101 | smallLossIdx = smallLossIdx[curDis[smallLossIdx]!=100] 102 | # bot 20% removed 103 | partSize = int(ratio*smallLossIdx.shape[0]) 104 | if partSize!=0: 105 | noiseLoc.extend(smallLossIdx[-partSize:]) 106 | labels[noiseLoc] = -1 107 | return labels 108 | 109 | 110 | def calDis(qFeature, gFeature):#246s 111 | x, y = F.normalize(qFeature), F.normalize(gFeature) 112 | # x, y = qFeature, gFeature 113 | m, n = x.shape[0], y.shape[0] 114 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 115 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 116 | disMat.addmm_(1, -2, x, y.t()) 117 | return disMat.clamp_(min=1e-5) 118 | 119 | 120 | def labelUnknown(knownFeat, allLab, unknownFeat): 121 | disMat = calDis(knownFeat, unknownFeat) 122 | labLoc = disMat.argmin(dim=0) 123 | return allLab[labLoc] 124 | 125 | 126 | def labelNoise(feature, labels): 127 | # features and labels with -1 128 | noiseFeat, pureFeat = feature[labels==-1,:], feature[labels!=-1,:] 129 | labels = labels[labels!=-1] 130 | unLab = labelUnknown(pureFeat, labels, noiseFeat) 131 | return unLab.numpy() 132 | 133 | 134 | def main(args): 135 | np.random.seed(args.seed) 136 | torch.manual_seed(args.seed) 137 | cudnn.benchmark = True 138 | 139 | # Create data loaders 140 | assert args.num_instances > 1, "num_instances should be greater than 1" 141 | assert args.batch_size % args.num_instances == 0, \ 142 | 'num_instances should divide batch_size' 143 | if args.height is None or args.width is None: 144 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 145 | (256, 128) 146 | 147 | # get source data 148 | src_dataset, src_extfeat_loader = \ 149 | get_source_data(args.src_dataset, args.data_dir, args.height, 150 | args.width, args.batch_size, args.workers) 151 | # get target data 152 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \ 153 | get_data(args.tgt_dataset, args.data_dir, args.height, 154 | args.width, args.batch_size, args.workers) 155 | 156 | # Create model 157 | # Hacking here to let the classifier be the number of source ids 158 | if args.src_dataset == 'dukemtmc': 159 | model = models.create(args.arch, num_classes=632, pretrained=False) 160 | elif args.src_dataset == 'market1501': 161 | model = models.create(args.arch, num_classes=676, pretrained=False) 162 | else: 163 | raise RuntimeError('Please specify the number of classes (ids) of the network.') 164 | 165 | # Load from checkpoint 166 | if args.resume: 167 | print('Resuming checkpoints from finetuned model on another dataset...\n') 168 | checkpoint = load_checkpoint(args.resume) 169 | model.load_state_dict(checkpoint['state_dict'], strict=False) 170 | else: 171 | raise RuntimeWarning('Not using a pre-trained model.') 172 | model = nn.DataParallel(model).cuda() 173 | 174 | # evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 175 | # if args.evaluate: return 176 | 177 | # Criterion 178 | criterion = [ 179 | TripletLoss(args.margin, args.num_instances, isAvg=True, use_semi=True).cuda(), 180 | TripletLoss(args.margin, args.num_instances, isAvg=True, use_semi=True).cuda(), 181 | ] 182 | 183 | 184 | # Optimizer 185 | 
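# A single Adam optimizer over all parameters (lr defaults to 6e-5, see --lr).
# A common variant -- sketched here under the assumption that the pretrained
# ResNet backbone is model.module.base and the embedding layer model.module.feat,
# and NOT what this script actually does -- uses a smaller backbone learning rate:
#   optimizer = torch.optim.Adam([
#       {'params': model.module.base.parameters(), 'lr': 0.1 * args.lr},
#       {'params': model.module.feat.parameters()},
#   ], lr=args.lr)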
optimizer = torch.optim.Adam( 186 | model.parameters(), lr = args.lr 187 | ) 188 | 189 | 190 | # training stage transformer on input images 191 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 192 | train_transformer = T.Compose([ 193 | T.Resize((args.height,args.width)), 194 | T.RandomHorizontalFlip(), 195 | T.ToTensor(), normalizer, 196 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3) 197 | ]) 198 | 199 | 200 | # # Start training 201 | for iter_n in range(args.iteration): 202 | if args.lambda_value == 0: 203 | source_features = 0 204 | else: 205 | # get source datas' feature 206 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq) 207 | # synchronization feature order with src_dataset.train 208 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0) 209 | 210 | # extract training images' features 211 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n+1)) 212 | target_features, tarNames = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq) 213 | # synchronization feature order with dataset.train 214 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0) 215 | target_real_label = np.asarray([tarNames[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval]) 216 | numTarID = len(set(target_real_label)) 217 | # calculate distance and rerank result 218 | print('Calculating feature distances...') 219 | target_features = target_features.numpy() 220 | cluster = KMeans(n_clusters=numTarID, n_jobs=8, n_init=1) 221 | 222 | # select & cluster images as training set of this epochs 223 | print('Clustering and labeling...') 224 | clusterRes = cluster.fit(target_features) 225 | labels, centers = clusterRes.labels_, clusterRes.cluster_centers_ 226 | # labels = splitLowconfi(target_features,labels,centers) 227 | # num_ids = len(set(labels)) 228 | # print('Iteration {} have {} training ids'.format(iter_n+1, num_ids)) 229 | # generate new dataset 230 | new_dataset = [] 231 | for (fname, _, cam), label in zip(tgt_dataset.trainval, labels): 232 | # if label==-1: continue 233 | # dont need to change codes in trainer.py _parsing_input function and sampler function after add 0 234 | new_dataset.append((fname,label,cam)) 235 | print('Iteration {} have {} training images'.format(iter_n+1, len(new_dataset))) 236 | train_loader = DataLoader( 237 | Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer), 238 | batch_size=args.batch_size, num_workers=4, 239 | sampler=RandomIdentitySampler(new_dataset, args.num_instances), 240 | pin_memory=True, drop_last=True 241 | ) 242 | 243 | # train model with new generated dataset 244 | trainer = Trainer(model, criterion) 245 | 246 | 247 | evaluator = Evaluator(model, print_freq=args.print_freq) 248 | 249 | # Start training 250 | for epoch in range(args.epochs): 251 | # trainer.train(epoch, remRate=0.2+(0.6/args.iteration)*(1+iter_n)) # to at most 80% 252 | trainer.train(epoch, train_loader, optimizer) 253 | # test only 254 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 255 | #print('co-model:\n') 256 | #rank_score = evaluatorB.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 257 | 258 | # Evaluate 259 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 260 | save_checkpoint({ 261 | 'state_dict': model.module.state_dict(), 262 | 'epoch': epoch + 1, 
'best_top1': rank_score.market1501[0], 263 | }, True, fpath=osp.join(args.logs_dir, 'adapted.pth.tar')) 264 | return (rank_score.map, rank_score.market1501[0]) 265 | 266 | 267 | if __name__ == '__main__': 268 | parser = argparse.ArgumentParser(description="Triplet loss classification") 269 | # data 270 | parser.add_argument('--src_dataset', type=str, default='dukemtmc', 271 | choices=datasets.names()) 272 | parser.add_argument('--tgt_dataset', type=str, default='market1501', 273 | choices=datasets.names()) 274 | parser.add_argument('--batch_size', type=int, default=64) 275 | parser.add_argument('--workers', type=int, default=4) 276 | parser.add_argument('--split', type=int, default=0) 277 | parser.add_argument('--noiseLam', type=float, default=0.5) 278 | parser.add_argument('--height', type=int, 279 | help="input height, default: 256 for resnet*, " 280 | "144 for inception") 281 | parser.add_argument('--width', type=int, 282 | help="input width, default: 128 for resnet*, " 283 | "56 for inception") 284 | parser.add_argument('--combine-trainval', action='store_true', 285 | help="train and val sets together for training, " 286 | "val set alone for validation") 287 | parser.add_argument('--num_instances', type=int, default=4, 288 | help="each minibatch consist of " 289 | "(batch_size // num_instances) identities, and " 290 | "each identity has num_instances instances, " 291 | "default: 4") 292 | # model 293 | parser.add_argument('--arch', type=str, default='resnet50', 294 | choices=models.names()) 295 | # loss 296 | parser.add_argument('--margin', type=float, default=0.5, 297 | help="margin of the triplet loss, default: 0.5") 298 | parser.add_argument('--lambda_value', type=float, default=0.1, 299 | help="balancing parameter, default: 0.1") 300 | parser.add_argument('--rho', type=float, default=1.6e-3, 301 | help="rho percentage, default: 1.6e-3") 302 | # optimizer 303 | parser.add_argument('--lr', type=float, default=6e-5, 304 | help="learning rate of all parameters") 305 | # training configs 306 | parser.add_argument('--resume', type=str, metavar='PATH', 307 | default = '') 308 | parser.add_argument('--evaluate', type=int, default=0, 309 | help="evaluation only") 310 | parser.add_argument('--seed', type=int, default=1) 311 | parser.add_argument('--print_freq', type=int, default=1) 312 | parser.add_argument('--iteration', type=int, default=10) 313 | parser.add_argument('--epochs', type=int, default=30) 314 | # metric learning 315 | parser.add_argument('--dist_metric', type=str, default='euclidean', 316 | choices=['euclidean', 'kissme']) 317 | # misc 318 | parser.add_argument('--data_dir', type=str, metavar='PATH', 319 | default='') 320 | parser.add_argument('--logs_dir', type=str, metavar='PATH', 321 | default='') 322 | 323 | args = parser.parse_args() 324 | mean_ap, rank1 = main(args) 325 | results_file = np.asarray([mean_ap, rank1]) 326 | file_name = time.strftime("%H%M%S", time.localtime()) 327 | file_name = osp.join(args.logs_dir, file_name) 328 | np.save(file_name, results_file) 329 | -------------------------------------------------------------------------------- /selfNoise.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function, absolute_import 5 | import argparse 6 | import time 7 | import os.path as osp 8 | import os 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from torch.nn import init 13 | from torch.backends import cudnn 14 | 
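# selfNoise.py performs self-training with DBSCAN pseudo-labels: DBSCAN marks
# outliers as -1 and labelNoise() below re-assigns each outlier to the label of
# its nearest confident sample. calScores() below measures a clustering against
# ground-truth ids via pair-wise precision/recall and F-score (note it returns
# only two values when `clusters` is empty). Hypothetical usage sketch:
#   p, r, f = calScores({cid: np.where(labels == cid)[0] for cid in set(labels)}, gt_pids)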
from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.dist_metric import DistanceMetric
18 | from reid.loss import TripletLoss
19 | from reid.trainers import Trainer
20 | from reid.evaluators import Evaluator, extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 | 
27 | from sklearn.cluster import DBSCAN
28 | from reid.rerank import re_ranking
29 | 
30 | 
31 | def calScores(clusters, labels):
32 |     """
33 |     compute pair-wise precision, pair-wise recall and F-score
34 |     """
35 |     from scipy.special import comb
36 |     if len(clusters) == 0:
37 |         return 0, 0, 0  # keep arity in sync with callers, which unpack three values
38 |     else:
39 |         curCluster = []
40 |         for curClus in clusters.values():
41 |             curCluster.append(labels[curClus])
42 |         TPandFP = sum([comb(len(val), 2) for val in curCluster])
43 |         TP = 0
44 |         for clusterVal in curCluster:
45 |             for setMember in set(clusterVal):
46 |                 if sum(clusterVal == setMember) < 2: continue
47 |                 TP += comb(sum(clusterVal == setMember), 2)
48 |         FP = TPandFP - TP
49 |         # FN and TN
50 |         TPandFN = sum([comb(labels.tolist().count(val), 2) for val in set(labels)])
51 |         FN = TPandFN - TP
52 |         # compute precision, recall and F-score
53 |         precision, recall = TP / (TP + FP), TP / (TP + FN)
54 |         fScore = 2 * precision * recall / (precision + recall)
55 |         return precision, recall, fScore
56 | 
57 | 
58 | def get_data(name, data_dir, height, width, batch_size,
59 |              workers):
60 |     root = osp.join(data_dir, name)
61 | 
62 |     dataset = datasets.create(name, root, num_val=0.1)
63 | 
64 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
65 |                              std=[0.229, 0.224, 0.225])
66 | 
67 |     # use all training and validation images in target dataset
68 |     train_set = dataset.trainval
69 |     num_classes = dataset.num_trainval_ids
70 | 
71 |     transformer = T.Compose([
72 |         T.Resize((height, width)),
73 |         T.ToTensor(),
74 |         normalizer,
75 |     ])
76 | 
77 |     extfeat_loader = DataLoader(
78 |         Preprocessor(train_set, root=dataset.images_dir,
79 |                      transform=transformer),
80 |         batch_size=batch_size, num_workers=workers,
81 |         shuffle=False, pin_memory=True)
82 | 
83 |     test_loader = DataLoader(
84 |         Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
85 |                      root=dataset.images_dir, transform=transformer),
86 |         batch_size=batch_size, num_workers=workers,
87 |         shuffle=False, pin_memory=True)
88 | 
89 |     return dataset, num_classes, extfeat_loader, test_loader
90 | 
91 | 
92 | def get_source_data(name, data_dir, height, width, batch_size,
93 |                     workers):
94 |     root = osp.join(data_dir, name)
95 | 
96 |     dataset = datasets.create(name, root, num_val=0.1)
97 | 
98 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
99 |                              std=[0.229, 0.224, 0.225])
100 | 
101 |     # use all training images on source dataset
102 |     train_set = dataset.train
103 |     num_classes = dataset.num_train_ids
104 | 
105 |     transformer = T.Compose([
106 |         T.Resize((height, width)),
107 |         T.ToTensor(),
108 |         normalizer,
109 |     ])
110 | 
111 |     extfeat_loader = DataLoader(
112 |         Preprocessor(train_set, root=dataset.images_dir,
113 |                      transform=transformer),
114 |         batch_size=batch_size, num_workers=workers,
115 |         shuffle=False, pin_memory=True)
116 | 
117 |     return dataset, extfeat_loader
118 | 
119 | 
120 | def calDis(qFeature, gFeature):  # 246s
121 |     x, y = F.normalize(qFeature), F.normalize(gFeature)
122 |     # x, y = qFeature, gFeature
123 |     m, n = 
x.shape[0], y.shape[0] 124 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 125 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 126 | disMat.addmm_(1, -2, x, y.t()) 127 | return disMat.clamp_(min=1e-5) 128 | 129 | 130 | def labelUnknown(knownFeat, allLab, unknownFeat): 131 | # allLab--label from known 132 | disMat = calDis(knownFeat, unknownFeat) 133 | labLoc = disMat.argmin(dim=0) 134 | return allLab[labLoc] 135 | 136 | 137 | def labelNoise(feature, labels): 138 | # features and labels with -1 139 | noiseFeat, pureFeat = feature[labels == -1, :], feature[labels != -1, :] 140 | pureLabs = labels[labels != -1] # no outliers 141 | unLab = labelUnknown(pureFeat, pureLabs, noiseFeat) 142 | labels[labels == -1] = unLab 143 | return labels.numpy() 144 | 145 | 146 | def getCenter(features, labels): 147 | allCenter = {} 148 | features = features[labels != -1, :] 149 | labels = labels[labels != -1] 150 | for pid in set(labels): 151 | allCenter[pid] = torch.from_numpy(features[labels == pid, :].mean(axis=0)).unsqueeze(0) 152 | return torch.cat(list(allCenter.values())) 153 | 154 | 155 | def main(args): 156 | np.random.seed(args.seed) 157 | torch.manual_seed(args.seed) 158 | cudnn.benchmark = True 159 | 160 | # Create data loaders 161 | assert args.num_instances > 1, "num_instances should be greater than 1" 162 | assert args.batch_size % args.num_instances == 0, \ 163 | 'num_instances should divide batch_size' 164 | if args.height is None or args.width is None: 165 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 166 | (256, 128) 167 | 168 | # get source data 169 | src_dataset, src_extfeat_loader = \ 170 | get_source_data(args.src_dataset, args.data_dir, args.height, 171 | args.width, args.batch_size, args.workers) 172 | # get target data 173 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \ 174 | get_data(args.tgt_dataset, args.data_dir, args.height, 175 | args.width, args.batch_size, args.workers) 176 | 177 | # Create model 178 | # Hacking here to let the classifier be the number of source ids 179 | if args.src_dataset == 'dukemtmc': 180 | model = models.create(args.arch, num_classes=632, pretrained=False) 181 | elif args.src_dataset == 'market1501': 182 | model = models.create(args.arch, num_classes=676, pretrained=False) 183 | else: 184 | raise RuntimeError('Please specify the number of classes (ids) of the network.') 185 | 186 | # Load from checkpoint 187 | if args.resume: 188 | print('Resuming checkpoints from finetuned model on another dataset...\n') 189 | checkpoint = load_checkpoint(args.resume) 190 | model.load_state_dict(checkpoint['state_dict'], strict=False) 191 | else: 192 | raise RuntimeWarning('Not using a pre-trained model.') 193 | model = nn.DataParallel(model).cuda() 194 | 195 | # Criterion 196 | criterion = [ 197 | TripletLoss(args.margin, args.num_instances, use_semi=False).cuda(), 198 | TripletLoss(args.margin, args.num_instances, use_semi=False).cuda() 199 | ] 200 | optimizer = torch.optim.Adam( 201 | model.parameters(), lr=args.lr 202 | ) 203 | 204 | # training stage transformer on input images 205 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 206 | train_transformer = T.Compose([ 207 | T.Resize((args.height, args.width)), 208 | T.RandomHorizontalFlip(), 209 | T.ToTensor(), normalizer, 210 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3) 211 | ]) 212 | 213 | # # Start training 214 | for iter_n in range(args.iteration): 215 | if args.lambda_value == 0: 216 | 
            source_features = 0
217 |         else:
218 |             # get source data's features
219 |             source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq)
220 |             # synchronize feature order with src_dataset.train
221 |             source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0)
222 | 
223 |         # extract training images' features
224 |         print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n + 1))
225 |         target_features, _ = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq)
226 |         # synchronize feature order with dataset.train
227 |         target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0)
228 |         # calculate distance and rerank result
229 |         print('Calculating feature distances...')
230 |         target_features = target_features.numpy()
231 |         rerank_dist = re_ranking(source_features, target_features, lambda_value=args.lambda_value)
232 |         if iter_n == 0:
233 |             # DBSCAN cluster
234 |             tri_mat = np.triu(rerank_dist, 1)  # tri_mat.dim=2
235 |             tri_mat = tri_mat[np.nonzero(tri_mat)]  # tri_mat.dim=1
236 |             tri_mat = np.sort(tri_mat, axis=None)
237 |             top_num = np.round(args.rho * tri_mat.size).astype(int)
238 |             eps = tri_mat[:top_num].mean()
239 |             print('eps in cluster: {:.3f}'.format(eps))
240 |             cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8)
241 |         # select & cluster images as the training set for this round
242 |         print('Clustering and labeling...')
243 |         labels = cluster.fit_predict(rerank_dist)
244 |         num_ids = len(set(labels)) - 1
245 |         print('Iteration {} has {} training ids'.format(iter_n + 1, num_ids))
246 |         # generate new dataset
247 |         new_dataset = []
248 |         # assign label for target ones
249 |         newLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels))
250 |         # unknownFeats = target_features[labels==-1,:]
251 |         counter = 0
252 |         from collections import defaultdict
253 |         realIDs, fakeIDs = defaultdict(list), []
254 |         for (fname, realID, cam), label in zip(tgt_dataset.trainval, newLab):
255 |             # no need to change _parsing_input in trainer.py or the sampler after this relabeling
256 |             new_dataset.append((fname, label, cam))
257 |             realIDs[realID].append(counter)
258 |             fakeIDs.append(label)
259 |             counter += 1
260 |         precision, recall, fscore = calScores(realIDs, np.asarray(fakeIDs))
261 |         print('Iteration {} has {} training images'.format(iter_n + 1, len(new_dataset)))
262 |         print(f'precision:{precision * 100}, recall:{100 * recall}, fscore:{100 * fscore}')
263 |         train_loader = DataLoader(
264 |             Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
265 |             batch_size=args.batch_size, num_workers=4,
266 |             sampler=RandomIdentitySampler(new_dataset, args.num_instances),
267 |             pin_memory=True, drop_last=True
268 |         )
269 | 
270 |         trainer = Trainer(model, criterion)
271 | 
272 |         # Start training
273 |         for epoch in range(args.epochs):
274 |             trainer.train(epoch, train_loader, optimizer)
275 |         # test only
276 |         evaluator = Evaluator(model, print_freq=args.print_freq)
277 |         # rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
278 | 
279 |     # Evaluate
280 |     rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
281 |     save_checkpoint({
282 |         'state_dict': model.module.state_dict(),
283 |         'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
284 |     }, True, fpath=osp.join(args.logs_dir, 'adapted.pth.tar'))
285 |     return rank_score.map, 
rank_score.market1501[0] 286 | 287 | 288 | if __name__ == '__main__': 289 | parser = argparse.ArgumentParser(description="Triplet loss classification") 290 | # data 291 | parser.add_argument('--src_dataset', type=str, default='dukemtmc', 292 | choices=datasets.names()) 293 | parser.add_argument('--tgt_dataset', type=str, default='market1501', 294 | choices=datasets.names()) 295 | parser.add_argument('--batch_size', type=int, default=64) 296 | parser.add_argument('--workers', type=int, default=4) 297 | parser.add_argument('--split', type=int, default=0) 298 | parser.add_argument('--noiseLam', type=float, default=0.5) 299 | parser.add_argument('--height', type=int, 300 | help="input height, default: 256 for resnet*, " 301 | "144 for inception") 302 | parser.add_argument('--width', type=int, 303 | help="input width, default: 128 for resnet*, " 304 | "56 for inception") 305 | parser.add_argument('--combine-trainval', action='store_true', 306 | help="train and val sets together for training, " 307 | "val set alone for validation") 308 | parser.add_argument('--num_instances', type=int, default=4, 309 | help="each minibatch consist of " 310 | "(batch_size // num_instances) identities, and " 311 | "each identity has num_instances instances, " 312 | "default: 4") 313 | # model 314 | parser.add_argument('--arch', type=str, default='resnet50', 315 | choices=models.names()) 316 | # loss 317 | parser.add_argument('--margin', type=float, default=0.5, 318 | help="margin of the triplet loss, default: 0.5") 319 | parser.add_argument('--lambda_value', type=float, default=0.1, 320 | help="balancing parameter, default: 0.1") 321 | parser.add_argument('--rho', type=float, default=1.6e-3, 322 | help="rho percentage, default: 1.6e-3") 323 | # optimizer 324 | parser.add_argument('--lr', type=float, default=6e-5, 325 | help="learning rate of all parameters") 326 | # training configs 327 | parser.add_argument('--resume', type=str, metavar='PATH', 328 | default='') 329 | parser.add_argument('--evaluate', type=int, default=0, 330 | help="evaluation only") 331 | parser.add_argument('--seed', type=int, default=1) 332 | parser.add_argument('--print_freq', type=int, default=1) 333 | parser.add_argument('--iteration', type=int, default=10) 334 | parser.add_argument('--epochs', type=int, default=30) 335 | # metric learning 336 | parser.add_argument('--dist_metric', type=str, default='euclidean', 337 | choices=['euclidean', 'kissme']) 338 | # misc 339 | parser.add_argument('--data_dir', type=str, metavar='PATH', 340 | default='') 341 | parser.add_argument('--logs_dir', type=str, metavar='PATH', 342 | default='') 343 | 344 | args = parser.parse_args() 345 | mean_ap, rank1 = main(args) 346 | -------------------------------------------------------------------------------- /selftrainingKmeansAsy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function, absolute_import 5 | import argparse 6 | import time 7 | import os.path as osp 8 | import os 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from torch.nn import init 13 | from torch.backends import cudnn 14 | from torch.utils.data import DataLoader 15 | from reid import datasets 16 | from reid import models 17 | from reid.dist_metric import DistanceMetric 18 | from reid.loss import TripletLoss 19 | from reid.trainers import Trainer, CoTrainerAsy, CoTeaching, CoTrainerAsySep 20 | from reid.evaluators import Evaluator, 
extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 | 
27 | from sklearn.cluster import KMeans
28 | from reid.rerank import re_ranking
29 | 
30 | 
31 | def get_data(name, data_dir, height, width, batch_size,
32 |              workers):
33 |     root = osp.join(data_dir, name)
34 | 
35 |     dataset = datasets.create(name, root, num_val=0.1)
36 | 
37 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
38 |                              std=[0.229, 0.224, 0.225])
39 | 
40 |     # use all training and validation images in target dataset
41 |     train_set = dataset.trainval
42 |     num_classes = dataset.num_trainval_ids
43 | 
44 |     transformer = T.Compose([
45 |         T.Resize((height, width)),
46 |         T.ToTensor(),
47 |         normalizer,
48 |     ])
49 | 
50 |     extfeat_loader = DataLoader(
51 |         Preprocessor(train_set, root=dataset.images_dir,
52 |                      transform=transformer),
53 |         batch_size=batch_size, num_workers=workers,
54 |         shuffle=False, pin_memory=True)
55 | 
56 |     test_loader = DataLoader(
57 |         Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
58 |                      root=dataset.images_dir, transform=transformer),
59 |         batch_size=batch_size // 2, num_workers=workers,
60 |         shuffle=False, pin_memory=True)
61 | 
62 |     return dataset, num_classes, extfeat_loader, test_loader
63 | 
64 | 
65 | def get_source_data(name, data_dir, height, width, batch_size, workers):
66 |     root = osp.join(data_dir, name)
67 | 
68 |     dataset = datasets.create(name, root, num_val=0.1)
69 | 
70 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
71 |                              std=[0.229, 0.224, 0.225])
72 | 
73 |     # use all training images on source dataset
74 |     train_set = dataset.train
75 |     num_classes = dataset.num_train_ids
76 | 
77 |     transformer = T.Compose([
78 |         T.Resize((height, width)),
79 |         T.ToTensor(),
80 |         normalizer,
81 |     ])
82 | 
83 |     extfeat_loader = DataLoader(
84 |         Preprocessor(train_set, root=dataset.images_dir,
85 |                      transform=transformer),
86 |         batch_size=batch_size, num_workers=workers,
87 |         shuffle=False, pin_memory=True)
88 | 
89 |     return dataset, extfeat_loader
90 | 
91 | 
92 | def splitLowconfi(feature, labels, centers, ratio=0.2):
93 |     # mark the bottom 20% least-confident samples (farthest from their cluster centre) as -1
94 |     # center VS feature
95 |     centerDis = calDis(torch.from_numpy(feature), torch.from_numpy(centers)).numpy()  # center VS samples
96 |     noiseLoc = []
97 |     for ii, pid in enumerate(set(labels)):
98 |         curDis = centerDis[:, ii]
99 |         curDis[labels != pid] = 100
100 |         smallLossIdx = curDis.argsort()
101 |         smallLossIdx = smallLossIdx[curDis[smallLossIdx] != 100]
102 |         # drop the farthest 20% of each cluster
103 |         partSize = int(ratio * smallLossIdx.shape[0])
104 |         if partSize != 0:
105 |             noiseLoc.extend(smallLossIdx[-partSize:])
106 |     labels[noiseLoc] = -1
107 |     return labels
108 | 
109 | 
110 | def calDis(qFeature, gFeature):  # 246s
111 |     x, y = F.normalize(qFeature), F.normalize(gFeature)
112 |     # x, y = qFeature, gFeature
113 |     m, n = x.shape[0], y.shape[0]
114 |     disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
115 |              torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
116 |     disMat.addmm_(1, -2, x, y.t())
117 |     return disMat.clamp_(min=1e-5)
118 | 
119 | 
120 | def labelUnknown(knownFeat, allLab, unknownFeat):
121 |     disMat = calDis(knownFeat, unknownFeat)
122 |     labLoc = disMat.argmin(dim=0)
123 |     return allLab[labLoc]
124 | 
125 | 
126 | def labelNoise(feature, labels):
127 |     # features and labels with -1
128 |     noiseFeat, pureFeat 
= feature[labels==-1,:], feature[labels!=-1,:] 129 | labels = labels[labels!=-1] 130 | unLab = labelUnknown(pureFeat, labels, noiseFeat) 131 | return unLab.numpy() 132 | 133 | 134 | def main(args): 135 | np.random.seed(args.seed) 136 | torch.manual_seed(args.seed) 137 | cudnn.benchmark = True 138 | 139 | # Create data loaders 140 | assert args.num_instances > 1, "num_instances should be greater than 1" 141 | assert args.batch_size % args.num_instances == 0, \ 142 | 'num_instances should divide batch_size' 143 | if args.height is None or args.width is None: 144 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 145 | (256, 128) 146 | 147 | # get source data 148 | src_dataset, src_extfeat_loader = \ 149 | get_source_data(args.src_dataset, args.data_dir, args.height, 150 | args.width, args.batch_size, args.workers) 151 | # get target data 152 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \ 153 | get_data(args.tgt_dataset, args.data_dir, args.height, 154 | args.width, args.batch_size, args.workers) 155 | 156 | # Create model 157 | # Hacking here to let the classifier be the number of source ids 158 | if args.src_dataset == 'dukemtmc': 159 | model = models.create(args.arch, num_classes=632, pretrained=False) 160 | coModel = models.create(args.arch, num_classes=632, pretrained=False) 161 | elif args.src_dataset == 'market1501': 162 | model = models.create(args.arch, num_classes=676, pretrained=False) 163 | coModel = models.create(args.arch, num_classes=676, pretrained=False) 164 | else: 165 | raise RuntimeError('Please specify the number of classes (ids) of the network.') 166 | 167 | # Load from checkpoint 168 | if args.resume: 169 | print('Resuming checkpoints from finetuned model on another dataset...\n') 170 | checkpoint = load_checkpoint(args.resume) 171 | model.load_state_dict(checkpoint['state_dict'], strict=False) 172 | coModel.load_state_dict(checkpoint['state_dict'], strict=False) 173 | else: 174 | raise RuntimeWarning('Not using a pre-trained model.') 175 | model = nn.DataParallel(model).cuda() 176 | coModel = nn.DataParallel(coModel).cuda() 177 | 178 | # evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 179 | # if args.evaluate: return 180 | 181 | # Criterion 182 | criterion = [ 183 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(), 184 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(), 185 | ] 186 | 187 | 188 | # Optimizer 189 | optimizer = torch.optim.Adam( 190 | model.parameters(), lr = args.lr 191 | ) 192 | coOptimizer = torch.optim.Adam( 193 | coModel.parameters(), lr = args.lr 194 | ) 195 | 196 | optims = [optimizer, coOptimizer] 197 | 198 | 199 | # training stage transformer on input images 200 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 201 | train_transformer = T.Compose([ 202 | T.Resize((args.height,args.width)), 203 | T.RandomHorizontalFlip(), 204 | T.ToTensor(), normalizer, 205 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3) 206 | ]) 207 | 208 | 209 | # # Start training 210 | for iter_n in range(args.iteration): 211 | if args.lambda_value == 0: 212 | source_features = 0 213 | else: 214 | # get source datas' feature 215 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq) 216 | # synchronization feature order with src_dataset.train 217 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0) 218 | 219 | # extract training 
images' features 220 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n+1)) 221 | target_features, tarNames = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq) 222 | # synchronization feature order with dataset.train 223 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0) 224 | target_real_label = np.asarray([tarNames[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval]) 225 | numTarID = len(set(target_real_label)) 226 | # calculate distance and rerank result 227 | print('Calculating feature distances...') 228 | target_features = target_features.numpy() 229 | cluster = KMeans(n_clusters=numTarID, n_jobs=8, n_init=1) 230 | 231 | # select & cluster images as training set of this epochs 232 | print('Clustering and labeling...') 233 | clusterRes = cluster.fit(target_features) 234 | labels, centers = clusterRes.labels_, clusterRes.cluster_centers_ 235 | labels = splitLowconfi(target_features,labels,centers) 236 | # num_ids = len(set(labels)) 237 | # print('Iteration {} have {} training ids'.format(iter_n+1, num_ids)) 238 | # generate new dataset 239 | new_dataset, unknown_dataset = [], [] 240 | # assign label for target ones 241 | unknownLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels)) 242 | # unknownFeats = target_features[labels==-1,:] 243 | unCounter = 0 244 | for (fname, _, cam), label in zip(tgt_dataset.trainval, labels): 245 | if label==-1: 246 | unknown_dataset.append((fname,int(unknownLab[unCounter]),cam)) # unknown data 247 | unCounter += 1 248 | continue 249 | # dont need to change codes in trainer.py _parsing_input function and sampler function after add 0 250 | new_dataset.append((fname,label,cam)) 251 | print('Iteration {} have {} training images'.format(iter_n+1, len(new_dataset))) 252 | 253 | train_loader = DataLoader( 254 | Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer), 255 | batch_size=args.batch_size, num_workers=4, 256 | sampler=RandomIdentitySampler(new_dataset, args.num_instances), 257 | pin_memory=True, drop_last=True 258 | ) 259 | # hard samples 260 | unLoader = DataLoader( 261 | Preprocessor(unknown_dataset, root=tgt_dataset.images_dir, transform=train_transformer), 262 | batch_size=args.batch_size, num_workers=4, 263 | sampler=RandomIdentitySampler(unknown_dataset, args.num_instances), 264 | pin_memory=True, drop_last=True 265 | ) 266 | 267 | # train model with new generated dataset 268 | trainer = CoTrainerAsy( 269 | model, coModel, train_loader, unLoader, criterion, optims 270 | ) 271 | # trainer = CoTeaching( 272 | # model, coModel, train_loader, unLoader, criterion, optims 273 | # ) 274 | # trainer = CoTrainerAsySep( 275 | # model, coModel, train_loader, unLoader, criterion, optims 276 | # ) 277 | 278 | evaluator = Evaluator(model, print_freq=args.print_freq) 279 | #evaluatorB = Evaluator(coModel, print_freq=args.print_freq) 280 | # Start training 281 | for epoch in range(args.epochs): 282 | trainer.train(epoch, remRate=0.2+(0.6/args.iteration)*(1+iter_n)) # to at most 80% 283 | # trainer.train(epoch, remRate=0.7+(0.3/args.iteration)*(1+iter_n)) 284 | # test only 285 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 286 | #print('co-model:\n') 287 | #rank_score = evaluatorB.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 288 | 289 | # Evaluate 290 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 291 | 
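    # Note on the remRate schedule used in the training loop above:
    # remRate = 0.2 + (0.6 / args.iteration) * (1 + iter_n) rises linearly from
    # 0.2 + 0.6/iteration on the first self-training round to 0.8 on the last;
    # e.g. with the default --iteration 10 it takes 0.26, 0.32, ..., 0.80, so at
    # most 80% of a batch is ever treated as clean.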
save_checkpoint({ 292 | 'state_dict': model.module.state_dict(), 293 | 'epoch': epoch + 1, 'best_top1': rank_score.market1501[0], 294 | }, True, fpath=osp.join(args.logs_dir, 'adapted.pth.tar')) 295 | return (rank_score.map, rank_score.market1501[0]) 296 | 297 | 298 | if __name__ == '__main__': 299 | parser = argparse.ArgumentParser(description="Triplet loss classification") 300 | # data 301 | parser.add_argument('--src_dataset', type=str, default='dukemtmc', 302 | choices=datasets.names()) 303 | parser.add_argument('--tgt_dataset', type=str, default='market1501', 304 | choices=datasets.names()) 305 | parser.add_argument('--batch_size', type=int, default=64) 306 | parser.add_argument('--workers', type=int, default=4) 307 | parser.add_argument('--split', type=int, default=0) 308 | parser.add_argument('--noiseLam', type=float, default=0.5) 309 | parser.add_argument('--height', type=int, 310 | help="input height, default: 256 for resnet*, " 311 | "144 for inception") 312 | parser.add_argument('--width', type=int, 313 | help="input width, default: 128 for resnet*, " 314 | "56 for inception") 315 | parser.add_argument('--combine-trainval', action='store_true', 316 | help="train and val sets together for training, " 317 | "val set alone for validation") 318 | parser.add_argument('--num_instances', type=int, default=4, 319 | help="each minibatch consist of " 320 | "(batch_size // num_instances) identities, and " 321 | "each identity has num_instances instances, " 322 | "default: 4") 323 | # model 324 | parser.add_argument('--arch', type=str, default='resnet50', 325 | choices=models.names()) 326 | # loss 327 | parser.add_argument('--margin', type=float, default=0.5, 328 | help="margin of the triplet loss, default: 0.5") 329 | parser.add_argument('--lambda_value', type=float, default=0.1, 330 | help="balancing parameter, default: 0.1") 331 | parser.add_argument('--rho', type=float, default=1.6e-3, 332 | help="rho percentage, default: 1.6e-3") 333 | # optimizer 334 | parser.add_argument('--lr', type=float, default=6e-5, 335 | help="learning rate of all parameters") 336 | # training configs 337 | parser.add_argument('--resume', type=str, metavar='PATH', 338 | default = '') 339 | parser.add_argument('--evaluate', type=int, default=0, 340 | help="evaluation only") 341 | parser.add_argument('--seed', type=int, default=1) 342 | parser.add_argument('--print_freq', type=int, default=1) 343 | parser.add_argument('--iteration', type=int, default=10) 344 | parser.add_argument('--epochs', type=int, default=30) 345 | # metric learning 346 | parser.add_argument('--dist_metric', type=str, default='euclidean', 347 | choices=['euclidean', 'kissme']) 348 | # misc 349 | parser.add_argument('--data_dir', type=str, metavar='PATH', 350 | default='') 351 | parser.add_argument('--logs_dir', type=str, metavar='PATH', 352 | default='') 353 | 354 | args = parser.parse_args() 355 | mean_ap, rank1 = main(args) 356 | results_file = np.asarray([mean_ap, rank1]) 357 | file_name = time.strftime("%H%M%S", time.localtime()) 358 | file_name = osp.join(args.logs_dir, file_name) 359 | np.save(file_name, results_file) 360 | -------------------------------------------------------------------------------- /selftrainingCT.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function, absolute_import 5 | import argparse 6 | import time 7 | import os.path as osp 8 | import os 9 | import numpy as np 10 | import torch 11 
| from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.dist_metric import DistanceMetric
18 | from reid.loss import TripletLoss
19 | from reid.trainers import CoTeaching
20 | from reid.evaluators import Evaluator, extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 | 
27 | from sklearn.cluster import DBSCAN
28 | from reid.rerank import re_ranking
29 | 
30 | 
31 | def calScores(clusters, labels):
32 |     """
33 |     compute pair-wise precision, pair-wise recall and F-score
34 |     """
35 |     from scipy.special import comb
36 |     if len(clusters) == 0:
37 |         return 0, 0, 0  # keep arity in sync with callers, which unpack three values
38 |     else:
39 |         curCluster = []
40 |         for curClus in clusters.values():
41 |             curCluster.append(labels[curClus])
42 |         TPandFP = sum([comb(len(val), 2) for val in curCluster])
43 |         TP = 0
44 |         for clusterVal in curCluster:
45 |             for setMember in set(clusterVal):
46 |                 if sum(clusterVal == setMember) < 2: continue
47 |                 TP += comb(sum(clusterVal == setMember), 2)
48 |         FP = TPandFP - TP
49 |         # FN and TN
50 |         TPandFN = sum([comb(labels.tolist().count(val), 2) for val in set(labels)])
51 |         FN = TPandFN - TP
52 |         # compute precision, recall and F-score
53 |         precision, recall = TP / (TP + FP), TP / (TP + FN)
54 |         fScore = 2 * precision * recall / (precision + recall)
55 |         return precision, recall, fScore
56 | 
57 | 
58 | def get_data(name, data_dir, height, width, batch_size,
59 |              workers):
60 |     root = osp.join(data_dir, name)
61 | 
62 |     dataset = datasets.create(name, root, num_val=0.1)
63 | 
64 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
65 |                              std=[0.229, 0.224, 0.225])
66 | 
67 |     # use all training and validation images in target dataset
68 |     train_set = dataset.trainval
69 |     num_classes = dataset.num_trainval_ids
70 | 
71 |     transformer = T.Compose([
72 |         T.Resize((height, width)),
73 |         T.ToTensor(),
74 |         normalizer,
75 |     ])
76 | 
77 |     extfeat_loader = DataLoader(
78 |         Preprocessor(train_set, root=dataset.images_dir,
79 |                      transform=transformer),
80 |         batch_size=batch_size, num_workers=workers,
81 |         shuffle=False, pin_memory=True)
82 | 
83 |     test_loader = DataLoader(
84 |         Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
85 |                      root=dataset.images_dir, transform=transformer),
86 |         batch_size=batch_size, num_workers=workers,
87 |         shuffle=False, pin_memory=True)
88 | 
89 |     return dataset, num_classes, extfeat_loader, test_loader
90 | 
91 | 
92 | def get_source_data(name, data_dir, height, width, batch_size,
93 |                     workers):
94 |     root = osp.join(data_dir, name)
95 | 
96 |     dataset = datasets.create(name, root, num_val=0.1)
97 | 
98 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
99 |                              std=[0.229, 0.224, 0.225])
100 | 
101 |     # use all training images on source dataset
102 |     train_set = dataset.train
103 |     num_classes = dataset.num_train_ids
104 | 
105 |     transformer = T.Compose([
106 |         T.Resize((height, width)),
107 |         T.ToTensor(),
108 |         normalizer,
109 |     ])
110 | 
111 |     extfeat_loader = DataLoader(
112 |         Preprocessor(train_set, root=dataset.images_dir,
113 |                      transform=transformer),
114 |         batch_size=batch_size, num_workers=workers,
115 |         shuffle=False, pin_memory=True)
116 | 
117 |     return dataset, extfeat_loader
118 | 
119 | 
120 | def calDis(qFeature, gFeature):  # 246s
121 |     x, 
y = F.normalize(qFeature), F.normalize(gFeature) 122 | # x, y = qFeature, gFeature 123 | m, n = x.shape[0], y.shape[0] 124 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 125 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 126 | disMat.addmm_(1, -2, x, y.t()) 127 | return disMat.clamp_(min=1e-5) 128 | 129 | 130 | def labelUnknown(knownFeat, allLab, unknownFeat): 131 | # allLab--label from known 132 | disMat = calDis(knownFeat, unknownFeat) 133 | labLoc = disMat.argmin(dim=0) 134 | return allLab[labLoc] 135 | 136 | 137 | def labelNoise(feature, labels): 138 | # features and labels with -1 139 | noiseFeat, pureFeat = feature[labels == -1, :], feature[labels != -1, :] 140 | pureLabs = labels[labels != -1] # no outliers 141 | unLab = labelUnknown(pureFeat, pureLabs, noiseFeat) 142 | labels[labels == -1] = unLab 143 | return labels.numpy() 144 | 145 | 146 | def getCenter(features, labels): 147 | allCenter = {} 148 | features = features[labels != -1, :] 149 | labels = labels[labels != -1] 150 | for pid in set(labels): 151 | allCenter[pid] = torch.from_numpy(features[labels == pid, :].mean(axis=0)).unsqueeze(0) 152 | return torch.cat(list(allCenter.values())) 153 | 154 | 155 | def main(args): 156 | np.random.seed(args.seed) 157 | torch.manual_seed(args.seed) 158 | cudnn.benchmark = True 159 | 160 | # Create data loaders 161 | assert args.num_instances > 1, "num_instances should be greater than 1" 162 | assert args.batch_size % args.num_instances == 0, \ 163 | 'num_instances should divide batch_size' 164 | if args.height is None or args.width is None: 165 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 166 | (256, 128) 167 | 168 | # get source data 169 | src_dataset, src_extfeat_loader = \ 170 | get_source_data(args.src_dataset, args.data_dir, args.height, 171 | args.width, args.batch_size, args.workers) 172 | # get target data 173 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \ 174 | get_data(args.tgt_dataset, args.data_dir, args.height, 175 | args.width, args.batch_size, args.workers) 176 | 177 | # Create model 178 | # Hacking here to let the classifier be the number of source ids 179 | if args.src_dataset == 'dukemtmc': 180 | model = models.create(args.arch, num_classes=632, pretrained=False) 181 | coModel = models.create(args.arch, num_classes=632, pretrained=False) 182 | elif args.src_dataset == 'market1501': 183 | model = models.create(args.arch, num_classes=676, pretrained=False) 184 | coModel = models.create(args.arch, num_classes=676, pretrained=False) 185 | elif args.src_dataset == 'msmt17': 186 | model = models.create(args.arch, num_classes=1041, pretrained=False) 187 | coModel = models.create(args.arch, num_classes=1041, pretrained=False) 188 | elif args.src_dataset == 'cuhk03': 189 | model = models.create(args.arch, num_classes=1230, pretrained=False) 190 | coModel = models.create(args.arch, num_classes=1230, pretrained=False) 191 | else: 192 | raise RuntimeError('Please specify the number of classes (ids) of the network.') 193 | 194 | # Load from checkpoint 195 | if args.resume: 196 | print('Resuming checkpoints from finetuned model on another dataset...\n') 197 | checkpoint = load_checkpoint(args.resume) 198 | model.load_state_dict(checkpoint['state_dict'], strict=False) 199 | coModel.load_state_dict(checkpoint['state_dict'], strict=False) 200 | else: 201 | raise RuntimeWarning('Not using a pre-trained model.') 202 | model = nn.DataParallel(model).cuda() 203 | coModel = nn.DataParallel(coModel).cuda() 204 | 
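    # A minimal sketch of the small-loss exchange that the CoTeaching trainer
    # (reid/trainers.py, not shown in this listing) is expected to perform per
    # batch; the names below are illustrative, not the trainer's actual API:
    #     keepNum = int(remRate * lossA.size(0))        # lossA/lossB: per-sample losses
    #     idxA = lossA.topk(keepNum, largest=False)[1]  # net A's small-loss subset
    #     idxB = lossB.topk(keepNum, largest=False)[1]
    # each network is then updated on its peer's small-loss subset, with remRate
    # annealed upwards across self-training rounds (see trainer.train calls below).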
205 | # Criterion 206 | criterion = [ 207 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(), 208 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda() 209 | ] 210 | optimizer = torch.optim.Adam( 211 | model.parameters(), lr=args.lr 212 | ) 213 | coOptimizer = torch.optim.Adam( 214 | coModel.parameters(), lr=args.lr 215 | ) 216 | 217 | optims = [optimizer, coOptimizer] 218 | 219 | # training stage transformer on input images 220 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 221 | train_transformer = T.Compose([ 222 | T.Resize((args.height, args.width)), 223 | T.RandomHorizontalFlip(), 224 | T.ToTensor(), normalizer, 225 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3) 226 | ]) 227 | 228 | # # Start training 229 | for iter_n in range(args.iteration): 230 | if args.lambda_value == 0: 231 | source_features = 0 232 | else: 233 | # get source datas' feature 234 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq) 235 | # synchronization feature order with src_dataset.train 236 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0) 237 | 238 | # extract training images' features 239 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n + 1)) 240 | target_features, _ = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq) 241 | # synchronization feature order with dataset.train 242 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0) 243 | # calculate distance and rerank result 244 | print('Calculating feature distances...') 245 | target_features = target_features.numpy() 246 | rerank_dist = re_ranking(source_features, target_features, lambda_value=args.lambda_value) 247 | if iter_n == 0: 248 | # DBSCAN cluster 249 | tri_mat = np.triu(rerank_dist, 1) # tri_mat.dim=2 250 | tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1 251 | tri_mat = np.sort(tri_mat, axis=None) 252 | top_num = np.round(args.rho * tri_mat.size).astype(int) 253 | eps = tri_mat[:top_num].mean() 254 | print('eps in cluster: {:.3f}'.format(eps)) 255 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8) 256 | # select & cluster images as training set of this epochs 257 | print('Clustering and labeling...') 258 | labels = cluster.fit_predict(rerank_dist) 259 | num_ids = len(set(labels)) - 1 260 | print('Iteration {} have {} training ids'.format(iter_n + 1, num_ids)) 261 | # generate new dataset 262 | new_dataset = [] 263 | # assign label for target ones 264 | newLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels)) 265 | # unknownFeats = target_features[labels==-1,:] 266 | counter = 0 267 | from collections import defaultdict 268 | realIDs, fakeIDs = defaultdict(list), [] 269 | for (fname, realID, cam), label in zip(tgt_dataset.trainval, newLab): 270 | # dont need to change codes in trainer.py _parsing_input function and sampler function after add 0 271 | new_dataset.append((fname, label, cam)) 272 | realIDs[realID].append(counter) 273 | fakeIDs.append(label) 274 | counter += 1 275 | precision, recall, fscore = calScores(realIDs, np.asarray(fakeIDs)) 276 | print('Iteration {} have {} training images'.format(iter_n + 1, len(new_dataset))) 277 | print(f'precision:{precision * 100}, recall:{100 * recall}, fscore:{100 * fscore}') 278 | train_loader = DataLoader( 279 | Preprocessor(new_dataset, 
root=tgt_dataset.images_dir, transform=train_transformer),
280 |             batch_size=args.batch_size, num_workers=4,
281 |             sampler=RandomIdentitySampler(new_dataset, args.num_instances),
282 |             pin_memory=True, drop_last=True
283 |         )
284 |         trainer = CoTeaching(
285 |             model, coModel, train_loader, criterion, optims
286 |         )
287 | 
288 |         # Start training
289 |         for epoch in range(args.epochs):
290 |             trainer.train(epoch, remRate=0.2 + (0.8 / args.iteration) * (1 + iter_n))  # anneals from 0.2 + 0.8/iteration up to 1.0 on the final round
291 |         # test only
292 |         evaluator = Evaluator(model, print_freq=args.print_freq)
293 |         rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
294 | 
295 |     # Evaluate
296 |     rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
297 |     save_checkpoint({
298 |         'state_dict': model.module.state_dict(),
299 |         'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
300 |     }, True, fpath=osp.join(args.logs_dir, 'adapted.pth.tar'))
301 |     return rank_score.map, rank_score.market1501[0]
302 | 
303 | 
304 | if __name__ == '__main__':
305 |     parser = argparse.ArgumentParser(description="Triplet loss classification")
306 |     # data
307 |     parser.add_argument('--src_dataset', type=str, default='dukemtmc',
308 |                         choices=datasets.names())
309 |     parser.add_argument('--tgt_dataset', type=str, default='market1501',
310 |                         choices=datasets.names())
311 |     parser.add_argument('--batch_size', type=int, default=64)
312 |     parser.add_argument('--workers', type=int, default=4)
313 |     parser.add_argument('--split', type=int, default=0)
314 |     parser.add_argument('--noiseLam', type=float, default=0.5)
315 |     parser.add_argument('--height', type=int,
316 |                         help="input height, default: 256 for resnet*, "
317 |                              "144 for inception")
318 |     parser.add_argument('--width', type=int,
319 |                         help="input width, default: 128 for resnet*, "
320 |                              "56 for inception")
321 |     parser.add_argument('--combine-trainval', action='store_true',
322 |                         help="train and val sets together for training, "
323 |                              "val set alone for validation")
324 |     parser.add_argument('--num_instances', type=int, default=4,
325 |                         help="each minibatch consists of "
326 |                              "(batch_size // num_instances) identities, and "
327 |                              "each identity has num_instances instances, "
328 |                              "default: 4")
329 |     # model
330 |     parser.add_argument('--arch', type=str, default='resnet50',
331 |                         choices=models.names())
332 |     # loss
333 |     parser.add_argument('--margin', type=float, default=0.5,
334 |                         help="margin of the triplet loss, default: 0.5")
335 |     parser.add_argument('--lambda_value', type=float, default=0.1,
336 |                         help="balancing parameter, default: 0.1")
337 |     parser.add_argument('--rho', type=float, default=1.6e-3,
338 |                         help="rho percentage, default: 1.6e-3")
339 |     # optimizer
340 |     parser.add_argument('--lr', type=float, default=6e-5,
341 |                         help="learning rate of all parameters")
342 |     # training configs
343 |     parser.add_argument('--resume', type=str, metavar='PATH',
344 |                         default='')
345 |     parser.add_argument('--evaluate', type=int, default=0,
346 |                         help="evaluation only")
347 |     parser.add_argument('--seed', type=int, default=1)
348 |     parser.add_argument('--print_freq', type=int, default=1)
349 |     parser.add_argument('--iteration', type=int, default=10)
350 |     parser.add_argument('--epochs', type=int, default=30)
351 |     # metric learning
352 |     parser.add_argument('--dist_metric', type=str, default='euclidean',
353 |                         choices=['euclidean', 'kissme'])
354 |     # misc
355 |     parser.add_argument('--data_dir', type=str, metavar='PATH',
356 |                         default='')
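    # Example invocation (checkpoint and directory names are placeholders):
    #     python selftrainingCT.py --src_dataset dukemtmc --tgt_dataset market1501 \
    #         --resume ./logs/dukePretrain.pth.tar --data_dir ./data --logs_dir ./logs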
357 |     parser.add_argument('--logs_dir', type=str, metavar='PATH',
358 |                         default='')
359 | 
360 |     args = parser.parse_args()
361 |     mean_ap, rank1 = main(args)
362 | 
--------------------------------------------------------------------------------
/selftrainingACT.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | from __future__ import print_function, absolute_import
5 | import argparse
6 | import time
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.loss import TripletLoss
18 | from reid.trainers import CoTrainerAsy
19 | from reid.evaluators import Evaluator, extract_features
20 | from reid.utils.data import transforms as T
21 | import torch.nn.functional as F
22 | from reid.utils.data.preprocessor import Preprocessor
23 | from reid.utils.data.sampler import RandomIdentitySampler
24 | from reid.utils.serialization import load_checkpoint, save_checkpoint
25 | 
26 | from sklearn.cluster import DBSCAN
27 | from reid.rerank import re_ranking
28 | 
29 | 
30 | def calScores(clusters, labels):
31 |     """
32 |     compute pair-wise precision, pair-wise recall and F-score
33 |     """
34 |     from scipy.special import comb
35 |     if len(clusters) == 0:
36 |         return 0, 0, 0  # keep arity in sync with callers, which unpack three values
37 |     else:
38 |         curCluster = []
39 |         for curClus in clusters.values():
40 |             curCluster.append(labels[curClus])
41 |         TPandFP = sum([comb(len(val), 2) for val in curCluster])
42 |         TP = 0
43 |         for clusterVal in curCluster:
44 |             for setMember in set(clusterVal):
45 |                 if sum(clusterVal == setMember) < 2: continue
46 |                 TP += comb(sum(clusterVal == setMember), 2)
47 |         FP = TPandFP - TP
48 |         # FN and TN
49 |         TPandFN = sum([comb(labels.tolist().count(val), 2) for val in set(labels)])
50 |         FN = TPandFN - TP
51 |         # compute precision, recall and F-score
52 |         precision, recall = TP / (TP + FP), TP / (TP + FN)
53 |         fScore = 2 * precision * recall / (precision + recall)
54 |         return precision, recall, fScore
55 | 
56 | 
57 | def get_data(name, data_dir, height, width, batch_size,
58 |              workers):
59 |     root = osp.join(data_dir, name)
60 | 
61 |     dataset = datasets.create(name, root, num_val=0.1)
62 | 
63 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
64 |                              std=[0.229, 0.224, 0.225])
65 | 
66 |     # use all training and validation images in target dataset
67 |     train_set = dataset.trainval
68 |     num_classes = dataset.num_trainval_ids
69 | 
70 |     transformer = T.Compose([
71 |         T.Resize((height, width)),
72 |         T.ToTensor(),
73 |         normalizer,
74 |     ])
75 | 
76 |     extfeat_loader = DataLoader(
77 |         Preprocessor(train_set, root=dataset.images_dir,
78 |                      transform=transformer),
79 |         batch_size=batch_size, num_workers=workers,
80 |         shuffle=False, pin_memory=True)
81 | 
82 |     test_loader = DataLoader(
83 |         Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
84 |                      root=dataset.images_dir, transform=transformer),
85 |         batch_size=batch_size // 2, num_workers=workers,
86 |         shuffle=False, pin_memory=True)
87 | 
88 |     return dataset, num_classes, extfeat_loader, test_loader
89 | 
90 | 
91 | def saveAll(nameList, rootDir, tarDir):
92 |     import os
93 |     import shutil
94 |     if os.path.exists(tarDir):
95 |         shutil.rmtree(tarDir)
96 |     os.makedirs(tarDir)
97 |     for name in nameList:
98 |         shutil.copyfile(os.path.join(rootDir, name), os.path.join(tarDir, name))
99 | 
100 | 
101 | def get_source_data(name, 
data_dir, height, width, batch_size, workers): 102 | root = osp.join(data_dir, name) 103 | 104 | dataset = datasets.create(name, root, num_val=0.1) 105 | 106 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 107 | std=[0.229, 0.224, 0.225]) 108 | 109 | # use all training images on source dataset 110 | train_set = dataset.train 111 | num_classes = dataset.num_train_ids 112 | 113 | transformer = T.Compose([ 114 | T.Resize((height, width)), 115 | T.ToTensor(), 116 | normalizer, 117 | ]) 118 | 119 | extfeat_loader = DataLoader( 120 | Preprocessor(train_set, root=dataset.images_dir, 121 | transform=transformer), 122 | batch_size=batch_size, num_workers=workers, 123 | shuffle=False, pin_memory=True) 124 | 125 | return dataset, extfeat_loader 126 | 127 | 128 | def calDis(qFeature, gFeature): # 246s 129 | x, y = F.normalize(qFeature), F.normalize(gFeature) 130 | # x, y = qFeature, gFeature 131 | m, n = x.shape[0], y.shape[0] 132 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 133 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 134 | disMat.addmm_(1, -2, x, y.t()) 135 | return disMat.clamp_(min=1e-5) 136 | 137 | 138 | def labelUnknown(knownFeat, allLab, unknownFeat): 139 | disMat = calDis(knownFeat, unknownFeat) 140 | labLoc = disMat.argmin(dim=0) 141 | return allLab[labLoc] 142 | 143 | 144 | def labelNoise(feature, labels): 145 | # features and labels with -1 146 | noiseFeat, pureFeat = feature[labels == -1, :], feature[labels != -1, :] 147 | labels = labels[labels != -1] 148 | unLab = labelUnknown(pureFeat, labels, noiseFeat) 149 | return unLab.numpy() 150 | 151 | 152 | def main(args): 153 | np.random.seed(args.seed) 154 | torch.manual_seed(args.seed) 155 | cudnn.benchmark = True 156 | 157 | # Create data loaders 158 | assert args.num_instances > 1, "num_instances should be greater than 1" 159 | assert args.batch_size % args.num_instances == 0, \ 160 | 'num_instances should divide batch_size' 161 | if args.height is None or args.width is None: 162 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 163 | (256, 128) 164 | 165 | # get source data 166 | src_dataset, src_extfeat_loader = \ 167 | get_source_data(args.src_dataset, args.data_dir, args.height, 168 | args.width, args.batch_size, args.workers) 169 | # get target data 170 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \ 171 | get_data(args.tgt_dataset, args.data_dir, args.height, 172 | args.width, args.batch_size, args.workers) 173 | 174 | # Create model 175 | # Hacking here to let the classifier be the number of source ids 176 | if args.src_dataset == 'dukemtmc': 177 | model = models.create(args.arch, num_classes=632, pretrained=False) 178 | coModel = models.create(args.arch, num_classes=632, pretrained=False) 179 | elif args.src_dataset == 'market1501': 180 | model = models.create(args.arch, num_classes=676, pretrained=False) 181 | coModel = models.create(args.arch, num_classes=676, pretrained=False) 182 | elif args.src_dataset == 'msmt17': 183 | model = models.create(args.arch, num_classes=1041, pretrained=False) 184 | coModel = models.create(args.arch, num_classes=1041, pretrained=False) 185 | elif args.src_dataset == 'cuhk03': 186 | model = models.create(args.arch, num_classes=1230, pretrained=False) 187 | coModel = models.create(args.arch, num_classes=1230, pretrained=False) 188 | else: 189 | raise RuntimeError('Please specify the number of classes (ids) of the network.') 190 | 191 | # Load from checkpoint 192 | if args.resume: 193 | print('Resuming checkpoints 
from finetuned model on another dataset...\n') 194 | checkpoint = load_checkpoint(args.resume) 195 | model.load_state_dict(checkpoint['state_dict'], strict=False) 196 | coModel.load_state_dict(checkpoint['state_dict'], strict=False) 197 | else: 198 | raise RuntimeWarning('Not using a pre-trained model.') 199 | model = nn.DataParallel(model).cuda() 200 | coModel = nn.DataParallel(coModel).cuda() 201 | 202 | evaluator = Evaluator(model, print_freq=args.print_freq) 203 | evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery) 204 | # if args.evaluate: return 205 | 206 | # Criterion 207 | criterion = [ 208 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(), 209 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(), 210 | ] 211 | 212 | # Optimizer 213 | optimizer = torch.optim.Adam( 214 | model.parameters(), lr=args.lr 215 | ) 216 | coOptimizer = torch.optim.Adam( 217 | coModel.parameters(), lr=args.lr 218 | ) 219 | 220 | optims = [optimizer, coOptimizer] 221 | 222 | # training stage transformer on input images 223 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 224 | train_transformer = T.Compose([ 225 | T.Resize((args.height, args.width)), 226 | T.RandomHorizontalFlip(), 227 | T.ToTensor(), normalizer, 228 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3) 229 | ]) 230 | 231 | # # Start training 232 | for iter_n in range(args.iteration): 233 | if args.lambda_value == 0: 234 | source_features = 0 235 | else: 236 | # get source datas' feature 237 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq) 238 | # synchronization feature order with src_dataset.train 239 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0) 240 | 241 | # extract training images' features 242 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n + 1)) 243 | target_features, _ = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq) 244 | # synchronization feature order with dataset.train 245 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0) 246 | # calculate distance and rerank result 247 | print('Calculating feature distances...') 248 | target_features = target_features.numpy() 249 | rerank_dist = re_ranking(source_features, target_features, lambda_value=args.lambda_value) 250 | if iter_n == 0: 251 | # DBSCAN cluster 252 | tri_mat = np.triu(rerank_dist, 1) # tri_mat.dim=2 253 | tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1 254 | tri_mat = np.sort(tri_mat, axis=None) 255 | top_num = np.round(args.rho * tri_mat.size).astype(int) 256 | eps = tri_mat[:top_num].mean() 257 | print('eps in cluster: {:.3f}'.format(eps)) 258 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8) 259 | 260 | # select & cluster images as training set of this epochs 261 | print('Clustering and labeling...') 262 | labels = cluster.fit_predict(rerank_dist) 263 | num_ids = len(set(labels)) - 1 264 | print('Iteration {} have {} training ids'.format(iter_n + 1, num_ids)) 265 | # generate new dataset 266 | new_dataset, unknown_dataset = [], [] 267 | # assign label for target ones 268 | unknownLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels)) 269 | # unknownFeats = target_features[labels==-1,:] 270 | unCounter, index = 0, 0 271 | from collections import defaultdict 272 | realIDs, fakeIDs = defaultdict(list), [] 
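        # The loop below routes every trainval sample by its DBSCAN label:
        # clustered samples keep their cluster id and go to new_dataset, while
        # outliers (label == -1) go to unknown_dataset, pseudo-labelled by their
        # nearest clustered sample (unknownLab), so both splits can be drawn
        # from with RandomIdentitySampler.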
273 |         for (fname, realPID, cam), label in zip(tgt_dataset.trainval, labels):
274 |             if label == -1:
275 |                 unknown_dataset.append((fname, int(unknownLab[unCounter]), cam))  # unknown data
276 |                 fakeIDs.append(int(unknownLab[unCounter]))
277 |                 realIDs[realPID].append(index)
278 |                 unCounter += 1
279 |                 index += 1
280 |                 continue
281 |             # no need to change _parsing_input in trainer.py or the sampler after this relabeling
282 |             new_dataset.append((fname, label, cam))
283 |             fakeIDs.append(label)
284 |             realIDs[realPID].append(index)
285 |             index += 1
286 |         print('Iteration {} has {} training images'.format(iter_n + 1, len(new_dataset)))
287 |         precision, recall, fscore = calScores(realIDs, np.asarray(fakeIDs))  # fakeIDs does not contain -1
288 |         print('precision:{}, recall:{}, fscore: {}'.format(100 * precision, 100 * recall, 100 * fscore))
289 | 
290 |         train_loader = DataLoader(
291 |             Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
292 |             batch_size=args.batch_size, num_workers=4,
293 |             sampler=RandomIdentitySampler(new_dataset, args.num_instances),
294 |             pin_memory=True, drop_last=True
295 |         )
296 |         # hard samples
297 |         # noiseImgs = [name[1] for name in unknown_dataset]
298 |         # saveAll(noiseImgs, tgt_dataset.images_dir, 'noiseImg')
299 |         # import ipdb; ipdb.set_trace()
300 |         unLoader = DataLoader(
301 |             Preprocessor(unknown_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
302 |             batch_size=args.batch_size, num_workers=4,
303 |             sampler=RandomIdentitySampler(unknown_dataset, args.num_instances),
304 |             pin_memory=True, drop_last=True
305 |         )
306 |         # train model on the newly generated dataset
307 |         trainer = CoTrainerAsy(
308 |             model, coModel, train_loader, unLoader, criterion, optims
309 |         )
310 | 
311 |         # Start training
312 |         for epoch in range(args.epochs):
313 |             trainer.train(epoch, remRate=0.2 + (0.8 / args.iteration) * (1 + iter_n))
314 | 
315 |         # test only
316 |         rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
317 |         # print('co-model:\n')
318 |         # rank_score = evaluatorB.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
319 | 
320 |     # Evaluate
321 |     rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
322 |     save_checkpoint({
323 |         'state_dict': model.module.state_dict(),
324 |         'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
325 |     }, True, fpath=osp.join(args.logs_dir, 'asyCo.pth'))
326 |     return rank_score.map, rank_score.market1501[0]
327 | 
328 | 
329 | if __name__ == '__main__':
330 |     parser = argparse.ArgumentParser(description="Triplet loss classification")
331 |     # data
332 |     parser.add_argument('--src_dataset', type=str, default='dukemtmc',
333 |                         choices=datasets.names())
334 |     parser.add_argument('--tgt_dataset', type=str, default='market1501',
335 |                         choices=datasets.names())
336 |     parser.add_argument('--batch_size', type=int, default=64)
337 |     parser.add_argument('--workers', type=int, default=4)
338 |     parser.add_argument('--split', type=int, default=0)
339 |     parser.add_argument('--noiseLam', type=float, default=0.5)
340 |     parser.add_argument('--height', type=int,
341 |                         help="input height, default: 256 for resnet*, "
342 |                              "144 for inception")
343 |     parser.add_argument('--width', type=int,
344 |                         help="input width, default: 128 for resnet*, "
345 |                              "56 for inception")
346 |     parser.add_argument('--combine-trainval', action='store_true',
347 |                         help="train and val sets together for training, "
348 |                              "val set alone for validation")
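    # With the defaults below (batch_size=64, num_instances=4), each PK-sampled
    # minibatch draws 16 identities x 4 instances per identity.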
349 |     parser.add_argument('--num_instances', type=int, default=4,
350 |                         help="each minibatch consists of "
351 |                              "(batch_size // num_instances) identities, and "
352 |                              "each identity has num_instances instances, "
353 |                              "default: 4")
354 |     # model
355 |     parser.add_argument('--arch', type=str, default='resnet50',
356 |                         choices=models.names())
357 |     # loss
358 |     parser.add_argument('--margin', type=float, default=0.5,
359 |                         help="margin of the triplet loss, default: 0.5")
360 |     parser.add_argument('--lambda_value', type=float, default=0.1,
361 |                         help="balancing parameter, default: 0.1")
362 |     parser.add_argument('--rho', type=float, default=1.6e-3,
363 |                         help="rho percentage, default: 1.6e-3")
364 |     # optimizer
365 |     parser.add_argument('--lr', type=float, default=6e-5,
366 |                         help="learning rate of all parameters")
367 |     # training configs
368 |     parser.add_argument('--resume', type=str, metavar='PATH',
369 |                         default='')
370 |     parser.add_argument('--evaluate', type=int, default=0,
371 |                         help="evaluation only")
372 |     parser.add_argument('--seed', type=int, default=1)
373 |     parser.add_argument('--print_freq', type=int, default=1)
374 |     parser.add_argument('--iteration', type=int, default=10)
375 |     parser.add_argument('--epochs', type=int, default=30)
376 |     # metric learning
377 |     parser.add_argument('--dist_metric', type=str, default='euclidean',
378 |                         choices=['euclidean', 'kissme'])
379 |     # misc
380 |     parser.add_argument('--data_dir', type=str, metavar='PATH',
381 |                         default='')
382 |     parser.add_argument('--logs_dir', type=str, metavar='PATH',
383 |                         default='')
384 | 
385 |     args = parser.parse_args()
386 |     mean_ap, rank1 = main(args)
387 | 
--------------------------------------------------------------------------------
/selftrainingRCT.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | from __future__ import print_function, absolute_import
5 | import argparse
6 | import time
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.dist_metric import DistanceMetric
18 | from reid.loss import TripletLoss
19 | from reid.trainers import RCoTeaching
20 | from reid.evaluators import Evaluator, extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 | 
27 | from sklearn.cluster import DBSCAN
28 | from reid.rerank import re_ranking
29 | 
30 | 
31 | def calScores(clusters, labels):
32 |     """
33 |     compute pair-wise precision, pair-wise recall and F-score
34 |     """
35 |     from scipy.special import comb
36 |     if len(clusters) == 0:
37 |         return 0, 0, 0  # keep arity in sync with callers, which unpack three values
38 |     else:
39 |         curCluster = []
40 |         for curClus in clusters.values():
41 |             curCluster.append(labels[curClus])
42 |         TPandFP = sum([comb(len(val), 2) for val in curCluster])
43 |         TP = 0
44 |         for clusterVal in curCluster:
45 |             for setMember in set(clusterVal):
46 |                 if sum(clusterVal == setMember) < 2: continue
47 |                 TP += comb(sum(clusterVal == setMember), 2)
48 |         FP = TPandFP - TP
49 |         # FN and TN
50 |         TPandFN = sum([comb(labels.tolist().count(val), 2) for val in set(labels)])
51 |         FN = TPandFN - TP
52 |         # cal precision 
53 |         precision, recall = TP / (TP + FP), TP / (TP + FN)
54 |         fScore = 2 * precision * recall / (precision + recall)
55 |         return precision, recall, fScore
56 | 
57 | 
58 | def get_data(name, data_dir, height, width, batch_size,
59 |              workers):
60 |     root = osp.join(data_dir, name)
61 | 
62 |     dataset = datasets.create(name, root, num_val=0.1)
63 | 
64 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
65 |                              std=[0.229, 0.224, 0.225])
66 | 
67 |     # use all training and validation images in the target dataset
68 |     train_set = dataset.trainval
69 |     num_classes = dataset.num_trainval_ids
70 | 
71 |     transformer = T.Compose([
72 |         T.Resize((height, width)),
73 |         T.ToTensor(),
74 |         normalizer,
75 |     ])
76 | 
77 |     extfeat_loader = DataLoader(
78 |         Preprocessor(train_set, root=dataset.images_dir,
79 |                      transform=transformer),
80 |         batch_size=batch_size, num_workers=workers,
81 |         shuffle=False, pin_memory=True)
82 | 
83 |     test_loader = DataLoader(
84 |         Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
85 |                      root=dataset.images_dir, transform=transformer),
86 |         batch_size=batch_size // 2, num_workers=workers,
87 |         shuffle=False, pin_memory=True)
88 | 
89 |     return dataset, num_classes, extfeat_loader, test_loader
90 | 
91 | 
92 | def saveAll(nameList, rootDir, tarDir):
93 |     import os
94 |     import shutil
95 |     if os.path.exists(tarDir):
96 |         shutil.rmtree(tarDir)
97 |     os.makedirs(tarDir)
98 |     for name in nameList:
99 |         shutil.copyfile(os.path.join(rootDir, name), os.path.join(tarDir, name))
100 | 
101 | 
102 | def get_source_data(name, data_dir, height, width, batch_size, workers):
103 |     root = osp.join(data_dir, name)
104 | 
105 |     dataset = datasets.create(name, root, num_val=0.1)
106 | 
107 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
108 |                              std=[0.229, 0.224, 0.225])
109 | 
110 |     # use all training images of the source dataset
111 |     train_set = dataset.train
112 |     num_classes = dataset.num_train_ids
113 | 
114 |     transformer = T.Compose([
115 |         T.Resize((height, width)),
116 |         T.ToTensor(),
117 |         normalizer,
118 |     ])
119 | 
120 |     extfeat_loader = DataLoader(
121 |         Preprocessor(train_set, root=dataset.images_dir,
122 |                      transform=transformer),
123 |         batch_size=batch_size, num_workers=workers,
124 |         shuffle=False, pin_memory=True)
125 | 
126 |     return dataset, extfeat_loader
127 | 
128 | 
129 | def calDis(qFeature, gFeature):  # pairwise Euclidean distances; ~246s
130 |     x, y = F.normalize(qFeature), F.normalize(gFeature)
131 |     # x, y = qFeature, gFeature
132 |     m, n = x.shape[0], y.shape[0]
133 |     disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
134 |              torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
135 |     disMat.addmm_(1, -2, x, y.t())  # deprecated positional (beta, alpha) form; works on the pinned torch==1.3.1
136 |     return disMat.clamp_(min=1e-5)
137 | 
138 | 
139 | def labelUnknown(knownFeat, allLab, unknownFeat):
140 |     disMat = calDis(knownFeat, unknownFeat)
141 |     labLoc = disMat.argmin(dim=0)
142 |     return allLab[labLoc]
143 | 
144 | 
145 | def labelNoise(feature, labels):
146 |     # split features into noisy (label == -1) and clean subsets
147 |     noiseFeat, pureFeat = feature[labels == -1, :], feature[labels != -1, :]
148 |     labels = labels[labels != -1]
149 |     unLab = labelUnknown(pureFeat, labels, noiseFeat)
150 |     return unLab.numpy()
151 | 
152 | 
153 | def main(args):
154 |     np.random.seed(args.seed)
155 |     torch.manual_seed(args.seed)
156 |     cudnn.benchmark = True
157 | 
158 |     # Create data loaders
159 |     assert args.num_instances > 1, "num_instances should be greater than 1"
160 |     assert args.batch_size % args.num_instances == 0, \
161 |         'num_instances should divide batch_size'
162 |     if args.height is None or args.width is None:
163 |         args.height, args.width = (144, 56) if args.arch == 'inception' else \
164 |                                   (256, 128)
165 | 
166 |     # get source data
167 |     src_dataset, src_extfeat_loader = \
168 |         get_source_data(args.src_dataset, args.data_dir, args.height,
169 |                         args.width, args.batch_size, args.workers)
170 |     # get target data
171 |     tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \
172 |         get_data(args.tgt_dataset, args.data_dir, args.height,
173 |                  args.width, args.batch_size, args.workers)
174 | 
175 |     # Create model
176 |     # Hack: fix the classifier output size to the number of source-domain ids
177 |     if args.src_dataset == 'dukemtmc':
178 |         model = models.create(args.arch, num_classes=632, pretrained=False)
179 |         coModel = models.create(args.arch, num_classes=632, pretrained=False)
180 |     elif args.src_dataset == 'market1501':
181 |         model = models.create(args.arch, num_classes=676, pretrained=False)
182 |         coModel = models.create(args.arch, num_classes=676, pretrained=False)
183 |     elif args.src_dataset == 'msmt17':
184 |         model = models.create(args.arch, num_classes=1041, pretrained=False)
185 |         coModel = models.create(args.arch, num_classes=1041, pretrained=False)
186 |     elif args.src_dataset == 'cuhk03':
187 |         model = models.create(args.arch, num_classes=1230, pretrained=False)
188 |         coModel = models.create(args.arch, num_classes=1230, pretrained=False)
189 |     else:
190 |         raise RuntimeError('Please specify the number of classes (ids) of the network.')
191 | 
192 |     # Load from checkpoint
193 |     if args.resume:
194 |         print('Resuming from a checkpoint fine-tuned on another dataset...\n')
195 |         checkpoint = load_checkpoint(args.resume)
196 |         model.load_state_dict(checkpoint['state_dict'], strict=False)
197 |         coModel.load_state_dict(checkpoint['state_dict'], strict=False)
198 |     else:
199 |         raise RuntimeWarning('Not using a pre-trained model.')  # aborts: a source-pretrained checkpoint (--resume) is expected
200 |     model = nn.DataParallel(model).cuda()
201 |     coModel = nn.DataParallel(coModel).cuda()
202 | 
203 |     evaluator = Evaluator(model, print_freq=args.print_freq)
204 |     # evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
205 |     # if args.evaluate: return
206 | 
207 |     # Criterion
208 |     criterion = [
209 |         TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
210 |         TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
211 |     ]
212 | 
213 |     # Optimizer
214 |     optimizer = torch.optim.Adam(
215 |         model.parameters(), lr=args.lr
216 |     )
217 |     coOptimizer = torch.optim.Adam(
218 |         coModel.parameters(), lr=args.lr
219 |     )
220 | 
221 |     optims = [optimizer, coOptimizer]
222 | 
223 |     # transformer applied to training images
224 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
225 |     train_transformer = T.Compose([
226 |         T.Resize((args.height, args.width)),
227 |         T.RandomHorizontalFlip(),
228 |         T.ToTensor(), normalizer,
229 |         T.RandomErasing(probability=0.5, sh=0.2, r1=0.3)
230 |     ])
231 | 
232 |     # Start training
233 |     for iter_n in range(args.iteration):
234 |         if args.lambda_value == 0:
235 |             source_features = 0
236 |         else:
237 |             # extract source data features
238 |             source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq)
239 |             # synchronize feature order with src_dataset.train
240 |             source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0)
241 | 
242 |         # extract training images' features
243 |         print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n + 1))
244 |         target_features, _ = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq)
245 |         # synchronize feature order with tgt_dataset.trainval
246 |         target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0)
247 |         # calculate distances and re-rank the result
248 |         print('Calculating feature distances...')
249 |         target_features = target_features.numpy()
250 |         rerank_dist = re_ranking(source_features, target_features, lambda_value=args.lambda_value)
251 |         if iter_n == 0:
252 |             # DBSCAN cluster: pick eps as the mean of the rho-smallest pairwise distances
253 |             tri_mat = np.triu(rerank_dist, 1)  # tri_mat.dim=2
254 |             tri_mat = tri_mat[np.nonzero(tri_mat)]  # tri_mat.dim=1
255 |             tri_mat = np.sort(tri_mat, axis=None)
256 |             top_num = np.round(args.rho * tri_mat.size).astype(int)
257 |             eps = tri_mat[:top_num].mean()
258 |             print('eps in cluster: {:.3f}'.format(eps))
259 |             cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8)
260 | 
261 |         # select & cluster images as the training set for this iteration
262 |         print('Clustering and labeling...')
263 |         labels = cluster.fit_predict(rerank_dist)
264 |         num_ids = len(set(labels)) - 1
265 |         print('Iteration {} has {} training ids'.format(iter_n + 1, num_ids))
266 |         # generate new dataset
267 |         new_dataset, unknown_dataset = [], []
268 |         # assign pseudo labels to the un-clustered (noisy) target samples
269 |         unknownLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels))
270 |         # unknownFeats = target_features[labels==-1,:]
271 |         unCounter, index = 0, 0
272 |         from collections import defaultdict
273 |         realIDs, fakeIDs = defaultdict(list), []
274 |         for (fname, realPID, cam), label in zip(tgt_dataset.trainval, labels):
275 |             if label == -1:
276 |                 unknown_dataset.append((fname, int(unknownLab[unCounter]), cam))  # unknown data
277 |                 fakeIDs.append(int(unknownLab[unCounter]))
278 |                 realIDs[realPID].append(index)
279 |                 unCounter += 1
280 |                 index += 1
281 |                 continue
282 |             # no changes needed in trainer.py's _parsing_input or in the sampler here
283 |             new_dataset.append((fname, label, cam))
284 |             fakeIDs.append(label)
285 |             realIDs[realPID].append(index)
286 |             index += 1
287 |         print('Iteration {} has {} training images'.format(iter_n + 1, len(new_dataset)))
288 |         precision, recall, fscore = calScores(realIDs, np.asarray(fakeIDs))  # fakeIDs does not contain -1
289 |         print('precision: {}, recall: {}, fscore: {}'.format(100 * precision, 100 * recall, 100 * fscore))
290 | 
291 |         train_loader = DataLoader(
292 |             Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
293 |             batch_size=args.batch_size, num_workers=4,
294 |             sampler=RandomIdentitySampler(new_dataset, args.num_instances),
295 |             pin_memory=True, drop_last=True
296 |         )
297 |         # hard samples
298 |         # noiseImgs = [name[1] for name in unknown_dataset]
299 |         # saveAll(noiseImgs, tgt_dataset.images_dir, 'noiseImg')
300 |         # import ipdb; ipdb.set_trace()
301 |         unLoader = DataLoader(
302 |             Preprocessor(unknown_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
303 |             batch_size=args.batch_size, num_workers=4,
304 |             sampler=RandomIdentitySampler(unknown_dataset, args.num_instances),
305 |             pin_memory=True, drop_last=True
306 |         )
307 |         # train the model with the newly generated dataset
308 |         trainer = RCoTeaching(
309 |             model, coModel, train_loader, unLoader, criterion, optims
310 |         )
311 | 
312 |         # Start training
313 |         for epoch in range(args.epochs):
314 |             trainer.train(epoch, remRate=0.2 + (0.8 / args.iteration) * (1 + iter_n))
315 | 
316 |         # test only
317 |         rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
318 |     # print('co-model:\n')
319 |     # rank_score = evaluatorB.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
320 | 
321 |     # Evaluate
322 |     rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
323 |     save_checkpoint({
324 |         'state_dict': model.module.state_dict(),
325 |         'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
326 |     }, True, fpath=osp.join(args.logs_dir, 'RCT.pth'))
327 |     return rank_score.map, rank_score.market1501[0]
328 | 
329 | 
330 | if __name__ == '__main__':
331 |     parser = argparse.ArgumentParser(description="Triplet loss classification")
332 |     # data
333 |     parser.add_argument('--src_dataset', type=str, default='dukemtmc',
334 |                         choices=datasets.names())
335 |     parser.add_argument('--tgt_dataset', type=str, default='market1501',
336 |                         choices=datasets.names())
337 |     parser.add_argument('--batch_size', type=int, default=64)
338 |     parser.add_argument('--workers', type=int, default=4)
339 |     parser.add_argument('--split', type=int, default=0)
340 |     parser.add_argument('--noiseLam', type=float, default=0.5)
341 |     parser.add_argument('--height', type=int,
342 |                         help="input height, default: 256 for resnet*, "
343 |                              "144 for inception")
344 |     parser.add_argument('--width', type=int,
345 |                         help="input width, default: 128 for resnet*, "
346 |                              "56 for inception")
347 |     parser.add_argument('--combine-trainval', action='store_true',
348 |                         help="train and val sets together for training, "
349 |                              "val set alone for validation")
350 |     parser.add_argument('--num_instances', type=int, default=4,
351 |                         help="each minibatch consists of "
352 |                              "(batch_size // num_instances) identities, and "
353 |                              "each identity has num_instances instances, "
354 |                              "default: 4")
355 |     # model
356 |     parser.add_argument('--arch', type=str, default='resnet50',
357 |                         choices=models.names())
358 |     # loss
359 |     parser.add_argument('--margin', type=float, default=0.5,
360 |                         help="margin of the triplet loss, default: 0.5")
361 |     parser.add_argument('--lambda_value', type=float, default=0.1,
362 |                         help="balancing parameter, default: 0.1")
363 |     parser.add_argument('--rho', type=float, default=1.6e-3,
364 |                         help="rho percentage, default: 1.6e-3")
365 |     # optimizer
366 |     parser.add_argument('--lr', type=float, default=6e-5,
367 |                         help="learning rate of all parameters")
368 |     # training configs
369 |     parser.add_argument('--resume', type=str, metavar='PATH',
370 |                         default='')
371 |     parser.add_argument('--evaluate', type=int, default=0,
372 |                         help="evaluation only")
373 |     parser.add_argument('--seed', type=int, default=1)
374 |     parser.add_argument('--print_freq', type=int, default=1)
375 |     parser.add_argument('--iteration', type=int, default=10)
376 |     parser.add_argument('--epochs', type=int, default=30)
377 |     # metric learning
378 |     parser.add_argument('--dist_metric', type=str, default='euclidean',
379 |                         choices=['euclidean', 'kissme'])
380 |     # misc
381 |     parser.add_argument('--data_dir', type=str, metavar='PATH',
382 |                         default='')
383 |     parser.add_argument('--logs_dir', type=str, metavar='PATH',
384 |                         default='')
385 | 
386 |     args = parser.parse_args()
387 |     mean_ap, rank1 = main(args)
388 | 
--------------------------------------------------------------------------------
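
A minimal usage sketch (not a file in the repository): selftrainingRCT.py aborts unless a source-pretrained checkpoint is supplied via --resume. Assuming DukeMTMC → Market-1501 adaptation and a hypothetical checkpoint path checkpoints/dukePre.pth, an invocation could look like:

    python selftrainingRCT.py --src_dataset dukemtmc --tgt_dataset market1501 \
        --resume checkpoints/dukePre.pth --data_dir ./data --logs_dir ./logs/RCT

All remaining hyperparameters (--lr 6e-5, --margin 0.5, --rho 1.6e-3, --iteration 10, --epochs 30) fall back to the defaults defined in the argument parser above.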