├── .gitignore ├── LICENSE ├── README.md ├── cluster-contrast-reid ├── LICENSE ├── README.md ├── clustercontrast │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── market1501.py │ │ └── msmt17.py │ ├── evaluation_metrics │ │ ├── __init__.py │ │ ├── classification.py │ │ └── ranking.py │ ├── evaluators.py │ ├── models │ │ ├── __init__.py │ │ ├── cm.py │ │ ├── dsbn.py │ │ ├── kmeans.py │ │ ├── pooling.py │ │ ├── resnet.py │ │ ├── resnet_ibn.py │ │ ├── resnet_ibn_a.py │ │ └── vision_transformer.py │ ├── trainers.py │ └── utils │ │ ├── __init__.py │ │ ├── data │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── preprocessor.py │ │ ├── sampler.py │ │ └── transforms.py │ │ ├── faiss_rerank.py │ │ ├── faiss_utils.py │ │ ├── logging.py │ │ ├── meters.py │ │ ├── osutils.py │ │ ├── rerank.py │ │ └── serialization.py ├── examples │ ├── cluster_contrast_train_usl.py │ └── test.py ├── market_uda.sh ├── market_usl.sh ├── msmt_uda.sh ├── msmt_usl.sh └── setup.py ├── dino ├── LICENSE ├── README.md ├── convert_model_dino.py ├── hubconf.py ├── main_dino.py ├── ours_vit.py ├── run_with_submitit.py ├── script │ ├── vit-b_ics_cfs_lup.sh │ ├── vit-s_cfs_lup.sh │ ├── vit-s_full_lup.sh │ ├── vit-s_ics_cfs_lup.sh │ ├── vit-s_ics_full_lup.sh │ └── vit-s_imagenet.sh ├── utils.py ├── vision_transformer.py └── visualize_attention.py ├── requirements.txt └── transreid_pytorch ├── LICENSE ├── README.md ├── config ├── __init__.py └── defaults.py ├── configs ├── market │ ├── debug.yml │ ├── vit_base_baseline.yml │ ├── vit_base_ics_384.yml │ ├── vit_small.yml │ ├── vit_small_ics.yml │ └── vit_small_ics_ddp.yml └── msmt17 │ ├── vit_base_baseline.yml │ ├── vit_base_ics_384.yml │ ├── vit_small.yml │ ├── vit_small_ics.yml │ └── vit_small_ics_ddp.yml ├── datasets ├── __init__.py ├── bases.py ├── make_dataloader.py ├── market1501.py ├── mm.py ├── msmt17.py ├── preprocessing.py ├── sampler.py ├── sampler_ddp.py └── transforms.py ├── loss ├── __init__.py ├── arcface.py ├── center_loss.py ├── make_loss.py ├── metric_learning.py ├── softmax_loss.py └── triplet_loss.py ├── model ├── __init__.py ├── backbones │ ├── __init__.py │ ├── resnet.py │ ├── resnet_ibn_a.py │ ├── swin_transformer.py │ ├── transformer_layers.py │ └── vit_pytorch.py └── make_model.py ├── processor ├── __init__.py └── processor.py ├── run.sh ├── setup.py ├── solver ├── __init__.py ├── cosine_lr.py ├── lr_scheduler.py ├── make_optimizer.py ├── scheduler.py └── scheduler_factory.py ├── test.py ├── train.py └── utils ├── __init__.py ├── faiss_rerank.py ├── faiss_utils.py ├── iotools.py ├── logger.py ├── meter.py ├── metrics.py └── reranking.py /.gitignore: -------------------------------------------------------------------------------- 1 | # temporary files which can be created if a process still has a handle open of a deleted file 2 | .fuse_hidden* 3 | 4 | # KDE directory preferences 5 | .directory 6 | 7 | # Linux trash folder which might appear on any partition or disk 8 | .Trash-* 9 | 10 | # .nfs files are created when an open file is removed but is still being accessed 11 | .nfs* 12 | 13 | 14 | *.DS_Store 15 | .AppleDouble 16 | .LSOverride 17 | 18 | # Icon must end with two \r 19 | Icon 20 | 21 | 22 | # Thumbnails 23 | ._* 24 | 25 | # Files that might appear in the root of a volume 26 | .DocumentRevisions-V100 27 | .fseventsd 28 | .Spotlight-V100 29 | .TemporaryItems 30 | .Trashes 31 | .VolumeIcon.icns 32 | .com.apple.timemachine.donotpresent 33 | 34 | # Directories potentially created on remote AFP share 35 | .AppleDB 36 | .AppleDesktop 37 | Network 
Trash Folder 38 | Temporary Items 39 | .apdisk 40 | 41 | 42 | # swap 43 | [._]*.s[a-v][a-z] 44 | [._]*.sw[a-p] 45 | [._]s[a-v][a-z] 46 | [._]sw[a-p] 47 | # session 48 | Session.vim 49 | # temporary 50 | .netrwhist 51 | *~ 52 | # auto-generated tag files 53 | tags 54 | 55 | 56 | # cache files for sublime text 57 | *.tmlanguage.cache 58 | *.tmPreferences.cache 59 | *.stTheme.cache 60 | 61 | # workspace files are user-specific 62 | *.sublime-workspace 63 | 64 | # project files should be checked into the repository, unless a significant 65 | # proportion of contributors will probably not be using SublimeText 66 | # *.sublime-project 67 | 68 | # sftp configuration file 69 | sftp-config.json 70 | 71 | # Package control specific files 72 | Package Control.last-run 73 | Package Control.ca-list 74 | Package Control.ca-bundle 75 | Package Control.system-ca-bundle 76 | Package Control.cache/ 77 | Package Control.ca-certs/ 78 | Package Control.merged-ca-bundle 79 | Package Control.user-ca-bundle 80 | oscrypto-ca-bundle.crt 81 | bh_unicode_properties.cache 82 | 83 | # Sublime-github package stores a github token in this file 84 | # https://packagecontrol.io/packages/sublime-github 85 | GitHub.sublime-settings 86 | 87 | 88 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 89 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 90 | 91 | # User-specific stuff: 92 | .idea 93 | .idea/**/workspace.xml 94 | .idea/**/tasks.xml 95 | 96 | # Sensitive or high-churn files: 97 | .idea/**/dataSources/ 98 | .idea/**/dataSources.ids 99 | .idea/**/dataSources.xml 100 | .idea/**/dataSources.local.xml 101 | .idea/**/sqlDataSources.xml 102 | .idea/**/dynamic.xml 103 | .idea/**/uiDesigner.xml 104 | 105 | # Gradle: 106 | .idea/**/gradle.xml 107 | .idea/**/libraries 108 | 109 | # Mongo Explorer plugin: 110 | .idea/**/mongoSettings.xml 111 | 112 | ## File-based project format: 113 | *.iws 114 | 115 | ## Plugin-specific files: 116 | 117 | # IntelliJ 118 | /out/ 119 | 120 | # mpeltonen/sbt-idea plugin 121 | .idea_modules/ 122 | 123 | # JIRA plugin 124 | atlassian-ide-plugin.xml 125 | 126 | # Crashlytics plugin (for Android Studio and IntelliJ) 127 | com_crashlytics_export_strings.xml 128 | crashlytics.properties 129 | crashlytics-build.properties 130 | fabric.properties 131 | 132 | 133 | # Byte-compiled / optimized / DLL files 134 | __pycache__/ 135 | *.py[cod] 136 | *$py.class 137 | 138 | # C extensions 139 | *.so 140 | 141 | # Distribution / packaging 142 | .Python 143 | env/ 144 | build/ 145 | develop-eggs/ 146 | dist/ 147 | downloads/ 148 | eggs/ 149 | .eggs/ 150 | lib/ 151 | lib64/ 152 | parts/ 153 | sdist/ 154 | var/ 155 | *.egg-info/ 156 | .installed.cfg 157 | *.egg 158 | 159 | # PyInstaller 160 | # Usually these files are written by a python script from a template 161 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
162 | *.manifest 163 | *.spec 164 | 165 | # Installer logs 166 | pip-log.txt 167 | pip-delete-this-directory.txt 168 | 169 | # Unit test / coverage reports 170 | htmlcov/ 171 | .tox/ 172 | .coverage 173 | .coverage.* 174 | .cache 175 | nosetests.xml 176 | coverage.xml 177 | *,cover 178 | .hypothesis/ 179 | 180 | # Translations 181 | *.mo 182 | *.pot 183 | 184 | # Django stuff: 185 | *.log 186 | local_settings.py 187 | 188 | # Flask stuff: 189 | instance/ 190 | .webassets-cache 191 | 192 | # Scrapy stuff: 193 | .scrapy 194 | 195 | # Sphinx documentation 196 | docs/_build/ 197 | 198 | # PyBuilder 199 | target/ 200 | 201 | # IPython Notebook 202 | .ipynb_checkpoints 203 | 204 | # pyenv 205 | .python-version 206 | 207 | # celery beat schedule file 208 | celerybeat-schedule 209 | 210 | # dotenv 211 | .env 212 | 213 | # virtualenv 214 | venv/ 215 | ENV/ 216 | 217 | # Spyder project settings 218 | .spyderproject 219 | 220 | # Rope project settings 221 | .ropeproject 222 | 223 | 224 | test*.yml 225 | vit_tiaocan*.yml 226 | *.pyc 227 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Alibaba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cluster-contrast-reid/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Alibaba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cluster-contrast-reid/README.md: --------------------------------------------------------------------------------
1 | # Cluster Contrast for Unsupervised Person Re-Identification
2 | This code is modified from [cluster-contrast-reid](https://github.com/alibaba/cluster-contrast-reid). Please refer to the original repo for more details.
3 | 
4 | ### Installation
5 | 
6 | ```shell
7 | python setup.py develop
8 | ```
9 | 
10 | ### Prepare Pre-trained Models
11 | Please download the pre-trained models and place them in a folder of your choice.
12 | 
13 | ### Training
14 | 
15 | You can train with 2 or 4 GPUs. For the full parameter configuration, see **`market_usl.sh`**, **`market_uda.sh`**, **`msmt_usl.sh`** and **`msmt_uda.sh`**.
16 | 
17 | - To train ViT-S with ICS, add **`--conv-stem`**.
18 | - Set **`-pp`** to the file path of the pre-trained model. For UDA ReID, the pre-trained model should first be fine-tuned on the source dataset.
19 | - We observe that high performance can be achieved on MSMT even with 2-GPU training.
20 | 
21 | ## Citation
22 | 
23 | If you find this code useful for your research, please cite the paper:
24 | ```
25 | @article{dai2021cluster,
26 |   title={Cluster Contrast for Unsupervised Person Re-Identification},
27 |   author={Dai, Zuozhuo and Wang, Guangyuan and Zhu, Siyu and Yuan, Weihao and Tan, Ping},
28 |   journal={arXiv preprint arXiv:2103.11568},
29 |   year={2021}
30 | }
31 | ```
32 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from . import datasets
4 | from . import evaluation_metrics
5 | from . import models
6 | from . import utils
7 | from . import evaluators
8 | from . import trainers
9 | 
10 | __version__ = '0.1.0'
11 | 
-------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/datasets/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import warnings
3 | 
4 | from .market1501 import Market1501
5 | from .msmt17 import MSMT17
6 | 
7 | 
8 | __factory = {
9 |     'market1501': Market1501,
10 |     'msmt17': MSMT17,
11 | }
12 | 
13 | 
14 | def names():
15 |     return sorted(__factory.keys())
16 | 
17 | 
18 | def create(name, root, *args, **kwargs):
19 |     """
20 |     Create a dataset instance.
21 | 
22 |     Parameters
23 |     ----------
24 |     name : str
25 |         The dataset name.
26 |     root : str
27 |         The path to the dataset directory.
28 |     split_id : int, optional
29 |         The index of the data split. Default: 0
30 |     num_val : int or float, optional
31 |         When an int, the number of validation identities. When a float,
32 |         the proportion of the trainval set used for validation. Default: 100
33 |     download : bool, optional
34 |         If True, will download the dataset.
Default: False 35 | """ 36 | if name not in __factory: 37 | raise KeyError("Unknown dataset:", name) 38 | return __factory[name](root, *args, **kwargs) 39 | 40 | 41 | def get_dataset(name, root, *args, **kwargs): 42 | warnings.warn("get_dataset is deprecated. Use create instead.") 43 | return create(name, root, *args, **kwargs) 44 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class Market1501(BaseImageDataset): 9 | """ 10 | Market1501 11 | Reference: 12 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 13 | URL: http://www.liangzheng.org/Project/project_reid.html 14 | 15 | Dataset statistics: 16 | # identities: 1501 (+1 for background) 17 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 18 | """ 19 | dataset_dir = 'market1501' 20 | 21 | def __init__(self, root, verbose=True, **kwargs): 22 | super(Market1501, self).__init__() 23 | self.dataset_dir = osp.join(root, self.dataset_dir) 24 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 25 | self.query_dir = osp.join(self.dataset_dir, 'query') 26 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 27 | 28 | self._check_before_run() 29 | 30 | train = self._process_dir(self.train_dir, relabel=True) 31 | query = self._process_dir(self.query_dir, relabel=False) 32 | gallery = self._process_dir(self.gallery_dir, relabel=False) 33 | 34 | if verbose: 35 | print("=> Market1501 loaded") 36 | self.print_dataset_statistics(train, query, gallery) 37 | 38 | self.train = train 39 | self.query = query 40 | self.gallery = gallery 41 | 42 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 43 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 44 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 45 | 46 | def _check_before_run(self): 47 | """Check if all files are available before going deeper""" 48 | if not osp.exists(self.dataset_dir): 49 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 50 | if not osp.exists(self.train_dir): 51 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 52 | if not osp.exists(self.query_dir): 53 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 54 | if not osp.exists(self.gallery_dir): 55 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 56 | 57 | def _process_dir(self, dir_path, relabel=False): 58 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 59 | pattern = re.compile(r'([-\d]+)_c(\d)') 60 | 61 | pid_container = set() 62 | for img_path in img_paths: 63 | pid, _ = map(int, pattern.search(img_path).groups()) 64 | if pid == -1: 65 | continue # junk images are just ignored 66 | pid_container.add(pid) 67 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 68 | 69 | dataset = [] 70 | for img_path in img_paths: 71 | pid, camid = map(int, pattern.search(img_path).groups()) 72 | if pid == -1: 73 | continue # junk images are just ignored 74 | assert 0 <= pid <= 1501 # pid == 0 means background 75 | assert 1 <= camid <= 6 76 | camid -= 1 # index starts from 0 77 | if 
relabel: 78 | pid = pid2label[pid] 79 | dataset.append((img_path, pid, camid)) 80 | 81 | return dataset 82 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | import glob 5 | import re 6 | from ..utils.data import BaseImageDataset 7 | 8 | 9 | class MSMT17(BaseImageDataset): 10 | """ 11 | MSMT17 12 | Reference: 13 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 14 | URL: http://www.pkuvmc.com/publications/msmt17.html 15 | Dataset statistics: 16 | # identities: 4101 17 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 18 | # cameras: 15 19 | """ 20 | dataset_dir = 'MSMT17' 21 | 22 | def __init__(self, root='', verbose=True, pid_begin=0, **kwargs): 23 | super(MSMT17, self).__init__() 24 | self.pid_begin = pid_begin 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | self.train_dir = osp.join(self.dataset_dir, 'train') 27 | self.test_dir = osp.join(self.dataset_dir, 'test') 28 | self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt') 29 | self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt') 30 | self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt') 31 | self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt') 32 | 33 | self._check_before_run() 34 | train = self._process_dir(self.train_dir, self.list_train_path) 35 | val = self._process_dir(self.train_dir, self.list_val_path) 36 | train += val 37 | query = self._process_dir(self.test_dir, self.list_query_path) 38 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) 39 | if verbose: 40 | print("=> MSMT17 loaded") 41 | self.print_dataset_statistics(train, query, gallery) 42 | 43 | self.train = train 44 | self.query = query 45 | self.gallery = gallery 46 | 47 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 48 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 49 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 50 | def _check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 56 | if not osp.exists(self.test_dir): 57 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 58 | 59 | def _process_dir(self, dir_path, list_path): 60 | with open(list_path, 'r') as txt: 61 | lines = txt.readlines() 62 | dataset = [] 63 | pid_container = set() 64 | cam_container = set() 65 | for img_idx, img_info in enumerate(lines): 66 | img_path, pid = img_info.split(' ') 67 | pid = int(pid) # no need to relabel 68 | camid = int(img_path.split('_')[2]) 69 | img_path = osp.join(dir_path, img_path) 70 | dataset.append((img_path, self.pid_begin +pid, camid-1)) 71 | pid_container.add(pid) 72 | cam_container.add(camid) 73 | # check if pid starts from 0 and increments with 1 74 | for idx, pid in enumerate(pid_container): 75 | assert idx == pid, "See code comment for explanation" 76 | return dataset 77 | 
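# Example (illustrative; the file name below is hypothetical): each line of
# list_train.txt is expected to look like
#     0000/0000_000_01_0303morning_0008_0.jpg 0
# i.e. a relative image path, a space, and the person id; the third
# '_'-separated field of the path ('01') is the 1-based camera id, which
# _process_dir shifts to 0-based. A minimal usage sketch:
#     dataset = MSMT17(root='/path/to/data')
#     img_path, pid, camid = dataset.train[0]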
-------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/evaluation_metrics/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .classification import accuracy
4 | from .ranking import cmc, mean_ap
5 | 
6 | __all__ = [
7 |     'accuracy',
8 |     'cmc',
9 |     'mean_ap'
10 | ]
11 | 
-------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/evaluation_metrics/classification.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | from ..utils import to_torch
5 | 
6 | 
7 | def accuracy(output, target, topk=(1,)):
8 |     with torch.no_grad():
9 |         output, target = to_torch(output), to_torch(target)
10 |         maxk = max(topk)
11 |         batch_size = target.size(0)
12 | 
13 |         _, pred = output.topk(maxk, 1, True, True)
14 |         pred = pred.t()
15 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
16 | 
17 |         ret = []
18 |         for k in topk:
19 |             # use reshape instead of view: the sliced tensor can be
20 |             # non-contiguous for k > 1, which makes view() raise
21 |             correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
22 |             ret.append(correct_k.mul_(1. / batch_size))
23 |         return ret
-------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/evaluation_metrics/ranking.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 | 
4 | import numpy as np
5 | from sklearn.metrics import average_precision_score
6 | 
7 | from ..utils import to_numpy
8 | 
9 | 
10 | def _unique_sample(ids_dict, num):
11 |     mask = np.zeros(num, dtype=bool)  # np.bool was removed in NumPy >= 1.24
12 |     for _, indices in ids_dict.items():
13 |         i = np.random.choice(indices)
14 |         mask[i] = True
15 |     return mask
16 | 
17 | 
18 | def cmc(distmat, query_ids=None, gallery_ids=None,
19 |         query_cams=None, gallery_cams=None, topk=100,
20 |         separate_camera_set=False,
21 |         single_gallery_shot=False,
22 |         first_match_break=False):
23 |     distmat = to_numpy(distmat)
24 |     m, n = distmat.shape
25 |     # Fill up default values
26 |     if query_ids is None:
27 |         query_ids = np.arange(m)
28 |     if gallery_ids is None:
29 |         gallery_ids = np.arange(n)
30 |     if query_cams is None:
31 |         query_cams = np.zeros(m).astype(np.int32)
32 |     if gallery_cams is None:
33 |         gallery_cams = np.ones(n).astype(np.int32)
34 |     # Ensure numpy array
35 |     query_ids = np.asarray(query_ids)
36 |     gallery_ids = np.asarray(gallery_ids)
37 |     query_cams = np.asarray(query_cams)
38 |     gallery_cams = np.asarray(gallery_cams)
39 |     # Sort and find correct matches
40 |     indices = np.argsort(distmat, axis=1)
41 |     matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
42 |     # Compute CMC for each query
43 |     ret = np.zeros(topk)
44 |     num_valid_queries = 0
45 |     for i in range(m):
46 |         # Filter out the same id and same camera
47 |         valid = ((gallery_ids[indices[i]] != query_ids[i]) |
48 |                  (gallery_cams[indices[i]] != query_cams[i]))
49 |         if separate_camera_set:
50 |             # Filter out samples from same camera
51 |             valid &= (gallery_cams[indices[i]] != query_cams[i])
52 |         if not np.any(matches[i, valid]): continue
53 |         if single_gallery_shot:
54 |             repeat = 10
55 |             gids = gallery_ids[indices[i][valid]]
56 |             inds = np.where(valid)[0]
57 |             ids_dict = defaultdict(list)
58 |             for j, x in zip(inds, gids):
59 |                 ids_dict[x].append(j)
60 |         else:
61 |             repeat = 1
62 |         for _ in range(repeat):
63 |             if
single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. / (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import collections 4 | from collections import OrderedDict 5 | import numpy as np 6 | import torch 7 | import random 8 | import copy 9 | 10 | from .evaluation_metrics import cmc, mean_ap 11 | from .utils.meters import AverageMeter 12 | from .utils.rerank import re_ranking 13 | from .utils import to_torch 14 | 15 | 16 | # def extract_cnn_feature(model, inputs): 17 | # inputs = to_torch(inputs).cuda() 18 | # outputs = model(inputs) 19 | # return outputs 20 | 21 | 22 | def extract_features(model, data_loader, print_freq=50, cluster_features=True): 23 | model.eval() 24 | batch_time = AverageMeter() 25 | data_time = AverageMeter() 26 | 27 | features = OrderedDict() 28 | labels = OrderedDict() 29 | 30 | end = time.time() 31 | with torch.no_grad(): 32 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader): 33 | data_time.update(time.time() - end) 34 | imgs = to_torch(imgs).cuda() 35 | outputs = model(imgs) 36 | # if cluster_features: 37 | # inv_idx = torch.arange(imgs.size(3) - 1, -1, -1).long().cuda() 38 | # imgs_flip = imgs.index_select(3, inv_idx) 39 | # outputs_flip = model(imgs_flip) 40 | # outputs = (outputs + outputs_flip)/2.0 41 | outputs = outputs.data.cpu() 42 | for fname, output, pid in zip(fnames, outputs, pids): 43 | features[fname] = output 44 | labels[fname] = pid 45 | 46 | batch_time.update(time.time() - end) 47 | end = 
time.time()
48 | 
49 |             if (i + 1) % print_freq == 0:
50 |                 print('Extract Features: [{}/{}]\t'
51 |                       'Time {:.3f} ({:.3f})\t'
52 |                       'Data {:.3f} ({:.3f})\t'
53 |                       .format(i + 1, len(data_loader),
54 |                               batch_time.val, batch_time.avg,
55 |                               data_time.val, data_time.avg))
56 | 
57 |     return features, labels
58 | 
59 | 
60 | def pairwise_distance(features, query=None, gallery=None):
61 |     if query is None and gallery is None:
62 |         n = len(features)
63 |         x = torch.cat(list(features.values()))
64 |         x = x.view(n, -1)
65 |         # NOTE: this shortcut assumes L2-normalized features (||x|| = 1),
66 |         # which holds here because eval-mode features are normalized
67 |         dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2
68 |         dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t())
69 |         return dist_m
70 | 
71 |     x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
72 |     y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
73 |     m, n = x.size(0), y.size(0)
74 |     x = x.view(m, -1)
75 |     y = y.view(n, -1)
76 |     dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
77 |              torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
78 |     # keyword form: the positional addmm_(beta, alpha, ...) signature is
79 |     # deprecated and rejected by recent PyTorch releases
80 |     dist_m.addmm_(x, y.t(), beta=1, alpha=-2)
81 |     return dist_m, x.numpy(), y.numpy()
82 | 
83 | 
84 | def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None,
85 |                  query_ids=None, gallery_ids=None,
86 |                  query_cams=None, gallery_cams=None,
87 |                  cmc_topk=(1, 5, 10), cmc_flag=False):
88 |     if query is not None and gallery is not None:
89 |         query_ids = [pid for _, pid, _ in query]
90 |         gallery_ids = [pid for _, pid, _ in gallery]
91 |         query_cams = [cam for _, _, cam in query]
92 |         gallery_cams = [cam for _, _, cam in gallery]
93 |     else:
94 |         assert (query_ids is not None and gallery_ids is not None
95 |                 and query_cams is not None and gallery_cams is not None)
96 | 
97 |     # Compute mean AP
98 |     mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
99 |     print('Mean AP: {:4.1%}'.format(mAP))
100 | 
101 |     if (not cmc_flag):
102 |         return mAP
103 | 
104 |     cmc_configs = {
105 |         'market1501': dict(separate_camera_set=False,
106 |                            single_gallery_shot=False,
107 |                            first_match_break=True),}
108 |     cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
109 |                             query_cams, gallery_cams, **params)
110 |                   for name, params in cmc_configs.items()}
111 | 
112 |     print('CMC Scores:')
113 |     for k in cmc_topk:
114 |         print('  top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1]))
115 |     return cmc_scores['market1501'], mAP
116 | 
117 | 
118 | class Evaluator(object):
119 |     def __init__(self, model):
120 |         super(Evaluator, self).__init__()
121 |         self.model = model
122 | 
123 |     def evaluate(self, data_loader, query, gallery, cmc_flag=False, rerank=False):
124 |         features, _ = extract_features(self.model, data_loader, cluster_features=False)
125 |         distmat, query_features, gallery_features = pairwise_distance(features, query, gallery)
126 |         results = evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag)
127 | 
128 |         if (not rerank):
129 |             return results
130 | 
131 |         print('Applying person re-ranking ...')
132 |         distmat_qq, _, _ = pairwise_distance(features, query, query)
133 |         distmat_gg, _, _ = pairwise_distance(features, gallery, gallery)
134 |         distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy())
135 |         return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag)
-------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/models/__init__.py: --------------------------------------------------------------------------------
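# A minimal usage sketch of the model factory defined below; the checkpoint
# path and keyword values are placeholders, not from the original repo:
#     from clustercontrast import models
#     model = models.create('resnet50', pretrained_path='/path/to/ckpt.pth',
#                           num_features=0, dropout=0, num_classes=0,
#                           pooling_type='avg')
#     model = model.cuda()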
1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnet_ibn import * 5 | from .vision_transformer import * 6 | __factory = { 7 | 'resnet18': resnet18, 8 | 'resnet34': resnet34, 9 | 'resnet50': resnet50, 10 | 'resnet101': resnet101, 11 | 'resnet152': resnet152, 12 | 'resnet_ibn50a': resnet_ibn50a, 13 | 'resnet_ibn101a': resnet_ibn101a, 14 | 'vit_small': vit_small, 15 | 'vit_base': vit_base, 16 | } 17 | 18 | 19 | def names(): 20 | return sorted(__factory.keys()) 21 | 22 | 23 | def create(name, *args, **kwargs): 24 | """ 25 | Create a model instance. 26 | 27 | Parameters 28 | ---------- 29 | name : str 30 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 31 | 'resnet50', 'resnet101', and 'resnet152'. 32 | pretrained : bool, optional 33 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 34 | model. Default: True 35 | cut_at_pooling : bool, optional 36 | If True, will cut the model before the last global pooling layer and 37 | ignore the remaining kwargs. Default: False 38 | num_features : int, optional 39 | If positive, will append a Linear layer after the global pooling layer, 40 | with this number of output units, followed by a BatchNorm layer. 41 | Otherwise these layers will not be appended. Default: 256 for 42 | 'inception', 0 for 'resnet*' 43 | norm : bool, optional 44 | If True, will normalize the feature to be unit L2-norm for each sample. 45 | Otherwise will append a ReLU layer after the above Linear layer if 46 | num_features > 0. Default: False 47 | dropout : float, optional 48 | If positive, will append a Dropout layer with this dropout rate. 49 | Default: 0 50 | num_classes : int, optional 51 | If positive, will append a Linear layer at the end as the classifier 52 | with this number of output units. Default: 0 53 | """ 54 | if name not in __factory: 55 | raise KeyError("Unknown model:", name) 56 | return __factory[name](*args, **kwargs) 57 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/models/cm.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | from abc import ABC 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, autograd 7 | 8 | 9 | class CM(autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, inputs, targets, features, momentum): 13 | ctx.features = features 14 | ctx.momentum = momentum 15 | ctx.save_for_backward(inputs, targets) 16 | outputs = inputs.mm(ctx.features.t()) 17 | 18 | return outputs 19 | 20 | @staticmethod 21 | def backward(ctx, grad_outputs): 22 | inputs, targets = ctx.saved_tensors 23 | grad_inputs = None 24 | if ctx.needs_input_grad[0]: 25 | grad_inputs = grad_outputs.mm(ctx.features) 26 | 27 | # momentum update 28 | for x, y in zip(inputs, targets): 29 | ctx.features[y] = ctx.momentum * ctx.features[y] + (1. 
- ctx.momentum) * x 30 | ctx.features[y] /= ctx.features[y].norm() 31 | 32 | return grad_inputs, None, None, None 33 | 34 | 35 | def cm(inputs, indexes, features, momentum=0.5): 36 | return CM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device)) 37 | 38 | 39 | class CM_Hard(autograd.Function): 40 | 41 | @staticmethod 42 | def forward(ctx, inputs, targets, features, momentum): 43 | ctx.features = features 44 | ctx.momentum = momentum 45 | ctx.save_for_backward(inputs, targets) 46 | outputs = inputs.mm(ctx.features.t()) 47 | 48 | return outputs 49 | 50 | @staticmethod 51 | def backward(ctx, grad_outputs): 52 | inputs, targets = ctx.saved_tensors 53 | grad_inputs = None 54 | if ctx.needs_input_grad[0]: 55 | grad_inputs = grad_outputs.mm(ctx.features) 56 | 57 | batch_centers = collections.defaultdict(list) 58 | for instance_feature, index in zip(inputs, targets.tolist()): 59 | batch_centers[index].append(instance_feature) 60 | 61 | for index, features in batch_centers.items(): 62 | distances = [] 63 | for feature in features: 64 | distance = feature.unsqueeze(0).mm(ctx.features[index].unsqueeze(0).t())[0][0] 65 | distances.append(distance.cpu().numpy()) 66 | 67 | median = np.argmin(np.array(distances)) 68 | ctx.features[index] = ctx.features[index] * ctx.momentum + (1 - ctx.momentum) * features[median] 69 | ctx.features[index] /= ctx.features[index].norm() 70 | 71 | return grad_inputs, None, None, None 72 | 73 | 74 | def cm_hard(inputs, indexes, features, momentum=0.5): 75 | return CM_Hard.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device)) 76 | 77 | 78 | class ClusterMemory(nn.Module, ABC): 79 | def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2, use_hard=False): 80 | super(ClusterMemory, self).__init__() 81 | self.num_features = num_features 82 | self.num_samples = num_samples 83 | 84 | self.momentum = momentum 85 | self.temp = temp 86 | self.use_hard = use_hard 87 | 88 | self.register_buffer('features', torch.zeros(num_samples, num_features)) 89 | 90 | def forward(self, inputs, targets): 91 | 92 | inputs = F.normalize(inputs, dim=1).cuda() 93 | if self.use_hard: 94 | outputs = cm_hard(inputs, targets, self.features, self.momentum) 95 | else: 96 | outputs = cm(inputs, targets, self.features, self.momentum) 97 | 98 | outputs /= self.temp 99 | loss = F.cross_entropy(outputs, targets) 100 | return loss 101 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/models/dsbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Domain-specific BatchNorm 5 | 6 | class DSBN2d(nn.Module): 7 | def __init__(self, planes): 8 | super(DSBN2d, self).__init__() 9 | self.num_features = planes 10 | self.BN_S = nn.BatchNorm2d(planes) 11 | self.BN_T = nn.BatchNorm2d(planes) 12 | 13 | def forward(self, x): 14 | if (not self.training): 15 | return self.BN_T(x) 16 | 17 | bs = x.size(0) 18 | assert (bs%2==0) 19 | split = torch.split(x, int(bs/2), 0) 20 | out1 = self.BN_S(split[0].contiguous()) 21 | out2 = self.BN_T(split[1].contiguous()) 22 | out = torch.cat((out1, out2), 0) 23 | return out 24 | 25 | class DSBN1d(nn.Module): 26 | def __init__(self, planes): 27 | super(DSBN1d, self).__init__() 28 | self.num_features = planes 29 | self.BN_S = nn.BatchNorm1d(planes) 30 | self.BN_T = nn.BatchNorm1d(planes) 31 | 32 | def forward(self, x): 33 | if (not self.training): 34 | return self.BN_T(x) 35 | 36 | bs 
= x.size(0) 37 | assert (bs%2==0) 38 | split = torch.split(x, int(bs/2), 0) 39 | out1 = self.BN_S(split[0].contiguous()) 40 | out2 = self.BN_T(split[1].contiguous()) 41 | out = torch.cat((out1, out2), 0) 42 | return out 43 | 44 | def convert_dsbn(model): 45 | for _, (child_name, child) in enumerate(model.named_children()): 46 | assert(not next(model.parameters()).is_cuda) 47 | if isinstance(child, nn.BatchNorm2d): 48 | m = DSBN2d(child.num_features) 49 | m.BN_S.load_state_dict(child.state_dict()) 50 | m.BN_T.load_state_dict(child.state_dict()) 51 | setattr(model, child_name, m) 52 | elif isinstance(child, nn.BatchNorm1d): 53 | m = DSBN1d(child.num_features) 54 | m.BN_S.load_state_dict(child.state_dict()) 55 | m.BN_T.load_state_dict(child.state_dict()) 56 | setattr(model, child_name, m) 57 | else: 58 | convert_dsbn(child) 59 | 60 | def convert_bn(model, use_target=True): 61 | for _, (child_name, child) in enumerate(model.named_children()): 62 | assert(not next(model.parameters()).is_cuda) 63 | if isinstance(child, DSBN2d): 64 | m = nn.BatchNorm2d(child.num_features) 65 | if use_target: 66 | m.load_state_dict(child.BN_T.state_dict()) 67 | else: 68 | m.load_state_dict(child.BN_S.state_dict()) 69 | setattr(model, child_name, m) 70 | elif isinstance(child, DSBN1d): 71 | m = nn.BatchNorm1d(child.num_features) 72 | if use_target: 73 | m.load_state_dict(child.BN_T.state_dict()) 74 | else: 75 | m.load_state_dict(child.BN_S.state_dict()) 76 | setattr(model, child_name, m) 77 | else: 78 | convert_bn(child, use_target=use_target) 79 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/models/kmeans.py: -------------------------------------------------------------------------------- 1 | # Written by Yixiao Ge 2 | 3 | import warnings 4 | 5 | import faiss 6 | import torch 7 | 8 | from ..utils import to_numpy, to_torch 9 | 10 | __all__ = ["label_generator_kmeans"] 11 | 12 | 13 | @torch.no_grad() 14 | def label_generator_kmeans(features, num_classes=500, cuda=True): 15 | 16 | assert num_classes, "num_classes for kmeans is null" 17 | 18 | # k-means cluster by faiss 19 | cluster = faiss.Kmeans( 20 | features.size(-1), num_classes, niter=300, verbose=True, gpu=cuda 21 | ) 22 | 23 | cluster.train(to_numpy(features)) 24 | 25 | _, labels = cluster.index.search(to_numpy(features), 1) 26 | labels = labels.reshape(-1) 27 | 28 | centers = to_torch(cluster.centroids).float() 29 | # labels = to_torch(labels).long() 30 | 31 | # k-means does not have outlier points 32 | assert not (-1 in labels) 33 | 34 | return labels, centers, num_classes, None 35 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn import init 5 | import torchvision 6 | import torch 7 | from .pooling import build_pooling_layer 8 | 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152'] 12 | 13 | 14 | class ResNet(nn.Module): 15 | __factory = { 16 | 18: torchvision.models.resnet18, 17 | 34: torchvision.models.resnet34, 18 | 50: torchvision.models.resnet50, 19 | 101: torchvision.models.resnet101, 20 | 152: torchvision.models.resnet152, 21 | } 22 | 23 | def __init__(self, depth, pretrained_path='', cut_at_pooling=False, 24 | num_features=0, norm=False, 
dropout=0, num_classes=0, pooling_type='avg'): 25 | print('pooling_type: {}'.format(pooling_type)) 26 | super(ResNet, self).__init__() 27 | # self.pretrained = pretrained 28 | self.depth = depth 29 | self.cut_at_pooling = cut_at_pooling 30 | # Construct base (pretrained) resnet 31 | if depth not in ResNet.__factory: 32 | raise KeyError("Unsupported depth:", depth) 33 | resnet = ResNet.__factory[depth](pretrained=False) 34 | resnet.layer4[0].conv2.stride = (1, 1) 35 | resnet.layer4[0].downsample[0].stride = (1, 1) 36 | resnet = self.load_param(resnet,model_path=pretrained_path) 37 | self.base = nn.Sequential( 38 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 39 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 40 | 41 | 42 | self.gap = build_pooling_layer(pooling_type) 43 | 44 | if not self.cut_at_pooling: 45 | self.num_features = num_features 46 | self.norm = norm 47 | self.dropout = dropout 48 | self.has_embedding = num_features > 0 49 | self.num_classes = num_classes 50 | 51 | out_planes = resnet.fc.in_features 52 | 53 | # Append new layers 54 | if self.has_embedding: 55 | self.feat = nn.Linear(out_planes, self.num_features) 56 | self.feat_bn = nn.BatchNorm1d(self.num_features) 57 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 58 | init.constant_(self.feat.bias, 0) 59 | else: 60 | # Change the num_features to CNN output channels 61 | self.num_features = out_planes 62 | self.feat_bn = nn.BatchNorm1d(self.num_features) 63 | self.feat_bn.bias.requires_grad_(False) 64 | if self.dropout > 0: 65 | self.drop = nn.Dropout(self.dropout) 66 | if self.num_classes > 0: 67 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 68 | init.normal_(self.classifier.weight, std=0.001) 69 | init.constant_(self.feat_bn.weight, 1) 70 | init.constant_(self.feat_bn.bias, 0) 71 | 72 | # if not pretrained: 73 | # self.reset_params() 74 | 75 | def forward(self, x): 76 | bs = x.size(0) 77 | x = self.base(x) 78 | 79 | x = self.gap(x) 80 | x = x.view(x.size(0), -1) 81 | 82 | if self.cut_at_pooling: 83 | return x 84 | 85 | if self.has_embedding: 86 | bn_x = self.feat_bn(self.feat(x)) 87 | else: 88 | bn_x = self.feat_bn(x) 89 | 90 | if (self.training is False): 91 | bn_x = F.normalize(bn_x) 92 | return bn_x 93 | 94 | if self.norm: 95 | bn_x = F.normalize(bn_x) 96 | elif self.has_embedding: 97 | bn_x = F.relu(bn_x) 98 | 99 | if self.dropout > 0: 100 | bn_x = self.drop(bn_x) 101 | 102 | if self.num_classes > 0: 103 | prob = self.classifier(bn_x) 104 | else: 105 | return bn_x 106 | 107 | return prob 108 | 109 | # def reset_params(self): 110 | # for m in self.modules(): 111 | # if isinstance(m, nn.Conv2d): 112 | # init.kaiming_normal_(m.weight, mode='fan_out') 113 | # if m.bias is not None: 114 | # init.constant_(m.bias, 0) 115 | # elif isinstance(m, nn.BatchNorm2d): 116 | # init.constant_(m.weight, 1) 117 | # init.constant_(m.bias, 0) 118 | # elif isinstance(m, nn.BatchNorm1d): 119 | # init.constant_(m.weight, 1) 120 | # init.constant_(m.bias, 0) 121 | # elif isinstance(m, nn.Linear): 122 | # init.normal_(m.weight, std=0.001) 123 | # if m.bias is not None: 124 | # init.constant_(m.bias, 0) 125 | 126 | def load_param(self, model, model_path): 127 | param_dict = torch.load(model_path,map_location='cpu') 128 | for i in param_dict: 129 | if 'fc' in i or 'classifier' in i or 'bottleneck' in i: 130 | continue 131 | if i.startswith('base'): 132 | j = i.replace('base.','') 133 | else: 134 | j = i 135 | model.state_dict()[j].copy_(param_dict[i]) 136 | return model 137 | 138 | 
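# Example (illustrative; the key names are hypothetical): load_param above
# copies backbone tensors and skips head parameters, so a checkpoint such as
#     {'base.conv1.weight': ..., 'fc.weight': ..., 'classifier.weight': ...}
# has 'base.conv1.weight' copied into 'conv1.weight' of the torchvision
# ResNet, while the 'fc'/'classifier'/'bottleneck' entries are ignored.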
def resnet18(**kwargs): 139 | return ResNet(18, **kwargs) 140 | 141 | 142 | def resnet34(**kwargs): 143 | return ResNet(34, **kwargs) 144 | 145 | 146 | def resnet50(**kwargs): 147 | return ResNet(50, **kwargs) 148 | 149 | 150 | def resnet101(**kwargs): 151 | return ResNet(101, **kwargs) 152 | 153 | 154 | def resnet152(**kwargs): 155 | return ResNet(152, **kwargs) 156 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/models/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | from .pooling import build_pooling_layer 9 | 10 | from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 11 | 12 | 13 | __all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a'] 14 | 15 | 16 | class ResNetIBN(nn.Module): 17 | __factory = { 18 | '50a': resnet50_ibn_a, 19 | '101a': resnet101_ibn_a 20 | } 21 | 22 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 23 | num_features=0, norm=False, dropout=0, num_classes=0, pooling_type='avg', pretrained_path=''): 24 | 25 | print('pooling_type: {}'.format(pooling_type)) 26 | super(ResNetIBN, self).__init__() 27 | 28 | self.depth = depth 29 | self.pretrained = pretrained 30 | self.cut_at_pooling = cut_at_pooling 31 | 32 | resnet = ResNetIBN.__factory[depth](pretrained=pretrained) 33 | resnet.layer4[0].conv2.stride = (1, 1) 34 | resnet.layer4[0].downsample[0].stride = (1, 1) 35 | resnet = self.load_param(resnet,model_path=pretrained_path) 36 | 37 | self.base = nn.Sequential( 38 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 39 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 40 | 41 | self.gap = build_pooling_layer(pooling_type) 42 | 43 | if not self.cut_at_pooling: 44 | self.num_features = num_features 45 | self.norm = norm 46 | self.dropout = dropout 47 | self.has_embedding = num_features > 0 48 | self.num_classes = num_classes 49 | 50 | out_planes = resnet.fc.in_features 51 | 52 | # Append new layers 53 | if self.has_embedding: 54 | self.feat = nn.Linear(out_planes, self.num_features) 55 | self.feat_bn = nn.BatchNorm1d(self.num_features) 56 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 57 | init.constant_(self.feat.bias, 0) 58 | else: 59 | # Change the num_features to CNN output channels 60 | self.num_features = out_planes 61 | self.feat_bn = nn.BatchNorm1d(self.num_features) 62 | self.feat_bn.bias.requires_grad_(False) 63 | if self.dropout > 0: 64 | self.drop = nn.Dropout(self.dropout) 65 | if self.num_classes > 0: 66 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 67 | init.normal_(self.classifier.weight, std=0.001) 68 | 69 | init.constant_(self.feat_bn.weight, 1) 70 | init.constant_(self.feat_bn.bias, 0) 71 | 72 | # if not pretrained: 73 | # self.reset_params() 74 | 75 | def forward(self, x): 76 | x = self.base(x) 77 | 78 | x = self.gap(x) 79 | x = x.view(x.size(0), -1) 80 | 81 | if self.cut_at_pooling: 82 | return x 83 | 84 | if self.has_embedding: 85 | bn_x = self.feat_bn(self.feat(x)) 86 | else: 87 | bn_x = self.feat_bn(x) 88 | 89 | if self.training is False: 90 | bn_x = F.normalize(bn_x) 91 | return bn_x 92 | 93 | if self.norm: 94 | bn_x = F.normalize(bn_x) 95 | elif self.has_embedding: 96 | bn_x = F.relu(bn_x) 97 | 98 | if self.dropout > 0: 99 | bn_x = self.drop(bn_x) 100 | 101 | if 
self.num_classes > 0: 102 | prob = self.classifier(bn_x) 103 | else: 104 | return bn_x 105 | 106 | return prob 107 | 108 | # def reset_params(self): 109 | # for m in self.modules(): 110 | # if isinstance(m, nn.Conv2d): 111 | # init.kaiming_normal_(m.weight, mode='fan_out') 112 | # if m.bias is not None: 113 | # init.constant_(m.bias, 0) 114 | # elif isinstance(m, nn.BatchNorm2d): 115 | # init.constant_(m.weight, 1) 116 | # init.constant_(m.bias, 0) 117 | # elif isinstance(m, nn.BatchNorm1d): 118 | # init.constant_(m.weight, 1) 119 | # init.constant_(m.bias, 0) 120 | # elif isinstance(m, nn.Linear): 121 | # init.normal_(m.weight, std=0.001) 122 | # if m.bias is not None: 123 | # init.constant_(m.bias, 0) 124 | def load_param(self, model, model_path): 125 | param_dict = torch.load(model_path,map_location='cpu') 126 | for i in param_dict: 127 | if 'fc' in i or 'classifier' in i or 'bottleneck' in i: 128 | continue 129 | if i.startswith('base'): 130 | j = i.replace('base.','') 131 | else: 132 | j = i 133 | model.state_dict()[j].copy_(param_dict[i]) 134 | return model 135 | 136 | 137 | def resnet_ibn50a(**kwargs): 138 | return ResNetIBN('50a', **kwargs) 139 | 140 | 141 | def resnet_ibn101a(**kwargs): 142 | return ResNetIBN('101a', **kwargs) 143 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from .utils.meters import AverageMeter 4 | 5 | 6 | class ClusterContrastTrainer(object): 7 | def __init__(self, encoder, memory=None): 8 | super(ClusterContrastTrainer, self).__init__() 9 | self.encoder = encoder 10 | self.memory = memory 11 | 12 | def train(self, epoch, data_loader, optimizer, print_freq=10, train_iters=400): 13 | self.encoder.train() 14 | 15 | batch_time = AverageMeter() 16 | data_time = AverageMeter() 17 | 18 | losses = AverageMeter() 19 | 20 | end = time.time() 21 | for i in range(train_iters): 22 | # load data 23 | inputs = data_loader.next() 24 | data_time.update(time.time() - end) 25 | 26 | # process inputs 27 | inputs, labels, indexes = self._parse_data(inputs) 28 | 29 | # forward 30 | f_out = self._forward(inputs) 31 | # print("f_out shape: {}".format(f_out.shape)) 32 | # compute loss with the hybrid memory 33 | # loss = self.memory(f_out, indexes) 34 | loss = self.memory(f_out, labels) 35 | 36 | optimizer.zero_grad() 37 | loss.backward() 38 | optimizer.step() 39 | 40 | losses.update(loss.item()) 41 | 42 | # print log 43 | batch_time.update(time.time() - end) 44 | end = time.time() 45 | 46 | if (i + 1) % print_freq == 0: 47 | print('Epoch: [{}][{}/{}]\t' 48 | 'Time {:.3f} ({:.3f})\t' 49 | 'Data {:.3f} ({:.3f})\t' 50 | 'Loss {:.3f} ({:.3f})' 51 | .format(epoch, i + 1, len(data_loader), 52 | batch_time.val, batch_time.avg, 53 | data_time.val, data_time.avg, 54 | losses.val, losses.avg)) 55 | 56 | def _parse_data(self, inputs): 57 | imgs, _, pids, _, indexes = inputs 58 | return imgs.cuda(), pids.cuda(), indexes.cuda() 59 | 60 | def _forward(self, inputs): 61 | return self.encoder(inputs) 62 | 63 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 
8 |         return tensor.cpu().numpy()
9 |     elif type(tensor).__module__ != 'numpy':
10 |         raise ValueError("Cannot convert {} to numpy array"
11 |                          .format(type(tensor)))
12 |     return tensor
13 | 
14 | 
15 | def to_torch(ndarray):
16 |     if type(ndarray).__module__ == 'numpy':
17 |         return torch.from_numpy(ndarray)
18 |     elif not torch.is_tensor(ndarray):
19 |         raise ValueError("Cannot convert {} to torch tensor"
20 |                          .format(type(ndarray)))
21 |     return ndarray
22 | 
-------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/data/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .base_dataset import BaseDataset, BaseImageDataset
4 | from .preprocessor import Preprocessor
5 | 
6 | 
7 | class IterLoader:
8 |     def __init__(self, loader, length=None):
9 |         self.loader = loader
10 |         self.length = length
11 |         self.iter = None
12 | 
13 |     def __len__(self):
14 |         if self.length is not None:
15 |             return self.length
16 | 
17 |         return len(self.loader)
18 | 
19 |     def new_epoch(self):
20 |         self.iter = iter(self.loader)
21 | 
22 |     def next(self):
23 |         # restart the underlying loader when it is exhausted; catching only
24 |         # StopIteration (rather than a bare except) avoids masking real errors
25 |         if self.iter is None:
26 |             self.new_epoch()
27 |         try:
28 |             return next(self.iter)
29 |         except StopIteration:
30 |             self.new_epoch()
31 |             return next(self.iter)
-------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/data/base_dataset.py: --------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy as np
3 | 
4 | 
5 | class BaseDataset(object):
6 |     """
7 |     Base class of reid dataset
8 |     """
9 | 
10 |     def get_imagedata_info(self, data):
11 |         pids, cams = [], []
12 |         for _, pid, camid in data:
13 |             pids += [pid]
14 |             cams += [camid]
15 |         pids = set(pids)
16 |         cams = set(cams)
17 |         num_pids = len(pids)
18 |         num_cams = len(cams)
19 |         num_imgs = len(data)
20 |         return num_pids, num_imgs, num_cams
21 | 
22 |     def print_dataset_statistics(self):
23 |         raise NotImplementedError
24 | 
25 |     @property
26 |     def images_dir(self):
27 |         return None
28 | 
29 | 
30 | class BaseImageDataset(BaseDataset):
31 |     """
32 |     Base class of image reid dataset
33 |     """
34 | 
35 |     def print_dataset_statistics(self, train, query, gallery):
36 |         num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
37 |         num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
38 |         num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
39 | 
40 |         print("Dataset statistics:")
41 |         print("  ----------------------------------------")
42 |         print("  subset   | # ids | # images | # cameras")
43 |         print("  ----------------------------------------")
44 |         print("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
45 |         print("  query    | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
46 |         print("  gallery  | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
47 |         print("  ----------------------------------------")
-------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/data/preprocessor.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import os.path as osp
4 | from torch.utils.data import DataLoader, Dataset
5 | import numpy as np
6 | import random
7 | import math
8 | 
from PIL import Image 9 | 10 | 11 | class Preprocessor(Dataset): 12 | def __init__(self, dataset, root=None, transform=None): 13 | super(Preprocessor, self).__init__() 14 | self.dataset = dataset 15 | self.root = root 16 | self.transform = transform 17 | 18 | def __len__(self): 19 | return len(self.dataset) 20 | 21 | def __getitem__(self, indices): 22 | return self._get_single_item(indices) 23 | 24 | def _get_single_item(self, index): 25 | fname, pid, camid = self.dataset[index] 26 | fpath = fname 27 | if self.root is not None: 28 | fpath = osp.join(self.root, fname) 29 | 30 | img = Image.open(fpath).convert('RGB') 31 | 32 | if self.transform is not None: 33 | img = self.transform(img) 34 | 35 | return img, fname, pid, camid, index 36 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import math 4 | 5 | import numpy as np 6 | import copy 7 | import random 8 | import torch 9 | from torch.utils.data.sampler import ( 10 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 11 | WeightedRandomSampler) 12 | 13 | 14 | def No_index(a, b): 15 | assert isinstance(a, list) 16 | return [i for i, j in enumerate(a) if j != b] 17 | 18 | 19 | class RandomIdentitySampler(Sampler): 20 | def __init__(self, data_source, num_instances): 21 | self.data_source = data_source 22 | self.num_instances = num_instances 23 | self.index_dic = defaultdict(list) 24 | for index, (_, pid, _) in enumerate(data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | self.num_samples = len(self.pids) 28 | 29 | def __len__(self): 30 | return self.num_samples * self.num_instances 31 | 32 | def __iter__(self): 33 | indices = torch.randperm(self.num_samples).tolist() 34 | ret = [] 35 | for i in indices: 36 | pid = self.pids[i] 37 | t = self.index_dic[pid] 38 | if len(t) >= self.num_instances: 39 | t = np.random.choice(t, size=self.num_instances, replace=False) 40 | else: 41 | t = np.random.choice(t, size=self.num_instances, replace=True) 42 | ret.extend(t) 43 | return iter(ret) 44 | 45 | 46 | class RandomMultipleGallerySampler(Sampler): 47 | def __init__(self, data_source, num_instances=4): 48 | super().__init__(data_source) 49 | self.data_source = data_source 50 | self.index_pid = defaultdict(int) 51 | self.pid_cam = defaultdict(list) 52 | self.pid_index = defaultdict(list) 53 | self.num_instances = num_instances 54 | 55 | for index, (_, pid, cam) in enumerate(data_source): 56 | if pid < 0: 57 | continue 58 | self.index_pid[index] = pid 59 | self.pid_cam[pid].append(cam) 60 | self.pid_index[pid].append(index) 61 | 62 | self.pids = list(self.pid_index.keys()) 63 | self.num_samples = len(self.pids) 64 | 65 | def __len__(self): 66 | return self.num_samples * self.num_instances 67 | 68 | def __iter__(self): 69 | indices = torch.randperm(len(self.pids)).tolist() 70 | ret = [] 71 | 72 | for kid in indices: 73 | i = random.choice(self.pid_index[self.pids[kid]]) 74 | 75 | _, i_pid, i_cam = self.data_source[i] 76 | 77 | ret.append(i) 78 | 79 | pid_i = self.index_pid[i] 80 | cams = self.pid_cam[pid_i] 81 | index = self.pid_index[pid_i] 82 | select_cams = No_index(cams, i_cam) 83 | 84 | if select_cams: 85 | 86 | if len(select_cams) >= self.num_instances: 87 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, 
replace=False) 88 | else: 89 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True) 90 | 91 | for kk in cam_indexes: 92 | ret.append(index[kk]) 93 | 94 | else: 95 | select_indexes = No_index(index, i) 96 | if not select_indexes: 97 | continue 98 | if len(select_indexes) >= self.num_instances: 99 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 100 | else: 101 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 102 | 103 | for kk in ind_indexes: 104 | ret.append(index[kk]) 105 | 106 | return iter(ret) 107 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | """ Randomly selects a rectangle region in an image and erases its pixels. 54 | 'Random Erasing Data Augmentation' by Zhong et al. 55 | See https://arxiv.org/pdf/1708.04896.pdf 56 | Args: 57 | probability: The probability that the Random Erasing operation will be performed. 58 | sl: Minimum proportion of erased area against input image. 59 | sh: Maximum proportion of erased area against input image. 60 | r1: Minimum aspect ratio of erased area. 61 | mean: Erasing value. 
62 | """ 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) >= self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/faiss_rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 5 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 7 | """ 8 | 9 | import os, sys 10 | import time 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | import gc 14 | import faiss 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \ 20 | index_init_gpu, index_init_cpu 21 | 22 | 23 | def k_reciprocal_neigh(initial_rank, i, k1): 24 | forward_k_neigh_index = initial_rank[i,:k1+1] 25 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 26 | fi = np.where(backward_k_neigh_index==i)[0] 27 | return forward_k_neigh_index[fi] 28 | 29 | 30 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False): 31 | end = time.time() 32 | if print_flag: 33 | print('Computing jaccard distance...') 34 | 35 | ngpus = faiss.get_num_gpus() 36 | N = target_features.size(0) 37 | mat_type = np.float16 if use_float16 else np.float32 38 | search_option = 2 39 | 40 | if (search_option==0): 41 | # GPU + PyTorch CUDA Tensors (1) 42 | res = faiss.StandardGpuResources() 43 | res.setDefaultNullStreamAllDevices() 44 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1) 45 | initial_rank = initial_rank.cpu().numpy() 46 | elif (search_option==1): 47 | # GPU + PyTorch CUDA Tensors (2) 48 | res = faiss.StandardGpuResources() 49 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1)) 50 | index.add(target_features.cpu().numpy()) 51 | _, initial_rank = search_index_pytorch(index, target_features, k1) 52 | res.syncDefaultStreamCurrentDevice() 53 | initial_rank = initial_rank.cpu().numpy() 54 | elif (search_option==2): 55 | # GPU 56 | index = index_init_gpu(ngpus, target_features.size(-1)) 57 | index.add(target_features.cpu().numpy()) 58 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 59 | else: 
60 | # CPU 61 | index = index_init_cpu(target_features.size(-1)) 62 | index.add(target_features.cpu().numpy()) 63 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 64 | 65 | 66 | nn_k1 = [] 67 | nn_k1_half = [] 68 | for i in range(N): 69 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1)) 70 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2)))) 71 | 72 | V = np.zeros((N, N), dtype=mat_type) 73 | for i in range(N): 74 | k_reciprocal_index = nn_k1[i] 75 | k_reciprocal_expansion_index = k_reciprocal_index 76 | for candidate in k_reciprocal_index: 77 | candidate_k_reciprocal_index = nn_k1_half[candidate] 78 | if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)): 79 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 80 | 81 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 82 | dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t()) 83 | if use_float16: 84 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type) 85 | else: 86 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy() 87 | 88 | del nn_k1, nn_k1_half 89 | 90 | if k2 != 1: 91 | V_qe = np.zeros_like(V, dtype=mat_type) 92 | for i in range(N): 93 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0) 94 | V = V_qe 95 | del V_qe 96 | 97 | del initial_rank 98 | 99 | invIndex = [] 100 | for i in range(N): 101 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 102 | 103 | jaccard_dist = np.zeros((N, N), dtype=mat_type) 104 | for i in range(N): 105 | temp_min = np.zeros((1, N), dtype=mat_type) 106 | # temp_max = np.zeros((1,N), dtype=mat_type) 107 | indNonZero = np.where(V[i, :] != 0)[0] 108 | indImages = [] 109 | indImages = [invIndex[ind] for ind in indNonZero] 110 | for j in range(len(indNonZero)): 111 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]]+np.minimum(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]) 112 | # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 113 | 114 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 115 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6) 116 | 117 | del invIndex, V 118 | 119 | pos_bool = (jaccard_dist < 0) 120 | jaccard_dist[pos_bool] = 0.0 121 | if print_flag: 122 | print("Jaccard distance computing time cost: {}".format(time.time()-end)) 123 | 124 | return jaccard_dist 125 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | 16 | return faiss.cast_integer_to_idx_t_ptr( 17 | x.storage().data_ptr() + x.storage_offset() * 8) 18 | 19 | def search_index_pytorch(index, x, k, D=None, I=None): 20 | """call the search function of an index with pytorch tensor I/O (CPU 21 | and GPU supported)""" 22 | assert 
x.is_contiguous() 23 | n, d = x.size() 24 | assert d == index.d 25 | 26 | if D is None: 27 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 28 | else: 29 | assert D.size() == (n, k) 30 | 31 | if I is None: 32 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 33 | else: 34 | assert I.size() == (n, k) 35 | torch.cuda.synchronize() 36 | xptr = swig_ptr_from_FloatTensor(x) 37 | Iptr = swig_ptr_from_LongTensor(I) 38 | Dptr = swig_ptr_from_FloatTensor(D) 39 | index.search_c(n, xptr, 40 | k, Dptr, Iptr) 41 | torch.cuda.synchronize() 42 | return D, I 43 | 44 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 45 | metric=faiss.METRIC_L2): 46 | assert xb.device == xq.device 47 | 48 | nq, d = xq.size() 49 | if xq.is_contiguous(): 50 | xq_row_major = True 51 | elif xq.t().is_contiguous(): 52 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 53 | xq_row_major = False 54 | else: 55 | raise TypeError('matrix should be row or column-major') 56 | 57 | xq_ptr = swig_ptr_from_FloatTensor(xq) 58 | 59 | nb, d2 = xb.size() 60 | assert d2 == d 61 | if xb.is_contiguous(): 62 | xb_row_major = True 63 | elif xb.t().is_contiguous(): 64 | xb = xb.t() 65 | xb_row_major = False 66 | else: 67 | raise TypeError('matrix should be row or column-major') 68 | xb_ptr = swig_ptr_from_FloatTensor(xb) 69 | 70 | if D is None: 71 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 72 | else: 73 | assert D.shape == (nq, k) 74 | assert D.device == xb.device 75 | 76 | if I is None: 77 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 78 | else: 79 | assert I.shape == (nq, k) 80 | assert I.device == xb.device 81 | 82 | D_ptr = swig_ptr_from_FloatTensor(D) 83 | I_ptr = swig_ptr_from_LongTensor(I) 84 | 85 | faiss.bruteForceKnn(res, metric, 86 | xb_ptr, xb_row_major, nb, 87 | xq_ptr, xq_row_major, nq, 88 | d, k, D_ptr, I_ptr) 89 | 90 | return D, I 91 | 92 | def index_init_gpu(ngpus, feat_dim): 93 | flat_config = [] 94 | for i in range(ngpus): 95 | cfg = faiss.GpuIndexFlatConfig() 96 | cfg.useFloat16 = False 97 | cfg.device = i 98 | flat_config.append(cfg) 99 | 100 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 101 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 102 | index = faiss.IndexShards(feat_dim) 103 | for sub_index in indexes: 104 | index.add_shard(sub_index) 105 | index.reset() 106 | return index 107 | 108 | def index_init_cpu(feat_dim): 109 | return faiss.IndexFlatL2(feat_dim) 110 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | 
if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Source: https://github.com/zhunzhong07/person-re-ranking 5 | Created on Mon Jun 26 14:46:56 2017 6 | @author: luohao 7 | Modified by Houjing Huang, 2017-12-22. 8 | - This version accepts distance matrix instead of raw features. 9 | - The difference of `/` division between python 2 and 3 is handled. 10 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 11 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 12 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 14 | API 15 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 16 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 17 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 18 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 19 | Returns: 20 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import print_function 24 | from __future__ import division 25 | 26 | __all__ = ['re_ranking'] 27 | 28 | import numpy as np 29 | 30 | 31 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 32 | 33 | # The following naming, e.g. gallery_num, is different from outer scope. 34 | # Don't care about it. 35 | 36 | original_dist = np.concatenate( 37 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 38 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 39 | axis=0) 40 | original_dist = np.power(original_dist, 2).astype(np.float32) 41 | original_dist = np.transpose(1. 
* original_dist/np.max(original_dist,axis = 0)) 42 | V = np.zeros_like(original_dist).astype(np.float32) 43 | initial_rank = np.argsort(original_dist).astype(np.int32) 44 | 45 | query_num = q_g_dist.shape[0] 46 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 47 | all_num = gallery_num 48 | 49 | for i in range(all_num): 50 | # k-reciprocal neighbors 51 | forward_k_neigh_index = initial_rank[i,:k1+1] 52 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 53 | fi = np.where(backward_k_neigh_index==i)[0] 54 | k_reciprocal_index = forward_k_neigh_index[fi] 55 | k_reciprocal_expansion_index = k_reciprocal_index 56 | for j in range(len(k_reciprocal_index)): 57 | candidate = k_reciprocal_index[j] 58 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 59 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 60 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 61 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 62 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 63 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 64 | 65 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 66 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 67 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 68 | original_dist = original_dist[:query_num,] 69 | if k2 != 1: 70 | V_qe = np.zeros_like(V,dtype=np.float32) 71 | for i in range(all_num): 72 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 73 | V = V_qe 74 | del V_qe 75 | del initial_rank 76 | invIndex = [] 77 | for i in range(gallery_num): 78 | invIndex.append(np.where(V[:,i] != 0)[0]) 79 | 80 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 81 | 82 | 83 | for i in range(query_num): 84 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 85 | indNonZero = np.where(V[i,:] != 0)[0] 86 | indImages = [] 87 | indImages = [invIndex[ind] for ind in indNonZero] 88 | for j in range(len(indNonZero)): 89 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 90 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 91 | 92 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 93 | del original_dist 94 | del V 95 | del jaccard_dist 96 | final_dist = final_dist[:query_num,query_num:] 97 | return final_dist 98 | -------------------------------------------------------------------------------- /cluster-contrast-reid/clustercontrast/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 
'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | # checkpoint = torch.load(fpath) 34 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 35 | print("=> Loaded checkpoint '{}'".format(fpath)) 36 | return checkpoint 37 | else: 38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 39 | 40 | 41 | def copy_state_dict(state_dict, model, strip=None): 42 | tgt_state = model.state_dict() 43 | copied_names = set() 44 | for name, param in state_dict.items(): 45 | if strip is not None and name.startswith(strip): 46 | name = name[len(strip):] 47 | if name not in tgt_state: 48 | continue 49 | if isinstance(param, Parameter): 50 | param = param.data 51 | if param.size() != tgt_state[name].size(): 52 | print('mismatch:', name, param.size(), tgt_state[name].size()) 53 | continue 54 | tgt_state[name].copy_(param) 55 | copied_names.add(name) 56 | 57 | missing = set(tgt_state.keys()) - copied_names 58 | if len(missing) > 0: 59 | print("missing keys in state_dict:", missing) 60 | 61 | return model 62 | -------------------------------------------------------------------------------- /cluster-contrast-reid/examples/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import sys 7 | 8 | import torch 9 | from torch import nn 10 | from torch.backends import cudnn 11 | from torch.utils.data import DataLoader 12 | 13 | from clustercontrast import datasets 14 | from clustercontrast import models 15 | from clustercontrast.models.dsbn import convert_dsbn, convert_bn 16 | from clustercontrast.evaluators import Evaluator 17 | from clustercontrast.utils.data import transforms as T 18 | from clustercontrast.utils.data.preprocessor import Preprocessor 19 | from clustercontrast.utils.logging import Logger 20 | from clustercontrast.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict 21 | 22 | 23 | def get_data(name, data_dir, height, width, batch_size, workers): 24 | root = osp.join(data_dir, name) 25 | 26 | dataset = datasets.create(name, root) 27 | 28 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 29 | std=[0.229, 0.224, 0.225]) 30 | 31 | test_transformer = T.Compose([ 32 | T.Resize((height, width), interpolation=3), 33 | T.ToTensor(), 34 | normalizer 35 | ]) 36 | 37 | test_loader = DataLoader( 38 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 39 | root=dataset.images_dir, transform=test_transformer), 40 | batch_size=batch_size, num_workers=workers, 41 | shuffle=False, pin_memory=True) 42 | return dataset, test_loader 43 | 44 | 45 | def main(): 46 | args = parser.parse_args() 47 | 48 | if args.seed is not None: 49 | random.seed(args.seed) 50 | np.random.seed(args.seed) 51 | torch.manual_seed(args.seed) 52 | cudnn.deterministic = True 53 | 54 | main_worker(args) 55 | 56 | 57 | def main_worker(args): 58 | cudnn.benchmark = True 59 | 60 | log_dir = osp.dirname(args.resume) 61 | sys.stdout = Logger(osp.join(log_dir, 'log_test.txt')) 62 | print("==========\nArgs:{}\n==========".format(args)) 63 | 64 | # Create data loaders 65 | dataset, test_loader = get_data(args.dataset, args.data_dir, args.height, 66 | args.width, args.batch_size, args.workers) 67 | 68 | # Create model 69 | model = models.create(args.arch, pretrained=False, num_features=args.features, dropout=args.dropout, 70 | num_classes=0, 
pooling_type=args.pooling_type) 71 | if args.dsbn: 72 | print("==> Load the model with domain-specific BNs") 73 | convert_dsbn(model) 74 | 75 | # Load from checkpoint 76 | checkpoint = load_checkpoint(args.resume) 77 | copy_state_dict(checkpoint['state_dict'], model, strip='module.') 78 | 79 | if args.dsbn: 80 | print("==> Test with {}-domain BNs".format("source" if args.test_source else "target")) 81 | convert_bn(model, use_target=(not args.test_source)) 82 | 83 | model.cuda() 84 | model = nn.DataParallel(model) 85 | 86 | # Evaluator 87 | model.eval() 88 | evaluator = Evaluator(model) 89 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True, rerank=args.rerank) 90 | return 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser(description="Testing the model") 95 | # data 96 | parser.add_argument('-d', '--dataset', type=str, default='market1501') 97 | parser.add_argument('-b', '--batch-size', type=int, default=256) 98 | parser.add_argument('-j', '--workers', type=int, default=4) 99 | parser.add_argument('--height', type=int, default=256, help="input height") 100 | parser.add_argument('--width', type=int, default=128, help="input width") 101 | # model 102 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 103 | choices=models.names()) 104 | parser.add_argument('--features', type=int, default=0) 105 | parser.add_argument('--dropout', type=float, default=0) 106 | 107 | parser.add_argument('--resume', type=str, 108 | default="/media/yixuan/DATA/cluster-contrast/market-res50/logs/model_best.pth.tar", 109 | metavar='PATH') 110 | # testing configs 111 | parser.add_argument('--rerank', action='store_true', 112 | help="apply k-reciprocal re-ranking during evaluation") 113 | parser.add_argument('--dsbn', action='store_true', 114 | help="test on the model with domain-specific BN") 115 | parser.add_argument('--test-source', action='store_true', 116 | help="test on the source domain") 117 | parser.add_argument('--seed', type=int, default=1) 118 | # path 119 | working_dir = osp.dirname(osp.abspath(__file__)) 120 | parser.add_argument('--data-dir', type=str, metavar='PATH', 121 | default='/media/yixuan/Project/guangyuan/workpalces/SpCL/examples/data') 122 | parser.add_argument('--pooling-type', type=str, default='gem') 123 | parser.add_argument('--embedding_features_path', type=str, 124 | default='/media/yixuan/Project/guangyuan/workpalces/SpCL/embedding_features/mark1501_res50_ibn/') 125 | main() 126 | -------------------------------------------------------------------------------- /cluster-contrast-reid/market_uda.sh: -------------------------------------------------------------------------------- 1 | # ViT-S+ICS 2 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d market1501 --iters 200 --eps 0.6 --self-norm --use-hard --hw-ratio 2 --num-instances 8 --conv-stem -pp ../../log/transreid/msmt17/vit_small_ics_cfs_lup/transformer_120.pth --logs-dir ../../log/cluster_contrast_reid/msmt2market/vit_small_ics_cfs_lup 3 | 4 | # VIT-S 5 | CUDA_VISIBLE_DEVICES=4,5,6,7 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d market1501 --iters 200 --eps 0.6 --self-norm --use-hard --hw-ratio 2 --num-instances 8 -pp ../../log/transreid/msmt17/vit_small_cfs_lup/transformer_120.pth --logs-dir ../../log/cluster_contrast_reid/msmt2market/vit_small_cfs_lup 6 | -------------------------------------------------------------------------------- /cluster-contrast-reid/market_usl.sh: 
-------------------------------------------------------------------------------- 1 | # ViT-S+ICS 2 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d market1501 --iters 200 --eps 0.6 --self-norm --use-hard --hw-ratio 2 --num-instances 8 --conv-stem -pp ../../model/vit_small_ics_cfs_lup.pth --logs-dir ../../log/cluster_contrast_reid/market/vit_small_ics_cfs_lup 3 | 4 | # VIT-S 5 | CUDA_VISIBLE_DEVICES=4,5,6,7 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d market1501 --iters 200 --eps 0.6 --self-norm --use-hard --hw-ratio 2 --num-instances 8 -pp ../../model/vit_small_cfs_lup.pth --logs-dir ../../log/cluster_contrast_reid/market/vit_small_cfs_lup 6 | -------------------------------------------------------------------------------- /cluster-contrast-reid/msmt_uda.sh: -------------------------------------------------------------------------------- 1 | # ViT-S+ICS 2 | CUDA_VISIBLE_DEVICES=2,3 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d msmt17 --iters 200 --eps 0.7 --self-norm --use-hard --hw-ratio 2 --num-instances 8 --conv-stem -pp ../../log/transreid/market/vit_small_ics_cfs_lup/transformer_120.pth --logs-dir ../../log/cluster_contrast_reid/market2msmt/vit_small_ics_cfs_lup 3 | 4 | # VIT-S 5 | # CUDA_VISIBLE_DEVICES=0,1 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d msmt17 --iters 200 --eps 0.7 --self-norm --use-hard --hw-ratio 2 --num-instances 8 -pp ../../log/transreid/market/vit_small_cfs_lup/transformer_120.pth --logs-dir ../../log/cluster_contrast_reid/market2msmt/vit_small_cfs_lup 6 | -------------------------------------------------------------------------------- /cluster-contrast-reid/msmt_usl.sh: -------------------------------------------------------------------------------- 1 | # ViT-S+ICS 2 | CUDA_VISIBLE_DEVICES=2,3 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d msmt17 --iters 200 --eps 0.7 --self-norm --use-hard --hw-ratio 2 --num-instances 8 --conv-stem -pp ../../model/vit_small_ics_cfs_lup.pth --logs-dir ../../log/cluster_contrast_reid/msmt17/vit_small_ics_cfs_lup --eval-step 50 3 | 4 | # VIT-S 5 | # CUDA_VISIBLE_DEVICES=0,1 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d msmt17 --iters 200 --eps 0.7 --self-norm --use-hard --hw-ratio 2 --num-instances 8 -pp ../../model/vit_small_cfs_lup.pth --logs-dir ../../log/cluster_contrast_reid/msmt17/vit_small_cfs_lup --eval-step 50 6 | # CUDA_VISIBLE_DEVICES=0,1 python examples/cluster_contrast_train_usl.py -b 256 -a vit_small -d msmt17 --iters 200 --eps 0.7 --self-norm --use-hard --hw-ratio 2 --num-instances 8 -pp /mnt1/michuan.lh/log/dino/lup_filter/deit_small_251w_forget/checkpoint.pth --logs-dir ../../log/cluster_contrast_reid/msmt17/vit_small_cfs_lup/checkpoint --eval-step 50 7 | -------------------------------------------------------------------------------- /cluster-contrast-reid/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup(name='ClusterContrast', 5 | version='1.0.0', 6 | description='Cluster Contrast for Unsupervised Person Re-Identification', 7 | author='GuangYuan wang', 8 | author_email='yixuan.wgy@alibaba-inc.com', 9 | # url='', 10 | install_requires=[ 11 | 'numpy', 'torch', 'torchvision', 12 | 'six', 'h5py', 'Pillow', 'scipy', 13 | 'scikit-learn', 'metric-learn', 'faiss_gpu'], 14 | packages=find_packages(), 15 | keywords=[ 16 | 'Unsupervised Learning', 17 | 
'Contrastive Learning', 18 | 'Object Re-identification' 19 | ]) 20 | -------------------------------------------------------------------------------- /dino/README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Vision Transformers with DINO 2 | We modify the code from [DINO](https://github.com/facebookresearch/dino). You can refer to the original repo for more details. 3 | 4 | ## Training 5 | Please set `--data_path`, `--filter_path`, `--output_dir`, and `--keep_num` in the shell files. 6 | You can set `--keep_num` to 2090122 (50% of the data) or 2508146 (60%) for the CFS-filtered pre-training. 7 | Pre-training on 60% of the data achieves better performance, while 50% trades a small amount of performance for a lower computational cost. 8 | 9 | - Training ViT-S 10 | ```bash 11 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 12 | --arch vit_small \ 13 | --data_path /home/michuan.lh/datasets/LUP \ 14 | --output_dir ./log/dino/lup/vit_small_full_lup \ 15 | --height 256 --width 128 \ 16 | --crop_height 128 --crop_width 64 \ 17 | --epochs 100 \ 18 | ``` 19 | 20 | - Training ViT-S+ICS 21 | ```bash 22 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 23 | --arch ours_vit_small \ 24 | --data_path /home/michuan.lh/datasets/LUP \ 25 | --output_dir ./log/dino/lup/vit_small_ics_full_lup \ 26 | --height 256 --width 128 \ 27 | --crop_height 128 --crop_width 64 \ 28 | --epochs 100 \ 29 | 30 | ``` 31 | 32 | - Training ViT-S+ICS+CFS 33 | ```bash 34 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 35 | --arch ours_vit_small \ 36 | --data_path /home/michuan.lh/datasets/LUP \ 37 | --filter_path /mnt1/michuan.lh/workspace/transreid_pytorch/save/cfs_list.pkl \ 38 | --keep_num 2508146 \ 39 | --output_dir /mnt1/michuan.lh/log/dino/lup_filter/open_source/vit_small_ics_cfs_lup \ 40 | --height 256 --width 128 \ 41 | --crop_height 128 --crop_width 64 \ 42 | --epochs 100 \ 43 | ``` 44 | 45 | ## License 46 | This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file. 47 | 48 | ## Citation 49 | If you find this repository useful, please consider giving a star :star: and citation :t-rex:: 50 | ``` 51 | @article{caron2021emerging, 52 | title={Emerging Properties in Self-Supervised Vision Transformers}, 53 | author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand}, 54 | journal={arXiv preprint arXiv:2104.14294}, 55 | year={2021} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /dino/convert_model_dino.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | import torch 7 | 8 | if __name__ == "__main__": 9 | input = sys.argv[1] 10 | # input = '/mnt1/michuan.lh/log/moco/sysu_200ep/ckpt_0020.pth' 11 | obj = torch.load(input, map_location="cpu") 12 | # obj = obj["state_dict"] 13 | obj = obj["teacher"] 14 | 15 | 16 | newmodel = {} 17 | for k, v in obj.items(): 18 | if k.startswith('module'): 19 | k = k.replace("module.", "") 20 | if not k.startswith("backbone."): 21 | continue 22 | old_k = k 23 | k = k.replace("backbone.", "") 24 | print(old_k, "->", k) 25 | newmodel[k] = v 26 | torch.save(newmodel,sys.argv[2]) 27 | -------------------------------------------------------------------------------- /dino/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from torchvision.models.resnet import resnet50 16 | 17 | import vision_transformer as vits 18 | 19 | dependencies = ["torch", "torchvision"] 20 | 21 | 22 | def dino_vits16(pretrained=True, **kwargs): 23 | """ 24 | ViT-Small/16x16 pre-trained with DINO. 25 | Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification. 26 | """ 27 | model = vits.__dict__["vit_small"](patch_size=16, num_classes=0, **kwargs) 28 | if pretrained: 29 | state_dict = torch.hub.load_state_dict_from_url( 30 | url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", 31 | map_location="cpu", 32 | ) 33 | model.load_state_dict(state_dict, strict=True) 34 | return model 35 | 36 | 37 | def dino_vits8(pretrained=True, **kwargs): 38 | """ 39 | ViT-Small/8x8 pre-trained with DINO. 40 | Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification. 41 | """ 42 | model = vits.__dict__["vit_small"](patch_size=8, num_classes=0, **kwargs) 43 | if pretrained: 44 | state_dict = torch.hub.load_state_dict_from_url( 45 | url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", 46 | map_location="cpu", 47 | ) 48 | model.load_state_dict(state_dict, strict=True) 49 | return model 50 | 51 | 52 | def dino_vitb16(pretrained=True, **kwargs): 53 | """ 54 | ViT-Base/16x16 pre-trained with DINO. 55 | Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification. 56 | """ 57 | model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs) 58 | if pretrained: 59 | state_dict = torch.hub.load_state_dict_from_url( 60 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", 61 | map_location="cpu", 62 | ) 63 | model.load_state_dict(state_dict, strict=True) 64 | return model 65 | 66 | 67 | def dino_vitb8(pretrained=True, **kwargs): 68 | """ 69 | ViT-Base/8x8 pre-trained with DINO. 70 | Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification. 
71 | """ 72 | model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs) 73 | if pretrained: 74 | state_dict = torch.hub.load_state_dict_from_url( 75 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", 76 | map_location="cpu", 77 | ) 78 | model.load_state_dict(state_dict, strict=True) 79 | return model 80 | 81 | 82 | def dino_resnet50(pretrained=True, **kwargs): 83 | """ 84 | ResNet-50 pre-trained with DINO. 85 | Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`). 86 | """ 87 | model = resnet50(pretrained=False, **kwargs) 88 | model.fc = torch.nn.Identity() 89 | if pretrained: 90 | state_dict = torch.hub.load_state_dict_from_url( 91 | url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth", 92 | map_location="cpu", 93 | ) 94 | model.load_state_dict(state_dict, strict=False) 95 | return model 96 | -------------------------------------------------------------------------------- /dino/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A script to run multinode training with submitit. 16 | Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py 17 | """ 18 | import argparse 19 | import os 20 | import uuid 21 | from pathlib import Path 22 | 23 | import main_dino 24 | import submitit 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()]) 29 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 30 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 31 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") 32 | 33 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 34 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 35 | parser.add_argument('--comment', default="", type=str, 36 | help='Comment to pass to scheduler, e.g. priority message') 37 | return parser.parse_args() 38 | 39 | 40 | def get_shared_folder() -> Path: 41 | user = os.getenv("USER") 42 | if Path("/checkpoint/").is_dir(): 43 | p = Path(f"/checkpoint/{user}/experiments") 44 | p.mkdir(exist_ok=True) 45 | return p 46 | raise RuntimeError("No shared folder available") 47 | 48 | 49 | def get_init_file(): 50 | # Init file must not exist, but it's parent dir must exist. 
51 | os.makedirs(str(get_shared_folder()), exist_ok=True) 52 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 53 | if init_file.exists(): 54 | os.remove(str(init_file)) 55 | return init_file 56 | 57 | 58 | class Trainer(object): 59 | def __init__(self, args): 60 | self.args = args 61 | 62 | def __call__(self): 63 | import main_dino 64 | 65 | self._setup_gpu_args() 66 | main_dino.train_dino(self.args) 67 | 68 | def checkpoint(self): 69 | import os 70 | import submitit 71 | 72 | self.args.dist_url = get_init_file().as_uri() 73 | print("Requeuing ", self.args) 74 | empty_trainer = type(self)(self.args) 75 | return submitit.helpers.DelayedSubmission(empty_trainer) 76 | 77 | def _setup_gpu_args(self): 78 | import submitit 79 | from pathlib import Path 80 | 81 | job_env = submitit.JobEnvironment() 82 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 83 | self.args.gpu = job_env.local_rank 84 | self.args.rank = job_env.global_rank 85 | self.args.world_size = job_env.num_tasks 86 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 87 | 88 | 89 | def main(): 90 | args = parse_args() 91 | if args.output_dir == "": 92 | args.output_dir = get_shared_folder() / "%j" 93 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 94 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) 95 | 96 | num_gpus_per_node = args.ngpus 97 | nodes = args.nodes 98 | timeout_min = args.timeout 99 | 100 | partition = args.partition 101 | kwargs = {} 102 | if args.use_volta32: 103 | kwargs['slurm_constraint'] = 'volta32gb' 104 | if args.comment: 105 | kwargs['slurm_comment'] = args.comment 106 | 107 | executor.update_parameters( 108 | mem_gb=40 * num_gpus_per_node, 109 | gpus_per_node=num_gpus_per_node, 110 | tasks_per_node=num_gpus_per_node, # one task per GPU 111 | cpus_per_task=10, 112 | nodes=nodes, 113 | timeout_min=timeout_min, # max is 60 * 72 114 | # Below are cluster dependent parameters 115 | slurm_partition=partition, 116 | slurm_signal_delay_s=120, 117 | **kwargs 118 | ) 119 | 120 | executor.update_parameters(name="dino") 121 | 122 | args.dist_url = get_init_file().as_uri() 123 | 124 | trainer = Trainer(args) 125 | job = executor.submit(trainer) 126 | 127 | print(f"Submitted job_id: {job.job_id}") 128 | print(f"Logs and checkpoints will be saved at: {args.output_dir}") 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /dino/script/vit-b_ics_cfs_lup.sh: -------------------------------------------------------------------------------- 1 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 2 | --arch ours_vit_base \ 3 | --data_path /home/michuan.lh/datasets/LUP \ 4 | --output_dir ../../log/dino/vit-b_ics_cfs_lup \ 5 | --filter_path ../../cfs_list.pkl \ 6 | --keep_num 2508146 \ 7 | --height 256 --width 128 \ 8 | --crop_height 128 --crop_width 64 \ 9 | --epochs 100 \ 10 | -------------------------------------------------------------------------------- /dino/script/vit-s_cfs_lup.sh: -------------------------------------------------------------------------------- 1 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 2 | --arch vit_small \ 3 | --data_path /home/michuan.lh/datasets/LUP \ 4 | --filter_path ../../cfs_list.pkl \ 5 | --keep_num 2508146 \ 6 | --output_dir ./log/dino/vit_small_cfs_lup \ 7 | --height 256 --width 128 \ 8 | --crop_height 128 
--crop_width 64 \ 9 | --epochs 100 \ 10 | -------------------------------------------------------------------------------- /dino/script/vit-s_full_lup.sh: -------------------------------------------------------------------------------- 1 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 2 | --arch vit_small \ 3 | --data_path /home/michuan.lh/datasets/LUP \ 4 | --output_dir ./log/dino/lup/vit_small_full_lup \ 5 | --height 256 --width 128 \ 6 | --crop_height 128 --crop_width 64 \ 7 | --epochs 100 \ 8 | -------------------------------------------------------------------------------- /dino/script/vit-s_ics_cfs_lup.sh: -------------------------------------------------------------------------------- 1 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 2 | --arch ours_vit_small \ 3 | --data_path /home/michuan.lh/datasets/LUP \ 4 | --filter_path ../../cfs_list.pkl \ 5 | --keep_num 2508146 \ 6 | --output_dir ./log/dino/vit_small_ics_cfs_lup \ 7 | --height 256 --width 128 \ 8 | --crop_height 128 --crop_width 64 \ 9 | --epochs 100 \ 10 | -------------------------------------------------------------------------------- /dino/script/vit-s_ics_full_lup.sh: -------------------------------------------------------------------------------- 1 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 2 | --arch ours_vit_small \ 3 | --data_path /home/michuan.lh/datasets/LUP \ 4 | --output_dir ./log/dino/lup/vit_small_ics_full_lup \ 5 | --height 256 --width 128 \ 6 | --crop_height 128 --crop_width 64 \ 7 | --epochs 100 \ 8 | -------------------------------------------------------------------------------- /dino/script/vit-s_imagenet.sh: -------------------------------------------------------------------------------- 1 | python -W ignore -m torch.distributed.launch --nproc_per_node=8 main_dino.py \ 2 | --arch vit_small \ 3 | --data_path /home/michuan.lh/datasets/imagenet/train \ 4 | --output_dir ./log/dino/vit-s_imagenet \ 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | faiss-gpu==1.7.1.post2 2 | h5py==2.10.0 3 | lmdb==1.1.1 4 | metric-learn==0.6.2 5 | numpy==1.19.5 6 | opencv-python==4.5.1.48 7 | opencv-python-headless==4.5.1.48 8 | pandas==1.1.5 9 | pytorch-ignite==0.1.2 10 | PyYAML==5.4.1 11 | scikit-image==0.16.2 12 | scikit-learn==0.23.1 13 | scipy==1.2.0 14 | six==1.15.0 15 | sklearn==0.0 16 | timm==0.3.4 17 | torch==1.7.1 18 | torchtoolbox==0.1.5 19 | torchvision==0.8.2 20 | wrapt==1.12.1 21 | yacs==0.1.7 22 | -------------------------------------------------------------------------------- /transreid_pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 heshuting555 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /transreid_pytorch/README.md: -------------------------------------------------------------------------------- 1 | # TransReID: Transformer-based Object Re-Identification [[pdf]](https://openaccess.thecvf.com/content/ICCV2021/papers/He_TransReID_Transformer-Based_Object_Re-Identification_ICCV_2021_paper.pdf) 2 | We modify the code from [TransReID](https://github.com/damo-cv/TransReID). You can refer to the original repo for more details. 3 | 4 | ## Requirements 5 | 6 | ### Installation 7 | 8 | We use torch 1.7.1 / torchvision 0.8.2 / timm 0.3.4 / CUDA 10.1 and 16G or 32G V100 GPUs 9 | for training and evaluation; these pinned versions match the top-level `requirements.txt`. 10 | Note that we use `torch.cuda.amp` to accelerate training, which requires PyTorch >= 1.6. 11 | 
12 | ### Prepare Pre-trained Models 13 | Please download the pre-trained models and put them in a folder of your choice. 14 | 15 | ## Training 16 | 17 | We use 1 GPU for training. Please modify `MODEL.PRETRAIN_PATH` and `OUTPUT_DIR` in the config file. 18 | 19 | ```bash 20 | python train.py --config_file configs/market/vit_small_ics.yml 21 | ``` 22 | 23 | You can also speed up training with 4-GPU distributed training, but performance may drop by 0.1~0.2% mAP; the launch command is shown after the AMP sketch below. 
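As noted in the Installation section, training relies on `torch.cuda.amp`. Below is a minimal sketch of how such a mixed-precision step is typically wired up; the `model`, `criterion`, and `optimizer` names are generic placeholders for illustration, not the actual TransReID training loop (which lives under `processor/`).

```python
import torch
from torch.cuda.amp import GradScaler, autocast  # available since PyTorch 1.6


def amp_train_step(model, criterion, optimizer, scaler, images, targets):
    optimizer.zero_grad()
    with autocast():  # run the forward pass in mixed precision
        outputs = model(images)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscale gradients, then take the optimizer step
    scaler.update()                # adapt the loss scale for the next iteration
    return loss.item()


# Usage: create one GradScaler for the whole run.
# scaler = GradScaler()
# for images, targets in train_loader:
#     amp_train_step(model, criterion, optimizer, scaler, images.cuda(), targets.cuda())
```

The corresponding 4-GPU launch command: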
24 | 25 | ```bash 26 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 66666 train.py --config_file configs/market/vit_small_ics_ddp.yml 27 | ``` 28 | 29 | ## Evaluation 30 | 31 | ```bash 32 | python test.py --config_file 'choose which config to test' MODEL.DEVICE_ID "('your device id')" TEST.WEIGHT "('your path of trained checkpoints')" 33 | ``` 34 | 35 | **Some examples:** 36 | 37 | ```bash 38 | # Market 39 | python test.py --config_file configs/market/vit_small_ics.yml MODEL.DEVICE_ID "('0')" TEST.WEIGHT 'XXXX/transformer_120.pth' 40 | ``` 41 | 42 | ## Citation 43 | 44 | If you find this code useful for your research, please cite our paper 45 | 46 | ``` 47 | @InProceedings{He_2021_ICCV, 48 | author = {He, Shuting and Luo, Hao and Wang, Pichao and Wang, Fan and Li, Hao and Jiang, Wei}, 49 | title = {TransReID: Transformer-Based Object Re-Identification}, 50 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 51 | month = {October}, 52 | year = {2021}, 53 | pages = {15013-15022} 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /transreid_pytorch/config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | from .defaults import _C as cfg_test 9 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/market/debug.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '/mnt1/michuan.lh/log/dino/lup_filter/open_source/vit_small_ics_cfs_lup/checkpoint.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'deit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: ('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0032 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 256 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/debug' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/market/vit_base_baseline.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/michuan.lh/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 
'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('market1501') 24 | ROOT_DIR: ('/home/michuan.lh/datasets') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'cosine' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 20 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../../log/transreid/market/vit_base_384' 54 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/market/vit_base_ics_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '../../model/vit_base_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | # DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [384, 128] 17 | SIZE_TEST: [384, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: ('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0004 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 64 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/market/vit_base_ics_cfs_lup_384' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/market/vit_small.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '../../model/vit_small_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | # STEM_CONV: True # False for vanilla ViT-S 13 | # DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: 
('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0004 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 64 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/market/vit_small_cfs_lup' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/market/vit_small_ics.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '../../model/vit_small_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | # DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: ('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0004 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 64 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/market/vit_small_ics_cfs_lup' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/market/vit_small_ics_ddp.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '../../model/vit_small_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'deit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: ('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0032 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 256 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | 
WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/market/vit_small_ics_cfs_lup_ddp' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/msmt17/vit_base_baseline.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/michuan.lh/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('msmt17') 24 | ROOT_DIR: ('/home/michuan.lh/datasets') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'cosine' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 20 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../../log/transreid/msmt17/vit_base_384' 54 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/msmt17/vit_base_ics_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '../../model/vit_base_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | # DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [384, 128] 17 | SIZE_TEST: [384, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('msmt17') 26 | ROOT_DIR: ('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0004 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 64 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/msmt17/vit_base_ics_cfs_lup_384' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/msmt17/vit_small.yml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '../../model/vit_small_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | # STEM_CONV: True # False for vanilla ViT-S 13 | # DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('msmt17') 26 | ROOT_DIR: ('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0004 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 64 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/msmt17/vit_small_cfs_lup' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/msmt17/vit_small_ics.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '../../model/vit_small_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | # DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('msmt17') 26 | ROOT_DIR: ('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0004 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 64 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/msmt17/vit_small_ics_cfs_lup' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/configs/msmt17/vit_small_ics_ddp.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '../../model/vit_small_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'deit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: 
[16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | DIST_TRAIN: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('msmt17') 26 | ROOT_DIR: ('/home/michuan.lh/datasets') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.0032 37 | WARMUP_EPOCHS: 20 38 | IMS_PER_BATCH: 256 39 | WARMUP_METHOD: 'cosine' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 120 42 | LOG_PERIOD: 20 43 | EVAL_PERIOD: 120 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | 48 | TEST: 49 | EVAL: True 50 | IMS_PER_BATCH: 256 51 | RE_RANKING: False 52 | WEIGHT: '' 53 | NECK_FEAT: 'before' 54 | FEAT_NORM: 'yes' 55 | 56 | OUTPUT_DIR: '../../log/transreid/msmt17/vit_small_ics_cfs_lup_ddp' 57 | -------------------------------------------------------------------------------- /transreid_pytorch/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_dataloader import make_dataloader -------------------------------------------------------------------------------- /transreid_pytorch/datasets/bases.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFile 2 | 3 | from torch.utils.data import Dataset 4 | import os.path as osp 5 | import random 6 | import torch 7 | import logging 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | 10 | 11 | def read_image(img_path): 12 | """Keep reading the image until it succeeds. 13 | This avoids IOError caused by heavy IO load.""" 14 | got_img = False 15 | if not osp.exists(img_path): 16 | raise IOError("{} does not exist".format(img_path)) 17 | while not got_img: 18 | try: 19 | img = Image.open(img_path).convert('RGB') 20 | got_img = True 21 | except IOError: 22 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 23 | pass 24 | return img 25 | 26 | 27 | class BaseDataset(object): 28 | """ 29 | Base class of reid dataset 30 | """ 31 | 32 | def get_imagedata_info(self, data): 33 | pids, cams, tracks = [], [], [] 34 | 35 | for _, pid, camid, trackid in data: 36 | pids += [pid] 37 | cams += [camid] 38 | tracks += [trackid] 39 | pids = set(pids) 40 | cams = set(cams) 41 | tracks = set(tracks) 42 | num_pids = len(pids) 43 | num_cams = len(cams) 44 | num_imgs = len(data) 45 | num_views = len(tracks) 46 | return num_pids, num_imgs, num_cams, num_views 47 | 48 | def print_dataset_statistics(self): 49 | raise NotImplementedError 50 | 51 | 52 | class BaseImageDataset(BaseDataset): 53 | """ 54 | Base class of image reid dataset 55 | """ 56 | 57 | def print_dataset_statistics(self, train, query, gallery): 58 | num_train_pids, num_train_imgs, num_train_cams, num_train_views = self.get_imagedata_info(train) 59 | num_query_pids, num_query_imgs, num_query_cams, num_train_views = self.get_imagedata_info(query) 60 | num_gallery_pids, num_gallery_imgs, num_gallery_cams, num_train_views = self.get_imagedata_info(gallery) 61 | logger = logging.getLogger("transreid.check") 62 | logger.info("Dataset statistics:") 63 | logger.info(" ----------------------------------------") 64 | logger.info(" subset | # ids | # images | # cameras") 65 | logger.info(" ----------------------------------------") 66 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 67 | logger.info(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 68 | logger.info(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 69 | logger.info(" ----------------------------------------") 70 | 71 | class ImageDataset(Dataset): 72 | def __init__(self, dataset, transform=None): 73 | self.dataset = dataset 74 | self.transform = transform 75 | 76 | def __len__(self): 77 | return len(self.dataset) 78 | 79 | def __getitem__(self, index): 80 | img_path, pid, camid, trackid = self.dataset[index] 81 | img = read_image(img_path) 82 | 83 | if self.transform is not None: 84 | img = self.transform(img) 85 | 86 | return img, pid, camid, trackid, img_path 87 | # return img, pid, camid, trackid,img_path.split('/')[-1] 88 | -------------------------------------------------------------------------------- /transreid_pytorch/datasets/make_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | from torch.utils.data import DataLoader 4 | 5 | from .bases import ImageDataset 6 | from timm.data.random_erasing import RandomErasing 7 | from .sampler import RandomIdentitySampler, RandomIdentitySampler_IdUniform 8 | from .market1501 import Market1501 9 | from .msmt17 import MSMT17 10 | from .sampler_ddp import RandomIdentitySampler_DDP 11 | import torch.distributed as dist 12 | from .mm import MM 13 | __factory = { 14 | 'market1501': Market1501, 15 | 'msmt17': MSMT17, 16 | 'mm': MM, 17 | } 18 | 19 | def train_collate_fn(batch): 20 | """ 21 | # collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果 22 | """ 23 | imgs, pids, camids, viewids , _ = zip(*batch) 24 | pids = torch.tensor(pids, dtype=torch.int64) 25 | viewids = torch.tensor(viewids, dtype=torch.int64) 26 | camids = torch.tensor(camids, dtype=torch.int64) 27 | return torch.stack(imgs, dim=0), pids, camids, viewids, 28 | 29 | def val_collate_fn(batch): 30 | 
imgs, pids, camids, viewids, img_paths = zip(*batch) 31 | viewids = torch.tensor(viewids, dtype=torch.int64) 32 | camids_batch = torch.tensor(camids, dtype=torch.int64) 33 | return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths 34 | 35 | def make_dataloader(cfg): 36 | train_transforms = T.Compose([ 37 | T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3), 38 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 39 | T.Pad(cfg.INPUT.PADDING), 40 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 41 | T.ToTensor(), 42 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD), 43 | RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'), 44 | ]) 45 | 46 | val_transforms = T.Compose([ 47 | T.Resize(cfg.INPUT.SIZE_TEST), 48 | T.ToTensor(), 49 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 50 | ]) 51 | 52 | num_workers = cfg.DATALOADER.NUM_WORKERS 53 | 54 | if cfg.DATASETS.NAMES == 'ourapi': 55 | dataset = OURAPI(root_train=cfg.DATASETS.ROOT_TRAIN_DIR, root_val=cfg.DATASETS.ROOT_VAL_DIR, config=cfg) 56 | else: 57 | dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR) 58 | 59 | train_set = ImageDataset(dataset.train, train_transforms) 60 | train_set_normal = ImageDataset(dataset.train, val_transforms) 61 | num_classes = dataset.num_train_pids 62 | cam_num = dataset.num_train_cams 63 | view_num = dataset.num_train_vids 64 | 65 | if cfg.DATALOADER.SAMPLER in ['softmax_triplet', 'img_triplet']: 66 | print('using img_triplet sampler') 67 | if cfg.MODEL.DIST_TRAIN: 68 | print('DIST_TRAIN START') 69 | mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size() 70 | data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) 71 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True) 72 | train_loader = torch.utils.data.DataLoader( 73 | train_set, 74 | num_workers=num_workers, 75 | batch_sampler=batch_sampler, 76 | collate_fn=train_collate_fn, 77 | pin_memory=True, 78 | ) 79 | else: 80 | train_loader = DataLoader( 81 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 82 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 83 | num_workers=num_workers, collate_fn=train_collate_fn 84 | ) 85 | elif cfg.DATALOADER.SAMPLER == 'softmax': 86 | print('using softmax sampler') 87 | train_loader = DataLoader( 88 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 89 | collate_fn=train_collate_fn 90 | ) 91 | elif cfg.DATALOADER.SAMPLER in ['id_triplet', 'id']: 92 | print('using ID sampler') 93 | train_loader = DataLoader( 94 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 95 | sampler=RandomIdentitySampler_IdUniform(dataset.train, cfg.DATALOADER.NUM_INSTANCE), 96 | num_workers=num_workers, collate_fn=train_collate_fn, drop_last = True, 97 | ) 98 | else: 99 | print('unsupported sampler! 
expected softmax or triplet but got {}'.format(cfg.DATALOADER.SAMPLER)) 100 | 101 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 102 | 103 | val_loader = DataLoader( 104 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 105 | collate_fn=val_collate_fn 106 | ) 107 | train_loader_normal = DataLoader( 108 | train_set_normal, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 109 | collate_fn=val_collate_fn 110 | ) 111 | return train_loader, train_loader_normal, val_loader, len(dataset.query), num_classes, cam_num, view_num 112 | -------------------------------------------------------------------------------- /transreid_pytorch/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | from collections import defaultdict 14 | import pickle 15 | class Market1501(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'market1501' 27 | 28 | def __init__(self, root='', verbose=True, pid_begin = 0, **kwargs): 29 | super(Market1501, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 32 | self.query_dir = osp.join(self.dataset_dir, 'query') 33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 34 | 35 | self._check_before_run() 36 | self.pid_begin = pid_begin 37 | train = self._process_dir(self.train_dir, relabel=True) 38 | query = self._process_dir(self.query_dir, relabel=False) 39 | gallery = self._process_dir(self.gallery_dir, relabel=False) 40 | 41 | if verbose: 42 | print("=> Market1501 loaded") 43 | self.print_dataset_statistics(train, query, gallery) 44 | 45 | self.train = train 46 | self.query = query 47 | self.gallery = gallery 48 | 49 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 52 | 53 | def _check_before_run(self): 54 | """Check if all files are available before going deeper""" 55 | if not osp.exists(self.dataset_dir): 56 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 57 | if not osp.exists(self.train_dir): 58 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 59 | if not osp.exists(self.query_dir): 60 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 61 | if not osp.exists(self.gallery_dir): 62 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 63 | 64 | def _process_dir(self, dir_path, relabel=False): 65 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 66 | pattern = re.compile(r'([-\d]+)_c(\d)') 67 | 68 | pid_container = set() 69 | for img_path in sorted(img_paths): 70 | pid, _ = map(int, pattern.search(img_path).groups()) 71
| if pid == -1: continue # junk images are just ignored 72 | pid_container.add(pid) 73 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 74 | dataset = [] 75 | for img_path in sorted(img_paths): 76 | pid, camid = map(int, pattern.search(img_path).groups()) 77 | if pid == -1: continue # junk images are just ignored 78 | assert 0 <= pid <= 1501 # pid == 0 means background 79 | assert 1 <= camid <= 6 80 | camid -= 1 # index starts from 0 81 | if relabel: pid = pid2label[pid] 82 | 83 | dataset.append((img_path, self.pid_begin + pid, camid, 1)) 84 | return dataset 85 | -------------------------------------------------------------------------------- /transreid_pytorch/datasets/mm.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import os 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | from collections import defaultdict 14 | class MM(BaseImageDataset): 15 | def __init__(self, root='', verbose=True, pid_begin = 0, **kwargs): 16 | super(MM, self).__init__() 17 | self.dataset_dir = osp.join(root, 'market1501') 18 | self.query_dir = osp.join(self.dataset_dir, 'query') 19 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 20 | 21 | market_dir = '/home/michuan.lh/datasets/market1501/bounding_box_train' 22 | msmt_dir = '/home/michuan.lh/datasets/MSMT17/train' 23 | train = self.process_msmt(msmt_dir) 24 | train.extend(self.process_label(market_dir,b_pid=1041,b_camid=15)) 25 | 26 | query = self._process_dir(self.query_dir, relabel=False) 27 | gallery = self._process_dir(self.gallery_dir, relabel=False) 28 | 29 | if verbose: 30 | print("=> MM loaded") 31 | self.print_dataset_statistics(train, query, gallery) 32 | 33 | self.train = train 34 | self.query = query 35 | self.gallery = gallery 36 | 37 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 38 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 39 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 40 | 41 | 42 | def _process_dir(self, dir_path, relabel=False): 43 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 44 | pattern = re.compile(r'([-\d]+)_c(\d)') 45 | 46 | pid_container = set() 47 | for img_path in sorted(img_paths): 48 | pid, _ = map(int, pattern.search(img_path).groups()) 49 | if pid == -1: continue # junk images are just ignored 50 | pid_container.add(pid) 51 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 52 | dataset = [] 53 | for img_path in sorted(img_paths): 54 | pid, camid = map(int, pattern.search(img_path).groups()) 55 | if pid == -1: continue # junk images are just ignored 56 | assert 0 <= pid <= 1501 # pid == 0 means background 57 | assert 1 <= camid <= 6 58 | camid -= 1 # index starts from 0 59 | if relabel: pid = pid2label[pid] 60 | 61 | dataset.append((img_path, pid, camid, 1)) 62 | return dataset 63 | 64 | def process_label(self, root_dir, b_pid=0, b_camid=0): 65 | img_paths = os.listdir(root_dir) 66 | pattern = re.compile(r'([-\d]+)_c(\d)') 67 | pid_container = set() 68 | camid_container = set() 69 | EXTs = ('.jpg', '.png', '.jpeg', '.bmp', '.ppm') 70 | for img_path in img_paths: 71 | if os.path.splitext(img_path)[-1] not in EXTs: continue 
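# Market-1501 style filenames encode identity and camera as '<pid>_c<camid>_...'
# (e.g. '0002_c3s1_000151_01.jpg' parses to pid=2, camid=3); pid == -1 marks
# junk images and is skipped. The pids collected here are remapped to
# contiguous labels below and shifted by b_pid (and camids by b_camid) so the
# Market-1501 identities do not collide with the MSMT17 identities already in
# the combined training list.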
72 | pid, camid = map(int, pattern.search(img_path).groups()) 73 | if pid == -1: continue # junk images are just ignored 74 | pid_container.add(pid) 75 | camid_container.add(camid) 76 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 77 | dataset = [] 78 | for img_path in sorted(img_paths): 79 | if os.path.splitext(img_path)[-1] not in EXTs: continue 80 | pid, camid = map(int, pattern.search(img_path).groups()) 81 | camid -= 1 # index starts from 0 82 | if pid == -1: continue # junk images are just ignored 83 | pid = pid2label[pid] 84 | dataset.append((os.path.join(root_dir,img_path), b_pid + pid, b_camid+camid, 1)) 85 | return dataset 86 | 87 | def process_msmt(self, msmt_dir): 88 | list_path = os.path.join(msmt_dir,'../list_train.txt') 89 | with open(list_path, 'r') as txt: 90 | lines = txt.readlines() 91 | dataset = [] 92 | pid_container = set() 93 | for img_idx, img_info in enumerate(lines): 94 | img_path, pid = img_info.split(' ') 95 | pid = int(pid) # no need to relabel 96 | camid = int(img_path.split('_')[2]) 97 | img_path = os.path.join(msmt_dir, img_path) 98 | dataset.append((img_path, pid, camid-1, 1)) 99 | return dataset 100 | 101 | 102 | -------------------------------------------------------------------------------- /transreid_pytorch/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import re 4 | 5 | import os.path as osp 6 | 7 | from .bases import BaseImageDataset 8 | 9 | 10 | class MSMT17(BaseImageDataset): 11 | """ 12 | MSMT17 13 | 14 | Reference: 15 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 16 | 17 | URL: http://www.pkuvmc.com/publications/msmt17.html 18 | 19 | Dataset statistics: 20 | # identities: 4101 21 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 22 | # cameras: 15 23 | """ 24 | dataset_dir = 'MSMT17' 25 | 26 | def __init__(self, root='', verbose=True, pid_begin=0, **kwargs): 27 | super(MSMT17, self).__init__() 28 | self.pid_begin = pid_begin 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | self.train_dir = osp.join(self.dataset_dir, 'train') 31 | self.test_dir = osp.join(self.dataset_dir, 'test') 32 | self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt') 33 | self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt') 34 | self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt') 35 | self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt') 36 | 37 | self._check_before_run() 38 | train = self._process_dir(self.train_dir, self.list_train_path) 39 | val = self._process_dir(self.train_dir, self.list_val_path) 40 | train += val 41 | query = self._process_dir(self.test_dir, self.list_query_path) 42 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) 43 | if verbose: 44 | print("=> MSMT17 loaded") 45 | self.print_dataset_statistics(train, query, gallery) 46 | 47 | self.train = train 48 | self.query = query 49 | self.gallery = gallery 50 | 51 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 52 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 53 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 54 | def _check_before_run(self): 55 | """Check if all files are available before going deeper""" 56 | if not 
osp.exists(self.dataset_dir): 57 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 58 | if not osp.exists(self.train_dir): 59 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 60 | if not osp.exists(self.test_dir): 61 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 62 | 63 | def _process_dir(self, dir_path, list_path): 64 | with open(list_path, 'r') as txt: 65 | lines = txt.readlines() 66 | dataset = [] 67 | pid_container = set() 68 | cam_container = set() 69 | for img_idx, img_info in enumerate(lines): 70 | img_path, pid = img_info.split(' ') 71 | pid = int(pid) # no need to relabel 72 | camid = int(img_path.split('_')[2]) 73 | img_path = osp.join(dir_path, img_path) 74 | dataset.append((img_path, self.pid_begin + pid, camid-1, 1)) 75 | pid_container.add(pid) 76 | cam_container.add(camid) 77 | print(cam_container, 'cam_container') 78 | # check if pid starts from 0 and increments with 1 79 | for idx, pid in enumerate(pid_container): 80 | assert idx == pid, "See code comment for explanation" 81 | return dataset 82 | -------------------------------------------------------------------------------- /transreid_pytorch/datasets/preprocessing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | 5 | class RandomErasing(object): 6 | """ Randomly selects a rectangle region in an image and erases its pixels. 7 | 'Random Erasing Data Augmentation' by Zhong et al. 8 | See https://arxiv.org/pdf/1708.04896.pdf 9 | Args: 10 | probability: The probability that the Random Erasing operation will be performed. 11 | sl: Minimum proportion of erased area against input image. 12 | sh: Maximum proportion of erased area against input image. 13 | r1: Minimum aspect ratio of erased area. 14 | mean: Erasing value. 15 | """ 16 | 17 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 18 | self.probability = probability 19 | self.mean = mean 20 | self.sl = sl 21 | self.sh = sh 22 | self.r1 = r1 23 | 24 | def __call__(self, img): 25 | 26 | if random.uniform(0, 1) >= self.probability: 27 | return img 28 | 29 | for attempt in range(100): 30 | area = img.size()[1] * img.size()[2] 31 | 32 | target_area = random.uniform(self.sl, self.sh) * area 33 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w < img.size()[2] and h < img.size()[1]: 39 | x1 = random.randint(0, img.size()[1] - h) 40 | y1 = random.randint(0, img.size()[2] - w) 41 | if img.size()[0] == 3: 42 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 43 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 44 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 45 | else: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | return img 48 | 49 | return img 50 | 51 | -------------------------------------------------------------------------------- /transreid_pytorch/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import torch 5 | import random 6 | import numpy as np 7 | 8 | class RandomIdentitySampler(Sampler): 9 | """ 10 | Randomly sample N identities, then for each identity, 11 | randomly sample K instances, therefore batch size is N*K. 12 | Args: 13 | - data_source (list): list of (img_path, pid, camid, trackid).
14 | - num_instances (int): number of instances per identity in a batch. 15 | - batch_size (int): number of examples in a batch. 16 | """ 17 | 18 | def __init__(self, data_source, batch_size, num_instances): 19 | self.data_source = data_source 20 | self.batch_size = batch_size 21 | self.num_instances = num_instances 22 | self.num_pids_per_batch = self.batch_size // self.num_instances 23 | self.index_dic = defaultdict(list) #dict with list value 24 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 25 | for index, (_, pid, _, _) in enumerate(self.data_source): 26 | self.index_dic[pid].append(index) 27 | self.pids = list(self.index_dic.keys()) 28 | 29 | # estimate number of examples in an epoch 30 | self.length = 0 31 | for pid in self.pids: 32 | idxs = self.index_dic[pid] 33 | num = len(idxs) 34 | if num < self.num_instances: 35 | num = self.num_instances 36 | self.length += num - num % self.num_instances 37 | 38 | def __iter__(self): 39 | batch_idxs_dict = defaultdict(list) 40 | 41 | for pid in self.pids: 42 | idxs = copy.deepcopy(self.index_dic[pid]) 43 | if len(idxs) < self.num_instances: 44 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 45 | random.shuffle(idxs) 46 | batch_idxs = [] 47 | for idx in idxs: 48 | batch_idxs.append(idx) 49 | if len(batch_idxs) == self.num_instances: 50 | batch_idxs_dict[pid].append(batch_idxs) 51 | batch_idxs = [] 52 | 53 | avai_pids = copy.deepcopy(self.pids) 54 | final_idxs = [] 55 | 56 | while len(avai_pids) >= self.num_pids_per_batch: 57 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 58 | for pid in selected_pids: 59 | batch_idxs = batch_idxs_dict[pid].pop(0) 60 | final_idxs.extend(batch_idxs) 61 | if len(batch_idxs_dict[pid]) == 0: 62 | avai_pids.remove(pid) 63 | 64 | return iter(final_idxs) 65 | 66 | def __len__(self): 67 | return self.length 68 | 69 | # New add by gu 70 | class RandomIdentitySampler_IdUniform(Sampler): 71 | """ 72 | Randomly sample N identities, then for each identity, 73 | randomly sample K instances, therefore batch size is N*K. 74 | 75 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 76 | 77 | Args: 78 | data_source (Dataset): dataset to sample from. 79 | num_instances (int): number of instances per identity. 
80 | """ 81 | def __init__(self, data_source, num_instances): 82 | self.data_source = data_source 83 | self.num_instances = num_instances 84 | self.index_dic = defaultdict(list) 85 | for index, item in enumerate(data_source): 86 | pid = item[1] 87 | self.index_dic[pid].append(index) 88 | self.pids = list(self.index_dic.keys()) 89 | self.num_identities = len(self.pids) 90 | 91 | def __iter__(self): 92 | indices = torch.randperm(self.num_identities) 93 | ret = [] 94 | for i in indices: 95 | pid = self.pids[i] 96 | t = self.index_dic[pid] 97 | replace = False if len(t) >= self.num_instances else True 98 | t = np.random.choice(t, size=self.num_instances, replace=replace) 99 | ret.extend(t) 100 | return iter(ret) 101 | 102 | def __len__(self): 103 | return self.num_identities * self.num_instances 104 | -------------------------------------------------------------------------------- /transreid_pytorch/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import math 8 | import random 9 | import cv2 10 | from collections import deque 11 | from PIL import Image, ImageFilter 12 | import torch 13 | 14 | class RandomErasing(object): 15 | """ Randomly selects a rectangle region in an image and erases its pixels. 16 | 'Random Erasing Data Augmentation' by Zhong et al. 17 | See https://arxiv.org/pdf/1708.04896.pdf 18 | Args: 19 | probability: The probability that the Random Erasing operation will be performed. 20 | sl: Minimum proportion of erased area against input image. 21 | sh: Maximum proportion of erased area against input image. 22 | r1: Minimum aspect ratio of erased area. 23 | mean: Erasing value. 24 | """ 25 | 26 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 27 | self.probability = probability 28 | self.mean = mean 29 | self.sl = sl 30 | self.sh = sh 31 | self.r1 = r1 32 | 33 | def __call__(self, img): 34 | 35 | if random.uniform(0, 1) >= self.probability: 36 | return img 37 | 38 | for attempt in range(100): 39 | area = img.size()[1] * img.size()[2] 40 | 41 | target_area = random.uniform(self.sl, self.sh) * area 42 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 43 | 44 | h = int(round(math.sqrt(target_area * aspect_ratio))) 45 | w = int(round(math.sqrt(target_area / aspect_ratio))) 46 | 47 | if w < img.size()[2] and h < img.size()[1]: 48 | x1 = random.randint(0, img.size()[1] - h) 49 | y1 = random.randint(0, img.size()[2] - w) 50 | if img.size()[0] == 3: 51 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 52 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 53 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 54 | else: 55 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 56 | return img 57 | 58 | return img 59 | 60 | class RandomPatch(object): 61 | """Random patch data augmentation. 62 | There is a patch pool that stores randomly extracted pathces from person images. 63 | 64 | For each input image, RandomPatch 65 | 1) extracts a random patch and stores the patch in the patch pool; 66 | 2) randomly selects a patch from the patch pool and pastes it on the 67 | input (at random position) to simulate occlusion. 68 | Reference: 69 | - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019. 70 | - Zhou et al. Learning Generalisable Omni-Scale Representations 71 | for Person Re-Identification. arXiv preprint, 2019. 
72 | """ 73 | 74 | def __init__( 75 | self, 76 | prob_happen=0.5, 77 | pool_capacity=50000, 78 | min_sample_size=100, 79 | patch_min_area=0.01, 80 | patch_max_area=0.5, 81 | patch_min_ratio=0.1, 82 | prob_rotate=0.5, 83 | prob_flip_leftright=0.5, 84 | ): 85 | self.prob_happen = prob_happen 86 | 87 | self.patch_min_area = patch_min_area 88 | self.patch_max_area = patch_max_area 89 | self.patch_min_ratio = patch_min_ratio 90 | 91 | self.prob_rotate = prob_rotate 92 | self.prob_flip_leftright = prob_flip_leftright 93 | 94 | self.patchpool = deque(maxlen=pool_capacity) 95 | self.min_sample_size = min_sample_size 96 | 97 | def generate_wh(self, W, H): 98 | area = W * H 99 | for attempt in range(100): 100 | target_area = random.uniform( 101 | self.patch_min_area, self.patch_max_area 102 | ) * area 103 | aspect_ratio = random.uniform( 104 | self.patch_min_ratio, 1. / self.patch_min_ratio 105 | ) 106 | h = int(round(math.sqrt(target_area * aspect_ratio))) 107 | w = int(round(math.sqrt(target_area / aspect_ratio))) 108 | if w < W and h < H: 109 | return w, h 110 | return None, None 111 | 112 | def transform_patch(self, patch): 113 | if random.uniform(0, 1) > self.prob_flip_leftright: 114 | patch = patch.transpose(Image.FLIP_LEFT_RIGHT) 115 | if random.uniform(0, 1) > self.prob_rotate: 116 | patch = patch.rotate(random.randint(-10, 10)) 117 | return patch 118 | 119 | def __call__(self, img): 120 | W, H = img.size # original image size 121 | 122 | # collect new patch 123 | w, h = self.generate_wh(W, H) 124 | if w is not None and h is not None: 125 | x1 = random.randint(0, W - w) 126 | y1 = random.randint(0, H - h) 127 | new_patch = img.crop((x1, y1, x1 + w, y1 + h)) 128 | self.patchpool.append(new_patch) 129 | 130 | if len(self.patchpool) < self.min_sample_size: 131 | return img 132 | 133 | if random.uniform(0, 1) > self.prob_happen: 134 | return img 135 | 136 | # paste a randomly selected patch on a random position 137 | patch = random.sample(self.patchpool, 1)[0] 138 | patchW, patchH = patch.size 139 | x1 = random.randint(0, W - patchW) 140 | y1 = random.randint(0, H - patchH) 141 | patch = self.transform_patch(patch) 142 | img.paste(patch, (x1, y1)) 143 | 144 | return img 145 | 146 | -------------------------------------------------------------------------------- /transreid_pytorch/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_loss import make_loss 2 | from .arcface import ArcFace -------------------------------------------------------------------------------- /transreid_pytorch/loss/arcface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | import math 6 | 7 | 8 | class ArcFace(nn.Module): 9 | def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False): 10 | super(ArcFace, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.s = s 14 | self.m = m 15 | self.cos_m = math.cos(m) 16 | self.sin_m = math.sin(m) 17 | 18 | self.th = math.cos(math.pi - m) 19 | self.mm = math.sin(math.pi - m) * m 20 | 21 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 22 | if bias: 23 | self.bias = Parameter(torch.Tensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 30 | if 
self.bias is not None: 31 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 32 | bound = 1 / math.sqrt(fan_in) 33 | nn.init.uniform_(self.bias, -bound, bound) 34 | 35 | def forward(self, input, label): 36 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 37 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 38 | phi = cosine * self.cos_m - sine * self.sin_m 39 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 40 | # --------------------------- convert label to one-hot --------------------------- 41 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 42 | one_hot = torch.zeros(cosine.size(), device='cuda') 43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 44 | # ------------- torch.where(out_i = x_i if condition_i else y_i) ------------- 45 | output = (one_hot * phi) + ( 46 | (1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 47 | output *= self.s 48 | # print(output) 49 | 50 | return output 51 | 52 | class CircleLoss(nn.Module): 53 | def __init__(self, in_features, num_classes, s=256, m=0.25): 54 | super(CircleLoss, self).__init__() 55 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 56 | self.s = s 57 | self.m = m 58 | self._num_classes = num_classes 59 | self.reset_parameters() 60 | 61 | 62 | def reset_parameters(self): 63 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 64 | 65 | def __call__(self, bn_feat, targets): 66 | 67 | sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight)) 68 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 69 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 70 | delta_p = 1 - self.m 71 | delta_n = self.m 72 | 73 | s_p = self.s * alpha_p * (sim_mat - delta_p) 74 | s_n = self.s * alpha_n * (sim_mat - delta_n) 75 | 76 | targets = F.one_hot(targets, num_classes=self._num_classes) 77 | 78 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 79 | 80 | return pred_class_logits -------------------------------------------------------------------------------- /transreid_pytorch/loss/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (batch_size,).
34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = [] 48 | for i in range(batch_size): 49 | value = distmat[i][mask[i]] 50 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 51 | dist.append(value) 52 | dist = torch.cat(dist) 53 | loss = dist.mean() 54 | return loss 55 | 56 | 57 | if __name__ == '__main__': 58 | use_gpu = False 59 | center_loss = CenterLoss(use_gpu=use_gpu) 60 | features = torch.rand(16, 2048) 61 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 62 | if use_gpu: 63 | features = torch.rand(16, 2048).cuda() 64 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 65 | 66 | loss = center_loss(features, targets) 67 | print(loss) 68 | -------------------------------------------------------------------------------- /transreid_pytorch/loss/make_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy 9 | from .triplet_loss import TripletLoss 10 | from .center_loss import CenterLoss 11 | 12 | 13 | def make_loss(cfg, num_classes): # modified by gu 14 | sampler = cfg.DATALOADER.SAMPLER 15 | feat_dim = 2048 16 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 17 | if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE: 18 | if cfg.MODEL.NO_MARGIN: 19 | triplet = TripletLoss() 20 | print("using soft triplet loss for training") 21 | else: 22 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 23 | print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN)) 24 | else: 25 | print('expected METRIC_LOSS_TYPE should be triplet' 26 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 27 | 28 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 29 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) 30 | print("label smooth on, numclasses:", num_classes) 31 | 32 | if sampler in ['softmax', 'id']: 33 | def loss_func(score, feat, target,target_cam): 34 | return F.cross_entropy(score, target) 35 | 36 | # elif cfg.DATALOADER.SAMPLER in ['softmax_triplet', 'id_triplet', 'img_triplet']: 37 | elif 'triplet' in sampler: 38 | def loss_func(score, feat, target, target_cam): 39 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 40 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 41 | if isinstance(score, list): 42 | ID_LOSS = [xent(scor, target) for scor in score[1:]] 43 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 44 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target) 45 | else: 46 | ID_LOSS = xent(score, target) 47 | 48 | if isinstance(feat, list): 49 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 50 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 51 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 52 | else: 53 | TRI_LOSS = triplet(feat, 
target, normalize_feature=cfg.SOLVER.TRP_L2)[0] 54 | 55 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 56 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 57 | else: 58 | if isinstance(score, list): 59 | ID_LOSS = [F.cross_entropy(scor, target) for scor in score[1:]] 60 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 61 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * F.cross_entropy(score[0], target) 62 | else: 63 | ID_LOSS = F.cross_entropy(score, target) 64 | 65 | if isinstance(feat, list): 66 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 67 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 68 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 69 | else: 70 | TRI_LOSS = triplet(feat, target, normalize_feature=cfg.SOLVER.TRP_L2)[0] 71 | 72 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 73 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 74 | else: 75 | print('expected METRIC_LOSS_TYPE should be triplet ' 76 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 77 | 78 | else: 79 | print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center ' 80 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 81 | return loss_func, center_criterion 82 | 83 | 84 | -------------------------------------------------------------------------------- /transreid_pytorch/loss/softmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | class CrossEntropyLabelSmooth(nn.Module): 5 | """Cross entropy loss with label smoothing regularizer. 6 | 7 | Reference: 8 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 9 | Equation: y = (1 - epsilon) * y + epsilon / K. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | epsilon (float): weight. 14 | """ 15 | 16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 17 | super(CrossEntropyLabelSmooth, self).__init__() 18 | self.num_classes = num_classes 19 | self.epsilon = epsilon 20 | self.use_gpu = use_gpu 21 | self.logsoftmax = nn.LogSoftmax(dim=1) 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (batch_size,) 28 | """ 29 | log_probs = self.logsoftmax(inputs) 30 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 31 | if self.use_gpu: targets = targets.cuda() 32 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 33 | loss = (- targets * log_probs).mean(0).sum() 34 | return loss 35 | 36 | class LabelSmoothingCrossEntropy(nn.Module): 37 | """ 38 | NLL loss with label smoothing. 39 | """ 40 | def __init__(self, smoothing=0.1): 41 | """ 42 | Constructor for the LabelSmoothing module. 43 | :param smoothing: label smoothing factor 44 | """ 45 | super(LabelSmoothingCrossEntropy, self).__init__() 46 | assert smoothing < 1.0 47 | self.smoothing = smoothing 48 | self.confidence = 1. 
- smoothing 49 | 50 | def forward(self, x, target): 51 | logprobs = F.log_softmax(x, dim=-1) 52 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 53 | nll_loss = nll_loss.squeeze(1) 54 | smooth_loss = -logprobs.mean(dim=-1) 55 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 56 | return loss.mean() -------------------------------------------------------------------------------- /transreid_pytorch/loss/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """Normalizing to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Variable 9 | Returns: 10 | x: pytorch Variable, same shape as input 11 | """ 12 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 13 | return x 14 | 15 | 16 | def euclidean_dist(x, y): 17 | """ 18 | Args: 19 | x: pytorch Variable, with shape [m, d] 20 | y: pytorch Variable, with shape [n, d] 21 | Returns: 22 | dist: pytorch Variable, with shape [m, n] 23 | """ 24 | m, n = x.size(0), y.size(0) 25 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 26 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 27 | dist = xx + yy 28 | dist = dist - 2 * torch.matmul(x, y.t()) 29 | # dist.addmm_(1, -2, x, y.t()) 30 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 31 | return dist 32 | 33 | 34 | def cosine_dist(x, y): 35 | """ 36 | Args: 37 | x: pytorch Variable, with shape [m, d] 38 | y: pytorch Variable, with shape [n, d] 39 | Returns: 40 | dist: pytorch Variable, with shape [m, n] 41 | """ 42 | m, n = x.size(0), y.size(0) 43 | x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n) 44 | y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t() 45 | xy_intersection = torch.mm(x, y.t()) 46 | dist = xy_intersection/(x_norm * y_norm) 47 | dist = (1. - dist) / 2 48 | return dist 49 | 50 | 51 | def hard_example_mining(dist_mat, labels, return_inds=False): 52 | """For each anchor, find the hardest positive and negative sample. 53 | Args: 54 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 55 | labels: pytorch LongTensor, with shape [N] 56 | return_inds: whether to return the indices. Save time if `False`(?) 57 | Returns: 58 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 59 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 60 | p_inds: pytorch LongTensor, with shape [N]; 61 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 62 | n_inds: pytorch LongTensor, with shape [N]; 63 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 64 | NOTE: Only consider the case in which all labels have same num of samples, 65 | thus we can cope with all anchors in parallel. 
66 | """ 67 | 68 | assert len(dist_mat.size()) == 2 69 | assert dist_mat.size(0) == dist_mat.size(1) 70 | N = dist_mat.size(0) 71 | 72 | # shape [N, N] 73 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 74 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 75 | 76 | # `dist_ap` means distance(anchor, positive) 77 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 78 | dist_ap, relative_p_inds = torch.max( 79 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 80 | # print(dist_mat[is_pos].shape) 81 | # `dist_an` means distance(anchor, negative) 82 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 83 | dist_an, relative_n_inds = torch.min( 84 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 85 | # shape [N] 86 | dist_ap = dist_ap.squeeze(1) 87 | dist_an = dist_an.squeeze(1) 88 | 89 | if return_inds: 90 | # shape [N, N] 91 | ind = (labels.new().resize_as_(labels) 92 | .copy_(torch.arange(0, N).long()) 93 | .unsqueeze(0).expand(N, N)) 94 | # shape [N, 1] 95 | p_inds = torch.gather( 96 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 97 | n_inds = torch.gather( 98 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 99 | # shape [N] 100 | p_inds = p_inds.squeeze(1) 101 | n_inds = n_inds.squeeze(1) 102 | return dist_ap, dist_an, p_inds, n_inds 103 | 104 | return dist_ap, dist_an 105 | 106 | 107 | class TripletLoss(object): 108 | """ 109 | Triplet loss using HARDER example mining, 110 | modified based on original triplet loss using hard example mining 111 | """ 112 | 113 | def __init__(self, margin=None, hard_factor=0.0): 114 | self.margin = margin 115 | self.hard_factor = hard_factor 116 | if margin is not None: 117 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 118 | else: 119 | self.ranking_loss = nn.SoftMarginLoss() 120 | 121 | def __call__(self, global_feat, labels, normalize_feature=False): 122 | if normalize_feature: 123 | global_feat = normalize(global_feat, axis=-1) 124 | dist_mat = euclidean_dist(global_feat, global_feat) 125 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 126 | 127 | # dist_ap *= (1.0 + self.hard_factor) 128 | # dist_an *= (1.0 - self.hard_factor) 129 | 130 | y = dist_an.new().resize_as_(dist_an).fill_(1) 131 | if self.margin is not None: 132 | loss = self.ranking_loss(dist_an, dist_ap, y) 133 | else: 134 | loss = self.ranking_loss(dist_an - dist_ap, y) 135 | return loss, dist_ap, dist_an 136 | 137 | 138 | -------------------------------------------------------------------------------- /transreid_pytorch/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- /transreid_pytorch/model/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/TransReID-SSL/fc39e88240aa7cb7b28dd2097e7f161ae2be3ad8/transreid_pytorch/model/backbones/__init__.py -------------------------------------------------------------------------------- /transreid_pytorch/model/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | 
class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, last_stride=2, block=Bottleneck,layers=[3, 4, 6, 3]): 86 | self.inplanes = 64 87 | super().__init__() 88 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | self.bn1 = nn.BatchNorm2d(64) 91 | self.relu = nn.ReLU(inplace=True) # add missed relu 92 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0) 93 | self.layer1 = self._make_layer(block, 64, layers[0]) 94 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 96 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 97 | 98 | def _make_layer(self, block, planes, blocks, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | downsample = nn.Sequential( 102 | nn.Conv2d(self.inplanes, planes * block.expansion, 103 | kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(planes * block.expansion), 105 | ) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x, cam_label=None): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | x = self.relu(x) # add missed relu 119 | x = self.maxpool(x) 120 | x = self.layer1(x) 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | 125 | return x 126 | 127 | def load_param(self, model_path): 128 | param_dict = torch.load(model_path) 129 | if 'model' in 
param_dict: 130 | param_dict = param_dict['model'] 131 | for i in param_dict: 132 | if 'fc' in i: 133 | continue 134 | self.state_dict()[i].copy_(param_dict[i]) 135 | 136 | def random_init(self): 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | m.weight.data.normal_(0, math.sqrt(2. / n)) 141 | elif isinstance(m, nn.BatchNorm2d): 142 | m.weight.data.fill_(1) 143 | m.bias.data.zero_() 144 | -------------------------------------------------------------------------------- /transreid_pytorch/processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import do_train, do_inference -------------------------------------------------------------------------------- /transreid_pytorch/run.sh: -------------------------------------------------------------------------------- 1 | # Single GPU 2 | python train.py --config_file configs/market/vit_small_ics.yml 3 | 4 | # Multiple GPUs 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 29500 train.py --config_file configs/market/vit_small_ics_ddp.yml 6 | 7 | -------------------------------------------------------------------------------- /transreid_pytorch/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup(name='TransReID', 5 | version='1.0.0', 6 | description='TransReID: Transformer-based Object Re-Identification', 7 | author='xxx', 8 | author_email='xxx', 9 | url='xxx', 10 | install_requires=[ 11 | 'numpy', 'torch==1.6.0', 'torchvision==0.7.0', 12 | 'h5py', 'opencv-python', 'yacs', 'timm==0.3.2' 13 | ], 14 | packages=find_packages(), 15 | keywords=[ 16 | 'Pure Transformer', 17 | 'Object Re-identification' 18 | ]) 19 | -------------------------------------------------------------------------------- /transreid_pytorch/solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import WarmupMultiStepLR 2 | from .make_optimizer import make_optimizer -------------------------------------------------------------------------------- /transreid_pytorch/solver/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import torch 10 | 11 | from .scheduler import Scheduler 12 | 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | class CosineLRScheduler(Scheduler): 18 | """ 19 | Cosine decay with restarts. 20 | This is described in the paper https://arxiv.org/abs/1608.03983.
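For reference, a minimal standalone sketch of the single-cycle schedule that `_get_lr` below implements (t_mul=1, decay_rate=1, cycle_limit=1, noise disabled); the helper name is illustrative and not part of this file:

    import math

    def cosine_with_warmup(t, t_initial, base_lr, lr_min=0.0,
                           warmup_t=0, warmup_lr_init=0.0):
        # Linear warmup from warmup_lr_init to base_lr over the first warmup_t steps.
        if t < warmup_t:
            return warmup_lr_init + t * (base_lr - warmup_lr_init) / warmup_t
        # One cosine half-period from base_lr down to lr_min; hold at lr_min
        # once the single cycle has finished (mirroring cycle_limit=1 below).
        t_curr = min(t, t_initial)
        return lr_min + 0.5 * (base_lr - lr_min) * (1 + math.cos(math.pi * t_curr / t_initial))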
21 | 22 | Inspiration from 23 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 24 | """ 25 | 26 | def __init__(self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 50 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 51 | "rate since t_initial = t_mul = eta_mul = 1.") 52 | self.t_initial = t_initial 53 | self.t_mul = t_mul 54 | self.lr_min = lr_min 55 | self.decay_rate = decay_rate 56 | self.cycle_limit = cycle_limit 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | self.warmup_prefix = warmup_prefix 60 | self.t_in_epochs = t_in_epochs 61 | if self.warmup_t: 62 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 63 | super().update_groups(self.warmup_lr_init) 64 | else: 65 | self.warmup_steps = [1 for _ in self.base_values] 66 | 67 | def _get_lr(self, t): 68 | if t < self.warmup_t: 69 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 70 | else: 71 | if self.warmup_prefix: 72 | t = t - self.warmup_t 73 | 74 | if self.t_mul != 1: 75 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 76 | t_i = self.t_mul ** i * self.t_initial 77 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 78 | else: 79 | i = t // self.t_initial 80 | t_i = self.t_initial 81 | t_curr = t - (self.t_initial * i) 82 | 83 | gamma = self.decay_rate ** i 84 | lr_min = self.lr_min * gamma 85 | lr_max_values = [v * gamma for v in self.base_values] 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | lrs = [ 89 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min for _ in self.base_values] 93 | 94 | return lrs 95 | 96 | def get_epoch_values(self, epoch: int): 97 | if self.t_in_epochs: 98 | return self._get_lr(epoch) 99 | else: 100 | return None 101 | 102 | def get_update_values(self, num_updates: int): 103 | if not self.t_in_epochs: 104 | return self._get_lr(num_updates) 105 | else: 106 | return None 107 | 108 | def get_cycle_length(self, cycles=0): 109 | if not cycles: 110 | cycles = self.cycle_limit 111 | cycles = max(1, cycles) 112 | if self.t_mul == 1.0: 113 | return self.t_initial * cycles 114 | else: 115 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 116 | -------------------------------------------------------------------------------- /transreid_pytorch/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with 
WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, # steps 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if list(milestones) != sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of increasing" 28 | " integers. Got {}".format(milestones) 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted, " 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /transreid_pytorch/solver/make_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def make_optimizer(cfg, model, center_criterion): 5 | params = [] 6 | for key, value in model.named_parameters(): 7 | if not value.requires_grad: 8 | continue 9 | lr = cfg.SOLVER.BASE_LR 10 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 11 | if "bias" in key: 12 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 14 | if cfg.SOLVER.LARGE_FC_LR: 15 | if "classifier" in key or "arcface" in key: 16 | lr = cfg.SOLVER.BASE_LR * 2 17 | print('Using two times learning rate for fc') 18 | 19 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 20 | 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 24 | optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 25 | else: 26 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 27 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 28 | 29 | return optimizer, optimizer_center 30 | -------------------------------------------------------------------------------- /transreid_pytorch/solver/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups.
9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
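        # For reference: the noise fields below drive the optional LR perturbation
        # in _add_noise further down -- noise_range_t gates when noise applies
        # (a scalar threshold or a [start, end) pair), noise_pct bounds the
        # relative magnitude of the perturbation, and noise_seed (plus the
        # current step) seeds the generator so the noise is reproducible.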
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /transreid_pytorch/solver/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | 6 | 7 | def create_scheduler(cfg, optimizer): 8 | num_epochs = cfg.SOLVER.MAX_EPOCHS 9 | # type 1 10 | # lr_min = 0.01 * cfg.SOLVER.BASE_LR 11 | # warmup_lr_init = 0.001 * cfg.SOLVER.BASE_LR 12 | # type 2 13 | lr_min = 0.002 * cfg.SOLVER.BASE_LR 14 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 15 | # type 3 16 | # lr_min = 0.001 * cfg.SOLVER.BASE_LR 17 | # warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 18 | 19 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS 20 | noise_range = None 21 | 22 | lr_scheduler = CosineLRScheduler( 23 | optimizer, 24 | t_initial=num_epochs, 25 | lr_min=lr_min, 26 | t_mul= 1., 27 | decay_rate=0.1, 28 | warmup_lr_init=warmup_lr_init, 29 | warmup_t=warmup_t, 30 | cycle_limit=1, 31 | t_in_epochs=True, 32 | noise_range_t=noise_range, 33 | noise_pct= 0.67, 34 | noise_std= 1., 35 | noise_seed=42, 36 | ) 37 | 38 | return lr_scheduler 39 | -------------------------------------------------------------------------------- /transreid_pytorch/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
from config import cfg 3 | import argparse 4 | from datasets import make_dataloader 5 | from model import make_model 6 | from processor import do_inference 7 | from utils.logger import setup_logger 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 12 | parser.add_argument( 13 | "--config_file", default="", help="path to config file", type=str 14 | ) 15 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 16 | nargs=argparse.REMAINDER) 17 | 18 | args = parser.parse_args() 19 | 20 | 21 | 22 | if args.config_file != "": 23 | cfg.merge_from_file(args.config_file) 24 | cfg.merge_from_list(args.opts) 25 | cfg.freeze() 26 | 27 | output_dir = cfg.OUTPUT_DIR 28 | if output_dir and not os.path.exists(output_dir): 29 | os.makedirs(output_dir) 30 | 31 | logger = setup_logger("transreid", output_dir, if_train=False) 32 | logger.info(args) 33 | 34 | if args.config_file != "": 35 | logger.info("Loaded configuration file {}".format(args.config_file)) 36 | with open(args.config_file, 'r') as cf: 37 | config_str = "\n" + cf.read() 38 | logger.info(config_str) 39 | logger.info("Running with config:\n{}".format(cfg)) 40 | 41 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 42 | 43 | train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 44 | 45 | model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num) 46 | model.load_param(cfg.TEST.WEIGHT) 47 | 48 | if cfg.DATASETS.NAMES == 'VehicleID': 49 | for trial in range(10): 50 | train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 51 | rank_1, rank5 = do_inference(cfg, 52 | model, 53 | val_loader, 54 | num_query) 55 | if trial == 0: 56 | all_rank_1 = rank_1 57 | all_rank_5 = rank5 58 | else: 59 | all_rank_1 = all_rank_1 + rank_1 60 | all_rank_5 = all_rank_5 + rank5 61 | 62 | logger.info("rank_1:{}, rank_5 {} : trial : {}".format(rank_1, rank5, trial)) 63 | logger.info("sum_rank_1:{:.1%}, sum_rank_5 {:.1%}".format(all_rank_1.sum()/10.0, all_rank_5.sum()/10.0)) 64 | else: 65 | do_inference(cfg, 66 | model, 67 | val_loader, 68 | num_query) 69 | 70 | -------------------------------------------------------------------------------- /transreid_pytorch/train.py: -------------------------------------------------------------------------------- 1 | from utils.logger import setup_logger 2 | from datasets import make_dataloader 3 | from model import make_model 4 | from solver import make_optimizer, WarmupMultiStepLR 5 | from solver.scheduler_factory import create_scheduler 6 | from loss import make_loss 7 | from processor import do_train 8 | import random 9 | import torch 10 | import numpy as np 11 | import os 12 | import argparse 13 | from config import cfg 14 | import torch.distributed as dist 15 | 16 | def set_seed(seed): 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | np.random.seed(seed) 21 | random.seed(seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = True 24 | 25 | if __name__ == '__main__': 26 | 27 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 28 | parser.add_argument( 29 | "--config_file", default="", help="path to config file", type=str 30 | ) 31 | 32 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 33 | 
nargs=argparse.REMAINDER) 34 | parser.add_argument("--local_rank", default=0, type=int) 35 | args = parser.parse_args() 36 | 37 | if args.config_file != "": 38 | cfg.merge_from_file(args.config_file) 39 | cfg.merge_from_list(args.opts) 40 | 41 | cfg.freeze() 42 | set_seed(cfg.SOLVER.SEED) 43 | 44 | if cfg.MODEL.DIST_TRAIN: 45 | torch.cuda.set_device(args.local_rank) 46 | 47 | output_dir = cfg.OUTPUT_DIR 48 | try: 49 | os.makedirs(output_dir) 50 | except: 51 | pass 52 | 53 | logger = setup_logger("transreid", output_dir, if_train=True) 54 | logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR)) 55 | # logger.info(args) 56 | 57 | if args.config_file != "": 58 | logger.info("Loaded configuration file {}".format(args.config_file)) 59 | with open(args.config_file, 'r') as cf: 60 | config_str = "\n" + cf.read() 61 | # logger.info(config_str) 62 | 63 | if cfg.MODEL.DIST_TRAIN: 64 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 65 | logger.info("Running with config:\n{}".format(cfg)) 66 | 67 | 68 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 69 | train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 70 | 71 | model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num) 72 | loss_func, center_criterion = make_loss(cfg, num_classes=num_classes) 73 | optimizer, optimizer_center = make_optimizer(cfg, model, center_criterion) 74 | 75 | if cfg.SOLVER.WARMUP_METHOD == 'cosine': 76 | logger.info('===========using cosine learning rate=======') 77 | scheduler = create_scheduler(cfg, optimizer) 78 | else: 79 | logger.info('===========using normal learning rate=======') 80 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, 81 | cfg.SOLVER.WARMUP_FACTOR, 82 | cfg.SOLVER.WARMUP_EPOCHS, cfg.SOLVER.WARMUP_METHOD) 83 | 84 | do_train( 85 | cfg, 86 | model, 87 | center_criterion, 88 | train_loader, 89 | val_loader, 90 | optimizer, 91 | optimizer_center, 92 | scheduler, 93 | loss_func, 94 | num_query, args.local_rank 95 | ) 96 | # print(cfg.OUTPUT_DIR) 97 | # print(cfg.MODEL.PRETRAIN_PATH) 98 | -------------------------------------------------------------------------------- /transreid_pytorch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/TransReID-SSL/fc39e88240aa7cb7b28dd2097e7f161ae2be3ad8/transreid_pytorch/utils/__init__.py -------------------------------------------------------------------------------- /transreid_pytorch/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | return faiss.cast_integer_to_long_ptr( 16 | x.storage().data_ptr() + x.storage_offset() * 8) 17 | 18 | def search_index_pytorch(index, x, k, D=None, I=None): 19 | """call the search function of an index with pytorch tensor I/O (CPU 20 | and GPU supported)""" 21 | assert x.is_contiguous() 22 | n, d = x.size() 23 | assert d == index.d 24 | 25 | if D is None: 26 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 27 | 
else: 28 | assert D.size() == (n, k) 29 | 30 | if I is None: 31 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 32 | else: 33 | assert I.size() == (n, k) 34 | torch.cuda.synchronize() 35 | xptr = swig_ptr_from_FloatTensor(x) 36 | Iptr = swig_ptr_from_LongTensor(I) 37 | Dptr = swig_ptr_from_FloatTensor(D) 38 | index.search_c(n, xptr, 39 | k, Dptr, Iptr) 40 | torch.cuda.synchronize() 41 | return D, I 42 | 43 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 44 | metric=faiss.METRIC_L2): 45 | assert xb.device == xq.device 46 | 47 | nq, d = xq.size() 48 | if xq.is_contiguous(): 49 | xq_row_major = True 50 | elif xq.t().is_contiguous(): 51 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 52 | xq_row_major = False 53 | else: 54 | raise TypeError('matrix should be row or column-major') 55 | 56 | xq_ptr = swig_ptr_from_FloatTensor(xq) 57 | 58 | nb, d2 = xb.size() 59 | assert d2 == d 60 | if xb.is_contiguous(): 61 | xb_row_major = True 62 | elif xb.t().is_contiguous(): 63 | xb = xb.t() 64 | xb_row_major = False 65 | else: 66 | raise TypeError('matrix should be row or column-major') 67 | xb_ptr = swig_ptr_from_FloatTensor(xb) 68 | 69 | if D is None: 70 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 71 | else: 72 | assert D.shape == (nq, k) 73 | assert D.device == xb.device 74 | 75 | if I is None: 76 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 77 | else: 78 | assert I.shape == (nq, k) 79 | assert I.device == xb.device 80 | 81 | D_ptr = swig_ptr_from_FloatTensor(D) 82 | I_ptr = swig_ptr_from_LongTensor(I) 83 | 84 | faiss.bruteForceKnn(res, metric, 85 | xb_ptr, xb_row_major, nb, 86 | xq_ptr, xq_row_major, nq, 87 | d, k, D_ptr, I_ptr) 88 | 89 | return D, I 90 | 91 | def index_init_gpu(ngpus, feat_dim): 92 | flat_config = [] 93 | for i in range(ngpus): 94 | cfg = faiss.GpuIndexFlatConfig() 95 | cfg.useFloat16 = False 96 | cfg.device = i 97 | flat_config.append(cfg) 98 | 99 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 100 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 101 | index = faiss.IndexShards(feat_dim) 102 | for sub_index in indexes: 103 | index.add_shard(sub_index) 104 | index.reset() 105 | return index 106 | 107 | def index_init_cpu(feat_dim): 108 | return faiss.IndexFlatL2(feat_dim) 109 | -------------------------------------------------------------------------------- /transreid_pytorch/utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import errno 8 | import json 9 | import os 10 | 11 | import os.path as osp 12 | 13 | 14 | def mkdir_if_missing(directory): 15 | if not osp.exists(directory): 16 | try: 17 | os.makedirs(directory) 18 | except OSError as e: 19 | if e.errno != errno.EEXIST: 20 | raise 21 | 22 | 23 | def check_isfile(path): 24 | isfile = osp.isfile(path) 25 | if not isfile: 26 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 27 | return isfile 28 | 29 | 30 | def read_json(fpath): 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | mkdir_if_missing(osp.dirname(fpath)) 38 | with open(fpath, 'w') as f: 39 | json.dump(obj, f, indent=4, separators=(',', ': ')) 40 | -------------------------------------------------------------------------------- /transreid_pytorch/utils/logger.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import os.path as osp 5 | def setup_logger(name, save_dir, if_train): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.DEBUG) 8 | 9 | ch = logging.StreamHandler(stream=sys.stdout) 10 | ch.setLevel(logging.DEBUG) 11 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s",datefmt = '%Y-%m-%d %H:%M:%S') 12 | ch.setFormatter(formatter) 13 | logger.addHandler(ch) 14 | 15 | if save_dir: 16 | if not osp.exists(save_dir): 17 | os.makedirs(save_dir) 18 | if if_train: 19 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w') 20 | else: 21 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='w') 22 | fh.setLevel(logging.DEBUG) 23 | fh.setFormatter(formatter) 24 | logger.addHandler(fh) 25 | 26 | return logger 27 | -------------------------------------------------------------------------------- /transreid_pytorch/utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /transreid_pytorch/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from utils.reranking import re_ranking 5 | 6 | 7 | def euclidean_distance(qf, gf): 8 | m = qf.shape[0] 9 | n = gf.shape[0] 10 | dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 11 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 12 | dist_mat.addmm_(1, -2, qf, gf.t()) 13 | return dist_mat.cpu().numpy() 14 | 15 | def cosine_similarity(qf, gf): 16 | epsilon = 0.00001 17 | dist_mat = qf.mm(gf.t()) 18 | qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True) # mx1 19 | gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True) # nx1 20 | qg_normdot = qf_norm.mm(gf_norm.t()) 21 | 22 | dist_mat = dist_mat.mul(1 / qg_normdot).cpu().numpy() 23 | dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon) 24 | dist_mat = np.arccos(dist_mat) 25 | return dist_mat 26 | 27 | 28 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 29 | """Evaluation with market1501 metric 30 | Key: for each query identity, its gallery images from the same camera view are discarded. 31 | """ 32 | num_q, num_g = distmat.shape 33 | # distmat g 34 | # q 1 3 2 4 35 | # 4 1 2 3 36 | if num_g < max_rank: 37 | max_rank = num_g 38 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 39 | indices = np.argsort(distmat, axis=1) 40 | # 0 2 1 3 41 | # 1 2 3 0 42 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 43 | # compute cmc curve for each query 44 | all_cmc = [] 45 | all_AP = [] 46 | num_valid_q = 0. 
# number of valid query 47 | for q_idx in range(num_q): 48 | # get query pid and camid 49 | q_pid = q_pids[q_idx] 50 | q_camid = q_camids[q_idx] 51 | 52 | # remove gallery samples that have the same pid and camid with query 53 | order = indices[q_idx] # select one row 54 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 55 | keep = np.invert(remove) 56 | 57 | # compute cmc curve 58 | # binary vector, positions with value 1 are correct matches 59 | orig_cmc = matches[q_idx][keep] 60 | if not np.any(orig_cmc): 61 | # this condition is true when query identity does not appear in gallery 62 | continue 63 | 64 | cmc = orig_cmc.cumsum() 65 | cmc[cmc > 1] = 1 66 | 67 | all_cmc.append(cmc[:max_rank]) 68 | num_valid_q += 1. 69 | 70 | # compute average precision 71 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 72 | num_rel = orig_cmc.sum() 73 | tmp_cmc = orig_cmc.cumsum() 74 | y = np.arange(1, tmp_cmc.shape[0] + 1) * 1.0 75 | tmp_cmc = tmp_cmc / y 76 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 77 | AP = tmp_cmc.sum() / num_rel 78 | all_AP.append(AP) 79 | 80 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 81 | 82 | all_cmc = np.asarray(all_cmc).astype(np.float32) 83 | all_cmc = all_cmc.sum(0) / num_valid_q 84 | mAP = np.mean(all_AP) 85 | 86 | return all_cmc, mAP 87 | 88 | 89 | class R1_mAP_eval(): 90 | def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False): 91 | super(R1_mAP_eval, self).__init__() 92 | self.num_query = num_query 93 | self.max_rank = max_rank 94 | self.feat_norm = feat_norm 95 | self.reranking = reranking 96 | 97 | def reset(self): 98 | self.feats = [] 99 | self.pids = [] 100 | self.camids = [] 101 | 102 | def update(self, output): # called once for each batch 103 | feat, pid, camid = output 104 | self.feats.append(feat.cpu()) 105 | self.pids.extend(np.asarray(pid)) 106 | self.camids.extend(np.asarray(camid)) 107 | 108 | def compute(self): # called after each epoch 109 | feats = torch.cat(self.feats, dim=0) 110 | if self.feat_norm: 111 | print("The test feature is normalized") 112 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) # along channel 113 | # query 114 | qf = feats[:self.num_query] 115 | q_pids = np.asarray(self.pids[:self.num_query]) 116 | q_camids = np.asarray(self.camids[:self.num_query]) 117 | # gallery 118 | gf = feats[self.num_query:] 119 | g_pids = np.asarray(self.pids[self.num_query:]) 120 | 121 | g_camids = np.asarray(self.camids[self.num_query:]) 122 | if self.reranking: 123 | print('=> Enter reranking') 124 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 125 | 126 | else: 127 | print('=> Computing DistMat with euclidean_distance') 128 | distmat = euclidean_distance(qf, gf) 129 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 130 | 131 | return cmc, mAP, distmat, self.pids, self.camids, qf, gf 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /transreid_pytorch/utils/reranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | 7 | """ 8 | 9 | """ 10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 13 | """ 14 | 15 | """ 16 | API 17 | 18 | probFea: all feature vectors of the query set (torch tensor) 19 | probFea: all feature vectors of the gallery set (torch tensor) 20 | k1,k2,lambda: parameters, the original paper is (k1=20,k2=6,lambda=0.3) 21 | MemorySave: set to 'True' when using MemorySave mode 22 | Minibatch: avaliable when 'MemorySave' is 'True' 23 | """ 24 | 25 | import numpy as np 26 | import torch 27 | 28 | 29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False): 30 | # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor 31 | query_num = probFea.size(0) 32 | all_num = query_num + galFea.size(0) 33 | if only_local: 34 | original_dist = local_distmat 35 | else: 36 | feat = torch.cat([probFea, galFea]) 37 | # print('using GPU to compute original distance') 38 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \ 39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t() 40 | distmat.addmm_(1, -2, feat, feat.t()) 41 | original_dist = distmat.cpu().numpy() 42 | del feat 43 | if not local_distmat is None: 44 | original_dist = original_dist + local_distmat 45 | gallery_num = original_dist.shape[0] 46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 47 | V = np.zeros_like(original_dist).astype(np.float16) 48 | initial_rank = np.argsort(original_dist).astype(np.int32) 49 | 50 | # print('starting re_ranking') 51 | for i in range(all_num): 52 | # k-reciprocal neighbors 53 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 55 | fi = np.where(backward_k_neigh_index == i)[0] 56 | k_reciprocal_index = forward_k_neigh_index[fi] 57 | k_reciprocal_expansion_index = k_reciprocal_index 58 | for j in range(len(k_reciprocal_index)): 59 | candidate = k_reciprocal_index[j] 60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 62 | :int(np.around(k1 / 2)) + 1] 63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 66 | candidate_k_reciprocal_index): 67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 68 | 69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 72 | original_dist = original_dist[:query_num, ] 73 | if k2 != 1: 74 | V_qe = np.zeros_like(V, dtype=np.float16) 75 | for i in range(all_num): 76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 77 | V = V_qe 78 | del V_qe 79 | del initial_rank 80 | invIndex = [] 81 | for i in range(gallery_num): 82 | invIndex.append(np.where(V[:, i] != 0)[0]) 83 | 84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 85 | 86 | for i in range(query_num): 87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 88 | indNonZero = np.where(V[i, :] != 0)[0] 89 | indImages = [invIndex[ind] for ind 
in indNonZero] 90 | for j in range(len(indNonZero)): 91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 92 | V[indImages[j], indNonZero[j]]) 93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 94 | 95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 96 | del original_dist 97 | del V 98 | del jaccard_dist 99 | final_dist = final_dist[:query_num, query_num:] 100 | return final_dist 101 | 102 | --------------------------------------------------------------------------------
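Two short usage sketches close out the dump, since it ends with the evaluation utilities. First, a minimal evaluation loop for the R1_mAP_eval class in utils/metrics.py; `model`, `val_loader` (yielding image/pid/camid batches) and `num_query` are illustrative stand-ins for the objects that train.py and test.py build via make_model and make_dataloader:

    import torch
    from utils.metrics import R1_mAP_eval

    evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=True, reranking=False)
    evaluator.reset()
    model.eval()
    with torch.no_grad():
        for img, pid, camid in val_loader:
            feat = model(img.cuda())
            evaluator.update((feat, pid, camid))   # buffers feats/pids/camids
    cmc, mAP, _, _, _, _, _ = evaluator.compute()  # CMC curve and mAP over the gallery
    print('mAP: {:.1%}  Rank-1: {:.1%}  Rank-5: {:.1%}'.format(mAP, cmc[0], cmc[4]))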
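Second, a minimal call sketch for re_ranking in utils/reranking.py, using the parameter values its docstring recommends (k1=20, k2=6, lambda=0.3); the random features are placeholders for real query/gallery embeddings:

    import torch
    from utils.reranking import re_ranking

    qf = torch.randn(100, 768)    # placeholder query features (torch tensor, per the API note)
    gf = torch.randn(1000, 768)   # placeholder gallery features
    distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
    # distmat is a (100, 1000) numpy array of query-to-gallery distances;
    # smaller means more similar, and it can be passed to eval_func for CMC/mAP.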