├── LICENSE ├── README.md ├── examples ├── data │ └── data.txt ├── test.py └── unsupervised_train.py ├── figs └── figure8.png ├── ice ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dukemtmc.py │ ├── market1501.py │ └── msmt17.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── evaluators.py ├── loss │ ├── __init__.py │ ├── contrastive.py │ ├── crossentropy.py │ └── triplet.py ├── models │ ├── __init__.py │ ├── resnet.py │ ├── resnet_ibn.py │ └── resnet_ibn_a.py ├── trainers.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── dataset.py │ ├── preprocessor.py │ ├── sampler.py │ └── transforms.py │ ├── faiss_rerank.py │ ├── faiss_utils.py │ ├── logging.py │ ├── lr_scheduler.py │ ├── meters.py │ ├── osutils.py │ ├── rerank.py │ └── serialization.py ├── setup.cfg └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 chenhao2345 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ICE 2 | This is the official PyTorch implementation of the ICCV 2021 paper 3 | [ICE: Inter-instance Contrastive Encoding for Unsupervised Person 4 | Re-identification](https://arxiv.org/pdf/2103.16364.pdf). 5 | 6 | [[Video](https://drive.google.com/file/d/1E__ru9u_oRcb44-WIH_GjBTv1-_5rcO2/view?usp=sharing)] [[Poster](https://drive.google.com/file/d/1HEkgtUCSOixIndH1ClhRZfAQGTIFfY-n/view?usp=sharing)] 7 | 8 | ![teaser](figs/figure8.png) 9 | 10 | ## Installation 11 | 12 | ```shell 13 | git clone https://github.com/chenhao2345/ICE 14 | cd ICE 15 | python setup.py develop 16 | ``` 17 | 18 | ## Prepare Datasets 19 | 20 | Download the raw datasets [DukeMTMC-reID](https://arxiv.org/abs/1609.01775), [Market-1501](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zheng_Scalable_Person_Re-Identification_ICCV_2015_paper.pdf), [MSMT17](https://arxiv.org/abs/1711.08565), 21 | and then unzip them under the directory like 22 | ``` 23 | ICE/examples/data 24 | ├── dukemtmc-reid 25 | │ └── DukeMTMC-reID 26 | ├── market1501 27 | └── msmt17 28 | └── MSMT17_V1(or MSMT17_V2) 29 | ``` 30 | 31 | ## Training 32 | We used **4 GPUs** to train our model. 
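The scripts wrap the network in `torch.nn.DataParallel`, so all visible GPUs are used automatically. A minimal sketch (device ids are illustrative, adjust them to your machine) to pin a run to 4 specific GPUs:

```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/unsupervised_train.py --dataset-target market1501
```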
33 | 34 | Train [Market-1501](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zheng_Scalable_Person_Re-Identification_ICCV_2015_paper.pdf): 35 | ``` 36 | python examples/unsupervised_train.py --dataset-target market1501 37 | ``` 38 | Train [DukeMTMC-reID](https://arxiv.org/abs/1609.01775): 39 | ``` 40 | python examples/unsupervised_train.py --dataset-target dukemtmc-reid 41 | ``` 42 | Train [MSMT17](https://arxiv.org/abs/1711.08565): 43 | ``` 44 | python examples/unsupervised_train.py --dataset-target msmt17 45 | ``` 46 | ## Citation 47 | If you find this project useful, please kindly star our project and cite our paper. 48 | ```bibtex 49 | @InProceedings{Chen_2021_ICCV, 50 | author = {Chen, Hao and Lagadec, Benoit and Bremond, Fran\c{c}ois}, 51 | title = {ICE: Inter-Instance Contrastive Encoding for Unsupervised Person Re-Identification}, 52 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 53 | month = {October}, 54 | year = {2021}, 55 | pages = {14960-14969} 56 | } 57 | ``` 58 | 59 | -------------------------------------------------------------------------------- /examples/data/data.txt: -------------------------------------------------------------------------------- 1 | Please put your datasets in this folder. 2 | 3 | ICE/examples/data 4 | ├── dukemtmc-reid 5 | │ └── DukeMTMC-reID 6 | ├── market1501 7 | └── msmt17 8 | └── MSMT17_V1(or MSMT17_V2) -------------------------------------------------------------------------------- /examples/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import argparse 4 | import os.path as osp 5 | import random 6 | import numpy as np 7 | import sys 8 | import time 9 | 10 | from sklearn.cluster import DBSCAN 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.backends import cudnn 15 | from torch.utils.data import DataLoader 16 | 17 | from ice.utils.logging import Logger 18 | from ice import datasets 19 | from ice import models 20 | from ice.trainers import ImageTrainer 21 | from ice.evaluators import Evaluator, extract_features 22 | from ice.utils.data import IterLoader 23 | from ice.utils.data import transforms as T 24 | from ice.utils.data.sampler import MoreCameraSampler 25 | from ice.utils.data.preprocessor import Preprocessor_mutual, Preprocessor 26 | from ice.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict 27 | from ice.utils.faiss_rerank import compute_jaccard_distance 28 | from ice.utils.lr_scheduler import WarmupMultiStepLR 29 | 30 | start_epoch = best_mAP = 0 31 | 32 | 33 | def get_data(name, data_dir): 34 | root = osp.join(data_dir, name) 35 | dataset = datasets.create(name, root) 36 | return dataset 37 | 38 | 39 | def get_train_loader(dataset, height, width, batch_size, workers, 40 | num_instances, iters, trainset=None, mutual=False, index=False): 41 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225]) 43 | train_transformer = T.Compose([ 44 | T.Resize((height, width), interpolation=3), 45 | T.Pad(10), 46 | T.RandomCrop((height, width)), 47 | T.RandomHorizontalFlip(p=0.5), 48 | T.RandomApply([T.GaussianBlur([.1, 2.])], p=0.5), 49 | T.ToTensor(), 50 | normalizer, 51 | T.RandomErasing(probability=0.6, mean=[0.485, 0.456, 0.406]), 52 | ]) 53 | 54 | weak_transformer = T.Compose([ 55 | T.Resize((height, width), interpolation=3), 56 | T.ToTensor(), 57 | normalizer, 58 | ]) 
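    # Added note (comment not in the original file): ICE draws two views of each
    # training image. train_transformer above is the strong view (pad/crop, flip,
    # blur, random erasing); weak_transformer keeps the image almost unchanged.
    # Preprocessor_mutual returns both views when mutual=True, so the trainer can
    # compare predictions across the two views.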
| ]) 59 | 60 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 61 | rmgs_flag = num_instances > 0 62 | if rmgs_flag: 63 | sampler = MoreCameraSampler(train_set, num_instances) 64 | else: 65 | sampler = None 66 | train_loader = IterLoader( 67 | DataLoader(Preprocessor_mutual(train_set, root=dataset.images_dir, transform=train_transformer, mutual=mutual, transform_weak=weak_transformer), 68 | batch_size=batch_size, num_workers=workers, sampler=sampler, 69 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters) 70 | 71 | return train_loader 72 | 73 | def get_test_loader(dataset, height, width, batch_size, workers, testset=None): 74 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 75 | std=[0.229, 0.224, 0.225]) 76 | 77 | test_transformer = T.Compose([ 78 | T.Resize((height, width), interpolation=3), 79 | T.ToTensor(), 80 | normalizer 81 | ]) 82 | 83 | if (testset is None): 84 | testset = list(set(dataset.query) | set(dataset.gallery)) 85 | 86 | test_loader = DataLoader( 87 | Preprocessor(testset, root=dataset.images_dir, transform=test_transformer), 88 | batch_size=batch_size, num_workers=workers, 89 | shuffle=False, pin_memory=True) 90 | 91 | return test_loader 92 | 93 | def create_model(args): 94 | model_1_ema = models.create(args.arch, num_features=args.features, dropout=args.dropout, num_classes=0) 95 | model_1_ema.cuda() 96 | model_1_ema = nn.DataParallel(model_1_ema) 97 | return model_1_ema 98 | 99 | 100 | def main(): 101 | args = parser.parse_args() 102 | 103 | if args.seed is not None: 104 | random.seed(args.seed) 105 | np.random.seed(args.seed) 106 | torch.manual_seed(args.seed) 107 | cudnn.deterministic = True 108 | 109 | main_worker(args) 110 | 111 | 112 | def main_worker(args): 113 | global start_epoch, best_mAP 114 | 115 | cudnn.benchmark = True 116 | 117 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 118 | print("==========\nArgs:{}\n==========".format(args)) 119 | 120 | # Create data loaders 121 | dataset_target = get_data(args.dataset_target, args.data_dir) 122 | test_loader_target = get_test_loader(dataset_target, args.height, args.width, 256, args.workers) 123 | 124 | # Create model 125 | model_1_ema = create_model(args) 126 | 127 | # Evaluator 128 | evaluator_1_ema = Evaluator(model_1_ema) 129 | checkpoint = load_checkpoint(osp.join(args.logs_dir, args.dataset_target+'_unsupervised', 'model_best.pth.tar')) 130 | model_1_ema.load_state_dict(checkpoint['state_dict']) 131 | evaluator_1_ema.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True) 132 | 133 | if __name__ == '__main__': 134 | parser = argparse.ArgumentParser(description="ICE Testing") 135 | # data 136 | parser.add_argument('-dt', '--dataset-target', type=str, default='market1501', 137 | choices=datasets.names()) 138 | parser.add_argument('-b', '--batch-size', type=int, default=32) 139 | parser.add_argument('-j', '--workers', type=int, default=8) 140 | parser.add_argument('--height', type=int, default=256, 141 | help="input height") 142 | parser.add_argument('--width', type=int, default=128, 143 | help="input width") 144 | parser.add_argument('--num-instances', type=int, default=4, 145 | help="each minibatch consists of " 146 | "(batch_size // num_instances) identities, and " 147 | "each identity has num_instances instances; " 148 | "set 0 to disable, default: 4") 149 | # model 150 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 151 | choices=models.names()) 152 | parser.add_argument('--features', type=int, 
default=0) 153 | parser.add_argument('--dropout', type=float, default=0) 154 | # optimizer 155 | parser.add_argument('--lr', type=float, default=0.00035, 156 | help="learning rate of new parameters") 157 | parser.add_argument('--alpha', type=float, default=0.999) 158 | parser.add_argument('--weight-decay', type=float, default=5e-4) 159 | parser.add_argument('--epochs', type=int, default=40) 160 | parser.add_argument('--iters', type=int, default=400) 161 | # training configs 162 | parser.add_argument('--seed', type=int, default=1) 163 | parser.add_argument('--print-freq', type=int, default=100) 164 | parser.add_argument('--eval-step', type=int, default=1) 165 | parser.add_argument('--tau-c', type=float, default=0.5) 166 | parser.add_argument('--tau-v', type=float, default=0.1) 167 | parser.add_argument('--scale-kl', type=float, default=0.4) 168 | parser.add_argument('--warmup-step', type=int, default=10) 169 | parser.add_argument('--milestones', nargs='+', type=int, default=[], 170 | help='milestones for the learning rate decay') 171 | # path 172 | working_dir = osp.dirname(osp.abspath(__file__)) 173 | parser.add_argument('--data-dir', type=str, metavar='PATH', 174 | default=osp.join(working_dir, 'data')) 175 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 176 | default=osp.join(working_dir, 'logs')) 177 | # cluster 178 | parser.add_argument('--eps', type=float, default=0.55, help="dbscan threshold") 179 | parser.add_argument('--k1', type=int, default=30, 180 | help="k1, default: 30") 181 | parser.add_argument('--min-samples', type=int, default=4, 182 | help="min sample, default: 4") 183 | end = time.time() 184 | main() 185 | print('Time used: {}'.format(time.time()-end)) 186 | -------------------------------------------------------------------------------- /examples/unsupervised_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import argparse 4 | import os.path as osp 5 | import random 6 | import numpy as np 7 | import sys 8 | import time 9 | 10 | from sklearn.cluster import DBSCAN 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.backends import cudnn 15 | from torch.utils.data import DataLoader 16 | 17 | from ice.utils.logging import Logger 18 | from ice import datasets 19 | from ice import models 20 | from ice.trainers import ImageTrainer 21 | from ice.evaluators import Evaluator, extract_features 22 | from ice.utils.data import IterLoader 23 | from ice.utils.data import transforms as T 24 | from ice.utils.data.sampler import MoreCameraSampler 25 | from ice.utils.data.preprocessor import Preprocessor_mutual, Preprocessor 26 | from ice.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict 27 | from ice.utils.faiss_rerank import compute_jaccard_distance 28 | from ice.utils.lr_scheduler import WarmupMultiStepLR 29 | 30 | start_epoch = best_mAP = 0 31 | 32 | 33 | def get_data(name, data_dir): 34 | root = osp.join(data_dir, name) 35 | dataset = datasets.create(name, root) 36 | return dataset 37 | 38 | 39 | def get_train_loader(dataset, height, width, batch_size, workers, 40 | num_instances, iters, trainset=None, mutual=False, index=False): 41 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225]) 43 | train_transformer = T.Compose([ 44 | T.Resize((height, width), interpolation=3), 45 | T.Pad(10), 46 | T.RandomCrop((height, width)), 47 | T.RandomHorizontalFlip(p=0.5), 48 
| T.RandomApply([T.GaussianBlur([.1, 2.])], p=0.5), 49 | T.ToTensor(), 50 | normalizer, 51 | T.RandomErasing(probability=0.6, mean=[0.485, 0.456, 0.406]), 52 | ]) 53 | 54 | weak_transformer = T.Compose([ 55 | T.Resize((height, width), interpolation=3), 56 | T.ToTensor(), 57 | normalizer, 58 | ]) 59 | 60 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 61 | rmgs_flag = num_instances > 0 62 | if rmgs_flag: 63 | sampler = MoreCameraSampler(train_set, num_instances) 64 | else: 65 | sampler = None 66 | train_loader = IterLoader( 67 | DataLoader(Preprocessor_mutual(train_set, root=dataset.images_dir, transform=train_transformer, mutual=mutual, transform_weak=weak_transformer), 68 | batch_size=batch_size, num_workers=workers, sampler=sampler, 69 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters) 70 | 71 | return train_loader 72 | 73 | def get_test_loader(dataset, height, width, batch_size, workers, testset=None): 74 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 75 | std=[0.229, 0.224, 0.225]) 76 | 77 | test_transformer = T.Compose([ 78 | T.Resize((height, width), interpolation=3), 79 | T.ToTensor(), 80 | normalizer 81 | ]) 82 | 83 | if (testset is None): 84 | testset = list(set(dataset.query) | set(dataset.gallery)) 85 | 86 | test_loader = DataLoader( 87 | Preprocessor(testset, root=dataset.images_dir, transform=test_transformer), 88 | batch_size=batch_size, num_workers=workers, 89 | shuffle=False, pin_memory=True) 90 | 91 | return test_loader 92 | 93 | def create_model(args): 94 | model_1 = models.create(args.arch, num_features=args.features, dropout=args.dropout, num_classes=0) 95 | 96 | model_1_ema = models.create(args.arch, num_features=args.features, dropout=args.dropout, num_classes=0) 97 | 98 | model_1.cuda() 99 | model_1_ema.cuda() 100 | model_1 = nn.DataParallel(model_1) 101 | model_1_ema = nn.DataParallel(model_1_ema) 102 | 103 | if args.init != '': 104 | initial_weights = load_checkpoint(args.init) 105 | copy_state_dict(initial_weights['state_dict'], model_1) 106 | copy_state_dict(initial_weights['state_dict'], model_1_ema) 107 | 108 | return model_1, model_1_ema 109 | 110 | 111 | def main(): 112 | args = parser.parse_args() 113 | 114 | if args.seed is not None: 115 | random.seed(args.seed) 116 | np.random.seed(args.seed) 117 | torch.manual_seed(args.seed) 118 | cudnn.deterministic = True 119 | 120 | main_worker(args) 121 | 122 | 123 | def main_worker(args): 124 | global start_epoch, best_mAP 125 | 126 | cudnn.benchmark = True 127 | 128 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 129 | print("==========\nArgs:{}\n==========".format(args)) 130 | 131 | # Create data loaders 132 | iters = args.iters if (args.iters>0) else None 133 | dataset_target = get_data(args.dataset_target, args.data_dir) 134 | test_loader_target = get_test_loader(dataset_target, args.height, args.width, 256, args.workers) 135 | 136 | # Create model 137 | model_1, model_1_ema = create_model(args) 138 | 139 | # Optimizer 140 | params = [] 141 | for key, value in model_1.named_parameters(): 142 | if not value.requires_grad: 143 | continue 144 | params += [{"params": [value], "lr": args.lr, "weight_decay": args.weight_decay}] 145 | optimizer = torch.optim.Adam(params) 146 | 147 | lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=1, warmup_factor=0.1, 148 | warmup_iters=args.warmup_step) 149 | 150 | # Evaluator 151 | evaluator_1_ema = Evaluator(model_1_ema) 152 | 153 | for epoch in range(args.epochs): 154 | 155 | cluster_loader = 
get_test_loader(dataset_target, args.height, args.width, 256, args.workers, testset=dataset_target.train) 156 | dict_f1, _ = extract_features(model_1_ema, cluster_loader, print_freq=50) 157 | cf = torch.stack(list(dict_f1.values())) 158 | 159 | rerank_dist = compute_jaccard_distance(cf, k1=args.k1, k2=6) 160 | eps = args.eps 161 | 162 | print('eps in cluster: {:.3f}'.format(eps)) 163 | print('Clustering and labeling...') 164 | cluster = DBSCAN(eps=eps, min_samples=args.min_samples, metric='precomputed', n_jobs=-1) 165 | labels = cluster.fit_predict(rerank_dist) 166 | num_ids = len(set(labels)) - (1 if -1 in labels else 0) 167 | 168 | centers = [] 169 | for id in range(num_ids): 170 | centers.append(torch.mean(cf[labels == id], dim=0)) 171 | centers = torch.stack(centers, dim=0) 172 | 173 | # change pseudo labels 174 | pseudo_labeled_dataset = [] 175 | pseudo_outlier_dataset = [] 176 | labels_true = [] 177 | 178 | cams = [] 179 | 180 | for i, ((fname, pid, cid), label) in enumerate(zip(dataset_target.train, labels)): 181 | labels_true.append(pid) 182 | cams.append(cid) 183 | if label == -1: 184 | pseudo_outlier_dataset.append((fname, label.item(), cid)) 185 | else: 186 | pseudo_labeled_dataset.append((fname, label.item(), cid)) 187 | cams = np.asarray(cams) 188 | 189 | intra_id_features = [] 190 | intra_id_labels = [] 191 | for cc in np.unique(cams): 192 | percam_ind = np.where(cams == cc)[0] 193 | percam_feature = cf[percam_ind].numpy() 194 | percam_label = labels[percam_ind] 195 | percam_class_num = len(np.unique(percam_label[percam_label >= 0])) 196 | percam_id_feature = np.zeros((percam_class_num, percam_feature.shape[1]), dtype=np.float32) 197 | cnt = 0 198 | for lbl in np.unique(percam_label): 199 | if lbl >= 0: 200 | ind = np.where(percam_label == lbl)[0] 201 | id_feat = np.mean(percam_feature[ind], axis=0) 202 | percam_id_feature[cnt, :] = id_feat 203 | intra_id_labels.append(lbl) 204 | cnt += 1 205 | percam_id_feature = percam_id_feature / np.linalg.norm(percam_id_feature, axis=1, keepdims=True) 206 | intra_id_features.append(torch.from_numpy(percam_id_feature)) 207 | 208 | print('Epoch {} has {} labeled samples of {} ids and {} unlabeled samples'. 
209 | format(epoch, len(pseudo_labeled_dataset), num_ids, len(pseudo_outlier_dataset))) 210 | print('Learning Rate:', optimizer.param_groups[0]['lr']) 211 | train_loader_target = get_train_loader(dataset_target, args.height, args.width, 212 | args.batch_size, args.workers, args.num_instances, iters, trainset=pseudo_labeled_dataset, mutual=True) 213 | 214 | # Trainer 215 | trainer = ImageTrainer(model_1, model_1_ema, num_cluster=num_ids, alpha=args.alpha, 216 | num_instance=args.num_instances, tau_c=args.tau_c, tau_v=args.tau_v, 217 | scale_kl=args.scale_kl) 218 | 219 | train_loader_target.new_epoch() 220 | 221 | trainer.train(epoch, train_loader_target, optimizer, 222 | print_freq=args.print_freq, train_iters=len(train_loader_target), centers=centers, 223 | intra_id_labels=intra_id_labels, intra_id_features=intra_id_features, cams=cams, all_pseudo_label=labels) 224 | 225 | lr_scheduler.step() 226 | 227 | if (epoch+1)%args.eval_step==0: 228 | cmc, mAP_1 = evaluator_1_ema.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True) 229 | is_best = mAP_1 > best_mAP 230 | best_mAP = max(mAP_1, best_mAP) 231 | save_checkpoint({ 232 | 'state_dict': model_1_ema.state_dict(), 233 | 'epoch': epoch + 1, 234 | 'best_mAP': best_mAP, 235 | }, is_best, fpath=osp.join(args.logs_dir, args.dataset_target+'_unsupervised','checkpoint.pth.tar')) 236 | 237 | checkpoint = load_checkpoint(osp.join(args.logs_dir, args.dataset_target+'_unsupervised', 'model_best.pth.tar')) 238 | model_1_ema.load_state_dict(checkpoint['state_dict']) 239 | evaluator_1_ema.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True) 240 | 241 | if __name__ == '__main__': 242 | parser = argparse.ArgumentParser(description="ICE Training") 243 | # data 244 | parser.add_argument('-dt', '--dataset-target', type=str, default='market1501', 245 | choices=datasets.names()) 246 | parser.add_argument('-b', '--batch-size', type=int, default=32) 247 | parser.add_argument('-j', '--workers', type=int, default=8) 248 | parser.add_argument('--height', type=int, default=256, 249 | help="input height") 250 | parser.add_argument('--width', type=int, default=128, 251 | help="input width") 252 | parser.add_argument('--num-instances', type=int, default=4, 253 | help="each minibatch consists of " 254 | "(batch_size // num_instances) identities, and " 255 | "each identity has num_instances instances; " 256 | "set 0 to disable, default: 4") 257 | # model 258 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 259 | choices=models.names()) 260 | parser.add_argument('--features', type=int, default=0) 261 | parser.add_argument('--dropout', type=float, default=0) 262 | # optimizer 263 | parser.add_argument('--lr', type=float, default=0.00035, 264 | help="learning rate of new parameters") 265 | parser.add_argument('--alpha', type=float, default=0.999) 266 | parser.add_argument('--weight-decay', type=float, default=5e-4) 267 | parser.add_argument('--epochs', type=int, default=40) 268 | parser.add_argument('--iters', type=int, default=400) 269 | # training configs 270 | parser.add_argument('--seed', type=int, default=1) 271 | parser.add_argument('--print-freq', type=int, default=100) 272 | parser.add_argument('--eval-step', type=int, default=1) 273 | parser.add_argument('--tau-c', type=float, default=0.5) 274 | parser.add_argument('--tau-v', type=float, default=0.1) 275 | parser.add_argument('--scale-kl', type=float, default=0.4) 276 | parser.add_argument('--warmup-step', type=int, default=10) 277 | 
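    # Added note (our reading of the ICE paper, not an original comment):
    # --tau-c is the temperature of the proxy (cluster-centroid) contrastive term,
    # --tau-v the temperature of the hard instance/view contrastive term (passed
    # to ViewContrastiveLoss as T), and --scale-kl the weight of the soft
    # consistency loss between the strongly and weakly augmented views.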
parser.add_argument('--milestones', nargs='+', type=int, default=[], 278 | help='milestones for the learning rate decay') 279 | # path 280 | working_dir = osp.dirname(osp.abspath(__file__)) 281 | parser.add_argument('--data-dir', type=str, metavar='PATH', 282 | default=osp.join(working_dir, 'data')) 283 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 284 | default=osp.join(working_dir, 'logs')) 285 | # cluster 286 | parser.add_argument('--eps', type=float, default=0.55, help="dbscan threshold") 287 | parser.add_argument('--k1', type=int, default=30, 288 | help="k1, default: 30") 289 | parser.add_argument('--min-samples', type=int, default=4, 290 | help="min sample, default: 4") 291 | # init 292 | parser.add_argument('--init', type=str, 293 | default='', 294 | metavar='PATH') 295 | end = time.time() 296 | main() 297 | print('Time used: {}'.format(time.time()-end)) 298 | -------------------------------------------------------------------------------- /figs/figure8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenhao2345/ICE/a206eb9a97ad431ab9d9cf38cdcf5ab6fdc6ad1c/figs/figure8.png -------------------------------------------------------------------------------- /ice/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import models 6 | from . import utils 7 | from . import evaluators 8 | from . import trainers 9 | 10 | __version__ = '1.0.0' 11 | -------------------------------------------------------------------------------- /ice/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .dukemtmc import DukeMTMC 5 | from .market1501 import Market1501 6 | from .msmt17 import MSMT17 7 | 8 | __factory = { 9 | 'market1501': Market1501, 10 | 'dukemtmc-reid': DukeMTMC, 11 | 'msmt17': MSMT17, 12 | } 13 | 14 | 15 | def names(): 16 | return sorted(__factory.keys()) 17 | 18 | 19 | def create(name, root, *args, **kwargs): 20 | """ 21 | Create a dataset instance. 22 | 23 | Parameters 24 | ---------- 25 | name : str 26 | The dataset name. 27 | root : str 28 | The path to the dataset directory. 29 | split_id : int, optional 30 | The index of data split. Default: 0 31 | num_val : int or float, optional 32 | When int, it means the number of validation identities. When float, 33 | it means the proportion of validation to all the trainval. Default: 100 34 | download : bool, optional 35 | If True, will download the dataset. Default: False 36 | """ 37 | if name not in __factory: 38 | raise KeyError("Unknown dataset:", name) 39 | return __factory[name](root, *args, **kwargs) 40 | 41 | 42 | def get_dataset(name, root, *args, **kwargs): 43 | warnings.warn("get_dataset is deprecated. 
Use create instead.") 44 | return create(name, root, *args, **kwargs) 45 | -------------------------------------------------------------------------------- /ice/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | import os 8 | import shutil 9 | from ..utils.data import BaseImageDataset 10 | from ..utils.osutils import mkdir_if_missing 11 | from ..utils.serialization import write_json 12 | 13 | 14 | class DukeMTMC(BaseImageDataset): 15 | """ 16 | DukeMTMC-reID 17 | Reference: 18 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 21 | 22 | Dataset statistics: 23 | # identities: 1404 (train + query) 24 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 25 | # cameras: 8 26 | """ 27 | dataset_dir = '.' 28 | 29 | def __init__(self, root, verbose=True, **kwargs): 30 | super(DukeMTMC, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 33 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 34 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 36 | 37 | self._download_data() 38 | self._check_before_run() 39 | 40 | train = self._process_dir(self.train_dir, relabel=True) 41 | query = self._process_dir(self.query_dir, relabel=False) 42 | gallery = self._process_dir(self.gallery_dir, relabel=False) 43 | 44 | if verbose: 45 | print("=> DukeMTMC-reID loaded") 46 | self.print_dataset_statistics(train, query, gallery) 47 | 48 | self.train = train 49 | self.query = query 50 | self.gallery = gallery 51 | 52 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 53 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 54 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 55 | 56 | def _download_data(self): 57 | if osp.exists(self.dataset_dir): 58 | print("This dataset has been downloaded.") 59 | return 60 | 61 | print("Creating directory {}".format(self.dataset_dir)) 62 | mkdir_if_missing(self.dataset_dir) 63 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 64 | 65 | print("Downloading DukeMTMC-reID dataset") 66 | urllib.request.urlretrieve(self.dataset_url, fpath) 67 | 68 | print("Extracting files") 69 | zip_ref = zipfile.ZipFile(fpath, 'r') 70 | zip_ref.extractall(self.dataset_dir) 71 | zip_ref.close() 72 | 73 | def _check_before_run(self): 74 | """Check if all files are available before going deeper""" 75 | if not osp.exists(self.dataset_dir): 76 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 77 | if not osp.exists(self.train_dir): 78 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 79 | if not osp.exists(self.query_dir): 80 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 81 | if not osp.exists(self.gallery_dir): 82 | raise RuntimeError("'{}' is not 
available".format(self.gallery_dir)) 83 | 84 | def _process_dir(self, dir_path, relabel=False): 85 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 86 | pattern = re.compile(r'([-\d]+)_c(\d)') 87 | 88 | pid_container = set() 89 | for img_path in img_paths: 90 | pid, _ = map(int, pattern.search(img_path).groups()) 91 | pid_container.add(pid) 92 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 93 | 94 | dataset = [] 95 | for img_path in img_paths: 96 | pid, camid = map(int, pattern.search(img_path).groups()) 97 | assert 1 <= camid <= 8 98 | camid -= 1 # index starts from 0 99 | if relabel: pid = pid2label[pid] 100 | dataset.append((img_path, pid, camid)) 101 | 102 | return dataset 103 | -------------------------------------------------------------------------------- /ice/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import shutil 6 | import os 7 | from ..utils.data import BaseImageDataset 8 | from ..utils.osutils import mkdir_if_missing 9 | from ..utils.serialization import write_json 10 | 11 | class Market1501(BaseImageDataset): 12 | """ 13 | Market1501 14 | Reference: 15 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 16 | URL: http://www.liangzheng.org/Project/project_reid.html 17 | 18 | Dataset statistics: 19 | # identities: 1501 (+1 for background) 20 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 21 | """ 22 | dataset_dir = '' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(Market1501, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 30 | 31 | self._check_before_run() 32 | 33 | train = self._process_dir(self.train_dir, relabel=True) 34 | query = self._process_dir(self.query_dir, relabel=False) 35 | gallery = self._process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print("=> Market1501 loaded") 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 48 | 49 | def _check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 55 | if not osp.exists(self.query_dir): 56 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 57 | if not osp.exists(self.gallery_dir): 58 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 59 | 60 | def _process_dir(self, dir_path, relabel=False): 61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 62 | pattern = re.compile(r'([-\d]+)_c(\d)') 63 | 64 | pid_container = set() 65 | for img_path in img_paths: 66 | pid, _ = map(int, pattern.search(img_path).groups()) 67 | if 
pid == -1: continue # junk images are just ignored 68 | pid_container.add(pid) 69 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 70 | 71 | dataset = [] 72 | for img_path in img_paths: 73 | pid, camid = map(int, pattern.search(img_path).groups()) 74 | if pid == -1: continue # junk images are just ignored 75 | assert 0 <= pid <= 1501 # pid == 0 means background 76 | assert 1 <= camid <= 6 77 | camid -= 1 # index starts from 0 78 | if relabel: pid = pid2label[pid] 79 | dataset.append((img_path, pid, camid)) 80 | 81 | return dataset 82 | -------------------------------------------------------------------------------- /ice/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import tarfile 4 | 5 | import glob 6 | import re 7 | import urllib 8 | import zipfile 9 | 10 | from ..utils.osutils import mkdir_if_missing 11 | from ..utils.serialization import write_json 12 | 13 | 14 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')): 15 | with open(list_file, 'r') as f: 16 | lines = f.readlines() 17 | ret = [] 18 | pids = [] 19 | for line in lines: 20 | line = line.strip() 21 | fname = line.split(' ')[0] 22 | pid, _, cam = map(int, pattern.search(osp.basename(fname)).groups()) 23 | cam = cam - 1 # start from 0 24 | if pid not in pids: 25 | pids.append(pid) 26 | ret.append((osp.join(subdir,fname), pid, cam)) 27 | return ret, pids 28 | 29 | class Dataset_MSMT(object): 30 | def __init__(self, root): 31 | self.root = root 32 | self.train, self.val, self.trainval = [], [], [] 33 | self.query, self.gallery = [], [] 34 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 35 | 36 | @property 37 | def images_dir(self): 38 | return osp.join(self.root, 'MSMT17_V2') 39 | 40 | def load(self, verbose=True): 41 | exdir = osp.join(self.root, 'MSMT17_V2') 42 | self.train, train_pids = _pluck_msmt(osp.join(exdir, 'list_train.txt'), 'mask_train_v2') 43 | self.val, val_pids = _pluck_msmt(osp.join(exdir, 'list_val.txt'), 'mask_train_v2') 44 | self.train = self.train + self.val 45 | self.query, query_pids = _pluck_msmt(osp.join(exdir, 'list_query.txt'), 'mask_test_v2') 46 | self.gallery, gallery_pids = _pluck_msmt(osp.join(exdir, 'list_gallery.txt'), 'mask_test_v2') 47 | self.num_train_pids = len(list(set(train_pids).union(set(val_pids)))) 48 | 49 | if verbose: 50 | print(self.__class__.__name__, "dataset loaded") 51 | print(" subset | # ids | # images") 52 | print(" ---------------------------") 53 | print(" train | {:5d} | {:8d}" 54 | .format(self.num_train_pids, len(self.train))) 55 | print(" query | {:5d} | {:8d}" 56 | .format(len(query_pids), len(self.query))) 57 | print(" gallery | {:5d} | {:8d}" 58 | .format(len(gallery_pids), len(self.gallery))) 59 | 60 | class MSMT17(Dataset_MSMT): 61 | 62 | def __init__(self, root, split_id=0, download=True): 63 | super(MSMT17, self).__init__(root) 64 | 65 | if download: 66 | self.download() 67 | 68 | self.load() 69 | 70 | def download(self): 71 | 72 | import re 73 | import hashlib 74 | import shutil 75 | from glob import glob 76 | from zipfile import ZipFile 77 | 78 | raw_dir = osp.join(self.root) 79 | mkdir_if_missing(raw_dir) 80 | 81 | # Download the raw zip file 82 | fpath = osp.join(raw_dir, 'MSMT17_V2') 83 | if osp.isdir(fpath): 84 | print("Using downloaded file: " + fpath) 85 | else: 86 | raise RuntimeError("Please download the dataset manually to 
{}".format(fpath)) 87 | -------------------------------------------------------------------------------- /ice/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap' 10 | ] 11 | -------------------------------------------------------------------------------- /ice/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from ..utils import to_torch 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | with torch.no_grad(): 9 | output, target = to_torch(output), to_torch(target) 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | ret = [] 18 | for k in topk: 19 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 20 | ret.append(correct_k.mul_(1. / batch_size)) 21 | return ret 22 | -------------------------------------------------------------------------------- /ice/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid 
& _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. / (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /ice/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import collections 4 | from collections import OrderedDict 5 | import numpy as np 6 | import torch 7 | import random 8 | import copy 9 | 10 | from .evaluation_metrics import cmc, mean_ap 11 | from .utils.meters import AverageMeter 12 | from .utils.rerank import re_ranking 13 | from .utils import to_torch 14 | 15 | def extract_cnn_feature(model, inputs): 16 | inputs = to_torch(inputs).cuda() 17 | outputs = model(inputs) 18 | outputs = outputs.data.cpu() 19 | return outputs 20 | 21 | def extract_features(model, data_loader, print_freq=50): 22 | model.eval() 23 | batch_time = AverageMeter() 24 | data_time = AverageMeter() 25 | 26 | features = OrderedDict() 27 | labels = OrderedDict() 28 | 29 | end = time.time() 30 | with torch.no_grad(): 31 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader): 32 | data_time.update(time.time() - end) 33 | 34 | outputs = extract_cnn_feature(model, imgs) 35 | for fname, output, pid in zip(fnames, outputs, pids): 36 | features[fname] = output 37 | labels[fname] = pid 38 | 39 | batch_time.update(time.time() - end) 40 | end = time.time() 41 | 42 | if (i + 1) % print_freq == 0: 43 | print('Extract Features: [{}/{}]\t' 44 | 'Time {:.3f} ({:.3f})\t' 45 | 'Data {:.3f} ({:.3f})\t' 46 | .format(i + 1, len(data_loader), 47 | batch_time.val, batch_time.avg, 48 | data_time.val, data_time.avg)) 49 | 50 | return features, labels 51 | 52 | 53 | def pairwise_distance(features, query=None, gallery=None): 54 | if query is None and gallery is None: 
55 | n = len(features) 56 | x = torch.cat(list(features.values())) 57 | x = x.view(n, -1) 58 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 59 | dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t()) 60 | return dist_m 61 | 62 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 63 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0) 64 | m, n = x.size(0), y.size(0) 65 | x = x.view(m, -1) 66 | y = y.view(n, -1) 67 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 68 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 69 | dist_m.addmm_(x, y.t(), beta=1, alpha=-2) 70 | return dist_m, x.numpy(), y.numpy() 71 | 72 | def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None, 73 | query_ids=None, gallery_ids=None, 74 | query_cams=None, gallery_cams=None, 75 | cmc_topk=(1, 5, 10), cmc_flag=False): 76 | if query is not None and gallery is not None: 77 | query_ids = [pid for _, pid, _ in query] 78 | gallery_ids = [pid for _, pid, _ in gallery] 79 | query_cams = [cam for _, _, cam in query] 80 | gallery_cams = [cam for _, _, cam in gallery] 81 | else: 82 | assert (query_ids is not None and gallery_ids is not None 83 | and query_cams is not None and gallery_cams is not None) 84 | 85 | # Compute mean AP 86 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 87 | print('Mean AP: {:4.1%}'.format(mAP)) 88 | 89 | if (not cmc_flag): 90 | return mAP 91 | 92 | cmc_configs = { 93 | 'market1501': dict(separate_camera_set=False, 94 | single_gallery_shot=False, 95 | first_match_break=True),} 96 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 97 | query_cams, gallery_cams, **params) 98 | for name, params in cmc_configs.items()} 99 | 100 | print('CMC Scores:') 101 | for k in cmc_topk: 102 | print(' top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1])) 103 | return cmc_scores['market1501'], mAP 104 | 105 | 106 | class Evaluator(object): 107 | def __init__(self, model): 108 | super(Evaluator, self).__init__() 109 | self.model = model 110 | 111 | def evaluate(self, data_loader, query, gallery, cmc_flag=False, rerank=False, only_distmat=False): 112 | features, _ = extract_features(self.model, data_loader) 113 | distmat, query_features, gallery_features = pairwise_distance(features, query, gallery) 114 | results = evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 115 | if only_distmat: 116 | return distmat 117 | if (not rerank): 118 | return results 119 | 120 | print('Applying person re-ranking ...') 121 | distmat_qq = pairwise_distance(features, query, query) 122 | distmat_gg = pairwise_distance(features, gallery, gallery) 123 | distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy()) 124 | return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 125 | -------------------------------------------------------------------------------- /ice/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .contrastive import ViewContrastiveLoss 4 | from .crossentropy import CrossEntropyLabelSmooth, SoftEntropy 5 | from .triplet import TripletLoss, SoftTripletLoss 6 | 7 | __all__ = [ 8 | 'CrossEntropyLabelSmooth', 9 | 'SoftEntropy', 10 | 'TripletLoss', 11 | 'SoftTripletLoss', 12 | 'ViewContrastiveLoss'] 13 | 
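# Minimal usage sketch (added, not from the original repo; assumes a CUDA
# device and L2-normalized (B, D) embeddings q, k of two views of the same
# batch, with a (B,) tensor of pseudo labels):
#
#   criterion = ViewContrastiveLoss(num_instance=4, T=0.1)
#   loss = criterion(q, k, label)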
-------------------------------------------------------------------------------- /ice/loss/contrastive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | class ViewContrastiveLoss(nn.Module): 7 | def __init__(self, num_instance=4, T=1.0): 8 | super(ViewContrastiveLoss, self).__init__() 9 | self.criterion = nn.CrossEntropyLoss() 10 | self.num_instance = num_instance 11 | self.T = T 12 | 13 | def forward(self, q, k, label): 14 | batchSize = q.shape[0] 15 | N = q.size(0) 16 | mat_sim = torch.matmul(q, k.transpose(0, 1)) 17 | mat_eq = label.expand(N, N).eq(label.expand(N, N).t()).float() 18 | # batch hard 19 | hard_p, hard_n, hard_p_indice, hard_n_indice = self.batch_hard(mat_sim, mat_eq, True) 20 | l_pos = hard_p.view(batchSize, 1) 21 | mat_ne = label.expand(N, N).ne(label.expand(N, N).t()) 22 | # positives = torch.masked_select(mat_sim, mat_eq).view(-1, 1) 23 | negatives = torch.masked_select(mat_sim, mat_ne).view(batchSize, -1) 24 | out = torch.cat((l_pos, negatives), dim=1) / self.T 25 | # out = torch.cat((l_pos, l_neg, negatives), dim=1) / self.T 26 | targets = torch.zeros([batchSize]).cuda().long() 27 | triple_dist = F.log_softmax(out, dim=1) 28 | triple_dist_ref = torch.zeros_like(triple_dist).scatter_(1, targets.unsqueeze(1), 1) 29 | # triple_dist_ref = torch.zeros_like(triple_dist).scatter_(1, targets.unsqueeze(1), 1)*l + torch.zeros_like(triple_dist).scatter_(1, targets.unsqueeze(1)+1, 1) * (1-l) 30 | loss = (- triple_dist_ref * triple_dist).mean(0).sum() 31 | return loss 32 | 33 | def batch_hard(self, mat_sim, mat_eq, indice=False): 34 | sorted_mat_sim, positive_indices = torch.sort(mat_sim + (9999999.) * (1 - mat_eq), dim=1, 35 | descending=False) 36 | hard_p = sorted_mat_sim[:, 0] 37 | hard_p_indice = positive_indices[:, 0] 38 | sorted_mat_distance, negative_indices = torch.sort(mat_sim + (-9999999.) * (mat_eq), dim=1, 39 | descending=True) 40 | hard_n = sorted_mat_distance[:, 0] 41 | hard_n_indice = negative_indices[:, 0] 42 | if (indice): 43 | return hard_p, hard_n, hard_p_indice, hard_n_indice 44 | return hard_p, hard_n 45 | 46 | -------------------------------------------------------------------------------- /ice/loss/crossentropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import * 5 | 6 | 7 | class CrossEntropyLabelSmooth(nn.Module): 8 | """Cross entropy loss with label smoothing regularizer. 9 | 10 | Reference: 11 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 12 | Equation: y = (1 - epsilon) * y + epsilon / K. 13 | 14 | Args: 15 | num_classes (int): number of classes. 16 | epsilon (float): weight. 
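        Example: with num_classes=10 and epsilon=0.1, the target class receives
        probability (1 - 0.1) + 0.1/10 = 0.91 and every other class 0.1/10 = 0.01.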
17 | """ 18 | 19 | def __init__(self, num_classes, epsilon=0.1): 20 | super(CrossEntropyLabelSmooth, self).__init__() 21 | self.num_classes = num_classes 22 | self.epsilon = epsilon 23 | self.logsoftmax = nn.LogSoftmax(dim=1).cuda() 24 | 25 | def forward(self, inputs, targets): 26 | """ 27 | Args: 28 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 29 | targets: ground truth labels with shape (num_classes) 30 | """ 31 | log_probs = self.logsoftmax(inputs) 32 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 33 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 34 | loss = (- targets * log_probs).mean(0).sum() 35 | return loss 36 | 37 | class SoftEntropy(nn.Module): 38 | def __init__(self): 39 | super(SoftEntropy, self).__init__() 40 | self.logsoftmax = nn.LogSoftmax(dim=1) 41 | 42 | def forward(self, inputs, targets): 43 | log_probs = self.logsoftmax(inputs) 44 | loss = (- F.softmax(targets, dim=1).detach() * log_probs).mean(0).sum() 45 | return loss 46 | -------------------------------------------------------------------------------- /ice/loss/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def euclidean_dist(x, y): 9 | m, n = x.size(0), y.size(0) 10 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 11 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 12 | dist = xx + yy 13 | dist.addmm_(x, y.t(), beta=1, alpha=-2) 14 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 15 | return dist 16 | 17 | def cosine_dist(x, y): 18 | bs1, bs2 = x.size(0), y.size(0) 19 | frac_up = torch.matmul(x, y.transpose(0, 1)) 20 | frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \ 21 | (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1) 22 | cosine = frac_up / frac_down 23 | return 1-cosine 24 | 25 | def _batch_hard(mat_distance, mat_similarity, indice=False): 26 | sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1, descending=True) 27 | hard_p = sorted_mat_distance[:, 0] 28 | hard_p_indice = positive_indices[:, 0] 29 | sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) 
* (mat_similarity), dim=1, descending=False) 30 | hard_n = sorted_mat_distance[:, 0] 31 | hard_n_indice = negative_indices[:, 0] 32 | if(indice): 33 | return hard_p, hard_n, hard_p_indice, hard_n_indice 34 | return hard_p, hard_n 35 | 36 | class TripletLoss(nn.Module): 37 | ''' 38 | Compute Triplet loss augmented with Batch Hard 39 | Details can be seen in 'In defense of the Triplet Loss for Person Re-Identification' 40 | ''' 41 | 42 | def __init__(self, margin, normalize_feature=False): 43 | super(TripletLoss, self).__init__() 44 | self.margin = margin 45 | self.normalize_feature = normalize_feature 46 | self.margin_loss = nn.MarginRankingLoss(margin=margin).cuda() 47 | 48 | def forward(self, emb, label): 49 | if self.normalize_feature: 50 | # equal to cosine similarity 51 | emb = F.normalize(emb) 52 | mat_dist = euclidean_dist(emb, emb) 53 | # mat_dist = cosine_dist(emb, emb) 54 | assert mat_dist.size(0) == mat_dist.size(1) 55 | N = mat_dist.size(0) 56 | mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float() 57 | 58 | dist_ap, dist_an = _batch_hard(mat_dist, mat_sim) 59 | assert dist_an.size(0)==dist_ap.size(0) 60 | y = torch.ones_like(dist_ap) 61 | loss = self.margin_loss(dist_an, dist_ap, y) 62 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0) 63 | return loss, prec 64 | 65 | class SoftTripletLoss(nn.Module): 66 | 67 | def __init__(self, margin=None, normalize_feature=False): 68 | super(SoftTripletLoss, self).__init__() 69 | self.margin = margin 70 | self.normalize_feature = normalize_feature 71 | 72 | def forward(self, emb1, emb2, label): 73 | if self.normalize_feature: 74 | # equal to cosine similarity 75 | emb1 = F.normalize(emb1) 76 | emb2 = F.normalize(emb2) 77 | 78 | mat_dist = euclidean_dist(emb1, emb1) 79 | assert mat_dist.size(0) == mat_dist.size(1) 80 | N = mat_dist.size(0) 81 | mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float() 82 | 83 | dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True) 84 | assert dist_an.size(0)==dist_ap.size(0) 85 | triple_dist = torch.stack((dist_ap, dist_an), dim=1) 86 | triple_dist = F.log_softmax(triple_dist, dim=1) 87 | if (self.margin is not None): 88 | loss = (- self.margin * triple_dist[:,0] - (1 - self.margin) * triple_dist[:,1]).mean() 89 | return loss 90 | 91 | mat_dist_ref = euclidean_dist(emb2, emb2) 92 | dist_ap_ref = torch.gather(mat_dist_ref, 1, ap_idx.view(N,1).expand(N,N))[:,0] 93 | dist_an_ref = torch.gather(mat_dist_ref, 1, an_idx.view(N,1).expand(N,N))[:,0] 94 | triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1) 95 | triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach() 96 | 97 | loss = (- triple_dist_ref * triple_dist).mean(0).sum() 98 | return loss 99 | 100 | -------------------------------------------------------------------------------- /ice/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnet_ibn import * 5 | 6 | __factory = { 7 | 'resnet18': resnet18, 8 | 'resnet34': resnet34, 9 | 'resnet50': resnet50, 10 | 'resnet101': resnet101, 11 | 'resnet152': resnet152, 12 | 'resnet_ibn50a': resnet_ibn50a, 13 | 'resnet_ibn101a': resnet_ibn101a 14 | } 15 | 16 | 17 | def names(): 18 | return sorted(__factory.keys()) 19 | 20 | 21 | def create(name, *args, **kwargs): 22 | """ 23 | Create a model instance. 24 | 25 | Parameters 26 | ---------- 27 | name : str 28 | Model name. 
Can be one of 'inception', 'resnet18', 'resnet34', 29 | 'resnet50', 'resnet101', and 'resnet152'. 30 | pretrained : bool, optional 31 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 32 | model. Default: True 33 | cut_at_pooling : bool, optional 34 | If True, will cut the model before the last global pooling layer and 35 | ignore the remaining kwargs. Default: False 36 | num_features : int, optional 37 | If positive, will append a Linear layer after the global pooling layer, 38 | with this number of output units, followed by a BatchNorm layer. 39 | Otherwise these layers will not be appended. Default: 256 for 40 | 'inception', 0 for 'resnet*' 41 | norm : bool, optional 42 | If True, will normalize the feature to be unit L2-norm for each sample. 43 | Otherwise will append a ReLU layer after the above Linear layer if 44 | num_features > 0. Default: False 45 | dropout : float, optional 46 | If positive, will append a Dropout layer with this dropout rate. 47 | Default: 0 48 | num_classes : int, optional 49 | If positive, will append a Linear layer at the end as the classifier 50 | with this number of output units. Default: 0 51 | """ 52 | if name not in __factory: 53 | raise KeyError("Unknown model:", name) 54 | return __factory[name](*args, **kwargs) 55 | -------------------------------------------------------------------------------- /ice/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152'] 12 | 13 | 14 | class ResNet(nn.Module): 15 | __factory = { 16 | 18: torchvision.models.resnet18, 17 | 34: torchvision.models.resnet34, 18 | 50: torchvision.models.resnet50, 19 | 101: torchvision.models.resnet101, 20 | 152: torchvision.models.resnet152, 21 | } 22 | 23 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 24 | num_features=0, norm=False, dropout=0, num_classes=0): 25 | super(ResNet, self).__init__() 26 | self.pretrained = pretrained 27 | self.depth = depth 28 | self.cut_at_pooling = cut_at_pooling 29 | # Construct base (pretrained) resnet 30 | if depth not in ResNet.__factory: 31 | raise KeyError("Unsupported depth:", depth) 32 | resnet = ResNet.__factory[depth](pretrained=pretrained) 33 | resnet.layer4[0].conv2.stride = (1,1) 34 | resnet.layer4[0].downsample[0].stride = (1,1) 35 | self.base = nn.Sequential( 36 | resnet.conv1, resnet.bn1, 37 | resnet.relu, 38 | resnet.maxpool, 39 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 40 | self.gap = nn.AdaptiveAvgPool2d(1) 41 | 42 | if not self.cut_at_pooling: 43 | self.num_features = num_features 44 | self.norm = norm 45 | self.dropout = dropout 46 | self.has_embedding = num_features > 0 47 | self.num_classes = num_classes 48 | 49 | out_planes = resnet.fc.in_features 50 | 51 | # Append new layers 52 | if self.has_embedding: 53 | self.feat = nn.Linear(out_planes, self.num_features) 54 | self.feat_bn = nn.BatchNorm1d(self.num_features) 55 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 56 | init.constant_(self.feat.bias, 0) 57 | else: 58 | # Change the num_features to CNN output channels 59 | self.num_features = out_planes 60 | self.feat_bn = nn.BatchNorm1d(self.num_features) 61 | 62 | self.feat_bn.bias.requires_grad_(False) 63 | if self.dropout > 0: 64 | 
self.drop = nn.Dropout(self.dropout) 65 | if self.num_classes > 0: 66 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 67 | init.normal_(self.classifier.weight, std=0.001) 68 | init.constant_(self.feat_bn.weight, 1) 69 | init.constant_(self.feat_bn.bias, 0) 70 | 71 | if not pretrained: 72 | self.reset_params() 73 | 74 | def forward(self, x): 75 | bs = x.size(0) 76 | x = self.base(x) 77 | 78 | x = self.gap(x) 79 | x = x.view(x.size(0), -1) 80 | 81 | if self.cut_at_pooling: 82 | return x 83 | 84 | if self.has_embedding: 85 | bn_x = self.feat_bn(self.feat(x)) 86 | else: 87 | bn_x = self.feat_bn(x) 88 | 89 | if self.training is False: 90 | bn_x = F.normalize(bn_x) 91 | return bn_x 92 | 93 | if self.norm: 94 | bn_x = F.normalize(bn_x) 95 | elif self.has_embedding: 96 | bn_x = F.relu(bn_x) 97 | 98 | if self.dropout > 0: 99 | bn_x = self.drop(bn_x) 100 | 101 | if self.num_classes > 0: 102 | prob = self.classifier(bn_x) 103 | else: 104 | return bn_x 105 | 106 | return x, prob 107 | 108 | def reset_params(self): 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | init.kaiming_normal_(m.weight, mode='fan_out') 112 | if m.bias is not None: 113 | init.constant_(m.bias, 0) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | init.constant_(m.weight, 1) 116 | init.constant_(m.bias, 0) 117 | elif isinstance(m, nn.BatchNorm1d): 118 | init.constant_(m.weight, 1) 119 | init.constant_(m.bias, 0) 120 | elif isinstance(m, nn.Linear): 121 | init.normal_(m.weight, std=0.001) 122 | if m.bias is not None: 123 | init.constant_(m.bias, 0) 124 | 125 | resnet = ResNet.__factory[self.depth](pretrained=self.pretrained) 126 | self.base[0].load_state_dict(resnet.conv1.state_dict()) 127 | self.base[1].load_state_dict(resnet.bn1.state_dict()) 128 | self.base[4].load_state_dict(resnet.layer1.state_dict()) 129 | self.base[5].load_state_dict(resnet.layer2.state_dict()) 130 | self.base[6].load_state_dict(resnet.layer3.state_dict()) 131 | self.base[7].load_state_dict(resnet.layer4.state_dict()) 132 | 133 | 134 | def resnet18(**kwargs): 135 | return ResNet(18, **kwargs) 136 | 137 | 138 | def resnet34(**kwargs): 139 | return ResNet(34, **kwargs) 140 | 141 | 142 | def resnet50(**kwargs): 143 | return ResNet(50, **kwargs) 144 | 145 | 146 | def resnet101(**kwargs): 147 | return ResNet(101, **kwargs) 148 | 149 | 150 | def resnet152(**kwargs): 151 | return ResNet(152, **kwargs) 152 | -------------------------------------------------------------------------------- /ice/models/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | 9 | from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 10 | 11 | 12 | __all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a'] 13 | 14 | 15 | class ResNetIBN(nn.Module): 16 | __factory = { 17 | '50a': resnet50_ibn_a, 18 | '101a': resnet101_ibn_a 19 | } 20 | 21 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 22 | num_features=0, norm=False, dropout=0, num_classes=0): 23 | super(ResNetIBN, self).__init__() 24 | 25 | self.depth = depth 26 | self.pretrained = pretrained 27 | self.cut_at_pooling = cut_at_pooling 28 | 29 | resnet = ResNetIBN.__factory[depth](pretrained=pretrained) 30 | resnet.layer4[0].conv2.stride = (1,1) 31 | resnet.layer4[0].downsample[0].stride = (1,1) 32 | self.base = nn.Sequential( 33 | 
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 34 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 35 | self.gap = nn.AdaptiveAvgPool2d(1) 36 | 37 | if not self.cut_at_pooling: 38 | self.num_features = num_features 39 | self.norm = norm 40 | self.dropout = dropout 41 | self.has_embedding = num_features > 0 42 | self.num_classes = num_classes 43 | 44 | out_planes = resnet.fc.in_features 45 | 46 | # Append new layers 47 | if self.has_embedding: 48 | self.feat = nn.Linear(out_planes, self.num_features) 49 | self.feat_bn = nn.BatchNorm1d(self.num_features) 50 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 51 | init.constant_(self.feat.bias, 0) 52 | else: 53 | # Change the num_features to CNN output channels 54 | self.num_features = out_planes 55 | self.feat_bn = nn.BatchNorm1d(self.num_features) 56 | self.feat_bn.bias.requires_grad_(False) 57 | if self.dropout > 0: 58 | self.drop = nn.Dropout(self.dropout) 59 | if self.num_classes > 0: 60 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 61 | init.normal_(self.classifier.weight, std=0.001) 62 | init.constant_(self.feat_bn.weight, 1) 63 | init.constant_(self.feat_bn.bias, 0) 64 | 65 | if not pretrained: 66 | self.reset_params() 67 | 68 | def forward(self, x): 69 | x = self.base(x) 70 | 71 | x = self.gap(x) 72 | x = x.view(x.size(0), -1) 73 | 74 | if self.cut_at_pooling: 75 | return x 76 | 77 | if self.has_embedding: 78 | bn_x = self.feat_bn(self.feat(x)) 79 | else: 80 | bn_x = self.feat_bn(x) 81 | 82 | if self.training is False: 83 | bn_x = F.normalize(bn_x) 84 | return bn_x 85 | 86 | if self.norm: 87 | bn_x = F.normalize(bn_x) 88 | elif self.has_embedding: 89 | bn_x = F.relu(bn_x) 90 | 91 | if self.dropout > 0: 92 | bn_x = self.drop(bn_x) 93 | 94 | if self.num_classes > 0: 95 | prob = self.classifier(bn_x) 96 | else: 97 | return bn_x 98 | 99 | return prob 100 | 101 | def reset_params(self): 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | init.kaiming_normal_(m.weight, mode='fan_out') 105 | if m.bias is not None: 106 | init.constant_(m.bias, 0) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | init.constant_(m.weight, 1) 109 | init.constant_(m.bias, 0) 110 | elif isinstance(m, nn.BatchNorm1d): 111 | init.constant_(m.weight, 1) 112 | init.constant_(m.bias, 0) 113 | elif isinstance(m, nn.Linear): 114 | init.normal_(m.weight, std=0.001) 115 | if m.bias is not None: 116 | init.constant_(m.bias, 0) 117 | 118 | resnet = ResNetIBN.__factory[self.depth](pretrained=self.pretrained) 119 | self.base[0].load_state_dict(resnet.conv1.state_dict()) 120 | self.base[1].load_state_dict(resnet.bn1.state_dict()) 121 | self.base[2].load_state_dict(resnet.relu.state_dict()) 122 | self.base[3].load_state_dict(resnet.maxpool.state_dict()) 123 | self.base[4].load_state_dict(resnet.layer1.state_dict()) 124 | self.base[5].load_state_dict(resnet.layer2.state_dict()) 125 | self.base[6].load_state_dict(resnet.layer3.state_dict()) 126 | self.base[7].load_state_dict(resnet.layer4.state_dict()) 127 | 128 | 129 | def resnet_ibn50a(**kwargs): 130 | return ResNetIBN('50a', **kwargs) 131 | 132 | 133 | def resnet_ibn101a(**kwargs): 134 | return ResNetIBN('101a', **kwargs) 135 | -------------------------------------------------------------------------------- /ice/models/resnet_ibn_a.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['ResNet', 
'resnet50_ibn_a', 'resnet101_ibn_a'] 8 | 9 | 10 | model_urls = { 11 | 'ibn_resnet50a': '/home/hchen/Projects/Baseline/logs/pretrained/resnet50_ibn_a.pth.tar', 12 | 'ibn_resnet101a': './logs/pretrained/resnet101_ibn_a.pth.tar', 13 | } 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class IBN(nn.Module): 55 | def __init__(self, planes): 56 | super(IBN, self).__init__() 57 | half1 = int(planes/2) 58 | self.half = half1 59 | half2 = planes - half1 60 | self.IN = nn.InstanceNorm2d(half1, affine=True) 61 | self.BN = nn.BatchNorm2d(half2) 62 | 63 | def forward(self, x): 64 | split = torch.split(x, self.half, 1) 65 | out1 = self.IN(split[0].contiguous()) 66 | out2 = self.BN(split[1].contiguous()) 67 | out = torch.cat((out1, out2), 1) 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 77 | if ibn: 78 | self.bn1 = IBN(planes) 79 | else: 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 82 | padding=1, bias=False) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 85 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 86 | self.relu = nn.ReLU(inplace=True) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | residual = x 92 | 93 | out = self.conv1(x) 94 | out = self.bn1(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv2(out) 98 | out = self.bn2(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv3(out) 102 | out = self.bn3(out) 103 | 104 | if self.downsample is not None: 105 | residual = self.downsample(x) 106 | 107 | out += residual 108 | out = self.relu(out) 109 | 110 | return out 111 | 112 | 113 | class ResNet(nn.Module): 114 | 115 | def __init__(self, block, layers, num_classes=1000): 116 | scale = 64 117 | self.inplanes = scale 118 | super(ResNet, self).__init__() 119 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 120 | bias=False) 121 | self.bn1 = nn.BatchNorm2d(scale) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 124 | self.layer1 = self._make_layer(block, scale, layers[0]) 125 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) 126 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) 127 | self.layer4 = 
self._make_layer(block, scale*8, layers[3], stride=2) 128 | self.avgpool = nn.AvgPool2d(7) 129 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 130 | 131 | for m in self.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 134 | m.weight.data.normal_(0, math.sqrt(2. / n)) 135 | elif isinstance(m, nn.BatchNorm2d): 136 | m.weight.data.fill_(1) 137 | m.bias.data.zero_() 138 | elif isinstance(m, nn.InstanceNorm2d): 139 | m.weight.data.fill_(1) 140 | m.bias.data.zero_() 141 | 142 | def _make_layer(self, block, planes, blocks, stride=1): 143 | downsample = None 144 | if stride != 1 or self.inplanes != planes * block.expansion: 145 | downsample = nn.Sequential( 146 | nn.Conv2d(self.inplanes, planes * block.expansion, 147 | kernel_size=1, stride=stride, bias=False), 148 | nn.BatchNorm2d(planes * block.expansion), 149 | ) 150 | 151 | layers = [] 152 | ibn = True 153 | if planes == 512: 154 | ibn = False 155 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 156 | self.inplanes = planes * block.expansion 157 | for i in range(1, blocks): 158 | layers.append(block(self.inplanes, planes, ibn)) 159 | 160 | return nn.Sequential(*layers) 161 | 162 | def forward(self, x): 163 | x = self.conv1(x) 164 | x = self.bn1(x) 165 | x = self.relu(x) 166 | x = self.maxpool(x) 167 | 168 | x = self.layer1(x) 169 | x = self.layer2(x) 170 | x = self.layer3(x) 171 | x = self.layer4(x) 172 | 173 | x = self.avgpool(x) 174 | x = x.view(x.size(0), -1) 175 | x = self.fc(x) 176 | 177 | return x 178 | 179 | 180 | def resnet50_ibn_a(pretrained=False, **kwargs): 181 | """Constructs a ResNet-50 model. 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 186 | if pretrained: 187 | state_dict = torch.load(model_urls['ibn_resnet50a'], map_location=torch.device('cpu'))['state_dict'] 188 | state_dict = remove_module_key(state_dict) 189 | model.load_state_dict(state_dict) 190 | return model 191 | 192 | 193 | def resnet101_ibn_a(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 
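IBN-a variant: in layer1-layer3 the first BatchNorm of each bottleneck is split into half InstanceNorm and half BatchNorm (the IBN module above); layer4 keeps plain BatchNorm.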
195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | state_dict = torch.load(model_urls['ibn_resnet101a'], map_location=torch.device('cpu'))['state_dict'] 201 | state_dict = remove_module_key(state_dict) 202 | model.load_state_dict(state_dict) 203 | return model 204 | 205 | 206 | def remove_module_key(state_dict): 207 | for key in list(state_dict.keys()): 208 | if 'module' in key: 209 | state_dict[key.replace('module.','')] = state_dict.pop(key) 210 | return state_dict 211 | -------------------------------------------------------------------------------- /ice/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import numpy as np 4 | import collections 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | from ice.loss import CrossEntropyLabelSmooth, ViewContrastiveLoss 10 | from .utils.meters import AverageMeter 11 | from .evaluation_metrics import accuracy 12 | 13 | 14 | class ImageTrainer(object): 15 | def __init__(self, model_1, model_1_ema, num_cluster=500, alpha=0.999, num_instance=4, tau_c=0.5, tau_v=0.09, 16 | scale_kl=2.0): 17 | super(ImageTrainer, self).__init__() 18 | self.model_1 = model_1 19 | self.model_1_ema = model_1_ema 20 | self.alpha = alpha 21 | 22 | self.tau_c = tau_c 23 | self.tau_v = tau_v 24 | self.scale_kl = scale_kl 25 | 26 | self.ccloss = CrossEntropyLabelSmooth(num_cluster) 27 | self.vcloss = ViewContrastiveLoss(num_instance=num_instance, T=tau_v) 28 | self.kl = nn.KLDivLoss(reduction='batchmean') 29 | self.crosscam_epoch = 0 30 | self.beta = 0.07 31 | self.bg_knn = 50 32 | 33 | self.mse = nn.MSELoss(reduction='sum') 34 | 35 | def train(self, epoch, data_loader_target, 36 | optimizer, print_freq=1, train_iters=200, centers=None, intra_id_labels=None, intra_id_features=None, 37 | cams=None, all_pseudo_label=None): 38 | self.model_1.train() 39 | self.model_1_ema.train() 40 | centers = centers.cuda() 41 | # outliers = outliers.cuda() 42 | 43 | batch_time = AverageMeter() 44 | data_time = AverageMeter() 45 | 46 | losses_ccl = AverageMeter() 47 | losses_vcl = AverageMeter() 48 | losses_cam = AverageMeter() 49 | losses_kl = AverageMeter() 50 | precisions = AverageMeter() 51 | 52 | self.all_img_cams = torch.tensor(cams).cuda() 53 | self.unique_cams = torch.unique(self.all_img_cams) 54 | # print(self.unique_cams) 55 | 56 | self.all_pseudo_label = torch.tensor(all_pseudo_label).cuda() 57 | self.init_intra_id_feat = intra_id_features 58 | # print(len(self.init_intra_id_feat)) 59 | 60 | # initialize proxy memory 61 | self.percam_memory = [] 62 | self.memory_class_mapper = [] 63 | self.concate_intra_class = [] 64 | for cc in self.unique_cams: 65 | percam_ind = torch.nonzero(self.all_img_cams == cc).squeeze(-1) 66 | uniq_class = torch.unique(self.all_pseudo_label[percam_ind]) 67 | uniq_class = uniq_class[uniq_class >= 0] 68 | self.concate_intra_class.append(uniq_class) 69 | cls_mapper = {int(uniq_class[j]): j for j in range(len(uniq_class))} 70 | self.memory_class_mapper.append(cls_mapper) # from pseudo label to index under each camera 71 | 72 | if len(self.init_intra_id_feat) > 0: 73 | # print('initializing ID memory from updated embedding features...') 74 | proto_memory = self.init_intra_id_feat[cc] 75 | proto_memory = proto_memory.cuda() 76 | self.percam_memory.append(proto_memory.detach()) 
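# After this loop, self.percam_memory[c] holds the proxy bank for camera c
# (roughly [num_pseudo_ids_in_cam_c, feat_dim], initialised from
# intra_id_features), and self.memory_class_mapper[c] maps a global pseudo
# label to its row in that bank, e.g.:
#   row = self.memory_class_mapper[c][int(pseudo_label)]
#   proxy = self.percam_memory[c][row]  # one proxy vector per (camera, identity)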
77 | self.concate_intra_class = torch.cat(self.concate_intra_class) 78 | 79 | if epoch >= self.crosscam_epoch: 80 | percam_tempV = [] 81 | for ii in self.unique_cams: 82 | percam_tempV.append(self.percam_memory[ii].detach().clone()) 83 | percam_tempV = torch.cat(percam_tempV, dim=0).cuda() 84 | 85 | end = time.time() 86 | for i in range(train_iters): 87 | target_inputs = data_loader_target.next() 88 | data_time.update(time.time() - end) 89 | # process inputs 90 | inputs_1, inputs_weak, targets, inputs_2, cids = self._parse_data(target_inputs) 91 | b, c, h, w = inputs_1.size() 92 | 93 | # ids for ShuffleBN 94 | shuffle_ids, reverse_ids = self.get_shuffle_ids(b) 95 | 96 | f_out_t1 = self.model_1(inputs_1) 97 | p_out_t1 = torch.matmul(f_out_t1, centers.transpose(1, 0)) / self.tau_c 98 | 99 | f_out_t2 = self.model_1(inputs_2) 100 | 101 | loss_cam = torch.tensor([0.]).cuda() 102 | for cc in torch.unique(cids): 103 | # print(cc) 104 | inds = torch.nonzero(cids == cc).squeeze(-1) 105 | percam_targets = targets[inds] 106 | # print(percam_targets) 107 | percam_feat = f_out_t1[inds] 108 | 109 | # # intra-camera loss 110 | # mapped_targets = [self.memory_class_mapper[cc][int(k)] for k in percam_targets] 111 | # mapped_targets = torch.tensor(mapped_targets).to(torch.device('cuda')) 112 | # # percam_inputs = ExemplarMemory.apply(percam_feat, mapped_targets, self.percam_memory[cc], self.alpha) 113 | # percam_inputs = torch.matmul(F.normalize(percam_feat), F.normalize(self.percam_memory[cc].t())) 114 | # percam_inputs /= self.beta # similarity score before softmax 115 | # loss_cam += F.cross_entropy(percam_inputs, mapped_targets) 116 | 117 | # cross-camera loss 118 | if epoch >= self.crosscam_epoch: 119 | associate_loss = 0 120 | # target_inputs = percam_feat.mm(percam_tempV.t().clone()) 121 | target_inputs = torch.matmul(F.normalize(percam_feat), F.normalize(percam_tempV.t().clone())) 122 | temp_sims = target_inputs.detach().clone() 123 | target_inputs /= self.beta 124 | 125 | for k in range(len(percam_feat)): 126 | ori_asso_ind = torch.nonzero(self.concate_intra_class == percam_targets[k]).squeeze(-1) 127 | temp_sims[k, ori_asso_ind] = -10000.0 # mask out positive 128 | sel_ind = torch.sort(temp_sims[k])[1][-self.bg_knn:] 129 | concated_input = torch.cat((target_inputs[k, ori_asso_ind], target_inputs[k, sel_ind]), dim=0) 130 | concated_target = torch.zeros((len(concated_input)), dtype=concated_input.dtype).to( 131 | torch.device('cuda')) 132 | concated_target[0:len(ori_asso_ind)] = 1.0 / len(ori_asso_ind) 133 | associate_loss += -1 * ( 134 | F.log_softmax(concated_input.unsqueeze(0), dim=1) * concated_target.unsqueeze( 135 | 0)).sum() 136 | loss_cam += 0.5 * associate_loss / len(percam_feat) 137 | 138 | with torch.no_grad(): 139 | inputs_1 = inputs_1[shuffle_ids] 140 | f_out_t1_ema = self.model_1_ema(inputs_1) 141 | f_out_t1_ema = f_out_t1_ema[reverse_ids] 142 | 143 | inputs_2 = inputs_2[shuffle_ids] 144 | f_out_t2_ema = self.model_1_ema(inputs_2) 145 | f_out_t2_ema = f_out_t2_ema[reverse_ids] 146 | 147 | inputs_weak = inputs_weak[shuffle_ids] 148 | f_out_weak_ema = self.model_1_ema(inputs_weak) 149 | f_out_weak_ema = f_out_weak_ema[reverse_ids] 150 | 151 | loss_ccl = self.ccloss(p_out_t1, targets) 152 | loss_vcl = self.vcloss(F.normalize(f_out_t1), F.normalize(f_out_t2_ema), targets) 153 | 154 | loss_kl = self.kl(F.softmax( 155 | torch.matmul(F.normalize(f_out_t1), F.normalize(f_out_t2_ema).transpose(1, 0)) / self.scale_kl, 156 | dim=1).log(), 157 | F.softmax(torch.matmul(F.normalize(f_out_weak_ema), 
158 | F.normalize(f_out_weak_ema).transpose(1, 0)) / self.scale_kl, 159 | dim=1)) * 10 160 | 161 | loss = loss_ccl + loss_vcl + loss_cam + loss_kl 162 | 163 | optimizer.zero_grad() 164 | loss.backward() 165 | optimizer.step() 166 | 167 | self._update_ema_variables(self.model_1, self.model_1_ema, self.alpha, epoch * len(data_loader_target) + i) 168 | 169 | prec_1, = accuracy(p_out_t1.data, targets.data) 170 | 171 | losses_ccl.update(loss_ccl.item()) 172 | losses_cam.update(loss_cam.item()) 173 | losses_vcl.update(loss_vcl.item()) 174 | losses_kl.update(loss_kl.item()) 175 | precisions.update(prec_1[0]) 176 | 177 | # print log # 178 | batch_time.update(time.time() - end) 179 | end = time.time() 180 | 181 | if (i + 1) % print_freq == 0: 182 | print('Epoch: [{}][{}/{}]\t' 183 | 'Time {:.3f} ({:.3f})\t' 184 | 'Data {:.3f} ({:.3f})\t' 185 | 'Loss_ccl {:.3f}\t' 186 | 'Loss_hard_instance {:.3f}\t' 187 | 'Loss_cam {:.3f}\t' 188 | 'Loss_kl {:.3f}\t' 189 | 'Prec {:.2%}\t' 190 | .format(epoch, i + 1, len(data_loader_target), 191 | batch_time.val, batch_time.avg, 192 | data_time.val, data_time.avg, 193 | losses_ccl.avg, 194 | losses_vcl.avg, 195 | losses_cam.avg, 196 | losses_kl.avg, 197 | precisions.avg)) 198 | 199 | def _update_ema_variables(self, model, ema_model, alpha, global_step): 200 | # alpha = min(1 - 1 / (global_step + 1), alpha) 201 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 202 | ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) 203 | 204 | def _parse_data(self, inputs): 205 | imgs_1, imgs_2, img_mutual, pids, cids = inputs 206 | inputs_1 = imgs_1.cuda() 207 | inputs_2 = imgs_2.cuda() 208 | inputs_mutual = img_mutual.cuda() 209 | targets = pids.cuda() 210 | cids = cids.cuda() 211 | return inputs_1, inputs_2, targets, inputs_mutual, cids 212 | 213 | def get_shuffle_ids(self, bsz): 214 | """generate shuffle ids for ShuffleBN""" 215 | forward_inds = torch.randperm(bsz).long().cuda() 216 | backward_inds = torch.zeros(bsz).long().cuda() 217 | value = torch.arange(bsz).long().cuda() 218 | backward_inds.index_copy_(0, forward_inds, value) 219 | return forward_inds, backward_inds 220 | -------------------------------------------------------------------------------- /ice/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /ice/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset, BaseVideoDataset 4 | from .preprocessor import Preprocessor 5 | 6 | class IterLoader: 7 | def __init__(self, loader, length=None): 8 | self.loader = loader 9 | self.length = length 10 | self.iter = None 11 | 12 | def __len__(self): 13 | if (self.length is not None): 14 | return self.length 15 | return len(self.loader) 16 | 17 
| def new_epoch(self): 18 | self.iter = iter(self.loader) 19 | 20 | def next(self): 21 | try: 22 | return next(self.iter) 23 | except: 24 | self.iter = iter(self.loader) 25 | return next(self.iter) 26 | -------------------------------------------------------------------------------- /ice/utils/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def get_videodata_info(self, data): 23 | pids, cams = [], [] 24 | for _, pid, camid, _ in data: 25 | pids += [pid] 26 | cams += [camid] 27 | pids = set(pids) 28 | cams = set(cams) 29 | num_pids = len(pids) 30 | num_cams = len(cams) 31 | num_imgs = len(data) 32 | return num_pids, num_imgs, num_cams 33 | 34 | def print_dataset_statistics(self): 35 | raise NotImplementedError 36 | 37 | @property 38 | def images_dir(self): 39 | return None 40 | 41 | 42 | class BaseImageDataset(BaseDataset): 43 | """ 44 | Base class of image reid dataset 45 | """ 46 | 47 | def print_dataset_statistics(self, train, query, gallery): 48 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 49 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 50 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 51 | 52 | print("Dataset statistics:") 53 | print(" ----------------------------------------") 54 | print(" subset | # ids | # images | # cameras") 55 | print(" ----------------------------------------") 56 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 57 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 58 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 59 | print(" ----------------------------------------") 60 | 61 | 62 | class BaseVideoDataset(BaseDataset): 63 | """ 64 | Base class of video reid dataset 65 | """ 66 | 67 | def print_dataset_statistics(self, train, query, gallery): 68 | num_train_pids, num_train_imgs, num_train_cams = self.get_videodata_info(train) 69 | num_query_pids, num_query_imgs, num_query_cams = self.get_videodata_info(query) 70 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_videodata_info(gallery) 71 | 72 | print("Dataset statistics:") 73 | print(" ----------------------------------------") 74 | print(" subset | # ids | # tracklets | # cameras") 75 | print(" ----------------------------------------") 76 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 77 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 78 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 79 | print(" ----------------------------------------") 80 | 81 | -------------------------------------------------------------------------------- /ice/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 
import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | x, y, _ = map(int, name.split('_')) 17 | assert pid == x and camid == y 18 | if relabel: 19 | ret.append((fname, index, camid)) 20 | else: 21 | ret.append((fname, pid, camid)) 22 | return ret 23 | 24 | def _pluck_gallery(identities, indices, relabel=False): 25 | ret = [] 26 | for index, pid in enumerate(indices): 27 | pid_images = identities[pid] 28 | for camid, cam_images in enumerate(pid_images): 29 | if len(cam_images[:-1])==0: 30 | for fname in cam_images: 31 | name = osp.splitext(fname)[0] 32 | x, y, _ = map(int, name.split('_')) 33 | assert pid == x and camid == y 34 | if relabel: 35 | ret.append((fname, index, camid)) 36 | else: 37 | ret.append((fname, pid, camid)) 38 | else: 39 | for fname in cam_images[:-1]: 40 | name = osp.splitext(fname)[0] 41 | x, y, _ = map(int, name.split('_')) 42 | assert pid == x and camid == y 43 | if relabel: 44 | ret.append((fname, index, camid)) 45 | else: 46 | ret.append((fname, pid, camid)) 47 | return ret 48 | 49 | def _pluck_query(identities, indices, relabel=False): 50 | ret = [] 51 | for index, pid in enumerate(indices): 52 | pid_images = identities[pid] 53 | for camid, cam_images in enumerate(pid_images): 54 | for fname in cam_images[-1:]: 55 | name = osp.splitext(fname)[0] 56 | x, y, _ = map(int, name.split('_')) 57 | assert pid == x and camid == y 58 | if relabel: 59 | ret.append((fname, index, camid)) 60 | else: 61 | ret.append((fname, pid, camid)) 62 | return ret 63 | 64 | 65 | class Dataset(object): 66 | def __init__(self, root, split_id=0): 67 | self.root = root 68 | self.split_id = split_id 69 | self.meta = None 70 | self.split = None 71 | self.train, self.val, self.trainval = [], [], [] 72 | self.query, self.gallery = [], [] 73 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 74 | 75 | @property 76 | def images_dir(self): 77 | return osp.join(self.root, 'images') 78 | 79 | def load(self, num_val=0.3, verbose=True): 80 | splits = read_json(osp.join(self.root, 'splits.json')) 81 | if self.split_id >= len(splits): 82 | raise ValueError("split_id exceeds total splits {}" 83 | .format(len(splits))) 84 | self.split = splits[self.split_id] 85 | 86 | # Randomly split train / val 87 | trainval_pids = np.asarray(self.split['trainval']) 88 | np.random.shuffle(trainval_pids) 89 | num = len(trainval_pids) 90 | if isinstance(num_val, float): 91 | num_val = int(round(num * num_val)) 92 | if num_val >= num or num_val < 0: 93 | raise ValueError("num_val exceeds total identities {}" 94 | .format(num)) 95 | train_pids = sorted(trainval_pids[:-num_val]) 96 | val_pids = sorted(trainval_pids[-num_val:]) 97 | 98 | self.meta = read_json(osp.join(self.root, 'meta.json')) 99 | identities = self.meta['identities'] 100 | self.train = _pluck(identities, train_pids, relabel=True) 101 | self.val = _pluck(identities, val_pids, relabel=True) 102 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 103 | self.query = _pluck_query(identities, self.split['query']) 104 | #self.gallery = _pluck(identities, self.split['gallery']) 105 | self.gallery = _pluck_gallery(identities, self.split['gallery']) 106 | self.num_train_ids = len(train_pids) 107 | self.num_val_ids = len(val_pids) 
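# Bookkeeping: identity counts per split; trainval is the union of train and
# val, relabeled to contiguous ids by _pluck(relabel=True). Image files are
# assumed to be named '<pid>_<camid>_<idx>.<ext>' (see the asserts in _pluck).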
108 | self.num_trainval_ids = len(trainval_pids) 109 | 110 | if verbose: 111 | print(self.__class__.__name__, "dataset loaded") 112 | print(" subset | # ids | # images") 113 | print(" ---------------------------") 114 | print(" train | {:5d} | {:8d}" 115 | .format(self.num_train_ids, len(self.train))) 116 | print(" val | {:5d} | {:8d}" 117 | .format(self.num_val_ids, len(self.val))) 118 | print(" trainval | {:5d} | {:8d}" 119 | .format(self.num_trainval_ids, len(self.trainval))) 120 | print(" query | {:5d} | {:8d}" 121 | .format(len(self.split['query']), len(self.query))) 122 | print(" gallery | {:5d} | {:8d}" 123 | .format(len(self.split['gallery']), len(self.gallery))) 124 | 125 | def _check_integrity(self): 126 | return osp.isdir(osp.join(self.root, 'images')) and \ 127 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 128 | osp.isfile(osp.join(self.root, 'splits.json')) 129 | -------------------------------------------------------------------------------- /ice/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import torch 7 | import random 8 | import math 9 | from PIL import Image 10 | 11 | class Preprocessor(Dataset): 12 | def __init__(self, dataset, root=None, transform=None, mutual=False): 13 | super(Preprocessor, self).__init__() 14 | self.dataset = dataset 15 | self.root = root 16 | self.transform = transform 17 | self.mutual = mutual 18 | 19 | def __len__(self): 20 | return len(self.dataset) 21 | 22 | def __getitem__(self, indices): 23 | if self.mutual: 24 | return self._get_mutual_item(indices) 25 | else: 26 | return self._get_single_item(indices) 27 | 28 | def _get_single_item(self, index): 29 | fname, pid, camid = self.dataset[index] 30 | fpath = fname 31 | if self.root is not None: 32 | fpath = osp.join(self.root, fname) 33 | 34 | img = Image.open(fpath).convert('RGB') 35 | 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | 39 | return img, fname, pid, camid, index 40 | 41 | def _get_mutual_item(self, index): 42 | fname, pid, camid = self.dataset[index] 43 | fpath = fname 44 | if self.root is not None: 45 | fpath = osp.join(self.root, fname) 46 | 47 | img = Image.open(fpath).convert('RGB') 48 | img2 = img.copy() 49 | 50 | if self.transform is not None: 51 | img1 = self.transform(img) 52 | img2 = self.transform(img2) 53 | else: 54 | raise NotImplementedError 55 | 56 | return img1, img2, fname, pid, camid, index 57 | 58 | 59 | class Preprocessor_mutual(Dataset): 60 | def __init__(self, dataset, root=None, transform=None, mutual=False, transform_weak=None): 61 | super(Preprocessor_mutual, self).__init__() 62 | self.dataset = dataset 63 | self.root = root 64 | self.transform = transform 65 | self.transform_weak = transform_weak 66 | self.mutual = mutual 67 | 68 | # self.use_gan=use_gan 69 | # self.num_cam = num_cam 70 | 71 | def __len__(self): 72 | return len(self.dataset) 73 | 74 | def __getitem__(self, indices): 75 | if self.mutual: 76 | return self._get_mutual_item(indices) 77 | else: 78 | return self._get_single_item(indices) 79 | 80 | def _get_single_item(self, index): 81 | fname, pid, camid = self.dataset[index] 82 | fpath = fname 83 | if self.root is not None: 84 | fpath = osp.join(self.root, fname) 85 | 86 | img = Image.open(fpath).convert('RGB') 87 | 88 | if self.transform is not None: 89 | img = self.transform(img) 90 | 91 | 
return img, fname, pid, camid 92 | 93 | def _get_mutual_item(self, index): 94 | fname, pid, camid = self.dataset[index] 95 | fpath = fname 96 | if self.root is not None: 97 | fpath = osp.join(self.root, fname) 98 | 99 | img = Image.open(fpath).convert('RGB') 100 | img_mutual = img.copy() 101 | img2 = img.copy() 102 | 103 | if self.transform is not None: 104 | img1 = self.transform(img) 105 | img_mutual = self.transform(img_mutual) 106 | if self.transform_weak is not None: 107 | img2 = self.transform_weak(img2) 108 | else: 109 | img2 = self.transform(img2) 110 | else: 111 | raise NotImplementedError 112 | 113 | return img1, img2, img_mutual, pid, camid -------------------------------------------------------------------------------- /ice/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import math 4 | 5 | import numpy as np 6 | import copy 7 | import random 8 | import torch 9 | from torch.utils.data.sampler import ( 10 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 11 | WeightedRandomSampler) 12 | 13 | 14 | def No_index(a, b): 15 | assert isinstance(a, list) 16 | return [i for i, j in enumerate(a) if j != b] 17 | 18 | 19 | class RandomIdentitySampler(Sampler): 20 | def __init__(self, data_source, num_instances, video=False): 21 | self.data_source = data_source 22 | self.num_instances = num_instances 23 | self.index_dic = defaultdict(list) 24 | self.video = video 25 | if self.video: 26 | for index, (_, pid, _, _) in enumerate(data_source): 27 | self.index_dic[pid].append(index) 28 | else: 29 | for index, (_, pid, _) in enumerate(data_source): 30 | self.index_dic[pid].append(index) 31 | self.pids = list(self.index_dic.keys()) 32 | self.num_samples = len(self.pids) 33 | 34 | def __len__(self): 35 | return self.num_samples * self.num_instances 36 | 37 | def __iter__(self): 38 | indices = torch.randperm(self.num_samples).tolist() 39 | ret = [] 40 | for i in indices: 41 | pid = self.pids[i] 42 | t = self.index_dic[pid] 43 | if len(t) >= self.num_instances: 44 | t = np.random.choice(t, size=self.num_instances, replace=False) 45 | else: 46 | t = np.random.choice(t, size=self.num_instances, replace=True) 47 | ret.extend(t) 48 | return iter(ret) 49 | 50 | 51 | class RandomMultipleGallerySampler(Sampler): 52 | def __init__(self, data_source, num_instances=4, video=False): 53 | self.data_source = data_source 54 | self.index_pid = defaultdict(int) 55 | self.pid_cam = defaultdict(list) 56 | self.pid_index = defaultdict(list) 57 | self.num_instances = num_instances 58 | self.video=video 59 | 60 | if self.video: 61 | for index, (_, pid, cam, _) in enumerate(data_source): 62 | if (pid < 0): continue 63 | self.index_pid[index] = pid 64 | self.pid_cam[pid].append(cam) 65 | self.pid_index[pid].append(index) 66 | else: 67 | for index, (_, pid, cam) in enumerate(data_source): 68 | if (pid < 0): continue 69 | self.index_pid[index] = pid 70 | self.pid_cam[pid].append(cam) 71 | self.pid_index[pid].append(index) 72 | 73 | self.pids = list(self.pid_index.keys()) 74 | self.num_samples = len(self.pids) 75 | 76 | def __len__(self): 77 | return self.num_samples * self.num_instances 78 | 79 | def __iter__(self): 80 | indices = torch.randperm(len(self.pids)).tolist() 81 | ret = [] 82 | 83 | for kid in indices: 84 | i = random.choice(self.pid_index[self.pids[kid]]) 85 | 86 | if self.video: 87 | _, i_pid, i_cam, _ = self.data_source[i] 88 | else: 89 | _, i_pid, i_cam 
= self.data_source[i] 90 | 91 | 92 | 93 | pid_i = self.index_pid[i] 94 | cams = self.pid_cam[pid_i] 95 | index = self.pid_index[pid_i] 96 | select_cams = No_index(cams, i_cam) 97 | 98 | if select_cams: 99 | ret.append(i) 100 | if len(select_cams) >= self.num_instances: 101 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False) 102 | else: 103 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True) 104 | 105 | for kk in cam_indexes: 106 | ret.append(index[kk]) 107 | 108 | else: 109 | if len(index) >= self.num_instances: 110 | t = np.random.choice(index, size=self.num_instances, replace=False) 111 | else: 112 | t = np.random.choice(index, size=self.num_instances, replace=True) 113 | ret.extend(t) 114 | 115 | 116 | return iter(ret) 117 | 118 | 119 | class RandomIdentityCameraSampler(Sampler): 120 | def __init__(self, data_source, num_instances=4, video=False): 121 | self.data_source = data_source 122 | self.index_pid = defaultdict(int) 123 | self.pid_cam = defaultdict(list) 124 | self.pid_index = defaultdict(list) 125 | self.num_instances = num_instances 126 | self.video=video 127 | 128 | if self.video: 129 | for index, (_, pid, cam, _) in enumerate(data_source): 130 | if (pid < 0): continue 131 | self.index_pid[index] = pid 132 | self.pid_cam[pid].append(cam) 133 | self.pid_index[pid].append(index) 134 | else: 135 | for index, (_, pid, cam) in enumerate(data_source): 136 | if (pid < 0): continue 137 | self.index_pid[index] = pid 138 | self.pid_cam[pid].append(cam) 139 | self.pid_index[pid].append(index) 140 | 141 | self.pids = list(self.pid_index.keys()) 142 | self.num_samples = len(self.pids) 143 | 144 | def __len__(self): 145 | return self.num_samples * self.num_instances 146 | 147 | def __iter__(self): 148 | indices = torch.randperm(len(self.pids)).tolist() 149 | ret = [] 150 | 151 | for kid in indices: 152 | i = random.choice(self.pid_index[self.pids[kid]]) 153 | 154 | if self.video: 155 | _, i_pid, i_cam, _ = self.data_source[i] 156 | else: 157 | _, i_pid, i_cam = self.data_source[i] 158 | cams = self.pid_cam[i_pid] 159 | index = self.pid_index[i_pid] 160 | select_cams = No_index(cams, i_cam) 161 | 162 | if select_cams: 163 | ret.append(i) 164 | if len(select_cams) >= self.num_instances: 165 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False) 166 | else: 167 | cam_indexes = np.random.permutation(select_cams) 168 | select_index = No_index(index,i) 169 | select_index = np.setdiff1d(select_index, cam_indexes) 170 | if len(select_index)==0: 171 | ind_indexes = np.random.choice(select_cams, size=self.num_instances-1-len(cam_indexes), replace=True) 172 | elif len(select_index) >= self.num_instances-1-len(cam_indexes): 173 | ind_indexes = np.random.choice(select_index, size=self.num_instances-1-len(cam_indexes), replace=False) 174 | else: 175 | ind_indexes = np.random.choice(select_index, size=self.num_instances-1-len(cam_indexes), replace=True) 176 | cam_indexes = np.concatenate((cam_indexes, ind_indexes)) 177 | for kk in cam_indexes: 178 | ret.append(index[kk]) 179 | else: 180 | if len(index) >= self.num_instances: 181 | t = np.random.choice(index, size=self.num_instances, replace=False) 182 | else: 183 | t = np.random.choice(index, size=self.num_instances, replace=True) 184 | ret.extend(t) 185 | 186 | return iter(ret) 187 | 188 | 189 | class MoreCameraSampler(Sampler): 190 | def __init__(self, data_source, num_instances=4, video=False): 191 | self.data_source = data_source 192 | 
self.index_pid = defaultdict(int) 193 | self.pid_cam = defaultdict(list) 194 | self.pid_index = defaultdict(list) 195 | self.num_instances = num_instances 196 | self.video = video 197 | 198 | if self.video: 199 | for index, (_, pid, cam, _) in enumerate(data_source): 200 | if (pid < 0): continue 201 | self.index_pid[index] = pid 202 | self.pid_cam[pid].append(cam) 203 | self.pid_index[pid].append(index) 204 | else: 205 | for index, (_, pid, cam) in enumerate(data_source): 206 | if (pid < 0): continue 207 | self.index_pid[index] = pid 208 | self.pid_cam[pid].append(cam) 209 | self.pid_index[pid].append(index) 210 | 211 | self.pids = list(self.pid_index.keys()) 212 | self.num_samples = len(self.pids) 213 | 214 | def __len__(self): 215 | return self.num_samples * self.num_instances 216 | 217 | def __iter__(self): 218 | indices = torch.randperm(len(self.pids)).tolist() 219 | ret = [] 220 | 221 | for kid in indices: 222 | i = random.choice(self.pid_index[self.pids[kid]]) 223 | 224 | if self.video: 225 | _, i_pid, i_cam, _ = self.data_source[i] 226 | else: 227 | _, i_pid, i_cam = self.data_source[i] 228 | 229 | cams = self.pid_cam[i_pid] 230 | index = self.pid_index[i_pid] 231 | 232 | unique_cams = set(cams) 233 | cams = np.array(cams) 234 | index = np.array(index) 235 | select_indexes = [] 236 | for cam in unique_cams: 237 | select_indexes.append(np.random.choice(index[cams==cam], size=1, replace=False)) 238 | select_indexes = np.concatenate(select_indexes) 239 | if len(select_indexes)< self.num_instances: 240 | diff_indexes = np.setdiff1d(index, select_indexes) 241 | if len(diff_indexes) == 0: 242 | select_indexes = np.random.choice(select_indexes, size=self.num_instances, replace=True) 243 | elif len(diff_indexes) >= (self.num_instances-len(select_indexes)): 244 | diff_indexes = np.random.choice(diff_indexes, size=(self.num_instances-len(select_indexes)), replace=False) 245 | else: 246 | diff_indexes = np.random.choice(diff_indexes, size=(self.num_instances-len(select_indexes)), replace=True) 247 | select_indexes = np.concatenate([select_indexes, diff_indexes]) 248 | else: 249 | select_indexes = np.random.choice(select_indexes, size=self.num_instances, replace=False) 250 | ret.extend(select_indexes) 251 | return iter(ret) -------------------------------------------------------------------------------- /ice/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torchvision.transforms import * 5 | from PIL import Image, ImageFilter 6 | import random 7 | import math 8 | import numpy as np 9 | 10 | 11 | class GaussianBlur(object): 12 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 13 | 14 | def __init__(self, sigma=[.1, 2.]): 15 | self.sigma = sigma 16 | 17 | def __call__(self, x): 18 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 19 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 20 | return x 21 | 22 | 23 | class AddGaussianNoise(object): 24 | def __init__(self, mean=0., std=1.): 25 | self.std = std 26 | self.mean = mean 27 | 28 | def __call__(self, tensor): 29 | return tensor + torch.randn(tensor.size()) * self.std + self.mean 30 | 31 | # def __repr__(self): 32 | # return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 33 | 34 | 35 | class RectScale(object): 36 | def __init__(self, height, width, interpolation=Image.BILINEAR): 37 | self.height = height 38 | self.width = width 39 | self.interpolation 
= interpolation 40 | 41 | def __call__(self, img): 42 | w, h = img.size 43 | if h == self.height and w == self.width: 44 | return img 45 | return img.resize((self.width, self.height), self.interpolation) 46 | 47 | 48 | class RandomSizedRectCrop(object): 49 | def __init__(self, height, width, interpolation=Image.BILINEAR): 50 | self.height = height 51 | self.width = width 52 | self.interpolation = interpolation 53 | 54 | def __call__(self, img): 55 | for attempt in range(10): 56 | area = img.size[0] * img.size[1] 57 | target_area = random.uniform(0.64, 1.0) * area 58 | aspect_ratio = random.uniform(2, 3) 59 | 60 | h = int(round(math.sqrt(target_area * aspect_ratio))) 61 | w = int(round(math.sqrt(target_area / aspect_ratio))) 62 | 63 | if w <= img.size[0] and h <= img.size[1]: 64 | x1 = random.randint(0, img.size[0] - w) 65 | y1 = random.randint(0, img.size[1] - h) 66 | 67 | img = img.crop((x1, y1, x1 + w, y1 + h)) 68 | assert(img.size == (w, h)) 69 | 70 | return img.resize((self.width, self.height), self.interpolation) 71 | 72 | # Fallback 73 | scale = RectScale(self.height, self.width, 74 | interpolation=self.interpolation) 75 | return scale(img) 76 | 77 | 78 | class RandomErasing(object): 79 | """ Randomly selects a rectangle region in an image and erases its pixels. 80 | 'Random Erasing Data Augmentation' by Zhong et al. 81 | See https://arxiv.org/pdf/1708.04896.pdf 82 | Args: 83 | probability: The probability that the Random Erasing operation will be performed. 84 | sl: Minimum proportion of erased area against input image. 85 | sh: Maximum proportion of erased area against input image. 86 | r1: Minimum aspect ratio of erased area. 87 | mean: Erasing value. 88 | """ 89 | 90 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 91 | self.probability = probability 92 | self.mean = mean 93 | self.sl = sl 94 | self.sh = sh 95 | self.r1 = r1 96 | 97 | def __call__(self, img): 98 | 99 | if random.uniform(0, 1) >= self.probability: 100 | return img 101 | 102 | for attempt in range(100): 103 | area = img.size()[1] * img.size()[2] 104 | 105 | target_area = random.uniform(self.sl, self.sh) * area 106 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 107 | 108 | h = int(round(math.sqrt(target_area * aspect_ratio))) 109 | w = int(round(math.sqrt(target_area / aspect_ratio))) 110 | 111 | if w < img.size()[2] and h < img.size()[1]: 112 | x1 = random.randint(0, img.size()[1] - h) 113 | y1 = random.randint(0, img.size()[2] - w) 114 | if img.size()[0] == 3: 115 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]#+np.random.normal(0, 0.2) 116 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]#+np.random.normal(0, 0.2) 117 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]#+np.random.normal(0, 0.2) 118 | else: 119 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 120 | return img 121 | 122 | return img 123 | -------------------------------------------------------------------------------- /ice/utils/faiss_rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
5 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 7 | """ 8 | 9 | import os, sys 10 | import time 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | import gc 14 | import faiss 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \ 20 | index_init_gpu, index_init_cpu 21 | 22 | def k_reciprocal_neigh(initial_rank, i, k1): 23 | forward_k_neigh_index = initial_rank[i,:k1+1] 24 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 25 | fi = np.where(backward_k_neigh_index==i)[0] 26 | return forward_k_neigh_index[fi] 27 | 28 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False): 29 | end = time.time() 30 | if print_flag: 31 | print('Computing jaccard distance...') 32 | 33 | ngpus = faiss.get_num_gpus() 34 | N = target_features.size(0) 35 | mat_type = np.float16 if use_float16 else np.float32 36 | 37 | if (search_option==0): 38 | # GPU + PyTorch CUDA Tensors (1) 39 | res = faiss.StandardGpuResources() 40 | res.setDefaultNullStreamAllDevices() 41 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1) 42 | initial_rank = initial_rank.cpu().numpy() 43 | elif (search_option==1): 44 | # GPU + PyTorch CUDA Tensors (2) 45 | res = faiss.StandardGpuResources() 46 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1)) 47 | index.add(target_features.cpu().numpy()) 48 | _, initial_rank = search_index_pytorch(index, target_features, k1) 49 | res.syncDefaultStreamCurrentDevice() 50 | initial_rank = initial_rank.cpu().numpy() 51 | elif (search_option==2): 52 | # GPU 53 | index = index_init_gpu(ngpus, target_features.size(-1)) 54 | index.add(target_features.cpu().numpy()) 55 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 56 | else: 57 | # CPU 58 | index = index_init_cpu(target_features.size(-1)) 59 | index.add(target_features.cpu().numpy()) 60 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 61 | 62 | 63 | nn_k1 = [] 64 | nn_k1_half = [] 65 | for i in range(N): 66 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1)) 67 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2)))) 68 | 69 | V = np.zeros((N, N), dtype=mat_type) 70 | for i in range(N): 71 | k_reciprocal_index = nn_k1[i] 72 | k_reciprocal_expansion_index = k_reciprocal_index 73 | for candidate in k_reciprocal_index: 74 | candidate_k_reciprocal_index = nn_k1_half[candidate] 75 | if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)): 76 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 77 | 78 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 79 | dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t()) 80 | if use_float16: 81 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type) 82 | else: 83 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy() 84 | 85 | del nn_k1, nn_k1_half 86 | 87 | if k2 != 1: 88 | V_qe = np.zeros_like(V, dtype=mat_type) 89 | for i in range(N): 90 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0) 91 | V = 
V_qe 92 | del V_qe 93 | 94 | del initial_rank 95 | 96 | invIndex = [] 97 | for i in range(N): 98 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 99 | 100 | jaccard_dist = np.zeros((N, N), dtype=mat_type) 101 | for i in range(N): 102 | temp_min = np.zeros((1,N), dtype=mat_type) 103 | # temp_max = np.zeros((1,N), dtype=mat_type) 104 | indNonZero = np.where(V[i,:] != 0)[0] 105 | indImages = [] 106 | indImages = [invIndex[ind] for ind in indNonZero] 107 | for j in range(len(indNonZero)): 108 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 109 | # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 110 | 111 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 112 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6) 113 | 114 | del invIndex, V 115 | 116 | pos_bool = (jaccard_dist < 0) 117 | jaccard_dist[pos_bool] = 0.0 118 | if print_flag: 119 | print ("Jaccard distance computing time cost: {}".format(time.time()-end)) 120 | 121 | return jaccard_dist 122 | -------------------------------------------------------------------------------- /ice/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | return faiss.cast_integer_to_long_ptr( 16 | x.storage().data_ptr() + x.storage_offset() * 8) 17 | 18 | def search_index_pytorch(index, x, k, D=None, I=None): 19 | """call the search function of an index with pytorch tensor I/O (CPU 20 | and GPU supported)""" 21 | assert x.is_contiguous() 22 | n, d = x.size() 23 | assert d == index.d 24 | 25 | if D is None: 26 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 27 | else: 28 | assert D.size() == (n, k) 29 | 30 | if I is None: 31 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 32 | else: 33 | assert I.size() == (n, k) 34 | torch.cuda.synchronize() 35 | xptr = swig_ptr_from_FloatTensor(x) 36 | Iptr = swig_ptr_from_LongTensor(I) 37 | Dptr = swig_ptr_from_FloatTensor(D) 38 | index.search_c(n, xptr, 39 | k, Dptr, Iptr) 40 | torch.cuda.synchronize() 41 | return D, I 42 | 43 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 44 | metric=faiss.METRIC_L2): 45 | assert xb.device == xq.device 46 | 47 | nq, d = xq.size() 48 | if xq.is_contiguous(): 49 | xq_row_major = True 50 | elif xq.t().is_contiguous(): 51 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 52 | xq_row_major = False 53 | else: 54 | raise TypeError('matrix should be row or column-major') 55 | 56 | xq_ptr = swig_ptr_from_FloatTensor(xq) 57 | 58 | nb, d2 = xb.size() 59 | assert d2 == d 60 | if xb.is_contiguous(): 61 | xb_row_major = True 62 | elif xb.t().is_contiguous(): 63 | xb = xb.t() 64 | xb_row_major = False 65 | else: 66 | raise TypeError('matrix should be row or column-major') 67 | xb_ptr = swig_ptr_from_FloatTensor(xb) 68 | 69 | if D is None: 70 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 71 | else: 72 | assert D.shape == (nq, k) 73 | assert D.device == xb.device 74 | 75 | if I is None: 76 | I = torch.empty(nq, k, device=xb.device, 
dtype=torch.int64) 77 | else: 78 | assert I.shape == (nq, k) 79 | assert I.device == xb.device 80 | 81 | D_ptr = swig_ptr_from_FloatTensor(D) 82 | I_ptr = swig_ptr_from_LongTensor(I) 83 | 84 | faiss.bruteForceKnn(res, metric, 85 | xb_ptr, xb_row_major, nb, 86 | xq_ptr, xq_row_major, nq, 87 | d, k, D_ptr, I_ptr) 88 | 89 | return D, I 90 | 91 | def index_init_gpu(ngpus, feat_dim): 92 | flat_config = [] 93 | for i in range(ngpus): 94 | cfg = faiss.GpuIndexFlatConfig() 95 | cfg.useFloat16 = False 96 | cfg.device = i 97 | flat_config.append(cfg) 98 | 99 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 100 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 101 | index = faiss.IndexShards(feat_dim) 102 | for sub_index in indexes: 103 | index.add_shard(sub_index) 104 | index.reset() 105 | return index 106 | 107 | def index_init_cpu(feat_dim): 108 | return faiss.IndexFlatL2(feat_dim) 109 | -------------------------------------------------------------------------------- /ice/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /ice/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | from torch.optim.lr_scheduler import * 9 | 10 | 11 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 12 | # separating MultiStepLR with WarmupLR 13 | # but the current LRScheduler design doesn't allow it 14 | 15 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 16 | def __init__( 17 | self, 18 | optimizer, 19 | milestones, 20 | gamma=0.1, 21 | warmup_factor=1.0 / 3, 22 | warmup_iters=500, 23 | warmup_method="linear", 24 | last_epoch=-1, 25 | ): 26 | if not list(milestones) == sorted(milestones): 27 | raise ValueError( 28 | "Milestones should be a list of" " increasing integers. 
15 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
16 |     def __init__(
17 |         self,
18 |         optimizer,
19 |         milestones,
20 |         gamma=0.1,
21 |         warmup_factor=1.0 / 3,
22 |         warmup_iters=500,
23 |         warmup_method="linear",
24 |         last_epoch=-1,
25 |     ):
26 |         if not list(milestones) == sorted(milestones):
27 |             raise ValueError(
28 |                 "Milestones should be a list of increasing "
29 |                 "integers. Got {}".format(milestones)
30 |             )
31 | 
32 |         if warmup_method not in ("constant", "linear"):
33 |             raise ValueError(
34 |                 "Only 'constant' or 'linear' warmup_method accepted, "
35 |                 "got {}".format(warmup_method)
36 |             )
37 |         self.milestones = milestones
38 |         self.gamma = gamma
39 |         self.warmup_factor = warmup_factor
40 |         self.warmup_iters = warmup_iters
41 |         self.warmup_method = warmup_method
42 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
43 | 
44 |     def get_lr(self):
45 |         warmup_factor = 1
46 |         if self.last_epoch < self.warmup_iters:
47 |             if self.warmup_method == "constant":
48 |                 warmup_factor = self.warmup_factor
49 |             elif self.warmup_method == "linear":
50 |                 alpha = float(self.last_epoch) / float(self.warmup_iters)
51 |                 warmup_factor = self.warmup_factor * (1 - alpha) + alpha
52 |         return [
53 |             base_lr
54 |             * warmup_factor
55 |             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
56 |             for base_lr in self.base_lrs
57 |         ]
58 | 
--------------------------------------------------------------------------------
/ice/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | 
4 | class AverageMeter(object):
5 |     """Computes and stores the average and current value"""
6 | 
7 |     def __init__(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
--------------------------------------------------------------------------------
/ice/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 | 
5 | 
6 | def mkdir_if_missing(dir_path):
7 |     try:
8 |         os.makedirs(dir_path)
9 |     except OSError as e:
10 |         if e.errno != errno.EEXIST:
11 |             raise
12 | 
--------------------------------------------------------------------------------
/ice/utils/rerank.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Source: https://github.com/zhunzhong07/person-re-ranking
5 | Created on Mon Jun 26 14:46:56 2017
6 | @author: luohao
7 | Modified by Yixiao Ge, 2020-3-14.
8 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
9 | url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
10 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
11 | API
12 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
13 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
14 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
15 | k1, k2, lambda_value: parameters; the original paper uses (k1=20, k2=6, lambda_value=0.3)
16 | Returns:
17 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
18 | """
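# A minimal usage sketch (editorial addition, not part of the original file).
# `q_feats` / `g_feats` are hypothetical [n, d] feature arrays; plain Euclidean
# distances suffice, since re_ranking squares and max-normalises its inputs:
#
#   from scipy.spatial.distance import cdist
#   q_g_dist = cdist(q_feats, g_feats)
#   q_q_dist = cdist(q_feats, q_feats)
#   g_g_dist = cdist(g_feats, g_feats)
#   final_dist = re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3)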
19 | from __future__ import absolute_import
20 | from __future__ import print_function
21 | from __future__ import division
22 | 
23 | __all__ = ['re_ranking']
24 | 
25 | import numpy as np
26 | import time
27 | 
28 | import torch
29 | import torch.nn.functional as F
30 | 
31 | 
32 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
33 |     # Note: names such as gallery_num below differ from the outer scope;
34 |     # here they count query and gallery samples together.
35 | 
36 |     original_dist = np.concatenate(
37 |         [np.concatenate([q_q_dist, q_g_dist], axis=1),
38 |         np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
39 |         axis=0)
40 |     original_dist = np.power(original_dist, 2).astype(np.float32)
41 |     original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0))
42 |     V = np.zeros_like(original_dist).astype(np.float32)
43 |     initial_rank = np.argsort(original_dist).astype(np.int32)
44 | 
45 |     query_num = q_g_dist.shape[0]
46 |     gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
47 |     all_num = gallery_num
48 | 
49 |     for i in range(all_num):
50 |         # k-reciprocal neighbors
51 |         forward_k_neigh_index = initial_rank[i, :k1 + 1]
52 |         backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
53 |         fi = np.where(backward_k_neigh_index == i)[0]
54 |         k_reciprocal_index = forward_k_neigh_index[fi]
55 |         k_reciprocal_expansion_index = k_reciprocal_index
56 |         for j in range(len(k_reciprocal_index)):
57 |             candidate = k_reciprocal_index[j]
58 |             candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2.)) + 1]
59 |             candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
60 |                                                :int(np.around(k1 / 2.)) + 1]
61 |             fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
62 |             candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
63 |             if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2. / 3 * len(
64 |                     candidate_k_reciprocal_index):
65 |                 k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
66 | 
67 |         k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
68 |         weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
69 |         V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight)
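    # Each row of V now encodes sample i's expanded k-reciprocal neighbourhood
    # with exp(-dist) weights; the k2 block below is the paper's local query
    # expansion, applied before the Jaccard distance is computed.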
70 |     original_dist = original_dist[:query_num, ]
71 |     if k2 != 1:
72 |         V_qe = np.zeros_like(V, dtype=np.float32)
73 |         for i in range(all_num):
74 |             V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
75 |         V = V_qe
76 |         del V_qe
77 |     del initial_rank
78 |     invIndex = []
79 |     for i in range(gallery_num):
80 |         invIndex.append(np.where(V[:, i] != 0)[0])
81 | 
82 |     jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)
83 | 
84 |     for i in range(query_num):
85 |         temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32)
86 |         indNonZero = np.where(V[i, :] != 0)[0]
87 |         # invIndex[ind] lists the samples whose V rows are non-zero at column ind
88 |         indImages = [invIndex[ind] for ind in indNonZero]
89 |         for j in range(len(indNonZero)):
90 |             temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
91 |                                                                                V[indImages[j], indNonZero[j]])
92 |         jaccard_dist[i] = 1 - temp_min / (2. - temp_min)
93 | 
94 |     final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
95 |     del original_dist
96 |     del V
97 |     del jaccard_dist
98 |     final_dist = final_dist[:query_num, query_num:]
99 |     return final_dist
100 | 
101 | 
102 | def k_reciprocal_neigh(initial_rank, i, k1):
103 |     forward_k_neigh_index = initial_rank[i, :k1 + 1]
104 |     backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
105 |     fi = torch.nonzero(backward_k_neigh_index == i)[:, 0]
106 |     return forward_k_neigh_index[fi]
107 | 
108 | 
109 | def compute_jaccard_dist(target_features, k1=20, k2=6, print_flag=True,
110 |                          lambda_value=0, source_features=None, use_gpu=False):
111 |     end = time.time()
112 |     N = target_features.size(0)
113 |     if (use_gpu):
114 |         # accelerate matrix distance computing
115 |         target_features = target_features.cuda()
116 |         if (source_features is not None):
117 |             source_features = source_features.cuda()
118 | 
119 |     if ((lambda_value > 0) and (source_features is not None)):
120 |         M = source_features.size(0)
121 |         sour_tar_dist = torch.pow(target_features, 2).sum(dim=1, keepdim=True).expand(N, M) + \
122 |                         torch.pow(source_features, 2).sum(dim=1, keepdim=True).expand(M, N).t()
123 |         sour_tar_dist.addmm_(target_features, source_features.t(), beta=1, alpha=-2)  # keyword form; the positional (beta, alpha) signature is deprecated
124 |         sour_tar_dist = 1 - torch.exp(-sour_tar_dist)
125 |         sour_tar_dist = sour_tar_dist.cpu()
126 |         source_dist_vec = sour_tar_dist.min(1)[0]
127 |         del sour_tar_dist
128 |         source_dist_vec /= source_dist_vec.max()
129 |         source_dist = torch.zeros(N, N)
130 |         for i in range(N):
131 |             source_dist[i, :] = source_dist_vec + source_dist_vec[i]
132 |         del source_dist_vec
133 | 
134 |     if print_flag:
135 |         print('Computing original distance...')
136 | 
137 |     original_dist = torch.pow(target_features, 2).sum(dim=1, keepdim=True) * 2
138 |     original_dist = original_dist.expand(N, N) - 2 * torch.mm(target_features, target_features.t())
139 |     original_dist /= original_dist.max(0)[0]
140 |     original_dist = original_dist.t()
141 |     initial_rank = torch.argsort(original_dist, dim=-1)
142 | 
143 |     original_dist = original_dist.cpu()
144 |     initial_rank = initial_rank.cpu()
145 |     all_num = gallery_num = original_dist.size(0)
146 | 
147 |     del target_features
148 |     if (source_features is not None):
149 |         del source_features
150 | 
151 |     if print_flag:
152 |         print('Computing Jaccard distance...')
153 | 
154 |     nn_k1 = []
155 |     nn_k1_half = []
156 |     for i in range(all_num):
157 |         nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))
158 |         nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1 / 2))))
159 | 
160 |     V = torch.zeros(all_num, all_num)
161 |     for i in range(all_num):
162 | 
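        # (Mirror of the numpy loop in re_ranking above: expand sample i's
        # k-reciprocal set with the half-k sets of its mutual neighbours, then
        # store an exp(-dist)-weighted membership row in V.)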
k_reciprocal_index = nn_k1[i] 163 | k_reciprocal_expansion_index = k_reciprocal_index 164 | for candidate in k_reciprocal_index: 165 | candidate_k_reciprocal_index = nn_k1_half[candidate] 166 | if (len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 167 | candidate_k_reciprocal_index)): 168 | k_reciprocal_expansion_index = torch.cat((k_reciprocal_expansion_index, candidate_k_reciprocal_index)) 169 | 170 | k_reciprocal_expansion_index = torch.unique(k_reciprocal_expansion_index) ## element-wise unique 171 | weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index]) 172 | V[i, k_reciprocal_expansion_index] = weight / torch.sum(weight) 173 | 174 | if k2 != 1: 175 | k2_rank = initial_rank[:, :k2].clone().view(-1) 176 | V_qe = V[k2_rank] 177 | V_qe = V_qe.view(initial_rank.size(0), k2, -1).sum(1) 178 | V_qe /= k2 179 | V = V_qe 180 | del V_qe 181 | del initial_rank 182 | 183 | invIndex = [] 184 | for i in range(gallery_num): 185 | invIndex.append(torch.nonzero(V[:, i])[:, 0]) # len(invIndex)=all_num 186 | 187 | jaccard_dist = torch.zeros_like(original_dist) 188 | for i in range(all_num): 189 | temp_min = torch.zeros(1, gallery_num) 190 | indNonZero = torch.nonzero(V[i, :])[:, 0] 191 | indImages = [] 192 | indImages = [invIndex[ind] for ind in indNonZero] 193 | for j in range(len(indNonZero)): 194 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + torch.min(V[i, indNonZero[j]], 195 | V[indImages[j], indNonZero[j]]) 196 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 197 | del invIndex 198 | 199 | del V 200 | 201 | pos_bool = (jaccard_dist < 0) 202 | jaccard_dist[pos_bool] = 0.0 203 | if print_flag: 204 | print("Time cost: {}".format(time.time() - end)) 205 | 206 | if (lambda_value > 0): 207 | return jaccard_dist * (1 - lambda_value) + source_dist * lambda_value 208 | else: 209 | return jaccard_dist -------------------------------------------------------------------------------- /ice/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | # checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 35 | print("=> Loaded checkpoint '{}'".format(fpath)) 36 | return checkpoint 37 | else: 38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 39 | 40 | 41 | def copy_state_dict(state_dict, model, strip=None): 42 | tgt_state = model.state_dict() 43 | copied_names = set() 44 | for name, param in state_dict.items(): 45 | if strip is not None and name.startswith(strip): 46 | name = name[len(strip):] 47 | if name not in tgt_state: 48 | continue 49 | if isinstance(param, Parameter): 50 | param = param.data 51 | if param.size() != 
tgt_state[name].size(): 52 | print('mismatch:', name, param.size(), tgt_state[name].size()) 53 | continue 54 | tgt_state[name].copy_(param) 55 | copied_names.add(name) 56 | 57 | missing = set(tgt_state.keys()) - copied_names 58 | if len(missing) > 0: 59 | print("missing keys in state_dict:", missing) 60 | 61 | return model 62 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup(name='ICE', 5 | version='1.0.0', 6 | description='ICE: Inter-instance Contrastive Encoding for Unsupervised Person Re-identification', 7 | author='Hao Chen', 8 | author_email='hao.chen@inria.fr', 9 | url='https://github.com/chenhao2345/ICE', 10 | install_requires=[ 11 | 'numpy', 'torch==1.7.0', 'torchvision==0.8.0', 12 | 'six', 'h5py', 'Pillow', 'scipy', 13 | 'scikit-learn', 'metric-learn', 'faiss-gpu'], 14 | packages=find_packages(), 15 | keywords=[ 16 | 'Contrastive Learning', 17 | 'Person Re-identification' 18 | ]) 19 | 20 | --------------------------------------------------------------------------------