├── reid
    ├── models
    │   ├── backbones
    │   │   ├── __init__.py
    │   │   ├── AIBN.py
    │   │   ├── TNorm.py
    │   │   └── resnet.py
    │   ├── __init__.py
    │   └── ft_net.py
    ├── evaluation_metrics
    │   ├── __init__.py
    │   ├── classification.py
    │   └── ranking.py
    ├── utils
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── preprocessor_camstyle.py
    │   │   ├── preprocessor.py
    │   │   ├── transforms.py
    │   │   ├── dataset.py
    │   │   ├── camera_dataset.py
    │   │   └── sampler.py
    │   ├── osutils.py
    │   ├── meters.py
    │   ├── __init__.py
    │   ├── logging.py
    │   └── serialization.py
    ├── loss
    │   ├── __init__.py
    │   ├── entropy_regularization.py
    │   └── triplet.py
    ├── cluster_utils
    │   ├── __init__.py
    │   ├── rerank.py
    │   └── cluster.py
    ├── __init__.py
    ├── feature_extraction
    │   ├── __init__.py
    │   ├── database.py
    │   └── cnn.py
    ├── datasets
    │   ├── cluster_dataset.py
    │   ├── __init__.py
    │   ├── dukemtmc.py
    │   ├── market1501.py
    │   └── msmt17.py
    ├── evaluators.py
    ├── evaluators_cos.py
    └── trainers.py
├── img
    └── fig1.png
├── .vscode
    └── settings.json
├── script
    ├── train_duke.sh
    ├── test_duke.sh
    ├── test_market.sh
    ├── train_market.sh
    ├── test_msmt.sh
    └── train_msmt.sh
├── .gitignore
├── README.md
├── requirements.txt
└── example
    └── iids_tnorm_self_kd.py

--------------------------------------------------------------------------------
/reid/models/backbones/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/img/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SY-Xuan/IIDS/HEAD/img/fig1.png
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "python.formatting.provider": "autopep8"
 3 | }
--------------------------------------------------------------------------------
/script/train_duke.sh:
--------------------------------------------------------------------------------
 1 | export PYTHONPATH=$PYTHONPATH:./
 2 | python ./example/iids_tnorm_self_kd.py --dataset dukemtmc --checkpoint path_to_checkpoint
--------------------------------------------------------------------------------
/script/test_duke.sh:
--------------------------------------------------------------------------------
 1 | export PYTHONPATH=$PYTHONPATH:./
 2 | python ./example/iids_tnorm_self_kd.py --dataset dukemtmc --evaluate --checkpoint path_to_checkpoint
--------------------------------------------------------------------------------
/script/test_market.sh:
--------------------------------------------------------------------------------
 1 | export PYTHONPATH=$PYTHONPATH:./
 2 | python ./example/iids_tnorm_self_kd.py --dataset market1501 --evaluate --checkpoint path_to_checkpoint
--------------------------------------------------------------------------------
/script/train_market.sh:
--------------------------------------------------------------------------------
 1 | export PYTHONPATH=$PYTHONPATH:./
 2 | python ./example/iids_tnorm_self_kd.py --dataset market1501 --checkpoint ~/disk/resnet50-0676ba61.pth
--------------------------------------------------------------------------------
/script/test_msmt.sh:
--------------------------------------------------------------------------------
 1 | export PYTHONPATH=$PYTHONPATH:./
 2 | python ./example/iids_tnorm_self_kd.py --dataset msmt17 --evaluate --use_cpu --checkpoint path_to_checkpoint
--------------------------------------------------------------------------------
/script/train_msmt.sh:
-------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:./ 2 | python ./example/iids_tnorm_self_kd.py --dataset msmt17 --epochs_stage1 36 --use_cpu --checkpoint path_to_checkpoint -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .dataset import Dataset 4 | from .camera_dataset import CameraDataset 5 | from .preprocessor import Preprocessor, BothPreprocessor 6 | from .preprocessor_camstyle import PreprocessorCAM 7 | -------------------------------------------------------------------------------- /reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .triplet import TripletLoss, SoftTripletLoss 4 | from .entropy_regularization import SoftLabelLoss, SoftEntropy 5 | 6 | __all__ = [ 7 | 'TripletLoss', 8 | 'SoftLabelLoss', 9 | 'SoftEntropy', 10 | 'SoftTripletLoss' 11 | ] 12 | -------------------------------------------------------------------------------- /reid/cluster_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .cluster import get_intra_cam_cluster_result, \ 2 | get_inter_cam_cluster_result, \ 3 | get_inter_cam_cluster_result_tnorm 4 | 5 | 6 | __ALL__ = ['get_intra_cam_cluster_result', 7 | 'get_inter_cam_cluster_result_tnorm', 8 | 'get_inter_cam_cluster_result',] -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import feature_extraction 6 | from . import loss 7 | from . import models 8 | from . import utils 9 | from . import evaluators 10 | from . import evaluators_cos 11 | from . import trainers 12 | from . 
import cluster_utils 13 | 14 | __version__ = '0.2.0' 15 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cnn import extract_cnn_feature, extract_cnn_feature_with_tnorm, extract_cnn_feature_specific 4 | from .database import FeatureDatabase 5 | 6 | __all__ = [ 7 | 'extract_cnn_feature', 8 | 'FeatureDatabase', 9 | 'extract_cnn_feature_with_tnorm', 10 | 'extract_cnn_feature_specific' 11 | ] 12 | -------------------------------------------------------------------------------- /reid/datasets/cluster_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | 5 | class Cluster(object): 6 | def __init__(self, root, cluster_result_cam, cam_id): 7 | self.root = root 8 | self.train_set = [] 9 | classes = [] 10 | for fname, pid in cluster_result_cam.items(): 11 | self.train_set.append((fname, pid, cam_id)) 12 | classes.append(pid) 13 | self.classes_num = len(set(classes)) 14 | 15 | @property 16 | def images_dir(self): 17 | return osp.join(self.root, 'images') 18 | -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 18 | ret.append(correct_k.mul_(1. 
/ batch_size))
19 |     return ret
20 | 
--------------------------------------------------------------------------------
/reid/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | 
 3 | import torch
 4 | 
 5 | 
 6 | def to_numpy(tensor):
 7 |     if torch.is_tensor(tensor):
 8 |         return tensor.cpu().numpy()
 9 |     elif type(tensor).__module__ != 'numpy':
10 |         raise ValueError("Cannot convert {} to numpy array"
11 |                          .format(type(tensor)))
12 |     return tensor
13 | 
14 | 
15 | def to_torch(ndarray):
16 |     if type(ndarray).__module__ == 'numpy':
17 |         return torch.from_numpy(ndarray)
18 |     elif not torch.is_tensor(ndarray):
19 |         raise ValueError("Cannot convert {} to torch tensor"
20 |                          .format(type(ndarray)))
21 |     return ndarray
22 | 
--------------------------------------------------------------------------------
/reid/utils/logging.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | import os
 3 | import sys
 4 | 
 5 | from .osutils import mkdir_if_missing
 6 | 
 7 | 
 8 | class Logger(object):
 9 |     def __init__(self, fpath=None):
10 |         self.console = sys.stdout
11 |         self.file = None
12 |         if fpath is not None:
13 |             mkdir_if_missing(os.path.dirname(fpath))
14 |             self.file = open(fpath, 'w')
15 | 
16 |     def __del__(self):
17 |         self.close()
18 | 
19 |     def __enter__(self):
20 |         pass
21 | 
22 |     def __exit__(self, *args):
23 |         self.close()
24 | 
25 |     def write(self, msg):
26 |         self.console.write(msg)
27 |         if self.file is not None:
28 |             self.file.write(msg)
29 | 
30 |     def flush(self):
31 |         self.console.flush()
32 |         if self.file is not None:
33 |             self.file.flush()
34 |             os.fsync(self.file.fileno())
35 | 
36 |     def close(self):
37 |         self.console.close()
38 |         if self.file is not None:
39 |             self.file.close()
40 | 
--------------------------------------------------------------------------------
/reid/loss/entropy_regularization.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | 
 6 | class SoftEntropy(nn.Module):
 7 |     def __init__(self):
 8 |         super(SoftEntropy, self).__init__()
 9 |         self.logsoftmax = nn.LogSoftmax(dim=1)
10 | 
11 |     def forward(self, inputs, targets):
12 |         log_probs = self.logsoftmax(inputs)
13 |         loss = (-F.softmax(targets, dim=1).detach() * log_probs).mean(0).sum()
14 |         return loss
15 | 
16 | 
17 | class SoftLabelLoss(nn.Module):
18 |     def __init__(self, alpha=1., T=20):
19 |         super(SoftLabelLoss, self).__init__()
20 |         self.alpha = alpha
21 |         self.T = T
22 |         self.logsoftmax = nn.LogSoftmax(dim=1)  # used in forward()
23 |         self.kl_div = nn.KLDivLoss(reduction='batchmean')
24 | 
25 |     def forward(self, p_logit, softlabel):
26 |         """
27 |         Args:
28 |             p_logit: prediction matrix (before softmax) with shape (batch_size, num_classes)
29 |             softlabel: soft target distribution with shape (batch_size, num_classes)
30 |         """
31 |         p_logit = p_logit.view(p_logit.size(0), -1)
32 |         log_probs = self.logsoftmax(p_logit / self.T)
33 | 
34 |         return self.T * self.alpha * self.kl_div(log_probs, softlabel)
35 | 
--------------------------------------------------------------------------------
/reid/datasets/__init__.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | import warnings
 3 | 
 4 | from .dukemtmc import DukeMTMC
 5 | from .market1501 import Market1501
 6 | from .cluster_dataset import Cluster
 7 | from .msmt17 import MSMT17
 8 | 
 9 | __factory = {
10 |     'market1501': Market1501,
11 | 
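    # A minimal usage sketch for this factory (illustrative only; the data root
    # below is an assumption taken from the README layout, not part of this file):
    #
    #     from reid import datasets
    #     market = datasets.create('market1501', './data/market1501')
    #     print(datasets.names())  # ['cluster', 'dukemtmc', 'market1501', 'msmt17']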
'dukemtmc': DukeMTMC, 12 | 'msmt17': MSMT17, 13 | 'cluster': Cluster, 14 | } 15 | 16 | 17 | def names(): 18 | return sorted(__factory.keys()) 19 | 20 | 21 | def create(name, root, *args, **kwargs): 22 | """ 23 | Create a dataset instance. 24 | 25 | Parameters 26 | ---------- 27 | name : str 28 | The dataset name. Can be one of 'viper', 'cuhk01', 'cuhk03', 29 | 'market1501', and 'dukemtmc'. 30 | root : str 31 | The path to the dataset directory. 32 | split_id : int, optional 33 | The index of data split. Default: 0 34 | num_val : int or float, optional 35 | When int, it means the number of validation identities. When float, 36 | it means the proportion of validation to all the trainval. Default: 100 37 | download : bool, optional 38 | If True, will download the dataset. Default: False 39 | """ 40 | if name not in __factory: 41 | raise KeyError("Unknown dataset:", name) 42 | return __factory[name](root, *args, **kwargs) 43 | 44 | 45 | def get_dataset(name, root, *args, **kwargs): 46 | warnings.warn("get_dataset is deprecated. Use create instead.") 47 | return create(name, root, *args, **kwargs) 48 | -------------------------------------------------------------------------------- /reid/feature_extraction/database.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import h5py 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FeatureDatabase(Dataset): 9 | def __init__(self, *args, **kwargs): 10 | super(FeatureDatabase, self).__init__() 11 | self.fid = h5py.File(*args, **kwargs) 12 | 13 | def __enter__(self): 14 | return self 15 | 16 | def __exit__(self, exc_type, exc_val, exc_tb): 17 | self.close() 18 | 19 | def __getitem__(self, keys): 20 | if isinstance(keys, (tuple, list)): 21 | return [self._get_single_item(k) for k in keys] 22 | return self._get_single_item(keys) 23 | 24 | def _get_single_item(self, key): 25 | return np.asarray(self.fid[key]) 26 | 27 | def __setitem__(self, key, value): 28 | if key in self.fid: 29 | if self.fid[key].shape == value.shape and \ 30 | self.fid[key].dtype == value.dtype: 31 | self.fid[key][...] 
= value 32 | else: 33 | del self.fid[key] 34 | self.fid.create_dataset(key, data=value) 35 | else: 36 | self.fid.create_dataset(key, data=value) 37 | 38 | def __delitem__(self, key): 39 | del self.fid[key] 40 | 41 | def __len__(self): 42 | return len(self.fid) 43 | 44 | def __iter__(self): 45 | return iter(self.fid) 46 | 47 | def flush(self): 48 | self.fid.flush() 49 | 50 | def close(self): 51 | self.fid.close() 52 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .ft_net import ft_net_inter, \ 4 | ft_net_intra, \ 5 | ft_net_inter_TNorm, \ 6 | ft_net_intra_TNorm, \ 7 | ft_net_both, \ 8 | ft_net_inter_specific, \ 9 | ft_net_intra_specific, \ 10 | ft_net_test \ 11 | 12 | 13 | __factory = { 14 | 'ft_net_inter': ft_net_inter, 15 | 'ft_net_intra': ft_net_intra, 16 | 'ft_net_intra_TNorm': ft_net_intra_TNorm, 17 | 'ft_net_inter_TNorm': ft_net_inter_TNorm, 18 | 'ft_net_both': ft_net_both, 19 | 'ft_net_inter_specific': ft_net_inter_specific, 20 | 'ft_net_intra_specific': ft_net_intra_specific, 21 | 'ft_net_test': ft_net_test, 22 | } 23 | 24 | 25 | def names(): 26 | return sorted(__factory.keys()) 27 | 28 | 29 | def create(name, *args, **kwargs): 30 | """ 31 | Create a model instance. 32 | 33 | Parameters 34 | ---------- 35 | name : str 36 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 37 | 'resnet50', 'resnet101', and 'resnet152'. 
38 | pretrained : bool, optional 39 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 40 | model. Default: True 41 | cut_at_pooling : bool, optional 42 | If True, will cut the model before the last global pooling layer and 43 | ignore the remaining kwargs. Default: False 44 | num_features : int, optional 45 | If positive, will append a Linear layer after the global pooling layer, 46 | with this number of output units, followed by a BatchNorm layer. 47 | Otherwise these layers will not be appended. Default: 256 for 48 | 'inception', 0 for 'resnet*' 49 | norm : bool, optional 50 | If True, will normalize the feature to be unit L2-norm for each sample. 51 | Otherwise will append a ReLU layer after the above Linear layer if 52 | num_features > 0. Default: False 53 | dropout : float, optional 54 | If positive, will append a Dropout layer with this dropout rate. 55 | Default: 0 56 | num_classes : int, optional 57 | If positive, will append a Linear layer at the end as the classifier 58 | with this number of output units. Default: 0 59 | """ 60 | if name not in __factory: 61 | raise KeyError("Unknown model:", name) 62 | return __factory[name](*args, **kwargs) 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | logs/ 131 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor_camstyle.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | from PIL import Image 5 | from ..serialization import read_json 6 | import random 7 | 8 | class PreprocessorCAM(object): 9 | def __init__(self, dataset, root=None, transform=None, num_cameras=None, mutual=True): 10 | super(PreprocessorCAM, self).__init__() 11 | self.dataset = dataset 12 | self.root = root 13 | self.transform = transform 14 | self.cam_style_dir = osp.join(osp.join(self.root, ".."), "cam_style") 15 | self.fname2real_name = read_json(osp.join(osp.join(self.root, ".."), "fname2real_name.json")) 16 | self.num_cameras = num_cameras 17 | self.mutual = mutual 18 | 19 | def __len__(self): 20 | return len(self.dataset) 21 | 22 | def __getitem__(self, indices): 23 | # if isinstance(indices, (tuple, list)): 24 | # return [self._get_single_item(index) for index in indices] 25 | if self.mutual: 26 | return self._get_mutual_item(indices) 27 | else: 28 | return self._get_single_item(indices) 29 | 30 | def _get_mutual_item(self, index): 31 | fname, pid, _ = self.dataset[index] 32 | fpath = fname 33 | camid = int(fname.split("_")[1]) 34 | if self.root is not None: 35 | fpath = osp.join(self.root, fname) 36 | convert = np.random.rand() > 0.5 37 | if convert: 38 | while True: 39 | argue_camid = random.randint(1, self.num_cameras) 40 | if argue_camid != (camid + 1): 41 | fpath_mix_up = osp.join(self.cam_style_dir, "{}_fake_{}to{}.jpg".format(self.fname2real_name[fname].split(".")[0], camid + 1, argue_camid)) 42 | img_cam = Image.open(fpath_mix_up).convert('RGB') 43 | break 44 | else: 45 | img_cam = Image.open(fpath).convert('RGB') 46 | 47 | img = Image.open(fpath).convert('RGB') 48 | 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | img_cam = self.transform(img_cam) 52 | 53 | return img_cam, img, fname, pid, camid 54 | 55 | def _get_single_item(self, index): 56 | fname, pid, camid = self.dataset[index] 57 | fpath = fname 58 | if self.root is not None: 59 | fpath = osp.join(self.root, fname) 60 | img = Image.open(fpath).convert('RGB') 61 | if self.transform is not None: 62 | img = self.transform(img) 63 | return img, fname, pid, camid 64 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | 4 | from PIL import Image 5 | 6 | 7 | class Preprocessor(object): 8 | def __init__(self, dataset, root=None, transform=None, mutual=False): 9 | super(Preprocessor, self).__init__() 10 | self.dataset = dataset 11 | self.root = root 12 | self.transform = transform 13 | self.mutual = 
mutual
14 | 
15 |     def __len__(self):
16 |         return len(self.dataset)
17 | 
18 |     def __getitem__(self, indices):
19 |         # if isinstance(indices, (tuple, list)):
20 |         #     return [self._get_single_item(index) for index in indices]
21 |         if self.mutual:
22 |             return self._get_mutual_item(indices)
23 |         else:
24 |             return self._get_single_item(indices)
25 | 
26 |     def _get_single_item(self, index):
27 |         fname, pid, camid = self.dataset[index]
28 |         fpath = fname
29 |         if self.root is not None:
30 |             fpath = osp.join(self.root, fname)
31 |         img = Image.open(fpath).convert('RGB')
32 |         if self.transform is not None:
33 |             img = self.transform(img)
34 |         return img, fname, pid, camid
35 | 
36 |     def _get_mutual_item(self, index):
37 |         fname, pid, camid = self.dataset[index]
38 |         fpath = fname
39 |         if self.root is not None:
40 |             fpath = osp.join(self.root, fname)
41 | 
42 |         img_1 = Image.open(fpath).convert('RGB')
43 |         img_2 = img_1.copy()
44 | 
45 |         if self.transform is not None:
46 |             img_1 = self.transform(img_1)
47 |             img_2 = self.transform(img_2)
48 | 
49 |         return img_1, img_2, fname, pid, camid
50 | 
51 | class BothPreprocessor(object):
52 |     def __init__(self, dataset, root=None, transform=None):
53 |         super(BothPreprocessor, self).__init__()
54 |         self.dataset = dataset
55 |         self.root = root
56 |         self.transform = transform
57 | 
58 |     def __len__(self):
59 |         return len(self.dataset)
60 | 
61 |     def __getitem__(self, indices):
62 |         # if isinstance(indices, (tuple, list)):
63 |         #     return [self._get_single_item(index) for index in indices]
64 |         return self._get_single_item(indices)
65 | 
66 |     def _get_single_item(self, index):
67 |         fname, global_pid, cam_pid, camid = self.dataset[index]
68 |         fpath = fname
69 |         if self.root is not None:
70 |             fpath = osp.join(self.root, fname)
71 |         img = Image.open(fpath).convert('RGB')
72 |         if self.transform is not None:
73 |             img = self.transform(img)
74 |         return img, fname, global_pid, cam_pid, camid
75 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ![Python 3.7.5](https://img.shields.io/badge/python-3.7.5-blue)
 2 | ![PyTorch 1.3.1](https://img.shields.io/badge/pytorch-1.3.1-yellow)
 3 | ![Cuda 9.2](https://img.shields.io/badge/cuda-9.2-yellowgreen)
 4 | 
 5 | # IIDS
 6 | PyTorch implementation of the paper ["Intra-Inter Camera Similarity for Unsupervised Person Re-Identification"](https://ieeexplore.ieee.org/abstract/document/9745321) (TPAMI 2022).
 7 | 
 8 | This is the extended version of IICS (CVPR 2021).
 9 | 
10 | ![fig1](./img/fig1.png)
11 | 
12 | ## Installation
13 | ### 1. Clone code
14 | ```
15 | git clone git@github.com:SY-Xuan/IIDS.git
16 | cd ./IIDS
17 | ```
18 | 
19 | ### 2. Install Python dependencies
20 | ```
21 | conda create --name IIDS --file requirements.txt
22 | ```
23 | 
24 | ### 3. Prepare datasets
25 | Download Market1501, DukeMTMC-ReID, and MSMT17 from their official websites and put the zip files under the corresponding directories as follows:
26 | ```
27 | ./data
28 | ├── dukemtmc
29 | │   └── raw
30 | │       └── DukeMTMC-reID.zip
31 | ├── market1501
32 | │   └── raw
33 | │       └── Market-1501-v15.09.15.zip
34 | └── msmt17
35 |     └── raw
36 |         └── MSMT17_V2.zip
37 | ```
38 | ## Usage
39 | ### 1. Download trained models
40 | * [Market1501](https://pkueducn-my.sharepoint.com/:f:/g/personal/shiyu_xuan_stu_pku_edu_cn/EuaJrwvGqnpJo2vc851CmnkBZFK2VjU2pbs0YXIfOItsSg?e=rA41MH)
41 | * [DukeMTMC-ReID](https://pkueducn-my.sharepoint.com/:f:/g/personal/shiyu_xuan_stu_pku_edu_cn/EuaJrwvGqnpJo2vc851CmnkBZFK2VjU2pbs0YXIfOItsSg?e=rA41MH)
42 | * [MSMT17](https://pkueducn-my.sharepoint.com/:f:/g/personal/shiyu_xuan_stu_pku_edu_cn/EuaJrwvGqnpJo2vc851CmnkBZFK2VjU2pbs0YXIfOItsSg?e=rA41MH)
43 | 
44 | ### 2. Evaluate Model
45 | Change the checkpoint path in ./script/test_market.sh, then run:
46 | ```
47 | sh ./script/test_market.sh
48 | ```
49 | 
50 | ### 3. Train Model
51 | You need to download the ImageNet-pretrained ResNet-50 model and change the checkpoint path in ./script/train_market.sh, then run:
52 | ```
53 | sh ./script/train_market.sh
54 | ```
55 | 
56 | ## Results
57 | | Datasets | mAP | Rank@1 | Method |
58 | | :--------: | :-----: | :----: | :----: |
59 | | Market1501 | 72.9% | 89.5% | CVPR2021 |
60 | | Market1501 | 78.0% | 91.2% | This Version |
61 | | DukeMTMC-ReID | 64.4% | 80.0% | CVPR2021 |
62 | | DukeMTMC-ReID | 68.7% | 82.1% | This Version |
63 | | MSMT17 | 26.9% | 56.4% | CVPR2021 |
64 | | MSMT17 | 35.1% | 64.4% | This Version |
65 | 
66 | ## Citations
67 | If you find this code useful for your research, please cite our paper:
68 | 
69 | ```
70 | @ARTICLE{9745321,
71 |   author={Xuan, Shiyu and Zhang, Shiliang},
72 |   journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
73 |   title={Intra-Inter Domain Similarity for Unsupervised Person Re-Identification},
74 |   year={2022},
75 |   volume={},
76 |   number={},
77 |   pages={1-1},
78 |   doi={10.1109/TPAMI.2022.3163451}}
79 | 
80 | @inproceedings{xuan2021intra,
81 |   title={Intra-inter camera similarity for unsupervised person re-identification},
82 |   author={Xuan, Shiyu and Zhang, Shiliang},
83 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
84 |   pages={11926--11935},
85 |   year={2021}
86 | }
87 | ```
88 | ## Contact me
89 | If you have any questions about this code or paper, feel free to contact me at
90 | shiyu_xuan@stu.pku.edu.cn.
91 | 
92 | ## Acknowledgement
93 | The code is built upon [open-reid](https://github.com/Cysu/open-reid).
--------------------------------------------------------------------------------
/reid/cluster_utils/rerank.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | 
 4 | def re_ranking(original_dist, k1=20, k2=6, lambda_value=0.3):
 5 | 
 6 |     all_num = original_dist.shape[0]
 7 | 
 8 |     euclidean_dist = original_dist
 9 |     gallery_num = original_dist.shape[0]  # gallery_num = all_num
10 | 
11 |     # original_dist = original_dist - np.min(original_dist)
12 |     original_dist = original_dist - np.min(original_dist, axis=0)
13 |     original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
14 |     V = np.zeros_like(original_dist).astype(np.float16)
15 |     initial_rank = np.argsort(original_dist).astype(np.int32)  ## default axis=-1.
16 | 
17 |     print('Starting re_ranking...')
18 |     for i in range(all_num):
19 |         # k-reciprocal neighbors
20 |         forward_k_neigh_index = initial_rank[i, :k1+1]  ## k1+1 because self always ranks first. forward_k_neigh_index.shape=[k1+1]. forward_k_neigh_index[0] == i.
21 |         backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1+1]  ##backward.shape = [k1+1, k1+1]. 
For each ele in forward_k_neigh_index, find its rank k1 neighbors 22 | fi = np.where(backward_k_neigh_index==i)[0] 23 | k_reciprocal_index = forward_k_neigh_index[fi] ## get R(p,k) in the paper 24 | k_reciprocal_expansion_index = k_reciprocal_index 25 | for j in range(len(k_reciprocal_index)): 26 | candidate = k_reciprocal_index[j] 27 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2))+1] 28 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2))+1] 29 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 30 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 31 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2/3*len(candidate_k_reciprocal_index): 32 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 33 | 34 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 35 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 36 | V[i,k_reciprocal_expansion_index] = weight/np.sum(weight) 37 | #original_dist = original_dist[:query_num,] 38 | if k2 != 1: 39 | V_qe = np.zeros_like(V,dtype=np.float16) 40 | for i in range(all_num): 41 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 42 | V = V_qe 43 | del V_qe 44 | del initial_rank 45 | invIndex = [] 46 | for i in range(gallery_num): 47 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 48 | 49 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float16) 50 | 51 | for i in range(all_num): 52 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float16) 53 | indNonZero = np.where(V[i,:] != 0)[0] 54 | indImages = [] 55 | indImages = [invIndex[ind] for ind in indNonZero] 56 | for j in range(len(indNonZero)): 57 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 58 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 59 | 60 | pos_bool = (jaccard_dist < 0) 61 | jaccard_dist[pos_bool] = 0.0 62 | 63 | #return jaccard_dist 64 | if lambda_value == 0: 65 | return jaccard_dist 66 | else: 67 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 68 | return final_dist 69 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 
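            # Each attempt samples a crop covering 64%-100% of the source area with a
            # height/width ratio drawn from [2, 3], i.e. a tall, person-shaped box;
            # a valid crop is resized to (width, height), and after 10 failed attempts
            # the RectScale fallback below is used instead.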
37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | class RandomErasing(object): 52 | """ Randomly selects a rectangle region in an image and erases its pixels. 53 | 'Random Erasing Data Augmentation' by Zhong et al. 54 | See https://arxiv.org/pdf/1708.04896.pdf 55 | Args: 56 | probability: The probability that the Random Erasing operation will be performed. 57 | sl: Minimum proportion of erased area against input image. 58 | sh: Maximum proportion of erased area against input image. 59 | r1: Minimum aspect ratio of erased area. 60 | mean: Erasing value. 61 | """ 62 | 63 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 64 | self.probability = probability 65 | self.mean = mean 66 | self.sl = sl 67 | self.sh = sh 68 | self.r1 = r1 69 | 70 | def __call__(self, img): 71 | 72 | if random.uniform(0, 1) > self.probability: 73 | return img 74 | 75 | for attempt in range(100): 76 | area = img.size()[1] * img.size()[2] 77 | 78 | target_area = random.uniform(self.sl, self.sh) * area 79 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 80 | 81 | h = int(round(math.sqrt(target_area * aspect_ratio))) 82 | w = int(round(math.sqrt(target_area / aspect_ratio))) 83 | 84 | if w < img.size()[2] and h < img.size()[1]: 85 | x1 = random.randint(0, img.size()[1] - h) 86 | y1 = random.randint(0, img.size()[2] - w) 87 | if img.size()[0] == 3: 88 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 89 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 90 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 91 | else: 92 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 93 | return img 94 | 95 | return img 96 | -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from ..utils import to_torch 6 | import numpy as np 7 | 8 | 9 | def extract_cnn_feature(model, inputs, norm=True): 10 | def fliplr(img): 11 | '''flip horizontal''' 12 | inv_idx = torch.arange(img.size(3) - 1, -1, -1).long() # N x C x H x W 13 | img_flip = img.index_select(3, inv_idx) 14 | return img_flip 15 | 16 | model.eval() 17 | inputs = to_torch(inputs) 18 | 19 | n, c, h, w = inputs.size() 20 | 21 | ff = torch.FloatTensor(n, 2048).zero_().cuda() 22 | 23 | for i in range(2): 24 | if (i == 1): 25 | inputs = fliplr(inputs) 26 | inputs2 = inputs.cuda() 27 | if hasattr(model, "module"): 28 | outputs = model.module.backbone_forward(inputs2) 29 | else: 30 | outputs = model.backbone_forward(inputs2) 31 | outputs = outputs.view(outputs.size(0), outputs.size(1)) 32 | 33 | ff += outputs * 0.5 34 | if norm: 35 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 36 | ff = ff.div(fnorm.expand_as(ff)) 37 | 38 | return ff 39 | 40 | 41 | def extract_cnn_feature_specific(model, inputs, camid, norm=True): 42 | def fliplr(img): 43 | '''flip horizontal''' 44 | inv_idx = torch.arange(img.size(3) - 1, -1, -1).long() # N x C x H x W 45 | img_flip = img.index_select(3, inv_idx) 46 | return img_flip 47 | 48 | model.eval() 49 | inputs = 
to_torch(inputs) 50 | 51 | n, c, h, w = inputs.size() 52 | 53 | ff = torch.FloatTensor(n, 2048).zero_().cuda() 54 | 55 | for i in range(2): 56 | if (i == 1): 57 | inputs = fliplr(inputs) 58 | inputs2 = inputs.cuda() 59 | if hasattr(model, "module"): 60 | outputs = model.module.backbone_forward(inputs2, camid, True) 61 | else: 62 | outputs = model.backbone_forward(inputs2, camid, True) 63 | outputs = outputs.view(outputs.size(0), outputs.size(1)) 64 | 65 | ff += outputs * 0.5 66 | if norm: 67 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 68 | ff = ff.div(fnorm.expand_as(ff)) 69 | 70 | return ff 71 | 72 | 73 | def extract_cnn_feature_with_tnorm(model, 74 | inputs, 75 | camid, 76 | convert_domain_index, 77 | norm=True): 78 | def fliplr(img): 79 | '''flip horizontal''' 80 | inv_idx = torch.arange(img.size(3) - 1, -1, -1).long() # N x C x H x W 81 | img_flip = img.index_select(3, inv_idx) 82 | return img_flip 83 | 84 | model.eval() 85 | inputs = to_torch(inputs) 86 | 87 | n, c, h, w = inputs.size() 88 | 89 | ff = torch.FloatTensor(n, 2048).zero_().cuda() 90 | domain_index = (camid.view(n), convert_domain_index) 91 | for i in range(2): 92 | if (i == 1): 93 | inputs = fliplr(inputs) 94 | inputs2 = inputs.cuda() 95 | if hasattr(model, "module"): 96 | outputs = model.module.backbone_forward(inputs2, 97 | domain_index, 98 | convert=True) 99 | else: 100 | outputs = model.backbone_forward(inputs2, 101 | domain_index, 102 | convert=True) 103 | outputs = outputs.view(outputs.size(0), outputs.size(1)) 104 | 105 | ff += outputs * 0.5 106 | if norm: 107 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 108 | ff = ff.div(fnorm.expand_as(ff)) 109 | 110 | return ff 111 | -------------------------------------------------------------------------------- /reid/models/backbones/AIBN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AIBNorm2d(nn.Module): 6 | def __init__(self, 7 | num_features, 8 | eps=1e-5, 9 | momentum=0.9, 10 | using_moving_average=True, 11 | only_bn=False, 12 | last_gamma=False, 13 | adaptive_weight=None, 14 | generate_weight=False, 15 | init_weight=0.1): 16 | super(AIBNorm2d, self).__init__() 17 | self.num_features = num_features 18 | self.eps = eps 19 | self.momentum = momentum 20 | self.using_moving_average = using_moving_average 21 | self.only_bn = only_bn 22 | self.last_gamma = last_gamma 23 | self.generate_weight = generate_weight 24 | if generate_weight: 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | if not only_bn: 28 | if adaptive_weight is not None: 29 | self.adaptive_weight = adaptive_weight 30 | else: 31 | self.adaptive_weight = nn.Parameter( 32 | torch.ones(1) * init_weight) 33 | self.register_buffer('running_mean', torch.zeros(num_features)) 34 | self.register_buffer('running_var', torch.ones(num_features)) 35 | 36 | self.reset_parameters() 37 | 38 | def reset_parameters(self): 39 | 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def _check_input_dim(self, input): 44 | if input.dim() != 4: 45 | raise ValueError('expected 4D input (got {}D input)'.format( 46 | input.dim())) 47 | 48 | def forward(self, x, weight=None, bias=None): 49 | self._check_input_dim(x) 50 | N, C, H, W = x.size() 51 | x = x.view(N, C, -1) 52 | mean_in = x.mean(-1, keepdim=True) 53 | var_in = x.var(-1, keepdim=True) 54 | 55 | temp = var_in + mean_in**2 56 | 57 | if self.training: 58 | mean_bn = mean_in.mean(0, 
keepdim=True) 59 | var_bn = temp.mean(0, keepdim=True) - mean_bn**2 60 | if self.using_moving_average: 61 | self.running_mean.mul_(self.momentum) 62 | self.running_mean.add_( 63 | (1 - self.momentum) * mean_bn.squeeze().data) 64 | self.running_var.mul_(self.momentum) 65 | self.running_var.add_( 66 | (1 - self.momentum) * var_bn.squeeze().data) 67 | else: 68 | self.running_mean.add_(mean_bn.squeeze().data) 69 | self.running_var.add_(mean_bn.squeeze().data**2 + 70 | var_bn.squeeze().data) 71 | else: 72 | mean_bn = torch.autograd.Variable( 73 | self.running_mean).unsqueeze(0).unsqueeze(2) 74 | var_bn = torch.autograd.Variable( 75 | self.running_var).unsqueeze(0).unsqueeze(2) 76 | 77 | if not self.only_bn: 78 | 79 | adaptive_weight = torch.clamp(self.adaptive_weight, 0, 1) 80 | mean = (1 - adaptive_weight[0] 81 | ) * mean_in + adaptive_weight[0] * mean_bn 82 | var = (1 - 83 | adaptive_weight[0]) * var_in + adaptive_weight[0] * var_bn 84 | 85 | x = (x - mean) / (var + self.eps).sqrt() 86 | x = x.view(N, C, H, W) 87 | else: 88 | x = (x - mean_bn) / (var_bn + self.eps).sqrt() 89 | x = x.view(N, C, H, W) 90 | 91 | if self.generate_weight: 92 | weight = self.weight.view(1, self.num_features, 1, 1) 93 | bias = self.bias.view(1, self.num_features, 1, 1) 94 | else: 95 | weight = weight.view(1, self.num_features, 1, 1) 96 | bias = bias.view(1, self.num_features, 1, 1) 97 | return x * weight + bias 98 | -------------------------------------------------------------------------------- /reid/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import write_json 7 | from tqdm import tqdm 8 | 9 | class DukeMTMC(Dataset): 10 | url = 'https://drive.google.com/uc?id=0B0VOCNYh8HeRdnBPa2ZWaVBYSVk' 11 | md5 = '2f93496f9b516d1ee5ef51c1d5e7d601' 12 | 13 | def __init__(self, root, split_id=0, num_val=100, download=True): 14 | super(DukeMTMC, self).__init__(root, split_id=split_id) 15 | 16 | if download: 17 | self.download() 18 | 19 | if not self._check_integrity(): 20 | raise RuntimeError("Dataset not found or corrupted. 
" + 21 | "You can use download=True to download it.") 22 | 23 | self.load(num_val) 24 | 25 | def download(self): 26 | if self._check_integrity(): 27 | print("Files already downloaded and verified") 28 | return 29 | 30 | import re 31 | import hashlib 32 | import shutil 33 | from glob import glob 34 | from zipfile import ZipFile 35 | 36 | raw_dir = osp.join(self.root, 'raw') 37 | mkdir_if_missing(raw_dir) 38 | 39 | # Download the raw zip file 40 | fpath = osp.join(raw_dir, 'DukeMTMC-reID.zip') 41 | # if not osp.isfile(fpath): 42 | # raise RuntimeError("Please download the dataset manually from {} " 43 | # "to {}".format(self.url, fpath)) 44 | 45 | # Extract the file 46 | exdir = osp.join(raw_dir, 'DukeMTMC-reID') 47 | if not osp.isdir(exdir): 48 | print("Extracting zip file") 49 | with ZipFile(fpath) as z: 50 | z.extractall(path=raw_dir) 51 | 52 | # Format 53 | images_dir = osp.join(self.root, 'images') 54 | print(self.root) 55 | mkdir_if_missing(images_dir) 56 | 57 | identities = [] 58 | all_pids = {} 59 | 60 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 61 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 62 | pids = set() 63 | fnames = [] 64 | for fpath in tqdm(fpaths): 65 | fname = osp.basename(fpath) 66 | pid, cam = map(int, pattern.search(fname).groups()) 67 | assert 1 <= cam <= 8 68 | cam -= 1 69 | if pid not in all_pids: 70 | all_pids[pid] = len(all_pids) 71 | pid = all_pids[pid] 72 | pids.add(pid) 73 | if pid >= len(identities): 74 | assert pid == len(identities) 75 | identities.append([[] for _ in range(8)]) # 8 camera views 76 | fname = ('{:08d}_{:02d}_{:04d}.jpg'.format( 77 | pid, cam, len(identities[pid][cam]))) 78 | identities[pid][cam].append(fname) 79 | fnames.append(fname) 80 | shutil.copy(fpath, osp.join(images_dir, fname)) 81 | return pids, fnames 82 | 83 | trainval_pids, _ = register('bounding_box_train') 84 | gallery_pids, gallery_names = register('bounding_box_test') 85 | query_pids, query_names = register('query') 86 | assert query_pids <= gallery_pids 87 | assert trainval_pids.isdisjoint(gallery_pids) 88 | 89 | # Save meta information into a json file 90 | meta = { 91 | 'name': 'DukeMTMC', 92 | 'shot': 'multiple', 93 | 'num_cameras': 8, 94 | 'identities': identities, 95 | 'gallery_names': gallery_names, 96 | 'query_names': query_names, 97 | } 98 | write_json(meta, osp.join(self.root, 'meta.json')) 99 | 100 | # Save the only training / test split 101 | splits = [{ 102 | 'trainval': sorted(list(trainval_pids)), 103 | 'query': sorted(list(query_pids)), 104 | 'gallery': sorted(list(gallery_pids)) 105 | }] 106 | write_json(splits, osp.join(self.root, 'splits.json')) 107 | -------------------------------------------------------------------------------- /reid/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import write_json 7 | 8 | from tqdm import tqdm 9 | 10 | class Market1501(Dataset): 11 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view' 12 | md5 = '65005ab7d12ec1c44de4eeafe813e68a' 13 | 14 | def __init__(self, root, split_id=0, num_val=100, download=True): 15 | super(Market1501, self).__init__(root, split_id=split_id) 16 | 17 | if download: 18 | self.download() 19 | 20 | if not self._check_integrity(): 21 | raise RuntimeError("Dataset not found or corrupted. 
" + 22 | "You can use download=True to download it.") 23 | 24 | self.load(num_val) 25 | 26 | def download(self): 27 | if self._check_integrity(): 28 | print("Files already downloaded and verified") 29 | return 30 | 31 | import re 32 | import hashlib 33 | import shutil 34 | from glob import glob 35 | from zipfile import ZipFile 36 | 37 | raw_dir = osp.join(self.root, 'raw') 38 | mkdir_if_missing(raw_dir) 39 | 40 | # Download the raw zip file 41 | fpath = osp.join(raw_dir, 'Market-1501-v15.09.15.zip') 42 | # if osp.isfile(fpath) and \ 43 | # hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 44 | # print("Using downloaded file: " + fpath) 45 | # else: 46 | # raise RuntimeError("Please download the dataset manually from {} " 47 | # "to {}".format(self.url, fpath)) 48 | 49 | # Extract the file 50 | exdir = osp.join(raw_dir, 'Market-1501-v15.09.15') 51 | if not osp.isdir(exdir): 52 | print("Extracting zip file") 53 | with ZipFile(fpath) as z: 54 | z.extractall(path=raw_dir) 55 | 56 | # Format 57 | images_dir = osp.join(self.root, 'images') 58 | mkdir_if_missing(images_dir) 59 | 60 | # 1501 identities (+1 for background) with 6 camera views each 61 | identities = [[[] for _ in range(6)] for _ in range(1502)] 62 | 63 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 64 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 65 | pids = set() 66 | fnames = [] 67 | for fpath in tqdm(fpaths): 68 | fname = osp.basename(fpath) 69 | pid, cam = map(int, pattern.search(fname).groups()) 70 | if pid == -1: continue # junk images are just ignored 71 | assert 0 <= pid <= 1501 # pid == 0 means background 72 | assert 1 <= cam <= 6 73 | cam -= 1 74 | pids.add(pid) 75 | fname = ('{:08d}_{:02d}_{:04d}.jpg'.format( 76 | pid, cam, len(identities[pid][cam]))) 77 | identities[pid][cam].append(fname) 78 | fnames.append(fname) 79 | shutil.copy(fpath, osp.join(images_dir, fname)) 80 | return pids, fnames 81 | 82 | trainval_pids, _ = register('bounding_box_train') 83 | gallery_pids, gallery_names = register('bounding_box_test') 84 | query_pids, query_names = register('query') 85 | assert query_pids <= gallery_pids 86 | assert trainval_pids.isdisjoint(gallery_pids) 87 | 88 | # Save meta information into a json file 89 | meta = { 90 | 'name': 'Market1501', 91 | 'shot': 'multiple', 92 | 'num_cameras': 6, 93 | 'identities': identities, 94 | 'gallery_names': gallery_names, 95 | 'query_names': query_names, 96 | } 97 | write_json(meta, osp.join(self.root, 'meta.json')) 98 | 99 | # Save the only training / test split 100 | splits = [{ 101 | 'trainval': sorted(list(trainval_pids)), 102 | 'query': sorted(list(query_pids)), 103 | 'gallery': sorted(list(gallery_pids)) 104 | }] 105 | write_json(splits, osp.join(self.root, 'splits.json')) 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _pytorch_select=0.2=gpu_0 6 | argon2-cffi=20.1.0=pypi_0 7 | async-generator=1.10=pypi_0 8 | attrs=20.3.0=pypi_0 9 | backcall=0.2.0=pypi_0 10 | blas=1.0=mkl 11 | bleach=3.2.1=pypi_0 12 | ca-certificates=2020.7.22=0 13 | certifi=2020.6.20=py37_0 14 | cffi=1.14.0=py37h2e261b9_0 15 | chardet=3.0.4=pypi_0 16 | click=7.1.2=pypi_0 17 | configparser=5.0.1=pypi_0 18 | cudatoolkit=9.2=0 19 | cudnn=7.6.5=cuda9.2_0 20 | cycler=0.10.0=pypi_0 21 | 
cython=0.29.21=pypi_0 22 | decorator=4.4.2=pypi_0 23 | defusedxml=0.6.0=pypi_0 24 | docker-pycreds=0.4.0=pypi_0 25 | entrypoints=0.3=pypi_0 26 | flake8=3.8.4=pypi_0 27 | freetype=2.10.2=h5ab3b9f_0 28 | gitdb=4.0.5=pypi_0 29 | gitpython=3.1.9=pypi_0 30 | h5py=2.10.0=pypi_0 31 | idna=2.10=pypi_0 32 | imageio=2.9.0=pypi_0 33 | importlib-metadata=2.0.0=pypi_0 34 | intel-openmp=2020.2=254 35 | ipykernel=5.4.3=pypi_0 36 | ipython=7.19.0=pypi_0 37 | ipython-genutils=0.2.0=pypi_0 38 | ipywidgets=7.6.3=pypi_0 39 | jedi=0.18.0=pypi_0 40 | jinja2=2.11.2=pypi_0 41 | joblib=0.16.0=pypi_0 42 | jpeg=9b=h024ee3a_2 43 | jsonschema=3.2.0=pypi_0 44 | jupyter=1.0.0=pypi_0 45 | jupyter-client=6.1.11=pypi_0 46 | jupyter-console=6.2.0=pypi_0 47 | jupyter-core=4.7.0=pypi_0 48 | jupyterlab-pygments=0.1.2=pypi_0 49 | jupyterlab-widgets=1.0.0=pypi_0 50 | kiwisolver=1.2.0=pypi_0 51 | lcms2=2.11=h396b838_0 52 | ld_impl_linux-64=2.33.1=h53a641e_7 53 | libedit=3.1.20191231=h14c3975_1 54 | libffi=3.2.1=hf484d3e_1007 55 | libgcc-ng=9.1.0=hdf63c60_0 56 | libpng=1.6.37=hbc83047_0 57 | libstdcxx-ng=9.1.0=hdf63c60_0 58 | libtiff=4.1.0=h2733197_1 59 | llvmlite=0.34.0=pypi_0 60 | lz4-c=1.9.2=he6710b0_1 61 | markupsafe=1.1.1=pypi_0 62 | matplotlib=3.3.2=pypi_0 63 | mccabe=0.6.1=pypi_0 64 | metric-learn=0.6.2=pypi_0 65 | mistune=0.8.4=pypi_0 66 | mkl=2020.2=256 67 | mkl-service=2.3.0=py37he904b0f_0 68 | mkl_fft=1.2.0=py37h23d657b_0 69 | mkl_random=1.1.1=py37h0573a6f_0 70 | nbclient=0.5.1=pypi_0 71 | nbconvert=6.0.7=pypi_0 72 | nbformat=5.0.8=pypi_0 73 | ncurses=6.2=he6710b0_1 74 | nest-asyncio=1.4.3=pypi_0 75 | networkx=2.5=pypi_0 76 | ninja=1.10.1=py37hfd86e86_0 77 | notebook=6.1.6=pypi_0 78 | numba=0.51.2=pypi_0 79 | numpy=1.19.1=py37hbc911f0_0 80 | numpy-base=1.19.1=py37hfa32c7d_0 81 | nvidia-ml-py3=7.352.0=pypi_0 82 | olefile=0.46=py37_0 83 | opencv-python=4.4.0.44=pypi_0 84 | openssl=1.1.1h=h7b6447c_0 85 | packaging=20.8=pypi_0 86 | pandas=1.1.4=pypi_0 87 | pandocfilters=1.4.3=pypi_0 88 | parso=0.8.1=pypi_0 89 | pathtools=0.1.2=pypi_0 90 | pexpect=4.8.0=pypi_0 91 | pickleshare=0.7.5=pypi_0 92 | pillow=7.2.0=py37hb39fc2d_0 93 | pip=20.2.2=py37_0 94 | prometheus-client=0.9.0=pypi_0 95 | promise=2.3=pypi_0 96 | prompt-toolkit=3.0.10=pypi_0 97 | protobuf=3.13.0=pypi_0 98 | psutil=5.7.2=pypi_0 99 | ptyprocess=0.7.0=pypi_0 100 | pycocotools=2.0.2=pypi_0 101 | pycodestyle=2.6.0=pypi_0 102 | pycparser=2.20=py_2 103 | pyflakes=2.2.0=pypi_0 104 | pygments=2.7.3=pypi_0 105 | pyparsing=2.4.7=pypi_0 106 | pyrsistent=0.17.3=pypi_0 107 | python=3.7.5=h0371630_0 108 | python-dateutil=2.8.1=pypi_0 109 | pytorch=1.3.1=cuda92py37hb0ba70e_0 110 | pytz=2020.4=pypi_0 111 | pywavelets=1.1.1=pypi_0 112 | pyyaml=5.3.1=pypi_0 113 | pyzmq=20.0.0=pypi_0 114 | qtconsole=5.0.1=pypi_0 115 | qtpy=1.9.0=pypi_0 116 | readline=7.0=h7b6447c_5 117 | requests=2.24.0=pypi_0 118 | scikit-image=0.17.2=pypi_0 119 | scikit-learn=0.23.2=pypi_0 120 | scipy=1.5.2=pypi_0 121 | send2trash=1.5.0=pypi_0 122 | sentry-sdk=0.19.0=pypi_0 123 | setuptools=49.6.0=py37_1 124 | shortuuid=1.0.1=pypi_0 125 | six=1.15.0=py_0 126 | sklearn=0.0=pypi_0 127 | smmap=3.0.4=pypi_0 128 | sqlite=3.33.0=h62c20be_0 129 | subprocess32=3.5.4=pypi_0 130 | terminado=0.9.2=pypi_0 131 | testpath=0.4.4=pypi_0 132 | threadpoolctl=2.1.0=pypi_0 133 | tifffile=2020.11.26=pypi_0 134 | tk=8.6.10=hbc83047_0 135 | torchvision=0.4.2=cuda92py37h1667eeb_0 136 | tornado=6.1=pypi_0 137 | tqdm=4.50.0=pypi_0 138 | traitlets=5.0.5=pypi_0 139 | urllib3=1.25.10=pypi_0 140 | wandb=0.10.5=pypi_0 141 | 
watchdog=0.10.3=pypi_0 142 | wcwidth=0.2.5=pypi_0 143 | webencodings=0.5.1=pypi_0 144 | wheel=0.35.1=py_0 145 | widgetsnbextension=3.5.1=pypi_0 146 | xz=5.2.5=h7b6447c_0 147 | zipp=3.3.0=pypi_0 148 | zlib=1.2.11=h7b6447c_3 149 | zstd=1.4.5=h9ceee32_0 -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False, validate_names=None): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | if validate_names is not None: 16 | if fname not in validate_names: 17 | continue 18 | name = osp.splitext(fname)[0] 19 | x, y, _ = map(int, name.split('_')) 20 | assert pid == x and camid == y 21 | if relabel: 22 | ret.append((fname, index, camid)) 23 | else: 24 | ret.append((fname, pid, camid)) 25 | return ret 26 | 27 | 28 | class Dataset(object): 29 | def __init__(self, root, split_id=0): 30 | self.root = root 31 | self.split_id = split_id 32 | self.meta = None 33 | self.split = None 34 | self.train, self.val, self.trainval = [], [], [] 35 | self.query, self.gallery = [], [] 36 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 37 | 38 | @property 39 | def images_dir(self): 40 | return osp.join(self.root, 'images') 41 | 42 | def load(self, num_val=0.3, verbose=True): 43 | splits = read_json(osp.join(self.root, 'splits.json')) 44 | if self.split_id >= len(splits): 45 | raise ValueError("split_id exceeds total splits {}" 46 | .format(len(splits))) 47 | self.split = splits[self.split_id] 48 | # Randomly split train / val 49 | trainval_pids = np.asarray(self.split['trainval']) 50 | np.random.shuffle(trainval_pids) 51 | num = len(trainval_pids) 52 | if isinstance(num_val, float): 53 | num_val = int(round(num * num_val)) 54 | 55 | if num_val >= num or num_val < 0: 56 | raise ValueError("num_val exceeds total identities {}" 57 | .format(num)) 58 | train_pids = sorted(trainval_pids[:-num_val]) 59 | val_pids = sorted(trainval_pids[-num_val:]) 60 | 61 | self.meta = read_json(osp.join(self.root, 'meta.json')) 62 | self.num_cameras = self.meta['num_cameras'] 63 | identities = self.meta['identities'] 64 | gallery_names = self.meta.get('gallery_names', None) 65 | if gallery_names is not None: 66 | gallery_names = set(gallery_names) 67 | query_names = self.meta.get('query_names', None) 68 | if query_names is not None: 69 | query_names = set(query_names) 70 | self.train = _pluck(identities, train_pids, relabel=True) 71 | self.val = _pluck(identities, val_pids, relabel=True) 72 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 73 | self.query = _pluck(identities, self.split['query'], validate_names=query_names) 74 | self.gallery = _pluck(identities, self.split['gallery'], validate_names=gallery_names) 75 | self.num_train_ids = len(train_pids) 76 | self.num_val_ids = len(val_pids) 77 | self.num_trainval_ids = len(trainval_pids) 78 | 79 | if verbose: 80 | print(self.__class__.__name__, "dataset loaded") 81 | print(" subset | # ids | # images") 82 | print(" ---------------------------") 83 | print(" train | {:5d} | {:8d}" 84 | .format(self.num_train_ids, len(self.train))) 85 | print(" val | {:5d} | {:8d}" 86 | .format(self.num_val_ids, 
len(self.val))) 87 | print(" trainval | {:5d} | {:8d}" 88 | .format(self.num_trainval_ids, len(self.trainval))) 89 | print(" query | {:5d} | {:8d}" 90 | .format(len(self.split['query']), len(self.query))) 91 | print(" gallery | {:5d} | {:8d}" 92 | .format(len(self.split['gallery']), len(self.gallery))) 93 | 94 | def _check_integrity(self): 95 | return osp.isdir(osp.join(self.root, 'images')) and \ 96 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 97 | osp.isfile(osp.join(self.root, 'splits.json')) 98 | -------------------------------------------------------------------------------- /reid/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import write_json 7 | from tqdm import tqdm 8 | 9 | 10 | class MSMT17(Dataset): 11 | def __init__(self, root, split_id=0, num_val=100, download=True): 12 | super(MSMT17, self).__init__(root, split_id=split_id) 13 | 14 | if download: 15 | self.download() 16 | 17 | if not self._check_integrity(): 18 | raise RuntimeError("Dataset not found or corrupted. " + 19 | "You can use download=True to download it.") 20 | 21 | self.load(num_val) 22 | 23 | def download(self): 24 | if self._check_integrity(): 25 | print("Files already downloaded and verified") 26 | return 27 | print("Generate dataset this step may take some times") 28 | import re 29 | import hashlib 30 | import shutil 31 | from glob import glob 32 | from zipfile import ZipFile 33 | 34 | raw_dir = osp.join(self.root, 'raw') 35 | mkdir_if_missing(raw_dir) 36 | 37 | fpath = osp.join(raw_dir, 'MSMT17_V2.zip') 38 | # if not osp.isfile(fpath): 39 | # raise RuntimeError("Please download the dataset manually") 40 | 41 | # Extract the file 42 | exdir = raw_dir 43 | if not osp.isdir(exdir): 44 | print("Extracting zip file") 45 | with ZipFile(fpath) as z: 46 | z.extractall(path=raw_dir) 47 | 48 | # Format 49 | images_dir = osp.join(self.root, 'images') 50 | print(self.root) 51 | mkdir_if_missing(images_dir) 52 | exdir = raw_dir 53 | # Format 54 | images_dir = osp.join(self.root, 'images') 55 | mkdir_if_missing(images_dir) 56 | 57 | identities = [] 58 | all_pids = {} 59 | 60 | def register(subdir, txt_file): 61 | with open(osp.join(exdir, txt_file), "r") as f: 62 | fpaths = f.readlines() 63 | pids = set() 64 | fnames = [] 65 | for fpath in tqdm(fpaths): 66 | fpath = fpath.replace("\n", '') 67 | filename = fpath.split(" ")[0] 68 | pid = int(fpath.split(" ")[1]) 69 | # deal with the query and gallery 70 | if subdir == "mask_test_v2": 71 | pid += 1500 72 | fname = osp.basename(filename) 73 | cam = int(fname.split("_")[2]) 74 | assert 1 <= cam <= 15 75 | cam -= 1 76 | if pid not in all_pids: 77 | all_pids[pid] = len(all_pids) 78 | pid = all_pids[pid] 79 | pids.add(pid) 80 | if pid >= len(identities): 81 | assert pid == len(identities) 82 | identities.append([[] 83 | for _ in range(15)]) # 15 camera views 84 | fname = ('{:08d}_{:02d}_{:04d}.jpg'.format( 85 | pid, cam, len(identities[pid][cam]))) 86 | identities[pid][cam].append(fname) 87 | fnames.append(fname) 88 | shutil.copy(osp.join(exdir, subdir, filename), 89 | osp.join(images_dir, fname)) 90 | return pids, fnames 91 | 92 | trainval_pids, _ = register('mask_train_v2', "list_train.txt") 93 | gallery_pids, gallery_names = register('mask_test_v2', 94 | "list_gallery.txt") 95 | query_pids, query_names = 
register('mask_test_v2', "list_query.txt") 96 | assert query_pids <= gallery_pids 97 | 98 | # Save meta information into a json file 99 | meta = { 100 | 'name': 'msmt17', 101 | 'shot': 'multiple', 102 | 'num_cameras': 15, 103 | 'identities': identities, 104 | 'gallery_names': gallery_names, 105 | 'query_names': query_names 106 | } 107 | write_json(meta, osp.join(self.root, 'meta.json')) 108 | 109 | # Save the only training / test split 110 | splits = [{ 111 | 'trainval': sorted(list(trainval_pids)), 112 | 'query': sorted(list(query_pids)), 113 | 'gallery': sorted(list(gallery_pids)) 114 | }] 115 | write_json(splits, osp.join(self.root, 'splits.json')) 116 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | from tqdm import tqdm 9 | 10 | 11 | def _unique_sample(ids_dict, num): 12 | mask = np.zeros(num, dtype=np.bool) 13 | for _, indices in ids_dict.items(): 14 | i = np.random.choice(indices) 15 | mask[i] = True 16 | return mask 17 | 18 | 19 | def cmc(distmat, query_ids=None, gallery_ids=None, 20 | query_cams=None, gallery_cams=None, topk=100, 21 | separate_camera_set=False, 22 | single_gallery_shot=False, 23 | first_match_break=False): 24 | distmat = to_numpy(distmat) 25 | m, n = distmat.shape 26 | # Fill up default values 27 | if query_ids is None: 28 | query_ids = np.arange(m) 29 | if gallery_ids is None: 30 | gallery_ids = np.arange(n) 31 | if query_cams is None: 32 | query_cams = np.zeros(m).astype(np.int32) 33 | if gallery_cams is None: 34 | gallery_cams = np.ones(n).astype(np.int32) 35 | # Ensure numpy array 36 | query_ids = np.asarray(query_ids) 37 | gallery_ids = np.asarray(gallery_ids) 38 | query_cams = np.asarray(query_cams) 39 | gallery_cams = np.asarray(gallery_cams) 40 | # Sort and find correct matches 41 | indices = np.argsort(distmat, axis=1) 42 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 43 | # Compute CMC for each query 44 | ret = np.zeros(topk) 45 | num_valid_queries = 0 46 | for i in tqdm(range(m)): 47 | # Filter out the same id and same camera 48 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 49 | (gallery_cams[indices[i]] != query_cams[i])) 50 | if separate_camera_set: 51 | # Filter out samples from same camera 52 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 53 | if not np.any(matches[i, valid]): continue 54 | if single_gallery_shot: 55 | repeat = 10 56 | gids = gallery_ids[indices[i][valid]] 57 | inds = np.where(valid)[0] 58 | ids_dict = defaultdict(list) 59 | for j, x in zip(inds, gids): 60 | ids_dict[x].append(j) 61 | else: 62 | repeat = 1 63 | for _ in range(repeat): 64 | if single_gallery_shot: 65 | # Randomly choose one instance for each id 66 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 67 | index = np.nonzero(matches[i, sampled])[0] 68 | else: 69 | index = np.nonzero(matches[i, valid])[0] 70 | delta = 1. 
/ (len(index) * repeat) 71 | for j, k in enumerate(index): 72 | if k - j >= topk: break 73 | if first_match_break: 74 | ret[k - j] += 1 75 | break 76 | ret[k - j] += delta 77 | num_valid_queries += 1 78 | if num_valid_queries == 0: 79 | raise RuntimeError("No valid query") 80 | return ret.cumsum() / num_valid_queries 81 | 82 | 83 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 84 | query_cams=None, gallery_cams=None): 85 | distmat = to_numpy(distmat) 86 | m, n = distmat.shape 87 | # Fill up default values 88 | if query_ids is None: 89 | query_ids = np.arange(m) 90 | if gallery_ids is None: 91 | gallery_ids = np.arange(n) 92 | if query_cams is None: 93 | query_cams = np.zeros(m).astype(np.int32) 94 | if gallery_cams is None: 95 | gallery_cams = np.ones(n).astype(np.int32) 96 | # Ensure numpy array 97 | query_ids = np.asarray(query_ids) 98 | gallery_ids = np.asarray(gallery_ids) 99 | query_cams = np.asarray(query_cams) 100 | gallery_cams = np.asarray(gallery_cams) 101 | # Sort and find correct matches 102 | indices = np.argsort(distmat, axis=1) 103 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 104 | # Compute AP for each query 105 | aps = [] 106 | for i in tqdm(range(m)): 107 | # Filter out the same id and same camera 108 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 109 | (gallery_cams[indices[i]] != query_cams[i])) 110 | y_true = matches[i, valid] 111 | y_score = -distmat[i][indices[i]][valid] 112 | if not np.any(y_true): continue 113 | aps.append(average_precision_score(y_true, y_score)) 114 | if len(aps) == 0: 115 | raise RuntimeError("No valid query") 116 | return np.mean(aps) 117 | -------------------------------------------------------------------------------- /reid/loss/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | class TripletLoss(nn.Module): 8 | def __init__(self, margin=0): 9 | super(TripletLoss, self).__init__() 10 | self.margin = margin 11 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 12 | 13 | def forward(self, inputs, targets): 14 | n = inputs.size(0) 15 | # Compute pairwise distance, replace by the official when merged 16 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 17 | dist = dist + dist.t() 18 | dist.addmm_(inputs, inputs.t(), beta=1., alpha=-2.) 19 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 20 | # For each anchor, find the hardest positive and negative 21 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 22 | dist_ap, dist_an = [], [] 23 | for i in range(n): 24 | dist_ap.append(dist[i][mask[i]].max().view(1)) 25 | dist_an.append(dist[i][mask[i] == 0].min().view(1)) 26 | dist_ap = torch.cat(dist_ap, dim=0) 27 | dist_an = torch.cat(dist_an, dim=0) 28 | # Compute ranking hinge loss 29 | y = dist_an.data.new() 30 | y.resize_as_(dist_an.data) 31 | y.fill_(1) 32 | loss = self.ranking_loss(dist_an, dist_ap, y) 33 | prec = (dist_an.data > dist_ap.data).sum() * 1. 
/ y.size(0) 34 | return loss, prec 35 | 36 | 37 | def euclidean_dist(x, y): 38 | m, n = x.size(0), y.size(0) 39 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 40 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 41 | dist = xx + yy 42 | dist.addmm_(1, -2, x, y.t()) 43 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 44 | return dist 45 | 46 | 47 | def cosine_dist(x, y): 48 | bs1, bs2 = x.size(0), y.size(0) 49 | frac_up = torch.matmul(x, y.transpose(0, 1)) 50 | frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \ 51 | (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1) 52 | cosine = frac_up / frac_down 53 | return 1 - cosine 54 | 55 | 56 | def _batch_hard(mat_distance, mat_similarity, indice=False): 57 | sorted_mat_distance, positive_indices = torch.sort( 58 | mat_distance + (-9999999.) * (1 - mat_similarity), 59 | dim=1, 60 | descending=True) 61 | hard_p = sorted_mat_distance[:, 0] 62 | hard_p_indice = positive_indices[:, 0] 63 | sorted_mat_distance, negative_indices = torch.sort( 64 | mat_distance + (9999999.) * (mat_similarity), dim=1, descending=False) 65 | hard_n = sorted_mat_distance[:, 0] 66 | hard_n_indice = negative_indices[:, 0] 67 | if (indice): 68 | return hard_p, hard_n, hard_p_indice, hard_n_indice 69 | return hard_p, hard_n 70 | 71 | 72 | class SoftTripletLoss(nn.Module): 73 | def __init__(self, margin=None, normalize_feature=False): 74 | super(SoftTripletLoss, self).__init__() 75 | self.margin = margin 76 | self.normalize_feature = normalize_feature 77 | 78 | def forward(self, emb1, emb2, label): 79 | if self.normalize_feature: 80 | # equal to cosine similarity 81 | emb1 = F.normalize(emb1) 82 | emb2 = F.normalize(emb2) 83 | 84 | mat_dist = euclidean_dist(emb1, emb1) 85 | assert mat_dist.size(0) == mat_dist.size(1) 86 | N = mat_dist.size(0) 87 | mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float() 88 | 89 | dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, 90 | mat_sim, 91 | indice=True) 92 | assert dist_an.size(0) == dist_ap.size(0) 93 | triple_dist = torch.stack((dist_ap, dist_an), dim=1) 94 | triple_dist = F.log_softmax(triple_dist, dim=1) 95 | if (self.margin is not None): 96 | loss = (-self.margin * triple_dist[:, 0] - 97 | (1 - self.margin) * triple_dist[:, 1]).mean() 98 | return loss 99 | 100 | mat_dist_ref = euclidean_dist(emb2, emb2) 101 | dist_ap_ref = torch.gather(mat_dist_ref, 1, 102 | ap_idx.view(N, 1).expand(N, N))[:, 0] 103 | dist_an_ref = torch.gather(mat_dist_ref, 1, 104 | an_idx.view(N, 1).expand(N, N))[:, 0] 105 | triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1) 106 | triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach() 107 | 108 | loss = (-triple_dist_ref * triple_dist).mean(0).sum() 109 | return loss 110 | -------------------------------------------------------------------------------- /reid/utils/data/camera_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False, validate_names=None, camera_id=None): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | if validate_names is not None: 16 | if fname not in validate_names: 17 | continue 18 | name = osp.splitext(fname)[0] 
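                # Filenames are assumed to follow the '{pid:08d}_{camid:02d}_{index:04d}.jpg'
                # convention produced by the dataset builders (e.g. reid/datasets/msmt17.py),
                # so the identity and camera labels can be parsed back out of the name and
                # sanity-checked below.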
19 | x, y, _ = map(int, name.split('_')) 20 | assert pid == x and camid == y 21 | if relabel: 22 | if camid == camera_id and camera_id is not None: 23 | ret.append((fname, index, camid)) 24 | elif camera_id is None: 25 | ret.append((fname, index, camid)) 26 | else: 27 | if camid == camera_id and camera_id is not None: 28 | ret.append((fname, pid, camid)) 29 | elif camera_id is None: 30 | ret.append((fname, pid, camid)) 31 | return ret 32 | 33 | 34 | class CameraDataset(object): 35 | def __init__(self, root, split_id=0, camera_id=0): 36 | self.root = root 37 | self.split_id = split_id 38 | self.meta = None 39 | self.split = None 40 | self.train, self.val, self.trainval = [], [], [] 41 | self.query, self.gallery = [], [] 42 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 43 | self.camera_id = camera_id 44 | 45 | @property 46 | def images_dir(self): 47 | return osp.join(self.root, 'images') 48 | 49 | def load(self, num_val=0.3, verbose=True): 50 | splits = read_json(osp.join(self.root, 'splits.json')) 51 | if self.split_id >= len(splits): 52 | raise ValueError("split_id exceeds total splits {}" 53 | .format(len(splits))) 54 | self.split = splits[self.split_id] 55 | 56 | # Randomly split train / val 57 | trainval_pids = np.asarray(self.split['trainval']) 58 | np.random.shuffle(trainval_pids) 59 | num = len(trainval_pids) 60 | if isinstance(num_val, float): 61 | num_val = int(round(num * num_val)) 62 | if num_val >= num or num_val < 0: 63 | raise ValueError("num_val exceeds total identities {}" 64 | .format(num)) 65 | train_pids = sorted(trainval_pids[:-num_val]) 66 | val_pids = sorted(trainval_pids[-num_val:]) 67 | 68 | self.meta = read_json(osp.join(self.root, 'meta.json')) 69 | identities = self.meta['identities'] 70 | gallery_names = self.meta.get('gallery_names', None) 71 | if gallery_names is not None: 72 | gallery_names = set(gallery_names) 73 | query_names = self.meta.get('query_names', None) 74 | if query_names is not None: 75 | query_names = set(query_names) 76 | self.train = _pluck(identities, train_pids, relabel=True, camera_id=self.camera_id) 77 | self.val = _pluck(identities, val_pids, relabel=True, camera_id=self.camera_id) 78 | self.trainval = _pluck(identities, trainval_pids, relabel=True, camera_id=self.camera_id) 79 | self.query = _pluck(identities, self.split['query'], validate_names=query_names) 80 | self.gallery = _pluck(identities, self.split['gallery'], validate_names=gallery_names) 81 | self.num_train_ids = len(train_pids) 82 | self.num_val_ids = len(val_pids) 83 | self.num_trainval_ids = len(trainval_pids) 84 | 85 | if verbose: 86 | print(self.__class__.__name__, "dataset loaded") 87 | print(" subset | # ids | # images") 88 | print(" ---------------------------") 89 | print(" train | {:5d} | {:8d}" 90 | .format(self.num_train_ids, len(self.train))) 91 | print(" val | {:5d} | {:8d}" 92 | .format(self.num_val_ids, len(self.val))) 93 | print(" trainval | {:5d} | {:8d}" 94 | .format(self.num_trainval_ids, len(self.trainval))) 95 | print(" query | {:5d} | {:8d}" 96 | .format(len(self.split['query']), len(self.query))) 97 | print(" gallery | {:5d} | {:8d}" 98 | .format(len(self.split['gallery']), len(self.gallery))) 99 | 100 | def _check_integrity(self): 101 | return osp.isdir(osp.join(self.root, 'images')) and \ 102 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 103 | osp.isfile(osp.join(self.root, 'splits.json')) 104 | -------------------------------------------------------------------------------- /reid/evaluators.py: 
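A minimal usage sketch for the Evaluator defined below; `model`, `test_loader`, and `dataset` are placeholders, and the loader only needs to yield `(imgs, fnames, pids, camids)` batches covering `dataset.query + dataset.gallery`:

    evaluator = Evaluator(model)
    top1 = evaluator.evaluate(test_loader, dataset.query, dataset.gallery)  # prints mAP/CMC, returns allshots top-1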
-------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | from .evaluation_metrics import cmc, mean_ap 8 | from .feature_extraction import extract_cnn_feature 9 | from .utils.meters import AverageMeter 10 | 11 | 12 | def extract_features(model, data_loader, print_freq=1, metric=None): 13 | model.eval() 14 | batch_time = AverageMeter() 15 | data_time = AverageMeter() 16 | 17 | features = OrderedDict() 18 | labels = OrderedDict() 19 | 20 | end = time.time() 21 | for i, (imgs, fnames, pids, _) in enumerate(data_loader): 22 | data_time.update(time.time() - end) 23 | 24 | outputs = extract_cnn_feature(model, imgs) 25 | for fname, output, pid in zip(fnames, outputs, pids): 26 | features[fname] = output 27 | labels[fname] = pid 28 | 29 | batch_time.update(time.time() - end) 30 | end = time.time() 31 | 32 | if (i + 1) % print_freq == 0: 33 | print('Extract Features: [{}/{}]\t' 34 | 'Time {:.3f} ({:.3f})\t' 35 | 'Data {:.3f} ({:.3f})\t' 36 | .format(i + 1, len(data_loader), 37 | batch_time.val, batch_time.avg, 38 | data_time.val, data_time.avg)) 39 | 40 | return features, labels 41 | 42 | 43 | def pairwise_distance(features, query=None, gallery=None, metric=None): 44 | if query is None and gallery is None: 45 | n = len(features) 46 | x = torch.cat(list(features.values())) 47 | x = x.view(n, -1) 48 | if metric is not None: 49 | x = metric.transform(x) 50 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 51 | dist = dist.expand(n, n) - 2 * torch.mm(x, x.t()) 52 | return dist 53 | 54 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 55 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0) 56 | m, n = x.size(0), y.size(0) 57 | x = x.view(m, -1) 58 | y = y.view(n, -1) 59 | if metric is not None: 60 | x = metric.transform(x) 61 | y = metric.transform(y) 62 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 63 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 64 | dist.addmm_(1, -2, x, y.t()) 65 | return dist 66 | 67 | 68 | def evaluate_all(distmat, query=None, gallery=None, 69 | query_ids=None, gallery_ids=None, 70 | query_cams=None, gallery_cams=None, 71 | cmc_topk=(1, 5, 10)): 72 | if query is not None and gallery is not None: 73 | query_ids = [pid for _, pid, _ in query] 74 | gallery_ids = [pid for _, pid, _ in gallery] 75 | query_cams = [cam for _, _, cam in query] 76 | gallery_cams = [cam for _, _, cam in gallery] 77 | else: 78 | assert (query_ids is not None and gallery_ids is not None 79 | and query_cams is not None and gallery_cams is not None) 80 | 81 | # Compute mean AP 82 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 83 | print('Mean AP: {:4.1%}'.format(mAP)) 84 | 85 | # Compute all kinds of CMC scores 86 | cmc_configs = { 87 | 'allshots': dict(separate_camera_set=False, 88 | single_gallery_shot=False, 89 | first_match_break=False), 90 | 'cuhk03': dict(separate_camera_set=True, 91 | single_gallery_shot=True, 92 | first_match_break=False), 93 | 'market1501': dict(separate_camera_set=False, 94 | single_gallery_shot=False, 95 | first_match_break=True)} 96 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 97 | query_cams, gallery_cams, **params) 98 | for name, params in cmc_configs.items()} 99 | 100 | print('CMC Scores{:>12}{:>12}{:>12}' 101 | .format('allshots', 'cuhk03', 'market1501')) 102 | for k in cmc_topk: 103 | print(' 
top-{:<4}{:12.1%}{:12.1%}{:12.1%}' 104 | .format(k, cmc_scores['allshots'][k - 1], 105 | cmc_scores['cuhk03'][k - 1], 106 | cmc_scores['market1501'][k - 1])) 107 | 108 | # Use the allshots cmc top-1 score for validation criterion 109 | return cmc_scores['allshots'][0] 110 | 111 | 112 | class Evaluator(object): 113 | def __init__(self, model): 114 | super(Evaluator, self).__init__() 115 | self.model = model 116 | 117 | def evaluate(self, data_loader, query, gallery, metric=None): 118 | features, _ = extract_features(self.model, data_loader) 119 | distmat = pairwise_distance(features, query, gallery, metric=metric) 120 | return evaluate_all(distmat, query=query, gallery=gallery) 121 | -------------------------------------------------------------------------------- /reid/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data.sampler import ( 9 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 10 | WeightedRandomSampler) 11 | 12 | 13 | class RandomIdentitySampler(Sampler): 14 | """ 15 | Randomly sample N identities, then for each identity, 16 | randomly sample K instances, therefore batch size is N*K. 17 | Args: 18 | - data_source (list): list of (img_path, pid, camid). 19 | - num_instances (int): number of instances per identity in a batch. 20 | - batch_size (int): number of examples in a batch. 21 | """ 22 | 23 | def __init__(self, data_source, batch_size, num_instances): 24 | self.data_source = data_source 25 | self.batch_size = batch_size 26 | self.num_instances = num_instances 27 | self.num_pids_per_batch = self.batch_size // self.num_instances 28 | self.index_dic = defaultdict(list) 29 | for index, (_, pid, _) in enumerate(self.data_source): 30 | self.index_dic[pid].append(index) 31 | self.pids = list(self.index_dic.keys()) 32 | 33 | # estimate number of examples in an epoch 34 | self.length = 0 35 | for pid in self.pids: 36 | idxs = self.index_dic[pid] 37 | num = len(idxs) 38 | if num < self.num_instances: 39 | num = self.num_instances 40 | self.length += num - num % self.num_instances 41 | 42 | def __iter__(self): 43 | batch_idxs_dict = defaultdict(list) 44 | 45 | for pid in self.pids: 46 | idxs = copy.deepcopy(self.index_dic[pid]) 47 | if len(idxs) < self.num_instances: 48 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 49 | random.shuffle(idxs) 50 | batch_idxs = [] 51 | for idx in idxs: 52 | batch_idxs.append(idx) 53 | if len(batch_idxs) == self.num_instances: 54 | batch_idxs_dict[pid].append(batch_idxs) 55 | batch_idxs = [] 56 | 57 | avai_pids = copy.deepcopy(self.pids) 58 | final_idxs = [] 59 | 60 | while len(avai_pids) >= self.num_pids_per_batch: 61 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 62 | for pid in selected_pids: 63 | batch_idxs = batch_idxs_dict[pid].pop(0) 64 | final_idxs.extend(batch_idxs) 65 | if len(batch_idxs_dict[pid]) == 0: 66 | avai_pids.remove(pid) 67 | 68 | self.length = len(final_idxs) 69 | return iter(final_idxs) 70 | 71 | def __len__(self): 72 | return self.length 73 | 74 | 75 | class BothRandomIdentitySampler(Sampler): 76 | """ 77 | Randomly sample N identities, then for each identity, 78 | randomly sample K instances, therefore batch size is N*K. 79 | Args: 80 | - data_source (list): list of (img_path, pid, camid). 
81 | - num_instances (int): number of instances per identity in a batch. 82 | - batch_size (int): number of examples in a batch. 83 | """ 84 | 85 | def __init__(self, data_source, batch_size, num_instances): 86 | self.data_source = data_source 87 | self.batch_size = batch_size 88 | self.num_instances = num_instances 89 | self.num_pids_per_batch = self.batch_size // self.num_instances 90 | self.index_dic = defaultdict(list) 91 | for index, (_, pid, _, _) in enumerate(self.data_source): 92 | self.index_dic[pid].append(index) 93 | self.pids = list(self.index_dic.keys()) 94 | 95 | # estimate number of examples in an epoch 96 | self.length = 0 97 | for pid in self.pids: 98 | idxs = self.index_dic[pid] 99 | num = len(idxs) 100 | if num < self.num_instances: 101 | num = self.num_instances 102 | self.length += num - num % self.num_instances 103 | 104 | def __iter__(self): 105 | batch_idxs_dict = defaultdict(list) 106 | 107 | for pid in self.pids: 108 | idxs = copy.deepcopy(self.index_dic[pid]) 109 | if len(idxs) < self.num_instances: 110 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 111 | random.shuffle(idxs) 112 | batch_idxs = [] 113 | for idx in idxs: 114 | batch_idxs.append(idx) 115 | if len(batch_idxs) == self.num_instances: 116 | batch_idxs_dict[pid].append(batch_idxs) 117 | batch_idxs = [] 118 | 119 | avai_pids = copy.deepcopy(self.pids) 120 | final_idxs = [] 121 | 122 | while len(avai_pids) >= self.num_pids_per_batch: 123 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 124 | for pid in selected_pids: 125 | batch_idxs = batch_idxs_dict[pid].pop(0) 126 | final_idxs.extend(batch_idxs) 127 | if len(batch_idxs_dict[pid]) == 0: 128 | avai_pids.remove(pid) 129 | 130 | self.length = len(final_idxs) 131 | return iter(final_idxs) 132 | 133 | def __len__(self): 134 | return self.length 135 | 136 | -------------------------------------------------------------------------------- /reid/models/backbones/TNorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class TNorm(nn.Module): 6 | def __init__(self, num_features, domain_number, using_moving_average=True, eps=1e-5, momentum=0.9): 7 | super(TNorm, self).__init__() 8 | self.num_features = num_features 9 | self.eps = eps 10 | self.momentum = momentum 11 | self.domain_number = domain_number 12 | self.using_moving_average = using_moving_average 13 | self.register_buffer('running_mean', torch.zeros((domain_number, num_features))) 14 | self.register_buffer('running_var', torch.ones((domain_number, num_features))) 15 | self.register_buffer('num_batches_tracked', torch.zeros(domain_number, dtype=torch.long)) 16 | self.reset_parameters() 17 | 18 | def reset_parameters(self): 19 | self.running_mean.zero_() 20 | self.running_var.fill_(1) 21 | self.num_batches_tracked.zero_() 22 | 23 | def _check_input_dim(self, input): 24 | if input.dim() != 4: 25 | raise ValueError('expected 4D input (got {}D input)' 26 | .format(input.dim())) 27 | 28 | def forward(self, x, domain_index=None, convert=False, selected_domain=None): 29 | self._check_input_dim(x) 30 | N, C, H, W = x.size() 31 | x = x.view(N, C, -1) 32 | mean_in = x.mean(-1, keepdim=True) 33 | var_in = x.var(-1, keepdim=True) 34 | 35 | temp = var_in + mean_in ** 2 36 | 37 | if self.training: 38 | 39 | if convert: 40 | assert selected_domain is not None 41 | if domain_index is not None and type(domain_index) != int: 42 | mean_bn = self.running_mean[domain_index, 
:].view(N, C, 1) 43 | var_bn = self.running_var[domain_index, :].view(N, C, 1) 44 | x_after_in = (x - mean_bn) / (var_bn + self.eps).sqrt() 45 | else: 46 | mean_bn = mean_in.mean(0, keepdim=True) 47 | var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2 48 | sig = (var_bn + self.eps).sqrt() 49 | mean_bn, sig = mean_bn.detach(), sig.detach() 50 | x_after_in = (x - mean_bn) / sig 51 | 52 | convert_mean = self.running_mean[selected_domain, :].view(N, C, 1) # N * C 53 | convert_var = self.running_var[selected_domain, :].view(N, C, 1) # N * C 54 | x = x_after_in * (convert_var + self.eps).sqrt() + convert_mean 55 | 56 | if domain_index is not None and type(domain_index) == int: 57 | if self.using_moving_average: 58 | self.running_mean[domain_index].mul_(self.momentum) 59 | self.running_mean[domain_index].add_((1 - self.momentum) * mean_bn.squeeze().data) 60 | self.running_var[domain_index].mul_(self.momentum) 61 | self.running_var[domain_index].add_((1 - self.momentum) * var_bn.squeeze().data) 62 | else: 63 | self.num_batches_tracked[domain_index] += 1 64 | exponential_average_factor = 1 - 1.0 / self.num_batches_tracked[domain_index] 65 | self.running_mean[domain_index].mul_(exponential_average_factor) 66 | self.running_mean[domain_index].add_( 67 | (1 - exponential_average_factor) * mean_bn.squeeze().data) 68 | self.running_var[domain_index].mul_(exponential_average_factor) 69 | self.running_var[domain_index].add_( 70 | (1 - exponential_average_factor) * var_bn.squeeze().data) 71 | else: 72 | if domain_index is not None: 73 | mean_bn = mean_in.mean(0, keepdim=True) 74 | var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2 75 | if self.using_moving_average: 76 | self.running_mean[domain_index].mul_(self.momentum) 77 | self.running_mean[domain_index].add_((1 - self.momentum) * mean_bn.squeeze().data) 78 | self.running_var[domain_index].mul_(self.momentum) 79 | self.running_var[domain_index].add_((1 - self.momentum) * var_bn.squeeze().data) 80 | else: 81 | self.num_batches_tracked[domain_index] += 1 82 | exponential_average_factor = 1 - 1.0 / self.num_batches_tracked[domain_index] 83 | self.running_mean[domain_index].mul_(exponential_average_factor) 84 | self.running_mean[domain_index].add_( 85 | (1 - exponential_average_factor) * mean_bn.squeeze().data) 86 | self.running_var[domain_index].mul_(exponential_average_factor) 87 | self.running_var[domain_index].add_( 88 | (1 - exponential_average_factor) * var_bn.squeeze().data) 89 | else: 90 | domain_mean_bn = torch.autograd.Variable(self.running_mean) 91 | domain_var_bn = torch.autograd.Variable(self.running_var) 92 | if convert: 93 | assert domain_index is not None 94 | 95 | x_after_in = (x - domain_mean_bn[domain_index[0], :].view(N, C, 1)) / (domain_var_bn[domain_index[0], :].view(N, C, 1) + self.eps).sqrt() 96 | 97 | convert_mean = domain_mean_bn[domain_index[1], :].view(1, C, 1) # N * C 98 | convert_var = domain_var_bn[domain_index[1], :].view(1, C, 1) # N * C 99 | x = x_after_in * (convert_var + self.eps).sqrt() + convert_mean 100 | else: 101 | pass 102 | x = x.view(N, C, H, W) 103 | 104 | return x 105 | -------------------------------------------------------------------------------- /reid/evaluators_cos.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from .evaluation_metrics import cmc, mean_ap 8 | from .feature_extraction import 
extract_cnn_feature, extract_cnn_feature_specific, extract_cnn_feature_with_tnorm 9 | from .utils.meters import AverageMeter 10 | 11 | import numpy as np 12 | 13 | 14 | def fresh_bn(model, data_loader): 15 | model.train() 16 | for i, (imgs, fnames, pids, _) in enumerate(data_loader): 17 | with torch.no_grad(): 18 | outputs = extract_cnn_feature(model, imgs) 19 | print('Fresh BN: [{}/{}]\t'.format(i, len(data_loader))) 20 | 21 | 22 | def extract_features(model, data_loader, print_freq=1, metric=None): 23 | model.eval() 24 | batch_time = AverageMeter() 25 | data_time = AverageMeter() 26 | 27 | features = OrderedDict() 28 | labels = OrderedDict() 29 | 30 | end = time.time() 31 | for i, (imgs, fnames, pids, _) in enumerate(data_loader): 32 | data_time.update(time.time() - end) 33 | with torch.no_grad(): 34 | outputs = extract_cnn_feature(model, imgs) 35 | for fname, output, pid in zip(fnames, outputs, pids): 36 | features[fname] = output 37 | labels[fname] = pid 38 | 39 | batch_time.update(time.time() - end) 40 | end = time.time() 41 | 42 | if (i + 1) % print_freq == 0: 43 | print('Extract Features: [{}/{}]\t' 44 | 'Time {:.3f} ({:.3f})\t' 45 | 'Data {:.3f} ({:.3f})\t' 46 | .format(i + 1, len(data_loader), 47 | batch_time.val, batch_time.avg, 48 | data_time.val, data_time.avg)) 49 | 50 | return features, labels 51 | 52 | 53 | def extract_features_tnorm(model, data_loader, print_freq=1, metric=None, camera_number=1): 54 | model.eval() 55 | batch_time = AverageMeter() 56 | data_time = AverageMeter() 57 | 58 | features = OrderedDict() 59 | labels = OrderedDict() 60 | 61 | end = time.time() 62 | for i, (imgs, fnames, pids, camid) in enumerate(data_loader): 63 | data_time.update(time.time() - end) 64 | with torch.no_grad(): 65 | for i in range(camera_number): 66 | t = extract_cnn_feature_with_tnorm(model, 67 | imgs, 68 | camid, 69 | i, 70 | norm=False) 71 | if i == 0: 72 | tmp = t 73 | else: 74 | tmp = tmp + t 75 | outputs = F.normalize(tmp, p=2, dim=1) 76 | for fname, output, pid in zip(fnames, outputs, pids): 77 | features[fname] = output 78 | labels[fname] = pid 79 | 80 | batch_time.update(time.time() - end) 81 | end = time.time() 82 | 83 | if (i + 1) % print_freq == 0: 84 | print('Extract Features: [{}/{}]\t' 85 | 'Time {:.3f} ({:.3f})\t' 86 | 'Data {:.3f} ({:.3f})\t' 87 | .format(i + 1, len(data_loader), 88 | batch_time.val, batch_time.avg, 89 | data_time.val, data_time.avg)) 90 | 91 | return features, labels 92 | 93 | 94 | def extract_features_specific(model, data_loader, print_freq=1, metric=None): 95 | model.eval() 96 | batch_time = AverageMeter() 97 | data_time = AverageMeter() 98 | 99 | features = OrderedDict() 100 | labels = OrderedDict() 101 | 102 | end = time.time() 103 | for i, (imgs, fnames, pids, camid) in enumerate(data_loader): 104 | data_time.update(time.time() - end) 105 | with torch.no_grad(): 106 | domain_index = camid.cuda() 107 | outputs = extract_cnn_feature_specific(model, imgs, domain_index) 108 | for fname, output, pid in zip(fnames, outputs, pids): 109 | features[fname] = output 110 | labels[fname] = pid 111 | 112 | batch_time.update(time.time() - end) 113 | end = time.time() 114 | 115 | if (i + 1) % print_freq == 0: 116 | print('Extract Features: [{}/{}]\t' 117 | 'Time {:.3f} ({:.3f})\t' 118 | 'Data {:.3f} ({:.3f})\t' 119 | .format(i + 1, len(data_loader), 120 | batch_time.val, batch_time.avg, 121 | data_time.val, data_time.avg)) 122 | 123 | return features, labels 124 | 125 | 126 | def pairwise_distance(features, query=None, gallery=None, metric=None, use_cpu=False): 
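    # The feature extractors above are assumed to return L2-normalized embeddings, so
    # 1 - <x, y> below is the cosine distance. When use_cpu is set, the query-gallery
    # product is computed with numpy on the CPU, which avoids holding the full distance
    # matrix on the GPU for large test sets.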
127 | if query is None and gallery is None: 128 | n = len(features) 129 | x = torch.cat(list(features.values())) 130 | x = x.view(n, -1) 131 | if metric is not None: 132 | x = metric.transform(x) 133 | dist = 1 - torch.mm(x, x.t()) 134 | return dist 135 | 136 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 137 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0) 138 | m, n = x.size(0), y.size(0) 139 | x = x.view(m, -1) 140 | y = y.view(n, -1) 141 | if metric is not None: 142 | x = metric.transform(x) 143 | y = metric.transform(y) 144 | if use_cpu: 145 | dist = 1 - np.matmul(x.cpu().numpy(), y.cpu().numpy().T) 146 | dist = np.array(dist) 147 | else: 148 | dist = 1 - torch.mm(x, y.t()) 149 | return dist 150 | 151 | 152 | def evaluate_all(distmat, query=None, gallery=None, 153 | query_ids=None, gallery_ids=None, 154 | query_cams=None, gallery_cams=None, 155 | cmc_topk=(1, 5, 10), return_mAP=False): 156 | if query is not None and gallery is not None: 157 | query_ids = [pid for _, pid, _ in query] 158 | gallery_ids = [pid for _, pid, _ in gallery] 159 | query_cams = [cam for _, _, cam in query] 160 | gallery_cams = [cam for _, _, cam in gallery] 161 | else: 162 | assert (query_ids is not None and gallery_ids is not None 163 | and query_cams is not None and gallery_cams is not None) 164 | 165 | # Compute mean AP 166 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 167 | print('Mean AP: {:4.1%}'.format(mAP)) 168 | 169 | # Compute all kinds of CMC scores 170 | cmc_configs = { 171 | 'market1501': dict(separate_camera_set=False, 172 | single_gallery_shot=False, 173 | first_match_break=True)} 174 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 175 | query_cams, gallery_cams, **params) 176 | for name, params in cmc_configs.items()} 177 | 178 | print('CMC Scores{:>12}' 179 | .format('market1501')) 180 | for k in cmc_topk: 181 | print(' top-{:<4}{:12.1%}' 182 | .format(k, 183 | cmc_scores['market1501'][k - 1])) 184 | 185 | # Use the allshots cmc top-1 score for validation criterion 186 | if return_mAP: 187 | return cmc_scores['market1501'][0], mAP 188 | else: 189 | return cmc_scores['market1501'][0] 190 | 191 | 192 | class Evaluator(object): 193 | def __init__(self, model, use_cpu=False): 194 | super(Evaluator, self).__init__() 195 | self.model = model 196 | self.use_cpu = use_cpu 197 | 198 | def evaluate(self, data_loader, query, gallery, metric=None, return_mAP=False): 199 | features, _ = extract_features(self.model, data_loader) 200 | distmat = pairwise_distance(features, query, gallery, metric=metric, use_cpu=self.use_cpu) 201 | return evaluate_all(distmat, query=query, gallery=gallery, return_mAP=return_mAP) 202 | 203 | def evaluate_specific(self, data_loader, query, gallery, metric=None, return_mAP=False): 204 | features, _ = extract_features_specific(self.model, data_loader) 205 | distmat = pairwise_distance(features, query, gallery, metric=metric, use_cpu=self.use_cpu) 206 | return evaluate_all(distmat, query=query, gallery=gallery, return_mAP=return_mAP) 207 | 208 | def evaluate_tnorm(self, data_loader, query, gallery, metric=None, return_mAP=False, camera_number=1): 209 | features, _ = extract_features_tnorm(self.model, data_loader, camera_number=camera_number) 210 | distmat = pairwise_distance(features, query, gallery, metric=metric, use_cpu=self.use_cpu) 211 | return evaluate_all(distmat, query=query, gallery=gallery, return_mAP=return_mAP) 212 | 213 | def fresh_bn(self, data_loader): 214 | fresh_bn(self.model, 
data_loader) 215 | -------------------------------------------------------------------------------- /reid/models/ft_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torchvision import models 5 | from .backbones.resnet import AIBNResNet, TNormResNet 6 | 7 | 8 | def weights_init_kaiming(m): 9 | classname = m.__class__.__name__ 10 | # print(classname) 11 | if classname.find('Conv') != -1: 12 | init.kaiming_normal_( 13 | m.weight.data, a=0, 14 | mode='fan_in') # For old pytorch, you may use kaiming_normal. 15 | elif classname.find('Linear') != -1: 16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 17 | init.constant_(m.bias.data, 0.0) 18 | elif classname.find('BatchNorm1d') != -1: 19 | init.normal_(m.weight.data, 1.0, 0.02) 20 | init.constant_(m.bias.data, 0.0) 21 | 22 | 23 | def weights_init_classifier(m): 24 | classname = m.__class__.__name__ 25 | if classname.find('Linear') != -1: 26 | init.normal_(m.weight.data, std=0.001) 27 | if m.bias is not None: 28 | init.constant_(m.bias.data, 0.0) 29 | 30 | 31 | 32 | class ft_net_intra_TNorm(nn.Module): 33 | def __init__(self, num_classes, stride=1, init_weight=0.1): 34 | super(ft_net_intra_TNorm, self).__init__() 35 | model_ft = TNormResNet(domain_number=len(num_classes), 36 | last_stride=stride, 37 | layers=[3, 4, 6, 3], 38 | init_weight=init_weight) 39 | 40 | self.model = model_ft 41 | self.classifier = nn.ModuleList([ 42 | nn.Sequential(nn.BatchNorm1d(2048), nn.Linear(2048, 43 | num, 44 | bias=False)) 45 | for num in num_classes 46 | ]) 47 | for classifier_one in self.classifier: 48 | init.normal_(classifier_one[1].weight.data, std=0.001) 49 | init.constant_(classifier_one[0].weight.data, 1.0) 50 | init.constant_(classifier_one[0].bias.data, 0.0) 51 | classifier_one[0].bias.requires_grad_(False) 52 | 53 | def backbone_forward(self, x, domain_index=None, convert=False): 54 | x = self.model(x, domain_index=domain_index, convert=convert) 55 | return x 56 | 57 | def forward(self, x, k=0, convert=False): 58 | x = self.backbone_forward(x, domain_index=k, convert=convert) 59 | x = x.view(x.size(0), x.size(1)) 60 | x = self.classifier[k](x) 61 | return x 62 | 63 | 64 | class ft_net_inter_TNorm(nn.Module): 65 | def __init__(self, num_classes, domain_number, stride=1, init_weight=0.1): 66 | super(ft_net_inter_TNorm, self).__init__() 67 | # domain number only for param initialization has no meaning 68 | model_ft = TNormResNet(domain_number, 69 | last_stride=stride, 70 | layers=[3, 4, 6, 3], 71 | init_weight=init_weight 72 | ) 73 | 74 | self.model = model_ft 75 | self.classifier = nn.Sequential( 76 | nn.BatchNorm1d(2048), nn.Linear(2048, num_classes, bias=False)) 77 | init.normal_(self.classifier[1].weight.data, std=0.001) 78 | init.constant_(self.classifier[0].weight.data, 1.0) 79 | init.constant_(self.classifier[0].bias.data, 0.0) 80 | self.classifier[0].bias.requires_grad_(False) 81 | 82 | def backbone_forward(self, x, domain_index=None, convert=False): 83 | x = self.model(x, domain_index=domain_index, convert=convert) 84 | return x 85 | 86 | def forward(self, x, domain_index=None, convert=False): 87 | x = self.backbone_forward(x, domain_index=domain_index, convert=convert) 88 | x = x.view(x.size(0), x.size(1)) 89 | prob = self.classifier(x) 90 | return prob, x 91 | 92 | 93 | class ft_net_intra(nn.Module): 94 | def __init__(self, num_classes, stride=1, init_weight=0.1, with_permute_adain=False): 95 | super(ft_net_intra, 
self).__init__() 96 | model_ft = AIBNResNet(last_stride=stride, 97 | layers=[3, 4, 6, 3], 98 | init_weight=init_weight, 99 | with_permute_adain=with_permute_adain) 100 | 101 | self.model = model_ft 102 | self.classifier = nn.ModuleList([ 103 | nn.Sequential(nn.BatchNorm1d(2048), nn.Linear(2048, 104 | num, 105 | bias=False)) 106 | for num in num_classes 107 | ]) 108 | for classifier_one in self.classifier: 109 | init.normal_(classifier_one[1].weight.data, std=0.001) 110 | init.constant_(classifier_one[0].weight.data, 1.0) 111 | init.constant_(classifier_one[0].bias.data, 0.0) 112 | classifier_one[0].bias.requires_grad_(False) 113 | 114 | def backbone_forward(self, x): 115 | x = self.model(x) 116 | return x 117 | 118 | def forward(self, x, k=0): 119 | x = self.backbone_forward(x) 120 | x = x.view(x.size(0), x.size(1)) 121 | x = self.classifier[k](x) 122 | return x 123 | 124 | 125 | class ft_net_inter(nn.Module): 126 | def __init__(self, num_classes, stride=1, init_weight=0.1, with_permute_adain=False): 127 | super(ft_net_inter, self).__init__() 128 | model_ft = AIBNResNet(last_stride=stride, 129 | layers=[3, 4, 6, 3], 130 | init_weight=init_weight, 131 | with_permute_adain=with_permute_adain) 132 | 133 | self.model = model_ft 134 | self.classifier = nn.Sequential( 135 | nn.BatchNorm1d(2048), nn.Linear(2048, num_classes, bias=False)) 136 | init.normal_(self.classifier[1].weight.data, std=0.001) 137 | init.constant_(self.classifier[0].weight.data, 1.0) 138 | init.constant_(self.classifier[0].bias.data, 0.0) 139 | self.classifier[0].bias.requires_grad_(False) 140 | 141 | def backbone_forward(self, x): 142 | x = self.model(x) 143 | return x 144 | 145 | def forward(self, x): 146 | x = self.backbone_forward(x) 147 | x = x.view(x.size(0), x.size(1)) 148 | prob = self.classifier(x) 149 | return prob, x 150 | 151 | 152 | class ft_net_both(nn.Module): 153 | def __init__(self, cam_num_classes, global_num_classes, stride=1): 154 | super(ft_net_both, self).__init__() 155 | model_ft = AIBNResNet(last_stride=stride, layers=[3, 4, 6, 3]) 156 | 157 | self.model = model_ft 158 | self.bn_neck = nn.BatchNorm1d(2048) 159 | self.global_classifier = nn.Linear(2048, 160 | global_num_classes, 161 | bias=False) 162 | self.classifier = nn.ModuleList([ 163 | nn.Linear(2048, cam_num_classes[i], bias=False) 164 | for i in range(len(cam_num_classes)) 165 | ]) 166 | self.intra_loss = nn.CrossEntropyLoss() 167 | init.normal_(self.global_classifier.weight.data, std=0.001) 168 | 169 | init.constant_(self.bn_neck.weight.data, 1.0) 170 | init.constant_(self.bn_neck.bias.data, 0.0) 171 | for cam in self.classifier: 172 | init.normal_(cam.weight.data, std=0.001) 173 | self.bn_neck.bias.requires_grad_(False) 174 | 175 | def backbone_forward(self, x): 176 | x = self.model(x) 177 | return x 178 | 179 | def forward(self, x, targets, domain_targets, camid): 180 | x = self.backbone_forward(x) 181 | x = x.view(x.size(0), x.size(1)) 182 | prob = self.bn_neck(x) 183 | global_prob = self.global_classifier(prob) 184 | unique_camids = torch.unique(camid) 185 | intra_loss = 0 186 | for index, current in enumerate(unique_camids): 187 | current_camid = (camid == current).nonzero().view(-1) 188 | data = torch.index_select(prob, index=current_camid, dim=0) 189 | pids = torch.index_select(domain_targets, 190 | index=current_camid, 191 | dim=0) 192 | intra_out = self.classifier[current](data) 193 | intra_loss = intra_loss + self.intra_loss(intra_out, pids) 194 | intra_loss /= len(unique_camids) 195 | return global_prob, x, intra_loss 196 | 197 | 
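# Illustrative usage of the AIBN-based heads above (class counts and tensor shapes are
# made-up examples, not values fixed by this file):
#
#   intra = ft_net_intra(num_classes=[702, 650, 711])   # one classifier head per camera
#   logits_cam0 = intra(images, k=0)                     # score a batch with the camera-0 head
#
#   inter = ft_net_inter(num_classes=500)
#   logits, feats = inter(images)                        # global logits plus 2048-d features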
198 | class ft_net_intra_specific(nn.Module): 199 | def __init__(self, domain_number, num_classes, stride=1): 200 | super(ft_net_intra_specific, self).__init__() 201 | model_ft = CameraAIBNResNet(domain_number=domain_number, 202 | last_stride=stride, 203 | layers=[3, 4, 6, 3]) 204 | 205 | self.model = model_ft 206 | self.bn_neck = CameraBNorm1d(2048, domain_number) 207 | self.classifier = nn.ModuleList( 208 | [nn.Linear(2048, num, bias=False) for num in num_classes]) 209 | for classifier_one in self.classifier: 210 | init.normal_(classifier_one.weight.data, std=0.001) 211 | init.constant_(self.bn_neck.weight.data, 1.0) 212 | init.constant_(self.bn_neck.bias.data, 0.0) 213 | self.bn_neck.bias.requires_grad_(False) 214 | 215 | def backbone_forward(self, x, domain_index, using_running=False): 216 | x = self.model(x, domain_index, using_running) 217 | return x 218 | 219 | def forward(self, x, k=0, using_running=False): 220 | x = self.backbone_forward(x, k, using_running) 221 | x = x.view(x.size(0), x.size(1)) 222 | x = self.bn_neck(x, k, using_running) 223 | x = self.classifier[k](x) 224 | return x 225 | 226 | 227 | class ft_net_inter_specific(nn.Module): 228 | def __init__(self, domain_number, num_classes, stride=1): 229 | super(ft_net_inter_specific, self).__init__() 230 | model_ft = CameraAIBNResNet(domain_number=domain_number, 231 | last_stride=stride, 232 | layers=[3, 4, 6, 3]) 233 | self.model = model_ft 234 | self.bn_neck = CameraBNorm1d(2048, domain_number) 235 | self.classifier = nn.Linear(2048, num_classes, bias=False) 236 | init.normal_(self.classifier.weight.data, std=0.001) 237 | init.constant_(self.bn_neck.weight.data, 1.0) 238 | init.constant_(self.bn_neck.bias.data, 0.0) 239 | self.bn_neck.bias.requires_grad_(False) 240 | 241 | def backbone_forward(self, x, domain_index, using_running=True): 242 | x = self.model(x, domain_index, using_running=using_running) 243 | return x 244 | 245 | def forward(self, x, domain_index, targets, using_running=True): 246 | unique_camids = torch.unique(domain_index) 247 | success = 0 248 | for index, current in enumerate(unique_camids): 249 | current_camid = (domain_index == current).nonzero().view(-1) 250 | if current_camid.size(0) > 1: 251 | data = torch.index_select(x, index=current_camid, dim=0) 252 | pids = torch.index_select(targets, index=current_camid, dim=0) 253 | out = self.backbone_forward(data, current, False) 254 | out = out.view(out.size(0), out.size(1)) 255 | out = self.bn_neck(out, current, using_running=False) 256 | if success == 0: 257 | out_features = out 258 | out_targets = pids 259 | else: 260 | out_features = torch.cat((out_features, out), dim=0) 261 | out_targets = torch.cat((out_targets, pids), dim=0) 262 | success += 1 263 | prob = self.classifier(out_features) 264 | return prob, out_features, out_targets 265 | 266 | 267 | class ft_net_test(nn.Module): 268 | def __init__(self, domain_number, stride=1): 269 | super(ft_net_test, self).__init__() 270 | model_ft = CameraAIBNResNet(domain_number=domain_number, 271 | last_stride=stride, 272 | layers=[3, 4, 6, 3]) 273 | 274 | self.model = model_ft 275 | 276 | def backbone_forward(self, x, domain_index, using_running=False): 277 | x = self.model(x, domain_index, using_running) 278 | return x 279 | 280 | def forward(self, x, k=0, using_running=False): 281 | x = self.backbone_forward(x, k, using_running) 282 | x = x.view(x.size(0), x.size(1)) 283 | return x 284 | -------------------------------------------------------------------------------- /reid/cluster_utils/cluster.py: 
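A rough sketch of how the two clustering entry points at the bottom of this file are driven during training; `model`, `train_loader`, and the threshold/mix-rate values are illustrative placeholders:

    per_cam = get_intra_cam_cluster_result(model, train_loader, eph=0.002, linkage="average")
    # per_cam[camid] maps image filename -> pseudo identity within that camera

    global_ids = get_inter_cam_cluster_result(model, train_loader, eph=0.002,
                                              linkage="average", mix_rate=0.02)
    # global_ids maps image filename -> pseudo identity shared across cameras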
-------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn.functional as F 5 | from reid.feature_extraction import extract_cnn_feature, extract_cnn_feature_with_tnorm 6 | from tqdm import tqdm 7 | from sklearn.cluster import AgglomerativeClustering 8 | import numpy as np 9 | from .rerank import re_ranking 10 | 11 | 12 | def to_numpy(tensor): 13 | if torch.is_tensor(tensor): 14 | return tensor.cpu().numpy() 15 | elif type(tensor).__module__ != 'numpy': 16 | raise ValueError("Cannot convert {} to numpy array".format( 17 | type(tensor))) 18 | return tensor 19 | 20 | 21 | def extract_features_per_cam(model, data_loader, norm=True): 22 | model.eval() 23 | per_cam_features_without = {} 24 | per_cam_features_norm = {} 25 | per_cam_fname = {} 26 | print("Start extract features per camera") 27 | for imgs, fnames, _, camid in tqdm(data_loader): 28 | camid = list(camid) 29 | for cam in camid: 30 | cam = cam.item() 31 | if cam not in per_cam_features_without.keys(): 32 | per_cam_features_without[cam] = [] 33 | per_cam_fname[cam] = [] 34 | per_cam_features_norm[cam] = [] 35 | with torch.no_grad(): 36 | outputs = extract_cnn_feature(model, imgs, norm=False) 37 | 38 | fnorm = torch.norm(outputs, p=2, dim=1, keepdim=True) 39 | norm_outputs = outputs.div(fnorm.expand_as(outputs)) 40 | if norm: 41 | for fname, output, cam in zip(fnames, norm_outputs, camid): 42 | cam = cam.item() 43 | per_cam_features_norm[cam].append(output) 44 | per_cam_fname[cam].append(fname) 45 | else: 46 | for fname, output, output_without, cam in zip( 47 | fnames, norm_outputs, outputs, camid): 48 | cam = cam.item() 49 | per_cam_features_norm[cam].append(output) 50 | per_cam_fname[cam].append(fname) 51 | per_cam_features_without[cam].append(output_without) 52 | if norm: 53 | return per_cam_features_norm, per_cam_fname 54 | else: 55 | return per_cam_features_norm, per_cam_features_without, per_cam_fname 56 | 57 | 58 | def extract_features_cross_cam(model, data_loader, norm=True, bn_neck=False): 59 | model.eval() 60 | cross_cam_features = [] 61 | cross_cam_features_without = [] 62 | cross_cam_fnames = [] 63 | cross_cam_distribute = [] 64 | cams = [] 65 | cam_number = len(model.classifier) 66 | print("Start extract features cross camera") 67 | for imgs, fnames, _, camid in tqdm(data_loader): 68 | with torch.no_grad(): 69 | outputs = extract_cnn_feature(model, imgs, norm=False) 70 | fnorm = torch.norm(outputs, p=2, dim=1, keepdim=True) 71 | norm_outputs = outputs.div(fnorm.expand_as(outputs)) 72 | if bn_neck: 73 | outputs = model.bn_neck(outputs) 74 | for i in range(cam_number): 75 | x = model.classifier[i](outputs) 76 | if i == 0: 77 | distribute = F.softmax(x.data, dim=1) 78 | else: 79 | distribute_tmp = F.softmax(x.data, dim=1) 80 | distribute = torch.cat((distribute, distribute_tmp), dim=1) 81 | if norm: 82 | for fname, output, cam, dis in zip(fnames, norm_outputs, camid, 83 | distribute): 84 | cam = cam.item() 85 | cross_cam_fnames.append(fname) 86 | cross_cam_features.append(output) 87 | cams.append(cam) 88 | cross_cam_distribute.append(dis.cpu().numpy()) 89 | else: 90 | for fname, output, output_without, cam, dis in zip( 91 | fnames, norm_outputs, outputs, camid, distribute): 92 | cam = cam.item() 93 | cross_cam_fnames.append(fname) 94 | cross_cam_features.append(output) 95 | cross_cam_features_without.append(output_without) 96 | cams.append(cam) 97 | 
cross_cam_distribute.append(dis.cpu().numpy()) 98 | if norm: 99 | return cross_cam_features, cross_cam_fnames, cross_cam_distribute, cams 100 | else: 101 | return cross_cam_features, cross_cam_features_without, cross_cam_fnames, cross_cam_distribute, cams 102 | 103 | 104 | def extract_features_cross_cam_with_tnorm(model, data_loader): 105 | model.eval() 106 | cross_cam_features = [] 107 | cross_cam_fnames = [] 108 | cross_cam_distribute = [] 109 | cams = [] 110 | cam_number = len(model.classifier) 111 | print("Start extract features cross camera") 112 | for imgs, fnames, _, camid in tqdm(data_loader): 113 | 114 | with torch.no_grad(): 115 | for i in range(cam_number): 116 | t = extract_cnn_feature_with_tnorm(model, 117 | imgs, 118 | camid, 119 | i, 120 | norm=False) 121 | if i == 0: 122 | tmp = t 123 | else: 124 | tmp = tmp + t 125 | x = model.classifier[i](t) 126 | if i == 0: 127 | distribute = F.softmax(x.data, dim=1) 128 | else: 129 | distribute_tmp = F.softmax(x.data, dim=1) 130 | distribute = torch.cat((distribute, distribute_tmp), dim=1) 131 | norm_outputs = F.normalize(tmp, p=2, dim=1) 132 | 133 | for fname, output, cam, dis in zip(fnames, norm_outputs, camid, 134 | distribute): 135 | cam = cam.item() 136 | cross_cam_fnames.append(fname) 137 | cross_cam_features.append(output) 138 | cams.append(cam) 139 | cross_cam_distribute.append(dis.cpu().numpy()) 140 | return cross_cam_features, cross_cam_fnames, cross_cam_distribute, cams 141 | 142 | 143 | def jaccard_sim_cross_cam(cross_cam_distribute): 144 | print( 145 | "Start calculate jaccard similarity cross camera, this step may cost a lot of time" 146 | ) 147 | n = cross_cam_distribute.size(0) 148 | jaccard_sim = torch.zeros((n, n)) 149 | for i in range(n): 150 | distribute = cross_cam_distribute[i] 151 | abs_sub = torch.abs(distribute - cross_cam_distribute) 152 | sum_distribute = distribute + cross_cam_distribute 153 | intersection = (sum_distribute - abs_sub).sum(dim=1) / 2 154 | union = (sum_distribute + abs_sub).sum(dim=1) / 2 155 | jaccard_sim[i, :] = intersection / union 156 | return to_numpy(jaccard_sim) 157 | 158 | 159 | def cluster_cross_cam(cross_cam_dist, 160 | cross_cam_fname, 161 | eph, 162 | linkage="average", 163 | cams=None, 164 | mix_rate=0., 165 | jaccard_sim=None, 166 | n_clusters=None): 167 | cluster_results = OrderedDict() 168 | print("Start cluster cross camera according to distance") 169 | if mix_rate > 0: 170 | assert jaccard_sim is not None, "if mix_rate > 0, the jaccard sim is needed" 171 | assert cams is not None, "if mix_rate > 0, the cam is needed" 172 | n = len(cross_cam_fname) 173 | cams = np.array(cams).reshape((n, 1)) 174 | expand_cams = np.tile(cams, n) 175 | mask = np.array(expand_cams != expand_cams.T, dtype=np.float32) 176 | cross_cam_dist -= mask * jaccard_sim * mix_rate 177 | cross_cam_dist = re_ranking(cross_cam_dist) 178 | if n_clusters is None: 179 | tri_mat = np.triu(cross_cam_dist, 1) 180 | tri_mat = tri_mat[np.nonzero(tri_mat)] 181 | tri_mat = np.sort(tri_mat, axis=None) 182 | top_num = np.round(eph * tri_mat.size).astype(int) 183 | eps = tri_mat[top_num] 184 | print(eps) 185 | else: 186 | eps = None 187 | 188 | Ag = AgglomerativeClustering(n_clusters=n_clusters, 189 | affinity="precomputed", 190 | linkage=linkage, 191 | distance_threshold=eps) 192 | labels = Ag.fit_predict(cross_cam_dist) 193 | print(len(set(labels))) 194 | tem = {} 195 | relabel = 0 196 | for fname, label in zip(cross_cam_fname, labels): 197 | if label not in tem.keys(): 198 | tem[label] = [] 199 | 
tem[label].append(fname) 200 | for label, names in tem.items(): 201 | if len(names) > 1: 202 | for name in names: 203 | cluster_results[name] = relabel 204 | relabel += 1 205 | return cluster_results 206 | 207 | 208 | def distance_cross_cam(features, use_cpu=False): 209 | print("Start calculate pairwise distance cross camera") 210 | n = len(features) 211 | x = torch.cat(features) 212 | x = x.view(n, -1) 213 | if use_cpu: 214 | dist = 1 - np.matmul(x.cpu().numpy(), x.cpu().numpy().T) 215 | else: 216 | dist = 1 - torch.mm(x, x.t()) 217 | 218 | return to_numpy(dist) 219 | 220 | 221 | def distane_per_cam(per_cam_features): 222 | per_cam_dist = {} 223 | print("Start calculate pairwise distance per camera") 224 | for k, features in per_cam_features.items(): 225 | 226 | n = len(features) 227 | x = torch.cat(features) 228 | x = x.view(n, -1) 229 | 230 | per_cam_dist[k] = 1 - torch.mm(x, x.t()) 231 | return per_cam_dist 232 | 233 | 234 | def cluster_per_cam(per_cam_dist, per_cam_fname, eph, linkage="average", n_clusters=None): 235 | cluster_results = {} 236 | print("Start cluster per camera according to distance") 237 | for k, dist in per_cam_dist.items(): 238 | cluster_results[k] = OrderedDict() 239 | 240 | # handle the number of samples is small 241 | dist = dist.cpu().numpy() 242 | n = dist.shape[0] 243 | if n < eph: 244 | eph = n // 2 245 | # double the number of samples 246 | dist = np.tile(dist, (2, 2)) 247 | per_cam_fname[k] = per_cam_fname[k] + per_cam_fname[k] 248 | 249 | dist = re_ranking(dist) 250 | if n_clusters is None: 251 | tri_mat = np.triu(dist, 1) 252 | tri_mat = tri_mat[np.nonzero(tri_mat)] 253 | tri_mat = np.sort(tri_mat, axis=None) 254 | top_num = np.round(eph * tri_mat.size).astype(int) 255 | eps = tri_mat[top_num] 256 | # eps = tri_mat[:top_num].mean() 257 | print(eps) 258 | else: 259 | eps = None 260 | 261 | Ag = AgglomerativeClustering(n_clusters=n_clusters, 262 | affinity="precomputed", 263 | linkage=linkage, 264 | distance_threshold=eps) 265 | # Ag = DBSCAN(eps=eps, min_samples=3, metric='precomputed') 266 | 267 | labels = Ag.fit_predict(dist) 268 | print(len(set(labels))) 269 | tem = {} 270 | relabel = 0 271 | for fname, label in zip(per_cam_fname[k], labels): 272 | if label != -1: 273 | if label not in tem.keys(): 274 | tem[label] = [] 275 | tem[label].append(fname) 276 | for label, names in tem.items(): 277 | if len(names) > 1: 278 | for name in names: 279 | cluster_results[k][name] = relabel 280 | relabel += 1 281 | # for fname, label in zip(per_cam_fname[k], labels): 282 | # if label != -1: 283 | # cluster_results[k][fname] = label 284 | return cluster_results 285 | 286 | 287 | def get_intra_cam_cluster_result(model, data_loader, eph, linkage, n_clusters=None): 288 | per_cam_features, per_cam_fname = extract_features_per_cam( 289 | model, data_loader) 290 | per_cam_dist = distane_per_cam(per_cam_features) 291 | cluster_results = cluster_per_cam(per_cam_dist, per_cam_fname, eph, 292 | linkage, n_clusters=n_clusters) 293 | return cluster_results 294 | 295 | 296 | def get_inter_cam_cluster_result(model, 297 | data_loader, 298 | eph, 299 | linkage, 300 | mix_rate=0., 301 | use_cpu=False, 302 | n_clusters=None): 303 | features, fnames, cross_cam_distribute, cams = extract_features_cross_cam( 304 | model, data_loader) 305 | 306 | cross_cam_distribute = torch.Tensor(np.array(cross_cam_distribute)).cuda() 307 | 308 | cross_cam_dist = distance_cross_cam(features, use_cpu=use_cpu) 309 | if mix_rate > 0: 310 | jaccard_sim = jaccard_sim_cross_cam(cross_cam_distribute) 311 | 
else: 312 | jaccard_sim = None 313 | 314 | cluster_results = cluster_cross_cam( 315 | cross_cam_dist, 316 | fnames, 317 | eph, 318 | linkage=linkage, 319 | cams=cams, 320 | mix_rate=mix_rate, 321 | jaccard_sim=jaccard_sim, 322 | n_clusters=n_clusters 323 | ) 324 | return cluster_results 325 | 326 | 327 | def get_inter_cam_cluster_result_tnorm(model, 328 | data_loader, 329 | eph, 330 | linkage, 331 | mix_rate=0., 332 | use_cpu=False): 333 | features, fnames, cross_cam_distribute, cams = extract_features_cross_cam_with_tnorm( 334 | model, data_loader) 335 | 336 | cross_cam_distribute = torch.Tensor(np.array(cross_cam_distribute)).cuda() 337 | cross_cam_dist = distance_cross_cam(features, use_cpu=use_cpu) 338 | 339 | if mix_rate > 0: 340 | jaccard_sim = jaccard_sim_cross_cam(cross_cam_distribute) 341 | else: 342 | jaccard_sim = None 343 | 344 | cluster_results = cluster_cross_cam( 345 | cross_cam_dist, 346 | fnames, 347 | eph, 348 | linkage=linkage, 349 | cams=cams, 350 | mix_rate=mix_rate, 351 | jaccard_sim=jaccard_sim, 352 | ) 353 | return cluster_results 354 | -------------------------------------------------------------------------------- /reid/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | from .AIBN import AIBNorm2d 7 | from .TNorm import TNorm 8 | import numpy as np 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, 13 | out_planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=1, 17 | bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | 52 | class Bottleneck(nn.Module): 53 | expansion = 4 54 | 55 | def __init__(self, inplanes, planes, stride=1, downsample=None): 56 | super(Bottleneck, self).__init__() 57 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 58 | self.bn1 = nn.BatchNorm2d(planes) 59 | self.conv2 = nn.Conv2d(planes, 60 | planes, 61 | kernel_size=3, 62 | stride=stride, 63 | padding=1, 64 | bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | 
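# AIBNBottleneck below mirrors the plain Bottleneck above but replaces nn.BatchNorm2d
# with AIBNorm2d (defined in AIBN.py). When no adaptive_weight is passed in, a single
# learnable weight is created and shared by bn1/bn2/bn3 and the downsample norm layer,
# so the (presumably instance/batch) statistic mixing ratio is tied within one block.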
class AIBNBottleneck(nn.Module): 96 | expansion = 4 97 | 98 | def __init__(self, 99 | inplanes, 100 | planes, 101 | stride=1, 102 | downsample=None, 103 | adaptive_weight=None, 104 | generate_weight=True, 105 | init_weight=0.1, 106 | ): 107 | super(AIBNBottleneck, self).__init__() 108 | if adaptive_weight is None: 109 | self.adaptive_weight = nn.Parameter(torch.ones(1) * init_weight) 110 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 111 | self.bn1 = AIBNorm2d(planes, 112 | adaptive_weight=self.adaptive_weight, 113 | generate_weight=generate_weight) 114 | self.conv2 = nn.Conv2d(planes, 115 | planes, 116 | kernel_size=3, 117 | stride=stride, 118 | padding=1, 119 | bias=False) 120 | self.bn2 = AIBNorm2d(planes, 121 | adaptive_weight=self.adaptive_weight, 122 | generate_weight=generate_weight) 123 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 124 | 125 | self.bn3 = AIBNorm2d(planes * 4, 126 | adaptive_weight=self.adaptive_weight, 127 | generate_weight=generate_weight) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.downsample = downsample 130 | if self.downsample is not None: 131 | self.downsample[1].adaptive_weight = self.adaptive_weight 132 | self.stride = stride 133 | 134 | def forward(self, x): 135 | residual = x 136 | 137 | out = self.conv1(x) 138 | out = self.bn1(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv2(out) 142 | out = self.bn2(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv3(out) 146 | out = self.bn3(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | 157 | class AIBNResNet(nn.Module): 158 | def __init__(self, 159 | last_stride=2, 160 | block=AIBNBottleneck, 161 | layers=[3, 4, 6, 3], 162 | init_weight=0.1, 163 | ): 164 | self.inplanes = 64 165 | super(AIBNResNet, self).__init__() 166 | self.conv1 = nn.Conv2d(3, 167 | 64, 168 | kernel_size=7, 169 | stride=2, 170 | padding=3, 171 | bias=False) 172 | self.bn1 = nn.BatchNorm2d(64) 173 | self.relu = nn.ReLU(inplace=True) # add missed relu 174 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 175 | self.layer1 = self._make_layer_normal(Bottleneck, 64, layers[0]) 176 | self.layer2 = self._make_layer_normal(Bottleneck, 177 | 128, 178 | layers[1], 179 | stride=2) 180 | self.layer3 = self._make_layer(block, 181 | 256, 182 | layers[2], 183 | stride=2, 184 | adaptive_weight=None, 185 | init_weight=init_weight) 186 | self.layer4 = self._make_layer(block, 187 | 512, 188 | layers[3], 189 | stride=last_stride, 190 | adaptive_weight=None, 191 | init_weight=init_weight) 192 | 193 | self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1)) 194 | 195 | def _make_layer(self, 196 | block, 197 | planes, 198 | blocks, 199 | stride=1, 200 | adaptive_weight=None, 201 | init_weight=0.1): 202 | downsample = None 203 | if stride != 1 or self.inplanes != planes * block.expansion: 204 | downsample = nn.Sequential( 205 | nn.Conv2d(self.inplanes, 206 | planes * block.expansion, 207 | kernel_size=1, 208 | stride=stride, 209 | bias=False), 210 | AIBNorm2d(planes * block.expansion, 211 | adaptive_weight=adaptive_weight, 212 | generate_weight=True, 213 | init_weight=init_weight), 214 | ) 215 | 216 | layers = [] 217 | layers.append( 218 | block(self.inplanes, 219 | planes, 220 | stride, 221 | downsample, 222 | adaptive_weight=adaptive_weight, 223 | generate_weight=True, 224 | init_weight=init_weight, 225 | )) 226 | self.inplanes = planes * block.expansion 227 
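        # The remaining blocks of this stage are stacked without stride or downsample.
        # Note that the i == (blocks - 1) branch below builds exactly the same block
        # as the else branch.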
| for i in range(1, blocks): 228 | if i == (blocks - 1): 229 | layers.append( 230 | block(self.inplanes, 231 | planes, 232 | adaptive_weight=adaptive_weight, 233 | generate_weight=True, 234 | init_weight=init_weight, 235 | )) 236 | else: 237 | layers.append( 238 | block(self.inplanes, 239 | planes, 240 | adaptive_weight=adaptive_weight, 241 | generate_weight=True, 242 | init_weight=init_weight, 243 | )) 244 | 245 | return nn.Sequential(*layers) 246 | 247 | def _make_layer_normal(self, block, planes, blocks, stride=1): 248 | downsample = None 249 | if stride != 1 or self.inplanes != planes * block.expansion: 250 | downsample = nn.Sequential( 251 | nn.Conv2d(self.inplanes, 252 | planes * block.expansion, 253 | kernel_size=1, 254 | stride=stride, 255 | bias=False), 256 | nn.BatchNorm2d(planes * block.expansion), 257 | ) 258 | 259 | layers = [] 260 | layers.append(block(self.inplanes, planes, stride, downsample)) 261 | self.inplanes = planes * block.expansion 262 | for i in range(1, blocks): 263 | layers.append(block(self.inplanes, planes)) 264 | 265 | return nn.Sequential(*layers) 266 | 267 | def forward(self, x): 268 | x = self.conv1(x) 269 | x = self.bn1(x) 270 | x = self.relu(x) # add missed relu 271 | x = self.maxpool(x) 272 | 273 | x = self.layer1(x) 274 | x = self.layer2(x) 275 | x = self.layer3(x) 276 | x = self.layer4(x) 277 | 278 | x = self.adaptive_pool(x) 279 | 280 | return x 281 | 282 | def load_param(self, model_path): 283 | param_dict = torch.load(model_path) 284 | for i in param_dict: 285 | if 'fc' in i: 286 | continue 287 | self.state_dict()[i].copy_(param_dict[i]) 288 | 289 | def random_init(self): 290 | for m in self.modules(): 291 | if isinstance(m, nn.Conv2d): 292 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 293 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 294 | elif isinstance(m, nn.BatchNorm2d): 295 | m.weight.data.fill_(1) 296 | m.bias.data.zero_() 297 | 298 | 299 | class TNormResNet(nn.Module): 300 | def __init__(self, 301 | domain_number=1, 302 | last_stride=2, 303 | block=AIBNBottleneck, 304 | layers=[3, 4, 6, 3], 305 | init_weight=0.1): 306 | self.inplanes = 64 307 | super(TNormResNet, self).__init__() 308 | self.domain_number = domain_number 309 | self.conv1 = nn.Conv2d(3, 310 | 64, 311 | kernel_size=7, 312 | stride=2, 313 | padding=3, 314 | bias=False) 315 | self.bn1 = nn.BatchNorm2d(64) 316 | self.relu = nn.ReLU(inplace=True) # add missed relu 317 | self.relu_after_tnorm = nn.ReLU() 318 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 319 | self.layer1 = self._make_layer_normal(Bottleneck, 64, layers[0]) 320 | self.tnorm1 = TNorm(256, domain_number) 321 | self.layer2 = self._make_layer_normal(Bottleneck, 322 | 128, 323 | layers[1], 324 | stride=2) 325 | self.tnorm2 = TNorm(512, domain_number) 326 | self.layer3 = self._make_layer(block, 327 | 256, 328 | layers[2], 329 | stride=2, 330 | adaptive_weight=None, 331 | init_weight=init_weight) 332 | 333 | self.tnorm3 = TNorm(1024, domain_number) 334 | self.layer4 = self._make_layer(block, 335 | 512, 336 | layers[3], 337 | stride=last_stride, 338 | adaptive_weight=None, 339 | init_weight=init_weight) 340 | 341 | self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1)) 342 | 343 | def _make_layer(self, 344 | block, 345 | planes, 346 | blocks, 347 | stride=1, 348 | adaptive_weight=None, 349 | init_weight=0.1): 350 | downsample = None 351 | if stride != 1 or self.inplanes != planes * block.expansion: 352 | downsample = nn.Sequential( 353 | nn.Conv2d(self.inplanes, 354 | planes * block.expansion, 355 | kernel_size=1, 356 | stride=stride, 357 | bias=False), 358 | AIBNorm2d( 359 | planes * block.expansion, 360 | adaptive_weight=adaptive_weight, 361 | generate_weight=True, 362 | init_weight=init_weight 363 | ), 364 | ) 365 | 366 | layers = [] 367 | layers.append( 368 | block(self.inplanes, 369 | planes, 370 | stride, 371 | downsample, 372 | adaptive_weight=adaptive_weight, 373 | generate_weight=True, 374 | init_weight=init_weight)) 375 | self.inplanes = planes * block.expansion 376 | for i in range(1, blocks): 377 | if i == (blocks - 1): 378 | layers.append( 379 | block(self.inplanes, 380 | planes, 381 | adaptive_weight=adaptive_weight, 382 | generate_weight=True, 383 | init_weight=init_weight)) 384 | else: 385 | layers.append( 386 | block(self.inplanes, 387 | planes, 388 | adaptive_weight=adaptive_weight, 389 | generate_weight=True, 390 | init_weight=init_weight)) 391 | 392 | return nn.Sequential(*layers) 393 | 394 | def _make_layer_normal(self, block, planes, blocks, stride=1): 395 | downsample = None 396 | if stride != 1 or self.inplanes != planes * block.expansion: 397 | downsample = nn.Sequential( 398 | nn.Conv2d(self.inplanes, 399 | planes * block.expansion, 400 | kernel_size=1, 401 | stride=stride, 402 | bias=False), 403 | nn.BatchNorm2d(planes * block.expansion), 404 | ) 405 | 406 | layers = [] 407 | layers.append(block(self.inplanes, planes, stride, downsample)) 408 | self.inplanes = planes * block.expansion 409 | for i in range(1, blocks): 410 | layers.append(block(self.inplanes, planes)) 411 | 412 | return nn.Sequential(*layers) 413 | 414 | def forward(self, x, domain_index=None, convert=False): 415 | if convert: 416 | selected_domain = np.random.randint(0, 417 | self.domain_number, 418 | size=(x.size(0))) 419 | else: 420 | selected_domain = None 421 | x = self.conv1(x) 
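        # When convert is True, a camera/domain index has just been drawn uniformly at
        # random for every sample; it is passed, together with domain_index and the
        # convert flag, to the TNorm layers inserted after layer1-3 below, which select
        # the normalization statistics to apply (see TNorm.py for the exact behaviour).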
422 | x = self.bn1(x) 423 | x = self.relu(x) # add missed relu 424 | x = self.maxpool(x) 425 | x = self.layer1(x) 426 | x = self.tnorm1(x, domain_index, convert, selected_domain) 427 | x = self.layer2(x) 428 | x = self.tnorm2(x, domain_index, convert, selected_domain) 429 | x = self.layer3(x) 430 | x = self.tnorm3(x, domain_index, convert, selected_domain) 431 | x = self.layer4(x) 432 | 433 | x = self.adaptive_pool(x) 434 | 435 | return x 436 | 437 | def load_param(self, model_path): 438 | param_dict = torch.load(model_path) 439 | for k, v in param_dict.items(): 440 | if k in self.state_dict().keys(): 441 | self.state_dict()[k].copy_(v) 442 | 443 | def random_init(self): 444 | for m in self.modules(): 445 | if isinstance(m, nn.Conv2d): 446 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 447 | m.weight.data.normal_(0, math.sqrt(2. / n)) 448 | elif isinstance(m, nn.BatchNorm2d): 449 | m.weight.data.fill_(1) 450 | m.bias.data.zero_() 451 | -------------------------------------------------------------------------------- /reid/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from .evaluation_metrics import accuracy 8 | from .loss import TripletLoss 9 | from .utils.meters import AverageMeter 10 | import numpy as np 11 | 12 | 13 | class BaseTrainer(object): 14 | def __init__(self, model, criterion, warm_up_epoch=-1): 15 | super(BaseTrainer, self).__init__() 16 | self.model = model 17 | self.criterion = criterion 18 | self.warm_up_epoch = warm_up_epoch 19 | 20 | def train(self, epoch, data_loader, optimizer, print_freq=1): 21 | self.model.train() 22 | 23 | batch_time = AverageMeter() 24 | data_time = AverageMeter() 25 | losses = AverageMeter() 26 | precisions = AverageMeter() 27 | 28 | end = time.time() 29 | for i, inputs in enumerate(data_loader): 30 | data_time.update(time.time() - end) 31 | 32 | inputs, targets = self._parse_data(inputs) 33 | loss, prec1 = self._forward(inputs, targets) 34 | 35 | losses.update(loss.item(), targets.size(0)) 36 | precisions.update(prec1, targets.size(0)) 37 | if epoch < self.warm_up_epoch: 38 | loss = loss * 0.1 39 | optimizer.zero_grad() 40 | loss.backward() 41 | optimizer.step() 42 | 43 | batch_time.update(time.time() - end) 44 | end = time.time() 45 | 46 | if (i + 1) % print_freq == 0: 47 | print('Epoch: [{}][{}/{}]\t' 48 | 'Time {:.3f} ({:.3f})\t' 49 | 'Data {:.3f} ({:.3f})\t' 50 | 'Loss {:.3f} ({:.3f})\t' 51 | 'Prec {:.2%} ({:.2%})\t'.format( 52 | epoch, i + 1, len(data_loader), batch_time.val, 53 | batch_time.avg, data_time.val, data_time.avg, 54 | losses.val, losses.avg, precisions.val, 55 | precisions.avg)) 56 | 57 | def _parse_data(self, inputs): 58 | raise NotImplementedError 59 | 60 | def _forward(self, inputs, targets): 61 | raise NotImplementedError 62 | 63 | 64 | class Trainer(BaseTrainer): 65 | def _parse_data(self, inputs): 66 | imgs, _, pids, _ = inputs 67 | inputs = imgs.cuda() 68 | targets = pids.cuda() 69 | return inputs, targets 70 | 71 | def _forward(self, inputs, targets): 72 | outputs = self.model(inputs) 73 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 74 | loss = self.criterion(outputs, targets) 75 | prec, = accuracy(outputs.data, targets.data) 76 | prec = prec[0] 77 | elif isinstance(self.criterion, TripletLoss): 78 | loss, prec = self.criterion(outputs, targets) 79 | else: 80 | raise 
ValueError("Unsupported loss:", self.criterion) 81 | return loss, prec 82 | 83 | 84 | class IntraCameraTrainer(BaseTrainer): 85 | def _parse_data(self, inputs): 86 | pass 87 | 88 | def _forward(self, inputs, targets, i): 89 | outputs = self.model(inputs, i) 90 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 91 | loss = self.criterion(outputs, targets) 92 | prec, = accuracy(outputs.data, targets.data) 93 | prec = prec[0] 94 | else: 95 | raise ValueError("Unsupported loss:", self.criterion) 96 | return loss, prec 97 | 98 | def train( 99 | self, 100 | cluster_epoch, 101 | epoch, 102 | data_loader, 103 | optimizer, 104 | print_freq=1, 105 | ): 106 | self.model.train() 107 | 108 | batch_time = AverageMeter() 109 | data_time = AverageMeter() 110 | losses = AverageMeter() 111 | precisions = AverageMeter() 112 | 113 | end = time.time() 114 | data_loader_size = min([len(l) for l in data_loader]) 115 | for i, inputs in enumerate(zip(*data_loader)): 116 | data_time.update(time.time() - end) 117 | for domain, domain_input in enumerate(inputs): 118 | imgs, _, pids, camid = domain_input 119 | imgs = imgs.cuda() 120 | targets = pids.cuda() 121 | assert domain == camid[0] 122 | 123 | loss, prec1 = self._forward(imgs, targets, domain) 124 | if domain == 0: 125 | loss_sum = loss 126 | else: 127 | loss_sum = loss_sum + loss 128 | losses.update(loss.item(), targets.size(0)) 129 | precisions.update(prec1, targets.size(0)) 130 | 131 | if cluster_epoch < self.warm_up_epoch: 132 | loss_sum = loss_sum * 0.1 133 | 134 | optimizer.zero_grad() 135 | loss_sum.backward() 136 | optimizer.step() 137 | 138 | batch_time.update(time.time() - end) 139 | end = time.time() 140 | 141 | if (i + 1) % print_freq == 0: 142 | print('Cluster_Epoch: [{}]\t' 143 | 'Epoch: [{}][{}/{}]\t' 144 | 'Time {:.3f} ({:.3f})\t' 145 | 'Data {:.3f} ({:.3f})\t' 146 | 'Loss {:.3f} ({:.3f})\t' 147 | 'Prec {:.2%} ({:.2%})\t'.format( 148 | cluster_epoch, epoch, i + 1, data_loader_size, 149 | batch_time.val, batch_time.avg, data_time.val, 150 | data_time.avg, losses.val, losses.avg, 151 | precisions.val, precisions.avg)) 152 | 153 | 154 | class IntraCameraSelfKDTnormTrainer(object): 155 | def __init__( 156 | self, 157 | model_1, 158 | entropy_criterion, 159 | soft_entropy_criterion, 160 | warm_up_epoch=-1, 161 | multi_task_weight=1.0 162 | ): 163 | super(IntraCameraSelfKDTnormTrainer, self).__init__() 164 | self.model_1 = model_1 165 | self.T = 1. 
166 | self.entropy_criterion = entropy_criterion 167 | self.soft_entropy_criterion = soft_entropy_criterion 168 | self.warm_up_epoch = warm_up_epoch 169 | self.multi_task_weight = multi_task_weight 170 | 171 | def _forward(self, inputs1, inputs2, targets, i): 172 | convert = np.random.rand() > 0.5 173 | outputs1 = self.model_1(inputs1, i) 174 | outputs2 = self.model_1(inputs2, i, convert=convert) 175 | 176 | loss_ce1 = self.entropy_criterion(outputs1, targets) 177 | prec1, = accuracy(outputs1.data, targets.data) 178 | prec1 = prec1[0] 179 | 180 | soft_loss1 = self.soft_entropy_criterion(outputs2 / self.T, (outputs1 / self.T).detach()) * self.T * self.T 181 | 182 | return loss_ce1, prec1, soft_loss1 183 | 184 | def train( 185 | self, 186 | cluster_epoch, 187 | epoch, 188 | data_loader, 189 | optimizer, 190 | print_freq=1, 191 | ): 192 | self.model_1.train() 193 | 194 | batch_time = AverageMeter() 195 | data_time = AverageMeter() 196 | losses_ce1 = AverageMeter() 197 | precisions_1 = AverageMeter() 198 | losses_soft1 = AverageMeter() 199 | 200 | end = time.time() 201 | data_loader_size = min([len(l) for l in data_loader]) 202 | 203 | for i, inputs in enumerate(zip(*data_loader)): 204 | data_time.update(time.time() - end) 205 | for domain, domain_input in enumerate(inputs): 206 | imgs1, imgs2, _, pids, _ = domain_input 207 | imgs1 = imgs1.cuda() 208 | imgs2 = imgs2.cuda() 209 | targets = pids.cuda() 210 | 211 | loss1, prec1, soft_loss1 = self._forward( 212 | imgs1, imgs2, targets, domain) 213 | if domain == 0: 214 | loss1_sum = loss1 215 | soft_loss1_sum = soft_loss1 216 | else: 217 | loss1_sum = loss1_sum + loss1 218 | soft_loss1_sum = soft_loss1_sum + soft_loss1 219 | 220 | losses_ce1.update(loss1.item(), targets.size(0)) 221 | precisions_1.update(prec1, targets.size(0)) 222 | losses_soft1.update(soft_loss1.item(), targets.size(0)) 223 | 224 | final_loss = loss1_sum + soft_loss1_sum * self.multi_task_weight 225 | 226 | if cluster_epoch < self.warm_up_epoch: 227 | final_loss = final_loss * 0.1 228 | 229 | optimizer.zero_grad() 230 | final_loss.backward() 231 | optimizer.step() 232 | 233 | batch_time.update(time.time() - end) 234 | end = time.time() 235 | 236 | if (i + 1) % print_freq == 0: 237 | print('Cluster_Epoch: [{}]\t' 238 | 'Epoch: [{}][{}/{}]\t' 239 | 'Time {:.3f} ({:.3f})\t' 240 | 'Data {:.3f} ({:.3f})\t' 241 | 'Loss_ce1 {:.3f} ({:.3f})\t' 242 | 'Prec1 {:.2%} ({:.2%})\t' 243 | 'Loss_soft1 {:.3f} ({:.3f})\t'.format( 244 | cluster_epoch, epoch, i + 1, data_loader_size, 245 | batch_time.val, batch_time.avg, data_time.val, 246 | data_time.avg, losses_ce1.val, losses_ce1.avg, 247 | precisions_1.val, precisions_1.avg, 248 | losses_soft1.val, losses_soft1.avg)) 249 | 250 | 251 | class InterCameraTrainer(BaseTrainer): 252 | def __init__(self, 253 | model, 254 | entropy_criterion, 255 | triple_criterion, 256 | warm_up_epoch=-1, 257 | multi_task_weight=1.): 258 | super(InterCameraTrainer, self).__init__(model, entropy_criterion, 259 | warm_up_epoch) 260 | self.triple_critetion = triple_criterion 261 | self.multi_task_weight = multi_task_weight 262 | 263 | def _forward(self, inputs, targets): 264 | prob, distance = self.model(inputs) 265 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 266 | loss_entropy = self.criterion(prob, targets) 267 | prec_entropy, = accuracy(prob.data, targets.data) 268 | prec_entropy = prec_entropy[0] 269 | 270 | loss_triple, prec_triple = self.triple_critetion(distance, targets) 271 | 272 | return loss_entropy, prec_entropy, loss_triple, prec_triple 273 | 
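    # train() below runs one epoch over the inter-camera pseudo-labelled loader,
    # combining the two objectives as loss_triple * multi_task_weight + loss_entropy
    # and scaling the total by 0.1 while cluster_epoch < warm_up_epoch.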
274 | def train(self, 275 | cluster_epoch, 276 | epoch, 277 | data_loader, 278 | optimizer, 279 | print_freq=1): 280 | self.model.train() 281 | 282 | batch_time = AverageMeter() 283 | data_time = AverageMeter() 284 | losses_entropy = AverageMeter() 285 | precisions_entropy = AverageMeter() 286 | losses_triple = AverageMeter() 287 | precisions_triple = AverageMeter() 288 | 289 | end = time.time() 290 | for i, inputs in enumerate(data_loader): 291 | data_time.update(time.time() - end) 292 | imgs, _, pids, _ = inputs 293 | imgs = imgs.cuda() 294 | targets = pids.cuda() 295 | 296 | loss_entropy, prec_entropy, loss_triple, prec_triple = self._forward( 297 | imgs, targets) 298 | 299 | loss = loss_triple * self.multi_task_weight + loss_entropy 300 | 301 | losses_entropy.update(loss_entropy.item(), targets.size(0)) 302 | precisions_entropy.update(prec_entropy, targets.size(0)) 303 | losses_triple.update(loss_triple.item(), targets.size(0)) 304 | precisions_triple.update(prec_triple, targets.size(0)) 305 | 306 | if cluster_epoch < self.warm_up_epoch: 307 | loss = loss * 0.1 308 | optimizer.zero_grad() 309 | loss.backward() 310 | optimizer.step() 311 | 312 | batch_time.update(time.time() - end) 313 | end = time.time() 314 | 315 | if (i + 1) % print_freq == 0: 316 | print('Cluster_Epoch: [{}]\t' 317 | 'Epoch: [{}][{}/{}]\t' 318 | 'Time {:.3f} ({:.3f})\t' 319 | 'Data {:.3f} ({:.3f})\t' 320 | 'Loss_Entropy {:.3f} ({:.3f})\t' 321 | 'Prec_Entropy {:.2%} ({:.2%})\t' 322 | 'Loss_Triple {:.3f} ({:.3f})\t' 323 | 'Prec_Triple {:.2%} ({:.2%})\t'.format( 324 | cluster_epoch, 325 | epoch, 326 | i + 1, 327 | len(data_loader), 328 | batch_time.val, 329 | batch_time.avg, 330 | data_time.val, 331 | data_time.avg, 332 | losses_entropy.val, 333 | losses_entropy.avg, 334 | precisions_entropy.val, 335 | precisions_entropy.avg, 336 | losses_triple.val, 337 | losses_triple.avg, 338 | precisions_triple.val, 339 | precisions_triple.avg, 340 | )) 341 | 342 | 343 | class InterCameraSelfKDTNormTrainer(object): 344 | def __init__( 345 | self, 346 | model_1, 347 | entropy_criterion, 348 | triplet_criterion, 349 | soft_entropy_criterion, 350 | triple_soft_criterion, 351 | warm_up_epoch=-1, 352 | multi_task_weight=1., 353 | ): 354 | super(InterCameraSelfKDTNormTrainer, self).__init__() 355 | self.model_1 = model_1 356 | 357 | self.entropy_criterion = entropy_criterion 358 | self.triplet_criterion = triplet_criterion 359 | self.soft_entropy_criterion = soft_entropy_criterion 360 | self.triple_soft_criterion = triple_soft_criterion 361 | self.warm_up_epoch = warm_up_epoch 362 | self.multi_task_weight = multi_task_weight 363 | self.T = 1. 
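        # Self-distillation setup: _forward below feeds two augmented views of the same
        # batch through the model, enabling TNorm conversion on the second view with
        # probability 0.5, and adds a soft cross-entropy between the two logit sets and
        # a soft triplet loss between the two distance outputs to the hard CE and
        # triplet terms.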
364 | 365 | def _forward(self, inputs1, inputs2, targets): 366 | convert = np.random.rand() > 0.5 367 | prob1, distance1 = self.model_1(inputs1) 368 | prob2, distance2 = self.model_1(inputs2, convert=convert) 369 | 370 | loss_ce1 = self.entropy_criterion(prob1, targets) 371 | prec1, = accuracy(prob1.data, targets.data) 372 | prec1 = prec1[0] 373 | 374 | soft_loss1 = self.soft_entropy_criterion(prob2 / self.T, (prob1 / self.T).detach()) * self.T * self.T 375 | 376 | loss_triple1, prec_triple1 = self.triplet_criterion(distance1, targets) 377 | 378 | loss_triple1_soft = self.triple_soft_criterion(distance2, distance1.detach(), targets) 379 | 380 | return loss_ce1, prec1, loss_triple1, prec_triple1, \ 381 | soft_loss1, loss_triple1_soft 382 | 383 | def train( 384 | self, 385 | cluster_epoch, 386 | epoch, 387 | data_loader, 388 | optimizer, 389 | print_freq=1, 390 | ): 391 | self.model_1.train() 392 | 393 | batch_time = AverageMeter() 394 | data_time = AverageMeter() 395 | losses_ce1 = AverageMeter() 396 | precisions_1 = AverageMeter() 397 | losses_triple1 = AverageMeter() 398 | precisions_triple1 = AverageMeter() 399 | losses_soft1 = AverageMeter() 400 | losses_triple_soft = AverageMeter() 401 | 402 | end = time.time() 403 | data_loader_size = len(data_loader) 404 | 405 | for i, inputs in enumerate(data_loader): 406 | data_time.update(time.time() - end) 407 | imgs1, imgs2, _, pids, _ = inputs 408 | 409 | imgs1 = imgs1.cuda() 410 | imgs2 = imgs2.cuda() 411 | targets = pids.cuda() 412 | 413 | # TODO: whether use soft triple loss 414 | loss_ce1, prec1, loss_triple1, prec_triple1, \ 415 | loss_soft1, loss_triple1_soft = self._forward( 416 | imgs1, imgs2, targets) 417 | 418 | losses_ce1.update(loss_ce1.item(), targets.size(0)) 419 | precisions_1.update(prec1, targets.size(0)) 420 | losses_triple1.update(loss_triple1.item(), targets.size(0)) 421 | precisions_triple1.update(prec_triple1, targets.size(0)) 422 | losses_soft1.update(loss_soft1, targets.size(0)) 423 | # Note: add soft triple 424 | losses_triple_soft.update(loss_triple1_soft, targets.size(0)) 425 | 426 | final_loss = loss_ce1 + loss_triple1 + loss_soft1 + loss_triple1_soft 427 | 428 | if cluster_epoch < self.warm_up_epoch: 429 | final_loss = final_loss * 0.1 430 | 431 | optimizer.zero_grad() 432 | final_loss.backward() 433 | optimizer.step() 434 | 435 | batch_time.update(time.time() - end) 436 | end = time.time() 437 | 438 | if (i + 1) % print_freq == 0: 439 | print('Cluster_Epoch: [{}]\t' 440 | 'Epoch: [{}][{}/{}]\t' 441 | 'Time {:.3f} ({:.3f})\t' 442 | 'Data {:.3f} ({:.3f})\t' 443 | 'Loss_ce1 {:.3f} ({:.3f})\t' 444 | 'Prec1 {:.2%} ({:.2%})\t' 445 | 'Loss_triple1 {:.3f} ({:.3f})\t' 446 | 'Prec_triple1 {:.2%} ({:.2%})\t' 447 | 'Loss_soft1 {:.3f} ({:.3f})\t' 448 | 'Loss_triple_soft {:.3f} ({:.3f})\t'.format( 449 | cluster_epoch, epoch, i + 1, data_loader_size, 450 | batch_time.val, batch_time.avg, data_time.val, 451 | data_time.avg, losses_ce1.val, losses_ce1.avg, 452 | precisions_1.val, precisions_1.avg, 453 | losses_triple1.val, losses_triple1.avg, 454 | precisions_triple1.val, precisions_triple1.avg, 455 | losses_soft1.val, losses_soft1.avg, 456 | losses_triple_soft.val, losses_triple_soft.avg)) 457 | -------------------------------------------------------------------------------- /example/iids_tnorm_self_kd.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import shutil 5 | 6 | import numpy as 
np 7 | import sys 8 | import torch 9 | from torch import nn 10 | from torch.backends import cudnn 11 | from torch.utils.data import DataLoader 12 | from torch.optim import lr_scheduler 13 | from reid.loss import TripletLoss, SoftEntropy, SoftTripletLoss 14 | 15 | from reid import datasets 16 | from reid import models 17 | from reid.trainers import IntraCameraSelfKDTnormTrainer 18 | from reid.trainers import InterCameraSelfKDTNormTrainer 19 | from reid.evaluators_cos import Evaluator 20 | from reid.utils.data import transforms as T 21 | from reid.utils.data.preprocessor import Preprocessor 22 | from reid.utils.logging import Logger 23 | from reid.utils.serialization import load_checkpoint, save_checkpoint 24 | from reid.cluster_utils import get_intra_cam_cluster_result 25 | from reid.cluster_utils import get_inter_cam_cluster_result_tnorm 26 | from reid.utils.data.sampler import RandomIdentitySampler 27 | 28 | 29 | def get_data( 30 | name, 31 | split_id, 32 | data_dir, 33 | height, 34 | width, 35 | batch_size, 36 | workers, 37 | ): 38 | root = osp.join(data_dir, name) 39 | 40 | dataset = datasets.create(name, root, split_id=split_id) 41 | 42 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 43 | std=[0.229, 0.224, 0.225]) 44 | 45 | train_set = dataset.trainval 46 | num_classes = dataset.num_trainval_ids 47 | 48 | train_transformer = T.Compose([ 49 | T.Resize((height, width), interpolation=3), 50 | T.ToTensor(), 51 | normalizer, 52 | ]) 53 | 54 | test_transformer = T.Compose([ 55 | T.Resize((height, width), interpolation=3), 56 | T.ToTensor(), 57 | normalizer, 58 | ]) 59 | 60 | train_loader = DataLoader(Preprocessor(train_set, 61 | root=dataset.images_dir, 62 | transform=train_transformer), 63 | batch_size=batch_size, 64 | num_workers=workers, 65 | shuffle=False, 66 | pin_memory=False, 67 | drop_last=False) 68 | 69 | val_loader = DataLoader(Preprocessor(dataset.val, 70 | root=dataset.images_dir, 71 | transform=test_transformer), 72 | batch_size=batch_size, 73 | num_workers=workers, 74 | shuffle=False, 75 | pin_memory=False) 76 | 77 | test_loader = DataLoader(Preprocessor( 78 | list(set(dataset.query) | set(dataset.gallery)), 79 | root=dataset.images_dir, 80 | transform=test_transformer), 81 | batch_size=batch_size, 82 | num_workers=workers, 83 | shuffle=False, 84 | pin_memory=False) 85 | 86 | return dataset, num_classes, train_loader, val_loader, test_loader 87 | 88 | 89 | def make_params(model, lr, weight_decay): 90 | params = [] 91 | for key, value in model.model.named_parameters(): 92 | if not value.requires_grad: 93 | continue 94 | 95 | params += [{ 96 | "params": [value], 97 | "lr": lr * 0.1, 98 | "weight_decay": weight_decay 99 | }] 100 | for key, value in model.classifier.named_parameters(): 101 | if not value.requires_grad: 102 | continue 103 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 104 | return params 105 | 106 | 107 | def get_mix_rate(mix_rate, epoch, num_epoch, power=0.6): 108 | return mix_rate * (1 - epoch / num_epoch)**power 109 | 110 | 111 | def main(args): 112 | np.random.seed(args.seed) 113 | torch.manual_seed(args.seed) 114 | cudnn.benchmark = True 115 | 116 | # Redirect print to both console and log file 117 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 118 | print(args) 119 | shutil.copy(sys.argv[0], osp.join(args.logs_dir, 120 | osp.basename(sys.argv[0]))) 121 | 122 | # Create data loaders 123 | if args.height is None or args.width is None: 124 | args.height, args.width = (256, 128) 125 | dataset, num_classes, 
train_loader, val_loader, test_loader = \ 126 | get_data(args.dataset, args.split, args.data_dir, args.height, 127 | args.width, args.batch_size * 8, args.workers, 128 | ) 129 | camera_number = {"market1501": 6, "dukemtmc": 8, "msmt17": 15} 130 | # Create model 131 | model = models.create("ft_net_inter_TNorm", 132 | domain_number=camera_number[args.dataset], 133 | num_classes=num_classes, 134 | stride=args.stride, 135 | init_weight=args.init_weight) 136 | 137 | # Load from checkpoint 138 | start_epoch = 0 139 | best_top1 = 0 140 | top1 = 0 141 | is_best = False 142 | if args.checkpoint is not None: 143 | if args.evaluate: 144 | 145 | checkpoint = load_checkpoint(args.checkpoint) 146 | param_dict = model.state_dict() 147 | for k, v in checkpoint['state_dict'].items(): 148 | if 'model' in k and k in param_dict.keys(): 149 | param_dict[k] = v 150 | model.load_state_dict(param_dict) 151 | else: 152 | model.model.load_param(args.checkpoint) 153 | model = model.cuda() 154 | 155 | # Distance metric 156 | metric = None 157 | 158 | # Evaluator 159 | evaluator = Evaluator(model, use_cpu=args.use_cpu) 160 | if args.evaluate: 161 | print("Test:") 162 | evaluator.evaluate_tnorm( 163 | test_loader, 164 | dataset.query, 165 | dataset.gallery, 166 | metric, 167 | return_mAP=True, 168 | camera_number=camera_number[args.dataset], 169 | ) 170 | evaluator.evaluate( 171 | test_loader, 172 | dataset.query, 173 | dataset.gallery, 174 | metric, 175 | ) 176 | return 177 | 178 | train_transformer = [ 179 | T.Resize((args.height, args.width), interpolation=3), 180 | T.RandomHorizontalFlip(), 181 | T.Pad(10), 182 | T.RandomCrop((args.height, args.width)), 183 | T.ToTensor(), 184 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 185 | T.RandomErasing(probability=0.5), 186 | ] 187 | train_transformer = T.Compose(train_transformer) 188 | for cluster_epoch in range(args.cluster_epochs): 189 | # -------------------------Stage 1 intra camera training-------------------------- 190 | # Cluster and generate new dataset and model 191 | cluster_result = get_intra_cam_cluster_result(model, train_loader, 192 | args.eph_stage1, 193 | args.linkage) 194 | cluster_datasets = [ 195 | datasets.create("cluster", osp.join(args.data_dir, args.dataset), 196 | cluster_result[cam_id], cam_id) 197 | for cam_id in sorted(cluster_result.keys()) 198 | ] 199 | 200 | cluster_dataloaders = [ 201 | DataLoader(Preprocessor(dataset.train_set, 202 | root=dataset.images_dir, 203 | transform=train_transformer, 204 | mutual=True), 205 | batch_size=args.batch_size, 206 | num_workers=args.workers, 207 | shuffle=True, 208 | pin_memory=False, 209 | drop_last=True) for dataset in cluster_datasets 210 | ] 211 | param_dict = model.model.state_dict() 212 | model = models.create( 213 | "ft_net_intra_TNorm", 214 | num_classes=[dt.classes_num for dt in cluster_datasets], 215 | stride=args.stride, 216 | init_weight=args.init_weight) 217 | 218 | model_param_dict = model.model.state_dict() 219 | for k, v in model_param_dict.items(): 220 | if k in param_dict.keys(): 221 | model_param_dict[k] = param_dict[k] 222 | model.model.load_state_dict(model_param_dict) 223 | 224 | model = model.cuda() 225 | criterion = nn.CrossEntropyLoss().cuda() 226 | soft_criterion = SoftEntropy().cuda() 227 | # Optimizer 228 | param_groups = make_params(model, args.lr, args.weight_decay) 229 | optimizer = torch.optim.SGD(param_groups, momentum=0.9) 230 | # Trainer 231 | trainer = IntraCameraSelfKDTnormTrainer(model, 232 | criterion, 233 | soft_criterion, 234 | 
warm_up_epoch=args.warm_up, 235 | multi_task_weight=args.multi_task_weight,) 236 | print("start training") 237 | # Start training 238 | for epoch in range(0, args.epochs_stage1): 239 | trainer.train( 240 | cluster_epoch, 241 | epoch, 242 | cluster_dataloaders, 243 | optimizer, 244 | print_freq=5 245 | ) 246 | #-------------------------------------------Stage 2 inter camera training----------------------------------- 247 | mix_rate = get_mix_rate(args.mix_rate, 248 | cluster_epoch, 249 | args.cluster_epochs, 250 | power=args.decay_factor) 251 | 252 | cluster_result = get_inter_cam_cluster_result_tnorm(model, 253 | train_loader, 254 | args.eph_stage2, 255 | args.linkage, 256 | mix_rate, 257 | use_cpu=args.use_cpu) 258 | 259 | cluster_dataset = datasets.create( 260 | "cluster", osp.join(args.data_dir, args.dataset), cluster_result, 261 | 0) 262 | 263 | cluster_dataloaders = DataLoader( 264 | Preprocessor(cluster_dataset.train_set, 265 | root=cluster_dataset.images_dir, 266 | transform=train_transformer, 267 | mutual=True), 268 | batch_size=args.batch_size_stage2, 269 | num_workers=args.workers, 270 | sampler=RandomIdentitySampler(cluster_dataset.train_set, 271 | args.batch_size_stage2, 272 | args.instances), 273 | pin_memory=False, 274 | drop_last=True) 275 | 276 | param_dict = model.model.state_dict() 277 | model = models.create("ft_net_inter_TNorm", 278 | domain_number=camera_number[args.dataset], 279 | num_classes=cluster_dataset.classes_num, 280 | stride=args.stride, 281 | init_weight=args.init_weight) 282 | model.model.load_state_dict(param_dict) 283 | 284 | model = model.cuda() 285 | # Criterion 286 | criterion_entropy = nn.CrossEntropyLoss().cuda() 287 | criterion_triple = TripletLoss(margin=args.margin).cuda() 288 | criterion_soft = SoftEntropy().cuda() 289 | criterion_triple_soft = SoftTripletLoss(margin=None).cuda() 290 | # Optimizer 291 | param_groups = make_params(model, 292 | args.lr * args.batch_size_stage2 / 32, 293 | args.weight_decay) 294 | 295 | optimizer = torch.optim.SGD(param_groups, momentum=0.9) 296 | # Trainer 297 | trainer = InterCameraSelfKDTNormTrainer( 298 | model, 299 | criterion_entropy, 300 | criterion_triple, 301 | criterion_soft, 302 | criterion_triple_soft, 303 | warm_up_epoch=args.warm_up, 304 | ) 305 | 306 | print("start training") 307 | # Start training 308 | for epoch in range(0, args.epochs_stage2): 309 | trainer.train(cluster_epoch, 310 | epoch, 311 | cluster_dataloaders, 312 | optimizer, 313 | print_freq=args.print_freq) 314 | if (cluster_epoch + 1) % 5 == 0: 315 | evaluator = Evaluator(model, use_cpu=args.use_cpu) 316 | top1, mAP = evaluator.evaluate( 317 | test_loader, dataset.query, dataset.gallery, metric, return_mAP=True) 318 | 319 | is_best = top1 > best_top1 320 | best_top1 = max(top1, best_top1) 321 | 322 | save_checkpoint( 323 | { 324 | 'state_dict': model.state_dict(), 325 | 'epoch': cluster_epoch + 1, 326 | 'best_top1': best_top1, 327 | 'cluster_epoch': cluster_epoch + 1, 328 | }, 329 | is_best, 330 | fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 331 | if cluster_epoch == (args.cluster_epochs - 1): 332 | save_checkpoint( 333 | { 334 | 'state_dict': model.state_dict(), 335 | 'epoch': cluster_epoch + 1, 336 | 'best_top1': best_top1, 337 | 'cluster_epoch': cluster_epoch + 1, 338 | }, 339 | False, 340 | fpath=osp.join(args.logs_dir, 'latest.pth.tar')) 341 | 342 | print('\n * cluster_epoch: {:3d} top1: {:5.1%} best: {:5.1%}{}\n'. 
343 | format(cluster_epoch, top1, best_top1, ' *' if is_best else '')) 344 | 345 | # Final test 346 | print('Test with best model:') 347 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar')) 348 | new_state_dict = {} 349 | for k in checkpoint['state_dict'].keys(): 350 | if 'model' in k: 351 | new_state_dict[k] = checkpoint['state_dict'][k] 352 | model.load_state_dict(new_state_dict, strict=False) 353 | best_rank1, mAP = evaluator.evaluate_tnorm( 354 | test_loader, 355 | dataset.query, 356 | dataset.gallery, 357 | metric, 358 | return_mAP=True, 359 | camera_number=camera_number[args.dataset]) 360 | best_rank1_2, mAP2 = evaluator.evaluate( 361 | test_loader, 362 | dataset.query, 363 | dataset.gallery, 364 | metric, 365 | return_mAP=True, 366 | ) 367 | print("Tnorm: Rank1: {} mAP: {}\t Normal: Rank1: {} mAP: {}".format( 368 | best_rank1, mAP, best_rank1_2, mAP2)) 369 | 370 | 371 | if __name__ == '__main__': 372 | parser = argparse.ArgumentParser(description="Softmax loss classification") 373 | # data 374 | parser.add_argument('--checkpoint', type=str, metavar='PATH') 375 | parser.add_argument('-d', 376 | '--dataset', 377 | type=str, 378 | default='market1501', 379 | choices=datasets.names()) 380 | parser.add_argument('--eph_stage1', type=float, default=0.0025) 381 | parser.add_argument('--eph_stage2', type=float, default=0.0017) 382 | parser.add_argument('--margin', type=float, default=0.3) 383 | parser.add_argument('--init_weight', type=float, default=0.1) 384 | parser.add_argument('--mix_rate', 385 | type=float, 386 | default=0.01, 387 | help="mu in Eq (5)") 388 | parser.add_argument('--decay_factor', type=float, default=0.6) 389 | 390 | parser.add_argument('-b', '--batch-size', type=int, default=8) 391 | parser.add_argument('-b2', '--batch-size-stage2', type=int, default=64) 392 | parser.add_argument('--instances', default=4) 393 | parser.add_argument('-j', '--workers', type=int, default=4) 394 | parser.add_argument('--split', type=int, default=0) 395 | parser.add_argument('--height', 396 | type=int, 397 | help="input height, default: 256 for resnet*, " 398 | "144 for inception") 399 | parser.add_argument('--width', 400 | type=int, 401 | help="input width, default: 128 for resnet*, " 402 | "56 for inception") 403 | # optimizer 404 | parser.add_argument('--lr', 405 | type=float, 406 | default=0.005, 407 | help="learning rate of new parameters, for pretrained " 408 | "parameters it is 10 times smaller than this") 409 | parser.add_argument('--momentum', type=float, default=0.9) 410 | parser.add_argument('--weight-decay', type=float, default=5e-4) 411 | # training configs 412 | parser.add_argument('--evaluate', 413 | action='store_true', 414 | help="evaluation only") 415 | parser.add_argument( 416 | '--use_cpu', 417 | action='store_true', 418 | help='use cpu to calculate dist to prevent from GPU OOM') 419 | parser.add_argument('--epochs_stage1', type=int, default=3) 420 | parser.add_argument('--epochs_stage2', type=int, default=2) 421 | parser.add_argument('--cluster_epochs', type=int, default=40) 422 | parser.add_argument('--warm_up', type=int, default=0) 423 | parser.add_argument('--start_save', 424 | type=int, 425 | default=0, 426 | help="start saving checkpoints after specific epoch") 427 | parser.add_argument('--seed', type=int, default=1) 428 | parser.add_argument('--print-freq', type=int, default=20) 429 | parser.add_argument('--linkage', type=str, default="average") 430 | parser.add_argument('--stride', type=int, default=1) 431 | 
parser.add_argument('--multi_task_weight', type=float, default=1.) 432 | # misc 433 | working_dir = osp.dirname(osp.abspath(__file__)) 434 | parser.add_argument('--data-dir', 435 | type=str, 436 | metavar='PATH', 437 | default=osp.join(working_dir, '../data')) 438 | parser.add_argument('--logs-dir', 439 | type=str, 440 | metavar='PATH', 441 | default=osp.join(working_dir, '../logs')) 442 | main(parser.parse_args()) 443 | --------------------------------------------------------------------------------
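Note: the snippet below is a minimal, self-contained sketch (not part of the repository) of two ideas from reid/cluster_utils/cluster.py: the fuzzy Jaccard similarity computed by jaccard_sim_cross_cam from concatenated per-camera classifier distributions, and the camera-aware distance mixing applied in cluster_cross_cam. All names and values are illustrative.

import numpy as np
import torch

def fuzzy_jaccard(distributions):
    # distributions: (n, d) tensor, each row a concatenation of per-camera softmax outputs
    n = distributions.size(0)
    sim = torch.zeros((n, n))
    for i in range(n):
        p = distributions[i]
        abs_sub = torch.abs(p - distributions)            # element-wise |p - q|
        summed = p + distributions                        # element-wise  p + q
        intersection = (summed - abs_sub).sum(dim=1) / 2  # sum of min(p, q)
        union = (summed + abs_sub).sum(dim=1) / 2         # sum of max(p, q)
        sim[i, :] = intersection / union
    return sim

distribute = torch.tensor([[0.7, 0.3], [0.7, 0.3], [0.1, 0.9]])  # toy distributions
jaccard = fuzzy_jaccard(distribute)                              # rows 0 and 1 have similarity 1.0

# Camera-aware mixing: the similarity is subtracted from the pairwise feature distance
# only for pairs that come from different cameras (the weight is the --mix_rate argument).
cams = np.array([0, 0, 1]).reshape(-1, 1)
expand = np.tile(cams, 3)
mask = (expand != expand.T).astype(np.float32)      # 1 for cross-camera pairs
feat = torch.nn.functional.normalize(torch.randn(3, 8), dim=1)
dist = 1 - (feat @ feat.t()).numpy()                # cosine distance on L2-normalized features
mixed = dist - 0.01 * mask * jaccard.numpy()
print(mixed)

In the repository, cluster_cross_cam then re-ranks the mixed distance matrix and clusters it with AgglomerativeClustering, using as distance_threshold the value found at the eph-quantile of the sorted pairwise distances when n_clusters is not given.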