├── figures
│   ├── ACT.jpg
│   └── MSMT.jpg
├── reid
│   ├── utils
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── preprocessor.py
│   │   │   ├── sampler.py
│   │   │   ├── transforms.py
│   │   │   ├── video_loader.py
│   │   │   └── dataset.py
│   │   ├── osutils.py
│   │   ├── meters.py
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   └── serialization.py
│   ├── evaluation_metrics
│   │   ├── __init__.py
│   │   ├── classification.py
│   │   └── ranking.py
│   ├── feature_extraction
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   └── database.py
│   ├── __init__.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── gan_loss.py
│   │   ├── var_loss.py
│   │   ├── neighbour_loss.py
│   │   ├── center_loss.py
│   │   ├── matchLoss.py
│   │   ├── oim.py
│   │   ├── triplet.py
│   │   └── virtual_ce.py
│   ├── models
│   │   ├── module.py
│   │   ├── functional.py
│   │   ├── classifier.py
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   └── inception.py
│   ├── metric_learning
│   │   ├── euclidean.py
│   │   ├── __init__.py
│   │   └── kissme.py
│   ├── dataloader.py
│   ├── dist_metric.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── cuhk03.py
│   │   ├── msmt17.py
│   │   ├── market1501.py
│   │   └── dukemtmc.py
│   ├── evaluators.py
│   └── rerank.py
├── requirements.txt
├── data
│   └── readme.md
├── readme.md
├── selftrainingKmeans.py
├── selfNoise.py
├── selftrainingKmeansAsy.py
├── selftrainingCT.py
├── selftrainingACT.py
└── selftrainingRCT.py
/figures/ACT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FlyingRoastDuck/ACT_AAAI20/HEAD/figures/ACT.jpg
--------------------------------------------------------------------------------
/figures/MSMT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FlyingRoastDuck/ACT_AAAI20/HEAD/figures/MSMT.jpg
--------------------------------------------------------------------------------
/reid/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .dataset import Dataset
4 | from .preprocessor import Preprocessor
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.16.4
2 | matplotlib==3.1.1
3 | torch==1.3.1
4 | metric_learn==0.4.0
5 | tqdm==4.32.2
6 | torchvision==0.2.0
7 | scipy==1.1.0
8 | h5py==2.9.0
9 | Pillow==6.2.1
10 | six==1.13.0
11 | scikit_learn==0.21.3
12 |
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .classification import accuracy
4 | from .ranking import cmc, mean_ap
5 |
6 | __all__ = [
7 | 'accuracy',
8 | 'cmc',
9 | 'mean_ap',
10 | ]
11 |
--------------------------------------------------------------------------------
/reid/feature_extraction/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .cnn import extract_cnn_feature
4 | from .database import FeatureDatabase
5 |
6 | __all__ = [
7 | 'extract_cnn_feature',
8 | 'FeatureDatabase',
9 | ]
10 |
--------------------------------------------------------------------------------
/reid/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 |
5 |
6 | def mkdir_if_missing(dir_path):
7 | try:
8 | os.makedirs(dir_path)
9 | except OSError as e:
10 | if e.errno != errno.EEXIST:
11 | raise
12 |
--------------------------------------------------------------------------------
/data/readme.md:
--------------------------------------------------------------------------------
1 | # Put all datasets under this folder like this:
2 |
3 | ```
4 | data
5 | ----cuhk03
6 | --------raw
7 | ------------cuhk03_release.zip
8 | ----market1501
9 | --------raw
10 | ------------Market-1501-v15.09.15.zip
11 | ----dukemtmc
12 | --------raw
13 | ------------DukeMTMC-reID.zip
14 | ```
--------------------------------------------------------------------------------
/reid/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from . import datasets
4 | from . import evaluation_metrics
5 | from . import feature_extraction
6 | from . import loss
7 | from . import metric_learning
8 | from . import models
9 | from . import utils
10 | from . import dist_metric
11 | from . import evaluators
12 | from . import trainers
13 | from . import rerank
14 | from . import dataloader
15 |
16 | __version__ = '0.2.0'
17 |
--------------------------------------------------------------------------------
/reid/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .oim import oim, OIM, OIMLoss
4 | from .triplet import TripletLoss
5 | from .center_loss import CenterLoss
6 | from .matchLoss import lossMMD
7 | from .neighbour_loss import NeiLoss
8 | from .virtual_ce import VirtualCE, VirtualKCE
9 |
10 | __all__ = [
11 |     'oim', 'OIM', 'OIMLoss', 'NeiLoss',
12 | 'TripletLoss', 'CenterLoss', 'lossMMD', 'VirtualCE', 'VirtualKCE'
13 | ]
14 |
--------------------------------------------------------------------------------
/reid/models/module.py:
--------------------------------------------------------------------------------
1 | from .functional import revgrad
2 | from torch.nn import Module
3 |
4 |
5 | class RevGrad(Module):
6 | def __init__(self, *args, **kwargs):
7 | """
8 | A gradient reversal layer.
9 |
10 | This layer has no parameters, and simply reverses the gradient
11 | in the backward pass.
12 | """
13 |
14 | super().__init__(*args, **kwargs)
15 |
16 | def forward(self, input_):
17 | return revgrad(input_)
18 |
--------------------------------------------------------------------------------
/reid/models/functional.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Function
2 |
3 |
4 | class RevGrad(Function):
5 | @staticmethod
6 | def forward(ctx, input_):
7 | ctx.save_for_backward(input_)
8 | output = input_
9 | return output
10 |
11 | @staticmethod
12 | def backward(ctx, grad_output): # pragma: no cover
13 | grad_input = None
14 | if ctx.needs_input_grad[0]:
15 | grad_input = -grad_output
16 | return grad_input
17 |
18 |
19 | revgrad = RevGrad.apply
20 |
--------------------------------------------------------------------------------
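
A minimal sketch of how the gradient reversal layer behaves (assuming the repository root is on PYTHONPATH): the forward pass is the identity, while backpropagated gradients are negated.

```
import torch
from reid.models.module import RevGrad

layer = RevGrad()
x = torch.ones(2, 3, requires_grad=True)
layer(x).sum().backward()
print(x.grad)  # every entry is -1: the gradient was flipped by the layer
```
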
/reid/metric_learning/euclidean.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import numpy as np
4 | from metric_learn.base_metric import BaseMetricLearner
5 |
6 |
7 | class Euclidean(BaseMetricLearner):
8 | def __init__(self):
9 | self.M_ = None
10 |
11 | def metric(self):
12 | return self.M_
13 |
14 | def fit(self, X):
15 | self.M_ = np.eye(X.shape[1])
16 | self.X_ = X
17 |
18 | def transform(self, X=None):
19 | if X is None:
20 | return self.X_
21 | return X
22 |
--------------------------------------------------------------------------------
/reid/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
--------------------------------------------------------------------------------
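
A small usage sketch for AverageMeter, tracking a running average of per-batch losses weighted by batch size (the numbers are illustrative):

```
from reid.utils.meters import AverageMeter

losses = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    losses.update(batch_loss, n=batch_size)
print(losses.val, losses.avg)  # 0.5 (last batch), 0.74 (weighted average)
```
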
/reid/evaluation_metrics/classification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from ..utils import to_torch
4 |
5 |
6 | def accuracy(output, target, topk=(1,)):
7 | output, target = to_torch(output), to_torch(target)
8 | maxk = max(topk)
9 | batch_size = target.size(0)
10 |
11 | _, pred = output.topk(maxk, 1, True, True)
12 | pred = pred.t()
13 | correct = pred.eq(target.view(1, -1).expand_as(pred))
14 |
15 | ret = []
16 | for k in topk:
17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
18 | ret.append(correct_k.mul_(1. / batch_size))
19 | return ret
20 |
--------------------------------------------------------------------------------
/reid/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 |
5 |
6 | def to_numpy(tensor):
7 | if torch.is_tensor(tensor):
8 | return tensor.cpu().numpy()
9 | elif type(tensor).__module__ != 'numpy':
10 | raise ValueError("Cannot convert {} to numpy array"
11 | .format(type(tensor)))
12 | return tensor
13 |
14 |
15 | def to_torch(ndarray):
16 | if type(ndarray).__module__ == 'numpy':
17 | return torch.from_numpy(ndarray).cuda()
18 | elif not torch.is_tensor(ndarray):
19 | raise ValueError("Cannot convert {} to torch tensor"
20 | .format(type(ndarray)))
21 | return ndarray.cuda()
22 |
--------------------------------------------------------------------------------
/reid/loss/gan_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class GanLoss(nn.Module):
6 | def __init__(self, modelD, modelG, inDim=2048, outDim=2, lossD=nn.CrossEntropyLoss(), lossG=nn.CrossEntropyLoss()):
7 | super(GanLoss, self).__init__()
8 | self.inDim = inDim
9 | self.outDim = outDim
10 | self.modelD = modelD
11 | self.modelG = modelG
12 | self.lossG = lossG
13 | self.lossD = lossD
14 |
15 | def forward(self, x, labels, domainLab):
16 | dScore, gScore = self.modelD(x), self.modelG(x)
17 | lossDomain = self.lossD(dScore, domainLab)
18 | lossCls = self.lossG(gScore, labels)
19 | return lossDomain, lossCls
20 |
--------------------------------------------------------------------------------
/reid/metric_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from metric_learn import (ITML_Supervised, LMNN, LSML_Supervised,
4 | SDML_Supervised, NCA, LFDA, RCA_Supervised)
5 |
6 | from .euclidean import Euclidean
7 | from .kissme import KISSME
8 |
9 | __factory = {
10 | 'euclidean': Euclidean,
11 | 'kissme': KISSME,
12 | 'itml': ITML_Supervised,
13 | 'lmnn': LMNN,
14 | 'lsml': LSML_Supervised,
15 | 'sdml': SDML_Supervised,
16 | 'nca': NCA,
17 | 'lfda': LFDA,
18 | 'rca': RCA_Supervised,
19 | }
20 |
21 |
22 | def get_metric(algorithm, *args, **kwargs):
23 | if algorithm not in __factory:
24 | raise KeyError("Unknown metric:", algorithm)
25 | return __factory[algorithm](*args, **kwargs)
26 |
--------------------------------------------------------------------------------
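
A minimal sketch of the metric factory on synthetic data (feature dimension and identity count are illustrative):

```
import numpy as np
from reid.metric_learning import get_metric

metric = get_metric('kissme')
X = np.random.randn(100, 16)       # 100 features of dimension 16
y = np.random.randint(0, 10, 100)  # 10 synthetic identities
metric.fit(X, y)
print(metric.metric().shape)       # (16, 16) Mahalanobis matrix
```
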
/reid/feature_extraction/cnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import OrderedDict
3 |
4 | from ..utils import to_torch
5 | from torch.autograd import Variable
6 |
7 |
8 | def extract_cnn_feature(model, inputs, modules=None):
9 |     model.eval()
10 |     inputs = Variable(to_torch(inputs))
11 |
12 |     if modules is None:
13 |         outputs = model(inputs)[0]
14 |         outputs = outputs.data.cpu()
15 |         return outputs  # pool5
16 |
17 |     # Register a forward hook for each requested module
18 |     outputs = OrderedDict()
19 |     handles = []
20 |     for m in modules:
21 |         outputs[id(m)] = None
22 |
23 |         def func(m, i, o): outputs[id(m)] = o.data.cpu()
24 |
25 |         handles.append(m.register_forward_hook(func))
26 |     model(inputs)
27 |     for h in handles:
28 |         h.remove()
29 |     return list(outputs.values())
30 |
--------------------------------------------------------------------------------
/reid/models/classifier.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import math
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn import init
8 | from torch.autograd import Variable
9 | import torchvision
10 | # from torch_deform_conv.layers import ConvOffset2D
11 | from reid.utils.serialization import load_checkpoint, save_checkpoint
12 |
13 |
14 | class classifier(nn.Module):
15 | def __init__(self, inDim, outDim):
16 |         super(classifier, self).__init__()
17 |         self.cls = nn.Linear(inDim, outDim)
18 |         # self.cls = nn.Sequential(
19 |         #     nn.Linear(inDim, outDim),
20 |         #     # nn.ReLU()
21 |         # )
22 |         init.normal_(self.cls.weight, std=0.001)
23 |         init.constant_(self.cls.bias, 0)
24 |
25 | def forward(self, x):
26 | return self.cls(x)
--------------------------------------------------------------------------------
/reid/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import numpy as np
3 | import torch
4 |
5 |
6 | class GraphLoader(object):
7 | def __init__(self, trainList, labels, model, loss=None):
8 | # self.hw = [384,128]
9 | self.graphs = trainList
10 | self.ID = labels
11 | self.model = model
12 | self.loss = loss
13 |
14 | def __getitem__(self, idx):
15 | curGraph = self.graphs[idx]
16 | # featSize = curGraph.size(0)
17 | # useFeat = curGraph[np.random.choice(featSize, size=(int(0.8*featSize),), replace=False), :]
18 | gEmb, scores = self.model(curGraph)
19 | loss = self.loss(scores, torch.LongTensor([self.ID[idx]]).cuda()) if self.loss is not None else 0
20 | return gEmb.squeeze(), loss # return embedding and loss
21 |
22 | def __len__(self):
23 | return len(self.graphs)
24 |
--------------------------------------------------------------------------------
/reid/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 |
5 | from .osutils import mkdir_if_missing
6 |
7 |
8 | class Logger(object):
9 | def __init__(self, fpath=None):
10 | self.console = sys.stdout
11 | self.file = None
12 | if fpath is not None:
13 | mkdir_if_missing(os.path.dirname(fpath))
14 | self.file = open(fpath, 'w')
15 |
16 | def __del__(self):
17 | self.close()
18 |
19 |     def __enter__(self):
20 |         return self
21 |
22 | def __exit__(self, *args):
23 | self.close()
24 |
25 | def write(self, msg):
26 | self.console.write(msg)
27 | if self.file is not None:
28 | self.file.write(msg)
29 |
30 | def flush(self):
31 | self.console.flush()
32 | if self.file is not None:
33 | self.file.flush()
34 | os.fsync(self.file.fileno())
35 |
36 | def close(self):
37 | self.console.close()
38 | if self.file is not None:
39 | self.file.close()
40 |
--------------------------------------------------------------------------------
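
A usage sketch for Logger, which tees console output into a file ('logs/train.log' is a placeholder path; its directory is created if missing):

```
import sys
from reid.utils.logging import Logger

sys.stdout = Logger('logs/train.log')
print('epoch 0: loss 1.23')  # shows on the console and lands in the file
```
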
/reid/dist_metric.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 |
5 | from .evaluators import extract_features
6 | from .metric_learning import get_metric
7 |
8 |
9 | class DistanceMetric(object):
10 | def __init__(self, algorithm='euclidean', *args, **kwargs):
11 | super(DistanceMetric, self).__init__()
12 | self.algorithm = algorithm
13 | self.metric = get_metric(algorithm, *args, **kwargs)
14 |
15 | def train(self, model, data_loader):
16 | if self.algorithm == 'euclidean': return
17 | features, labels = extract_features(model, data_loader)
18 |         features = torch.stack(list(features.values())).numpy()
19 | labels = torch.Tensor(list(labels.values())).numpy()
20 | self.metric.fit(features, labels)
21 |
22 | def transform(self, X):
23 | if torch.is_tensor(X):
24 | X = X.numpy()
25 | X = self.metric.transform(X)
26 | X = torch.from_numpy(X)
27 | else:
28 | X = self.metric.transform(X)
29 | return X
30 |
31 |
--------------------------------------------------------------------------------
/reid/utils/data/preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os.path as osp
3 |
4 | from PIL import Image
5 |
6 |
7 | class Preprocessor(object):
8 | def __init__(self, dataset, root=None, transform=None):
9 | super(Preprocessor, self).__init__()
10 | self.dataset = dataset
11 | self.root = root
12 | self.transform = transform
13 |
14 | def __len__(self):
15 | return len(self.dataset)
16 |
17 | def __getitem__(self, indices):
18 | if isinstance(indices, (tuple, list)):
19 | return [self._get_single_item(index) for index in indices]
20 | return self._get_single_item(indices)
21 |
22 | def _get_single_item(self, index):
23 | fname, pid, camid = self.dataset[index]
24 | fpath = fname
25 | if self.root is not None:
26 | fpath = osp.join(self.root, fname)
27 | if not osp.exists(fpath):
28 | fpath = fname
29 |
30 | img = Image.open(fpath).convert('RGB')
31 | if self.transform is not None:
32 | img = self.transform(img)
33 | return img, fname, pid, camid
34 |
--------------------------------------------------------------------------------
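
A sketch of how Preprocessor plugs into a PyTorch DataLoader; the (fname, pid, camid) tuples and image root below are placeholders for what the dataset classes in reid/datasets produce:

```
import torchvision.transforms as T
from torch.utils.data import DataLoader
from reid.utils.data import Preprocessor
from reid.utils.data.transforms import RectScale

dataset = [('0001_c1_000001.jpg', 1, 0), ('0001_c2_000002.jpg', 1, 1)]
transform = T.Compose([RectScale(256, 128), T.ToTensor()])
loader = DataLoader(
    Preprocessor(dataset, root='data/market1501/images', transform=transform),
    batch_size=2)
for imgs, fnames, pids, camids in loader:
    print(imgs.shape)  # (2, 3, 256, 128)
```
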
/reid/loss/var_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import defaultdict
4 |
5 |
6 | class VarLoss(nn.Module):
7 | def __init__(self, feat_dim=768):
8 | super(VarLoss, self).__init__()
9 | self.feat_dim = feat_dim
10 | self.simiFunc = nn.Softmax(dim=0)
11 |
12 | def __calDis(self, x, y): # 246s
13 | # x, y = F.normalize(qFeature), F.normalize(gFeature)
14 | # x, y = qFeature, gFeature
15 | m, n = x.shape[0], y.shape[0]
16 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
17 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
18 | disMat.addmm_(1, -2, x, y.t())
19 | return disMat
20 |
21 | def forward(self, x, labels):
22 | labelMap = defaultdict(list)
23 | # per-ID features
24 | labVal = [int(val) for val in labels.cpu()]
25 | for pid in set(labVal):
26 | labelMap[pid].append(x[labels == pid, :])
27 | # cal loss
28 | loss = 0
29 |         for keyNum in labelMap.keys():
30 |             meanVec = labelMap[keyNum][0].mean(dim=0, keepdim=True)
31 |             dist = self.__calDis(meanVec, labelMap[keyNum][0])
32 |             loss += dist.mean()
33 |         return loss
34 |
--------------------------------------------------------------------------------
/reid/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data.sampler import (
7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler,
8 | WeightedRandomSampler)
9 |
10 |
11 | class RandomIdentitySampler(Sampler):
12 | def __init__(self, data_source, num_instances=1):
13 | self.data_source = data_source
14 | self.num_instances = num_instances
15 | self.index_dic = defaultdict(list)
16 | for index, (_, pid, _) in enumerate(data_source):
17 | self.index_dic[pid].append(index)
18 | self.pids = list(self.index_dic.keys())
19 | self.num_samples = len(self.pids)
20 |
21 | def __len__(self):
22 | return self.num_samples * self.num_instances
23 |
24 | def __iter__(self):
25 | indices = torch.randperm(self.num_samples)
26 | ret = []
27 | for i in indices:
28 | pid = self.pids[i]
29 | t = self.index_dic[pid]
30 | if len(t) >= self.num_instances:
31 | t = np.random.choice(t, size=self.num_instances, replace=False)
32 | else:
33 | t = np.random.choice(t, size=self.num_instances, replace=True)
34 | ret.extend(t)
35 | return iter(ret)
36 |
--------------------------------------------------------------------------------
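
A sketch of the sampler on synthetic data: with num_instances=4, each pass draws 4 indices per identity, so batches can be arranged as P identities x K instances (the layout TripletLoss expects):

```
from reid.utils.data.sampler import RandomIdentitySampler

dataset = [('img_{}.jpg'.format(i), i // 4, 0) for i in range(32)]  # 8 ids
sampler = RandomIdentitySampler(dataset, num_instances=4)
print(len(sampler))              # 8 ids * 4 instances = 32
print(len(list(iter(sampler))))  # 32 indices, grouped by identity
```
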
/reid/loss/neighbour_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import defaultdict
4 |
5 | class NeiLoss(nn.Module):
6 | def __init__(self, feat_dim=768):
7 | super(NeiLoss, self).__init__()
8 | self.feat_dim = feat_dim
9 | self.simiFunc = nn.Softmax(dim=0)
10 |
11 | def __calDis(self, x, y):#246s
12 | # x, y = F.normalize(qFeature), F.normalize(gFeature)
13 | # x, y = qFeature, gFeature
14 | m, n = x.shape[0], y.shape[0]
15 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
16 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
17 | disMat.addmm_(1, -2, x, y.t())
18 | return disMat
19 |
20 | def forward(self, x, labels):
21 | bSize = x.shape[0]
22 | labelMap = defaultdict(list)
23 | distmat = self.__calDis(x,x)
24 | # per-ID features
25 | labVal = [int(val) for val in labels.cpu()]
26 | for pid in set(labVal):
27 | labelMap[pid].append(labels==pid)
28 | # cal loss
29 | loss = 0
30 | for keyNum in labelMap.keys():
31 | mask = labelMap[keyNum]
32 | curProb = distmat[labels==keyNum][0]
33 | loss += -torch.log(self.simiFunc(curProb)[mask].sum()).sum()
34 | return loss/len(labelMap.keys())
35 |
--------------------------------------------------------------------------------
/reid/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import warnings
3 |
4 | from .cuhk03 import CUHK03
5 | from .dukemtmc import DukeMTMC
6 | from .market1501 import Market1501
7 | from .msmt17 import msmt17
8 |
9 | __factory = {
10 | 'cuhk03': CUHK03,
11 | 'market1501': Market1501,
12 | 'dukemtmc': DukeMTMC,
13 | 'msmt17': msmt17,
14 | }
15 |
16 |
17 | def names():
18 | return sorted(__factory.keys())
19 |
20 |
21 | def create(name, root, *args, **kwargs):
22 | """
23 | Create a dataset instance.
24 |
25 | Parameters
26 | ----------
27 | name : str
28 |         The dataset name. Can be one of 'cuhk03', 'market1501',
29 |         'dukemtmc', and 'msmt17'.
30 | root : str
31 | The path to the dataset directory.
32 | split_id : int, optional
33 | The index of data split. Default: 0
34 | num_val : int or float, optional
35 | When int, it means the number of validation identities. When float,
36 | it means the proportion of validation to all the trainval. Default: 100
37 | download : bool, optional
38 | If True, will download the dataset. Default: False
39 | """
40 | if name not in __factory:
41 | raise KeyError("Unknown dataset:", name)
42 | return __factory[name](root, *args, **kwargs)
43 |
44 |
45 | def get_dataset(name, root, *args, **kwargs):
46 | warnings.warn("get_dataset is deprecated. Use create instead.")
47 | return create(name, root, *args, **kwargs)
48 |
--------------------------------------------------------------------------------
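
A usage sketch for the dataset factory, assuming the Market-1501 archive is already placed under data/market1501/raw as described in data/readme.md:

```
from reid import datasets

print(datasets.names())  # ['cuhk03', 'dukemtmc', 'market1501', 'msmt17']
dataset = datasets.create('market1501', 'data/market1501')
print(dataset.num_trainval_ids, len(dataset.query), len(dataset.gallery))
```
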
/reid/feature_extraction/database.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import h5py
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class FeatureDatabase(Dataset):
9 | def __init__(self, *args, **kwargs):
10 | super(FeatureDatabase, self).__init__()
11 | self.fid = h5py.File(*args, **kwargs)
12 |
13 | def __enter__(self):
14 | return self
15 |
16 | def __exit__(self, exc_type, exc_val, exc_tb):
17 | self.close()
18 |
19 | def __getitem__(self, keys):
20 | if isinstance(keys, (tuple, list)):
21 | return [self._get_single_item(k) for k in keys]
22 | return self._get_single_item(keys)
23 |
24 | def _get_single_item(self, key):
25 | return np.asarray(self.fid[key])
26 |
27 | def __setitem__(self, key, value):
28 | if key in self.fid:
29 | if self.fid[key].shape == value.shape and \
30 | self.fid[key].dtype == value.dtype:
31 | self.fid[key][...] = value
32 | else:
33 | del self.fid[key]
34 | self.fid.create_dataset(key, data=value)
35 | else:
36 | self.fid.create_dataset(key, data=value)
37 |
38 | def __delitem__(self, key):
39 | del self.fid[key]
40 |
41 | def __len__(self):
42 | return len(self.fid)
43 |
44 | def __iter__(self):
45 | return iter(self.fid)
46 |
47 | def flush(self):
48 | self.fid.flush()
49 |
50 | def close(self):
51 | self.fid.close()
52 |
--------------------------------------------------------------------------------
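
A sketch of FeatureDatabase as a context manager; keys are arbitrary strings (here a placeholder image name) stored as HDF5 datasets:

```
import numpy as np
from reid.feature_extraction import FeatureDatabase

with FeatureDatabase('features.h5', 'w') as db:
    db['0001_c1_000001.jpg'] = np.random.randn(2048).astype(np.float32)
with FeatureDatabase('features.h5', 'r') as db:
    print(db['0001_c1_000001.jpg'].shape)  # (2048,)
```
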
/reid/loss/center_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class CenterLoss(nn.Module):
5 | """Center loss.
6 |
7 | Reference:
8 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
9 |
10 | Args:
11 | num_classes (int): number of classes.
12 | feat_dim (int): feature dimension.
13 | """
14 | def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
15 | super(CenterLoss, self).__init__()
16 | self.num_classes = num_classes
17 | self.feat_dim = feat_dim
18 | self.use_gpu = use_gpu
19 |
20 | if self.use_gpu:
21 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
22 | else:
23 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
24 |
25 | def forward(self, x, labels):
26 | """
27 | Args:
28 | x: feature matrix with shape (batch_size, feat_dim).
29 | labels: ground truth labels with shape (batch_size).
30 | """
31 | batch_size = x.size(0)
32 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
33 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
34 | distmat.addmm_(1, -2, x, self.centers.t())
35 |
36 | classes = torch.arange(self.num_classes).long()
37 | if self.use_gpu: classes = classes.cuda()
38 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
39 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
40 |
41 | dist = distmat * mask.float()
42 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
43 |
44 | return loss
45 |
--------------------------------------------------------------------------------
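
A minimal CPU sketch of CenterLoss (use_gpu=False; the feature dimension and class count are illustrative). The centers are an nn.Parameter, so gradients flow into them as well:

```
import torch
from reid.loss import CenterLoss

criterion = CenterLoss(num_classes=10, feat_dim=2048, use_gpu=False)
feats = torch.randn(16, 2048)
labels = torch.randint(0, 10, (16,))
loss = criterion(feats, labels)
loss.backward()  # updates flow into criterion.centers
```
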
/reid/loss/matchLoss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def pairwiseDis(x, y):#246s
5 | # x, y = F.normalize(qFeature), F.normalize(gFeature)
6 | # x, y = qFeature, gFeature
7 | m, n = x.shape[0], y.shape[0]
8 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
9 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
10 | disMat.addmm_(1, -2, x, y.t())
11 | return disMat
12 |
13 |
14 | def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
15 | n_samples = source.shape[0]+target.shape[0]
16 | L2_distance = pairwiseDis(source, target)
17 | if fix_sigma:
18 | bandwidth = fix_sigma
19 | else:
20 | bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
21 | bandwidth /= kernel_mul ** (kernel_num // 2)
22 | bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
23 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
24 | return sum(kernel_val)/len(kernel_val)
25 |
26 | def lossMMD(srcFeat, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
27 | srcBatch = srcFeat.shape[0]
28 | newFeat = torch.cat([srcFeat,target])
29 | kernels = guassian_kernel(newFeat, newFeat, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
30 | XX = kernels[:srcBatch, :srcBatch]
31 | YY = kernels[srcBatch:, srcBatch:]
32 | XY = kernels[:srcBatch, srcBatch:]
33 | YX = kernels[srcBatch:, :srcBatch]
34 | return XX.mean() + YY.mean() - XY.mean() -YX.mean()
35 |
36 |
37 | # if __name__ == "__main__":
38 | # import numpy as np
39 | # srcFeat, tarFeat = np.load('E://gcn//srcGFeat.npy'), np.load('E://gcn//tarGFeat.npy')
40 | # disMat = lossMMD(torch.from_numpy(srcFeat), torch.from_numpy(tarFeat))
--------------------------------------------------------------------------------
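
A sketch of the MMD loss on synthetic source/target feature batches; shifting the target distribution away from the source increases the value:

```
import torch
from reid.loss import lossMMD

src = torch.randn(64, 2048)
tgt = torch.randn(64, 2048) + 0.5  # shifted target distribution
print(lossMMD(src, tgt).item())
```
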
/reid/metric_learning/kissme.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import numpy as np
4 | from metric_learn.base_metric import BaseMetricLearner
5 |
6 |
7 | def validate_cov_matrix(M):
8 | M = (M + M.T) * 0.5
9 | k = 0
10 | I = np.eye(M.shape[0])
11 | while True:
12 | try:
13 | _ = np.linalg.cholesky(M)
14 | break
15 | except np.linalg.LinAlgError:
16 | # Find the nearest positive definite matrix for M. Modified from
17 | # http://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
18 | # Might take several minutes
19 | k += 1
20 |             w, v = np.linalg.eig(M)
21 |             min_eig = w.min()
22 | M += (-min_eig * k * k + np.spacing(min_eig)) * I
23 | return M
24 |
25 |
26 | class KISSME(BaseMetricLearner):
27 | def __init__(self):
28 | self.M_ = None
29 |
30 | def metric(self):
31 | return self.M_
32 |
33 | def fit(self, X, y=None):
34 | n = X.shape[0]
35 | if y is None:
36 | y = np.arange(n)
37 | X1, X2 = np.meshgrid(np.arange(n), np.arange(n))
38 | X1, X2 = X1[X1 < X2], X2[X1 < X2]
39 | matches = (y[X1] == y[X2])
40 | num_matches = matches.sum()
41 | num_non_matches = len(matches) - num_matches
42 | idxa = X1[matches]
43 | idxb = X2[matches]
44 | S = X[idxa] - X[idxb]
45 | C1 = S.transpose().dot(S) / num_matches
46 | p = np.random.choice(num_non_matches, num_matches, replace=False)
47 | idxa = X1[~matches]
48 | idxb = X2[~matches]
49 | idxa = idxa[p]
50 | idxb = idxb[p]
51 | S = X[idxa] - X[idxb]
52 | C0 = S.transpose().dot(S) / num_matches
53 | self.M_ = np.linalg.inv(C1) - np.linalg.inv(C0)
54 | self.M_ = validate_cov_matrix(self.M_)
55 | self.X_ = X
56 |
--------------------------------------------------------------------------------
/reid/loss/oim.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, autograd
6 |
7 |
8 | class OIM(autograd.Function):
9 | def __init__(self, lut, momentum=0.5):
10 | super(OIM, self).__init__()
11 | self.lut = lut
12 | self.momentum = momentum
13 |
14 | def forward(self, inputs, targets):
15 | self.save_for_backward(inputs, targets)
16 | outputs = inputs.mm(self.lut.t())
17 | return outputs
18 |
19 | def backward(self, grad_outputs):
20 | inputs, targets = self.saved_tensors
21 | grad_inputs = None
22 | if self.needs_input_grad[0]:
23 | grad_inputs = grad_outputs.mm(self.lut)
24 | for x, y in zip(inputs, targets):
25 | self.lut[y] = self.momentum * self.lut[y] + (1. - self.momentum) * x
26 | self.lut[y] /= self.lut[y].norm()
27 | return grad_inputs, None
28 |
29 |
30 | def oim(inputs, targets, lut, momentum=0.5):
31 | return OIM(lut, momentum=momentum)(inputs, targets)
32 |
33 |
34 | class OIMLoss(nn.Module):
35 | def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
36 | weight=None, size_average=True):
37 | super(OIMLoss, self).__init__()
38 | self.num_features = num_features
39 | self.num_classes = num_classes
40 | self.momentum = momentum
41 | self.scalar = scalar
42 | self.weight = weight
43 | self.size_average = size_average
44 |
45 | self.register_buffer('lut', torch.zeros(num_classes, num_features))
46 |
47 | def forward(self, inputs, targets):
48 | inputs = oim(inputs, targets, self.lut, momentum=self.momentum)
49 | inputs *= self.scalar
50 | loss = F.cross_entropy(inputs, targets, weight=self.weight,
51 | size_average=self.size_average)
52 | return loss, inputs
53 |
--------------------------------------------------------------------------------
/reid/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import json
3 | import os.path as osp
4 | import shutil
5 |
6 | import torch
7 | from torch.nn import Parameter
8 |
9 | from .osutils import mkdir_if_missing
10 |
11 |
12 | def read_json(fpath):
13 | with open(fpath, 'r') as f:
14 | obj = json.load(f)
15 | return obj
16 |
17 |
18 | def write_json(obj, fpath):
19 | mkdir_if_missing(osp.dirname(fpath))
20 | with open(fpath, 'w') as f:
21 | json.dump(obj, f, indent=4, separators=(',', ': '))
22 |
23 |
24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
25 | mkdir_if_missing(osp.dirname(fpath))
26 | torch.save(state, fpath)
27 | if is_best:
28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
29 |
30 |
31 | def load_checkpoint(fpath):
32 | if osp.isfile(fpath):
33 | checkpoint = torch.load(fpath)
34 | print("=> Loaded checkpoint '{}'".format(fpath))
35 | return checkpoint
36 | else:
37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath))
38 |
39 |
40 | def copy_state_dict(state_dict, model, strip=None):
41 | tgt_state = model.state_dict()
42 | copied_names = set()
43 | for name, param in state_dict.items():
44 | if strip is not None and name.startswith(strip):
45 | name = name[len(strip):]
46 | if name not in tgt_state:
47 | continue
48 | if isinstance(param, Parameter):
49 | param = param.data
50 | if param.size() != tgt_state[name].size():
51 | print('mismatch:', name, param.size(), tgt_state[name].size())
52 | continue
53 | tgt_state[name].copy_(param)
54 | copied_names.add(name)
55 |
56 | missing = set(tgt_state.keys()) - copied_names
57 | if len(missing) > 0:
58 | print("missing keys in state_dict:", missing)
59 |
60 | return model
61 |
--------------------------------------------------------------------------------
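
A sketch of the checkpoint helpers; the model and paths are placeholders:

```
import torch
from reid.utils.serialization import save_checkpoint, load_checkpoint

model = torch.nn.Linear(4, 2)
state = {'state_dict': model.state_dict(), 'epoch': 10, 'best_top1': 75.2}
save_checkpoint(state, is_best=True, fpath='logs/checkpoint.pth.tar')
# is_best=True also copies the file to logs/model_best.pth.tar
model.load_state_dict(load_checkpoint('logs/checkpoint.pth.tar')['state_dict'])
```
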
/reid/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .inception import *
4 | from .resnet import *
5 | from .classifier import classifier
6 |
7 | __factory = {
8 | 'inception': inception,
9 | 'resnet18': resnet18,
10 | 'resnet34': resnet34,
11 | 'resnet50': resnet50,
12 | 'resnet101': resnet101,
13 | 'resnet152': resnet152,
14 | 'classifier': classifier
15 | }
16 |
17 |
18 | def names():
19 | return sorted(__factory.keys())
20 |
21 |
22 | def create(name, *args, **kwargs):
23 | """
24 | Create a model instance.
25 |
26 | Parameters
27 | ----------
28 | name : str
29 | Model name. Can be one of 'inception', 'resnet18', 'resnet34',
30 | 'resnet50', 'resnet101', and 'resnet152'.
31 | pretrained : bool, optional
32 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained
33 | model. Default: True
34 | cut_at_pooling : bool, optional
35 | If True, will cut the model before the last global pooling layer and
36 | ignore the remaining kwargs. Default: False
37 | num_features : int, optional
38 | If positive, will append a Linear layer after the global pooling layer,
39 | with this number of output units, followed by a BatchNorm layer.
40 | Otherwise these layers will not be appended. Default: 256 for
41 | 'inception', 0 for 'resnet*'
42 | norm : bool, optional
43 | If True, will normalize the feature to be unit L2-norm for each sample.
44 | Otherwise will append a ReLU layer after the above Linear layer if
45 | num_features > 0. Default: False
46 | dropout : float, optional
47 | If positive, will append a Dropout layer with this dropout rate.
48 | Default: 0
49 | num_classes : int, optional
50 | If positive, will append a Linear layer at the end as the classifier
51 | with this number of output units. Default: 0
52 | """
53 | if name not in __factory:
54 | raise KeyError("Unknown model:", name)
55 | return __factory[name](*args, **kwargs)
56 |
--------------------------------------------------------------------------------
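
A sketch of the model factory; 702 (the number of DukeMTMC-reID training identities) is only an example, and pretrained ImageNet weights are fetched by torchvision:

```
from reid import models

model = models.create('resnet50', num_features=2048,
                      dropout=0.5, num_classes=702)
print(models.names())
```
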
/reid/loss/triplet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 |
7 |
8 | class TripletLoss(nn.Module):
9 | def __init__(self, margin=0, num_instances=0, use_semi=True, isAvg=True):
10 | super(TripletLoss, self).__init__()
11 | self.margin = margin
12 | self.ranking_loss = nn.MarginRankingLoss(margin=self.margin, reduce=isAvg)
13 | self.K = num_instances
14 | self.use_semi = use_semi
15 |
16 | def forward(self, inputs, targets, epoch):
17 | n = inputs.size(0)
18 | P = n // self.K
19 |
20 | # Compute pairwise distance, replace by the official when merged
21 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
22 | dist = dist + dist.t()
23 | dist.addmm_(1, -2, inputs, inputs.t())
24 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
25 | # For each anchor, find the hardest positive and negative
26 | mask = targets.expand(n, n).eq(targets.expand(n, n).t())
27 | dist_ap, dist_an = [], []
28 | if self.use_semi:
29 | for i in range(P):
30 | for j in range(self.K):
31 | neg_examples = dist[i * self.K + j][mask[i * self.K + j] == 0]
32 | for pair in range(j + 1, self.K):
33 | ap = dist[i * self.K + j][i * self.K + pair]
34 | dist_ap.append(ap.unsqueeze(0))
35 | dist_an.append(neg_examples.min().unsqueeze(0))
36 | else:
37 | for i in range(n):
38 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
39 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
40 | dist_ap = torch.cat(dist_ap)
41 | dist_an = torch.cat(dist_an)
42 | # Compute ranking hinge loss
43 | y = dist_an.data.new()
44 | y.resize_as_(dist_an.data)
45 | y.fill_(1)
46 | y = Variable(y)
47 | loss = self.ranking_loss(dist_an, dist_ap, y)
48 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
49 | return loss, prec
50 |
--------------------------------------------------------------------------------
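
A sketch of TripletLoss on a synthetic P x K batch grouped by identity (the layout RandomIdentitySampler produces); the epoch argument is accepted but unused:

```
import torch
from reid.loss import TripletLoss

P, K = 8, 4
feats = torch.randn(P * K, 2048, requires_grad=True)
pids = torch.arange(P).repeat_interleave(K)  # [0,0,0,0,1,1,1,1,...]
criterion = TripletLoss(margin=0.3, num_instances=K)
loss, prec = criterion(feats, pids, epoch=0)
loss.backward()
```
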
/reid/loss/virtual_ce.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 | from torch.nn import functional as F
7 | from scipy.stats import norm
8 |
9 | import numpy as np
10 |
11 |
12 | class VirtualCE(nn.Module):
13 | def __init__(self, beta=0.1):
14 | super(VirtualCE, self).__init__()
15 | self.beta = beta
16 |
17 | def forward(self, inputs, targets):
18 | # norm first
19 | n = inputs.shape[0]
20 | inputs = F.normalize(inputs, p=2)
21 | allPids = targets.cpu().numpy().tolist()
22 | # All Centers
23 | centerHash = {
24 | pid: F.normalize(inputs[targets == pid, :].mean(dim=0, keepdim=True), p=2).detach() for pid in set(allPids)
25 | }
26 | allCenters = torch.autograd.Variable(torch.cat(list(centerHash.values()))).cuda()
27 | centerPID = torch.from_numpy(np.asarray(list(centerHash.keys())))
28 | # sampler vs center
29 | samplerCenter = torch.autograd.Variable(torch.cat([allCenters[centerPID == pid, :] for pid in allPids])).cuda()
30 | # inputs--(128*1024), allCenters--(32*1024)
31 | vce = torch.diag(torch.exp(samplerCenter.mm(inputs.t()) / self.beta)) # 1*128
32 | centerScore = torch.exp(allCenters.mm(inputs.t()) / self.beta).sum(dim=0) # 32(center number)*128->1*128
33 | return -torch.log(vce.div(centerScore)).mean()
34 |
35 |
36 | class VirtualKCE(nn.Module):
37 | def __init__(self, beta=0.1):
38 | super(VirtualKCE, self).__init__()
39 | self.beta = beta
40 |
41 | def forward(self, inputs, targets):
42 | # norm first
43 | n = inputs.shape[0]
44 | inputs = F.normalize(inputs, p=2)
45 | allPids = targets.cpu().numpy().tolist()
46 | # All Centers
47 | centerHash = {
48 | pid: F.normalize(inputs[targets == pid, :].mean(dim=0, keepdim=True), p=2).detach() for pid in set(allPids)
49 | }
50 | allCenters = torch.autograd.Variable(torch.cat(list(centerHash.values()))).cuda()
51 | centerPID = torch.from_numpy(np.asarray(list(centerHash.keys())))
52 | samplerCenter = torch.autograd.Variable(torch.cat([allCenters[centerPID == pid, :] for pid in allPids])).cuda()
53 | # inputs--(128*1024), allCenters--(32*1024)
54 | vce = torch.diag(torch.exp(samplerCenter.mm(inputs.t()) / self.beta)) # 1*128
55 | centerScore = torch.exp(allCenters.mm(inputs.t()) / self.beta).sum(dim=0) # 32*128->1*128
56 | kNegScore = torch.diag(inputs.mm(inputs.t()))
57 | return -torch.log(vce.div(kNegScore + centerScore)).mean()
58 |
--------------------------------------------------------------------------------
/reid/datasets/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | import errno
5 | import shutil
6 | import json
7 | import os.path as osp
8 |
9 | import torch
10 |
11 |
12 | def mkdir_if_missing(directory):
13 | if not osp.exists(directory):
14 | try:
15 | os.makedirs(directory)
16 | except OSError as e:
17 | if e.errno != errno.EEXIST:
18 | raise
19 |
20 |
21 | class AverageMeter(object):
22 | """Computes and stores the average and current value.
23 |
24 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
25 | """
26 |
27 | def __init__(self):
28 | self.reset()
29 |
30 | def reset(self):
31 | self.val = 0
32 | self.avg = 0
33 | self.sum = 0
34 | self.count = 0
35 |
36 | def update(self, val, n=1):
37 | self.val = val
38 | self.sum += val * n
39 | self.count += n
40 | self.avg = self.sum / self.count
41 |
42 |
43 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
44 | mkdir_if_missing(osp.dirname(fpath))
45 | torch.save(state, fpath)
46 | if is_best:
47 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
48 |
49 |
50 | class Logger(object):
51 | """
52 | Write console output to external text file.
53 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
54 | """
55 |
56 | def __init__(self, fpath=None):
57 | self.console = sys.stdout
58 | self.file = None
59 | if fpath is not None:
60 | mkdir_if_missing(os.path.dirname(fpath))
61 | self.file = open(fpath, 'w')
62 |
63 | def __del__(self):
64 | self.close()
65 |
66 |     def __enter__(self):
67 |         return self
68 |
69 | def __exit__(self, *args):
70 | self.close()
71 |
72 | def write(self, msg):
73 | self.console.write(msg)
74 | if self.file is not None:
75 | self.file.write(msg)
76 |
77 | def flush(self):
78 | self.console.flush()
79 | if self.file is not None:
80 | self.file.flush()
81 | os.fsync(self.file.fileno())
82 |
83 | def close(self):
84 | self.console.close()
85 | if self.file is not None:
86 | self.file.close()
87 |
88 |
89 | def read_json(fpath):
90 | with open(fpath, 'r') as f:
91 | obj = json.load(f)
92 | return obj
93 |
94 |
95 | def write_json(obj, fpath):
96 | mkdir_if_missing(osp.dirname(fpath))
97 | with open(fpath, 'w') as f:
98 | json.dump(obj, f, indent=4, separators=(',', ': '))
99 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Asymmetric Co-Teaching for Unsupervised Cross Domain Person Re-Identification (AAAI 2020)
2 |
3 | Code for AAAI 2020 paper [Asymmetric Co-Teaching for Unsupervised Cross Domain Person Re-Identification](https://arxiv.org/abs/1912.01349).
4 |
5 | 
6 |
7 | ## Requirements
8 | * python 3.7
9 | * A server with 4 GPUs
10 | * Market1501, DukeMTMC-reID and other datasets.
11 | * Other necessary packages listed in [requirements.txt](requirements.txt)
12 |
13 | ## Adaptation with ACT
14 | 1. Download all necessary datasets and move them to 'data' following the instructions in 'data/readme.md'.
15 |
16 | 2. If you want to train from a pre-adapted model for fast reproduction,
17 | please download all models in the Resources section below and run the following command:
18 |
19 | ```
20 | python selftrainingACT.py --src_dataset {src_dataset_name} --tgt_dataset {tgt_dataset_name} --resume {model's path} --data_dir ./data --logs_dir {path to save model}
21 | ```
22 |
23 | Available choices for "src_dataset_name" and "tgt_dataset_name" are:
24 | market1501 (for Market1501), dukemtmc (for DukeMTMC-reID), and cuhk03 (for CUHK03).
25 |
26 |
27 | 3. If you want to train from scratch, please train the source and adapted models using the code in
28 | [Adaptive-ReID](https://github.com/LcDog/DomainAdaptiveReID), then follow step 2.
29 |
30 | ## Adaptation with other co-teaching-like structures.
31 | To reproduce the results in Tab. 2 of our paper, please run selftrainingRCT.py and selftrainingCT.py in a similar way.
32 |
33 | ## Adaptation with other clustering methods.
34 | To reproduce Tab. 3 of our paper, run selftrainingKmeans.py (the co-teaching version) and selftrainingKmeansAsy.py (the ACT version).
35 |
36 | If you find this code useful in your research, please consider citing:
37 | ```
38 | @inproceedings{yang2020asymmetric,
39 | title={Asymmetric Co-Teaching for Unsupervised Cross-Domain Person Re-Identification.},
40 | author={Yang, Fengxiang and Li, Ke and Zhong, Zhun and Luo, Zhiming and Sun, Xing and Cheng, Hao and Guo, Xiaowei and Huang, Feiyue and Ji, Rongrong and Li, Shaozi},
41 | booktitle={AAAI},
42 | pages={12597--12604},
43 | year={2020}
44 | }
45 | ```
46 |
47 | ## Acknowledgments
48 | Our code is based on [open-reid](https://github.com/Cysu/open-reid) and [Adaptive-ReID](https://arxiv.org/abs/1807.11334);
49 | if you use our code, please also cite their papers.
50 | ```
51 | @article{song2018unsupervised,
52 | title={Unsupervised domain adaptive re-identification: Theory and practice},
53 | author={Song, Liangchen and Wang, Cheng and Zhang, Lefei and Du, Bo and Zhang, Qian and Huang, Chang and Wang, Xinggang},
54 | journal={arXiv preprint arXiv:1807.11334},
55 | year={2018}
56 | }
57 | ```
58 |
59 |
60 |
61 | ## Resources
62 |
63 | 1. Pretrained Models:
64 |
65 | All pre-adapted models are named according to the following pattern:
66 | ```
67 | ada{src}2{tgt}.pth
68 | ```
69 | where "src" and "tgt" are the initial letters of the source and target dataset names, i.e., M for Market1501, D for DukeMTMC-reID and C for CUHK03.
70 |
71 | [Baidu NetDisk](https://pan.baidu.com/s/1uPjKpkdZjqSJdk3XxR1-Yg), Password: 9aba
72 |
73 | [Google Drive](https://drive.google.com/file/d/1W1BcmHjmzxR3TVj2rFpnV703Huat3AeA/view?usp=sharing)
74 |
75 | 2. Results on MSMT17 (MS), Duke (D) and Market (M).
76 |
77 | 
78 |
--------------------------------------------------------------------------------
/reid/utils/data/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from PIL import Image, ImageOps, ImageEnhance
4 | import random
5 | import math
6 | from torchvision.transforms import *
7 |
8 |
9 | class RectScale(object):
10 | def __init__(self, height, width, interpolation=Image.BILINEAR):
11 | self.height = height
12 | self.width = width
13 | self.interpolation = interpolation
14 |
15 | def __call__(self, img):
16 | w, h = img.size
17 | if h == self.height and w == self.width:
18 | return img
19 | return img.resize((self.width, self.height), self.interpolation)
20 |
21 |
22 | class RandomSizedRectCrop(object):
23 | def __init__(self, height, width, interpolation=Image.BILINEAR):
24 | self.height = height
25 | self.width = width
26 | self.interpolation = interpolation
27 |
28 | def __call__(self, img):
29 | for attempt in range(10):
30 | area = img.size[0] * img.size[1]
31 | target_area = random.uniform(0.64, 1.0) * area
32 | aspect_ratio = random.uniform(2, 3)
33 |
34 | h = int(round(math.sqrt(target_area * aspect_ratio)))
35 | w = int(round(math.sqrt(target_area / aspect_ratio)))
36 |
37 | if w <= img.size[0] and h <= img.size[1]:
38 | x1 = random.randint(0, img.size[0] - w)
39 | y1 = random.randint(0, img.size[1] - h)
40 |
41 | img = img.crop((x1, y1, x1 + w, y1 + h))
42 | assert(img.size == (w, h))
43 |
44 | return img.resize((self.width, self.height), self.interpolation)
45 |
46 | # Fallback
47 | scale = RectScale(self.height, self.width,
48 | interpolation=self.interpolation)
49 | return scale(img)
50 |
51 | class RandomErasing(object):
52 | '''
53 | Class that performs Random Erasing in Random Erasing Data Augmentation by Zhong et al.
54 | -------------------------------------------------------------------------------------
55 | probability: The probability that the operation will be performed.
56 | sl: min erasing area
57 | sh: max erasing area
58 | r1: min aspect ratio
59 | mean: erasing value
60 | -------------------------------------------------------------------------------------
61 | '''
62 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.2, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
63 | self.probability = probability
64 | self.mean = mean
65 | self.sl = sl
66 | self.sh = sh
67 | self.r1 = r1
68 |
69 | def __call__(self, img):
70 |
71 | if random.uniform(0, 1) > self.probability:
72 | return img
73 |
74 | for attempt in range(100):
75 | area = img.size()[1] * img.size()[2]
76 |
77 | target_area = random.uniform(self.sl, self.sh) * area
78 | aspect_ratio = random.uniform(self.r1, 1/self.r1)
79 |
80 | h = int(round(math.sqrt(target_area * aspect_ratio)))
81 | w = int(round(math.sqrt(target_area / aspect_ratio)))
82 |
83 | if w < img.size()[2] and h < img.size()[1]:
84 | x1 = random.randint(0, img.size()[1] - h)
85 | y1 = random.randint(0, img.size()[2] - w)
86 | if img.size()[0] == 3:
87 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
88 | img[1, x1:x1+h, y1:y1+w] = self.mean[1]
89 | img[2, x1:x1+h, y1:y1+w] = self.mean[2]
90 | else:
91 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
92 | return img
93 |
94 | return img
95 |
--------------------------------------------------------------------------------
/reid/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import math
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn import init
8 | from torch.autograd import Variable
9 | import torchvision
10 | # from torch_deform_conv.layers import ConvOffset2D
11 | from reid.utils.serialization import load_checkpoint, save_checkpoint
12 |
13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
14 | 'resnet152']
15 |
16 |
17 | class ResNet(nn.Module):
18 | __factory = {
19 | 18: torchvision.models.resnet18,
20 | 34: torchvision.models.resnet34,
21 | 50: torchvision.models.resnet50,
22 | 101: torchvision.models.resnet101,
23 | 152: torchvision.models.resnet152,
24 | }
25 |
26 | def __init__(self, depth, checkpoint=None, pretrained=True, num_features=2048,
27 | dropout=0.1, num_classes=0, numCams=0):
28 | super(ResNet, self).__init__()
29 |
30 | self.depth = depth
31 | self.checkpoint = checkpoint
32 | self.pretrained = pretrained
33 | self.num_features = num_features
34 | self.dropout = dropout
35 | self.num_classes = num_classes
36 |
37 | if self.dropout > 0:
38 | self.drop = nn.Dropout(self.dropout)
39 | # Construct base (pretrained) resnet
40 | if depth not in ResNet.__factory:
41 | raise KeyError("Unsupported depth:", depth)
42 | self.base = ResNet.__factory[depth](pretrained=pretrained)
43 | out_planes = self.base.fc.in_features
44 |
45 | # resume from pre-iteration training
46 | if self.checkpoint:
47 | state_dict = load_checkpoint(checkpoint)
48 | self.load_state_dict(state_dict['state_dict'], strict=False)
49 |
50 | self.feat = nn.Linear(out_planes, self.num_features, bias=False)
51 | self.feat_bn = nn.BatchNorm1d(self.num_features)
52 | self.relu = nn.ReLU(inplace=True)
53 |         init.normal_(self.feat.weight, std=0.001)
54 |         init.constant_(self.feat_bn.weight, 1)
55 |         init.constant_(self.feat_bn.bias, 0)
56 |
57 | # x2 classifier
58 | self.classifier_x2 = nn.Linear(self.num_features, self.num_classes)
59 |         init.normal_(self.classifier_x2.weight, std=0.001)
60 |         init.constant_(self.classifier_x2.bias, 0)
61 |
62 | if not self.pretrained:
63 | self.reset_params()
64 |
65 | def forward(self, x):
66 | for name, module in self.base._modules.items():
67 | if name == 'avgpool': break
68 | x = module(x)
69 |         # x: (bSize, C, H, W) feature map before global pooling
70 | x1 = F.avg_pool2d(x, x.size()[2:])
71 | x1 = x1.view(x1.size(0), -1)
72 | # get classification
73 | x2 = F.avg_pool2d(x, x.size()[2:])
74 | x2 = x2.view(x2.size(0), -1)
75 | x2 = self.feat(x2)
76 | x2 = self.feat_bn(x2)
77 | x2 = self.relu(x2)
78 | if self.num_classes != 0:
79 | return x1, x2, self.classifier_x2(x2) # pool5, fc2048, classifier
80 | return x1, x2
81 |
82 | def reset_params(self):
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv2d):
85 |                 init.normal_(m.weight, std=0.001)
86 |                 if m.bias is not None:
87 |                     init.constant_(m.bias, 0)
88 |             elif isinstance(m, nn.BatchNorm2d):
89 |                 init.constant_(m.weight, 1)
90 |                 init.constant_(m.bias, 0)
91 |             elif isinstance(m, nn.Linear):
92 |                 init.normal_(m.weight, std=0.001)
93 |                 if m.bias is not None:
94 |                     init.constant_(m.bias, 0)
95 |
96 |
97 | def resnet18(**kwargs):
98 | return ResNet(18, **kwargs)
99 |
100 |
101 | def resnet34(**kwargs):
102 | return ResNet(34, **kwargs)
103 |
104 |
105 | def resnet50(**kwargs):
106 | return ResNet(50, **kwargs)
107 |
108 |
109 | def resnet101(**kwargs):
110 | return ResNet(101, **kwargs)
111 |
112 |
113 | def resnet152(**kwargs):
114 | return ResNet(152, **kwargs)
115 |
--------------------------------------------------------------------------------
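
A sketch of the backbone's forward interface on random input (pretrained=False avoids a weight download; the shapes assume a 256x128 crop): with num_classes > 0 it returns (pool5, fc2048, logits), otherwise (pool5, fc2048).

```
import torch
from reid.models import resnet50

model = resnet50(pretrained=False, num_classes=10)
model.eval()
x = torch.randn(2, 3, 256, 128)
pool5, emb, logits = model(x)
print(pool5.shape, emb.shape, logits.shape)  # (2, 2048) (2, 2048) (2, 10)
```
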
/reid/datasets/cuhk03.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 |
4 | import numpy as np
5 |
6 | from ..utils.data import Dataset
7 | from ..utils.osutils import mkdir_if_missing
8 | from ..utils.serialization import write_json
9 |
10 |
11 | class CUHK03(Dataset):
12 | url = 'https://docs.google.com/spreadsheet/viewform?usp=drive_web&formkey=dHRkMkFVSUFvbTJIRkRDLWRwZWpONnc6MA#gid=0'
13 | md5 = '728939e58ad9f0ff53e521857dd8fb43'
14 |
15 | def __init__(self, root, split_id=0, num_val=100, download=True):
16 | super(CUHK03, self).__init__(root, split_id=split_id)
17 |
18 | if download:
19 | self.download()
20 |
21 | if not self._check_integrity():
22 | raise RuntimeError("Dataset not found or corrupted. " +
23 | "You can use download=True to download it.")
24 |
25 | self.load(num_val)
26 |
27 | def download(self):
28 | if self._check_integrity():
29 | print("Files already downloaded and verified")
30 | return
31 |
32 | import h5py
33 | import hashlib
34 | from scipy.misc import imsave
35 | from zipfile import ZipFile
36 |
37 | raw_dir = osp.join(self.root, 'raw')
38 | mkdir_if_missing(raw_dir)
39 |
40 | # Download the raw zip file
41 | fpath = osp.join(raw_dir, 'cuhk03_release.zip')
42 | if osp.isfile(fpath) and \
43 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
44 | print("Using downloaded file: " + fpath)
45 | else:
46 | raise RuntimeError("Please download the dataset manually from {} "
47 | "to {}".format(self.url, fpath))
48 |
49 | # Extract the file
50 | exdir = osp.join(raw_dir, 'cuhk03_release')
51 | if not osp.isdir(exdir):
52 | print("Extracting zip file")
53 | with ZipFile(fpath) as z:
54 | z.extractall(path=raw_dir)
55 |
56 | # Format
57 | images_dir = osp.join(self.root, 'images')
58 | mkdir_if_missing(images_dir)
59 | matdata = h5py.File(osp.join(exdir, 'cuhk-03.mat'), 'r')
60 |
61 | def deref(ref):
62 | return matdata[ref][:].T
63 |
64 | def dump_(refs, pid, cam, fnames):
65 | for ref in refs:
66 | img = deref(ref)
67 | if img.size == 0 or img.ndim < 2: break
68 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, cam, len(fnames))
69 | imsave(osp.join(images_dir, fname), img)
70 | fnames.append(fname)
71 |
72 | identities = []
73 | for labeled, detected in zip(
74 | matdata['labeled'][0], matdata['detected'][0]):
75 | labeled, detected = deref(labeled), deref(detected)
76 | assert labeled.shape == detected.shape
77 | for i in range(labeled.shape[0]):
78 | pid = len(identities)
79 | images = [[], []]
80 | dump_(labeled[i, :5], pid, 0, images[0])
81 | dump_(detected[i, :5], pid, 0, images[0])
82 | dump_(labeled[i, 5:], pid, 1, images[1])
83 | dump_(detected[i, 5:], pid, 1, images[1])
84 | identities.append(images)
85 |
86 | # Save meta information into a json file
87 | meta = {'name': 'cuhk03', 'shot': 'multiple', 'num_cameras': 2,
88 | 'identities': identities}
89 | write_json(meta, osp.join(self.root, 'meta.json'))
90 |
91 | # Save training and test splits
92 | splits = []
93 | view_counts = [deref(ref).shape[0] for ref in matdata['labeled'][0]]
94 | vid_offsets = np.r_[0, np.cumsum(view_counts)]
95 | for ref in matdata['testsets'][0]:
96 | test_info = deref(ref).astype(np.int32)
97 | test_pids = sorted(
98 | [int(vid_offsets[i - 1] + j - 1) for i, j in test_info])
99 | trainval_pids = list(set(range(vid_offsets[-1])) - set(test_pids))
100 | split = {'trainval': trainval_pids,
101 | 'query': test_pids,
102 | 'gallery': test_pids}
103 | splits.append(split)
104 | write_json(splits, osp.join(self.root, 'splits.json'))
105 |
--------------------------------------------------------------------------------
/reid/datasets/msmt17.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import os
4 | import numpy as np
5 |
6 | from ..utils.data import Dataset
7 | from ..utils.osutils import mkdir_if_missing
8 | from ..utils.serialization import read_json
9 | from ..utils.serialization import write_json
10 |
11 |
12 | class msmt17(Dataset):
13 | url = 'https://drive.google.com/file/d/1PduQX1OBuoXDh9JxybYBoDEcKhSx_Q8j/view?usp=sharing'
14 | md5 = 'ea5502ae9dd06c596ad866bd1db0280d'
15 |
16 | def __init__(self, root, split_id=0, num_val=100, download=True):
17 | super(msmt17, self).__init__(root, split_id=split_id)
18 |
19 | if download:
20 | self.download()
21 |
22 | self.load()
23 |
24 | def download(self):
25 | if self._check_integrity():
26 | print("Files already downloaded and verified")
27 | return
28 |
29 | import re
30 | import hashlib
31 | import tarfile
32 |
33 | raw_dir = osp.join(self.root, 'raw')
34 | mkdir_if_missing(raw_dir)
35 |
36 | # Download the raw zip file
37 | fpath = osp.join(raw_dir, 'MSMT17_V1.tar.gz')
38 | if osp.isfile(fpath) and \
39 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
40 | print("Using downloaded file: " + fpath)
41 | else:
42 | raise RuntimeError("Please download the dataset manually from {} "
43 | "to {}".format(self.url, fpath))
44 |
45 | # Extract the file
46 | exdir = osp.join(raw_dir, 'MSMT17_V1')
47 | if not osp.isdir(exdir):
48 | print("Extracting tar file")
49 | with tarfile.open(fpath) as tar:
50 | tar.extractall(raw_dir)
51 |
52 | # MSMT17 files
53 | def register(typeName, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)'), test_query=False):
54 | assert typeName.lower() in ['gallery', 'query', 'train', 'val']
55 | nameMap = {
56 | 'gallery': 'test', 'query': 'test',
57 | 'train': 'train', 'val': 'train'
58 | }
59 | with open(osp.join(exdir, 'list_{}.txt'.format(typeName.lower())), 'r') as f:
60 | fpaths = f.readlines()
61 | fpaths = [name.strip().split(' ')[0] for name in fpaths]
62 | fpaths = sorted([osp.join(exdir, nameMap[typeName.lower()], name) for name in fpaths])
63 | curData = []
64 | for fpath in fpaths:
65 | fname = osp.basename(fpath)
66 | pid, _, cam = map(int, pattern.search(fname).groups())
67 | cam -= 1
68 | curData.append((fpath, pid, cam))
69 | return curData
70 |
71 | self.train = register('train')
72 | self.val = register('val')
73 | self.trainval = self.train + self.val
74 | self.gallery = register('gallery')
75 | self.query = register('query')
76 |
77 | ########################
78 | # Added
79 | def load(self, verbose=True):
80 |         trainPids = [pid for _, pid, _ in self.train]
81 |         valPids = [pid for _, pid, _ in self.val]
82 |         trainvalPids = [pid for _, pid, _ in self.trainval]
83 |         galleryPids = [pid for _, pid, _ in self.gallery]
84 |         queryPids = [pid for _, pid, _ in self.query]
85 | self.num_train_ids = len(set(trainPids))
86 | self.num_val_ids = len(set(valPids))
87 | self.num_trainval_ids = len(set(trainvalPids))
88 | self.num_query_ids = len(set(queryPids))
89 | self.num_gallery_ids = len(set(galleryPids))
90 | ##########
91 | if verbose:
92 | print(self.__class__.__name__, "dataset loaded")
93 | print(" subset | # ids | # images")
94 | print(" ---------------------------")
95 | print(" train | {:5d} | {:8d}".format(self.num_train_ids, len(self.train)))
96 | print(" val | {:5d} | {:8d}".format(self.num_val_ids, len(self.val)))
97 | print(" trainval | {:5d} | {:8d}".format(self.num_trainval_ids, len(self.trainval)))
98 | print(" query | {:5d} | {:8d}".format(self.num_query_ids, len(self.query)))
99 | print(" gallery | {:5d} | {:8d}".format(self.num_gallery_ids, len(self.gallery)))
100 | ########################
101 |
--------------------------------------------------------------------------------
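For reference, a hedged sketch of what `register()` above extracts from one MSMT17-style file name; the name here is invented but keeps the `pid_index_camera` numeric prefix the regex relies on:

import re

pattern = re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')
fname = '0001_005_03_0303morning_0015_1.jpg'  # hypothetical example name
pid, _, cam = map(int, pattern.search(fname).groups())
cam -= 1                                      # cameras become 0-based
print(pid, cam)                               # -> 1 2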
/reid/utils/data/video_loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 | import random
9 |
10 | def read_image(img_path):
11 |     """Keep trying to read the image until it succeeds.
12 |     This avoids transient IOErrors under heavy I/O load."""
13 | got_img = False
14 | while not got_img:
15 | try:
16 | img = Image.open(img_path).convert('RGB')
17 | got_img = True
18 | except IOError:
19 |             print("IOError when reading '{}'; retrying.".format(img_path))
20 | pass
21 | return img
22 |
23 |
24 | class VideoDataset(Dataset):
25 | """Video Person ReID Dataset.
26 | Note batch data has shape (batch, seq_len, channel, height, width).
27 | """
28 |     sample_methods = ['random', 'dense']  # only these two strategies are implemented below
29 | 
30 |     def __init__(self, dataset, seq_len=15, sample='random', transform=None):
31 | self.dataset = dataset
32 | self.seq_len = seq_len
33 | self.sample = sample
34 | self.transform = transform
35 |
36 | def __len__(self):
37 | return len(self.dataset)
38 |
39 | def __getitem__(self, index):
40 | img_paths, pid, camid = self.dataset[index]
41 | num = len(img_paths)
42 | if self.sample == 'random':
43 | """
44 | Randomly sample seq_len consecutive frames from num frames,
45 | if num is smaller than seq_len, then replicate items.
46 | This sampling strategy is used in training phase.
47 | """
48 |             frame_indices = list(range(num))
49 |             rand_end = max(0, len(frame_indices) - self.seq_len - 1)
50 |             begin_index = random.randint(0, rand_end)
51 |             end_index = min(begin_index + self.seq_len, len(frame_indices))
52 | 
53 |             indices = frame_indices[begin_index:end_index]
54 | 
55 |             # replicate frames when the tracklet is shorter than seq_len
56 |             num_avail = len(indices)
57 |             while len(indices) < self.seq_len:
58 |                 indices.append(indices[len(indices) % num_avail])
59 |             indices = np.array(indices)
60 | imgs = []
61 | for index in indices:
62 | index=int(index)
63 | img_path = img_paths[index]
64 | img = read_image(img_path)
65 | if self.transform is not None:
66 | img = self.transform(img)
67 | img = img.unsqueeze(0)
68 | imgs.append(img)
69 | imgs = torch.cat(imgs, dim=0)
70 | #imgs=imgs.permute(1,0,2,3)
71 | return imgs, pid, camid
72 |
73 | elif self.sample == 'dense':
74 | """
75 | Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1.
76 | This sampling strategy is used in test phase.
77 | """
78 |             cur_index = 0
79 |             frame_indices = list(range(num))
80 |             indices_list = []
81 |             while num - cur_index > self.seq_len:
82 |                 indices_list.append(frame_indices[cur_index:cur_index + self.seq_len])
83 |                 cur_index += self.seq_len
84 |             last_seq = frame_indices[cur_index:]
85 |             # pad the trailing clip by repeating its frames up to seq_len
86 |             num_avail = len(last_seq)
87 |             while len(last_seq) < self.seq_len and num_avail > 0:
88 |                 last_seq.append(last_seq[len(last_seq) % num_avail])
89 |             indices_list.append(last_seq)
90 | imgs_list=[]
91 | for indices in indices_list:
92 | imgs = []
93 | for index in indices:
94 | index=int(index)
95 | img_path = img_paths[index]
96 | img = read_image(img_path)
97 | if self.transform is not None:
98 | img = self.transform(img)
99 | img = img.unsqueeze(0)
100 | imgs.append(img)
101 | imgs = torch.cat(imgs, dim=0)
102 | #imgs=imgs.permute(1,0,2,3)
103 | imgs_list.append(imgs)
104 | imgs_array = torch.stack(imgs_list)
105 | return imgs_array, pid, camid
106 |
107 | else:
108 | raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
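A minimal usage sketch for VideoDataset, assuming the fixes above; the tracklet paths are placeholders (point them at real frames, since read_image retries until a read succeeds), and a batch comes out as (batch, seq_len, channel, height, width):

import torchvision.transforms as T
from torch.utils.data import DataLoader
from reid.utils.data.video_loader import VideoDataset

# one tracklet: (frame paths, pid, camid); the paths are hypothetical
tracklets = [(['/data/seq0/f0.jpg', '/data/seq0/f1.jpg'], 0, 0)]
transform = T.Compose([T.Resize((256, 128)), T.ToTensor()])

train_set = VideoDataset(tracklets, seq_len=4, sample='random', transform=transform)
loader = DataLoader(train_set, batch_size=1)
imgs, pid, camid = next(iter(loader))
print(imgs.shape)  # -> torch.Size([1, 4, 3, 256, 128])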
/reid/evaluation_metrics/ranking.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | from sklearn.metrics import average_precision_score
6 |
7 | from ..utils import to_numpy
8 |
9 |
10 | def _unique_sample(ids_dict, num):
11 |     mask = np.zeros(num, dtype=bool)
12 | for _, indices in ids_dict.items():
13 | i = np.random.choice(indices)
14 | mask[i] = True
15 | return mask
16 |
17 |
18 | def cmc(distmat, query_ids=None, gallery_ids=None,
19 | query_cams=None, gallery_cams=None, topk=100,
20 | separate_camera_set=False,
21 | single_gallery_shot=False,
22 | first_match_break=False):
23 | distmat = to_numpy(distmat)
24 | m, n = distmat.shape
25 | # Fill up default values
26 | if query_ids is None:
27 | query_ids = np.arange(m)
28 | if gallery_ids is None:
29 | gallery_ids = np.arange(n)
30 | if query_cams is None:
31 | query_cams = np.zeros(m).astype(np.int32)
32 | if gallery_cams is None:
33 | gallery_cams = np.ones(n).astype(np.int32)
34 | # Ensure numpy array
35 | query_ids = np.asarray(query_ids)
36 | gallery_ids = np.asarray(gallery_ids)
37 | query_cams = np.asarray(query_cams)
38 | gallery_cams = np.asarray(gallery_cams)
39 | # Sort and find correct matches
40 | indices = np.argsort(distmat, axis=1)
41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
42 | # Compute CMC for each query
43 | ret = np.zeros(topk)
44 | num_valid_queries = 0
45 | for i in range(m):
46 | # Filter out the same id and same camera
47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) |
48 | (gallery_cams[indices[i]] != query_cams[i]))
49 | if separate_camera_set:
50 | # Filter out samples from same camera
51 | valid &= (gallery_cams[indices[i]] != query_cams[i])
52 | if not np.any(matches[i, valid]): continue
53 | if single_gallery_shot:
54 | repeat = 10
55 | gids = gallery_ids[indices[i][valid]]
56 | inds = np.where(valid)[0]
57 | ids_dict = defaultdict(list)
58 | for j, x in zip(inds, gids):
59 | ids_dict[x].append(j)
60 | else:
61 | repeat = 1
62 | for _ in range(repeat):
63 | if single_gallery_shot:
64 | # Randomly choose one instance for each id
65 | sampled = (valid & _unique_sample(ids_dict, len(valid)))
66 | index = np.nonzero(matches[i, sampled])[0]
67 | else:
68 | index = np.nonzero(matches[i, valid])[0]
69 | delta = 1. / (len(index) * repeat)
70 | for j, k in enumerate(index):
71 | if k - j >= topk: break
72 | if first_match_break:
73 | ret[k - j] += 1
74 | break
75 | ret[k - j] += delta
76 | num_valid_queries += 1
77 | if num_valid_queries == 0:
78 | raise RuntimeError("No valid query")
79 | return ret.cumsum() / num_valid_queries
80 |
81 |
82 | def mean_ap(distmat, query_ids=None, gallery_ids=None,
83 | query_cams=None, gallery_cams=None):
84 | distmat = to_numpy(distmat)
85 | m, n = distmat.shape
86 | # Fill up default values
87 | if query_ids is None:
88 | query_ids = np.arange(m)
89 | if gallery_ids is None:
90 | gallery_ids = np.arange(n)
91 | if query_cams is None:
92 | query_cams = np.zeros(m).astype(np.int32)
93 | if gallery_cams is None:
94 | gallery_cams = np.ones(n).astype(np.int32)
95 | # Ensure numpy array
96 | query_ids = np.asarray(query_ids)
97 | gallery_ids = np.asarray(gallery_ids)
98 | query_cams = np.asarray(query_cams)
99 | gallery_cams = np.asarray(gallery_cams)
100 | # Sort and find correct matches
101 | indices = np.argsort(distmat, axis=1)
102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
103 | # Compute AP for each query
104 | aps = []
105 | for i in range(m):
106 | # Filter out the same id and same camera
107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) |
108 | (gallery_cams[indices[i]] != query_cams[i]))
109 | y_true = matches[i, valid]
110 | y_score = -distmat[i][indices[i]][valid]
111 | if not np.any(y_true): continue
112 | aps.append(average_precision_score(y_true, y_score))
113 | if len(aps) == 0:
114 | raise RuntimeError("No valid query")
115 | return np.mean(aps)
116 |
--------------------------------------------------------------------------------
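A small sanity check for cmc() and mean_ap() on a synthetic problem; the distances are chosen so query 0 ranks its true match first and query 1 ranks it second, giving rank-1 = 0.5 and mAP = 0.75:

import numpy as np
from reid.evaluation_metrics import cmc, mean_ap

distmat = np.array([[0.1, 0.9, 0.8],
                    [0.7, 0.4, 0.2]])
query_ids, gallery_ids = np.array([0, 1]), np.array([0, 1, 2])
query_cams = np.zeros(2, dtype=np.int32)
gallery_cams = np.ones(3, dtype=np.int32)   # different cams, so every pair is valid

scores = cmc(distmat, query_ids, gallery_ids, query_cams, gallery_cams,
             topk=3, first_match_break=True)
print(scores)  # -> [0.5, 1.0, 1.0]
print(mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams))  # -> 0.75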
/reid/evaluators.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | from collections import OrderedDict
4 | from collections import namedtuple
5 |
6 | import torch
7 |
8 | from .evaluation_metrics import cmc, mean_ap
9 | from .feature_extraction import extract_cnn_feature
10 | from .utils.meters import AverageMeter
11 |
12 |
13 | def extract_features(model, data_loader, print_freq=1):
14 | model.eval()
15 | batch_time = AverageMeter()
16 | data_time = AverageMeter()
17 |
18 | features = OrderedDict()
19 | labels = OrderedDict()
20 |
21 | end = time.time()
22 | for i, (imgs, fnames, pids, _) in enumerate(data_loader):
23 | data_time.update(time.time() - end)
24 |
25 | outputs = extract_cnn_feature(model, imgs)
26 | for fname, output, pid in zip(fnames, outputs, pids):
27 | features[fname] = output
28 | labels[fname] = pid
29 |
30 | batch_time.update(time.time() - end)
31 | end = time.time()
32 |
33 | if (i + 1) % print_freq == 0:
34 | print('Extract Features: [{}/{}]\t'
35 | 'Time {:.3f} ({:.3f})\t'
36 | 'Data {:.3f} ({:.3f})\t'
37 | .format(i + 1, len(data_loader),
38 | batch_time.val, batch_time.avg,
39 | data_time.val, data_time.avg))
40 |
41 | return features, labels
42 |
43 |
44 | def pairwise_distance(features, query=None, gallery=None, metric=None):
45 | if query is None and gallery is None:
46 | n = len(features)
47 | x = torch.cat(list(features.values()))
48 | x = x.view(n, -1)
49 | if metric is not None:
50 | x = metric.transform(x)
51 |         dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(n, n)
52 |         dist = dist + dist.t() - 2 * torch.mm(x, x.t())  # ||x_i||^2 + ||x_j||^2 - 2 x_i.x_j
53 | return dist
54 |
55 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
56 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
57 | m, n = x.size(0), y.size(0)
58 | x = x.view(m, -1)
59 | y = y.view(n, -1)
60 | if metric is not None:
61 | x = metric.transform(x)
62 | y = metric.transform(y)
63 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
64 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
65 | dist.addmm_(1, -2, x, y.t())
66 | return dist
67 |
68 |
69 | def evaluate_all(distmat, query=None, gallery=None,
70 | query_ids=None, gallery_ids=None,
71 | query_cams=None, gallery_cams=None,
72 | cmc_topk=(1, 5, 10)):
73 | if query is not None and gallery is not None:
74 | query_ids = [pid for _, pid, _ in query]
75 | gallery_ids = [pid for _, pid, _ in gallery]
76 | query_cams = [cam for _, _, cam in query]
77 | gallery_cams = [cam for _, _, cam in gallery]
78 | else:
79 | assert (query_ids is not None and gallery_ids is not None
80 | and query_cams is not None and gallery_cams is not None)
81 |
82 | # Compute mean AP
83 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
84 | print('Mean AP: {:4.1%}'.format(mAP))
85 |
86 | cmc_configs = {
87 | 'allshots': dict(separate_camera_set=False,
88 | single_gallery_shot=False,
89 | first_match_break=False),
90 | 'cuhk03': dict(separate_camera_set=True,
91 | single_gallery_shot=True,
92 | first_match_break=False),
93 | 'market1501': dict(separate_camera_set=False,
94 | single_gallery_shot=False,
95 | first_match_break=True)}
96 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
97 | query_cams, gallery_cams, **params)
98 | for name, params in cmc_configs.items()}
99 |
100 | print('CMC Scores{:>12}{:>12}{:>12}'
101 | .format('allshots', 'cuhk03', 'market1501'))
102 | rank_score = namedtuple(
103 | 'rank_score',
104 | ['map', 'allshots', 'cuhk03', 'market1501'],
105 | )
106 | for k in cmc_topk:
107 | print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}'
108 | .format(k, cmc_scores['allshots'][k - 1],
109 | cmc_scores['cuhk03'][k - 1],
110 | cmc_scores['market1501'][k - 1]))
111 | score = rank_score(
112 | mAP,
113 | cmc_scores['allshots'], cmc_scores['cuhk03'],
114 | cmc_scores['market1501'],
115 | )
116 | return score
117 |
118 |
119 | class Evaluator(object):
120 | def __init__(self, model, print_freq=1):
121 | super(Evaluator, self).__init__()
122 | self.model = model
123 | self.print_freq = print_freq
124 |
125 | def evaluate(self, data_loader, query, gallery, metric=None):
126 | features, _ = extract_features(self.model, data_loader, print_freq=self.print_freq)
127 | distmat = pairwise_distance(features, query, gallery, metric=metric)
128 | return evaluate_all(distmat, query=query, gallery=gallery)
129 |
--------------------------------------------------------------------------------
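A hedged sketch of driving Evaluator end to end; `get_data` is the loader helper defined in the training scripts further below, the data path is a placeholder, and a CUDA device is assumed:

from torch import nn
from reid import models
from reid.evaluators import Evaluator

dataset, num_classes, _, test_loader = get_data(
    'market1501', '/path/to/data', 256, 128, batch_size=64, workers=4)
model = nn.DataParallel(models.create('resnet50', num_classes=num_classes)).cuda()

evaluator = Evaluator(model, print_freq=50)
rank_score = evaluator.evaluate(test_loader, dataset.query, dataset.gallery)
print('mAP: {:.1%}  rank-1: {:.1%}'.format(rank_score.map, rank_score.market1501[0]))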
/reid/models/inception.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.nn import init
7 |
8 |
9 | __all__ = ['InceptionNet', 'inception']
10 |
11 |
12 | def _make_conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
13 | bias=False):
14 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
15 | stride=stride, padding=padding, bias=bias)
16 | bn = nn.BatchNorm2d(out_planes)
17 | relu = nn.ReLU(inplace=True)
18 | return nn.Sequential(conv, bn, relu)
19 |
20 |
21 | class Block(nn.Module):
22 | def __init__(self, in_planes, out_planes, pool_method, stride):
23 | super(Block, self).__init__()
24 | self.branches = nn.ModuleList([
25 | nn.Sequential(
26 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0),
27 | _make_conv(out_planes, out_planes, stride=stride)
28 | ),
29 | nn.Sequential(
30 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0),
31 | _make_conv(out_planes, out_planes),
32 | _make_conv(out_planes, out_planes, stride=stride))
33 | ])
34 |
35 | if pool_method == 'Avg':
36 | assert stride == 1
37 | self.branches.append(
38 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0))
39 | self.branches.append(nn.Sequential(
40 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
41 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0)))
42 | else:
43 | self.branches.append(
44 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1))
45 |
46 | def forward(self, x):
47 | return torch.cat([b(x) for b in self.branches], 1)
48 |
49 |
50 | class InceptionNet(nn.Module):
51 | def __init__(self, cut_at_pooling=False, num_features=256, norm=False,
52 | dropout=0, num_classes=0):
53 | super(InceptionNet, self).__init__()
54 | self.cut_at_pooling = cut_at_pooling
55 |
56 | self.conv1 = _make_conv(3, 32)
57 | self.conv2 = _make_conv(32, 32)
58 | self.conv3 = _make_conv(32, 32)
59 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
60 | self.in_planes = 32
61 | self.inception4a = self._make_inception(64, 'Avg', 1)
62 | self.inception4b = self._make_inception(64, 'Max', 2)
63 | self.inception5a = self._make_inception(128, 'Avg', 1)
64 | self.inception5b = self._make_inception(128, 'Max', 2)
65 | self.inception6a = self._make_inception(256, 'Avg', 1)
66 | self.inception6b = self._make_inception(256, 'Max', 2)
67 |
68 | if not self.cut_at_pooling:
69 | self.num_features = num_features
70 | self.norm = norm
71 | self.dropout = dropout
72 | self.has_embedding = num_features > 0
73 | self.num_classes = num_classes
74 |
75 | self.avgpool = nn.AdaptiveAvgPool2d(1)
76 |
77 | if self.has_embedding:
78 | self.feat = nn.Linear(self.in_planes, self.num_features)
79 | self.feat_bn = nn.BatchNorm1d(self.num_features)
80 | else:
81 | # Change the num_features to CNN output channels
82 | self.num_features = self.in_planes
83 | if self.dropout > 0:
84 | self.drop = nn.Dropout(self.dropout)
85 | if self.num_classes > 0:
86 | self.classifier = nn.Linear(self.num_features, self.num_classes)
87 |
88 | self.reset_params()
89 |
90 | def forward(self, x):
91 | x = self.conv1(x)
92 | x = self.conv2(x)
93 | x = self.conv3(x)
94 | x = self.pool3(x)
95 | x = self.inception4a(x)
96 | x = self.inception4b(x)
97 | x = self.inception5a(x)
98 | x = self.inception5b(x)
99 | x = self.inception6a(x)
100 | x = self.inception6b(x)
101 |
102 | if self.cut_at_pooling:
103 | return x
104 |
105 | x = self.avgpool(x)
106 | x = x.view(x.size(0), -1)
107 |
108 | if self.has_embedding:
109 | x = self.feat(x)
110 | x = self.feat_bn(x)
111 | if self.norm:
112 | x = F.normalize(x)
113 | elif self.has_embedding:
114 | x = F.relu(x)
115 | if self.dropout > 0:
116 | x = self.drop(x)
117 | if self.num_classes > 0:
118 | x = self.classifier(x)
119 | return x
120 |
121 | def _make_inception(self, out_planes, pool_method, stride):
122 | block = Block(self.in_planes, out_planes, pool_method, stride)
123 | self.in_planes = (out_planes * 4 if pool_method == 'Avg' else
124 | out_planes * 2 + self.in_planes)
125 | return block
126 |
127 | def reset_params(self):
128 | for m in self.modules():
129 | if isinstance(m, nn.Conv2d):
130 |                 init.kaiming_normal_(m.weight, mode='fan_out')
131 |                 if m.bias is not None:
132 |                     init.constant_(m.bias, 0)
133 |             elif isinstance(m, nn.BatchNorm2d):
134 |                 init.constant_(m.weight, 1)
135 |                 init.constant_(m.bias, 0)
136 |             elif isinstance(m, nn.Linear):
137 |                 init.normal_(m.weight, std=0.001)
138 |                 if m.bias is not None:
139 |                     init.constant_(m.bias, 0)
140 |
141 |
142 | def inception(**kwargs):
143 | return InceptionNet(**kwargs)
144 |
--------------------------------------------------------------------------------
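A quick shape check for InceptionNet; (144, 56) is the inception input resolution the training scripts below default to:

import torch
from reid.models.inception import InceptionNet

net = InceptionNet(num_features=256, norm=True, dropout=0.5, num_classes=10)
net.eval()                         # BatchNorm uses running stats, dropout is a no-op
x = torch.randn(2, 3, 144, 56)     # (batch, channel, height, width)
print(net(x).shape)                # -> torch.Size([2, 10])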
/reid/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os.path as osp
3 |
4 | import numpy as np
5 |
6 | from ..serialization import read_json
7 |
8 |
9 | def _pluck(identities, indices, relabel=False):
10 | ret = []
11 | for index, pid in enumerate(indices):
12 | pid_images = identities[pid]
13 | for camid, cam_images in enumerate(pid_images):
14 | for fname in cam_images:
15 | name = osp.splitext(fname)[0]
16 | x, y, _ = map(int, name.split('_'))
17 | assert pid == x and camid == y
18 | if relabel:
19 | ret.append((fname, index, camid))
20 | else:
21 | ret.append((fname, pid, camid))
22 | return ret
23 |
24 | def _pluck_gallery(identities, indices, relabel=False):
25 | ret = []
26 | for index, pid in enumerate(indices):
27 | pid_images = identities[pid]
28 | for camid, cam_images in enumerate(pid_images):
29 | if len(cam_images[:-1])==0:
30 | for fname in cam_images:
31 | name = osp.splitext(fname)[0]
32 | x, y, _ = map(int, name.split('_'))
33 | assert pid == x and camid == y
34 | if relabel:
35 | ret.append((fname, index, camid))
36 | else:
37 | ret.append((fname, pid, camid))
38 | else:
39 | for fname in cam_images[:-1]:
40 | name = osp.splitext(fname)[0]
41 | x, y, _ = map(int, name.split('_'))
42 | assert pid == x and camid == y
43 | if relabel:
44 | ret.append((fname, index, camid))
45 | else:
46 | ret.append((fname, pid, camid))
47 | return ret
48 |
49 | def _pluck_query(identities, indices, relabel=False):
50 | ret = []
51 | for index, pid in enumerate(indices):
52 | pid_images = identities[pid]
53 | for camid, cam_images in enumerate(pid_images):
54 | for fname in cam_images[-1:]:
55 | name = osp.splitext(fname)[0]
56 | x, y, _ = map(int, name.split('_'))
57 | assert pid == x and camid == y
58 | if relabel:
59 | ret.append((fname, index, camid))
60 | else:
61 | ret.append((fname, pid, camid))
62 | return ret
63 |
64 |
65 | class Dataset(object):
66 | def __init__(self, root, split_id=0):
67 | self.root = root
68 | self.split_id = split_id
69 | self.meta = None
70 | self.split = None
71 | self.train, self.val, self.trainval = [], [], []
72 | self.query, self.gallery = [], []
73 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0
74 |
75 | @property
76 | def images_dir(self):
77 | return osp.join(self.root, 'images')
78 |
79 | def load(self, num_val=0.3, verbose=True):
80 | splits = read_json(osp.join(self.root, 'splits.json'))
81 | if self.split_id >= len(splits):
82 | raise ValueError("split_id exceeds total splits {}"
83 | .format(len(splits)))
84 | self.split = splits[self.split_id]
85 |
86 | # Randomly split train / val
87 | trainval_pids = np.asarray(self.split['trainval'])
88 | np.random.shuffle(trainval_pids)
89 | num = len(trainval_pids)
90 | if isinstance(num_val, float):
91 | num_val = int(round(num * num_val))
92 | if num_val >= num or num_val < 0:
93 | raise ValueError("num_val exceeds total identities {}"
94 | .format(num))
95 | train_pids = sorted(trainval_pids[:-num_val])
96 | val_pids = sorted(trainval_pids[-num_val:])
97 |
98 | self.meta = read_json(osp.join(self.root, 'meta.json'))
99 | identities = self.meta['identities']
100 | self.train = _pluck(identities, train_pids, relabel=True)
101 | self.val = _pluck(identities, val_pids, relabel=True)
102 | self.trainval = _pluck(identities, trainval_pids, relabel=True)
103 | self.query = _pluck_query(identities, self.split['query'])
104 | #self.gallery = _pluck(identities, self.split['gallery'])
105 | self.gallery = _pluck_gallery(identities, self.split['gallery'])
106 | self.num_train_ids = len(train_pids)
107 | self.num_val_ids = len(val_pids)
108 | self.num_trainval_ids = len(trainval_pids)
109 |
110 | if verbose:
111 | print(self.__class__.__name__, "dataset loaded")
112 | print(" subset | # ids | # images")
113 | print(" ---------------------------")
114 | print(" train | {:5d} | {:8d}"
115 | .format(self.num_train_ids, len(self.train)))
116 | print(" val | {:5d} | {:8d}"
117 | .format(self.num_val_ids, len(self.val)))
118 | print(" trainval | {:5d} | {:8d}"
119 | .format(self.num_trainval_ids, len(self.trainval)))
120 | print(" query | {:5d} | {:8d}"
121 | .format(len(self.split['query']), len(self.query)))
122 | print(" gallery | {:5d} | {:8d}"
123 | .format(len(self.split['gallery']), len(self.gallery)))
124 |
125 | def _check_integrity(self):
126 | return osp.isdir(osp.join(self.root, 'images')) and \
127 | osp.isfile(osp.join(self.root, 'meta.json')) and \
128 | osp.isfile(osp.join(self.root, 'splits.json'))
129 |
--------------------------------------------------------------------------------
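For orientation, a minimal made-up example of the meta.json / splits.json layout that Dataset.load() consumes; identities[pid][camid] lists file names following the '{pid:08d}_{camid:02d}_{index:04d}.jpg' convention asserted in _pluck():

meta = {
    'name': 'toy', 'shot': 'multiple', 'num_cameras': 2,
    'identities': [
        [['00000000_00_0000.jpg'], ['00000000_01_0000.jpg']],  # pid 0, cameras 0/1
        [['00000001_00_0000.jpg'], ['00000001_01_0000.jpg']],  # pid 1, cameras 0/1
    ],
}
splits = [{'trainval': [0], 'query': [1], 'gallery': [1]}]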
/reid/rerank.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
5 |
6 | import numpy as np
7 | from scipy.spatial.distance import cdist
8 | import torch
9 | import tqdm
10 | import torch.nn.functional as F
11 |
12 |
13 | def pairwiseDis(qFeature, gFeature): # 246s
14 | x, y = F.normalize(qFeature), F.normalize(gFeature)
15 | # x, y = qFeature, gFeature
16 | m, n = x.shape[0], y.shape[0]
17 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
18 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
19 | disMat.addmm_(1, -2, x, y.t())
20 | print('-----* Distance Matrix has been computed*-----')
21 | return disMat.clamp_(min=1e-5)
22 |
23 |
24 | def re_ranking(input_feature_source, input_feature, k1=20, k2=6, lambda_value=0.1):
25 | all_num = input_feature.shape[0]
26 | # feat = input_feature.astype(np.float16)
27 | feat = torch.from_numpy(input_feature) # target
28 | del input_feature
29 |
30 | if lambda_value != 0:
31 | print('Computing source distance...')
32 | srcFeat, tarFeat = input_feature_source, feat
33 | # all_num_source = input_feature_source.shape[0]
34 | # sour_tar_dist = np.power(cdist(input_feature, input_feature_source), 2).astype(np.float32) #608s
35 | sour_tar_dist = pairwiseDis(srcFeat, tarFeat).t().numpy()
36 |         sour_tar_dist = 1 - np.exp(-sour_tar_dist)  # target-vs-source distance
37 |         source_dist_vec = np.min(sour_tar_dist, axis=1)
38 |         source_dist_vec = source_dist_vec / (np.max(source_dist_vec) + 1e-3)  # normalized, one value per target sample
39 | source_dist = np.zeros([all_num, all_num]) # tar size
40 | for i in range(all_num):
41 | source_dist[i, :] = source_dist_vec + source_dist_vec[i]
42 | del sour_tar_dist
43 | del source_dist_vec
44 |
45 | print('Computing original distance...')
46 | original_dist = pairwiseDis(feat, feat).cpu().numpy()
47 | print('done...')
48 | # original_dist = np.power(original_dist,2).astype(np.float16)
49 | del feat
50 | # original_dist = np.concatenate(dist,axis=0)
51 | gallery_num = original_dist.shape[0] # gallery_num=all_num
52 | original_dist = np.transpose(original_dist / (np.max(original_dist, axis=0)))
53 | V = np.zeros_like(original_dist).astype(np.float16)
54 | initial_rank = np.argsort(original_dist).astype(np.int32) ## default axis=-1.
55 |
56 | print('Starting re_ranking...')
57 | for i in tqdm.tqdm(range(all_num)):
58 | # k-reciprocal neighbors
59 | forward_k_neigh_index = initial_rank[i,
60 | :k1 + 1] ## k1+1 because self always ranks first. forward_k_neigh_index.shape=[k1+1]. forward_k_neigh_index[0] == i.
61 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,
62 | :k1 + 1] ##backward.shape = [k1+1, k1+1]. For each ele in forward_k_neigh_index, find its rank k1 neighbors
63 | fi = np.where(backward_k_neigh_index == i)[0]
64 | k_reciprocal_index = forward_k_neigh_index[fi] ## get R(p,k) in the paper
65 | k_reciprocal_expansion_index = k_reciprocal_index
66 | for j in range(len(k_reciprocal_index)):
67 | candidate = k_reciprocal_index[j]
68 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
69 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
70 | :int(np.around(k1 / 2)) + 1]
71 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
72 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
73 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
74 | candidate_k_reciprocal_index):
75 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
76 |
77 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique
78 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
79 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
80 | # original_dist = original_dist[:query_num,]
81 | if k2 != 1:
82 | V_qe = np.zeros_like(V, dtype=np.float16)
83 | for i in tqdm.tqdm(range(all_num)):
84 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
85 | V = V_qe
86 | del V_qe
87 | del initial_rank
88 | invIndex = []
89 | for i in tqdm.tqdm(range(gallery_num)):
90 | invIndex.append(np.where(V[:, i] != 0)[0]) # len(invIndex)=all_num
91 |
92 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
93 |
94 | for i in tqdm.tqdm(range(all_num)):
95 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
96 | indNonZero = np.where(V[i, :] != 0)[0]
97 | indImages = []
98 | indImages = [invIndex[ind] for ind in indNonZero]
99 | for j in range(len(indNonZero)):
100 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
101 | V[indImages[j], indNonZero[j]])
102 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
103 |
104 | pos_bool = (jaccard_dist < 0)
105 | jaccard_dist[pos_bool] = 0.0
106 |
107 | if lambda_value == 0:
108 | return jaccard_dist
109 | else:
110 | final_dist = jaccard_dist * (1 - lambda_value) + source_dist * lambda_value
111 | return final_dist
112 |
--------------------------------------------------------------------------------
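A smoke-test sketch for re_ranking() on random features; the source features enter only through the domain-similarity term weighted by lambda_value, so lambda_value=0 returns the pure Jaccard distance on the target set:

import numpy as np
import torch
from reid.rerank import re_ranking

src_feat = torch.randn(50, 128)                           # source-domain features (torch)
tgt_feat = np.random.randn(100, 128).astype(np.float32)   # target features (numpy)

dist = re_ranking(src_feat, tgt_feat, k1=20, k2=6, lambda_value=0.1)
print(dist.shape)  # -> (100, 100)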
/reid/datasets/market1501.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import numpy as np
4 |
5 | from ..utils.data import Dataset
6 | from ..utils.osutils import mkdir_if_missing
7 | from ..utils.serialization import read_json
8 | from ..utils.serialization import write_json
9 |
10 |
11 | def _pluck(identities, indices, relabel=False):
12 | """Extract im names of given pids.
13 | Args:
14 | identities: containing im names
15 | indices: pids
16 | relabel: whether to transform pids to classification labels
17 | """
18 | ret = []
19 | for index, pid in enumerate(indices):
20 | pid_images = identities[pid]
21 | for camid, cam_images in enumerate(pid_images):
22 | for fname in cam_images:
23 | name = osp.splitext(fname)[0]
24 | x, y, _ = map(int, name.split('_'))
25 | assert pid == x and camid == y
26 | if relabel:
27 | ret.append((fname, index, camid))
28 | else:
29 | ret.append((fname, pid, camid))
30 | return ret
31 |
32 |
33 | ########################
34 |
35 |
36 | class Market1501(Dataset):
37 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view'
38 | md5 = '65005ab7d12ec1c44de4eeafe813e68a'
39 |
40 | def __init__(self, root, split_id=0, num_val=100, download=True):
41 | super(Market1501, self).__init__(root, split_id=split_id)
42 |
43 | if download:
44 | self.download()
45 |
46 | if not self._check_integrity():
47 | raise RuntimeError("Dataset not found or corrupted. " +
48 | "You can use download=True to download it.")
49 |
50 | self.load(num_val)
51 |
52 | def download(self):
53 | if self._check_integrity():
54 | print("Files already downloaded and verified")
55 | return
56 |
57 | import re
58 | import hashlib
59 | import shutil
60 | from glob import glob
61 | from zipfile import ZipFile
62 |
63 | raw_dir = osp.join(self.root, 'raw')
64 | mkdir_if_missing(raw_dir)
65 |
66 | # Download the raw zip file
67 | fpath = osp.join(raw_dir, 'Market-1501-v15.09.15.zip')
68 | if osp.isfile(fpath) and \
69 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
70 | print("Using downloaded file: " + fpath)
71 | else:
72 | raise RuntimeError("Please download the dataset manually from {} "
73 | "to {}".format(self.url, fpath))
74 |
75 | # Extract the file
76 | exdir = osp.join(raw_dir, 'Market-1501-v15.09.15')
77 | if not osp.isdir(exdir):
78 | print("Extracting zip file")
79 | with ZipFile(fpath) as z:
80 | z.extractall(path=raw_dir)
81 |
82 | # Format
83 | images_dir = osp.join(self.root, 'images')
84 | mkdir_if_missing(images_dir)
85 |
86 | # 1501 identities (+1 for background) with 6 camera views each
87 | identities = [[[] for _ in range(6)] for _ in range(1502)]
88 |
89 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')):
90 | fnames = [] ######### Added. Names of images in new dir.
91 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg')))
92 | pids = set()
93 | for fpath in fpaths:
94 | fname = osp.basename(fpath)
95 | pid, cam = map(int, pattern.search(fname).groups())
96 | if pid == -1: continue # junk images are just ignored
97 | assert 0 <= pid <= 1501 # pid == 0 means background
98 | assert 1 <= cam <= 6
99 | cam -= 1
100 | pids.add(pid)
101 | fname = ('{:08d}_{:02d}_{:04d}.jpg'
102 | .format(pid, cam, len(identities[pid][cam])))
103 | identities[pid][cam].append(fname)
104 | shutil.copy(fpath, osp.join(images_dir, fname))
105 | fnames.append(fname) ######### Added
106 | return pids, fnames
107 |
108 | trainval_pids, _ = register('bounding_box_train')
109 | gallery_pids, gallery_fnames = register('bounding_box_test')
110 | query_pids, query_fnames = register('query')
111 | assert query_pids <= gallery_pids
112 | assert trainval_pids.isdisjoint(gallery_pids)
113 |
114 | # Save meta information into a json file
115 | meta = {'name': 'Market1501', 'shot': 'multiple', 'num_cameras': 6,
116 | 'identities': identities,
117 | 'query_fnames': query_fnames, ######### Added
118 | 'gallery_fnames': gallery_fnames} ######### Added
119 | write_json(meta, osp.join(self.root, 'meta.json'))
120 |
121 | # Save the only training / test split
122 | splits = [{
123 | 'trainval': sorted(list(trainval_pids)),
124 | 'query': sorted(list(query_pids)),
125 | 'gallery': sorted(list(gallery_pids))}]
126 | write_json(splits, osp.join(self.root, 'splits.json'))
127 |
128 | ########################
129 | # Added
130 | def load(self, num_val=0.3, verbose=True):
131 | splits = read_json(osp.join(self.root, 'splits.json'))
132 | if self.split_id >= len(splits):
133 | raise ValueError("split_id exceeds total splits {}"
134 | .format(len(splits)))
135 | self.split = splits[self.split_id]
136 |
137 | # Randomly split train / val
138 | trainval_pids = np.asarray(self.split['trainval'])
139 | np.random.shuffle(trainval_pids)
140 | num = len(trainval_pids)
141 | if isinstance(num_val, float):
142 | num_val = int(round(num * num_val))
143 | if num_val >= num or num_val < 0:
144 | raise ValueError("num_val exceeds total identities {}"
145 | .format(num))
146 | train_pids = sorted(trainval_pids[:-num_val])
147 | val_pids = sorted(trainval_pids[-num_val:])
148 |
149 | self.meta = read_json(osp.join(self.root, 'meta.json'))
150 | identities = self.meta['identities']
151 |
152 | self.train = _pluck(identities, train_pids, relabel=True)
153 | self.val = _pluck(identities, val_pids, relabel=True)
154 | self.trainval = _pluck(identities, trainval_pids, relabel=True)
155 | self.num_train_ids = len(train_pids)
156 | self.num_val_ids = len(val_pids)
157 | self.num_trainval_ids = len(trainval_pids)
158 |
159 | ##########
160 | # Added
161 | query_fnames = self.meta['query_fnames']
162 | gallery_fnames = self.meta['gallery_fnames']
163 | self.query = []
164 | for fname in query_fnames:
165 | name = osp.splitext(fname)[0]
166 | pid, cam, _ = map(int, name.split('_'))
167 | self.query.append((fname, pid, cam))
168 | self.gallery = []
169 | for fname in gallery_fnames:
170 | name = osp.splitext(fname)[0]
171 | pid, cam, _ = map(int, name.split('_'))
172 | self.gallery.append((fname, pid, cam))
173 | ##########
174 |
175 | if verbose:
176 | print(self.__class__.__name__, "dataset loaded")
177 | print(" subset | # ids | # images")
178 | print(" ---------------------------")
179 | print(" train | {:5d} | {:8d}"
180 | .format(self.num_train_ids, len(self.train)))
181 | print(" val | {:5d} | {:8d}"
182 | .format(self.num_val_ids, len(self.val)))
183 | print(" trainval | {:5d} | {:8d}"
184 | .format(self.num_trainval_ids, len(self.trainval)))
185 | print(" query | {:5d} | {:8d}"
186 | .format(len(self.split['query']), len(self.query)))
187 | print(" gallery | {:5d} | {:8d}"
188 | .format(len(self.split['gallery']), len(self.gallery)))
189 | ########################
190 |
--------------------------------------------------------------------------------
/reid/datasets/dukemtmc.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 |
4 | from ..utils.data import Dataset
5 | from ..utils.osutils import mkdir_if_missing
6 | from ..utils.serialization import read_json
7 | from ..utils.serialization import write_json
8 |
9 |
10 | def _pluck(identities, indices, relabel=False):
11 | """Extract im names of given pids.
12 | Args:
13 | identities: containing im names
14 | indices: pids
15 | relabel: whether to transform pids to classification labels
16 | """
17 | ret = []
18 | for index, pid in enumerate(indices):
19 | pid_images = identities[pid]
20 | for camid, cam_images in enumerate(pid_images):
21 | for fname in cam_images:
22 | name = osp.splitext(fname)[0]
23 | x, y, _ = map(int, name.split('_'))
24 | assert pid == x and camid == y
25 | if relabel:
26 | ret.append((fname, index, camid))
27 | else:
28 | ret.append((fname, pid, camid))
29 | return ret
30 |
31 |
32 | class DukeMTMC(Dataset):
33 | url = 'https://drive.google.com/uc?id=0B0VOCNYh8HeRdnBPa2ZWaVBYSVk'
34 | md5 = '2f93496f9b516d1ee5ef51c1d5e7d601'
35 |
36 | def __init__(self, root, split_id=0, num_val=100, download=True):
37 | super(DukeMTMC, self).__init__(root, split_id=split_id)
38 |
39 | if download:
40 | self.download()
41 |
42 | if not self._check_integrity():
43 | raise RuntimeError("Dataset not found or corrupted. " +
44 | "You can use download=True to download it.")
45 |
46 | self.load(num_val)
47 |
48 | def download(self):
49 | if self._check_integrity():
50 | print("Files already downloaded and verified")
51 | return
52 |
53 | import re
54 | import hashlib
55 | import shutil
56 | from glob import glob
57 | from zipfile import ZipFile
58 |
59 | raw_dir = osp.join(self.root, 'raw')
60 | mkdir_if_missing(raw_dir)
61 |
62 | # Download the raw zip file
63 | fpath = osp.join(raw_dir, 'DukeMTMC-reID.zip')
64 | if osp.isfile(fpath) and \
65 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
66 | print("Using downloaded file: " + fpath)
67 | else:
68 | raise RuntimeError("Please download the dataset manually from {} "
69 | "to {}".format(self.url, fpath))
70 |
71 | # Extract the file
72 | exdir = osp.join(raw_dir, 'DukeMTMC-reID')
73 | if not osp.isdir(exdir):
74 | print("Extracting zip file")
75 | with ZipFile(fpath) as z:
76 | z.extractall(path=raw_dir)
77 |
78 | # Format
79 | images_dir = osp.join(self.root, 'images')
80 | mkdir_if_missing(images_dir)
81 |
82 | identities = []
83 | all_pids = {}
84 |
85 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')):
86 | fnames = [] ###### New Add. Names of images in new dir
87 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg')))
88 | pids = set()
89 | for fpath in fpaths:
90 | fname = osp.basename(fpath)
91 | pid, cam = map(int, pattern.search(fname).groups())
92 | assert 1 <= cam <= 8
93 | cam -= 1
94 | if pid not in all_pids:
95 | all_pids[pid] = len(all_pids)
96 | pid = all_pids[pid]
97 | pids.add(pid)
98 | if pid >= len(identities):
99 | assert pid == len(identities)
100 | identities.append([[] for _ in range(8)]) # 8 camera views
101 | fname = ('{:08d}_{:02d}_{:04d}.jpg'
102 | .format(pid, cam, len(identities[pid][cam])))
103 | identities[pid][cam].append(fname)
104 | shutil.copy(fpath, osp.join(images_dir, fname))
105 | fnames.append(fname) ######## added
106 | return pids, fnames
107 |
108 | trainval_pids, _ = register('bounding_box_train')
109 | gallery_pids, gallery_fnames = register('bounding_box_test')
110 | query_pids, query_fnames = register('query')
111 | assert query_pids <= gallery_pids
112 | assert trainval_pids.isdisjoint(gallery_pids)
113 |
114 | # Save meta information into a json file
115 | meta = {'name': 'DukeMTMC', 'shot': 'multiple', 'num_cameras': 8,
116 | 'identities': identities,
117 | 'query_fnames': query_fnames, ########## Added
118 | 'gallery_fnames': gallery_fnames} ######### Added
119 | write_json(meta, osp.join(self.root, 'meta.json'))
120 |
121 | # Save the only training / test split
122 | splits = [{
123 | 'trainval': sorted(list(trainval_pids)),
124 | 'query': sorted(list(query_pids)),
125 | 'gallery': sorted(list(gallery_pids))}]
126 | write_json(splits, osp.join(self.root, 'splits.json'))
127 |
128 | ########################
129 | # Added
130 | def load(self, num_val=0.3, verbose=True):
131 | import numpy as np
132 | splits = read_json(osp.join(self.root, 'splits.json'))
133 | if self.split_id >= len(splits):
134 | raise ValueError("split_id exceeds total splits {}"
135 | .format(len(splits)))
136 | self.split = splits[self.split_id]
137 |
138 | # Randomly split train / val
139 | trainval_pids = np.asarray(self.split['trainval'])
140 | np.random.shuffle(trainval_pids)
141 | num = len(trainval_pids)
142 | if isinstance(num_val, float):
143 | num_val = int(round(num * num_val))
144 | if num_val >= num or num_val < 0:
145 | raise ValueError("num_val exceeds total identities {}"
146 | .format(num))
147 | train_pids = sorted(trainval_pids[:-num_val])
148 | val_pids = sorted(trainval_pids[-num_val:])
149 |
150 | self.meta = read_json(osp.join(self.root, 'meta.json'))
151 | identities = self.meta['identities']
152 |
153 | self.train = _pluck(identities, train_pids, relabel=True)
154 | self.val = _pluck(identities, val_pids, relabel=True)
155 | self.trainval = _pluck(identities, trainval_pids, relabel=True)
156 | self.num_train_ids = len(train_pids)
157 | self.num_val_ids = len(val_pids)
158 | self.num_trainval_ids = len(trainval_pids)
159 |
160 | ##########
161 | # Added
162 | query_fnames = self.meta['query_fnames']
163 | gallery_fnames = self.meta['gallery_fnames']
164 | self.query = []
165 | for fname in query_fnames:
166 | name = osp.splitext(fname)[0]
167 | pid, cam, _ = map(int, name.split('_'))
168 | self.query.append((fname, pid, cam))
169 | self.gallery = []
170 | for fname in gallery_fnames:
171 | name = osp.splitext(fname)[0]
172 | pid, cam, _ = map(int, name.split('_'))
173 | self.gallery.append((fname, pid, cam))
174 | ##########
175 |
176 | if verbose:
177 | print(self.__class__.__name__, "dataset loaded")
178 | print(" subset | # ids | # images")
179 | print(" ---------------------------")
180 | print(" train | {:5d} | {:8d}"
181 | .format(self.num_train_ids, len(self.train)))
182 | print(" val | {:5d} | {:8d}"
183 | .format(self.num_val_ids, len(self.val)))
184 | print(" trainval | {:5d} | {:8d}"
185 | .format(self.num_trainval_ids, len(self.trainval)))
186 | print(" query | {:5d} | {:8d}"
187 | .format(len(self.split['query']), len(self.query)))
188 | print(" gallery | {:5d} | {:8d}"
189 | .format(len(self.split['gallery']), len(self.gallery)))
190 | ########################
191 |
--------------------------------------------------------------------------------
/selftrainingKmeans.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import print_function, absolute_import
5 | import argparse
6 | import time
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.dist_metric import DistanceMetric
18 | from reid.loss import TripletLoss
19 | from reid.trainers import Trainer
20 | from reid.evaluators import Evaluator, extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 |
27 | from sklearn.cluster import KMeans
28 | from reid.rerank import re_ranking
29 |
30 |
31 | def get_data(name, data_dir, height, width, batch_size,
32 | workers):
33 | root = osp.join(data_dir, name)
34 |
35 | dataset = datasets.create(name, root, num_val=0.1)
36 |
37 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
38 | std=[0.229, 0.224, 0.225])
39 |
40 | # use all training and validation images in target dataset
41 | train_set = dataset.trainval
42 | num_classes = dataset.num_trainval_ids
43 |
44 | transformer = T.Compose([
45 | T.Resize((height,width)),
46 | T.ToTensor(),
47 | normalizer,
48 | ])
49 |
50 | extfeat_loader = DataLoader(
51 | Preprocessor(train_set, root=dataset.images_dir,
52 | transform=transformer),
53 | batch_size=batch_size, num_workers=workers,
54 | shuffle=False, pin_memory=True)
55 |
56 | test_loader = DataLoader(
57 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
58 | root=dataset.images_dir, transform=transformer),
59 | batch_size=batch_size//2, num_workers=workers,
60 | shuffle=False, pin_memory=True)
61 |
62 | return dataset, num_classes, extfeat_loader, test_loader
63 |
64 |
65 | def get_source_data(name, data_dir, height, width, batch_size, workers):
66 | root = osp.join(data_dir, name)
67 |
68 | dataset = datasets.create(name, root, num_val=0.1)
69 |
70 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
71 | std=[0.229, 0.224, 0.225])
72 |
73 | # use all training images on source dataset
74 | train_set = dataset.train
75 | num_classes = dataset.num_train_ids
76 |
77 | transformer = T.Compose([
78 | T.Resize((height,width)),
79 | T.ToTensor(),
80 | normalizer,
81 | ])
82 |
83 | extfeat_loader = DataLoader(
84 | Preprocessor(train_set, root=dataset.images_dir,
85 | transform=transformer),
86 | batch_size=batch_size, num_workers=workers,
87 | shuffle=False, pin_memory=True)
88 |
89 | return dataset, extfeat_loader
90 |
91 |
92 | def splitLowconfi(feature, labels, centers, ratio=0.2):
93 |     # mark the bottom `ratio` (default 20%) most dissimilar samples of each cluster as noise (-1)
94 |     # distances are measured between cluster centers and sample features
95 | centerDis = calDis(torch.from_numpy(feature), torch.from_numpy(centers)).numpy() # center VS samples
96 | noiseLoc = []
97 | for ii, pid in enumerate(set(labels)):
98 | curDis = centerDis[:,ii]
99 | curDis[labels!=pid] = 100
100 | smallLossIdx = curDis.argsort()
101 | smallLossIdx = smallLossIdx[curDis[smallLossIdx]!=100]
102 | # bot 20% removed
103 | partSize = int(ratio*smallLossIdx.shape[0])
104 | if partSize!=0:
105 | noiseLoc.extend(smallLossIdx[-partSize:])
106 | labels[noiseLoc] = -1
107 | return labels
108 |
109 |
110 | def calDis(qFeature, gFeature):#246s
111 | x, y = F.normalize(qFeature), F.normalize(gFeature)
112 | # x, y = qFeature, gFeature
113 | m, n = x.shape[0], y.shape[0]
114 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
115 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
116 | disMat.addmm_(1, -2, x, y.t())
117 | return disMat.clamp_(min=1e-5)
118 |
119 |
120 | def labelUnknown(knownFeat, allLab, unknownFeat):
121 | disMat = calDis(knownFeat, unknownFeat)
122 | labLoc = disMat.argmin(dim=0)
123 | return allLab[labLoc]
124 |
125 |
126 | def labelNoise(feature, labels):
127 | # features and labels with -1
128 | noiseFeat, pureFeat = feature[labels==-1,:], feature[labels!=-1,:]
129 | labels = labels[labels!=-1]
130 | unLab = labelUnknown(pureFeat, labels, noiseFeat)
131 | return unLab.numpy()
132 |
133 |
134 | def main(args):
135 | np.random.seed(args.seed)
136 | torch.manual_seed(args.seed)
137 | cudnn.benchmark = True
138 |
139 | # Create data loaders
140 | assert args.num_instances > 1, "num_instances should be greater than 1"
141 | assert args.batch_size % args.num_instances == 0, \
142 | 'num_instances should divide batch_size'
143 | if args.height is None or args.width is None:
144 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
145 | (256, 128)
146 |
147 | # get source data
148 | src_dataset, src_extfeat_loader = \
149 | get_source_data(args.src_dataset, args.data_dir, args.height,
150 | args.width, args.batch_size, args.workers)
151 | # get target data
152 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \
153 | get_data(args.tgt_dataset, args.data_dir, args.height,
154 | args.width, args.batch_size, args.workers)
155 |
156 | # Create model
157 | # Hacking here to let the classifier be the number of source ids
158 | if args.src_dataset == 'dukemtmc':
159 | model = models.create(args.arch, num_classes=632, pretrained=False)
160 | elif args.src_dataset == 'market1501':
161 | model = models.create(args.arch, num_classes=676, pretrained=False)
162 | else:
163 | raise RuntimeError('Please specify the number of classes (ids) of the network.')
164 |
165 | # Load from checkpoint
166 | if args.resume:
167 | print('Resuming checkpoints from finetuned model on another dataset...\n')
168 | checkpoint = load_checkpoint(args.resume)
169 | model.load_state_dict(checkpoint['state_dict'], strict=False)
170 | else:
171 |         raise RuntimeError('A pre-trained checkpoint is required; pass --resume.')
172 | model = nn.DataParallel(model).cuda()
173 |
174 | # evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
175 | # if args.evaluate: return
176 |
177 | # Criterion
178 | criterion = [
179 | TripletLoss(args.margin, args.num_instances, isAvg=True, use_semi=True).cuda(),
180 | TripletLoss(args.margin, args.num_instances, isAvg=True, use_semi=True).cuda(),
181 | ]
182 |
183 |
184 | # Optimizer
185 | optimizer = torch.optim.Adam(
186 | model.parameters(), lr = args.lr
187 | )
188 |
189 |
190 | # training stage transformer on input images
191 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
192 | train_transformer = T.Compose([
193 | T.Resize((args.height,args.width)),
194 | T.RandomHorizontalFlip(),
195 | T.ToTensor(), normalizer,
196 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3)
197 | ])
198 |
199 |
200 | # # Start training
201 | for iter_n in range(args.iteration):
202 | if args.lambda_value == 0:
203 | source_features = 0
204 | else:
205 | # get source datas' feature
206 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq)
207 | # synchronization feature order with src_dataset.train
208 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0)
209 |
210 | # extract training images' features
211 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n+1))
212 | target_features, tarNames = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq)
213 | # synchronization feature order with dataset.train
214 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0)
215 |         target_real_label = np.asarray([int(tarNames[f]) for f, _, _ in tgt_dataset.trainval])
216 | numTarID = len(set(target_real_label))
217 | # calculate distance and rerank result
218 | print('Calculating feature distances...')
219 | target_features = target_features.numpy()
220 | cluster = KMeans(n_clusters=numTarID, n_jobs=8, n_init=1)
221 |
222 | # select & cluster images as training set of this epochs
223 | print('Clustering and labeling...')
224 | clusterRes = cluster.fit(target_features)
225 | labels, centers = clusterRes.labels_, clusterRes.cluster_centers_
226 | # labels = splitLowconfi(target_features,labels,centers)
227 | # num_ids = len(set(labels))
228 | # print('Iteration {} have {} training ids'.format(iter_n+1, num_ids))
229 | # generate new dataset
230 | new_dataset = []
231 | for (fname, _, cam), label in zip(tgt_dataset.trainval, labels):
232 | # if label==-1: continue
233 | # dont need to change codes in trainer.py _parsing_input function and sampler function after add 0
234 | new_dataset.append((fname,label,cam))
235 | print('Iteration {} have {} training images'.format(iter_n+1, len(new_dataset)))
236 | train_loader = DataLoader(
237 | Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
238 | batch_size=args.batch_size, num_workers=4,
239 | sampler=RandomIdentitySampler(new_dataset, args.num_instances),
240 | pin_memory=True, drop_last=True
241 | )
242 |
243 | # train model with new generated dataset
244 | trainer = Trainer(model, criterion)
245 |
246 |
247 | evaluator = Evaluator(model, print_freq=args.print_freq)
248 |
249 | # Start training
250 | for epoch in range(args.epochs):
251 | # trainer.train(epoch, remRate=0.2+(0.6/args.iteration)*(1+iter_n)) # to at most 80%
252 | trainer.train(epoch, train_loader, optimizer)
253 | # test only
254 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
255 | #print('co-model:\n')
256 | #rank_score = evaluatorB.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
257 |
258 | # Evaluate
259 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
260 | save_checkpoint({
261 | 'state_dict': model.module.state_dict(),
262 | 'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
263 | }, True, fpath=osp.join(args.logs_dir, 'adapted.pth.tar'))
264 | return (rank_score.map, rank_score.market1501[0])
265 |
266 |
267 | if __name__ == '__main__':
268 | parser = argparse.ArgumentParser(description="Triplet loss classification")
269 | # data
270 | parser.add_argument('--src_dataset', type=str, default='dukemtmc',
271 | choices=datasets.names())
272 | parser.add_argument('--tgt_dataset', type=str, default='market1501',
273 | choices=datasets.names())
274 | parser.add_argument('--batch_size', type=int, default=64)
275 | parser.add_argument('--workers', type=int, default=4)
276 | parser.add_argument('--split', type=int, default=0)
277 | parser.add_argument('--noiseLam', type=float, default=0.5)
278 | parser.add_argument('--height', type=int,
279 | help="input height, default: 256 for resnet*, "
280 | "144 for inception")
281 | parser.add_argument('--width', type=int,
282 | help="input width, default: 128 for resnet*, "
283 | "56 for inception")
284 | parser.add_argument('--combine-trainval', action='store_true',
285 | help="train and val sets together for training, "
286 | "val set alone for validation")
287 | parser.add_argument('--num_instances', type=int, default=4,
288 | help="each minibatch consist of "
289 | "(batch_size // num_instances) identities, and "
290 | "each identity has num_instances instances, "
291 | "default: 4")
292 | # model
293 | parser.add_argument('--arch', type=str, default='resnet50',
294 | choices=models.names())
295 | # loss
296 | parser.add_argument('--margin', type=float, default=0.5,
297 | help="margin of the triplet loss, default: 0.5")
298 | parser.add_argument('--lambda_value', type=float, default=0.1,
299 | help="balancing parameter, default: 0.1")
300 | parser.add_argument('--rho', type=float, default=1.6e-3,
301 | help="rho percentage, default: 1.6e-3")
302 | # optimizer
303 | parser.add_argument('--lr', type=float, default=6e-5,
304 | help="learning rate of all parameters")
305 | # training configs
306 | parser.add_argument('--resume', type=str, metavar='PATH',
307 | default = '')
308 | parser.add_argument('--evaluate', type=int, default=0,
309 | help="evaluation only")
310 | parser.add_argument('--seed', type=int, default=1)
311 | parser.add_argument('--print_freq', type=int, default=1)
312 | parser.add_argument('--iteration', type=int, default=10)
313 | parser.add_argument('--epochs', type=int, default=30)
314 | # metric learning
315 | parser.add_argument('--dist_metric', type=str, default='euclidean',
316 | choices=['euclidean', 'kissme'])
317 | # misc
318 | parser.add_argument('--data_dir', type=str, metavar='PATH',
319 | default='')
320 | parser.add_argument('--logs_dir', type=str, metavar='PATH',
321 | default='')
322 |
323 | args = parser.parse_args()
324 | mean_ap, rank1 = main(args)
325 | results_file = np.asarray([mean_ap, rank1])
326 | file_name = time.strftime("%H%M%S", time.localtime())
327 | file_name = osp.join(args.logs_dir, file_name)
328 | np.save(file_name, results_file)
329 |
--------------------------------------------------------------------------------
/selfNoise.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import print_function, absolute_import
5 | import argparse
6 | import time
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.dist_metric import DistanceMetric
18 | from reid.loss import TripletLoss
19 | from reid.trainers import Trainer
20 | from reid.evaluators import Evaluator, extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 |
27 | from sklearn.cluster import DBSCAN
28 | from reid.rerank import re_ranking
29 |
30 |
31 | def calScores(clusters, labels):
32 | """
33 | compute pair-wise precision, recall and F-score
34 | """
35 | from scipy.special import comb
36 | if len(clusters) == 0:
37 | return 0, 0, 0  # callers unpack three values (precision, recall, fScore)
38 | else:
39 | curCluster = []
40 | for curClus in clusters.values():
41 | curCluster.append(labels[curClus])
42 | TPandFP = sum([comb(len(val), 2) for val in curCluster])
43 | TP = 0
44 | for clusterVal in curCluster:
45 | for setMember in set(clusterVal):
46 | if sum(clusterVal == setMember) < 2: continue
47 | TP += comb(sum(clusterVal == setMember), 2)
48 | FP = TPandFP - TP
49 | # FN and TN
50 | TPandFN = sum([comb(labels.tolist().count(val), 2) for val in set(labels)])
51 | FN = TPandFN - TP
52 | # cal precision and recall
53 | precision, recall = TP / (TP + FP), TP / (TP + FN)
54 | fScore = 2 * precision * recall / (precision + recall)
55 | return precision, recall, fScore
56 |
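# Worked example (illustrative note, not in the original source): for
# clusters = {0: [0, 1, 2], 1: [3, 4]} and labels = np.asarray([7, 7, 8, 9, 9]),
# TPandFP = C(3,2) + C(2,2) = 4 and TP = C(2,2) + C(2,2) = 2, while
# TPandFN = C(2,2) + C(2,2) = 2, giving precision = 0.5, recall = 1.0 and
# fScore = 2 * 0.5 * 1.0 / 1.5 = 2/3.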
57 |
58 | def get_data(name, data_dir, height, width, batch_size,
59 | workers):
60 | root = osp.join(data_dir, name)
61 |
62 | dataset = datasets.create(name, root, num_val=0.1)
63 |
64 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
65 | std=[0.229, 0.224, 0.225])
66 |
67 | # use all training and validation images in target dataset
68 | train_set = dataset.trainval
69 | num_classes = dataset.num_trainval_ids
70 |
71 | transformer = T.Compose([
72 | T.Resize((height, width)),
73 | T.ToTensor(),
74 | normalizer,
75 | ])
76 |
77 | extfeat_loader = DataLoader(
78 | Preprocessor(train_set, root=dataset.images_dir,
79 | transform=transformer),
80 | batch_size=batch_size, num_workers=workers,
81 | shuffle=False, pin_memory=True)
82 |
83 | test_loader = DataLoader(
84 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
85 | root=dataset.images_dir, transform=transformer),
86 | batch_size=batch_size, num_workers=workers,
87 | shuffle=False, pin_memory=True)
88 |
89 | return dataset, num_classes, extfeat_loader, test_loader
90 |
91 |
92 | def get_source_data(name, data_dir, height, width, batch_size,
93 | workers):
94 | root = osp.join(data_dir, name)
95 |
96 | dataset = datasets.create(name, root, num_val=0.1)
97 |
98 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
99 | std=[0.229, 0.224, 0.225])
100 |
101 | # use all training images on source dataset
102 | train_set = dataset.train
103 | num_classes = dataset.num_train_ids
104 |
105 | transformer = T.Compose([
106 | T.Resize((height, width)),
107 | T.ToTensor(),
108 | normalizer,
109 | ])
110 |
111 | extfeat_loader = DataLoader(
112 | Preprocessor(train_set, root=dataset.images_dir,
113 | transform=transformer),
114 | batch_size=batch_size, num_workers=workers,
115 | shuffle=False, pin_memory=True)
116 |
117 | return dataset, extfeat_loader
118 |
119 |
120 | def calDis(qFeature, gFeature): # 246s
121 | x, y = F.normalize(qFeature), F.normalize(gFeature)
122 | # x, y = qFeature, gFeature
123 | m, n = x.shape[0], y.shape[0]
124 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
125 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
126 | disMat.addmm_(1, -2, x, y.t())
127 | return disMat.clamp_(min=1e-5)
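# Note (added): calDis expands ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y on the
# L2-normalized features; the positional disMat.addmm_(beta, alpha, mat1, mat2)
# call is the legacy signature matching the torch==1.3.1 pin in
# requirements.txt (recent PyTorch versions expect keyword arguments here).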
128 |
129 |
130 | def labelUnknown(knownFeat, allLab, unknownFeat):
131 | # allLab--label from known
132 | disMat = calDis(knownFeat, unknownFeat)
133 | labLoc = disMat.argmin(dim=0)
134 | return allLab[labLoc]
135 |
136 |
137 | def labelNoise(feature, labels):
138 | # features and labels with -1
139 | noiseFeat, pureFeat = feature[labels == -1, :], feature[labels != -1, :]
140 | pureLabs = labels[labels != -1] # no outliers
141 | unLab = labelUnknown(pureFeat, pureLabs, noiseFeat)
142 | labels[labels == -1] = unLab
143 | return labels.numpy()
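# Note (added): DBSCAN marks low-density samples as -1; labelNoise re-assigns
# each such outlier the pseudo label of its nearest clustered feature, so every
# target image ends up with some identity for triplet sampling.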
144 |
145 |
146 | def getCenter(features, labels):
147 | allCenter = {}
148 | features = features[labels != -1, :]
149 | labels = labels[labels != -1]
150 | for pid in set(labels):
151 | allCenter[pid] = torch.from_numpy(features[labels == pid, :].mean(axis=0)).unsqueeze(0)
152 | return torch.cat(list(allCenter.values()))
153 |
154 |
155 | def main(args):
156 | np.random.seed(args.seed)
157 | torch.manual_seed(args.seed)
158 | cudnn.benchmark = True
159 |
160 | # Create data loaders
161 | assert args.num_instances > 1, "num_instances should be greater than 1"
162 | assert args.batch_size % args.num_instances == 0, \
163 | 'num_instances should divide batch_size'
164 | if args.height is None or args.width is None:
165 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
166 | (256, 128)
167 |
168 | # get source data
169 | src_dataset, src_extfeat_loader = \
170 | get_source_data(args.src_dataset, args.data_dir, args.height,
171 | args.width, args.batch_size, args.workers)
172 | # get target data
173 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \
174 | get_data(args.tgt_dataset, args.data_dir, args.height,
175 | args.width, args.batch_size, args.workers)
176 |
177 | # Create model
178 | # Hacking here to let the classifier be the number of source ids
179 | if args.src_dataset == 'dukemtmc':
180 | model = models.create(args.arch, num_classes=632, pretrained=False)
181 | elif args.src_dataset == 'market1501':
182 | model = models.create(args.arch, num_classes=676, pretrained=False)
183 | else:
184 | raise RuntimeError('Please specify the number of classes (ids) of the network.')
185 |
186 | # Load from checkpoint
187 | if args.resume:
188 | print('Resuming checkpoints from finetuned model on another dataset...\n')
189 | checkpoint = load_checkpoint(args.resume)
190 | model.load_state_dict(checkpoint['state_dict'], strict=False)
191 | else:
192 | raise RuntimeWarning('Not using a pre-trained model.')
193 | model = nn.DataParallel(model).cuda()
194 |
195 | # Criterion
196 | criterion = [
197 | TripletLoss(args.margin, args.num_instances, use_semi=False).cuda(),
198 | TripletLoss(args.margin, args.num_instances, use_semi=False).cuda()
199 | ]
200 | optimizer = torch.optim.Adam(
201 | model.parameters(), lr=args.lr
202 | )
203 |
204 | # training stage transformer on input images
205 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
206 | train_transformer = T.Compose([
207 | T.Resize((args.height, args.width)),
208 | T.RandomHorizontalFlip(),
209 | T.ToTensor(), normalizer,
210 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3)
211 | ])
212 |
213 | # # Start training
214 | for iter_n in range(args.iteration):
215 | if args.lambda_value == 0:
216 | source_features = 0
217 | else:
218 | # get source datas' feature
219 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq)
220 | # synchronization feature order with src_dataset.train
221 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0)
222 |
223 | # extract training images' features
224 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n + 1))
225 | target_features, _ = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq)
226 | # synchronization feature order with dataset.train
227 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0)
228 | # calculate distance and rerank result
229 | print('Calculating feature distances...')
230 | target_features = target_features.numpy()
231 | rerank_dist = re_ranking(source_features, target_features, lambda_value=args.lambda_value)
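# lambda_value balances source-vs-target distances inside re_ranking; with
# lambda_value == 0 the source features are presumably never touched, which is
# why a placeholder 0 is passed above.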
232 | if iter_n == 0:
233 | # DBSCAN cluster
234 | tri_mat = np.triu(rerank_dist, 1) # tri_mat.dim=2
235 | tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1
236 | tri_mat = np.sort(tri_mat, axis=None)
237 | top_num = np.round(args.rho * tri_mat.size).astype(int)
238 | eps = tri_mat[:top_num].mean()
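# eps is thus the mean of the rho-fraction smallest pairwise re-ranked
# distances; it is estimated once at iteration 0 and the same DBSCAN
# object is reused in later iterations.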
239 | print('eps in cluster: {:.3f}'.format(eps))
240 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8)
241 | # select & cluster images as training set of this epochs
242 | print('Clustering and labeling...')
243 | labels = cluster.fit_predict(rerank_dist)
244 | num_ids = len(set(labels)) - 1
245 | print('Iteration {} has {} training ids'.format(iter_n + 1, num_ids))
246 | # generate new dataset
247 | new_dataset = []
248 | # assign label for target ones
249 | newLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels))
250 | # unknownFeats = target_features[labels==-1,:]
251 | counter = 0
252 | from collections import defaultdict
253 | realIDs, fakeIDs = defaultdict(list), []
254 | for (fname, realID, cam), label in zip(tgt_dataset.trainval, newLab):
255 | # no need to change the code in trainer.py's _parsing_input function or the sampler function after adding 0
256 | new_dataset.append((fname, label, cam))
257 | realIDs[realID].append(counter)
258 | fakeIDs.append(label)
259 | counter += 1
260 | precision, recall, fscore = calScores(realIDs, np.asarray(fakeIDs))
261 | print('Iteration {} has {} training images'.format(iter_n + 1, len(new_dataset)))
262 | print(f'precision: {100 * precision}, recall: {100 * recall}, fscore: {100 * fscore}')
263 | train_loader = DataLoader(
264 | Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
265 | batch_size=args.batch_size, num_workers=4,
266 | sampler=RandomIdentitySampler(new_dataset, args.num_instances),
267 | pin_memory=True, drop_last=True
268 | )
269 |
270 | trainer = Trainer(model, criterion)
271 |
272 | # Start training
273 | for epoch in range(args.epochs):
274 | trainer.train(epoch, train_loader, optimizer)
275 | # test only
276 | evaluator = Evaluator(model, print_freq=args.print_freq)
277 | # rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
278 |
279 | # Evaluate
280 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
281 | save_checkpoint({
282 | 'state_dict': model.module.state_dict(),
283 | 'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
284 | }, True, fpath=osp.join(args.logs_dir, 'adapted.pth.tar'))
285 | return rank_score.map, rank_score.market1501[0]
286 |
287 |
288 | if __name__ == '__main__':
289 | parser = argparse.ArgumentParser(description="Triplet loss classification")
290 | # data
291 | parser.add_argument('--src_dataset', type=str, default='dukemtmc',
292 | choices=datasets.names())
293 | parser.add_argument('--tgt_dataset', type=str, default='market1501',
294 | choices=datasets.names())
295 | parser.add_argument('--batch_size', type=int, default=64)
296 | parser.add_argument('--workers', type=int, default=4)
297 | parser.add_argument('--split', type=int, default=0)
298 | parser.add_argument('--noiseLam', type=float, default=0.5)
299 | parser.add_argument('--height', type=int,
300 | help="input height, default: 256 for resnet*, "
301 | "144 for inception")
302 | parser.add_argument('--width', type=int,
303 | help="input width, default: 128 for resnet*, "
304 | "56 for inception")
305 | parser.add_argument('--combine-trainval', action='store_true',
306 | help="train and val sets together for training, "
307 | "val set alone for validation")
308 | parser.add_argument('--num_instances', type=int, default=4,
309 | help="each minibatch consist of "
310 | "(batch_size // num_instances) identities, and "
311 | "each identity has num_instances instances, "
312 | "default: 4")
313 | # model
314 | parser.add_argument('--arch', type=str, default='resnet50',
315 | choices=models.names())
316 | # loss
317 | parser.add_argument('--margin', type=float, default=0.5,
318 | help="margin of the triplet loss, default: 0.5")
319 | parser.add_argument('--lambda_value', type=float, default=0.1,
320 | help="balancing parameter, default: 0.1")
321 | parser.add_argument('--rho', type=float, default=1.6e-3,
322 | help="rho percentage, default: 1.6e-3")
323 | # optimizer
324 | parser.add_argument('--lr', type=float, default=6e-5,
325 | help="learning rate of all parameters")
326 | # training configs
327 | parser.add_argument('--resume', type=str, metavar='PATH',
328 | default='')
329 | parser.add_argument('--evaluate', type=int, default=0,
330 | help="evaluation only")
331 | parser.add_argument('--seed', type=int, default=1)
332 | parser.add_argument('--print_freq', type=int, default=1)
333 | parser.add_argument('--iteration', type=int, default=10)
334 | parser.add_argument('--epochs', type=int, default=30)
335 | # metric learning
336 | parser.add_argument('--dist_metric', type=str, default='euclidean',
337 | choices=['euclidean', 'kissme'])
338 | # misc
339 | parser.add_argument('--data_dir', type=str, metavar='PATH',
340 | default='')
341 | parser.add_argument('--logs_dir', type=str, metavar='PATH',
342 | default='')
343 |
344 | args = parser.parse_args()
345 | mean_ap, rank1 = main(args)
346 |
--------------------------------------------------------------------------------
/selftrainingKmeansAsy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import print_function, absolute_import
5 | import argparse
6 | import time
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.dist_metric import DistanceMetric
18 | from reid.loss import TripletLoss
19 | from reid.trainers import Trainer, CoTrainerAsy, CoTeaching, CoTrainerAsySep
20 | from reid.evaluators import Evaluator, extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 |
27 | from sklearn.cluster import KMeans
28 | from reid.rerank import re_ranking
29 |
30 |
31 | def get_data(name, data_dir, height, width, batch_size,
32 | workers):
33 | root = osp.join(data_dir, name)
34 |
35 | dataset = datasets.create(name, root, num_val=0.1)
36 |
37 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
38 | std=[0.229, 0.224, 0.225])
39 |
40 | # use all training and validation images in target dataset
41 | train_set = dataset.trainval
42 | num_classes = dataset.num_trainval_ids
43 |
44 | transformer = T.Compose([
45 | T.Resize((height,width)),
46 | T.ToTensor(),
47 | normalizer,
48 | ])
49 |
50 | extfeat_loader = DataLoader(
51 | Preprocessor(train_set, root=dataset.images_dir,
52 | transform=transformer),
53 | batch_size=batch_size, num_workers=workers,
54 | shuffle=False, pin_memory=True)
55 |
56 | test_loader = DataLoader(
57 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
58 | root=dataset.images_dir, transform=transformer),
59 | batch_size=batch_size//2, num_workers=workers,
60 | shuffle=False, pin_memory=True)
61 |
62 | return dataset, num_classes, extfeat_loader, test_loader
63 |
64 |
65 | def get_source_data(name, data_dir, height, width, batch_size, workers):
66 | root = osp.join(data_dir, name)
67 |
68 | dataset = datasets.create(name, root, num_val=0.1)
69 |
70 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
71 | std=[0.229, 0.224, 0.225])
72 |
73 | # use all training images on source dataset
74 | train_set = dataset.train
75 | num_classes = dataset.num_train_ids
76 |
77 | transformer = T.Compose([
78 | T.Resize((height,width)),
79 | T.ToTensor(),
80 | normalizer,
81 | ])
82 |
83 | extfeat_loader = DataLoader(
84 | Preprocessor(train_set, root=dataset.images_dir,
85 | transform=transformer),
86 | batch_size=batch_size, num_workers=workers,
87 | shuffle=False, pin_memory=True)
88 |
89 | return dataset, extfeat_loader
90 |
91 |
92 | def splitLowconfi(feature, labels, centers, ratio=0.2):
93 | # mark the bottom 20% least similar samples in each cluster as -1
94 | # center VS feature
95 | centerDis = calDis(torch.from_numpy(feature), torch.from_numpy(centers)).numpy() # center VS samples
96 | noiseLoc = []
97 | for ii, pid in enumerate(set(labels)):
98 | curDis = centerDis[:,ii]
99 | curDis[labels!=pid] = 100
100 | smallLossIdx = curDis.argsort()
101 | smallLossIdx = smallLossIdx[curDis[smallLossIdx]!=100]
102 | # bot 20% removed
103 | partSize = int(ratio*smallLossIdx.shape[0])
104 | if partSize!=0:
105 | noiseLoc.extend(smallLossIdx[-partSize:])
106 | labels[noiseLoc] = -1
107 | return labels
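# Example (added for clarity): with ratio=0.2 and a cluster of 10 members, the
# 2 samples farthest from their KMeans center are flagged as -1; they are later
# collected into unknown_dataset and re-labeled by labelNoise.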
108 |
109 |
110 | def calDis(qFeature, gFeature):#246s
111 | x, y = F.normalize(qFeature), F.normalize(gFeature)
112 | # x, y = qFeature, gFeature
113 | m, n = x.shape[0], y.shape[0]
114 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
115 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
116 | disMat.addmm_(1, -2, x, y.t())
117 | return disMat.clamp_(min=1e-5)
118 |
119 |
120 | def labelUnknown(knownFeat, allLab, unknownFeat):
121 | disMat = calDis(knownFeat, unknownFeat)
122 | labLoc = disMat.argmin(dim=0)
123 | return allLab[labLoc]
124 |
125 |
126 | def labelNoise(feature, labels):
127 | # features and labels with -1
128 | noiseFeat, pureFeat = feature[labels==-1,:], feature[labels!=-1,:]
129 | labels = labels[labels!=-1]
130 | unLab = labelUnknown(pureFeat, labels, noiseFeat)
131 | return unLab.numpy()
132 |
133 |
134 | def main(args):
135 | np.random.seed(args.seed)
136 | torch.manual_seed(args.seed)
137 | cudnn.benchmark = True
138 |
139 | # Create data loaders
140 | assert args.num_instances > 1, "num_instances should be greater than 1"
141 | assert args.batch_size % args.num_instances == 0, \
142 | 'num_instances should divide batch_size'
143 | if args.height is None or args.width is None:
144 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
145 | (256, 128)
146 |
147 | # get source data
148 | src_dataset, src_extfeat_loader = \
149 | get_source_data(args.src_dataset, args.data_dir, args.height,
150 | args.width, args.batch_size, args.workers)
151 | # get target data
152 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \
153 | get_data(args.tgt_dataset, args.data_dir, args.height,
154 | args.width, args.batch_size, args.workers)
155 |
156 | # Create model
157 | # Hacking here to let the classifier be the number of source ids
158 | if args.src_dataset == 'dukemtmc':
159 | model = models.create(args.arch, num_classes=632, pretrained=False)
160 | coModel = models.create(args.arch, num_classes=632, pretrained=False)
161 | elif args.src_dataset == 'market1501':
162 | model = models.create(args.arch, num_classes=676, pretrained=False)
163 | coModel = models.create(args.arch, num_classes=676, pretrained=False)
164 | else:
165 | raise RuntimeError('Please specify the number of classes (ids) of the network.')
166 |
167 | # Load from checkpoint
168 | if args.resume:
169 | print('Resuming checkpoints from finetuned model on another dataset...\n')
170 | checkpoint = load_checkpoint(args.resume)
171 | model.load_state_dict(checkpoint['state_dict'], strict=False)
172 | coModel.load_state_dict(checkpoint['state_dict'], strict=False)
173 | else:
174 | raise RuntimeWarning('Not using a pre-trained model.')
175 | model = nn.DataParallel(model).cuda()
176 | coModel = nn.DataParallel(coModel).cuda()
177 |
178 | # evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
179 | # if args.evaluate: return
180 |
181 | # Criterion
182 | criterion = [
183 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
184 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
185 | ]
186 |
187 |
188 | # Optimizer
189 | optimizer = torch.optim.Adam(
190 | model.parameters(), lr=args.lr
191 | )
192 | coOptimizer = torch.optim.Adam(
193 | coModel.parameters(), lr=args.lr
194 | )
195 |
196 | optims = [optimizer, coOptimizer]
197 |
198 |
199 | # training stage transformer on input images
200 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
201 | train_transformer = T.Compose([
202 | T.Resize((args.height,args.width)),
203 | T.RandomHorizontalFlip(),
204 | T.ToTensor(), normalizer,
205 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3)
206 | ])
207 |
208 |
209 | # # Start training
210 | for iter_n in range(args.iteration):
211 | if args.lambda_value == 0:
212 | source_features = 0
213 | else:
214 | # get source datas' feature
215 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq)
216 | # synchronization feature order with src_dataset.train
217 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0)
218 |
219 | # extract training images' features
220 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n+1))
221 | target_features, tarNames = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq)
222 | # synchronization feature order with dataset.train
223 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0)
224 | target_real_label = np.asarray([tarNames[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval])
225 | numTarID = len(set(target_real_label))
226 | # calculate distance and rerank result
227 | print('Calculating feature distances...')
228 | target_features = target_features.numpy()
229 | cluster = KMeans(n_clusters=numTarID, n_jobs=8, n_init=1)
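# n_clusters is taken from the true number of target identities (an oracle
# hint read off the annotations above); n_jobs is still a valid KMeans
# argument in the scikit_learn==0.21.3 pinned in requirements.txt.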
230 |
231 | # select & cluster images as training set of this epochs
232 | print('Clustering and labeling...')
233 | clusterRes = cluster.fit(target_features)
234 | labels, centers = clusterRes.labels_, clusterRes.cluster_centers_
235 | labels = splitLowconfi(target_features, labels, centers)
236 | # num_ids = len(set(labels))
237 | # print('Iteration {} have {} training ids'.format(iter_n+1, num_ids))
238 | # generate new dataset
239 | new_dataset, unknown_dataset = [], []
240 | # assign label for target ones
241 | unknownLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels))
242 | # unknownFeats = target_features[labels==-1,:]
243 | unCounter = 0
244 | for (fname, _, cam), label in zip(tgt_dataset.trainval, labels):
245 | if label==-1:
246 | unknown_dataset.append((fname,int(unknownLab[unCounter]),cam)) # unknown data
247 | unCounter += 1
248 | continue
249 | # dont need to change codes in trainer.py _parsing_input function and sampler function after add 0
250 | new_dataset.append((fname,label,cam))
251 | print('Iteration {} has {} training images'.format(iter_n+1, len(new_dataset)))
252 |
253 | train_loader = DataLoader(
254 | Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
255 | batch_size=args.batch_size, num_workers=4,
256 | sampler=RandomIdentitySampler(new_dataset, args.num_instances),
257 | pin_memory=True, drop_last=True
258 | )
259 | # hard samples
260 | unLoader = DataLoader(
261 | Preprocessor(unknown_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
262 | batch_size=args.batch_size, num_workers=4,
263 | sampler=RandomIdentitySampler(unknown_dataset, args.num_instances),
264 | pin_memory=True, drop_last=True
265 | )
266 |
267 | # train model with new generated dataset
268 | trainer = CoTrainerAsy(
269 | model, coModel, train_loader, unLoader, criterion, optims
270 | )
271 | # trainer = CoTeaching(
272 | # model, coModel, train_loader, unLoader, criterion, optims
273 | # )
274 | # trainer = CoTrainerAsySep(
275 | # model, coModel, train_loader, unLoader, criterion, optims
276 | # )
277 |
278 | evaluator = Evaluator(model, print_freq=args.print_freq)
279 | #evaluatorB = Evaluator(coModel, print_freq=args.print_freq)
280 | # Start training
281 | for epoch in range(args.epochs):
282 | trainer.train(epoch, remRate=0.2+(0.6/args.iteration)*(1+iter_n)) # to at most 80%
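# remRate appears to be the kept (small-loss) fraction: it grows linearly
# from 0.2 + 0.6/iteration in the first iteration to 0.8 in the last one.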
283 | # trainer.train(epoch, remRate=0.7+(0.3/args.iteration)*(1+iter_n))
284 | # test only
285 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
286 | #print('co-model:\n')
287 | #rank_score = evaluatorB.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
288 |
289 | # Evaluate
290 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
291 | save_checkpoint({
292 | 'state_dict': model.module.state_dict(),
293 | 'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
294 | }, True, fpath=osp.join(args.logs_dir, 'adapted.pth.tar'))
295 | return (rank_score.map, rank_score.market1501[0])
296 |
297 |
298 | if __name__ == '__main__':
299 | parser = argparse.ArgumentParser(description="Triplet loss classification")
300 | # data
301 | parser.add_argument('--src_dataset', type=str, default='dukemtmc',
302 | choices=datasets.names())
303 | parser.add_argument('--tgt_dataset', type=str, default='market1501',
304 | choices=datasets.names())
305 | parser.add_argument('--batch_size', type=int, default=64)
306 | parser.add_argument('--workers', type=int, default=4)
307 | parser.add_argument('--split', type=int, default=0)
308 | parser.add_argument('--noiseLam', type=float, default=0.5)
309 | parser.add_argument('--height', type=int,
310 | help="input height, default: 256 for resnet*, "
311 | "144 for inception")
312 | parser.add_argument('--width', type=int,
313 | help="input width, default: 128 for resnet*, "
314 | "56 for inception")
315 | parser.add_argument('--combine-trainval', action='store_true',
316 | help="train and val sets together for training, "
317 | "val set alone for validation")
318 | parser.add_argument('--num_instances', type=int, default=4,
319 | help="each minibatch consist of "
320 | "(batch_size // num_instances) identities, and "
321 | "each identity has num_instances instances, "
322 | "default: 4")
323 | # model
324 | parser.add_argument('--arch', type=str, default='resnet50',
325 | choices=models.names())
326 | # loss
327 | parser.add_argument('--margin', type=float, default=0.5,
328 | help="margin of the triplet loss, default: 0.5")
329 | parser.add_argument('--lambda_value', type=float, default=0.1,
330 | help="balancing parameter, default: 0.1")
331 | parser.add_argument('--rho', type=float, default=1.6e-3,
332 | help="rho percentage, default: 1.6e-3")
333 | # optimizer
334 | parser.add_argument('--lr', type=float, default=6e-5,
335 | help="learning rate of all parameters")
336 | # training configs
337 | parser.add_argument('--resume', type=str, metavar='PATH',
338 | default='')
339 | parser.add_argument('--evaluate', type=int, default=0,
340 | help="evaluation only")
341 | parser.add_argument('--seed', type=int, default=1)
342 | parser.add_argument('--print_freq', type=int, default=1)
343 | parser.add_argument('--iteration', type=int, default=10)
344 | parser.add_argument('--epochs', type=int, default=30)
345 | # metric learning
346 | parser.add_argument('--dist_metric', type=str, default='euclidean',
347 | choices=['euclidean', 'kissme'])
348 | # misc
349 | parser.add_argument('--data_dir', type=str, metavar='PATH',
350 | default='')
351 | parser.add_argument('--logs_dir', type=str, metavar='PATH',
352 | default='')
353 |
354 | args = parser.parse_args()
355 | mean_ap, rank1 = main(args)
356 | results_file = np.asarray([mean_ap, rank1])
357 | file_name = time.strftime("%H%M%S", time.localtime())
358 | file_name = osp.join(args.logs_dir, file_name)
359 | np.save(file_name, results_file)
360 |
--------------------------------------------------------------------------------
/selftrainingCT.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import print_function, absolute_import
5 | import argparse
6 | import time
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.dist_metric import DistanceMetric
18 | from reid.loss import TripletLoss
19 | from reid.trainers import CoTeaching
20 | from reid.evaluators import Evaluator, extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 |
27 | from sklearn.cluster import DBSCAN
28 | from reid.rerank import re_ranking
29 |
30 |
31 | def calScores(clusters, labels):
32 | """
33 | compute pair-wise precision, recall and F-score
34 | """
35 | from scipy.special import comb
36 | if len(clusters) == 0:
37 | return 0, 0, 0  # callers unpack three values (precision, recall, fScore)
38 | else:
39 | curCluster = []
40 | for curClus in clusters.values():
41 | curCluster.append(labels[curClus])
42 | TPandFP = sum([comb(len(val), 2) for val in curCluster])
43 | TP = 0
44 | for clusterVal in curCluster:
45 | for setMember in set(clusterVal):
46 | if sum(clusterVal == setMember) < 2: continue
47 | TP += comb(sum(clusterVal == setMember), 2)
48 | FP = TPandFP - TP
49 | # FN and TN
50 | TPandFN = sum([comb(labels.tolist().count(val), 2) for val in set(labels)])
51 | FN = TPandFN - TP
52 | # cal precision and recall
53 | precision, recall = TP / (TP + FP), TP / (TP + FN)
54 | fScore = 2 * precision * recall / (precision + recall)
55 | return precision, recall, fScore
56 |
57 |
58 | def get_data(name, data_dir, height, width, batch_size,
59 | workers):
60 | root = osp.join(data_dir, name)
61 |
62 | dataset = datasets.create(name, root, num_val=0.1)
63 |
64 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
65 | std=[0.229, 0.224, 0.225])
66 |
67 | # use all training and validation images in target dataset
68 | train_set = dataset.trainval
69 | num_classes = dataset.num_trainval_ids
70 |
71 | transformer = T.Compose([
72 | T.Resize((height, width)),
73 | T.ToTensor(),
74 | normalizer,
75 | ])
76 |
77 | extfeat_loader = DataLoader(
78 | Preprocessor(train_set, root=dataset.images_dir,
79 | transform=transformer),
80 | batch_size=batch_size, num_workers=workers,
81 | shuffle=False, pin_memory=True)
82 |
83 | test_loader = DataLoader(
84 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
85 | root=dataset.images_dir, transform=transformer),
86 | batch_size=batch_size, num_workers=workers,
87 | shuffle=False, pin_memory=True)
88 |
89 | return dataset, num_classes, extfeat_loader, test_loader
90 |
91 |
92 | def get_source_data(name, data_dir, height, width, batch_size,
93 | workers):
94 | root = osp.join(data_dir, name)
95 |
96 | dataset = datasets.create(name, root, num_val=0.1)
97 |
98 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
99 | std=[0.229, 0.224, 0.225])
100 |
101 | # use all training images on source dataset
102 | train_set = dataset.train
103 | num_classes = dataset.num_train_ids
104 |
105 | transformer = T.Compose([
106 | T.Resize((height, width)),
107 | T.ToTensor(),
108 | normalizer,
109 | ])
110 |
111 | extfeat_loader = DataLoader(
112 | Preprocessor(train_set, root=dataset.images_dir,
113 | transform=transformer),
114 | batch_size=batch_size, num_workers=workers,
115 | shuffle=False, pin_memory=True)
116 |
117 | return dataset, extfeat_loader
118 |
119 |
120 | def calDis(qFeature, gFeature): # 246s
121 | x, y = F.normalize(qFeature), F.normalize(gFeature)
122 | # x, y = qFeature, gFeature
123 | m, n = x.shape[0], y.shape[0]
124 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
125 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
126 | disMat.addmm_(1, -2, x, y.t())
127 | return disMat.clamp_(min=1e-5)
128 |
129 |
130 | def labelUnknown(knownFeat, allLab, unknownFeat):
131 | # allLab--label from known
132 | disMat = calDis(knownFeat, unknownFeat)
133 | labLoc = disMat.argmin(dim=0)
134 | return allLab[labLoc]
135 |
136 |
137 | def labelNoise(feature, labels):
138 | # features and labels with -1
139 | noiseFeat, pureFeat = feature[labels == -1, :], feature[labels != -1, :]
140 | pureLabs = labels[labels != -1] # no outliers
141 | unLab = labelUnknown(pureFeat, pureLabs, noiseFeat)
142 | labels[labels == -1] = unLab
143 | return labels.numpy()
144 |
145 |
146 | def getCenter(features, labels):
147 | allCenter = {}
148 | features = features[labels != -1, :]
149 | labels = labels[labels != -1]
150 | for pid in set(labels):
151 | allCenter[pid] = torch.from_numpy(features[labels == pid, :].mean(axis=0)).unsqueeze(0)
152 | return torch.cat(list(allCenter.values()))
153 |
154 |
155 | def main(args):
156 | np.random.seed(args.seed)
157 | torch.manual_seed(args.seed)
158 | cudnn.benchmark = True
159 |
160 | # Create data loaders
161 | assert args.num_instances > 1, "num_instances should be greater than 1"
162 | assert args.batch_size % args.num_instances == 0, \
163 | 'num_instances should divide batch_size'
164 | if args.height is None or args.width is None:
165 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
166 | (256, 128)
167 |
168 | # get source data
169 | src_dataset, src_extfeat_loader = \
170 | get_source_data(args.src_dataset, args.data_dir, args.height,
171 | args.width, args.batch_size, args.workers)
172 | # get target data
173 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \
174 | get_data(args.tgt_dataset, args.data_dir, args.height,
175 | args.width, args.batch_size, args.workers)
176 |
177 | # Create model
178 | # Hacking here to let the classifier be the number of source ids
179 | if args.src_dataset == 'dukemtmc':
180 | model = models.create(args.arch, num_classes=632, pretrained=False)
181 | coModel = models.create(args.arch, num_classes=632, pretrained=False)
182 | elif args.src_dataset == 'market1501':
183 | model = models.create(args.arch, num_classes=676, pretrained=False)
184 | coModel = models.create(args.arch, num_classes=676, pretrained=False)
185 | elif args.src_dataset == 'msmt17':
186 | model = models.create(args.arch, num_classes=1041, pretrained=False)
187 | coModel = models.create(args.arch, num_classes=1041, pretrained=False)
188 | elif args.src_dataset == 'cuhk03':
189 | model = models.create(args.arch, num_classes=1230, pretrained=False)
190 | coModel = models.create(args.arch, num_classes=1230, pretrained=False)
191 | else:
192 | raise RuntimeError('Please specify the number of classes (ids) of the network.')
193 |
194 | # Load from checkpoint
195 | if args.resume:
196 | print('Resuming checkpoints from finetuned model on another dataset...\n')
197 | checkpoint = load_checkpoint(args.resume)
198 | model.load_state_dict(checkpoint['state_dict'], strict=False)
199 | coModel.load_state_dict(checkpoint['state_dict'], strict=False)
200 | else:
201 | raise RuntimeWarning('Not using a pre-trained model.')
202 | model = nn.DataParallel(model).cuda()
203 | coModel = nn.DataParallel(coModel).cuda()
204 |
205 | # Criterion
206 | criterion = [
207 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
208 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda()
209 | ]
210 | optimizer = torch.optim.Adam(
211 | model.parameters(), lr=args.lr
212 | )
213 | coOptimizer = torch.optim.Adam(
214 | coModel.parameters(), lr=args.lr
215 | )
216 |
217 | optims = [optimizer, coOptimizer]
218 |
219 | # training stage transformer on input images
220 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
221 | train_transformer = T.Compose([
222 | T.Resize((args.height, args.width)),
223 | T.RandomHorizontalFlip(),
224 | T.ToTensor(), normalizer,
225 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3)
226 | ])
227 |
228 | # # Start training
229 | for iter_n in range(args.iteration):
230 | if args.lambda_value == 0:
231 | source_features = 0
232 | else:
233 | # get source datas' feature
234 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq)
235 | # synchronization feature order with src_dataset.train
236 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0)
237 |
238 | # extract training images' features
239 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n + 1))
240 | target_features, _ = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq)
241 | # synchronization feature order with dataset.train
242 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0)
243 | # calculate distance and rerank result
244 | print('Calculating feature distances...')
245 | target_features = target_features.numpy()
246 | rerank_dist = re_ranking(source_features, target_features, lambda_value=args.lambda_value)
247 | if iter_n == 0:
248 | # DBSCAN cluster
249 | tri_mat = np.triu(rerank_dist, 1) # tri_mat.dim=2
250 | tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1
251 | tri_mat = np.sort(tri_mat, axis=None)
252 | top_num = np.round(args.rho * tri_mat.size).astype(int)
253 | eps = tri_mat[:top_num].mean()
254 | print('eps in cluster: {:.3f}'.format(eps))
255 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8)
256 | # select & cluster images as training set of this epochs
257 | print('Clustering and labeling...')
258 | labels = cluster.fit_predict(rerank_dist)
259 | num_ids = len(set(labels)) - 1
260 | print('Iteration {} has {} training ids'.format(iter_n + 1, num_ids))
261 | # generate new dataset
262 | new_dataset = []
263 | # assign label for target ones
264 | newLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels))
265 | # unknownFeats = target_features[labels==-1,:]
266 | counter = 0
267 | from collections import defaultdict
268 | realIDs, fakeIDs = defaultdict(list), []
269 | for (fname, realID, cam), label in zip(tgt_dataset.trainval, newLab):
270 | # no need to change the code in trainer.py's _parsing_input function or the sampler function after adding 0
271 | new_dataset.append((fname, label, cam))
272 | realIDs[realID].append(counter)
273 | fakeIDs.append(label)
274 | counter += 1
275 | precision, recall, fscore = calScores(realIDs, np.asarray(fakeIDs))
276 | print('Iteration {} has {} training images'.format(iter_n + 1, len(new_dataset)))
277 | print(f'precision:{precision * 100}, recall:{100 * recall}, fscore:{100 * fscore}')
278 | train_loader = DataLoader(
279 | Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
280 | batch_size=args.batch_size, num_workers=4,
281 | sampler=RandomIdentitySampler(new_dataset, args.num_instances),
282 | pin_memory=True, drop_last=True
283 | )
284 | trainer = CoTeaching(
285 | model, coModel, train_loader, criterion, optims
286 | )
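# Co-teaching (Han et al., NeurIPS 2018): the two peer networks each select
# their small-loss (i.e. likely clean) samples within a mini-batch and feed
# them to the other network, so label noise is filtered symmetrically.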
287 |
288 | # Start training
289 | for epoch in range(args.epochs):
290 | trainer.train(epoch, remRate=0.2 + (0.8 / args.iteration) * (1 + iter_n)) # keep rate reaches 100% by the last iteration
291 | # test only
292 | evaluator = Evaluator(model, print_freq=args.print_freq)
293 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
294 |
295 | # Evaluate
296 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
297 | save_checkpoint({
298 | 'state_dict': model.module.state_dict(),
299 | 'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
300 | }, True, fpath=osp.join(args.logs_dir, 'adapted.pth.tar'))
301 | return rank_score.map, rank_score.market1501[0]
302 |
303 |
304 | if __name__ == '__main__':
305 | parser = argparse.ArgumentParser(description="Triplet loss classification")
306 | # data
307 | parser.add_argument('--src_dataset', type=str, default='dukemtmc',
308 | choices=datasets.names())
309 | parser.add_argument('--tgt_dataset', type=str, default='market1501',
310 | choices=datasets.names())
311 | parser.add_argument('--batch_size', type=int, default=64)
312 | parser.add_argument('--workers', type=int, default=4)
313 | parser.add_argument('--split', type=int, default=0)
314 | parser.add_argument('--noiseLam', type=float, default=0.5)
315 | parser.add_argument('--height', type=int,
316 | help="input height, default: 256 for resnet*, "
317 | "144 for inception")
318 | parser.add_argument('--width', type=int,
319 | help="input width, default: 128 for resnet*, "
320 | "56 for inception")
321 | parser.add_argument('--combine-trainval', action='store_true',
322 | help="train and val sets together for training, "
323 | "val set alone for validation")
324 | parser.add_argument('--num_instances', type=int, default=4,
325 | help="each minibatch consist of "
326 | "(batch_size // num_instances) identities, and "
327 | "each identity has num_instances instances, "
328 | "default: 4")
329 | # model
330 | parser.add_argument('--arch', type=str, default='resnet50',
331 | choices=models.names())
332 | # loss
333 | parser.add_argument('--margin', type=float, default=0.5,
334 | help="margin of the triplet loss, default: 0.5")
335 | parser.add_argument('--lambda_value', type=float, default=0.1,
336 | help="balancing parameter, default: 0.1")
337 | parser.add_argument('--rho', type=float, default=1.6e-3,
338 | help="rho percentage, default: 1.6e-3")
339 | # optimizer
340 | parser.add_argument('--lr', type=float, default=6e-5,
341 | help="learning rate of all parameters")
342 | # training configs
343 | parser.add_argument('--resume', type=str, metavar='PATH',
344 | default='')
345 | parser.add_argument('--evaluate', type=int, default=0,
346 | help="evaluation only")
347 | parser.add_argument('--seed', type=int, default=1)
348 | parser.add_argument('--print_freq', type=int, default=1)
349 | parser.add_argument('--iteration', type=int, default=10)
350 | parser.add_argument('--epochs', type=int, default=30)
351 | # metric learning
352 | parser.add_argument('--dist_metric', type=str, default='euclidean',
353 | choices=['euclidean', 'kissme'])
354 | # misc
355 | parser.add_argument('--data_dir', type=str, metavar='PATH',
356 | default='')
357 | parser.add_argument('--logs_dir', type=str, metavar='PATH',
358 | default='')
359 |
360 | args = parser.parse_args()
361 | mean_ap, rank1 = main(args)
362 |
--------------------------------------------------------------------------------
/selftrainingACT.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import print_function, absolute_import
5 | import argparse
6 | import time
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.loss import TripletLoss
18 | from reid.trainers import CoTrainerAsy
19 | from reid.evaluators import Evaluator, extract_features
20 | from reid.utils.data import transforms as T
21 | import torch.nn.functional as F
22 | from reid.utils.data.preprocessor import Preprocessor
23 | from reid.utils.data.sampler import RandomIdentitySampler
24 | from reid.utils.serialization import load_checkpoint, save_checkpoint
25 |
26 | from sklearn.cluster import DBSCAN
27 | from reid.rerank import re_ranking
28 |
29 |
30 | def calScores(clusters, labels):
31 | """
32 | compute pair-wise precision, recall and F-score
33 | """
34 | from scipy.special import comb
35 | if len(clusters) == 0:
36 | return 0, 0, 0  # callers unpack three values (precision, recall, fScore)
37 | else:
38 | curCluster = []
39 | for curClus in clusters.values():
40 | curCluster.append(labels[curClus])
41 | TPandFP = sum([comb(len(val), 2) for val in curCluster])
42 | TP = 0
43 | for clusterVal in curCluster:
44 | for setMember in set(clusterVal):
45 | if sum(clusterVal == setMember) < 2: continue
46 | TP += comb(sum(clusterVal == setMember), 2)
47 | FP = TPandFP - TP
48 | # FN and TN
49 | TPandFN = sum([comb(labels.tolist().count(val), 2) for val in set(labels)])
50 | FN = TPandFN - TP
51 | # cal precision and recall
52 | precision, recall = TP / (TP + FP), TP / (TP + FN)
53 | fScore = 2 * precision * recall / (precision + recall)
54 | return precision, recall, fScore
55 |
56 |
57 | def get_data(name, data_dir, height, width, batch_size,
58 | workers):
59 | root = osp.join(data_dir, name)
60 |
61 | dataset = datasets.create(name, root, num_val=0.1)
62 |
63 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
64 | std=[0.229, 0.224, 0.225])
65 |
66 | # use all training and validation images in target dataset
67 | train_set = dataset.trainval
68 | num_classes = dataset.num_trainval_ids
69 |
70 | transformer = T.Compose([
71 | T.Resize((height, width)),
72 | T.ToTensor(),
73 | normalizer,
74 | ])
75 |
76 | extfeat_loader = DataLoader(
77 | Preprocessor(train_set, root=dataset.images_dir,
78 | transform=transformer),
79 | batch_size=batch_size, num_workers=workers,
80 | shuffle=False, pin_memory=True)
81 |
82 | test_loader = DataLoader(
83 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
84 | root=dataset.images_dir, transform=transformer),
85 | batch_size=batch_size // 2, num_workers=workers,
86 | shuffle=False, pin_memory=True)
87 |
88 | return dataset, num_classes, extfeat_loader, test_loader
89 |
90 |
91 | def saveAll(nameList, rootDir, tarDir):
92 | import os
93 | import shutil
94 | if os.path.exists(tarDir):
95 | shutil.rmtree(tarDir)
96 | os.makedirs(tarDir)
97 | for name in nameList:
98 | shutil.copyfile(os.path.join(rootDir, name), os.path.join(tarDir, name))
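# Helper for dumping a list of image files into tarDir; only used by the
# commented-out noise-inspection lines near the unLoader below.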
99 |
100 |
101 | def get_source_data(name, data_dir, height, width, batch_size, workers):
102 | root = osp.join(data_dir, name)
103 |
104 | dataset = datasets.create(name, root, num_val=0.1)
105 |
106 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
107 | std=[0.229, 0.224, 0.225])
108 |
109 | # use all training images on source dataset
110 | train_set = dataset.train
111 | num_classes = dataset.num_train_ids
112 |
113 | transformer = T.Compose([
114 | T.Resize((height, width)),
115 | T.ToTensor(),
116 | normalizer,
117 | ])
118 |
119 | extfeat_loader = DataLoader(
120 | Preprocessor(train_set, root=dataset.images_dir,
121 | transform=transformer),
122 | batch_size=batch_size, num_workers=workers,
123 | shuffle=False, pin_memory=True)
124 |
125 | return dataset, extfeat_loader
126 |
127 |
128 | def calDis(qFeature, gFeature): # 246s
129 | x, y = F.normalize(qFeature), F.normalize(gFeature)
130 | # x, y = qFeature, gFeature
131 | m, n = x.shape[0], y.shape[0]
132 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
133 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
134 | disMat.addmm_(1, -2, x, y.t())
135 | return disMat.clamp_(min=1e-5)
136 |
137 |
138 | def labelUnknown(knownFeat, allLab, unknownFeat):
139 | disMat = calDis(knownFeat, unknownFeat)
140 | labLoc = disMat.argmin(dim=0)
141 | return allLab[labLoc]
142 |
143 |
144 | def labelNoise(feature, labels):
145 | # features and labels with -1
146 | noiseFeat, pureFeat = feature[labels == -1, :], feature[labels != -1, :]
147 | labels = labels[labels != -1]
148 | unLab = labelUnknown(pureFeat, labels, noiseFeat)
149 | return unLab.numpy()
150 |
151 |
152 | def main(args):
153 | np.random.seed(args.seed)
154 | torch.manual_seed(args.seed)
155 | cudnn.benchmark = True
156 |
157 | # Create data loaders
158 | assert args.num_instances > 1, "num_instances should be greater than 1"
159 | assert args.batch_size % args.num_instances == 0, \
160 | 'num_instances should divide batch_size'
161 | if args.height is None or args.width is None:
162 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
163 | (256, 128)
164 |
165 | # get source data
166 | src_dataset, src_extfeat_loader = \
167 | get_source_data(args.src_dataset, args.data_dir, args.height,
168 | args.width, args.batch_size, args.workers)
169 | # get target data
170 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \
171 | get_data(args.tgt_dataset, args.data_dir, args.height,
172 | args.width, args.batch_size, args.workers)
173 |
174 | # Create model
175 | # Hacking here to let the classifier be the number of source ids
176 | if args.src_dataset == 'dukemtmc':
177 | model = models.create(args.arch, num_classes=632, pretrained=False)
178 | coModel = models.create(args.arch, num_classes=632, pretrained=False)
179 | elif args.src_dataset == 'market1501':
180 | model = models.create(args.arch, num_classes=676, pretrained=False)
181 | coModel = models.create(args.arch, num_classes=676, pretrained=False)
182 | elif args.src_dataset == 'msmt17':
183 | model = models.create(args.arch, num_classes=1041, pretrained=False)
184 | coModel = models.create(args.arch, num_classes=1041, pretrained=False)
185 | elif args.src_dataset == 'cuhk03':
186 | model = models.create(args.arch, num_classes=1230, pretrained=False)
187 | coModel = models.create(args.arch, num_classes=1230, pretrained=False)
188 | else:
189 | raise RuntimeError('Please specify the number of classes (ids) of the network.')
190 |
191 | # Load from checkpoint
192 | if args.resume:
193 | print('Resuming checkpoints from finetuned model on another dataset...\n')
194 | checkpoint = load_checkpoint(args.resume)
195 | model.load_state_dict(checkpoint['state_dict'], strict=False)
196 | coModel.load_state_dict(checkpoint['state_dict'], strict=False)
197 | else:
198 | raise RuntimeWarning('Not using a pre-trained model.')
199 | model = nn.DataParallel(model).cuda()
200 | coModel = nn.DataParallel(coModel).cuda()
201 |
202 | evaluator = Evaluator(model, print_freq=args.print_freq)
203 | evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
204 | # if args.evaluate: return
205 |
206 | # Criterion
207 | criterion = [
208 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
209 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
210 | ]
211 |
212 | # Optimizer
213 | optimizer = torch.optim.Adam(
214 | model.parameters(), lr=args.lr
215 | )
216 | coOptimizer = torch.optim.Adam(
217 | coModel.parameters(), lr=args.lr
218 | )
219 |
220 | optims = [optimizer, coOptimizer]
221 |
222 | # training stage transformer on input images
223 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
224 | train_transformer = T.Compose([
225 | T.Resize((args.height, args.width)),
226 | T.RandomHorizontalFlip(),
227 | T.ToTensor(), normalizer,
228 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3)
229 | ])
230 |
231 | # # Start training
232 | for iter_n in range(args.iteration):
233 | if args.lambda_value == 0:
234 | source_features = 0
235 | else:
236 | # get source datas' feature
237 | source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq)
238 | # synchronization feature order with src_dataset.train
239 | source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0)
240 |
241 | # extract training images' features
242 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n + 1))
243 | target_features, _ = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq)
244 | # synchronization feature order with dataset.train
245 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0)
246 | # calculate distance and rerank result
247 | print('Calculating feature distances...')
248 | target_features = target_features.numpy()
249 | rerank_dist = re_ranking(source_features, target_features, lambda_value=args.lambda_value)
250 | if iter_n == 0:
251 | # DBSCAN cluster
252 | tri_mat = np.triu(rerank_dist, 1) # tri_mat.dim=2
253 | tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1
254 | tri_mat = np.sort(tri_mat, axis=None)
255 | top_num = np.round(args.rho * tri_mat.size).astype(int)
256 | eps = tri_mat[:top_num].mean()
257 | print('eps in cluster: {:.3f}'.format(eps))
258 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8)
259 |
260 | # select & cluster images as training set of this epochs
261 | print('Clustering and labeling...')
262 | labels = cluster.fit_predict(rerank_dist)
263 | num_ids = len(set(labels)) - 1
264 | print('Iteration {} has {} training ids'.format(iter_n + 1, num_ids))
265 | # generate new dataset
266 | new_dataset, unknown_dataset = [], []
267 | # assign label for target ones
268 | unknownLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels))
269 | # unknownFeats = target_features[labels==-1,:]
270 | unCounter, index = 0, 0
271 | from collections import defaultdict
272 | realIDs, fakeIDs = defaultdict(list), []
273 | for (fname, realPID, cam), label in zip(tgt_dataset.trainval, labels):
274 | if label == -1:
275 | unknown_dataset.append((fname, int(unknownLab[unCounter]), cam)) # unknown data
276 | fakeIDs.append(int(unknownLab[unCounter]))
277 | realIDs[realPID].append(index)
278 | unCounter += 1
279 | index += 1
280 | continue
281 | # no need to change the code in trainer.py's _parsing_input function or the sampler function after adding 0
282 | new_dataset.append((fname, label, cam))
283 | fakeIDs.append(label)
284 | realIDs[realPID].append(index)
285 | index += 1
286 | print('Iteration {} has {} training images'.format(iter_n + 1, len(new_dataset)))
287 | precision, recall, fscore = calScores(realIDs, np.asarray(fakeIDs)) # fakeIDs does not contain -1
288 | print('precision: {}, recall: {}, fscore: {}'.format(100 * precision, 100 * recall, 100 * fscore))
289 |
290 | train_loader = DataLoader(
291 | Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
292 | batch_size=args.batch_size, num_workers=4,
293 | sampler=RandomIdentitySampler(new_dataset, args.num_instances),
294 | pin_memory=True, drop_last=True
295 | )
296 | # hard samples
297 | # noiseImgs = [name[1] for name in unknown_dataset]
298 | # saveAll(noiseImgs, tgt_dataset.images_dir, 'noiseImg')
299 | # import ipdb; ipdb.set_trace()
300 | unLoader = DataLoader(
301 | Preprocessor(unknown_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
302 | batch_size=args.batch_size, num_workers=4,
303 | sampler=RandomIdentitySampler(unknown_dataset, args.num_instances),
304 | pin_memory=True, drop_last=True
305 | )
306 | # train model with new generated dataset
307 | trainer = CoTrainerAsy(
308 | model, coModel, train_loader, unLoader, criterion, optims
309 | )
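# Asymmetric co-teaching (the ACT idea): the main model trains on the
# clustered inliers from train_loader while the co-model mines small-loss
# samples from the DBSCAN outliers in unLoader, feeding the ones it trusts
# back to the main model; the exact sample exchange lives in reid/trainers.py.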
310 |
311 | # Start training
312 | for epoch in range(args.epochs):
313 | trainer.train(epoch, remRate=0.2 + (0.8 / args.iteration) * (1 + iter_n))
314 |
315 | # test only
316 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
317 | # print('co-model:\n')
318 | # rank_score = evaluatorB.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
319 |
320 | # Evaluate
321 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
322 | save_checkpoint({
323 | 'state_dict': model.module.state_dict(),
324 | 'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
325 | }, True, fpath=osp.join(args.logs_dir, 'asyCo.pth'))
326 | return rank_score.map, rank_score.market1501[0]
327 |
328 |
329 | if __name__ == '__main__':
330 |     parser = argparse.ArgumentParser(description="Triplet loss classification")
331 |     # data
332 |     parser.add_argument('--src_dataset', type=str, default='dukemtmc',
333 |                         choices=datasets.names())
334 |     parser.add_argument('--tgt_dataset', type=str, default='market1501',
335 |                         choices=datasets.names())
336 |     parser.add_argument('--batch_size', type=int, default=64)
337 |     parser.add_argument('--workers', type=int, default=4)
338 |     parser.add_argument('--split', type=int, default=0)
339 |     parser.add_argument('--noiseLam', type=float, default=0.5)
340 |     parser.add_argument('--height', type=int,
341 |                         help="input height, default: 256 for resnet*, "
342 |                              "144 for inception")
343 |     parser.add_argument('--width', type=int,
344 |                         help="input width, default: 128 for resnet*, "
345 |                              "56 for inception")
346 |     parser.add_argument('--combine-trainval', action='store_true',
347 |                         help="train and val sets together for training, "
348 |                              "val set alone for validation")
349 |     parser.add_argument('--num_instances', type=int, default=4,
350 |                         help="each minibatch consists of "
351 |                              "(batch_size // num_instances) identities, and "
352 |                              "each identity has num_instances instances, "
353 |                              "default: 4")
354 |     # model
355 |     parser.add_argument('--arch', type=str, default='resnet50',
356 |                         choices=models.names())
357 |     # loss
358 |     parser.add_argument('--margin', type=float, default=0.5,
359 |                         help="margin of the triplet loss, default: 0.5")
360 |     parser.add_argument('--lambda_value', type=float, default=0.1,
361 |                         help="balancing parameter, default: 0.1")
362 |     parser.add_argument('--rho', type=float, default=1.6e-3,
363 |                         help="rho percentage, default: 1.6e-3")
364 |     # optimizer
365 |     parser.add_argument('--lr', type=float, default=6e-5,
366 |                         help="learning rate of all parameters")
367 |     # training configs
368 |     parser.add_argument('--resume', type=str, metavar='PATH',
369 |                         default='')
370 |     parser.add_argument('--evaluate', type=int, default=0,
371 |                         help="evaluation only")
372 |     parser.add_argument('--seed', type=int, default=1)
373 |     parser.add_argument('--print_freq', type=int, default=1)
374 |     parser.add_argument('--iteration', type=int, default=10)
375 |     parser.add_argument('--epochs', type=int, default=30)
376 |     # metric learning
377 |     parser.add_argument('--dist_metric', type=str, default='euclidean',
378 |                         choices=['euclidean', 'kissme'])
379 |     # misc
380 |     parser.add_argument('--data_dir', type=str, metavar='PATH',
381 |                         default='')
382 |     parser.add_argument('--logs_dir', type=str, metavar='PATH',
383 |                         default='')
384 | 
385 |     args = parser.parse_args()
386 |     mean_ap, rank1 = main(args)
387 |
--------------------------------------------------------------------------------
/selftrainingRCT.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import print_function, absolute_import
5 | import argparse
6 | import time
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import init
13 | from torch.backends import cudnn
14 | from torch.utils.data import DataLoader
15 | from reid import datasets
16 | from reid import models
17 | from reid.dist_metric import DistanceMetric
18 | from reid.loss import TripletLoss
19 | from reid.trainers import RCoTeaching
20 | from reid.evaluators import Evaluator, extract_features
21 | from reid.utils.data import transforms as T
22 | import torch.nn.functional as F
23 | from reid.utils.data.preprocessor import Preprocessor
24 | from reid.utils.data.sampler import RandomIdentitySampler
25 | from reid.utils.serialization import load_checkpoint, save_checkpoint
26 |
27 | from sklearn.cluster import DBSCAN
28 | from reid.rerank import re_ranking
29 |
30 |
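 | # Editorial note on calScores below: `clusters` maps each real identity to the sample
 | # indices it received, and `labels` holds the assigned pseudo-IDs. TP counts sample
 | # pairs sharing both a real ID and a pseudo-ID; e.g. pseudo-labels [a, a, b] under one
 | # real ID give comb(3, 2) = 3 candidate pairs, of which comb(2, 2) = 1 is a TP.
 | # Caveat: the quantity divided by same-real-ID pairs is conventionally called pairwise
 | # recall rather than precision (and vice versa), so the two printed numbers may be
 | # swapped relative to the usual definitions.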
31 | def calScores(clusters, labels):
32 | """
33 |     Compute pairwise precision, recall and F-score of a clustering against the ground-truth identity labels.
34 | """
35 | from scipy.special import comb
36 | if len(clusters) == 0:
37 |         return 0, 0, 0
38 | else:
39 | curCluster = []
40 | for curClus in clusters.values():
41 | curCluster.append(labels[curClus])
42 | TPandFP = sum([comb(len(val), 2) for val in curCluster])
43 | TP = 0
44 | for clusterVal in curCluster:
45 | for setMember in set(clusterVal):
46 | if sum(clusterVal == setMember) < 2: continue
47 | TP += comb(sum(clusterVal == setMember), 2)
48 | FP = TPandFP - TP
49 | # FN and TN
50 | TPandFN = sum([comb(labels.tolist().count(val), 2) for val in set(labels)])
51 | FN = TPandFN - TP
52 | # cal precision and recall
53 | precision, recall = TP / (TP + FP), TP / (TP + FN)
54 | fScore = 2 * precision * recall / (precision + recall)
55 | return precision, recall, fScore
56 |
57 |
58 | def get_data(name, data_dir, height, width, batch_size,
59 | workers):
60 | root = osp.join(data_dir, name)
61 |
62 | dataset = datasets.create(name, root, num_val=0.1)
63 |
64 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
65 | std=[0.229, 0.224, 0.225])
66 |
67 | # use all training and validation images in target dataset
68 | train_set = dataset.trainval
69 | num_classes = dataset.num_trainval_ids
70 |
71 | transformer = T.Compose([
72 | T.Resize((height, width)),
73 | T.ToTensor(),
74 | normalizer,
75 | ])
76 |
77 | extfeat_loader = DataLoader(
78 | Preprocessor(train_set, root=dataset.images_dir,
79 | transform=transformer),
80 | batch_size=batch_size, num_workers=workers,
81 | shuffle=False, pin_memory=True)
82 |
83 | test_loader = DataLoader(
84 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
85 | root=dataset.images_dir, transform=transformer),
86 | batch_size=batch_size // 2, num_workers=workers,
87 | shuffle=False, pin_memory=True)
88 |
89 | return dataset, num_classes, extfeat_loader, test_loader
90 |
91 |
92 | def saveAll(nameList, rootDir, tarDir):
93 | import os
94 | import shutil
95 | if os.path.exists(tarDir):
96 | shutil.rmtree(tarDir)
97 | os.makedirs(tarDir)
98 | for name in nameList:
99 | shutil.copyfile(os.path.join(rootDir, name), os.path.join(tarDir, name))
100 |
101 |
102 | def get_source_data(name, data_dir, height, width, batch_size, workers):
103 | root = osp.join(data_dir, name)
104 |
105 | dataset = datasets.create(name, root, num_val=0.1)
106 |
107 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
108 | std=[0.229, 0.224, 0.225])
109 |
110 |     # use all training images of the source dataset
111 | train_set = dataset.train
112 | num_classes = dataset.num_train_ids
113 |
114 | transformer = T.Compose([
115 | T.Resize((height, width)),
116 | T.ToTensor(),
117 | normalizer,
118 | ])
119 |
120 | extfeat_loader = DataLoader(
121 | Preprocessor(train_set, root=dataset.images_dir,
122 | transform=transformer),
123 | batch_size=batch_size, num_workers=workers,
124 | shuffle=False, pin_memory=True)
125 |
126 | return dataset, extfeat_loader
127 |
128 |
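 | # Editorial note on calDis below: features are L2-normalized, then squared Euclidean
 | # distances are built via the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.y in a
 | # single addmm_ call; clamping at 1e-5 avoids zeros/negatives from floating-point
 | # round-off. For unit vectors this is equivalent to 2 - 2*cos(x, y).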
129 | def calDis(qFeature, gFeature): # 246s
130 | x, y = F.normalize(qFeature), F.normalize(gFeature)
131 | # x, y = qFeature, gFeature
132 | m, n = x.shape[0], y.shape[0]
133 | disMat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
134 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
135 |     disMat.addmm_(x, y.t(), beta=1, alpha=-2)  # avoids the deprecated positional beta/alpha form
136 | return disMat.clamp_(min=1e-5)
137 |
138 |
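 | # Editorial note: the two helpers below pseudo-label DBSCAN outliers. labelUnknown
 | # computes distances from every clustered ("known") feature to every outlier and copies
 | # the label of the nearest clustered sample (argmin over dim=0, i.e. over known samples);
 | # labelNoise splits the features by the -1 marker and delegates to labelUnknown.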
139 | def labelUnknown(knownFeat, allLab, unknownFeat):
140 | disMat = calDis(knownFeat, unknownFeat)
141 | labLoc = disMat.argmin(dim=0)
142 | return allLab[labLoc]
143 |
144 |
145 | def labelNoise(feature, labels):
146 |     # split features into DBSCAN outliers (label -1) and clustered (pure) samples
147 | noiseFeat, pureFeat = feature[labels == -1, :], feature[labels != -1, :]
148 | labels = labels[labels != -1]
149 | unLab = labelUnknown(pureFeat, labels, noiseFeat)
150 | return unLab.numpy()
151 |
152 |
153 | def main(args):
154 | np.random.seed(args.seed)
155 | torch.manual_seed(args.seed)
156 | cudnn.benchmark = True
157 |
158 | # Create data loaders
159 | assert args.num_instances > 1, "num_instances should be greater than 1"
160 | assert args.batch_size % args.num_instances == 0, \
161 | 'num_instances should divide batch_size'
162 | if args.height is None or args.width is None:
163 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
164 | (256, 128)
165 |
166 | # get source data
167 | src_dataset, src_extfeat_loader = \
168 | get_source_data(args.src_dataset, args.data_dir, args.height,
169 | args.width, args.batch_size, args.workers)
170 | # get target data
171 | tgt_dataset, num_classes, tgt_extfeat_loader, test_loader = \
172 | get_data(args.tgt_dataset, args.data_dir, args.height,
173 | args.width, args.batch_size, args.workers)
174 |
175 | # Create model
176 | # Hacking here to let the classifier be the number of source ids
177 | if args.src_dataset == 'dukemtmc':
178 | model = models.create(args.arch, num_classes=632, pretrained=False)
179 | coModel = models.create(args.arch, num_classes=632, pretrained=False)
180 | elif args.src_dataset == 'market1501':
181 | model = models.create(args.arch, num_classes=676, pretrained=False)
182 | coModel = models.create(args.arch, num_classes=676, pretrained=False)
183 | elif args.src_dataset == 'msmt17':
184 | model = models.create(args.arch, num_classes=1041, pretrained=False)
185 | coModel = models.create(args.arch, num_classes=1041, pretrained=False)
186 | elif args.src_dataset == 'cuhk03':
187 | model = models.create(args.arch, num_classes=1230, pretrained=False)
188 | coModel = models.create(args.arch, num_classes=1230, pretrained=False)
189 | else:
190 | raise RuntimeError('Please specify the number of classes (ids) of the network.')
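    |     # Editorial note: these hard-coded num_classes values must match the classifier
    |     # head of the source-pretrained checkpoint loaded via --resume; they appear to be
    |     # the identity counts of each source training split, not tunable hyper-parameters.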
191 |
192 | # Load from checkpoint
193 | if args.resume:
194 |         print('Resuming from a checkpoint fine-tuned on another dataset...\n')
195 | checkpoint = load_checkpoint(args.resume)
196 | model.load_state_dict(checkpoint['state_dict'], strict=False)
197 | coModel.load_state_dict(checkpoint['state_dict'], strict=False)
198 | else:
199 |         raise RuntimeError('A source-pretrained model is required; please specify --resume.')
200 | model = nn.DataParallel(model).cuda()
201 | coModel = nn.DataParallel(coModel).cuda()
202 |
203 | evaluator = Evaluator(model, print_freq=args.print_freq)
204 | # evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
205 | # if args.evaluate: return
206 |
207 | # Criterion
208 | criterion = [
209 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
210 | TripletLoss(args.margin, args.num_instances, isAvg=False, use_semi=False).cuda(),
211 | ]
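    |     # Editorial note: two identical triplet losses, one per peer network, since the
    |     # co-teaching trainer updates model and coModel with separate criteria and optimizers.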
212 |
213 | # Optimizer
214 | optimizer = torch.optim.Adam(
215 | model.parameters(), lr=args.lr
216 | )
217 | coOptimizer = torch.optim.Adam(
218 | coModel.parameters(), lr=args.lr
219 | )
220 |
221 | optims = [optimizer, coOptimizer]
222 |
223 |     # training-stage transform applied to input images
224 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
225 | train_transformer = T.Compose([
226 | T.Resize((args.height, args.width)),
227 | T.RandomHorizontalFlip(),
228 | T.ToTensor(), normalizer,
229 | T.RandomErasing(probability=0.5, sh=0.2, r1=0.3)
230 | ])
231 |
232 |     # Start self-training iterations
233 | for iter_n in range(args.iteration):
234 | if args.lambda_value == 0:
235 | source_features = 0
236 | else:
237 |             # extract source-domain features
238 |             source_features, _ = extract_features(model, src_extfeat_loader, print_freq=args.print_freq)
239 |             # synchronize feature order with src_dataset.train
240 |             source_features = torch.cat([source_features[f].unsqueeze(0) for f, _, _ in src_dataset.train], 0)
241 |
242 | # extract training images' features
243 | print('Iteration {}: Extracting Target Dataset Features...'.format(iter_n + 1))
244 | target_features, _ = extract_features(model, tgt_extfeat_loader, print_freq=args.print_freq)
245 |         # synchronize feature order with tgt_dataset.trainval
246 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in tgt_dataset.trainval], 0)
247 | # calculate distance and rerank result
248 | print('Calculating feature distances...')
249 | target_features = target_features.numpy()
250 | rerank_dist = re_ranking(source_features, target_features, lambda_value=args.lambda_value)
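    |         # Editorial note: re_ranking appears to build a k-reciprocal re-ranked distance
    |         # matrix over the target samples, with lambda_value balancing in the source
    |         # features (when lambda_value == 0 the source branch is skipped entirely, above).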
251 | if iter_n == 0:
252 | # DBSCAN cluster
253 | tri_mat = np.triu(rerank_dist, 1) # tri_mat.dim=2
254 | tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1
255 | tri_mat = np.sort(tri_mat, axis=None)
256 | top_num = np.round(args.rho * tri_mat.size).astype(int)
257 | eps = tri_mat[:top_num].mean()
258 | print('eps in cluster: {:.3f}'.format(eps))
259 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8)
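    |             # Editorial note: eps and the DBSCAN instance are estimated only on the
    |             # first iteration and reused afterwards, so the density radius stays fixed
    |             # while the features (and hence rerank_dist) evolve across iterations.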
260 |
261 |         # select & cluster images for this iteration's training set
262 | print('Clustering and labeling...')
263 | labels = cluster.fit_predict(rerank_dist)
264 | num_ids = len(set(labels)) - 1
265 |         print('Iteration {} has {} training ids'.format(iter_n + 1, num_ids))
266 | # generate new dataset
267 | new_dataset, unknown_dataset = [], []
268 |         # assign pseudo-labels to DBSCAN outliers
269 | unknownLab = labelNoise(torch.from_numpy(target_features), torch.from_numpy(labels))
270 | # unknownFeats = target_features[labels==-1,:]
271 | unCounter, index = 0, 0
272 | from collections import defaultdict
273 | realIDs, fakeIDs = defaultdict(list), []
274 | for (fname, realPID, cam), label in zip(tgt_dataset.trainval, labels):
275 | if label == -1:
276 | unknown_dataset.append((fname, int(unknownLab[unCounter]), cam)) # unknown data
277 | fakeIDs.append(int(unknownLab[unCounter]))
278 | realIDs[realPID].append(index)
279 | unCounter += 1
280 | index += 1
281 | continue
282 |             # no need to modify trainer.py's _parsing_input or the sampler after adding these samples
283 | new_dataset.append((fname, label, cam))
284 | fakeIDs.append(label)
285 | realIDs[realPID].append(index)
286 | index += 1
287 |         print('Iteration {} has {} training images'.format(iter_n + 1, len(new_dataset)))
288 | precision, recall, fscore = calScores(realIDs, np.asarray(fakeIDs)) # fakeIDs does not contain -1
289 |         print('precision: {:.2f}, recall: {:.2f}, fscore: {:.2f}'.format(100 * precision, 100 * recall, 100 * fscore))
290 |
291 | train_loader = DataLoader(
292 | Preprocessor(new_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
293 | batch_size=args.batch_size, num_workers=4,
294 | sampler=RandomIdentitySampler(new_dataset, args.num_instances),
295 | pin_memory=True, drop_last=True
296 | )
297 | # hard samples
298 | # noiseImgs = [name[1] for name in unknown_dataset]
299 | # saveAll(noiseImgs, tgt_dataset.images_dir, 'noiseImg')
300 | # import ipdb; ipdb.set_trace()
301 | unLoader = DataLoader(
302 | Preprocessor(unknown_dataset, root=tgt_dataset.images_dir, transform=train_transformer),
303 | batch_size=args.batch_size, num_workers=4,
304 | sampler=RandomIdentitySampler(unknown_dataset, args.num_instances),
305 | pin_memory=True, drop_last=True
306 | )
307 |         # train both networks on the newly generated dataset
308 | trainer = RCoTeaching(
309 | model, coModel, train_loader, unLoader, criterion, optims
310 | )
311 |
312 | # Start training
313 | for epoch in range(args.epochs):
314 | trainer.train(epoch, remRate=0.2 + (0.8 / args.iteration) * (1 + iter_n))
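    |             # Editorial note: remRate is the co-teaching keep-rate. It grows linearly
    |             # from 0.2 + 0.8/iteration at iter_n == 0 to 1.0 at the final iteration,
    |             # e.g. with iteration=10 it takes 0.28, 0.36, ..., 1.0, keeping ever more
    |             # samples as the pseudo-labels become more reliable.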
315 |
316 |         # evaluate on the target domain after this iteration
317 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
318 | # print('co-model:\n')
319 | # rank_score = evaluatorB.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
320 |
321 |     # Final evaluation
322 | rank_score = evaluator.evaluate(test_loader, tgt_dataset.query, tgt_dataset.gallery)
323 | save_checkpoint({
324 | 'state_dict': model.module.state_dict(),
325 | 'epoch': epoch + 1, 'best_top1': rank_score.market1501[0],
326 | }, True, fpath=osp.join(args.logs_dir, 'RCT.pth'))
327 | return rank_score.map, rank_score.market1501[0]
328 |
329 |
330 | if __name__ == '__main__':
331 | parser = argparse.ArgumentParser(description="Triplet loss classification")
332 | # data
333 | parser.add_argument('--src_dataset', type=str, default='dukemtmc',
334 | choices=datasets.names())
335 | parser.add_argument('--tgt_dataset', type=str, default='market1501',
336 | choices=datasets.names())
337 | parser.add_argument('--batch_size', type=int, default=64)
338 | parser.add_argument('--workers', type=int, default=4)
339 | parser.add_argument('--split', type=int, default=0)
340 | parser.add_argument('--noiseLam', type=float, default=0.5)
341 | parser.add_argument('--height', type=int,
342 | help="input height, default: 256 for resnet*, "
343 | "144 for inception")
344 | parser.add_argument('--width', type=int,
345 | help="input width, default: 128 for resnet*, "
346 | "56 for inception")
347 | parser.add_argument('--combine-trainval', action='store_true',
348 | help="train and val sets together for training, "
349 | "val set alone for validation")
350 | parser.add_argument('--num_instances', type=int, default=4,
351 | help="each minibatch consist of "
352 | "(batch_size // num_instances) identities, and "
353 | "each identity has num_instances instances, "
354 | "default: 4")
355 | # model
356 | parser.add_argument('--arch', type=str, default='resnet50',
357 | choices=models.names())
358 | # loss
359 | parser.add_argument('--margin', type=float, default=0.5,
360 | help="margin of the triplet loss, default: 0.5")
361 | parser.add_argument('--lambda_value', type=float, default=0.1,
362 | help="balancing parameter, default: 0.1")
363 | parser.add_argument('--rho', type=float, default=1.6e-3,
364 | help="rho percentage, default: 1.6e-3")
365 | # optimizer
366 | parser.add_argument('--lr', type=float, default=6e-5,
367 | help="learning rate of all parameters")
368 | # training configs
369 | parser.add_argument('--resume', type=str, metavar='PATH',
370 | default='')
371 | parser.add_argument('--evaluate', type=int, default=0,
372 | help="evaluation only")
373 | parser.add_argument('--seed', type=int, default=1)
374 | parser.add_argument('--print_freq', type=int, default=1)
375 | parser.add_argument('--iteration', type=int, default=10)
376 | parser.add_argument('--epochs', type=int, default=30)
377 | # metric learning
378 | parser.add_argument('--dist_metric', type=str, default='euclidean',
379 | choices=['euclidean', 'kissme'])
380 | # misc
381 | parser.add_argument('--data_dir', type=str, metavar='PATH',
382 | default='')
383 | parser.add_argument('--logs_dir', type=str, metavar='PATH',
384 | default='')
385 |
386 | args = parser.parse_args()
387 | mean_ap, rank1 = main(args)
388 |
--------------------------------------------------------------------------------