├── reid
│   ├── loss
│   │   ├── __init__.py
│   │   ├── oim.py
│   │   └── triplet.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── inception.py
│   │   └── resnet.py
│   ├── utils
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── preprocessor.py
│   │   │   ├── sampler.py
│   │   │   ├── transforms.py
│   │   │   └── dataset.py
│   │   ├── osutils.py
│   │   ├── meters.py
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   └── serialization.py
│   ├── evaluation_metrics
│   │   ├── __init__.py
│   │   ├── classification.py
│   │   └── ranking.py
│   ├── feature_extraction
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   └── database.py
│   ├── __init__.py
│   ├── metric_learning
│   │   ├── euclidean.py
│   │   ├── __init__.py
│   │   └── kissme.py
│   ├── dist_metric.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cuhk01.py
│   │   ├── viper.py
│   │   ├── cuhk03.py
│   │   ├── market1501.py
│   │   └── dukemtmc.py
│   ├── test_dataset.py
│   ├── trainers.py
│   ├── evaluators.py
│   └── extract_feature.py
├── README.md
├── run.sh
├── LICENSE
└── mancs_train.py

--------------------------------------------------------------------------------
/reid/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .dataset import Dataset
4 | from .preprocessor import Preprocessor
5 | 
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .classification import accuracy
4 | from .ranking import cmc, mean_ap
5 | 
6 | __all__ = [
7 |     'accuracy',
8 |     'cmc',
9 |     'mean_ap',
10 | ]
11 | 
--------------------------------------------------------------------------------
/reid/feature_extraction/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .cnn import extract_cnn_feature
4 | from .database import FeatureDatabase
5 | 
6 | __all__ = [
7 |     'extract_cnn_feature',
8 |     'FeatureDatabase',
9 | ]
10 | 
--------------------------------------------------------------------------------
/reid/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .oim import oim, OIM, OIMLoss
4 | from .triplet import TripletLoss, FocalLoss
5 | 
6 | __all__ = [
7 |     'oim',
8 |     'OIM',
9 |     'OIMLoss',
10 |     'TripletLoss',
11 |     'FocalLoss',
12 | ]
13 | 
--------------------------------------------------------------------------------
/reid/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 | 
5 | 
6 | def mkdir_if_missing(dir_path):
7 |     try:
8 |         os.makedirs(dir_path)
9 |     except OSError as e:
10 |         if e.errno != errno.EEXIST:
11 |             raise
12 | 
--------------------------------------------------------------------------------
/reid/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from . import datasets
4 | from . import evaluation_metrics
5 | from . import feature_extraction
6 | from . import loss
7 | from . import metric_learning
8 | from . import models
9 | from . import utils
10 | from . import dist_metric
11 | from . import evaluators
12 | from . import trainers
13 | 
14 | __version__ = '0.2.0'
15 | 
--------------------------------------------------------------------------------
/reid/metric_learning/euclidean.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import numpy as np
4 | from metric_learn.base_metric import BaseMetricLearner
5 | 
6 | 
7 | class Euclidean(BaseMetricLearner):
8 |     def __init__(self):
9 |         self.M_ = None
10 | 
11 |     def metric(self):
12 |         return self.M_
13 | 
14 |     def fit(self, X):
15 |         self.M_ = np.eye(X.shape[1])
16 |         self.X_ = X
17 | 
18 |     def transform(self, X=None):
19 |         if X is None:
20 |             return self.X_
21 |         return X
22 | 
--------------------------------------------------------------------------------
/reid/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | 
4 | class AverageMeter(object):
5 |     """Computes and stores the average and current value"""
6 | 
7 |     def __init__(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
--------------------------------------------------------------------------------
/reid/evaluation_metrics/classification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from ..utils import to_torch
4 | 
5 | 
6 | def accuracy(output, target, topk=(1,)):
7 |     output, target = to_torch(output), to_torch(target)
8 |     maxk = max(topk)
9 |     batch_size = target.size(0)
10 | 
11 |     _, pred = output.topk(maxk, 1, True, True)
12 |     pred = pred.t()
13 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
14 | 
15 |     ret = []
16 |     for k in topk:
17 |         correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
18 |         ret.append(correct_k.mul_(1. / batch_size))
19 |     return ret
20 | 
--------------------------------------------------------------------------------
/reid/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | 
5 | 
6 | def to_numpy(tensor):
7 |     if torch.is_tensor(tensor):
8 |         return tensor.cpu().numpy()
9 |     elif type(tensor).__module__ != 'numpy':
10 |         raise ValueError("Cannot convert {} to numpy array"
11 |                          .format(type(tensor)))
12 |     return tensor
13 | 
14 | 
15 | def to_torch(ndarray):
16 |     if type(ndarray).__module__ == 'numpy':
17 |         return torch.from_numpy(ndarray)
18 |     elif not torch.is_tensor(ndarray):
19 |         raise ValueError("Cannot convert {} to torch tensor"
20 |                          .format(type(ndarray)))
21 |     return ndarray
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | If you find this project helpful for your research, please cite the following paper:
3 | 
4 | ```
5 | @inproceedings{wang2018mancs,
6 |   title={Mancs: A multi-task attentional network with curriculum sampling for person re-identification},
7 |   author={Wang, Cheng and Zhang, Qian and Huang, Chang and Liu, Wenyu and Wang, Xinggang},
8 |   booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
9 |   pages={365--381},
10 |   year={2018}
11 | }
12 | ```
13 | 
14 | IMPORTANT NOTICE: Although this software is licensed under MIT, our intention is to make it free for academic research purposes. If you are going to use it in a product, we suggest you [contact us](https://xinggangw.info) regarding possible patent issues.
15 | 
--------------------------------------------------------------------------------
/reid/metric_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from metric_learn import (ITML_Supervised, LMNN, LSML_Supervised,
4 |                           SDML_Supervised, NCA, LFDA, RCA_Supervised)
5 | 
6 | from .euclidean import Euclidean
7 | from .kissme import KISSME
8 | 
9 | __factory = {
10 |     'euclidean': Euclidean,
11 |     'kissme': KISSME,
12 |     'itml': ITML_Supervised,
13 |     'lmnn': LMNN,
14 |     'lsml': LSML_Supervised,
15 |     'sdml': SDML_Supervised,
16 |     'nca': NCA,
17 |     'lfda': LFDA,
18 |     'rca': RCA_Supervised,
19 | }
20 | 
21 | 
22 | def get_metric(algorithm, *args, **kwargs):
23 |     if algorithm not in __factory:
24 |         raise KeyError("Unknown metric:", algorithm)
25 |     return __factory[algorithm](*args, **kwargs)
26 | 
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ./mancs_train.py -d market1501 \
2 |     -a resnet50 \
3 |     -b 256 \
4 |     --num-instances 16 \
5 |     --data-dir ./data \
6 |     --logs-dir ./logs/market1501-resnet50/ \
7 |     --lr 0.0003 \
8 |     --margin 0.5 \
9 |     -j 4 \
10 |     --epochs 150 \
11 |     --combine-trainval \
12 |     --start_save 100
--------------------------------------------------------------------------------
/reid/feature_extraction/cnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import OrderedDict
3 | 
4 | from torch.autograd import Variable
5 | 
6 | from ..utils import to_torch
7 | 
8 | 
9 | def extract_cnn_feature(model, inputs, modules=None):
10 |     model.eval()
11 |     inputs = to_torch(inputs)
12 |     inputs = Variable(inputs, volatile=True)
13 |     if modules is None:
14 |         outputs = model(inputs)
15 |         outputs = outputs[0].data.cpu()  # outputs contains [x1, x2, x3]
16 |         return outputs
17 |     # Register forward hook for each module
18 |     outputs = OrderedDict()
19 |     handles = []
20 |     for m in modules:
21 |         outputs[id(m)] = None
22 |         def func(m, i, o): outputs[id(m)] = o.data.cpu()
23 |         handles.append(m.register_forward_hook(func))
24 |     model(inputs)
25 |     for h in handles:
26 |         h.remove()
27 |     return list(outputs.values())
28 | 
--------------------------------------------------------------------------------
/reid/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | 
5 | from .osutils import mkdir_if_missing
6 | 
7 | 
8 | class Logger(object):
9 |     def __init__(self, fpath=None):
10 |         self.console = sys.stdout
11 |         self.file = None
12 |         if fpath is not None:
13 |             mkdir_if_missing(os.path.dirname(fpath))
14 |             self.file = open(fpath, 'w')
15 | 
16 |     def __del__(self):
17 |         self.close()
18 | 
19 |     def __enter__(self):
20 |         pass
21 | 
22 |     def __exit__(self, *args):
23 |         self.close()
24 | 
25 |     def write(self, msg):
26 |         self.console.write(msg)
27 |         if self.file is not None:
28 |             self.file.write(msg)
29 | 
30 |     def flush(self):
31 |         self.console.flush()
32 |         if self.file is not None:
33 |             self.file.flush()
34 |             os.fsync(self.file.fileno())
35 | 
36 |     def close(self):
37 |         self.console.close()
38 |         if self.file is not None:
39 |             self.file.close()
40 | 
--------------------------------------------------------------------------------
/reid/dist_metric.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | 
5 | from .evaluators import extract_features
6 | from .metric_learning import get_metric
7 | 
8 | 
9 | class DistanceMetric(object):
10 |     def __init__(self, algorithm='euclidean', *args, **kwargs):
11 |         super(DistanceMetric, self).__init__()
12 |         self.algorithm = algorithm
13 |         self.metric = get_metric(algorithm, *args, **kwargs)
14 | 
15 |     def train(self, model, data_loader):
16 |         if self.algorithm == 'euclidean': return
17 |         features, labels = extract_features(model, data_loader)
18 |         features = torch.stack(list(features.values())).numpy()  # stack() needs a list, not a dict view
19 |         labels = torch.Tensor(list(labels.values())).numpy()
20 |         self.metric.fit(features, labels)
21 | 
22 |     def transform(self, X):
23 |         if torch.is_tensor(X):
24 |             X = X.numpy()
25 |             X = self.metric.transform(X)
26 |             X = torch.from_numpy(X)
27 |         else:
28 |             X = self.metric.transform(X)
29 |         return X
30 | 
31 | 
--------------------------------------------------------------------------------
/reid/utils/data/preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os.path as osp
3 | 
4 | from PIL import Image
5 | 
6 | 
7 | class Preprocessor(object):
8 |     def __init__(self, dataset, root=None, transform=None):
9 |         super(Preprocessor, self).__init__()
10 |         self.dataset = dataset
11 |         self.root = root
12 |         self.transform = transform
13 | 
14 |     def __len__(self):
15 |         return len(self.dataset)
16 | 
17 |     def __getitem__(self, indices):
18 |         if isinstance(indices, (tuple, list)):
19 |             return [self._get_single_item(index) for index in indices]
20 |         return self._get_single_item(indices)
21 | 
22 |     def _get_single_item(self, index):
23 |         fname, pid, camid = self.dataset[index]
24 |         fpath = fname
25 |         if self.root is not None:
26 |             fpath = osp.join(self.root, fname)
27 |         img = Image.open(fpath).convert('RGB')
28 |         if self.transform is not None:
29 |             img = self.transform(img)
30 |         return img, fname, pid, camid
31 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Cheng Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/reid/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 | 
4 | import numpy as np
5 | import torch
6 | from torch.utils.data.sampler import (
7 |     Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler,
8 |     WeightedRandomSampler)
9 | 
10 | 
11 | class RandomIdentitySampler(Sampler):
12 |     def __init__(self, data_source, num_instances=1):
13 |         self.data_source = data_source
14 |         self.num_instances = num_instances
15 |         self.index_dic = defaultdict(list)
16 |         for index, (_, pid, _) in enumerate(data_source):
17 |             self.index_dic[pid].append(index)
18 |         self.pids = list(self.index_dic.keys())
19 |         self.num_samples = len(self.pids)
20 | 
21 |     def __len__(self):
22 |         return self.num_samples * self.num_instances
23 | 
24 |     def __iter__(self):
25 |         indices = torch.randperm(self.num_samples)
26 |         ret = []
27 |         for i in indices:
28 |             pid = self.pids[i]
29 |             t = self.index_dic[pid]
30 |             if len(t) >= self.num_instances:
31 |                 t = np.random.choice(t, size=self.num_instances, replace=False)
32 |             else:
33 |                 t = np.random.choice(t, size=self.num_instances, replace=True)
34 |             ret.extend(t)
35 |         return iter(ret)
36 | 
--------------------------------------------------------------------------------
/reid/feature_extraction/database.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import h5py
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | 
7 | 
8 | class FeatureDatabase(Dataset):
9 |     def __init__(self, *args, **kwargs):
10 |         super(FeatureDatabase, self).__init__()
11 |         self.fid = h5py.File(*args, **kwargs)
12 | 
13 |     def __enter__(self):
14 |         return self
15 | 
16 |     def __exit__(self, exc_type, exc_val, exc_tb):
17 |         self.close()
18 | 
19 |     def __getitem__(self, keys):
20 |         if isinstance(keys, (tuple, list)):
21 |             return [self._get_single_item(k) for k in keys]
22 |         return self._get_single_item(keys)
23 | 
24 |     def _get_single_item(self, key):
25 |         return np.asarray(self.fid[key])
26 | 
27 |     def __setitem__(self, key, value):
28 |         if key in self.fid:
29 |             if self.fid[key].shape == value.shape and \
30 |                self.fid[key].dtype == value.dtype:
31 |                 self.fid[key][...] = value
32 |             else:
33 |                 del self.fid[key]
34 |                 self.fid.create_dataset(key, data=value)
35 |         else:
36 |             self.fid.create_dataset(key, data=value)
37 | 
38 |     def __delitem__(self, key):
39 |         del self.fid[key]
40 | 
41 |     def __len__(self):
42 |         return len(self.fid)
43 | 
44 |     def __iter__(self):
45 |         return iter(self.fid)
46 | 
47 |     def flush(self):
48 |         self.fid.flush()
49 | 
50 |     def close(self):
51 |         self.fid.close()
52 | 
--------------------------------------------------------------------------------
/reid/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import warnings
3 | 
4 | from .cuhk01 import CUHK01
5 | from .cuhk03 import CUHK03
6 | from .dukemtmc import DukeMTMC
7 | from .market1501 import Market1501
8 | from .viper import VIPeR
9 | 
10 | 
11 | __factory = {
12 |     'viper': VIPeR,
13 |     'cuhk01': CUHK01,
14 |     'cuhk03': CUHK03,
15 |     'market1501': Market1501,
16 |     'dukemtmc': DukeMTMC,
17 | }
18 | 
19 | 
20 | def names():
21 |     return sorted(__factory.keys())
22 | 
23 | 
24 | def create(name, root, *args, **kwargs):
25 |     """
26 |     Create a dataset instance.
27 | 
28 |     Parameters
29 |     ----------
30 |     name : str
31 |         The dataset name. Can be one of 'viper', 'cuhk01', 'cuhk03',
32 |         'market1501', and 'dukemtmc'.
33 |     root : str
34 |         The path to the dataset directory.
35 |     split_id : int, optional
36 |         The index of data split. Default: 0
37 |     num_val : int or float, optional
38 |         When int, it means the number of validation identities. When float,
39 |         it means the proportion of validation to all the trainval. Default: 100
40 |     download : bool, optional
41 |         If True, will download the dataset. Default: False
42 |     """
43 |     if name not in __factory:
44 |         raise KeyError("Unknown dataset:", name)
45 |     return __factory[name](root, *args, **kwargs)
46 | 
47 | 
48 | def get_dataset(name, root, *args, **kwargs):
49 |     warnings.warn("get_dataset is deprecated. Use create instead.")
50 |     return create(name, root, *args, **kwargs)
51 | 
--------------------------------------------------------------------------------
/reid/metric_learning/kissme.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import numpy as np
4 | from metric_learn.base_metric import BaseMetricLearner
5 | 
6 | 
7 | def validate_cov_matrix(M):
8 |     M = (M + M.T) * 0.5
9 |     k = 0
10 |     I = np.eye(M.shape[0])
11 |     while True:
12 |         try:
13 |             _ = np.linalg.cholesky(M)
14 |             break
15 |         except np.linalg.LinAlgError:
16 |             # Find the nearest positive definite matrix for M. Modified from
17 |             # http://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
18 |             # Might take several minutes
19 |             k += 1
20 |             w, v = np.linalg.eig(M)
21 |             min_eig = w.min()  # w holds the eigenvalues (v the eigenvectors)
22 |             M += (-min_eig * k * k + np.spacing(min_eig)) * I
23 |     return M
24 | 
25 | 
26 | class KISSME(BaseMetricLearner):
27 |     def __init__(self):
28 |         self.M_ = None
29 | 
30 |     def metric(self):
31 |         return self.M_
32 | 
33 |     def fit(self, X, y=None):
34 |         n = X.shape[0]
35 |         if y is None:
36 |             y = np.arange(n)
37 |         X1, X2 = np.meshgrid(np.arange(n), np.arange(n))
38 |         X1, X2 = X1[X1 < X2], X2[X1 < X2]
39 |         matches = (y[X1] == y[X2])
40 |         num_matches = matches.sum()
41 |         num_non_matches = len(matches) - num_matches
42 |         idxa = X1[matches]
43 |         idxb = X2[matches]
44 |         S = X[idxa] - X[idxb]
45 |         C1 = S.transpose().dot(S) / num_matches
46 |         p = np.random.choice(num_non_matches, num_matches, replace=False)
47 |         idxa = X1[~matches]
48 |         idxb = X2[~matches]
49 |         idxa = idxa[p]
50 |         idxb = idxb[p]
51 |         S = X[idxa] - X[idxb]
52 |         C0 = S.transpose().dot(S) / num_matches
53 |         self.M_ = np.linalg.inv(C1) - np.linalg.inv(C0)
54 |         self.M_ = validate_cov_matrix(self.M_)
55 |         self.X_ = X
56 | 
--------------------------------------------------------------------------------
/reid/loss/oim.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, autograd
6 | 
7 | 
8 | class OIM(autograd.Function):
9 |     def __init__(self, lut, momentum=0.5):
10 |         super(OIM, self).__init__()
11 |         self.lut = lut
12 |         self.momentum = momentum
13 | 
14 |     def forward(self, inputs, targets):
15 |         self.save_for_backward(inputs, targets)
16 |         outputs = inputs.mm(self.lut.t())
17 |         return outputs
18 | 
19 |     def backward(self, grad_outputs):
20 |         inputs, targets = self.saved_tensors
21 |         grad_inputs = None
22 |         if self.needs_input_grad[0]:
23 |             grad_inputs = grad_outputs.mm(self.lut)
24 |         for x, y in zip(inputs, targets):
25 |             self.lut[y] = self.momentum * self.lut[y] + (1. - self.momentum) * x
26 |             self.lut[y] /= self.lut[y].norm()
27 |         return grad_inputs, None
28 | 
29 | 
30 | def oim(inputs, targets, lut, momentum=0.5):
31 |     return OIM(lut, momentum=momentum)(inputs, targets)
32 | 
33 | 
34 | class OIMLoss(nn.Module):
35 |     def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
36 |                  weight=None, size_average=True):
37 |         super(OIMLoss, self).__init__()
38 |         self.num_features = num_features
39 |         self.num_classes = num_classes
40 |         self.momentum = momentum
41 |         self.scalar = scalar
42 |         self.weight = weight
43 |         self.size_average = size_average
44 | 
45 |         self.register_buffer('lut', torch.zeros(num_classes, num_features))
46 | 
47 |     def forward(self, inputs, targets):
48 |         inputs = oim(inputs, targets, self.lut, momentum=self.momentum)
49 |         inputs *= self.scalar
50 |         loss = F.cross_entropy(inputs, targets, weight=self.weight,
51 |                                size_average=self.size_average)
52 |         return loss, inputs
53 | 
--------------------------------------------------------------------------------
/reid/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import json
3 | import os.path as osp
4 | import shutil
5 | 
6 | import torch
7 | from torch.nn import Parameter
8 | 
9 | from .osutils import mkdir_if_missing
10 | 
11 | 
12 | def read_json(fpath):
13 |     with open(fpath, 'r') as f:
14 |         obj = json.load(f)
15 |     return obj
16 | 
17 | 
18 | def write_json(obj, fpath):
19 |     mkdir_if_missing(osp.dirname(fpath))
20 |     with open(fpath, 'w') as f:
21 |         json.dump(obj, f, indent=4, separators=(',', ': '))
22 | 
23 | 
24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
25 |     mkdir_if_missing(osp.dirname(fpath))
26 |     torch.save(state, fpath)
27 |     if is_best:
28 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
29 | 
30 | 
31 | def load_checkpoint(fpath):
32 |     if osp.isfile(fpath):
33 |         checkpoint = torch.load(fpath)
34 |         print("=> Loaded checkpoint '{}'".format(fpath))
35 |         return checkpoint
36 |     else:
37 |         raise ValueError("=> No checkpoint found at '{}'".format(fpath))
38 | 
39 | 
40 | def copy_state_dict(state_dict, model, strip=None):
41 |     tgt_state = model.state_dict()
42 |     copied_names = set()
43 |     for name, param in state_dict.items():
44 |         if strip is not None and name.startswith(strip):
45 |             name = name[len(strip):]
46 |         if name not in tgt_state:
47 |             continue
48 |         if isinstance(param, Parameter):
49 |             param = param.data
50 |         if param.size() != tgt_state[name].size():
51 |             print('mismatch:', name, param.size(), tgt_state[name].size())
52 |             continue
53 |         tgt_state[name].copy_(param)
54 |         copied_names.add(name)
55 | 
56 |     missing = set(tgt_state.keys()) - copied_names
57 |     if len(missing) > 0:
58 |         print("missing keys in state_dict:", missing)
59 | 
60 |     return model
61 | 
--------------------------------------------------------------------------------
/reid/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .inception import *
4 | from .resnet import *
5 | 
6 | 
7 | __factory = {
8 |     'inception': inception,
9 |     'resnet18': resnet18,
10 |     'resnet34': resnet34,
11 |     'resnet50': resnet50,
12 |     'resnet101': resnet101,
13 |     'resnet152': resnet152,
14 | }
15 | 
16 | 
17 | def names():
18 |     return sorted(__factory.keys())
19 | 
20 | 
21 | def create(name, *args, **kwargs):
22 |     """
23 |     Create a model instance.
24 | 
25 |     Parameters
26 |     ----------
27 |     name : str
28 |         Model name. Can be one of 'inception', 'resnet18', 'resnet34',
29 |         'resnet50', 'resnet101', and 'resnet152'.
30 |     pretrained : bool, optional
31 |         Only applied for 'resnet*' models. If True, will use ImageNet pretrained
32 |         model. Default: True
33 |     cut_at_pooling : bool, optional
34 |         If True, will cut the model before the last global pooling layer and
35 |         ignore the remaining kwargs. Default: False
36 |     num_features : int, optional
37 |         If positive, will append a Linear layer after the global pooling layer,
38 |         with this number of output units, followed by a BatchNorm layer.
39 |         Otherwise these layers will not be appended. Default: 256 for
40 |         'inception', 0 for 'resnet*'
41 |     norm : bool, optional
42 |         If True, will normalize the feature to be unit L2-norm for each sample.
43 |         Otherwise will append a ReLU layer after the above Linear layer if
44 |         num_features > 0. Default: False
45 |     dropout : float, optional
46 |         If positive, will append a Dropout layer with this dropout rate.
47 |         Default: 0
48 |     num_classes : int, optional
49 |         If positive, will append a Linear layer at the end as the classifier
50 |         with this number of output units. Default: 0
51 |     """
52 |     if name not in __factory:
53 |         raise KeyError("Unknown model:", name)
54 |     return __factory[name](*args, **kwargs)
55 | 
--------------------------------------------------------------------------------
/reid/test_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | 
3 | from PIL import Image
4 | import os
5 | import os.path
6 | 
7 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
8 | 
9 | 
10 | def is_image_file(filename):
11 |     """Checks if a file is an image.
12 |     Args:
13 |         filename (string): path to a file
14 |     Returns:
15 |         bool: True if the filename ends with a known image extension
16 |     """
17 |     filename_lower = filename.lower()
18 |     return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
19 | 
20 | 
21 | def find_classes(dir):
22 |     classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
23 |     classes.sort()
24 |     class_to_idx = {classes[i]: i for i in range(len(classes))}
25 |     return classes, class_to_idx
26 | 
27 | 
28 | def make_dataset(dir_):
29 |     images = []
30 |     dir_ = os.path.expanduser(dir_)
31 |     for fname in sorted(os.listdir(dir_)):
32 |         if is_image_file(fname):
33 |             path = os.path.join(dir_, fname)
34 |             images.append(path)
35 | 
36 |     return images
37 | 
38 | 
39 | def pil_loader(path):
40 |     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
41 |     with open(path, 'rb') as f:
42 |         img = Image.open(f)
43 |         return img.convert('RGB')
44 | 
45 | 
46 | def accimage_loader(path):
47 |     import accimage
48 |     try:
49 |         return accimage.Image(path)
50 |     except IOError:
51 |         # Potentially a decoding problem, fall back to PIL.Image
52 |         return pil_loader(path)
53 | 
54 | 
55 | def default_loader(path):
56 |     from torchvision import get_image_backend
57 |     if get_image_backend() == 'accimage':
58 |         return accimage_loader(path)
59 |     else:
60 |         return pil_loader(path)
61 | 
62 | 
63 | class TestDataset(data.Dataset):
64 | 
65 |     def __init__(self, root, transform=None, loader=default_loader):
66 |         imgs = make_dataset(root)
67 |         if len(imgs) == 0:
68 |             raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
69 |                                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))
70 | 
71 |         self.root = root
72 |         self.imgs = imgs
73 |         self.transform = transform
74 |         self.loader = loader
75 | 
76 |     def __getitem__(self, index):
77 |         """
78 |         Args:
79 |             index (int): Index
80 |         Returns:
81 |             tuple: (image, target) where target is class_index of the target class.
82 |         """
83 |         path = self.imgs[index]
84 |         img = self.loader(path)
85 |         if self.transform is not None:
86 |             img = self.transform(img)
87 | 
88 |         return img, path.split('/')[-1]
89 | 
90 |     def __len__(self):
91 |         return len(self.imgs)
--------------------------------------------------------------------------------
/reid/datasets/cuhk01.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | 
4 | import numpy as np
5 | 
6 | from ..utils.data import Dataset
7 | from ..utils.osutils import mkdir_if_missing
8 | from ..utils.serialization import write_json
9 | 
10 | 
11 | class CUHK01(Dataset):
12 |     url = 'https://docs.google.com/spreadsheet/viewform?formkey=dF9pZ1BFZkNiMG1oZUdtTjZPalR0MGc6MA'
13 |     md5 = 'e6d55c0da26d80cda210a2edeb448e98'
14 | 
15 |     def __init__(self, root, split_id=0, num_val=100, download=True):
16 |         super(CUHK01, self).__init__(root, split_id=split_id)
17 | 
18 |         if download:
19 |             self.download()
20 | 
21 |         if not self._check_integrity():
22 |             raise RuntimeError("Dataset not found or corrupted. " +
" + 23 | "You can use download=True to download it.") 24 | 25 | self.load(num_val) 26 | 27 | def download(self): 28 | if self._check_integrity(): 29 | print("Files already downloaded and verified") 30 | return 31 | 32 | import hashlib 33 | import shutil 34 | from glob import glob 35 | from zipfile import ZipFile 36 | 37 | raw_dir = osp.join(self.root, 'raw') 38 | mkdir_if_missing(raw_dir) 39 | 40 | # Download the raw zip file 41 | fpath = osp.join(raw_dir, 'CUHK01.zip') 42 | if osp.isfile(fpath) and \ 43 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 44 | print("Using downloaded file: " + fpath) 45 | else: 46 | raise RuntimeError("Please download the dataset manually from {} " 47 | "to {}".format(self.url, fpath)) 48 | 49 | # Extract the file 50 | exdir = osp.join(raw_dir, 'campus') 51 | if not osp.isdir(exdir): 52 | print("Extracting zip file") 53 | with ZipFile(fpath) as z: 54 | z.extractall(path=raw_dir) 55 | 56 | # Format 57 | images_dir = osp.join(self.root, 'images') 58 | mkdir_if_missing(images_dir) 59 | 60 | identities = [[[] for _ in range(2)] for _ in range(971)] 61 | 62 | files = sorted(glob(osp.join(exdir, '*.png'))) 63 | for fpath in files: 64 | fname = osp.basename(fpath) 65 | pid, cam = int(fname[:4]), int(fname[4:7]) 66 | assert 1 <= pid <= 971 67 | assert 1 <= cam <= 4 68 | pid, cam = pid - 1, (cam - 1) // 2 69 | fname = ('{:08d}_{:02d}_{:04d}.png' 70 | .format(pid, cam, len(identities[pid][cam]))) 71 | identities[pid][cam].append(fname) 72 | shutil.copy(fpath, osp.join(images_dir, fname)) 73 | 74 | # Save meta information into a json file 75 | meta = {'name': 'cuhk01', 'shot': 'multiple', 'num_cameras': 2, 76 | 'identities': identities} 77 | write_json(meta, osp.join(self.root, 'meta.json')) 78 | 79 | # Randomly create ten training and test split 80 | num = len(identities) 81 | splits = [] 82 | for _ in range(10): 83 | pids = np.random.permutation(num).tolist() 84 | trainval_pids = sorted(pids[:num // 2]) 85 | test_pids = sorted(pids[num // 2:]) 86 | split = {'trainval': trainval_pids, 87 | 'query': test_pids, 88 | 'gallery': test_pids} 89 | splits.append(split) 90 | write_json(splits, osp.join(self.root, 'splits.json')) 91 | -------------------------------------------------------------------------------- /reid/datasets/viper.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..utils.data import Dataset 7 | from ..utils.osutils import mkdir_if_missing 8 | from ..utils.serialization import write_json 9 | 10 | 11 | class VIPeR(Dataset): 12 | url = 'http://users.soe.ucsc.edu/~manduchi/VIPeR.v1.0.zip' 13 | md5 = '1c2d9fc1cc800332567a0da25a1ce68c' 14 | 15 | def __init__(self, root, split_id=0, num_val=100, download=True): 16 | super(VIPeR, self).__init__(root, split_id=split_id) 17 | 18 | if download: 19 | self.download() 20 | 21 | if not self._check_integrity(): 22 | raise RuntimeError("Dataset not found or corrupted. 
" + 23 | "You can use download=True to download it.") 24 | 25 | self.load(num_val) 26 | 27 | def download(self): 28 | if self._check_integrity(): 29 | print("Files already downloaded and verified") 30 | return 31 | 32 | import hashlib 33 | from glob import glob 34 | from scipy.misc import imsave, imread 35 | from six.moves import urllib 36 | from zipfile import ZipFile 37 | 38 | raw_dir = osp.join(self.root, 'raw') 39 | mkdir_if_missing(raw_dir) 40 | 41 | # Download the raw zip file 42 | fpath = osp.join(raw_dir, 'VIPeR.v1.0.zip') 43 | if osp.isfile(fpath) and \ 44 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 45 | print("Using downloaded file: " + fpath) 46 | else: 47 | print("Downloading {} to {}".format(self.url, fpath)) 48 | urllib.request.urlretrieve(self.url, fpath) 49 | 50 | # Extract the file 51 | exdir = osp.join(raw_dir, 'VIPeR') 52 | if not osp.isdir(exdir): 53 | print("Extracting zip file") 54 | with ZipFile(fpath) as z: 55 | z.extractall(path=raw_dir) 56 | 57 | # Format 58 | images_dir = osp.join(self.root, 'images') 59 | mkdir_if_missing(images_dir) 60 | cameras = [sorted(glob(osp.join(exdir, 'cam_a', '*.bmp'))), 61 | sorted(glob(osp.join(exdir, 'cam_b', '*.bmp')))] 62 | assert len(cameras[0]) == len(cameras[1]) 63 | identities = [] 64 | for pid, (cam1, cam2) in enumerate(zip(*cameras)): 65 | images = [] 66 | # view-0 67 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, 0, 0) 68 | imsave(osp.join(images_dir, fname), imread(cam1)) 69 | images.append([fname]) 70 | # view-1 71 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, 1, 0) 72 | imsave(osp.join(images_dir, fname), imread(cam2)) 73 | images.append([fname]) 74 | identities.append(images) 75 | 76 | # Save meta information into a json file 77 | meta = {'name': 'VIPeR', 'shot': 'single', 'num_cameras': 2, 78 | 'identities': identities} 79 | write_json(meta, osp.join(self.root, 'meta.json')) 80 | 81 | # Randomly create ten training and test split 82 | num = len(identities) 83 | splits = [] 84 | for _ in range(10): 85 | pids = np.random.permutation(num).tolist() 86 | trainval_pids = sorted(pids[:num // 2]) 87 | test_pids = sorted(pids[num // 2:]) 88 | split = {'trainval': trainval_pids, 89 | 'query': test_pids, 90 | 'gallery': test_pids} 91 | splits.append(split) 92 | write_json(splits, osp.join(self.root, 'splits.json')) 93 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from PIL import Image, ImageOps, ImageEnhance 4 | import random 5 | import math 6 | from torchvision.transforms import * 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * 
35 |             w = int(round(math.sqrt(target_area / aspect_ratio)))
36 | 
37 |             if w <= img.size[0] and h <= img.size[1]:
38 |                 x1 = random.randint(0, img.size[0] - w)
39 |                 y1 = random.randint(0, img.size[1] - h)
40 | 
41 |                 img = img.crop((x1, y1, x1 + w, y1 + h))
42 |                 assert(img.size == (w, h))
43 | 
44 |                 return img.resize((self.width, self.height), self.interpolation)
45 | 
46 |         # Fallback
47 |         scale = RectScale(self.height, self.width,
48 |                           interpolation=self.interpolation)
49 |         return scale(img)
50 | 
51 | class RandomErasing(object):
52 |     '''
53 |     Class that performs Random Erasing in Random Erasing Data Augmentation by Zhong et al.
54 |     -------------------------------------------------------------------------------------
55 |     probability: The probability that the operation will be performed.
56 |     sl: min erasing area
57 |     sh: max erasing area
58 |     r1: min aspect ratio
59 |     mean: erasing value
60 |     -------------------------------------------------------------------------------------
61 |     '''
62 |     def __init__(self, probability=0.5, sl=0.02, sh=0.2, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
63 |         self.probability = probability
64 |         self.mean = mean
65 |         self.sl = sl
66 |         self.sh = sh
67 |         self.r1 = r1
68 | 
69 |     def __call__(self, img):
70 | 
71 |         if random.uniform(0, 1) > self.probability:
72 |             return img
73 | 
74 |         for attempt in range(100):
75 |             area = img.size()[1] * img.size()[2]
76 | 
77 |             target_area = random.uniform(self.sl, self.sh) * area
78 |             aspect_ratio = random.uniform(self.r1, 1 / self.r1)
79 | 
80 |             h = int(round(math.sqrt(target_area * aspect_ratio)))
81 |             w = int(round(math.sqrt(target_area / aspect_ratio)))
82 | 
83 |             if w < img.size()[2] and h < img.size()[1]:
84 |                 x1 = random.randint(0, img.size()[1] - h)
85 |                 y1 = random.randint(0, img.size()[2] - w)
86 |                 if img.size()[0] == 3:
87 |                     img[0, x1:x1+h, y1:y1+w] = self.mean[0]
88 |                     img[1, x1:x1+h, y1:y1+w] = self.mean[1]
89 |                     img[2, x1:x1+h, y1:y1+w] = self.mean[2]
90 |                 else:
91 |                     img[0, x1:x1+h, y1:y1+w] = self.mean[0]
92 |                 return img
93 | 
94 |         return img
95 | 
--------------------------------------------------------------------------------
/reid/loss/triplet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 | from torch.nn import functional as F
7 | from scipy.stats import norm
8 | 
9 | import numpy as np
10 | 
11 | class TripletLoss(nn.Module):
12 |     def __init__(self, margin=0, num_instances=0):
13 |         super(TripletLoss, self).__init__()
14 |         self.margin = margin
15 |         self.ranking_loss = nn.MarginRankingLoss(margin=self.margin)
16 |         self.K = num_instances
17 | 
18 |     def forward(self, inputs, targets, epoch):
19 |         n = inputs.size(0)
20 |         P = n // self.K  # identities per batch; integer division keeps range(P) valid
21 |         t0 = 30.0
22 |         t1 = 60.0
23 | 
24 |         # Compute pairwise distance, replace by the official when merged
25 |         dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
26 |         dist = dist + dist.t()
27 |         dist.addmm_(1, -2, inputs, inputs.t())
28 |         dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
29 |         # For each anchor, find the hardest positive and negative
30 |         mask = targets.expand(n, n).eq(targets.expand(n, n).t())
31 |         dist_ap, dist_an = [], []
32 |         if True:
33 |             mean = max(80.0 - 3.0 * epoch, 0.0)
34 |             std = 15 * 0.001 ** (max((epoch - t0) / (t1 - t0), 0.0))
35 |             neg_probs = norm(mean, std).pdf(np.linspace(0, 79, 80))
36 |             neg_probs = torch.from_numpy(neg_probs).clamp(min=3e-5)
37 |             for i in range(P):
38 |                 for j in range(self.K):
39 |                     neg_examples = dist[i * self.K + j][mask[i * self.K + j] == 0]
40 |                     sort_neg_examples = torch.topk(neg_examples, k=80, largest=False)[0]
41 |                     for pair in range(j + 1, self.K):
42 |                         dist_ap.append(dist[i * self.K + j][i * self.K + pair])
43 |                         choosen_neg = sort_neg_examples[torch.multinomial(neg_probs, 1).cuda()]
44 |                         dist_an.append(choosen_neg)
45 |         else:
46 |             for i in range(n):
47 |                 dist_ap.append(dist[i][mask[i]].max())
48 |                 dist_an.append(dist[i][mask[i] == 0].min())
49 |         dist_ap = torch.cat(dist_ap)
50 |         dist_an = torch.cat(dist_an)
51 |         # Compute ranking hinge loss
52 |         y = dist_an.data.new()
53 |         y.resize_as_(dist_an.data)
54 |         y.fill_(1)
55 |         y = Variable(y)
56 |         loss = self.ranking_loss(dist_an, dist_ap, y)
57 |         prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
58 |         return loss, prec
59 | 
60 | class FocalLoss(nn.Module):
61 |     def __init__(self, gamma=2.0, alpha=None, size_average=True):
62 |         super(FocalLoss, self).__init__()
63 |         self.gamma = gamma
64 |         self.alpha = alpha
65 |         if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])  # Python 2's 'long' removed; it does not exist in Python 3
66 |         if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
67 |         self.size_average = size_average
68 | 
69 |     def forward(self, input, target, epoch):
70 |         if input.dim() > 2:
71 |             input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
72 |             input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
73 |             input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
74 |         target = target.view(-1, 1)
75 |         logpt = F.log_softmax(input)
76 |         logpt = logpt.gather(1, target)
77 |         logpt = logpt.view(-1)
78 |         pt = Variable(logpt.data.exp())
79 |         if self.alpha is not None:
80 |             if self.alpha.type() != input.data.type():
81 |                 self.alpha = self.alpha.type_as(input.data)
82 |             at = self.alpha.gather(0, target.data.view(-1))
83 |             logpt = logpt * Variable(at)
84 | 
85 |         loss = -1 * ((1 - pt) ** self.gamma) * logpt
86 |         if self.size_average: return loss.mean()
87 |         else: return loss.sum()
88 | 
89 | 
--------------------------------------------------------------------------------
/reid/trainers.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | 
4 | import torch
5 | from torch.autograd import Variable
6 | 
7 | from .evaluation_metrics import accuracy
8 | from .loss import OIMLoss, TripletLoss
9 | from .utils.meters import AverageMeter
10 | 
11 | 
12 | class BaseTrainer(object):
13 |     def __init__(self, model, criterions):
14 |         super(BaseTrainer, self).__init__()
15 |         self.model = model
16 |         self.criterions = criterions
17 | 
18 |     def train(self, epoch, data_loader, optimizer, print_freq=1):
19 |         self.model.train()
20 | 
21 |         batch_time = AverageMeter()
22 |         data_time = AverageMeter()
23 |         losses = AverageMeter()
24 |         precisions = AverageMeter()
25 | 
26 |         end = time.time()
27 |         for i, inputs in enumerate(data_loader):
28 |             data_time.update(time.time() - end)
29 | 
30 |             inputs, targets = self._parse_data(inputs)
31 |             loss, prec1 = self._forward(inputs, targets, epoch)
32 | 
33 |             losses.update(loss.data[0], targets.size(0))
34 |             precisions.update(prec1, targets.size(0))
35 | 
36 |             optimizer.zero_grad()
37 |             loss.backward()
38 |             # add gradient clip for lstm
39 |             for param in self.model.parameters():
40 |                 try:
41 |                     param.grad.data.clamp_(-1., 1.)  # in-place clamp; plain clamp() would discard the result
42 |                 except AttributeError:  # parameters without gradients have grad == None
43 |                     continue
44 | 
45 |             optimizer.step()
46 | 
47 |             batch_time.update(time.time() - end)
48 |             end = time.time()
49 | 
50 |             if (i + 1) % print_freq == 0:
51 |                 print('Epoch: [{}][{}/{}]\t'
52 |                       'Time {:.3f} ({:.3f})\t'
53 |                       'Data {:.3f} ({:.3f})\t'
54 |                       'Loss {:.3f} ({:.3f})\t'
55 |                       'Prec {:.2%} ({:.2%})\t'
56 |                       .format(epoch, i + 1, len(data_loader),
57 |                               batch_time.val, batch_time.avg,
58 |                               data_time.val, data_time.avg,
59 |                               losses.val, losses.avg,
60 |                               precisions.val, precisions.avg))
61 | 
62 |     def _parse_data(self, inputs):
63 |         raise NotImplementedError
64 | 
65 |     def _forward(self, inputs, targets, epoch):
66 |         raise NotImplementedError
67 | 
68 | 
69 | class Trainer(BaseTrainer):
70 |     def _parse_data(self, inputs):
71 |         imgs, _, pids, _ = inputs
72 |         inputs = [Variable(imgs)]
73 |         targets = Variable(pids.cuda())
74 |         return inputs, targets
75 | 
76 |     def _forward(self, inputs, targets, epoch):
77 |         outputs = self.model(*inputs)  # outputs = [x1, x2, x3]
78 |         # new added by wc
79 |         # x1 triplet loss
80 |         loss_tri, prec_tri = self.criterions[0](outputs[0], targets, epoch)
81 |         # x2 global feature cross entropy loss
82 |         loss_global = self.criterions[1](outputs[1], targets, epoch)
83 |         # prec_global, = accuracy(outputs[1].data, targets.data)
84 |         # prec_global = prec_global[0]
85 |         # x3 local lstm feature cross entropy loss
86 |         loss_local = self.criterions[2](outputs[2], targets)
87 | 
88 |         return loss_tri + loss_global + 0.2 * loss_local, prec_tri
89 | 
90 | 
91 | # new added by wc
92 | 
93 | '''
94 |     if isinstance(self.criterions, torch.nn.CrossEntropyLoss):
95 |         loss = self.criterions(outputs[1], targets)
96 |         prec, = accuracy(outputs[1].data, targets.data)
97 |         prec = prec[0]
98 |     elif isinstance(self.criterions, OIMLoss):
99 |         loss, outputs = self.criterions(outputs, targets)
100 |         prec, = accuracy(outputs.data, targets.data)
101 |         prec = prec[0]
102 |     elif isinstance(self.criterions, TripletLoss):
103 |         loss, prec = self.criterions(outputs, targets)
104 |     else:
105 |         raise ValueError("Unsupported loss:", self.criterion)
106 |     return loss, prec
107 | '''
108 | 
--------------------------------------------------------------------------------
/reid/datasets/cuhk03.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | 
4 | import numpy as np
5 | 
6 | from ..utils.data import Dataset
7 | from ..utils.osutils import mkdir_if_missing
8 | from ..utils.serialization import write_json
9 | 
10 | 
11 | class CUHK03(Dataset):
12 |     url = 'https://docs.google.com/spreadsheet/viewform?usp=drive_web&formkey=dHRkMkFVSUFvbTJIRkRDLWRwZWpONnc6MA#gid=0'
13 |     md5 = '728939e58ad9f0ff53e521857dd8fb43'
14 | 
15 |     def __init__(self, root, split_id=0, num_val=100, download=True):
16 |         super(CUHK03, self).__init__(root, split_id=split_id)
17 | 
18 |         if download:
19 |             self.download()
20 | 
21 |         if not self._check_integrity():
22 |             raise RuntimeError("Dataset not found or corrupted. " +
" + 23 | "You can use download=True to download it.") 24 | 25 | self.load(num_val) 26 | 27 | def download(self): 28 | if self._check_integrity(): 29 | print("Files already downloaded and verified") 30 | return 31 | 32 | import h5py 33 | import hashlib 34 | from scipy.misc import imsave 35 | from zipfile import ZipFile 36 | 37 | raw_dir = osp.join(self.root, 'raw') 38 | mkdir_if_missing(raw_dir) 39 | 40 | # Download the raw zip file 41 | fpath = osp.join(raw_dir, 'cuhk03_release.zip') 42 | if osp.isfile(fpath) and \ 43 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 44 | print("Using downloaded file: " + fpath) 45 | else: 46 | raise RuntimeError("Please download the dataset manually from {} " 47 | "to {}".format(self.url, fpath)) 48 | 49 | # Extract the file 50 | exdir = osp.join(raw_dir, 'cuhk03_release') 51 | if not osp.isdir(exdir): 52 | print("Extracting zip file") 53 | with ZipFile(fpath) as z: 54 | z.extractall(path=raw_dir) 55 | 56 | # Format 57 | images_dir = osp.join(self.root, 'images') 58 | mkdir_if_missing(images_dir) 59 | matdata = h5py.File(osp.join(exdir, 'cuhk-03.mat'), 'r') 60 | 61 | def deref(ref): 62 | return matdata[ref][:].T 63 | 64 | def dump_(refs, pid, cam, fnames): 65 | for ref in refs: 66 | img = deref(ref) 67 | if img.size == 0 or img.ndim < 2: break 68 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, cam, len(fnames)) 69 | imsave(osp.join(images_dir, fname), img) 70 | fnames.append(fname) 71 | 72 | identities = [] 73 | for labeled, detected in zip( 74 | matdata['labeled'][0], matdata['detected'][0]): 75 | labeled, detected = deref(labeled), deref(detected) 76 | assert labeled.shape == detected.shape 77 | for i in range(labeled.shape[0]): 78 | pid = len(identities) 79 | images = [[], []] 80 | dump_(labeled[i, :5], pid, 0, images[0]) 81 | dump_(detected[i, :5], pid, 0, images[0]) 82 | dump_(labeled[i, 5:], pid, 1, images[1]) 83 | dump_(detected[i, 5:], pid, 1, images[1]) 84 | identities.append(images) 85 | 86 | # Save meta information into a json file 87 | meta = {'name': 'cuhk03', 'shot': 'multiple', 'num_cameras': 2, 88 | 'identities': identities} 89 | write_json(meta, osp.join(self.root, 'meta.json')) 90 | 91 | # Save training and test splits 92 | splits = [] 93 | view_counts = [deref(ref).shape[0] for ref in matdata['labeled'][0]] 94 | vid_offsets = np.r_[0, np.cumsum(view_counts)] 95 | for ref in matdata['testsets'][0]: 96 | test_info = deref(ref).astype(np.int32) 97 | test_pids = sorted( 98 | [int(vid_offsets[i-1] + j - 1) for i, j in test_info]) 99 | trainval_pids = list(set(range(vid_offsets[-1])) - set(test_pids)) 100 | split = {'trainval': trainval_pids, 101 | 'query': test_pids, 102 | 'gallery': test_pids} 103 | splits.append(split) 104 | write_json(splits, osp.join(self.root, 'splits.json')) 105 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 
21 |         single_gallery_shot=False,
22 |         first_match_break=False):
23 |     distmat = to_numpy(distmat)
24 |     m, n = distmat.shape
25 |     # Fill up default values
26 |     if query_ids is None:
27 |         query_ids = np.arange(m)
28 |     if gallery_ids is None:
29 |         gallery_ids = np.arange(n)
30 |     if query_cams is None:
31 |         query_cams = np.zeros(m).astype(np.int32)
32 |     if gallery_cams is None:
33 |         gallery_cams = np.ones(n).astype(np.int32)
34 |     # Ensure numpy array
35 |     query_ids = np.asarray(query_ids)
36 |     gallery_ids = np.asarray(gallery_ids)
37 |     query_cams = np.asarray(query_cams)
38 |     gallery_cams = np.asarray(gallery_cams)
39 |     # Sort and find correct matches
40 |     indices = np.argsort(distmat, axis=1)
41 |     matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
42 |     # Compute CMC for each query
43 |     ret = np.zeros(topk)
44 |     num_valid_queries = 0
45 |     for i in range(m):
46 |         # Filter out the same id and same camera
47 |         valid = ((gallery_ids[indices[i]] != query_ids[i]) |
48 |                  (gallery_cams[indices[i]] != query_cams[i]))
49 |         if separate_camera_set:
50 |             # Filter out samples from same camera
51 |             valid &= (gallery_cams[indices[i]] != query_cams[i])
52 |         if not np.any(matches[i, valid]): continue
53 |         if single_gallery_shot:
54 |             repeat = 10
55 |             gids = gallery_ids[indices[i][valid]]
56 |             inds = np.where(valid)[0]
57 |             ids_dict = defaultdict(list)
58 |             for j, x in zip(inds, gids):
59 |                 ids_dict[x].append(j)
60 |         else:
61 |             repeat = 1
62 |         for _ in range(repeat):
63 |             if single_gallery_shot:
64 |                 # Randomly choose one instance for each id
65 |                 sampled = (valid & _unique_sample(ids_dict, len(valid)))
66 |                 index = np.nonzero(matches[i, sampled])[0]
67 |             else:
68 |                 index = np.nonzero(matches[i, valid])[0]
69 |             delta = 1. / (len(index) * repeat)
70 |             for j, k in enumerate(index):
71 |                 if k - j >= topk: break
72 |                 if first_match_break:
73 |                     ret[k - j] += 1
74 |                     break
75 |                 ret[k - j] += delta
76 |         num_valid_queries += 1
77 |     if num_valid_queries == 0:
78 |         raise RuntimeError("No valid query")
79 |     return ret.cumsum() / num_valid_queries
80 | 
81 | 
82 | def mean_ap(distmat, query_ids=None, gallery_ids=None,
83 |             query_cams=None, gallery_cams=None):
84 |     distmat = to_numpy(distmat)
85 |     m, n = distmat.shape
86 |     # Fill up default values
87 |     if query_ids is None:
88 |         query_ids = np.arange(m)
89 |     if gallery_ids is None:
90 |         gallery_ids = np.arange(n)
91 |     if query_cams is None:
92 |         query_cams = np.zeros(m).astype(np.int32)
93 |     if gallery_cams is None:
94 |         gallery_cams = np.ones(n).astype(np.int32)
95 |     # Ensure numpy array
96 |     query_ids = np.asarray(query_ids)
97 |     gallery_ids = np.asarray(gallery_ids)
98 |     query_cams = np.asarray(query_cams)
99 |     gallery_cams = np.asarray(gallery_cams)
100 |     # Sort and find correct matches
101 |     indices = np.argsort(distmat, axis=1)
102 |     matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
103 |     # Compute AP for each query
104 |     aps = []
105 |     for i in range(m):
106 |         # Filter out the same id and same camera
107 |         valid = ((gallery_ids[indices[i]] != query_ids[i]) |
108 |                  (gallery_cams[indices[i]] != query_cams[i]))
109 |         y_true = matches[i, valid]
110 |         y_score = -distmat[i][indices[i]][valid]
111 |         if not np.any(y_true): continue
112 |         aps.append(average_precision_score(y_true, y_score))
113 |     if len(aps) == 0:
114 |         raise RuntimeError("No valid query")
115 |     return np.mean(aps)
116 | 
--------------------------------------------------------------------------------
/reid/evaluators.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | from collections import OrderedDict
4 | 
5 | import torch
6 | 
7 | from .evaluation_metrics import cmc, mean_ap
8 | from .feature_extraction import extract_cnn_feature
9 | from .utils.meters import AverageMeter
10 | 
11 | 
12 | def extract_features(model, data_loader, print_freq=1, metric=None):
13 |     model.eval()
14 |     batch_time = AverageMeter()
15 |     data_time = AverageMeter()
16 | 
17 |     features = OrderedDict()
18 |     labels = OrderedDict()
19 | 
20 |     end = time.time()
21 |     for i, (imgs, fnames, pids, _) in enumerate(data_loader):
22 |         data_time.update(time.time() - end)
23 | 
24 |         outputs = extract_cnn_feature(model, imgs)
25 |         # bs, ncrops, c, h, w = imgs.size()
26 |         # outputs = extract_cnn_feature(model, imgs.view(-1, c, h, w))
27 |         # outputs = outputs.view(bs, ncrops, -1).mean(1)
28 |         for fname, output, pid in zip(fnames, outputs, pids):
29 |             features[fname] = output
30 |             labels[fname] = pid
31 | 
32 |         batch_time.update(time.time() - end)
33 |         end = time.time()
34 | 
35 |         if (i + 1) % print_freq == 0:
36 |             print('Extract Features: [{}/{}]\t'
37 |                   'Time {:.3f} ({:.3f})\t'
38 |                   'Data {:.3f} ({:.3f})\t'
39 |                   .format(i + 1, len(data_loader),
40 |                           batch_time.val, batch_time.avg,
41 |                           data_time.val, data_time.avg))
42 | 
43 |     return features, labels
44 | 
45 | 
46 | def pairwise_distance(features, query=None, gallery=None, metric=None):
47 |     if query is None and gallery is None:
48 |         n = len(features)
49 |         x = torch.cat(list(features.values()))
50 |         x = x.view(n, -1)
51 |         if metric is not None:
52 |             x = metric.transform(x)
53 |         dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(n, n)
54 |         dist = dist + dist.t() - 2 * torch.mm(x, x.t())  # ||xi||^2 + ||xj||^2 - 2*xi.xj
55 |         return dist
56 | 
57 |     x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
58 |     y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
59 |     m, n = x.size(0), y.size(0)
60 |     x = x.view(m, -1)
61 |     y = y.view(n, -1)
62 |     if metric is not None:
63 |         x = metric.transform(x)
64 |         y = metric.transform(y)
65 |     dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
66 |            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
67 |     dist.addmm_(1, -2, x, y.t())
68 |     return dist
69 | 
70 | 
71 | def evaluate_all(distmat, query=None, gallery=None,
72 |                  query_ids=None, gallery_ids=None,
73 |                  query_cams=None, gallery_cams=None,
74 |                  cmc_topk=(1, 5, 10)):
75 |     if query is not None and gallery is not None:
76 |         query_ids = [pid for _, pid, _ in query]
77 |         gallery_ids = [pid for _, pid, _ in gallery]
78 |         query_cams = [cam for _, _, cam in query]
79 |         gallery_cams = [cam for _, _, cam in gallery]
80 |     else:
81 |         assert (query_ids is not None and gallery_ids is not None
82 |                 and query_cams is not None and gallery_cams is not None)
83 | 
84 |     # Compute mean AP
85 |     mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
86 |     print('Mean AP: {:4.1%}'.format(mAP))
87 | 
88 |     # Compute all kinds of CMC scores
89 |     cmc_configs = {
90 |         'allshots': dict(separate_camera_set=False,
91 |                          single_gallery_shot=False,
92 |                          first_match_break=False),
93 |         'cuhk03': dict(separate_camera_set=True,
94 |                        single_gallery_shot=True,
95 |                        first_match_break=False),
96 |         'market1501': dict(separate_camera_set=False,
97 |                            single_gallery_shot=False,
98 |                            first_match_break=True)}
99 |     cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
100 |                             query_cams, gallery_cams, **params)
101 |                   for name, params in cmc_configs.items()}
102 | 
103 |     print('CMC Scores{:>12}{:>12}{:>12}'
104 |           .format('allshots', 'cuhk03', 'market1501'))
105 |     for k in cmc_topk:
106 |         print('  top-{:<4}{:12.1%}{:12.1%}{:12.1%}'
107 |               .format(k, cmc_scores['allshots'][k - 1],
108 |                       cmc_scores['cuhk03'][k - 1],
109 |                       cmc_scores['market1501'][k - 1]))
110 | 
111 |     # Use the market1501 cmc top-1 score for validation criterion
112 |     return cmc_scores['market1501'][0]
113 |     # return mAP
114 | 
115 | class Evaluator(object):
116 |     def __init__(self, model):
117 |         super(Evaluator, self).__init__()
118 |         self.model = model
119 | 
120 |     def evaluate(self, data_loader, query, gallery, metric=None):
121 |         features, _ = extract_features(self.model, data_loader)
122 |         distmat = pairwise_distance(features, query, gallery, metric=metric)
123 |         return evaluate_all(distmat, query=query, gallery=gallery)
124 | 
--------------------------------------------------------------------------------
/reid/models/inception.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.nn import init
7 | 
8 | 
9 | __all__ = ['InceptionNet', 'inception']
10 | 
11 | 
12 | def _make_conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
13 |                bias=False):
14 |     conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
15 |                      stride=stride, padding=padding, bias=bias)
16 |     bn = nn.BatchNorm2d(out_planes)
17 |     relu = nn.ReLU(inplace=True)
18 |     return nn.Sequential(conv, bn, relu)
19 | 
20 | 
21 | class Block(nn.Module):
22 |     def __init__(self, in_planes, out_planes, pool_method, stride):
23 |         super(Block, self).__init__()
24 |         self.branches = nn.ModuleList([
25 |             nn.Sequential(
26 |                 _make_conv(in_planes, out_planes, kernel_size=1, padding=0),
27 |                 _make_conv(out_planes, out_planes, stride=stride)
28 |             ),
29 |             nn.Sequential(
30 |                 _make_conv(in_planes, out_planes, kernel_size=1, padding=0),
31 |                 _make_conv(out_planes, out_planes),
32 |                 _make_conv(out_planes, out_planes, stride=stride))
33 |         ])
34 | 
35 |         if pool_method == 'Avg':
36 |             assert stride == 1
37 |             self.branches.append(
38 |                 _make_conv(in_planes, out_planes, kernel_size=1, padding=0))
39 |             self.branches.append(nn.Sequential(
40 |                 nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
41 |                 _make_conv(in_planes, out_planes, kernel_size=1, padding=0)))
42 |         else:
43 |             self.branches.append(
44 |                 nn.MaxPool2d(kernel_size=3, stride=stride, padding=1))
45 | 
46 |     def forward(self, x):
47 |         return torch.cat([b(x) for b in self.branches], 1)
48 | 
49 | 
50 | class InceptionNet(nn.Module):
51 |     def __init__(self, cut_at_pooling=False, num_features=256, norm=False,
52 |                  dropout=0, num_classes=0):
53 |         super(InceptionNet, self).__init__()
54 |         self.cut_at_pooling = cut_at_pooling
55 | 
56 |         self.conv1 = _make_conv(3, 32)
57 |         self.conv2 = _make_conv(32, 32)
58 |         self.conv3 = _make_conv(32, 32)
59 |         self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
60 |         self.in_planes = 32
61 |         self.inception4a = self._make_inception(64, 'Avg', 1)
62 |         self.inception4b = self._make_inception(64, 'Max', 2)
63 |         self.inception5a = self._make_inception(128, 'Avg', 1)
64 |         self.inception5b = self._make_inception(128, 'Max', 2)
65 |         self.inception6a = self._make_inception(256, 'Avg', 1)
66 |         self.inception6b = self._make_inception(256, 'Max', 2)
67 | 
68 |         if not self.cut_at_pooling:
69 |             self.num_features = num_features
70 |             self.norm = norm
71 |             self.dropout = dropout
72 | self.has_embedding = num_features > 0 73 | self.num_classes = num_classes 74 | 75 | self.avgpool = nn.AdaptiveAvgPool2d(1) 76 | 77 | if self.has_embedding: 78 | self.feat = nn.Linear(self.in_planes, self.num_features) 79 | self.feat_bn = nn.BatchNorm1d(self.num_features) 80 | else: 81 | # Change the num_features to CNN output channels 82 | self.num_features = self.in_planes 83 | if self.dropout > 0: 84 | self.drop = nn.Dropout(self.dropout) 85 | if self.num_classes > 0: 86 | self.classifier = nn.Linear(self.num_features, self.num_classes) 87 | 88 | self.reset_params() 89 | 90 | def forward(self, x): 91 | x = self.conv1(x) 92 | x = self.conv2(x) 93 | x = self.conv3(x) 94 | x = self.pool3(x) 95 | x = self.inception4a(x) 96 | x = self.inception4b(x) 97 | x = self.inception5a(x) 98 | x = self.inception5b(x) 99 | x = self.inception6a(x) 100 | x = self.inception6b(x) 101 | 102 | if self.cut_at_pooling: 103 | return x 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | if self.has_embedding: 109 | x = self.feat(x) 110 | x = self.feat_bn(x) 111 | if self.norm: 112 | x = F.normalize(x) 113 | elif self.has_embedding: 114 | x = F.relu(x) 115 | if self.dropout > 0: 116 | x = self.drop(x) 117 | if self.num_classes > 0: 118 | x = self.classifier(x) 119 | return x 120 | 121 | def _make_inception(self, out_planes, pool_method, stride): 122 | block = Block(self.in_planes, out_planes, pool_method, stride) 123 | self.in_planes = (out_planes * 4 if pool_method == 'Avg' else 124 | out_planes * 2 + self.in_planes) 125 | return block 126 | 127 | def reset_params(self): 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | init.kaiming_normal(m.weight, mode='fan_out') 131 | if m.bias is not None: 132 | init.constant(m.bias, 0) 133 | elif isinstance(m, nn.BatchNorm2d): 134 | init.constant(m.weight, 1) 135 | init.constant(m.bias, 0) 136 | elif isinstance(m, nn.Linear): 137 | init.normal(m.weight, std=0.001) 138 | if m.bias is not None: 139 | init.constant(m.bias, 0) 140 | 141 | 142 | def inception(**kwargs): 143 | return InceptionNet(**kwargs) 144 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | x, y, _ = map(int, name.split('_')) 17 | assert pid == x and camid == y 18 | if relabel: 19 | ret.append((fname, index, camid)) 20 | else: 21 | ret.append((fname, pid, camid)) 22 | return ret 23 | 24 | def _pluck_gallery(identities, indices, relabel=False): 25 | ret = [] 26 | for index, pid in enumerate(indices): 27 | pid_images = identities[pid] 28 | for camid, cam_images in enumerate(pid_images): 29 | if len(cam_images[:-1])==0: 30 | for fname in cam_images: 31 | name = osp.splitext(fname)[0] 32 | x, y, _ = map(int, name.split('_')) 33 | assert pid == x and camid == y 34 | if relabel: 35 | ret.append((fname, index, camid)) 36 | else: 37 | ret.append((fname, pid, camid)) 38 | else: 39 | for fname in cam_images[:-1]: 40 | name = osp.splitext(fname)[0] 41 | x, y, _ = map(int, name.split('_')) 42 | assert pid == x and camid == y 43 
| if relabel: 44 | ret.append((fname, index, camid)) 45 | else: 46 | ret.append((fname, pid, camid)) 47 | return ret 48 | 49 | def _pluck_query(identities, indices, relabel=False): 50 | ret = [] 51 | for index, pid in enumerate(indices): 52 | pid_images = identities[pid] 53 | for camid, cam_images in enumerate(pid_images): 54 | for fname in cam_images[-1:]: 55 | name = osp.splitext(fname)[0] 56 | x, y, _ = map(int, name.split('_')) 57 | assert pid == x and camid == y 58 | if relabel: 59 | ret.append((fname, index, camid)) 60 | else: 61 | ret.append((fname, pid, camid)) 62 | return ret 63 | 64 | 65 | class Dataset(object): 66 | def __init__(self, root, split_id=0): 67 | self.root = root 68 | self.split_id = split_id 69 | self.meta = None 70 | self.split = None 71 | self.train, self.val, self.trainval = [], [], [] 72 | self.query, self.gallery = [], [] 73 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 74 | 75 | @property 76 | def images_dir(self): 77 | return osp.join(self.root, 'images') 78 | 79 | def load(self, num_val=0.3, verbose=True): 80 | splits = read_json(osp.join(self.root, 'splits.json')) 81 | if self.split_id >= len(splits): 82 | raise ValueError("split_id exceeds total splits {}" 83 | .format(len(splits))) 84 | self.split = splits[self.split_id] 85 | 86 | # Randomly split train / val 87 | trainval_pids = np.asarray(self.split['trainval']) 88 | np.random.shuffle(trainval_pids) 89 | num = len(trainval_pids) 90 | if isinstance(num_val, float): 91 | num_val = int(round(num * num_val)) 92 | if num_val >= num or num_val < 0: 93 | raise ValueError("num_val exceeds total identities {}" 94 | .format(num)) 95 | train_pids = sorted(trainval_pids[:-num_val]) 96 | val_pids = sorted(trainval_pids[-num_val:]) 97 | 98 | self.meta = read_json(osp.join(self.root, 'meta.json')) 99 | identities = self.meta['identities'] 100 | self.train = _pluck(identities, train_pids, relabel=True) 101 | self.val = _pluck(identities, val_pids, relabel=True) 102 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 103 | self.query = _pluck_query(identities, self.split['query']) 104 | #self.gallery = _pluck(identities, self.split['gallery']) 105 | self.gallery = _pluck_gallery(identities, self.split['gallery']) 106 | self.num_train_ids = len(train_pids) 107 | self.num_val_ids = len(val_pids) 108 | self.num_trainval_ids = len(trainval_pids) 109 | 110 | if verbose: 111 | print(self.__class__.__name__, "dataset loaded") 112 | print(" subset | # ids | # images") 113 | print(" ---------------------------") 114 | print(" train | {:5d} | {:8d}" 115 | .format(self.num_train_ids, len(self.train))) 116 | print(" val | {:5d} | {:8d}" 117 | .format(self.num_val_ids, len(self.val))) 118 | print(" trainval | {:5d} | {:8d}" 119 | .format(self.num_trainval_ids, len(self.trainval))) 120 | print(" query | {:5d} | {:8d}" 121 | .format(len(self.split['query']), len(self.query))) 122 | print(" gallery | {:5d} | {:8d}" 123 | .format(len(self.split['gallery']), len(self.gallery))) 124 | 125 | def _check_integrity(self): 126 | return osp.isdir(osp.join(self.root, 'images')) and \ 127 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 128 | osp.isfile(osp.join(self.root, 'splits.json')) 129 | -------------------------------------------------------------------------------- /reid/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | 5 | from 
..utils.data import Dataset 6 | from ..utils.osutils import mkdir_if_missing 7 | from ..utils.serialization import read_json 8 | from ..utils.serialization import write_json 9 | 10 | 11 | ######################## 12 | # Added 13 | def _pluck(identities, indices, relabel=False): 14 | """Extract im names of given pids. 15 | Args: 16 | identities: containing im names 17 | indices: pids 18 | relabel: whether to transform pids to classification labels 19 | """ 20 | ret = [] 21 | for index, pid in enumerate(indices): 22 | pid_images = identities[pid] 23 | for camid, cam_images in enumerate(pid_images): 24 | for fname in cam_images: 25 | name = osp.splitext(fname)[0] 26 | x, y, _ = map(int, name.split('_')) 27 | assert pid == x and camid == y 28 | if relabel: 29 | ret.append((fname, index, camid)) 30 | else: 31 | ret.append((fname, pid, camid)) 32 | return ret 33 | ######################## 34 | 35 | 36 | class Market1501(Dataset): 37 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view' 38 | md5 = '65005ab7d12ec1c44de4eeafe813e68a' 39 | 40 | def __init__(self, root, split_id=0, num_val=100, download=True): 41 | super(Market1501, self).__init__(root, split_id=split_id) 42 | 43 | if download: 44 | self.download() 45 | 46 | if not self._check_integrity(): 47 | raise RuntimeError("Dataset not found or corrupted. " + 48 | "You can use download=True to download it.") 49 | 50 | self.load(num_val) 51 | 52 | def download(self): 53 | if self._check_integrity(): 54 | print("Files already downloaded and verified") 55 | return 56 | 57 | import re 58 | import hashlib 59 | import shutil 60 | from glob import glob 61 | from zipfile import ZipFile 62 | 63 | raw_dir = osp.join(self.root, 'raw') 64 | mkdir_if_missing(raw_dir) 65 | 66 | # Download the raw zip file 67 | fpath = osp.join(raw_dir, 'Market-1501-v15.09.15.zip') 68 | if osp.isfile(fpath) and \ 69 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 70 | print("Using downloaded file: " + fpath) 71 | else: 72 | raise RuntimeError("Please download the dataset manually from {} " 73 | "to {}".format(self.url, fpath)) 74 | 75 | # Extract the file 76 | exdir = osp.join(raw_dir, 'Market-1501-v15.09.15') 77 | if not osp.isdir(exdir): 78 | print("Extracting zip file") 79 | with ZipFile(fpath) as z: 80 | z.extractall(path=raw_dir) 81 | 82 | # Format 83 | images_dir = osp.join(self.root, 'images') 84 | mkdir_if_missing(images_dir) 85 | 86 | # 1501 identities (+1 for background) with 6 camera views each 87 | identities = [[[] for _ in range(6)] for _ in range(1502)] 88 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 89 | fnames = [] ######### Added. Names of images in new dir. 
90 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 91 | pids = set() 92 | for fpath in fpaths: 93 | fname = osp.basename(fpath) 94 | pid, cam = map(int, pattern.search(fname).groups()) 95 | if pid == -1: continue # junk images are just ignored 96 | assert 0 <= pid <= 1501 # pid == 0 means background 97 | assert 1 <= cam <= 6 98 | cam -= 1 99 | pids.add(pid) 100 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 101 | .format(pid, cam, len(identities[pid][cam]))) 102 | identities[pid][cam].append(fname) 103 | shutil.copy(fpath, osp.join(images_dir, fname)) 104 | fnames.append(fname) ######### Added 105 | return pids, fnames 106 | 107 | trainval_pids, _ = register('bounding_box_train') 108 | gallery_pids, gallery_fnames = register('bounding_box_test') 109 | query_pids, query_fnames = register('query') 110 | assert query_pids <= gallery_pids 111 | assert trainval_pids.isdisjoint(gallery_pids) 112 | 113 | # Save meta information into a json file 114 | meta = {'name': 'Market1501', 'shot': 'multiple', 'num_cameras': 6, 115 | 'identities': identities, 116 | 'query_fnames': query_fnames, ######### Added 117 | 'gallery_fnames': gallery_fnames} ######### Added 118 | write_json(meta, osp.join(self.root, 'meta.json')) 119 | 120 | # Save the only training / test split 121 | splits = [{ 122 | 'trainval': sorted(list(trainval_pids)), 123 | 'query': sorted(list(query_pids)), 124 | 'gallery': sorted(list(gallery_pids))}] 125 | write_json(splits, osp.join(self.root, 'splits.json')) 126 | 127 | ######################## 128 | # Added 129 | def load(self, num_val=0.3, verbose=True): 130 | splits = read_json(osp.join(self.root, 'splits.json')) 131 | if self.split_id >= len(splits): 132 | raise ValueError("split_id exceeds total splits {}" 133 | .format(len(splits))) 134 | self.split = splits[self.split_id] 135 | 136 | # Randomly split train / val 137 | trainval_pids = np.asarray(self.split['trainval']) 138 | np.random.shuffle(trainval_pids) 139 | num = len(trainval_pids) 140 | if isinstance(num_val, float): 141 | num_val = int(round(num * num_val)) 142 | if num_val >= num or num_val < 0: 143 | raise ValueError("num_val exceeds total identities {}" 144 | .format(num)) 145 | train_pids = sorted(trainval_pids[:-num_val]) 146 | val_pids = sorted(trainval_pids[-num_val:]) 147 | 148 | self.meta = read_json(osp.join(self.root, 'meta.json')) 149 | identities = self.meta['identities'] 150 | 151 | self.train = _pluck(identities, train_pids, relabel=True) 152 | self.val = _pluck(identities, val_pids, relabel=True) 153 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 154 | self.num_train_ids = len(train_pids) 155 | self.num_val_ids = len(val_pids) 156 | self.num_trainval_ids = len(trainval_pids) 157 | 158 | ########## 159 | # Added 160 | query_fnames = self.meta['query_fnames'] 161 | gallery_fnames = self.meta['gallery_fnames'] 162 | self.query = [] 163 | for fname in query_fnames: 164 | name = osp.splitext(fname)[0] 165 | pid, cam, _ = map(int, name.split('_')) 166 | self.query.append((fname, pid, cam)) 167 | self.gallery = [] 168 | for fname in gallery_fnames: 169 | name = osp.splitext(fname)[0] 170 | pid, cam, _ = map(int, name.split('_')) 171 | self.gallery.append((fname, pid, cam)) 172 | ########## 173 | 174 | if verbose: 175 | print(self.__class__.__name__, "dataset loaded") 176 | print(" subset | # ids | # images") 177 | print(" ---------------------------") 178 | print(" train | {:5d} | {:8d}" 179 | .format(self.num_train_ids, len(self.train))) 180 | print(" val | {:5d} | {:8d}" 181 | 
.format(self.num_val_ids, len(self.val))) 182 | print(" trainval | {:5d} | {:8d}" 183 | .format(self.num_trainval_ids, len(self.trainval))) 184 | print(" query | {:5d} | {:8d}" 185 | .format(len(self.split['query']), len(self.query))) 186 | print(" gallery | {:5d} | {:8d}" 187 | .format(len(self.split['gallery']), len(self.gallery))) 188 | ######################## 189 | -------------------------------------------------------------------------------- /reid/extract_feature.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | 5 | import os 6 | import time 7 | import numpy as np 8 | import sys 9 | import torch 10 | from torch import nn 11 | from torch.backends import cudnn 12 | from torch.utils.data import DataLoader 13 | from torchvision.transforms import TenCrop, Resize, Lambda, CenterCrop 14 | 15 | from reid.feature_extraction import extract_cnn_feature 16 | from reid import datasets 17 | from reid import models 18 | from reid.utils.meters import AverageMeter 19 | from reid.dist_metric import DistanceMetric 20 | from reid.loss import TripletLoss 21 | from reid.trainers import Trainer 22 | from reid.evaluators import Evaluator 23 | from reid.utils.data import transforms as T 24 | from reid.utils.data.preprocessor import Preprocessor 25 | from reid.utils.data.sampler import RandomIdentitySampler 26 | from reid.utils.logging import Logger 27 | from reid.utils.serialization import load_checkpoint, save_checkpoint 28 | 29 | from test_dataset import TestDataset 30 | from sklearn.preprocessing import normalize 31 | 32 | 33 | GT_FILES = './data/market1501/raw/Market-1501-v15.09.15/gt_bbox' 34 | 35 | def get_dataloader(batch_size=64, workers=4): 36 | #prepare dataset 37 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 38 | std=[0.229, 0.224, 0.225]) 39 | 40 | 41 | test_transformer = T.Compose([ 42 | Resize((256, 128)), 43 | T.ToTensor(), 44 | normalizer, 45 | ]) 46 | 47 | #test_time augmentation 48 | #test_transformer = T.Compose([ 49 | # Resize((288, 144)), 50 | # TenCrop((256,128)), 51 | # Lambda(lambda crops: torch.stack([normalizer(T.ToTensor()(crop)) for crop in crops])), 52 | # ]) 53 | 54 | 55 | single_query_dir = './data/market1501/raw/Market-1501-v15.09.15/query' 56 | multi_query_dir = './data/market1501/raw/Market-1501-v15.09.15/gt_bbox' 57 | gallery_dir = './data/market1501/raw/Market-1501-v15.09.15/bounding_box_test' 58 | 59 | 60 | single_query_loader = DataLoader(TestDataset(root=single_query_dir, transform=test_transformer), 61 | batch_size=batch_size, 62 | num_workers=workers, 63 | shuffle=False, pin_memory=True) 64 | 65 | multi_query_loader = DataLoader(TestDataset(root=multi_query_dir, transform=test_transformer), 66 | batch_size=batch_size, 67 | num_workers=workers, 68 | shuffle=False, pin_memory=True) 69 | 70 | gallery_loader = DataLoader(TestDataset(root=gallery_dir, transform=test_transformer), 71 | batch_size=batch_size, 72 | num_workers=workers, 73 | shuffle=False, pin_memory=True) 74 | return single_query_loader, multi_query_loader, gallery_loader 75 | 76 | def extract_features(model, data_loader, print_freq=1, save_name='feature.mat'): 77 | 78 | batch_time = AverageMeter() 79 | data_time = AverageMeter() 80 | 81 | ids = [] 82 | cams = [] 83 | features = [] 84 | query_files = [] 85 | end = time.time() 86 | for i, (imgs, fnames) in enumerate(data_loader): 87 | data_time.update(time.time() - end) 88 | 89 | outputs = 
extract_cnn_feature(model, imgs) 90 | #for test time augmentation 91 | #bs, ncrops, c, h, w = imgs.size() 92 | #outputs = extract_cnn_feature(model, imgs.view(-1,c,h,w)) 93 | #outputs = outputs.view(bs,ncrops,-1).mean(1) 94 | for fname, output in zip(fnames, outputs): 95 | if fname[0]=='-': 96 | ids.append(-1) 97 | cams.append(int(fname[4])) 98 | else: 99 | ids.append(int(fname[:4])) 100 | cams.append(int(fname[6])) 101 | features.append(output.numpy()) 102 | query_files.append(fname) 103 | batch_time.update(time.time() - end) 104 | end = time.time() 105 | 106 | if (i + 1) % print_freq == 0: 107 | print('Extract Features: [{}/{}]\t' 108 | 'Time {:.3f} ({:.3f})\t' 109 | 'Data {:.3f} ({:.3f})\t' 110 | .format(i + 1, len(data_loader), 111 | batch_time.val, batch_time.avg, 112 | data_time.val, data_time.avg)) 113 | 114 | return features, ids, cams, query_files 115 | 116 | def evaluate(): 117 | print ('Get dataloader... ') 118 | single_query_loader, multi_query_loader, gallery_loader = get_dataloader() 119 | 120 | print ('Create and load pre-trained model...') 121 | model = models.create('resnet50',dropout=0.0, num_features=2048, num_classes=751) 122 | checkpoint = load_checkpoint('./logs/deep-person-1-new-augmentation/market1501-resnet50/model_best.pth.tar') 123 | model.load_state_dict(checkpoint['state_dict']) 124 | model = nn.DataParallel(model).cuda() 125 | model.eval() 126 | 127 | print ('Extract single_query&gallery feature...') 128 | single_query_feat, single_query_ids, single_query_cams, query_files = extract_features(model, single_query_loader) 129 | gallery_feat, gallery_ids, gallery_cams, _ = extract_features(model, gallery_loader) 130 | 131 | print ('Get multi_query feature...') 132 | multi_query_dict = dict() 133 | for i, (imgs, fnames) in enumerate(multi_query_loader): 134 | outputs = extract_cnn_feature(model, imgs) 135 | # test time augmentation 136 | #bs, ncrops, c, h, w = imgs.size() 137 | #outputs = extract_cnn_feature(model, imgs.view(-1,c,h,w)) 138 | #outputs = outputs.view(bs, ncrops, -1).mean(1) 139 | for fname, output in zip(fnames, outputs): 140 | if multi_query_dict.get(fname[:7])==None: 141 | multi_query_dict[fname[:7]]=[] 142 | multi_query_dict[fname[:7]].append(output.numpy()) 143 | 144 | query_max_feat = [] 145 | query_avg_feat = [] 146 | for query_file in query_files: 147 | index = query_file[:7] 148 | multi_features = multi_query_dict[index] 149 | multi_features = normalize(multi_features) 150 | query_max_feat.append(np.max(multi_features,axis=0)) 151 | query_avg_feat.append(np.mean(multi_features,axis=0)) 152 | 153 | 154 | assert len(query_max_feat)==len(query_avg_feat)==len(single_query_feat) 155 | 156 | print ('Write to mat file...') 157 | import scipy.io as sio 158 | if not os.path.exists('./matdata'): 159 | os.mkdir('./matdata') 160 | sio.savemat('./matdata/queryID.mat', {'queryID':np.array(single_query_ids)}) 161 | sio.savemat('./matdata/queryCAM.mat', {'queryCAM':np.array(single_query_cams)}) 162 | sio.savemat('./matdata/testID.mat', {'testID':np.array(gallery_ids)}) 163 | sio.savemat('./matdata/testCAM.mat', {'testCAM':np.array(gallery_cams)}) 164 | sio.savemat('./matdata/Hist_query.mat', {'Hist_query':np.array(single_query_feat)}) 165 | sio.savemat('./matdata/Hist_test.mat', {'Hist_test':np.array(gallery_feat)}) 166 | sio.savemat('./matdata/Hist_query_max.mat', {'Hist_max':np.array(query_max_feat)}) 167 | sio.savemat('./matdata/Hist_query_avg.mat', {'Hist_avg':np.array(query_avg_feat)}) 168 | 169 | return 170 | 171 | evaluate() 172 | 
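For reference, a hedged sketch of consuming the .mat files the script just wrote, reusing the repo's own metrics. It assumes the key names exactly as saved above and uses the raw (unnormalized) single-query features; note that junk images (pid -1) are kept as distractors rather than excluded, so scores can differ slightly from the standard Market-1501 protocol:

    import numpy as np
    import scipy.io as sio
    from reid.evaluation_metrics import cmc, mean_ap

    def load(name, key):
        return sio.loadmat('./matdata/' + name)[key]

    query = load('Hist_query.mat', 'Hist_query')       # (num_query, D)
    gallery = load('Hist_test.mat', 'Hist_test')       # (num_gallery, D)
    query_ids = load('queryID.mat', 'queryID').ravel()
    gallery_ids = load('testID.mat', 'testID').ravel()
    query_cams = load('queryCAM.mat', 'queryCAM').ravel()
    gallery_cams = load('testCAM.mat', 'testCAM').ravel()

    # Squared Euclidean distance between every query/gallery pair
    distmat = (np.power(query, 2).sum(1, keepdims=True)
               + np.power(gallery, 2).sum(1, keepdims=True).T
               - 2 * query.dot(gallery.T))

    print('mAP: {:.1%}'.format(
        mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)))
    print('top-1: {:.1%}'.format(
        cmc(distmat, query_ids, gallery_ids, query_cams, gallery_cams,
            first_match_break=True)[0]))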
-------------------------------------------------------------------------------- /reid/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import read_json 7 | from ..utils.serialization import write_json 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | """Extract im names of given pids. 11 | Args: 12 | identities: containing im names 13 | indices: pids 14 | relabel: whether to transform pids to classification labels 15 | """ 16 | ret = [] 17 | for index, pid in enumerate(indices): 18 | pid_images = identities[pid] 19 | for camid, cam_images in enumerate(pid_images): 20 | for fname in cam_images: 21 | name = osp.splitext(fname)[0] 22 | x, y, _ = map(int, name.split('_')) 23 | assert pid == x and camid == y 24 | if relabel: 25 | ret.append((fname, index, camid)) 26 | else: 27 | ret.append((fname, pid, camid)) 28 | return ret 29 | 30 | class DukeMTMC(Dataset): 31 | url = 'https://drive.google.com/uc?id=0B0VOCNYh8HeRdnBPa2ZWaVBYSVk' 32 | md5 = '2f93496f9b516d1ee5ef51c1d5e7d601' 33 | 34 | def __init__(self, root, split_id=0, num_val=100, download=True): 35 | super(DukeMTMC, self).__init__(root, split_id=split_id) 36 | 37 | if download: 38 | self.download() 39 | 40 | if not self._check_integrity(): 41 | raise RuntimeError("Dataset not found or corrupted. " + 42 | "You can use download=True to download it.") 43 | 44 | self.load(num_val) 45 | 46 | def download(self): 47 | if self._check_integrity(): 48 | print("Files already downloaded and verified") 49 | return 50 | 51 | import re 52 | import hashlib 53 | import shutil 54 | from glob import glob 55 | from zipfile import ZipFile 56 | 57 | raw_dir = osp.join(self.root, 'raw') 58 | mkdir_if_missing(raw_dir) 59 | 60 | # Download the raw zip file 61 | fpath = osp.join(raw_dir, 'DukeMTMC-reID.zip') 62 | if osp.isfile(fpath) and \ 63 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 64 | print("Using downloaded file: " + fpath) 65 | else: 66 | raise RuntimeError("Please download the dataset manually from {} " 67 | "to {}".format(self.url, fpath)) 68 | 69 | # Extract the file 70 | exdir = osp.join(raw_dir, 'DukeMTMC-reID') 71 | if not osp.isdir(exdir): 72 | print("Extracting zip file") 73 | with ZipFile(fpath) as z: 74 | z.extractall(path=raw_dir) 75 | 76 | # Format 77 | images_dir = osp.join(self.root, 'images') 78 | mkdir_if_missing(images_dir) 79 | 80 | identities = [] 81 | all_pids = {} 82 | 83 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 84 | fnames = [] ###### New Add. 
Names of images in new dir 85 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 86 | pids = set() 87 | for fpath in fpaths: 88 | fname = osp.basename(fpath) 89 | pid, cam = map(int, pattern.search(fname).groups()) 90 | assert 1 <= cam <= 8 91 | cam -= 1 92 | if pid not in all_pids: 93 | all_pids[pid] = len(all_pids) 94 | pid = all_pids[pid] 95 | pids.add(pid) 96 | if pid >= len(identities): 97 | assert pid == len(identities) 98 | identities.append([[] for _ in range(8)]) # 8 camera views 99 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 100 | .format(pid, cam, len(identities[pid][cam]))) 101 | identities[pid][cam].append(fname) 102 | shutil.copy(fpath, osp.join(images_dir, fname)) 103 | fnames.append(fname)######## added 104 | return pids, fnames 105 | 106 | trainval_pids, _ = register('bounding_box_train') 107 | gallery_pids, gallery_fnames = register('bounding_box_test') 108 | query_pids, query_fnames = register('query') 109 | assert query_pids <= gallery_pids 110 | assert trainval_pids.isdisjoint(gallery_pids) 111 | 112 | # Save meta information into a json file 113 | meta = {'name': 'DukeMTMC', 'shot': 'multiple', 'num_cameras': 8, 114 | 'identities': identities, 115 | 'query_fnames': query_fnames,########## Added 116 | 'gallery_fnames': gallery_fnames} ######### Added 117 | write_json(meta, osp.join(self.root, 'meta.json')) 118 | 119 | # Save the only training / test split 120 | splits = [{ 121 | 'trainval': sorted(list(trainval_pids)), 122 | 'query': sorted(list(query_pids)), 123 | 'gallery': sorted(list(gallery_pids))}] 124 | write_json(splits, osp.join(self.root, 'splits.json')) 125 | 126 | ######################## 127 | # Added 128 | def load(self, num_val=0.3, verbose=True): 129 | import numpy as np 130 | splits = read_json(osp.join(self.root, 'splits.json')) 131 | if self.split_id >= len(splits): 132 | raise ValueError("split_id exceeds total splits {}" 133 | .format(len(splits))) 134 | self.split = splits[self.split_id] 135 | 136 | # Randomly split train / val 137 | trainval_pids = np.asarray(self.split['trainval']) 138 | np.random.shuffle(trainval_pids) 139 | num = len(trainval_pids) 140 | if isinstance(num_val, float): 141 | num_val = int(round(num * num_val)) 142 | if num_val >= num or num_val < 0: 143 | raise ValueError("num_val exceeds total identities {}" 144 | .format(num)) 145 | train_pids = sorted(trainval_pids[:-num_val]) 146 | val_pids = sorted(trainval_pids[-num_val:]) 147 | 148 | self.meta = read_json(osp.join(self.root, 'meta.json')) 149 | identities = self.meta['identities'] 150 | 151 | self.train = _pluck(identities, train_pids, relabel=True) 152 | self.val = _pluck(identities, val_pids, relabel=True) 153 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 154 | self.num_train_ids = len(train_pids) 155 | self.num_val_ids = len(val_pids) 156 | self.num_trainval_ids = len(trainval_pids) 157 | 158 | ########## 159 | # Added 160 | query_fnames = self.meta['query_fnames'] 161 | gallery_fnames = self.meta['gallery_fnames'] 162 | self.query = [] 163 | for fname in query_fnames: 164 | name = osp.splitext(fname)[0] 165 | pid, cam, _ = map(int, name.split('_')) 166 | self.query.append((fname, pid, cam)) 167 | self.gallery = [] 168 | for fname in gallery_fnames: 169 | name = osp.splitext(fname)[0] 170 | pid, cam, _ = map(int, name.split('_')) 171 | self.gallery.append((fname, pid, cam)) 172 | ########## 173 | 174 | if verbose: 175 | print(self.__class__.__name__, "dataset loaded") 176 | print(" subset | # ids | # images") 177 | print(" 
---------------------------") 178 | print(" train | {:5d} | {:8d}" 179 | .format(self.num_train_ids, len(self.train))) 180 | print(" val | {:5d} | {:8d}" 181 | .format(self.num_val_ids, len(self.val))) 182 | print(" trainval | {:5d} | {:8d}" 183 | .format(self.num_trainval_ids, len(self.trainval))) 184 | print(" query | {:5d} | {:8d}" 185 | .format(len(self.split['query']), len(self.query))) 186 | print(" gallery | {:5d} | {:8d}" 187 | .format(len(self.split['gallery']), len(self.gallery))) 188 | ######################## 189 | -------------------------------------------------------------------------------- /mancs_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import os 5 | import numpy as np 6 | import sys 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | from torch.utils.data import DataLoader 11 | from torchvision.transforms import TenCrop, Lambda, Resize 12 | sys.path.append(osp.join(os.getcwd(), 'open-reid')) 13 | from reid import datasets 14 | from reid import models 15 | from reid.dist_metric import DistanceMetric 16 | from reid.loss import TripletLoss,FocalLoss 17 | from reid.trainers import Trainer 18 | from reid.evaluators import Evaluator 19 | from reid.utils.data import transforms as T 20 | from reid.utils.data.preprocessor import Preprocessor 21 | from reid.utils.data.sampler import RandomIdentitySampler 22 | from reid.utils.logging import Logger 23 | from reid.utils.serialization import load_checkpoint, save_checkpoint 24 | 25 | 26 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances, 27 | workers, combine_trainval): 28 | root = osp.join(data_dir, name) 29 | 30 | dataset = datasets.create(name, root, num_val=0.1, split_id=split_id) 31 | 32 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 33 | std=[0.229, 0.224, 0.225]) 34 | 35 | train_set = dataset.trainval if combine_trainval else dataset.train 36 | num_classes = (dataset.num_trainval_ids if combine_trainval 37 | else dataset.num_train_ids) 38 | 39 | train_transformer = T.Compose([ 40 | Resize((256,128)), 41 | T.RandomSizedRectCrop(height, width), 42 | T.RandomHorizontalFlip(), 43 | T.ToTensor(), 44 | normalizer, 45 | T.RandomErasing(probability=0.5,sh=0.2,r1=0.3) 46 | ]) 47 | 48 | test_transformer = T.Compose([ 49 | T.RectScale(height, width), 50 | T.ToTensor(), 51 | normalizer, 52 | ]) 53 | 54 | 55 | train_loader = DataLoader( 56 | Preprocessor(train_set, root=dataset.images_dir, 57 | transform=train_transformer), 58 | batch_size=batch_size, num_workers=workers, 59 | sampler=RandomIdentitySampler(train_set, num_instances), 60 | pin_memory=True, drop_last=True) 61 | 62 | val_loader = DataLoader( 63 | Preprocessor(dataset.val, root=dataset.images_dir, 64 | transform=test_transformer), 65 | batch_size=batch_size, num_workers=workers, 66 | shuffle=False, pin_memory=True) 67 | 68 | test_loader = DataLoader( 69 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 70 | root=dataset.images_dir, transform=test_transformer), 71 | batch_size=batch_size, num_workers=workers, 72 | shuffle=False, pin_memory=True) 73 | 74 | return dataset, num_classes, train_loader, val_loader, test_loader 75 | 76 | 77 | def main(args): 78 | np.random.seed(args.seed) 79 | torch.manual_seed(args.seed) 80 | cudnn.benchmark = True 81 | 82 | # Redirect print to both console and log file 83 | if not args.evaluate: 84 | sys.stdout = 
Logger(osp.join(args.logs_dir, 'log.txt')) 85 | 86 | # Create data loaders 87 | assert args.num_instances > 1, "num_instances should be greater than 1" 88 | assert args.batch_size % args.num_instances == 0, \ 89 | 'num_instances should divide batch_size' 90 | if args.height is None or args.width is None: 91 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 92 | (256, 128) 93 | dataset, num_classes, train_loader, val_loader, test_loader = \ 94 | get_data(args.dataset, args.split, args.data_dir, args.height, 95 | args.width, args.batch_size, args.num_instances, args.workers, 96 | args.combine_trainval) 97 | # Create model 98 | # Hacking here to let the classifier be the last feature embedding layer 99 | # Net structure: avgpool -> FC(1024) -> FC(args.features) 100 | model = models.create(args.arch, num_classes=num_classes) 101 | 102 | # Load from checkpoint 103 | start_epoch = best_top1 = 0 104 | if args.resume: 105 | checkpoint = load_checkpoint(args.resume) 106 | model.load_state_dict(checkpoint['state_dict']) 107 | #start_epoch = checkpoint['epoch'] 108 | best_top1 = checkpoint['best_top1'] 109 | print("=> Start epoch {} best top1 {:.1%}" 110 | .format(start_epoch, best_top1)) 111 | model = nn.DataParallel(model).cuda() 112 | 113 | # Distance metric 114 | metric = DistanceMetric(algorithm=args.dist_metric) 115 | 116 | # Evaluator 117 | evaluator = Evaluator(model) 118 | if args.evaluate: 119 | metric.train(model, train_loader) 120 | #print("Validation:") 121 | #evaluator.evaluate(val_loader, dataset.val, dataset.val, metric) 122 | print("Test:") 123 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric) 124 | return 125 | 126 | # Criterion 127 | criterion = [] 128 | criterion.append(TripletLoss(margin=args.margin,num_instances=args.num_instances).cuda()) 129 | #criterion.append(nn.CrossEntropyLoss().cuda()) 130 | criterion.append(FocalLoss().cuda()) 131 | criterion.append(nn.CrossEntropyLoss().cuda()) 132 | 133 | #multi lr 134 | base_param_ids = set(map(id, model.module.base.parameters())) 135 | new_params = [p for p in model.parameters() if 136 | id(p) not in base_param_ids] 137 | param_groups = [ 138 | {'params': model.module.base.parameters(), 'lr_mult': 1.0}, 139 | {'params': new_params, 'lr_mult': 3.0}] 140 | # Optimizer 141 | optimizer = torch.optim.Adam(param_groups, lr=args.lr, 142 | weight_decay=args.weight_decay) 143 | #optimizer = torch.optim.SGD(param_groups, lr=args.lr, 144 | #momentum=0.9, weight_decay=args.weight_decay) 145 | 146 | # Trainer 147 | trainer = Trainer(model, criterion) 148 | # Schedule learning rate 149 | def adjust_lr(epoch): 150 | lr = args.lr if epoch <= 100 else \ 151 | args.lr * (0.001 ** ((epoch - 100) / 50.0)) 152 | for g in optimizer.param_groups: 153 | g['lr'] = lr * g.get('lr_mult', 1) 154 | 155 | # Start training 156 | for epoch in range(start_epoch, args.epochs): 157 | adjust_lr(epoch) 158 | trainer.train(epoch, train_loader, optimizer) 159 | if epoch < args.start_save: 160 | continue 161 | #top1 = evaluator.evaluate(val_loader, dataset.val, dataset.val) 162 | top1 = evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric) 163 | 164 | is_best = top1 > best_top1 165 | best_top1 = max(top1, best_top1) 166 | save_checkpoint({ 167 | 'state_dict': model.module.state_dict(), 168 | 'epoch': epoch + 1, 169 | 'best_top1': best_top1, 170 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 171 | 172 | print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'. 
173 |               format(epoch, top1, best_top1, ' *' if is_best else ''))
174 | 
175 |     # Final test
176 |     print('Test with best model:')
177 |     checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
178 |     model.module.load_state_dict(checkpoint['state_dict'])
179 |     metric.train(model, train_loader)
180 |     evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
181 | 
182 | 
183 | if __name__ == '__main__':
184 |     parser = argparse.ArgumentParser(description="Triplet loss classification")
185 |     # data
186 |     parser.add_argument('-d', '--dataset', type=str, default='cuhk03',
187 |                         choices=datasets.names())
188 |     parser.add_argument('-b', '--batch-size', type=int, default=256)
189 |     parser.add_argument('-j', '--workers', type=int, default=4)
190 |     parser.add_argument('--split', type=int, default=0)
191 |     parser.add_argument('--height', type=int,
192 |                         help="input height, default: 256 for resnet*, "
193 |                              "144 for inception")
194 |     parser.add_argument('--width', type=int,
195 |                         help="input width, default: 128 for resnet*, "
196 |                              "56 for inception")
197 |     parser.add_argument('--combine-trainval', action='store_true',
198 |                         help="train and val sets together for training, "
199 |                              "val set alone for validation")
200 |     parser.add_argument('--num-instances', type=int, default=4,
201 |                         help="each minibatch consists of "
202 |                              "(batch_size // num_instances) identities, and "
203 |                              "each identity has num_instances instances, "
204 |                              "default: 4")
205 |     # model
206 |     parser.add_argument('-a', '--arch', type=str, default='resnet50',
207 |                         choices=models.names())
208 |     parser.add_argument('--features', type=int, default=128)
209 |     parser.add_argument('--dropout', type=float, default=0)
210 |     # loss
211 |     parser.add_argument('--margin', type=float, default=0.5,
212 |                         help="margin of the triplet loss, default: 0.5")
213 |     # optimizer
214 |     parser.add_argument('--lr', type=float, default=0.0002,
215 |                         help="learning rate of all parameters")
216 |     parser.add_argument('--weight-decay', type=float, default=5e-4)
217 |     # training configs
218 |     parser.add_argument('--resume', type=str, default='', metavar='PATH')
219 |     parser.add_argument('--evaluate', action='store_true',
220 |                         help="evaluation only")
221 |     parser.add_argument('--epochs', type=int, default=150)
222 |     parser.add_argument('--start_save', type=int, default=0,
223 |                         help="start saving checkpoints after a specific epoch")
224 |     parser.add_argument('--seed', type=int, default=1)
225 |     parser.add_argument('--print-freq', type=int, default=1)
226 |     # metric learning
227 |     parser.add_argument('--dist-metric', type=str, default='euclidean',
228 |                         choices=['euclidean', 'kissme'])
229 |     # misc
230 |     working_dir = osp.dirname(osp.abspath(__file__))
231 |     parser.add_argument('--data-dir', type=str, metavar='PATH',
232 |                         default=osp.join(working_dir, 'data'))
233 |     parser.add_argument('--logs-dir', type=str, metavar='PATH',
234 |                         default=osp.join(working_dir, 'logs'))
235 |     main(parser.parse_args())
236 | 
--------------------------------------------------------------------------------
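The adjust_lr() rule in main() above keeps the base learning rate flat for the first 100 epochs and then decays it exponentially (by a factor of 0.001 every further 50 epochs, with a continuous exponent). A standalone sketch of the same schedule, with a few checkpoint values:

    def mancs_lr(epoch, base_lr=2e-4):
        # Mirrors adjust_lr() in mancs_train.py: flat for 100 epochs, then
        # base_lr * 0.001 ** ((epoch - 100) / 50), i.e. x0.001 every 50 epochs.
        return base_lr if epoch <= 100 else base_lr * (0.001 ** ((epoch - 100) / 50.0))

    for e in (1, 100, 125, 150):
        print(e, '{:.2e}'.format(mancs_lr(e)))
    # 1 and 100 -> 2.00e-04; 125 -> 6.32e-06; 150 -> 2.00e-07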
/reid/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import math
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn import init
8 | from torch.autograd import Variable
9 | import torchvision
10 | from torch_deform_conv.layers import ConvOffset2D  # imported but unused in this file
11 | 
12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
13 |            'resnet152']
14 | 
15 | class ResNet(nn.Module):
16 |     __factory = {
17 |         18: torchvision.models.resnet18,
18 |         34: torchvision.models.resnet34,
19 |         50: torchvision.models.resnet50,
20 |         101: torchvision.models.resnet101,
21 |         152: torchvision.models.resnet152,
22 |     }
23 | 
24 |     def __init__(self, depth, pretrained=True, num_features=2048,
25 |                  dropout=0.1, num_classes=0):
26 |         super(ResNet, self).__init__()
27 | 
28 |         self.depth = depth
29 |         self.pretrained = pretrained
30 | 
31 | 
32 |         # Construct base (pretrained) resnet
33 |         if depth not in ResNet.__factory:
34 |             raise KeyError("Unsupported depth:", depth)
35 |         self.base = ResNet.__factory[depth](pretrained=pretrained)
36 |         # Change layer4[-1] relu to prelu
37 |         self.base._modules['layer4'][-1].relu = nn.PReLU()
38 | 
39 |         self.num_features = num_features
40 |         self.dropout = dropout
41 |         self.num_classes = num_classes
42 | 
43 |         out_planes = self.base.fc.in_features
44 |         # In deep person has_embedding is always False, so self.num_features == out_planes
45 |         # for x2 embedding
46 |         self.feat = nn.Linear(out_planes, self.num_features, bias=False)
47 |         self.feat_bn = nn.BatchNorm1d(self.num_features)
48 |         self.prelu = nn.PReLU()
49 |         init.normal(self.feat.weight, std=0.001)
50 |         init.constant(self.feat_bn.weight, 1)
51 |         init.constant(self.feat_bn.bias, 0)
52 | 
53 |         if self.dropout > 0:
54 |             self.drop = nn.Dropout(self.dropout)
55 |         # x2 classifier
56 |         self.classifier_x2 = nn.Linear(self.num_features, self.num_classes)
57 |         init.normal(self.classifier_x2.weight, std=0.001)
58 |         init.constant(self.classifier_x2.bias, 0)
59 |         # x3 module
60 |         self.att1 = SELayer(256, reduction=16)
61 |         self.att2 = SELayer(512, reduction=16)
62 |         self.att3 = SELayer(1024, reduction=16)
63 |         unset_module = []
64 |         self.embed_x3 = nn.Linear(1792, 1792, bias=False)
65 |         unset_module.append(self.embed_x3)
66 |         self.bn_x3 = nn.BatchNorm1d(1792)  # embed_x3 outputs (B, 1792); BatchNorm2d would reject a 2-d input
67 |         unset_module.append(self.bn_x3)
68 |         self.prelu_x3 = nn.PReLU()
69 |         self.classifier_x3 = nn.Linear(1792, self.num_classes)
70 |         unset_module.append(self.classifier_x3)
71 |         if not self.pretrained:
72 |             self.reset_params()
73 | 
74 |     def forward(self, x):
75 |         for name, module in self.base._modules.items():
76 |             if name == 'layer2':
77 |                 att1 = self.att1(x)
78 |                 x = x + att1 * x
79 |                 att1 = att1 * x
80 |             if name == 'layer3':
81 |                 att2 = self.att2(x)
82 |                 x = x + att2 * x
83 |                 att2 = att2 * x
84 |             if name == 'layer4':
85 |                 att3 = self.att3(x)
86 |                 x = x + att3 * x
87 |                 att3 = att3 * x
88 |             if name == 'avgpool':
89 |                 break
90 |             x = module(x)
91 |         # triplet loss branch
92 |         x1 = F.avg_pool2d(x, x.size()[2:])
93 |         x1 = x1.view(x1.size(0), -1)
94 |         # global feature branch
95 |         x2 = F.avg_pool2d(x, x.size()[2:])
96 |         x2 = x2.view(x2.size(0), -1)
97 |         x2 = self.feat(x2)
98 |         x2 = self.feat_bn(x2)
99 |         x2 = self.prelu(x2)
100 |         x2 = self.drop(x2)
101 |         x2 = self.classifier_x2(x2)
102 |         # x3 module
103 |         #######
104 |         att1 = F.adaptive_avg_pool2d(att1, 1)
105 |         att1 = att1.view(att1.size(0), -1)
106 |         att2 = F.adaptive_avg_pool2d(att2, 1)
107 |         att2 = att2.view(att2.size(0), -1)
108 |         att3 = F.adaptive_avg_pool2d(att3, 1)
109 |         att3 = att3.view(att3.size(0), -1)
110 |         att_feat = torch.cat([att1, att2, att3], 1)
111 |         att_feat = self.embed_x3(att_feat)
112 |         att_feat = self.bn_x3(att_feat)
113 |         att_feat = self.prelu_x3(att_feat)  # use the x3 branch's own PReLU, otherwise prelu_x3 is dead
114 |         att_feat = self.drop(att_feat)
115 |         x3 = self.classifier_x3(att_feat)
116 |         return x1, x2, x3
117 | 
118 |     def reset_params(self):
119 |         for m in self.modules():
120 |             if isinstance(m, nn.Conv2d):
121 | 
init.normal(m.weight, std=0.001) 122 | if m.bias is not None: 123 | init.constant(m.bias, 0) 124 | elif isinstance(m, nn.BatchNorm2d): 125 | init.constant(m.weight, 1) 126 | init.constant(m.bias, 0) 127 | elif isinstance(m, nn.Linear): 128 | init.normal(m.weight, std=0.001) 129 | if m.bias is not None: 130 | init.constant(m.bias, 0) 131 | 132 | class SELayer(nn.Module): 133 | def __init__(self, channel, reduction=16): 134 | super(SELayer, self).__init__() 135 | self.conv = nn.Sequential( 136 | nn.Conv2d(channel, reduction, 1), 137 | nn.PReLU(), 138 | nn.Conv2d(reduction, channel, 1), 139 | nn.Sigmoid() 140 | ) 141 | 142 | for m in self.modules(): 143 | if isinstance(m, nn.Linear): 144 | init.normal(m.weight, std=0.001) 145 | if m.bias is not None: 146 | init.constant(m.bias, 0) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | m.weight.data.fill_(1) 149 | m.bias.data.zero_() 150 | elif isinstance(m, nn.Conv2d): 151 | init.normal(m.weight, std=0.001) 152 | if m.bias is not None: 153 | init.constant(m.bias, 0) 154 | def forward(self, x): 155 | y = self.conv(x) 156 | return y 157 | 158 | class Attention_module(nn.Module): 159 | def __init__(self,in_planes,feat_nums,num_classes): 160 | super(Attention_module, self).__init__() 161 | self.conv = nn.Conv2d(in_planes,feat_nums,3,padding=1,groups=1) 162 | self.pool = nn.AdaptiveMaxPool2d((1,1)) 163 | self.classify_1 = nn.Linear(feat_nums,num_classes) 164 | for m in self.modules(): 165 | if isinstance(m, nn.Linear): 166 | init.normal(m.weight, std=0.001) 167 | if m.bias is not None: 168 | init.constant(m.bias, 0) 169 | elif isinstance(m, nn.BatchNorm2d): 170 | m.weight.data.fill_(1) 171 | m.bias.data.zero_() 172 | elif isinstance(m, nn.Conv2d): 173 | init.normal(m.weight, std=0.001) 174 | if m.bias is not None: 175 | init.constant(m.bias, 0) 176 | 177 | def forward(self, x): 178 | feat = x = self.conv(x) 179 | x = self.pool(x) 180 | x = F.relu(x) 181 | x = x.view(x.size(0),-1) 182 | x = self.classify_1(x) 183 | pred = torch.max(x,1)[1] 184 | wt = self.classify_1.weight[pred.data] # B*C 185 | wt = wt.view(wt.size(0), wt.size(1),1,1) 186 | cam = (wt*feat).sum(1,keepdim=True) 187 | cam = fm_norm(cam) 188 | cam_mask = F.sigmoid(cam) 189 | return x,cam_mask 190 | 191 | def fm_norm(inputs,p=2): 192 | ''' 193 | input should have shape of B*C*H*W 194 | ''' 195 | b,c,h,w = inputs.size() 196 | inputs = inputs.view(b, c, -1) 197 | inputs = F.normalize(inputs,p=p,dim=2) 198 | inputs = inputs.view(b,c,h,w) 199 | 200 | return inputs 201 | 202 | def init_params(*modules): 203 | ''' 204 | modules should be list or tuple 205 | ''' 206 | def reset_params(m): 207 | if isinstance(m, nn.Conv2d): 208 | init.normal(m.weight, std=0.001) 209 | if m.bias is not None: 210 | init.constant(m.bias, 0) 211 | elif isinstance(m, nn.Linear): 212 | init.normal(m.weight, std=0.001) 213 | if m.bias is not None: 214 | init.constant(m.bias, 0) 215 | elif isinstance(m, nn.BatchNorm2d): 216 | init.constant(m.weight,1) 217 | init.constant(m.bias, 0) 218 | 219 | for m in modules: 220 | reset_params(m) 221 | 222 | class Bottleneck(nn.Module): 223 | expansion = 4 224 | 225 | def __init__(self, inplanes, planes, stride=1): 226 | 227 | super(Bottleneck, self).__init__() 228 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 229 | self.bn1 = nn.BatchNorm2d(planes) 230 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 231 | padding=1, bias=False) 232 | self.bn2 = nn.BatchNorm2d(planes) 233 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 
234 | self.bn3 = nn.BatchNorm2d(planes * 4) 235 | self.relu = nn.ReLU(inplace=True) 236 | self.downsample = nn.Sequential(nn.Conv2d(inplanes,planes*4,1,bias=False), 237 | nn.BatchNorm2d(planes*4)) 238 | #self.downsample = None 239 | self.stride = stride 240 | 241 | for m in self.modules(): 242 | if isinstance(m, nn.Linear): 243 | init.normal(m.weight, std=0.001) 244 | if m.bias is not None: 245 | init.constant(m.bias, 0) 246 | elif isinstance(m, nn.BatchNorm2d): 247 | m.weight.data.fill_(1) 248 | m.bias.data.zero_() 249 | elif isinstance(m, nn.Conv2d): 250 | init.normal(m.weight, std=0.001) 251 | if m.bias is not None: 252 | init.constant(m.bias, 0) 253 | 254 | def forward(self, x): 255 | residual = x 256 | 257 | out = self.conv1(x) 258 | out = self.bn1(out) 259 | out = self.relu(out) 260 | 261 | out = self.conv2(out) 262 | out = self.bn2(out) 263 | out = self.relu(out) 264 | 265 | out = self.conv3(out) 266 | out = self.bn3(out) 267 | 268 | if self.downsample is not None: 269 | residual = self.downsample(x) 270 | 271 | out += residual 272 | out = self.relu(out) 273 | 274 | return out 275 | 276 | def resnet18(**kwargs): 277 | return ResNet(18, **kwargs) 278 | 279 | 280 | def resnet34(**kwargs): 281 | return ResNet(34, **kwargs) 282 | 283 | 284 | def resnet50(**kwargs): 285 | return ResNet(50, **kwargs) 286 | 287 | 288 | def resnet101(**kwargs): 289 | return ResNet(101, **kwargs) 290 | 291 | 292 | def resnet152(**kwargs): 293 | return ResNet(152, **kwargs) 294 | 295 | --------------------------------------------------------------------------------
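Finally, a minimal smoke test for the three-branch ResNet above, assuming the Variable-era PyTorch this repository targets (and the BatchNorm1d fix in the x3 head). The expected shapes follow from the code: a 2048-d pooled triplet feature plus two num_classes logit heads, the second fed by the 256 + 512 + 1024 = 1792-d attention feature:

    import torch
    from torch.autograd import Variable  # repo targets the old Variable API
    from reid.models.resnet import resnet50

    model = resnet50(pretrained=False, num_classes=751)
    model.eval()  # eval mode so BatchNorm handles a tiny batch deterministically
    x1, x2, x3 = model(Variable(torch.randn(2, 3, 256, 128)))
    print(x1.size(), x2.size(), x3.size())
    # expected: (2, 2048), (2, 751), (2, 751)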