├── README.md ├── assets ├── perf.png └── pipeline.png ├── configs ├── TMGF_full.yml └── default.py ├── evaluate.py ├── libs ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dukemtmc.py │ ├── market1501.py │ ├── msmt17.py │ ├── msmt17_v2.py │ ├── occ_duke.py │ ├── personx.py │ ├── vehicleid.py │ ├── vehiclex.py │ └── veri.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── evaluators.py ├── models │ ├── __init__.py │ ├── mb.py │ ├── pooling.py │ ├── resnet.py │ ├── resnet_ibn.py │ ├── resnet_ibn_a.py │ ├── resnet_part.py │ ├── vit.py │ └── vit_encoder.py ├── trainers.py └── utils │ ├── __init__.py │ ├── checkpoint_io.py │ ├── clustering.py │ ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── preprocessor.py │ ├── sampler.py │ └── transforms.py │ ├── faiss_rerank.py │ ├── faiss_utils.py │ ├── logging.py │ ├── meters.py │ ├── osutils.py │ ├── prepare_data.py │ ├── prepare_model.py │ ├── prepare_optimizer.py │ ├── prepare_scheduler.py │ ├── rerank.py │ ├── scheduler.py │ └── serialization.py ├── requirements.txt ├── train.py └── train.sh /README.md: -------------------------------------------------------------------------------- 1 | # Transformer-Based Multi-Grained Features for Unsupervised Person Re-Identification (TMGF) 2 | 3 | 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transformer-based-multi-grained-features-for/unsupervised-person-re-identification-on-12)](https://paperswithcode.com/sota/unsupervised-person-re-identification-on-12?p=transformer-based-multi-grained-features-for) \ 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transformer-based-multi-grained-features-for/unsupervised-person-re-identification-on-5)](https://paperswithcode.com/sota/unsupervised-person-re-identification-on-5?p=transformer-based-multi-grained-features-for) \ 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transformer-based-multi-grained-features-for/unsupervised-person-re-identification-on-4)](https://paperswithcode.com/sota/unsupervised-person-re-identification-on-4?p=transformer-based-multi-grained-features-for) 8 | 9 | Official implementation of the paper [Transformer-Based Multi-Grained Features for Unsupervised Person Re-Identification](https://openaccess.thecvf.com/content/WACV2023W/RWS/html/Li_Transformer_Based_Multi-Grained_Features_for_Unsupervised_Person_Re-Identification_WACVW_2023_paper.html) (WACV2023 workshop). 10 | 11 | In this work, a dual-branch network built upon ViT generates part features of different granularities from local tokens, which are learned together with global features for better discriminative capacity. Extensive experiments on three person Re-ID datasets show that the proposed method achieves state-of-the-art performance among unsupervised methods. 12 | 13 | ![pipeline](assets/pipeline.png) 14 | 15 | ## Prerequisites 16 | 17 | ### Pretrained Weight 18 | 19 | TMGF is fine-tuned from a pretrained weight, which can be obtained from [TransReID-SSL](https://github.com/damo-cv/TransReID-SSL). In our model, we use `ViT-S/16+ICS`. Download it [here](https://drive.google.com/file/d/18FL9JaJNlo15-UksalcJRXX-0dgo4Mz4/view?usp=sharing). 20 | 21 | ### Installation 22 | 23 | Clone this repo and extract the files. 24 | 25 | We recommend using `conda` to create a virtual Python 3.7 environment and installing all requirements in it.
Extra packages are listed in `requirements.txt` and can be installed with `pip`: 26 | 27 | ```bash 28 | conda create -n torch1.6 python=3.7 29 | conda activate torch1.6 30 | 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | ### Training 35 | 36 | Download the datasets and put them in the correct locations. 37 | Check and run the shell script `train.sh`: 38 | 39 | ```bash 40 | CUDA_VISIBLE_DEVICES=0 ./train.sh # run on GPU 0 41 | ``` 42 | 43 | We use the `yacs` config system for better parameter management. You may need to modify the dataset root directory `DATASET.ROOT_DIR` and the pretrained weight path `MODEL.PRETRAIN_PATH`. Check [here](https://github.com/rbgirshick/yacs) to learn how to change configs as you like. 44 | 45 | > 2023/8/18: We fixed a misuse of `nn.DataParallel` in [prepare_model.py](https://github.com/RikoLi/WACV23-workshop-TMGF/blob/main/libs/utils/prepare_model.py), which could cause bugs in the forward pass. 46 | 47 | 48 | ### Evaluation 49 | 50 | You can run evaluation on any dataset with a provided model weight. 51 | 52 | ```bash 53 | CUDA_VISIBLE_DEVICES=0 python evaluate.py --weight /path/to/model/weight.pth --conf configs/TMGF_full.yml # run on GPU 0 54 | ``` 55 | 56 | ## Performance 57 | 58 | ![perf](assets/perf.png) 59 | 60 | ## Acknowledgement 61 | 62 | We would like to sincerely thank [TransReID](https://github.com/damo-cv/TransReID), [TransReID-SSL](https://github.com/damo-cv/TransReID-SSL) and [O2CAP](https://github.com/Terminator8758/O2CAP) for their insightful ideas and outstanding work! 63 | 64 | ## Citation 65 | 66 | If you find our work helpful in your research, please cite it as follows: 67 | 68 | ```bibtex 69 | @InProceedings{Li_2023_WACV, 70 | author = {Li, Jiachen and Wang, Menglin and Gong, Xiaojin}, 71 | title = {Transformer Based Multi-Grained Features for Unsupervised Person Re-Identification}, 72 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) Workshops}, 73 | month = {January}, 74 | year = {2023}, 75 | pages = {42-50} 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /assets/perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RikoLi/WACV23-workshop-TMGF/a4acd2ded1c72f54e3f27e27b8682286228d7f7a/assets/perf.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RikoLi/WACV23-workshop-TMGF/a4acd2ded1c72f54e3f27e27b8682286228d7f7a/assets/pipeline.png -------------------------------------------------------------------------------- /configs/TMGF_full.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | STEM_CONV: true 3 | PRETRAIN_HW_RATIO: 2 4 | PRETRAIN_PATH: "/home/ljc/.cache/torch/checkpoints/vit_small_ics_cfs_lup.pth" 5 | NUM_PARTS: 5 6 | GRANULARITIES: [2, 3] 7 | BRANCH: "all" 8 | GLOBAL_FEATURE_TYPE: "mean" 9 | HAS_HEAD: true 10 | HAS_EARLY_FEATURE: true 11 | ENABLE_EARLY_NORM: false 12 | LOG: 13 | LOG_DIR: "/home/ljc/works/TMGF/logs" 14 | CHECKPOINT: 15 | SAVE_INTERVAL: 10 16 | INPUT: 17 | HEIGHT: 384 18 | WIDTH: 128 19 | PIXEL_MEAN: [0.3525, 0.3106, 0.3140] 20 | PIXEL_STD: [0.2660, 0.2522, 0.2505] 21 | TRAIN: 22 | EPOCHS: 50 23 | ITERS: 200 24 | BATCHSIZE: 32 25 | FP16: true 26 | MEMORY_BANK: 27 | MOMENTUM: 0.2 28 | PROXY_TEMP: 0.07 29 | BG_KNN: 50 30 |
BALANCE_W: 0.15 31 | PART_W: 0.1 -------------------------------------------------------------------------------- /configs/default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | 5 | # Random seed 6 | _C.SEED = 1 7 | 8 | # Default task name 9 | _C.TASK_NAME = 'untitiled_task' 10 | 11 | # Task description 12 | _C.DESC = 'default_desc' 13 | 14 | # Model settings 15 | _C.MODEL = CN() 16 | _C.MODEL.ARCH = 'tmgf' 17 | _C.MODEL.STRIDE_SIZE = [16, 16] 18 | _C.MODEL.SIE_COEF = 3.0 19 | _C.MODEL.SIE_CAMERA = 6 20 | _C.MODEL.SIE_VIEW = 0 21 | _C.MODEL.DROP_PATH = 0.1 22 | _C.MODEL.DROP_OUT = 0.0 23 | _C.MODEL.ATTN_DROP_RATE = 0.0 24 | _C.MODEL.PRETRAIN_HW_RATIO = 1 25 | _C.MODEL.GEM_POOL = False 26 | _C.MODEL.STEM_CONV = False 27 | _C.MODEL.PRETRAIN_PATH = '/home/ljc/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 28 | _C.MODEL.NUM_PARTS = 5 # total number of parts 29 | _C.MODEL.HAS_HEAD = True # whether to use muti-grained projection heads 30 | _C.MODEL.HAS_EARLY_FEATURE = True # whether to obtain (L-1)-th layer output feature 31 | _C.MODEL.ENABLE_EARLY_NORM = False # whether to apply LayerNorm on (L-1)-th layer output feature 32 | _C.MODEL.GLOBAL_FEATURE_TYPE = 'mean' # which global token fusion method to use: mean, b1, b2 33 | _C.MODEL.GRANULARITIES = [2, 3] # number of part splits in each branch, sum up to MODEL.NUM_PART 34 | _C.MODEL.BRANCH = 'all' # which branch to use: all, b1, b2 35 | 36 | 37 | # Dataset settings 38 | _C.DATASET = CN() 39 | _C.DATASET.ROOT_DIR = '/home/ljc/datasets' 40 | _C.DATASET.NAME = 'Market1501' 41 | 42 | # Sampler settings 43 | _C.SAMPLER = CN() 44 | _C.SAMPLER.TYPE = 'proxy_balance' 45 | _C.SAMPLER.NUM_INSTANCES = 4 46 | 47 | # Clustering settings 48 | _C.CLUSTER = CN() 49 | _C.CLUSTER.EPS = 0.5 50 | _C.CLUSTER.MIN_SAMPLES = 4 51 | _C.CLUSTER.K1 = 20 52 | _C.CLUSTER.K2 = 6 53 | 54 | # Input settings 55 | _C.INPUT = CN() 56 | _C.INPUT.HEIGHT = 384 57 | _C.INPUT.WIDTH = 128 58 | _C.INPUT.PIXEL_MEAN = [0.3525, 0.3106, 0.3140] # LUPerson statistics 59 | _C.INPUT.PIXEL_STD = [0.2660, 0.2522, 0.2505] # LUPerson statistics 60 | 61 | # Optimizer settings 62 | _C.OPTIM = CN() 63 | _C.OPTIM.NAME = 'SGD' 64 | _C.OPTIM.BASE_LR = 3.5e-4 65 | _C.OPTIM.WEIGHT_DECAY = 0.0005 66 | _C.OPTIM.WEIGHT_DECAY_BIAS = 0.0005 67 | _C.OPTIM.MOMENTUM = 0.9 68 | _C.OPTIM.BIAS_LR_FACTOR = 1.0 69 | _C.OPTIM.SCHEDULER_TYPE = 'warmup' 70 | _C.OPTIM.WARMUP_EPOCHS = 10 71 | _C.OPTIM.WARMUP_FACTOR = 0.01 72 | _C.OPTIM.GAMMA = 0.1 73 | _C.OPTIM.WARMUP_METHOD = 'linear' 74 | _C.OPTIM.MILESTONES = [20, 40] 75 | 76 | # Logging settings 77 | _C.LOG = CN() 78 | _C.LOG.PRINT_FREQ = 50 79 | _C.LOG.LOG_DIR = '/home/ljc/works/TMGF/logs' 80 | _C.LOG.CHECKPOINT = CN() 81 | _C.LOG.CHECKPOINT.SAVE_DIR = '/home/ljc/works/TMGF/ckpt' 82 | _C.LOG.CHECKPOINT.SAVE_INTERVAL = 100 83 | _C.LOG.CHECKPOINT.LOAD_DIR = '' 84 | _C.LOG.CHECKPOINT.LOAD_EPOCH = 0 85 | _C.LOG.SAVE_BENCHMARK = False 86 | _C.LOG.BENCHMARK_PATH = '/home/ljc/works/TMGF/benchmark.csv' 87 | 88 | # Training settings 89 | _C.TRAIN = CN() 90 | _C.TRAIN.EPOCHS = 50 91 | _C.TRAIN.ITERS = 200 92 | _C.TRAIN.BATCHSIZE = 32 93 | _C.TRAIN.NUM_WORKERS = 8 94 | _C.TRAIN.FP16 = False 95 | 96 | # Test settings 97 | _C.TEST = CN() 98 | _C.TEST.BATCHSIZE = 32 99 | _C.TEST.NUM_WORKERS = 8 100 | _C.TEST.EVAL_STEP = 10 101 | _C.TEST.RE_RANK = False 102 | 103 | # Memory bank settings 104 | _C.MEMORY_BANK = CN() 105 | _C.MEMORY_BANK.MOMENTUM = 0.2 106 | 
_C.MEMORY_BANK.PROXY_TEMP = 0.07 107 | _C.MEMORY_BANK.BG_KNN = 50 108 | _C.MEMORY_BANK.POS_K = 3 109 | _C.MEMORY_BANK.BALANCE_W = 0.15 110 | _C.MEMORY_BANK.PART_W = 1.0 111 | 112 | def get_cfg_defaults(): 113 | return _C.clone() -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | import numpy as np 5 | import time 6 | from datetime import timedelta 7 | from configs.default import get_cfg_defaults 8 | from libs.utils.prepare_model import create_vit_model 9 | from libs.utils.prepare_data import get_data, get_test_loader 10 | from libs.evaluators import Evaluator 11 | 12 | def evaluate(cfg, weight_path): 13 | model = create_vit_model(cfg) 14 | dataset = get_data(cfg.DATASET.NAME, cfg.DATASET.ROOT_DIR) 15 | test_loader = get_test_loader(cfg, dataset, cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH, cfg.TEST.BATCHSIZE, cfg.TEST.NUM_WORKERS) 16 | evaluator = Evaluator(cfg, model) 17 | 18 | weight = torch.load(weight_path) 19 | model.load_state_dict(weight) 20 | print('=> Model weights loaded.') 21 | 22 | print('=> Start evaluation...') 23 | st = time.monotonic() 24 | evaluator.evaluate_vit(test_loader, dataset.query, dataset.gallery, cmc_flag=True, rerank=cfg.TEST.RE_RANK) 25 | et = time.monotonic() 26 | dt = timedelta(seconds=et-st) 27 | print('=> Evaluation time: {}'.format(dt)) 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--conf', type=str, default='', help='Config file path.') 32 | parser.add_argument('--weight', type=str, default='', help='Model parameter weight (.pth format) path.') 33 | parser.add_argument('opts', help='Modify config options using CMD.', default=None, nargs=argparse.REMAINDER) 34 | args = parser.parse_args() 35 | 36 | # Load config using yacs 37 | cfg = get_cfg_defaults() 38 | if args.conf != '': 39 | cfg.merge_from_file(args.conf) 40 | cfg.merge_from_list(args.opts) 41 | cfg.freeze() 42 | 43 | # Init env. 
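    # Illustrative usage sketch (hypothetical paths, not part of the original script):
    # because `cfg.merge_from_list(args.opts)` is applied above, any key defined in
    # configs/default.py can also be overridden as trailing KEY VALUE pairs on the
    # command line, e.g.
    #   CUDA_VISIBLE_DEVICES=0 python evaluate.py --conf configs/TMGF_full.yml \
    #       --weight /path/to/weight.pth TEST.RE_RANK True TEST.BATCHSIZE 64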
44 | if cfg.SEED is not None: 45 | random.seed(cfg.SEED) 46 | np.random.seed(cfg.SEED) 47 | torch.manual_seed(cfg.SEED) 48 | torch.backends.cudnn.deterministic = True 49 | torch.backends.cudnn.benchmark = True 50 | 51 | # Run 52 | evaluate(cfg, args.weight) -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RikoLi/WACV23-workshop-TMGF/a4acd2ded1c72f54e3f27e27b8682286228d7f7a/libs/__init__.py -------------------------------------------------------------------------------- /libs/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .dukemtmc import DukeMTMC 5 | from .market1501 import Market1501 6 | #from .msmt17 import MSMT17 7 | from .msmt17_v2 import MSMT17 8 | from .personx import PersonX 9 | from .veri import VeRi 10 | from .vehicleid import VehicleID 11 | from .vehiclex import VehicleX 12 | from .occ_duke import OCCDuke 13 | 14 | 15 | __factory = { 16 | 'Market1501': Market1501, 17 | 'DukeMTMC-reID': DukeMTMC, 18 | 'OCCDuke': OCCDuke, 19 | 'MSMT17': MSMT17, 20 | 'personx': PersonX, 21 | 'VeRi': VeRi, 22 | 'vehicleid': VehicleID, 23 | 'vehiclex': VehicleX 24 | } 25 | 26 | 27 | def names(): 28 | return sorted(__factory.keys()) 29 | 30 | 31 | def create(name, root, *args, **kwargs): 32 | """ 33 | Create a dataset instance. 34 | 35 | Parameters 36 | ---------- 37 | name : str 38 | The dataset name. 39 | root : str 40 | The path to the dataset directory. 41 | split_id : int, optional 42 | The index of data split. Default: 0 43 | num_val : int or float, optional 44 | When int, it means the number of validation identities. When float, 45 | it means the proportion of validation to all the trainval. Default: 100 46 | download : bool, optional 47 | If True, will download the dataset. Default: False 48 | """ 49 | if name not in __factory: 50 | raise KeyError("Unknown dataset:", name) 51 | return __factory[name](root, *args, **kwargs) 52 | 53 | 54 | def get_dataset(name, root, *args, **kwargs): 55 | warnings.warn("get_dataset is deprecated. Use create instead.") 56 | return create(name, root, *args, **kwargs) 57 | -------------------------------------------------------------------------------- /libs/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | 13 | class DukeMTMC(BaseImageDataset): 14 | """ 15 | DukeMTMC-reID 16 | Reference: 17 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 18 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 19 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 20 | 21 | Dataset statistics: 22 | # identities: 1404 (train + query) 23 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 24 | # cameras: 8 25 | """ 26 | dataset_dir = '.' 
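    # Illustrative usage sketch (assumed root path): this class is normally built
    # through the factory in libs/datasets/__init__.py, e.g.
    #   from libs import datasets
    #   duke = datasets.create('DukeMTMC-reID', '/path/to/DukeMTMC-reID')
    # where the root directory is expected to contain the bounding_box_train/,
    # query/ and bounding_box_test/ sub-directories used below.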
27 | 28 | def __init__(self, root, verbose=True, **kwargs): 29 | super(DukeMTMC, self).__init__() 30 | self.dataset_dir = root #osp.join(root, self.dataset_dir) 31 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 32 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'query') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 35 | 36 | self._download_data() 37 | self._check_before_run() 38 | 39 | train = self._process_dir(self.train_dir, relabel=True) 40 | query = self._process_dir(self.query_dir, relabel=False) 41 | gallery = self._process_dir(self.gallery_dir, relabel=False) 42 | 43 | if verbose: 44 | print("=> DukeMTMC-reID loaded") 45 | self.print_dataset_statistics(train, query, gallery) 46 | 47 | self.train = train 48 | self.query = query 49 | self.gallery = gallery 50 | 51 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 52 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 53 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 54 | 55 | def _download_data(self): 56 | if osp.exists(self.dataset_dir): 57 | print("This dataset has been downloaded.") 58 | return 59 | 60 | print("Creating directory {}".format(self.dataset_dir)) 61 | mkdir_if_missing(self.dataset_dir) 62 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 63 | 64 | print("Downloading DukeMTMC-reID dataset") 65 | urllib.request.urlretrieve(self.dataset_url, fpath) 66 | 67 | print("Extracting files") 68 | zip_ref = zipfile.ZipFile(fpath, 'r') 69 | zip_ref.extractall(self.dataset_dir) 70 | zip_ref.close() 71 | 72 | def _check_before_run(self): 73 | """Check if all files are available before going deeper""" 74 | if not osp.exists(self.dataset_dir): 75 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 76 | if not osp.exists(self.train_dir): 77 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 78 | if not osp.exists(self.query_dir): 79 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 80 | if not osp.exists(self.gallery_dir): 81 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 82 | 83 | def _process_dir(self, dir_path, relabel=False): 84 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 85 | pattern = re.compile(r'([-\d]+)_c(\d)') 86 | 87 | pid_container = set() 88 | for img_path in img_paths: 89 | pid, _ = map(int, pattern.search(img_path).groups()) 90 | pid_container.add(pid) 91 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 92 | 93 | dataset = [] 94 | for img_path in img_paths: 95 | pid, camid = map(int, pattern.search(img_path).groups()) 96 | assert 1 <= camid <= 8 97 | camid -= 1 # index starts from 0 98 | if relabel: pid = pid2label[pid] 99 | dataset.append((img_path, pid, camid)) 100 | 101 | return dataset 102 | -------------------------------------------------------------------------------- /libs/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | class 
Market1501(BaseImageDataset): 13 | """ 14 | Market1501 15 | Reference: 16 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 17 | URL: http://www.liangzheng.org/Project/project_reid.html 18 | 19 | Dataset statistics: 20 | # identities: 1501 (+1 for background) 21 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 22 | """ 23 | #dataset_dir = 'Market-1501-v15.09.15' 24 | 25 | def __init__(self, root, verbose=True, **kwargs): 26 | super(Market1501, self).__init__() 27 | self.dataset_dir = root # osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 31 | 32 | self._check_before_run() 33 | 34 | train = self._process_dir(self.train_dir, relabel=True) 35 | query = self._process_dir(self.query_dir, relabel=False) 36 | gallery = self._process_dir(self.gallery_dir, relabel=False) 37 | 38 | if verbose: 39 | print("=> Market1501 loaded") 40 | self.print_dataset_statistics(train, query, gallery) 41 | 42 | self.train = train 43 | self.query = query 44 | self.gallery = gallery 45 | 46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 49 | 50 | def _check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 56 | if not osp.exists(self.query_dir): 57 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 58 | if not osp.exists(self.gallery_dir): 59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 60 | 61 | def _process_dir(self, dir_path, relabel=False): 62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 63 | pattern = re.compile(r'([-\d]+)_c(\d)') 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: continue # junk images are just ignored 76 | #assert 0 <= pid <= 1501 # pid == 0 means background 77 | assert 1 <= camid <= 6 78 | camid -= 1 # index starts from 0 79 | if relabel: pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /libs/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import tarfile 4 | 5 | import glob 6 | import re 7 | import urllib 8 | import zipfile 9 | 10 | from ..utils.osutils import mkdir_if_missing 11 | from ..utils.serialization import write_json 12 | 13 | 14 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')): 15 | with open(list_file, 'r') as f: 16 | lines = 
f.readlines() 17 | ret = [] 18 | pids = [] 19 | for line in lines: 20 | line = line.strip() 21 | fname = line.split(' ')[0] 22 | pid, _, cam = map(int, pattern.search(osp.basename(fname)).groups()) 23 | if pid not in pids: 24 | pids.append(pid) 25 | ret.append((osp.join(subdir,fname), pid, cam)) 26 | return ret, pids 27 | 28 | class Dataset_MSMT(object): 29 | def __init__(self, root): 30 | self.root = root 31 | self.train, self.val, self.trainval = [], [], [] 32 | self.query, self.gallery = [], [] 33 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 34 | 35 | @property 36 | def images_dir(self): 37 | return self.root 38 | #return osp.join(self.root, 'MSMT17_V1') 39 | 40 | def load(self, verbose=True): 41 | #exdir = osp.join(self.root, 'MSMT17_V1') 42 | exdir = self.root 43 | self.train, train_pids = _pluck_msmt(osp.join(exdir, 'list_train.txt'), 'train') 44 | self.val, val_pids = _pluck_msmt(osp.join(exdir, 'list_val.txt'), 'train') 45 | self.train = self.train + self.val 46 | self.query, query_pids = _pluck_msmt(osp.join(exdir, 'list_query.txt'), 'test') 47 | self.gallery, gallery_pids = _pluck_msmt(osp.join(exdir, 'list_gallery.txt'), 'test') 48 | self.num_train_pids = len(list(set(train_pids).union(set(val_pids)))) 49 | 50 | if verbose: 51 | print(self.__class__.__name__, "dataset loaded") 52 | print(" subset | # ids | # images") 53 | print(" ---------------------------") 54 | print(" train | {:5d} | {:8d}" 55 | .format(self.num_train_pids, len(self.train))) 56 | print(" query | {:5d} | {:8d}" 57 | .format(len(query_pids), len(self.query))) 58 | print(" gallery | {:5d} | {:8d}" 59 | .format(len(gallery_pids), len(self.gallery))) 60 | 61 | class MSMT17(Dataset_MSMT): 62 | 63 | def __init__(self, root, split_id=0, download=True): 64 | super(MSMT17, self).__init__(root) 65 | 66 | if download: 67 | self.download() 68 | 69 | self.load() 70 | 71 | def download(self): 72 | 73 | import re 74 | import hashlib 75 | import shutil 76 | from glob import glob 77 | from zipfile import ZipFile 78 | 79 | raw_dir = osp.join(self.root) 80 | mkdir_if_missing(raw_dir) 81 | 82 | # Download the raw zip file 83 | fpath = raw_dir #osp.join(raw_dir, 'MSMT17_V1') 84 | if osp.isdir(fpath): 85 | print("Using downloaded file: " + fpath) 86 | else: 87 | raise RuntimeError("Please download the dataset manually to {}".format(fpath)) 88 | -------------------------------------------------------------------------------- /libs/datasets/msmt17_v2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | class MSMT17(BaseImageDataset): 13 | 14 | def __init__(self, root, verbose=True, **kwargs): 15 | super(MSMT17, self).__init__() 16 | self.dataset_dir = root # osp.join(root, self.dataset_dir) 17 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 18 | self.query_dir = osp.join(self.dataset_dir, 'query') 19 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 20 | 21 | self._check_before_run() 22 | 23 | train = self._process_dir(self.train_dir, relabel=True) 24 | query = self._process_dir(self.query_dir, relabel=False) 25 | gallery = self._process_dir(self.gallery_dir, relabel=False) 26 | 27 | if verbose: 28 | print("=> MSMT17 loaded") 29 | 
self.print_dataset_statistics(train, query, gallery) 30 | 31 | self.train = train 32 | self.query = query 33 | self.gallery = gallery 34 | 35 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 36 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 37 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 38 | 39 | def _check_before_run(self): 40 | """Check if all files are available before going deeper""" 41 | if not osp.exists(self.dataset_dir): 42 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 43 | if not osp.exists(self.train_dir): 44 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 45 | if not osp.exists(self.query_dir): 46 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 47 | if not osp.exists(self.gallery_dir): 48 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 49 | 50 | def _process_dir(self, dir_path, relabel=False): 51 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 52 | #pattern = re.compile(r'([-\d]+)_c(\d)') # pattern for market and duke 53 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') # pattern for msmt17 54 | pid_container = set() 55 | for img_path in img_paths: 56 | pid, _ = map(int, pattern.search(img_path).groups()) 57 | if pid == -1: continue # junk images are just ignored 58 | pid_container.add(pid) 59 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 60 | 61 | dataset = [] 62 | for img_path in img_paths: 63 | pid, camid = map(int, pattern.search(img_path).groups()) 64 | if pid == -1: continue # junk images are just ignored 65 | #assert 0 <= pid <= 1501 # pid == 0 means background 66 | assert 1 <= camid <= 15 67 | camid -= 1 # index starts from 0 68 | if relabel: pid = pid2label[pid] 69 | dataset.append((img_path, pid, camid)) 70 | 71 | return dataset 72 | -------------------------------------------------------------------------------- /libs/datasets/occ_duke.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | 13 | class OCCDuke(BaseImageDataset): 14 | """ 15 | Occluded-DukeMTMC-reID 16 | """ 17 | dataset_dir = '.' 
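    # Illustrative note (example filename is hypothetical): _process_dir below uses
    # the same filename pattern as Market-1501/DukeMTMC, r'([-\d]+)_c(\d)', so a
    # DukeMTMC-style name parses as
    #   re.search(r'([-\d]+)_c(\d)', '0001_c2_f0046182.jpg').groups()  # -> ('0001', '2')
    # i.e. pid=1 and camid=2, with camid later shifted to start from 0.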
18 | 19 | def __init__(self, root, verbose=True, **kwargs): 20 | super(OCCDuke, self).__init__() 21 | self.dataset_dir = root #osp.join(root, self.dataset_dir) 22 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 23 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 24 | self.query_dir = osp.join(self.dataset_dir, 'query') 25 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 26 | 27 | self._download_data() 28 | self._check_before_run() 29 | 30 | train = self._process_dir(self.train_dir, relabel=True) 31 | query = self._process_dir(self.query_dir, relabel=False) 32 | gallery = self._process_dir(self.gallery_dir, relabel=False) 33 | 34 | if verbose: 35 | print("=> Occluded DukeMTMC-reID loaded") 36 | self.print_dataset_statistics(train, query, gallery) 37 | 38 | self.train = train 39 | self.query = query 40 | self.gallery = gallery 41 | 42 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 43 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 44 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 45 | 46 | def _download_data(self): 47 | if osp.exists(self.dataset_dir): 48 | print("This dataset has been downloaded.") 49 | return 50 | 51 | print("Creating directory {}".format(self.dataset_dir)) 52 | mkdir_if_missing(self.dataset_dir) 53 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 54 | 55 | print("Downloading DukeMTMC-reID dataset") 56 | urllib.request.urlretrieve(self.dataset_url, fpath) 57 | 58 | print("Extracting files") 59 | zip_ref = zipfile.ZipFile(fpath, 'r') 60 | zip_ref.extractall(self.dataset_dir) 61 | zip_ref.close() 62 | 63 | def _check_before_run(self): 64 | """Check if all files are available before going deeper""" 65 | if not osp.exists(self.dataset_dir): 66 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 67 | if not osp.exists(self.train_dir): 68 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 69 | if not osp.exists(self.query_dir): 70 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 71 | if not osp.exists(self.gallery_dir): 72 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 73 | 74 | def _process_dir(self, dir_path, relabel=False): 75 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 76 | pattern = re.compile(r'([-\d]+)_c(\d)') 77 | 78 | pid_container = set() 79 | for img_path in img_paths: 80 | pid, _ = map(int, pattern.search(img_path).groups()) 81 | pid_container.add(pid) 82 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 83 | 84 | dataset = [] 85 | for img_path in img_paths: 86 | pid, camid = map(int, pattern.search(img_path).groups()) 87 | assert 1 <= camid <= 8 88 | camid -= 1 # index starts from 0 89 | if relabel: pid = pid2label[pid] 90 | dataset.append((img_path, pid, camid)) 91 | 92 | return dataset 93 | -------------------------------------------------------------------------------- /libs/datasets/personx.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | class 
PersonX(BaseImageDataset): 13 | """ 14 | PersonX 15 | Reference: 16 | Sun et al. Dissecting Person Re-identification from the Viewpoint of Viewpoint. CVPR 2019. 17 | 18 | Dataset statistics: 19 | # identities: 1266 20 | # images: 9840 (train) + 5136 (query) + 30816 (gallery) 21 | """ 22 | dataset_dir = 'PersonX' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(PersonX, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 30 | 31 | self._check_before_run() 32 | 33 | train = self._process_dir(self.train_dir, relabel=True) 34 | query = self._process_dir(self.query_dir, relabel=False) 35 | gallery = self._process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print("=> PersonX loaded") 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 48 | 49 | def _check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 55 | if not osp.exists(self.query_dir): 56 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 57 | if not osp.exists(self.gallery_dir): 58 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 59 | 60 | def _process_dir(self, dir_path, relabel=False): 61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 62 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 63 | cam2label = {3:1, 4:2, 8:3, 10:4, 11:5, 12:6} 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | pid_container.add(pid) 69 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 70 | 71 | dataset = [] 72 | for img_path in img_paths: 73 | pid, camid = map(int, pattern.search(img_path).groups()) 74 | assert (camid in cam2label.keys()) 75 | camid = cam2label[camid] 76 | camid -= 1 # index starts from 0 77 | if relabel: pid = pid2label[pid] 78 | dataset.append((img_path, pid, camid)) 79 | 80 | return dataset 81 | -------------------------------------------------------------------------------- /libs/datasets/vehicleid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | import os.path as osp 7 | 8 | from ..utils.data import BaseImageDataset 9 | from collections import defaultdict 10 | 11 | 12 | class VehicleID(BaseImageDataset): 13 | """ 14 | VehicleID 15 | Reference: 16 | Deep Relative Distance Learning: Tell the Difference Between Similar Vehicles 17 | 18 | Dataset statistics: 19 | # train_list: 13164 vehicles for model training 20 | # test_list_800: 800 vehicles for model testing(small test set in paper 21 | # 
test_list_1600: 1600 vehicles for model testing(medium test set in paper 22 | # test_list_2400: 2400 vehicles for model testing(large test set in paper 23 | # test_list_3200: 3200 vehicles for model testing 24 | # test_list_6000: 6000 vehicles for model testing 25 | # test_list_13164: 13164 vehicles for model testing 26 | """ 27 | dataset_dir = 'VehicleID' 28 | 29 | def __init__(self, root, verbose=True, test_size=800, **kwargs): 30 | super(VehicleID, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.img_dir = osp.join(self.dataset_dir, 'image') 33 | self.split_dir = osp.join(self.dataset_dir, 'train_test_split') 34 | self.train_list = osp.join(self.split_dir, 'train_list.txt') 35 | self.test_size = test_size 36 | 37 | if self.test_size == 800: 38 | self.test_list = osp.join(self.split_dir, 'test_list_800.txt') 39 | elif self.test_size == 1600: 40 | self.test_list = osp.join(self.split_dir, 'test_list_1600.txt') 41 | elif self.test_size == 2400: 42 | self.test_list = osp.join(self.split_dir, 'test_list_2400.txt') 43 | 44 | print(self.test_list) 45 | 46 | self.check_before_run() 47 | 48 | train, query, gallery = self.process_split(relabel=True) 49 | self.train = train 50 | self.query = query 51 | self.gallery = gallery 52 | 53 | if verbose: 54 | print('=> VehicleID loaded') 55 | self.print_dataset_statistics(train, query, gallery) 56 | 57 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 58 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 59 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 60 | 61 | def check_before_run(self): 62 | """Check if all files are available before going deeper""" 63 | if not osp.exists(self.dataset_dir): 64 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 65 | if not osp.exists(self.split_dir): 66 | raise RuntimeError('"{}" is not available'.format(self.split_dir)) 67 | if not osp.exists(self.train_list): 68 | raise RuntimeError('"{}" is not available'.format(self.train_list)) 69 | if self.test_size not in [800, 1600, 2400]: 70 | raise RuntimeError('"{}" is not available'.format(self.test_size)) 71 | if not osp.exists(self.test_list): 72 | raise RuntimeError('"{}" is not available'.format(self.test_list)) 73 | 74 | def get_pid2label(self, pids): 75 | pid_container = set(pids) 76 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 77 | return pid2label 78 | 79 | def parse_img_pids(self, nl_pairs, pid2label=None): 80 | # il_pair is the pairs of img name and label 81 | output = [] 82 | for info in nl_pairs: 83 | name = info[0] 84 | pid = info[1] 85 | if pid2label is not None: 86 | pid = pid2label[pid] 87 | camid = 0 # don't have camid information use 0 for all 88 | img_path = osp.join(self.img_dir, name+'.jpg') 89 | output.append((img_path, pid, camid)) 90 | return output 91 | 92 | def process_split(self, relabel=False): 93 | # read train paths 94 | train_pid_dict = defaultdict(list) 95 | 96 | # 'train_list.txt' format: 97 | # the first number is the number of image 98 | # the second number is the id of vehicle 99 | with open(self.train_list) as f_train: 100 | train_data = f_train.readlines() 101 | for data in train_data: 102 | name, pid = data.strip().split(' ') 103 | pid = int(pid) 104 | train_pid_dict[pid].append([name, pid]) 105 | train_pids = list(train_pid_dict.keys()) 106 | num_train_pids = len(train_pids) 107 | assert 
num_train_pids == 13164, 'There should be 13164 vehicles for training,' \ 108 | ' but but got {}, please check the data'\ 109 | .format(num_train_pids) 110 | # print('num of train ids: {}'.format(num_train_pids)) 111 | test_pid_dict = defaultdict(list) 112 | with open(self.test_list) as f_test: 113 | test_data = f_test.readlines() 114 | for data in test_data: 115 | name, pid = data.split(' ') 116 | pid = int(pid) 117 | test_pid_dict[pid].append([name, pid]) 118 | test_pids = list(test_pid_dict.keys()) 119 | num_test_pids = len(test_pids) 120 | assert num_test_pids == self.test_size, 'There should be {} vehicles for testing,' \ 121 | ' but but got {}, please check the data'\ 122 | .format(self.test_size, num_test_pids) 123 | 124 | train_data = [] 125 | query_data = [] 126 | gallery_data = [] 127 | 128 | # for train ids, all images are used in the train set. 129 | for pid in train_pids: 130 | imginfo = train_pid_dict[pid] # imginfo include image name and id 131 | train_data.extend(imginfo) 132 | 133 | # for each test id, random choose one image for gallery 134 | # and the other ones for query. 135 | for pid in test_pids: 136 | imginfo = test_pid_dict[pid] 137 | sample = random.choice(imginfo) 138 | imginfo.remove(sample) 139 | query_data.extend(imginfo) 140 | gallery_data.append(sample) 141 | 142 | if relabel: 143 | train_pid2label = self.get_pid2label(train_pids) 144 | else: 145 | train_pid2label = None 146 | # for key, value in train_pid2label.items(): 147 | # print('{key}:{value}'.format(key=key, value=value)) 148 | 149 | train = self.parse_img_pids(train_data, train_pid2label) 150 | query = self.parse_img_pids(query_data) 151 | gallery = self.parse_img_pids(gallery_data) 152 | return train, query, gallery 153 | -------------------------------------------------------------------------------- /libs/datasets/vehiclex.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | 9 | from ..utils.data import BaseDataset 10 | 11 | 12 | class VehicleX(BaseDataset): 13 | """ 14 | VeRi 15 | Reference: 16 | PAMTRI: Pose-Aware Multi-Task Learning for Vehicle Re-Identification Using Highly Randomized Synthetic Data. 
In: ICCV 2019 17 | """ 18 | dataset_dir = 'AIC20_ReID_Simulation' 19 | 20 | def __init__(self, root, verbose=True, **kwargs): 21 | super(VehicleX, self).__init__() 22 | self.dataset_dir = osp.join(root, self.dataset_dir) 23 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 24 | 25 | self.check_before_run() 26 | 27 | train = self.process_dir(self.train_dir, relabel=True) 28 | 29 | if verbose: 30 | print('=> VehicleX loaded') 31 | self.print_dataset_statistics(train) 32 | 33 | self.train = train 34 | 35 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 36 | 37 | def check_before_run(self): 38 | """Check if all files are available before going deeper""" 39 | if not osp.exists(self.dataset_dir): 40 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 41 | if not osp.exists(self.train_dir): 42 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 43 | 44 | def process_dir(self, dir_path, relabel=False): 45 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 46 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 47 | 48 | pid_container = set() 49 | for img_path in img_paths: 50 | pid, _ = map(int, pattern.search(img_path).groups()) 51 | if pid == -1: 52 | continue # junk images are just ignored 53 | pid_container.add(pid) 54 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 55 | 56 | dataset = [] 57 | for img_path in img_paths: 58 | pid, camid = map(int, pattern.search(img_path).groups()) 59 | if pid == -1: 60 | continue # junk images are just ignored 61 | assert 1 <= pid <= 1362 62 | assert 6 <= camid <= 36 63 | camid -= 6 # index starts from 0 64 | if relabel: 65 | pid = pid2label[pid] 66 | dataset.append((img_path, pid, camid)) 67 | return dataset 68 | 69 | def print_dataset_statistics(self, train): 70 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 71 | 72 | print("Dataset statistics:") 73 | print(" ----------------------------------------") 74 | print(" subset | # ids | # images | # cameras") 75 | print(" ----------------------------------------") 76 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 77 | print(" ----------------------------------------") 78 | -------------------------------------------------------------------------------- /libs/datasets/veri.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | 9 | from ..utils.data import BaseImageDataset 10 | 11 | 12 | class VeRi(BaseImageDataset): 13 | """ 14 | VeRi 15 | Reference: 16 | Liu, X., Liu, W., Ma, H., Fu, H.: Large-scale vehicle re-identification in urban surveillance videos. In: IEEE % 17 | International Conference on Multimedia and Expo. (2016) accepted. 
18 | Dataset statistics: 19 | # identities: 776 vehicles(576 for training and 200 for testing) 20 | # images: 37778 (train) + 11579 (query) 21 | """ 22 | # dataset_dir = 'VeRi' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(VeRi, self).__init__() 26 | self.dataset_dir = root 27 | # self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 31 | 32 | self.check_before_run() 33 | 34 | train = self.process_dir(self.train_dir, relabel=True) 35 | query = self.process_dir(self.query_dir, relabel=False) 36 | gallery = self.process_dir(self.gallery_dir, relabel=False) 37 | 38 | if verbose: 39 | print('=> VeRi loaded') 40 | self.print_dataset_statistics(train, query, gallery) 41 | 42 | self.train = train 43 | self.query = query 44 | self.gallery = gallery 45 | 46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 49 | 50 | def check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 56 | if not osp.exists(self.query_dir): 57 | raise RuntimeError('"{}" is not available'.format(self.query_dir)) 58 | if not osp.exists(self.gallery_dir): 59 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) 60 | 61 | def process_dir(self, dir_path, relabel=False): 62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 63 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: 69 | continue # junk images are just ignored 70 | pid_container.add(pid) 71 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 72 | 73 | dataset = [] 74 | for img_path in img_paths: 75 | pid, camid = map(int, pattern.search(img_path).groups()) 76 | if pid == -1: 77 | continue # junk images are just ignored 78 | assert 0 <= pid <= 776 # pid == 0 means background 79 | assert 1 <= camid <= 20 80 | camid -= 1 # index starts from 0 81 | if relabel: 82 | pid = pid2label[pid] 83 | dataset.append((img_path, pid, camid)) 84 | 85 | return dataset 86 | -------------------------------------------------------------------------------- /libs/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap' 10 | ] 11 | -------------------------------------------------------------------------------- /libs/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from ..utils import to_torch 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | with torch.no_grad(): 9 | output, target = to_torch(output), 
to_torch(target) 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | ret = [] 18 | for k in topk: 19 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 20 | ret.append(correct_k.mul_(1. / batch_size)) 21 | return ret 22 | -------------------------------------------------------------------------------- /libs/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import tqdm 5 | import numpy as np 6 | from sklearn.metrics import average_precision_score 7 | 8 | from ..utils import to_numpy 9 | 10 | 11 | def _unique_sample(ids_dict, num): 12 | mask = np.zeros(num, dtype=np.bool) 13 | for _, indices in ids_dict.items(): 14 | i = np.random.choice(indices) 15 | mask[i] = True 16 | return mask 17 | 18 | 19 | def cmc(distmat, query_ids=None, gallery_ids=None, 20 | query_cams=None, gallery_cams=None, topk=100, 21 | separate_camera_set=False, 22 | single_gallery_shot=False, 23 | first_match_break=False): 24 | distmat = to_numpy(distmat) 25 | m, n = distmat.shape 26 | # Fill up default values 27 | if query_ids is None: 28 | query_ids = np.arange(m) 29 | if gallery_ids is None: 30 | gallery_ids = np.arange(n) 31 | if query_cams is None: 32 | query_cams = np.zeros(m).astype(np.int32) 33 | if gallery_cams is None: 34 | gallery_cams = np.ones(n).astype(np.int32) 35 | # Ensure numpy array 36 | query_ids = np.asarray(query_ids) 37 | gallery_ids = np.asarray(gallery_ids) 38 | query_cams = np.asarray(query_cams) 39 | gallery_cams = np.asarray(gallery_cams) 40 | # Sort and find correct matches 41 | indices = np.argsort(distmat, axis=1) 42 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 43 | # Compute CMC for each query 44 | ret = np.zeros(topk) 45 | num_valid_queries = 0 46 | for i in tqdm.tqdm(range(m), desc='CMC Eval.'): 47 | # Filter out the same id and same camera 48 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 49 | (gallery_cams[indices[i]] != query_cams[i])) 50 | if separate_camera_set: 51 | # Filter out samples from same camera 52 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 53 | if not np.any(matches[i, valid]): continue 54 | if single_gallery_shot: 55 | repeat = 10 56 | gids = gallery_ids[indices[i][valid]] 57 | inds = np.where(valid)[0] 58 | ids_dict = defaultdict(list) 59 | for j, x in zip(inds, gids): 60 | ids_dict[x].append(j) 61 | else: 62 | repeat = 1 63 | for _ in range(repeat): 64 | if single_gallery_shot: 65 | # Randomly choose one instance for each id 66 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 67 | index = np.nonzero(matches[i, sampled])[0] 68 | else: 69 | index = np.nonzero(matches[i, valid])[0] 70 | delta = 1. 
/ (len(index) * repeat) 71 | for j, k in enumerate(index): 72 | if k - j >= topk: break 73 | if first_match_break: 74 | ret[k - j] += 1 75 | break 76 | ret[k - j] += delta 77 | num_valid_queries += 1 78 | if num_valid_queries == 0: 79 | raise RuntimeError("No valid query") 80 | return ret.cumsum() / num_valid_queries 81 | 82 | 83 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 84 | query_cams=None, gallery_cams=None): 85 | distmat = to_numpy(distmat) 86 | m, n = distmat.shape 87 | # Fill up default values 88 | if query_ids is None: 89 | query_ids = np.arange(m) 90 | if gallery_ids is None: 91 | gallery_ids = np.arange(n) 92 | if query_cams is None: 93 | query_cams = np.zeros(m).astype(np.int32) 94 | if gallery_cams is None: 95 | gallery_cams = np.ones(n).astype(np.int32) 96 | # Ensure numpy array 97 | query_ids = np.asarray(query_ids) 98 | gallery_ids = np.asarray(gallery_ids) 99 | query_cams = np.asarray(query_cams) 100 | gallery_cams = np.asarray(gallery_cams) 101 | # Sort and find correct matches 102 | indices = np.argsort(distmat, axis=1) 103 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 104 | # Compute AP for each query 105 | aps = [] 106 | for i in tqdm.tqdm(range(m), desc='mAP Eval.'): 107 | # Filter out the same id and same camera 108 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 109 | (gallery_cams[indices[i]] != query_cams[i])) 110 | y_true = matches[i, valid] 111 | y_score = -distmat[i][indices[i]][valid] 112 | if not np.any(y_true): continue 113 | aps.append(average_precision_score(y_true, y_score)) 114 | if len(aps) == 0: 115 | raise RuntimeError("No valid query") 116 | return np.mean(aps) 117 | -------------------------------------------------------------------------------- /libs/evaluators.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import csv 3 | from collections import OrderedDict 4 | import torch 5 | import tqdm 6 | from .evaluation_metrics import cmc, mean_ap 7 | from .utils.rerank import re_ranking 8 | from .utils import to_torch 9 | 10 | def save_benchmark(cfg, mAP, cmc, task_name, time_cost): 11 | if not osp.exists(cfg.LOG.BENCHMARK_PATH): 12 | with open(cfg.LOG.BENCHMARK_PATH, 'w', newline='') as f: 13 | writer = csv.writer(f) 14 | writer.writerow(['Task', 'Training time', 'Dataset', 'mAP', 'Rank-1', 'Rank-5', 'Rank-10']) 15 | with open(cfg.LOG.BENCHMARK_PATH, 'a', newline='') as f: 16 | writer = csv.writer(f) 17 | writer.writerow([task_name, '{}'.format(time_cost), cfg.DATASET.NAME, '{:.1%}'.format(mAP), '{:.1%}'.format(cmc[0]), 18 | '{:.1%}'.format(cmc[4]), '{:.1%}'.format(cmc[9])]) 19 | print('=> Benchmark is updated.') 20 | 21 | 22 | def extract_vit_features(model, data_loader): 23 | model.eval() 24 | features = OrderedDict() 25 | labels = OrderedDict() 26 | 27 | with torch.no_grad(): 28 | for i, (imgs, fnames, pids, cams, _) in enumerate(tqdm.tqdm(data_loader)): 29 | imgs = to_torch(imgs).cuda() 30 | cams = to_torch(cams).cuda() 31 | outputs = model(imgs, cam_label=cams) 32 | if isinstance(outputs, dict): 33 | outputs = outputs['global'] # 只抽取全局特征 34 | outputs = outputs.data.cpu() 35 | 36 | for fname, output, pid in zip(fnames, outputs, pids): 37 | features[fname] = output 38 | labels[fname] = pid 39 | return features, labels 40 | 41 | def extract_multipart_vit_features(model, data_loader, num_parts): 42 | model.eval() 43 | 44 | global_feats = OrderedDict() 45 | labels = OrderedDict() 46 | 47 | part_feats = [OrderedDict() for _ in range(num_parts)] 48 | 49 | 
with torch.no_grad(): 50 | for i, (imgs, fnames, pids, cams, _) in enumerate(tqdm.tqdm(data_loader)): 51 | 52 | imgs = to_torch(imgs).cuda() 53 | cams = to_torch(cams).cuda() 54 | out_dict = model(imgs, cam_label=cams) 55 | for k, v in out_dict.items(): 56 | if k == 'global': 57 | out_dict[k] = v.data.cpu() 58 | elif k == 'part': 59 | out_dict[k] = [each.data.cpu() for each in v] 60 | 61 | # Global 62 | for fname, output, pid in zip(fnames, out_dict['global'], pids): 63 | global_feats[fname] = output 64 | labels[fname] = pid 65 | 66 | # Part 67 | for part, pf in zip(out_dict['part'], part_feats): 68 | for fname, output, pid in zip(fnames, part, pids): 69 | pf[fname] = output 70 | 71 | return global_feats, part_feats, labels # OD, list, OD 72 | 73 | def pairwise_distance(features, query=None, gallery=None): 74 | if query is None and gallery is None: 75 | n = len(features) 76 | x = torch.cat(list(features.values())) 77 | x = x.view(n, -1) 78 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 79 | dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t()) 80 | return dist_m 81 | 82 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 83 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0) 84 | m, n = x.size(0), y.size(0) 85 | x = x.view(m, -1) 86 | y = y.view(n, -1) 87 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 88 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 89 | dist_m.addmm_(1, -2, x, y.t()) 90 | return dist_m, x.numpy(), y.numpy() 91 | 92 | 93 | def evaluate_all(distmat, query=None, gallery=None, 94 | query_ids=None, gallery_ids=None, 95 | query_cams=None, gallery_cams=None, 96 | cmc_topk=(1, 5, 10), cmc_flag=False): 97 | if query is not None and gallery is not None: 98 | query_ids = [pid for _, pid, _ in query] 99 | gallery_ids = [pid for _, pid, _ in gallery] 100 | query_cams = [cam for _, _, cam in query] 101 | gallery_cams = [cam for _, _, cam in gallery] 102 | else: 103 | assert (query_ids is not None and gallery_ids is not None 104 | and query_cams is not None and gallery_cams is not None) 105 | 106 | # Compute mean AP 107 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 108 | print('Mean AP: {:4.1%}'.format(mAP)) 109 | 110 | if (not cmc_flag): 111 | return mAP 112 | 113 | cmc_configs = { 114 | 'market1501': dict(separate_camera_set=False, 115 | single_gallery_shot=False, 116 | first_match_break=True),} 117 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 118 | query_cams, gallery_cams, **params) 119 | for name, params in cmc_configs.items()} 120 | 121 | print('CMC Scores:') 122 | for k in cmc_topk: 123 | print(' top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1])) 124 | return cmc_scores['market1501'], mAP 125 | 126 | 127 | class Evaluator(object): 128 | def __init__(self, cfg, model): 129 | super(Evaluator, self).__init__() 130 | self.model = model 131 | self.cfg = cfg 132 | 133 | def evaluate_vit(self, data_loader, query, gallery, cmc_flag=False, rerank=False, is_concat=False): 134 | features, _ = extract_vit_features(self.model, data_loader) 135 | distmat, _, _ = pairwise_distance(features, query, gallery) 136 | 137 | results = evaluate_all(distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 138 | 139 | if (not rerank): 140 | return results 141 | 142 | print('Applying person re-ranking ...') 143 | 144 | distmat_qq = pairwise_distance(features, query, query) 145 | distmat_gg = pairwise_distance(features, gallery, gallery) 146 | distmat = 
re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy()) 147 | return evaluate_all(distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) -------------------------------------------------------------------------------- /libs/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnet_ibn import * 5 | from .resnet_part import * 6 | from .vit_encoder import * 7 | 8 | 9 | __factory = { 10 | 'resnet18': resnet18, 11 | 'resnet34': resnet34, 12 | 'resnet50': resnet50, 13 | 'resnet101': resnet101, 14 | 'resnet152': resnet152, 15 | 'resnet_ibn50a': resnet_ibn50a, 16 | 'resnet_ibn101a': resnet_ibn101a, 17 | 'tmgf': tmgf 18 | } 19 | 20 | 21 | def names(): 22 | return sorted(__factory.keys()) 23 | 24 | 25 | def create(name, *args, **kwargs): 26 | """ 27 | Create a model instance. 28 | 29 | Parameters 30 | ---------- 31 | name : str 32 | Model name. Can be one of 'resnet18', 'resnet34', 'resnet50', 33 | 'resnet101', 'resnet152', 'resnet_ibn50a', 'resnet_ibn101a', and 'tmgf'. 34 | pretrained : bool, optional 35 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 36 | model. Default: True 37 | cut_at_pooling : bool, optional 38 | If True, will cut the model before the last global pooling layer and 39 | ignore the remaining kwargs. Default: False 40 | num_features : int, optional 41 | If positive, will append a Linear layer after the global pooling layer, 42 | with this number of output units, followed by a BatchNorm layer. 43 | Otherwise these layers will not be appended. Default: 0 for 44 | the 'resnet*' models 45 | norm : bool, optional 46 | If True, will normalize the feature to be unit L2-norm for each sample. 47 | Otherwise will append a ReLU layer after the above Linear layer if 48 | num_features > 0. Default: False 49 | dropout : float, optional 50 | If positive, will append a Dropout layer with this dropout rate. 51 | Default: 0 52 | num_classes : int, optional 53 | If positive, will append a Linear layer at the end as the classifier 54 | with this number of output units. Default: 0 55 | """ 56 | if name not in __factory: 57 | raise KeyError("Unknown model:", name) 58 | return __factory[name](*args, **kwargs) 59 | -------------------------------------------------------------------------------- /libs/models/mb.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Memory bank loss implementations. 3 | Implementations are inspired by SpCL and O2CAP; thanks for their excellent work! 4 | SpCL: https://github.com/yxgeee/SpCL 5 | O2CAP: https://github.com/Terminator8758/O2CAP 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Function 12 | from torch.cuda import amp 13 | 14 | class PartMatMul(Function): 15 | """ 16 | Matrix multiplication with memory bank update. An extra part dim is added. 17 | In the forward pass, it only applies a matmul between the anchors and the memory bank. 18 | In the backward pass, it updates the memory bank with momentum.
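# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): the momentum
# update described above, written out for a single proxy index y. Shapes and
# values are assumptions for illustration.
import torch
import torch.nn.functional as F

em = F.normalize(torch.randn(3, 10, 8), dim=2)   # memory bank: (n_part, n_proxy, c)
x = F.normalize(torch.randn(3, 8), dim=1)        # anchor feature of one sample: (n_part, c)
y, alpha = 4, 0.2                                # its proxy label and the momentum
em[:, y, :] = alpha * em[:, y, :] + (1.0 - alpha) * x
em[:, y, :] /= em[:, y, :].norm(dim=1).unsqueeze(-1)   # keep the updated slot L2-normalized
# ---------------------------------------------------------------------------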
19 | """ 20 | 21 | @staticmethod 22 | @amp.custom_fwd 23 | def forward(ctx, inputs, targets, em, alpha): 24 | ctx.em = em 25 | ctx.alpha = alpha 26 | ctx.save_for_backward(inputs, targets) 27 | outputs = inputs.matmul(ctx.em.permute(0,2,1)) # (n_part, b, c) x (n_part, c, n_proxy) -> (n_part, b, n_proxy) 28 | return outputs 29 | 30 | @staticmethod 31 | @amp.custom_bwd 32 | def backward(ctx, grad_outputs): 33 | inputs, targets = ctx.saved_tensors 34 | grad_inputs = None 35 | if ctx.needs_input_grad[0]: 36 | grad_inputs = grad_outputs.matmul(ctx.em) # (n_part, b, c) 37 | 38 | for i, y in enumerate(targets): 39 | x = inputs[:,i,:] # (n_part, c) 40 | ctx.em[:,y,:] = ctx.alpha * ctx.em[:,y,:] + (1.0 - ctx.alpha) * x 41 | ctx.em[:,y,:] /= ctx.em[:,y,:].norm(dim=1).unsqueeze(-1) 42 | 43 | return grad_inputs, None, None, None 44 | 45 | def part_matmul(inputs, targets, em, alpha): 46 | return PartMatMul.apply(inputs, targets, em, alpha) 47 | 48 | class MultiPartMemory(nn.Module): 49 | def __init__(self, cfg): 50 | """ 51 | Multi-part offline/online loss with momentum proxy memory bank. 52 | 53 | Params: 54 | cfg: Config instance. 55 | Returns: 56 | A MultiPartMemory instance. 57 | """ 58 | super().__init__() 59 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 60 | self.temp = cfg.MEMORY_BANK.PROXY_TEMP 61 | self.momentum = cfg.MEMORY_BANK.MOMENTUM 62 | self.neg_k = cfg.MEMORY_BANK.BG_KNN 63 | self.posK = cfg.MEMORY_BANK.POS_K 64 | self.balance_w = cfg.MEMORY_BANK.BALANCE_W 65 | self.part_weight = cfg.MEMORY_BANK.PART_W 66 | self.num_parts = cfg.MODEL.NUM_PARTS 67 | 68 | self.all_proxy_labels = None 69 | self.proxy_memory = None 70 | self.proxy2cluster = None 71 | self.cluster2proxy = None 72 | self.part_proxies = None 73 | self.unique_cams = None 74 | self.cam2proxy = None 75 | 76 | def forward(self, feature_dict, targets, cam=None, epoch=None): 77 | 78 | # proxy labels in a batch 79 | batch_proxy_labels = self.all_proxy_labels[targets].to(self.device) 80 | 81 | # loss computation 82 | all_feats = torch.cat([feature_dict['global'].unsqueeze(0), feature_dict['part']], dim=0) 83 | all_scores = part_matmul(all_feats, batch_proxy_labels, self.proxy_memory, self.momentum) # (n_part, b, n_proxy) 84 | all_scaled_scores = all_scores / self.temp 85 | global_off_loss, part_off_loss = self.offline_loss_part_parallel(all_scaled_scores, batch_proxy_labels) 86 | part_off_loss = part_off_loss.mean() 87 | 88 | all_temp_scores = all_scores.detach().clone() 89 | global_on_loss, part_on_loss = self.online_loss_part_parallel(all_scaled_scores, batch_proxy_labels, all_temp_scores) 90 | part_on_loss = part_on_loss.mean() 91 | 92 | # part loss weight 93 | part_off_loss *= self.part_weight 94 | part_on_loss *= self.part_weight 95 | 96 | loss_dict = { 97 | 'loss': global_off_loss + global_on_loss + part_off_loss + part_on_loss, 98 | 'global_off_loss': global_off_loss, 99 | 'global_on_loss': global_on_loss, 100 | 'part_off_loss': part_off_loss, 101 | 'part_on_loss': part_on_loss 102 | } 103 | return loss_dict 104 | 105 | def offline_loss_part_parallel(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 106 | """ 107 | Compute offline loss for both global and part level. 108 | All parts are handled parallelly to alleviate time consumption. 109 | 110 | Params: 111 | scores: Scaled batch-proxy similarity scores. 112 | labels: Proxy labels in a batch. 113 | Returns: 114 | Offline global and part losses. 
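# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): the selection
# scheme implemented below, shown for one sample and a single (global) part.
# The toy sizes, pos_ind and neg_k values are assumptions for illustration.
import torch
import torch.nn.functional as F

scores = torch.randn(1, 20)              # scaled similarities to all proxies: (n_part, n_proxy)
pos_ind = torch.tensor([3, 7])           # proxies belonging to the sample's own cluster
neg_k = 5
tmp = scores.detach().clone()
tmp[:, pos_ind] = 10000.0                # push the positives to the end of the ascending sort
sel_ind = torch.argsort(tmp)[:, -(neg_k + len(pos_ind)):]
sel_input = scores.gather(1, sel_ind)    # hardest negatives followed by all positives
sel_target = torch.zeros_like(sel_input)
sel_target[:, -len(pos_ind):] = 1.0 / len(pos_ind)   # uniform soft target over the positives
toy_loss = -(F.log_softmax(sel_input, dim=1) * sel_target).sum()
# ---------------------------------------------------------------------------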
115 | """ 116 | 117 | temp_scores = scores.detach().clone() 118 | loss = 0 119 | 120 | if scores.size(0) > 1: 121 | part_loss = 0 122 | else: 123 | part_loss = torch.tensor(0).type_as(scores) 124 | 125 | for i in range(scores.size(1)): 126 | pos_ind = torch.tensor(self.cluster2proxy[self.proxy2cluster[labels[i].item()].item()]).type_as(labels) 127 | temp_scores[:, i, pos_ind] = 10000.0 128 | sel_ind = torch.argsort(temp_scores[:, i, :])[:, -self.neg_k-len(pos_ind):] # (n_part, neg_k+pos_k) 129 | sel_input = scores[:, i, :].gather(dim=1, index=sel_ind) # (n_part, neg_k+pos_k) 130 | sel_target = torch.zeros(sel_input.shape).type_as(sel_input) # (n_part, neg_k+pos_k) 131 | sel_target[:, -len(pos_ind):] = 1.0 / len(pos_ind) 132 | logit = -1.0 * (F.log_softmax(sel_input, dim=1) * sel_target) # (n_part, neg_k+pos_k) 133 | loss += logit[0,:].sum() 134 | 135 | # compute part loss when there exists part feature 136 | if scores.size(0) > 1: 137 | part_loss += logit[1:,:].sum(dim=1) 138 | 139 | loss /= scores.size(1) 140 | part_loss /= scores.size(1) 141 | 142 | return loss, part_loss 143 | 144 | def online_loss_part_parallel(self, scores: torch.Tensor, labels: torch.Tensor, temp_scores: torch.Tensor): 145 | """ 146 | Compute online loss for both global and part level. 147 | All parts and batch samples are handled parallelly to alleviate time consumption. 148 | 149 | Params: 150 | scores: Scaled batch-proxy similarity scores. 151 | labels: Proxy labels in a batch. 152 | temp_scores: Detached scores for positive/negative samples retrieval. 153 | Returns: 154 | Online global and part losses. 155 | """ 156 | # compute online similarity 157 | temp_memory = self.proxy_memory.detach().clone() 158 | proxy_sims = torch.matmul(temp_memory, temp_memory.permute(0,2,1)) # (1+N_part, N_proxy, N_proxy) 159 | sims = self.balance_w * temp_scores + (1 - self.balance_w) * proxy_sims[:, labels, :] # (1+N_part, B, N_proxy) 160 | 161 | # CA-NN: camera-aware nearest neighbors 162 | all_cam_tops = [] 163 | for cc in self.unique_cams: 164 | proxy_inds = self.cam2proxy[int(cc)].long().to(self.device) # 当前相机下的proxy label 165 | max_idx = sims[:, :, proxy_inds].argmax(dim=2) 166 | all_cam_tops.append(proxy_inds[max_idx]) 167 | 168 | # retrieve positive samples 169 | all_cam_tops = torch.stack(all_cam_tops, dim=-1) # (1+N_part, B, N_cam) 170 | top_sims = torch.gather(sims, dim=2, index=all_cam_tops) # (1+N_part, B, N_cam) 171 | sel_inds = torch.argsort(top_sims, dim=2)[:, :, -self.posK:] 172 | pos_inds = torch.gather(all_cam_tops, dim=2, index=sel_inds) 173 | scatter_sims = torch.scatter(sims, dim=2, index=pos_inds, src=10000.0*torch.ones(pos_inds.shape).type_as(sims)) # (1+N_part, B, N_proxy) 174 | top_inds = torch.sort(scatter_sims, dim=2)[1][:, :, -self.neg_k-self.posK:] # (1+N_part, B, N_pn) 175 | sel_inputs = torch.gather(scores, dim=2, index=top_inds) 176 | sel_targets = torch.zeros(sel_inputs.shape).type_as(sel_inputs) 177 | sel_targets[:, :, -self.posK:] = 1.0 / self.posK 178 | 179 | # global loss 180 | loss = -1.0 * (F.log_softmax(sel_inputs[0], dim=1) * sel_targets[0]).sum(dim=1).mean() # scalar 181 | 182 | # part loss 183 | if scores.size(0) > 1: 184 | part_loss = -1.0 * (F.log_softmax(sel_inputs[1:], dim=2) * sel_targets[1:]).sum(dim=2).mean(dim=1) # (N_part, ) 185 | else: 186 | part_loss = torch.tensor(0).type_as(loss) 187 | 188 | return loss, part_loss 189 | -------------------------------------------------------------------------------- /libs/models/pooling.py: 
-------------------------------------------------------------------------------- 1 | # Credit to https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/layers/pooling.py 2 | from abc import ABC 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | __all__ = [ 9 | "GeneralizedMeanPoolingPFpn", 10 | "GeneralizedMeanPoolingList", 11 | "GeneralizedMeanPoolingP", 12 | "AdaptiveAvgMaxPool2d", 13 | "FastGlobalAvgPool2d", 14 | "avg_pooling", 15 | "max_pooling", 16 | ] 17 | 18 | 19 | class GeneralizedMeanPoolingList(nn.Module, ABC): 20 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of 21 | several input planes. 22 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 23 | - At p = infinity, one gets Max Pooling 24 | - At p = 1, one gets Average Pooling 25 | The output is of size H x W, for any input size. 26 | The number of output features is equal to the number of input planes. 27 | Args: 28 | output_size: the target output size of the image of the form H x W. 29 | Can be a tuple (H, W) or a single H for a square image H x H 30 | H and W can be either a ``int``, or ``None`` which means the size 31 | will be the same as that of the input. 32 | """ 33 | 34 | def __init__(self, output_size=1, eps=1e-6): 35 | super(GeneralizedMeanPoolingList, self).__init__() 36 | self.output_size = output_size 37 | self.eps = eps 38 | 39 | def forward(self, x_list): 40 | outs = [] 41 | for x in x_list: 42 | x = x.clamp(min=self.eps) 43 | out = torch.nn.functional.adaptive_avg_pool2d(x, self.output_size) 44 | outs.append(out) 45 | return torch.stack(outs, -1).mean(-1) 46 | 47 | def __repr__(self): 48 | return ( 49 | self.__class__.__name__ 50 | + "(" 51 | + "output_size=" 52 | + str(self.output_size) 53 | + ")" 54 | ) 55 | 56 | 57 | class GeneralizedMeanPooling(nn.Module, ABC): 58 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of 59 | several input planes. 60 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 61 | - At p = infinity, one gets Max Pooling 62 | - At p = 1, one gets Average Pooling 63 | The output is of size H x W, for any input size. 64 | The number of output features is equal to the number of input planes. 65 | Args: 66 | output_size: the target output size of the image of the form H x W. 67 | Can be a tuple (H, W) or a single H for a square image H x H 68 | H and W can be either a ``int``, or ``None`` which means the size 69 | will be the same as that of the input. 
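# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): numerical
# behaviour of the power-average (GeM) pooling described above on a toy
# 2x2 feature map; the p values are assumptions for illustration.
import torch
import torch.nn.functional as F

fmap = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])   # (B, C, H, W) = (1, 1, 2, 2)
gem = lambda t, p: F.adaptive_avg_pool2d(t.clamp(min=1e-6).pow(p), 1).pow(1.0 / p)
gem(fmap, 1.0)    # 2.50  -> equals average pooling
gem(fmap, 3.0)    # ~2.92 -> between average and max
gem(fmap, 10.0)   # ~3.50 -> approaches max pooling (4.0) as p grows
# ---------------------------------------------------------------------------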
70 | """ 71 | 72 | def __init__(self, norm, output_size=1, eps=1e-6): 73 | super(GeneralizedMeanPooling, self).__init__() 74 | assert norm > 0 75 | self.p = float(norm) 76 | self.output_size = output_size 77 | self.eps = eps 78 | 79 | def forward(self, x): 80 | x = x.clamp(min=self.eps).pow(self.p) 81 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow( 82 | 1.0 / self.p 83 | ) 84 | 85 | def __repr__(self): 86 | return ( 87 | self.__class__.__name__ 88 | + "(" 89 | + str(self.p) 90 | + ", " 91 | + "output_size=" 92 | + str(self.output_size) 93 | + ")" 94 | ) 95 | 96 | 97 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling, ABC): 98 | """ Same, but norm is trainable 99 | """ 100 | 101 | def __init__(self, norm=3, output_size=1, eps=1e-6): 102 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 103 | self.p = nn.Parameter(torch.ones(1) * norm) 104 | 105 | 106 | class GeneralizedMeanPoolingFpn(nn.Module, ABC): 107 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of 108 | several input planes. 109 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 110 | - At p = infinity, one gets Max Pooling 111 | - At p = 1, one gets Average Pooling 112 | The output is of size H x W, for any input size. 113 | The number of output features is equal to the number of input planes. 114 | Args: 115 | output_size: the target output size of the image of the form H x W. 116 | Can be a tuple (H, W) or a single H for a square image H x H 117 | H and W can be either a ``int``, or ``None`` which means the size 118 | will be the same as that of the input. 119 | """ 120 | 121 | def __init__(self, norm, output_size=1, eps=1e-6): 122 | super(GeneralizedMeanPoolingFpn, self).__init__() 123 | assert norm > 0 124 | self.p = float(norm) 125 | self.output_size = output_size 126 | self.eps = eps 127 | 128 | def forward(self, x_lists): 129 | outs = [] 130 | for x in x_lists: 131 | x = x.clamp(min=self.eps).pow(self.p) 132 | out = torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow( 133 | 1.0 / self.p 134 | ) 135 | outs.append(out) 136 | return torch.cat(outs, 1) 137 | 138 | def __repr__(self): 139 | return ( 140 | self.__class__.__name__ 141 | + "(" 142 | + str(self.p) 143 | + ", " 144 | + "output_size=" 145 | + str(self.output_size) 146 | + ")" 147 | ) 148 | 149 | 150 | class GeneralizedMeanPoolingPFpn(GeneralizedMeanPoolingFpn, ABC): 151 | """ Same, but norm is trainable 152 | """ 153 | 154 | def __init__(self, norm=3, output_size=1, eps=1e-6): 155 | super(GeneralizedMeanPoolingPFpn, self).__init__(norm, output_size, eps) 156 | self.p = nn.Parameter(torch.ones(1) * norm) 157 | 158 | 159 | class AdaptiveAvgMaxPool2d(nn.Module, ABC): 160 | def __init__(self): 161 | super(AdaptiveAvgMaxPool2d, self).__init__() 162 | self.avgpool = FastGlobalAvgPool2d() 163 | 164 | def forward(self, x): 165 | x_avg = self.avgpool(x, self.output_size) 166 | x_max = F.adaptive_max_pool2d(x, 1) 167 | x = x_max + x_avg 168 | return x 169 | 170 | 171 | class FastGlobalAvgPool2d(nn.Module, ABC): 172 | def __init__(self, flatten=False): 173 | super(FastGlobalAvgPool2d, self).__init__() 174 | self.flatten = flatten 175 | 176 | def forward(self, x): 177 | if self.flatten: 178 | in_size = x.size() 179 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2) 180 | else: 181 | return ( 182 | x.view(x.size(0), x.size(1), -1) 183 | .mean(-1) 184 | .view(x.size(0), x.size(1), 1, 1) 185 | ) 186 | 187 | 188 | def avg_pooling(): 189 | return 
nn.AdaptiveAvgPool2d(1) 190 | # return FastGlobalAvgPool2d() 191 | 192 | 193 | def max_pooling(): 194 | return nn.AdaptiveMaxPool2d(1) 195 | 196 | 197 | class Flatten(nn.Module): 198 | def forward(self, input): 199 | return input.view(input.size(0), -1) 200 | 201 | 202 | __pooling_factory = { 203 | "avg": avg_pooling, 204 | "max": max_pooling, 205 | "gem": GeneralizedMeanPoolingP, 206 | "gemFpn": GeneralizedMeanPoolingPFpn, 207 | "gemList": GeneralizedMeanPoolingList, 208 | "avg+max": AdaptiveAvgMaxPool2d, 209 | } 210 | 211 | 212 | def pooling_names(): 213 | return sorted(__pooling_factory.keys()) 214 | 215 | 216 | def build_pooling_layer(name): 217 | """ 218 | Create a pooling layer. 219 | Parameters 220 | ---------- 221 | name : str 222 | The backbone name. 223 | """ 224 | if name not in __pooling_factory: 225 | raise KeyError("Unknown pooling layer:", name) 226 | return __pooling_factory[name]() -------------------------------------------------------------------------------- /libs/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | from .pooling import build_pooling_layer 9 | from torch.hub import load_state_dict_from_url 10 | 11 | 12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 13 | 'resnet152'] 14 | 15 | 16 | class ResNet(nn.Module): 17 | __factory = { 18 | 18: torchvision.models.resnet18, 19 | 34: torchvision.models.resnet34, 20 | 50: torchvision.models.resnet50, 21 | 101: torchvision.models.resnet101, 22 | 152: torchvision.models.resnet152, 23 | } 24 | 25 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 26 | num_features=0, norm=False, dropout=0, num_classes=0, 27 | pooling_type='avg', pretrained_weight=None): 28 | super(ResNet, self).__init__() 29 | self.pretrained = pretrained 30 | self.depth = depth 31 | self.cut_at_pooling = cut_at_pooling 32 | # Construct base (pretrained) resnet 33 | if depth not in ResNet.__factory: 34 | raise KeyError("Unsupported depth:", depth) 35 | resnet = ResNet.__factory[depth](pretrained=pretrained) 36 | resnet.layer4[0].conv2.stride = (1,1) 37 | resnet.layer4[0].downsample[0].stride = (1,1) 38 | 39 | # Load other pretrained weight 40 | if pretrained_weight is not None: 41 | print('load pretrained weight from {}'.format(pretrained_weight)) 42 | self.load_pretrained_weight(resnet, pretrained_weight) 43 | 44 | self.base = nn.Sequential( 45 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 46 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 47 | # self.gap = nn.AdaptiveAvgPool2d(1) # Vanilla model 48 | self.gap = build_pooling_layer(pooling_type) # follow cluster-contrast 49 | 50 | if not self.cut_at_pooling: 51 | self.num_features = num_features 52 | self.norm = norm 53 | self.dropout = dropout 54 | self.has_embedding = num_features > 0 55 | self.num_classes = num_classes 56 | 57 | out_planes = resnet.fc.in_features 58 | 59 | # Append new layers 60 | if self.has_embedding: 61 | self.feat = nn.Linear(out_planes, self.num_features) 62 | self.feat_bn = nn.BatchNorm1d(self.num_features) 63 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 64 | init.constant_(self.feat.bias, 0) 65 | else: 66 | # Change the num_features to CNN output channels 67 | self.num_features = out_planes 68 | self.feat_bn = nn.BatchNorm1d(self.num_features) 69 | 
self.feat_bn.bias.requires_grad_(False) 70 | if self.dropout > 0: 71 | self.drop = nn.Dropout(self.dropout) 72 | if self.num_classes > 0: 73 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 74 | init.normal_(self.classifier.weight, std=0.001) 75 | init.constant_(self.feat_bn.weight, 1) 76 | init.constant_(self.feat_bn.bias, 0) 77 | 78 | if not pretrained: 79 | self.reset_params() 80 | 81 | def forward(self, x): 82 | bs = x.size(0) 83 | x = self.base(x) 84 | 85 | x = self.gap(x) 86 | x = x.view(x.size(0), -1) 87 | 88 | if self.cut_at_pooling: 89 | return x 90 | 91 | if self.has_embedding: 92 | bn_x = self.feat_bn(self.feat(x)) 93 | else: 94 | bn_x = self.feat_bn(x) 95 | 96 | if (self.training is False): 97 | bn_x = F.normalize(bn_x) 98 | return bn_x 99 | 100 | if self.norm: 101 | bn_x = F.normalize(bn_x) 102 | elif self.has_embedding: 103 | bn_x = F.relu(bn_x) 104 | 105 | if self.dropout > 0: 106 | bn_x = self.drop(bn_x) 107 | 108 | if self.num_classes > 0: 109 | prob = self.classifier(bn_x) 110 | else: 111 | return bn_x 112 | 113 | return prob 114 | 115 | def reset_params(self): 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | init.kaiming_normal_(m.weight, mode='fan_out') 119 | if m.bias is not None: 120 | init.constant_(m.bias, 0) 121 | elif isinstance(m, nn.BatchNorm2d): 122 | init.constant_(m.weight, 1) 123 | init.constant_(m.bias, 0) 124 | elif isinstance(m, nn.BatchNorm1d): 125 | init.constant_(m.weight, 1) 126 | init.constant_(m.bias, 0) 127 | elif isinstance(m, nn.Linear): 128 | init.normal_(m.weight, std=0.001) 129 | if m.bias is not None: 130 | init.constant_(m.bias, 0) 131 | 132 | resnet = ResNet.__factory[self.depth](pretrained=self.pretrained) 133 | self.base[0].load_state_dict(resnet.conv1.state_dict()) 134 | self.base[1].load_state_dict(resnet.bn1.state_dict()) 135 | self.base[2].load_state_dict(resnet.maxpool.state_dict()) 136 | self.base[3].load_state_dict(resnet.layer1.state_dict()) 137 | self.base[4].load_state_dict(resnet.layer2.state_dict()) 138 | self.base[5].load_state_dict(resnet.layer3.state_dict()) 139 | self.base[6].load_state_dict(resnet.layer4.state_dict()) 140 | 141 | def load_pretrained_weight(self, model, pretrained_weight): 142 | weight = torch.load(pretrained_weight) 143 | if 'lup_moco_r50.pth' in pretrained_weight: 144 | model.load_state_dict(weight, strict=False) # LUPerson pretrain only provides trainable params 145 | else: 146 | raise NotImplementedError 147 | 148 | 149 | 150 | 151 | 152 | def resnet18(**kwargs): 153 | return ResNet(18, **kwargs) 154 | 155 | 156 | def resnet34(**kwargs): 157 | return ResNet(34, **kwargs) 158 | 159 | 160 | def resnet50(**kwargs): 161 | return ResNet(50, **kwargs) 162 | 163 | 164 | def resnet101(**kwargs): 165 | return ResNet(101, **kwargs) 166 | 167 | 168 | def resnet152(**kwargs): 169 | return ResNet(152, **kwargs) 170 | -------------------------------------------------------------------------------- /libs/models/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | 9 | from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 10 | 11 | 12 | __all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a'] 13 | 14 | 15 | class ResNetIBN(nn.Module): 16 | __factory = { 17 | '50a': resnet50_ibn_a, 18 | '101a': resnet101_ibn_a 19 | } 20 | 21 | def 
__init__(self, depth, pretrained=True, cut_at_pooling=False, 22 | num_features=0, norm=False, dropout=0, num_classes=0): 23 | super(ResNetIBN, self).__init__() 24 | 25 | self.depth = depth 26 | self.pretrained = pretrained 27 | self.cut_at_pooling = cut_at_pooling 28 | 29 | resnet = ResNetIBN.__factory[depth](pretrained=pretrained) 30 | resnet.layer4[0].conv2.stride = (1,1) 31 | resnet.layer4[0].downsample[0].stride = (1,1) 32 | self.base = nn.Sequential( 33 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 34 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 35 | self.gap = nn.AdaptiveAvgPool2d(1) 36 | 37 | if not self.cut_at_pooling: 38 | self.num_features = num_features 39 | self.norm = norm 40 | self.dropout = dropout 41 | self.has_embedding = num_features > 0 42 | self.num_classes = num_classes 43 | 44 | out_planes = resnet.fc.in_features 45 | 46 | # Append new layers 47 | if self.has_embedding: 48 | self.feat = nn.Linear(out_planes, self.num_features) 49 | self.feat_bn = nn.BatchNorm1d(self.num_features) 50 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 51 | init.constant_(self.feat.bias, 0) 52 | else: 53 | # Change the num_features to CNN output channels 54 | self.num_features = out_planes 55 | self.feat_bn = nn.BatchNorm1d(self.num_features) 56 | self.feat_bn.bias.requires_grad_(False) 57 | if self.dropout > 0: 58 | self.drop = nn.Dropout(self.dropout) 59 | if self.num_classes > 0: 60 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 61 | init.normal_(self.classifier.weight, std=0.001) 62 | init.constant_(self.feat_bn.weight, 1) 63 | init.constant_(self.feat_bn.bias, 0) 64 | 65 | if not pretrained: 66 | self.reset_params() 67 | 68 | def forward(self, x): 69 | x = self.base(x) 70 | 71 | x = self.gap(x) 72 | x = x.view(x.size(0), -1) 73 | 74 | if self.cut_at_pooling: 75 | return x 76 | 77 | if self.has_embedding: 78 | bn_x = self.feat_bn(self.feat(x)) 79 | else: 80 | bn_x = self.feat_bn(x) 81 | 82 | if self.training is False: 83 | bn_x = F.normalize(bn_x) 84 | return bn_x 85 | 86 | if self.norm: 87 | bn_x = F.normalize(bn_x) 88 | elif self.has_embedding: 89 | bn_x = F.relu(bn_x) 90 | 91 | if self.dropout > 0: 92 | bn_x = self.drop(bn_x) 93 | 94 | if self.num_classes > 0: 95 | prob = self.classifier(bn_x) 96 | else: 97 | return bn_x 98 | 99 | return prob 100 | 101 | def reset_params(self): 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | init.kaiming_normal_(m.weight, mode='fan_out') 105 | if m.bias is not None: 106 | init.constant_(m.bias, 0) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | init.constant_(m.weight, 1) 109 | init.constant_(m.bias, 0) 110 | elif isinstance(m, nn.BatchNorm1d): 111 | init.constant_(m.weight, 1) 112 | init.constant_(m.bias, 0) 113 | elif isinstance(m, nn.Linear): 114 | init.normal_(m.weight, std=0.001) 115 | if m.bias is not None: 116 | init.constant_(m.bias, 0) 117 | 118 | resnet = ResNetIBN.__factory[self.depth](pretrained=self.pretrained) 119 | self.base[0].load_state_dict(resnet.conv1.state_dict()) 120 | self.base[1].load_state_dict(resnet.bn1.state_dict()) 121 | self.base[2].load_state_dict(resnet.relu.state_dict()) 122 | self.base[3].load_state_dict(resnet.maxpool.state_dict()) 123 | self.base[4].load_state_dict(resnet.layer1.state_dict()) 124 | self.base[5].load_state_dict(resnet.layer2.state_dict()) 125 | self.base[6].load_state_dict(resnet.layer3.state_dict()) 126 | self.base[7].load_state_dict(resnet.layer4.state_dict()) 127 | 128 | 129 | def resnet_ibn50a(**kwargs): 
130 | return ResNetIBN('50a', **kwargs) 131 | 132 | 133 | def resnet_ibn101a(**kwargs): 134 | return ResNetIBN('101a', **kwargs) 135 | -------------------------------------------------------------------------------- /libs/models/resnet_ibn_a.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['ResNet', 'resnet50_ibn_a', 'resnet101_ibn_a'] 8 | 9 | 10 | model_urls = { 11 | 'ibn_resnet50a': './logs/pretrained/resnet50_ibn_a.pth.tar', 12 | 'ibn_resnet101a': './logs/pretrained/resnet101_ibn_a.pth.tar', 13 | } 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class IBN(nn.Module): 55 | def __init__(self, planes): 56 | super(IBN, self).__init__() 57 | half1 = int(planes/2) 58 | self.half = half1 59 | half2 = planes - half1 60 | self.IN = nn.InstanceNorm2d(half1, affine=True) 61 | self.BN = nn.BatchNorm2d(half2) 62 | 63 | def forward(self, x): 64 | split = torch.split(x, self.half, 1) 65 | out1 = self.IN(split[0].contiguous()) 66 | out2 = self.BN(split[1].contiguous()) 67 | out = torch.cat((out1, out2), 1) 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 77 | if ibn: 78 | self.bn1 = IBN(planes) 79 | else: 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 82 | padding=1, bias=False) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 85 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 86 | self.relu = nn.ReLU(inplace=True) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | residual = x 92 | 93 | out = self.conv1(x) 94 | out = self.bn1(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv2(out) 98 | out = self.bn2(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv3(out) 102 | out = self.bn3(out) 103 | 104 | if self.downsample is not None: 105 | residual = self.downsample(x) 106 | 107 | out += residual 108 | out = self.relu(out) 109 | 110 | return out 111 | 112 | 113 | class ResNet(nn.Module): 114 | 115 | def __init__(self, block, layers, num_classes=1000): 116 | scale = 64 117 | self.inplanes = scale 118 | super(ResNet, self).__init__() 119 | self.conv1 = nn.Conv2d(3, scale, 
kernel_size=7, stride=2, padding=3, 120 | bias=False) 121 | self.bn1 = nn.BatchNorm2d(scale) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 124 | self.layer1 = self._make_layer(block, scale, layers[0]) 125 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) 126 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) 127 | self.layer4 = self._make_layer(block, scale*8, layers[3], stride=2) 128 | self.avgpool = nn.AvgPool2d(7) 129 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 130 | 131 | for m in self.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 134 | m.weight.data.normal_(0, math.sqrt(2. / n)) 135 | elif isinstance(m, nn.BatchNorm2d): 136 | m.weight.data.fill_(1) 137 | m.bias.data.zero_() 138 | elif isinstance(m, nn.InstanceNorm2d): 139 | m.weight.data.fill_(1) 140 | m.bias.data.zero_() 141 | 142 | def _make_layer(self, block, planes, blocks, stride=1): 143 | downsample = None 144 | if stride != 1 or self.inplanes != planes * block.expansion: 145 | downsample = nn.Sequential( 146 | nn.Conv2d(self.inplanes, planes * block.expansion, 147 | kernel_size=1, stride=stride, bias=False), 148 | nn.BatchNorm2d(planes * block.expansion), 149 | ) 150 | 151 | layers = [] 152 | ibn = True 153 | if planes == 512: 154 | ibn = False 155 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 156 | self.inplanes = planes * block.expansion 157 | for i in range(1, blocks): 158 | layers.append(block(self.inplanes, planes, ibn)) 159 | 160 | return nn.Sequential(*layers) 161 | 162 | def forward(self, x): 163 | x = self.conv1(x) 164 | x = self.bn1(x) 165 | x = self.relu(x) 166 | x = self.maxpool(x) 167 | 168 | x = self.layer1(x) 169 | x = self.layer2(x) 170 | x = self.layer3(x) 171 | x = self.layer4(x) 172 | 173 | x = self.avgpool(x) 174 | x = x.view(x.size(0), -1) 175 | x = self.fc(x) 176 | 177 | return x 178 | 179 | 180 | def resnet50_ibn_a(pretrained=False, **kwargs): 181 | """Constructs a ResNet-50 model. 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 186 | if pretrained: 187 | state_dict = torch.load(model_urls['ibn_resnet50a'], map_location=torch.device('cpu'))['state_dict'] 188 | state_dict = remove_module_key(state_dict) 189 | model.load_state_dict(state_dict) 190 | return model 191 | 192 | 193 | def resnet101_ibn_a(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 
195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | state_dict = torch.load(model_urls['ibn_resnet101a'], map_location=torch.device('cpu'))['state_dict'] 201 | state_dict = remove_module_key(state_dict) 202 | model.load_state_dict(state_dict) 203 | return model 204 | 205 | 206 | def remove_module_key(state_dict): 207 | for key in list(state_dict.keys()): 208 | if 'module' in key: 209 | state_dict[key.replace('module.','')] = state_dict.pop(key) 210 | return state_dict 211 | -------------------------------------------------------------------------------- /libs/models/resnet_part.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | from .pooling import build_pooling_layer 9 | 10 | 11 | __all__ = ['ResNetPart', 'resnet18_part', 'resnet34_part', 'resnet50_part', 'resnet101_part', 12 | 'resnet152_part'] 13 | 14 | 15 | class ResNetPart(nn.Module): 16 | __factory = { 17 | 18: torchvision.models.resnet18, 18 | 34: torchvision.models.resnet34, 19 | 50: torchvision.models.resnet50, 20 | 101: torchvision.models.resnet101, 21 | 152: torchvision.models.resnet152, 22 | } 23 | 24 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 25 | num_features=0, norm=False, dropout=0, num_classes=0, 26 | pooling_type='avg', part_pooling_type='avg', num_parts=1): 27 | assert num_classes == 0, 'Disenable parametric classifier!' 28 | super(ResNetPart, self).__init__() 29 | self.pretrained = pretrained 30 | self.depth = depth 31 | self.cut_at_pooling = cut_at_pooling 32 | self.num_parts = num_parts 33 | # Construct base (pretrained) resnet 34 | if depth not in ResNetPart.__factory: 35 | raise KeyError("Unsupported depth:", depth) 36 | resnet = ResNetPart.__factory[depth](pretrained=pretrained) 37 | resnet.layer4[0].conv2.stride = (1,1) 38 | resnet.layer4[0].downsample[0].stride = (1,1) 39 | self.base = nn.Sequential( 40 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 41 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 42 | # self.gap = nn.AdaptiveAvgPool2d(1) # Vanilla model 43 | self.global_pool = build_pooling_layer(pooling_type) # follow cluster-contrast 44 | self.part_pool = build_pooling_layer(part_pooling_type) # 局部特征可用其他池化 45 | 46 | if not self.cut_at_pooling: 47 | self.num_features = num_features 48 | self.norm = norm 49 | self.dropout = dropout 50 | self.has_embedding = num_features > 0 51 | self.num_classes = num_classes 52 | 53 | out_planes = resnet.fc.in_features 54 | 55 | # Append new layers 56 | if self.has_embedding: 57 | # Global feature 58 | self.feat = nn.Linear(out_planes, self.num_features) 59 | self.feat_bn = nn.BatchNorm1d(self.num_features) 60 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 61 | init.constant_(self.feat.bias, 0) 62 | 63 | # Part feature 64 | self.part_feats = nn.ModuleList([nn.Linear(out_planes, self.num_features) for _ in range(self.num_parts)]) 65 | self.part_bns = nn.ModuleList([nn.BatchNorm1d(self.num_features) for _ in range(self.num_parts)]) 66 | map(lambda part_feat: init.kaiming_normal_(part_feat.weight, mode='fan_out'), self.part_feats) 67 | map(lambda part_feat: init.constant_(part_feat.bias, 0), self.part_feats) 68 | else: 69 | # Change the num_features to CNN output channels 70 | 
self.num_features = out_planes 71 | self.feat_bn = nn.BatchNorm1d(self.num_features) 72 | 73 | # A separate BN layer for each part branch 74 | self.part_bns = nn.ModuleList([nn.BatchNorm1d(self.num_features) for _ in range(self.num_parts)]) 75 | 76 | # Freeze the BN bias terms 77 | self.feat_bn.bias.requires_grad_(False) # NOTE the BN bias is frozen while the weight is still optimized 78 | for part_bn in self.part_bns: part_bn.bias.requires_grad_(False) # also freeze the bias of each part-branch BN (a bare map() is lazy and would never run) 79 | 80 | if self.dropout > 0: 81 | self.drop = nn.Dropout(self.dropout) 82 | if self.num_classes > 0: 83 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 84 | init.normal_(self.classifier.weight, std=0.001) 85 | 86 | # Initialize the BN parameters 87 | init.constant_(self.feat_bn.weight, 1) 88 | init.constant_(self.feat_bn.bias, 0) 89 | for part_bn in self.part_bns: init.constant_(part_bn.weight, 1) 90 | for part_bn in self.part_bns: init.constant_(part_bn.bias, 0) 91 | 92 | if not pretrained: 93 | self.reset_params() 94 | 95 | def forward(self, x): 96 | bs = x.size(0) 97 | x = self.base(x) # (b, c, h, w) 98 | 99 | # Part level features 100 | if self.num_parts > 1: 101 | part_x = x.split(x.size(2)//self.num_parts, dim=2)[:self.num_parts] 102 | else: 103 | part_x = None 104 | 105 | assert part_x is not None, 'Check num_parts!' 106 | 107 | x = self.global_pool(x) 108 | part_x = list(map(self.part_pool, part_x)) 109 | x = x.view(x.size(0), -1) 110 | part_x = list(map(lambda part: part.view(part.size(0), -1), part_x)) 111 | 112 | if self.cut_at_pooling: 113 | return {'global': x, 'part': part_x} 114 | 115 | if self.has_embedding: 116 | bn_x = self.feat_bn(self.feat(x)) 117 | # bn_part_x = list(map(self.feat, part_x)) 118 | # bn_part_x = list(map(self.feat_bn, part_x)) 119 | bn_part_x = [part_head(px) for px, part_head in zip(part_x, self.part_feats)] 120 | bn_part_x = [part_bn(px) for px, part_bn in zip(bn_part_x, self.part_bns)] 121 | else: 122 | bn_x = self.feat_bn(x) 123 | # bn_part_x = list(map(self.feat_bn, part_x)) # BUG: the part and global features shared the same BN layer, which was problematic 124 | bn_part_x = [self.part_bns[i](part_x[i]) for i in range(self.num_parts)] 125 | 126 | if (self.training is False): 127 | bn_x = F.normalize(bn_x) 128 | bn_part_x = list(map(F.normalize, bn_part_x)) 129 | return {'global': bn_x, 'part': bn_part_x} 130 | 131 | if self.norm: 132 | bn_x = F.normalize(bn_x) 133 | bn_part_x = list(map(F.normalize, bn_part_x)) 134 | elif self.has_embedding: 135 | bn_x = F.relu(bn_x) 136 | bn_part_x = list(map(F.relu, bn_part_x)) 137 | 138 | if self.dropout > 0: 139 | bn_x = self.drop(bn_x) 140 | bn_part_x = list(map(self.drop, bn_part_x)) 141 | 142 | # if self.num_classes > 0: 143 | # prob = self.classifier(bn_x) 144 | # else: 145 | return {'global': bn_x, 'part': bn_part_x} 146 | 147 | # return prob 148 | 149 | def reset_params(self): 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | init.kaiming_normal_(m.weight, mode='fan_out') 153 | if m.bias is not None: 154 | init.constant_(m.bias, 0) 155 | elif isinstance(m, nn.BatchNorm2d): 156 | init.constant_(m.weight, 1) 157 | init.constant_(m.bias, 0) 158 | elif isinstance(m, nn.BatchNorm1d): 159 | init.constant_(m.weight, 1) 160 | init.constant_(m.bias, 0) 161 | elif isinstance(m, nn.Linear): 162 | init.normal_(m.weight, std=0.001) 163 | if m.bias is not None: 164 | init.constant_(m.bias, 0) 165 | 166 | resnet = ResNetPart.__factory[self.depth](pretrained=self.pretrained) 167 | self.base[0].load_state_dict(resnet.conv1.state_dict()) 168 | self.base[1].load_state_dict(resnet.bn1.state_dict()) 169 |
self.base[2].load_state_dict(resnet.maxpool.state_dict()) 170 | self.base[3].load_state_dict(resnet.layer1.state_dict()) 171 | self.base[4].load_state_dict(resnet.layer2.state_dict()) 172 | self.base[5].load_state_dict(resnet.layer3.state_dict()) 173 | self.base[6].load_state_dict(resnet.layer4.state_dict()) 174 | 175 | def resnet18_part(**kwargs): 176 | return ResNetPart(18, **kwargs) 177 | 178 | 179 | def resnet34_part(**kwargs): 180 | return ResNetPart(34, **kwargs) 181 | 182 | 183 | def resnet50_part(**kwargs): 184 | return ResNetPart(50, **kwargs) 185 | 186 | def resnet101_part(**kwargs): 187 | return ResNetPart(101, **kwargs) 188 | 189 | 190 | def resnet152_part(**kwargs): 191 | return ResNetPart(152, **kwargs) 192 | -------------------------------------------------------------------------------- /libs/models/vit.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in 4 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 5 | 6 | The official jax code is released and available at https://github.com/google-research/vision_transformer 7 | 8 | Status/TODO: 9 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. 10 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. 11 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. 12 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. 13 | 14 | Acknowledgments: 15 | * The paper authors for releasing code and weights, thanks! 16 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 17 | for some einops/einsum fun 18 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 19 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 20 | 21 | Hacked together by / Copyright 2020 Ross Wightman 22 | """ 23 | import math 24 | from functools import partial 25 | from itertools import repeat 26 | 27 | import torch 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | 31 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 32 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 33 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 34 | from torch._six import container_abcs 35 | else: 36 | import collections.abc as container_abcs 37 | 38 | 39 | # From PyTorch internals 40 | def _ntuple(n): 41 | def parse(x): 42 | if isinstance(x, container_abcs.Iterable): 43 | return x 44 | return tuple(repeat(x, n)) 45 | return parse 46 | 47 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 48 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 49 | to_2tuple = _ntuple(2) 50 | 51 | def drop_path(x, drop_prob: float = 0., training: bool = False): 52 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 53 | 54 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 55 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 56 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for 57 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 58 | 'survival rate' as the argument. 59 | 60 | """ 61 | if drop_prob == 0. or not training: 62 | return x 63 | keep_prob = 1 - drop_prob 64 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 65 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 66 | random_tensor.floor_() # binarize 67 | output = x.div(keep_prob) * random_tensor 68 | return output 69 | 70 | class DropPath(nn.Module): 71 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 72 | """ 73 | def __init__(self, drop_prob=None): 74 | super(DropPath, self).__init__() 75 | self.drop_prob = drop_prob 76 | 77 | def forward(self, x): 78 | return drop_path(x, self.drop_prob, self.training) 79 | 80 | 81 | def _cfg(url='', **kwargs): 82 | return { 83 | 'url': url, 84 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 85 | 'crop_pct': .9, 'interpolation': 'bicubic', 86 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 87 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 88 | **kwargs 89 | } 90 | 91 | 92 | default_cfgs = { 93 | # patch models 94 | 'vit_small_patch16_224': _cfg( 95 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', 96 | ), 97 | 'vit_base_patch16_224': _cfg( 98 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 99 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 100 | ), 101 | 'vit_base_patch16_384': _cfg( 102 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', 103 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 104 | 'vit_base_patch32_384': _cfg( 105 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', 106 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 107 | 'vit_large_patch16_224': _cfg( 108 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', 109 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 110 | 'vit_large_patch16_384': _cfg( 111 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', 112 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 113 | 'vit_large_patch32_384': _cfg( 114 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 115 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 116 | 'vit_huge_patch16_224': _cfg(), 117 | 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), 118 | # hybrid models 119 | 'vit_small_resnet26d_224': _cfg(), 120 | 'vit_small_resnet50d_s3_224': _cfg(), 121 | 'vit_base_resnet26d_224': _cfg(), 122 | 'vit_base_resnet50d_224': _cfg(), 123 | } 124 | 125 | 126 | class Mlp(nn.Module): 127 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 128 | super().__init__() 129 | out_features = out_features or in_features 130 | hidden_features = hidden_features or in_features 131 | self.fc1 = nn.Linear(in_features, hidden_features) 132 | self.act = act_layer() 
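# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): what the
# drop_path() helper defined earlier in this file does during training.
# The import path and toy batch are assumptions for illustration.
import torch
from libs.models.vit import drop_path

torch.manual_seed(0)
residual = torch.ones(4, 2)                            # residual-branch output for 4 samples
out = drop_path(residual, drop_prob=0.5, training=True)
# Each row is either all zeros (dropped sample) or rescaled by 1/(1-0.5)=2.0,
# so the expected value per element stays 1.0.
# ---------------------------------------------------------------------------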
133 | self.fc2 = nn.Linear(hidden_features, out_features) 134 | self.drop = nn.Dropout(drop) 135 | 136 | def forward(self, x): 137 | x = self.fc1(x) 138 | x = self.act(x) 139 | x = self.drop(x) 140 | x = self.fc2(x) 141 | x = self.drop(x) 142 | return x 143 | 144 | 145 | class Attention(nn.Module): 146 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 147 | super().__init__() 148 | self.num_heads = num_heads 149 | head_dim = dim // num_heads 150 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 151 | self.scale = qk_scale or head_dim ** -0.5 152 | 153 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 154 | self.attn_drop = nn.Dropout(attn_drop) 155 | self.proj = nn.Linear(dim, dim) 156 | self.proj_drop = nn.Dropout(proj_drop) 157 | 158 | def forward(self, x): 159 | B, N, C = x.shape 160 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 161 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 162 | 163 | attn = (q @ k.transpose(-2, -1)) * self.scale 164 | attn = attn.softmax(dim=-1) 165 | attn = self.attn_drop(attn) 166 | 167 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 168 | x = self.proj(x) 169 | x = self.proj_drop(x) 170 | return x 171 | 172 | 173 | class Block(nn.Module): 174 | 175 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 176 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 177 | super().__init__() 178 | self.norm1 = norm_layer(dim) 179 | self.attn = Attention( 180 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 181 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 182 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 183 | self.norm2 = norm_layer(dim) 184 | mlp_hidden_dim = int(dim * mlp_ratio) 185 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 186 | 187 | def forward(self, x): 188 | x = x + self.drop_path(self.attn(self.norm1(x))) 189 | x = x + self.drop_path(self.mlp(self.norm2(x))) 190 | return x 191 | 192 | class HybridEmbed(nn.Module): 193 | """ CNN Feature Map Embedding 194 | Extract feature map from CNN, flatten, project to embedding dim. 195 | """ 196 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 197 | super().__init__() 198 | assert isinstance(backbone, nn.Module) 199 | img_size = to_2tuple(img_size) 200 | self.img_size = img_size 201 | self.backbone = backbone 202 | if feature_size is None: 203 | with torch.no_grad(): 204 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 205 | # map for all networks, the feature metadata has reliable channel and stride info, but using 206 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
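# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): the qkv reshape
# used by Attention.forward above, on toy dimensions so the shapes are easy
# to follow. All sizes here are assumptions for illustration.
import torch

B, N, C, num_heads = 2, 5, 8, 4
qkv = torch.randn(B, N, 3 * C)                                     # output of nn.Linear(dim, dim * 3)
qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                                   # each (B, num_heads, N, head_dim)
attn = (q @ k.transpose(-2, -1)) * (C // num_heads) ** -0.5        # (B, num_heads, N, N)
out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)  # back to (B, N, C)
# ---------------------------------------------------------------------------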
207 | training = backbone.training 208 | if training: 209 | backbone.eval() 210 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) 211 | if isinstance(o, (list, tuple)): 212 | o = o[-1] # last feature if backbone outputs list/tuple of features 213 | feature_size = o.shape[-2:] 214 | feature_dim = o.shape[1] 215 | backbone.train(training) 216 | else: 217 | feature_size = to_2tuple(feature_size) 218 | if hasattr(self.backbone, 'feature_info'): 219 | feature_dim = self.backbone.feature_info.channels()[-1] 220 | else: 221 | feature_dim = self.backbone.num_features 222 | self.num_patches = feature_size[0] * feature_size[1] 223 | self.proj = nn.Conv2d(feature_dim, embed_dim, 1) 224 | 225 | def forward(self, x): 226 | x = self.backbone(x) 227 | if isinstance(x, (list, tuple)): 228 | x = x[-1] # last feature if backbone outputs list/tuple of features 229 | x = self.proj(x).flatten(2).transpose(1, 2) 230 | return x 231 | 232 | 233 | class PatchEmbed_overlap(nn.Module): 234 | """ Image to Patch Embedding with overlapping patches 235 | """ 236 | def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768): 237 | super().__init__() 238 | img_size = to_2tuple(img_size) 239 | patch_size = to_2tuple(patch_size) 240 | stride_size_tuple = to_2tuple(stride_size) 241 | self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1 242 | self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1 243 | print('using stride: {}, and patch number is num_y{} * num_x{}'.format(stride_size, self.num_y, self.num_x)) 244 | num_patches = self.num_x * self.num_y 245 | self.img_size = img_size 246 | self.patch_size = patch_size 247 | self.num_patches = num_patches 248 | 249 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size) 250 | for m in self.modules(): 251 | if isinstance(m, nn.Conv2d): 252 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 253 | m.weight.data.normal_(0, math.sqrt(2. / n)) 254 | elif isinstance(m, nn.BatchNorm2d): 255 | m.weight.data.fill_(1) 256 | m.bias.data.zero_() 257 | elif isinstance(m, nn.InstanceNorm2d): 258 | m.weight.data.fill_(1) 259 | m.bias.data.zero_() 260 | 261 | def forward(self, x): 262 | B, C, H, W = x.shape 263 | 264 | # FIXME look at relaxing size constraints 265 | assert H == self.img_size[0] and W == self.img_size[1], \ 266 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 267 | x = self.proj(x) 268 | 269 | x = x.flatten(2).transpose(1, 2) # [64, 8, 768] 270 | return x 271 | 272 | class GeneralizedMeanPooling(nn.Module): 273 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. 274 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 275 | - At p = infinity, one gets Max Pooling 276 | - At p = 1, one gets Average Pooling 277 | The output is of size H x W, for any input size. 278 | The number of output features is equal to the number of input planes. 279 | Args: 280 | output_size: the target output size of the image of the form H x W. 281 | Can be a tuple (H, W) or a single H for a square image H x H 282 | H and W can be either a ``int``, or ``None`` which means the size will 283 | be the same as that of the input. 
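# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): the patch-grid
# arithmetic used by PatchEmbed_overlap above. The image size, patch size and
# stride below are made-up values for illustration.
img_h, img_w, patch, stride = 256, 128, 16, 12
num_y = (img_h - patch) // stride + 1   # 21
num_x = (img_w - patch) // stride + 1   # 10
num_patches = num_y * num_x             # 210 overlapping patches (plus one cls token later)
# ---------------------------------------------------------------------------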
284 | """ 285 | 286 | def __init__(self, norm=3, output_size=1, eps=1e-6): 287 | super(GeneralizedMeanPooling, self).__init__() 288 | assert norm > 0 289 | self.p = float(norm) 290 | self.output_size = output_size 291 | self.eps = eps 292 | 293 | def forward(self, x): 294 | x = x.clamp(min=self.eps).pow(self.p) 295 | return F.adaptive_avg_pool1d(x, self.output_size).pow(1. / self.p) 296 | 297 | class IBN(nn.Module): 298 | def __init__(self, planes): 299 | super(IBN, self).__init__() 300 | half1 = int(planes/2) 301 | self.half = half1 302 | half2 = planes - half1 303 | self.IN = nn.InstanceNorm2d(half1, affine=True) 304 | self.BN = nn.BatchNorm2d(half2) 305 | 306 | def forward(self, x): 307 | split = torch.split(x, self.half, 1) 308 | out1 = self.IN(split[0].contiguous()) 309 | out2 = self.BN(split[1].contiguous()) 310 | out = torch.cat((out1, out2), 1) 311 | return out 312 | 313 | class PatchEmbed(nn.Module): 314 | """ Image to Patch Embedding with overlapping patches 315 | """ 316 | def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, embed_dim=768, stem_conv=False): 317 | super().__init__() 318 | img_size = to_2tuple(img_size) 319 | patch_size = to_2tuple(patch_size) 320 | stride_size_tuple = to_2tuple(stride_size) 321 | self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1 322 | self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1 323 | print('using stride: {}, and patch number is num_y{} * num_x{}'.format(stride_size, self.num_y, self.num_x)) 324 | self.num_patches = self.num_x * self.num_y 325 | self.img_size = img_size 326 | self.patch_size = patch_size 327 | 328 | self.stem_conv = stem_conv 329 | if self.stem_conv: 330 | hidden_dim = 64 331 | stem_stride = 2 332 | stride_size = patch_size = patch_size[0] // stem_stride 333 | self.conv = nn.Sequential( 334 | nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3,bias=False), 335 | IBN(hidden_dim), 336 | nn.ReLU(inplace=True), 337 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,padding=1,bias=False), 338 | IBN(hidden_dim), 339 | nn.ReLU(inplace=True), 340 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,padding=1,bias=False), 341 | nn.BatchNorm2d(hidden_dim), 342 | nn.ReLU(inplace=True), 343 | ) 344 | in_chans = hidden_dim 345 | 346 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size) 347 | 348 | def forward(self, x): 349 | if self.stem_conv: 350 | x = self.conv(x) 351 | x = self.proj(x) 352 | x = x.flatten(2).transpose(1, 2) # [64, 8, 768] 353 | return x 354 | 355 | class TransReID(nn.Module): 356 | """ Transformer-based Object Re-Identification 357 | """ 358 | def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 359 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., camera=0, view=0, 360 | drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), has_early_feature=False, sie_coef=1.0, 361 | gem_pool=False, stem_conv=False, enable_early_norm=False, **kwargs): 362 | super().__init__() 363 | self.enable_early_norm = enable_early_norm 364 | self.num_classes = num_classes 365 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 366 | self.has_early_feature = has_early_feature 367 | self.patch_embed = PatchEmbed( 368 | img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans, 369 | embed_dim=embed_dim, 
stem_conv=stem_conv) 370 | 371 | num_patches = self.patch_embed.num_patches 372 | 373 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 374 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 375 | self.cam_num = camera 376 | self.view_num = view 377 | self.sie_coef = sie_coef 378 | self.in_planes = 768 379 | self.gem_pool = gem_pool 380 | if self.gem_pool: 381 | print('using gem pooling') 382 | # Initialize SIE Embedding 383 | if camera > 1 and view > 1: 384 | self.sie_embed = nn.Parameter(torch.zeros(camera * view, 1, embed_dim)) 385 | trunc_normal_(self.sie_embed, std=.02) 386 | print('camera number is : {} and viewpoint number is : {}'.format(camera, view)) 387 | print('using SIE_Lambda is : {}'.format(sie_coef)) 388 | elif camera > 1: 389 | self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim)) 390 | trunc_normal_(self.sie_embed, std=.02) 391 | print('camera number is : {}'.format(camera)) 392 | print('using SIE_Lambda is : {}'.format(sie_coef)) 393 | elif view > 1: 394 | self.sie_embed = nn.Parameter(torch.zeros(view, 1, embed_dim)) 395 | trunc_normal_(self.sie_embed, std=.02) 396 | print('viewpoint number is : {}'.format(view)) 397 | print('using SIE_Lambda is : {}'.format(sie_coef)) 398 | 399 | print('using drop_out rate is : {}'.format(drop_rate)) 400 | print('using attn_drop_out rate is : {}'.format(attn_drop_rate)) 401 | print('using drop_path rate is : {}'.format(drop_path_rate)) 402 | 403 | self.pos_drop = nn.Dropout(p=drop_rate) 404 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 405 | 406 | self.blocks = nn.ModuleList([ 407 | Block( 408 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 409 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 410 | for i in range(depth)]) 411 | 412 | self.norm = norm_layer(embed_dim) 413 | 414 | # Classifier head 415 | # self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 416 | trunc_normal_(self.cls_token, std=.02) 417 | trunc_normal_(self.pos_embed, std=.02) 418 | 419 | self.apply(self._init_weights) 420 | self.gem = GeneralizedMeanPooling() 421 | 422 | def _init_weights(self, m): 423 | if isinstance(m, nn.Linear): 424 | trunc_normal_(m.weight, std=.02) 425 | if isinstance(m, nn.Linear) and m.bias is not None: 426 | nn.init.constant_(m.bias, 0) 427 | elif isinstance(m, nn.LayerNorm): 428 | nn.init.constant_(m.bias, 0) 429 | nn.init.constant_(m.weight, 1.0) 430 | 431 | @torch.jit.ignore 432 | def no_weight_decay(self): 433 | return {'pos_embed', 'cls_token'} 434 | 435 | def get_classifier(self): 436 | return self.head 437 | 438 | def reset_classifier(self, num_classes, global_pool=''): 439 | self.num_classes = num_classes 440 | self.fc = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 441 | 442 | def forward_features(self, x, camera_id, view_id): 443 | B = x.shape[0] 444 | x = self.patch_embed(x) 445 | 446 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 447 | x = torch.cat((cls_tokens, x), dim=1) 448 | 449 | if self.cam_num > 0 and self.view_num > 0: 450 | x = x + self.pos_embed + self.sie_coef * self.sie_embed[camera_id * self.view_num + view_id] 451 | elif self.cam_num > 0: 452 | x = x + self.pos_embed + self.sie_coef * self.sie_embed[camera_id] 453 | elif self.view_num > 0: 454 | x = x + self.pos_embed + self.sie_coef * self.sie_embed[view_id] 455 | else: 456 | 
x = x + self.pos_embed 457 | 458 | x = self.pos_drop(x) 459 | 460 | if self.has_early_feature: 461 | for blk in self.blocks[:-1]: 462 | x = blk(x) 463 | if self.enable_early_norm: 464 | x = self.norm(x) 465 | return x 466 | 467 | else: 468 | for blk in self.blocks: 469 | x = blk(x) 470 | 471 | x = self.norm(x) 472 | 473 | if self.gem_pool: 474 | gf = self.gem(x[:,1:].permute(0,2,1)).squeeze() 475 | return x[:, 0] + gf 476 | 477 | return x 478 | 479 | def forward(self, x, cam_label=None, view_label=None): 480 | x = self.forward_features(x, cam_label, view_label) 481 | return x 482 | 483 | def load_param(self, model_path, hw_ratio): 484 | param_dict = torch.load(model_path, map_location='cpu') 485 | count = 0 486 | if 'model' in param_dict: 487 | param_dict = param_dict['model'] 488 | if 'state_dict' in param_dict: 489 | param_dict = param_dict['state_dict'] 490 | for k, v in param_dict.items(): 491 | if 'head' in k or 'dist' in k or 'pre_logits' in k: 492 | continue 493 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 494 | # For old models that I trained prior to conv based patchification 495 | O, I, H, W = self.patch_embed.proj.weight.shape 496 | v = v.reshape(O, -1, H, W) 497 | elif k == 'pos_embed' and v.shape != self.pos_embed.shape: 498 | # To resize pos embedding when using model at different size from pretrained weights 499 | if 'distilled' in model_path: 500 | print('distill need to choose right cls token in the pth') 501 | v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1) 502 | v = resize_pos_embed(v, self.pos_embed, self.patch_embed.num_y, self.patch_embed.num_x, hw_ratio) 503 | try: 504 | self.state_dict()[k].copy_(v) 505 | count += 1 506 | except: 507 | print('===========================ERROR=========================') 508 | print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape)) 509 | print('Load %d / %d layers.'%(count,len(self.state_dict().keys()))) 510 | 511 | def resize_pos_embed(posemb, posemb_new, hight, width, hw_ratio): 512 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 513 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 514 | ntok_new = posemb_new.shape[1] 515 | 516 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 517 | ntok_new -= 1 518 | 519 | gs_old_h = int(math.sqrt(len(posemb_grid)*hw_ratio)) 520 | gs_old_w = gs_old_h // hw_ratio 521 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width)) 522 | posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2) 523 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 524 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 525 | posemb = torch.cat([posemb_token, posemb_grid], dim=1) 526 | return posemb 527 | 528 | 529 | def vit_base_patch16_224_TransReID(img_size=(256, 128), stride_size=16, drop_path_rate=0.1, camera=0, view=0,has_early_feature=False,sie_xishu=1.5, 530 | enable_early_norm=False, **kwargs): 531 | model = TransReID( 532 | img_size=img_size, patch_size=16, stride_size=stride_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\ 533 | camera=camera, view=view, drop_path_rate=drop_path_rate, sie_coef=sie_xishu, has_early_feature=has_early_feature, enable_early_norm=enable_early_norm, **kwargs) 534 | 535 | return model 536 | 537 | def vit_small_patch16_224_TransReID(img_size=(256, 128), stride_size=16, drop_path_rate=0.1, camera=0, view=0, has_early_feature=False, sie_xishu=1.5, 538 | enable_early_norm=False, **kwargs): 539 | model = TransReID( 540 | img_size=img_size, patch_size=16, stride_size=stride_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, drop_path_rate=drop_path_rate,\ 541 | camera=camera, view=view, sie_coef=sie_xishu, has_early_feature=has_early_feature, enable_early_norm=enable_early_norm, **kwargs) 542 | model.in_planes = 384 543 | 544 | return model 545 | 546 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 547 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 548 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 549 | def norm_cdf(x): 550 | # Computes standard normal cumulative distribution function 551 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 552 | 553 | if (mean < a - 2 * std) or (mean > b + 2 * std): 554 | print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 555 | "The distribution of values may be incorrect.",) 556 | 557 | with torch.no_grad(): 558 | # Values are generated by using a truncated uniform distribution and 559 | # then using the inverse CDF for the normal distribution. 560 | # Get upper and lower cdf values 561 | l = norm_cdf((a - mean) / std) 562 | u = norm_cdf((b - mean) / std) 563 | 564 | # Uniformly fill tensor with values from [l, u], then translate to 565 | # [2l-1, 2u-1]. 
566 | tensor.uniform_(2 * l - 1, 2 * u - 1) 567 | 568 | # Use inverse cdf transform for normal distribution to get truncated 569 | # standard normal 570 | tensor.erfinv_() 571 | 572 | # Transform to proper mean, std 573 | tensor.mul_(std * math.sqrt(2.)) 574 | tensor.add_(mean) 575 | 576 | # Clamp to ensure it's in the proper range 577 | tensor.clamp_(min=a, max=b) 578 | return tensor 579 | 580 | 581 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 582 | # type: (Tensor, float, float, float, float) -> Tensor 583 | r"""Fills the input Tensor with values drawn from a truncated 584 | normal distribution. The values are effectively drawn from the 585 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 586 | with values outside :math:`[a, b]` redrawn until they are within 587 | the bounds. The method used for generating the random values works 588 | best when :math:`a \leq \text{mean} \leq b`. 589 | Args: 590 | tensor: an n-dimensional `torch.Tensor` 591 | mean: the mean of the normal distribution 592 | std: the standard deviation of the normal distribution 593 | a: the minimum cutoff value 594 | b: the maximum cutoff value 595 | Examples: 596 | >>> w = torch.empty(3, 5) 597 | >>> nn.init.trunc_normal_(w) 598 | """ 599 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 600 | -------------------------------------------------------------------------------- /libs/models/vit_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | The backbone implementation is inspired by TransReID series. The pretrained weighted is provided by LUPerson. 3 | Thanks for their excellent works! 4 | TransReID: https://github.com/damo-cv/TransReID 5 | TransReID-SSL: https://github.com/damo-cv/TransReID-SSL 6 | LUPerson: https://github.com/DengpanFu/LUPerson 7 | """ 8 | 9 | import sys 10 | import os.path as osp 11 | sys.path.append(osp.abspath(osp.join(__file__, '..'))) 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import copy 17 | from libs.models.vit import vit_small_patch16_224_TransReID 18 | 19 | def shuffle_unit(features, shift, group, begin=1): 20 | 21 | batchsize = features.size(0) 22 | dim = features.size(-1) 23 | # Shift Operation 24 | feature_random = torch.cat([features[:, begin-1+shift:], features[:, begin:begin-1+shift]], dim=1) 25 | x = feature_random 26 | # Patch Shuffle Operation 27 | try: 28 | x = x.view(batchsize, group, -1, dim) 29 | except: 30 | x = torch.cat([x, x[:, -2:-1, :]], dim=1) 31 | x = x.view(batchsize, group, -1, dim) 32 | 33 | x = torch.transpose(x, 1, 2).contiguous() 34 | x = x.view(batchsize, -1, dim) 35 | 36 | return x 37 | 38 | def weights_init_kaiming(m): 39 | classname = m.__class__.__name__ 40 | if classname.find('Linear') != -1: 41 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 42 | nn.init.constant_(m.bias, 0.0) 43 | 44 | elif classname.find('Conv') != -1: 45 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 46 | if m.bias is not None: 47 | nn.init.constant_(m.bias, 0.0) 48 | elif classname.find('BatchNorm') != -1: 49 | if m.affine: 50 | nn.init.constant_(m.weight, 1.0) 51 | nn.init.constant_(m.bias, 0.0) 52 | 53 | def weights_init_classifier(m): 54 | classname = m.__class__.__name__ 55 | if classname.find('Linear') != -1: 56 | nn.init.normal_(m.weight, std=0.001) 57 | if m.bias: 58 | nn.init.constant_(m.bias, 0.0) 59 | 60 | 61 | class TMGF(nn.Module): 62 | """ 63 | Transformer-based Multi-Grained Feature encoder. 
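forward() returns a dict with an L2-normalized 'global' feature of shape (B, C) and L2-normalized 'part' features permuted to shape (P, B, C), where P equals num_parts, i.e. the sum of the two branch granularities.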
64 | """ 65 | 66 | __factory = { 67 | 'tmgf': vit_small_patch16_224_TransReID 68 | } 69 | 70 | def __init__(self, arch, img_size, sie_coef, camera_num, view_num, stride_size, drop_path_rate, drop_rate, attn_drop_rate, 71 | pretrain_path, hw_ratio, gem_pool, stem_conv, num_parts, has_early_feature, has_head, global_feature_type, 72 | granularities, branch, enable_early_norm, **kwargs): 73 | super().__init__() 74 | print(f'using Transformer_type: {arch} as a backbone') 75 | 76 | assert sum(granularities) == num_parts 77 | assert branch in ('all', 'b1', 'b2') 78 | 79 | if camera_num: 80 | camera_num = camera_num 81 | else: 82 | camera_num = 0 83 | if view_num: 84 | view_num = view_num 85 | else: 86 | view_num = 0 87 | 88 | self.base = TMGF.__factory[arch](img_size=img_size, sie_xishu=sie_coef, camera=camera_num, view=view_num, stride_size=stride_size, 89 | drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 90 | gem_pool=gem_pool, stem_conv=stem_conv, has_early_feature=has_early_feature, 91 | enable_early_norm=enable_early_norm, **kwargs) # local_feature = True for no projection head ablation 92 | self.in_planes = self.base.in_planes 93 | self.has_head = has_head 94 | self.global_feature_type = global_feature_type 95 | self.granularities = granularities 96 | self.branch = branch 97 | 98 | 99 | if pretrain_path != '': 100 | if osp.exists(pretrain_path): 101 | self.base.load_param(pretrain_path, hw_ratio) 102 | print('Loading pretrained weights from {} ...'.format(pretrain_path)) 103 | else: 104 | raise FileNotFoundError('Cannot find {}'.format(pretrain_path)) 105 | else: 106 | print('Initialize weights randomly.') 107 | 108 | # Part split settings 109 | self.num_parts = num_parts 110 | self.fmap_h = img_size[0] // stride_size[0] 111 | self.fmap_w = img_size[1] // stride_size[1] 112 | 113 | # Two different granularity branches 114 | if self.has_head: 115 | block = self.base.blocks[-1] 116 | layer_norm = self.base.norm 117 | self.b1 = nn.Sequential( 118 | copy.deepcopy(block), 119 | copy.deepcopy(layer_norm) 120 | ) 121 | self.b2 = nn.Sequential( 122 | copy.deepcopy(block), 123 | copy.deepcopy(layer_norm) 124 | ) 125 | 126 | # Pooling layers 127 | for i, g in enumerate(self.granularities): 128 | setattr(self, 'b{}_pool'.format(i+1), nn.AvgPool2d(kernel_size=(self.fmap_h//g, self.fmap_w), 129 | stride=(self.fmap_h//g,))) 130 | 131 | print('num_parts={}, branch_parts={}'.format(self.num_parts, self.granularities)) 132 | 133 | # Global bottleneck 134 | self.bottleneck = self.make_bnneck(self.in_planes, weights_init_kaiming) 135 | 136 | 137 | # Part bottleneck 138 | self.part_bns = nn.ModuleList([ 139 | self.make_bnneck(self.in_planes, weights_init_kaiming) for i in range(self.num_parts) 140 | ]) 141 | 142 | def forward_single_branch(self, x, branch, label=None, cam_label=None, view_label=None): 143 | """ 144 | Full ViT, no projection head. One part pooling branch. 145 | """ 146 | 147 | x = self.base(x, cam_label=cam_label, view_label=view_label) 148 | B = x.size(0) 149 | x_glb = x[:,0,:] 150 | x_patch = x[:,1:,:] 151 | x_patch = x_patch.permute(0,2,1).reshape((B, self.in_planes, self.fmap_h, self.fmap_w)) 152 | x_part = getattr(self, '{}_pool'.format(branch))(x_patch).squeeze() 153 | 154 | return x_glb, x_part 155 | 156 | def forward_multi_branch(self, x, label=None, cam_label=None, view_label=None): 157 | """ 158 | ViT 1st ~ (L-1)-th layers, duplicated L-th layers as projection heads for two branches. 
159 | """ 160 | 161 | x = self.base(x, cam_label=cam_label, view_label=view_label) # output before last layer 162 | B = x.size(0) 163 | 164 | # Split after head 165 | # branch 1 166 | x_b1 = self.b1(x) # (B, L, C) 167 | x_b1_glb = x_b1[:,0,:] # (B, C) 168 | x_b1_patch = x_b1[:,1:,:] # (B, L-1, C) 169 | x_b1_patch = x_b1_patch.permute(0,2,1).reshape((B, self.in_planes, self.fmap_h, self.fmap_w)) 170 | x_b1_patch = self.b1_pool(x_b1_patch).squeeze() # (B, C, P1) 171 | 172 | # branch 2 173 | x_b2 = self.b2(x) 174 | x_b2_glb = x_b2[:,0,:] 175 | x_b2_patch = x_b2[:,1:,:] 176 | x_b2_patch = x_b2_patch.permute(0,2,1).reshape((B, self.in_planes, self.fmap_h, self.fmap_w)) 177 | x_b2_patch = self.b2_pool(x_b2_patch).squeeze() # (B, C, P2) 178 | 179 | # Mean global feature 180 | if self.global_feature_type == 'mean': 181 | x_glb = 0.5 * (x_b1_glb + x_b2_glb) # (B, C) 182 | elif self.global_feature_type == 'b1': 183 | x_glb = x_b1_glb 184 | elif self.global_feature_type == 'b2': 185 | x_glb = x_b2_glb 186 | else: 187 | raise ValueError('Invalid global feature type: {}'.format(self.global_feature_type)) 188 | 189 | # Stack two branch part features 190 | x_part = torch.cat([x_b1_patch, x_b2_patch], dim=2) # (B, C, P), P = P1 + P2 191 | 192 | return x_glb, x_part 193 | 194 | def forward_multi_branch_no_head(self, x, label=None, cam_label=None, view_label=None): 195 | """ 196 | Full ViT, no projection head. Two part pooling branches. 197 | """ 198 | 199 | x = self.base(x, cam_label=cam_label, view_label=view_label) 200 | B = x.size(0) 201 | 202 | # Split without head 203 | # branch 1 204 | x_patch = x[:,1:,:] # (B, L-1, C) 205 | x_patch = x_patch.permute(0,2,1).reshape((B, self.in_planes, self.fmap_h, self.fmap_w)) 206 | x_b1_patch = self.b1_pool(x_patch).squeeze() # (B, C, P1) 207 | 208 | # branch 2 209 | x_b2_patch = self.b2_pool(x_patch).squeeze() # (B, C, P2) 210 | 211 | # global feature 212 | x_glb = x[:,0,:] # (B, C) 213 | 214 | # Stack two branch part features 215 | x_part = torch.cat([x_b1_patch, x_b2_patch], dim=2) # (B, C, P), P = P1 + P2 216 | 217 | return x_glb, x_part 218 | 219 | def forward(self, x, label=None, cam_label=None, view_label=None): 220 | B = x.size(0) 221 | if self.has_head: 222 | x_glb, x_part = self.forward_multi_branch(x, label, cam_label, view_label) 223 | elif self.branch != 'all': 224 | x_glb, x_part = self.forward_single_branch(x, self.branch, label, cam_label, view_label) 225 | else: 226 | x_glb, x_part = self.forward_multi_branch_no_head(x, label, cam_label, view_label) 227 | 228 | # BNNeck + L2 norm 229 | x_glb = self.bottleneck(x_glb) 230 | x_part = torch.stack([self.part_bns[i](x_part[:,:,i]) for i in range(x_part.size(2))], dim=2) 231 | 232 | x_glb = F.normalize(x_glb, dim=1) 233 | x_part = F.normalize(x_part, dim=1) 234 | 235 | assert x_part.size(2) == self.num_parts, 'x_part size: {} != num_parts: {}'.format( 236 | x_part.size(2), self.num_parts) # check part num 237 | 238 | return {'global': x_glb, 'part': x_part.permute(2,0,1)} # x_part as (P, B, C) 239 | 240 | def make_bnneck(self, dims, init_func): 241 | bn = nn.BatchNorm1d(dims) 242 | bn.bias.requires_grad_(False) # disable bias update 243 | bn.apply(init_func) 244 | return bn 245 | 246 | def tmgf(**kwargs): 247 | return TMGF(**kwargs) -------------------------------------------------------------------------------- /libs/trainers.py: -------------------------------------------------------------------------------- 1 | import time 2 | from torch.cuda import amp 3 | from .utils.meters import AverageMeter 4 | 5 
| class _BaseTrainer: 6 | """The most basic trainer class.""" 7 | def __init__(self, encoder, memory) -> None: 8 | super().__init__() 9 | self.encoder = encoder 10 | self.memory = memory 11 | 12 | def train(self, epoch, data_loader, optimizer, print_freq=10, train_iters=400): 13 | raise NotImplementedError 14 | 15 | def _parse_data(self, inputs): 16 | imgs, _, _, cams, index_target, _ = inputs # img, fname, pseudo_label, camid, img_index, accum_label 17 | return imgs.cuda(), cams.cuda(), index_target.cuda() 18 | 19 | def _forward(self, inputs): 20 | return self.encoder(inputs) 21 | 22 | class ViTTrainerFp16(_BaseTrainer): 23 | """ 24 | ViT trainer with FP16 forwarding. 25 | """ 26 | def __init__(self, encoder, memory) -> None: 27 | super().__init__(encoder, memory) 28 | 29 | def train(self, epoch, data_loader, optimizer, print_freq=10, train_iters=400, fp16=False): 30 | self.encoder.train() 31 | batch_time = AverageMeter() 32 | data_time = AverageMeter() 33 | losses = AverageMeter() 34 | end = time.time() 35 | 36 | # amp fp16 training 37 | scaler = amp.GradScaler() if fp16 else None 38 | 39 | for i in range(train_iters): 40 | # load data 41 | inputs = data_loader.next() 42 | data_time.update(time.time() - end) 43 | 44 | # process inputs 45 | inputs, cams, index_target = self._parse_data(inputs) 46 | 47 | # loss 48 | with amp.autocast(enabled=fp16): 49 | # forward 50 | f_out = self._forward(inputs, cam_label=cams) # dict: global & part features 51 | 52 | # compute loss with the memory 53 | loss_dict = self.memory(f_out, index_target, cams, epoch) 54 | loss = loss_dict['loss'] 55 | 56 | optimizer.zero_grad() 57 | 58 | if scaler is None: 59 | loss.backward() 60 | optimizer.step() 61 | else: 62 | scaler.scale(loss).backward() 63 | scaler.step(optimizer) 64 | scaler.update() 65 | 66 | 67 | losses.update(loss.item()) 68 | 69 | # print log 70 | batch_time.update(time.time() - end) 71 | end = time.time() 72 | 73 | if (i + 1) % print_freq == 0: 74 | print('Epoch: [{}][{}/{}] ' 75 | 'Time: {:.3f} ({:.3f}), ' 76 | 'Data: {:.3f} ({:.3f}), ' 77 | 'Loss: {:.3f} ({:.3f}), ' 78 | '{}' 79 | .format(epoch, i + 1, len(data_loader), 80 | batch_time.val, batch_time.avg, 81 | data_time.val, data_time.avg, 82 | losses.val, losses.avg, 83 | ', '.join(['{}: {:.3f}'.format(k, v) for k, v in loss_dict.items()]))) 84 | 85 | def _parse_data(self, inputs): 86 | imgs, _, _, cams, index_target, _ = inputs # img, fname, pseudo_label, camid, img_index, accum_label 87 | return imgs.cuda(), cams.cuda(), index_target.cuda() 88 | 89 | def _forward(self, *args, **kwargs): 90 | return self.encoder(*args, **kwargs) 91 | -------------------------------------------------------------------------------- /libs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /libs/utils/checkpoint_io.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import torch 4 | 5 | def save_checkpoint(model, optimizer, scheduler, ckpt_save_dir, epoch): 6 | if not osp.exists(ckpt_save_dir): 7 | os.makedirs(ckpt_save_dir, exist_ok=True) 8 | torch.save(model.state_dict(), osp.join(ckpt_save_dir, 'weight_{}.pth'.format(epoch))) 9 | torch.save(optimizer.state_dict(), osp.join(ckpt_save_dir, 'optim_{}.pth'.format(epoch))) 10 | torch.save(scheduler.state_dict(), osp.join(ckpt_save_dir, 'scheduler_{}.pth'.format(epoch))) 11 | 12 | def load_checkpoint(model, optimizer, scheduler, ckpt_load_dir, ckpt_load_ep): 13 | weight = torch.load(osp.join(ckpt_load_dir, 'weight_{}.pth').format(ckpt_load_ep)) 14 | opt_params = torch.load(osp.join(ckpt_load_dir, 'optim_{}.pth'.format(ckpt_load_ep))) 15 | sch_params = torch.load(osp.join(ckpt_load_dir, 'scheduler_{}.pth'.format(ckpt_load_ep))) 16 | 17 | model.load_state_dict(weight) 18 | optimizer.load_state_dict(opt_params) 19 | scheduler.load_state_dict(sch_params) 20 | 21 | return model, optimizer, scheduler -------------------------------------------------------------------------------- /libs/utils/clustering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from sklearn.cluster import DBSCAN 5 | from .faiss_rerank import compute_jaccard_distance 6 | 7 | 8 | def cam_label_split(cluster_labels, all_img_cams): 9 | """ 10 | Split proxies using camera labels. 11 | 12 | Params: 13 | cluster_labels: Pseudo labels from DBSCAN clustering. 14 | all_img_cams: Camera labels of all images. 15 | Returns: 16 | Proxy labels of all images. 17 | """ 18 | proxy_labels = -1 * np.ones(cluster_labels.shape, cluster_labels.dtype) 19 | cnt = 0 20 | for i in range(0, int(cluster_labels.max() + 1)): 21 | inds = np.where(cluster_labels == i)[0] 22 | local_cams = all_img_cams[inds] 23 | for cc in np.unique(local_cams): 24 | pc_inds = np.where(local_cams == cc)[0] 25 | proxy_labels[inds[pc_inds]] = cnt 26 | cnt += 1 27 | return proxy_labels 28 | 29 | def dbscan_clustering(cfg, features): 30 | """ 31 | DBSCAN clustering. Generate pseudo labels. 32 | 33 | Params: 34 | cfg: Config instance. 35 | features: Image features extracted by the model. 36 | Returns: 37 | Pseudo cluster labels of all images. 38 | """ 39 | 40 | rerank_dist = compute_jaccard_distance(features, k1=cfg.CLUSTER.K1, k2=cfg.CLUSTER.K2) 41 | print('=> Global DBSCAN params: eps={:.3f}, min_samples={:.3f}'.format(cfg.CLUSTER.EPS, cfg.CLUSTER.MIN_SAMPLES)) 42 | 43 | dbscan = DBSCAN(eps=cfg.CLUSTER.EPS, min_samples=cfg.CLUSTER.MIN_SAMPLES, metric='precomputed', n_jobs=-1) 44 | cluster_labels = dbscan.fit_predict(rerank_dist) 45 | 46 | return cluster_labels 47 | 48 | def get_centers(features, labels): 49 | """ 50 | Get L2-normalized centers of all pseudo classes. 51 | 52 | Params: 53 | features: Image features extracted by the model. 54 | labels: Pseudo labels of all features. 55 | Returns: 56 | L2-normalized centers of all pseudo classes. 
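Outlier samples labeled -1 are ignored; centers are computed only for pseudo labels 0 .. num_ids - 1.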
57 | """ 58 | num_ids = len(set(labels)) - (1 if -1 in labels else 0) 59 | centers = np.zeros((num_ids, features.shape[1]), dtype=np.float32) 60 | for i in range(num_ids): 61 | idx = torch.where(torch.from_numpy(labels) == i)[0].numpy() 62 | temp = features[idx,:] 63 | if len(temp.shape) == 1: 64 | temp = temp.reshape(1, -1) 65 | centers[i,:] = temp.mean(0) 66 | centers = torch.from_numpy(centers) 67 | return F.normalize(centers, dim=1) 68 | -------------------------------------------------------------------------------- /libs/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset 4 | from .preprocessor import Preprocessor 5 | 6 | class IterLoader: 7 | def __init__(self, loader, length=None): 8 | self.loader = loader 9 | self.length = length 10 | self.iter = None 11 | 12 | def __len__(self): 13 | if (self.length is not None): 14 | return self.length 15 | return len(self.loader) 16 | 17 | def new_epoch(self): 18 | self.iter = iter(self.loader) 19 | 20 | def next(self): 21 | try: 22 | return next(self.iter) 23 | except: 24 | self.iter = iter(self.loader) 25 | return next(self.iter) 26 | -------------------------------------------------------------------------------- /libs/utils/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def print_dataset_statistics(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def images_dir(self): 27 | return None 28 | 29 | 30 | class BaseImageDataset(BaseDataset): 31 | """ 32 | Base class of image reid dataset 33 | """ 34 | 35 | def print_dataset_statistics(self, train, query, gallery): 36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 38 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 39 | 40 | print("Dataset statistics:") 41 | print(" ----------------------------------------") 42 | print(" subset | # ids | # images | # cameras") 43 | print(" ----------------------------------------") 44 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 45 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 46 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 47 | print(" ----------------------------------------") 48 | -------------------------------------------------------------------------------- /libs/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import torchvision.transforms as T 6 | import numpy as np 7 | import random 8 | import math 9 | from PIL import Image 10 | 11 | import torch.utils.data as data 12 | import 
random 13 | 14 | class ProxySampleSet(Dataset): 15 | def __init__(self, args, fnames) -> None: 16 | super().__init__() 17 | self.fnames = fnames 18 | self.transform = T.Compose([ 19 | T.Resize((args.height, args.width), interpolation=3), 20 | T.RandomHorizontalFlip(p=0.5), 21 | T.Pad(10), 22 | T.RandomCrop((args.height, args.width)), 23 | T.ToTensor(), 24 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 25 | T.RandomErasing(p=0.5, value=[0.485, 0.456, 0.406]) 26 | ]) 27 | 28 | def __getitem__(self, index: int): 29 | fn = self.fnames[index] 30 | img = Image.open(fn).convert('RGB') 31 | img = self.transform(img) 32 | return img 33 | 34 | def __len__(self) -> int: 35 | return len(self.fnames) 36 | 37 | class Preprocessor(Dataset): 38 | def __init__(self, dataset, root=None, transform=None): 39 | super(Preprocessor, self).__init__() 40 | self.dataset = dataset 41 | self.root = root 42 | self.transform = transform 43 | 44 | def __len__(self): 45 | return len(self.dataset) 46 | 47 | def __getitem__(self, indices): 48 | return self._get_single_item(indices) 49 | 50 | def _get_single_item(self, index): 51 | #fname, pid, camid = self.dataset[index] 52 | input_data = self.dataset[index] 53 | fname = input_data[0] 54 | pid = input_data[1] 55 | camid = input_data[2] 56 | fpath = fname 57 | if self.root is not None: 58 | fpath = osp.join(self.root, fname) 59 | 60 | img = Image.open(fpath).convert('RGB') 61 | 62 | if self.transform is not None: 63 | img = self.transform(img) 64 | 65 | return img, fname, pid, camid, index 66 | 67 | 68 | 69 | class CameraAwarePreprocessor(object): 70 | def __init__(self, dataset, root=None, transform=None): 71 | super(CameraAwarePreprocessor, self).__init__() 72 | self.dataset = dataset 73 | self.root = root 74 | self.transform = transform 75 | 76 | def __len__(self): 77 | return len(self.dataset) 78 | 79 | def __getitem__(self, indices): 80 | if isinstance(indices, (tuple, list)): 81 | return [self._get_single_item(index) for index in indices] 82 | return self._get_single_item(indices) 83 | 84 | def _get_single_item(self, index): 85 | fname, pseudo_label, camid, img_index, accum_label = self.dataset[index] 86 | 87 | fpath = fname 88 | if self.root is not None: 89 | fpath = osp.join(self.root, fname) 90 | 91 | img = Image.open(fpath).convert('RGB') 92 | 93 | if self.transform is not None: 94 | img = self.transform(img) 95 | 96 | return img, fname, pseudo_label, camid, img_index, accum_label 97 | 98 | -------------------------------------------------------------------------------- /libs/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import math 4 | 5 | from typing import * 6 | import numpy as np 7 | import copy 8 | import random 9 | import torch 10 | from torch.utils.data.sampler import ( 11 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 12 | WeightedRandomSampler) 13 | 14 | 15 | def No_index(a, b): 16 | assert isinstance(a, list) 17 | return [i for i, j in enumerate(a) if j != b] 18 | 19 | 20 | class RandomIdentitySampler(Sampler): 21 | def __init__(self, data_source, num_instances=4, class_position=1): 22 | self.data_source = data_source 23 | #self.class_position = class_posotion 24 | self.num_instances = num_instances 25 | self.index_dic = defaultdict(list) 26 | #for index, (_, pid, _) in enumerate(data_source): 27 | for index, each_input in enumerate(data_source): 28 | pid = 
each_input[class_position] 29 | self.index_dic[pid].append(index) 30 | self.pids = list(self.index_dic.keys()) 31 | self.num_samples = len(self.pids) 32 | 33 | def __len__(self): 34 | return self.num_samples * self.num_instances 35 | 36 | def __iter__(self): 37 | indices = torch.randperm(self.num_samples).tolist() 38 | ret = [] 39 | for i in indices: 40 | pid = self.pids[i] 41 | t = self.index_dic[pid] 42 | if len(t) >= self.num_instances: 43 | t = np.random.choice(t, size=self.num_instances, replace=False) 44 | else: 45 | t = np.random.choice(t, size=self.num_instances, replace=True) 46 | ret.extend(t) 47 | return iter(ret) 48 | 49 | 50 | class RandomMultipleGallerySampler(Sampler): 51 | def __init__(self, data_source, class_position, num_instances=4): 52 | self.data_source = data_source 53 | self.index_pid = defaultdict(int) 54 | self.pid_cam = defaultdict(list) 55 | self.pid_index = defaultdict(list) 56 | self.num_instances = num_instances 57 | self.class_position = class_position 58 | 59 | #for index, (_, pid, cam) in enumerate(data_source): 60 | for index, each_input in enumerate(data_source): 61 | pid = each_input[self.class_position] # 1: cluster_label, 4: proxy_label 62 | cam = each_input[2] 63 | if (pid<0): continue 64 | self.index_pid[index] = pid 65 | self.pid_cam[pid].append(cam) 66 | self.pid_index[pid].append(index) 67 | 68 | self.pids = list(self.pid_index.keys()) 69 | self.num_samples = len(self.pids) 70 | 71 | def __len__(self): 72 | return self.num_samples * self.num_instances 73 | 74 | def __iter__(self): 75 | indices = torch.randperm(len(self.pids)).tolist() 76 | ret = [] 77 | 78 | for kid in indices: 79 | i = random.choice(self.pid_index[self.pids[kid]]) 80 | 81 | #_, i_pid, i_cam = self.data_source[i] 82 | i_pid = self.data_source[i][1] 83 | i_cam = self.data_source[i][2] 84 | ret.append(i) 85 | 86 | pid_i = self.index_pid[i] 87 | cams = self.pid_cam[pid_i] 88 | index = self.pid_index[pid_i] 89 | select_cams = No_index(cams, i_cam) 90 | 91 | if select_cams: # as a priority: select images in the same cluster/class, from different cameras (my add) 92 | 93 | if len(select_cams) >= self.num_instances: 94 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False) 95 | else: 96 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True) 97 | 98 | for kk in cam_indexes: 99 | ret.append(index[kk]) 100 | 101 | else: # otherwise select images in the same camera, or do not select more if it's an outlier (my add) 102 | select_indexes = No_index(index, i) 103 | if (not select_indexes): continue 104 | if len(select_indexes) >= self.num_instances: 105 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 106 | else: 107 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 108 | 109 | for kk in ind_indexes: 110 | ret.append(index[kk]) 111 | 112 | 113 | return iter(ret) 114 | 115 | 116 | 117 | class ClassUniformlySampler(Sampler): 118 | ''' 119 | random sample according to class label 120 | Arguments: 121 | data_source (Dataset): data_loader to sample from 122 | class_position (int): which one is used as class 123 | k (int): sample k images of each class 124 | ''' 125 | def __init__(self, samples, class_position, k=4, has_outlier=False, cam_num=0): 126 | 127 | self.samples = samples 128 | self.class_position = class_position 129 | self.k = k 130 | self.has_outlier = has_outlier 131 | self.cam_num = cam_num 132 | self.class_dict = 
self._tuple2dict(self.samples) 133 | 134 | def __iter__(self): 135 | self.sample_list = self._generate_list(self.class_dict) 136 | return iter(self.sample_list) 137 | 138 | def __len__(self): 139 | return len(self.sample_list) 140 | 141 | def _tuple2dict(self, inputs): 142 | ''' 143 | :param inputs: list with tuple elemnts, [(image_path1, class_index_1), (image_path_2, class_index_2), ...] 144 | :return: dict, {class_index_i: [samples_index1, samples_index2, ...]} 145 | ''' 146 | id_dict = {} 147 | for index, each_input in enumerate(inputs): 148 | class_index = each_input[self.class_position] # from which index to obtain the label 149 | if class_index not in list(id_dict.keys()): 150 | id_dict[class_index] = [index] 151 | else: 152 | id_dict[class_index].append(index) 153 | return id_dict 154 | 155 | def _generate_list(self, id_dict): 156 | ''' 157 | :param dict: dict, whose values are list 158 | :return: 159 | ''' 160 | sample_list = [] 161 | 162 | dict_copy = id_dict.copy() 163 | keys = list(dict_copy.keys()) 164 | random.shuffle(keys) 165 | outlier_cnt = 0 166 | for key in keys: 167 | value = dict_copy[key] 168 | if self.has_outlier and len(value)<=self.cam_num: 169 | random.shuffle(value) 170 | sample_list.append(value[0]) # sample outlier only one time 171 | outlier_cnt += 1 172 | elif len(value) >= self.k: 173 | random.shuffle(value) 174 | sample_list.extend(value[0: self.k]) 175 | else: 176 | value = value * self.k # copy a person's image list for k-times 177 | random.shuffle(value) 178 | sample_list.extend(value[0: self.k]) 179 | if outlier_cnt > 0: 180 | print('in Sampler: outlier number= {}'.format(outlier_cnt)) 181 | return sample_list 182 | 183 | 184 | 185 | class ClassAndCameraBalancedSampler(Sampler): 186 | def __init__(self, data_source, num_instances=4, class_position=1): 187 | self.data_source = data_source 188 | self.index_pid = defaultdict(int) 189 | self.pid_cam = defaultdict(list) 190 | self.pid_index = defaultdict(list) 191 | self.num_instances = num_instances 192 | 193 | # for index, (_, pid, cam) in enumerate(data_source): 194 | for index, each_input in enumerate(data_source): 195 | pid = each_input[class_position] 196 | cam = each_input[2] 197 | if (pid<0): continue 198 | self.index_pid[index] = pid 199 | self.pid_cam[pid].append(cam) 200 | self.pid_index[pid].append(index) 201 | 202 | self.pids = list(self.pid_index.keys()) 203 | self.num_samples = len(self.pids) 204 | 205 | def __len__(self): 206 | return self.num_samples * self.num_instances 207 | 208 | def __iter__(self): 209 | indices = torch.randperm(len(self.pids)).tolist() 210 | ret = [] 211 | 212 | for ii in indices: 213 | curr_id = self.pids[ii] 214 | indexes = np.array(self.pid_index[curr_id]) 215 | cams = np.array(self.pid_cam[curr_id]) 216 | uniq_cams = np.unique(cams) 217 | if len(uniq_cams) >= self.num_instances: # more cameras than per-class-instances 218 | sel_cams = np.random.choice(uniq_cams, size=self.num_instances, replace=False) 219 | for cc in sel_cams: 220 | ind = np.where(cams==cc)[0] 221 | sel_idx = np.random.choice(indexes[ind], size=1, replace=False) 222 | ret.append(sel_idx[0]) 223 | else: 224 | sel_cams = np.random.choice(uniq_cams, size=self.num_instances, replace=True) 225 | for cc in np.unique(sel_cams): 226 | sample_num = len(np.where(sel_cams == cc)[0]) 227 | ind = np.where(cams == cc)[0] 228 | if len(ind) >= sample_num: 229 | sel_idx = np.random.choice(indexes[ind], size=sample_num, replace=False) 230 | else: 231 | sel_idx = np.random.choice(indexes[ind], size=sample_num, 
replace=True) 232 | for idx in sel_idx: 233 | ret.append(idx) 234 | return iter(ret) 235 | 236 | class ClusterProxyBalancedSampler(Sampler): 237 | ''' 238 | Cluster-proxy balanced sampler. Samples are equally collected from different proxies in different clusters. 239 | 240 | Steps: 241 | 1. Randomly select a cluster `c_i` from all clusters. Add it into the selected set. 242 | 2. Randomly select a proxy `p_j` in the chosen cluster `c_i`. 243 | 3. Randomly select `k` samples from `p_j` in `c_i`. 244 | 4. Repeat until all `batchsize // num_instances` proxies are sampled. 245 | ''' 246 | def __init__(self, samples, k=4, has_outlier=False, cam_num=0): 247 | 248 | self.samples = samples 249 | self.k = k 250 | self.has_outlier = has_outlier 251 | self.cam_num = cam_num 252 | self.dicts = self._tuple2dict(self.samples) # label -> img_index 253 | 254 | def __iter__(self): 255 | self.sample_list = self._generate_list(self.dicts) 256 | return iter(self.sample_list) 257 | 258 | def __len__(self): 259 | return len(self.sample_list) 260 | 261 | def _tuple2dict(self, inputs): 262 | ''' 263 | :param inputs: list with tuple elemnts, [(image_path1, class_index_1), (image_path_2, class_index_2), ...] 264 | :return: dict, {class_index_i: [samples_index1, samples_index2, ...]} 265 | ''' 266 | cluster2proxy_dict = {} 267 | proxy2id_dict = {} 268 | for index, each_input in enumerate(inputs): 269 | clbl = each_input[1] 270 | plbl = each_input[4] 271 | 272 | # Record cluster proxy mappings 273 | if clbl not in cluster2proxy_dict.keys(): 274 | cluster2proxy_dict[clbl] = [plbl] 275 | else: 276 | cluster2proxy_dict[clbl].append(plbl) 277 | 278 | # Record proxy label mappings 279 | if plbl not in proxy2id_dict.keys(): 280 | proxy2id_dict[plbl] = [index] 281 | else: 282 | proxy2id_dict[plbl].append(index) 283 | return cluster2proxy_dict, proxy2id_dict 284 | 285 | def _generate_list(self, dicts: List[dict]): 286 | ''' 287 | dicts: list of dicts. containing cluster2id and proxy2id. 
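Concretely, this is the (cluster2proxy_dict, proxy2id_dict) pair produced by _tuple2dict above.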
288 | ''' 289 | sample_list = [] 290 | cluster2proxy_dict, proxy2id_dict = dicts 291 | 292 | # Check each cluster for proxies 293 | cluster2proxy_dict_copy = cluster2proxy_dict.copy() 294 | clusters = list(cluster2proxy_dict_copy.keys()) 295 | random.shuffle(clusters) 296 | for c in clusters: 297 | proxies = cluster2proxy_dict_copy[c] 298 | sel_proxy = random.sample(proxies, k=1)[0] 299 | img_indices = proxy2id_dict[sel_proxy] 300 | if len(img_indices) >= self.k: 301 | random.shuffle(img_indices) 302 | sample_list.extend(img_indices[:self.k]) 303 | else: 304 | img_indices = img_indices * self.k 305 | random.shuffle(img_indices) 306 | sample_list.extend(img_indices[:self.k]) 307 | return sample_list 308 | 309 | class HardProxyBalancedSampler(Sampler): 310 | """ 311 | 对proxy进行PK均衡采样,每个proxy内选择K个距离proxy中心最远的样本。 312 | """ 313 | pass -------------------------------------------------------------------------------- /libs/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | """ Randomly selects a rectangle region in an image and erases its pixels. 54 | 'Random Erasing Data Augmentation' by Zhong et al. 55 | See https://arxiv.org/pdf/1708.04896.pdf 56 | Args: 57 | probability: The probability that the Random Erasing operation will be performed. 58 | sl: Minimum proportion of erased area against input image. 59 | sh: Maximum proportion of erased area against input image. 60 | r1: Minimum aspect ratio of erased area. 61 | mean: Erasing value. 
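In this repo the training transform builds it as RandomErasing(probability=0.5, mean=cfg.INPUT.PIXEL_MEAN) (see prepare_data.py).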
62 | """ 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) >= self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | -------------------------------------------------------------------------------- /libs/utils/faiss_rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 5 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 7 | """ 8 | 9 | import os, sys 10 | import time 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | import gc 14 | import faiss 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \ 20 | index_init_gpu, index_init_cpu 21 | 22 | 23 | def k_reciprocal_neigh(initial_rank, i, k1): 24 | forward_k_neigh_index = initial_rank[i,:k1+1] 25 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 26 | fi = np.where(backward_k_neigh_index==i)[0] 27 | return forward_k_neigh_index[fi] 28 | 29 | 30 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False): 31 | end = time.time() 32 | if print_flag: 33 | print('Computing jaccard distance...') 34 | 35 | ngpus = faiss.get_num_gpus() 36 | N = target_features.size(0) 37 | mat_type = np.float16 if use_float16 else np.float32 38 | 39 | if (search_option==0): 40 | # GPU + PyTorch CUDA Tensors (1) 41 | res = faiss.StandardGpuResources() 42 | res.setDefaultNullStreamAllDevices() 43 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1) 44 | initial_rank = initial_rank.cpu().numpy() 45 | elif (search_option==1): 46 | # GPU + PyTorch CUDA Tensors (2) 47 | res = faiss.StandardGpuResources() 48 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1)) 49 | index.add(target_features.cpu().numpy()) 50 | _, initial_rank = search_index_pytorch(index, target_features, k1) 51 | res.syncDefaultStreamCurrentDevice() 52 | initial_rank = initial_rank.cpu().numpy() 53 | elif (search_option==2): 54 | # GPU 55 | index = index_init_gpu(ngpus, target_features.size(-1)) 56 | index.add(target_features.cpu().numpy()) 57 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 58 | else: 59 | # CPU 60 | index = 
index_init_cpu(target_features.size(-1)) 61 | index.add(target_features.cpu().numpy()) 62 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 63 | 64 | 65 | nn_k1 = [] 66 | nn_k1_half = [] 67 | for i in range(N): 68 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1)) 69 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2)))) 70 | 71 | V = np.zeros((N, N), dtype=mat_type) 72 | for i in range(N): 73 | k_reciprocal_index = nn_k1[i] 74 | k_reciprocal_expansion_index = k_reciprocal_index 75 | for candidate in k_reciprocal_index: 76 | candidate_k_reciprocal_index = nn_k1_half[candidate] 77 | if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)): 78 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 79 | 80 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 81 | dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t()) 82 | if use_float16: 83 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type) 84 | else: 85 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy() 86 | 87 | del nn_k1, nn_k1_half 88 | 89 | if k2 != 1: 90 | V_qe = np.zeros_like(V, dtype=mat_type) 91 | for i in range(N): 92 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0) 93 | V = V_qe 94 | del V_qe 95 | 96 | del initial_rank 97 | 98 | invIndex = [] 99 | for i in range(N): 100 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 101 | 102 | jaccard_dist = np.zeros((N, N), dtype=mat_type) 103 | for i in range(N): 104 | temp_min = np.zeros((1, N), dtype=mat_type) 105 | # temp_max = np.zeros((1,N), dtype=mat_type) 106 | indNonZero = np.where(V[i, :] != 0)[0] 107 | indImages = [] 108 | indImages = [invIndex[ind] for ind in indNonZero] 109 | for j in range(len(indNonZero)): 110 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]]+np.minimum(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]) 111 | # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 112 | 113 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 114 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6) 115 | 116 | del invIndex, V 117 | 118 | pos_bool = (jaccard_dist < 0) 119 | jaccard_dist[pos_bool] = 0.0 120 | if print_flag: 121 | print("Jaccard distance computing time cost: {}".format(time.time()-end)) 122 | 123 | return jaccard_dist 124 | -------------------------------------------------------------------------------- /libs/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | 16 | return faiss.cast_integer_to_idx_t_ptr( 17 | x.storage().data_ptr() + x.storage_offset() * 8) 18 | 19 | def search_index_pytorch(index, x, k, D=None, I=None): 20 | """call the search function of an index with pytorch tensor I/O (CPU 21 | and GPU supported)""" 22 | assert x.is_contiguous() 23 | n, d = x.size() 24 | assert d == index.d 25 | 
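# allocate distance (D) and index (I) output buffers on the query's device if the caller did not supply them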
26 | if D is None: 27 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 28 | else: 29 | assert D.size() == (n, k) 30 | 31 | if I is None: 32 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 33 | else: 34 | assert I.size() == (n, k) 35 | torch.cuda.synchronize() 36 | xptr = swig_ptr_from_FloatTensor(x) 37 | Iptr = swig_ptr_from_LongTensor(I) 38 | Dptr = swig_ptr_from_FloatTensor(D) 39 | index.search_c(n, xptr, 40 | k, Dptr, Iptr) 41 | torch.cuda.synchronize() 42 | return D, I 43 | 44 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 45 | metric=faiss.METRIC_L2): 46 | assert xb.device == xq.device 47 | 48 | nq, d = xq.size() 49 | if xq.is_contiguous(): 50 | xq_row_major = True 51 | elif xq.t().is_contiguous(): 52 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 53 | xq_row_major = False 54 | else: 55 | raise TypeError('matrix should be row or column-major') 56 | 57 | xq_ptr = swig_ptr_from_FloatTensor(xq) 58 | 59 | nb, d2 = xb.size() 60 | assert d2 == d 61 | if xb.is_contiguous(): 62 | xb_row_major = True 63 | elif xb.t().is_contiguous(): 64 | xb = xb.t() 65 | xb_row_major = False 66 | else: 67 | raise TypeError('matrix should be row or column-major') 68 | xb_ptr = swig_ptr_from_FloatTensor(xb) 69 | 70 | if D is None: 71 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 72 | else: 73 | assert D.shape == (nq, k) 74 | assert D.device == xb.device 75 | 76 | if I is None: 77 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 78 | else: 79 | assert I.shape == (nq, k) 80 | assert I.device == xb.device 81 | 82 | D_ptr = swig_ptr_from_FloatTensor(D) 83 | I_ptr = swig_ptr_from_LongTensor(I) 84 | 85 | faiss.bruteForceKnn(res, metric, 86 | xb_ptr, xb_row_major, nb, 87 | xq_ptr, xq_row_major, nq, 88 | d, k, D_ptr, I_ptr) 89 | 90 | return D, I 91 | 92 | def index_init_gpu(ngpus, feat_dim): 93 | flat_config = [] 94 | for i in range(ngpus): 95 | cfg = faiss.GpuIndexFlatConfig() 96 | cfg.useFloat16 = False 97 | cfg.device = i 98 | flat_config.append(cfg) 99 | 100 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 101 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 102 | index = faiss.IndexShards(feat_dim) 103 | for sub_index in indexes: 104 | index.add_shard(sub_index) 105 | index.reset() 106 | return index 107 | 108 | def index_init_cpu(feat_dim): 109 | return faiss.IndexFlatL2(feat_dim) 110 | -------------------------------------------------------------------------------- /libs/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | 
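For context, `Logger` tees everything written to stdout into an extra log file. A minimal usage sketch, with a placeholder log path that is not taken from the repo:

```python
import sys
from libs.utils.logging import Logger

# Redirect stdout so every print() goes to both the console and the log file.
# Logger captures the original sys.stdout in __init__, so reassigning afterwards is safe.
sys.stdout = Logger('/tmp/tmgf/train_log.txt')  # placeholder path
print('This line appears on the console and is appended to train_log.txt')
```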
-------------------------------------------------------------------------------- /libs/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /libs/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /libs/utils/prepare_data.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | sys.path.append(osp.abspath(osp.join(osp.dirname(__file__), '../..'))) 4 | 5 | from torch.utils.data import DataLoader 6 | from libs import datasets 7 | from .data import transforms as T 8 | from .data import IterLoader 9 | from .data.sampler import ClassUniformlySampler, RandomMultipleGallerySampler, ClusterProxyBalancedSampler 10 | from .data.preprocessor import Preprocessor, CameraAwarePreprocessor 11 | 12 | def get_data(name, data_dir): 13 | root = osp.join(data_dir, name) 14 | print('root path= {}'.format(root)) 15 | dataset = datasets.create(name, root) 16 | return dataset 17 | 18 | def get_train_loader(cfg, dataset, height, width, batch_size, workers, 19 | num_instances, iters, trainset=None): 20 | # Preprocessing 21 | normalizer = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, 22 | std=cfg.INPUT.PIXEL_STD) 23 | train_transformer = T.Compose([ 24 | T.Resize((height, width), interpolation=3), 25 | T.RandomHorizontalFlip(p=0.5), 26 | T.Pad(10), 27 | T.RandomCrop((height, width)), 28 | T.ToTensor(), 29 | normalizer, 30 | T.RandomErasing(probability=0.5, mean=cfg.INPUT.PIXEL_MEAN) 31 | ]) 32 | 33 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 34 | 35 | # Choose sampler type 36 | # class_position [1: cluster_label, 4: proxy_label] 37 | if cfg.SAMPLER.TYPE == 'proxy_balance': 38 | sampler = ClassUniformlySampler(train_set, class_position=4, k=num_instances) 39 | elif cfg.SAMPLER.TYPE == 'cluster_balance': 40 | sampler = ClassUniformlySampler(train_set, class_position=1, k=num_instances) 41 | elif cfg.SAMPLER.TYPE == 'cam_cluster_balance': 42 | sampler = RandomMultipleGallerySampler(train_set, class_position=1, num_instances=num_instances) 43 | elif cfg.SAMPLER.TYPE == 'cam_proxy_balance': 44 | sampler = RandomMultipleGallerySampler(train_set, class_position=4, num_instances=num_instances) 45 | elif cfg.SAMPLER.TYPE == 'cluster_proxy_balance': 46 | sampler = ClusterProxyBalancedSampler(train_set, k=num_instances) 47 | else: 48 | raise ValueError('Invalid sampler type name!') 49 | 50 | # Create dataloader 51 | train_loader = IterLoader( 52 | DataLoader(CameraAwarePreprocessor(train_set, root=dataset.images_dir, transform=train_transformer), 53 | batch_size=batch_size, 
num_workers=workers, sampler=sampler, 54 | shuffle=False, pin_memory=True, drop_last=True), length=iters) 55 | return train_loader 56 | 57 | def get_test_loader(cfg, dataset, height, width, batch_size, workers, testset=None): 58 | normalizer = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, 59 | std=cfg.INPUT.PIXEL_STD) 60 | 61 | test_transformer = T.Compose([ 62 | T.Resize((height, width), interpolation=3), 63 | T.ToTensor(), 64 | normalizer 65 | ]) 66 | 67 | if (testset is None): 68 | testset = list(set(dataset.query) | set(dataset.gallery)) 69 | 70 | test_loader = DataLoader( 71 | Preprocessor(testset, root=dataset.images_dir, transform=test_transformer), 72 | batch_size=batch_size, num_workers=workers, 73 | shuffle=False, pin_memory=True) 74 | 75 | return test_loader 76 | -------------------------------------------------------------------------------- /libs/utils/prepare_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .. import models 3 | 4 | def create_vit_model(cfg): 5 | """ 6 | Create ViT model. 7 | 8 | Params: 9 | cfg: Config instance. 10 | Returns: 11 | The TMGF model. 12 | """ 13 | 14 | model = models.create(cfg.MODEL.ARCH, arch=cfg.MODEL.ARCH, 15 | img_size=[cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH], sie_coef=cfg.MODEL.SIE_COEF, 16 | camera_num=cfg.MODEL.SIE_CAMERA, view_num=cfg.MODEL.SIE_VIEW, 17 | stride_size=cfg.MODEL.STRIDE_SIZE, drop_path_rate=cfg.MODEL.DROP_PATH, 18 | drop_rate=cfg.MODEL.DROP_OUT, attn_drop_rate=cfg.MODEL.ATTN_DROP_RATE, 19 | pretrain_path=cfg.MODEL.PRETRAIN_PATH, hw_ratio=cfg.MODEL.PRETRAIN_HW_RATIO, 20 | gem_pool=cfg.MODEL.GEM_POOL, stem_conv=cfg.MODEL.STEM_CONV, num_parts=cfg.MODEL.NUM_PARTS, 21 | has_head=cfg.MODEL.HAS_HEAD, global_feature_type=cfg.MODEL.GLOBAL_FEATURE_TYPE, 22 | granularities=cfg.MODEL.GRANULARITIES, branch=cfg.MODEL.BRANCH, has_early_feature=cfg.MODEL.HAS_EARLY_FEATURE, 23 | enable_early_norm=cfg.MODEL.ENABLE_EARLY_NORM) 24 | model.cuda() 25 | return model 26 | -------------------------------------------------------------------------------- /libs/utils/prepare_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def make_vit_optimizer(cfg, model): 4 | """ 5 | Create ViT optimizer. 6 | 7 | Params: 8 | cfg: Config instance. 9 | model: The model to be optimized. 10 | Returns: 11 | An optimizer. 
12 | """ 13 | 14 | params = [] 15 | for key, value in model.named_parameters(): 16 | if not value.requires_grad: 17 | continue 18 | lr = cfg.OPTIM.BASE_LR 19 | weight_decay = cfg.OPTIM.WEIGHT_DECAY 20 | if "bias" in key: 21 | lr = cfg.OPTIM.BASE_LR * cfg.OPTIM.BIAS_LR_FACTOR 22 | weight_decay = cfg.OPTIM.WEIGHT_DECAY_BIAS 23 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 24 | 25 | if cfg.OPTIM.NAME == 'SGD': 26 | optimizer = getattr(torch.optim, cfg.OPTIM.NAME)(params, momentum=cfg.OPTIM.MOMENTUM) 27 | elif cfg.OPTIM.NAME == 'AdamW': 28 | optimizer = torch.optim.AdamW(params, lr=cfg.OPTIM.BASE_LR, weight_decay=cfg.OPTIM.WEIGHT_DECAY) 29 | else: 30 | optimizer = getattr(torch.optim, cfg.OPTIM.NAME)(params) 31 | 32 | return optimizer -------------------------------------------------------------------------------- /libs/utils/prepare_scheduler.py: -------------------------------------------------------------------------------- 1 | from .scheduler import CosineLRScheduler, WarmupMultiStepLR 2 | 3 | 4 | def create_scheduler(cfg, optimizer): 5 | scheduler_type = cfg.OPTIM.SCHEDULER_TYPE 6 | if scheduler_type == 'cosine': 7 | num_epochs = cfg.TRAIN.EPOCHS 8 | # type 1 9 | # lr_min = 0.01 * cfg.SOLVER.BASE_LR 10 | # warmup_lr_init = 0.001 * cfg.SOLVER.BASE_LR 11 | # type 2 12 | lr_min = 0.002 * cfg.OPTIM.BASE_LR 13 | warmup_lr_init = 0.01 * cfg.OPTIM.BASE_LR 14 | # type 3 15 | # lr_min = 0.001 * cfg.SOLVER.BASE_LR 16 | # warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 17 | 18 | warmup_t = cfg.OPTIM.WARMUP_EPOCHS 19 | noise_range = None 20 | 21 | lr_scheduler = CosineLRScheduler( 22 | optimizer, 23 | t_initial=num_epochs, 24 | lr_min=lr_min, 25 | t_mul= 1., 26 | decay_rate=0.1, 27 | warmup_lr_init=warmup_lr_init, 28 | warmup_t=warmup_t, 29 | cycle_limit=1, 30 | t_in_epochs=True, 31 | noise_range_t=noise_range, 32 | noise_pct= 0.67, 33 | noise_std= 1., 34 | noise_seed=42, 35 | ) 36 | elif scheduler_type == 'warmup': 37 | lr_scheduler = WarmupMultiStepLR(optimizer, cfg.OPTIM.MILESTONES, gamma=cfg.OPTIM.GAMMA, 38 | warmup_factor=cfg.OPTIM.WARMUP_FACTOR, 39 | warmup_iters=cfg.OPTIM.WARMUP_EPOCHS) 40 | else: 41 | raise ValueError(f'Invalid scheduler type {scheduler_type}!') 42 | 43 | return lr_scheduler -------------------------------------------------------------------------------- /libs/utils/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Source: https://github.com/zhunzhong07/person-re-ranking 5 | Created on Mon Jun 26 14:46:56 2017 6 | @author: luohao 7 | Modified by Houjing Huang, 2017-12-22. 8 | - This version accepts distance matrix instead of raw features. 9 | - The difference of `/` division between python 2 and 3 is handled. 10 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 11 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
12 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 14 | API 15 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 16 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 17 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 18 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 19 | Returns: 20 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import print_function 24 | from __future__ import division 25 | 26 | __all__ = ['re_ranking'] 27 | 28 | import numpy as np 29 | 30 | 31 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 32 | 33 | # The following naming, e.g. gallery_num, is different from outer scope. 34 | # Don't care about it. 35 | 36 | original_dist = np.concatenate( 37 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 38 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 39 | axis=0) 40 | original_dist = np.power(original_dist, 2).astype(np.float32) 41 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0)) 42 | V = np.zeros_like(original_dist).astype(np.float32) 43 | initial_rank = np.argsort(original_dist).astype(np.int32) 44 | 45 | query_num = q_g_dist.shape[0] 46 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 47 | all_num = gallery_num 48 | 49 | for i in range(all_num): 50 | # k-reciprocal neighbors 51 | forward_k_neigh_index = initial_rank[i,:k1+1] 52 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 53 | fi = np.where(backward_k_neigh_index==i)[0] 54 | k_reciprocal_index = forward_k_neigh_index[fi] 55 | k_reciprocal_expansion_index = k_reciprocal_index 56 | for j in range(len(k_reciprocal_index)): 57 | candidate = k_reciprocal_index[j] 58 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 59 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 60 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 61 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 62 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 63 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 64 | 65 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 66 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 67 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 68 | original_dist = original_dist[:query_num,] 69 | if k2 != 1: 70 | V_qe = np.zeros_like(V,dtype=np.float32) 71 | for i in range(all_num): 72 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 73 | V = V_qe 74 | del V_qe 75 | del initial_rank 76 | invIndex = [] 77 | for i in range(gallery_num): 78 | invIndex.append(np.where(V[:,i] != 0)[0]) 79 | 80 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 81 | 82 | 83 | for i in range(query_num): 84 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 85 | indNonZero = np.where(V[i,:] != 0)[0] 86 | indImages = [] 87 | indImages = [invIndex[ind] for ind in indNonZero] 88 | for j in range(len(indNonZero)): 89 | 
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 90 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 91 | 92 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 93 | del original_dist 94 | del V 95 | del jaccard_dist 96 | final_dist = final_dist[:query_num,query_num:] 97 | return final_dist 98 | -------------------------------------------------------------------------------- /libs/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import logging 4 | from bisect import bisect_right 5 | from typing import Dict, Any 6 | 7 | class Scheduler: 8 | """ Parameter Scheduler Base Class 9 | A scheduler base class that can be used to schedule any optimizer parameter groups. 10 | 11 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 12 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 13 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 14 | 15 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 16 | 17 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 18 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 19 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 20 | 21 | Based on ideas from: 22 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 23 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 24 | """ 25 | 26 | def __init__(self, 27 | optimizer: torch.optim.Optimizer, 28 | param_group_field: str, 29 | noise_range_t=None, 30 | noise_type='normal', 31 | noise_pct=0.67, 32 | noise_std=1.0, 33 | noise_seed=None, 34 | initialize: bool = True) -> None: 35 | self.optimizer = optimizer 36 | self.param_group_field = param_group_field 37 | self._initial_param_group_field = f"initial_{param_group_field}" 38 | if initialize: 39 | for i, group in enumerate(self.optimizer.param_groups): 40 | if param_group_field not in group: 41 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 42 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 43 | else: 44 | for i, group in enumerate(self.optimizer.param_groups): 45 | if self._initial_param_group_field not in group: 46 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 47 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 48 | self.metric = None # any point to having this for all? 
49 | self.noise_range_t = noise_range_t 50 | self.noise_pct = noise_pct 51 | self.noise_type = noise_type 52 | self.noise_std = noise_std 53 | self.noise_seed = noise_seed if noise_seed is not None else 42 54 | self.update_groups(self.base_values) 55 | 56 | def state_dict(self) -> Dict[str, Any]: 57 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 58 | 59 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 60 | self.__dict__.update(state_dict) 61 | 62 | def get_epoch_values(self, epoch: int): 63 | return None 64 | 65 | def get_update_values(self, num_updates: int): 66 | return None 67 | 68 | def step(self, epoch: int, metric: float = None) -> None: 69 | self.metric = metric 70 | values = self.get_epoch_values(epoch) 71 | if values is not None: 72 | values = self._add_noise(values, epoch) 73 | self.update_groups(values) 74 | 75 | def step_update(self, num_updates: int, metric: float = None): 76 | self.metric = metric 77 | values = self.get_update_values(num_updates) 78 | if values is not None: 79 | values = self._add_noise(values, num_updates) 80 | self.update_groups(values) 81 | 82 | def update_groups(self, values): 83 | if not isinstance(values, (list, tuple)): 84 | values = [values] * len(self.optimizer.param_groups) 85 | for param_group, value in zip(self.optimizer.param_groups, values): 86 | param_group[self.param_group_field] = value 87 | 88 | def _add_noise(self, lrs, t): 89 | if self.noise_range_t is not None: 90 | if isinstance(self.noise_range_t, (list, tuple)): 91 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 92 | else: 93 | apply_noise = t >= self.noise_range_t 94 | if apply_noise: 95 | g = torch.Generator() 96 | g.manual_seed(self.noise_seed + t) 97 | if self.noise_type == 'normal': 98 | while True: 99 | # resample if noise out of percent limit, brute force but shouldn't spin much 100 | noise = torch.randn(1, generator=g).item() 101 | if abs(noise) < self.noise_pct: 102 | break 103 | else: 104 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 105 | lrs = [v + v * noise for v in lrs] 106 | return lrs 107 | 108 | _logger = logging.getLogger(__name__) 109 | 110 | 111 | class CosineLRScheduler(Scheduler): 112 | """ 113 | Cosine decay with restarts. 114 | This is described in the paper https://arxiv.org/abs/1608.03983. 
115 | 116 | Inspiration from 117 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 118 | """ 119 | 120 | def __init__(self, 121 | optimizer: torch.optim.Optimizer, 122 | t_initial: int, 123 | t_mul: float = 1., 124 | lr_min: float = 0., 125 | decay_rate: float = 1., 126 | warmup_t=0, 127 | warmup_lr_init=0, 128 | warmup_prefix=False, 129 | cycle_limit=0, 130 | t_in_epochs=True, 131 | noise_range_t=None, 132 | noise_pct=0.67, 133 | noise_std=1.0, 134 | noise_seed=42, 135 | initialize=True) -> None: 136 | super().__init__( 137 | optimizer, param_group_field="lr", 138 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 139 | initialize=initialize) 140 | 141 | assert t_initial > 0 142 | assert lr_min >= 0 143 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 144 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 145 | "rate since t_initial = t_mul = eta_mul = 1.") 146 | self.t_initial = t_initial 147 | self.t_mul = t_mul 148 | self.lr_min = lr_min 149 | self.decay_rate = decay_rate 150 | self.cycle_limit = cycle_limit 151 | self.warmup_t = warmup_t 152 | self.warmup_lr_init = warmup_lr_init 153 | self.warmup_prefix = warmup_prefix 154 | self.t_in_epochs = t_in_epochs 155 | if self.warmup_t: 156 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 157 | super().update_groups(self.warmup_lr_init) 158 | else: 159 | self.warmup_steps = [1 for _ in self.base_values] 160 | 161 | def _get_lr(self, t): 162 | if t < self.warmup_t: 163 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 164 | else: 165 | if self.warmup_prefix: 166 | t = t - self.warmup_t 167 | 168 | if self.t_mul != 1: 169 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 170 | t_i = self.t_mul ** i * self.t_initial 171 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 172 | else: 173 | i = t // self.t_initial 174 | t_i = self.t_initial 175 | t_curr = t - (self.t_initial * i) 176 | 177 | gamma = self.decay_rate ** i 178 | lr_min = self.lr_min * gamma 179 | lr_max_values = [v * gamma for v in self.base_values] 180 | 181 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 182 | lrs = [ 183 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 184 | ] 185 | else: 186 | lrs = [self.lr_min for _ in self.base_values] 187 | 188 | return lrs 189 | 190 | def get_epoch_values(self, epoch: int): 191 | if self.t_in_epochs: 192 | return self._get_lr(epoch) 193 | else: 194 | return None 195 | 196 | def get_update_values(self, num_updates: int): 197 | if not self.t_in_epochs: 198 | return self._get_lr(num_updates) 199 | else: 200 | return None 201 | 202 | def get_cycle_length(self, cycles=0): 203 | if not cycles: 204 | cycles = self.cycle_limit 205 | cycles = max(1, cycles) 206 | if self.t_mul == 1.0: 207 | return self.t_initial * cycles 208 | else: 209 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 210 | 211 | 212 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 213 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, warmup_iters=500, 214 | warmup_method="linear", last_epoch=-1): 215 | if not list(milestones) == sorted(milestones): 216 | raise ValueError( 217 | "Milestones should be a list of" " increasing integers. 
Got {}", 218 | milestones,) 219 | 220 | if warmup_method not in ("constant", "linear"): 221 | raise ValueError( 222 | "Only 'constant' or 'linear' warmup_method accepted" 223 | "got {}".format(warmup_method) 224 | ) 225 | self.milestones = milestones 226 | self.gamma = gamma 227 | self.warmup_factor = warmup_factor 228 | self.warmup_iters = warmup_iters 229 | self.warmup_method = warmup_method 230 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 231 | 232 | def get_lr(self): 233 | warmup_factor = 1 234 | if self.last_epoch < self.warmup_iters: 235 | if self.warmup_method == "constant": 236 | warmup_factor = self.warmup_factor 237 | elif self.warmup_method == "linear": 238 | alpha = float(self.last_epoch) / float(self.warmup_iters) 239 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 240 | return [ 241 | base_lr 242 | * warmup_factor 243 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 244 | for base_lr in self.base_lrs 245 | ] -------------------------------------------------------------------------------- /libs/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | # checkpoint = torch.load(fpath) 34 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 35 | print("=> Loaded checkpoint '{}'".format(fpath)) 36 | return checkpoint 37 | else: 38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 39 | 40 | 41 | def copy_state_dict(state_dict, model, strip=None): 42 | tgt_state = model.state_dict() 43 | copied_names = set() 44 | for name, param in state_dict.items(): 45 | if strip is not None and name.startswith(strip): 46 | name = name[len(strip):] 47 | if name not in tgt_state: 48 | continue 49 | if isinstance(param, Parameter): 50 | param = param.data 51 | if param.size() != tgt_state[name].size(): 52 | print('mismatch:', name, param.size(), tgt_state[name].size()) 53 | continue 54 | tgt_state[name].copy_(param) 55 | copied_names.add(name) 56 | 57 | missing = set(tgt_state.keys()) - copied_names 58 | if len(missing) > 0: 59 | print("missing keys in state_dict:", missing) 60 | 61 | return model 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | argon2-cffi==21.3.0 3 | argon2-cffi-bindings==21.2.0 4 | attrs==21.4.0 5 | backcall==0.2.0 6 | bleach==4.1.0 7 | cachetools==4.2.2 8 | certifi==2020.12.5 9 | cffi==1.15.0 10 | chardet==4.0.0 11 | colorama==0.4.4 12 | cycler==0.10.0 13 | Cython==0.29.23 14 | debugpy==1.5.1 15 | decorator==5.0.9 16 | defusedxml==0.7.1 17 | easydict==1.9 18 | entrypoints==0.3 19 | 
faiss-gpu==1.7.2 20 | filelock==3.0.12 21 | fire==0.4.0 22 | future==0.18.2 23 | gdown==3.13.0 24 | google-auth==1.30.0 25 | google-auth-oauthlib==0.4.4 26 | grpcio==1.37.1 27 | hdbscan==0.8.27 28 | idna==2.10 29 | importlib-metadata==4.0.1 30 | importlib-resources==5.4.0 31 | infomap==2.3.0 32 | ipdb==0.13.7 33 | ipykernel==6.7.0 34 | ipython==7.23.1 35 | ipython-genutils==0.2.0 36 | ipywidgets==7.6.5 37 | jedi==0.18.0 38 | Jinja2==3.0.3 39 | joblib==1.0.1 40 | jsonschema==4.4.0 41 | jupyter==1.0.0 42 | jupyter-client==7.1.1 43 | jupyter-console==6.4.0 44 | jupyter-core==4.9.1 45 | jupyterlab-pygments==0.1.2 46 | jupyterlab-widgets==1.0.2 47 | kiwisolver==1.3.1 48 | llvmlite==0.36.0 49 | Markdown==3.3.4 50 | MarkupSafe==2.0.1 51 | matplotlib==3.4.2 52 | matplotlib-inline==0.1.2 53 | mistune==0.8.4 54 | munch==2.5.0 55 | nbclient==0.5.10 56 | nbconvert==6.4.0 57 | nbformat==5.1.3 58 | nest-asyncio==1.5.4 59 | notebook==6.4.7 60 | numba==0.53.1 61 | numpy==1.20.3 62 | oauthlib==3.1.0 63 | opencv-python==3.4.2.17 64 | packaging==21.3 65 | pandas==1.2.4 66 | pandocfilters==1.5.0 67 | parso==0.8.2 68 | pexpect==4.8.0 69 | pickleshare==0.7.5 70 | Pillow==8.2.0 71 | pretrainedmodels==0.7.4 72 | prettytable==2.1.0 73 | prometheus-client==0.12.0 74 | prompt-toolkit==3.0.18 75 | protobuf==3.17.0 76 | ptyprocess==0.7.0 77 | pyasn1==0.4.8 78 | pyasn1-modules==0.2.8 79 | pycparser==2.21 80 | Pygments==2.9.0 81 | pynndescent==0.5.2 82 | pyparsing==2.4.7 83 | PyQt5-Qt5==5.15.2 84 | PyQt5-sip==12.11.0 85 | pyrsistent==0.18.1 86 | PySocks==1.7.1 87 | python-dateutil==2.8.1 88 | pytorch-ignite==0.1.2 89 | pytz==2021.1 90 | PyYAML==5.4.1 91 | pyzmq==22.3.0 92 | qtconsole==5.2.2 93 | QtPy==2.0.0 94 | requests==2.25.1 95 | requests-oauthlib==1.3.0 96 | rsa==4.7.2 97 | scikit-learn==0.24.2 98 | scipy==1.6.3 99 | Send2Trash==1.8.0 100 | six==1.16.0 101 | sklearn==0.0 102 | tabulate==0.8.9 103 | tensorboard==2.5.0 104 | tensorboard-data-server==0.6.1 105 | tensorboard-plugin-wit==1.8.0 106 | tensorboardX==2.2 107 | termcolor==1.1.0 108 | terminado==0.12.1 109 | testpath==0.5.0 110 | threadpoolctl==2.1.0 111 | toml==0.10.2 112 | torch==1.6.0 113 | torchvision==0.7.0 114 | tornado==6.1 115 | tqdm==4.60.0 116 | traitlets==5.1.0 117 | typing-extensions==3.10.0.0 118 | umap-learn==0.5.1 119 | urllib3==1.26.4 120 | wcwidth==0.2.5 121 | webencodings==0.5.1 122 | Werkzeug==2.0.1 123 | widgetsnbextension==3.5.2 124 | wrapt==1.13.3 125 | yacs==0.1.8 126 | zipp==3.4.1 127 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import warnings 4 | warnings.filterwarnings('ignore') 5 | import argparse 6 | import random 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import time 11 | from configs.default import get_cfg_defaults 12 | from datetime import timedelta 13 | from libs.utils.prepare_data import get_data, get_test_loader, get_train_loader 14 | from libs.utils.prepare_model import create_vit_model 15 | from libs import trainers 16 | from libs.models import mb 17 | from libs.evaluators import Evaluator, extract_multipart_vit_features, save_benchmark 18 | from libs.utils.logging import Logger 19 | from libs.utils.checkpoint_io import load_checkpoint, save_checkpoint 20 | from libs.utils.prepare_optimizer import make_vit_optimizer 21 | from libs.utils.prepare_scheduler import create_scheduler 22 | from libs.utils.clustering 
import dbscan_clustering, cam_label_split, get_centers 23 | 24 | 25 | def main(cfg): 26 | # Check output dir 27 | assert osp.exists(cfg.LOG.LOG_DIR) 28 | assert osp.exists(cfg.LOG.CHECKPOINT.SAVE_DIR) 29 | 30 | start_time = time.monotonic() 31 | 32 | # Build task folder 33 | task_name = time.strftime('%Y%m%d') + '_' + cfg.TASK_NAME 34 | log_file_name = osp.join(cfg.LOG.LOG_DIR, task_name+'.txt') 35 | ckpt_save_dir = osp.join(cfg.LOG.CHECKPOINT.SAVE_DIR, task_name) 36 | 37 | 38 | # Print settings 39 | sys.stdout = Logger(log_file_name) 40 | print("==========\n{}\n==========".format(cfg)) 41 | print('=> Task name:', task_name) 42 | print('=> Description:', cfg.DESC) 43 | 44 | 45 | # Create datasets 46 | iters = cfg.TRAIN.ITERS if (cfg.TRAIN.ITERS>0) else None 47 | print("=> Load unlabeled dataset") 48 | dataset = get_data(cfg.DATASET.NAME, cfg.DATASET.ROOT_DIR) 49 | 50 | # Create model 51 | model = create_vit_model(cfg) 52 | 53 | # Create memory 54 | memory = mb.MultiPartMemory(cfg).cuda() 55 | 56 | # Get dataloaders 57 | cluster_loader = get_test_loader(cfg, dataset, cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH, cfg.TEST.BATCHSIZE, cfg.TEST.NUM_WORKERS, testset=sorted(dataset.train)) 58 | test_loader = get_test_loader(cfg, dataset, cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH, cfg.TEST.BATCHSIZE, cfg.TEST.NUM_WORKERS) 59 | 60 | # Evaluator 61 | evaluator = Evaluator(cfg, model) 62 | 63 | 64 | # Optimizer & scheduler 65 | optimizer = make_vit_optimizer(cfg, model) 66 | lr_scheduler = create_scheduler(cfg, optimizer) 67 | 68 | # Load checkpoint 69 | if len(cfg.LOG.CHECKPOINT.LOAD_DIR) == 0: 70 | start_ep = 0 # default: start from beginning 71 | print('=> Train from beginning.') 72 | else: 73 | model, optimizer, lr_scheduler = load_checkpoint(model, optimizer, lr_scheduler, cfg.LOG.CHECKPOINT.LOAD_DIR, cfg.LOG.CHECKPOINT.LOAD_EPOCH) 74 | start_ep = cfg.LOG.CHECKPOINT.LOAD_EPOCH 75 | print('=> Continue training from epoch={}, load checkpoint from {}'.format(start_ep, cfg.LOG.CHECKPOINT.LOAD_DIR)) 76 | 77 | # Trainer 78 | trainer = trainers.ViTTrainerFp16(model, memory) 79 | 80 | # Training pipeline 81 | for epoch in range(start_ep, cfg.TRAIN.EPOCHS): 82 | print('=> EPOCH num={}'.format(epoch+1)) 83 | 84 | # Feature extraction 85 | print('=> Extract features...') 86 | features, part_feats, _ = extract_multipart_vit_features(model, cluster_loader, cfg.MODEL.NUM_PARTS) 87 | features = torch.cat([features[f].unsqueeze(0) for f, _, _ in sorted(dataset.train)], 0) 88 | part_feats = [torch.cat([pf[f].unsqueeze(0) for f, _, _ in sorted(dataset.train)], 0) for pf in part_feats] 89 | 90 | # Clustering for pseudo labels 91 | cluster_labels = dbscan_clustering(cfg, features) 92 | 93 | 94 | # Camera proxy generation 95 | print('=> cam-split with global features') 96 | all_img_cams = np.array([c for _, _, c in sorted(dataset.train)]) 97 | proxy_labels = cam_label_split(cluster_labels, all_img_cams) 98 | 99 | num_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) 100 | num_proxies = len(set(proxy_labels)) - (1 if -1 in proxy_labels else 0) 101 | num_outliers = len(np.where(proxy_labels == -1)[0]) 102 | print('=> Global feature clusters: {}\n=> Generated proxies: {}\n=> Outliers: {}'.format( 103 | num_clusters, num_proxies, num_outliers 104 | )) 105 | 106 | 107 | # Add pseudo labels into training set 108 | pseudo_labeled_dataset = [] 109 | for i, ((fname, _, cid), gcl, pl) in enumerate(zip(sorted(dataset.train), cluster_labels, proxy_labels)): 110 | if gcl != -1 and pl != -1: 111 | 
pseudo_labeled_dataset.append((fname, gcl, cid, i, pl)) 112 | 113 | 114 | # Cluster-proxy mappings 115 | proxy_labels = torch.from_numpy(proxy_labels).long() 116 | cluster_labels = torch.from_numpy(cluster_labels).long() 117 | cluster2proxy = {} # global cluster label -> proxy 118 | proxy2cluster = {} # proxy -> global cluster label 119 | cam2proxy = {} # cam -> proxy 120 | for p in range(0, int(proxy_labels.max() + 1)): 121 | proxy2cluster[p] = torch.unique(cluster_labels[proxy_labels == p]) 122 | for c in range(0, int(cluster_labels.max() + 1)): 123 | cluster2proxy[c] = torch.unique(proxy_labels[cluster_labels == c]) 124 | for cc in range(0, int(all_img_cams.max() + 1)): 125 | cam2proxy[cc] = torch.unique(proxy_labels[all_img_cams == cc]) 126 | cam2proxy[cc] = cam2proxy[cc][cam2proxy[cc] != -1] # remove outliers 127 | 128 | # Set memory attributes 129 | memory.all_proxy_labels = proxy_labels # proxy label of all samples 130 | memory.proxy2cluster = proxy2cluster 131 | memory.cluster2proxy = cluster2proxy 132 | 133 | # Stack into a single memory 134 | proxy_memory = [get_centers(features.numpy(), proxy_labels.numpy()).cuda()] + \ 135 | [get_centers(f.numpy(), proxy_labels.numpy()).cuda() for f in part_feats] 136 | memory.proxy_memory = torch.stack(proxy_memory, dim=0) # (n_part, n_proxy, c) 137 | 138 | 139 | # camera-proxy mapping 140 | memory.unique_cams = torch.unique(torch.from_numpy(all_img_cams)) 141 | memory.cam2proxy = cam2proxy 142 | 143 | # Get a train loader 144 | train_loader = get_train_loader(cfg, dataset, cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH, 145 | cfg.TRAIN.BATCHSIZE, cfg.TRAIN.NUM_WORKERS, 146 | cfg.SAMPLER.NUM_INSTANCES, iters, 147 | trainset=pseudo_labeled_dataset) 148 | 149 | 150 | 151 | # Train one epoch 152 | curr_lr = lr_scheduler._get_lr(epoch+1)[0] if cfg.OPTIM.SCHEDULER_TYPE == 'cosine' else lr_scheduler.get_lr()[0] 153 | print('=> Current Lr: {:.2e}'.format(curr_lr)) 154 | train_loader.new_epoch() 155 | trainer.train(epoch+1, train_loader, optimizer, print_freq=cfg.LOG.PRINT_FREQ, train_iters=len(train_loader), fp16=cfg.TRAIN.FP16) 156 | 157 | # Update scheduler 158 | if cfg.OPTIM.SCHEDULER_TYPE == 'cosine': 159 | lr_scheduler.step(epoch+1) 160 | else: 161 | lr_scheduler.step() 162 | 163 | # Save checkpoint 164 | if (epoch+1) % cfg.LOG.CHECKPOINT.SAVE_INTERVAL == 0: 165 | save_checkpoint(model, optimizer, lr_scheduler, ckpt_save_dir, epoch+1) 166 | print('=> Checkpoint is saved.') 167 | 168 | # Evaluation 169 | if ((epoch+1) % cfg.TEST.EVAL_STEP == 0 or (epoch == cfg.TRAIN.EPOCHS - 1)): 170 | print('=> Epoch {} test: '.format(epoch+1)) 171 | cmc, mAP = evaluator.evaluate_vit(test_loader, dataset.query, dataset.gallery, cmc_flag=True, rerank=cfg.TEST.RE_RANK) 172 | 173 | torch.cuda.empty_cache() 174 | print('=> CUDA cache is released.') 175 | print('') 176 | 177 | 178 | end_time = time.monotonic() 179 | dtime = timedelta(seconds=end_time - start_time) 180 | print('=> Task finished: {}'.format(task_name)) 181 | print('Total running time: {}'.format(dtime)) 182 | 183 | # Save benchmark 184 | if cfg.LOG.SAVE_BENCHMARK: 185 | save_benchmark(cfg, mAP, cmc, task_name, dtime) 186 | 187 | if __name__ == '__main__': 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument('--conf', type=str, default='', help='Config file path.') 190 | parser.add_argument('opts', help='Modify config options using CMD.', default=None, nargs=argparse.REMAINDER) 191 | args = parser.parse_args() 192 | 193 | # Load config using yacs 194 | cfg = get_cfg_defaults() 195 | if args.conf != '': 
196 | cfg.merge_from_file(args.conf) 197 | cfg.merge_from_list(args.opts) 198 | cfg.freeze() 199 | 200 | # Init env. 201 | if cfg.SEED is not None: 202 | random.seed(cfg.SEED) 203 | np.random.seed(cfg.SEED) 204 | torch.manual_seed(cfg.SEED) 205 | torch.backends.cudnn.deterministic = True 206 | torch.backends.cudnn.benchmark = True 207 | 208 | # Run 209 | main(cfg) 210 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # TMGF full model 2 | # market 3 | python train.py --conf configs/TMGF_full.yml TASK_NAME TMGF_training_market DATASET.ROOT_DIR ./datasets DATASET.NAME Market1501 MODEL.SIE_CAMERA 6 MEMORY_BANK.POS_K 3 4 | # duke 5 | # python train.py --conf configs/TMGF_full.yml TASK_NAME TMGF_training_duke DATASET.ROOT_DIR ./datasets DATASET.NAME DukeMTMC-reID MODEL.SIE_CAMERA 8 MEMORY_BANK.POS_K 2 6 | # msmt 7 | # python train.py --conf configs/TMGF_full.yml TASK_NAME TMGF_training_msmt DATASET.ROOT_DIR ./datasets DATASET.NAME MSMT17 MODEL.SIE_CAMERA 15 MEMORY_BANK.POS_K 3 --------------------------------------------------------------------------------
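
The trailing `KEY VALUE` pairs passed to `train.py` in `train.sh` above are consumed by yacs: `cfg.merge_from_file(args.conf)` first applies the YAML config, then `cfg.merge_from_list(args.opts)` overrides individual keys from the command line. Below is a minimal, self-contained sketch of that override pattern; it is not part of the repository, and the keys are illustrative stand-ins for the real defaults defined in `configs/default.py`:

```python
from yacs.config import CfgNode as CN

# Illustrative defaults; the actual project defines its defaults in configs/default.py.
cfg = CN()
cfg.TASK_NAME = 'debug'
cfg.DATASET = CN()
cfg.DATASET.ROOT_DIR = './datasets'
cfg.DATASET.NAME = 'Market1501'
cfg.MODEL = CN()
cfg.MODEL.SIE_CAMERA = 6

# Roughly what `python train.py --conf ... TASK_NAME TMGF_training_msmt DATASET.NAME MSMT17 MODEL.SIE_CAMERA 15`
# does after argparse collects the trailing arguments into args.opts.
opts = ['TASK_NAME', 'TMGF_training_msmt', 'DATASET.NAME', 'MSMT17', 'MODEL.SIE_CAMERA', 15]
cfg.merge_from_list(opts)  # command-line overrides win over the defaults / YAML values
cfg.freeze()               # make the config read-only for the rest of the run

print(cfg.DATASET.NAME, cfg.MODEL.SIE_CAMERA)  # -> MSMT17 15
```

Note that yacs decodes string values coming from the command line (e.g. `'15'` becomes the integer `15`) and checks each override against the type of the corresponding default, so numeric overrides such as `MODEL.SIE_CAMERA 15` in `train.sh` work as written.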