├── ChannelAug.py ├── LICENSE ├── README.md ├── clustercontrast ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── evaluators.cpython-36.pyc │ ├── evaluators.cpython-37.pyc │ ├── trainers.cpython-36.pyc │ └── trainers.cpython-37.pyc ├── datasets │ ├── __init__.py │ ├── dukemtmcreid.py │ ├── market1501.py │ ├── msmt17.py │ ├── personx.py │ ├── regdb_ir.py │ ├── regdb_rgb.py │ ├── sysu_all.py │ ├── sysu_ir.py │ ├── sysu_rgb.py │ └── veri.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── evaluators.py ├── models │ ├── __init__.py │ ├── agw.py │ ├── cm.py │ ├── dsbn.py │ ├── kmeans.py │ ├── pooling.py │ ├── resnet.py │ ├── resnet_agw.py │ ├── resnet_ibn.py │ └── resnet_ibn_a.py ├── trainers.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── base_dataset.cpython-36.pyc │ │ ├── base_dataset.cpython-37.pyc │ │ ├── preprocessor.cpython-36.pyc │ │ ├── preprocessor.cpython-37.pyc │ │ ├── sampler.cpython-36.pyc │ │ ├── sampler.cpython-37.pyc │ │ ├── transforms.cpython-36.pyc │ │ └── transforms.cpython-37.pyc │ ├── base_dataset.py │ ├── preprocessor.py │ ├── sampler.py │ └── transforms.py │ ├── faiss_rerank.py │ ├── faiss_utils.py │ ├── infomap_cluster.py │ ├── infomap_utils.py │ ├── logging.py │ ├── meters.py │ ├── osutils.py │ ├── rerank.py │ └── serialization.py ├── examples └── note.md ├── meters.py ├── prepare_regdb.py ├── prepare_sysu.py ├── requirements.txt ├── run_test_regdb.sh ├── run_test_sysu.sh ├── run_train_regdb.sh ├── run_train_sysu.sh ├── test_regdb.py ├── test_sysu.py ├── train_regdb.py ├── train_sysu.py └── training_logs_example ├── RegDB └── regdb_s1_s2.log └── SYSU-MM01 └── sysu_s1_s2.log /ChannelAug.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | 5 | #from PIL import Image 6 | import random 7 | import math 8 | #import numpy as np 9 | #import torch 10 | 11 | class ChannelExchange(object): 12 | """ Randomly replaces all three channels with one selected channel, 13 | or converts the image to grayscale when the random index exceeds 2. 14 | Args: 15 | gray: inclusive upper bound of the random index. With the default 16 | gray = 2, one of the R/G/B channels is always selected; 17 | gray = 3 also enables the grayscale branch. 18 | 19 | """ 20 | 21 | def __init__(self, gray = 2): 22 | self.gray = gray 23 | 24 | def __call__(self, img): 25 | 26 | idx = random.randint(0, self.gray) 27 | 28 | if idx ==0: 29 | # R channel selected: copy R into G and B 30 | img[1, :,:] = img[0,:,:] 31 | img[2, :,:] = img[0,:,:] 32 | elif idx ==1: 33 | # G channel selected: copy G into R and B 34 | img[0, :,:] = img[1,:,:] 35 | img[2, :,:] = img[1,:,:] 36 | elif idx ==2: 37 | # B channel selected: copy B into R and G 38 | img[0, :,:] = img[2,:,:] 39 | img[1, :,:] = img[2,:,:] 40 | else: 41 | tmp_img = 0.2989 * img[0,:,:] + 0.5870 * img[1,:,:] + 0.1140 * img[2,:,:] # ITU-R BT.601 grayscale 42 | img[0,:,:] = tmp_img 43 | img[1,:,:] = tmp_img 44 | img[2,:,:] = tmp_img 45 | return img 46 | 47 | class ChannelAdap(object): 48 | """ Randomly selects one channel and copies it into the other two, 49 | or leaves the image unchanged (one case out of four). 50 | Args: 51 | probability: kept for interface compatibility; the probability 52 | check in __call__ is currently commented out, so this 53 | argument has no effect. 54 | 
55 | """ 56 | 57 | def __init__(self, probability = 0.5): 58 | self.probability = probability 59 | 60 | 61 | def __call__(self, img): 62 | 63 | # if random.uniform(0, 1) > self.probability: 64 | # return img 65 | 66 | idx = random.randint(0, 3) 67 | 68 | if idx ==0: 69 | # random select R Channel 70 | img[1, :,:] = img[0,:,:] 71 | img[2, :,:] = img[0,:,:] 72 | elif idx ==1: 73 | # random select B Channel 74 | img[0, :,:] = img[1,:,:] 75 | img[2, :,:] = img[1,:,:] 76 | elif idx ==2: 77 | # random select G Channel 78 | img[0, :,:] = img[2,:,:] 79 | img[1, :,:] = img[2,:,:] 80 | else: 81 | img = img 82 | 83 | return img 84 | 85 | 86 | class ChannelAdapGray(object): 87 | """ Adaptive selects a channel or two channels. 88 | Args: 89 | probability: The probability that the Random Erasing operation will be performed. 90 | sl: Minimum proportion of erased area against input image. 91 | sh: Maximum proportion of erased area against input image. 92 | r1: Minimum aspect ratio of erased area. 93 | mean: Erasing value. 94 | """ 95 | 96 | def __init__(self, probability = 0.5): 97 | self.probability = probability 98 | 99 | 100 | def __call__(self, img): 101 | 102 | # if random.uniform(0, 1) > self.probability: 103 | # return img 104 | 105 | idx = random.randint(0, 3) 106 | 107 | if idx ==0: 108 | # random select R Channel 109 | img[1, :,:] = img[0,:,:] 110 | img[2, :,:] = img[0,:,:] 111 | elif idx ==1: 112 | # random select B Channel 113 | img[0, :,:] = img[1,:,:] 114 | img[2, :,:] = img[1,:,:] 115 | elif idx ==2: 116 | # random select G Channel 117 | img[0, :,:] = img[2,:,:] 118 | img[1, :,:] = img[2,:,:] 119 | else: 120 | if random.uniform(0, 1) > self.probability: 121 | # return img 122 | img = img 123 | else: 124 | tmp_img = 0.2989 * img[0,:,:] + 0.5870 * img[1,:,:] + 0.1140 * img[2,:,:] 125 | img[0,:,:] = tmp_img 126 | img[1,:,:] = tmp_img 127 | img[2,:,:] = tmp_img 128 | return img 129 | 130 | 131 | class Gray(object): 132 | """ Adaptive selects a channel or two channels. 133 | Args: 134 | probability: The probability that the Random Erasing operation will be performed. 135 | sl: Minimum proportion of erased area against input image. 136 | sh: Maximum proportion of erased area against input image. 137 | r1: Minimum aspect ratio of erased area. 138 | mean: Erasing value. 139 | """ 140 | 141 | def __init__(self, probability = 0.5): 142 | self.probability = probability 143 | 144 | 145 | def __call__(self, img): 146 | 147 | # if random.uniform(0, 1) > self.probability: 148 | # return img 149 | 150 | tmp_img = 0.2989 * img[0,:,:] + 0.5870 * img[1,:,:] + 0.1140 * img[2,:,:] 151 | img[0,:,:] = tmp_img 152 | img[1,:,:] = tmp_img 153 | img[2,:,:] = tmp_img 154 | return img 155 | 156 | 157 | class ChannelRandomErasing(object): 158 | """ Randomly selects a rectangle region in an image and erases its pixels. 159 | 'Random Erasing Data Augmentation' by Zhong et al. 160 | See https://arxiv.org/pdf/1708.04896.pdf 161 | Args: 162 | probability: The probability that the Random Erasing operation will be performed. 163 | sl: Minimum proportion of erased area against input image. 164 | sh: Maximum proportion of erased area against input image. 165 | r1: Minimum aspect ratio of erased area. 166 | mean: Erasing value. 
167 | """ 168 | 169 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 170 | 171 | self.probability = probability 172 | self.mean = mean 173 | self.sl = sl 174 | self.sh = sh 175 | self.r1 = r1 176 | 177 | def __call__(self, img): 178 | 179 | if random.uniform(0, 1) > self.probability: 180 | return img 181 | 182 | for attempt in range(100): 183 | area = img.size()[1] * img.size()[2] 184 | 185 | target_area = random.uniform(self.sl, self.sh) * area 186 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 187 | 188 | h = int(round(math.sqrt(target_area * aspect_ratio))) 189 | w = int(round(math.sqrt(target_area / aspect_ratio))) 190 | 191 | if w < img.size()[2] and h < img.size()[1]: 192 | x1 = random.randint(0, img.size()[1] - h) 193 | y1 = random.randint(0, img.size()[2] - w) 194 | if img.size()[0] == 3: 195 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 196 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 197 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 198 | else: 199 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 200 | return img 201 | 202 | return img -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Alibaba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Visible-Infrared Person Re-Identification via Progressive Graph Matching and Alternate Learning 2 | 3 | 4 | ## Dataset Preprocessing 5 | Convert the dataset format (like Market1501). 6 | ```shell 7 | python prepare_sysu.py # for SYSU-MM01 8 | python prepare_regdb.py # for RegDB 9 | ``` 10 | You need to change the file path in the `prepare_sysu(regdb).py`. 11 | 12 | Note: a pre-processed dataset can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1Ovc8SRbWHkMMit26DfEaiA) (Password: ReID) or [Google Drive](https://drive.google.com/drive/folders/1TJG3TRgqi_DUMItJeFU4285IaB10-cXl?usp=sharing). 
13 | 14 | ## Training 15 | ```shell 16 | ./run_train_sysu.sh # for SYSU-MM01 17 | ./run_train_regdb.sh # for RegDB 18 | ``` 19 | Training proceeds in two stages; select a stage by commenting out the other stage's `main_worker` call, like this: 20 | ```python 21 | main_worker_stage1(args,log_s1_name) # Stage 1 22 | # main_worker_stage2(args,log_s1_name,log_s2_name) # Stage 2 23 | ``` 24 | Update: we optimized the code to make training more stable. In the second training stage, we recommend setting `use_hard` to `True`, following [1]. 25 | 26 | ## Test 27 | ```shell 28 | ./run_test_sysu.sh # for SYSU-MM01 29 | ./run_test_regdb.sh # for RegDB 30 | ``` 31 | 32 | ## Citation 33 | ```bibtex 34 | @InProceedings{Wu_2023_CVPR, 35 | author = {Wu, Zesen and Ye, Mang}, 36 | title = {Unsupervised Visible-Infrared Person Re-Identification via Progressive Graph Matching and Alternate Learning}, 37 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 38 | month = {June}, 39 | year = {2023}, 40 | pages = {9548-9558} 41 | } 42 | ``` 43 | 44 | Our trained models can be downloaded [here](https://drive.google.com/drive/folders/1NIpM5uv9_DUbCafwy7Z28yXPnMXxNtss?usp=sharing). 45 | 46 | [1] Dai, Zuozhuo, et al. "Cluster contrast for unsupervised person re-identification." Proceedings of the Asian Conference on Computer Vision. 2022. 47 | 48 | ## Contact 49 | zesenwu@whu.edu.cn 50 | 51 | The code is built on ClusterContrast and ADCA. 52 | -------------------------------------------------------------------------------- /clustercontrast/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import models 6 | from . import utils 7 | from . import evaluators 8 | from . 
import trainers 9 | 10 | __version__ = '0.1.0' 11 | -------------------------------------------------------------------------------- /clustercontrast/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .market1501 import Market1501 5 | from .msmt17 import MSMT17 6 | from .personx import PersonX 7 | from .veri import VeRi 8 | from .dukemtmcreid import DukeMTMCreID 9 | from .sysu_all import sysu_all 10 | from .sysu_ir import sysu_ir 11 | from .sysu_rgb import sysu_rgb 12 | from .regdb_ir import regdb_ir 13 | from .regdb_rgb import regdb_rgb 14 | __factory = { 15 | 'market1501': Market1501, 16 | 'msmt17': MSMT17, 17 | 'personx': PersonX, 18 | 'veri': VeRi, 19 | 'dukemtmcreid': DukeMTMCreID, 20 | 'sysu_all': sysu_all, 21 | 'sysu_ir':sysu_ir, 22 | 'sysu_rgb':sysu_rgb, 23 | 'regdb_ir':regdb_ir, 24 | 'regdb_rgb':regdb_rgb 25 | } 26 | 27 | 28 | def names(): 29 | return sorted(__factory.keys()) 30 | 31 | 32 | def create(name, root,trial=0, *args, **kwargs): 33 | """ 34 | Create a dataset instance. 35 | 36 | Parameters 37 | ---------- 38 | name : str 39 | The dataset name. 40 | root : str 41 | The path to the dataset directory. 42 | trial : int, optional 43 | The index of the data split (RegDB trial). 
Default: 0 44 | num_val : int or float, optional 45 | Unused; accepted via **kwargs for backward compatibility. 46 | Default: 100 47 | download : bool, optional 48 | Unused; datasets are never downloaded automatically. Default: False 49 | """ 50 | if name not in __factory: 51 | raise KeyError("Unknown dataset:", name) 52 | return __factory[name](root, trial=trial, *args, **kwargs) 53 | 54 | 55 | def get_dataset(name, root, *args, **kwargs): 56 | warnings.warn("get_dataset is deprecated. Use create instead.") 57 | return create(name, root, *args, **kwargs) 58 | -------------------------------------------------------------------------------- /clustercontrast/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import re 4 | from ..utils.data import BaseImageDataset 5 | 6 | 7 | def process_dir(dir_path, relabel=False): 8 | img_paths = glob.glob(osp.join(dir_path, "*.jpg")) 9 | pattern = re.compile(r"([-\d]+)_c(\d)") 10 | 11 | # get all identities 12 | pid_container = set() 13 | for img_path in img_paths: 14 | pid, _ = map(int, pattern.search(img_path).groups()) 15 | if pid == -1: 16 | continue 17 | pid_container.add(pid) 18 | 19 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 20 | 21 | data = [] 22 | for img_path in img_paths: 23 | pid, camid = map(int, pattern.search(img_path).groups()) 24 | if (pid not in pid_container) or (pid == -1): 25 | continue 26 | 27 | assert 1 <= camid <= 8 28 | camid -= 1 29 | 30 | if relabel: 31 | pid = pid2label[pid] 32 | data.append((img_path, pid, camid)) 33 | 34 | return data 35 | 36 | 37 | class DukeMTMCreID(BaseImageDataset): 38 | 39 | """DukeMTMC-reID. 40 | Reference: 41 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, 42 | Multi-Camera Tracking. ECCVW 2016. 43 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person 44 | Re-identification Baseline in vitro. ICCV 2017. 45 | 46 | 47 | Dataset statistics: 48 | - identities: 1404 (train + query). 49 | - images: 16522 (train) + 2228 (query) + 17661 (gallery). 50 | - cameras: 8. 
51 | """ 52 | 53 | dataset_dir = "DukeMTMC-reID" 54 | 55 | def __init__(self, root, verbose=True): 56 | super(DukeMTMCreID, self).__init__() 57 | self.root = osp.abspath(osp.expanduser(root)) 58 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 59 | 60 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 61 | self.query_dir = osp.join(self.dataset_dir, 'query') 62 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 63 | 64 | train = process_dir(dir_path=self.train_dir, relabel=True) 65 | query = process_dir(dir_path=self.query_dir, relabel=False) 66 | gallery = process_dir(dir_path=self.gallery_dir, relabel=False) 67 | 68 | self.train = train 69 | self.query = query 70 | self.gallery = gallery 71 | 72 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 73 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 74 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 75 | 76 | def _check_before_run(self): 77 | """Check if all files are available before going deeper""" 78 | if not osp.exists(self.dataset_dir): 79 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 80 | if not osp.exists(self.train_dir): 81 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 82 | if not osp.exists(self.query_dir): 83 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 84 | if not osp.exists(self.gallery_dir): 85 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 86 | -------------------------------------------------------------------------------- /clustercontrast/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class Market1501(BaseImageDataset): 9 | """ 10 | Market1501 11 | Reference: 12 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
13 | URL: http://www.liangzheng.org/Project/project_reid.html 14 | 15 | Dataset statistics: 16 | # identities: 1501 (+1 for background) 17 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 18 | """ 19 | dataset_dir = 'market1501' 20 | 21 | def __init__(self, root, verbose=True, **kwargs): 22 | super(Market1501, self).__init__() 23 | root='/dat01/yangbin/data/' 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 28 | 29 | self._check_before_run() 30 | 31 | train = self._process_dir(self.train_dir, relabel=True) 32 | query = self._process_dir(self.query_dir, relabel=False) 33 | gallery = self._process_dir(self.gallery_dir, relabel=False) 34 | 35 | if verbose: 36 | print("=> Market1501 loaded") 37 | self.print_dataset_statistics(train, query, gallery) 38 | 39 | self.train = train 40 | self.query = query 41 | self.gallery = gallery 42 | 43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 46 | 47 | def _check_before_run(self): 48 | """Check if all files are available before going deeper""" 49 | if not osp.exists(self.dataset_dir): 50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 51 | if not osp.exists(self.train_dir): 52 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 53 | if not osp.exists(self.query_dir): 54 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 55 | if not osp.exists(self.gallery_dir): 56 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 57 | 58 | def _process_dir(self, dir_path, relabel=False): 59 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 60 | pattern = re.compile(r'([-\d]+)_c(\d)') 61 | 62 | pid_container = set() 63 | for img_path in img_paths: 64 | pid, _ = map(int, pattern.search(img_path).groups()) 65 | if pid == -1: 66 | continue # junk images are just ignored 67 | pid_container.add(pid) 68 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 69 | 70 | dataset = [] 71 | for img_path in img_paths: 72 | pid, camid = map(int, pattern.search(img_path).groups()) 73 | if pid == -1: 74 | continue # junk images are just ignored 75 | assert 0 <= pid <= 1501 # pid == 0 means background 76 | assert 1 <= camid <= 6 77 | camid -= 1 # index starts from 0 78 | if relabel: 79 | pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /clustercontrast/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | import glob 5 | import re 6 | from ..utils.data import BaseImageDataset 7 | 8 | 9 | def _process_dir(dir_path, relabel=False): 10 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 11 | pattern = re.compile(r'([-\d]+)_c(\d+)') 12 | 13 | pid_container = set() 14 | for img_path in img_paths: 15 | pid, _ = map(int, pattern.search(img_path).groups()) 16 | if pid == -1: 17 | continue # junk images are just ignored 18 | pid_container.add(pid) 19 | 
pid2label = {pid: label for label, pid in enumerate(pid_container)} 20 | dataset = [] 21 | for img_path in img_paths: 22 | pid, camid = map(int, pattern.search(img_path).groups()) 23 | if pid == -1: 24 | continue # junk images are just ignored 25 | assert 1 <= camid <= 15 26 | camid -= 1 # index starts from 0 27 | if relabel: 28 | pid = pid2label[pid] 29 | dataset.append((img_path, pid, camid)) 30 | 31 | return dataset 32 | 33 | 34 | class MSMT17(BaseImageDataset): 35 | dataset_dir = 'MSMT17_V1' 36 | 37 | def __init__(self, root, verbose=True, **kwargs): 38 | super(MSMT17, self).__init__() 39 | self.dataset_dir = osp.join(root, self.dataset_dir) 40 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 41 | self.query_dir = osp.join(self.dataset_dir, 'query') 42 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 43 | 44 | self._check_before_run() 45 | 46 | train = _process_dir(self.train_dir, relabel=True) 47 | query = _process_dir(self.query_dir, relabel=False) 48 | gallery = _process_dir(self.gallery_dir, relabel=False) 49 | 50 | if verbose: 51 | print("=> MSMT17_V1 loaded") 52 | self.print_dataset_statistics(train, query, gallery) 53 | 54 | self.train = train 55 | self.query = query 56 | self.gallery = gallery 57 | 58 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 59 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 60 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 61 | 62 | def _check_before_run(self): 63 | """Check if all files are available before going deeper""" 64 | if not osp.exists(self.dataset_dir): 65 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 66 | if not osp.exists(self.train_dir): 67 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 68 | if not osp.exists(self.query_dir): 69 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 70 | if not osp.exists(self.gallery_dir): 71 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 72 | -------------------------------------------------------------------------------- /clustercontrast/datasets/personx.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | 6 | from ..utils.data import BaseImageDataset 7 | 8 | 9 | class PersonX(BaseImageDataset): 10 | """ 11 | PersonX 12 | Reference: 13 | Sun et al. Dissecting Person Re-identification from the Viewpoint of Viewpoint. CVPR 2019. 
14 | 15 | Dataset statistics: 16 | # identities: 1266 17 | # images: 9840 (train) + 5136 (query) + 30816 (gallery) 18 | """ 19 | dataset_dir = 'PersonX' 20 | 21 | def __init__(self, root, verbose=True, **kwargs): 22 | super(PersonX, self).__init__() 23 | self.dataset_dir = osp.join(root, self.dataset_dir) 24 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 25 | self.query_dir = osp.join(self.dataset_dir, 'query') 26 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 27 | 28 | self._check_before_run() 29 | 30 | train = self._process_dir(self.train_dir, relabel=True) 31 | query = self._process_dir(self.query_dir, relabel=False) 32 | gallery = self._process_dir(self.gallery_dir, relabel=False) 33 | 34 | if verbose: 35 | print("=> PersonX loaded") 36 | self.print_dataset_statistics(train, query, gallery) 37 | 38 | self.train = train 39 | self.query = query 40 | self.gallery = gallery 41 | 42 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 43 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 44 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 45 | 46 | def _check_before_run(self): 47 | """Check if all files are available before going deeper""" 48 | if not osp.exists(self.dataset_dir): 49 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 50 | if not osp.exists(self.train_dir): 51 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 52 | if not osp.exists(self.query_dir): 53 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 54 | if not osp.exists(self.gallery_dir): 55 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 56 | 57 | def _process_dir(self, dir_path, relabel=False): 58 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 59 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 60 | cam2label = {3: 1, 4: 2, 8: 3, 10: 4, 11: 5, 12: 6} 61 | 62 | pid_container = set() 63 | for img_path in img_paths: 64 | pid, _ = map(int, pattern.search(img_path).groups()) 65 | pid_container.add(pid) 66 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 67 | 68 | dataset = [] 69 | for img_path in img_paths: 70 | pid, camid = map(int, pattern.search(img_path).groups()) 71 | assert (camid in cam2label.keys()) 72 | camid = cam2label[camid] 73 | camid -= 1 # index starts from 0 74 | if relabel: pid = pid2label[pid] 75 | dataset.append((img_path, pid, camid)) 76 | 77 | return dataset 78 | -------------------------------------------------------------------------------- /clustercontrast/datasets/regdb_ir.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class regdb_ir(BaseImageDataset): 9 | """ 10 | regdb_ir 11 | train on Market1501-format converted data 12 | test on the original RegDB data 13 | """ 14 | dataset_dir = 'regdb/ir_modify/' 15 | 16 | def __init__(self, root, trial= 0,verbose=True, **kwargs): 17 | super(regdb_ir, self).__init__() 18 | # print('regdb_ir',trial) 19 | # root='./data/' 20 | dataset_dir = '/data0/ReIDData/RegDB/ir_modify/' # NOTE: hard-coded dataset root; change to your local path 21 | self.dataset_dir = dataset_dir 22 | self.train_dir = osp.join(self.dataset_dir, str(trial)+'/'+'bounding_box_train') 23 | 24 | self.query_dir = osp.join(self.dataset_dir, 
str(trial)+'/'+'query') 25 | self.gallery_dir = osp.join(self.dataset_dir, str(trial)+'/'+'bounding_box_test') 26 | 27 | 28 | self._check_before_run() 29 | 30 | train = self._process_dir(self.train_dir, relabel=True) 31 | query = self._process_dir(self.query_dir, relabel=False) 32 | gallery = self._process_dir(self.gallery_dir, relabel=False) 33 | 34 | if verbose: 35 | print("=> regdb_ir loaded",trial) 36 | self.print_dataset_statistics(train, query, gallery) 37 | 38 | self.train = train 39 | self.query = query 40 | self.gallery = gallery 41 | 42 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 43 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 44 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 45 | 46 | def _check_before_run(self): 47 | """Check if all files are available before going deeper""" 48 | if not osp.exists(self.dataset_dir): 49 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 50 | if not osp.exists(self.train_dir): 51 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 52 | if not osp.exists(self.query_dir): 53 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 54 | if not osp.exists(self.gallery_dir): 55 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 56 | 57 | def _process_dir(self, dir_path, relabel=False): 58 | img_paths = glob.glob(osp.join(dir_path, '*.bmp')) 59 | pattern = re.compile(r'([-\d]+)_c(\d)') 60 | 61 | pid_container = set() 62 | for img_path in img_paths: 63 | pid, _ = map(int, pattern.search(img_path).groups()) 64 | if pid == -1: 65 | continue # junk images are just ignored 66 | pid_container.add(pid) 67 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 68 | 69 | dataset = [] 70 | for img_path in img_paths: 71 | pid, camid = map(int, pattern.search(img_path).groups()) 72 | if pid == -1: 73 | continue # junk images are just ignored 74 | assert 0 <= pid <= 1501 # pid == 0 means background 75 | assert 1 <= camid <= 6 76 | camid -= 1 # index starts from 0 77 | if relabel: 78 | pid = pid2label[pid] 79 | dataset.append((img_path, pid, camid)) 80 | 81 | return dataset 82 | -------------------------------------------------------------------------------- /clustercontrast/datasets/regdb_rgb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class regdb_rgb(BaseImageDataset): 9 | """ 10 | regdb_rgb 11 | train on Market1501-format converted data 12 | test on the original RegDB data 13 | """ 14 | # dataset_dir = '/data0/ReIDData/RegDB/rgb_modify/' 15 | 16 | def __init__(self, root,trial=0, verbose=True, **kwargs): 17 | super(regdb_rgb, self).__init__() 18 | # root='./data/' 19 | # print('regdb_rgb',trial) 20 | dataset_dir = '/data0/ReIDData/RegDB/rgb_modify/' # NOTE: hard-coded dataset root; change to your local path 21 | self.dataset_dir = dataset_dir 22 | self.train_dir = osp.join(self.dataset_dir, str(trial)+'/'+'bounding_box_train') 23 | 24 | 25 | 26 | self.query_dir = osp.join(self.dataset_dir, str(trial)+'/'+'query')#osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, str(trial)+'/'+'bounding_box_test') 28 | 29 | self._check_before_run() 30 | 31 | train = self._process_dir(self.train_dir, relabel=True) 32 | query = 
self._process_dir(self.query_dir, relabel=False) 33 | gallery = self._process_dir(self.gallery_dir, relabel=False) 34 | 35 | if verbose: 36 | print("=> regdb_rgb loaded",trial) 37 | self.print_dataset_statistics(train, query, gallery) 38 | 39 | self.train = train 40 | self.query = query 41 | self.gallery = gallery 42 | 43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 46 | 47 | def _check_before_run(self): 48 | """Check if all files are available before going deeper""" 49 | if not osp.exists(self.dataset_dir): 50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 51 | if not osp.exists(self.train_dir): 52 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 53 | if not osp.exists(self.query_dir): 54 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 55 | if not osp.exists(self.gallery_dir): 56 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 57 | 58 | def _process_dir(self, dir_path, relabel=False): 59 | img_paths = glob.glob(osp.join(dir_path, '*.bmp')) 60 | pattern = re.compile(r'([-\d]+)_c(\d)') 61 | 62 | pid_container = set() 63 | for img_path in img_paths: 64 | pid, _ = map(int, pattern.search(img_path).groups()) 65 | if pid == -1: 66 | continue # junk images are just ignored 67 | pid_container.add(pid) 68 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 69 | 70 | dataset = [] 71 | for img_path in img_paths: 72 | pid, camid = map(int, pattern.search(img_path).groups()) 73 | if pid == -1: 74 | continue # junk images are just ignored 75 | assert 0 <= pid <= 1501 # pid == 0 means background 76 | assert 1 <= camid <= 6 77 | camid -= 1 # index starts from 0 78 | if relabel: 79 | pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /clustercontrast/datasets/sysu_all.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class sysu_all(BaseImageDataset): 9 | """ 10 | sysu_all: SYSU-MM01 (visible + infrared cameras) 11 | train on Market1501-format data produced by prepare_sysu.py 12 | test on the original SYSU-MM01 data 
13 | 14 | 15 | Dataset statistics (SYSU-MM01): 16 | # identities: 491 (395 for training) 17 | # cameras: 6 (4 visible + 2 infrared) 18 | """ 19 | dataset_dir = 'sysu/all_modify/' 20 | 21 | def __init__(self, root, verbose=True, **kwargs): 22 | super(sysu_all, self).__init__() 23 | root='/dat01/yangbin/data/' # NOTE: hard-coded dataset root; change to your local path 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 28 | 29 | self._check_before_run() 30 | 31 | train = self._process_dir(self.train_dir, relabel=True) 32 | query = self._process_dir(self.query_dir, relabel=False) 33 | gallery = self._process_dir(self.gallery_dir, relabel=False) 34 | 35 | if verbose: 36 | print("=> sysu_all loaded") 37 | self.print_dataset_statistics(train, query, gallery) 38 | 39 | self.train = train 40 | self.query = query 41 | self.gallery = gallery 42 | 43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 46 | 47 | def _check_before_run(self): 48 | """Check if all files are available before going deeper""" 49 | if not osp.exists(self.dataset_dir): 50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 51 | if not osp.exists(self.train_dir): 52 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 53 | if not osp.exists(self.query_dir): 54 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 55 | if not osp.exists(self.gallery_dir): 56 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 57 | 58 | def _process_dir(self, dir_path, relabel=False): 59 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 60 | pattern = re.compile(r'([-\d]+)_c(\d)') 61 | 62 | pid_container = set() 63 | for img_path in img_paths: 64 | pid, _ = map(int, pattern.search(img_path).groups()) 65 | if pid == -1: 66 | continue # junk images are just ignored 67 | pid_container.add(pid) 68 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 69 | 70 | dataset = [] 71 | for img_path in img_paths: 72 | pid, camid = map(int, pattern.search(img_path).groups()) 73 | if pid == -1: 74 | continue # junk images are just ignored 75 | assert 0 <= pid <= 1501 # pid == 0 means background 76 | assert 1 <= camid <= 6 77 | camid -= 1 # index starts from 0 78 | if relabel: 79 | pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /clustercontrast/datasets/sysu_ir.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class sysu_ir(BaseImageDataset): 9 | """ 10 | sysu_ir 11 | train on Market1501-format converted data 12 | test on the original SYSU-MM01 data 13 | """ 14 | dataset_dir = 'ir_modify/' 15 | 16 | def __init__(self, root, verbose=True, **kwargs): 17 | super(sysu_ir, self).__init__() 18 | root='/data0/data_wzs/SYSU-MM01-Original/SYSU-MM01' # NOTE: hard-coded dataset root; change to your local path 19 | self.dataset_dir = osp.join(root, 
self.dataset_dir) 20 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 21 | 22 | self.query_dir = osp.join(self.dataset_dir, 'query')# not used 23 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') # not used 24 | 25 | 26 | self._check_before_run() 27 | 28 | train = self._process_dir(self.train_dir, relabel=True) 29 | query = self._process_dir(self.query_dir, relabel=False) 30 | gallery = self._process_dir(self.gallery_dir, relabel=False) 31 | 32 | if verbose: 33 | print("=> sysu_ir loaded") 34 | self.print_dataset_statistics(train, query, gallery) 35 | 36 | self.train = train 37 | self.query = query 38 | self.gallery = gallery 39 | 40 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 41 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 42 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 43 | 44 | def _check_before_run(self): 45 | """Check if all files are available before going deeper""" 46 | if not osp.exists(self.dataset_dir): 47 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 48 | if not osp.exists(self.train_dir): 49 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 50 | if not osp.exists(self.query_dir): 51 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 52 | if not osp.exists(self.gallery_dir): 53 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 54 | 55 | def _process_dir(self, dir_path, relabel=False): 56 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 57 | pattern = re.compile(r'([-\d]+)_c(\d)') 58 | 59 | pid_container = set() 60 | for img_path in img_paths: 61 | pid, _ = map(int, pattern.search(img_path).groups()) 62 | if pid == -1: 63 | continue # junk images are just ignored 64 | pid_container.add(pid) 65 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 66 | 67 | dataset = [] 68 | for img_path in img_paths: 69 | pid, camid = map(int, pattern.search(img_path).groups()) 70 | if pid == -1: 71 | continue # junk images are just ignored 72 | assert 0 <= pid <= 1501 # pid == 0 means background 73 | assert 1 <= camid <= 6 74 | camid -= 1 # index starts from 0 75 | if relabel: 76 | pid = pid2label[pid] 77 | dataset.append((img_path, pid, camid)) 78 | 79 | return dataset 80 | -------------------------------------------------------------------------------- /clustercontrast/datasets/sysu_rgb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class sysu_rgb(BaseImageDataset): 9 | """ 10 | sysu_rgb 11 | train on Market1501-format converted data 12 | test on the original SYSU-MM01 data 13 | """ 14 | dataset_dir = 'rgb_modify/' 15 | 16 | def __init__(self, root, verbose=True, **kwargs): 17 | super(sysu_rgb, self).__init__() 18 | root='/data0/data_wzs/SYSU-MM01-Original/SYSU-MM01/' # NOTE: hard-coded dataset root; change to your local path 19 | self.dataset_dir = osp.join(root, self.dataset_dir) 20 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')#for training 21 | 22 | self.query_dir = osp.join(self.dataset_dir, 'query')# not used 23 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')# not used 24 | 25 | self._check_before_run() 26 | 27 | train = self._process_dir(self.train_dir, relabel=True) 28 | query = 
self._process_dir(self.query_dir, relabel=False) 29 | gallery = self._process_dir(self.gallery_dir, relabel=False) 30 | 31 | if verbose: 32 | print("=> sysu_rgb loaded") 33 | self.print_dataset_statistics(train, query, gallery) 34 | 35 | self.train = train 36 | self.query = query 37 | self.gallery = gallery 38 | 39 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 40 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 41 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 42 | 43 | def _check_before_run(self): 44 | """Check if all files are available before going deeper""" 45 | if not osp.exists(self.dataset_dir): 46 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 47 | if not osp.exists(self.train_dir): 48 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 49 | if not osp.exists(self.query_dir): 50 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 51 | if not osp.exists(self.gallery_dir): 52 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 53 | 54 | def _process_dir(self, dir_path, relabel=False): 55 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 56 | pattern = re.compile(r'([-\d]+)_c(\d)') 57 | 58 | pid_container = set() 59 | for img_path in img_paths: 60 | pid, _ = map(int, pattern.search(img_path).groups()) 61 | if pid == -1: 62 | continue # junk images are just ignored 63 | pid_container.add(pid) 64 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 65 | 66 | dataset = [] 67 | for img_path in img_paths: 68 | pid, camid = map(int, pattern.search(img_path).groups()) 69 | if pid == -1: 70 | continue # junk images are just ignored 71 | assert 0 <= pid <= 1501 # pid == 0 means background 72 | assert 1 <= camid <= 6 73 | camid -= 1 # index starts from 0 74 | if relabel: 75 | pid = pid2label[pid] 76 | dataset.append((img_path, pid, camid)) 77 | 78 | return dataset 79 | -------------------------------------------------------------------------------- /clustercontrast/datasets/veri.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | 9 | from ..utils.data import BaseImageDataset 10 | 11 | 12 | class VeRi(BaseImageDataset): 13 | """ 14 | VeRi 15 | Reference: 16 | Liu, X., Liu, W., Ma, H., Fu, H.: Large-scale vehicle re-identification in urban surveillance videos. In: IEEE % 17 | International Conference on Multimedia and Expo. (2016) accepted. 
18 | Dataset statistics: 19 | # identities: 776 vehicles(576 for training and 200 for testing) 20 | # images: 37778 (train) + 11579 (query) 21 | """ 22 | dataset_dir = 'VeRi' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(VeRi, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 30 | 31 | self.check_before_run() 32 | 33 | train = self.process_dir(self.train_dir, relabel=True) 34 | query = self.process_dir(self.query_dir, relabel=False) 35 | gallery = self.process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print('=> VeRi loaded') 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 48 | 49 | def check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 55 | if not osp.exists(self.query_dir): 56 | raise RuntimeError('"{}" is not available'.format(self.query_dir)) 57 | if not osp.exists(self.gallery_dir): 58 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) 59 | 60 | def process_dir(self, dir_path, relabel=False): 61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 62 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 63 | 64 | pid_container = set() 65 | for img_path in img_paths: 66 | pid, _ = map(int, pattern.search(img_path).groups()) 67 | if pid == -1: 68 | continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: 76 | continue # junk images are just ignored 77 | assert 0 <= pid <= 776 # pid == 0 means background 78 | assert 1 <= camid <= 20 79 | camid -= 1 # index starts from 0 80 | if relabel: 81 | pid = pid2label[pid] 82 | dataset.append((img_path, pid, camid)) 83 | 84 | return dataset 85 | -------------------------------------------------------------------------------- /clustercontrast/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap' 10 | ] 11 | -------------------------------------------------------------------------------- /clustercontrast/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from ..utils import to_torch 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | with torch.no_grad(): 9 | output, target = to_torch(output), to_torch(target) 10 | 
maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | ret = [] 18 | for k in topk: 19 | correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True) # reshape: the sliced tensor is non-contiguous, so view(-1) would fail 20 | ret.append(correct_k.mul_(1. / batch_size)) 21 | return ret 22 | -------------------------------------------------------------------------------- /clustercontrast/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=bool) # np.bool was removed in NumPy 1.24 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False,regdb = False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | if regdb == True: 48 | valid = ((gallery_ids[indices[i]] == query_ids[i]) | 49 | (gallery_cams[indices[i]] == query_cams[i])) 50 | else: 51 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 52 | (gallery_cams[indices[i]] != query_cams[i])) 53 | if separate_camera_set: 54 | # Filter out samples from same camera 55 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 56 | if not np.any(matches[i, valid]): continue 57 | if single_gallery_shot: 58 | repeat = 10 59 | gids = gallery_ids[indices[i][valid]] 60 | inds = np.where(valid)[0] 61 | ids_dict = defaultdict(list) 62 | for j, x in zip(inds, gids): 63 | ids_dict[x].append(j) 64 | else: 65 | repeat = 1 66 | for _ in range(repeat): 67 | if single_gallery_shot: 68 | # Randomly choose one instance for each id 69 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 70 | index = np.nonzero(matches[i, sampled])[0] 71 | else: 72 | index = np.nonzero(matches[i, valid])[0] 73 | delta = 1. 
/ (len(index) * repeat) 74 | for j, k in enumerate(index): 75 | if k - j >= topk: break 76 | if first_match_break: 77 | ret[k - j] += 1 78 | break 79 | ret[k - j] += delta 80 | num_valid_queries += 1 81 | if num_valid_queries == 0: 82 | raise RuntimeError("No valid query") 83 | return ret.cumsum() / num_valid_queries 84 | 85 | 86 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 87 | query_cams=None, gallery_cams=None,regdb = False): 88 | distmat = to_numpy(distmat) 89 | m, n = distmat.shape 90 | # Fill up default values 91 | if query_ids is None: 92 | query_ids = np.arange(m) 93 | if gallery_ids is None: 94 | gallery_ids = np.arange(n) 95 | if query_cams is None: 96 | query_cams = np.zeros(m).astype(np.int32) 97 | if gallery_cams is None: 98 | gallery_cams = np.ones(n).astype(np.int32) 99 | # Ensure numpy array 100 | query_ids = np.asarray(query_ids) 101 | gallery_ids = np.asarray(gallery_ids) 102 | query_cams = np.asarray(query_cams) 103 | gallery_cams = np.asarray(gallery_cams) 104 | # Sort and find correct matches 105 | indices = np.argsort(distmat, axis=1) 106 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 107 | # Compute AP for each query 108 | aps = [] 109 | for i in range(m): 110 | # Filter out the same id and same camera 111 | if regdb == True: 112 | valid = ((gallery_ids[indices[i]] == query_ids[i]) | 113 | (gallery_cams[indices[i]] == query_cams[i])) 114 | else: 115 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 116 | (gallery_cams[indices[i]] != query_cams[i])) 117 | y_true = matches[i, valid] 118 | y_score = -distmat[i][indices[i]][valid] 119 | if not np.any(y_true): continue 120 | aps.append(average_precision_score(y_true, y_score)) 121 | if len(aps) == 0: 122 | raise RuntimeError("No valid query") 123 | return np.mean(aps) 124 | -------------------------------------------------------------------------------- /clustercontrast/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import collections 4 | from collections import OrderedDict 5 | import numpy as np 6 | import torch 7 | import random 8 | import copy 9 | 10 | from .evaluation_metrics import cmc, mean_ap 11 | from .utils.meters import AverageMeter 12 | from .utils.rerank import re_ranking 13 | from .utils import to_torch 14 | 15 | def fliplr(img): 16 | '''flip horizontal''' 17 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W 18 | img_flip = img.index_select(3,inv_idx) 19 | return img_flip 20 | 21 | def extract_cnn_feature(model, inputs,mode): 22 | inputs = to_torch(inputs).cuda() 23 | # inputs1 = inputs 24 | # print(inputs) 25 | outputs = model(inputs,inputs,modal=mode) 26 | outputs = outputs.data.cpu() 27 | return outputs 28 | 29 | 30 | def extract_features(model, data_loader, print_freq=50,flip=True,mode=0): 31 | model.eval() 32 | batch_time = AverageMeter() 33 | data_time = AverageMeter() 34 | 35 | features = OrderedDict() 36 | labels = OrderedDict() 37 | 38 | end = time.time() 39 | with torch.no_grad(): 40 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader): 41 | data_time.update(time.time() - end) 42 | 43 | outputs = extract_cnn_feature(model, imgs,mode) 44 | # horizontal-flip test-time augmentation, controlled by the flip flag (the local variable previously shadowed the parameter) 45 | imgs_flip = fliplr(imgs) if flip else imgs 46 | outputs_flip = extract_cnn_feature(model, imgs_flip,mode) 47 | 48 | for fname, output,output_flip,pid in zip(fnames, outputs,outputs_flip, pids): 49 | features[fname] = (output.detach() + output_flip.detach())/2.0 50 | labels[fname] = pid 51 
| 52 | batch_time.update(time.time() - end) 53 | end = time.time() 54 | 55 | if (i + 1) % print_freq == 0: 56 | print('Extract Features: [{}/{}]\t' 57 | 'Time {:.3f} ({:.3f})\t' 58 | 'Data {:.3f} ({:.3f})\t' 59 | .format(i + 1, len(data_loader), 60 | batch_time.val, batch_time.avg, 61 | data_time.val, data_time.avg)) 62 | 63 | return features, labels 64 | 65 | 66 | def pairwise_distance(features, query=None, gallery=None): 67 | if query is None and gallery is None: 68 | n = len(features) 69 | x = torch.cat(list(features.values())) 70 | x = x.view(n, -1) 71 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 # assumes L2-normalized features; row-wise ranking is then unaffected 72 | dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t()) 73 | return dist_m 74 | 75 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 76 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0) 77 | m, n = x.size(0), y.size(0) 78 | x = x.view(m, -1) 79 | y = y.view(n, -1) 80 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 81 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 82 | dist_m.addmm_(x, y.t(), beta=1, alpha=-2) # keyword form; the positional (beta, alpha, ...) overload was removed in newer PyTorch 83 | return dist_m, x.numpy(), y.numpy() 84 | 85 | 86 | def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None, 87 | query_ids=None, gallery_ids=None, 88 | query_cams=None, gallery_cams=None, 89 | cmc_topk=(1, 5, 10), cmc_flag=False,regdb=False): 90 | if query is not None and gallery is not None: 91 | query_ids = [pid for _, pid, _ in query] 92 | gallery_ids = [pid for _, pid, _ in gallery] 93 | query_cams = [cam for _, _, cam in query] 94 | gallery_cams = [cam for _, _, cam in gallery] 95 | else: 96 | assert (query_ids is not None and gallery_ids is not None 97 | and query_cams is not None and gallery_cams is not None) 98 | 99 | # Compute mean AP 100 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams,regdb=regdb) 101 | print('Mean AP: {:4.1%}'.format(mAP)) 102 | 103 | if (not cmc_flag): 104 | return mAP 105 | 106 | cmc_configs = { 107 | 'market1501': dict(separate_camera_set=False, 108 | single_gallery_shot=False, 109 | first_match_break=True),} 110 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 111 | query_cams, gallery_cams,regdb=regdb, **params) 112 | for name, params in cmc_configs.items()} 113 | 114 | print('CMC Scores:') 115 | for k in cmc_topk: 116 | print(' top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1])) 117 | return cmc_scores['market1501'], mAP 118 | 119 | 120 | class Evaluator(object): 121 | def __init__(self, model): 122 | super(Evaluator, self).__init__() 123 | self.model = model 124 | 125 | def evaluate(self, data_loader, query, gallery, cmc_flag=False, rerank=False,modal=0,regdb=False): 126 | features, _ = extract_features(self.model, data_loader,mode=modal) 127 | # print(features,features) 128 | distmat, query_features, gallery_features = pairwise_distance(features, query, gallery) 129 | # print(distmat) 130 | results = evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag,regdb=regdb) 131 | 132 | if (not rerank): 133 | return results 134 | 135 | print('Applying person re-ranking ...') 136 | distmat_qq, _, _ = pairwise_distance(features, query, query) 137 | distmat_gg, _, _ = pairwise_distance(features, gallery, gallery) 138 | distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy()) 139 | return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 140 | 
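For reference, a minimal sketch of how the `Evaluator` above is typically driven for cross-modality testing; `model`, `test_loader`, and the `query`/`gallery` lists of `(fname, pid, camid)` tuples are assumed to be built elsewhere (e.g. in `test_sysu.py`), and the `modal` value shown is illustrative:

```python
# minimal usage sketch, assuming model / test_loader / query / gallery exist
from clustercontrast.evaluators import Evaluator

evaluator = Evaluator(model)
# cmc_flag=True returns (cmc_scores, mAP); modal selects the forward branch
# of the two-stream model; regdb=True switches the valid-match rule in
# ranking.py to the RegDB convention
cmc_scores, mAP = evaluator.evaluate(test_loader, query, gallery,
                                     cmc_flag=True, modal=1, regdb=False)
print('top-1: {:.1%}  mAP: {:.1%}'.format(cmc_scores[0], mAP))
```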
-------------------------------------------------------------------------------- /clustercontrast/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnet_ibn import * 5 | 6 | from .agw import * 7 | __factory = { 8 | 'resnet18': resnet18, 9 | 'resnet34': resnet34, 10 | 'resnet50': resnet50, 11 | 'resnet101': resnet101, 12 | 'resnet152': resnet152, 13 | 'resnet_ibn50a': resnet_ibn50a, 14 | 'resnet_ibn101a': resnet_ibn101a, 15 | 'agw': agw 16 | } 17 | 18 | 19 | def names(): 20 | return sorted(__factory.keys()) 21 | 22 | 23 | def create(name, *args, **kwargs): 24 | """ 25 | Create a model instance. 26 | 27 | Parameters 28 | ---------- 29 | name : str 30 | Model name. Can be one of 'resnet18', 'resnet34', 'resnet50', 31 | 'resnet101', 'resnet152', 'resnet_ibn50a', 'resnet_ibn101a', and 'agw'. 32 | pretrained : bool, optional 33 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 34 | model. Default: True 35 | cut_at_pooling : bool, optional 36 | If True, will cut the model before the last global pooling layer and 37 | ignore the remaining kwargs. Default: False 38 | num_features : int, optional 39 | If positive, will append a Linear layer after the global pooling layer, 40 | with this number of output units, followed by a BatchNorm layer. 41 | Otherwise these layers will not be appended. Default: 0 for 42 | 'resnet*' models. 43 | norm : bool, optional 44 | If True, will normalize the feature to be unit L2-norm for each sample. 45 | Otherwise will append a ReLU layer after the above Linear layer if 46 | num_features > 0. Default: False 47 | dropout : float, optional 48 | If positive, will append a Dropout layer with this dropout rate. 49 | Default: 0 50 | num_classes : int, optional 51 | If positive, will append a Linear layer at the end as the classifier 52 | with this number of output units. Default: 0 53 | """ 54 | if name not in __factory: 55 | raise KeyError("Unknown model:", name) 56 | return __factory[name](*args, **kwargs) 57 | -------------------------------------------------------------------------------- /clustercontrast/models/agw.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from .resnet_agw import resnet50 as resnet50_agw 5 | 6 | class Normalize(nn.Module): 7 | def __init__(self, power=2): 8 | super(Normalize, self).__init__() 9 | self.power = power 10 | 11 | def forward(self, x): 12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1.
/ self.power) 13 | out = x.div(norm) 14 | return out 15 | 16 | class Non_local(nn.Module): 17 | def __init__(self, in_channels, reduc_ratio=2): 18 | super(Non_local, self).__init__() 19 | 20 | self.in_channels = in_channels 21 | self.inter_channels = reduc_ratio // reduc_ratio # always evaluates to 1; the reference AGW implementation uses in_channels // reduc_ratio instead 22 | 23 | self.g = nn.Sequential( 24 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 25 | padding=0), 26 | ) 27 | 28 | self.W = nn.Sequential( 29 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 30 | kernel_size=1, stride=1, padding=0), 31 | nn.BatchNorm2d(self.in_channels), 32 | ) 33 | nn.init.constant_(self.W[1].weight, 0.0) 34 | nn.init.constant_(self.W[1].bias, 0.0) 35 | 36 | 37 | 38 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 39 | kernel_size=1, stride=1, padding=0) 40 | 41 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 42 | kernel_size=1, stride=1, padding=0) 43 | 44 | def forward(self, x): 45 | ''' 46 | :param x: (b, c, h, w) 47 | :return: 48 | ''' 49 | 50 | batch_size = x.size(0) 51 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 52 | g_x = g_x.permute(0, 2, 1) 53 | 54 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 55 | theta_x = theta_x.permute(0, 2, 1) 56 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 57 | f = torch.matmul(theta_x, phi_x) 58 | N = f.size(-1) 59 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 60 | f_div_C = f / N 61 | 62 | y = torch.matmul(f_div_C, g_x) 63 | y = y.permute(0, 2, 1).contiguous() 64 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 65 | W_y = self.W(y) 66 | z = W_y + x 67 | 68 | return z 69 | 70 | 71 | # ##################################################################### 72 | def weights_init_kaiming(m): 73 | classname = m.__class__.__name__ 74 | # print(classname) 75 | if classname.find('Conv') != -1: 76 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 77 | elif classname.find('Linear') != -1: 78 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 79 | init.zeros_(m.bias.data) 80 | elif classname.find('BatchNorm1d') != -1: 81 | init.normal_(m.weight.data, 1.0, 0.01) 82 | init.zeros_(m.bias.data) 83 | 84 | def weights_init_classifier(m): 85 | classname = m.__class__.__name__ 86 | if classname.find('Linear') != -1: 87 | init.normal_(m.weight.data, 0, 0.001) 88 | if m.bias is not None: 89 | init.zeros_(m.bias.data) 90 | 91 | 92 | 93 | class visible_module(nn.Module): 94 | def __init__(self, arch='resnet50'): 95 | super(visible_module, self).__init__() 96 | 97 | model_v = resnet50_agw(pretrained=True, 98 | last_conv_stride=1, last_conv_dilation=1) 99 | # avg pooling to global pooling 100 | self.visible = model_v 101 | 102 | def forward(self, x): 103 | x = self.visible.conv1(x) 104 | x = self.visible.bn1(x) 105 | x = self.visible.relu(x) 106 | x = self.visible.maxpool(x) 107 | return x 108 | 109 | 110 | class thermal_module(nn.Module): 111 | def __init__(self, arch='resnet50'): 112 | super(thermal_module, self).__init__() 113 | 114 | model_t = resnet50_agw(pretrained=True, 115 | last_conv_stride=1, last_conv_dilation=1) 116 | # avg pooling to global pooling 117 | self.thermal = model_t 118 | 119 | def forward(self, x): 120 | x = self.thermal.conv1(x) 121 | x = self.thermal.bn1(x) 122 | x = self.thermal.relu(x) 123 | x = self.thermal.maxpool(x) 124 | return x 125 | 126 | 127 | class base_resnet(nn.Module): 128 | def __init__(self, arch='resnet50'):
129 | super(base_resnet, self).__init__() 130 | 131 | model_base = resnet50_agw(pretrained=True, 132 | last_conv_stride=1, last_conv_dilation=1) 133 | # avg pooling to global pooling 134 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 135 | self.base = model_base 136 | 137 | def forward(self, x): 138 | x = self.base.layer1(x) 139 | x = self.base.layer2(x) 140 | x = self.base.layer3(x) 141 | x = self.base.layer4(x) 142 | return x 143 | 144 | ##### 145 | class embed_net_ori(nn.Module): 146 | def __init__(self, num_classes=1000, no_local= 'on', gm_pool = 'on', arch='resnet50'): 147 | super(embed_net_ori, self).__init__() 148 | 149 | self.thermal_module = thermal_module(arch=arch) 150 | self.visible_module = visible_module(arch=arch) 151 | self.base_resnet = base_resnet(arch=arch) 152 | self.non_local = no_local 153 | if self.non_local =='on': 154 | layers=[3, 4, 6, 3] 155 | non_layers=[0,2,3,0] 156 | self.NL_1 = nn.ModuleList( 157 | [Non_local(256) for i in range(non_layers[0])]) 158 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 159 | self.NL_2 = nn.ModuleList( 160 | [Non_local(512) for i in range(non_layers[1])]) 161 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 162 | self.NL_3 = nn.ModuleList( 163 | [Non_local(1024) for i in range(non_layers[2])]) 164 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 165 | self.NL_4 = nn.ModuleList( 166 | [Non_local(2048) for i in range(non_layers[3])]) 167 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 168 | 169 | 170 | pool_dim = 2048 171 | self.num_features = pool_dim 172 | self.l2norm = Normalize(2) 173 | self.bottleneck = nn.BatchNorm1d(pool_dim) 174 | self.bottleneck.bias.requires_grad_(False) # no shift 175 | 176 | self.classifier = nn.Linear(pool_dim, num_classes, bias=False) 177 | 178 | self.bottleneck.apply(weights_init_kaiming) 179 | self.classifier.apply(weights_init_classifier) 180 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 181 | self.gm_pool = gm_pool 182 | 183 | def forward(self, x1, x2, modal=0,label_1=None,label_2=None): 184 | # print(x1,x2) 185 | single_size = x1.size(0) 186 | if modal == 0: 187 | x1 = self.visible_module(x1) 188 | x2 = self.thermal_module(x2) 189 | x = torch.cat((x1, x2), 0) 190 | label = torch.cat((label_1, label_2), -1) 191 | elif modal == 1: 192 | x = self.visible_module(x1) 193 | elif modal == 2: 194 | x = self.thermal_module(x2) 195 | 196 | # shared block 197 | if self.non_local == 'on': 198 | NL1_counter = 0 199 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 200 | for i in range(len(self.base_resnet.base.layer1)): 201 | x = self.base_resnet.base.layer1[i](x) 202 | if i == self.NL_1_idx[NL1_counter]: 203 | _, C, H, W = x.shape 204 | x = self.NL_1[NL1_counter](x) 205 | NL1_counter += 1 206 | # Layer 2 207 | NL2_counter = 0 208 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 209 | for i in range(len(self.base_resnet.base.layer2)): 210 | x = self.base_resnet.base.layer2[i](x) 211 | if i == self.NL_2_idx[NL2_counter]: 212 | _, C, H, W = x.shape 213 | x = self.NL_2[NL2_counter](x) 214 | NL2_counter += 1 215 | # Layer 3 216 | NL3_counter = 0 217 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 218 | for i in range(len(self.base_resnet.base.layer3)): 219 | x = self.base_resnet.base.layer3[i](x) 220 | if i == self.NL_3_idx[NL3_counter]: 221 | _, C, H, W = x.shape 222 | x = self.NL_3[NL3_counter](x) 223 | NL3_counter += 1 224 | # Layer 4 225 | NL4_counter = 0 226 | if len(self.NL_4_idx) == 
0: self.NL_4_idx = [-1] 227 | for i in range(len(self.base_resnet.base.layer4)): 228 | x = self.base_resnet.base.layer4[i](x) 229 | if i == self.NL_4_idx[NL4_counter]: 230 | _, C, H, W = x.shape 231 | x = self.NL_4[NL4_counter](x) 232 | NL4_counter += 1 233 | else: 234 | x = self.base_resnet(x) 235 | if self.gm_pool == 'on': 236 | b, c, h, w = x.shape 237 | x = x.view(b, c, -1) 238 | p = 3.0 239 | x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 240 | else: 241 | x_pool = self.avgpool(x) 242 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 243 | 244 | feat = self.bottleneck(x_pool) 245 | 246 | 247 | if self.training: 248 | return feat,feat[:single_size],feat[single_size:],label_1,label_2,x_pool[:single_size],x_pool[single_size:] 249 | # x_pool#, self.classifier(feat) 250 | else: 251 | # return self.l2norm(x_pool), self.l2norm(feat) 252 | return self.l2norm(feat)#self.l2norm(x_pool)#, 253 | 254 | # if self.training: 255 | # return x_pool, self.classifier(feat) 256 | # else: 257 | # return self.l2norm(x_pool), self.l2norm(feat) 258 | 259 | 260 | 261 | 262 | def agw(pretrained=False,no_local= 'on', **kwargs): 263 | """Constructs a ResNet-50 model. 264 | Args: 265 | pretrained (bool): If True, returns a model pre-trained on ImageNet 266 | """ 267 | model = embed_net_ori(no_local= 'on', gm_pool = 'on') #without no-local -> resnet with non-local->agw 268 | 269 | return model -------------------------------------------------------------------------------- /clustercontrast/models/cm.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | from abc import ABC 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, autograd 7 | 8 | 9 | class CM(autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, inputs, targets, features, momentum): 13 | ctx.features = features 14 | ctx.momentum = momentum 15 | ctx.save_for_backward(inputs, targets) 16 | outputs = inputs.mm(ctx.features.t()) 17 | 18 | return outputs 19 | 20 | @staticmethod 21 | def backward(ctx, grad_outputs): 22 | inputs, targets = ctx.saved_tensors 23 | grad_inputs = None 24 | if ctx.needs_input_grad[0]: 25 | grad_inputs = grad_outputs.mm(ctx.features) 26 | 27 | # momentum update 28 | for x, y in zip(inputs, targets): 29 | ctx.features[y] = ctx.momentum * ctx.features[y] + (1. 
- ctx.momentum) * x 30 | ctx.features[y] /= ctx.features[y].norm() 31 | 32 | return grad_inputs, None, None, None 33 | 34 | 35 | def cm(inputs, indexes, features, momentum=0.5): 36 | return CM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device)) 37 | 38 | 39 | class CM_Hard(autograd.Function): 40 | 41 | @staticmethod 42 | def forward(ctx, inputs, targets, features, momentum): 43 | ctx.features = features 44 | ctx.momentum = momentum 45 | ctx.save_for_backward(inputs, targets) 46 | outputs = inputs.mm(ctx.features.t()) 47 | 48 | return outputs 49 | 50 | @staticmethod 51 | def backward(ctx, grad_outputs): 52 | inputs, targets = ctx.saved_tensors 53 | grad_inputs = None 54 | if ctx.needs_input_grad[0]: 55 | grad_inputs = grad_outputs.mm(ctx.features) 56 | 57 | batch_centers = collections.defaultdict(list) 58 | for instance_feature, index in zip(inputs, targets.tolist()): 59 | batch_centers[index].append(instance_feature) 60 | 61 | for index, features in batch_centers.items(): 62 | distances = [] 63 | for feature in features: 64 | distance = feature.unsqueeze(0).mm(ctx.features[index].unsqueeze(0).t())[0][0] 65 | distances.append(distance.cpu().numpy()) 66 | 67 | median = np.argmin(np.array(distances)) 68 | ctx.features[index] = ctx.features[index] * ctx.momentum + (1 - ctx.momentum) * features[median] 69 | ctx.features[index] /= ctx.features[index].norm() 70 | 71 | return grad_inputs, None, None, None 72 | 73 | 74 | def cm_hard(inputs, indexes, features, momentum=0.5): 75 | return CM_Hard.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device)) 76 | 77 | 78 | class ClusterMemory(nn.Module, ABC): 79 | def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2, use_hard=False): 80 | super(ClusterMemory, self).__init__() 81 | self.num_features = num_features 82 | self.num_samples = num_samples 83 | 84 | self.momentum = momentum 85 | self.temp = temp 86 | self.use_hard = use_hard 87 | 88 | self.register_buffer('features', torch.zeros(num_samples, num_features)) 89 | 90 | def forward(self, inputs, targets, momentum=None): 91 | if not momentum: 92 | inputs = F.normalize(inputs, dim=1).cuda() 93 | if self.use_hard: 94 | outputs = cm_hard(inputs, targets, self.features, self.momentum) 95 | else: 96 | outputs = cm(inputs, targets, self.features, self.momentum) 97 | outputs /= self.temp 98 | loss = F.cross_entropy(outputs, targets) 99 | return loss 100 | # dynamic ccl loss 101 | # lamda = 0.5 102 | # weight = torch.pow((1 - F.softmax(outputs,1)), lamda) 103 | # outputs = F.softmax(outputs, dim=1) 104 | # loss = 0.0 105 | # for b in range(targets.shape[0]): 106 | # loss -= torch.pow((1 - outputs[b][targets[b]]), lamda) * torch.log(outputs[b][targets[b]]) 107 | # return loss / targets.shape[0] 108 | # cross specific momentum 109 | else: 110 | inputs = F.normalize(inputs, dim=1).cuda() 111 | if self.use_hard: 112 | outputs = cm_hard(inputs, targets, self.features, momentum) 113 | else: 114 | outputs = cm(inputs, targets, self.features, momentum) 115 | outputs /= self.temp 116 | loss = F.cross_entropy(outputs, targets) 117 | return loss 118 | # dynamic ccl loss 119 | # lamda = 0.5 120 | # weight = torch.pow((1 - F.softmax(outputs,1)), lamda) 121 | # outputs = F.softmax(outputs, dim=1) 122 | # loss = 0.0 123 | # for b in range(targets.shape[0]): 124 | # loss -= torch.pow((1 - outputs[b][targets[b]]), lamda) * torch.log(outputs[b][targets[b]]) 125 | # return loss / targets.shape[0] 126 | 
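ClusterMemory above is the cluster-level InfoNCE head used throughout this repo: the forward pass scores each feature against the whole centroid bank, and the custom backward of CM / CM_Hard momentum-updates and re-normalizes the matched centroids. The following is a minimal wiring sketch, an assumption rather than repo code; the random centroid initialization and the `feats` tensor are stand-ins for the clustering step and the encoder outputs.

# --- usage sketch; shapes and initialization are illustrative stand-ins ---
import torch
import torch.nn.functional as F

num_clusters, feat_dim = 500, 2048
memory = ClusterMemory(feat_dim, num_clusters, temp=0.05,
                       momentum=0.2, use_hard=False).cuda()
# seed the memory bank with L2-normalized cluster centroids (random here)
memory.features = F.normalize(torch.randn(num_clusters, feat_dim), dim=1).cuda()

feats = torch.randn(32, feat_dim, device='cuda', requires_grad=True)  # encoder outputs
pseudo_labels = torch.randint(0, num_clusters, (32,), device='cuda')  # cluster ids
loss = memory(feats, pseudo_labels)  # cross-entropy over similarities / temp
loss.backward()  # grads flow to feats; the bank is updated inside CM.backward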
-------------------------------------------------------------------------------- /clustercontrast/models/dsbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Domain-specific BatchNorm 5 | 6 | class DSBN2d(nn.Module): 7 | def __init__(self, planes): 8 | super(DSBN2d, self).__init__() 9 | self.num_features = planes 10 | self.BN_S = nn.BatchNorm2d(planes) 11 | self.BN_T = nn.BatchNorm2d(planes) 12 | 13 | def forward(self, x): 14 | if (not self.training): 15 | return self.BN_T(x) 16 | 17 | bs = x.size(0) 18 | assert (bs%2==0) 19 | split = torch.split(x, int(bs/2), 0) 20 | out1 = self.BN_S(split[0].contiguous()) 21 | out2 = self.BN_T(split[1].contiguous()) 22 | out = torch.cat((out1, out2), 0) 23 | return out 24 | 25 | class DSBN1d(nn.Module): 26 | def __init__(self, planes): 27 | super(DSBN1d, self).__init__() 28 | self.num_features = planes 29 | self.BN_S = nn.BatchNorm1d(planes) 30 | self.BN_T = nn.BatchNorm1d(planes) 31 | 32 | def forward(self, x): 33 | if (not self.training): 34 | return self.BN_T(x) 35 | 36 | bs = x.size(0) 37 | assert (bs%2==0) 38 | split = torch.split(x, int(bs/2), 0) 39 | out1 = self.BN_S(split[0].contiguous()) 40 | out2 = self.BN_T(split[1].contiguous()) 41 | out = torch.cat((out1, out2), 0) 42 | return out 43 | 44 | def convert_dsbn(model): 45 | for _, (child_name, child) in enumerate(model.named_children()): 46 | assert(not next(model.parameters()).is_cuda) 47 | if isinstance(child, nn.BatchNorm2d): 48 | m = DSBN2d(child.num_features) 49 | m.BN_S.load_state_dict(child.state_dict()) 50 | m.BN_T.load_state_dict(child.state_dict()) 51 | setattr(model, child_name, m) 52 | elif isinstance(child, nn.BatchNorm1d): 53 | m = DSBN1d(child.num_features) 54 | m.BN_S.load_state_dict(child.state_dict()) 55 | m.BN_T.load_state_dict(child.state_dict()) 56 | setattr(model, child_name, m) 57 | else: 58 | convert_dsbn(child) 59 | 60 | def convert_bn(model, use_target=True): 61 | for _, (child_name, child) in enumerate(model.named_children()): 62 | assert(not next(model.parameters()).is_cuda) 63 | if isinstance(child, DSBN2d): 64 | m = nn.BatchNorm2d(child.num_features) 65 | if use_target: 66 | m.load_state_dict(child.BN_T.state_dict()) 67 | else: 68 | m.load_state_dict(child.BN_S.state_dict()) 69 | setattr(model, child_name, m) 70 | elif isinstance(child, DSBN1d): 71 | m = nn.BatchNorm1d(child.num_features) 72 | if use_target: 73 | m.load_state_dict(child.BN_T.state_dict()) 74 | else: 75 | m.load_state_dict(child.BN_S.state_dict()) 76 | setattr(model, child_name, m) 77 | else: 78 | convert_bn(child, use_target=use_target) 79 | -------------------------------------------------------------------------------- /clustercontrast/models/kmeans.py: -------------------------------------------------------------------------------- 1 | # Written by Yixiao Ge 2 | 3 | import warnings 4 | 5 | import faiss 6 | import torch 7 | 8 | from ..utils import to_numpy, to_torch 9 | 10 | __all__ = ["label_generator_kmeans"] 11 | 12 | 13 | @torch.no_grad() 14 | def label_generator_kmeans(features, num_classes=500, cuda=True): 15 | 16 | assert num_classes, "num_classes for kmeans is null" 17 | 18 | # k-means cluster by faiss 19 | cluster = faiss.Kmeans( 20 | features.size(-1), num_classes, niter=300, verbose=True, gpu=cuda 21 | ) 22 | 23 | cluster.train(to_numpy(features)) 24 | 25 | _, labels = cluster.index.search(to_numpy(features), 1) 26 | labels = labels.reshape(-1) 27 | 28 | centers = 
to_torch(cluster.centroids).float() 29 | # labels = to_torch(labels).long() 30 | 31 | # k-means does not have outlier points 32 | assert not (-1 in labels) 33 | 34 | return labels, centers, num_classes, None 35 | -------------------------------------------------------------------------------- /clustercontrast/models/pooling.py: -------------------------------------------------------------------------------- 1 | # Credit to https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/layers/pooling.py 2 | from abc import ABC 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | __all__ = [ 9 | "GeneralizedMeanPoolingPFpn", 10 | "GeneralizedMeanPoolingList", 11 | "GeneralizedMeanPoolingP", 12 | "AdaptiveAvgMaxPool2d", 13 | "FastGlobalAvgPool2d", 14 | "avg_pooling", 15 | "max_pooling", 16 | ] 17 | 18 | 19 | class GeneralizedMeanPoolingList(nn.Module, ABC): 20 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of 21 | several input planes. 22 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 23 | - At p = infinity, one gets Max Pooling 24 | - At p = 1, one gets Average Pooling 25 | The output is of size H x W, for any input size. 26 | The number of output features is equal to the number of input planes. 27 | Args: 28 | output_size: the target output size of the image of the form H x W. 29 | Can be a tuple (H, W) or a single H for a square image H x H 30 | H and W can be either a ``int``, or ``None`` which means the size 31 | will be the same as that of the input. 32 | """ 33 | 34 | def __init__(self, output_size=1, eps=1e-6): 35 | super(GeneralizedMeanPoolingList, self).__init__() 36 | self.output_size = output_size 37 | self.eps = eps 38 | 39 | def forward(self, x_list): 40 | outs = [] 41 | for x in x_list: 42 | x = x.clamp(min=self.eps) 43 | out = torch.nn.functional.adaptive_avg_pool2d(x, self.output_size) 44 | outs.append(out) 45 | return torch.stack(outs, -1).mean(-1) 46 | 47 | def __repr__(self): 48 | return ( 49 | self.__class__.__name__ 50 | + "(" 51 | + "output_size=" 52 | + str(self.output_size) 53 | + ")" 54 | ) 55 | 56 | 57 | class GeneralizedMeanPooling(nn.Module, ABC): 58 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of 59 | several input planes. 60 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 61 | - At p = infinity, one gets Max Pooling 62 | - At p = 1, one gets Average Pooling 63 | The output is of size H x W, for any input size. 64 | The number of output features is equal to the number of input planes. 65 | Args: 66 | output_size: the target output size of the image of the form H x W. 67 | Can be a tuple (H, W) or a single H for a square image H x H 68 | H and W can be either a ``int``, or ``None`` which means the size 69 | will be the same as that of the input. 
70 | """ 71 | 72 | def __init__(self, norm, output_size=1, eps=1e-6): 73 | super(GeneralizedMeanPooling, self).__init__() 74 | assert norm > 0 75 | self.p = float(norm) 76 | self.output_size = output_size 77 | self.eps = eps 78 | 79 | def forward(self, x): 80 | x = x.clamp(min=self.eps).pow(self.p) 81 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow( 82 | 1.0 / self.p 83 | ) 84 | 85 | def __repr__(self): 86 | return ( 87 | self.__class__.__name__ 88 | + "(" 89 | + str(self.p) 90 | + ", " 91 | + "output_size=" 92 | + str(self.output_size) 93 | + ")" 94 | ) 95 | 96 | 97 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling, ABC): 98 | """ Same, but norm is trainable 99 | """ 100 | 101 | def __init__(self, norm=3, output_size=1, eps=1e-6): 102 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 103 | self.p = nn.Parameter(torch.ones(1) * norm) 104 | 105 | 106 | class GeneralizedMeanPoolingFpn(nn.Module, ABC): 107 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of 108 | several input planes. 109 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 110 | - At p = infinity, one gets Max Pooling 111 | - At p = 1, one gets Average Pooling 112 | The output is of size H x W, for any input size. 113 | The number of output features is equal to the number of input planes. 114 | Args: 115 | output_size: the target output size of the image of the form H x W. 116 | Can be a tuple (H, W) or a single H for a square image H x H 117 | H and W can be either a ``int``, or ``None`` which means the size 118 | will be the same as that of the input. 119 | """ 120 | 121 | def __init__(self, norm, output_size=1, eps=1e-6): 122 | super(GeneralizedMeanPoolingFpn, self).__init__() 123 | assert norm > 0 124 | self.p = float(norm) 125 | self.output_size = output_size 126 | self.eps = eps 127 | 128 | def forward(self, x_lists): 129 | outs = [] 130 | for x in x_lists: 131 | x = x.clamp(min=self.eps).pow(self.p) 132 | out = torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow( 133 | 1.0 / self.p 134 | ) 135 | outs.append(out) 136 | return torch.cat(outs, 1) 137 | 138 | def __repr__(self): 139 | return ( 140 | self.__class__.__name__ 141 | + "(" 142 | + str(self.p) 143 | + ", " 144 | + "output_size=" 145 | + str(self.output_size) 146 | + ")" 147 | ) 148 | 149 | 150 | class GeneralizedMeanPoolingPFpn(GeneralizedMeanPoolingFpn, ABC): 151 | """ Same, but norm is trainable 152 | """ 153 | 154 | def __init__(self, norm=3, output_size=1, eps=1e-6): 155 | super(GeneralizedMeanPoolingPFpn, self).__init__(norm, output_size, eps) 156 | self.p = nn.Parameter(torch.ones(1) * norm) 157 | 158 | 159 | class AdaptiveAvgMaxPool2d(nn.Module, ABC): 160 | def __init__(self): 161 | super(AdaptiveAvgMaxPool2d, self).__init__() 162 | self.avgpool = FastGlobalAvgPool2d() 163 | 164 | def forward(self, x): 165 | x_avg = self.avgpool(x, self.output_size) 166 | x_max = F.adaptive_max_pool2d(x, 1) 167 | x = x_max + x_avg 168 | return x 169 | 170 | 171 | class FastGlobalAvgPool2d(nn.Module, ABC): 172 | def __init__(self, flatten=False): 173 | super(FastGlobalAvgPool2d, self).__init__() 174 | self.flatten = flatten 175 | 176 | def forward(self, x): 177 | if self.flatten: 178 | in_size = x.size() 179 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2) 180 | else: 181 | return ( 182 | x.view(x.size(0), x.size(1), -1) 183 | .mean(-1) 184 | .view(x.size(0), x.size(1), 1, 1) 185 | ) 186 | 187 | 188 | def avg_pooling(): 189 | return 
nn.AdaptiveAvgPool2d(1) 190 | # return FastGlobalAvgPool2d() 191 | 192 | 193 | def max_pooling(): 194 | return nn.AdaptiveMaxPool2d(1) 195 | 196 | 197 | class Flatten(nn.Module): 198 | def forward(self, input): 199 | return input.view(input.size(0), -1) 200 | 201 | 202 | __pooling_factory = { 203 | "avg": avg_pooling, 204 | "max": max_pooling, 205 | "gem": GeneralizedMeanPoolingP, 206 | "gemFpn": GeneralizedMeanPoolingPFpn, 207 | "gemList": GeneralizedMeanPoolingList, 208 | "avg+max": AdaptiveAvgMaxPool2d, 209 | } 210 | 211 | 212 | def pooling_names(): 213 | return sorted(__pooling_factory.keys()) 214 | 215 | 216 | def build_pooling_layer(name): 217 | """ 218 | Create a pooling layer. 219 | Parameters 220 | ---------- 221 | name : str 222 | The backbone name. 223 | """ 224 | if name not in __pooling_factory: 225 | raise KeyError("Unknown pooling layer:", name) 226 | return __pooling_factory[name]() -------------------------------------------------------------------------------- /clustercontrast/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn import init 5 | import torchvision 6 | import torch 7 | from .pooling import build_pooling_layer 8 | 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152'] 12 | 13 | 14 | class ResNet(nn.Module): 15 | __factory = { 16 | 18: torchvision.models.resnet18, 17 | 34: torchvision.models.resnet34, 18 | 50: torchvision.models.resnet50, 19 | 101: torchvision.models.resnet101, 20 | 152: torchvision.models.resnet152, 21 | } 22 | 23 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 24 | num_features=0, norm=False, dropout=0, num_classes=0, pooling_type='avg'): 25 | super(ResNet, self).__init__() 26 | self.pretrained = pretrained 27 | self.depth = depth 28 | self.cut_at_pooling = cut_at_pooling 29 | # Construct base (pretrained) resnet 30 | if depth not in ResNet.__factory: 31 | raise KeyError("Unsupported depth:", depth) 32 | # resnet = ResNet.__factory[depth](pretrained=pretrained) 33 | resnet = ResNet.__factory[depth](pretrained=False) 34 | resnet.load_state_dict( torch.load('/dat01/yangbin/cluster-contrast-reid-main/examples/pretrained/resnet50-19c8e357.pth')) 35 | 36 | resnet.layer4[0].conv2.stride = (1, 1) 37 | resnet.layer4[0].downsample[0].stride = (1, 1) 38 | self.base = nn.Sequential( 39 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 40 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 41 | 42 | self.gap = build_pooling_layer(pooling_type) 43 | 44 | if not self.cut_at_pooling: 45 | self.num_features = num_features 46 | self.norm = norm 47 | self.dropout = dropout 48 | self.has_embedding = num_features > 0 49 | self.num_classes = num_classes 50 | 51 | out_planes = resnet.fc.in_features 52 | 53 | # Append new layers 54 | if self.has_embedding: 55 | self.feat = nn.Linear(out_planes, self.num_features) 56 | self.feat_bn = nn.BatchNorm1d(self.num_features) 57 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 58 | init.constant_(self.feat.bias, 0) 59 | else: 60 | # Change the num_features to CNN output channels 61 | self.num_features = out_planes 62 | self.feat_bn = nn.BatchNorm1d(self.num_features) 63 | self.feat_bn.bias.requires_grad_(False) 64 | if self.dropout > 0: 65 | self.drop = nn.Dropout(self.dropout) 66 | if self.num_classes > 0: 67 | self.classifier = nn.Linear(self.num_features, 
self.num_classes, bias=False) 68 | init.normal_(self.classifier.weight, std=0.001) 69 | init.constant_(self.feat_bn.weight, 1) 70 | init.constant_(self.feat_bn.bias, 0) 71 | 72 | if not pretrained: 73 | self.reset_params() 74 | 75 | def forward(self, x): 76 | bs = x.size(0) 77 | x = self.base(x) 78 | 79 | x = self.gap(x) 80 | x = x.view(x.size(0), -1) 81 | 82 | if self.cut_at_pooling: 83 | return x 84 | 85 | if self.has_embedding: 86 | bn_x = self.feat_bn(self.feat(x)) 87 | else: 88 | bn_x = self.feat_bn(x) 89 | 90 | if (self.training is False): 91 | bn_x = F.normalize(bn_x) 92 | return bn_x 93 | 94 | if self.norm: 95 | bn_x = F.normalize(bn_x) 96 | elif self.has_embedding: 97 | bn_x = F.relu(bn_x) 98 | 99 | if self.dropout > 0: 100 | bn_x = self.drop(bn_x) 101 | 102 | if self.num_classes > 0: 103 | prob = self.classifier(bn_x) 104 | else: 105 | return bn_x 106 | 107 | return prob 108 | 109 | def reset_params(self): 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | init.kaiming_normal_(m.weight, mode='fan_out') 113 | if m.bias is not None: 114 | init.constant_(m.bias, 0) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | init.constant_(m.weight, 1) 117 | init.constant_(m.bias, 0) 118 | elif isinstance(m, nn.BatchNorm1d): 119 | init.constant_(m.weight, 1) 120 | init.constant_(m.bias, 0) 121 | elif isinstance(m, nn.Linear): 122 | init.normal_(m.weight, std=0.001) 123 | if m.bias is not None: 124 | init.constant_(m.bias, 0) 125 | 126 | 127 | def resnet18(**kwargs): 128 | return ResNet(18, **kwargs) 129 | 130 | 131 | def resnet34(**kwargs): 132 | return ResNet(34, **kwargs) 133 | 134 | 135 | def resnet50(**kwargs): 136 | return ResNet(50, **kwargs) 137 | 138 | 139 | def resnet101(**kwargs): 140 | return ResNet(101, **kwargs) 141 | 142 | 143 | def resnet152(**kwargs): 144 | return ResNet(152, **kwargs) 145 | -------------------------------------------------------------------------------- /clustercontrast/models/resnet_agw.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': './examples/pretrained/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = 
self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | # original padding is 1; original dilation is 1 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | m.weight.data.fill_(1) 117 | m.bias.data.zero_() 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def remove_fc(state_dict): 151 | """Remove the fc layer parameters from state_dict.""" 152 | # for key, value in state_dict.items(): 153 | for key, value in list(state_dict.items()): 154 | if key.startswith('fc.'): 155 | del state_dict[key] 156 | return state_dict 157 | 158 | 159 | def resnet18(pretrained=False, **kwargs): 160 | """Constructs a ResNet-18 model. 161 | Args: 162 | pretrained (bool): If True, returns a model pre-trained on ImageNet 163 | """ 164 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 165 | if pretrained: 166 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 167 | return model 168 | 169 | 170 | def resnet34(pretrained=False, **kwargs): 171 | """Constructs a ResNet-34 model. 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 187 | if pretrained: 188 | # model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 189 | model.load_state_dict(remove_fc(torch.load(model_urls['resnet50']))) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict( 201 | remove_fc(model_zoo.load_url(model_urls['resnet101']))) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 
207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict( 213 | remove_fc(model_zoo.load_url(model_urls['resnet152']))) 214 | return model -------------------------------------------------------------------------------- /clustercontrast/models/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | from .pooling import build_pooling_layer 9 | 10 | from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 11 | 12 | 13 | __all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a'] 14 | 15 | 16 | class ResNetIBN(nn.Module): 17 | __factory = { 18 | '50a': resnet50_ibn_a, 19 | '101a': resnet101_ibn_a 20 | } 21 | 22 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 23 | num_features=0, norm=False, dropout=0, num_classes=0, pooling_type='avg'): 24 | 25 | super(ResNetIBN, self).__init__() 26 | 27 | self.depth = depth 28 | self.pretrained = pretrained 29 | self.cut_at_pooling = cut_at_pooling 30 | 31 | resnet = ResNetIBN.__factory[depth](pretrained=pretrained) 32 | resnet.layer4[0].conv2.stride = (1, 1) 33 | resnet.layer4[0].downsample[0].stride = (1, 1) 34 | 35 | self.base = nn.Sequential( 36 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 37 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 38 | 39 | self.gap = build_pooling_layer(pooling_type) 40 | 41 | if not self.cut_at_pooling: 42 | self.num_features = num_features 43 | self.norm = norm 44 | self.dropout = dropout 45 | self.has_embedding = num_features > 0 46 | self.num_classes = num_classes 47 | 48 | out_planes = resnet.fc.in_features 49 | 50 | # Append new layers 51 | if self.has_embedding: 52 | self.feat = nn.Linear(out_planes, self.num_features) 53 | self.feat_bn = nn.BatchNorm1d(self.num_features) 54 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 55 | init.constant_(self.feat.bias, 0) 56 | else: 57 | # Change the num_features to CNN output channels 58 | self.num_features = out_planes 59 | self.feat_bn = nn.BatchNorm1d(self.num_features) 60 | self.feat_bn.bias.requires_grad_(False) 61 | if self.dropout > 0: 62 | self.drop = nn.Dropout(self.dropout) 63 | if self.num_classes > 0: 64 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 65 | init.normal_(self.classifier.weight, std=0.001) 66 | 67 | init.constant_(self.feat_bn.weight, 1) 68 | init.constant_(self.feat_bn.bias, 0) 69 | 70 | if not pretrained: 71 | self.reset_params() 72 | 73 | def forward(self, x): 74 | x = self.base(x) 75 | 76 | x = self.gap(x) 77 | x = x.view(x.size(0), -1) 78 | 79 | if self.cut_at_pooling: 80 | return x 81 | 82 | if self.has_embedding: 83 | bn_x = self.feat_bn(self.feat(x)) 84 | else: 85 | bn_x = self.feat_bn(x) 86 | 87 | if self.training is False: 88 | bn_x = F.normalize(bn_x) 89 | return bn_x 90 | 91 | if self.norm: 92 | bn_x = F.normalize(bn_x) 93 | elif self.has_embedding: 94 | bn_x = F.relu(bn_x) 95 | 96 | if self.dropout > 0: 97 | bn_x = self.drop(bn_x) 98 | 99 | if self.num_classes > 0: 100 | prob = self.classifier(bn_x) 101 | else: 102 | return bn_x 103 | 104 | return prob 105 | 106 | def reset_params(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | init.kaiming_normal_(m.weight, 
mode='fan_out') 110 | if m.bias is not None: 111 | init.constant_(m.bias, 0) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | init.constant_(m.weight, 1) 114 | init.constant_(m.bias, 0) 115 | elif isinstance(m, nn.BatchNorm1d): 116 | init.constant_(m.weight, 1) 117 | init.constant_(m.bias, 0) 118 | elif isinstance(m, nn.Linear): 119 | init.normal_(m.weight, std=0.001) 120 | if m.bias is not None: 121 | init.constant_(m.bias, 0) 122 | 123 | 124 | def resnet_ibn50a(**kwargs): 125 | return ResNetIBN('50a', **kwargs) 126 | 127 | 128 | def resnet_ibn101a(**kwargs): 129 | return ResNetIBN('101a', **kwargs) 130 | -------------------------------------------------------------------------------- /clustercontrast/models/resnet_ibn_a.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['ResNet', 'resnet50_ibn_a', 'resnet101_ibn_a'] 8 | 9 | 10 | model_urls = { 11 | 'ibn_resnet50a': './examples/pretrained/resnet50_ibn_a.pth.tar', 12 | 'ibn_resnet101a': './examples/pretrained/resnet101_ibn_a.pth.tar', 13 | } 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class IBN(nn.Module): 55 | def __init__(self, planes): 56 | super(IBN, self).__init__() 57 | half1 = int(planes/2) 58 | self.half = half1 59 | half2 = planes - half1 60 | self.IN = nn.InstanceNorm2d(half1, affine=True) 61 | self.BN = nn.BatchNorm2d(half2) 62 | 63 | def forward(self, x): 64 | split = torch.split(x, self.half, 1) 65 | out1 = self.IN(split[0].contiguous()) 66 | out2 = self.BN(split[1].contiguous()) 67 | out = torch.cat((out1, out2), 1) 68 | return out 69 | 70 | class Bottleneck(nn.Module): 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 74 | super(Bottleneck, self).__init__() 75 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 76 | if ibn: 77 | self.bn1 = IBN(planes) 78 | else: 79 | self.bn1 = nn.BatchNorm2d(planes) 80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 81 | padding=1, bias=False) 82 | self.bn2 = nn.BatchNorm2d(planes) 83 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 84 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = 
self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | residual = self.downsample(x) 105 | 106 | out += residual 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, block, layers, num_classes=1000): 115 | scale = 64 116 | self.inplanes = scale 117 | super(ResNet, self).__init__() 118 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 119 | bias=False) 120 | self.bn1 = nn.BatchNorm2d(scale) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 123 | self.layer1 = self._make_layer(block, scale, layers[0]) 124 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) 125 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) 126 | self.layer4 = self._make_layer(block, scale*8, layers[3], stride=2) 127 | self.avgpool = nn.AvgPool2d(7) 128 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. / n)) 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | elif isinstance(m, nn.InstanceNorm2d): 138 | m.weight.data.fill_(1) 139 | m.bias.data.zero_() 140 | 141 | def _make_layer(self, block, planes, blocks, stride=1): 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | nn.Conv2d(self.inplanes, planes * block.expansion, 146 | kernel_size=1, stride=stride, bias=False), 147 | nn.BatchNorm2d(planes * block.expansion), 148 | ) 149 | 150 | layers = [] 151 | ibn = True 152 | if planes == 512: 153 | ibn = False 154 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 155 | self.inplanes = planes * block.expansion 156 | for i in range(1, blocks): 157 | layers.append(block(self.inplanes, planes, ibn)) 158 | 159 | return nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | x = self.conv1(x) 163 | x = self.bn1(x) 164 | x = self.relu(x) 165 | x = self.maxpool(x) 166 | 167 | x = self.layer1(x) 168 | x = self.layer2(x) 169 | x = self.layer3(x) 170 | x = self.layer4(x) 171 | 172 | x = self.avgpool(x) 173 | x = x.view(x.size(0), -1) 174 | x = self.fc(x) 175 | 176 | return x 177 | 178 | 179 | def resnet50_ibn_a(pretrained=False, **kwargs): 180 | """Constructs a ResNet-50 model. 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | state_dict = torch.load(model_urls['ibn_resnet50a'], map_location=torch.device('cpu'))['state_dict'] 187 | state_dict = remove_module_key(state_dict) 188 | model.load_state_dict(state_dict) 189 | return model 190 | 191 | 192 | def resnet101_ibn_a(pretrained=False, **kwargs): 193 | """Constructs a ResNet-101 model. 
194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 198 | if pretrained: 199 | state_dict = torch.load(model_urls['ibn_resnet101a'], map_location=torch.device('cpu'))['state_dict'] 200 | state_dict = remove_module_key(state_dict) 201 | model.load_state_dict(state_dict) 202 | return model 203 | 204 | 205 | def remove_module_key(state_dict): 206 | for key in list(state_dict.keys()): 207 | if 'module' in key: 208 | state_dict[key.replace('module.','')] = state_dict.pop(key) 209 | return state_dict 210 | -------------------------------------------------------------------------------- /clustercontrast/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import time 4 | from .utils.meters import AverageMeter 5 | import torch.nn as nn 6 | import torch 7 | from torch.nn import functional as F 8 | import math 9 | 10 | 11 | 12 | 13 | def pdist_torch(emb1, emb2): 14 | ''' 15 | compute the euclidean distance matrix between embeddings1 and embeddings2 16 | using gpu 17 | ''' 18 | m, n = emb1.shape[0], emb2.shape[0] 19 | emb1_pow = torch.pow(emb1, 2).sum(dim = 1, keepdim = True).expand(m, n) 20 | emb2_pow = torch.pow(emb2, 2).sum(dim = 1, keepdim = True).expand(n, m).t() 21 | dist_mtx = emb1_pow + emb2_pow 22 | dist_mtx = dist_mtx.addmm_(1, -2, emb1, emb2.t()) 23 | # dist_mtx = dist_mtx.clamp(min = 1e-12) 24 | dist_mtx = dist_mtx.clamp(min = 1e-12).sqrt() 25 | return dist_mtx 26 | def softmax_weights(dist, mask): 27 | max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] 28 | diff = dist - max_v 29 | Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero 30 | W = torch.exp(diff) * mask / Z 31 | return W 32 | def normalize(x, axis=-1): 33 | """Normalizing to unit length along the specified dimension. 34 | Args: 35 | x: pytorch Variable 36 | Returns: 37 | x: pytorch Variable, same shape as input 38 | """ 39 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 40 | return x 41 | 42 | class ClusterContrastTrainer(object): 43 | def __init__(self, encoder, memory=None): 44 | super(ClusterContrastTrainer, self).__init__() 45 | self.encoder = encoder 46 | self.memory_ir = memory 47 | self.memory_rgb = memory 48 | # self.tri = TripletLoss_ADP(alpha = 1, gamma = 1, square = 1) 49 | def train(self, epoch, data_loader_ir, data_loader_rgb, optimizer, print_freq=10, train_iters=400, i2r=None, r2i=None): 50 | self.encoder.train() 51 | 52 | batch_time = AverageMeter() 53 | data_time = AverageMeter() 54 | losses = AverageMeter() 55 | criterion_tri = OriTripletLoss(256, 0.3) # (batchsize, margin); instantiated here but not used in this loop 56 | 57 | 58 | end = time.time() 59 | for i in range(train_iters): 60 | # load data 61 | inputs_ir = data_loader_ir.next() 62 | inputs_rgb = data_loader_rgb.next() 63 | data_time.update(time.time() - end) 64 | 65 | # process inputs 66 | inputs_ir, labels_ir, indexes_ir = self._parse_data_ir(inputs_ir) 67 | inputs_rgb, inputs_rgb1, labels_rgb, indexes_rgb = self._parse_data_rgb(inputs_rgb) 68 | # KL any?
69 | 70 | # forward 71 | inputs_rgb = torch.cat((inputs_rgb,inputs_rgb1),0) 72 | labels_rgb = torch.cat((labels_rgb,labels_rgb),-1) 73 | _,f_out_rgb,f_out_ir,labels_rgb,labels_ir,pool_rgb,pool_ir = self._forward(inputs_rgb,inputs_ir,label_1=labels_rgb,label_2=labels_ir,modal=0) 74 | 75 | # intra-modality nce loss 76 | loss_ir = self.memory_ir(f_out_ir, labels_ir) 77 | loss_rgb = self.memory_rgb(f_out_rgb, labels_rgb) 78 | 79 | # cross contrastive learning 80 | if r2i: 81 | rgb2ir_labels = torch.tensor([r2i[key.item()] for key in labels_rgb]).cuda() 82 | ir2rgb_labels = torch.tensor([i2r[key.item()] for key in labels_ir]).cuda() 83 | alternate = True 84 | if alternate: 85 | # accl 86 | if epoch % 2 == 1: 87 | cross_loss = 1 * self.memory_rgb(f_out_ir, ir2rgb_labels.long()) 88 | else: 89 | cross_loss = 1 * self.memory_ir(f_out_rgb, rgb2ir_labels.long()) 90 | else: 91 | cross_loss = self.memory_rgb(f_out_ir, ir2rgb_labels.long()) + self.memory_ir(f_out_rgb, rgb2ir_labels.long()) 92 | # Unidirectional 93 | # cross_loss = self.memory_rgb(f_out_ir, ir2rgb_labels.long()) 94 | # cross_loss = self.memory_ir(f_out_rgb, rgb2ir_labels.long()) 95 | else: 96 | cross_loss = torch.tensor(0.0) 97 | 98 | new_loss_rgb = loss_rgb 99 | new_cross_loss = cross_loss 100 | 101 | loss = loss_ir+new_loss_rgb+0.25*new_cross_loss # total loss 102 | optimizer.zero_grad() 103 | loss.backward() 104 | optimizer.step() 105 | 106 | losses.update(loss.item()) 107 | 108 | # print log 109 | batch_time.update(time.time() - end) 110 | end = time.time() 111 | 112 | if (i + 1) % print_freq == 0: 113 | print('Epoch: [{}][{}/{}]\t' 114 | 'Time {:.3f} ({:.3f})\t' 115 | 'Data {:.3f} ({:.3f})\t' 116 | 'Loss {:.3f} ({:.3f})\t' 117 | 'Loss ir {:.3f}\t' 118 | 'Loss rgb {:.3f}\t' 119 | 'Loss cross {:.3f}\t' 120 | # 'Loss tri rgb {:.3f}\t' 121 | # 'Loss tri ir {:.3f}\t' 122 | .format(epoch, i + 1, len(data_loader_rgb), 123 | batch_time.val, batch_time.avg, 124 | data_time.val, data_time.avg, 125 | losses.val, losses.avg,loss_ir,new_loss_rgb,new_cross_loss 126 | # , loss_tri_rgb 127 | # , loss_tri_ir 128 | )) 129 | 130 | def _parse_data_rgb(self, inputs): 131 | imgs,imgs1, _, pids, _, indexes = inputs 132 | return imgs.cuda(),imgs1.cuda(), pids.cuda(), indexes.cuda() 133 | 134 | def _parse_data_ir(self, inputs): 135 | imgs, _, pids, _, indexes = inputs 136 | return imgs.cuda(), pids.cuda(), indexes.cuda() 137 | 138 | def _forward(self, x1, x2, label_1=None,label_2=None,modal=0): 139 | return self.encoder(x1, x2, modal=modal,label_1=label_1,label_2=label_2) 140 | 141 | 142 | class OriTripletLoss(nn.Module): 143 | """Triplet loss with hard positive/negative mining. 144 | 145 | Reference: 146 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 147 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 148 | 149 | Args: 150 | - margin (float): margin for triplet. 
151 | """ 152 | 153 | def __init__(self, batch_size, margin=0.3): 154 | super(OriTripletLoss, self).__init__() 155 | self.margin = margin 156 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 157 | 158 | def forward(self, inputs, targets): 159 | """ 160 | Args: 161 | - inputs: feature matrix with shape (batch_size, feat_dim) 162 | - targets: ground truth labels with shape (num_classes) 163 | """ 164 | n = inputs.size(0) 165 | 166 | # Compute pairwise distance, replace by the official when merged 167 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 168 | dist = dist + dist.t() 169 | dist.addmm_(1, -2, inputs, inputs.t()) 170 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 171 | 172 | # For each anchor, find the hardest positive and negative 173 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 174 | dist_ap, dist_an = [], [] 175 | for i in range(n): 176 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 177 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 178 | dist_ap = torch.cat(dist_ap) 179 | dist_an = torch.cat(dist_an) 180 | 181 | # Compute ranking hinge loss 182 | y = torch.ones_like(dist_an) 183 | loss = self.ranking_loss(dist_an, dist_ap, y) 184 | 185 | # compute accuracy 186 | correct = torch.ge(dist_an, dist_ap).sum().item() 187 | return loss, correct -------------------------------------------------------------------------------- /clustercontrast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /clustercontrast/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset 4 | from .preprocessor import Preprocessor 5 | 6 | 7 | class IterLoader: 8 | def __init__(self, loader, length=None): 9 | self.loader = loader 10 | self.length = length 11 | self.iter = None 12 | 13 | def __len__(self): 14 | if self.length is not None: 15 | return self.length 16 | 17 | return len(self.loader) 18 | 19 | def new_epoch(self): 20 | self.iter = iter(self.loader) 21 | 22 | def next(self): 23 | try: 24 | return next(self.iter) 25 | except: 26 | self.iter = iter(self.loader) 27 | return next(self.iter) 28 | -------------------------------------------------------------------------------- /clustercontrast/utils/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zesenwu23/USL-VI-ReID/ed4a7df28755b64d8367cfa733cf681b316c26c2/clustercontrast/utils/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /clustercontrast/utils/data/__pycache__/__init__.cpython-37.pyc: 
--------------------------------------------------------------------------------
/clustercontrast/utils/data/base_dataset.py:
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def print_dataset_statistics(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def images_dir(self): 27 | return None 28 | 29 | 30 | class BaseImageDataset(BaseDataset): 31 | """ 32 | Base class of image reid dataset 33 | """ 34 | 35 | def print_dataset_statistics(self, train, query, gallery): 36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 38 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 39 | 40 | print("Dataset statistics:") 41 | print(" ----------------------------------------") 42 | print(" subset | # ids | # images | # cameras") 43 | print(" ----------------------------------------") 44 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 45 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 46 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 47 | print(" ----------------------------------------") 48 | -------------------------------------------------------------------------------- /clustercontrast/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import random 7 | import math 8 | from PIL import Image 9 | 10 | 11 | class Preprocessor(Dataset): 12 | def __init__(self, dataset, root=None, transform=None): 13 | super(Preprocessor, self).__init__() 14 | self.dataset = dataset 15 | self.root = root 16 | self.transform = transform 17 | 18 | def __len__(self): 19 | return len(self.dataset) 20 | 21 | def __getitem__(self, indices): 22 | return self._get_single_item(indices) 23 | 24 | def _get_single_item(self, index): 25 | fname, pid, camid = self.dataset[index] 26 | fpath = fname 27 | if self.root is not None: 28 | fpath = osp.join(self.root, fname) 29 | 30 | img = Image.open(fpath).convert('RGB') 31 | 32 | if self.transform is not None: 33 | img = self.transform(img) 34 | 35 | return img, fname, pid, camid, index 36 | 37 | class Preprocessor_color(Dataset): 38 | def __init__(self, dataset, root=None, transform=None,transform1=None): 39 | super(Preprocessor_color, self).__init__() 40 | self.dataset = dataset 41 | self.root = root 42 | self.transform = transform 43 | self.transform1 = transform1 44 | def __len__(self): 45 | return len(self.dataset) 46 | 47 | def __getitem__(self, indices): 48 | return self._get_single_item(indices) 49 | 50 | def _get_single_item(self, index): 51 | fname, pid, camid = self.dataset[index] 52 | fpath = fname 53 | if self.root is not None: 54 | fpath = osp.join(self.root, fname) 55 | 56 | img_ori = Image.open(fpath).convert('RGB') 57 | 58 | if self.transform is not None: 59 | img = self.transform(img_ori) 60 
| img1 = self.transform1(img_ori) 61 | return img, img1,fname, pid, camid, index 62 | -------------------------------------------------------------------------------- /clustercontrast/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import math 4 | 5 | import numpy as np 6 | import copy 7 | import random 8 | import torch 9 | from torch.utils.data.sampler import ( 10 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 11 | WeightedRandomSampler) 12 | 13 | 14 | def No_index(a, b): 15 | assert isinstance(a, list) 16 | return [i for i, j in enumerate(a) if j != b] 17 | 18 | 19 | class RandomIdentitySampler(Sampler): 20 | def __init__(self, data_source, num_instances): 21 | self.data_source = data_source 22 | self.num_instances = num_instances 23 | self.index_dic = defaultdict(list) 24 | for index, (_, pid, _) in enumerate(data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | self.num_samples = len(self.pids) 28 | 29 | def __len__(self): 30 | return self.num_samples * self.num_instances 31 | 32 | def __iter__(self): 33 | indices = torch.randperm(self.num_samples).tolist() 34 | ret = [] 35 | for i in indices: 36 | pid = self.pids[i] 37 | t = self.index_dic[pid] 38 | if len(t) >= self.num_instances: 39 | t = np.random.choice(t, size=self.num_instances, replace=False) 40 | else: 41 | t = np.random.choice(t, size=self.num_instances, replace=True) 42 | ret.extend(t) 43 | return iter(ret) 44 | 45 | 46 | class RandomMultipleGallerySampler(Sampler): 47 | def __init__(self, data_source, num_instances=4): 48 | super().__init__(data_source) 49 | self.data_source = data_source 50 | self.index_pid = defaultdict(int) 51 | self.pid_cam = defaultdict(list) 52 | self.pid_index = defaultdict(list) 53 | self.num_instances = num_instances 54 | 55 | for index, (_, pid, cam) in enumerate(data_source): 56 | if pid < 0: 57 | continue 58 | self.index_pid[index] = pid 59 | self.pid_cam[pid].append(cam) 60 | self.pid_index[pid].append(index) 61 | 62 | self.pids = list(self.pid_index.keys()) 63 | self.num_samples = len(self.pids) 64 | 65 | def __len__(self): 66 | return self.num_samples * self.num_instances 67 | 68 | def __iter__(self): 69 | indices = torch.randperm(len(self.pids)).tolist() 70 | ret = [] 71 | 72 | for kid in indices: 73 | i = random.choice(self.pid_index[self.pids[kid]]) 74 | 75 | _, i_pid, i_cam = self.data_source[i] 76 | 77 | ret.append(i) 78 | 79 | pid_i = self.index_pid[i] 80 | cams = self.pid_cam[pid_i] 81 | index = self.pid_index[pid_i] 82 | select_cams = No_index(cams, i_cam) 83 | 84 | if select_cams: 85 | 86 | if len(select_cams) >= self.num_instances: 87 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False) 88 | else: 89 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True) 90 | 91 | for kk in cam_indexes: 92 | ret.append(index[kk]) 93 | 94 | else: 95 | select_indexes = No_index(index, i) 96 | if not select_indexes: 97 | continue 98 | if len(select_indexes) >= self.num_instances: 99 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 100 | else: 101 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 102 | 103 | for kk in ind_indexes: 104 | ret.append(index[kk]) 105 | 106 | return iter(ret) 107 | 108 | 109 | class 
RandomMultipleGallerySamplerNoCam(Sampler): 110 | def __init__(self, data_source, num_instances=4): 111 | super().__init__(data_source) 112 | 113 | self.data_source = data_source 114 | self.index_pid = defaultdict(int) 115 | self.pid_index = defaultdict(list) 116 | self.num_instances = num_instances 117 | 118 | for index, (_, pid, cam) in enumerate(data_source): 119 | if pid < 0: 120 | continue 121 | self.index_pid[index] = pid 122 | self.pid_index[pid].append(index) 123 | 124 | self.pids = list(self.pid_index.keys()) 125 | self.num_samples = len(self.pids) 126 | 127 | def __len__(self): 128 | return self.num_samples * self.num_instances 129 | 130 | def __iter__(self): 131 | indices = torch.randperm(len(self.pids)).tolist() 132 | ret = [] 133 | 134 | for kid in indices: 135 | i = random.choice(self.pid_index[self.pids[kid]]) 136 | _, i_pid, i_cam = self.data_source[i] 137 | 138 | ret.append(i) 139 | 140 | pid_i = self.index_pid[i] 141 | index = self.pid_index[pid_i] 142 | 143 | select_indexes = No_index(index, i) 144 | if not select_indexes: 145 | continue 146 | if len(select_indexes) >= self.num_instances: 147 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 148 | else: 149 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 150 | 151 | for kk in ind_indexes: 152 | ret.append(index[kk]) 153 | 154 | return iter(ret) 155 | -------------------------------------------------------------------------------- /clustercontrast/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | """ Randomly selects a rectangle region in an image and erases its pixels. 54 | 'Random Erasing Data Augmentation' by Zhong et al. 55 | See https://arxiv.org/pdf/1708.04896.pdf 56 | Args: 57 | probability: The probability that the Random Erasing operation will be performed. 58 | sl: Minimum proportion of erased area against input image. 
59 | sh: Maximum proportion of erased area against input image. 60 | r1: Minimum aspect ratio of erased area. 61 | mean: Erasing value. 62 | """ 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) >= self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | -------------------------------------------------------------------------------- /clustercontrast/utils/faiss_rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 5 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 7 | """ 8 | 9 | import os, sys 10 | import time 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | import gc 14 | import faiss 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \ 20 | index_init_gpu, index_init_cpu 21 | 22 | 23 | def k_reciprocal_neigh(initial_rank, i, k1): 24 | forward_k_neigh_index = initial_rank[i,:k1+1] 25 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 26 | fi = np.where(backward_k_neigh_index==i)[0] 27 | return forward_k_neigh_index[fi] 28 | 29 | 30 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False): 31 | end = time.time() 32 | if print_flag: 33 | print('Computing jaccard distance...') 34 | 35 | ngpus = faiss.get_num_gpus() 36 | N = target_features.size(0) 37 | mat_type = np.float16 if use_float16 else np.float32 38 | 39 | if (search_option==0): 40 | # GPU + PyTorch CUDA Tensors (1) 41 | res = faiss.StandardGpuResources() 42 | res.setDefaultNullStreamAllDevices() 43 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1) 44 | initial_rank = initial_rank.cpu().numpy() 45 | elif (search_option==1): 46 | # GPU + PyTorch CUDA Tensors (2) 47 | res = faiss.StandardGpuResources() 48 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1)) 49 | index.add(target_features.cpu().numpy()) 50 | _, initial_rank = search_index_pytorch(index, target_features, k1) 51 | res.syncDefaultStreamCurrentDevice() 52 | initial_rank = initial_rank.cpu().numpy() 53 | elif (search_option==2): 54 | # GPU 55 | index = index_init_gpu(ngpus, target_features.size(-1)) 56 | 
index.add(target_features.cpu().numpy()) 57 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 58 | else: 59 | # CPU 60 | index = index_init_cpu(target_features.size(-1)) 61 | index.add(target_features.cpu().numpy()) 62 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 63 | 64 | 65 | nn_k1 = [] 66 | nn_k1_half = [] 67 | for i in range(N): 68 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1)) 69 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2)))) 70 | 71 | V = np.zeros((N, N), dtype=mat_type) 72 | for i in range(N): 73 | k_reciprocal_index = nn_k1[i] 74 | k_reciprocal_expansion_index = k_reciprocal_index 75 | for candidate in k_reciprocal_index: 76 | candidate_k_reciprocal_index = nn_k1_half[candidate] 77 | if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)): 78 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 79 | 80 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 81 | dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t()) 82 | if use_float16: 83 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type) 84 | else: 85 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy() 86 | 87 | del nn_k1, nn_k1_half 88 | 89 | if k2 != 1: 90 | V_qe = np.zeros_like(V, dtype=mat_type) 91 | for i in range(N): 92 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0) 93 | V = V_qe 94 | del V_qe 95 | 96 | del initial_rank 97 | 98 | invIndex = [] 99 | for i in range(N): 100 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 101 | 102 | jaccard_dist = np.zeros((N, N), dtype=mat_type) 103 | for i in range(N): 104 | temp_min = np.zeros((1, N), dtype=mat_type) 105 | # temp_max = np.zeros((1,N), dtype=mat_type) 106 | indNonZero = np.where(V[i, :] != 0)[0] 107 | indImages = [] 108 | indImages = [invIndex[ind] for ind in indNonZero] 109 | for j in range(len(indNonZero)): 110 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]]+np.minimum(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]) 111 | # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 112 | 113 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 114 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6) 115 | 116 | del invIndex, V 117 | 118 | pos_bool = (jaccard_dist < 0) 119 | jaccard_dist[pos_bool] = 0.0 120 | if print_flag: 121 | print("Jaccard distance computing time cost: {}".format(time.time()-end)) 122 | 123 | return jaccard_dist 124 | -------------------------------------------------------------------------------- /clustercontrast/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | 16 | return faiss.cast_integer_to_idx_t_ptr( 17 | x.storage().data_ptr() + x.storage_offset() * 8) 18 | 19 | def search_index_pytorch(index, x, k, D=None, I=None): 20 | """call the search 
function of an index with pytorch tensor I/O (CPU
21 |     and GPU supported)"""
22 |     assert x.is_contiguous()
23 |     n, d = x.size()
24 |     assert d == index.d
25 | 
26 |     if D is None:
27 |         D = torch.empty((n, k), dtype=torch.float32, device=x.device)
28 |     else:
29 |         assert D.size() == (n, k)
30 | 
31 |     if I is None:
32 |         I = torch.empty((n, k), dtype=torch.int64, device=x.device)
33 |     else:
34 |         assert I.size() == (n, k)
35 |     torch.cuda.synchronize()
36 |     xptr = swig_ptr_from_FloatTensor(x)
37 |     Iptr = swig_ptr_from_LongTensor(I)
38 |     Dptr = swig_ptr_from_FloatTensor(D)
39 |     index.search_c(n, xptr,
40 |                    k, Dptr, Iptr)
41 |     torch.cuda.synchronize()
42 |     return D, I
43 | 
44 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
45 |                              metric=faiss.METRIC_L2):
46 |     assert xb.device == xq.device
47 | 
48 |     nq, d = xq.size()
49 |     if xq.is_contiguous():
50 |         xq_row_major = True
51 |     elif xq.t().is_contiguous():
52 |         xq = xq.t()  # I initially wrote xq:t(), Lua is still haunting me :-)
53 |         xq_row_major = False
54 |     else:
55 |         raise TypeError('matrix should be row or column-major')
56 | 
57 |     xq_ptr = swig_ptr_from_FloatTensor(xq)
58 | 
59 |     nb, d2 = xb.size()
60 |     assert d2 == d
61 |     if xb.is_contiguous():
62 |         xb_row_major = True
63 |     elif xb.t().is_contiguous():
64 |         xb = xb.t()
65 |         xb_row_major = False
66 |     else:
67 |         raise TypeError('matrix should be row or column-major')
68 |     xb_ptr = swig_ptr_from_FloatTensor(xb)
69 | 
70 |     if D is None:
71 |         D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
72 |     else:
73 |         assert D.shape == (nq, k)
74 |         assert D.device == xb.device
75 | 
76 |     if I is None:
77 |         I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
78 |     else:
79 |         assert I.shape == (nq, k)
80 |         assert I.device == xb.device
81 | 
82 |     D_ptr = swig_ptr_from_FloatTensor(D)
83 |     I_ptr = swig_ptr_from_LongTensor(I)
84 | 
85 |     faiss.bruteForceKnn(res, metric,
86 |                         xb_ptr, xb_row_major, nb,
87 |                         xq_ptr, xq_row_major, nq,
88 |                         d, k, D_ptr, I_ptr)
89 | 
90 |     return D, I
91 | 
92 | def index_init_gpu(ngpus, feat_dim):
93 |     flat_config = []
94 |     for i in range(ngpus):
95 |         cfg = faiss.GpuIndexFlatConfig()
96 |         cfg.useFloat16 = False
97 |         cfg.device = i
98 |         flat_config.append(cfg)
99 | 
100 |     res = [faiss.StandardGpuResources() for i in range(ngpus)]
101 |     indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)]
102 |     index = faiss.IndexShards(feat_dim)
103 |     for sub_index in indexes:
104 |         index.add_shard(sub_index)
105 |     index.reset()
106 |     return index
107 | 
108 | def index_init_cpu(feat_dim):
109 |     return faiss.IndexFlatL2(feat_dim)
110 | 
--------------------------------------------------------------------------------
/clustercontrast/utils/infomap_cluster.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | # import infomap
4 | from infomap import infomap
5 | # from infomap import Infomap
6 | import faiss
7 | import math
8 | import multiprocessing as mp
9 | from clustercontrast.utils.infomap_utils import Timer
10 | 
11 | 
12 | 
13 | 
14 | def l2norm(vec):
15 |     """
16 |     L2-normalize each row of the feature matrix.
17 |     :param vec:
18 |     :return:
19 |     """
20 |     vec /= np.linalg.norm(vec, axis=1).reshape(-1, 1)
21 |     return vec
22 | 
23 | 
24 | def intdict2ndarray(d, default_val=-1):
25 |     length = max(d.keys())
26 |     print(length + 1)
27 |     arr = np.zeros(length + 1) + default_val
28 |     # arr = np.zeros(len(d)) + default_val
29 |     # print(arr.shape)
30 |     # print(len(d))
31 |     for k, v in d.items():
32 |         arr[k] = v
33 |     return arr
34 | 
35 | 
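# A quick doctest-style illustration of intdict2ndarray above, using a
# hypothetical input dict (not taken from the repo): the maximum key decides
# the array length, and indices that never appear are filled with default_val.
#
#     intdict2ndarray({0: 1, 2: 0})
#     # prints 3, then returns array([ 1., -1.,  0.])   (index 1 is absent -> -1)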
36 | def read_meta(fn_meta, start_pos=0, verbose=True):
37 |     """
38 |     idx2lb: maps each vertex (sample index) to its class label
39 |     lb2idxs: maps each class label to the list of its sample indices
40 |     """
41 |     lb2idxs = {}
42 |     idx2lb = {}
43 |     with open(fn_meta) as f:
44 |         for idx, x in enumerate(f.readlines()[start_pos:]):
45 |             lb = int(x.strip())
46 |             if lb not in lb2idxs:
47 |                 lb2idxs[lb] = []
48 |             lb2idxs[lb] += [idx]
49 |             idx2lb[idx] = lb
50 | 
51 |     inst_num = len(idx2lb)
52 |     cls_num = len(lb2idxs)
53 |     if verbose:
54 |         print('[{}] #cls: {}, #inst: {}'.format(fn_meta, cls_num, inst_num))
55 |     return lb2idxs, idx2lb
56 | 
57 | 
58 | class knn_faiss():
59 |     """
60 |     Brute-force inner-product search;
61 |     for L2-normalized features, inner product is equivalent to cosine similarity.
62 |     """
63 | 
64 |     def __init__(self, feats, k, knn_method='faiss-cpu', verbose=True):
65 |         self.verbose = verbose
66 | 
67 |         with Timer('[{}] build index {}'.format(knn_method, k), verbose):
68 |             feats = feats.astype('float32')
69 |             size, dim = feats.shape
70 |             if knn_method == 'faiss-gpu':
71 |                 i = math.ceil(size / 1000000)
72 |                 if i > 1:
73 |                     i = (i - 1) * 4
74 |                 res = faiss.StandardGpuResources()
75 |                 res.setTempMemory(i * 1024 * 1024 * 1024)
76 |                 index = faiss.GpuIndexFlatIP(res, dim)
77 |             else:
78 |                 index = faiss.IndexFlatIP(dim)
79 |             index.add(feats)
80 | 
81 |         with Timer('[{}] query topk {}'.format(knn_method, k), verbose):
82 |             sims, nbrs = index.search(feats, k=k)
83 |             self.knns = [(np.array(nbr, dtype=np.int32),
84 |                           1 - np.array(sim, dtype=np.float32))
85 |                          for nbr, sim in zip(nbrs, sims)]
86 | 
87 |     def filter_by_th(self, i):
88 |         th_nbrs = []
89 |         th_dists = []
90 |         nbrs, dists = self.knns[i]
91 |         for n, dist in zip(nbrs, dists):
92 |             if 1 - dist < self.th:
93 |                 continue
94 |             th_nbrs.append(n)
95 |             th_dists.append(dist)
96 |         th_nbrs = np.array(th_nbrs)
97 |         th_dists = np.array(th_dists)
98 |         return th_nbrs, th_dists
99 | 
100 |     def get_knns(self, th=None):
101 |         if th is None or th <= 0.:
102 |             return self.knns
103 |         # TODO: optimize the filtering process by numpy
104 |         # nproc = mp.cpu_count()
105 |         nproc = 1
106 |         with Timer('filter edges by th {} (CPU={})'.format(th, nproc),
107 |                    self.verbose):
108 |             self.th = th
109 |             self.th_knns = []
110 |             tot = len(self.knns)
111 |             if nproc > 1:
112 |                 pool = mp.Pool(nproc)
113 |                 th_knns = list(
114 |                     tqdm(pool.imap(self.filter_by_th, range(tot)), total=tot))
115 |                 pool.close()
116 |             else:
117 |                 th_knns = [self.filter_by_th(i) for i in range(tot)]
118 |             return th_knns
119 | 
120 | 
121 | def knns2ordered_nbrs(knns, sort=True):
122 |     if isinstance(knns, list):
123 |         knns = np.array(knns)
124 |     nbrs = knns[:, 0, :].astype(np.int32)
125 |     dists = knns[:, 1, :]
126 |     if sort:
127 |         # sort dists from low to high
128 |         nb_idx = np.argsort(dists, axis=1)
129 |         idxs = np.arange(nb_idx.shape[0]).reshape(-1, 1)
130 |         dists = dists[idxs, nb_idx]
131 |         nbrs = nbrs[idxs, nb_idx]
132 |     return dists, nbrs
133 | 
134 | 
135 | # Build the graph edges
136 | def get_links(single, links, nbrs, dists, min_sim):
137 |     for i in tqdm(range(nbrs.shape[0])):
138 |         count = 0
139 |         for j in range(0, len(nbrs[i])):
140 |             # skip the node itself
141 |             if i == nbrs[i][j]:
142 |                 pass
143 |             elif dists[i][j] <= 1 - min_sim:
144 |                 count += 1
145 |                 links[(i, nbrs[i][j])] = float(1 - dists[i][j])
146 |             else:
147 |                 break
148 |         # record isolated nodes
149 |         if count == 0:
150 |             single.append(i)
151 |     return single, links
152 | 
153 | 
154 | def cluster_by_infomap(nbrs, dists, min_sim, cluster_num=2):
155 |     """
156 |     Clustering based on Infomap.
157 |     :param nbrs:
158 |     :param dists:
159 |     :param min_sim:
160 |     :return:
161 |     """
162 |     single = []
163 |     links = {}
164 |     with Timer('get links', verbose=True):
165 |         single, links = get_links(single=single, links=links, nbrs=nbrs, dists=dists, min_sim=min_sim)
166 | 
167 |     infomapWrapper = infomap.Infomap("--two-level --directed")
168 |     for (i, j), sim in tqdm(links.items()):
169 |         # _ = infomapWrapper.addLink(int(i), int(j), sim)
170 |         _ = infomapWrapper.addLink(int(i), int(j))
171 |     # run the clustering
172 |     infomapWrapper.run()
173 | 
174 |     label2idx = {}
175 |     idx2label = {}
176 | 
177 |     # collect the clustering results
178 |     # for node in infomapWrapper.iterTree():
179 |     for node in infomapWrapper.tree.leafIter():
180 |         # node.physicalId: index of the feature vector
181 |         # node.moduleIndex(): cluster id
182 |         if node.moduleIndex() not in label2idx:
183 |             label2idx[node.moduleIndex()] = []
184 |         # label2idx[node.moduleIndex()].append(node.physicalId)
185 |         label2idx[node.moduleIndex()].append(node.physIndex)
186 | 
187 |     node_count = 0
188 |     for k, v in label2idx.items():
189 |         if k == 0:
190 |             each_index_list = v[2:]
191 |             node_count += len(each_index_list)
192 |             label2idx[k] = each_index_list
193 |         else:
194 |             each_index_list = v[1:]
195 |             node_count += len(each_index_list)
196 |             label2idx[k] = each_index_list
197 | 
198 |         for each_index in each_index_list:
199 |             idx2label[each_index] = k
200 | 
201 |     keys_len = len(list(label2idx.keys()))
202 |     # put isolated nodes into the result
203 |     for single_node in single:
204 |         idx2label[single_node] = keys_len
205 |         label2idx[keys_len] = [single_node]
206 |         keys_len += 1
207 |         node_count += 1
208 | 
209 |     # number of isolated nodes
210 |     print("Number of isolated nodes: {}".format(len(single)))
211 | 
212 |     idx_len = len(list(idx2label.keys()))
213 |     assert idx_len == node_count, 'idx_len not equal node_count!'
214 | 
215 |     print("Total number of nodes: {}".format(idx_len))
216 | 
217 |     old_label_container = set()
218 |     for each_label, each_index_list in label2idx.items():
219 |         if len(each_index_list) <= cluster_num:
220 |             for each_index in each_index_list:
221 |                 idx2label[each_index] = -1
222 |         else:
223 |             old_label_container.add(each_label)
224 | 
225 |     old2new = {old_label: new_label for new_label, old_label in enumerate(old_label_container)}
226 | 
227 |     for each_index, each_label in idx2label.items():
228 |         if each_label == -1:
229 |             continue
230 |         # print(each_index,each_label)
231 |         idx2label[each_index] = old2new[each_label]
232 | 
233 |     pre_labels = intdict2ndarray(idx2label)
234 | 
235 |     print("Number of clusters: {}/{}".format(keys_len, len(set(pre_labels)) - (1 if -1 in pre_labels else 0)))
236 | 
237 |     return pre_labels
238 | 
239 | 
240 | def get_dist_nbr(features, k=80, knn_method='faiss-cpu'):
241 |     index = knn_faiss(feats=features, k=k, knn_method=knn_method)
242 |     knns = index.get_knns()
243 |     dists, nbrs = knns2ordered_nbrs(knns)
244 |     return dists, nbrs
245 | 
246 | 
247 | 
248 | 
--------------------------------------------------------------------------------
/clustercontrast/utils/infomap_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | 
4 | class TextColors:
5 |     HEADER = '\033[35m'
6 |     OKBLUE = '\033[34m'
7 |     OKGREEN = '\033[32m'
8 |     WARNING = '\033[33m'
9 |     FATAL = '\033[31m'
10 |     ENDC = '\033[0m'
11 |     BOLD = '\033[1m'
12 |     UNDERLINE = '\033[4m'
13 | 
14 | 
15 | class Timer():
16 |     def __init__(self, name='task', verbose=True):
17 |         self.name = name
18 |         self.verbose = verbose
19 | 
20 |     def __enter__(self):
21 |         self.start = time.time()
22 |         return self
23 | 
24 |     def __exit__(self, exc_type, exc_val, exc_tb):
25 |         if self.verbose:
26 |             print('[Time] {} consumes {:.4f} s'.format(
27 |                 self.name,
28 |                 time.time() - self.start))
29 |         return exc_type is None
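Taken together, infomap_cluster.py and the Timer helper above are typically driven as follows. This is a minimal sketch with placeholder data; k=30, min_sim=0.6 and cluster_num=4 are illustrative values, not hyperparameters taken from the training scripts:

import numpy as np

from clustercontrast.utils.infomap_cluster import get_dist_nbr, cluster_by_infomap
from clustercontrast.utils.infomap_utils import Timer

feats = np.random.rand(1000, 2048).astype('float32')   # stand-in for extracted features
feats /= np.linalg.norm(feats, axis=1, keepdims=True)  # L2-normalize so inner product == cosine

with Timer('infomap clustering'):
    dists, nbrs = get_dist_nbr(feats, k=30, knn_method='faiss-cpu')
    pseudo_labels = cluster_by_infomap(nbrs, dists, min_sim=0.6, cluster_num=4)

print(pseudo_labels.shape)  # one label per sample; clusters of size <= cluster_num are marked -1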
-------------------------------------------------------------------------------- /clustercontrast/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /clustercontrast/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /clustercontrast/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /clustercontrast/utils/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Source: https://github.com/zhunzhong07/person-re-ranking 5 | Created on Mon Jun 26 14:46:56 2017 6 | @author: luohao 7 | Modified by Houjing Huang, 2017-12-22. 8 | - This version accepts distance matrix instead of raw features. 9 | - The difference of `/` division between python 2 and 3 is handled. 10 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 11 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
12 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 14 | API 15 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 16 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 17 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 18 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 19 | Returns: 20 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import print_function 24 | from __future__ import division 25 | 26 | __all__ = ['re_ranking'] 27 | 28 | import numpy as np 29 | 30 | 31 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 32 | 33 | # The following naming, e.g. gallery_num, is different from outer scope. 34 | # Don't care about it. 35 | 36 | original_dist = np.concatenate( 37 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 38 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 39 | axis=0) 40 | original_dist = np.power(original_dist, 2).astype(np.float32) 41 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0)) 42 | V = np.zeros_like(original_dist).astype(np.float32) 43 | initial_rank = np.argsort(original_dist).astype(np.int32) 44 | 45 | query_num = q_g_dist.shape[0] 46 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 47 | all_num = gallery_num 48 | 49 | for i in range(all_num): 50 | # k-reciprocal neighbors 51 | forward_k_neigh_index = initial_rank[i,:k1+1] 52 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 53 | fi = np.where(backward_k_neigh_index==i)[0] 54 | k_reciprocal_index = forward_k_neigh_index[fi] 55 | k_reciprocal_expansion_index = k_reciprocal_index 56 | for j in range(len(k_reciprocal_index)): 57 | candidate = k_reciprocal_index[j] 58 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 59 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 60 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 61 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 62 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 63 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 64 | 65 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 66 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 67 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 68 | original_dist = original_dist[:query_num,] 69 | if k2 != 1: 70 | V_qe = np.zeros_like(V,dtype=np.float32) 71 | for i in range(all_num): 72 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 73 | V = V_qe 74 | del V_qe 75 | del initial_rank 76 | invIndex = [] 77 | for i in range(gallery_num): 78 | invIndex.append(np.where(V[:,i] != 0)[0]) 79 | 80 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 81 | 82 | 83 | for i in range(query_num): 84 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 85 | indNonZero = np.where(V[i,:] != 0)[0] 86 | indImages = [] 87 | indImages = [invIndex[ind] for ind in indNonZero] 88 | for j in range(len(indNonZero)): 89 | 
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 90 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 91 | 92 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 93 | del original_dist 94 | del V 95 | del jaccard_dist 96 | final_dist = final_dist[:query_num,query_num:] 97 | return final_dist 98 | -------------------------------------------------------------------------------- /clustercontrast/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | # checkpoint = torch.load(fpath) 34 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 35 | print("=> Loaded checkpoint '{}'".format(fpath)) 36 | return checkpoint 37 | else: 38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 39 | 40 | 41 | def copy_state_dict(state_dict, model, strip=None): 42 | tgt_state = model.state_dict() 43 | copied_names = set() 44 | for name, param in state_dict.items(): 45 | if strip is not None and name.startswith(strip): 46 | name = name[len(strip):] 47 | if name not in tgt_state: 48 | continue 49 | if isinstance(param, Parameter): 50 | param = param.data 51 | if param.size() != tgt_state[name].size(): 52 | print('mismatch:', name, param.size(), tgt_state[name].size()) 53 | continue 54 | tgt_state[name].copy_(param) 55 | copied_names.add(name) 56 | 57 | missing = set(tgt_state.keys()) - copied_names 58 | if len(missing) > 0: 59 | print("missing keys in state_dict:", missing) 60 | 61 | return model 62 | -------------------------------------------------------------------------------- /examples/note.md: -------------------------------------------------------------------------------- 1 | The folder examples can be downloaded [here](https://drive.google.com/drive/folders/1NIpM5uv9_DUbCafwy7Z28yXPnMXxNtss?usp=sharing) 2 | -------------------------------------------------------------------------------- /meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /prepare_regdb.py: -------------------------------------------------------------------------------- 
1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | from shutil import copyfile 8 | 9 | # You only need to change this line to your dataset download path 10 | download_path = '/data0/ReIDData/RegDB' 11 | 12 | if not os.path.isdir(download_path): 13 | print('please change the download_path') 14 | 15 | #----------------------------------------- 16 | 17 | #----------------------------------------- 18 | #query 19 | # query_path = download_path + '/query' 20 | mode1 ='/ir_modify/' 21 | save_path_first = download_path + mode1 22 | if not os.path.isdir(save_path_first): 23 | os.mkdir(save_path_first) 24 | for trial in range(1,11): 25 | n=0 26 | save_path = download_path + mode1+str(trial) 27 | if not os.path.isdir(save_path): 28 | os.mkdir(save_path) 29 | query_save_path = download_path + mode1+str(trial)+'/query' 30 | if not os.path.isdir(query_save_path): 31 | os.mkdir(query_save_path) 32 | ####################### 33 | data_path=download_path 34 | 35 | test_file_path = os.path.join(data_path,'idx/test_thermal_'+str(trial)+'.txt') 36 | with open(test_file_path) as f: 37 | data_file_list = open(test_file_path, 'rt').read().splitlines() 38 | # Get full list of image and labels 39 | files_ir = [data_path + '/' + s for s in data_file_list] 40 | # print(files_ir) 41 | # new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 42 | # files_ir.extend(new_files) 43 | # print(files_ir) 44 | exist_id = {} 45 | for file_path in files_ir: 46 | file_list = file_path.split('/') 47 | c_id = 'c1' 48 | ID = file_path.split(' ')[1] 49 | # print(ID) 50 | img_name = file_list[-1].split(' ')[-0] 51 | # print(file_list) 52 | src_path = file_path.split(' ')[0] 53 | dst_path = query_save_path 54 | if not os.path.isdir(dst_path): 55 | os.mkdir(dst_path) 56 | name = ID+"_"+c_id+"_"+img_name 57 | exist_id[ID] = exist_id.get(ID,0)+1 58 | if exist_id[ID]<=4: 59 | # print(src_path,dst_path) 60 | copyfile(src_path, dst_path + '/' + name) 61 | print(dst_path + '/' + name) 62 | mode1 ='/rgb_modify/' 63 | save_path_first = download_path + mode1 64 | if not os.path.isdir(save_path_first): 65 | os.mkdir(save_path_first) 66 | for trial in range(1,11): 67 | n=0 68 | save_path = download_path + mode1+str(trial) 69 | if not os.path.isdir(save_path): 70 | os.mkdir(save_path) 71 | query_save_path = download_path + mode1+str(trial)+'/query' 72 | if not os.path.isdir(query_save_path): 73 | os.mkdir(query_save_path) 74 | ####################### 75 | data_path=download_path 76 | 77 | test_file_path = os.path.join(data_path,'idx/test_visible_'+str(trial)+'.txt') 78 | with open(test_file_path) as f: 79 | data_file_list = open(test_file_path, 'rt').read().splitlines() 80 | # Get full list of image and labels 81 | files_ir = [data_path + '/' + s for s in data_file_list] 82 | # print(files_ir) 83 | # new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 84 | # files_ir.extend(new_files) 85 | # print(files_ir) 86 | exist_id = {} 87 | for file_path in files_ir: 88 | file_list = file_path.split('/') 89 | c_id = 'c1' 90 | ID = file_path.split(' ')[1] 91 | img_name = file_list[-1].split(' ')[-0] 92 | src_path = file_path.split(' ')[0] 93 | dst_path = query_save_path 94 | if not os.path.isdir(dst_path): 95 | os.mkdir(dst_path) 96 | name = ID+"_"+c_id+"_"+img_name 97 | exist_id[ID] = exist_id.get(ID,0)+1 98 | if exist_id[ID]<=4: 99 | # print(name) 100 | copyfile(src_path, 
dst_path + '/' + name) 101 | print(dst_path + '/' + name) 102 | 103 | mode1 ='/ir_modify/' 104 | if not os.path.isdir(save_path_first): 105 | os.mkdir(save_path_first) 106 | for trial in range(1,11): 107 | save_path = download_path + mode1+str(trial) 108 | if not os.path.isdir(save_path): 109 | os.mkdir(save_path) 110 | query_save_path = download_path + mode1+str(trial)+'/bounding_box_test' 111 | if not os.path.isdir(query_save_path): 112 | os.mkdir(query_save_path) 113 | ####################### 114 | data_path=download_path 115 | 116 | test_file_path = os.path.join(data_path,'idx/test_thermal_'+str(trial)+'.txt') 117 | with open(test_file_path) as f: 118 | data_file_list = open(test_file_path, 'rt').read().splitlines() 119 | files_ir = [data_path + '/' + s for s in data_file_list] 120 | 121 | for file_path in files_ir: 122 | file_list = file_path.split('/') 123 | c_id = 'c1' 124 | ID = file_path.split(' ')[1] 125 | 126 | img_name = file_list[-1].split(' ')[-0] 127 | 128 | src_path = file_path.split(' ')[0] 129 | dst_path = query_save_path 130 | if not os.path.isdir(dst_path): 131 | os.mkdir(dst_path) 132 | name = ID+"_"+c_id+"_"+img_name 133 | print(name) 134 | 135 | copyfile(src_path, dst_path + '/' + name) 136 | print(dst_path + '/' + name) 137 | # ################################# 138 | mode1 ='/rgb_modify/' 139 | save_path_first = download_path + mode1 140 | if not os.path.isdir(save_path_first): 141 | os.mkdir(save_path_first) 142 | for trial in range(1,11): 143 | save_path = download_path + mode1+str(trial) 144 | if not os.path.isdir(save_path): 145 | os.mkdir(save_path) 146 | query_save_path = download_path + mode1+str(trial)+'/bounding_box_test' 147 | if not os.path.isdir(query_save_path): 148 | os.mkdir(query_save_path) 149 | ####################### 150 | data_path=download_path 151 | 152 | test_file_path = os.path.join(data_path,'idx/test_visible_'+str(trial)+'.txt') 153 | with open(test_file_path) as f: 154 | data_file_list = open(test_file_path, 'rt').read().splitlines() 155 | 156 | files_ir = [data_path + '/' + s for s in data_file_list] 157 | 158 | 159 | for file_path in files_ir: 160 | file_list = file_path.split('/') 161 | c_id = 'c1' 162 | ID = file_path.split(' ')[1] 163 | 164 | img_name = file_list[-1].split(' ')[-0] 165 | 166 | src_path = file_path.split(' ')[0] 167 | dst_path = query_save_path 168 | if not os.path.isdir(dst_path): 169 | os.mkdir(dst_path) 170 | name = ID+"_"+c_id+"_"+img_name 171 | # print(src_path,dst_path) 172 | copyfile(src_path, dst_path + '/' + name) 173 | print(dst_path + '/' + name) 174 | 175 | 176 | # ################################# 177 | mode1 ='/ir_modify/' 178 | mode2 ='/rgb_modify/' 179 | save_path_first = download_path + mode1 180 | if not os.path.isdir(save_path_first): 181 | os.mkdir(save_path_first) 182 | for trial in range(1,11): 183 | save_path = download_path + mode1+str(trial) 184 | if not os.path.isdir(save_path): 185 | os.mkdir(save_path) 186 | query_save_path = download_path + mode1+str(trial)+'/bounding_box_train' 187 | if not os.path.isdir(query_save_path): 188 | os.mkdir(query_save_path) 189 | ####################### 190 | data_path=download_path 191 | 192 | test_file_path = os.path.join(data_path,'idx/train_thermal_'+str(trial)+'.txt') 193 | with open(test_file_path) as f: 194 | data_file_list = open(test_file_path, 'rt').read().splitlines() 195 | # Get full list of image and labels 196 | files_ir = [data_path + '/' + s for s in data_file_list] 197 | 198 | 199 | for file_path in files_ir: 200 | file_list = 
file_path.split('/') 201 | c_id = 'c1' 202 | ID = file_path.split(' ')[1] 203 | img_name = file_list[-1].split(' ')[-0] 204 | src_path = file_path.split(' ')[0] 205 | dst_path = query_save_path 206 | if not os.path.isdir(dst_path): 207 | os.mkdir(dst_path) 208 | name = ID+"_"+c_id+"_"+img_name 209 | copyfile(src_path, dst_path + '/' + name) 210 | print(dst_path + '/' + name) 211 | save_path = download_path + mode2+str(trial) 212 | if not os.path.isdir(save_path): 213 | os.mkdir(save_path) 214 | query_save_path = download_path + mode2+str(trial)+'/bounding_box_train' 215 | test_file_path = os.path.join(data_path,'idx/train_visible_'+str(trial)+'.txt') 216 | 217 | with open(test_file_path) as f: 218 | data_file_list = open(test_file_path, 'rt').read().splitlines() 219 | # Get full list of image and labels 220 | files_ir = [data_path + '/' + s for s in data_file_list] 221 | for file_path in files_ir: 222 | file_list = file_path.split('/') 223 | c_id = 'c1' 224 | ID = file_path.split(' ')[1] 225 | img_name = file_list[-1].split(' ')[-0] 226 | src_path = file_path.split(' ')[0] 227 | dst_path = query_save_path 228 | if not os.path.isdir(dst_path): 229 | os.mkdir(dst_path) 230 | name = ID+"_"+c_id+"_"+img_name 231 | copyfile(src_path, dst_path + '/' + name) 232 | print(dst_path + '/' + name) 233 | -------------------------------------------------------------------------------- /prepare_sysu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | from shutil import copyfile 8 | 9 | # You only need to change this line to your dataset download path 10 | download_path = '/data0/data_wzs/SYSU-MM01-Original/SYSU-MM01' 11 | 12 | if not os.path.isdir(download_path): 13 | print('please change the download_path') 14 | 15 | save_path = download_path + '/ir_modify' 16 | if not os.path.isdir(save_path): 17 | os.mkdir(save_path) 18 | #----------------------------------------- 19 | #query 20 | # query_path = download_path + '/query' 21 | query_save_path = download_path + '/ir_modify/query' 22 | if not os.path.isdir(query_save_path): 23 | os.mkdir(query_save_path) 24 | ####################### 25 | data_path=download_path 26 | ir_cameras = ['cam3','cam6'] 27 | test_file_path = os.path.join(data_path,'exp/test_id.txt') 28 | files_rgb = [] 29 | files_ir = [] 30 | files_test=[] 31 | with open(test_file_path, 'r') as file: 32 | ids = file.read().splitlines() 33 | ids = [int(y) for y in ids[0].split(',')] 34 | ids = ["%04d" % x for x in ids] 35 | for id in sorted(ids): 36 | n=0 37 | for cam in ir_cameras: 38 | img_dir = os.path.join(data_path,cam,id) 39 | if os.path.isdir(img_dir): 40 | for single in os.listdir(img_dir): 41 | if n < 4: 42 | files_ir.append(img_dir+'/'+single) 43 | 44 | else: 45 | files_test.append(img_dir+'/'+single) 46 | n=n+1 47 | # new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 48 | # files_ir.extend(new_files) 49 | # print(files_ir) 50 | 51 | for file_path in files_ir: 52 | file_list = file_path.split('/') 53 | ID = file_list[-2] 54 | c_id = 'c'+file_list[-3][-1] 55 | img_name = file_list[-1] 56 | # print(file_list) 57 | src_path = file_path 58 | dst_path = query_save_path 59 | if not os.path.isdir(dst_path): 60 | os.mkdir(dst_path) 61 | name = ID+"_"+c_id+"_"+img_name 62 | copyfile(src_path, dst_path + '/' + name) 63 | 64 | for file_path in 
files_test: 65 | file_list = file_path.split('/') 66 | ID = file_list[-2] 67 | c_id = 'c'+file_list[-3][-1] 68 | img_name = file_list[-1] 69 | # print(file_list) 70 | src_path = file_path 71 | dst_path = download_path + '/ir_modify/bounding_box_test' 72 | if not os.path.isdir(dst_path): 73 | os.mkdir(dst_path) 74 | name = ID+"_"+c_id+"_"+img_name 75 | copyfile(src_path, dst_path + '/' + name) 76 | print(dst_path + '/' + name) 77 | ############################ 78 | 79 | 80 | query_save_path = download_path + '/ir_modify/bounding_box_train' 81 | if not os.path.isdir(query_save_path): 82 | os.mkdir(query_save_path) 83 | ####################### 84 | data_path=download_path 85 | ir_cameras = ['cam3','cam6'] 86 | # rgb_cameras = ['cam1','cam2','cam3','cam4','cam5','cam6'] 87 | test_file_path = os.path.join(data_path,'exp/train_id.txt') 88 | file_path_val = os.path.join(data_path,'exp/val_id.txt') 89 | files_rgb = [] 90 | files_ir = [] 91 | with open(test_file_path, 'r') as file: 92 | ids = file.read().splitlines() 93 | ids = [int(y) for y in ids[0].split(',')] 94 | ids_train = ["%04d" % x for x in ids] 95 | 96 | with open(file_path_val, 'r') as file: 97 | ids = file.read().splitlines() 98 | ids = [int(y) for y in ids[0].split(',')] 99 | id_val = ["%04d" % x for x in ids] 100 | # print(id_val) 101 | ids_train.extend(id_val) 102 | for id in sorted(ids_train): 103 | for cam in ir_cameras: 104 | img_dir = os.path.join(data_path,cam,id) 105 | if os.path.isdir(img_dir): 106 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 107 | files_ir.extend(new_files) 108 | # print(files_ir) 109 | ############################ 110 | for file_path in files_ir: 111 | file_list = file_path.split('/') 112 | ID = file_list[-2] 113 | c_id = 'c'+file_list[-3][-1] 114 | img_name = file_list[-1] 115 | # print(file_list) 116 | src_path = file_path 117 | dst_path = query_save_path 118 | if not os.path.isdir(dst_path): 119 | os.mkdir(dst_path) 120 | name = ID+"_"+c_id+"_"+img_name 121 | copyfile(src_path, dst_path + '/' + name) 122 | print(dst_path + '/' + name) 123 | print(len(files_ir)) 124 | 125 | #------------------------------------------------------- 126 | 127 | save_path = download_path + '/rgb_modify' 128 | if not os.path.isdir(save_path): 129 | os.mkdir(save_path) 130 | #----------------------------------------- 131 | #query 132 | # query_path = download_path + '/query' 133 | query_save_path = download_path + '/rgb_modify/query' 134 | if not os.path.isdir(query_save_path): 135 | os.mkdir(query_save_path) 136 | ####################### 137 | data_path=download_path 138 | ir_cameras = ['cam1','cam2','cam4','cam5'] 139 | test_file_path = os.path.join(data_path,'exp/test_id.txt') 140 | files_rgb = [] 141 | files_ir = [] 142 | files_test=[] 143 | with open(test_file_path, 'r') as file: 144 | ids = file.read().splitlines() 145 | ids = [int(y) for y in ids[0].split(',')] 146 | ids = ["%04d" % x for x in ids] 147 | for id in sorted(ids): 148 | n=0 149 | for cam in ir_cameras: 150 | img_dir = os.path.join(data_path,cam,id) 151 | if os.path.isdir(img_dir): 152 | for single in os.listdir(img_dir): 153 | if n < 4: 154 | files_ir.append(img_dir+'/'+single) 155 | 156 | else: 157 | files_test.append(img_dir+'/'+single) 158 | n=n+1 159 | # new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 160 | # files_ir.extend(new_files) 161 | # print(files_ir) 162 | 163 | for file_path in files_ir: 164 | file_list = file_path.split('/') 165 | ID = file_list[-2] 166 | c_id = 'c'+file_list[-3][-1] 167 | img_name = 
file_list[-1] 168 | # print(file_list) 169 | src_path = file_path 170 | dst_path = query_save_path 171 | if not os.path.isdir(dst_path): 172 | os.mkdir(dst_path) 173 | name = ID+"_"+c_id+"_"+img_name 174 | copyfile(src_path, dst_path + '/' + name) 175 | print(dst_path + '/' + name) 176 | for file_path in files_test: 177 | file_list = file_path.split('/') 178 | ID = file_list[-2] 179 | c_id = 'c'+file_list[-3][-1] 180 | img_name = file_list[-1] 181 | # print(file_list) 182 | src_path = file_path 183 | dst_path = download_path + '/rgb_modify/bounding_box_test' 184 | if not os.path.isdir(dst_path): 185 | os.mkdir(dst_path) 186 | name = ID+"_"+c_id+"_"+img_name 187 | copyfile(src_path, dst_path + '/' + name) 188 | print(dst_path + '/' + name) 189 | ############################ 190 | 191 | 192 | query_save_path = download_path + '/rgb_modify/bounding_box_train' 193 | if not os.path.isdir(query_save_path): 194 | os.mkdir(query_save_path) 195 | ####################### 196 | data_path=download_path 197 | ir_cameras = ['cam1','cam2','cam4','cam5'] 198 | # rgb_cameras = ['cam1','cam2','cam3','cam4','cam5','cam6'] 199 | test_file_path = os.path.join(data_path,'exp/train_id.txt') 200 | file_path_val = os.path.join(data_path,'exp/val_id.txt') 201 | files_rgb = [] 202 | files_ir = [] 203 | with open(test_file_path, 'r') as file: 204 | ids = file.read().splitlines() 205 | ids = [int(y) for y in ids[0].split(',')] 206 | ids_train = ["%04d" % x for x in ids] 207 | 208 | with open(file_path_val, 'r') as file: 209 | ids = file.read().splitlines() 210 | ids = [int(y) for y in ids[0].split(',')] 211 | id_val = ["%04d" % x for x in ids] 212 | # print(id_val) 213 | ids_train.extend(id_val) 214 | for id in sorted(ids_train): 215 | for cam in ir_cameras: 216 | img_dir = os.path.join(data_path,cam,id) 217 | if os.path.isdir(img_dir): 218 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 219 | files_ir.extend(new_files) 220 | # print(files_ir) 221 | print(len(files_ir)) 222 | ############################ 223 | for file_path in files_ir: 224 | file_list = file_path.split('/') 225 | ID = file_list[-2] 226 | c_id = 'c'+file_list[-3][-1] 227 | img_name = file_list[-1] 228 | # print(file_list) 229 | src_path = file_path 230 | dst_path = query_save_path 231 | if not os.path.isdir(dst_path): 232 | os.mkdir(dst_path) 233 | name = ID+"_"+c_id+"_"+img_name 234 | copyfile(src_path, dst_path + '/' + name) 235 | print(dst_path + '/' + name) 236 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | async-timeout==3.0.1 2 | blinker==1.4 3 | brotlipy==0.7.0 4 | certifi==2021.5.30 5 | einops==0.3.0 6 | entrypoints==0.3 7 | faiss-gpu==1.6.4 8 | fire==0.4.0 9 | future==0.18.2 10 | google-auth-oauthlib==0.4.1 11 | h5py==2.8.0 12 | jedi==0.17.0 13 | llvmlite==0.36.0 14 | mkl-fft==1.3.0 15 | mkl-random==1.1.1 16 | mkl-service==2.3.0 17 | numba==0.53.1 18 | olefile==0.46 19 | pandas==1.1.5 20 | protobuf==3.17.2 21 | pyasn1-modules==0.2.8 22 | pynndescent==0.5.7 23 | pytz==2021.3 24 | PyYAML==6.0 25 | requests-oauthlib==1.3.0 26 | scikit-learn==0.22.1 27 | swin-transformer-pytorch==0.4.1 28 | tensorboard-plugin-wit==1.6.0 29 | tensorboardX==2.4 30 | termcolor==1.1.0 31 | torch==1.8.2 32 | torch-tb-profiler==0.3.1 33 | torchaudio==0.8.2 34 | torchvision==0.9.2 35 | traitlets==4.3.3 36 | urllib3==1.26.7 37 | yacs==0.1.8 38 | 
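The pins above target an older CUDA-era stack (torch 1.8.2 / torchvision 0.9.2 / faiss-gpu 1.6.4). A small sanity check after `pip install -r requirements.txt` — an illustrative snippet, not a script shipped with the repo:

import faiss
import torch
import torchvision

print(torch.__version__)           # expected: 1.8.2
print(torchvision.__version__)     # expected: 0.9.2
print(torch.cuda.is_available())   # faiss-gpu==1.6.4 assumes a CUDA-capable machine
print(faiss.get_num_gpus())        # queried by compute_jaccard_distance in faiss_rerank.py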
-------------------------------------------------------------------------------- /run_test_regdb.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 2 | python test_regdb.py \ 3 | -b 256 -a agw -d regdb_rgb \ 4 | --iters 100 \ 5 | --eps 0.6 --num-instances 16 \ 6 | --logs-dir "/data1/wzs/cvpr_upload/origin/" 7 | -------------------------------------------------------------------------------- /run_test_sysu.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 2 | python test_sysu.py \ 3 | -b 256 -a agw -d sysu_all \ 4 | --iters 200 \ 5 | --eps 0.6 \ 6 | --num-instances 16 \ 7 | --logs-dir "/data1/wzs/cvpr23_upload/origin" 8 | -------------------------------------------------------------------------------- /run_train_regdb.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 2 | python train_regdb.py -b 256 -a agw -d regdb_rgb \ 3 | --iters 100 --num-instances 16 \ 4 | --data-dir "/data0/ReIDData/RegDB" \ 5 | --logs-dir "/data1/cvpr23_upload/origin/regdb" \ 6 | --trial 1 7 | 8 | # trial: 1,2,3,4,5,6,7,8,9,10 9 | -------------------------------------------------------------------------------- /run_train_sysu.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 2 | python train_sysu.py -b 256 -a agw -d sysu_all \ 3 | --num-instances 16 \ 4 | --data-dir "/data0/data_wzs/SYSU-MM01-Original/SYSU-MM01" \ 5 | --logs-dir "/data1/wzs/cvpr23_upload/origin" \ -------------------------------------------------------------------------------- /test_regdb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, absolute_import 3 | import argparse 4 | import os.path as osp 5 | import random 6 | import numpy as np 7 | import sys 8 | import collections 9 | import time 10 | from datetime import timedelta 11 | 12 | from sklearn.cluster import DBSCAN 13 | from PIL import Image 14 | import torch 15 | from torch import nn 16 | from torch.backends import cudnn 17 | from torch.utils.data import DataLoader 18 | import torch.nn.functional as F 19 | 20 | from clustercontrast import datasets 21 | from clustercontrast import models 22 | from clustercontrast.models.cm import ClusterMemory 23 | from clustercontrast.utils.data import IterLoader 24 | from clustercontrast.utils.data import transforms as T 25 | from clustercontrast.utils.data.preprocessor import Preprocessor,Preprocessor_color 26 | from clustercontrast.utils.logging import Logger 27 | from clustercontrast.utils.serialization import load_checkpoint, save_checkpoint 28 | from clustercontrast.utils.faiss_rerank import compute_jaccard_distance 29 | from clustercontrast.utils.data.sampler import RandomMultipleGallerySampler, RandomMultipleGallerySamplerNoCam 30 | import os 31 | import torch.utils.data as data 32 | from torch.autograd import Variable 33 | import math 34 | from ChannelAug import ChannelAdap, ChannelAdapGray, ChannelRandomErasing,ChannelExchange,Gray 35 | from collections import Counter 36 | start_epoch = best_mAP = 0 37 | 38 | def get_data(name, data_dir,trial=0): 39 | root = osp.join(data_dir, name) 40 | dataset = datasets.create(name, root,trial=trial) 41 | return dataset 42 | 43 | 44 | 45 | 46 | 47 | 48 | def get_train_loader_ir(args, dataset, height, width, batch_size, workers, 49 | 
num_instances, iters, trainset=None, no_cam=False,train_transformer=None): 50 | 51 | 52 | 53 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 54 | rmgs_flag = num_instances > 0 55 | if rmgs_flag: 56 | if no_cam: 57 | sampler = RandomMultipleGallerySamplerNoCam(train_set, num_instances) 58 | else: 59 | sampler = RandomMultipleGallerySampler(train_set, num_instances) 60 | else: 61 | sampler = None 62 | train_loader = IterLoader( 63 | DataLoader(Preprocessor(train_set, root=dataset.images_dir, transform=train_transformer), 64 | batch_size=batch_size, num_workers=workers, sampler=sampler, 65 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters) 66 | 67 | return train_loader 68 | 69 | def get_train_loader_color(args, dataset, height, width, batch_size, workers, 70 | num_instances, iters, trainset=None, no_cam=False,train_transformer=None,train_transformer1=None): 71 | 72 | 73 | 74 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 75 | rmgs_flag = num_instances > 0 76 | if rmgs_flag: 77 | if no_cam: 78 | sampler = RandomMultipleGallerySamplerNoCam(train_set, num_instances) 79 | else: 80 | sampler = RandomMultipleGallerySampler(train_set, num_instances) 81 | else: 82 | sampler = None 83 | if train_transformer1 is None: 84 | train_loader = IterLoader( 85 | DataLoader(Preprocessor(train_set, root=dataset.images_dir, transform=train_transformer), 86 | batch_size=batch_size, num_workers=workers, sampler=sampler, 87 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters) 88 | else: 89 | train_loader = IterLoader( 90 | DataLoader(Preprocessor_color(train_set, root=dataset.images_dir, transform=train_transformer,transform1=train_transformer1), 91 | batch_size=batch_size, num_workers=workers, sampler=sampler, 92 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters) 93 | 94 | return train_loader 95 | 96 | 97 | def get_test_loader(dataset, height, width, batch_size, workers, testset=None,test_transformer=None): 98 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 99 | std=[0.229, 0.224, 0.225]) 100 | if test_transformer is None: 101 | test_transformer = T.Compose([ 102 | T.Resize((height, width), interpolation=3), # 3 == PIL bicubic 103 | T.ToTensor(), 104 | normalizer 105 | ]) 106 | 107 | if testset is None: 108 | testset = list(set(dataset.query) | set(dataset.gallery)) 109 | 110 | test_loader = DataLoader( 111 | Preprocessor(testset, root=dataset.images_dir, transform=test_transformer), 112 | batch_size=batch_size, num_workers=workers, 113 | shuffle=False, pin_memory=True) 114 | 115 | return test_loader 116 | 117 | 118 | def create_model(args): 119 | model = models.create(args.arch, num_features=args.features, norm=True, dropout=args.dropout, 120 | num_classes=0, pooling_type=args.pooling_type) 121 | # use CUDA 122 | model.cuda() 123 | model = nn.DataParallel(model)#,output_device=1) 124 | return model 125 | 126 | 127 | def main(): 128 | args = parser.parse_args() 129 | 130 | if args.seed is not None: 131 | random.seed(args.seed) 132 | np.random.seed(args.seed) 133 | torch.manual_seed(args.seed) 134 | cudnn.deterministic = True 135 | main_worker(args) 136 | 137 | class TestData(data.Dataset): 138 | def __init__(self, test_img_file, test_label, transform=None, img_size = (144,288)): 139 | 140 | test_image = [] 141 | for i in range(len(test_img_file)): 142 | img = Image.open(test_img_file[i]) 143 | img = img.resize((img_size[0], img_size[1]), Image.LANCZOS) # LANCZOS is the former ANTIALIAS alias; Pillow 10 removed ANTIALIAS 144 | pix_array = np.array(img) 145 | 
test_image.append(pix_array) 146 | test_image = np.array(test_image) 147 | self.test_image = test_image 148 | self.test_label = test_label 149 | self.transform = transform 150 | 151 | def __getitem__(self, index): 152 | img1, target1 = self.test_image[index], self.test_label[index] 153 | img1 = self.transform(img1) 154 | return img1, target1 155 | 156 | def __len__(self): 157 | return len(self.test_image) 158 | 159 | 160 | def fliplr(img): 161 | '''flip horizontal''' 162 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W 163 | img_flip = img.index_select(3,inv_idx) 164 | return img_flip 165 | def extract_gall_feat(model,gall_loader,ngall): 166 | pool_dim=2048 167 | net = model 168 | net.eval() 169 | print ('Extracting Gallery Feature...') 170 | start = time.time() 171 | ptr = 0 172 | gall_feat_pool = np.zeros((ngall, pool_dim)) 173 | gall_feat_fc = np.zeros((ngall, pool_dim)) 174 | with torch.no_grad(): 175 | for batch_idx, (input, label ) in enumerate(gall_loader): 176 | batch_num = input.size(0) 177 | flip_input = fliplr(input) 178 | input = Variable(input.cuda()) 179 | feat_fc = net( input,input, 2) # modal flag 2 -> thermal/IR branch 180 | flip_input = Variable(flip_input.cuda()) 181 | feat_fc_1 = net( flip_input,flip_input, 2) 182 | feature_fc = (feat_fc.detach() + feat_fc_1.detach())/2 # average original and horizontally flipped features 183 | fnorm_fc = torch.norm(feature_fc, p=2, dim=1, keepdim=True) 184 | feature_fc = feature_fc.div(fnorm_fc.expand_as(feature_fc)) 185 | gall_feat_fc[ptr:ptr+batch_num,: ] = feature_fc.cpu().numpy() 186 | ptr = ptr + batch_num 187 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 188 | return gall_feat_fc 189 | 190 | def extract_query_feat(model,query_loader,nquery): 191 | pool_dim=2048 192 | net = model 193 | net.eval() 194 | print ('Extracting Query Feature...') 195 | start = time.time() 196 | ptr = 0 197 | query_feat_pool = np.zeros((nquery, pool_dim)) 198 | query_feat_fc = np.zeros((nquery, pool_dim)) 199 | with torch.no_grad(): 200 | for batch_idx, (input, label ) in enumerate(query_loader): 201 | batch_num = input.size(0) 202 | flip_input = fliplr(input) 203 | input = Variable(input.cuda()) 204 | feat_fc = net( input, input,1) # modal flag 1 -> visible branch 205 | flip_input = Variable(flip_input.cuda()) 206 | feat_fc_1 = net( flip_input,flip_input, 1) 207 | feature_fc = (feat_fc.detach() + feat_fc_1.detach())/2 208 | fnorm_fc = torch.norm(feature_fc, p=2, dim=1, keepdim=True) 209 | feature_fc = feature_fc.div(fnorm_fc.expand_as(feature_fc)) 210 | query_feat_fc[ptr:ptr+batch_num,: ] = feature_fc.cpu().numpy() 211 | 212 | ptr = ptr + batch_num 213 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 214 | return query_feat_fc 215 | 216 | 217 | 218 | 219 | def pairwise_distance(features_q, features_g): 220 | x = torch.from_numpy(features_q) 221 | y = torch.from_numpy(features_g) 222 | m, n = x.size(0), y.size(0) 223 | x = x.view(m, -1) 224 | y = y.view(n, -1) 225 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 226 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 227 | dist_m.addmm_(x, y.t(), beta=1, alpha=-2) # ||q-g||^2 = ||q||^2 + ||g||^2 - 2*q.g; the old positional addmm_ signature is deprecated 228 | return dist_m.numpy() 229 | 230 | 231 | 232 | 233 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'): 234 | if modal=='visible': 235 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 236 | elif modal=='thermal': 237 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 238 | 239 | with open(input_data_path) as f: 240 | data_file_list = f.read().splitlines() # each line: "<relative image path> <integer label>" 241 | # Get full list of image and labels 242 | 
file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 243 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 244 | 245 | return file_image, np.array(file_label) 246 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20): 247 | num_q, num_g = distmat.shape 248 | if num_g < max_rank: 249 | max_rank = num_g 250 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 251 | indices = np.argsort(distmat, axis=1) 252 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 253 | 254 | # compute cmc curve for each query 255 | all_cmc = [] 256 | all_AP = [] 257 | all_INP = [] 258 | num_valid_q = 0. # number of valid queries 259 | 260 | # only two cameras 261 | q_camids = np.ones(num_q).astype(np.int32) 262 | g_camids = 2* np.ones(num_g).astype(np.int32) 263 | 264 | for q_idx in range(num_q): 265 | # get query pid and camid 266 | q_pid = q_pids[q_idx] 267 | q_camid = q_camids[q_idx] 268 | 269 | # remove gallery samples that have the same pid and camid as the query 270 | order = indices[q_idx] 271 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 272 | keep = np.invert(remove) 273 | 274 | # compute cmc curve 275 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 276 | if not np.any(raw_cmc): 277 | # this condition is true when the query identity does not appear in the gallery 278 | continue 279 | 280 | cmc = raw_cmc.cumsum() 281 | 282 | # compute mINP 283 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook 284 | pos_idx = np.where(raw_cmc == 1) 285 | pos_max_idx = np.max(pos_idx) 286 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 287 | all_INP.append(inp) 288 | 289 | cmc[cmc > 1] = 1 290 | 291 | all_cmc.append(cmc[:max_rank]) 292 | num_valid_q += 1. 293 | 294 | # compute average precision 295 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 296 | num_rel = raw_cmc.sum() 297 | tmp_cmc = raw_cmc.cumsum() 298 | tmp_cmc = [x / (i+1.)
for i, x in enumerate(tmp_cmc)] 299 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 300 | AP = tmp_cmc.sum() / num_rel 301 | all_AP.append(AP) 302 | 303 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 304 | 305 | all_cmc = np.asarray(all_cmc).astype(np.float32) 306 | all_cmc = all_cmc.sum(0) / num_valid_q 307 | mAP = np.mean(all_AP) 308 | mINP = np.mean(all_INP) 309 | return all_cmc, mAP, mINP 310 | 311 | def main_worker(args): 312 | log_name='regdb_s2'#model path 313 | model = create_model(args) 314 | for trial in range(1,11):#(1,11): 315 | args.test_batch=64 316 | args.img_w=args.width 317 | args.img_h=args.height 318 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406], 319 | std=[0.229, 0.224, 0.225]) 320 | transform_test = T.Compose([ 321 | T.ToPILImage(), 322 | T.Resize((args.img_h,args.img_w)), 323 | T.ToTensor(), 324 | normalize, 325 | ]) 326 | logs_dir_root = osp.join(args.logs_dir+'/'+log_name) 327 | # args.logs_dir = osp.join(logs_dir_root,str(trial)) 328 | print('==> Test with the best model:') 329 | 330 | checkpoint = load_checkpoint(osp.join(logs_dir_root+'/'+str(trial), 'model_best.pth.tar')) 331 | 332 | model.load_state_dict(checkpoint['state_dict']) 333 | 334 | mode='visible to thermal' 335 | print(mode) 336 | data_path='/data0/ReIDData/RegDB/' 337 | query_img, query_label = process_test_regdb(data_path, trial=trial, modal='visible') 338 | gall_img, gall_label = process_test_regdb(data_path, trial=trial, modal='thermal') 339 | 340 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 341 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 342 | nquery = len(query_label) 343 | ngall = len(gall_label) 344 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 345 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 346 | query_feat_fc = extract_query_feat(model,query_loader,nquery) 347 | # for trial in range(1): 348 | ngall = len(gall_label) 349 | gall_feat_fc = extract_gall_feat(model,gall_loader,ngall) 350 | # fc feature 351 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 352 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 353 | if trial == 1: 354 | all_cmc = cmc 355 | all_mAP = mAP 356 | all_mINP = mINP 357 | 358 | else: 359 | all_cmc = all_cmc + cmc 360 | all_mAP = all_mAP + mAP 361 | all_mINP = all_mINP + mINP 362 | 363 | 364 | print('Test Trial: {}'.format(trial)) 365 | print( 366 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 367 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 368 | cmc = all_cmc / 10 369 | mAP = all_mAP / 10 370 | mINP = all_mINP / 10 371 | print('All Average:') 372 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 373 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 374 | ################################# 375 | for trial in range(1,11):#(1,11): 376 | args.test_batch=64 377 | args.img_w=args.width 378 | args.img_h=args.height 379 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406], 380 | std=[0.229, 0.224, 0.225]) 381 | transform_test = T.Compose([ 382 | T.ToPILImage(), 383 | T.Resize((args.img_h,args.img_w)), 384 | T.ToTensor(), 385 | normalize, 386 | ]) 387 | logs_dir_root = osp.join(args.logs_dir+'/'+log_name) 388 | # args.logs_dir = 
osp.join(logs_dir_root,str(trial)) 389 | print('==> Test with the best model:') 390 | model = create_model(args) 391 | checkpoint = load_checkpoint(osp.join(logs_dir_root+'/'+str(trial), 'model_best.pth.tar')) 392 | 393 | model.load_state_dict(checkpoint['state_dict']) 394 | 395 | mode='thermal to visible' 396 | print(mode) 397 | data_path='/data0/ReIDData/RegDB/' 398 | query_img, query_label = process_test_regdb(data_path, trial=trial, modal='thermal') 399 | gall_img, gall_label = process_test_regdb(data_path, trial=trial, modal='visible') 400 | 401 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 402 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 403 | nquery = len(query_label) 404 | ngall = len(gall_label) 405 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 406 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 407 | query_feat_fc = extract_gall_feat(model,query_loader,nquery) # thermal queries use the modal-2 (IR) branch, so the gallery extractor is reused here 408 | # for trial in range(1): 409 | ngall = len(gall_label) 410 | gall_feat_fc = extract_query_feat(model,gall_loader,ngall) # visible gallery uses the modal-1 branch 411 | # fc feature 412 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 413 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 414 | if trial == 1: 415 | all_cmc = cmc 416 | all_mAP = mAP 417 | all_mINP = mINP 418 | 419 | else: 420 | all_cmc = all_cmc + cmc 421 | all_mAP = all_mAP + mAP 422 | all_mINP = all_mINP + mINP 423 | 424 | 425 | print('Test Trial: {}'.format(trial)) 426 | print( 427 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 428 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 429 | cmc = all_cmc / 10 430 | mAP = all_mAP / 10 431 | mINP = all_mINP / 10 432 | print('All Average:') 433 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 434 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 435 | ################################# 436 | 437 | 438 | if __name__ == '__main__': 439 | parser = argparse.ArgumentParser(description="Self-paced contrastive learning on unsupervised re-ID") 440 | # data 441 | parser.add_argument('-d', '--dataset', type=str, default='dukemtmcreid', 442 | choices=datasets.names()) 443 | parser.add_argument('-b', '--batch-size', type=int, default=2) 444 | parser.add_argument('-j', '--workers', type=int, default=8) 445 | parser.add_argument('--height', type=int, default=288, help="input height") 446 | parser.add_argument('--width', type=int, default=144, help="input width") 447 | parser.add_argument('--num-instances', type=int, default=4, 448 | help="each minibatch consists of " 449 | "(batch_size // num_instances) identities, and " 450 | "each identity has num_instances instances; " 451 | "set 0 to disable the identity sampler") 452 | # cluster 453 | parser.add_argument('--eps', type=float, default=0.6, 454 | help="max neighbor distance for DBSCAN") 455 | parser.add_argument('--eps-gap', type=float, default=0.02, 456 | help="multi-scale criterion for measuring cluster reliability") 457 | parser.add_argument('--k1', type=int, default=30, 458 | help="hyperparameter for jaccard distance") 459 | parser.add_argument('--k2', type=int, default=6, 460 | help="hyperparameter for jaccard distance") 461 | 462 | # model 463 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 464 | choices=models.names())
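# --features 0 keeps the backbone's native output dimension (2048 for the resnet/agw
# models used by the run scripts, matching pool_dim in the extract_* helpers above);
# a positive value typically adds a linear embedding of that size -- see
# clustercontrast/models for the exact behavior.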
465 | parser.add_argument('--features', type=int, default=0) 466 | parser.add_argument('--dropout', type=float, default=0) 467 | parser.add_argument('--momentum', type=float, default=0.2, 468 | help="update momentum for the hybrid memory") 469 | # optimizer 470 | parser.add_argument('--lr', type=float, default=0.00035, 471 | help="learning rate") 472 | parser.add_argument('--weight-decay', type=float, default=5e-4) 473 | parser.add_argument('--epochs', type=int, default=50) 474 | parser.add_argument('--iters', type=int, default=100) 475 | parser.add_argument('--step-size', type=int, default=20) 476 | # training configs 477 | parser.add_argument('--seed', type=int, default=1) 478 | parser.add_argument('--print-freq', type=int, default=10) 479 | parser.add_argument('--eval-step', type=int, default=1) 480 | parser.add_argument('--temp', type=float, default=0.05, 481 | help="temperature for scaling contrastive loss") 482 | # path 483 | working_dir = osp.dirname(osp.abspath(__file__)) 484 | parser.add_argument('--data-dir', type=str, metavar='PATH', 485 | default=osp.join(working_dir, 'data')) 486 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 487 | default=osp.join(working_dir, 'logs')) 488 | parser.add_argument('--pooling-type', type=str, default='gem') 489 | parser.add_argument('--use-hard', action="store_true") 490 | parser.add_argument('--no-cam', action="store_true") 491 | 492 | main() 493 | --------------------------------------------------------------------------------
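The test driver above scores retrieval with a negated inner product: extract_gall_feat and extract_query_feat L2-normalise every feature, so -np.matmul(query_feat_fc, gall_feat_fc.T) ranks the gallery identically to Euclidean or cosine distance (||q - g||^2 = 2 - 2 q.g for unit vectors), and eval_regdb turns the resulting distance matrix into CMC/mAP/mINP. A minimal self-contained sketch of that ranking step, with toy shapes and labels (illustrative only, not repository code):

import numpy as np

rng = np.random.default_rng(0)
query_feat_fc = rng.normal(size=(4, 2048))
gall_feat_fc = rng.normal(size=(6, 2048))
# L2-normalise rows, mirroring the torch.norm / div step in the extract_* helpers
query_feat_fc /= np.linalg.norm(query_feat_fc, axis=1, keepdims=True)
gall_feat_fc /= np.linalg.norm(gall_feat_fc, axis=1, keepdims=True)

distmat = -np.matmul(query_feat_fc, gall_feat_fc.T)  # smaller = more similar
q_pids = np.array([0, 1, 2, 3])
g_pids = np.array([0, 0, 1, 2, 3, 3])

# the first step of eval_regdb: sort the gallery per query, then check id matches
indices = np.argsort(distmat, axis=1)
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
print('toy rank-1: {:.2%}'.format(matches[:, 0].mean()))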