├── LICENSE ├── README.md ├── examples ├── test.py ├── train_baseline.py └── train_idm.py ├── idm ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dukemtmc.py │ ├── market1501.py │ ├── msmt17.py │ ├── personx.py │ └── unreal.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ ├── rank.py │ ├── rank_cylib │ │ ├── Makefile │ │ ├── rank_cy.c │ │ ├── rank_cy.pyx │ │ ├── setup.py │ │ └── test_cython.py │ └── ranking.py ├── evaluators.py ├── loss │ ├── __init__.py │ ├── crossentropy.py │ ├── idm_loss.py │ ├── triplet.py │ └── triplet_xbm.py ├── models │ ├── __init__.py │ ├── dsbn.py │ ├── idm_dsbn.py │ ├── idm_module.py │ ├── resnet.py │ ├── resnet_ibn.py │ ├── resnet_ibn_a.py │ ├── resnet_ibn_idm.py │ ├── resnet_idm.py │ └── xbm.py ├── trainers.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── preprocessor.py │ ├── sampler.py │ └── transforms.py │ ├── faiss_rerank.py │ ├── faiss_utils.py │ ├── logging.py │ ├── lr_scheduler.py │ ├── meters.py │ ├── osutils.py │ ├── rerank.py │ └── serialization.py └── scripts ├── run_idm.sh ├── run_idm_xbm.sh ├── run_naive_baseline.sh ├── run_strong_baseline.sh ├── run_test_baseline.sh └── run_test_idm.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 SikaStar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python >=3.7](https://img.shields.io/badge/Python->=3.7-blue.svg) 2 | ![PyTorch >=1.1](https://img.shields.io/badge/PyTorch->=1.1-yellow.svg) 3 | 4 | ## News 5 | IDM has been extended to [IDM++](https://arxiv.org/abs/2203.01682). IDM++ is a strong cross-domain person re-ID method, which achieves new state of the art under both the unsupervised domain adaptation (UDA) and domain generalization (DG) re-ID scenarios. The code will be updated. 
6 | 
7 | ## Citation
8 | If you find our work useful for your research, please cite our papers:
9 | ```
10 | @inproceedings{dai2021idm,
11 | title={IDM: An Intermediate Domain Module for Domain Adaptive Person Re-ID},
12 | author={Dai, Yongxing and Liu, Jun and Sun, Yifan and Tong, Zekun and Zhang, Chi and Duan, Ling-Yu},
13 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
14 | year={2021}
15 | }
16 | 
17 | @article{dai2022bridging,
18 | title={Bridging the Source-to-target Gap for Cross-domain Person Re-Identification with Intermediate Domains},
19 | author={Dai, Yongxing and Sun, Yifan and Liu, Jun and Tong, Zekun and Yang, Yi and Duan, Ling-Yu},
20 | journal={arXiv preprint arXiv:2203.01682},
21 | year={2022}
22 | }
23 | ```
24 | If you have any questions, please open an issue or contact me: yongxingdai@pku.edu.cn
25 | 
26 | 
27 | # Intermediate Domain Module (IDM)
28 | 
29 | This repository is the official implementation of [IDM: An Intermediate Domain Module for Domain Adaptive Person Re-ID](http://arxiv.org/abs/2108.02413), which was accepted to [ICCV 2021 (Oral)](http://iccv2021.thecvf.com/node/44).
30 | 
31 | `IDM` achieves state-of-the-art performance on the **unsupervised domain adaptation** task for person re-ID.
32 | 
33 | ## Requirements
34 | 
35 | ### Installation
36 | 
37 | ```shell
38 | git clone https://github.com/SikaStar/IDM.git
39 | cd IDM/idm/evaluation_metrics/rank_cylib && make all
40 | ```
41 | 
42 | ### Prepare Datasets
43 | 
44 | ```shell
45 | cd examples && mkdir data
46 | ```
47 | Download the person re-ID datasets [Market-1501](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zheng_Scalable_Person_Re-Identification_ICCV_2015_paper.pdf), [DukeMTMC-ReID](https://arxiv.org/abs/1701.07717), [MSMT17](https://arxiv.org/abs/1711.08565), [PersonX](https://github.com/sxzrt/Instructions-of-the-PersonX-dataset#data-for-visda2020-chanllenge), and
48 | [UnrealPerson](https://github.com/FlyHighest/UnrealPerson).
49 | Then unzip them so that the directory tree looks like
50 | ```
51 | IDM/examples/data
52 | ├── dukemtmc
53 | │   └── DukeMTMC-reID
54 | ├── market1501
55 | │   └── Market-1501-v15.09.15
56 | ├── msmt17
57 | │   └── MSMT17_V1
58 | ├── personx
59 | │   └── PersonX
60 | └── unreal
61 |     ├── list_unreal_train.txt
62 |     └── unreal_vX.Y
63 | ```
64 | 
65 | ### Prepare ImageNet Pre-trained Models for IBN-Net
66 | When training with the [IBN-ResNet](https://arxiv.org/abs/1807.09441) backbone, you need to download the ImageNet-pretrained model from this [link](https://drive.google.com/drive/folders/1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S) and save it under `logs/pretrained/`.
67 | ```shell
68 | mkdir logs && cd logs
69 | mkdir pretrained
70 | ```
71 | The file tree should be
72 | ```
73 | IDM/logs
74 | └── pretrained
75 |    └── resnet50_ibn_a.pth.tar
76 | ```
77 | ImageNet-pretrained models for **ResNet-50** are downloaded automatically by the Python scripts.
78 | 
79 | 
80 | ## Training
81 | 
82 | We use 4 GTX 2080 Ti GPUs for training. **Note that**
83 | 
84 | + The source and target domains are trained jointly.
85 | + For the baseline methods, use `-a resnet50` for the ResNet-50 backbone, and `-a resnet_ibn50a` for the IBN-ResNet backbone.
86 | + For IDM, use `-a resnet50_idm` to insert IDM into the ResNet-50 backbone, and `-a resnet_ibn50a_idm` to insert IDM into the IBN-ResNet backbone.
87 | + For the strong baseline, use `--use-xbm` to enable [XBM](https://arxiv.org/abs/1912.06798) (a variant of the memory bank; see the sketch below).
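For reference, XBM maintains a first-in-first-out memory of embeddings from past mini-batches so that the pairwise losses can mine positives and negatives far beyond the current batch. Below is a minimal sketch of that idea; the class and method names are hypothetical, and the repository's actual implementation lives in `idm/models/xbm.py`:

```python
import torch


class FeatureMemory:
    """FIFO cross-batch memory holding embeddings and labels from past batches."""

    def __init__(self, memory_size, feature_dim):
        self.feats = torch.zeros(memory_size, feature_dim)
        self.labels = torch.zeros(memory_size, dtype=torch.long)
        self.ptr = 0        # next write position
        self.full = False   # True once the buffer has wrapped around

    @torch.no_grad()
    def enqueue_dequeue(self, feats, labels):
        n = feats.size(0)
        if self.ptr + n > self.feats.size(0):  # wrap around, overwriting the oldest entries
            self.ptr = 0
            self.full = True
        self.feats[self.ptr:self.ptr + n] = feats
        self.labels[self.ptr:self.ptr + n] = labels
        self.ptr += n

    def get(self):
        # Return everything stored so far; the loss is then computed between
        # the current batch and this much larger pool of features.
        if self.full:
            return self.feats, self.labels
        return self.feats[:self.ptr], self.labels[:self.ptr]
```

During training, each batch is pushed into the memory and the metric loss additionally mines pairs between the batch and the whole memory, enlarging the effective pool of hard negatives at negligible extra cost.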
88 | 
89 | 
90 | ### Baseline Methods
91 | To train the baseline methods in the paper, run commands like:
92 | ```shell
93 | # Naive Baseline
94 | CUDA_VISIBLE_DEVICES=0,1,2,3 sh scripts/run_naive_baseline.sh ${source} ${target} ${arch}
95 | 
96 | # Strong Baseline
97 | CUDA_VISIBLE_DEVICES=0,1,2,3 sh scripts/run_strong_baseline.sh ${source} ${target} ${arch}
98 | ```
99 | 
100 | **Some examples:**
101 | ```shell
102 | ### market1501 -> dukemtmc ###
103 | 
104 | # ResNet-50
105 | CUDA_VISIBLE_DEVICES=0,1,2,3 sh scripts/run_strong_baseline.sh market1501 dukemtmc resnet50
106 | 
107 | # IBN-ResNet-50
108 | CUDA_VISIBLE_DEVICES=0,1,2,3 sh scripts/run_strong_baseline.sh market1501 dukemtmc resnet_ibn50a
109 | ```
110 | 
111 | ### Training with IDM
112 | 
113 | To train the models with our IDM, run commands like:
114 | ```shell
115 | # Naive Baseline + IDM
116 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
117 | sh scripts/run_idm.sh ${source} ${target} ${arch} ${stage} ${mu1} ${mu2} ${mu3}
118 | 
119 | # Strong Baseline + IDM
120 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
121 | sh scripts/run_idm_xbm.sh ${source} ${target} ${arch} ${stage} ${mu1} ${mu2} ${mu3}
122 | ```
123 | 
124 | + Defaults: `--stage 0 --mu1 0.7 --mu2 0.1 --mu3 1.0`.
125 | 
126 | **Some examples:**
127 | ```shell
128 | ### market1501 -> dukemtmc ###
129 | 
130 | # ResNet-50 + IDM
131 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
132 | sh scripts/run_idm_xbm.sh market1501 dukemtmc resnet50_idm 0 0.7 0.1 1.0
133 | 
134 | # IBN-ResNet-50 + IDM
135 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
136 | sh scripts/run_idm_xbm.sh market1501 dukemtmc resnet_ibn50a_idm 0 0.7 0.1 1.0
137 | ```
138 | 
139 | ## Evaluation
140 | 
141 | We use 1 GTX 2080 Ti GPU for testing. **Note that**
142 | 
143 | + Use `--dsbn` for domain-adaptive models, and add `--test-source` if you want to test on the source domain (see the sketch below).
144 | + Use `-a resnet50` for the ResNet-50 backbone, and `-a resnet_ibn50a` for the IBN-ResNet backbone.
145 | + Use `-a resnet50_idm` for ResNet-50 + IDM, and `-a resnet_ibn50a_idm` for IBN-ResNet + IDM.
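For background, `--dsbn` refers to domain-specific batch normalization: each BN layer keeps separate running statistics and affine parameters for the source and the target domain. A minimal sketch of the idea, with hypothetical names (not the repository's `idm/models/dsbn.py`):

```python
import torch
from torch import nn


class DSBN2d(nn.Module):
    """A BN layer with one branch per domain."""

    def __init__(self, num_features):
        super().__init__()
        self.bn_source = nn.BatchNorm2d(num_features)
        self.bn_target = nn.BatchNorm2d(num_features)

    def forward(self, x, is_target=True):
        # Route the input through the BN branch of its own domain; during
        # joint training, each batch is first split by domain.
        return self.bn_target(x) if is_target else self.bn_source(x)
```

At test time only one branch is used, which is what `convert_bn(model, use_target=...)` in `examples/test.py` selects and what the `--test-source` flag controls.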
146 | 
147 | To evaluate the **baseline model** on the **target-domain** dataset, run:
148 | ```shell
149 | CUDA_VISIBLE_DEVICES=0 \
150 | python3 examples/test.py --dsbn -d ${dataset} -a ${arch} --resume ${resume}
151 | ```
152 | 
153 | To evaluate the **baseline model** on the **source-domain** dataset, run:
154 | ```shell
155 | CUDA_VISIBLE_DEVICES=0 \
156 | python3 examples/test.py --dsbn --test-source -d ${dataset} -a ${arch} --resume ${resume}
157 | ```
158 | 
159 | To evaluate the **IDM model** on the **target-domain** dataset, run:
160 | ```shell
161 | CUDA_VISIBLE_DEVICES=0 \
162 | python3 examples/test.py --dsbn-idm -d ${dataset} -a ${arch} --resume ${resume} --stage ${stage}
163 | ```
164 | 
165 | To evaluate the **IDM model** on the **source-domain** dataset, run:
166 | ```shell
167 | CUDA_VISIBLE_DEVICES=0 \
168 | python3 examples/test.py --dsbn-idm --test-source -d ${dataset} -a ${arch} --resume ${resume} --stage ${stage}
169 | ```
170 | 
171 | 
172 | **Some examples:**
173 | ```shell
174 | ### market1501 -> dukemtmc ###
175 | 
176 | # evaluate the target domain "dukemtmc" on the strong baseline model
177 | CUDA_VISIBLE_DEVICES=0 \
178 | python3 examples/test.py --dsbn -d dukemtmc -a resnet50 \
179 | --resume logs/resnet50_strong_baseline/market1501-TO-dukemtmc/model_best.pth.tar
180 | 
181 | # evaluate the source domain "market1501" on the strong baseline model
182 | CUDA_VISIBLE_DEVICES=0 \
183 | python3 examples/test.py --dsbn --test-source -d market1501 -a resnet50 \
184 | --resume logs/resnet50_strong_baseline/market1501-TO-dukemtmc/model_best.pth.tar
185 | 
186 | # evaluate the target domain "dukemtmc" on the IDM model (after stage-0)
187 | python3 examples/test.py --dsbn-idm -d dukemtmc -a resnet50_idm \
188 | --resume logs/resnet50_idm_xbm/market1501-TO-dukemtmc/model_best.pth.tar --stage 0
189 | 
190 | # evaluate the source domain "market1501" on the IDM model (after stage-0)
191 | python3 examples/test.py --dsbn-idm --test-source -d market1501 -a resnet50_idm \
192 | --resume logs/resnet50_idm_xbm/market1501-TO-dukemtmc/model_best.pth.tar --stage 0
193 | 
194 | ```
195 | 
196 | ## Acknowledgement
197 | Our code is based on [MMT](https://github.com/yxgeee/MMT) and [SpCL](https://github.com/yxgeee/SpCL). Thanks to [Yixiao](https://geyixiao.com/) for his wonderful work.
198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /examples/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import sys 7 | 8 | import torch 9 | from torch import nn 10 | from torch.backends import cudnn 11 | from torch.utils.data import DataLoader 12 | 13 | sys.path.append(".") 14 | from idm import datasets 15 | from idm import models 16 | from idm.models.dsbn import convert_dsbn, convert_bn 17 | from idm.models.idm_dsbn import convert_dsbn_idm, convert_bn_idm 18 | from idm.evaluators import Evaluator 19 | from idm.utils.data import transforms as T 20 | from idm.utils.data.preprocessor import Preprocessor 21 | from idm.utils.logging import Logger 22 | from idm.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict 23 | 24 | 25 | def get_data(name, data_dir, height, width, batch_size, workers): 26 | root = osp.join(data_dir, name) 27 | 28 | dataset = datasets.create(name, root) 29 | 30 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 31 | std=[0.229, 0.224, 0.225]) 32 | 33 | test_transformer = T.Compose([ 34 | T.Resize((height, width), interpolation=3), 35 | T.ToTensor(), 36 | normalizer 37 | ]) 38 | 39 | test_loader = DataLoader( 40 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 41 | root=dataset.images_dir, transform=test_transformer), 42 | batch_size=batch_size, num_workers=workers, 43 | shuffle=False, pin_memory=True) 44 | 45 | return dataset, test_loader 46 | 47 | 48 | def filter_layers(stage): 49 | layer_names = ['conv', 'layer1', 'layer2', 'layer3', 'layer4', 'feat_bn'] 50 | ori_bn_names = [] 51 | idm_bn_names = [] 52 | for i in range(len(layer_names)): 53 | if i < stage+1: 54 | ori_bn_names.append(layer_names[i]) 55 | else: 56 | idm_bn_names.append(layer_names[i]) 57 | return idm_bn_names 58 | 59 | 60 | def main(): 61 | args = parser.parse_args() 62 | 63 | if args.seed is not None: 64 | random.seed(args.seed) 65 | np.random.seed(args.seed) 66 | torch.manual_seed(args.seed) 67 | cudnn.deterministic = True 68 | 69 | main_worker(args) 70 | 71 | 72 | def main_worker(args): 73 | cudnn.benchmark = True 74 | 75 | log_dir = osp.dirname(args.resume) 76 | sys.stdout = Logger(osp.join(log_dir, 'log_test.txt')) 77 | print("==========\nArgs:{}\n==========".format(args)) 78 | 79 | # Create data loaders 80 | dataset, test_loader = get_data(args.dataset, args.data_dir, args.height, 81 | args.width, args.batch_size, args.workers) 82 | 83 | # Create model 84 | model = models.create(args.arch, pretrained=False, num_features=args.features, dropout=args.dropout, num_classes=0) 85 | if args.dsbn: 86 | print("==> Load the model with domain-specific BNs") 87 | convert_dsbn(model) 88 | 89 | if args.dsbn_idm: 90 | print("==> Load the model with domain-specific BNs (IDM)") 91 | idm_bn_names = filter_layers(args.stage) 92 | convert_dsbn_idm(model, idm_bn_names, idm=False) 93 | 94 | # Load from checkpoint 95 | checkpoint = load_checkpoint(args.resume) 96 | copy_state_dict(checkpoint['state_dict'], model, strip='module.') 97 | 98 | if args.dsbn: 99 | print("==> Test with {}-domain BNs".format("source" if args.test_source else "target")) 100 | convert_bn(model, use_target=(not args.test_source)) 101 | 102 | if args.dsbn_idm: 103 | print("==> Test with {}-domain BNs".format("source" if args.test_source else "target")) 104 | 
convert_bn_idm(model, use_target=(not args.test_source)) 105 | 106 | model.cuda() 107 | model = nn.DataParallel(model) 108 | 109 | # Evaluator 110 | model.eval() 111 | evaluator = Evaluator(model) 112 | print("Test on {}:".format(args.dataset)) 113 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True, rerank=args.rerank) 114 | return 115 | 116 | if __name__ == '__main__': 117 | parser = argparse.ArgumentParser(description="Testing the model") 118 | # data 119 | parser.add_argument('-d', '--dataset', type=str, required=True, 120 | choices=datasets.names()) 121 | parser.add_argument('-b', '--batch-size', type=int, default=256) 122 | parser.add_argument('-j', '--workers', type=int, default=4) 123 | parser.add_argument('--height', type=int, default=256, help="input height") 124 | parser.add_argument('--width', type=int, default=128, help="input width") 125 | # model 126 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 127 | choices=models.names()) 128 | parser.add_argument('--features', type=int, default=0) 129 | parser.add_argument('--dropout', type=float, default=0) 130 | parser.add_argument('--resume', type=str, required=True, metavar='PATH') 131 | 132 | # idm parameters 133 | parser.add_argument('--stage', type=int, default=0, 134 | help="insert IDM module after stage 0/1/2/3/4") 135 | # testing configs 136 | parser.add_argument('--rerank', action='store_true', 137 | help="evaluation only") 138 | parser.add_argument('--dsbn', action='store_true', 139 | help="test on the model with domain-specific BN") 140 | parser.add_argument('--dsbn-idm', action='store_true', 141 | help="test on the model with domain-specific BN (IDM)") 142 | parser.add_argument('--test-source', action='store_true', 143 | help="test on the source domain") 144 | parser.add_argument('--seed', type=int, default=1) 145 | # path 146 | working_dir = osp.dirname(osp.abspath(__file__)) 147 | parser.add_argument('--data-dir', type=str, metavar='PATH', 148 | default=osp.join(working_dir, 'data')) 149 | main() 150 | -------------------------------------------------------------------------------- /examples/train_baseline.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import sys 7 | import collections 8 | import copy 9 | import time 10 | from datetime import timedelta 11 | 12 | from sklearn.cluster import DBSCAN, KMeans 13 | from sklearn.preprocessing import normalize 14 | 15 | 16 | import torch 17 | from torch import nn 18 | from torch.backends import cudnn 19 | from torch.utils.data import DataLoader 20 | import torch.nn.functional as F 21 | 22 | sys.path.append(".") 23 | from idm import datasets 24 | from idm import models 25 | from idm.models.dsbn import convert_dsbn, convert_bn 26 | from idm.models.idm_dsbn import convert_dsbn_idm, convert_bn_idm 27 | from idm.models.xbm import XBM 28 | from idm.trainers import Baseline_Trainer, IDM_Trainer 29 | from idm.evaluators import Evaluator, extract_features 30 | from idm.utils.data import IterLoader 31 | from idm.utils.data import transforms as T 32 | from idm.utils.data.sampler import RandomMultipleGallerySampler 33 | from idm.utils.data.preprocessor import Preprocessor 34 | from idm.utils.logging import Logger 35 | from idm.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict 36 | from idm.utils.rerank import 
compute_jaccard_distance 37 | 38 | 39 | start_epoch = best_mAP = 0 40 | 41 | def get_data(name, data_dir): 42 | root = osp.join(data_dir, name) 43 | dataset = datasets.create(name, root) 44 | return dataset 45 | 46 | def get_train_loader(args, dataset, height, width, batch_size, workers, 47 | num_instances, iters, trainset=None): 48 | 49 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 50 | std=[0.229, 0.224, 0.225]) 51 | train_transformer = T.Compose([ 52 | T.Resize((height, width), interpolation=3), 53 | T.RandomHorizontalFlip(p=0.5), 54 | T.Pad(10), 55 | T.RandomCrop((height, width)), 56 | T.ToTensor(), 57 | normalizer, 58 | T.RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406]) 59 | ]) 60 | 61 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 62 | rmgs_flag = num_instances > 0 63 | if rmgs_flag: 64 | sampler = RandomMultipleGallerySampler(train_set, num_instances) 65 | else: 66 | sampler = None 67 | train_loader = IterLoader( 68 | DataLoader(Preprocessor(train_set, root=dataset.images_dir, transform=train_transformer), 69 | batch_size=batch_size, num_workers=workers, sampler=sampler, 70 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters) 71 | 72 | return train_loader 73 | 74 | 75 | def get_test_loader(dataset, height, width, batch_size, workers, testset=None): 76 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 77 | std=[0.229, 0.224, 0.225]) 78 | 79 | test_transformer = T.Compose([ 80 | T.Resize((height, width), interpolation=3), 81 | T.ToTensor(), 82 | normalizer 83 | ]) 84 | 85 | if (testset is None): 86 | testset = list(set(dataset.query) | set(dataset.gallery)) 87 | 88 | test_loader = DataLoader( 89 | Preprocessor(testset, root=dataset.images_dir, transform=test_transformer), 90 | batch_size=batch_size, num_workers=workers, 91 | shuffle=False, pin_memory=True) 92 | 93 | return test_loader 94 | 95 | 96 | def create_model(args): 97 | model = models.create(args.arch, num_features=args.features, norm=False, dropout=args.dropout, 98 | num_classes=args.nclass) 99 | 100 | convert_dsbn(model) 101 | 102 | # use CUDA 103 | model.cuda() 104 | model = nn.DataParallel(model) 105 | return model 106 | 107 | 108 | def main(): 109 | args = parser.parse_args() 110 | 111 | if args.seed is not None: 112 | random.seed(args.seed) 113 | np.random.seed(args.seed) 114 | torch.manual_seed(args.seed) 115 | cudnn.deterministic = True 116 | 117 | main_worker(args) 118 | 119 | 120 | def main_worker(args): 121 | global start_epoch, best_mAP 122 | start_time = time.monotonic() 123 | 124 | cudnn.benchmark = True 125 | 126 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 127 | print("==========\nArgs:{}\n==========".format(args)) 128 | 129 | # Create datasets 130 | iters = args.iters if (args.iters>0) else None 131 | print("==> Load source-domain dataset") 132 | dataset_source = get_data(args.dataset_source, args.data_dir) 133 | print("==> Load target-domain dataset") 134 | dataset_target = get_data(args.dataset_target, args.data_dir) 135 | test_loader_target = get_test_loader(dataset_target, args.height, args.width, args.batch_size, args.workers) 136 | train_loader_source = get_train_loader(args, dataset_source, args.height, args.width, 137 | args.batch_size, args.workers, args.num_instances, iters) 138 | 139 | source_classes = dataset_source.num_train_pids 140 | 141 | args.nclass = source_classes+len(dataset_target.train) 142 | args.s_class = source_classes 143 | args.t_class = len(dataset_target.train) 144 | 145 | # Create model 146 | 
model = create_model(args) 147 | print(model) 148 | 149 | # Create XBM 150 | 151 | datasetSize = len(dataset_source.train)+len(dataset_target.train) 152 | 153 | args.memorySize = int(args.ratio*datasetSize) 154 | xbm = XBM(args.memorySize, args.featureSize) 155 | print('XBM memory size = ', args.memorySize) 156 | # Initialize source-domain class centroids 157 | sour_cluster_loader = get_test_loader(dataset_source, args.height, args.width, 158 | args.batch_size, args.workers, testset=sorted(dataset_source.train)) 159 | source_features, _ = extract_features(model, sour_cluster_loader, print_freq=50) 160 | sour_fea_dict = collections.defaultdict(list) 161 | for f, pid, _ in sorted(dataset_source.train): 162 | sour_fea_dict[pid].append(source_features[f].unsqueeze(0)) 163 | source_centers = [torch.cat(sour_fea_dict[pid],0).mean(0) for pid in sorted(sour_fea_dict.keys())] 164 | source_centers = torch.stack(source_centers,0) 165 | source_centers = F.normalize(source_centers, dim=1) 166 | model.module.classifier.weight.data[0:source_classes].copy_(source_centers.cuda()) 167 | 168 | del source_centers, sour_cluster_loader, sour_fea_dict 169 | 170 | # Evaluator 171 | evaluator = Evaluator(model) 172 | 173 | # Optimizer 174 | params = [{"params": [value]} for _, value in model.named_parameters() if value.requires_grad] 175 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) 176 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1) 177 | 178 | # Trainer 179 | trainer = Baseline_Trainer(model, xbm, args.nclass, margin=args.margin) 180 | 181 | for epoch in range(args.epochs): 182 | 183 | tgt_cluster_loader = get_test_loader(dataset_target, args.height, args.width, 184 | args.batch_size, args.workers, testset=sorted(dataset_target.train)) 185 | target_features, _ = extract_features(model, tgt_cluster_loader, print_freq=50) 186 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in 187 | sorted(dataset_target.train)], 0) 188 | 189 | del tgt_cluster_loader 190 | print('==> Create pseudo labels for unlabeled target domain with DBSCAN clustering') 191 | 192 | rerank_dist = compute_jaccard_distance(target_features, k1=args.k1, k2=args.k2, use_gpu=False).numpy() 193 | print('Clustering and labeling...') 194 | eps = args.eps 195 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=-1) 196 | labels = cluster.fit_predict(rerank_dist) 197 | num_ids = len(set(labels)) - (1 if -1 in labels else 0) 198 | args.t_class = num_ids 199 | 200 | print('\n Clustered into {} classes \n'.format(args.t_class)) 201 | 202 | 203 | # generate new dataset and calculate cluster centers 204 | new_dataset = [] 205 | cluster_centers = collections.defaultdict(list) 206 | for i, ((fname, _, cid), label) in enumerate(zip(sorted(dataset_target.train), labels)): 207 | if label == -1: continue 208 | new_dataset.append((fname, source_classes+label, cid)) 209 | cluster_centers[label].append(target_features[i]) 210 | 211 | cluster_centers = [torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys())] 212 | cluster_centers = torch.stack(cluster_centers) 213 | model.module.classifier.weight.data[args.s_class:args.s_class+args.t_class].copy_(F.normalize(cluster_centers, dim=1).float().cuda()) 214 | 215 | del cluster_centers, target_features 216 | 217 | train_loader_target = get_train_loader(args, dataset_target, args.height, args.width, 218 | args.batch_size, args.workers, args.num_instances, iters, 219 | 
trainset=new_dataset) 220 | 221 | train_loader_source.new_epoch() 222 | train_loader_target.new_epoch() 223 | trainer.train(epoch, train_loader_source, train_loader_target, args.s_class, args.t_class, optimizer, 224 | print_freq=args.print_freq, train_iters=args.iters, use_xbm=args.use_xbm) 225 | 226 | if ((epoch+1)%args.eval_step==0 or (epoch==args.epochs-1)): 227 | 228 | print('Test on target: ', args.dataset_target) 229 | _, mAP = evaluator.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True) 230 | is_best = (mAP>best_mAP) 231 | best_mAP = max(mAP, best_mAP) 232 | save_checkpoint({ 233 | 'state_dict': model.state_dict(), 234 | 'epoch': epoch + 1, 235 | 'best_mAP': best_mAP, 236 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 237 | 238 | print('\n * Finished epoch {:3d} model mAP: {:5.1%} best: {:5.1%}{}\n'. 239 | format(epoch, mAP, best_mAP, ' *' if is_best else '')) 240 | 241 | lr_scheduler.step() 242 | 243 | 244 | print ('==> Test with the best model on the target domain:') 245 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar')) 246 | model.load_state_dict(checkpoint['state_dict']) 247 | evaluator.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True) 248 | 249 | end_time = time.monotonic() 250 | print('Total running time: ', timedelta(seconds=end_time - start_time)) 251 | 252 | if __name__ == '__main__': 253 | parser = argparse.ArgumentParser(description="Self-paced contrastive learning on UDA re-ID") 254 | # data 255 | parser.add_argument('-ds', '--dataset-source', type=str, default='dukemtmc') 256 | parser.add_argument('-dt', '--dataset-target', type=str, default='market1501') 257 | parser.add_argument('-b', '--batch-size', type=int, default=64) 258 | parser.add_argument('-j', '--workers', type=int, default=4) 259 | parser.add_argument('--height', type=int, default=256, help="input height") 260 | parser.add_argument('--width', type=int, default=128, help="input width") 261 | parser.add_argument('--num-instances', type=int, default=4, 262 | help="each minibatch consist of " 263 | "(batch_size // num_instances) identities, and " 264 | "each identity has num_instances instances, " 265 | "default: 0 (NOT USE)") 266 | # cluster 267 | parser.add_argument('--eps', type=float, default=0.6, 268 | help="max neighbor distance for DBSCAN") 269 | parser.add_argument('--k1', type=int, default=30, 270 | help="hyperparameter for jaccard distance") 271 | parser.add_argument('--k2', type=int, default=6, 272 | help="hyperparameter for jaccard distance") 273 | parser.add_argument('--nclass', type=int, default=1000, 274 | help="number of classes (source+target)") 275 | parser.add_argument('--s-class', type=int, default=1000, 276 | help="number of classes (source)") 277 | parser.add_argument('--t-class', type=int, default=1000, 278 | help="number of classes (target)") 279 | # loss 280 | parser.add_argument('--margin', type=float, default=0.3, 281 | help="margin for triplet loss") 282 | parser.add_argument('--mu1', type=float, default=0.5, 283 | help="weight for loss_bridge_pred") 284 | parser.add_argument('--mu2', type=float, default=0.1, 285 | help="weight for loss_bridge_feat") 286 | parser.add_argument('--mu3', type=float, default=1, 287 | help="weight for loss_div") 288 | 289 | # model 290 | parser.add_argument('-a', '--arch', type=str, default='resnet50_idm', 291 | choices=models.names()) 292 | parser.add_argument('--features', type=int, default=0) 293 | 
parser.add_argument('--dropout', type=float, default=0) 294 | 295 | # xbm parameters 296 | parser.add_argument('--memorySize', type=int, default=8192, 297 | help='meomory bank size') 298 | parser.add_argument('--ratio', type=float, default=1, 299 | help='memorySize=ratio*data_size') 300 | parser.add_argument('--featureSize', type=int, default=2048) 301 | parser.add_argument('--use-xbm', action='store_true',help="if True: strong baseline; if False: naive baseline") 302 | 303 | # optimizer 304 | parser.add_argument('--lr', type=float, default=0.00035, 305 | help="learning rate") 306 | parser.add_argument('--weight-decay', type=float, default=5e-4) 307 | parser.add_argument('--epochs', type=int, default=50) 308 | parser.add_argument('--iters', type=int, default=400) 309 | parser.add_argument('--step-size', type=int, default=30) 310 | # training configs 311 | parser.add_argument('--seed', type=int, default=1) 312 | parser.add_argument('--print-freq', type=int, default=50) 313 | parser.add_argument('--eval-step', type=int, default=10) 314 | 315 | # path 316 | working_dir = osp.dirname(osp.abspath(__file__)) 317 | parser.add_argument('--data-dir', type=str, metavar='PATH', 318 | default=osp.join(working_dir, 'data')) 319 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 320 | default=osp.join(working_dir, 'logs')) 321 | main() 322 | 323 | -------------------------------------------------------------------------------- /examples/train_idm.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import sys 7 | import collections 8 | import time 9 | from datetime import timedelta 10 | from sklearn.cluster import DBSCAN 11 | import torch 12 | from torch import nn 13 | from torch.backends import cudnn 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | 17 | sys.path.append(".") 18 | from idm import datasets 19 | from idm import models 20 | from idm.models.idm_dsbn import convert_dsbn_idm, convert_bn_idm 21 | from idm.models.xbm import XBM 22 | from idm.trainers import Baseline_Trainer, IDM_Trainer 23 | from idm.evaluators import Evaluator, extract_features 24 | from idm.utils.data import IterLoader 25 | from idm.utils.data import transforms as T 26 | from idm.utils.data.sampler import RandomMultipleGallerySampler 27 | from idm.utils.data.preprocessor import Preprocessor 28 | from idm.utils.logging import Logger 29 | from idm.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict 30 | from idm.utils.rerank import compute_jaccard_distance 31 | 32 | 33 | start_epoch = best_mAP = 0 34 | 35 | def get_data(name, data_dir): 36 | root = osp.join(data_dir, name) 37 | dataset = datasets.create(name, root) 38 | return dataset 39 | 40 | def get_train_loader(args, dataset, height, width, batch_size, workers, 41 | num_instances, iters, trainset=None): 42 | 43 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 44 | std=[0.229, 0.224, 0.225]) 45 | train_transformer = T.Compose([ 46 | T.Resize((height, width), interpolation=3), 47 | T.RandomHorizontalFlip(p=0.5), 48 | T.Pad(10), 49 | T.RandomCrop((height, width)), 50 | T.ToTensor(), 51 | normalizer, 52 | T.RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406]) 53 | ]) 54 | 55 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 56 | rmgs_flag = num_instances > 0 57 | if rmgs_flag: 58 | 
sampler = RandomMultipleGallerySampler(train_set, num_instances) 59 | else: 60 | sampler = None 61 | train_loader = IterLoader( 62 | DataLoader(Preprocessor(train_set, root=dataset.images_dir, transform=train_transformer), 63 | batch_size=batch_size, num_workers=workers, sampler=sampler, 64 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters) 65 | 66 | return train_loader 67 | 68 | def get_test_loader(dataset, height, width, batch_size, workers, testset=None): 69 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 70 | std=[0.229, 0.224, 0.225]) 71 | 72 | test_transformer = T.Compose([ 73 | T.Resize((height, width), interpolation=3), 74 | T.ToTensor(), 75 | normalizer 76 | ]) 77 | 78 | if (testset is None): 79 | testset = list(set(dataset.query) | set(dataset.gallery)) 80 | 81 | test_loader = DataLoader( 82 | Preprocessor(testset, root=dataset.images_dir, transform=test_transformer), 83 | batch_size=batch_size, num_workers=workers, 84 | shuffle=False, pin_memory=True) 85 | 86 | return test_loader 87 | 88 | def filter_layers(stage): 89 | layer_names = ['conv', 'layer1', 'layer2', 'layer3', 'layer4', 'feat_bn'] 90 | ori_bn_names = [] 91 | idm_bn_names = [] 92 | for i in range(len(layer_names)): 93 | if i < stage+1: 94 | ori_bn_names.append(layer_names[i]) 95 | else: 96 | idm_bn_names.append(layer_names[i]) 97 | return idm_bn_names 98 | 99 | def create_model(args): 100 | model = models.create(args.arch, num_features=args.features, norm=False, dropout=args.dropout, 101 | num_classes=args.nclass) 102 | 103 | idm_bn_names = filter_layers(args.stage) 104 | convert_dsbn_idm(model, idm_bn_names, idm=False) 105 | 106 | # use CUDA 107 | model.cuda() 108 | model = nn.DataParallel(model) 109 | return model 110 | 111 | 112 | def main(): 113 | args = parser.parse_args() 114 | 115 | if args.seed is not None: 116 | random.seed(args.seed) 117 | np.random.seed(args.seed) 118 | torch.manual_seed(args.seed) 119 | cudnn.deterministic = True 120 | 121 | main_worker(args) 122 | 123 | 124 | def main_worker(args): 125 | global start_epoch, best_mAP 126 | start_time = time.monotonic() 127 | 128 | cudnn.benchmark = True 129 | 130 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 131 | print("==========\nArgs:{}\n==========".format(args)) 132 | 133 | # Create datasets 134 | iters = args.iters if (args.iters>0) else None 135 | print("==> Load source-domain dataset") 136 | dataset_source = get_data(args.dataset_source, args.data_dir) 137 | print("==> Load target-domain dataset") 138 | dataset_target = get_data(args.dataset_target, args.data_dir) 139 | test_loader_target = get_test_loader(dataset_target, args.height, args.width, args.batch_size, args.workers) 140 | train_loader_source = get_train_loader(args, dataset_source, args.height, args.width, 141 | args.batch_size, args.workers, args.num_instances, iters) 142 | 143 | source_classes = dataset_source.num_train_pids 144 | args.nclass = source_classes+len(dataset_target.train) 145 | args.s_class = source_classes 146 | args.t_class = len(dataset_target.train) 147 | 148 | # Create model 149 | model = create_model(args) 150 | print(model) 151 | 152 | # Create XBM 153 | 154 | datasetSize = len(dataset_source.train)+len(dataset_target.train) 155 | 156 | args.memorySize = int(args.ratio*datasetSize) 157 | xbm = XBM(args.memorySize, args.featureSize) 158 | print('XBM memory size = ', args.memorySize) 159 | # Initialize source-domain class centroids 160 | sour_cluster_loader = get_test_loader(dataset_source, args.height, args.width, 161 | 
args.batch_size, args.workers, testset=sorted(dataset_source.train)) 162 | source_features, _ = extract_features(model, sour_cluster_loader, print_freq=50) 163 | sour_fea_dict = collections.defaultdict(list) 164 | for f, pid, _ in sorted(dataset_source.train): 165 | sour_fea_dict[pid].append(source_features[f].unsqueeze(0)) 166 | source_centers = [torch.cat(sour_fea_dict[pid],0).mean(0) for pid in sorted(sour_fea_dict.keys())] 167 | source_centers = torch.stack(source_centers,0) 168 | source_centers = F.normalize(source_centers, dim=1) 169 | model.module.classifier.weight.data[0:source_classes].copy_(source_centers.cuda()) 170 | 171 | del source_centers, sour_cluster_loader, sour_fea_dict 172 | 173 | # Evaluator 174 | evaluator = Evaluator(model) 175 | 176 | # Optimizer 177 | params = [{"params": [value]} for _, value in model.named_parameters() if value.requires_grad] 178 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) 179 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1) 180 | 181 | # Trainer 182 | trainer = IDM_Trainer(model, xbm, args.nclass, margin=args.margin, mu1=args.mu1, mu2=args.mu2, mu3=args.mu3) 183 | 184 | for epoch in range(args.epochs): 185 | with torch.no_grad(): 186 | tgt_cluster_loader = get_test_loader(dataset_target, args.height, args.width, 187 | args.batch_size, args.workers, testset=sorted(dataset_target.train)) 188 | time.sleep(0.5) 189 | target_features, _ = extract_features(model, tgt_cluster_loader, print_freq=50) 190 | target_features = torch.cat([target_features[f].unsqueeze(0) for f, _, _ in 191 | sorted(dataset_target.train)], 0) 192 | 193 | del tgt_cluster_loader 194 | print('==> Create pseudo labels for unlabeled target domain with DBSCAN clustering') 195 | 196 | rerank_dist = compute_jaccard_distance(target_features, k1=args.k1, k2=args.k2, use_gpu=False).numpy() 197 | print('Clustering and labeling...') 198 | eps = args.eps 199 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=-1) 200 | labels = cluster.fit_predict(rerank_dist) 201 | del rerank_dist 202 | num_ids = len(set(labels)) - (1 if -1 in labels else 0) 203 | args.t_class = num_ids 204 | 205 | print('\n Clustered into {} classes \n'.format(args.t_class)) 206 | 207 | 208 | # generate new dataset and calculate cluster centers 209 | new_dataset = [] 210 | cluster_centers = collections.defaultdict(list) 211 | for i, ((fname, _, cid), label) in enumerate(zip(sorted(dataset_target.train), labels)): 212 | if label == -1: continue 213 | new_dataset.append((fname, source_classes+label, cid)) 214 | cluster_centers[label].append(target_features[i]) 215 | 216 | # dataset_target.train = new_dataset 217 | 218 | cluster_centers = [torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys())] 219 | cluster_centers = torch.stack(cluster_centers) 220 | model.module.classifier.weight.data[args.s_class:args.s_class+args.t_class].copy_(F.normalize(cluster_centers, dim=1).float().cuda()) 221 | 222 | del cluster_centers, target_features 223 | 224 | train_loader_target = get_train_loader(args, dataset_target, args.height, args.width, 225 | args.batch_size, args.workers, args.num_instances, iters, 226 | trainset=new_dataset) 227 | 228 | time.sleep(0.5) 229 | train_loader_source.new_epoch() 230 | time.sleep(0.5) 231 | train_loader_target.new_epoch() 232 | time.sleep(0.5) 233 | trainer.train(epoch, train_loader_source, train_loader_target, args.s_class, args.t_class, optimizer, 234 | 
print_freq=args.print_freq, train_iters=args.iters, use_xbm=args.use_xbm, stage=args.stage)
235 | 
236 | if ((epoch+1)%args.eval_step==0 or (epoch==args.epochs-1)):
237 | 
238 | print('Test on target: ', args.dataset_target)
239 | _, mAP = evaluator.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True)
240 | is_best = (mAP>best_mAP)
241 | best_mAP = max(mAP, best_mAP)
242 | save_checkpoint({
243 | 'state_dict': model.state_dict(),
244 | 'epoch': epoch + 1,
245 | 'best_mAP': best_mAP,
246 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
247 | 
248 | print('\n * Finished epoch {:3d} model mAP: {:5.1%} best: {:5.1%}{}\n'.
249 | format(epoch, mAP, best_mAP, ' *' if is_best else ''))
250 | 
251 | lr_scheduler.step()
252 | 
253 | print('==> Test with the best model on the target domain:')
254 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
255 | model.load_state_dict(checkpoint['state_dict'])
256 | evaluator.evaluate(test_loader_target, dataset_target.query, dataset_target.gallery, cmc_flag=True)
257 | 
258 | end_time = time.monotonic()
259 | print('Total running time: ', timedelta(seconds=end_time - start_time))
260 | 
261 | if __name__ == '__main__':
262 | parser = argparse.ArgumentParser(description="Self-paced contrastive learning on UDA re-ID")
263 | # data
264 | parser.add_argument('-ds', '--dataset-source', type=str, default='dukemtmc')
265 | parser.add_argument('-dt', '--dataset-target', type=str, default='market1501')
266 | parser.add_argument('-b', '--batch-size', type=int, default=64)
267 | parser.add_argument('-j', '--workers', type=int, default=4)
268 | parser.add_argument('--height', type=int, default=256, help="input height")
269 | parser.add_argument('--width', type=int, default=128, help="input width")
270 | parser.add_argument('--num-instances', type=int, default=4,
271 | help="each minibatch consists of "
272 | "(batch_size // num_instances) identities, and "
273 | "each identity has num_instances instances, "
274 | "default: 0 (not used)")
275 | # cluster
276 | parser.add_argument('--eps', type=float, default=0.6,
277 | help="max neighbor distance for DBSCAN")
278 | parser.add_argument('--k1', type=int, default=30,
279 | help="hyperparameter for jaccard distance")
280 | parser.add_argument('--k2', type=int, default=6,
281 | help="hyperparameter for jaccard distance")
282 | parser.add_argument('--nclass', type=int, default=1000,
283 | help="number of classes (source+target)")
284 | parser.add_argument('--s-class', type=int, default=1000,
285 | help="number of classes (source)")
286 | parser.add_argument('--t-class', type=int, default=1000,
287 | help="number of classes (target)")
288 | # loss
289 | parser.add_argument('--margin', type=float, default=0.3,
290 | help="margin for triplet loss")
291 | parser.add_argument('--mu1', type=float, default=0.7,
292 | help="weight for loss_bridge_pred")
293 | parser.add_argument('--mu2', type=float, default=0.1,
294 | help="weight for loss_bridge_feat")
295 | parser.add_argument('--mu3', type=float, default=1,
296 | help="weight for loss_div")
297 | 
298 | # model
299 | parser.add_argument('-a', '--arch', type=str, default='resnet50_idm',
300 | choices=models.names())
301 | parser.add_argument('--features', type=int, default=0)
302 | parser.add_argument('--dropout', type=float, default=0)
303 | 
304 | # xbm parameters
305 | parser.add_argument('--memorySize', type=int, default=8192,
306 | help='memory bank size')
307 | parser.add_argument('--ratio', type=float,
default=1, 308 | help='memorySize=ratio*data_size') 309 | parser.add_argument('--featureSize', type=int, default=2048) 310 | parser.add_argument('--use-xbm', action='store_true', 311 | help="if True: strong baseline; if False: naive baseline") 312 | 313 | # idm parameters 314 | parser.add_argument('--stage', type=int, default=0, 315 | help="insert IDM module after stage 0/1/2/3/4") 316 | # optimizer 317 | parser.add_argument('--lr', type=float, default=0.00035, 318 | help="learning rate") 319 | parser.add_argument('--weight-decay', type=float, default=5e-4) 320 | parser.add_argument('--epochs', type=int, default=50) 321 | parser.add_argument('--iters', type=int, default=400) 322 | parser.add_argument('--step-size', type=int, default=20) 323 | # training configs 324 | parser.add_argument('--seed', type=int, default=1) 325 | parser.add_argument('--print-freq', type=int, default=50) 326 | parser.add_argument('--eval-step', type=int, default=1) 327 | 328 | # path 329 | working_dir = osp.dirname(osp.abspath(__file__)) 330 | parser.add_argument('--data-dir', type=str, metavar='PATH', 331 | default=osp.join(working_dir, 'data')) 332 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 333 | default=osp.join(working_dir, 'logs')) 334 | main() 335 | -------------------------------------------------------------------------------- /idm/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import models 6 | from . import utils 7 | from . import evaluators 8 | from . import trainers 9 | 10 | -------------------------------------------------------------------------------- /idm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .dukemtmc import DukeMTMC 5 | from .market1501 import Market1501 6 | from .msmt17 import MSMT17 7 | from .personx import PersonX 8 | from .unreal import UnrealPerson 9 | 10 | 11 | __factory = { 12 | 'market1501': Market1501, 13 | 'dukemtmc': DukeMTMC, 14 | 'msmt17': MSMT17, 15 | 'personx': PersonX, 16 | 'unreal': UnrealPerson 17 | } 18 | 19 | 20 | def names(): 21 | return sorted(__factory.keys()) 22 | 23 | 24 | def create(name, root, *args, **kwargs): 25 | """ 26 | Create a dataset instance. 27 | 28 | Parameters 29 | ---------- 30 | name : str 31 | The dataset name. 32 | root : str 33 | The path to the dataset directory. 34 | split_id : int, optional 35 | The index of data split. Default: 0 36 | num_val : int or float, optional 37 | When int, it means the number of validation identities. When float, 38 | it means the proportion of validation to all the trainval. Default: 100 39 | download : bool, optional 40 | If True, will download the dataset. Default: False 41 | """ 42 | if name not in __factory: 43 | raise KeyError("Unknown dataset:", name) 44 | return __factory[name](root, *args, **kwargs) 45 | 46 | 47 | def get_dataset(name, root, *args, **kwargs): 48 | warnings.warn("get_dataset is deprecated. 
Use create instead.") 49 | return create(name, root, *args, **kwargs) 50 | -------------------------------------------------------------------------------- /idm/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | 13 | class DukeMTMC(BaseImageDataset): 14 | """ 15 | DukeMTMC-reID 16 | Reference: 17 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 18 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 19 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 20 | 21 | Dataset statistics: 22 | # identities: 1404 (train + query) 23 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 24 | # cameras: 8 25 | """ 26 | dataset_dir = '.' 27 | 28 | def __init__(self, root, verbose=True, **kwargs): 29 | super(DukeMTMC, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 32 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 35 | 36 | self._download_data() 37 | self._check_before_run() 38 | 39 | train = self._process_dir(self.train_dir, relabel=True) 40 | query = self._process_dir(self.query_dir, relabel=False) 41 | gallery = self._process_dir(self.gallery_dir, relabel=False) 42 | 43 | if verbose: 44 | print("=> DukeMTMC-reID loaded") 45 | self.print_dataset_statistics(train, query, gallery) 46 | 47 | self.train = train 48 | self.query = query 49 | self.gallery = gallery 50 | 51 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 52 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 53 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 54 | 55 | def _download_data(self): 56 | if osp.exists(self.dataset_dir): 57 | print("This dataset has been downloaded.") 58 | return 59 | 60 | print("Creating directory {}".format(self.dataset_dir)) 61 | mkdir_if_missing(self.dataset_dir) 62 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 63 | 64 | print("Downloading DukeMTMC-reID dataset") 65 | urllib.request.urlretrieve(self.dataset_url, fpath) 66 | 67 | print("Extracting files") 68 | zip_ref = zipfile.ZipFile(fpath, 'r') 69 | zip_ref.extractall(self.dataset_dir) 70 | zip_ref.close() 71 | 72 | def _check_before_run(self): 73 | """Check if all files are available before going deeper""" 74 | if not osp.exists(self.dataset_dir): 75 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 76 | if not osp.exists(self.train_dir): 77 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 78 | if not osp.exists(self.query_dir): 79 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 80 | if not osp.exists(self.gallery_dir): 81 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 82 | 83 | def 
_process_dir(self, dir_path, relabel=False): 84 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 85 | pattern = re.compile(r'([-\d]+)_c(\d)') 86 | 87 | pid_container = set() 88 | for img_path in img_paths: 89 | pid, _ = map(int, pattern.search(img_path).groups()) 90 | pid_container.add(pid) 91 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 92 | 93 | dataset = [] 94 | for img_path in img_paths: 95 | pid, camid = map(int, pattern.search(img_path).groups()) 96 | assert 1 <= camid <= 8 97 | camid -= 1 # index starts from 0 98 | if relabel: pid = pid2label[pid] 99 | dataset.append((img_path, pid, camid)) 100 | 101 | return dataset 102 | -------------------------------------------------------------------------------- /idm/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | class Market1501(BaseImageDataset): 13 | """ 14 | Market1501 15 | Reference: 16 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 17 | URL: http://www.liangzheng.org/Project/project_reid.html 18 | 19 | Dataset statistics: 20 | # identities: 1501 (+1 for background) 21 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 22 | """ 23 | dataset_dir = 'Market-1501-v15.09.15' 24 | 25 | def __init__(self, root, verbose=True, **kwargs): 26 | super(Market1501, self).__init__() 27 | self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 31 | 32 | self._check_before_run() 33 | 34 | train = self._process_dir(self.train_dir, relabel=True) 35 | query = self._process_dir(self.query_dir, relabel=False) 36 | gallery = self._process_dir(self.gallery_dir, relabel=False) 37 | 38 | if verbose: 39 | print("=> Market1501 loaded") 40 | self.print_dataset_statistics(train, query, gallery) 41 | 42 | self.train = train 43 | self.query = query 44 | self.gallery = gallery 45 | 46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 49 | 50 | def _check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 56 | if not osp.exists(self.query_dir): 57 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 58 | if not osp.exists(self.gallery_dir): 59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 60 | 61 | def _process_dir(self, dir_path, relabel=False): 62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 63 | pattern = re.compile(r'([-\d]+)_c(\d)') 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: continue # 
junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: continue # junk images are just ignored 76 | assert 0 <= pid <= 1501 # pid == 0 means background 77 | assert 1 <= camid <= 6 78 | camid -= 1 # index starts from 0 79 | if relabel: pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /idm/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import tarfile 4 | 5 | import glob 6 | import re 7 | import urllib 8 | import zipfile 9 | 10 | from ..utils.osutils import mkdir_if_missing 11 | from ..utils.serialization import write_json 12 | 13 | 14 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')): 15 | with open(list_file, 'r') as f: 16 | lines = f.readlines() 17 | ret = [] 18 | pids = [] 19 | for line in lines: 20 | line = line.strip() 21 | fname = line.split(' ')[0] 22 | pid, _, cam = map(int, pattern.search(osp.basename(fname)).groups()) 23 | if pid not in pids: 24 | pids.append(pid) 25 | ret.append((osp.join(subdir,fname), pid, cam)) 26 | return ret, pids 27 | 28 | class Dataset_MSMT(object): 29 | def __init__(self, root): 30 | self.root = root 31 | self.train, self.val, self.trainval = [], [], [] 32 | self.query, self.gallery = [], [] 33 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 34 | 35 | @property 36 | def images_dir(self): 37 | return osp.join(self.root, 'MSMT17_V1') 38 | 39 | def load(self, verbose=True): 40 | exdir = osp.join(self.root, 'MSMT17_V1') 41 | self.train, train_pids = _pluck_msmt(osp.join(exdir, 'list_train.txt'), 'train') 42 | self.val, val_pids = _pluck_msmt(osp.join(exdir, 'list_val.txt'), 'train') 43 | self.train = self.train + self.val 44 | self.query, query_pids = _pluck_msmt(osp.join(exdir, 'list_query.txt'), 'test') 45 | self.gallery, gallery_pids = _pluck_msmt(osp.join(exdir, 'list_gallery.txt'), 'test') 46 | self.num_train_pids = len(list(set(train_pids).union(set(val_pids)))) 47 | 48 | if verbose: 49 | print(self.__class__.__name__, "dataset loaded") 50 | print(" subset | # ids | # images") 51 | print(" ---------------------------") 52 | print(" train | {:5d} | {:8d}" 53 | .format(self.num_train_pids, len(self.train))) 54 | print(" query | {:5d} | {:8d}" 55 | .format(len(query_pids), len(self.query))) 56 | print(" gallery | {:5d} | {:8d}" 57 | .format(len(gallery_pids), len(self.gallery))) 58 | 59 | class MSMT17(Dataset_MSMT): 60 | 61 | def __init__(self, root, split_id=0, download=True): 62 | super(MSMT17, self).__init__(root) 63 | 64 | if download: 65 | self.download() 66 | 67 | self.load() 68 | 69 | def download(self): 70 | 71 | import re 72 | import hashlib 73 | import shutil 74 | from glob import glob 75 | from zipfile import ZipFile 76 | 77 | raw_dir = osp.join(self.root) 78 | mkdir_if_missing(raw_dir) 79 | 80 | # Download the raw zip file 81 | fpath = osp.join(raw_dir, 'MSMT17_V1') 82 | if osp.isdir(fpath): 83 | print("Using downloaded file: " + fpath) 84 | else: 85 | raise RuntimeError("Please download the dataset manually to {}".format(fpath)) 86 | 
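All of the dataset classes in this package are registered in the `__factory` dict of `idm/datasets/__init__.py` and instantiated through `datasets.create(name, root)`; each sample is an `(img_path, pid, camid)` tuple. A usage sketch (the root path is an assumption based on the layout described in the README):

```python
from idm import datasets

# 'msmt17' is one of the keys listed by datasets.names(); the root should
# point at the dataset folder prepared under examples/data.
dataset = datasets.create('msmt17', 'examples/data/msmt17')

img_path, pid, camid = dataset.train[0]  # samples are (path, pid, camid) tuples
print(dataset.num_train_pids, len(dataset.train), len(dataset.query), len(dataset.gallery))
```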
-------------------------------------------------------------------------------- /idm/datasets/personx.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | class PersonX(BaseImageDataset): 13 | """ 14 | PersonX 15 | Reference: 16 | Sun et al. Dissecting Person Re-identification from the Viewpoint of Viewpoint. CVPR 2019. 17 | 18 | Dataset statistics: 19 | # identities: 1266 20 | # images: 9840 (train) + 5136 (query) + 30816 (gallery) 21 | """ 22 | dataset_dir = 'PersonX' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(PersonX, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 30 | 31 | self._check_before_run() 32 | 33 | train = self._process_dir(self.train_dir, relabel=True) 34 | query = self._process_dir(self.query_dir, relabel=False) 35 | gallery = self._process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print("=> PersonX loaded") 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 48 | 49 | def _check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 55 | if not osp.exists(self.query_dir): 56 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 57 | if not osp.exists(self.gallery_dir): 58 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 59 | 60 | def _process_dir(self, dir_path, relabel=False): 61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 62 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 63 | cam2label = {3:1, 4:2, 8:3, 10:4, 11:5, 12:6} 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | pid_container.add(pid) 69 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 70 | 71 | dataset = [] 72 | for img_path in img_paths: 73 | pid, camid = map(int, pattern.search(img_path).groups()) 74 | assert (camid in cam2label.keys()) 75 | camid = cam2label[camid] 76 | camid -= 1 # index starts from 0 77 | if relabel: pid = pid2label[pid] 78 | dataset.append((img_path, pid, camid)) 79 | 80 | return dataset 81 | -------------------------------------------------------------------------------- /idm/datasets/unreal.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | 
import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | 13 | class UnrealPerson(BaseImageDataset): 14 | """ 15 | UnrealPerson 16 | Reference: 17 | Zhang et al. UnrealPerson: An Adaptive Pipeline towards Costless Person Re-identification. CVPR 2021. 18 | URL: https://github.com/FlyHighest/UnrealPerson 19 | "list_unreal_train.txt" is from https://github.com/FlyHighest/UnrealPerson/tree/main/JVTC/list_unreal 20 | 21 | Dataset statistics: 22 | # identities: 3000 23 | # cameras: 34 24 | # images: 120,000 25 | """ 26 | dataset_dir = '' 27 | 28 | def __init__(self, root, verbose=True, **kwargs): 29 | super(UnrealPerson, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_list = osp.join(self.dataset_dir, 'list_unreal_train.txt') 32 | 33 | self._check_before_run() 34 | 35 | train = self._process_dir(self.train_list) 36 | self.train = train 37 | self.query = [] 38 | self.gallery = [] 39 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 40 | 41 | if verbose: 42 | print("=> UnrealPerson loaded") 43 | print(" subset | # ids | # cams | # images") 44 | print(" ---------------------------") 45 | print(" train | {:5d} | {:5d} | {:8d}" 46 | .format(self.num_train_pids, self.num_train_cams, self.num_train_imgs)) 47 | 48 | def _check_before_run(self): 49 | """Check if all files are available before going deeper""" 50 | if not osp.exists(self.dataset_dir): 51 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 52 | if not osp.exists(self.train_list): 53 | raise RuntimeError("'{}' is not available".format(self.train_list)) 54 | 55 | def _process_dir(self, list_file): 56 | with open(list_file, 'r') as f: 57 | lines = f.readlines() 58 | dataset = [] 59 | pid_container = set() 60 | for line in lines: 61 | line = line.strip() 62 | pid = line.split(' ')[1] 63 | pid_container.add(pid) 64 | 65 | pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))} 66 | 67 | for line in lines: 68 | line = line.strip() 69 | fname, pid, camid = line.split(' ')[0], line.split(' ')[1], int(line.split(' ')[2]) 70 | img_path = osp.join(self.dataset_dir, fname) 71 | dataset.append((img_path, pid2label[pid], camid)) 72 | 73 | return dataset 74 | -------------------------------------------------------------------------------- /idm/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap' 10 | ] 11 | -------------------------------------------------------------------------------- /idm/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from ..utils import to_torch 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | with torch.no_grad(): 9 | output, target = to_torch(output), to_torch(target) 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | ret = [] 18 | for k in topk: 19 | correct_k = correct[:k].contiguous().view(-1).float().sum(dim=0, keepdim=True) 20 | 
ret.append(correct_k.mul_(1. / batch_size)) 21 | return ret 22 | -------------------------------------------------------------------------------- /idm/evaluation_metrics/rank.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import numpy as np 3 | import warnings 4 | from collections import defaultdict 5 | 6 | try: 7 | from .rank_cylib.rank_cy import evaluate_cy 8 | IS_CYTHON_AVAI = True 9 | except ImportError: 10 | IS_CYTHON_AVAI = False 11 | warnings.warn( 12 | 'Cython evaluation (very fast so highly recommended) is ' 13 | 'unavailable, now use python evaluation.' 14 | ) 15 | 16 | 17 | def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 18 | """Evaluation with cuhk03 metric 19 | Key: one image for each gallery identity is randomly sampled for each query identity. 20 | Random sampling is performed num_repeats times. 21 | """ 22 | num_repeats = 10 23 | num_q, num_g = distmat.shape 24 | 25 | if num_g < max_rank: 26 | max_rank = num_g 27 | print( 28 | 'Note: number of gallery samples is quite small, got {}'. 29 | format(num_g) 30 | ) 31 | 32 | indices = np.argsort(distmat, axis=1) 33 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 34 | 35 | # compute cmc curve for each query 36 | all_cmc = [] 37 | all_AP = [] 38 | num_valid_q = 0. # number of valid query 39 | 40 | for q_idx in range(num_q): 41 | # get query pid and camid 42 | q_pid = q_pids[q_idx] 43 | q_camid = q_camids[q_idx] 44 | 45 | # remove gallery samples that have the same pid and camid with query 46 | order = indices[q_idx] 47 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 48 | keep = np.invert(remove) 49 | 50 | # compute cmc curve 51 | raw_cmc = matches[q_idx][ 52 | keep] # binary vector, positions with value 1 are correct matches 53 | if not np.any(raw_cmc): 54 | # this condition is true when query identity does not appear in gallery 55 | continue 56 | 57 | kept_g_pids = g_pids[order][keep] 58 | g_pids_dict = defaultdict(list) 59 | for idx, pid in enumerate(kept_g_pids): 60 | g_pids_dict[pid].append(idx) 61 | 62 | cmc = 0. 63 | for repeat_idx in range(num_repeats): 64 | mask = np.zeros(len(raw_cmc), dtype=bool) 65 | for _, idxs in g_pids_dict.items(): 66 | # randomly sample one image for each gallery person 67 | rnd_idx = np.random.choice(idxs) 68 | mask[rnd_idx] = True 69 | masked_raw_cmc = raw_cmc[mask] 70 | _cmc = masked_raw_cmc.cumsum() 71 | _cmc[_cmc > 1] = 1 72 | cmc += _cmc[:max_rank].astype(np.float32) 73 | 74 | cmc /= num_repeats 75 | all_cmc.append(cmc) 76 | # compute AP 77 | num_rel = raw_cmc.sum() 78 | tmp_cmc = raw_cmc.cumsum() 79 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 80 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 81 | AP = tmp_cmc.sum() / num_rel 82 | all_AP.append(AP) 83 | num_valid_q += 1. 84 | 85 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 86 | 87 | all_cmc = np.asarray(all_cmc).astype(np.float32) 88 | all_cmc = all_cmc.sum(0) / num_valid_q 89 | mAP = np.mean(all_AP) 90 | 91 | return all_cmc, mAP 92 | 93 | 94 | def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 95 | """Evaluation with market1501 metric 96 | Key: for each query identity, its gallery images from the same camera view are discarded.
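For example, if a query's correct matches appear at gallery ranks 1 and 3, its AP is (1/1 + 2/3) / 2 = 0.833, and mAP averages AP over all valid queries.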
97 | """ 98 | num_q, num_g = distmat.shape 99 | 100 | if num_g < max_rank: 101 | max_rank = num_g 102 | print( 103 | 'Note: number of gallery samples is quite small, got {}'. 104 | format(num_g) 105 | ) 106 | 107 | indices = np.argsort(distmat, axis=1) 108 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 109 | 110 | # compute cmc curve for each query 111 | all_cmc = [] 112 | all_AP = [] 113 | num_valid_q = 0. # number of valid query 114 | 115 | for q_idx in range(num_q): 116 | # get query pid and camid 117 | q_pid = q_pids[q_idx] 118 | q_camid = q_camids[q_idx] 119 | 120 | # remove gallery samples that have the same pid and camid with query 121 | order = indices[q_idx] 122 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 123 | keep = np.invert(remove) 124 | 125 | # compute cmc curve 126 | raw_cmc = matches[q_idx][ 127 | keep] # binary vector, positions with value 1 are correct matches 128 | if not np.any(raw_cmc): 129 | # this condition is true when query identity does not appear in gallery 130 | continue 131 | 132 | cmc = raw_cmc.cumsum() 133 | cmc[cmc > 1] = 1 134 | 135 | all_cmc.append(cmc[:max_rank]) 136 | num_valid_q += 1. 137 | 138 | # compute average precision 139 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 140 | num_rel = raw_cmc.sum() 141 | tmp_cmc = raw_cmc.cumsum() 142 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 143 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 144 | AP = tmp_cmc.sum() / num_rel 145 | all_AP.append(AP) 146 | 147 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 148 | 149 | all_cmc = np.asarray(all_cmc).astype(np.float32) 150 | all_cmc = all_cmc.sum(0) / num_valid_q 151 | mAP = np.mean(all_AP) 152 | 153 | return all_cmc, mAP 154 | 155 | 156 | def evaluate_py( 157 | distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03 158 | ): 159 | if use_metric_cuhk03: 160 | return eval_cuhk03( 161 | distmat, q_pids, g_pids, q_camids, g_camids, max_rank 162 | ) 163 | else: 164 | return eval_market1501( 165 | distmat, q_pids, g_pids, q_camids, g_camids, max_rank 166 | ) 167 | 168 | 169 | def evaluate_rank( 170 | distmat, 171 | q_pids, 172 | g_pids, 173 | q_camids, 174 | g_camids, 175 | max_rank=50, 176 | use_metric_cuhk03=False, 177 | use_cython=True 178 | ): 179 | """Evaluates CMC rank. 180 | Args: 181 | distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 182 | q_pids (numpy.ndarray): 1-D array containing person identities 183 | of each query instance. 184 | g_pids (numpy.ndarray): 1-D array containing person identities 185 | of each gallery instance. 186 | q_camids (numpy.ndarray): 1-D array containing camera views under 187 | which each query instance is captured. 188 | g_camids (numpy.ndarray): 1-D array containing camera views under 189 | which each gallery instance is captured. 190 | max_rank (int, optional): maximum CMC rank to be computed. Default is 50. 191 | use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03. 192 | Default is False. This should be enabled when using cuhk03 classic split. 193 | use_cython (bool, optional): use cython code for evaluation. Default is True. 194 | This is highly recommended as the cython code can speed up the cmc computation 195 | by more than 10x. This requires Cython to be installed. 
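If the compiled extension cannot be imported, evaluation transparently falls back to the pure-python implementation.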
196 | """ 197 | if use_cython and IS_CYTHON_AVAI: 198 | return evaluate_cy( 199 | distmat, q_pids, g_pids, q_camids, g_camids, max_rank, 200 | use_metric_cuhk03 201 | ) 202 | else: 203 | return evaluate_py( 204 | distmat, q_pids, g_pids, q_camids, g_camids, max_rank, 205 | use_metric_cuhk03 206 | ) 207 | -------------------------------------------------------------------------------- /idm/evaluation_metrics/rank_cylib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python3 setup.py build_ext --inplace 3 | rm -rf build 4 | clean: 5 | rm -rf build 6 | rm -f rank_cy.c *.so -------------------------------------------------------------------------------- /idm/evaluation_metrics/rank_cylib/rank_cy.pyx: -------------------------------------------------------------------------------- 1 | # cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True 2 | 3 | from __future__ import print_function 4 | import numpy as np 5 | from libc.stdint cimport int64_t, uint64_t 6 | 7 | import cython 8 | 9 | cimport numpy as np 10 | 11 | import random 12 | from collections import defaultdict 13 | 14 | """ 15 | Compiler directives: 16 | https://github.com/cython/cython/wiki/enhancements-compilerdirectives 17 | 18 | Cython tutorial: 19 | https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html 20 | 21 | Credit to https://github.com/luzai 22 | """ 23 | 24 | 25 | # Main interface 26 | cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False): 27 | distmat = np.asarray(distmat, dtype=np.float32) 28 | q_pids = np.asarray(q_pids, dtype=np.int64) 29 | g_pids = np.asarray(g_pids, dtype=np.int64) 30 | q_camids = np.asarray(q_camids, dtype=np.int64) 31 | g_camids = np.asarray(g_camids, dtype=np.int64) 32 | if use_metric_cuhk03: 33 | return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 34 | return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 35 | 36 | 37 | cpdef eval_cuhk03_cy(float[:,:] distmat, int64_t[:] q_pids, int64_t[:]g_pids, 38 | int64_t[:]q_camids, int64_t[:]g_camids, int64_t max_rank): 39 | 40 | cdef int64_t num_q = distmat.shape[0] 41 | cdef int64_t num_g = distmat.shape[1] 42 | 43 | if num_g < max_rank: 44 | max_rank = num_g 45 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 46 | 47 | cdef: 48 | int64_t num_repeats = 10 49 | int64_t[:,:] indices = np.argsort(distmat, axis=1) 50 | int64_t[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 51 | 52 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 53 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 54 | float num_valid_q = 0. 
# number of valid query 55 | 56 | int64_t q_idx, q_pid, q_camid, g_idx 57 | int64_t[:] order = np.zeros(num_g, dtype=np.int64) 58 | int64_t keep 59 | 60 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 61 | float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32) 62 | float[:] cmc, masked_cmc 63 | int64_t num_g_real, num_g_real_masked, rank_idx, rnd_idx 64 | uint64_t meet_condition 65 | float AP 66 | int64_t[:] kept_g_pids, mask 67 | 68 | float num_rel 69 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 70 | float tmp_cmc_sum 71 | 72 | for q_idx in range(num_q): 73 | # get query pid and camid 74 | q_pid = q_pids[q_idx] 75 | q_camid = q_camids[q_idx] 76 | 77 | # remove gallery samples that have the same pid and camid with query 78 | for g_idx in range(num_g): 79 | order[g_idx] = indices[q_idx, g_idx] 80 | num_g_real = 0 81 | meet_condition = 0 82 | kept_g_pids = np.zeros(num_g, dtype=np.int64) 83 | 84 | for g_idx in range(num_g): 85 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 86 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 87 | kept_g_pids[num_g_real] = g_pids[order[g_idx]] 88 | num_g_real += 1 89 | if matches[q_idx][g_idx] > 1e-31: 90 | meet_condition = 1 91 | 92 | if not meet_condition: 93 | # this condition is true when query identity does not appear in gallery 94 | continue 95 | 96 | # cuhk03-specific setting 97 | g_pids_dict = defaultdict(list) # overhead! 98 | for g_idx in range(num_g_real): 99 | g_pids_dict[kept_g_pids[g_idx]].append(g_idx) 100 | 101 | cmc = np.zeros(max_rank, dtype=np.float32) 102 | for _ in range(num_repeats): 103 | mask = np.zeros(num_g_real, dtype=np.int64) 104 | 105 | for _, idxs in g_pids_dict.items(): 106 | # randomly sample one image for each gallery person 107 | rnd_idx = np.random.choice(idxs) 108 | #rnd_idx = idxs[0] # use deterministic for debugging 109 | mask[rnd_idx] = 1 110 | 111 | num_g_real_masked = 0 112 | for g_idx in range(num_g_real): 113 | if mask[g_idx] == 1: 114 | masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx] 115 | num_g_real_masked += 1 116 | 117 | masked_cmc = np.zeros(num_g, dtype=np.float32) 118 | function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked) 119 | for g_idx in range(num_g_real_masked): 120 | if masked_cmc[g_idx] > 1: 121 | masked_cmc[g_idx] = 1 122 | 123 | for rank_idx in range(max_rank): 124 | cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats 125 | 126 | for rank_idx in range(max_rank): 127 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 128 | # compute average precision 129 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 130 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 131 | num_rel = 0 132 | tmp_cmc_sum = 0 133 | for g_idx in range(num_g_real): 134 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 135 | num_rel += raw_cmc[g_idx] 136 | all_AP[q_idx] = tmp_cmc_sum / num_rel 137 | num_valid_q += 1. 
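 # query counted as valid: it has at least one cross-camera match in the gallery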
138 | 139 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 140 | 141 | # compute averaged cmc 142 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 143 | for rank_idx in range(max_rank): 144 | for q_idx in range(num_q): 145 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 146 | avg_cmc[rank_idx] /= num_valid_q 147 | 148 | cdef float mAP = 0 149 | for q_idx in range(num_q): 150 | mAP += all_AP[q_idx] 151 | mAP /= num_valid_q 152 | 153 | return np.asarray(avg_cmc).astype(np.float32), mAP 154 | 155 | 156 | cpdef eval_market1501_cy(float[:,:] distmat, int64_t[:] q_pids, int64_t[:]g_pids, 157 | int64_t[:]q_camids, int64_t[:]g_camids, int64_t max_rank): 158 | 159 | cdef int64_t num_q = distmat.shape[0] 160 | cdef int64_t num_g = distmat.shape[1] 161 | 162 | if num_g < max_rank: 163 | max_rank = num_g 164 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 165 | 166 | cdef: 167 | int64_t[:,:] indices = np.argsort(distmat, axis=1) 168 | int64_t[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 169 | 170 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 171 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 172 | float num_valid_q = 0. # number of valid query 173 | 174 | int64_t q_idx, q_pid, q_camid, g_idx 175 | int64_t[:] order = np.zeros(num_g, dtype=np.int64) 176 | int64_t keep 177 | 178 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 179 | float[:] cmc = np.zeros(num_g, dtype=np.float32) 180 | int64_t num_g_real, rank_idx 181 | uint64_t meet_condition 182 | 183 | float num_rel 184 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 185 | float tmp_cmc_sum 186 | 187 | for q_idx in range(num_q): 188 | # get query pid and camid 189 | q_pid = q_pids[q_idx] 190 | q_camid = q_camids[q_idx] 191 | 192 | # remove gallery samples that have the same pid and camid with query 193 | for g_idx in range(num_g): 194 | order[g_idx] = indices[q_idx, g_idx] 195 | num_g_real = 0 196 | meet_condition = 0 197 | 198 | for g_idx in range(num_g): 199 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 200 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 201 | num_g_real += 1 202 | if matches[q_idx][g_idx] > 1e-31: 203 | meet_condition = 1 204 | 205 | if not meet_condition: 206 | # this condition is true when query identity does not appear in gallery 207 | continue 208 | 209 | # compute cmc 210 | function_cumsum(raw_cmc, cmc, num_g_real) 211 | for g_idx in range(num_g_real): 212 | if cmc[g_idx] > 1: 213 | cmc[g_idx] = 1 214 | 215 | for rank_idx in range(max_rank): 216 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 217 | num_valid_q += 1. 
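 # query is valid; its AP is accumulated below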
218 |  219 | # compute average precision 220 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 221 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 222 | num_rel = 0 223 | tmp_cmc_sum = 0 224 | for g_idx in range(num_g_real): 225 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 226 | num_rel += raw_cmc[g_idx] 227 | all_AP[q_idx] = tmp_cmc_sum / num_rel 228 | 229 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 230 | 231 | # compute averaged cmc 232 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 233 | for rank_idx in range(max_rank): 234 | for q_idx in range(num_q): 235 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 236 | avg_cmc[rank_idx] /= num_valid_q 237 | 238 | cdef float mAP = 0 239 | for q_idx in range(num_q): 240 | mAP += all_AP[q_idx] 241 | mAP /= num_valid_q 242 | 243 | return np.asarray(avg_cmc).astype(np.float32), mAP 244 | 245 | 246 | # Compute the cumulative sum 247 | cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, int64_t n): 248 | cdef int64_t i 249 | dst[0] = src[0] 250 | for i in range(1, n): 251 | dst[i] = src[i] + dst[i - 1] -------------------------------------------------------------------------------- /idm/evaluation_metrics/rank_cylib/setup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from distutils.core import setup 3 | from distutils.extension import Extension 4 | from Cython.Build import cythonize 5 | 6 | 7 | def numpy_include(): 8 | try: 9 | numpy_include = np.get_include() 10 | except AttributeError: 11 | numpy_include = np.get_numpy_include() 12 | return numpy_include 13 | 14 | 15 | ext_modules = [ 16 | Extension( 17 | 'rank_cy', 18 | ['rank_cy.pyx'], 19 | include_dirs=[numpy_include()], 20 | ) 21 | ] 22 | 23 | setup( 24 | name='Cython-based reid evaluation code', 25 | ext_modules=cythonize(ext_modules) 26 | ) 27 | -------------------------------------------------------------------------------- /idm/evaluation_metrics/rank_cylib/test_cython.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import numpy as np 4 | import timeit 5 | import os.path as osp 6 | 7 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 8 | 9 | from torchreid import metrics 10 | """ 11 | Test the speed of cython-based evaluation code. The speed improvements 12 | can be much bigger when using the real reid data, which contains a larger 13 | amount of query and gallery images. 14 | 15 | Note: you might encounter the following error: 16 | 'AssertionError: Error: all query identities do not appear in gallery'. 17 | This is normal because the inputs are random numbers. Just try again. 
18 | """ 19 | 20 | print('*** Compare running time ***') 21 | 22 | setup = ''' 23 | import sys 24 | import os.path as osp 25 | import numpy as np 26 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 27 | from torchreid import metrics 28 | num_q = 30 29 | num_g = 300 30 | max_rank = 5 31 | distmat = np.random.rand(num_q, num_g) * 20 32 | q_pids = np.random.randint(0, num_q, size=num_q) 33 | g_pids = np.random.randint(0, num_g, size=num_g) 34 | q_camids = np.random.randint(0, 5, size=num_q) 35 | g_camids = np.random.randint(0, 5, size=num_g) 36 | ''' 37 | 38 | print('=> Using market1501\'s metric') 39 | pytime = timeit.timeit( 40 | 'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', 41 | setup=setup, 42 | number=20 43 | ) 44 | cytime = timeit.timeit( 45 | 'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', 46 | setup=setup, 47 | number=20 48 | ) 49 | print('Python time: {} s'.format(pytime)) 50 | print('Cython time: {} s'.format(cytime)) 51 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 52 | 53 | print('=> Using cuhk03\'s metric') 54 | pytime = timeit.timeit( 55 | 'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)', 56 | setup=setup, 57 | number=20 58 | ) 59 | cytime = timeit.timeit( 60 | 'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)', 61 | setup=setup, 62 | number=20 63 | ) 64 | print('Python time: {} s'.format(pytime)) 65 | print('Cython time: {} s'.format(cytime)) 66 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 67 | """ 68 | print("=> Check precision") 69 | 70 | num_q = 30 71 | num_g = 300 72 | max_rank = 5 73 | distmat = np.random.rand(num_q, num_g) * 20 74 | q_pids = np.random.randint(0, num_q, size=num_q) 75 | g_pids = np.random.randint(0, num_g, size=num_g) 76 | q_camids = np.random.randint(0, 5, size=num_q) 77 | g_camids = np.random.randint(0, 5, size=num_g) 78 | 79 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 80 | print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 81 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 82 | print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 83 | """ 84 | -------------------------------------------------------------------------------- /idm/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if 
gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. / (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /idm/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | import torch 5 | from .evaluation_metrics import cmc, mean_ap 6 | from .utils.meters import AverageMeter 7 | from .utils.rerank import re_ranking 8 | from .evaluation_metrics.rank import evaluate_rank 9 | from .utils import 
to_torch 10 | 11 | 12 | def extract_cnn_feature(model, inputs): 13 | inputs = to_torch(inputs).cuda() 14 | outputs = model(inputs) 15 | outputs = outputs.data.cpu() 16 | return outputs 17 | 18 | 19 | def extract_features(model, data_loader, print_freq=50): 20 | model.eval() 21 | batch_time = AverageMeter() 22 | data_time = AverageMeter() 23 | 24 | features = OrderedDict() 25 | labels = OrderedDict() 26 | 27 | end = time.time() 28 | with torch.no_grad(): 29 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader): 30 | data_time.update(time.time() - end) 31 | 32 | outputs = extract_cnn_feature(model, imgs) 33 | for fname, output, pid in zip(fnames, outputs, pids): 34 | features[fname] = output 35 | labels[fname] = pid 36 | 37 | batch_time.update(time.time() - end) 38 | end = time.time() 39 | 40 | if (i + 1) % print_freq == 0: 41 | print('Extract Features: [{}/{}]\t' 42 | 'Time {:.3f} ({:.3f})\t' 43 | 'Data {:.3f} ({:.3f})\t' 44 | .format(i + 1, len(data_loader), 45 | batch_time.val, batch_time.avg, 46 | data_time.val, data_time.avg)) 47 | 48 | return features, labels 49 | 50 | 51 | def pairwise_distance(features, query=None, gallery=None): 52 | if query is None and gallery is None: 53 | n = len(features) 54 | x = torch.cat(list(features.values())) 55 | x = x.view(n, -1) 56 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 57 | dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t()) 58 | return dist_m 59 | 60 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 61 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0) 62 | m, n = x.size(0), y.size(0) 63 | x = x.view(m, -1) 64 | y = y.view(n, -1) 65 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 66 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 67 | dist_m.addmm_(1, -2, x, y.t()) 68 | return dist_m, x.numpy(), y.numpy() 69 | 70 | # def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None, 71 | # query_ids=None, gallery_ids=None, 72 | # query_cams=None, gallery_cams=None, 73 | # cmc_topk=(1, 5, 10), cmc_flag=False): 74 | # if query is not None and gallery is not None: 75 | # query_ids = [pid for _, pid, _ in query] 76 | # gallery_ids = [pid for _, pid, _ in gallery] 77 | # query_cams = [cam for _, _, cam in query] 78 | # gallery_cams = [cam for _, _, cam in gallery] 79 | # else: 80 | # assert (query_ids is not None and gallery_ids is not None 81 | # and query_cams is not None and gallery_cams is not None) 82 | # 83 | # # Compute mean AP 84 | # mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 85 | # print('Mean AP: {:4.1%}'.format(mAP)) 86 | # 87 | # if (not cmc_flag): 88 | # return mAP 89 | # 90 | # cmc_configs = { 91 | # 'market1501': dict(separate_camera_set=False, 92 | # single_gallery_shot=False, 93 | # first_match_break=True),} 94 | # cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 95 | # query_cams, gallery_cams, **params) 96 | # for name, params in cmc_configs.items()} 97 | # 98 | # print('CMC Scores:') 99 | # for k in cmc_topk: 100 | # print(' top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1])) 101 | # return cmc_scores['market1501'], mAP 102 | 103 | def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None, 104 | query_ids=None, gallery_ids=None, 105 | query_cams=None, gallery_cams=None, 106 | cmc_topk=(1, 5, 10), cmc_flag=False): 107 | if query is not None and gallery is not None: 108 | query_ids = [pid for _, pid, _ in query] 109 | gallery_ids = [pid for _, pid, _ 
in gallery] 110 | query_cams = [cam for _, _, cam in query] 111 | gallery_cams = [cam for _, _, cam in gallery] 112 | else: 113 | assert (query_ids is not None and gallery_ids is not None 114 | and query_cams is not None and gallery_cams is not None) 115 | 116 | 117 | results = evaluate_rank(distmat, query_ids, gallery_ids, query_cams, gallery_cams, 118 | max_rank=50, use_metric_cuhk03=False, use_cython=True) 119 | all_cmc, mAP = results 120 | 121 | print('Mean AP: {:4.1%}'.format(mAP)) 122 | 123 | if (not cmc_flag): 124 | return mAP 125 | 126 | print('CMC Scores:') 127 | for k in cmc_topk: 128 | print(' top-{:<4}{:12.1%}'.format(k, all_cmc[k-1])) 129 | return all_cmc[0], mAP 130 | 131 | 132 | class Evaluator(object): 133 | def __init__(self, model): 134 | super(Evaluator, self).__init__() 135 | self.model = model 136 | 137 | def evaluate(self, data_loader, query, gallery, cmc_flag=False, rerank=False): 138 | features, _ = extract_features(self.model, data_loader) 139 | distmat, query_features, gallery_features = pairwise_distance(features, query, gallery) 140 | results = evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 141 | 142 | if (not rerank): 143 | return results 144 | 145 | print('Applying person re-ranking ...') 146 | distmat_qq, _, _ = pairwise_distance(features, query, query) 147 | distmat_gg, _, _ = pairwise_distance(features, gallery, gallery) 148 | distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy()) 149 | return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 150 | -------------------------------------------------------------------------------- /idm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .triplet import TripletLoss 4 | from .triplet_xbm import TripletLossXBM 5 | from .crossentropy import CrossEntropyLabelSmooth 6 | from .idm_loss import DivLoss, BridgeFeatLoss, BridgeProbLoss 7 | 8 | __all__ = [ 9 | 'DivLoss', 10 | 'BridgeFeatLoss', 11 | 'BridgeProbLoss', 12 | 'TripletLoss', 13 | 'TripletLossXBM', 14 | 'CrossEntropyLabelSmooth', 15 | ] 16 | -------------------------------------------------------------------------------- /idm/loss/crossentropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CrossEntropyLabelSmooth(nn.Module): 6 | """Cross entropy loss with label smoothing regularizer. 7 | 8 | Reference: 9 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 10 | Equation: y = (1 - epsilon) * y + epsilon / K. 11 | 12 | Args: 13 | num_classes (int): number of classes. 14 | epsilon (float): weight. 
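In the equation above, K equals num_classes: e.g., with num_classes=10 and epsilon=0.1, the true class receives target probability 0.91 and every other class 0.01.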
15 | """ 16 | 17 | def __init__(self, num_classes, epsilon=0.1): 18 | super(CrossEntropyLabelSmooth, self).__init__() 19 | self.num_classes = num_classes 20 | self.epsilon = epsilon 21 | self.logsoftmax = nn.LogSoftmax(dim=1).cuda() 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (batch_size) 28 | """ 29 | 30 | log_probs = self.logsoftmax(inputs) 31 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 32 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 33 | loss = (- targets * log_probs).mean(0).sum() 34 | return loss 35 | -------------------------------------------------------------------------------- /idm/loss/idm_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class DivLoss(nn.Module): 8 | def __init__(self, ): 9 | super(DivLoss, self).__init__() 10 | 11 | def forward(self, scores): 12 | mu = scores.mean(0) 13 | std = ((scores-mu)**2).mean(0,keepdim=True).clamp(min=1e-12).sqrt() 14 | loss_std = -std.sum() 15 | return loss_std 16 | 17 | 18 | class BridgeFeatLoss(nn.Module): 19 | def __init__(self): 20 | super(BridgeFeatLoss, self).__init__() 21 | 22 | def forward(self, feats_s, feats_t, feats_mixed, lam): 23 | 24 | dist_mixed2s = ((feats_mixed-feats_s)**2).sum(1, keepdim=True) 25 | dist_mixed2t = ((feats_mixed-feats_t)**2).sum(1, keepdim=True) 26 | 27 | dist_mixed2s = dist_mixed2s.clamp(min=1e-12).sqrt() 28 | dist_mixed2t = dist_mixed2t.clamp(min=1e-12).sqrt() 29 | 30 | dist_mixed = torch.cat((dist_mixed2s, dist_mixed2t), 1) 31 | lam_dist_mixed = (lam*dist_mixed).sum(1, keepdim=True) 32 | loss = lam_dist_mixed.mean() 33 | 34 | return loss 35 | 36 | 37 | class BridgeProbLoss(nn.Module): 38 | 39 | def __init__(self, num_classes, epsilon=0.1): 40 | super(BridgeProbLoss, self).__init__() 41 | self.num_classes = num_classes 42 | self.epsilon = epsilon 43 | self.logsoftmax = nn.LogSoftmax(dim=1).cuda() 44 | self.device_num = torch.cuda.device_count() 45 | 46 | def forward(self, inputs, targets, lam): 47 | 48 | inputs = inputs.view(self.device_num, -1, inputs.size(-1)) 49 | inputs_s, inputs_t, inputs_mixed = inputs.split(inputs.size(1) // 3, dim=1) 50 | inputs_ori = torch.cat((inputs_s, inputs_t), 1).view(-1, inputs.size(-1)) 51 | inputs_mixed = inputs_mixed.contiguous().view(-1, inputs.size(-1)) 52 | log_probs_ori = self.logsoftmax(inputs_ori) 53 | log_probs_mixed = self.logsoftmax(inputs_mixed) 54 | 55 | targets = torch.zeros_like(log_probs_ori).scatter_(1, targets.unsqueeze(1), 1) 56 | targets = targets.view(self.device_num, -1, targets.size(-1)) 57 | targets_s, targets_t = targets.split(targets.size(1) // 2, dim=1) 58 | targets_s = targets_s.contiguous() 59 | targets_t = targets_t.contiguous() 60 | 61 | targets_s = targets_s.contiguous().view(-1, targets.size(-1)) 62 | targets_t = targets_t.contiguous().view(-1, targets.size(-1)) 63 | 64 | targets = targets.view(-1, targets.size(-1)) 65 | soft_targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 66 | 67 | lam = lam.view(-1, 1) 68 | soft_targets_mixed = lam*targets_s+(1.-lam)*targets_t 69 | soft_targets_mixed = (1 - self.epsilon) * soft_targets_mixed + self.epsilon / self.num_classes 70 | loss_ori = (- soft_targets*log_probs_ori).mean(0).sum() 71 | loss_bridge_prob = (-
soft_targets_mixed*log_probs_mixed).mean(0).sum() 72 | 73 | return loss_ori, loss_bridge_prob 74 | -------------------------------------------------------------------------------- /idm/loss/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def euclidean_dist(x, y): 9 | m, n = x.size(0), y.size(0) 10 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 11 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 12 | dist = xx + yy 13 | dist.addmm_(1, -2, x, y.t()) 14 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 15 | return dist 16 | 17 | def _batch_hard(mat_distance, mat_similarity, indice=False): 18 | sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1, descending=True) 19 | hard_p = sorted_mat_distance[:, 0] 20 | hard_p_indice = positive_indices[:, 0] 21 | sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1, descending=False) 22 | hard_n = sorted_mat_distance[:, 0] 23 | hard_n_indice = negative_indices[:, 0] 24 | if(indice): 25 | return hard_p, hard_n, hard_p_indice, hard_n_indice 26 | return hard_p, hard_n 27 | 28 | class TripletLoss(nn.Module): 29 | ''' 30 | Compute Triplet loss augmented with Batch Hard 31 | Details can be seen in 'In defense of the Triplet Loss for Person Re-Identification' 32 | ''' 33 | 34 | def __init__(self, margin, normalize_feature=False): 35 | super(TripletLoss, self).__init__() 36 | self.margin = margin 37 | self.normalize_feature = normalize_feature 38 | self.margin_loss = nn.MarginRankingLoss(margin=margin).cuda() 39 | 40 | def forward(self, emb, label): 41 | if self.normalize_feature: 42 | emb = F.normalize(emb) 43 | mat_dist = euclidean_dist(emb, emb) 44 | 45 | assert mat_dist.size(0) == mat_dist.size(1) 46 | N = mat_dist.size(0) 47 | mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float() 48 | 49 | dist_ap, dist_an = _batch_hard(mat_dist, mat_sim) 50 | assert dist_an.size(0)==dist_ap.size(0) 51 | y = torch.ones_like(dist_ap) 52 | loss = self.margin_loss(dist_an, dist_ap, y) 53 | 54 | return loss 55 | -------------------------------------------------------------------------------- /idm/loss/triplet_xbm.py: -------------------------------------------------------------------------------- 1 | # Reference: Wang et al. Cross-Batch Memory for Embedding Learning, in CVPR 2020. 
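(XBM maintains a queue-style memory of features from past mini-batches; here, inputs_row/targets_row in TripletLossXBM.forward are expected to come from that memory.)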
2 | # https://github.com/msight-tech/research-xbm 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def euclidean_dist(x, y): 10 | m, n = x.size(0), y.size(0) 11 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 12 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 13 | dist = xx + yy 14 | dist.addmm_(1, -2, x, y.t()) 15 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 16 | return dist 17 | 18 | class TripletLossXBM(nn.Module): 19 | def __init__(self, margin=0.3, norm=False): 20 | super(TripletLossXBM, self).__init__() 21 | self.margin = margin 22 | self.norm = norm 23 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 24 | 25 | def forward(self, inputs_col, targets_col, inputs_row, targets_row): 26 | 27 | n = inputs_col.size(0) 28 | if self.norm: 29 | inputs_col = F.normalize(inputs_col) 30 | inputs_row = F.normalize(inputs_row) 31 | 32 | dist = euclidean_dist(inputs_col, inputs_row) 33 | 34 | # split the positive and negative pairs 35 | pos_mask = targets_col.expand( 36 | targets_row.shape[0], n 37 | ).t() == targets_row.expand(n, targets_row.shape[0]) 38 | neg_mask = ~pos_mask 39 | # For each anchor, find the hardest positive and negative 40 | dist_ap, dist_an = [], [] 41 | 42 | for i in range(n): 43 | dist_ap.append(dist[i][pos_mask[i]].max().unsqueeze(0)) 44 | dist_an.append(dist[i][neg_mask[i]].min().unsqueeze(0)) 45 | 46 | dist_ap = torch.cat(dist_ap) 47 | dist_an = torch.cat(dist_an) 48 | 49 | # Compute ranking hinge loss 50 | y = torch.ones_like(dist_an) 51 | loss = self.ranking_loss(dist_an, dist_ap, y) 52 | 53 | return loss 54 | -------------------------------------------------------------------------------- /idm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .idm_module import * 5 | from .resnet_idm import * 6 | from .resnet_ibn import * 7 | from .resnet_ibn_idm import * 8 | 9 | 10 | __factory = { 11 | 'resnet18': resnet18, 12 | 'resnet34': resnet34, 13 | 'resnet50': resnet50, 14 | 'resnet101': resnet101, 15 | 'resnet152': resnet152, 16 | 'resnet_ibn50a': resnet_ibn50a, 17 | 'resnet_ibn101a': resnet_ibn101a, 18 | 'resnet50_idm': resnet50_idm, 19 | 'resnet_ibn50a_idm': resnet_ibn50a_idm 20 | } 21 | 22 | 23 | def names(): 24 | return sorted(__factory.keys()) 25 | 26 | 27 | def create(name, *args, **kwargs): 28 | """ 29 | Create a model instance. 30 | 31 | Parameters 32 | ---------- 33 | name : str 34 | Model name. Can be one of 'resnet18', 'resnet34', 'resnet50', 35 | 'resnet101', 'resnet152', 'resnet_ibn50a', 'resnet_ibn101a', 'resnet50_idm', and 'resnet_ibn50a_idm'. 36 | pretrained : bool, optional 37 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 38 | model. Default: True 39 | cut_at_pooling : bool, optional 40 | If True, will cut the model before the last global pooling layer and 41 | ignore the remaining kwargs. Default: False 42 | num_features : int, optional 43 | If positive, will append a Linear layer after the global pooling layer, 44 | with this number of output units, followed by a BatchNorm layer. 45 | Otherwise these layers will not be appended. Default: 0 for 46 | 'resnet*' models. 47 | norm : bool, optional 48 | If True, will normalize the feature to be unit L2-norm for each sample. 49 | Otherwise will append a ReLU layer after the above Linear layer if 50 | num_features > 0. 
Default: False 51 | dropout : float, optional 52 | If positive, will append a Dropout layer with this dropout rate. 53 | Default: 0 54 | num_classes : int, optional 55 | If positive, will append a Linear layer at the end as the classifier 56 | with this number of output units. Default: 0 57 | """ 58 | if name not in __factory: 59 | raise KeyError("Unknown model:", name) 60 | return __factory[name](*args, **kwargs) 61 | -------------------------------------------------------------------------------- /idm/models/dsbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Domain-specific BatchNorm 5 | 6 | class DSBN2d(nn.Module): 7 | def __init__(self, planes): 8 | super(DSBN2d, self).__init__() 9 | self.num_features = planes 10 | self.BN_S = nn.BatchNorm2d(planes) 11 | self.BN_T = nn.BatchNorm2d(planes) 12 | 13 | def forward(self, x): 14 | if (not self.training): 15 | return self.BN_T(x) 16 | 17 | bs = x.size(0) 18 | assert (bs%2==0) 19 | split = torch.split(x, int(bs/2), 0) 20 | out1 = self.BN_S(split[0].contiguous()) 21 | out2 = self.BN_T(split[1].contiguous()) 22 | out = torch.cat((out1, out2), 0) 23 | return out 24 | 25 | class DSBN1d(nn.Module): 26 | def __init__(self, planes): 27 | super(DSBN1d, self).__init__() 28 | self.num_features = planes 29 | self.BN_S = nn.BatchNorm1d(planes) 30 | self.BN_T = nn.BatchNorm1d(planes) 31 | 32 | def forward(self, x): 33 | if (not self.training): 34 | return self.BN_T(x) 35 | 36 | bs = x.size(0) 37 | assert (bs%2==0) 38 | split = torch.split(x, int(bs/2), 0) 39 | out1 = self.BN_S(split[0].contiguous()) 40 | out2 = self.BN_T(split[1].contiguous()) 41 | out = torch.cat((out1, out2), 0) 42 | return out 43 | 44 | def convert_dsbn(model): 45 | for _, (child_name, child) in enumerate(model.named_children()): 46 | assert(not next(model.parameters()).is_cuda) 47 | if isinstance(child, nn.BatchNorm2d): 48 | m = DSBN2d(child.num_features) 49 | m.BN_S.load_state_dict(child.state_dict()) 50 | m.BN_T.load_state_dict(child.state_dict()) 51 | setattr(model, child_name, m) 52 | elif isinstance(child, nn.BatchNorm1d) and child_name!='d_bn1': 53 | m = DSBN1d(child.num_features) 54 | m.BN_S.load_state_dict(child.state_dict()) 55 | m.BN_T.load_state_dict(child.state_dict()) 56 | setattr(model, child_name, m) 57 | else: 58 | convert_dsbn(child) 59 | 60 | def convert_bn(model, use_target=True): 61 | for _, (child_name, child) in enumerate(model.named_children()): 62 | assert(not next(model.parameters()).is_cuda) 63 | if isinstance(child, DSBN2d): 64 | m = nn.BatchNorm2d(child.num_features) 65 | if use_target: 66 | m.load_state_dict(child.BN_T.state_dict()) 67 | else: 68 | m.load_state_dict(child.BN_S.state_dict()) 69 | setattr(model, child_name, m) 70 | elif isinstance(child, DSBN1d): 71 | m = nn.BatchNorm1d(child.num_features) 72 | if use_target: 73 | m.load_state_dict(child.BN_T.state_dict()) 74 | else: 75 | m.load_state_dict(child.BN_S.state_dict()) 76 | setattr(model, child_name, m) 77 | else: 78 | convert_bn(child, use_target=use_target) 79 | -------------------------------------------------------------------------------- /idm/models/idm_dsbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Domain-specific BatchNorm 6 | class DSBN2d(nn.Module): 7 | def __init__(self, planes): 8 | super(DSBN2d, self).__init__() 9 | self.num_features = planes 10 | self.BN_S = nn.BatchNorm2d(planes) 11 | self.BN_T 
= nn.BatchNorm2d(planes) 12 | 13 | def forward(self, x): 14 | if (not self.training): 15 | return self.BN_T(x) 16 | 17 | bs = x.size(0) 18 | assert (bs%2==0) 19 | split = torch.split(x, int(bs/2), 0) 20 | out1 = self.BN_S(split[0].contiguous()) 21 | out2 = self.BN_T(split[1].contiguous()) 22 | out = torch.cat((out1, out2), 0) 23 | return out 24 | 25 | 26 | class DSBN1d(nn.Module): 27 | def __init__(self, planes): 28 | super(DSBN1d, self).__init__() 29 | self.num_features = planes 30 | self.BN_S = nn.BatchNorm1d(planes) 31 | self.BN_T = nn.BatchNorm1d(planes) 32 | 33 | def forward(self, x): 34 | if (not self.training): 35 | return self.BN_T(x) 36 | 37 | bs = x.size(0) 38 | assert (bs%2==0) 39 | split = torch.split(x, int(bs/2), 0) 40 | out1 = self.BN_S(split[0].contiguous()) 41 | out2 = self.BN_T(split[1].contiguous()) 42 | out = torch.cat((out1, out2), 0) 43 | return out 44 | 45 | 46 | class DSBN2d_idm(nn.Module): 47 | def __init__(self, planes): 48 | super(DSBN2d_idm, self).__init__() 49 | self.num_features = planes 50 | self.BN_S = nn.BatchNorm2d(planes) 51 | self.BN_T = nn.BatchNorm2d(planes) 52 | self.BN_mix = nn.BatchNorm2d(planes) 53 | 54 | def forward(self, x): 55 | if (not self.training): 56 | return self.BN_T(x) 57 | 58 | bs = x.size(0) 59 | assert (bs%3==0) 60 | split = torch.split(x, int(bs/3), 0) 61 | out1 = self.BN_S(split[0].contiguous()) 62 | out2 = self.BN_T(split[1].contiguous()) 63 | out3 = self.BN_mix(split[2].contiguous()) 64 | out = torch.cat((out1, out2, out3), 0) 65 | return out 66 | 67 | 68 | class DSBN1d_idm(nn.Module): 69 | def __init__(self, planes): 70 | super(DSBN1d_idm, self).__init__() 71 | self.num_features = planes 72 | self.BN_S = nn.BatchNorm1d(planes) 73 | self.BN_T = nn.BatchNorm1d(planes) 74 | self.BN_mix = nn.BatchNorm1d(planes) 75 | 76 | def forward(self, x): 77 | if (not self.training): 78 | return self.BN_T(x) 79 | 80 | bs = x.size(0) 81 | assert (bs%3==0) 82 | split = torch.split(x, int(bs/3), 0) 83 | out1 = self.BN_S(split[0].contiguous()) 84 | out2 = self.BN_T(split[1].contiguous()) 85 | out3 = self.BN_mix(split[2].contiguous()) 86 | out = torch.cat((out1, out2, out3), 0) 87 | return out 88 | 89 | 90 | 91 | def convert_dsbn_idm(model, mixup_bn_names, idm=False): 92 | 93 | for _, (child_name, child) in enumerate(model.named_children()): 94 | # print(child_name) 95 | idm_flag = idm 96 | assert(not next(model.parameters()).is_cuda) 97 | for name in mixup_bn_names: 98 | if name in child_name: 99 | idm_flag = True 100 | if isinstance(child, nn.BatchNorm2d) and not idm_flag: 101 | m = DSBN2d(child.num_features) 102 | m.BN_S.load_state_dict(child.state_dict()) 103 | m.BN_T.load_state_dict(child.state_dict()) 104 | setattr(model, child_name, m) 105 | elif isinstance(child, nn.BatchNorm2d) and idm_flag: 106 | m = DSBN2d_idm(child.num_features) 107 | m.BN_S.load_state_dict(child.state_dict()) 108 | m.BN_T.load_state_dict(child.state_dict()) 109 | m.BN_mix.load_state_dict(child.state_dict()) 110 | setattr(model, child_name, m) 111 | elif isinstance(child, nn.BatchNorm1d) and not idm_flag: 112 | m = DSBN1d(child.num_features) 113 | m.BN_S.load_state_dict(child.state_dict()) 114 | m.BN_T.load_state_dict(child.state_dict()) 115 | setattr(model, child_name, m) 116 | elif isinstance(child, nn.BatchNorm1d) and idm_flag: 117 | m = DSBN1d_idm(child.num_features) 118 | m.BN_S.load_state_dict(child.state_dict()) 119 | m.BN_T.load_state_dict(child.state_dict()) 120 | m.BN_mix.load_state_dict(child.state_dict()) 121 | setattr(model, child_name, m) 122 | else: 123 | 
convert_dsbn_idm(child, mixup_bn_names, idm=idm_flag) 124 | 125 | 126 | def convert_bn_idm(model, use_target=True): 127 | for _, (child_name, child) in enumerate(model.named_children()): 128 | assert(not next(model.parameters()).is_cuda) 129 | if isinstance(child, DSBN2d): 130 | m = nn.BatchNorm2d(child.num_features) 131 | if use_target: 132 | m.load_state_dict(child.BN_T.state_dict()) 133 | else: 134 | m.load_state_dict(child.BN_S.state_dict()) 135 | setattr(model, child_name, m) 136 | elif isinstance(child, DSBN2d_idm): 137 | m = nn.BatchNorm2d(child.num_features) 138 | if use_target: 139 | m.load_state_dict(child.BN_T.state_dict()) 140 | else: 141 | m.load_state_dict(child.BN_S.state_dict()) 142 | setattr(model, child_name, m) 143 | elif isinstance(child, DSBN1d): 144 | m = nn.BatchNorm1d(child.num_features) 145 | if use_target: 146 | m.load_state_dict(child.BN_T.state_dict()) 147 | else: 148 | m.load_state_dict(child.BN_S.state_dict()) 149 | setattr(model, child_name, m) 150 | elif isinstance(child, DSBN1d_idm): 151 | m = nn.BatchNorm1d(child.num_features) 152 | if use_target: 153 | m.load_state_dict(child.BN_T.state_dict()) 154 | else: 155 | m.load_state_dict(child.BN_S.state_dict()) 156 | setattr(model, child_name, m) 157 | else: 158 | convert_bn_idm(child, use_target=use_target) 159 | 160 | -------------------------------------------------------------------------------- /idm/models/idm_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | import torch 5 | 6 | 7 | class IDM(nn.Module): 8 | def __init__(self, channel=64): 9 | super(IDM, self).__init__() 10 | self.channel = channel 11 | self.adaptiveFC1 = nn.Linear(2*channel, channel) 12 | self.adaptiveFC2 = nn.Linear(channel, int(channel/2)) 13 | self.adaptiveFC3 = nn.Linear(int(channel/2), 2) 14 | self.softmax = nn.Softmax(dim=1) 15 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 16 | self.max_pool = nn.AdaptiveMaxPool2d(1) 17 | 18 | def forward(self, x): 19 | 20 | if (not self.training): 21 | return x 22 | 23 | bs = x.size(0) 24 | assert (bs%2==0) 25 | split = torch.split(x, int(bs/2), 0) 26 | x_s = split[0].contiguous() # [B, C, H, W] 27 | x_t = split[1].contiguous() 28 | 29 | x_embd_s = torch.cat((self.avg_pool(x_s.detach()).squeeze(), self.max_pool(x_s.detach()).squeeze()), 1) # [B, 2*C] 30 | x_embd_t = torch.cat((self.avg_pool(x_t.detach()).squeeze(), self.max_pool(x_t.detach()).squeeze()), 1) 31 | 32 | x_embd_s, x_embd_t = self.adaptiveFC1(x_embd_s), self.adaptiveFC1(x_embd_t) # [B, C] 33 | x_embd = x_embd_s+x_embd_t 34 | x_embd = self.adaptiveFC2(x_embd) 35 | lam = self.adaptiveFC3(x_embd) 36 | lam = self.softmax(lam) # [B, 2] 37 | x_inter = lam[:, 0].reshape(-1,1,1,1)*x_s + lam[:, 1].reshape(-1,1,1,1)*x_t 38 | out = torch.cat((x_s, x_t, x_inter), 0) 39 | return out, lam 40 | -------------------------------------------------------------------------------- /idm/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | from collections import OrderedDict 8 | 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152'] 12 | 13 | 14 | class ResNet(nn.Module): 15 | __factory = { 16 | 18: torchvision.models.resnet18, 17 | 34: torchvision.models.resnet34, 18 | 50: 
torchvision.models.resnet50, 19 | 101: torchvision.models.resnet101, 20 | 152: torchvision.models.resnet152, 21 | } 22 | 23 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 24 | num_features=0, norm=False, dropout=0, num_classes=0): 25 | super(ResNet, self).__init__() 26 | self.pretrained = pretrained 27 | self.depth = depth 28 | self.cut_at_pooling = cut_at_pooling 29 | # Construct base (pretrained) resnet 30 | if depth not in ResNet.__factory: 31 | raise KeyError("Unsupported depth:", depth) 32 | resnet = ResNet.__factory[depth](pretrained=pretrained) 33 | resnet.layer4[0].conv2.stride = (1,1) 34 | resnet.layer4[0].downsample[0].stride = (1,1) 35 | # self.base = nn.Sequential( 36 | # resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 37 | # resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 38 | 39 | self.conv = nn.Sequential(OrderedDict([ 40 | ('conv1', resnet.conv1), 41 | ('bn1', resnet.bn1), 42 | ('relu', resnet.relu), 43 | ('maxpool', resnet.maxpool)])) 44 | 45 | self.layer1 = resnet.layer1 46 | self.layer2 = resnet.layer2 47 | self.layer3 = resnet.layer3 48 | self.layer4 = resnet.layer4 49 | 50 | 51 | self.gap = nn.AdaptiveAvgPool2d(1) 52 | 53 | if not self.cut_at_pooling: 54 | self.num_features = num_features 55 | self.norm = norm 56 | self.dropout = dropout 57 | self.has_embedding = num_features > 0 58 | self.num_classes = num_classes 59 | 60 | out_planes = resnet.fc.in_features 61 | 62 | # Append new layers 63 | if self.has_embedding: 64 | self.feat = nn.Linear(out_planes, self.num_features) 65 | self.feat_bn = nn.BatchNorm1d(self.num_features) 66 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 67 | init.constant_(self.feat.bias, 0) 68 | else: 69 | # Change the num_features to CNN output channels 70 | self.num_features = out_planes 71 | self.feat_bn = nn.BatchNorm1d(self.num_features) 72 | self.feat_bn.bias.requires_grad_(False) 73 | if self.dropout > 0: 74 | self.drop = nn.Dropout(self.dropout) 75 | if self.num_classes > 0: 76 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 77 | init.normal_(self.classifier.weight, std=0.001) 78 | init.constant_(self.feat_bn.weight, 1) 79 | init.constant_(self.feat_bn.bias, 0) 80 | 81 | if not pretrained: 82 | self.reset_params() 83 | 84 | def forward(self, x, output_prob=False): 85 | bs = x.size(0) 86 | # x = self.base(x) 87 | x = self.conv(x) 88 | x = self.layer1(x) 89 | x = self.layer2(x) 90 | x = self.layer3(x) 91 | x = self.layer4(x) 92 | 93 | x = self.gap(x) 94 | x = x.view(x.size(0), -1) 95 | 96 | if self.cut_at_pooling: 97 | return x 98 | 99 | if self.has_embedding: 100 | bn_x = self.feat_bn(self.feat(x)) 101 | else: 102 | bn_x = self.feat_bn(x) 103 | 104 | 105 | if (self.training is False and output_prob is False): 106 | bn_x = F.normalize(bn_x) 107 | return bn_x 108 | 109 | if self.norm: 110 | norm_bn_x = F.normalize(bn_x) 111 | elif self.has_embedding: 112 | bn_x = F.relu(bn_x) 113 | 114 | if self.dropout > 0: 115 | bn_x = self.drop(bn_x) 116 | 117 | if self.num_classes > 0: 118 | prob = self.classifier(bn_x) 119 | else: 120 | return bn_x 121 | 122 | if self.norm: 123 | return prob, x, norm_bn_x 124 | else: 125 | return prob, x 126 | 127 | def reset_params(self): 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | init.kaiming_normal_(m.weight, mode='fan_out') 131 | if m.bias is not None: 132 | init.constant_(m.bias, 0) 133 | elif isinstance(m, nn.BatchNorm2d): 134 | init.constant_(m.weight, 1) 135 | init.constant_(m.bias, 0) 136 | elif 
isinstance(m, nn.BatchNorm1d): 137 | init.constant_(m.weight, 1) 138 | init.constant_(m.bias, 0) 139 | elif isinstance(m, nn.Linear): 140 | init.normal_(m.weight, std=0.001) 141 | if m.bias is not None: 142 | init.constant_(m.bias, 0) 143 | 144 | 145 | def resnet18(**kwargs): 146 | return ResNet(18, **kwargs) 147 | 148 | 149 | def resnet34(**kwargs): 150 | return ResNet(34, **kwargs) 151 | 152 | 153 | def resnet50(**kwargs): 154 | return ResNet(50, **kwargs) 155 | 156 | 157 | def resnet101(**kwargs): 158 | return ResNet(101, **kwargs) 159 | 160 | 161 | def resnet152(**kwargs): 162 | return ResNet(152, **kwargs) 163 | -------------------------------------------------------------------------------- /idm/models/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 7 | 8 | 9 | __all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a'] 10 | 11 | 12 | class ResNetIBN(nn.Module): 13 | __factory = { 14 | '50a': resnet50_ibn_a, 15 | '101a': resnet101_ibn_a 16 | } 17 | 18 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 19 | num_features=0, norm=False, dropout=0, num_classes=0): 20 | super(ResNetIBN, self).__init__() 21 | 22 | self.depth = depth 23 | self.pretrained = pretrained 24 | self.cut_at_pooling = cut_at_pooling 25 | 26 | resnet = ResNetIBN.__factory[depth](pretrained=pretrained) 27 | resnet.layer4[0].conv2.stride = (1,1) 28 | resnet.layer4[0].downsample[0].stride = (1,1) 29 | self.base = nn.Sequential( 30 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 31 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 32 | self.gap = nn.AdaptiveAvgPool2d(1) 33 | 34 | if not self.cut_at_pooling: 35 | self.num_features = num_features 36 | self.norm = norm 37 | self.dropout = dropout 38 | self.has_embedding = num_features > 0 39 | self.num_classes = num_classes 40 | 41 | out_planes = resnet.fc.in_features 42 | 43 | # Append new layers 44 | if self.has_embedding: 45 | self.feat = nn.Linear(out_planes, self.num_features) 46 | self.feat_bn = nn.BatchNorm1d(self.num_features) 47 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 48 | init.constant_(self.feat.bias, 0) 49 | else: 50 | # Change the num_features to CNN output channels 51 | self.num_features = out_planes 52 | self.feat_bn = nn.BatchNorm1d(self.num_features) 53 | self.feat_bn.bias.requires_grad_(False) 54 | if self.dropout > 0: 55 | self.drop = nn.Dropout(self.dropout) 56 | if self.num_classes > 0: 57 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 58 | init.normal_(self.classifier.weight, std=0.001) 59 | init.constant_(self.feat_bn.weight, 1) 60 | init.constant_(self.feat_bn.bias, 0) 61 | 62 | if not pretrained: 63 | self.reset_params() 64 | 65 | def forward(self, x): 66 | x = self.base(x) 67 | 68 | x = self.gap(x) 69 | x = x.view(x.size(0), -1) 70 | 71 | if self.cut_at_pooling: 72 | return x 73 | 74 | if self.has_embedding: 75 | bn_x = self.feat_bn(self.feat(x)) 76 | else: 77 | bn_x = self.feat_bn(x) 78 | 79 | if self.training is False: 80 | bn_x = F.normalize(bn_x) 81 | return bn_x 82 | 83 | if self.norm: 84 | bn_x = F.normalize(bn_x) 85 | elif self.has_embedding: 86 | bn_x = F.relu(bn_x) 87 | 88 | if self.dropout > 0: 89 | bn_x = self.drop(bn_x) 90 | 91 | if self.num_classes > 0: 92 | prob = self.classifier(bn_x) 93 | else: 94 | 
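# no classifier head configured: return the raw embedding instead of logits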
return bn_x 95 | 96 | return prob, x 97 | 98 | def reset_params(self): 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | init.kaiming_normal_(m.weight, mode='fan_out') 102 | if m.bias is not None: 103 | init.constant_(m.bias, 0) 104 | elif isinstance(m, nn.BatchNorm2d): 105 | init.constant_(m.weight, 1) 106 | init.constant_(m.bias, 0) 107 | elif isinstance(m, nn.BatchNorm1d): 108 | init.constant_(m.weight, 1) 109 | init.constant_(m.bias, 0) 110 | elif isinstance(m, nn.Linear): 111 | init.normal_(m.weight, std=0.001) 112 | if m.bias is not None: 113 | init.constant_(m.bias, 0) 114 | 115 | 116 | def resnet_ibn50a(**kwargs): 117 | return ResNetIBN('50a', **kwargs) 118 | 119 | 120 | def resnet_ibn101a(**kwargs): 121 | return ResNetIBN('101a', **kwargs) 122 | -------------------------------------------------------------------------------- /idm/models/resnet_ibn_a.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | __all__ = ['ResNet', 'resnet50_ibn_a', 'resnet101_ibn_a'] 7 | 8 | 9 | model_urls = { 10 | 'ibn_resnet50a': './logs/pretrained/resnet50_ibn_a.pth.tar', 11 | 'ibn_resnet101a': './logs/pretrained/resnet101_ibn_a.pth.tar', 12 | } 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | "3x3 convolution with padding" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, downsample=None): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = conv3x3(inplanes, planes, stride) 27 | self.bn1 = nn.BatchNorm2d(planes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class IBN(nn.Module): 54 | def __init__(self, planes): 55 | super(IBN, self).__init__() 56 | half1 = int(planes/2) 57 | self.half = half1 58 | half2 = planes - half1 59 | self.IN = nn.InstanceNorm2d(half1, affine=True) 60 | self.BN = nn.BatchNorm2d(half2) 61 | 62 | def forward(self, x): 63 | split = torch.split(x, self.half, 1) 64 | out1 = self.IN(split[0].contiguous()) 65 | out2 = self.BN(split[1].contiguous()) 66 | out = torch.cat((out1, out2), 1) 67 | return out 68 | 69 | 70 | class Bottleneck(nn.Module): 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 74 | super(Bottleneck, self).__init__() 75 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 76 | if ibn: 77 | self.bn1 = IBN(planes) 78 | else: 79 | self.bn1 = nn.BatchNorm2d(planes) 80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 81 | padding=1, bias=False) 82 | self.bn2 = nn.BatchNorm2d(planes) 83 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 84 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = 
self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | residual = self.downsample(x) 105 | 106 | out += residual 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, block, layers, num_classes=1000): 115 | scale = 64 116 | self.inplanes = scale 117 | super(ResNet, self).__init__() 118 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 119 | bias=False) 120 | self.bn1 = nn.BatchNorm2d(scale) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 123 | self.layer1 = self._make_layer(block, scale, layers[0]) 124 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) 125 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) 126 | self.layer4 = self._make_layer(block, scale*8, layers[3], stride=2) 127 | self.avgpool = nn.AvgPool2d(7) 128 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. / n)) 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | elif isinstance(m, nn.InstanceNorm2d): 138 | m.weight.data.fill_(1) 139 | m.bias.data.zero_() 140 | 141 | def _make_layer(self, block, planes, blocks, stride=1): 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | nn.Conv2d(self.inplanes, planes * block.expansion, 146 | kernel_size=1, stride=stride, bias=False), 147 | nn.BatchNorm2d(planes * block.expansion), 148 | ) 149 | 150 | layers = [] 151 | ibn = True 152 | if planes == 512: 153 | ibn = False 154 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 155 | self.inplanes = planes * block.expansion 156 | for i in range(1, blocks): 157 | layers.append(block(self.inplanes, planes, ibn)) 158 | 159 | return nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | x = self.conv1(x) 163 | x = self.bn1(x) 164 | x = self.relu(x) 165 | x = self.maxpool(x) 166 | 167 | x = self.layer1(x) 168 | x = self.layer2(x) 169 | x = self.layer3(x) 170 | x = self.layer4(x) 171 | 172 | x = self.avgpool(x) 173 | x = x.view(x.size(0), -1) 174 | x = self.fc(x) 175 | 176 | return x 177 | 178 | 179 | def resnet50_ibn_a(pretrained=False, **kwargs): 180 | """Constructs a ResNet-50 model. 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | state_dict = torch.load(model_urls['ibn_resnet50a'], map_location=torch.device('cpu'))['state_dict'] 187 | state_dict = remove_module_key(state_dict) 188 | model.load_state_dict(state_dict) 189 | return model 190 | 191 | 192 | def resnet101_ibn_a(pretrained=False, **kwargs): 193 | """Constructs a ResNet-101 model. 
194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 198 | if pretrained: 199 | state_dict = torch.load(model_urls['ibn_resnet101a'], map_location=torch.device('cpu'))['state_dict'] 200 | state_dict = remove_module_key(state_dict) 201 | model.load_state_dict(state_dict) 202 | return model 203 | 204 | 205 | def remove_module_key(state_dict): 206 | for key in list(state_dict.keys()): 207 | if 'module' in key: 208 | state_dict[key.replace('module.','')] = state_dict.pop(key) 209 | return state_dict 210 | -------------------------------------------------------------------------------- /idm/models/resnet_ibn_idm.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | from collections import OrderedDict 9 | 10 | from .idm_module import IDM 11 | from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 12 | 13 | __all__ = ['ResNetIBN', 'resnet_ibn50a_idm', 'resnet_ibn101a_idm'] 14 | 15 | 16 | class ResNetIBN(nn.Module): 17 | __factory = { 18 | '50a': resnet50_ibn_a, 19 | '101a': resnet101_ibn_a 20 | } 21 | 22 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 23 | num_features=0, norm=False, dropout=0, num_classes=0): 24 | super(ResNetIBN, self).__init__() 25 | 26 | self.depth = depth 27 | self.pretrained = pretrained 28 | self.cut_at_pooling = cut_at_pooling 29 | 30 | resnet = ResNetIBN.__factory[depth](pretrained=pretrained) 31 | resnet.layer4[0].conv2.stride = (1, 1) 32 | resnet.layer4[0].downsample[0].stride = (1, 1) 33 | 34 | self.conv = nn.Sequential(OrderedDict([ 35 | ('conv1', resnet.conv1), 36 | ('bn1', resnet.bn1), 37 | ('relu', resnet.relu), 38 | ('maxpool', resnet.maxpool)])) 39 | 40 | self.layer1 = resnet.layer1 41 | self.layer2 = resnet.layer2 42 | self.layer3 = resnet.layer3 43 | self.layer4 = resnet.layer4 44 | 45 | self.gap = nn.AdaptiveAvgPool2d(1) 46 | 47 | self.idm1 = IDM(channel=64) 48 | self.idm2 = IDM(channel=256) 49 | self.idm3 = IDM(channel=512) 50 | self.idm4 = IDM(channel=1024) 51 | self.idm5 = IDM(channel=2048) 52 | 53 | if not self.cut_at_pooling: 54 | self.num_features = num_features 55 | self.norm = norm 56 | self.dropout = dropout 57 | self.has_embedding = num_features > 0 58 | self.num_classes = num_classes 59 | 60 | out_planes = resnet.fc.in_features 61 | 62 | # Append new layers 63 | if self.has_embedding: 64 | self.feat = nn.Linear(out_planes, self.num_features) 65 | self.feat_bn = nn.BatchNorm1d(self.num_features) 66 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 67 | init.constant_(self.feat.bias, 0) 68 | else: 69 | # Change the num_features to CNN output channels 70 | self.num_features = out_planes 71 | self.feat_bn = nn.BatchNorm1d(self.num_features) 72 | self.feat_bn.bias.requires_grad_(False) 73 | if self.dropout > 0: 74 | self.drop = nn.Dropout(self.dropout) 75 | if self.num_classes > 0: 76 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 77 | init.normal_(self.classifier.weight, std=0.001) 78 | init.constant_(self.feat_bn.weight, 1) 79 | init.constant_(self.feat_bn.bias, 0) 80 | 81 | if not pretrained: 82 | self.reset_params() 83 | 84 | def forward(self, x, stage=0): 85 | x = self.conv(x) 86 | if stage == 0 and self.training: 87 | x, attention_lam = self.idm1(x) 88 | x = self.layer1(x) 
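# `stage` picks which of the five IDM modules fires: in training mode the chosen module splits the batch into its source and target halves, predicts per-sample mixing weights, and returns the source, target, and intermediate features concatenated along the batch dimension (per-domain batch B: 2B -> 3B), together with `attention_lam` ([B, 2], rows summing to 1).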
89 | if stage == 1 and self.training: 90 | x, attention_lam = self.idm2(x) 91 | x = self.layer2(x) 92 | if stage == 2 and self.training: 93 | x, attention_lam = self.idm3(x) 94 | x = self.layer3(x) 95 | if stage == 3 and self.training: 96 | x, attention_lam = self.idm4(x) 97 | x = self.layer4(x) 98 | if stage == 4 and self.training: 99 | x, attention_lam = self.idm5(x) 100 | 101 | x = self.gap(x) 102 | x = x.view(x.size(0), -1) 103 | 104 | if self.cut_at_pooling: 105 | return x 106 | 107 | if self.has_embedding: 108 | bn_x = self.feat_bn(self.feat(x)) 109 | else: 110 | bn_x = self.feat_bn(x) 111 | 112 | if self.training is False: 113 | bn_x = F.normalize(bn_x) 114 | return bn_x 115 | 116 | if self.norm: 117 | bn_x = F.normalize(bn_x) 118 | elif self.has_embedding: 119 | bn_x = F.relu(bn_x) 120 | 121 | if self.dropout > 0: 122 | bn_x = self.drop(bn_x) 123 | 124 | if self.num_classes > 0: 125 | prob = self.classifier(bn_x) 126 | else: 127 | return bn_x 128 | 129 | return prob, x, attention_lam 130 | 131 | def reset_params(self): 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | init.kaiming_normal_(m.weight, mode='fan_out') 135 | if m.bias is not None: 136 | init.constant_(m.bias, 0) 137 | elif isinstance(m, nn.BatchNorm2d): 138 | init.constant_(m.weight, 1) 139 | init.constant_(m.bias, 0) 140 | elif isinstance(m, nn.BatchNorm1d): 141 | init.constant_(m.weight, 1) 142 | init.constant_(m.bias, 0) 143 | elif isinstance(m, nn.Linear): 144 | init.normal_(m.weight, std=0.001) 145 | if m.bias is not None: 146 | init.constant_(m.bias, 0) 147 | 148 | 149 | def resnet_ibn50a_idm(**kwargs): 150 | return ResNetIBN('50a', **kwargs) 151 | 152 | 153 | def resnet_ibn101a_idm(**kwargs): 154 | return ResNetIBN('101a', **kwargs) 155 | 156 | -------------------------------------------------------------------------------- /idm/models/resnet_idm.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | from collections import OrderedDict 8 | from .idm_module import IDM 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18_idm', 'resnet34_idm', 'resnet50_idm', 12 | 'resnet101_idm', 'resnet152_idm'] 13 | 14 | 15 | class ResNet(nn.Module): 16 | __factory = { 17 | 18: torchvision.models.resnet18, 18 | 34: torchvision.models.resnet34, 19 | 50: torchvision.models.resnet50, 20 | 101: torchvision.models.resnet101, 21 | 152: torchvision.models.resnet152, 22 | } 23 | 24 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 25 | num_features=0, norm=False, dropout=0, num_classes=0): 26 | super(ResNet, self).__init__() 27 | self.pretrained = pretrained 28 | self.depth = depth 29 | self.cut_at_pooling = cut_at_pooling 30 | # Construct base (pretrained) resnet 31 | if depth not in ResNet.__factory: 32 | raise KeyError("Unsupported depth:", depth) 33 | resnet = ResNet.__factory[depth](pretrained=pretrained) 34 | resnet.layer4[0].conv2.stride = (1,1) 35 | resnet.layer4[0].downsample[0].stride = (1,1) 36 | 37 | self.conv = nn.Sequential(OrderedDict([ 38 | ('conv1', resnet.conv1), 39 | ('bn1', resnet.bn1), 40 | ('relu', resnet.relu), 41 | ('maxpool', resnet.maxpool)])) 42 | 43 | self.layer1 = resnet.layer1 44 | self.layer2 = resnet.layer2 45 | self.layer3 = resnet.layer3 46 | self.layer4 = resnet.layer4 47 | self.gap = nn.AdaptiveAvgPool2d(1) 48 | 49 | self.idm1 = IDM(channel=64) 50 | self.idm2 = 
IDM(channel=256) 51 | self.idm3 = IDM(channel=512) 52 | self.idm4 = IDM(channel=1024) 53 | self.idm5 = IDM(channel=2048) 54 | 55 | if not self.cut_at_pooling: 56 | self.num_features = num_features 57 | self.norm = norm 58 | self.dropout = dropout 59 | self.has_embedding = num_features > 0 60 | self.num_classes = num_classes 61 | 62 | out_planes = resnet.fc.in_features 63 | 64 | # Append new layers 65 | if self.has_embedding: 66 | self.feat = nn.Linear(out_planes, self.num_features) 67 | self.feat_bn = nn.BatchNorm1d(self.num_features) 68 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 69 | init.constant_(self.feat.bias, 0) 70 | else: 71 | # Change the num_features to CNN output channels 72 | self.num_features = out_planes 73 | self.feat_bn = nn.BatchNorm1d(self.num_features) 74 | self.feat_bn.bias.requires_grad_(False) 75 | if self.dropout > 0: 76 | self.drop = nn.Dropout(self.dropout) 77 | if self.num_classes > 0: 78 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 79 | init.normal_(self.classifier.weight, std=0.001) 80 | init.constant_(self.feat_bn.weight, 1) 81 | init.constant_(self.feat_bn.bias, 0) 82 | 83 | if not pretrained: 84 | self.reset_params() 85 | 86 | def forward(self, x, output_prob=False, stage=0): 87 | bs = x.size(0) 88 | # x = self.base(x) 89 | x = self.conv(x) 90 | if stage==0 and self.training: 91 | x, attention_lam = self.idm1(x) 92 | x = self.layer1(x) 93 | if stage==1 and self.training: 94 | x, attention_lam = self.idm2(x) 95 | x = self.layer2(x) 96 | if stage==2 and self.training: 97 | x, attention_lam = self.idm3(x) 98 | x = self.layer3(x) 99 | if stage==3 and self.training: 100 | x, attention_lam = self.idm4(x) 101 | x = self.layer4(x) 102 | if stage==4 and self.training: 103 | x, attention_lam = self.idm5(x) 104 | 105 | x = self.gap(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | if self.cut_at_pooling: 109 | return x 110 | 111 | if self.has_embedding: 112 | bn_x = self.feat_bn(self.feat(x)) 113 | else: 114 | bn_x = self.feat_bn(x) 115 | 116 | if (self.training is False and output_prob is False): 117 | bn_x = F.normalize(bn_x) 118 | return bn_x 119 | 120 | if self.norm: 121 | norm_bn_x = F.normalize(bn_x) 122 | elif self.has_embedding: 123 | bn_x = F.relu(bn_x) 124 | 125 | if self.dropout > 0: 126 | bn_x = self.drop(bn_x) 127 | 128 | if self.num_classes > 0: 129 | prob = self.classifier(bn_x) 130 | else: 131 | return bn_x 132 | 133 | if self.norm: 134 | return prob, x, norm_bn_x 135 | 136 | else: 137 | return prob, x, attention_lam 138 | 139 | def reset_params(self): 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | init.kaiming_normal_(m.weight, mode='fan_out') 143 | if m.bias is not None: 144 | init.constant_(m.bias, 0) 145 | elif isinstance(m, nn.BatchNorm2d): 146 | init.constant_(m.weight, 1) 147 | init.constant_(m.bias, 0) 148 | elif isinstance(m, nn.BatchNorm1d): 149 | init.constant_(m.weight, 1) 150 | init.constant_(m.bias, 0) 151 | elif isinstance(m, nn.Linear): 152 | init.normal_(m.weight, std=0.001) 153 | if m.bias is not None: 154 | init.constant_(m.bias, 0) 155 | 156 | 157 | def resnet18_idm(**kwargs): 158 | return ResNet(18, **kwargs) 159 | 160 | 161 | def resnet34_idm(**kwargs): 162 | return ResNet(34, **kwargs) 163 | 164 | 165 | def resnet50_idm(**kwargs): 166 | return ResNet(50, **kwargs) 167 | 168 | 169 | def resnet101_idm(**kwargs): 170 | return ResNet(101, **kwargs) 171 | 172 | 173 | def resnet152_idm(**kwargs): 174 | return ResNet(152, **kwargs) 175 | 176 | 
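A minimal usage sketch for the module above (an editorial example, not a file from the repository): it assumes an even batch laid out as [source half; target half], an example head of `num_classes=1000`, CPU tensors, and the repo's usual 256x128 person crops, and simply probes the train/eval forward contracts of `resnet50_idm`.

```python
import torch
from idm.models.resnet_idm import resnet50_idm

# Build the IDM-enabled backbone (num_classes and image size are example choices).
model = resnet50_idm(pretrained=False, num_classes=1000)

model.train()
imgs = torch.randn(8, 3, 256, 128)       # [4 source images; 4 target images]
prob, feats, lam = model(imgs, stage=0)  # stage=0: mix right after the stem conv
print(prob.shape)   # torch.Size([12, 1000]) -> source, target, intermediate
print(feats.shape)  # torch.Size([12, 2048])
print(lam.shape)    # torch.Size([4, 2]), each row sums to 1

model.eval()        # IDM modules are bypassed at test time
print(model(imgs).shape)  # torch.Size([8, 2048]), L2-normalized embeddings
```

The batch grows from 8 to 12 during training because the selected IDM module appends the mixed intermediate-domain features; the trainer in `idm/trainers.py` relies on exactly this three-way split when it carves `feats` back apart.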
-------------------------------------------------------------------------------- /idm/models/xbm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | 10 | 11 | class XBM: 12 | def __init__(self, memory_size, feature_size): 13 | self.K = memory_size 14 | self.D = feature_size 15 | self.feats = torch.zeros(self.K, self.D).cuda() 16 | self.targets = torch.zeros(self.K, dtype=torch.long).cuda() 17 | self.ptr = 0 18 | 19 | @property 20 | def is_full(self): 21 | return self.targets[-1].item() != 0  # heuristic: the last slot keeps its zero init until the ring buffer wraps 22 | 23 | def get(self): 24 | if self.is_full: 25 | return self.feats, self.targets 26 | else: 27 | return self.feats[:self.ptr], self.targets[:self.ptr] 28 | 29 | def enqueue_dequeue(self, feats, targets): 30 | q_size = len(targets) 31 | if self.ptr + q_size > self.K: 32 | self.feats[-q_size:] = feats 33 | self.targets[-q_size:] = targets 34 | self.ptr = 0 35 | else: 36 | self.feats[self.ptr: self.ptr + q_size] = feats 37 | self.targets[self.ptr: self.ptr + q_size] = targets 38 | self.ptr += q_size 39 | 40 | def clean_target_domain(self, source_classes, target_classes):  # drop target-domain entries and compact the memory 41 | empty_feats = torch.zeros(self.K, self.D).cuda() 42 | empty_targets = torch.zeros(self.K, dtype=torch.long).cuda() 43 | j = 0 44 | for i in range(self.K): 45 | if self.targets[i] >= source_classes and self.targets[i] < source_classes + target_classes: 46 | continue 47 | else: 48 | empty_feats[j] = self.feats[i] 49 | empty_targets[j] = self.targets[i] 50 | j += 1 51 | self.feats = empty_feats 52 | self.targets = empty_targets 53 | if j > 0: 54 | self.ptr = j 55 | 56 | class MultiLabelXBM: 57 | def __init__(self, memory_size, feature_size): 58 | self.K = memory_size 59 | self.D = feature_size 60 | self.feats = torch.zeros(self.K, self.D).cuda() 61 | self.targets_s = torch.zeros(self.K, dtype=torch.long).cuda() 62 | self.targets_t = torch.zeros(self.K, dtype=torch.long).cuda() 63 | self.ptr = 0 64 | 65 | @property 66 | def is_full(self): 67 | return self.targets_s[-1].item() != 0 68 | 69 | def get(self): 70 | if self.is_full: 71 | return self.feats, self.targets_s, self.targets_t 72 | else: 73 | return self.feats[:self.ptr], self.targets_s[:self.ptr], self.targets_t[:self.ptr] 74 | 75 | def enqueue_dequeue(self, feats, source_targets, target_targets): 76 | q_size = len(source_targets) 77 | if self.ptr + q_size > self.K: 78 | self.feats[-q_size:] = feats 79 | self.targets_s[-q_size:] = source_targets 80 | self.targets_t[-q_size:] = target_targets 81 | self.ptr = 0 82 | else: 83 | self.feats[self.ptr: self.ptr + q_size] = feats 84 | self.targets_s[self.ptr: self.ptr + q_size] = source_targets 85 | self.targets_t[self.ptr: self.ptr + q_size] = target_targets 86 | self.ptr += q_size 87 | 88 | def clean_target_domain(self, source_classes, target_classes):  # filter on the source-label buffer and keep both label buffers aligned 89 | empty_feats = torch.zeros(self.K, self.D).cuda() 90 | empty_targets_s = torch.zeros(self.K, dtype=torch.long).cuda() 91 | empty_targets_t = torch.zeros(self.K, dtype=torch.long).cuda() 92 | j = 0 93 | for i in range(self.K): 94 | if self.targets_s[i] >= source_classes and self.targets_s[i] < source_classes + target_classes: 95 | continue 96 | else: 97 | empty_feats[j] = self.feats[i] 98 | empty_targets_s[j] = self.targets_s[i] 99 | empty_targets_t[j] = self.targets_t[i] 100 | j += 1 101 | self.feats, self.targets_s, self.targets_t = empty_feats, empty_targets_s, empty_targets_t 102 | if j > 0: 103 | self.ptr = j 104 | 105 | 106 | class LogitsXBM: 107 | def __init__(self, memory_size, logits_size): 108 | self.K = memory_size 109 | self.D = logits_size 110 | self.feats = torch.zeros(self.K, self.D).cuda() 111 | self.targets = torch.zeros(self.K, dtype=torch.long).cuda() 112 | self.ptr = 0 113 | 114 | @property 115 | def is_full(self): 116 | return self.targets[-1].item() != 0 117 | 118 | def get(self): 119 | if self.is_full: 120 | return self.feats, self.targets 121 | else: 122 | return 
self.feats[:self.ptr], self.targets[:self.ptr] 123 | 124 | def enqueue_dequeue(self, feats, targets): 125 | q_size = len(targets) 126 | if self.ptr + q_size > self.K: 127 | self.feats[-q_size:] = feats 128 | self.targets[-q_size:] = targets 129 | self.ptr = 0 130 | else: 131 | self.feats[self.ptr: self.ptr + q_size] = feats 132 | self.targets[self.ptr: self.ptr + q_size] = targets 133 | self.ptr += q_size 134 | 135 | def clean_target_domain(self, source_classes, target_classes):  # drop target-domain entries and compact the memory 136 | empty_feats = torch.zeros(self.K, self.D).cuda() 137 | empty_targets = torch.zeros(self.K, dtype=torch.long).cuda() 138 | j = 0 139 | for i in range(self.K): 140 | if self.targets[i] >= source_classes and self.targets[i] < source_classes + target_classes: 141 | continue 142 | else: 143 | empty_feats[j] = self.feats[i] 144 | empty_targets[j] = self.targets[i] 145 | j += 1 146 | self.feats = empty_feats 147 | self.targets = empty_targets 148 | if j > 0: 149 | self.ptr = j -------------------------------------------------------------------------------- /idm/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import torch 4 | from .utils.meters import AverageMeter 5 | from .evaluation_metrics import accuracy 6 | from .loss import TripletLoss, CrossEntropyLabelSmooth, TripletLossXBM, DivLoss, BridgeFeatLoss, BridgeProbLoss 7 | 8 | 9 | 10 | class Baseline_Trainer(object): 11 | def __init__(self, model, xbm, num_classes, margin=None): 12 | super(Baseline_Trainer, self).__init__() 13 | self.model = model 14 | self.xbm = xbm 15 | self.num_classes = num_classes 16 | self.criterion_ce = CrossEntropyLabelSmooth(num_classes).cuda() 17 | self.criterion_tri = TripletLoss(margin=margin).cuda() 18 | self.criterion_tri_xbm = TripletLossXBM(margin=margin) 19 | 20 | def train(self, epoch, data_loader_source, data_loader_target, source_classes, target_classes, 21 | optimizer, print_freq=50, train_iters=400, use_xbm=False): 22 | self.criterion_ce = CrossEntropyLabelSmooth(source_classes + target_classes).cuda() 23 | 24 | self.model.train() 25 | 26 | batch_time = AverageMeter() 27 | data_time = AverageMeter() 28 | losses = AverageMeter() 29 | losses_ce = AverageMeter() 30 | losses_tri = AverageMeter() 31 | losses_xbm = AverageMeter() 32 | precisions_s = AverageMeter() 33 | precisions_t = AverageMeter() 34 | 35 | end = time.time() 36 | for i in range(train_iters): 37 | # load data 38 | source_inputs = data_loader_source.next() 39 | target_inputs = data_loader_target.next() 40 | data_time.update(time.time() - end) 41 | 42 | # process inputs 43 | s_inputs, s_targets, _ = self._parse_data(source_inputs) 44 | t_inputs, t_targets, t_indexes = self._parse_data(target_inputs) 45 | 46 | # arrange batch for domain-specific BN 47 | device_num = torch.cuda.device_count() 48 | B, C, H, W = s_inputs.size() 49 | 50 | def reshape(inputs): 51 | return inputs.view(device_num, -1, C, H, W) 52 | 53 | s_inputs, t_inputs = reshape(s_inputs), reshape(t_inputs) 54 | inputs = torch.cat((s_inputs, t_inputs), 1).view(-1, C, H, W) 55 | 56 | targets = torch.cat((s_targets.view(device_num, -1), t_targets.view(device_num, -1)), 1) 57 | targets = targets.view(-1) 58 | # forward 59 | prob, feats = self._forward(inputs) 60 | prob = prob[:, 0:source_classes + target_classes] 61 | 62 | # split feats 63 | ori_feats = feats.view(device_num, -1, feats.size(-1)) 64 | feats_s, feats_t = ori_feats.split(ori_feats.size(1) // 2, dim=1) 65 | ori_feats = torch.cat((feats_s, feats_t), 1).view(-1, ori_feats.size(-1)) 66 | 67 | # classification+triplet 68 | loss_ce = self.criterion_ce(prob, targets) 69 | loss_tri = self.criterion_tri(ori_feats, targets) 70 | 71 | # enqueue and dequeue for 
xbm 72 | if use_xbm: 73 | self.xbm.enqueue_dequeue(ori_feats.detach(), targets.detach()) 74 | xbm_feats, xbm_targets = self.xbm.get() 75 | loss_xbm = self.criterion_tri_xbm(ori_feats, targets, xbm_feats, xbm_targets) 76 | losses_xbm.update(loss_xbm.item()) 77 | loss = loss_ce + loss_tri + loss_xbm 78 | else: 79 | loss = loss_ce + loss_tri 80 | 81 | optimizer.zero_grad() 82 | loss.backward() 83 | optimizer.step() 84 | 85 | ori_prob = prob.view(device_num, -1, prob.size(-1)) 86 | prob_s, prob_t = ori_prob.split(ori_prob.size(1) // 2, dim=1) 87 | prob_s, prob_t = prob_s.contiguous(), prob_t.contiguous() 88 | prec_s, = accuracy(prob_s.view(-1, prob_s.size(-1)).data, s_targets.data) 89 | prec_t, = accuracy(prob_t.view(-1, prob_s.size(-1)).data, t_targets.data) 90 | 91 | losses.update(loss.item()) 92 | losses_ce.update(loss_ce.item()) 93 | losses_tri.update(loss_tri.item()) 94 | precisions_s.update(prec_s[0]) 95 | precisions_t.update(prec_t[0]) 96 | 97 | # print log 98 | batch_time.update(time.time() - end) 99 | end = time.time() 100 | 101 | if (i + 1) % print_freq == 0: 102 | 103 | if use_xbm: 104 | print('Epoch: [{}][{}/{}]\t' 105 | 'Time {:.3f} ({:.3f}) ' 106 | 'Data {:.3f} ({:.3f}) ' 107 | 'Loss {:.3f} ({:.3f}) ' 108 | 'Loss_ce {:.3f} ({:.3f}) ' 109 | 'Loss_tri {:.3f} ({:.3f}) ' 110 | 'Loss_xbm {:.3f} ({:.3f}) ' 111 | 'Prec_s {:.2%} ({:.2%}) ' 112 | 'Prec_t {:.2%} ({:.2%}) ' 113 | .format(epoch, i + 1, len(data_loader_target), 114 | batch_time.val, batch_time.avg, 115 | data_time.val, data_time.avg, 116 | losses.val, losses.avg, 117 | losses_ce.val, losses_ce.avg, 118 | losses_tri.val, losses_tri.avg, 119 | losses_xbm.val, losses_xbm.avg, 120 | precisions_s.val, precisions_s.avg, 121 | precisions_t.val, precisions_t.avg 122 | )) 123 | else: 124 | print('Epoch: [{}][{}/{}]\t' 125 | 'Time {:.3f} ({:.3f}) ' 126 | 'Data {:.3f} ({:.3f}) ' 127 | 'Loss {:.3f} ({:.3f}) ' 128 | 'Loss_ce {:.3f} ({:.3f}) ' 129 | 'Loss_tri {:.3f} ({:.3f}) ' 130 | 'Prec_s {:.2%} ({:.2%}) ' 131 | 'Prec_t {:.2%} ({:.2%}) ' 132 | .format(epoch, i + 1, len(data_loader_target), 133 | batch_time.val, batch_time.avg, 134 | data_time.val, data_time.avg, 135 | losses.val, losses.avg, 136 | losses_ce.val, losses_ce.avg, 137 | losses_tri.val, losses_tri.avg, 138 | precisions_s.val, precisions_s.avg, 139 | precisions_t.val, precisions_t.avg 140 | )) 141 | 142 | def _parse_data(self, inputs): 143 | imgs, _, pids, _, indexes = inputs 144 | return imgs.cuda(), pids.cuda(), indexes.cuda() 145 | 146 | def _forward(self, inputs): 147 | return self.model(inputs) 148 | 149 | 150 | class IDM_Trainer(object): 151 | def __init__(self, model, xbm, num_classes, margin=None, mu1=1.0, mu2=1.0, mu3=1.0): 152 | super(IDM_Trainer, self).__init__() 153 | self.model = model 154 | self.xbm = xbm 155 | self.mu1 = mu1 156 | self.mu2 = mu2 157 | self.mu3 = mu3 158 | self.num_classes = num_classes 159 | self.criterion_ce = BridgeProbLoss(num_classes).cuda() 160 | self.criterion_tri = TripletLoss(margin=margin).cuda() 161 | self.criterion_tri_xbm = TripletLossXBM(margin=margin) 162 | self.criterion_bridge_feat = BridgeFeatLoss() 163 | self.criterion_diverse = DivLoss() 164 | 165 | def train(self, epoch, data_loader_source, data_loader_target, source_classes, target_classes, 166 | optimizer, print_freq=50, train_iters=400, use_xbm=False, stage=0): 167 | 168 | self.criterion_ce = BridgeProbLoss(source_classes + target_classes).cuda() 169 | 170 | self.model.train() 171 | 172 | batch_time = AverageMeter() 173 | data_time = AverageMeter() 174 | losses = 
AverageMeter() 175 | losses_ce = AverageMeter() 176 | losses_tri = AverageMeter() 177 | losses_xbm = AverageMeter() 178 | losses_bridge_prob = AverageMeter() 179 | losses_bridge_feat = AverageMeter() 180 | losses_diverse = AverageMeter() 181 | 182 | precisions_s = AverageMeter() 183 | precisions_t = AverageMeter() 184 | 185 | end = time.time() 186 | for i in range(train_iters): 187 | # load data 188 | source_inputs = data_loader_source.next() 189 | target_inputs = data_loader_target.next() 190 | data_time.update(time.time() - end) 191 | 192 | # process inputs 193 | s_inputs, s_targets, _ = self._parse_data(source_inputs) 194 | t_inputs, t_targets, t_indexes = self._parse_data(target_inputs) 195 | 196 | # arrange batch for domain-specific BN 197 | device_num = torch.cuda.device_count() 198 | B, C, H, W = s_inputs.size() 199 | 200 | def reshape(inputs): 201 | return inputs.view(device_num, -1, C, H, W) 202 | 203 | s_inputs, t_inputs = reshape(s_inputs), reshape(t_inputs) 204 | inputs = torch.cat((s_inputs, t_inputs), 1).view(-1, C, H, W) 205 | 206 | targets = torch.cat((s_targets.view(device_num, -1), t_targets.view(device_num, -1)), 1) 207 | targets = targets.view(-1) 208 | # forward 209 | prob, feats, attention_lam= self._forward(inputs, stage) # attention_lam: [B, 2] 210 | prob = prob[:, 0:source_classes + target_classes] 211 | 212 | # split feats 213 | ori_feats = feats.view(device_num, -1, feats.size(-1)) 214 | feats_s, feats_t, feats_mixed = ori_feats.split(ori_feats.size(1) // 3, dim=1) 215 | ori_feats = torch.cat((feats_s, feats_t), 1).view(-1, ori_feats.size(-1)) 216 | 217 | # classification+triplet 218 | loss_ce, loss_bridge_prob = self.criterion_ce(prob, targets, attention_lam[:,0].detach()) 219 | loss_tri = self.criterion_tri(ori_feats, targets) 220 | loss_diverse = self.criterion_diverse(attention_lam) 221 | 222 | feats_s = feats_s.contiguous().view(-1, feats.size(-1)) 223 | feats_t = feats_t.contiguous().view(-1, feats.size(-1)) 224 | feats_mixed = feats_mixed.contiguous().view(-1, feats.size(-1)) 225 | 226 | loss_bridge_feat = self.criterion_bridge_feat(feats_s, feats_t, feats_mixed, attention_lam) 227 | 228 | 229 | # enqueue and dequeue for xbm 230 | if use_xbm: 231 | self.xbm.enqueue_dequeue(ori_feats.detach(), targets.detach()) 232 | xbm_feats, xbm_targets = self.xbm.get() 233 | loss_xbm = self.criterion_tri_xbm(ori_feats, targets, xbm_feats, xbm_targets) 234 | losses_xbm.update(loss_xbm.item()) 235 | loss = (1.-self.mu1) * loss_ce + loss_tri + loss_xbm + \ 236 | self.mu1 * loss_bridge_prob + self.mu2 * loss_bridge_feat + self.mu3 * loss_diverse 237 | else: 238 | loss = (1.-self.mu1) * loss_ce + loss_tri + \ 239 | self.mu1 * loss_bridge_prob + self.mu2 * loss_bridge_feat + self.mu3 * loss_diverse 240 | 241 | optimizer.zero_grad() 242 | loss.backward() 243 | optimizer.step() 244 | 245 | ori_prob = prob.view(device_num, -1, prob.size(-1)) 246 | prob_s, prob_t, _ = ori_prob.split(ori_prob.size(1) // 3, dim=1) 247 | 248 | prob_s, prob_t = prob_s.contiguous(), prob_t.contiguous() 249 | prec_s, = accuracy(prob_s.view(-1, prob_s.size(-1)).data, s_targets.data) 250 | prec_t, = accuracy(prob_t.view(-1, prob_s.size(-1)).data, t_targets.data) 251 | 252 | losses.update(loss.item()) 253 | losses_ce.update(loss_ce.item()) 254 | losses_tri.update(loss_tri.item()) 255 | losses_bridge_prob.update(loss_bridge_prob.item()) 256 | losses_bridge_feat.update(loss_bridge_feat.item()) 257 | losses_diverse.update(loss_diverse.item()) 258 | 259 | precisions_s.update(prec_s[0]) 260 | 
precisions_t.update(prec_t[0]) 261 | 262 | # print log 263 | batch_time.update(time.time() - end) 264 | end = time.time() 265 | 266 | if (i + 1) % print_freq == 0: 267 | 268 | if use_xbm: 269 | print('Epoch: [{}][{}/{}]\t' 270 | 'Time {:.3f} ({:.3f}) ' 271 | 'Data {:.3f} ({:.3f}) ' 272 | 'Loss {:.3f} ({:.3f}) ' 273 | 'Loss_ce {:.3f} ({:.3f}) ' 274 | 'Loss_tri {:.3f} ({:.3f}) ' 275 | 'Loss_xbm {:.3f} ({:.3f}) ' 276 | 'Loss_bridge_prob {:.3f} ({:.3f}) ' 277 | 'Loss_bridge_feat {:.3f} ({:.3f}) ' 278 | 'Loss_diverse {:.3f} ({:.3f}) ' 279 | 'Prec_s {:.2%} ({:.2%}) ' 280 | 'Prec_t {:.2%} ({:.2%}) ' 281 | .format(epoch, i + 1, len(data_loader_target), 282 | batch_time.val, batch_time.avg, 283 | data_time.val, data_time.avg, 284 | losses.val, losses.avg, 285 | losses_ce.val, losses_ce.avg, 286 | losses_tri.val, losses_tri.avg, 287 | losses_xbm.val, losses_xbm.avg, 288 | losses_bridge_prob.val, losses_bridge_prob.avg, 289 | losses_bridge_feat.val, losses_bridge_feat.avg, 290 | losses_diverse.val, losses_diverse.avg, 291 | precisions_s.val, precisions_s.avg, 292 | precisions_t.val, precisions_t.avg 293 | )) 294 | else: 295 | print('Epoch: [{}][{}/{}]\t' 296 | 'Time {:.3f} ({:.3f}) ' 297 | 'Data {:.3f} ({:.3f}) ' 298 | 'Loss {:.3f} ({:.3f}) ' 299 | 'Loss_ce {:.3f} ({:.3f}) ' 300 | 'Loss_tri {:.3f} ({:.3f}) ' 301 | 'Loss_bridge_prob {:.3f} ({:.3f}) ' 302 | 'Loss_bridge_feat {:.3f} ({:.3f}) ' 303 | 'Loss_diverse {:.3f} ({:.3f}) ' 304 | 'Prec_s {:.2%} ({:.2%}) ' 305 | 'Prec_t {:.2%} ({:.2%}) ' 306 | .format(epoch, i + 1, len(data_loader_target), 307 | batch_time.val, batch_time.avg, 308 | data_time.val, data_time.avg, 309 | losses.val, losses.avg, 310 | losses_ce.val, losses_ce.avg, 311 | losses_tri.val, losses_tri.avg, 312 | losses_bridge_prob.val, losses_bridge_prob.avg, 313 | losses_bridge_feat.val, losses_bridge_feat.avg, 314 | losses_diverse.val, losses_diverse.avg, 315 | precisions_s.val, precisions_s.avg, 316 | precisions_t.val, precisions_t.avg 317 | )) 318 | 319 | def _parse_data(self, inputs): 320 | imgs, _, pids, _, indexes = inputs 321 | return imgs.cuda(), pids.cuda(), indexes.cuda() 322 | 323 | def _forward(self, inputs, stage): 324 | return self.model(inputs, stage=stage) 325 | -------------------------------------------------------------------------------- /idm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /idm/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset 4 | from .preprocessor import Preprocessor 5 | 6 | class IterLoader: 7 | def __init__(self, loader, length=None): 8 | self.loader = loader 9 | self.length = length 10 | self.iter = None 11 | 12 | def __len__(self): 13 | if (self.length is not None): 14 | return 
self.length 15 | return len(self.loader) 16 | 17 | def new_epoch(self): 18 | self.iter = iter(self.loader) 19 | 20 | def next(self): 21 | try: 22 | return next(self.iter) 23 | except (StopIteration, TypeError):  # loader exhausted, or new_epoch() was never called 24 | self.iter = iter(self.loader) 25 | return next(self.iter) 26 | -------------------------------------------------------------------------------- /idm/utils/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def print_dataset_statistics(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def images_dir(self): 27 | return None 28 | 29 | 30 | class BaseImageDataset(BaseDataset): 31 | """ 32 | Base class of image reid dataset 33 | """ 34 | 35 | def print_dataset_statistics(self, train, query, gallery): 36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 38 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 39 | 40 | print("Dataset statistics:") 41 | print("  ----------------------------------------") 42 | print("  subset   | # ids | # images | # cameras") 43 | print("  ----------------------------------------") 44 | print("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 45 | print("  query    | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 46 | print("  gallery  | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 47 | print("  ----------------------------------------") 48 | -------------------------------------------------------------------------------- /idm/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import random 7 | import math 8 | from PIL import Image 9 | import torch 10 | import pdb 11 | 12 | 13 | class Preprocessor(Dataset): 14 | def __init__(self, dataset, root=None, transform=None): 15 | super(Preprocessor, self).__init__() 16 | self.dataset = dataset 17 | self.root = root 18 | self.transform = transform 19 | 20 | def __len__(self): 21 | return len(self.dataset) 22 | 23 | def __getitem__(self, indices): 24 | return self._get_single_item(indices) 25 | 26 | def _get_single_item(self, index): 27 | fname, pid, camid = self.dataset[index] 28 | fpath = fname 29 | if self.root is not None: 30 | fpath = osp.join(self.root, fname) 31 | 32 | img = Image.open(fpath).convert('RGB') 33 | 34 | if self.transform is not None: 35 | img = self.transform(img) 36 | 37 | return img, fname, pid, camid, index 38 | 39 | 40 | -------------------------------------------------------------------------------- /idm/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import math 4 | 5 | import numpy as np 6 | import copy 7 
| import random 8 | import torch 9 | from torch.utils.data.sampler import ( 10 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 11 | WeightedRandomSampler) 12 | 13 | 14 | def No_index(a, b): 15 | assert isinstance(a, list) 16 | return [i for i, j in enumerate(a) if j != b] 17 | 18 | 19 | class RandomIdentitySampler(Sampler): 20 | def __init__(self, data_source, num_instances): 21 | self.data_source = data_source 22 | self.num_instances = num_instances 23 | self.index_dic = defaultdict(list) 24 | for index, (_, pid, _) in enumerate(data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | self.num_samples = len(self.pids) 28 | 29 | def __len__(self): 30 | return self.num_samples * self.num_instances 31 | 32 | def __iter__(self): 33 | indices = torch.randperm(self.num_samples).tolist() 34 | ret = [] 35 | for i in indices: 36 | pid = self.pids[i] 37 | t = self.index_dic[pid] 38 | if len(t) >= self.num_instances: 39 | t = np.random.choice(t, size=self.num_instances, replace=False) 40 | else: 41 | t = np.random.choice(t, size=self.num_instances, replace=True) 42 | ret.extend(t) 43 | return iter(ret) 44 | 45 | 46 | class RandomMultipleGallerySampler(Sampler): 47 | def __init__(self, data_source, num_instances=4): 48 | self.data_source = data_source 49 | self.index_pid = defaultdict(int) 50 | self.pid_cam = defaultdict(list) 51 | self.pid_index = defaultdict(list) 52 | self.num_instances = num_instances 53 | 54 | for index, (_, pid, cam) in enumerate(data_source): 55 | if (pid<0): continue 56 | self.index_pid[index] = pid 57 | self.pid_cam[pid].append(cam) 58 | self.pid_index[pid].append(index) 59 | 60 | self.pids = list(self.pid_index.keys()) 61 | self.num_samples = len(self.pids) 62 | 63 | def __len__(self): 64 | return self.num_samples * self.num_instances 65 | 66 | def __iter__(self): 67 | indices = torch.randperm(len(self.pids)).tolist() 68 | ret = [] 69 | 70 | for kid in indices: 71 | i = random.choice(self.pid_index[self.pids[kid]]) 72 | 73 | _, i_pid, i_cam = self.data_source[i] 74 | 75 | ret.append(i) 76 | 77 | pid_i = self.index_pid[i] 78 | cams = self.pid_cam[pid_i] 79 | index = self.pid_index[pid_i] 80 | select_cams = No_index(cams, i_cam) 81 | 82 | if select_cams: 83 | 84 | if len(select_cams) >= self.num_instances: 85 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False) 86 | else: 87 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True) 88 | 89 | for kk in cam_indexes: 90 | ret.append(index[kk]) 91 | 92 | else: 93 | select_indexes = No_index(index, i) 94 | if (not select_indexes): continue 95 | if len(select_indexes) >= self.num_instances: 96 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 97 | else: 98 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 99 | 100 | for kk in ind_indexes: 101 | ret.append(index[kk]) 102 | 103 | 104 | return iter(ret) 105 | -------------------------------------------------------------------------------- /idm/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | 
self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | """ Randomly selects a rectangle region in an image and erases its pixels. 54 | 'Random Erasing Data Augmentation' by Zhong et al. 55 | See https://arxiv.org/pdf/1708.04896.pdf 56 | Args: 57 | probability: The probability that the Random Erasing operation will be performed. 58 | sl: Minimum proportion of erased area against input image. 59 | sh: Maximum proportion of erased area against input image. 60 | r1: Minimum aspect ratio of erased area. 61 | mean: Erasing value. 62 | """ 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) >= self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | -------------------------------------------------------------------------------- /idm/utils/faiss_rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
5 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 7 | """ 8 | 9 | import os, sys 10 | import time 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | import gc 14 | import faiss 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \ 20 | index_init_gpu, index_init_cpu 21 | 22 | def k_reciprocal_neigh(initial_rank, i, k1): 23 | forward_k_neigh_index = initial_rank[i,:k1+1] 24 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 25 | fi = np.where(backward_k_neigh_index==i)[0] 26 | return forward_k_neigh_index[fi] 27 | 28 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False): 29 | end = time.time() 30 | if print_flag: 31 | print('Computing jaccard distance...') 32 | 33 | ngpus = faiss.get_num_gpus() 34 | N = target_features.size(0) 35 | mat_type = np.float16 if use_float16 else np.float32 36 | 37 | if (search_option==0): 38 | # GPU + PyTorch CUDA Tensors (1) 39 | res = faiss.StandardGpuResources() 40 | res.setDefaultNullStreamAllDevices() 41 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1) 42 | initial_rank = initial_rank.cpu().numpy() 43 | elif (search_option==1): 44 | # GPU + PyTorch CUDA Tensors (2) 45 | res = faiss.StandardGpuResources() 46 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1)) 47 | index.add(target_features.cpu().numpy()) 48 | _, initial_rank = search_index_pytorch(index, target_features, k1) 49 | res.syncDefaultStreamCurrentDevice() 50 | initial_rank = initial_rank.cpu().numpy() 51 | elif (search_option==2): 52 | # GPU 53 | index = index_init_gpu(ngpus, target_features.size(-1)) 54 | index.add(target_features.cpu().numpy()) 55 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 56 | else: 57 | # CPU 58 | index = index_init_cpu(target_features.size(-1)) 59 | index.add(target_features.cpu().numpy()) 60 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 61 | 62 | 63 | nn_k1 = [] 64 | nn_k1_half = [] 65 | for i in range(N): 66 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1)) 67 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2)))) 68 | 69 | V = np.zeros((N, N), dtype=mat_type) 70 | for i in range(N): 71 | k_reciprocal_index = nn_k1[i] 72 | k_reciprocal_expansion_index = k_reciprocal_index 73 | for candidate in k_reciprocal_index: 74 | candidate_k_reciprocal_index = nn_k1_half[candidate] 75 | if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)): 76 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 77 | 78 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 79 | dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t()) 80 | if use_float16: 81 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type) 82 | else: 83 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy() 84 | 85 | del nn_k1, nn_k1_half 86 | 87 | if k2 != 1: 88 | V_qe = np.zeros_like(V, dtype=mat_type) 89 | for i in range(N): 90 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0) 91 | V = 
V_qe 92 | del V_qe 93 | 94 | del initial_rank 95 | 96 | invIndex = [] 97 | for i in range(N): 98 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 99 | 100 | jaccard_dist = np.zeros((N, N), dtype=mat_type) 101 | for i in range(N): 102 | temp_min = np.zeros((1,N), dtype=mat_type) 103 | # temp_max = np.zeros((1,N), dtype=mat_type) 104 | indNonZero = np.where(V[i,:] != 0)[0] 105 | indImages = [] 106 | indImages = [invIndex[ind] for ind in indNonZero] 107 | for j in range(len(indNonZero)): 108 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 109 | # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 110 | 111 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 112 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6) 113 | 114 | del invIndex, V 115 | 116 | pos_bool = (jaccard_dist < 0) 117 | jaccard_dist[pos_bool] = 0.0 118 | if print_flag: 119 | print ("Jaccard distance computing time cost: {}".format(time.time()-end)) 120 | 121 | return jaccard_dist 122 | -------------------------------------------------------------------------------- /idm/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | return faiss.cast_integer_to_long_ptr( 16 | x.storage().data_ptr() + x.storage_offset() * 8) 17 | 18 | def search_index_pytorch(index, x, k, D=None, I=None): 19 | """call the search function of an index with pytorch tensor I/O (CPU 20 | and GPU supported)""" 21 | assert x.is_contiguous() 22 | n, d = x.size() 23 | assert d == index.d 24 | 25 | if D is None: 26 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 27 | else: 28 | assert D.size() == (n, k) 29 | 30 | if I is None: 31 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 32 | else: 33 | assert I.size() == (n, k) 34 | torch.cuda.synchronize() 35 | xptr = swig_ptr_from_FloatTensor(x) 36 | Iptr = swig_ptr_from_LongTensor(I) 37 | Dptr = swig_ptr_from_FloatTensor(D) 38 | index.search_c(n, xptr, 39 | k, Dptr, Iptr) 40 | torch.cuda.synchronize() 41 | return D, I 42 | 43 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 44 | metric=faiss.METRIC_L2): 45 | assert xb.device == xq.device 46 | 47 | nq, d = xq.size() 48 | if xq.is_contiguous(): 49 | xq_row_major = True 50 | elif xq.t().is_contiguous(): 51 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 52 | xq_row_major = False 53 | else: 54 | raise TypeError('matrix should be row or column-major') 55 | 56 | xq_ptr = swig_ptr_from_FloatTensor(xq) 57 | 58 | nb, d2 = xb.size() 59 | assert d2 == d 60 | if xb.is_contiguous(): 61 | xb_row_major = True 62 | elif xb.t().is_contiguous(): 63 | xb = xb.t() 64 | xb_row_major = False 65 | else: 66 | raise TypeError('matrix should be row or column-major') 67 | xb_ptr = swig_ptr_from_FloatTensor(xb) 68 | 69 | if D is None: 70 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 71 | else: 72 | assert D.shape == (nq, k) 73 | assert D.device == xb.device 74 | 75 | if I is None: 76 | I = torch.empty(nq, k, device=xb.device, 
dtype=torch.int64)
77 |     else:
78 |         assert I.shape == (nq, k)
79 |         assert I.device == xb.device
80 | 
81 |     D_ptr = swig_ptr_from_FloatTensor(D)
82 |     I_ptr = swig_ptr_from_LongTensor(I)
83 | 
84 |     faiss.bruteForceKnn(res, metric,
85 |                         xb_ptr, xb_row_major, nb,
86 |                         xq_ptr, xq_row_major, nq,
87 |                         d, k, D_ptr, I_ptr)
88 | 
89 |     return D, I
90 | 
91 | def index_init_gpu(ngpus, feat_dim):
92 |     flat_config = []
93 |     for i in range(ngpus):
94 |         cfg = faiss.GpuIndexFlatConfig()
95 |         cfg.useFloat16 = False
96 |         cfg.device = i
97 |         flat_config.append(cfg)
98 | 
99 |     res = [faiss.StandardGpuResources() for i in range(ngpus)]
100 |     indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)]
101 |     index = faiss.IndexShards(feat_dim)
102 |     for sub_index in indexes:
103 |         index.add_shard(sub_index)
104 |     index.reset()
105 |     return index
106 | 
107 | def index_init_cpu(feat_dim):
108 |     return faiss.IndexFlatL2(feat_dim)
109 | 
--------------------------------------------------------------------------------
/idm/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | 
5 | from .osutils import mkdir_if_missing
6 | 
7 | 
8 | class Logger(object):
9 |     def __init__(self, fpath=None):
10 |         self.console = sys.stdout
11 |         self.file = None
12 |         if fpath is not None:
13 |             mkdir_if_missing(os.path.dirname(fpath))
14 |             self.file = open(fpath, 'w')
15 | 
16 |     def __del__(self):
17 |         self.close()
18 | 
19 |     def __enter__(self):
20 |         return self  # return the logger so it can be used as a context manager
21 | 
22 |     def __exit__(self, *args):
23 |         self.close()
24 | 
25 |     def write(self, msg):
26 |         self.console.write(msg)
27 |         if self.file is not None:
28 |             self.file.write(msg)
29 | 
30 |     def flush(self):
31 |         self.console.flush()
32 |         if self.file is not None:
33 |             self.file.flush()
34 |             os.fsync(self.file.fileno())
35 | 
36 |     def close(self):
37 |         self.console.flush()  # keep sys.stdout usable after the logger is closed
38 |         if self.file is not None:
39 |             self.file.close()
40 | 
--------------------------------------------------------------------------------
/idm/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 | from torch.optim.lr_scheduler import *
9 | 
10 | 
11 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
12 | # separating MultiStepLR with WarmupLR
13 | # but the current LRScheduler design doesn't allow it
14 | 
15 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
16 |     def __init__(
17 |         self,
18 |         optimizer,
19 |         milestones,
20 |         gamma=0.1,
21 |         warmup_factor=1.0 / 3,
22 |         warmup_iters=500,
23 |         warmup_method="linear",
24 |         last_epoch=-1,
25 |     ):
26 |         if list(milestones) != sorted(milestones):
27 |             raise ValueError(
28 |                 "Milestones should be a list of increasing integers."
29 |                 " Got {}".format(milestones)
30 |             )
31 | 
32 |         if warmup_method not in ("constant", "linear"):
33 |             raise ValueError(
34 |                 "Only 'constant' or 'linear' warmup_method accepted, "
35 |                 "got {}".format(warmup_method)
36 |             )
37 |         self.milestones = milestones
38 |         self.gamma = gamma
39 |         self.warmup_factor = warmup_factor
40 |         self.warmup_iters = warmup_iters
41 |         self.warmup_method = warmup_method
42 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
43 | 
44 |     def get_lr(self):
45 |         warmup_factor = 1
46 |         if self.last_epoch < self.warmup_iters:
47 |             if self.warmup_method == "constant":
48 |                 warmup_factor = self.warmup_factor
49 |             elif self.warmup_method == "linear":
50 |                 alpha = float(self.last_epoch) / float(self.warmup_iters)
51 |                 warmup_factor = self.warmup_factor * (1 - alpha) + alpha
52 |         return [
53 |             base_lr
54 |             * warmup_factor
55 |             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
56 |             for base_lr in self.base_lrs
57 |         ]
--------------------------------------------------------------------------------
/idm/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | 
4 | class AverageMeter(object):
5 |     """Computes and stores the average and current value"""
6 | 
7 |     def __init__(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
--------------------------------------------------------------------------------
/idm/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 | 
5 | 
6 | def mkdir_if_missing(dir_path):
7 |     try:
8 |         os.makedirs(dir_path)
9 |     except OSError as e:
10 |         if e.errno != errno.EEXIST:
11 |             raise
12 | 
--------------------------------------------------------------------------------
/idm/utils/rerank.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Source: https://github.com/zhunzhong07/person-re-ranking
5 | Created on Mon Jun 26 14:46:56 2017
6 | @author: luohao
7 | Modified by Houjing Huang, 2017-12-22.
8 | - This version accepts distance matrix instead of raw features.
9 | - The difference of `/` division between python 2 and 3 is handled.
10 | - numpy.float16 is replaced by numpy.float32 for numerical precision.
11 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
12 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 14 | API 15 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 16 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 17 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 18 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 19 | Returns: 20 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import print_function 24 | from __future__ import division 25 | 26 | __all__ = ['re_ranking', 'compute_jaccard_distance'] 27 | 28 | import numpy as np 29 | import time 30 | 31 | import torch 32 | import torch.nn.functional as F 33 | 34 | 35 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 36 | 37 | # The following naming, e.g. gallery_num, is different from outer scope. 38 | # Don't care about it. 39 | 40 | original_dist = np.concatenate( 41 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 42 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 43 | axis=0) 44 | original_dist = np.power(original_dist, 2).astype(np.float32) 45 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0)) 46 | V = np.zeros_like(original_dist).astype(np.float32) 47 | initial_rank = np.argsort(original_dist).astype(np.int32) 48 | 49 | query_num = q_g_dist.shape[0] 50 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 51 | all_num = gallery_num 52 | 53 | for i in range(all_num): 54 | # k-reciprocal neighbors 55 | forward_k_neigh_index = initial_rank[i,:k1+1] 56 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 57 | fi = np.where(backward_k_neigh_index==i)[0] 58 | k_reciprocal_index = forward_k_neigh_index[fi] 59 | k_reciprocal_expansion_index = k_reciprocal_index 60 | for j in range(len(k_reciprocal_index)): 61 | candidate = k_reciprocal_index[j] 62 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 63 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 64 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 65 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 66 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 68 | 69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 70 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 71 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 72 | original_dist = original_dist[:query_num,] 73 | if k2 != 1: 74 | V_qe = np.zeros_like(V,dtype=np.float32) 75 | for i in range(all_num): 76 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 77 | V = V_qe 78 | del V_qe 79 | del initial_rank 80 | invIndex = [] 81 | for i in range(gallery_num): 82 | invIndex.append(np.where(V[:,i] != 0)[0]) 83 | 84 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 85 | 86 | 87 | for i in range(query_num): 88 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 89 | indNonZero = np.where(V[i,:] != 0)[0] 90 | indImages = [] 91 | 
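        # invIndex[n] holds the images whose V row is non-zero at neighbor n; the loop
        # below accumulates min(V[i, n], V[g, n]) over the shared neighbors n into
        # temp_min -- the fuzzy set intersection that the Jaccard distance is built on.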
indImages = [invIndex[ind] for ind in indNonZero]
92 |         for j in range(len(indNonZero)):
93 |             temp_min[0,indImages[j]] = temp_min[0,indImages[j]] + np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
94 |         jaccard_dist[i] = 1-temp_min/(2.-temp_min)
95 | 
96 |     final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
97 |     del original_dist
98 |     del V
99 |     del jaccard_dist
100 |     final_dist = final_dist[:query_num,query_num:]
101 |     return final_dist
102 | 
103 | 
104 | def k_reciprocal_neigh(initial_rank, i, k1):
105 |     forward_k_neigh_index = initial_rank[i, :k1 + 1]
106 |     backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
107 |     fi = torch.nonzero(backward_k_neigh_index == i)[:, 0]
108 |     return forward_k_neigh_index[fi]
109 | 
110 | 
111 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True,
112 |                              lambda_value=0, source_features=None, use_gpu=False):
113 |     end = time.time()
114 |     N = target_features.size(0)
115 |     if (use_gpu):
116 |         # accelerate matrix distance computing
117 |         target_features = target_features.cuda()
118 |         if (source_features is not None):
119 |             source_features = source_features.cuda()
120 | 
121 |     if ((lambda_value > 0) and (source_features is not None)):  # source_dist below exists only under this condition; the final return assumes callers keep lambda_value and source_features consistent
122 |         M = source_features.size(0)
123 |         sour_tar_dist = torch.pow(target_features, 2).sum(dim=1, keepdim=True).expand(N, M) + \
124 |                         torch.pow(source_features, 2).sum(dim=1, keepdim=True).expand(M, N).t()
125 |         sour_tar_dist.addmm_(target_features, source_features.t(), beta=1, alpha=-2)  # beta/alpha as keywords: the positional addmm_(beta, alpha, ...) form is deprecated in modern PyTorch
126 |         sour_tar_dist = 1 - torch.exp(-sour_tar_dist)
127 |         sour_tar_dist = sour_tar_dist.cpu()
128 |         source_dist_vec = sour_tar_dist.min(1)[0]
129 |         del sour_tar_dist
130 |         source_dist_vec /= source_dist_vec.max()
131 |         source_dist = torch.zeros(N, N)
132 |         for i in range(N):
133 |             source_dist[i, :] = source_dist_vec + source_dist_vec[i]
134 |         del source_dist_vec
135 | 
136 |     if print_flag:
137 |         print('Computing original distance...')
138 | 
139 |     original_dist = torch.pow(target_features, 2).sum(dim=1, keepdim=True) * 2  # drops the ||x_j||^2 term: the per-row ordering (all argsort needs) is unchanged
140 |     original_dist = original_dist.expand(N, N) - 2 * torch.mm(target_features, target_features.t())
141 |     original_dist /= original_dist.max(0)[0]
142 |     original_dist = original_dist.t()
143 |     initial_rank = torch.argsort(original_dist, dim=-1)
144 | 
145 |     original_dist = original_dist.cpu()
146 |     initial_rank = initial_rank.cpu()
147 |     all_num = gallery_num = original_dist.size(0)
148 | 
149 |     del target_features
150 |     if (source_features is not None):
151 |         del source_features
152 | 
153 |     if print_flag:
154 |         print('Computing Jaccard distance...')
155 | 
156 |     nn_k1 = []
157 |     nn_k1_half = []
158 |     for i in range(all_num):
159 |         nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))
160 |         nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1 / 2))))
161 | 
162 |     V = torch.zeros(all_num, all_num)
163 |     for i in range(all_num):
164 |         k_reciprocal_index = nn_k1[i]
165 |         k_reciprocal_expansion_index = k_reciprocal_index
166 |         for candidate in k_reciprocal_index:
167 |             candidate_k_reciprocal_index = nn_k1_half[candidate]
168 |             if (len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
169 |                     candidate_k_reciprocal_index)):
170 |                 k_reciprocal_expansion_index = torch.cat((k_reciprocal_expansion_index, candidate_k_reciprocal_index))
171 | 
172 |         k_reciprocal_expansion_index = torch.unique(k_reciprocal_expansion_index)  # element-wise unique
173 |         weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index])
174 |         V[i,
k_reciprocal_expansion_index] = weight / torch.sum(weight) 175 | 176 | if k2 != 1: 177 | k2_rank = initial_rank[:, :k2].clone().view(-1) 178 | V_qe = V[k2_rank] 179 | V_qe = V_qe.view(initial_rank.size(0), k2, -1).sum(1) 180 | V_qe /= k2 181 | V = V_qe 182 | del V_qe 183 | del initial_rank 184 | 185 | invIndex = [] 186 | for i in range(gallery_num): 187 | invIndex.append(torch.nonzero(V[:, i])[:, 0]) # len(invIndex)=all_num 188 | 189 | jaccard_dist = torch.zeros_like(original_dist) 190 | for i in range(all_num): 191 | temp_min = torch.zeros(1, gallery_num) 192 | indNonZero = torch.nonzero(V[i, :])[:, 0] 193 | indImages = [] 194 | indImages = [invIndex[ind] for ind in indNonZero] 195 | for j in range(len(indNonZero)): 196 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + torch.min(V[i, indNonZero[j]], 197 | V[indImages[j], indNonZero[j]]) 198 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 199 | del invIndex 200 | 201 | del V 202 | 203 | pos_bool = (jaccard_dist < 0) 204 | jaccard_dist[pos_bool] = 0.0 205 | if print_flag: 206 | print("Time cost: {}".format(time.time() - end)) 207 | 208 | if (lambda_value > 0): 209 | return jaccard_dist * (1 - lambda_value) + source_dist * lambda_value 210 | else: 211 | return jaccard_dist -------------------------------------------------------------------------------- /idm/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | # torch.save(state, fpath, _use_new_zipfile_serialization=False) # for torch >= 1.6 28 | if is_best: 29 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 30 | 31 | 32 | def load_checkpoint(fpath): 33 | if osp.isfile(fpath): 34 | # checkpoint = torch.load(fpath) 35 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 36 | print("=> Loaded checkpoint '{}'".format(fpath)) 37 | return checkpoint 38 | else: 39 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 40 | 41 | 42 | def copy_state_dict(state_dict, model, strip=None): 43 | tgt_state = model.state_dict() 44 | copied_names = set() 45 | for name, param in state_dict.items(): 46 | if strip is not None and name.startswith(strip): 47 | name = name[len(strip):] 48 | if name not in tgt_state: 49 | continue 50 | if isinstance(param, Parameter): 51 | param = param.data 52 | if param.size() != tgt_state[name].size(): 53 | print('mismatch:', name, param.size(), tgt_state[name].size()) 54 | continue 55 | tgt_state[name].copy_(param) 56 | copied_names.add(name) 57 | 58 | missing = set(tgt_state.keys()) - copied_names 59 | if len(missing) > 0: 60 | print("missing keys in state_dict:", missing) 61 | 62 | return model 63 | -------------------------------------------------------------------------------- /scripts/run_idm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 
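# Positional arguments: source dataset, target dataset, backbone architecture,
# IDM stage, and the three IDM loss weights (mu1, mu2, mu3).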
source=$1
4 | target=$2
5 | arch=$3
6 | stage=$4
7 | mu1=$5
8 | mu2=$6
9 | mu3=$7
10 | 
11 | 
12 | if [ $# -ne 7 ]
13 | then
14 |     echo "Arguments error: usage: $0 <source> <target> <arch> <stage> <mu1> <mu2> <mu3>"
15 |     exit 1
16 | fi
17 | 
18 | 
19 | python3 examples/train_idm.py -ds ${source} -dt ${target} -a ${arch} \
20 |     --logs-dir logs/${arch}/${source}-TO-${target} \
21 |     --stage ${stage} --mu1 ${mu1} --mu2 ${mu2} --mu3 ${mu3}
22 | 
23 | 
24 | 
--------------------------------------------------------------------------------
/scripts/run_idm_xbm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | source=$1
4 | target=$2
5 | arch=$3
6 | stage=$4
7 | mu1=$5
8 | mu2=$6
9 | mu3=$7
10 | 
11 | 
12 | if [ $# -ne 7 ]
13 | then
14 |     echo "Arguments error: usage: $0 <source> <target> <arch> <stage> <mu1> <mu2> <mu3>"
15 |     exit 1
16 | fi
17 | 
18 | 
19 | python3 examples/train_idm.py -ds ${source} -dt ${target} -a ${arch} \
20 |     --logs-dir logs/${arch}_xbm/${source}-TO-${target} \
21 |     --use-xbm --stage ${stage} --mu1 ${mu1} --mu2 ${mu2} --mu3 ${mu3}
22 | 
23 | 
24 | 
--------------------------------------------------------------------------------
/scripts/run_naive_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | source=$1
4 | target=$2
5 | arch=$3
6 | 
7 | if [ $# -ne 3 ]
8 | then
9 |     echo "Arguments error: usage: $0 <source> <target> <arch>"
10 |     exit 1
11 | fi
12 | 
13 | 
14 | python3 examples/train_baseline.py -ds ${source} -dt ${target} -a ${arch} \
15 |     --logs-dir logs/${arch}_naive_baseline/${source}-TO-${target}
16 | 
17 | 
--------------------------------------------------------------------------------
/scripts/run_strong_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | source=$1
4 | target=$2
5 | arch=$3
6 | 
7 | if [ $# -ne 3 ]
8 | then
9 |     echo "Arguments error: usage: $0 <source> <target> <arch>"
10 |     exit 1
11 | fi
12 | 
13 | 
14 | python3 examples/train_baseline.py -ds ${source} -dt ${target} -a ${arch} \
15 |     --logs-dir logs/${arch}_strong_baseline/${source}-TO-${target} --use-xbm
16 | 
17 | 
--------------------------------------------------------------------------------
/scripts/run_test_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | dataset=$1
4 | arch=$2
5 | resume=$3
6 | 
7 | 
8 | if [ $# -ne 3 ]
9 | then
10 |     echo "Arguments error: usage: $0 <dataset> <arch> <resume>"
11 |     exit 1
12 | fi
13 | 
14 | python3 examples/test.py -d ${dataset} -a ${arch} --resume ${resume} --dsbn
15 | 
16 | 
17 | 
--------------------------------------------------------------------------------
/scripts/run_test_idm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | dataset=$1
4 | arch=$2
5 | stage=$3
6 | resume=$4
7 | 
8 | if [ $# -ne 4 ]
9 | then
10 |     echo "Arguments error: usage: $0 <dataset> <arch> <stage> <resume>"
11 |     exit 1
12 | fi
13 | 
14 | python3 examples/test.py -d ${dataset} -a ${arch} --stage ${stage} --resume ${resume} --dsbn-idm
15 | 
16 | 
17 | 
--------------------------------------------------------------------------------
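
The helper scripts above are invoked as, e.g., `bash scripts/run_idm.sh <source> <target> <arch> <stage> <mu1> <mu2> <mu3>`. For the re-ranking utilities in `idm/utils/rerank.py`, the following is a minimal sketch of how `re_ranking` and `compute_jaccard_distance` are called; the random `q_feats`/`g_feats` are placeholders standing in for features extracted by a trained re-ID model, not part of the repository.

```python
import torch
import torch.nn.functional as F

from idm.utils.rerank import re_ranking, compute_jaccard_distance

# Placeholder L2-normalized features (stand-ins for extracted re-ID features).
q_feats = F.normalize(torch.randn(100, 2048), dim=1)
g_feats = F.normalize(torch.randn(300, 2048), dim=1)

# re_ranking() expects numpy distance matrices.
q_g_dist = torch.cdist(q_feats, g_feats).numpy()
q_q_dist = torch.cdist(q_feats, q_feats).numpy()
g_g_dist = torch.cdist(g_feats, g_feats).numpy()

# k-reciprocal re-ranking with the original paper's parameters.
final_dist = re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3)
print(final_dist.shape)  # (100, 300): re-ranked query-gallery distances

# Jaccard distance over one feature set, e.g. for clustering during UDA training.
jaccard = compute_jaccard_distance(torch.cat([q_feats, g_feats]), k1=20, k2=6)
print(jaccard.shape)  # torch.Size([400, 400])
```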