├── .gitignore ├── README.md ├── config ├── PAT.yml ├── __init__.py ├── defaults.py └── vit.yml ├── data ├── __init__.py ├── build_DG_dataloader.py ├── common.py ├── data_utils.py ├── datasets │ ├── AirportALERT.py │ ├── DG_cuhk02.py │ ├── DG_cuhk03_detected.py │ ├── DG_cuhk03_labeled.py │ ├── DG_cuhk_sysu.py │ ├── DG_dukemtmcreid.py │ ├── DG_grid.py │ ├── DG_iLIDS.py │ ├── DG_market1501.py │ ├── DG_prid.py │ ├── DG_viper.py │ ├── __init__.py │ ├── bases.py │ ├── caviara.py │ ├── cuhk03.py │ ├── dukemtmcreid.py │ ├── grid.py │ ├── iLIDS.py │ ├── lpw.py │ ├── market1501.py │ ├── msmt17.py │ ├── pes3d.py │ ├── pku.py │ ├── prai.py │ ├── prid.py │ ├── randperson.py │ ├── sensereid.py │ ├── shinpuhkan.py │ ├── sysu_mm.py │ ├── thermalworld.py │ ├── vehicleid.py │ ├── veri.py │ ├── veri_keypoint.py │ ├── veriwild.py │ └── viper.py ├── samplers │ ├── __init__.py │ ├── data_sampler.py │ └── triplet_sampler.py └── transforms │ ├── __init__.py │ ├── autoaugment.py │ ├── build.py │ ├── functional.py │ └── transforms.py ├── enviroments.sh ├── loss ├── __init__.py ├── arcface.py ├── build_loss.py ├── ce_labelSmooth.py ├── center_loss.py ├── make_loss.py ├── metric_learning.py ├── myloss.py ├── smooth.py ├── softmax_loss.py └── triplet_loss.py ├── model ├── __init__.py ├── backbones │ ├── IBN.py │ ├── __init__.py │ ├── resnet.py │ ├── resnet_ibn.py │ └── vit_pytorch.py └── make_model.py ├── processor ├── __init__.py ├── ori_vit_processor_with_amp.py └── part_attention_vit_processor.py ├── run.sh ├── solver ├── __init__.py ├── cosine_lr.py ├── lr_scheduler.py ├── make_optimizer.py ├── scheduler.py └── scheduler_factory.py ├── test.py ├── train.py ├── utils ├── __init__.py ├── comm.py ├── file_io.py ├── iotools.py ├── logger.py ├── meter.py ├── metrics.py ├── registry.py └── reranking.py └── visualization ├── config_vis ├── __init__.py └── vit_b.py ├── good_samples_market_query.json ├── readme.md ├── test.jpg ├── vit_explain.py └── vit_rollout ├── vit_example.py ├── vit_grad_rollout.py └── vit_rollout.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | tb_log 3 | .vscode 4 | *.yml -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Part-Aware-Transformer 2 | 3 | ## 🔥 News 4 | We updated the visualization codes. 5 | 6 | See instructions in /visualization/readme.md. 7 | 8 | ## Welcome 9 | 10 | This is the official repo for "Part-Aware Transformer for Generalizable Person Re-identification" [ICCV 2023] 11 | 12 |
13 | 14 | 15 | 16 | 17 | ## Abstract 18 | Domain generalization person re-identification (DG-ReID) aims to train a model on source domains and generalize well on unseen domains. 19 | Vision Transformer usually yields better generalization ability than common CNN networks under distribution shifts. 20 | However, Transformer-based ReID models inevitably over-fit to domain-specific biases due to the supervised learning strategy on the source domain. 21 | We observe that while the global images of different IDs should have different features, their similar local parts (e.g., black backpack) are not bounded by this constraint. 22 | Motivated by this, we propose a pure Transformer model (termed Part-aware Transformer) for DG-ReID by designing a proxy task, named Cross-ID Similarity Learning (CSL), to mine local visual information shared by different IDs. This proxy task allows the model to learn generic features because it only cares about the visual similarity of the parts regardless of the ID labels, thus alleviating the side effect of domain-specific biases. 23 | Based on the local similarity obtained in CSL, a Part-guided Self-Distillation (PSD) is proposed to further improve the generalization of global features. 24 | Our method achieves state-of-the-art performance under most DG ReID settings. 25 | 26 | ## Framework 27 |
28 | 29 | ## Visualizations 30 |
31 |
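For readers who want a concrete feel for the two objectives described in the abstract, here is a minimal, illustrative-only PyTorch sketch. It is **not** the implementation used in this repo (see `loss/` and `model/make_model.py` for that); the names `csl_psd_sketch`, `part_feats` and `tau` are invented purely for illustration.

```
import torch
import torch.nn.functional as F

def csl_psd_sketch(global_feat, part_feats, tau=0.07):
    # global_feat: (B, D) global/CLS embeddings; part_feats: (B, P, D) local part embeddings
    B, P, D = part_feats.shape
    parts = F.normalize(part_feats.reshape(B * P, D), dim=1)

    # CSL idea: pull each part towards its most similar part from a *different*
    # image, regardless of ID labels (nearest cross-image part as the positive).
    sim = parts @ parts.t() / tau
    same_img = torch.arange(B, device=parts.device).repeat_interleave(P)
    mask = same_img[:, None].eq(same_img[None, :])       # True for parts of the same image
    sim_other = sim.masked_fill(mask, float('-inf'))
    pos_idx = sim_other.argmax(dim=1)
    csl_loss = F.cross_entropy(sim_other, pos_idx)

    # PSD idea: distill the (detached) part-level signal into the global feature.
    teacher = F.normalize(part_feats.mean(dim=1), dim=1).detach()
    student = F.normalize(global_feat, dim=1)
    psd_loss = (1.0 - (student * teacher).sum(dim=1)).mean()
    return csl_loss, psd_loss
```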
32 | 33 | # Instructions 34 | 35 | Here are some instructions to run our code. 36 | Our code is based on [TransReID](https://github.com/damo-cv/TransReID), thanks for their excellent work. 37 | 38 | ## 1. Clone this repo 39 | ``` 40 | git clone https://github.com/liyuke65535/Part-Aware-Transformer.git 41 | ``` 42 | 43 | ## 2. Prepare your environment 44 | ``` 45 | conda create -n pat python==3.10 46 | conda activate pat 47 | bash enviroments.sh 48 | ``` 49 | 50 | ## 3. Prepare pretrained model (ViT-B) and datasets 51 | You can download it from huggingface, rwightman, or else where. 52 | For example, pretrained model is avaliable at [ViT-B](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth). 53 | 54 | As for datasets, follow the instructions in [MetaBIN](https://github.com/bismex/MetaBIN#8-datasets). 55 | 56 | ## 4. Modify the config file 57 | ``` 58 | # modify the model path and dataset paths of the config file 59 | vim ./config/PAT.yml 60 | ``` 61 | 62 | ## 5. Train a model 63 | ``` 64 | bash run.sh 65 | ``` 66 | 67 | ## 6. Evaluation only 68 | ``` 69 | # modify the trained path in config 70 | vim ./config/PAT.yml 71 | 72 | # evaluation 73 | python test.py --config ./config/PAT.yml 74 | ``` 75 | ## Citation 76 | If you find this repo useful for your research, you're welcome to cite our paper. 77 | ``` 78 | @inproceedings{ni2023part, 79 | title={Part-Aware Transformer for Generalizable Person Re-identification}, 80 | author={Ni, Hao and Li, Yuke and Gao, Lianli and Shen, Heng Tao and Song, Jingkuan}, 81 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 82 | pages={11280--11289}, 83 | year={2023} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /config/PAT.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: "../../.cache/torch/hub/checkpoints" # root of pretrain path 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'on' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'part_attention_vit' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256,128] 15 | SIZE_TEST: [256,128] 16 | REA: 17 | ENABLED: False 18 | PIXEL_MEAN: [0.5, 0.5, 0.5] 19 | PIXEL_STD: [0.5, 0.5, 0.5] 20 | LGT: # Local Grayscale Transfomation 21 | DO_LGT: False 22 | PROB: 0.5 23 | 24 | DATASETS: 25 | TRAIN: ('Market1501',) 26 | TEST: ("DukeMTMC",) 27 | ROOT_DIR: ('../../data') # root of datasets 28 | 29 | DATALOADER: 30 | SAMPLER: 'softmax_triplet' 31 | NUM_INSTANCE: 4 32 | NUM_WORKERS: 8 33 | 34 | SOLVER: 35 | OPTIMIZER_NAME: 'SGD' 36 | MAX_EPOCHS: 60 37 | BASE_LR: 0.001 # 0.0004 for msmt 38 | IMS_PER_BATCH: 64 39 | WARMUP_METHOD: 'linear' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 5 42 | LOG_PERIOD: 60 43 | EVAL_PERIOD: 1 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | SEED: 1234 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 128 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: True 56 | 57 | LOG_ROOT: '../../data/exp/' # root of log file 58 | TB_LOG_ROOT: './tb_log/' 59 | LOG_NAME: 'PAT/market/vit_base' 60 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 
| 3 | from .defaults import _C as cfg 4 | from .defaults import _C as cfg_test 5 | -------------------------------------------------------------------------------- /config/vit.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: "../../.cache/torch/hub/checkpoints" # root of pretrain path 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'on' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'vit' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256,128] 15 | SIZE_TEST: [256,128] 16 | REA: 17 | ENABLED: False 18 | PIXEL_MEAN: [0.5, 0.5, 0.5] 19 | PIXEL_STD: [0.5, 0.5, 0.5] 20 | LGT: # Local Grayscale Transfomation 21 | DO_LGT: False 22 | PROB: 0.5 23 | 24 | DATASETS: 25 | TRAIN: ('Market1501',) 26 | TEST: ("DukeMTMC",) 27 | ROOT_DIR: ('../../data') # root of datasets 28 | 29 | DATALOADER: 30 | SAMPLER: 'softmax_triplet' 31 | NUM_INSTANCE: 4 32 | NUM_WORKERS: 8 33 | 34 | SOLVER: 35 | OPTIMIZER_NAME: 'SGD' 36 | MAX_EPOCHS: 60 37 | BASE_LR: 0.008 # 0.0004 for msmt 38 | IMS_PER_BATCH: 64 39 | WARMUP_METHOD: 'linear' 40 | LARGE_FC_LR: False 41 | CHECKPOINT_PERIOD: 5 42 | LOG_PERIOD: 60 43 | EVAL_PERIOD: 5 44 | WEIGHT_DECAY: 1e-4 45 | WEIGHT_DECAY_BIAS: 1e-4 46 | BIAS_LR_FACTOR: 2 47 | SEED: 1234 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 128 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: True 56 | 57 | LOG_ROOT: '../../data/exp/' # root of log file 58 | TB_LOG_ROOT: './tb_log/' 59 | LOG_NAME: 'vit/market/vit_base' 60 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_DG_dataloader import build_reid_train_loader, build_reid_test_loader 2 | -------------------------------------------------------------------------------- /data/build_DG_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | import collections.abc as container_abcs 5 | 6 | # from torch._six import container_abcs, string_classes, int_classes 7 | int_classes = int 8 | string_classes = str 9 | from torch.utils.data import DataLoader 10 | from utils import comm 11 | import random 12 | 13 | from . 
import samplers 14 | from .common import CommDataset 15 | from .datasets import DATASET_REGISTRY 16 | from .transforms import build_transforms 17 | 18 | _root = os.getenv("REID_DATASETS", "../../data") 19 | 20 | 21 | def build_reid_train_loader(cfg): 22 | gettrace = getattr(sys, 'gettrace', None) 23 | if gettrace(): 24 | print('*'*100) 25 | print('Hmm, Big Debugger is watching me') 26 | print('*'*100) 27 | num_workers = 0 28 | else: 29 | num_workers = cfg.DATALOADER.NUM_WORKERS 30 | 31 | train_transforms = build_transforms(cfg, is_train=True, is_fake=False) 32 | train_items = list() 33 | domain_idx = 0 34 | camera_all = list() 35 | 36 | # load datasets 37 | _root = cfg.DATASETS.ROOT_DIR 38 | for d in cfg.DATASETS.TRAIN: 39 | if d == 'CUHK03_NP': 40 | dataset = DATASET_REGISTRY.get('CUHK03')(root=_root, cuhk03_labeled=False) 41 | else: 42 | dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL) 43 | if comm.is_main_process(): 44 | dataset.show_train() 45 | if len(dataset.train[0]) < 4: 46 | for i, x in enumerate(dataset.train): 47 | add_info = {} # dictionary 48 | 49 | if cfg.DATALOADER.CAMERA_TO_DOMAIN: 50 | add_info['domains'] = dataset.train[i][2] 51 | camera_all.append(dataset.train[i][2]) 52 | else: 53 | add_info['domains'] = int(domain_idx) 54 | dataset.train[i] = list(dataset.train[i]) 55 | dataset.train[i].append(add_info) 56 | dataset.train[i] = tuple(dataset.train[i]) 57 | domain_idx += 1 58 | train_items.extend(dataset.train) 59 | 60 | train_set = CommDataset(train_items, train_transforms, relabel=True) 61 | 62 | train_loader = make_sampler( 63 | train_set=train_set, 64 | num_batch=cfg.SOLVER.IMS_PER_BATCH, 65 | num_instance=cfg.DATALOADER.NUM_INSTANCE, 66 | num_workers=num_workers, 67 | mini_batch_size=cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(), 68 | drop_last=cfg.DATALOADER.DROP_LAST, 69 | flag1=cfg.DATALOADER.NAIVE_WAY, 70 | flag2=cfg.DATALOADER.DELETE_REM, 71 | cfg = cfg) 72 | 73 | return train_loader 74 | 75 | 76 | def build_reid_test_loader(cfg, dataset_name, opt=None, flag_test=True, shuffle=False, only_gallery=False, only_query=False, eval_time=False): 77 | test_transforms = build_transforms(cfg, is_train=False) 78 | _root = cfg.DATASETS.ROOT_DIR 79 | if opt is None: 80 | dataset = DATASET_REGISTRY.get(dataset_name)(root=_root) 81 | if comm.is_main_process(): 82 | if flag_test: 83 | dataset.show_test() 84 | else: 85 | dataset.show_train() 86 | else: 87 | dataset = DATASET_REGISTRY.get(dataset_name)(root=[_root, opt]) 88 | if flag_test: 89 | if only_gallery: 90 | test_items = dataset.gallery 91 | elif only_query: 92 | test_set = CommDataset([random.choice(dataset.query)], test_transforms, relabel=False) 93 | return test_set 94 | else: 95 | test_items = dataset.query + dataset.gallery 96 | if shuffle: # only for visualization 97 | random.shuffle(test_items) 98 | else: 99 | test_items = dataset.train 100 | 101 | test_set = CommDataset(test_items, test_transforms, relabel=False) 102 | 103 | batch_size = cfg.TEST.IMS_PER_BATCH 104 | data_sampler = samplers.InferenceSampler(len(test_set)) 105 | batch_sampler = torch.utils.data.BatchSampler(data_sampler, batch_size, False) 106 | 107 | gettrace = getattr(sys, 'gettrace', None) 108 | if gettrace(): 109 | num_workers = 0 110 | else: 111 | num_workers = cfg.DATALOADER.NUM_WORKERS 112 | 113 | test_loader = DataLoader( 114 | test_set, 115 | batch_sampler=batch_sampler, 116 | num_workers=num_workers, # save some memory 117 | collate_fn=fast_batch_collator) 118 | return test_loader, len(dataset.query) 
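# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): both
# builders are driven by the config object exposed in config/__init__.py.
# This assumes config.defaults._C is a yacs-style CfgNode providing the
# fields read above (DATASETS, DATALOADER, SOLVER, TEST).
#
#   from config import cfg
#   cfg.merge_from_file('./config/PAT.yml')
#   train_loader = build_reid_train_loader(cfg)
#   test_loader, num_query = build_reid_test_loader(cfg, cfg.DATASETS.TEST[0])
#   batch = next(iter(train_loader))
#   imgs, pids = batch['images'], batch['targets']  # keys set by CommDataset in data/common.py
# ---------------------------------------------------------------------------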
119 | 120 | 121 | def trivial_batch_collator(batch): 122 | """ 123 | A batch collator that does nothing. 124 | """ 125 | return batch 126 | 127 | 128 | def fast_batch_collator(batched_inputs): 129 | """ 130 | A simple batch collator for most common reid tasks 131 | """ 132 | elem = batched_inputs[0] 133 | if isinstance(elem, torch.Tensor): 134 | out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype) 135 | for i, tensor in enumerate(batched_inputs): 136 | out[i] += tensor 137 | return out 138 | 139 | elif isinstance(elem, container_abcs.Mapping): 140 | return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem} 141 | 142 | elif isinstance(elem, float): 143 | return torch.tensor(batched_inputs, dtype=torch.float64) 144 | elif isinstance(elem, int_classes): 145 | return torch.tensor(batched_inputs) 146 | elif isinstance(elem, string_classes): 147 | return batched_inputs 148 | elif isinstance(elem, list): 149 | out_g = [] 150 | out_pt1 = [] 151 | out_pt2 = [] 152 | out_pt3 = [] 153 | # out = torch.stack(elem, dim=0) 154 | for i, tensor_list in enumerate(batched_inputs): 155 | out_g.append(tensor_list[0]) 156 | out_pt1.append(tensor_list[1]) 157 | out_pt2.append(tensor_list[2]) 158 | out_pt3.append(tensor_list[3]) 159 | out = torch.stack(out_g, dim=0) 160 | out_pt1 = torch.stack(out_pt1, dim=0) 161 | out_pt2 = torch.stack(out_pt2, dim=0) 162 | out_pt3 = torch.stack(out_pt3, dim=0) 163 | return out, out_pt1, out_pt2, out_pt3 164 | 165 | 166 | def make_sampler(train_set, num_batch, num_instance, num_workers, 167 | mini_batch_size, drop_last=True, flag1=True, flag2=True, seed=None, cfg=None): 168 | 169 | if flag1: 170 | data_sampler = samplers.RandomIdentitySampler(train_set.img_items, 171 | mini_batch_size, num_instance) 172 | else: 173 | data_sampler = samplers.DomainSuffleSampler(train_set.img_items, 174 | num_batch, num_instance, flag2, seed, cfg) 175 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, drop_last) 176 | train_loader = torch.utils.data.DataLoader( 177 | train_set, 178 | num_workers=num_workers, 179 | batch_sampler=batch_sampler, 180 | collate_fn=fast_batch_collator, 181 | ) 182 | return train_loader -------------------------------------------------------------------------------- /data/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | from .data_utils import read_image 5 | 6 | 7 | class CommDataset(Dataset): 8 | """Image Person ReID Dataset""" 9 | 10 | def __init__(self, img_items, transform=None, relabel=True): 11 | self.img_items = img_items 12 | self.transform = transform 13 | self.relabel = relabel 14 | 15 | self.pid_dict = {} 16 | if self.relabel: 17 | pids = list() 18 | for i, item in enumerate(img_items): 19 | if item[1] in pids: continue 20 | pids.append(item[1]) 21 | self.pids = pids 22 | self.pid_dict = dict([(p, i) for i, p in enumerate(self.pids)]) 23 | 24 | def __len__(self): 25 | return len(self.img_items) 26 | 27 | def __getitem__(self, index): 28 | if len(self.img_items[index]) > 3: 29 | img_path, pid, camid, others = self.img_items[index] 30 | else: 31 | img_path, pid, camid = self.img_items[index] 32 | others = '' 33 | img = read_image(img_path) 34 | if self.transform is not None: img = self.transform(img) 35 | if self.relabel: pid = self.pid_dict[pid] 36 | return { 37 | "images": img, 38 | "targets": pid, 39 | "camid": camid, 40 | "img_path": img_path, 41 | "others": others 42 | 
} 43 | 44 | @property 45 | def num_classes(self): 46 | return len(self.pids) 47 | -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageOps 3 | 4 | from utils.file_io import PathManager 5 | 6 | 7 | def read_image(file_name, format=None): 8 | """ 9 | Read an image into the given format. 10 | Will apply rotation and flipping if the image has such exif information. 11 | Args: 12 | file_name (str): image file path 13 | format (str): one of the supported image modes in PIL, or "BGR" 14 | Returns: 15 | image (np.ndarray): an HWC image 16 | """ 17 | with PathManager.open(file_name, "rb") as f: 18 | image = Image.open(f) 19 | 20 | # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973 21 | try: 22 | image = ImageOps.exif_transpose(image) 23 | except Exception: 24 | pass 25 | 26 | if format is not None: 27 | # PIL only supports RGB, so convert to RGB and flip channels over below 28 | conversion_format = format 29 | if format == "BGR": 30 | conversion_format = "RGB" 31 | image = image.convert(conversion_format) 32 | image = np.asarray(image) 33 | if format == "BGR": 34 | # flip channels if needed 35 | image = image[:, :, ::-1] 36 | # PIL squeezes out the channel dimension for "L", so make it HWC 37 | if format == "L": 38 | image = np.expand_dims(image, -1) 39 | image = Image.fromarray(image) 40 | return image 41 | -------------------------------------------------------------------------------- /data/datasets/AirportALERT.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from . import DATASET_REGISTRY 4 | from .bases import ImageDataset 5 | 6 | __all__ = ['AirportALERT', ] 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class AirportALERT(ImageDataset): 11 | dataset_dir = "AirportALERT" 12 | dataset_name = "airport" 13 | 14 | def __init__(self, root='datasets', **kwargs): 15 | self.root = root 16 | self.train_path = os.path.join(self.root, self.dataset_dir) 17 | self.train_file = os.path.join(self.root, self.dataset_dir, 'filepath.txt') 18 | 19 | required_files = [self.train_file, self.train_path] 20 | self.check_before_run(required_files) 21 | 22 | train = self.process_train(self.train_path, self.train_file) 23 | 24 | super().__init__(train, [], [], **kwargs) 25 | 26 | def process_train(self, dir_path, train_file): 27 | data = [] 28 | with open(train_file, "r") as f: 29 | img_paths = [line.strip('\n') for line in f.readlines()] 30 | 31 | for path in img_paths: 32 | split_path = path.split('\\') 33 | img_path = '/'.join(split_path) 34 | camid = self.dataset_name + "_" + split_path[0] 35 | pid = self.dataset_name + "_" + split_path[1] 36 | img_path = os.path.join(dir_path, img_path) 37 | if 11001 <= int(split_path[1]) <= 401999: 38 | data.append([img_path, pid, camid]) 39 | 40 | return data 41 | -------------------------------------------------------------------------------- /data/datasets/DG_cuhk02.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from . 
import DATASET_REGISTRY 5 | from .bases import ImageDataset 6 | 7 | __all__ = ['DG_CUHK02', ] 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class DG_CUHK02(ImageDataset): 12 | dataset_dir = "cuhk02" 13 | dataset_name = "cuhk02" 14 | 15 | def __init__(self, root='datasets', **kwargs): 16 | self.root = root 17 | self.train_path = os.path.join(self.root, self.dataset_dir) 18 | 19 | required_files = [self.train_path] 20 | self.check_before_run(required_files) 21 | 22 | train = self.process_train(self.train_path) 23 | 24 | super().__init__(train, [], [], **kwargs) 25 | 26 | def process_train(self, train_path): 27 | 28 | 29 | cam_split = True 30 | 31 | data = [] 32 | file_path = os.listdir(train_path) 33 | for pid_dir in file_path: 34 | img_file = os.path.join(train_path, pid_dir) 35 | cam1_folder = os.path.join(img_file, 'cam1') 36 | cam = '1' 37 | 38 | # if os.path.join(img_file, 'cam1'): 39 | img_paths = glob(os.path.join(cam1_folder, "*.png")) 40 | for img_path in img_paths: 41 | split_path = img_path.split('/')[-1].split('_') 42 | pid = self.dataset_name + "_" + pid_dir + "_" + split_path[0] 43 | camid = int(cam) 44 | # if cam_split: 45 | # camid = self.dataset_name + "_" + pid_dir + "_" + cam 46 | # else: 47 | # camid = self.dataset_name + "_" + cam 48 | data.append([img_path, pid, camid]) 49 | 50 | cam2_folder = os.path.join(img_file, 'cam2') 51 | cam = '2' 52 | 53 | img_paths = glob(os.path.join(cam2_folder, "*.png")) 54 | for img_path in img_paths: 55 | split_path = img_path.split('/')[-1].split('_') 56 | pid = self.dataset_name + "_" + pid_dir + "_" + split_path[0] 57 | camid = int(cam) 58 | # if cam_split: 59 | # camid = self.dataset_name + "_" + pid_dir + "_" + cam 60 | # else: 61 | # camid = self.dataset_name + "_" + cam 62 | data.append([img_path, pid, camid]) 63 | return data 64 | -------------------------------------------------------------------------------- /data/datasets/DG_cuhk_sysu.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from . import DATASET_REGISTRY 5 | from .bases import ImageDataset 6 | 7 | __all__ = ['DG_CUHK_SYSU', ] 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class DG_CUHK_SYSU(ImageDataset): 12 | dataset_dir = "CUHK-SYSU" 13 | dataset_name = "CUHK-SYSU" 14 | 15 | def __init__(self, root='datasets', **kwargs): 16 | self.root = root 17 | self.train_path = os.path.join(self.root, self.dataset_dir, 'cropped_image') 18 | 19 | required_files = [self.train_path] 20 | self.check_before_run(required_files) 21 | 22 | train = self.process_train(self.train_path) 23 | 24 | super().__init__(train, [], [], **kwargs) 25 | 26 | def process_train(self, train_path): 27 | data = [] 28 | img_paths = glob(os.path.join(train_path, "*.png")) 29 | for img_path in img_paths: 30 | split_path = img_path.split('/')[-1].split('_') # p00001_n01_s00001_hard0.png 31 | pid = self.dataset_name + "_" + split_path[0][1:] 32 | camid = int(split_path[2][1:]) 33 | # camid = self.dataset_name + "_" + split_path[2][1:] 34 | data.append([img_path, pid, camid]) 35 | return data 36 | -------------------------------------------------------------------------------- /data/datasets/DG_dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import re 4 | 5 | from .bases import ImageDataset 6 | from ..datasets import DATASET_REGISTRY 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class DG_DukeMTMC(ImageDataset): 11 | """DukeMTMC-reID. 
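DG variant: unlike the plain ``DukeMTMC`` dataset, the query and gallery images are merged into the training set (``train = train + query + gallery`` in ``__init__``), so this class exposes no test split of its own.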
12 | 13 | Reference: 14 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 15 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 16 | 17 | URL: ``_ 18 | 19 | Dataset statistics: 20 | - identities: 1404 (train + query). 21 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 22 | - cameras: 8. 23 | """ 24 | dataset_dir = 'DukeMTMC-reID' 25 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 26 | dataset_name = "dukemtmc" 27 | 28 | def __init__(self, root='datasets', **kwargs): 29 | # self.root = osp.abspath(osp.expanduser(root)) 30 | self.root = root 31 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 32 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'query') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 35 | 36 | required_files = [ 37 | self.dataset_dir, 38 | self.train_dir, 39 | self.query_dir, 40 | self.gallery_dir, 41 | ] 42 | self.check_before_run(required_files) 43 | 44 | train = self.process_dir(self.train_dir) 45 | query = self.process_dir(self.query_dir, is_train=True) 46 | gallery = self.process_dir(self.gallery_dir, is_train=True) 47 | train = train + query + gallery 48 | 49 | super(DG_DukeMTMC, self).__init__(train, [], [], **kwargs) 50 | 51 | def process_dir(self, dir_path, is_train=True): 52 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 53 | pattern = re.compile(r'([-\d]+)_c(\d)') 54 | 55 | data = [] 56 | for img_path in img_paths: 57 | pid, camid = map(int, pattern.search(img_path).groups()) 58 | assert 1 <= camid <= 8 59 | camid -= 1 # index starts from 0 60 | if is_train: 61 | pid = self.dataset_name + "_" + str(pid) 62 | data.append((img_path, pid, camid)) 63 | 64 | return data 65 | -------------------------------------------------------------------------------- /data/datasets/DG_grid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | 5 | from . 
import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | import pdb 8 | 9 | __all__ = ['DG_GRID',] 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class DG_GRID(ImageDataset): 14 | dataset_dir = "GRID" 15 | dataset_name = 'grid' 16 | 17 | def __init__(self, root='datasets', split_id = 0, **kwargs): 18 | 19 | if isinstance(root, list): 20 | split_id = root[1] 21 | self.root = root[0] 22 | else: 23 | self.root = root 24 | split_id = 0 25 | self.dataset_dir = os.path.join(self.root, self.dataset_dir) 26 | 27 | self.probe_path = os.path.join( 28 | self.dataset_dir, 'probe' 29 | ) 30 | self.gallery_path = os.path.join( 31 | self.dataset_dir, 'gallery' 32 | ) 33 | self.split_mat_path = os.path.join( 34 | self.dataset_dir, 'features_and_partitions.mat' 35 | ) 36 | self.split_path = os.path.join(self.dataset_dir, 'splits.json') 37 | 38 | required_files = [ 39 | self.dataset_dir, self.probe_path, self.gallery_path, 40 | self.split_mat_path 41 | ] 42 | self.check_before_run(required_files) 43 | 44 | self.prepare_split() 45 | splits = self.read_json(self.split_path) 46 | if split_id >= len(splits): 47 | raise ValueError( 48 | 'split_id exceeds range, received {}, ' 49 | 'but expected between 0 and {}'.format( 50 | split_id, 51 | len(splits) - 1 52 | ) 53 | ) 54 | split = splits[split_id] 55 | 56 | train = split['train'] 57 | query = split['query'] 58 | gallery = split['gallery'] 59 | 60 | train = [tuple(item) for item in train] 61 | query = [tuple(item) for item in query] 62 | gallery = [tuple(item) for item in gallery] 63 | 64 | super(DG_GRID, self).__init__(train, query, gallery, **kwargs) 65 | 66 | def prepare_split(self): 67 | if not os.path.exists(self.split_path): 68 | print('Creating 10 random splits') 69 | split_mat = loadmat(self.split_mat_path) 70 | trainIdxAll = split_mat['trainIdxAll'][0] # length = 10 71 | probe_img_paths = sorted( 72 | glob(os.path.join(self.probe_path, '*.jpeg')) 73 | ) 74 | gallery_img_paths = sorted( 75 | glob(os.path.join(self.gallery_path, '*.jpeg')) 76 | ) 77 | 78 | splits = [] 79 | for split_idx in range(10): 80 | train_idxs = trainIdxAll[split_idx][0][0][2][0].tolist() 81 | assert len(train_idxs) == 125 82 | idx2label = { 83 | idx: label 84 | for label, idx in enumerate(train_idxs) 85 | } 86 | 87 | train, query, gallery = [], [], [] 88 | 89 | # processing probe folder 90 | for img_path in probe_img_paths: 91 | img_name = os.path.basename(img_path) 92 | img_idx = int(img_name.split('_')[0]) 93 | camid = int( 94 | img_name.split('_')[1] 95 | ) - 1 # index starts from 0 96 | if img_idx in train_idxs: 97 | train.append((img_path, idx2label[img_idx], camid)) 98 | else: 99 | query.append((img_path, img_idx, camid)) 100 | 101 | # process gallery folder 102 | for img_path in gallery_img_paths: 103 | img_name = os.path.basename(img_path) 104 | img_idx = int(img_name.split('_')[0]) 105 | camid = int( 106 | img_name.split('_')[1] 107 | ) - 1 # index starts from 0 108 | if img_idx in train_idxs: 109 | train.append((img_path, idx2label[img_idx], camid)) 110 | else: 111 | gallery.append((img_path, img_idx, camid)) 112 | 113 | split = { 114 | 'train': train, 115 | 'query': query, 116 | 'gallery': gallery, 117 | 'num_train_pids': 125, 118 | 'num_query_pids': 125, 119 | 'num_gallery_pids': 900 120 | } 121 | splits.append(split) 122 | 123 | print('Totally {} splits are created'.format(len(splits))) 124 | self.write_json(splits, self.split_path) 125 | print('Split file saved to {}'.format(self.split_path)) 126 | 127 | 128 | def read_json(self, fpath): 129 | import json 130 
| """Reads json file from a path.""" 131 | with open(fpath, 'r') as f: 132 | obj = json.load(f) 133 | return obj 134 | 135 | 136 | def write_json(self, obj, fpath): 137 | import json 138 | """Writes to a json file.""" 139 | self.mkdir_if_missing(os.path.dirname(fpath)) 140 | with open(fpath, 'w') as f: 141 | json.dump(obj, f, indent=4, separators=(',', ': ')) 142 | 143 | 144 | def mkdir_if_missing(self, dirname): 145 | import errno 146 | """Creates dirname if it is missing.""" 147 | if not os.path.exists(dirname): 148 | try: 149 | os.makedirs(dirname) 150 | except OSError as e: 151 | if e.errno != errno.EEXIST: 152 | raise -------------------------------------------------------------------------------- /data/datasets/DG_iLIDS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import copy 4 | import random 5 | from collections import defaultdict 6 | from . import DATASET_REGISTRY 7 | from .bases import ImageDataset 8 | 9 | __all__ = ['DG_iLIDS', ] 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class DG_iLIDS(ImageDataset): 14 | dataset_dir = "QMUL-iLIDS" 15 | dataset_name = "ilids" 16 | 17 | def __init__(self, root='datasets', split_id = 0, **kwargs): 18 | 19 | if isinstance(root, list): 20 | split_id = root[1] 21 | self.root = root[0] 22 | else: 23 | self.root = root 24 | split_id = 0 25 | self.dataset_dir = os.path.join(self.root, self.dataset_dir) 26 | # self.download_dataset(self.dataset_dir, self.dataset_url) 27 | 28 | self.data_dir = os.path.join(self.dataset_dir, 'images') 29 | self.split_path = os.path.join(self.dataset_dir, 'splits.json') 30 | 31 | required_files = [self.dataset_dir, self.data_dir] 32 | self.check_before_run(required_files) 33 | 34 | self.prepare_split() 35 | splits = self.read_json(self.split_path) 36 | if split_id >= len(splits): 37 | raise ValueError( 38 | 'split_id exceeds range, received {}, but ' 39 | 'expected between 0 and {}'.format(split_id, 40 | len(splits) - 1) 41 | ) 42 | split = splits[split_id] 43 | 44 | train, query, gallery = self.process_split(split) 45 | 46 | super(DG_iLIDS, self).__init__(train, query, gallery, **kwargs) 47 | 48 | def prepare_split(self): 49 | if not os.path.exists(self.split_path): 50 | print('Creating splits ...') 51 | 52 | paths = glob.glob(os.path.join(self.data_dir, '*.jpg')) 53 | img_names = [os.path.basename(path) for path in paths] 54 | num_imgs = len(img_names) 55 | assert num_imgs == 476, 'There should be 476 images, but ' \ 56 | 'got {}, please check the data'.format(num_imgs) 57 | 58 | # store image names 59 | # image naming format: 60 | # the first four digits denote the person ID 61 | # the last four digits denote the sequence index 62 | pid_dict = defaultdict(list) 63 | for img_name in img_names: 64 | pid = int(img_name[:4]) 65 | pid_dict[pid].append(img_name) 66 | pids = list(pid_dict.keys()) 67 | num_pids = len(pids) 68 | assert num_pids == 119, 'There should be 119 identities, ' \ 69 | 'but got {}, please check the data'.format(num_pids) 70 | 71 | num_train_pids = int(num_pids * 0.5) 72 | 73 | splits = [] 74 | for _ in range(10): 75 | # randomly choose num_train_pids train IDs and the rest for test IDs 76 | pids_copy = copy.deepcopy(pids) 77 | random.shuffle(pids_copy) 78 | train_pids = pids_copy[:num_train_pids] 79 | test_pids = pids_copy[num_train_pids:] 80 | 81 | train = [] 82 | query = [] 83 | gallery = [] 84 | 85 | # for train IDs, all images are used in the train set. 
86 | for pid in train_pids: 87 | img_names = pid_dict[pid] 88 | train.extend(img_names) 89 | 90 | # for each test ID, randomly choose two images, one for 91 | # query and the other one for gallery. 92 | for pid in test_pids: 93 | img_names = pid_dict[pid] 94 | samples = random.sample(img_names, 2) 95 | query.append(samples[0]) 96 | gallery.append(samples[1]) 97 | 98 | split = {'train': train, 'query': query, 'gallery': gallery} 99 | splits.append(split) 100 | 101 | print('Totally {} splits are created'.format(len(splits))) 102 | self.write_json(splits, self.split_path) 103 | print('Split file is saved to {}'.format(self.split_path)) 104 | 105 | def get_pid2label(self, img_names): 106 | pid_container = set() 107 | for img_name in img_names: 108 | pid = int(img_name[:4]) 109 | pid_container.add(pid) 110 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 111 | return pid2label 112 | 113 | def parse_img_names(self, img_names, pid2label=None): 114 | data = [] 115 | 116 | for img_name in img_names: 117 | pid = int(img_name[:4]) 118 | if pid2label is not None: 119 | pid = pid2label[pid] 120 | camid = int(img_name[4:7]) - 1 # 0-based 121 | img_path = os.path.join(self.data_dir, img_name) 122 | data.append((img_path, pid, camid)) 123 | 124 | return data 125 | 126 | def process_split(self, split): 127 | train_pid2label = self.get_pid2label(split['train']) 128 | train = self.parse_img_names(split['train'], train_pid2label) 129 | query = self.parse_img_names(split['query']) 130 | gallery = self.parse_img_names(split['gallery']) 131 | return train, query, gallery 132 | 133 | def read_json(self, fpath): 134 | import json 135 | """Reads json file from a path.""" 136 | with open(fpath, 'r') as f: 137 | obj = json.load(f) 138 | return obj 139 | 140 | def write_json(self, obj, fpath): 141 | import json 142 | """Writes to a json file.""" 143 | self.mkdir_if_missing(os.path.dirname(fpath)) 144 | with open(fpath, 'w') as f: 145 | json.dump(obj, f, indent=4, separators=(',', ': ')) 146 | 147 | def mkdir_if_missing(self, dirname): 148 | import errno 149 | """Creates dirname if it is missing.""" 150 | if not os.path.exists(dirname): 151 | try: 152 | os.makedirs(dirname) 153 | except OSError as e: 154 | if e.errno != errno.EEXIST: 155 | raise -------------------------------------------------------------------------------- /data/datasets/DG_market1501.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import re 4 | import warnings 5 | 6 | from .bases import ImageDataset 7 | from ..datasets import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class DG_Market1501(ImageDataset): 12 | """Market1501. 13 | 14 | Reference: 15 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 16 | 17 | URL: ``_ 18 | 19 | Dataset statistics: 20 | - identities: 1501 (+1 for background). 21 | - images: 12936 (train) + 3368 (query) + 15913 (gallery). 
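- cameras: 6.

Note: as with the other DG_* datasets, query and gallery images are merged into the training set in ``__init__``.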
22 | """ 23 | _junk_pids = [0, -1] 24 | dataset_dir = '' 25 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip' 26 | dataset_name = "market1501" 27 | 28 | def __init__(self, root='datasets', market1501_500k=False, **kwargs): 29 | # self.root = osp.abspath(osp.expanduser(root)) 30 | self.root = root 31 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 32 | 33 | # allow alternative directory structure 34 | self.data_dir = self.dataset_dir 35 | data_dir = osp.join(self.data_dir, 'market1501') 36 | if osp.isdir(data_dir): 37 | self.data_dir = data_dir 38 | else: 39 | warnings.warn('The current data structure is deprecated. Please ' 40 | 'put data folders such as "bounding_box_train" under ' 41 | '"Market-1501-v15.09.15".') 42 | 43 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train') 44 | self.query_dir = osp.join(self.data_dir, 'query') 45 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test') 46 | self.extra_gallery_dir = osp.join(self.data_dir, 'images') 47 | self.market1501_500k = market1501_500k 48 | 49 | required_files = [ 50 | self.data_dir, 51 | self.train_dir, 52 | self.query_dir, 53 | self.gallery_dir, 54 | ] 55 | if self.market1501_500k: 56 | required_files.append(self.extra_gallery_dir) 57 | self.check_before_run(required_files) 58 | 59 | train = self.process_dir(self.train_dir) 60 | query = self.process_dir(self.query_dir, is_train=True) 61 | gallery = self.process_dir(self.gallery_dir, is_train=True) 62 | train = train + query + gallery 63 | if self.market1501_500k: 64 | gallery += self.process_dir(self.extra_gallery_dir, is_train=False) 65 | 66 | super(DG_Market1501, self).__init__(train, [], [], **kwargs) 67 | 68 | def process_dir(self, dir_path, is_train=True): 69 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 70 | pattern = re.compile(r'([-\d]+)_c(\d)') 71 | 72 | data = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1 or pid == 0: 76 | continue # junk images are just ignored 77 | assert 0 <= pid <= 1501 # pid == 0 means background 78 | assert 1 <= camid <= 6 79 | camid -= 1 # index starts from 0 80 | if is_train: 81 | pid = self.dataset_name + "_" + str(pid) 82 | data.append((img_path, pid, camid)) 83 | 84 | return data 85 | -------------------------------------------------------------------------------- /data/datasets/DG_prid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | import random 5 | from . 
import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | import pdb 8 | 9 | __all__ = ['DG_PRID', ] 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class DG_PRID(ImageDataset): 14 | dataset_dir = "prid_2011" 15 | dataset_name = 'prid' 16 | _junk_pids = list(range(201, 750)) 17 | 18 | def __init__(self, root='datasets', split_id=0, **kwargs): 19 | 20 | if isinstance(root, list): 21 | split_id = root[1] 22 | self.root = root[0] 23 | else: 24 | self.root = root 25 | split_id = 0 26 | self.dataset_dir = os.path.join(self.root, self.dataset_dir) 27 | # self.download_dataset(self.dataset_dir, self.dataset_url) 28 | 29 | self.cam_a_dir = os.path.join( 30 | self.dataset_dir, 'single_shot', 'cam_a' 31 | ) 32 | self.cam_b_dir = os.path.join( 33 | self.dataset_dir, 'single_shot', 'cam_b' 34 | ) 35 | self.split_path = os.path.join(self.dataset_dir, 'splits_single_shot.json') 36 | 37 | required_files = [ 38 | self.dataset_dir, 39 | self.cam_a_dir, 40 | self.cam_b_dir 41 | ] 42 | self.check_before_run(required_files) 43 | 44 | self.prepare_split() 45 | splits = self.read_json(self.split_path) 46 | if split_id >= len(splits): 47 | raise ValueError( 48 | 'split_id exceeds range, received {}, but expected between 0 and {}' 49 | .format(split_id, 50 | len(splits) - 1) 51 | ) 52 | split = splits[split_id] 53 | 54 | train, query, gallery = self.process_split(split) 55 | 56 | super(DG_PRID, self).__init__(train, query, gallery, **kwargs) 57 | 58 | def prepare_split(self): 59 | if not os.path.exists(self.split_path): 60 | print('Creating splits ...') 61 | 62 | splits = [] 63 | for _ in range(10): 64 | # randomly sample 100 IDs for train and use the rest 100 IDs for test 65 | # (note: there are only 200 IDs appearing in both views) 66 | pids = [i for i in range(1, 201)] 67 | train_pids = random.sample(pids, 100) 68 | train_pids.sort() 69 | test_pids = [i for i in pids if i not in train_pids] 70 | split = {'train': train_pids, 'test': test_pids} 71 | splits.append(split) 72 | 73 | print('Totally {} splits are created'.format(len(splits))) 74 | self.write_json(splits, self.split_path) 75 | print('Split file is saved to {}'.format(self.split_path)) 76 | 77 | def process_split(self, split): 78 | train_pids = split['train'] 79 | test_pids = split['test'] 80 | 81 | train_pid2label = {pid: label for label, pid in enumerate(train_pids)} 82 | 83 | # train 84 | train = [] 85 | for pid in train_pids: 86 | img_name = 'person_' + str(pid).zfill(4) + '.png' 87 | pid = train_pid2label[pid] 88 | img_a_path = os.path.join(self.cam_a_dir, img_name) 89 | train.append((img_a_path, pid, 0)) 90 | img_b_path = os.path.join(self.cam_b_dir, img_name) 91 | train.append((img_b_path, pid, 1)) 92 | 93 | # query and gallery 94 | query, gallery = [], [] 95 | for pid in test_pids: 96 | img_name = 'person_' + str(pid).zfill(4) + '.png' 97 | img_a_path = os.path.join(self.cam_a_dir, img_name) 98 | query.append((img_a_path, pid, 0)) 99 | img_b_path = os.path.join(self.cam_b_dir, img_name) 100 | gallery.append((img_b_path, pid, 1)) 101 | for pid in range(201, 750): 102 | img_name = 'person_' + str(pid).zfill(4) + '.png' 103 | img_b_path = os.path.join(self.cam_b_dir, img_name) 104 | gallery.append((img_b_path, pid, 1)) 105 | 106 | return train, query, gallery 107 | 108 | def read_json(self, fpath): 109 | import json 110 | """Reads json file from a path.""" 111 | with open(fpath, 'r') as f: 112 | obj = json.load(f) 113 | return obj 114 | 115 | def write_json(self, obj, fpath): 116 | import json 117 | """Writes to a json file.""" 118 | 
self.mkdir_if_missing(os.path.dirname(fpath)) 119 | with open(fpath, 'w') as f: 120 | json.dump(obj, f, indent=4, separators=(',', ': ')) 121 | 122 | def mkdir_if_missing(self, dirname): 123 | import errno 124 | """Creates dirname if it is missing.""" 125 | if not os.path.exists(dirname): 126 | try: 127 | os.makedirs(dirname) 128 | except OSError as e: 129 | if e.errno != errno.EEXIST: 130 | raise -------------------------------------------------------------------------------- /data/datasets/DG_viper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from . import DATASET_REGISTRY 5 | from .bases import ImageDataset 6 | 7 | __all__ = ['DG_viper', ] 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class DG_VIPeR(ImageDataset): 12 | dataset_dir = "viper" 13 | dataset_name = "viper" 14 | 15 | def __init__(self, root='datasets', **kwargs): 16 | if isinstance(root, list): 17 | type = root[1] 18 | self.root = root[0] 19 | else: 20 | self.root = root 21 | type = 'split_1a' 22 | self.train_dir = os.path.join(self.root, self.dataset_dir, type, 'train') 23 | self.query_dir = os.path.join(self.root, self.dataset_dir, type, 'query') 24 | self.gallery_dir = os.path.join(self.root, self.dataset_dir, type, 'gallery') 25 | 26 | required_files = [ 27 | self.train_dir, 28 | self.query_dir, 29 | self.gallery_dir, 30 | ] 31 | self.check_before_run(required_files) 32 | 33 | train = self.process_train(self.train_dir, is_train = True) 34 | query = self.process_train(self.query_dir, is_train = False) 35 | gallery = self.process_train(self.gallery_dir, is_train = False) 36 | 37 | super().__init__(train, query, gallery, **kwargs) 38 | 39 | def process_train(self, path, is_train = True): 40 | data = [] 41 | img_list = glob(os.path.join(path, '*.png')) 42 | for img_path in img_list: 43 | img_name = img_path.split('/')[-1] # p000_c1_d045.png 44 | split_name = img_name.split('_') 45 | pid = int(split_name[0][1:]) 46 | if is_train: 47 | pid = self.dataset_name + "_" + str(pid) 48 | camid = int(split_name[1][1:]) 49 | # dirid = int(split_name[2][1:-4]) 50 | data.append([img_path, pid, camid]) 51 | 52 | return data -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.registry import Registry 2 | 3 | DATASET_REGISTRY = Registry("DATASET") 4 | DATASET_REGISTRY.__doc__ = """ 5 | Registry for datasets 6 | It must returns an instance of :class:`Backbone`. 
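Usage in this repo: each dataset module decorates its class with ``@DATASET_REGISTRY.register()`` and ``data/build_DG_dataloader.py`` looks it up by name via ``DATASET_REGISTRY.get(name)(root=...)``.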
7 | """ 8 | 9 | 10 | # Person re-id datasets 11 | from .cuhk03 import CUHK03 12 | from .DG_cuhk_sysu import DG_CUHK_SYSU 13 | from .DG_cuhk02 import DG_CUHK02 14 | from .DG_cuhk03_labeled import DG_CUHK03_labeled 15 | from .DG_cuhk03_detected import DG_CUHK03_detected 16 | from .dukemtmcreid import DukeMTMC 17 | from .DG_dukemtmcreid import DG_DukeMTMC 18 | from .market1501 import Market1501 19 | from .DG_market1501 import DG_Market1501 20 | from .msmt17 import MSMT17 21 | from .AirportALERT import AirportALERT 22 | from .iLIDS import iLIDS 23 | from .pku import PKU 24 | from .grid import GRID 25 | from .prai import PRAI 26 | from .prid import PRID 27 | from .DG_prid import DG_PRID 28 | from .DG_grid import DG_GRID 29 | from .sensereid import SenseReID 30 | from .sysu_mm import SYSU_mm 31 | from .thermalworld import Thermalworld 32 | from .pes3d import PeS3D 33 | from .caviara import CAVIARa 34 | from .viper import VIPeR 35 | from .DG_viper import DG_VIPeR 36 | from .DG_iLIDS import DG_iLIDS 37 | from .lpw import LPW 38 | from .shinpuhkan import Shinpuhkan 39 | # Vehicle re-id datasets 40 | from .veri import VeRi 41 | from .veri_keypoint import VeRi_keypoint 42 | from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID 43 | from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild 44 | from .randperson import RandPerson 45 | 46 | 47 | __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")] 48 | -------------------------------------------------------------------------------- /data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | 5 | class Dataset(object): 6 | """An abstract class representing a Dataset. 7 | This is the base class for ``ImageDataset`` and ``VideoDataset``. 8 | Args: 9 | train (list): contains tuples of (img_path(s), pid, camid). 10 | query (list): contains tuples of (img_path(s), pid, camid). 11 | gallery (list): contains tuples of (img_path(s), pid, camid). 12 | transform: transform function. 13 | mode (str): 'train', 'query' or 'gallery'. 14 | combineall (bool): combines train, query and gallery in a 15 | dataset for training. 16 | verbose (bool): show information. 17 | """ 18 | _junk_pids = [] # contains useless person IDs, e.g. background, false detections 19 | 20 | def __init__(self, train, query, gallery, transform=None, mode='train', 21 | combineall=False, verbose=True, **kwargs): 22 | self.train = train 23 | self.query = query 24 | self.gallery = gallery 25 | self.query = [tuple(q_tuple)+({'q_or_g': 'query'},) for q_tuple in self.query] 26 | self.gallery = [tuple(g_tuple)+({'q_or_g': 'gallery'},) for g_tuple in self.gallery] 27 | self.transform = transform 28 | self.mode = mode 29 | self.combineall = combineall 30 | self.verbose = verbose 31 | 32 | # if self.train != []: 33 | self.num_train_pids = self.get_num_pids(self.train) 34 | self.num_train_cams = self.get_num_cams(self.train) 35 | 36 | if self.combineall: 37 | self.combine_all() 38 | 39 | if self.mode == 'train': 40 | self.data = self.train 41 | elif self.mode == 'query': 42 | self.data = self.query 43 | elif self.mode == 'gallery': 44 | self.data = self.gallery 45 | else: 46 | raise ValueError('Invalid mode. 
Got {}, but expected to be ' 47 | 'one of [train | query | gallery]'.format(self.mode)) 48 | 49 | # if self.verbose: 50 | # self.show_summary() 51 | 52 | def __getitem__(self, index): 53 | raise NotImplementedError 54 | 55 | def __len__(self): 56 | return len(self.data) 57 | 58 | def __radd__(self, other): 59 | """Supports sum([dataset1, dataset2, dataset3]).""" 60 | if other == 0: 61 | return self 62 | else: 63 | return self.__add__(other) 64 | 65 | def parse_data(self, data): 66 | """Parses data list and returns the number of person IDs 67 | and the number of camera views. 68 | Args: 69 | data (list): contains tuples of (img_path(s), pid, camid) 70 | """ 71 | pids = set() 72 | cams = set() 73 | if len(data[0]) > 3: 74 | for _, pid, camid, _ in data: 75 | pids.add(pid) 76 | cams.add(camid) 77 | else: 78 | for _, pid, camid in data: 79 | pids.add(pid) 80 | cams.add(camid) 81 | return len(pids), len(cams) 82 | 83 | def get_num_pids(self, data): 84 | """Returns the number of training person identities.""" 85 | return self.parse_data(data)[0] 86 | 87 | def get_num_cams(self, data): 88 | """Returns the number of training cameras.""" 89 | return self.parse_data(data)[1] 90 | 91 | def show_summary(self): 92 | """Shows dataset statistics.""" 93 | pass 94 | 95 | def combine_all(self): 96 | """Combines train, query and gallery in a dataset for training.""" 97 | combined = copy.deepcopy(self.train) 98 | 99 | def _combine_data(data): 100 | for img_path, pid, camid, _ in data: 101 | if pid in self._junk_pids: 102 | continue 103 | pid = self.dataset_name + "_" + str(pid) 104 | combined.append((img_path, pid, camid)) 105 | 106 | _combine_data(self.query) 107 | _combine_data(self.gallery) 108 | 109 | self.train = combined 110 | self.num_train_pids = self.get_num_pids(self.train) 111 | 112 | def check_before_run(self, required_files): 113 | """Checks if required files exist before going deeper. 114 | Args: 115 | required_files (str or list): string file name(s). 116 | """ 117 | if isinstance(required_files, str): 118 | required_files = [required_files] 119 | 120 | for fpath in required_files: 121 | if not os.path.exists(fpath): 122 | raise RuntimeError('"{}" is not found'.format(fpath)) 123 | 124 | def __repr__(self): 125 | num_train_pids, num_train_cams = self.parse_data(self.train) 126 | num_query_pids, num_query_cams = self.parse_data(self.query) 127 | num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery) 128 | 129 | msg = ' ----------------------------------------\n' \ 130 | ' subset | # ids | # items | # cameras\n' \ 131 | ' ----------------------------------------\n' \ 132 | ' train | {:5d} | {:7d} | {:9d}\n' \ 133 | ' query | {:5d} | {:7d} | {:9d}\n' \ 134 | ' gallery | {:5d} | {:7d} | {:9d}\n' \ 135 | ' ----------------------------------------\n' \ 136 | ' items: images/tracklets for image/video dataset\n'.format( 137 | num_train_pids, len(self.train), num_train_cams, 138 | num_query_pids, len(self.query), num_query_cams, 139 | num_gallery_pids, len(self.gallery), num_gallery_cams 140 | ) 141 | 142 | return msg 143 | 144 | 145 | class ImageDataset(Dataset): 146 | """A base class representing ImageDataset. 147 | All other image datasets should subclass it. 148 | ``__getitem__`` returns an image given index. 149 | It will return ``img``, ``pid``, ``camid`` and ``img_path`` 150 | where ``img`` has shape (channel, height, width). As a result, 151 | data in each batch has shape (batch_size, channel, height, width). 
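(In this repo the actual batch dictionaries are produced by ``CommDataset`` in ``data/common.py``, whose ``__getitem__`` returns the keys ``images``, ``targets``, ``camid``, ``img_path`` and ``others``.)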
152 | """ 153 | 154 | def __init__(self, train, query, gallery, **kwargs): 155 | super(ImageDataset, self).__init__(train, query, gallery, **kwargs) 156 | 157 | def show_train(self): 158 | logger = logging.getLogger('PAT') 159 | num_train_pids, num_train_cams = self.parse_data(self.train) 160 | logger.info('=> Loaded {}'.format(self.__class__.__name__)) 161 | logger.info(' ----------------------------------------') 162 | logger.info(' subset | # ids | # images | # cameras') 163 | logger.info(' ----------------------------------------') 164 | logger.info(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams)) 165 | logger.info(' ----------------------------------------') 166 | 167 | def show_test(self): 168 | logger = logging.getLogger('PAT') 169 | num_query_pids, num_query_cams = self.parse_data(self.query) 170 | num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery) 171 | logger.info('=> Loaded {}'.format(self.__class__.__name__)) 172 | logger.info(' ----------------------------------------') 173 | logger.info(' subset | # ids | # images | # cameras') 174 | logger.info(' ----------------------------------------') 175 | logger.info(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams)) 176 | logger.info(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams)) 177 | logger.info(' ----------------------------------------') 178 | -------------------------------------------------------------------------------- /data/datasets/caviara.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | 5 | from . import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | import pdb 8 | import random 9 | import numpy as np 10 | 11 | __all__ = ['CAVIARa',] 12 | 13 | 14 | @DATASET_REGISTRY.register() 15 | class CAVIARa(ImageDataset): 16 | dataset_dir = "CAVIARa" 17 | dataset_name = "caviara" 18 | 19 | def __init__(self, root='datasets', **kwargs): 20 | self.root = root 21 | self.train_path = os.path.join(self.root, self.dataset_dir) 22 | 23 | required_files = [self.train_path] 24 | self.check_before_run(required_files) 25 | 26 | train = self.process_train(self.train_path) 27 | 28 | super().__init__(train, [], [], **kwargs) 29 | 30 | def process_train(self, train_path): 31 | data = [] 32 | 33 | img_list = glob(os.path.join(train_path, "*.jpg")) 34 | for img_path in img_list: 35 | img_name = img_path.split('/')[-1] 36 | pid = self.dataset_name + "_" + img_name[:4] 37 | camid = self.dataset_name + "_cam0" 38 | data.append([img_path, pid, camid]) 39 | 40 | return data 41 | -------------------------------------------------------------------------------- /data/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import re 4 | 5 | from .bases import ImageDataset 6 | from ..datasets import DATASET_REGISTRY 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class DukeMTMC(ImageDataset): 11 | """DukeMTMC-reID. 12 | 13 | Reference: 14 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 15 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 16 | 17 | URL: ``_ 18 | 19 | Dataset statistics: 20 | - identities: 1404 (train + query). 
21 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 22 | - cameras: 8. 23 | """ 24 | dataset_dir = 'DukeMTMC-reID' 25 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 26 | dataset_name = "dukemtmc" 27 | 28 | def __init__(self, root='datasets', **kwargs): 29 | # self.root = osp.abspath(osp.expanduser(root)) 30 | self.root = root 31 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 32 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'query') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 35 | 36 | required_files = [ 37 | self.dataset_dir, 38 | self.train_dir, 39 | self.query_dir, 40 | self.gallery_dir, 41 | ] 42 | self.check_before_run(required_files) 43 | 44 | train = self.process_dir(self.train_dir) 45 | query = self.process_dir(self.query_dir, is_train=False) 46 | gallery = self.process_dir(self.gallery_dir, is_train=False) 47 | 48 | super(DukeMTMC, self).__init__(train, query, gallery, **kwargs) 49 | 50 | def process_dir(self, dir_path, is_train=True): 51 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 52 | pattern = re.compile(r'([-\d]+)_c(\d)') 53 | 54 | data = [] 55 | for img_path in img_paths: 56 | pid, camid = map(int, pattern.search(img_path).groups()) 57 | assert 1 <= camid <= 8 58 | camid -= 1 # index starts from 0 59 | if is_train: 60 | pid = self.dataset_name + "_" + str(pid) 61 | data.append((img_path, pid, camid)) 62 | 63 | return data 64 | -------------------------------------------------------------------------------- /data/datasets/grid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | 5 | from . 
import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | import pdb 8 | 9 | __all__ = ['GRID',] 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class GRID(ImageDataset): 14 | dataset_dir = "GRID" 15 | dataset_name = 'grid' 16 | 17 | def __init__(self, root='datasets', split_id = 0, **kwargs): 18 | self.root = root 19 | self.dataset_dir = os.path.join(self.root, self.dataset_dir) 20 | 21 | self.probe_path = os.path.join( 22 | self.dataset_dir, 'probe' 23 | ) 24 | self.gallery_path = os.path.join( 25 | self.dataset_dir, 'gallery' 26 | ) 27 | self.split_mat_path = os.path.join( 28 | self.dataset_dir, 'features_and_partitions.mat' 29 | ) 30 | self.split_path = os.path.join(self.dataset_dir, 'splits.json') 31 | 32 | required_files = [ 33 | self.dataset_dir, self.probe_path, self.gallery_path, 34 | self.split_mat_path 35 | ] 36 | self.check_before_run(required_files) 37 | 38 | self.prepare_split() 39 | splits = self.read_json(self.split_path) 40 | if split_id >= len(splits): 41 | raise ValueError( 42 | 'split_id exceeds range, received {}, ' 43 | 'but expected between 0 and {}'.format( 44 | split_id, 45 | len(splits) - 1 46 | ) 47 | ) 48 | split = splits[split_id] 49 | 50 | train = split['train'] 51 | query = split['query'] 52 | gallery = split['gallery'] 53 | 54 | train = [tuple(item) for item in train] 55 | query = [tuple(item) for item in query] 56 | gallery = [tuple(item) for item in gallery] 57 | 58 | super(GRID, self).__init__(train, query, gallery, **kwargs) 59 | 60 | def prepare_split(self): 61 | if not os.path.exists(self.split_path): 62 | print('Creating 10 random splits') 63 | split_mat = loadmat(self.split_mat_path) 64 | trainIdxAll = split_mat['trainIdxAll'][0] # length = 10 65 | probe_img_paths = sorted( 66 | glob(os.path.join(self.probe_path, '*.jpeg')) 67 | ) 68 | gallery_img_paths = sorted( 69 | glob(os.path.join(self.gallery_path, '*.jpeg')) 70 | ) 71 | 72 | splits = [] 73 | for split_idx in range(10): 74 | train_idxs = trainIdxAll[split_idx][0][0][2][0].tolist() 75 | assert len(train_idxs) == 125 76 | idx2label = { 77 | idx: label 78 | for label, idx in enumerate(train_idxs) 79 | } 80 | 81 | train, query, gallery = [], [], [] 82 | 83 | # processing probe folder 84 | for img_path in probe_img_paths: 85 | img_name = os.path.basename(img_path) 86 | img_idx = int(img_name.split('_')[0]) 87 | camid = int( 88 | img_name.split('_')[1] 89 | ) - 1 # index starts from 0 90 | if img_idx in train_idxs: 91 | train.append((img_path, idx2label[img_idx], camid)) 92 | else: 93 | query.append((img_path, img_idx, camid)) 94 | 95 | # process gallery folder 96 | for img_path in gallery_img_paths: 97 | img_name = os.path.basename(img_path) 98 | img_idx = int(img_name.split('_')[0]) 99 | camid = int( 100 | img_name.split('_')[1] 101 | ) - 1 # index starts from 0 102 | if img_idx in train_idxs: 103 | train.append((img_path, idx2label[img_idx], camid)) 104 | else: 105 | gallery.append((img_path, img_idx, camid)) 106 | 107 | split = { 108 | 'train': train, 109 | 'query': query, 110 | 'gallery': gallery, 111 | 'num_train_pids': 125, 112 | 'num_query_pids': 125, 113 | 'num_gallery_pids': 900 114 | } 115 | splits.append(split) 116 | 117 | print('Totally {} splits are created'.format(len(splits))) 118 | self.write_json(splits, self.split_path) 119 | print('Split file saved to {}'.format(self.split_path)) 120 | 121 | 122 | def read_json(self, fpath): 123 | import json 124 | """Reads json file from a path.""" 125 | with open(fpath, 'r') as f: 126 | obj = json.load(f) 127 | return obj 128 | 129 | 130 
| def write_json(self, obj, fpath): 131 | import json 132 | """Writes to a json file.""" 133 | self.mkdir_if_missing(os.path.dirname(fpath)) 134 | with open(fpath, 'w') as f: 135 | json.dump(obj, f, indent=4, separators=(',', ': ')) 136 | 137 | 138 | def mkdir_if_missing(self, dirname): 139 | import errno 140 | """Creates dirname if it is missing.""" 141 | if not os.path.exists(dirname): 142 | try: 143 | os.makedirs(dirname) 144 | except OSError as e: 145 | if e.errno != errno.EEXIST: 146 | raise -------------------------------------------------------------------------------- /data/datasets/iLIDS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import copy 4 | import random 5 | from collections import defaultdict 6 | from . import DATASET_REGISTRY 7 | from .bases import ImageDataset 8 | 9 | __all__ = ['iLIDS', ] 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class iLIDS(ImageDataset): 14 | dataset_dir = "QMUL-iLIDS" 15 | dataset_name = "ilids" 16 | 17 | def __init__(self, root='datasets', split_id = 0, **kwargs): 18 | # self.root = os.path.abspath(os.path.expanduser(root)) 19 | self.root = root 20 | self.dataset_dir = os.path.join(self.root, self.dataset_dir) 21 | # self.download_dataset(self.dataset_dir, self.dataset_url) 22 | 23 | self.data_dir = os.path.join(self.dataset_dir, 'images') 24 | self.split_path = os.path.join(self.dataset_dir, 'splits.json') 25 | 26 | required_files = [self.dataset_dir, self.data_dir] 27 | self.check_before_run(required_files) 28 | 29 | self.prepare_split() 30 | splits = self.read_json(self.split_path) 31 | if split_id >= len(splits): 32 | raise ValueError( 33 | 'split_id exceeds range, received {}, but ' 34 | 'expected between 0 and {}'.format(split_id, 35 | len(splits) - 1) 36 | ) 37 | split = splits[split_id] 38 | 39 | train, query, gallery = self.process_split(split) 40 | 41 | super(iLIDS, self).__init__(train, query, gallery, **kwargs) 42 | 43 | def prepare_split(self): 44 | if not os.path.exists(self.split_path): 45 | print('Creating splits ...') 46 | 47 | paths = glob.glob(os.path.join(self.data_dir, '*.jpg')) 48 | img_names = [os.path.basename(path) for path in paths] 49 | num_imgs = len(img_names) 50 | assert num_imgs == 476, 'There should be 476 images, but ' \ 51 | 'got {}, please check the data'.format(num_imgs) 52 | 53 | # store image names 54 | # image naming format: 55 | # the first four digits denote the person ID 56 | # the last four digits denote the sequence index 57 | pid_dict = defaultdict(list) 58 | for img_name in img_names: 59 | pid = int(img_name[:4]) 60 | pid_dict[pid].append(img_name) 61 | pids = list(pid_dict.keys()) 62 | num_pids = len(pids) 63 | assert num_pids == 119, 'There should be 119 identities, ' \ 64 | 'but got {}, please check the data'.format(num_pids) 65 | 66 | num_train_pids = int(num_pids * 0.5) 67 | 68 | splits = [] 69 | for _ in range(10): 70 | # randomly choose num_train_pids train IDs and the rest for test IDs 71 | pids_copy = copy.deepcopy(pids) 72 | random.shuffle(pids_copy) 73 | train_pids = pids_copy[:num_train_pids] 74 | test_pids = pids_copy[num_train_pids:] 75 | 76 | train = [] 77 | query = [] 78 | gallery = [] 79 | 80 | # for train IDs, all images are used in the train set. 81 | for pid in train_pids: 82 | img_names = pid_dict[pid] 83 | train.extend(img_names) 84 | 85 | # for each test ID, randomly choose two images, one for 86 | # query and the other one for gallery. 
87 | for pid in test_pids: 88 | img_names = pid_dict[pid] 89 | samples = random.sample(img_names, 2) 90 | query.append(samples[0]) 91 | gallery.append(samples[1]) 92 | 93 | split = {'train': train, 'query': query, 'gallery': gallery} 94 | splits.append(split) 95 | 96 | print('Totally {} splits are created'.format(len(splits))) 97 | self.write_json(splits, self.split_path) 98 | print('Split file is saved to {}'.format(self.split_path)) 99 | 100 | def get_pid2label(self, img_names): 101 | pid_container = set() 102 | for img_name in img_names: 103 | pid = int(img_name[:4]) 104 | pid_container.add(pid) 105 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 106 | return pid2label 107 | 108 | def parse_img_names(self, img_names, pid2label=None): 109 | data = [] 110 | 111 | for img_name in img_names: 112 | pid = int(img_name[:4]) 113 | if pid2label is not None: 114 | pid = pid2label[pid] 115 | camid = int(img_name[4:7]) - 1 # 0-based 116 | img_path = os.path.join(self.data_dir, img_name) 117 | data.append((img_path, pid, camid)) 118 | 119 | return data 120 | 121 | def process_split(self, split): 122 | train_pid2label = self.get_pid2label(split['train']) 123 | train = self.parse_img_names(split['train'], train_pid2label) 124 | query = self.parse_img_names(split['query']) 125 | gallery = self.parse_img_names(split['gallery']) 126 | return train, query, gallery 127 | 128 | def read_json(self, fpath): 129 | import json 130 | """Reads json file from a path.""" 131 | with open(fpath, 'r') as f: 132 | obj = json.load(f) 133 | return obj 134 | 135 | def write_json(self, obj, fpath): 136 | import json 137 | """Writes to a json file.""" 138 | self.mkdir_if_missing(os.path.dirname(fpath)) 139 | with open(fpath, 'w') as f: 140 | json.dump(obj, f, indent=4, separators=(',', ': ')) 141 | 142 | def mkdir_if_missing(self, dirname): 143 | import errno 144 | """Creates dirname if it is missing.""" 145 | if not os.path.exists(dirname): 146 | try: 147 | os.makedirs(dirname) 148 | except OSError as e: 149 | if e.errno != errno.EEXIST: 150 | raise 151 | -------------------------------------------------------------------------------- /data/datasets/lpw.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from . 
import DATASET_REGISTRY 5 | from .bases import ImageDataset 6 | 7 | __all__ = ['LPW', ] 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class LPW(ImageDataset): 12 | dataset_dir = "pep_256x128" 13 | dataset_name = "lpw" 14 | 15 | def __init__(self, root='datasets', **kwargs): 16 | self.root = root 17 | self.train_path = os.path.join(self.root, self.dataset_dir) 18 | 19 | required_files = [self.train_path] 20 | self.check_before_run(required_files) 21 | 22 | train = self.process_train(self.train_path) 23 | 24 | super().__init__(train, [], [], **kwargs) 25 | 26 | def process_train(self, train_path): 27 | data = [] 28 | 29 | file_path_list = ['scen1', 'scen2', 'scen3'] 30 | 31 | for scene in file_path_list: 32 | cam_list = os.listdir(os.path.join(train_path, scene)) 33 | for cam in cam_list: 34 | camid = self.dataset_name + "_" + cam 35 | pid_list = os.listdir(os.path.join(train_path, scene, cam)) 36 | for pid_dir in pid_list: 37 | img_paths = glob(os.path.join(train_path, scene, cam, pid_dir, "*.jpg")) 38 | for img_path in img_paths: 39 | pid = self.dataset_name + "_" + scene + "-" + pid_dir 40 | data.append([img_path, pid, camid]) 41 | return data 42 | -------------------------------------------------------------------------------- /data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import re 4 | import warnings 5 | 6 | from .bases import ImageDataset 7 | from ..datasets import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class Market1501(ImageDataset): 12 | """Market1501. 13 | 14 | Reference: 15 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 16 | 17 | URL: ``_ 18 | 19 | Dataset statistics: 20 | - identities: 1501 (+1 for background). 21 | - images: 12936 (train) + 3368 (query) + 15913 (gallery). 22 | """ 23 | _junk_pids = [0, -1] 24 | dataset_dir = 'market1501' 25 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip' 26 | dataset_name = "market1501" 27 | 28 | def __init__(self, root='datasets', market1501_500k=False, **kwargs): 29 | # self.root = osp.abspath(osp.expanduser(root)) 30 | self.root = root 31 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 32 | 33 | # allow alternative directory structure 34 | self.data_dir = self.dataset_dir 35 | data_dir = osp.join(self.data_dir, 'Market1501') 36 | if osp.isdir(data_dir): 37 | self.data_dir = data_dir 38 | else: 39 | warnings.warn('The current data structure is deprecated. 
Please ' 40 | 'put data folders such as "bounding_box_train" under ' 41 | '"Market-1501-v15.09.15".') 42 | 43 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train') 44 | self.query_dir = osp.join(self.data_dir, 'query') 45 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test') 46 | self.extra_gallery_dir = osp.join(self.data_dir, 'images') 47 | self.market1501_500k = market1501_500k 48 | 49 | required_files = [ 50 | self.data_dir, 51 | self.train_dir, 52 | self.query_dir, 53 | self.gallery_dir, 54 | ] 55 | if self.market1501_500k: 56 | required_files.append(self.extra_gallery_dir) 57 | self.check_before_run(required_files) 58 | 59 | train = self.process_dir(self.train_dir) 60 | query = self.process_dir(self.query_dir, is_train=False) 61 | gallery = self.process_dir(self.gallery_dir, is_train=False) 62 | if self.market1501_500k: 63 | gallery += self.process_dir(self.extra_gallery_dir, is_train=False) 64 | 65 | super(Market1501, self).__init__(train, query, gallery, **kwargs) 66 | 67 | def process_dir(self, dir_path, is_train=True): 68 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 69 | pattern = re.compile(r'([-\d]+)_c(\d)') 70 | 71 | data = [] 72 | for img_path in img_paths: 73 | pid, camid = map(int, pattern.search(img_path).groups()) 74 | if pid == -1: 75 | continue # junk images are just ignored 76 | assert 0 <= pid <= 1501 # pid == 0 means background 77 | assert 1 <= camid <= 6 78 | camid -= 1 # index starts from 0 79 | if is_train: 80 | pid = self.dataset_name + "_" + str(pid) 81 | data.append((img_path, pid, camid)) 82 | 83 | return data 84 | -------------------------------------------------------------------------------- /data/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path as osp 4 | 5 | from .bases import ImageDataset 6 | from ..datasets import DATASET_REGISTRY 7 | ##### Log ##### 8 | # 22.01.2019 9 | # - add v2 10 | # - v1 and v2 differ in dir names 11 | # - note that faces in v2 are blurred 12 | TRAIN_DIR_KEY = 'train_dir' 13 | TEST_DIR_KEY = 'test_dir' 14 | VERSION_DICT = { 15 | 'MSMT17': { 16 | TRAIN_DIR_KEY: 'train', 17 | TEST_DIR_KEY: 'test', 18 | }, 19 | 'MSMT17_V2': { 20 | TRAIN_DIR_KEY: 'mask_train_v2', 21 | TEST_DIR_KEY: 'mask_test_v2', 22 | } 23 | } 24 | 25 | 26 | @DATASET_REGISTRY.register() 27 | class MSMT17(ImageDataset): 28 | """MSMT17. 29 | Reference: 30 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 31 | URL: ``_ 32 | 33 | Dataset statistics: 34 | - identities: 4101. 35 | - images: 32621 (train) + 11659 (query) + 82161 (gallery). 36 | - cameras: 15. 
37 | """ 38 | # dataset_dir = 'MSMT17_V2' 39 | dataset_url = None 40 | dataset_name = 'MSMT17' 41 | 42 | def __init__(self, root='datasets', **kwargs): 43 | self.root = root 44 | self.dataset_dir = self.root 45 | 46 | has_main_dir = False 47 | for main_dir in VERSION_DICT: 48 | if osp.exists(osp.join(self.dataset_dir, main_dir)): 49 | train_dir = VERSION_DICT[main_dir][TRAIN_DIR_KEY] 50 | test_dir = VERSION_DICT[main_dir][TEST_DIR_KEY] 51 | has_main_dir = True 52 | break 53 | assert has_main_dir, 'Dataset folder not found' 54 | 55 | self.train_dir = osp.join(self.dataset_dir, main_dir, train_dir) 56 | self.test_dir = osp.join(self.dataset_dir, main_dir, test_dir) 57 | self.list_train_path = osp.join(self.dataset_dir, main_dir, 'list_train.txt') 58 | self.list_val_path = osp.join(self.dataset_dir, main_dir, 'list_val.txt') 59 | self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt') 60 | self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt') 61 | 62 | required_files = [ 63 | self.dataset_dir, 64 | self.train_dir, 65 | self.test_dir 66 | ] 67 | self.check_before_run(required_files) 68 | 69 | train = self.process_dir(self.train_dir, self.list_train_path) 70 | val = self.process_dir(self.train_dir, self.list_val_path) 71 | query = self.process_dir(self.test_dir, self.list_query_path, is_train=False) 72 | gallery = self.process_dir(self.test_dir, self.list_gallery_path, is_train=False) 73 | 74 | num_train_pids = self.get_num_pids(train) 75 | query_tmp = [] 76 | for img_path, pid, camid in query: 77 | query_tmp.append((img_path, pid+num_train_pids, camid)) 78 | del query 79 | query = query_tmp 80 | 81 | gallery_temp = [] 82 | for img_path, pid, camid in gallery: 83 | gallery_temp.append((img_path, pid+num_train_pids, camid)) 84 | del gallery 85 | gallery = gallery_temp 86 | 87 | # Note: to fairly compare with published methods on the conventional ReID setting, 88 | # do not add val images to the training set. 89 | if 'combineall' in kwargs and kwargs['combineall']: 90 | train += val 91 | 92 | super(MSMT17, self).__init__(train, query, gallery, **kwargs) 93 | 94 | def process_dir(self, dir_path, list_path, is_train=True): 95 | with open(list_path, 'r') as txt: 96 | lines = txt.readlines() 97 | 98 | data = [] 99 | 100 | for img_idx, img_info in enumerate(lines): 101 | img_path, pid = img_info.split(' ') 102 | pid = int(pid) # no need to relabel 103 | camid = int(img_path.split('_')[2]) - 1 # index starts from 0 104 | img_path = osp.join(dir_path, img_path) 105 | if is_train: 106 | pid = self.dataset_name + "_" + str(pid) 107 | data.append((img_path, pid, camid)) 108 | 109 | return data -------------------------------------------------------------------------------- /data/datasets/pes3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | 5 | from . 
import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | import pdb 8 | import random 9 | import numpy as np 10 | 11 | __all__ = ['PeS3D',] 12 | 13 | 14 | @DATASET_REGISTRY.register() 15 | class PeS3D(ImageDataset): 16 | dataset_dir = "3DPeS" 17 | dataset_name = "pes3d" 18 | 19 | def __init__(self, root='datasets', **kwargs): 20 | self.root = root 21 | self.train_path = os.path.join(self.root, self.dataset_dir) 22 | 23 | required_files = [self.train_path] 24 | self.check_before_run(required_files) 25 | 26 | train = self.process_train(self.train_path) 27 | 28 | super().__init__(train, [], [], **kwargs) 29 | 30 | def process_train(self, train_path): 31 | data = [] 32 | 33 | pid_list = os.listdir(train_path) 34 | for pid_dir in pid_list: 35 | pid = self.dataset_name + "_" + pid_dir 36 | img_list = glob(os.path.join(train_path, pid_dir, "*.bmp")) 37 | for img_path in img_list: 38 | camid = self.dataset_name + "_cam0" 39 | data.append([img_path, pid, camid]) 40 | return data 41 | -------------------------------------------------------------------------------- /data/datasets/pku.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from . import DATASET_REGISTRY 5 | from .bases import ImageDataset 6 | 7 | __all__ = ['PKU', ] 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PKU(ImageDataset): 12 | dataset_dir = "PKUv1a_128x48" 13 | dataset_name = 'pku' 14 | 15 | def __init__(self, root='datasets', **kwargs): 16 | self.root = root 17 | self.train_path = os.path.join(self.root, self.dataset_dir) 18 | 19 | required_files = [self.train_path] 20 | self.check_before_run(required_files) 21 | 22 | train = self.process_train(self.train_path) 23 | 24 | super().__init__(train, [], [], **kwargs) 25 | 26 | def process_train(self, train_path): 27 | data = [] 28 | img_paths = glob(os.path.join(train_path, "*.png")) 29 | 30 | for img_path in img_paths: 31 | split_path = img_path.split('/') 32 | img_info = split_path[-1].split('_') 33 | pid = self.dataset_name + "_" + img_info[0] 34 | camid = self.dataset_name + "_" + img_info[1] 35 | data.append([img_path, pid, camid]) 36 | return data 37 | -------------------------------------------------------------------------------- /data/datasets/prai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | 5 | from . 
import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | 8 | __all__ = ['PRAI',] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class PRAI(ImageDataset): 13 | dataset_dir = "PRAI-1581" 14 | dataset_name = 'prai' 15 | 16 | def __init__(self, root='datasets', **kwargs): 17 | self.root = root 18 | self.train_path = os.path.join(self.root, self.dataset_dir, 'images') 19 | 20 | required_files = [self.train_path] 21 | self.check_before_run(required_files) 22 | 23 | train = self.process_train(self.train_path) 24 | 25 | super().__init__(train, [], [], **kwargs) 26 | 27 | def process_train(self, train_path): 28 | data = [] 29 | img_paths = glob(os.path.join(train_path, "*.jpg")) 30 | for img_path in img_paths: 31 | split_path = img_path.split('/') 32 | img_info = split_path[-1].split('_') 33 | pid = self.dataset_name + "_" + img_info[0] 34 | camid = self.dataset_name + "_" + img_info[1] 35 | data.append([img_path, pid, camid]) 36 | return data 37 | 38 | -------------------------------------------------------------------------------- /data/datasets/prid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | import random 5 | from . import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | import pdb 8 | 9 | __all__ = ['PRID',] 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class PRID(ImageDataset): 14 | dataset_dir = "prid_2011" 15 | dataset_name = 'prid' 16 | _junk_pids = list(range(201, 750)) 17 | 18 | def __init__(self, root='datasets', split_id=0, **kwargs): 19 | 20 | 21 | self.root = root 22 | self.dataset_dir = os.path.join(self.root, self.dataset_dir) 23 | # self.download_dataset(self.dataset_dir, self.dataset_url) 24 | 25 | self.cam_a_dir = os.path.join( 26 | self.dataset_dir, 'single_shot', 'cam_a' 27 | ) 28 | self.cam_b_dir = os.path.join( 29 | self.dataset_dir, 'single_shot', 'cam_b' 30 | ) 31 | self.split_path = os.path.join(self.dataset_dir, 'splits_single_shot.json') 32 | 33 | required_files = [ 34 | self.dataset_dir, 35 | self.cam_a_dir, 36 | self.cam_b_dir 37 | ] 38 | self.check_before_run(required_files) 39 | 40 | self.prepare_split() 41 | splits = self.read_json(self.split_path) 42 | if split_id >= len(splits): 43 | raise ValueError( 44 | 'split_id exceeds range, received {}, but expected between 0 and {}' 45 | .format(split_id, 46 | len(splits) - 1) 47 | ) 48 | split = splits[split_id] 49 | 50 | train, query, gallery = self.process_split(split) 51 | 52 | super(PRID, self).__init__(train, query, gallery, **kwargs) 53 | 54 | def prepare_split(self): 55 | if not os.path.exists(self.split_path): 56 | print('Creating splits ...') 57 | 58 | splits = [] 59 | for _ in range(10): 60 | # randomly sample 100 IDs for train and use the rest 100 IDs for test 61 | # (note: there are only 200 IDs appearing in both views) 62 | pids = [i for i in range(1, 201)] 63 | train_pids = random.sample(pids, 100) 64 | train_pids.sort() 65 | test_pids = [i for i in pids if i not in train_pids] 66 | split = {'train': train_pids, 'test': test_pids} 67 | splits.append(split) 68 | 69 | print('Totally {} splits are created'.format(len(splits))) 70 | self.write_json(splits, self.split_path) 71 | print('Split file is saved to {}'.format(self.split_path)) 72 | 73 | def process_split(self, split): 74 | train_pids = split['train'] 75 | test_pids = split['test'] 76 | 77 | train_pid2label = {pid: label for label, pid in enumerate(train_pids)} 78 | 79 | # train 80 | train = [] 81 | for pid in 
train_pids: 82 | img_name = 'person_' + str(pid).zfill(4) + '.png' 83 | pid = train_pid2label[pid] 84 | img_a_path = os.path.join(self.cam_a_dir, img_name) 85 | train.append((img_a_path, pid, 0)) 86 | img_b_path = os.path.join(self.cam_b_dir, img_name) 87 | train.append((img_b_path, pid, 1)) 88 | 89 | # query and gallery 90 | query, gallery = [], [] 91 | for pid in test_pids: 92 | img_name = 'person_' + str(pid).zfill(4) + '.png' 93 | img_a_path = os.path.join(self.cam_a_dir, img_name) 94 | query.append((img_a_path, pid, 0)) 95 | img_b_path = os.path.join(self.cam_b_dir, img_name) 96 | gallery.append((img_b_path, pid, 1)) 97 | for pid in range(201, 750): 98 | img_name = 'person_' + str(pid).zfill(4) + '.png' 99 | img_b_path = os.path.join(self.cam_b_dir, img_name) 100 | gallery.append((img_b_path, pid, 1)) 101 | 102 | return train, query, gallery 103 | 104 | def read_json(self, fpath): 105 | import json 106 | """Reads json file from a path.""" 107 | with open(fpath, 'r') as f: 108 | obj = json.load(f) 109 | return obj 110 | 111 | 112 | def write_json(self, obj, fpath): 113 | import json 114 | """Writes to a json file.""" 115 | self.mkdir_if_missing(os.path.dirname(fpath)) 116 | with open(fpath, 'w') as f: 117 | json.dump(obj, f, indent=4, separators=(',', ': ')) 118 | 119 | def mkdir_if_missing(self, dirname): 120 | import errno 121 | """Creates dirname if it is missing.""" 122 | if not os.path.exists(dirname): 123 | try: 124 | os.makedirs(dirname) 125 | except OSError as e: 126 | if e.errno != errno.EEXIST: 127 | raise -------------------------------------------------------------------------------- /data/datasets/randperson.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | from glob import glob 4 | 5 | from ..datasets import DATASET_REGISTRY 6 | 7 | @DATASET_REGISTRY.register() 8 | class RandPerson(object): 9 | 10 | def __init__(self, root, combineall=True): 11 | 12 | self.images_dir = osp.join(root) 13 | self.img_path = 'your_path/randperson_subset/randperson_subset' 14 | self.train_path = self.img_path 15 | self.gallery_path = '' 16 | self.query_path = '' 17 | self.train = [] 18 | self.gallery = [] 19 | self.query = [] 20 | self.num_train_ids = 0 21 | self.has_time_info = True 22 | # self.show_train() 23 | 24 | def preprocess(self): 25 | fpaths = sorted(glob(osp.join(self.images_dir, self.train_path, '*g'))) 26 | 27 | data = [] 28 | all_pids = {} 29 | camera_offset = [0, 2, 4, 6, 8, 9, 10, 12, 13, 14, 15] 30 | frame_offset = [0, 160000, 340000,490000, 640000, 1070000, 1330000, 1590000, 1890000, 3190000, 3490000] 31 | fps = 24 32 | 33 | for fpath in fpaths: 34 | fname = osp.basename(fpath) # filename: id6_s2_c2_f6.jpg 35 | fields = fname.split('_') 36 | pid = int(fields[0]) 37 | if pid not in all_pids: 38 | all_pids[pid] = len(all_pids) 39 | pid = all_pids[pid] # relabel 40 | camid = camera_offset[int(fields[1][1:])] + int(fields[2][1:]) # make it starting from 0 41 | time = (frame_offset[int(fields[1][1:])] + int(fields[3][1:7])) / fps 42 | data.append((fpath, pid, camid, time)) 43 | # print(fname, pid, camid, time) 44 | return data, int(len(all_pids)) 45 | 46 | def show_train(self): 47 | self.train, self.num_train_ids = self.preprocess() 48 | 49 | print(self.__class__.__name__, "dataset loaded") 50 | print(" subset | # ids | # images") 51 | print(" ---------------------------") 52 | print(" all | {:5d} | {:8d}\n" 53 | .format(self.num_train_ids, len(self.train))) 
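# --- Illustrative note (not part of the original randperson.py) ---
# Worked example of the filename parsing in preprocess() above, assuming a
# hypothetical file name of the form "<pid>_s<scene>_c<cam>_f<frame>.jpg",
# e.g. "0_s01_c02_f000123.jpg":
#   pid   -> 0 (then relabeled through all_pids)
#   camid -> camera_offset[1] + 2 = 2 + 2 = 4
#   time  -> (frame_offset[1] + 123) / 24 = 160123 / 24, i.e. about 6671.8 s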
54 | -------------------------------------------------------------------------------- /data/datasets/sensereid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from . import DATASET_REGISTRY 5 | from .bases import ImageDataset 6 | 7 | __all__ = ['SenseReID', ] 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SenseReID(ImageDataset): 12 | dataset_dir = "SenseReID" 13 | dataset_name = "senseid" 14 | 15 | def __init__(self, root='datasets', **kwargs): 16 | self.root = root 17 | self.train_path = os.path.join(self.root, self.dataset_dir) 18 | 19 | required_files = [self.train_path] 20 | self.check_before_run(required_files) 21 | 22 | train = self.process_train(self.train_path) 23 | 24 | super().__init__(train, [], [], **kwargs) 25 | 26 | def process_train(self, train_path): 27 | data = [] 28 | file_path_list = ['test_gallery', 'test_prob'] 29 | 30 | for file_path in file_path_list: 31 | sub_file = os.path.join(train_path, file_path) 32 | img_name = glob(os.path.join(sub_file, "*.jpg")) 33 | for img_path in img_name: 34 | img_name = img_path.split('/')[-1] 35 | img_info = img_name.split('_') 36 | pid = self.dataset_name + "_" + img_info[0] 37 | camid = self.dataset_name + "_" + img_info[1].split('.')[0] 38 | data.append([img_path, pid, camid]) 39 | return data 40 | -------------------------------------------------------------------------------- /data/datasets/shinpuhkan.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from . import DATASET_REGISTRY 4 | from .bases import ImageDataset 5 | 6 | __all__ = ['Shinpuhkan', ] 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class Shinpuhkan(ImageDataset): 11 | dataset_dir = "shinpuhkan" 12 | dataset_name = 'shinpuhkan' 13 | 14 | def __init__(self, root='datasets', **kwargs): 15 | self.root = root 16 | self.train_path = os.path.join(self.root, self.dataset_dir) 17 | 18 | required_files = [self.train_path] 19 | self.check_before_run(required_files) 20 | 21 | train = self.process_train(self.train_path) 22 | 23 | super().__init__(train, [], [], **kwargs) 24 | 25 | def process_train(self, train_path): 26 | data = [] 27 | 28 | for root, dirs, files in os.walk(train_path): 29 | img_names = list(filter(lambda x: x.endswith(".jpg"), files)) 30 | # fmt: off 31 | if len(img_names) == 0: continue 32 | # fmt: on 33 | for img_name in img_names: 34 | img_path = os.path.join(root, img_name) 35 | split_path = img_name.split('_') 36 | pid = self.dataset_name + "_" + split_path[0] 37 | camid = self.dataset_name + "_" + split_path[2] 38 | data.append((img_path, pid, camid)) 39 | 40 | return data 41 | -------------------------------------------------------------------------------- /data/datasets/sysu_mm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | 5 | from . 
import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | import pdb 8 | 9 | __all__ = ['SYSU_mm', ] 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class SYSU_mm(ImageDataset): 14 | dataset_dir = "SYSU-MM01" 15 | dataset_name = "sysumm01" 16 | 17 | def __init__(self, root='datasets', **kwargs): 18 | self.root = root 19 | self.train_path = os.path.join(self.root, self.dataset_dir) 20 | 21 | required_files = [self.train_path] 22 | self.check_before_run(required_files) 23 | 24 | train = self.process_train(self.train_path) 25 | 26 | super().__init__(train, [], [], **kwargs) 27 | 28 | def process_train(self, train_path): 29 | data = [] 30 | 31 | file_path_list = ['cam1', 'cam2', 'cam4', 'cam5'] 32 | 33 | for file_path in file_path_list: 34 | camid = self.dataset_name + "_" + file_path 35 | pid_list = os.listdir(os.path.join(train_path, file_path)) 36 | for pid_dir in pid_list: 37 | pid = self.dataset_name + "_" + pid_dir 38 | img_list = glob(os.path.join(train_path, file_path, pid_dir, "*.jpg")) 39 | for img_path in img_list: 40 | data.append([img_path, pid, camid]) 41 | return data 42 | 43 | -------------------------------------------------------------------------------- /data/datasets/thermalworld.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | from glob import glob 4 | 5 | from . import DATASET_REGISTRY 6 | from .bases import ImageDataset 7 | import pdb 8 | import random 9 | import numpy as np 10 | 11 | __all__ = ['Thermalworld',] 12 | 13 | 14 | @DATASET_REGISTRY.register() 15 | class Thermalworld(ImageDataset): 16 | dataset_dir = "thermalworld_rgb" 17 | dataset_name = "thermalworld" 18 | 19 | def __init__(self, root='datasets', **kwargs): 20 | self.root = root 21 | self.train_path = os.path.join(self.root, self.dataset_dir) 22 | 23 | required_files = [self.train_path] 24 | self.check_before_run(required_files) 25 | 26 | train = self.process_train(self.train_path) 27 | 28 | super().__init__(train, [], [], **kwargs) 29 | 30 | def process_train(self, train_path): 31 | data = [] 32 | pid_list = os.listdir(train_path) 33 | for pid_dir in pid_list: 34 | pid = self.dataset_name + "_" + pid_dir 35 | img_list = glob(os.path.join(train_path, pid_dir, "*.jpg")) 36 | for img_path in img_list: 37 | camid = self.dataset_name + "_cam0" 38 | data.append([img_path, pid, camid]) 39 | return data 40 | -------------------------------------------------------------------------------- /data/datasets/vehicleid.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | 4 | from .bases import ImageDataset 5 | from ..datasets import DATASET_REGISTRY 6 | 7 | 8 | @DATASET_REGISTRY.register() 9 | class VehicleID(ImageDataset): 10 | """VehicleID. 11 | 12 | Reference: 13 | Liu et al. Deep relative distance learning: Tell the difference between similar vehicles. CVPR 2016. 14 | 15 | URL: ``_ 16 | 17 | Train dataset statistics: 18 | - identities: 13164. 19 | - images: 113346. 
20 | """ 21 | dataset_dir = "vehicleid" 22 | dataset_name = "vehicleid" 23 | 24 | def __init__(self, root='datasets', test_list='', **kwargs): 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | self.image_dir = osp.join(self.dataset_dir, 'image') 27 | self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt') 28 | if test_list: 29 | self.test_list = test_list 30 | else: 31 | self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_13164.txt') 32 | 33 | required_files = [ 34 | self.dataset_dir, 35 | self.image_dir, 36 | self.train_list, 37 | self.test_list, 38 | ] 39 | self.check_before_run(required_files) 40 | 41 | train = self.process_dir(self.train_list, is_train=True) 42 | query, gallery = self.process_dir(self.test_list, is_train=False) 43 | 44 | super(VehicleID, self).__init__(train, query, gallery, **kwargs) 45 | 46 | def process_dir(self, list_file, is_train=True): 47 | img_list_lines = open(list_file, 'r').readlines() 48 | 49 | dataset = [] 50 | for idx, line in enumerate(img_list_lines): 51 | line = line.strip() 52 | vid = int(line.split(' ')[1]) 53 | imgid = line.split(' ')[0] 54 | img_path = osp.join(self.image_dir, imgid + '.jpg') 55 | if is_train: 56 | vid = self.dataset_name + "_" + str(vid) 57 | dataset.append((img_path, vid, int(imgid))) 58 | 59 | if is_train: return dataset 60 | else: 61 | random.shuffle(dataset) 62 | vid_container = set() 63 | query = [] 64 | gallery = [] 65 | for sample in dataset: 66 | if sample[1] not in vid_container: 67 | vid_container.add(sample[1]) 68 | gallery.append(sample) 69 | else: 70 | query.append(sample) 71 | 72 | return query, gallery 73 | 74 | 75 | @DATASET_REGISTRY.register() 76 | class SmallVehicleID(VehicleID): 77 | """VehicleID. 78 | Small test dataset statistics: 79 | - identities: 800. 80 | - images: 6493. 81 | """ 82 | 83 | def __init__(self, root='datasets', **kwargs): 84 | # self.dataset_dir = osp.join(root, self.dataset_dir) 85 | self.test_list = osp.join(root, self.dataset_dir, 'train_test_split/test_list_800.txt') 86 | 87 | super(SmallVehicleID, self).__init__(root, self.test_list, **kwargs) 88 | 89 | 90 | @DATASET_REGISTRY.register() 91 | class MediumVehicleID(VehicleID): 92 | """VehicleID. 93 | Medium test dataset statistics: 94 | - identities: 1600. 95 | - images: 13377. 96 | """ 97 | 98 | def __init__(self, root='datasets', **kwargs): 99 | # self.dataset_dir = osp.join(root, self.dataset_dir) 100 | self.test_list = osp.join(root, self.dataset_dir, 'train_test_split/test_list_1600.txt') 101 | 102 | super(MediumVehicleID, self).__init__(root, self.test_list, **kwargs) 103 | 104 | 105 | @DATASET_REGISTRY.register() 106 | class LargeVehicleID(VehicleID): 107 | """VehicleID. 108 | Large test dataset statistics: 109 | - identities: 2400. 110 | - images: 19777. 111 | """ 112 | 113 | def __init__(self, root='datasets', **kwargs): 114 | # self.dataset_dir = osp.join(root, self.dataset_dir) 115 | self.test_list = osp.join(root, self.dataset_dir, 'train_test_split/test_list_2400.txt') 116 | 117 | super(LargeVehicleID, self).__init__(root, self.test_list, **kwargs) 118 | -------------------------------------------------------------------------------- /data/datasets/veri.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import re 4 | 5 | from .bases import ImageDataset 6 | from ..datasets import DATASET_REGISTRY 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class VeRi(ImageDataset): 11 | """VeRi. 
12 | 13 | Reference: 14 | Liu et al. A Deep Learning based Approach for Progressive Vehicle Re-Identification. ECCV 2016. 15 | 16 | URL: ``_ 17 | 18 | Dataset statistics: 19 | - identities: 775. 20 | - images: 37778 (train) + 1678 (query) + 11579 (gallery). 21 | """ 22 | dataset_dir = "veri" 23 | dataset_name = "veri" 24 | 25 | def __init__(self, root='datasets', **kwargs): 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | 28 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 31 | 32 | required_files = [ 33 | self.dataset_dir, 34 | self.train_dir, 35 | self.query_dir, 36 | self.gallery_dir, 37 | ] 38 | self.check_before_run(required_files) 39 | 40 | train = self.process_dir(self.train_dir) 41 | query = self.process_dir(self.query_dir, is_train=False) 42 | gallery = self.process_dir(self.gallery_dir, is_train=False) 43 | 44 | super(VeRi, self).__init__(train, query, gallery, **kwargs) 45 | 46 | def process_dir(self, dir_path, is_train=True): 47 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 48 | pattern = re.compile(r'([\d]+)_c(\d\d\d)') 49 | 50 | data = [] 51 | for img_path in img_paths: 52 | pid, camid = map(int, pattern.search(img_path).groups()) 53 | if pid == -1: continue # junk images are just ignored 54 | assert 1 <= pid <= 776 55 | assert 1 <= camid <= 20 56 | camid -= 1 # index starts from 0 57 | if is_train: 58 | pid = self.dataset_name + "_" + str(pid) 59 | data.append((img_path, pid, camid)) 60 | 61 | return data 62 | -------------------------------------------------------------------------------- /data/datasets/veri_keypoint.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import re 4 | 5 | from .bases import ImageDataset 6 | from ..datasets import DATASET_REGISTRY 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class VeRi_keypoint(ImageDataset): 11 | """VeRi. 12 | 13 | Reference: 14 | Liu et al. A Deep Learning based Approach for Progressive Vehicle Re-Identification. ECCV 2016. 15 | 16 | URL: ``_ 17 | 18 | Dataset statistics: 19 | - identities: 775. 20 | - images: 37778 (train) + 1678 (query) + 11579 (gallery). 
21 | """ 22 | dataset_dir = "veri" 23 | dataset_name = "veri" 24 | 25 | def __init__(self, root='datasets', **kwargs): 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.keypoint_dir = osp.join(root, 'veri_keypoint') 28 | 29 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 30 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 31 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 32 | 33 | required_files = [ 34 | self.dataset_dir, 35 | self.train_dir, 36 | self.query_dir, 37 | self.gallery_dir, 38 | self.keypoint_dir 39 | ] 40 | self.check_before_run(required_files) 41 | 42 | train = self.process_dir(self.train_dir) 43 | train = self.process_keypoint(self.keypoint_dir, train) 44 | query = self.process_dir(self.query_dir, is_train=False) 45 | gallery = self.process_dir(self.gallery_dir, is_train=False) 46 | 47 | super(VeRi_keypoint, self).__init__(train, query, gallery, **kwargs) 48 | 49 | def process_dir(self, dir_path, is_train=True): 50 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 51 | pattern = re.compile(r'([\d]+)_c(\d\d\d)') 52 | 53 | data = [] 54 | for img_path in img_paths: 55 | pid, camid = map(int, pattern.search(img_path).groups()) 56 | if pid == -1: continue # junk images are just ignored 57 | assert 1 <= pid <= 776 58 | assert 1 <= camid <= 20 59 | camid -= 1 # index starts from 0 60 | if is_train: 61 | pid = self.dataset_name + "_" + str(pid) 62 | data.append((img_path, pid, camid)) 63 | 64 | 65 | return data 66 | 67 | 68 | def process_keypoint(self, dir_path, data): 69 | train_name = [] 70 | train_raw = [] 71 | train_keypoint = [] 72 | train_orientation = [] 73 | is_keypoint = False 74 | is_orientation = True 75 | is_aligned = False 76 | with open(osp.join(dir_path, 'keypoint_train_aligned.txt')) as f: 77 | for line in f: 78 | train_raw.append(line) 79 | line_split = line.split(' ') 80 | train_name.append(line_split[0].split('/')[-1]) 81 | 82 | if is_keypoint: 83 | train_keypoint.append(line_split[1:41]) 84 | if is_orientation: 85 | tmp = line_split[-1] 86 | if '\n' in tmp: 87 | tmp = tmp[0] 88 | assert 0 <= int(tmp) <= 7 # orientation should be 0~7 89 | train_orientation.append(int(tmp)) 90 | 91 | if is_aligned: 92 | train_name = sorted(tuple(train_name)) 93 | train_raw = sorted(tuple(train_raw)) 94 | 95 | with open(osp.join(dir_path, 'keypoint_train_aligned.txt'), 'w') as f: 96 | for i, x in enumerate(data): 97 | j = 0 98 | flag_break = False 99 | while (j < len(train_name) and not flag_break): 100 | if train_name[j] in x[0]: 101 | if train_name[j] in train_raw[j]: 102 | f.write(train_raw[j]) 103 | flag_break = True 104 | del train_name[j] 105 | del train_raw[j] 106 | print(i) 107 | else: 108 | assert() 109 | j += 1 110 | 111 | 112 | for i, x in enumerate(data): 113 | j = 0 114 | flag_break = False 115 | while(j < len(train_name) and not flag_break): 116 | if train_name[j] in x[0]: 117 | add_info = {} # dictionary 118 | add_info['domains'] = int(train_orientation[j]) 119 | data[i] = list(data[i]) 120 | data[i].append(add_info) 121 | data[i] = tuple(data[i]) 122 | flag_break = True 123 | del train_name[j] 124 | del train_orientation[j] 125 | # print(i) 126 | j += 1 127 | 128 | cnt = 0 129 | no_title = [] 130 | no_title_local = [] 131 | for line in data: 132 | if len(line) != 4: 133 | assert() 134 | # no_title.append(line[0]) 135 | # tmp1 = line[0].split('/')[-1] 136 | # tmp2 = tmp1.split('_') 137 | # tmp3 = '_'.join(tmp2[2:]) 138 | # for line2 in train_name: 139 | # if tmp3 in line2: 140 | # print(line2) 141 | # 
no_title_local.append(tmp3) 142 | # cnt += 1 143 | 144 | return data -------------------------------------------------------------------------------- /data/datasets/veriwild.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from .bases import ImageDataset 4 | from ..datasets import DATASET_REGISTRY 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class VeRiWild(ImageDataset): 9 | """VeRi-Wild. 10 | 11 | Reference: 12 | Lou et al. A Large-Scale Dataset for Vehicle Re-Identification in the Wild. CVPR 2019. 13 | 14 | URL: ``_ 15 | 16 | Train dataset statistics: 17 | - identities: 30671. 18 | - images: 277797. 19 | """ 20 | dataset_dir = "VERI-Wild" 21 | dataset_name = "veriwild" 22 | 23 | def __init__(self, root='datasets', query_list='', gallery_list='', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | 26 | self.image_dir = osp.join(self.dataset_dir, 'images') 27 | self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt') 28 | self.vehicle_info = osp.join(self.dataset_dir, 'train_test_split/vehicle_info.txt') 29 | if query_list and gallery_list: 30 | self.query_list = query_list 31 | self.gallery_list = gallery_list 32 | else: 33 | self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_10000_query.txt') 34 | self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_10000.txt') 35 | 36 | required_files = [ 37 | self.image_dir, 38 | self.train_list, 39 | self.query_list, 40 | self.gallery_list, 41 | self.vehicle_info, 42 | ] 43 | self.check_before_run(required_files) 44 | 45 | self.imgid2vid, self.imgid2camid, self.imgid2imgpath = self.process_vehicle(self.vehicle_info) 46 | 47 | train = self.process_dir(self.train_list) 48 | query = self.process_dir(self.query_list, is_train=False) 49 | gallery = self.process_dir(self.gallery_list, is_train=False) 50 | 51 | super(VeRiWild, self).__init__(train, query, gallery, **kwargs) 52 | 53 | def process_dir(self, img_list, is_train=True): 54 | img_list_lines = open(img_list, 'r').readlines() 55 | 56 | dataset = [] 57 | for idx, line in enumerate(img_list_lines): 58 | line = line.strip() 59 | vid = int(line.split('/')[0]) 60 | imgid = line.split('/')[1] 61 | if is_train: 62 | vid = self.dataset_name + "_" + str(vid) 63 | dataset.append((self.imgid2imgpath[imgid], vid, int(self.imgid2camid[imgid]))) 64 | 65 | assert len(dataset) == len(img_list_lines) 66 | return dataset 67 | 68 | def process_vehicle(self, vehicle_info): 69 | imgid2vid = {} 70 | imgid2camid = {} 71 | imgid2imgpath = {} 72 | vehicle_info_lines = open(vehicle_info, 'r').readlines() 73 | 74 | for idx, line in enumerate(vehicle_info_lines[1:]): 75 | vid = line.strip().split('/')[0] 76 | imgid = line.strip().split(';')[0].split('/')[1] 77 | camid = line.strip().split(';')[1] 78 | # img_path = osp.join(self.image_dir, vid, imgid + '.jpg') 79 | img_path = osp.join(self.image_dir, imgid + '.jpg') 80 | imgid2vid[imgid] = vid 81 | imgid2camid[imgid] = camid 82 | imgid2imgpath[imgid] = img_path 83 | 84 | assert len(imgid2vid) == len(vehicle_info_lines) - 1 85 | return imgid2vid, imgid2camid, imgid2imgpath 86 | 87 | 88 | @DATASET_REGISTRY.register() 89 | class SmallVeRiWild(VeRiWild): 90 | """VeRi-Wild. 91 | Small test dataset statistics: 92 | - identities: 3000. 93 | - images: 41861. 
94 | """ 95 | 96 | def __init__(self, root='datasets', **kwargs): 97 | # self.dataset_dir = osp.join(root, self.dataset_dir) 98 | self.query_list = osp.join(root, self.dataset_dir, 'train_test_split/test_3000_query.txt') 99 | self.gallery_list = osp.join(root, self.dataset_dir, 'train_test_split/test_3000.txt') 100 | 101 | super(SmallVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs) 102 | 103 | 104 | @DATASET_REGISTRY.register() 105 | class MediumVeRiWild(VeRiWild): 106 | """VeRi-Wild. 107 | Medium test dataset statistics: 108 | - identities: 5000. 109 | - images: 69389. 110 | """ 111 | 112 | def __init__(self, root='datasets', **kwargs): 113 | # self.dataset_dir = osp.join(root, self.dataset_dir) 114 | self.query_list = osp.join(root, self.dataset_dir, 'train_test_split/test_5000_query.txt') 115 | self.gallery_list = osp.join(root, self.dataset_dir, 'train_test_split/test_5000.txt') 116 | 117 | super(MediumVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs) 118 | 119 | 120 | @DATASET_REGISTRY.register() 121 | class LargeVeRiWild(VeRiWild): 122 | """VeRi-Wild. 123 | Large test dataset statistics: 124 | - identities: 10000. 125 | - images: 138517. 126 | """ 127 | 128 | def __init__(self, root='datasets', **kwargs): 129 | # self.dataset_dir = osp.join(root, self.dataset_dir) 130 | self.query_list = osp.join(root, self.dataset_dir, 'train_test_split/test_10000_query.txt') 131 | self.gallery_list = osp.join(root, self.dataset_dir, 'train_test_split/test_10000.txt') 132 | 133 | super(LargeVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs) 134 | -------------------------------------------------------------------------------- /data/datasets/viper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from . 
import DATASET_REGISTRY 5 | from .bases import ImageDataset 6 | 7 | __all__ = ['VIPeR', ] 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class VIPeR(ImageDataset): 12 | dataset_dir = "VIPeR" 13 | dataset_name = "viper" 14 | 15 | def __init__(self, root='datasets', **kwargs): 16 | self.root = root 17 | self.train_path = os.path.join(self.root, self.dataset_dir) 18 | 19 | required_files = [self.train_path] 20 | self.check_before_run(required_files) 21 | 22 | train = self.process_train(self.train_path) 23 | 24 | super().__init__(train, [], [], **kwargs) 25 | 26 | def process_train(self, train_path): 27 | data = [] 28 | 29 | file_path_list = ['cam_a', 'cam_b'] 30 | 31 | for file_path in file_path_list: 32 | camid = self.dataset_name + "_" + file_path 33 | img_list = glob(os.path.join(train_path, file_path, "*.bmp")) 34 | for img_path in img_list: 35 | img_name = img_path.split('/')[-1] 36 | pid = self.dataset_name + "_" + img_name.split('_')[0] 37 | data.append([img_path, pid, camid]) 38 | 39 | return data -------------------------------------------------------------------------------- /data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, DomainSuffleSampler, RandomIdentitySampler 2 | from .data_sampler import TrainingSampler, InferenceSampler 3 | -------------------------------------------------------------------------------- /data/samplers/data_sampler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Optional 3 | 4 | import numpy as np 5 | from torch.utils.data import Sampler 6 | 7 | from utils import comm 8 | 9 | 10 | class TrainingSampler(Sampler): 11 | """ 12 | In training, we only care about the "infinite stream" of training data. 13 | So this sampler produces an infinite stream of indices and 14 | all workers cooperate to correctly shuffle the indices and sample different indices. 15 | The samplers in each worker effectively produces `indices[worker_id::num_workers]` 16 | where `indices` is an infinite stream of indices consisting of 17 | `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) 18 | or `range(size) + range(size) + ...` (if shuffle is False) 19 | """ 20 | 21 | def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None): 22 | """ 23 | Args: 24 | size (int): the total number of data of the underlying dataset to sample from 25 | shuffle (bool): whether to shuffle the indices or not 26 | seed (int): the initial seed of the shuffle. Must be the same 27 | across all workers. If None, will use a random seed shared 28 | among workers (require synchronization among all workers). 29 | """ 30 | self._size = size 31 | assert size > 0 32 | self._shuffle = shuffle 33 | if seed is None: 34 | seed = comm.shared_random_seed() 35 | self._seed = int(seed) 36 | 37 | self._rank = comm.get_rank() 38 | self._world_size = comm.get_world_size() 39 | 40 | def __iter__(self): 41 | start = self._rank 42 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) 43 | 44 | def _infinite_indices(self): 45 | np.random.seed(self._seed) 46 | while True: 47 | if self._shuffle: 48 | yield from np.random.permutation(self._size) 49 | else: 50 | yield from np.arange(self._size) 51 | 52 | 53 | class InferenceSampler(Sampler): 54 | """ 55 | Produce indices for inference. 
56 | Inference needs to run on the __exact__ set of samples, 57 | therefore when the total number of samples is not divisible by the number of workers, 58 | this sampler produces different number of samples on different workers. 59 | """ 60 | 61 | def __init__(self, size: int): 62 | """ 63 | Args: 64 | size (int): the total number of data of the underlying dataset to sample from 65 | """ 66 | self._size = size 67 | assert size > 0 68 | 69 | begin = 0 70 | end = self._size 71 | self._local_indices = range(begin, end) 72 | 73 | def __iter__(self): 74 | yield from self._local_indices 75 | 76 | def __len__(self): 77 | return len(self._local_indices) 78 | -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_transforms 2 | from .transforms import * 3 | from .autoaugment import * 4 | -------------------------------------------------------------------------------- /data/transforms/build.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as T 3 | 4 | from .transforms import * 5 | from .autoaugment import AutoAugment 6 | from PIL import Image, ImageFilter, ImageOps 7 | 8 | from .transforms import LGT 9 | 10 | class GaussianBlur(object): 11 | """ 12 | Apply Gaussian Blur to the PIL image. 13 | """ 14 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 15 | self.prob = p 16 | self.radius_min = radius_min 17 | self.radius_max = radius_max 18 | 19 | def __call__(self, img): 20 | do_it = random.random() <= self.prob 21 | if not do_it: 22 | return img 23 | 24 | return img.filter( 25 | ImageFilter.GaussianBlur( 26 | radius=random.uniform(self.radius_min, self.radius_max) 27 | ) 28 | ) 29 | 30 | 31 | class Solarization(object): 32 | """ 33 | Apply Solarization to the PIL image. 
34 | """ 35 | def __init__(self, p): 36 | self.p = p 37 | 38 | def __call__(self, img): 39 | if random.random() < self.p: 40 | return ImageOps.solarize(img) 41 | else: 42 | return img 43 | 44 | def build_transforms(cfg, is_train=True, is_fake=False): 45 | res = [] 46 | 47 | if is_train: 48 | size_train = cfg.INPUT.SIZE_TRAIN 49 | 50 | # augmix augmentation 51 | do_augmix = cfg.INPUT.DO_AUGMIX 52 | 53 | # auto augmentation 54 | do_autoaug = cfg.INPUT.DO_AUTOAUG 55 | # total_iter = cfg.SOLVER.MAX_ITER 56 | total_iter = cfg.SOLVER.MAX_EPOCHS 57 | 58 | # horizontal filp 59 | do_flip = cfg.INPUT.DO_FLIP 60 | flip_prob = cfg.INPUT.FLIP_PROB 61 | 62 | # padding 63 | do_pad = cfg.INPUT.DO_PAD 64 | padding = cfg.INPUT.PADDING 65 | padding_mode = cfg.INPUT.PADDING_MODE 66 | 67 | # Local Grayscale Transfomation 68 | do_lgt = cfg.INPUT.LGT.DO_LGT 69 | lgt_prob = cfg.INPUT.LGT.PROB 70 | 71 | # color jitter 72 | do_cj = cfg.INPUT.CJ.ENABLED 73 | cj_prob = cfg.INPUT.CJ.PROB 74 | cj_brightness = cfg.INPUT.CJ.BRIGHTNESS 75 | cj_contrast = cfg.INPUT.CJ.CONTRAST 76 | cj_saturation = cfg.INPUT.CJ.SATURATION 77 | cj_hue = cfg.INPUT.CJ.HUE 78 | 79 | # random erasing 80 | do_rea = cfg.INPUT.REA.ENABLED 81 | rea_prob = cfg.INPUT.REA.PROB 82 | rea_mean = cfg.INPUT.REA.MEAN 83 | # random patch 84 | do_rpt = cfg.INPUT.RPT.ENABLED 85 | rpt_prob = cfg.INPUT.RPT.PROB 86 | 87 | if do_autoaug: 88 | res.append(AutoAugment(total_iter)) 89 | res.append(T.Resize(size_train, interpolation=3)) 90 | if do_flip: 91 | res.append(T.RandomHorizontalFlip(p=flip_prob)) 92 | if do_pad: 93 | res.extend([T.Pad(padding, padding_mode=padding_mode), 94 | T.RandomCrop(size_train)]) 95 | if do_lgt: 96 | res.append(LGT(lgt_prob)) 97 | if do_cj: 98 | res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob)) 99 | if do_augmix: 100 | res.append(AugMix()) 101 | # if do_rea: 102 | # res.append(RandomErasing(probability=rea_prob, mean=rea_mean, sh=1/3)) 103 | if do_rpt: 104 | res.append(RandomPatch(prob_happen=rpt_prob)) 105 | if is_fake: 106 | if cfg.META.DATA.SYNTH_FLAG == 'jitter': 107 | res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=1.0)) 108 | elif cfg.META.DATA.SYNTH_FLAG == 'augmix': 109 | res.append(AugMix()) 110 | elif cfg.META.DATA.SYNTH_FLAG == 'both': 111 | res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob)) 112 | res.append(AugMix()) 113 | res.extend([ 114 | T.ToTensor(), 115 | T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]) 116 | ]) 117 | if do_rea: 118 | from timm.data.random_erasing import RandomErasing as RE 119 | res.append(RE(probability=rea_prob, mode='pixel', max_count=1, device='cpu')) 120 | else: 121 | size_test = cfg.INPUT.SIZE_TEST 122 | res.append(T.Resize(size_test, interpolation=3)) 123 | res.extend([ 124 | T.ToTensor(), 125 | T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]) 126 | ]) 127 | return T.Compose(res) 128 | -------------------------------------------------------------------------------- /data/transforms/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image, ImageOps, ImageEnhance 4 | 5 | 6 | def to_tensor(pic): 7 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 8 | 9 | See ``ToTensor`` for more details. 10 | 11 | Args: 12 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 13 | 14 | Returns: 15 | Tensor: Converted image. 
16 | """ 17 | if isinstance(pic, np.ndarray): 18 | assert len(pic.shape) in (2, 3) 19 | # handle numpy array 20 | if pic.ndim == 2: 21 | pic = pic[:, :, None] 22 | 23 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 24 | # backward compatibility 25 | if isinstance(img, torch.ByteTensor): 26 | return img.float() 27 | else: 28 | return img 29 | 30 | # handle PIL Image 31 | if pic.mode == 'I': 32 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 33 | elif pic.mode == 'I;16': 34 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 35 | elif pic.mode == 'F': 36 | img = torch.from_numpy(np.array(pic, np.float32, copy=False)) 37 | elif pic.mode == '1': 38 | img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) 39 | else: 40 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 41 | # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK 42 | if pic.mode == 'YCbCr': 43 | nchannel = 3 44 | elif pic.mode == 'I;16': 45 | nchannel = 1 46 | else: 47 | nchannel = len(pic.mode) 48 | img = img.view(pic.size[1], pic.size[0], nchannel) 49 | # put it from HWC to CHW format 50 | # yikes, this transpose takes 80% of the loading time/CPU 51 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 52 | if isinstance(img, torch.ByteTensor): 53 | return img.float() 54 | else: 55 | return img 56 | 57 | 58 | def int_parameter(level, maxval): 59 | """Helper function to scale `val` between 0 and maxval . 60 | Args: 61 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 62 | maxval: Maximum value that the operation can have. This will be scaled to 63 | level/PARAMETER_MAX. 64 | Returns: 65 | An int that results from scaling `maxval` according to `level`. 66 | """ 67 | return int(level * maxval / 10) 68 | 69 | 70 | def float_parameter(level, maxval): 71 | """Helper function to scale `val` between 0 and maxval. 72 | Args: 73 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 74 | maxval: Maximum value that the operation can have. This will be scaled to 75 | level/PARAMETER_MAX. 76 | Returns: 77 | A float that results from scaling `maxval` according to `level`. 78 | """ 79 | return float(level) * maxval / 10. 
80 | 81 | 82 | def sample_level(n): 83 | return np.random.uniform(low=0.1, high=n) 84 | 85 | 86 | def autocontrast(pil_img, *args): 87 | return ImageOps.autocontrast(pil_img) 88 | 89 | 90 | def equalize(pil_img, *args): 91 | return ImageOps.equalize(pil_img) 92 | 93 | 94 | def posterize(pil_img, level, *args): 95 | level = int_parameter(sample_level(level), 4) 96 | return ImageOps.posterize(pil_img, 4 - level) 97 | 98 | 99 | def rotate(pil_img, level, *args): 100 | degrees = int_parameter(sample_level(level), 30) 101 | if np.random.uniform() > 0.5: 102 | degrees = -degrees 103 | return pil_img.rotate(degrees, resample=Image.BILINEAR) 104 | 105 | 106 | def solarize(pil_img, level, *args): 107 | level = int_parameter(sample_level(level), 256) 108 | return ImageOps.solarize(pil_img, 256 - level) 109 | 110 | 111 | def shear_x(pil_img, level, image_size): 112 | level = float_parameter(sample_level(level), 0.3) 113 | if np.random.uniform() > 0.5: 114 | level = -level 115 | return pil_img.transform(image_size, 116 | Image.AFFINE, (1, level, 0, 0, 1, 0), 117 | resample=Image.BILINEAR) 118 | 119 | 120 | def shear_y(pil_img, level, image_size): 121 | level = float_parameter(sample_level(level), 0.3) 122 | if np.random.uniform() > 0.5: 123 | level = -level 124 | return pil_img.transform(image_size, 125 | Image.AFFINE, (1, 0, 0, level, 1, 0), 126 | resample=Image.BILINEAR) 127 | 128 | 129 | def translate_x(pil_img, level, image_size): 130 | level = int_parameter(sample_level(level), image_size[0] / 3) 131 | if np.random.random() > 0.5: 132 | level = -level 133 | return pil_img.transform(image_size, 134 | Image.AFFINE, (1, 0, level, 0, 1, 0), 135 | resample=Image.BILINEAR) 136 | 137 | 138 | def translate_y(pil_img, level, image_size): 139 | level = int_parameter(sample_level(level), image_size[1] / 3) 140 | if np.random.random() > 0.5: 141 | level = -level 142 | return pil_img.transform(image_size, 143 | Image.AFFINE, (1, 0, 0, 0, 1, level), 144 | resample=Image.BILINEAR) 145 | 146 | 147 | # operation that overlaps with ImageNet-C's test set 148 | def color(pil_img, level, *args): 149 | level = float_parameter(sample_level(level), 1.8) + 0.1 150 | return ImageEnhance.Color(pil_img).enhance(level) 151 | 152 | 153 | # operation that overlaps with ImageNet-C's test set 154 | def contrast(pil_img, level, *args): 155 | level = float_parameter(sample_level(level), 1.8) + 0.1 156 | return ImageEnhance.Contrast(pil_img).enhance(level) 157 | 158 | 159 | # operation that overlaps with ImageNet-C's test set 160 | def brightness(pil_img, level, *args): 161 | level = float_parameter(sample_level(level), 1.8) + 0.1 162 | return ImageEnhance.Brightness(pil_img).enhance(level) 163 | 164 | 165 | # operation that overlaps with ImageNet-C's test set 166 | def sharpness(pil_img, level, *args): 167 | level = float_parameter(sample_level(level), 1.8) + 0.1 168 | return ImageEnhance.Sharpness(pil_img).enhance(level) 169 | 170 | 171 | augmentations_reid = [ 172 | autocontrast, equalize, posterize, shear_x, shear_y, 173 | color, contrast, brightness, sharpness 174 | ] 175 | 176 | augmentations = [ 177 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 178 | translate_x, translate_y 179 | ] 180 | 181 | augmentations_all = [ 182 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 183 | translate_x, translate_y, color, contrast, brightness, sharpness 184 | ] 185 | -------------------------------------------------------------------------------- /enviroments.sh: 
-------------------------------------------------------------------------------- 1 | pip install torch torchvision torchaudio 2 | pip install einops 3 | pip install timm 4 | pip install scikit-image 5 | pip install opencv-python 6 | pip install tensorboard 7 | pip install yacs 8 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_loss import make_loss 2 | from .arcface import ArcFace 3 | from .smooth import * 4 | from .myloss import * 5 | -------------------------------------------------------------------------------- /loss/arcface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | import math 6 | 7 | 8 | class ArcFace(nn.Module): 9 | def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False): 10 | super(ArcFace, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.s = s 14 | self.m = m 15 | self.cos_m = math.cos(m) 16 | self.sin_m = math.sin(m) 17 | 18 | self.th = math.cos(math.pi - m) 19 | self.mm = math.sin(math.pi - m) * m 20 | 21 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 22 | if bias: 23 | self.bias = Parameter(torch.Tensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 30 | if self.bias is not None: 31 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 32 | bound = 1 / math.sqrt(fan_in) 33 | nn.init.uniform_(self.bias, -bound, bound) 34 | 35 | def forward(self, input, label): 36 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 37 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 38 | phi = cosine * self.cos_m - sine * self.sin_m 39 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 40 | # --------------------------- convert label to one-hot --------------------------- 41 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 42 | one_hot = torch.zeros(cosine.size(), device='cuda') 43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 44 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 45 | output = (one_hot * phi) + ( 46 | (1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 47 | output *= self.s 48 | # print(output) 49 | 50 | return output 51 | 52 | class CircleLoss(nn.Module): 53 | def __init__(self, in_features, num_classes, s=256, m=0.25): 54 | super(CircleLoss, self).__init__() 55 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 56 | self.s = s 57 | self.m = m 58 | self._num_classes = num_classes 59 | self.reset_parameters() 60 | 61 | 62 | def reset_parameters(self): 63 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 64 | 65 | def __call__(self, bn_feat, targets): 66 | 67 | sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight)) 68 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 69 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 
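        # alpha_p / alpha_n are the self-paced weights of Circle Loss (Sun et al., CVPR 2020):
        # the further a pair is from its optimum (O_p = 1 + m for positives, O_n = -m for
        # negatives), the larger its weight. delta_p / delta_n below are the decision margins.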
70 | delta_p = 1 - self.m 71 | delta_n = self.m 72 | 73 | s_p = self.s * alpha_p * (sim_mat - delta_p) 74 | s_n = self.s * alpha_n * (sim_mat - delta_n) 75 | 76 | targets = F.one_hot(targets, num_classes=self._num_classes) 77 | 78 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 79 | 80 | return pred_class_logits -------------------------------------------------------------------------------- /loss/build_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy 3 | from .triplet_loss import TripletLoss 4 | from .center_loss import CenterLoss 5 | from .ce_labelSmooth import CrossEntropyLabelSmooth as CE_LS 6 | 7 | feat_dim_dict = { 8 | 'local_attention_vit': 768, 9 | 'vit': 768, 10 | 'resnet18': 512, 11 | 'resnet34': 512 12 | } 13 | 14 | def build_loss(cfg, num_classes): 15 | name = cfg.MODEL.NAME 16 | sampler = cfg.DATALOADER.SAMPLER 17 | if cfg.MODEL.NAME not in feat_dim_dict.keys(): 18 | feat_dim = 2048 19 | else: 20 | feat_dim = feat_dim_dict[cfg.MODEL.NAME] 21 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 22 | if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE: 23 | if cfg.MODEL.NO_MARGIN: 24 | triplet = TripletLoss() 25 | print("using soft triplet loss for training") 26 | else: 27 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 28 | print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN)) 29 | else: 30 | print('expected METRIC_LOSS_TYPE should be triplet' 31 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 32 | 33 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 34 | if name == 'local_attention_vit' and cfg.MODEL.PC_LOSS: 35 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) 36 | else: 37 | xent = CE_LS(num_classes=num_classes) 38 | print("label smooth on, numclasses:", num_classes) 39 | 40 | if sampler == 'softmax': # softmax loss only 41 | def loss_func(score, feat, target): 42 | return F.cross_entropy(score, target) 43 | 44 | # softmax & triplet 45 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet' or 'GS': 46 | def loss_func(score, feat, target, domains=None, t_domains=None, all_posvid=None, soft_label=False, soft_weight=0.1, soft_lambda=0.2): 47 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 48 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 49 | if name == 'local_attention_vit' and cfg.MODEL.PC_LOSS: 50 | ID_LOSS = xent(score, target, all_posvid=all_posvid, soft_label=soft_label,soft_weight=soft_weight, soft_lambda=soft_lambda) 51 | else: 52 | ID_LOSS = xent(score, target) 53 | else: 54 | ID_LOSS = F.cross_entropy(score, target) 55 | 56 | TRI_LOSS = triplet(feat, target)[0] 57 | # DOMAIN_LOSS = xent(domains, t_domains) 58 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 59 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 60 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center': 61 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 62 | return xent(score, target) + \ 63 | triplet(feat, target)[0] + \ 64 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 65 | else: 66 | return F.cross_entropy(score, target) + \ 67 | triplet(feat, target)[0] + \ 68 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 69 | else: 70 | print('expected METRIC_LOSS_TYPE with center should be center, triplet_center' 71 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 72 | 73 | else: 74 | print('expected sampler should be softmax, triplet, softmax_triplet or 
softmax_triplet_center' 75 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 76 | return loss_func, center_criterion 77 | 78 | 79 | -------------------------------------------------------------------------------- /loss/ce_labelSmooth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | class CrossEntropyLabelSmooth(nn.Module): 5 | """Cross entropy loss with label smoothing regularizer. 6 | 7 | Reference: 8 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 9 | Equation: y = (1 - epsilon) * y + epsilon / K. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | epsilon (float): weight. 14 | """ 15 | 16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 17 | super(CrossEntropyLabelSmooth, self).__init__() 18 | self.num_classes = num_classes 19 | self.epsilon = epsilon 20 | self.use_gpu = use_gpu 21 | self.logsoftmax = nn.LogSoftmax(dim=1) 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (num_classes) 28 | """ 29 | log_probs = self.logsoftmax(inputs) 30 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 31 | if self.use_gpu: targets = targets.cuda() 32 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 33 | loss = (- targets * log_probs).mean(0).sum() 34 | return loss 35 | 36 | class LabelSmoothingCrossEntropy(nn.Module): 37 | """ 38 | NLL loss with label smoothing. 39 | """ 40 | def __init__(self, smoothing=0.1): 41 | """ 42 | Constructor for the LabelSmoothing module. 43 | :param smoothing: label smoothing factor 44 | """ 45 | super(LabelSmoothingCrossEntropy, self).__init__() 46 | assert smoothing < 1.0 47 | self.smoothing = smoothing 48 | self.confidence = 1. - smoothing 49 | 50 | def forward(self, x, target): 51 | logprobs = F.log_softmax(x, dim=-1) 52 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 53 | nll_loss = nll_loss.squeeze(1) 54 | smooth_loss = -logprobs.mean(dim=-1) 55 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 56 | return loss.mean() -------------------------------------------------------------------------------- /loss/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 
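        Returns:
            Scalar loss: the squared Euclidean distance between each feature and the
            center of its class, averaged over the batch.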
34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | # distmat = torch.addmm(distmat, x, self.centers.t(), beta=1, alpha=-2) 41 | distmat.addmm_(1, -2, x.float(), self.centers.t()) 42 | 43 | classes = torch.arange(self.num_classes).long() 44 | if self.use_gpu: classes = classes.cuda() 45 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 46 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 47 | 48 | dist = [] 49 | for i in range(batch_size): 50 | value = distmat[i][mask[i]] 51 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 52 | dist.append(value) 53 | dist = torch.cat(dist) 54 | loss = dist.mean() 55 | return loss 56 | 57 | 58 | if __name__ == '__main__': 59 | use_gpu = False 60 | center_loss = CenterLoss(use_gpu=use_gpu) 61 | features = torch.rand(16, 2048) 62 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 63 | if use_gpu: 64 | features = torch.rand(16, 2048).cuda() 65 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 66 | 67 | loss = center_loss(features, targets) 68 | print(loss) 69 | -------------------------------------------------------------------------------- /loss/make_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy 3 | from .triplet_loss import TripletLoss 4 | from .center_loss import CenterLoss 5 | 6 | 7 | def make_loss(cfg, num_classes): 8 | sampler = cfg.DATALOADER.SAMPLER 9 | if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE: 10 | if cfg.MODEL.NO_MARGIN: 11 | triplet = TripletLoss() 12 | print("using soft triplet loss for training") 13 | else: 14 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 15 | print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN)) 16 | else: 17 | print('expected METRIC_LOSS_TYPE should be triplet' 18 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 19 | 20 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 21 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) 22 | print("label smooth on, numclasses:", num_classes) 23 | 24 | if sampler == 'softmax': 25 | def loss_func(score, feat, target): 26 | return F.cross_entropy(score, target) 27 | 28 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 29 | def loss_func(score, feat, target, target_cam): 30 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 31 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 32 | if isinstance(score, list): 33 | ID_LOSS = [xent(scor, target) for scor in score[1:]] 34 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 35 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target) 36 | else: 37 | ID_LOSS = xent(score, target) 38 | 39 | if isinstance(feat, list): 40 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 41 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 42 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 43 | else: 44 | TRI_LOSS = triplet(feat, target)[0] 45 | 46 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 47 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 48 | else: 49 | if isinstance(score, list): 50 | ID_LOSS = [F.cross_entropy(scor, target) for scor in score[1:]] 51 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 52 | 
ID_LOSS = 0.5 * ID_LOSS + 0.5 * F.cross_entropy(score[0], target) 53 | else: 54 | ID_LOSS = F.cross_entropy(score, target) 55 | 56 | if isinstance(feat, list): 57 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 58 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 59 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 60 | else: 61 | TRI_LOSS = triplet(feat, target)[0] 62 | 63 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 64 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 65 | else: 66 | print('expected METRIC_LOSS_TYPE should be triplet' 67 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 68 | 69 | else: 70 | print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center' 71 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 72 | return loss_func 73 | 74 | 75 | -------------------------------------------------------------------------------- /loss/myloss.py: -------------------------------------------------------------------------------- 1 | from doctest import FAIL_FAST 2 | from importlib.resources import path 3 | 4 | from numpy import tensordot 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class Pedal(nn.Module): 10 | 11 | def __init__(self, scale=10, k=10): 12 | super(Pedal, self).__init__() 13 | self.scale =scale 14 | self.k = k 15 | 16 | 17 | def forward(self, feature, centers, position, PatchMemory = None, vid=None, camid=None): 18 | 19 | loss = 0 20 | 21 | 22 | all_posvid = [] 23 | for p in range(feature.size(0)): 24 | part_feat = feature[p, :, :] 25 | part_centers = centers[p, :, :] 26 | m, n = part_feat.size(0), part_centers.size(0) 27 | dist_map = part_feat.pow(2).sum(dim=1, keepdim=True).expand(m, n) + \ 28 | part_centers.pow(2).sum(dim=1, keepdim=True).expand(n, m).t() 29 | dist_map.addmm_(1, -2, part_feat, part_centers.t()) 30 | 31 | trick = torch.arange(dist_map.size(1)).cuda().expand_as(dist_map) 32 | 33 | neg, index = dist_map[trick!=position.unsqueeze(dim=1).expand_as(dist_map)].view(dist_map.size(0), -1).sort(dim=1) 34 | 35 | pos_camid = torch.tensor(PatchMemory.camid).cuda() 36 | pos_camid = pos_camid[(index[:,:self.k])] 37 | flag = pos_camid != camid.unsqueeze(dim=1).expand_as(pos_camid) 38 | 39 | pos_vid = torch.tensor(PatchMemory.vid).cuda() 40 | pos_vid = pos_vid[(index[:,:self.k])] 41 | all_posvid.append(pos_vid) 42 | 43 | x = ((-1 * self.scale * neg[:, :self.k]).exp().sum(dim=1)).log() 44 | 45 | y = ((-1 * self.scale * neg).exp().sum(dim=1)).log() 46 | 47 | l = (-x + y).sum().div(feature.size(1)) 48 | l = torch.where(torch.isnan(l), torch.full_like(l, 0.), l) 49 | loss += l 50 | loss = loss.div(feature.size(0)) 51 | 52 | return loss, all_posvid 53 | 54 | 55 | 56 | class Ipfl(nn.Module): 57 | def __init__(self, margin=1.0, p=2, eps=1e-6, max_iter=15, nearest=3, num=2, swap=False): 58 | 59 | super(Ipfl, self).__init__() 60 | self.margin = margin 61 | self.p = p 62 | self.eps = eps 63 | self.swap = swap 64 | self.max_iter = max_iter 65 | self.num = num 66 | self.nearest = nearest 67 | 68 | 69 | def forward(self, feature, centers): 70 | 71 | image_label = torch.arange(feature.size(0) // self.num).repeat(self.num, 1).transpose(0, 1).contiguous().view(-1) 72 | center_label = torch.arange(feature.size(0) // self.num) 73 | loss = 0 74 | size = 0 75 | 76 | for i in range(0, feature.size(0), 1): 77 | label = image_label[i] 78 | diff = (feature[i, :].expand_as(centers) - centers).pow(self.p).sum(dim=1) 79 | diff = torch.sqrt(diff) 80 | 81 | same = diff[center_label == label] 82 | sorted, index = diff[center_label != 
label].sort() 83 | trust_diff_label = [] 84 | trust_diff = [] 85 | 86 | # cycle ranking 87 | max_iter = self.max_iter if self.max_iter < index.size(0) else index.size(0) 88 | for j in range(max_iter): 89 | s = centers[center_label != label, :][index[j]] 90 | l = center_label[center_label != label][index[j]] 91 | 92 | sout = (s.expand_as(centers) - centers).pow(self.p).sum(dim=1) 93 | sout = sout.pow(1. / self.p) 94 | 95 | ssorted, sindex = torch.sort(sout) 96 | near = center_label[sindex[:self.nearest]] 97 | if (label not in near): # view as different identity 98 | trust_diff.append(sorted[j]) 99 | trust_diff_label.append(l) 100 | break 101 | 102 | if len(trust_diff) == 0: 103 | trust_diff.append(torch.tensor([0.]).cuda()) 104 | 105 | min_diff = torch.stack(trust_diff, dim=0).min() 106 | 107 | dist_hinge = torch.clamp(self.margin + same.mean() - min_diff, min=0.0) 108 | 109 | size += 1 110 | loss += dist_hinge 111 | 112 | loss = loss / size 113 | return loss 114 | 115 | 116 | class TripletHard(nn.Module): 117 | def __init__(self, margin=1.0, p=2, eps=1e-5, swap=False, norm=False): 118 | super(TripletHard, self).__init__() 119 | self.margin = margin 120 | self.p = p 121 | self.eps = eps 122 | self.swap = swap 123 | self.norm = norm 124 | self.sigma = 3 125 | 126 | 127 | def forward(self, feature, label): 128 | 129 | if self.norm: 130 | feature = feature.div(feature.norm(dim=1).unsqueeze(1)) 131 | loss = 0 132 | 133 | m, n = feature.size(0), feature.size(0) 134 | dist_map = feature.pow(2).sum(dim=1, keepdim=True).expand(m, n) + \ 135 | feature.pow(2).sum(dim=1, keepdim=True).expand(n, m).t() + self.eps 136 | dist_map.addmm_(1, -2, feature, feature.t()).sqrt_() 137 | 138 | sorted, index = dist_map.sort(dim=1) 139 | 140 | for i in range(feature.size(0)): 141 | 142 | same = sorted[i, :][label[index[i, :]] == label[i]] 143 | diff = sorted[i, :][label[index[i, :]] != label[i]] 144 | dist_hinge = torch.clamp(self.margin + same[1] - diff.min(), min=0.0) 145 | loss += dist_hinge 146 | 147 | loss = loss / (feature.size(0)) 148 | return loss 149 | -------------------------------------------------------------------------------- /loss/smooth.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | 4 | 5 | class PatchMemory(object): 6 | 7 | def __init__(self, momentum=0.1, num=1): 8 | 9 | self.name = [] 10 | self.agent = [] 11 | self.momentum = momentum 12 | self.num = num 13 | 14 | self.camid = [] 15 | self.vid = [] 16 | 17 | 18 | 19 | def get_soft_label(self, path, feat_list, vid=None, camid=None): 20 | 21 | feat = torch.stack(feat_list, dim=0) 22 | 23 | feat = feat[:, ::self.num, :] 24 | 25 | 26 | position = [] 27 | 28 | 29 | # update the agent 30 | for j,p in enumerate(path): 31 | 32 | current_soft_feat = feat[:, j, :].detach() 33 | if current_soft_feat.is_cuda: 34 | current_soft_feat = current_soft_feat.cpu() 35 | key = p 36 | if key not in self.name: 37 | self.name.append(key) 38 | self.camid.append(camid[j]) 39 | self.vid.append(vid[j]) 40 | self.agent.append(current_soft_feat) 41 | ind = self.name.index(key) 42 | position.append(ind) 43 | 44 | else: 45 | ind = self.name.index(key) 46 | tmp = self.agent.pop(ind) 47 | tmp = tmp*(1-self.momentum) + self.momentum*current_soft_feat 48 | self.agent.insert(ind, tmp) 49 | position.append(ind) 50 | 51 | if len(position) != 0: 52 | position = torch.tensor(position).cuda() 53 | 54 | agent = torch.stack(self.agent, dim=1).cuda() 55 | return agent, position 56 | 57 | def 
_dequeue_and_enqueue(self, keys): 58 | # gather keys before updating queue 59 | keys = concat_all_gather(keys) 60 | 61 | batch_size = keys.shape[0] 62 | 63 | ptr = int(self.queue_ptr) 64 | assert self.K % batch_size == 0 # for simplicity 65 | 66 | # replace the keys at ptr (dequeue and enqueue) 67 | self.queue[:, ptr:ptr + batch_size] = keys.T 68 | ptr = (ptr + batch_size) % self.K # move pointer 69 | 70 | self.queue_ptr[0] = ptr 71 | 72 | 73 | # utils 74 | @torch.no_grad() 75 | def concat_all_gather(tensor): 76 | """ 77 | Performs all_gather operation on the provided tensors. 78 | *** Warning ***: torch.distributed.all_gather has no gradient. 79 | """ 80 | tensors_gather = [torch.ones_like(tensor) 81 | for _ in range(torch.distributed.get_world_size())] 82 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 83 | 84 | output = torch.cat(tensors_gather, dim=0) 85 | return output 86 | 87 | 88 | 89 | class SmoothingForImage(object): 90 | def __init__(self, momentum=0.1, num=1): 91 | 92 | self.map = dict() 93 | self.momentum = momentum 94 | self.num = num 95 | 96 | 97 | def get_soft_label(self, path, feature): 98 | 99 | feature = torch.cat(feature, dim=1) 100 | soft_label = [] 101 | 102 | for j,p in enumerate(path): 103 | 104 | current_soft_feat = feature[j*self.num:(j+1)*self.num, :].detach().mean(dim=0) 105 | if current_soft_feat.is_cuda: 106 | current_soft_feat = current_soft_feat.cpu() 107 | 108 | key = p 109 | if key not in self.map: 110 | self.map.setdefault(key, current_soft_feat) 111 | soft_label.append(self.map[key]) 112 | else: 113 | self.map[key] = self.map[key]*(1-self.momentum) + self.momentum*current_soft_feat 114 | soft_label.append(self.map[key]) 115 | soft_label = torch.stack(soft_label, dim=0).cuda() 116 | return soft_label 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /loss/softmax_loss.py: -------------------------------------------------------------------------------- 1 | from cmath import isnan 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | class CrossEntropyLabelSmooth(nn.Module): 6 | """Cross entropy loss with label smoothing regularizer. 7 | 8 | Reference: 9 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 10 | Equation: y = (1 - epsilon) * y + epsilon / K. 11 | 12 | Args: 13 | num_classes (int): number of classes. 14 | epsilon (float): weight. 
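        Example (illustrative): with num_classes K = 4 and epsilon = 0.1, a one-hot target
        [0, 1, 0, 0] is smoothed to [0.025, 0.925, 0.025, 0.025].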
15 | """ 16 | 17 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 18 | super(CrossEntropyLabelSmooth, self).__init__() 19 | self.num_classes = num_classes 20 | self.epsilon = epsilon 21 | self.use_gpu = use_gpu 22 | self.logsoftmax = nn.LogSoftmax(dim=1) 23 | 24 | def forward(self, inputs, targets, all_posvid=None, soft_label=False, soft_weight=0.1, soft_lambda=0.2): 25 | """ 26 | Args: 27 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 28 | targets: ground truth labels with shape (num_classes) 29 | """ 30 | all_posvid = torch.cat(all_posvid, dim=1) 31 | soft_targets = [] 32 | for i in range(all_posvid.size(0)): 33 | s_id, s_num = torch.unique(all_posvid[i,:], return_counts=True) 34 | sum_num = s_num.sum() 35 | temp = torch.zeros(inputs.size(1)).cuda().scatter_(0, s_id, (soft_lambda/sum_num)*s_num) 36 | soft_targets.append(temp) 37 | 38 | soft_targets = torch.stack(soft_targets, dim=0) 39 | 40 | 41 | log_probs = self.logsoftmax(inputs) 42 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 43 | if self.use_gpu: targets = targets.cuda() 44 | if soft_label: 45 | soft_targets = (1 - soft_lambda) * targets + soft_targets 46 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 47 | loss = (- targets * log_probs).mean(0).sum()*(1 - soft_weight) + \ 48 | (- soft_targets * log_probs).mean(0).sum()*soft_weight 49 | # if torch.isnan(loss).item(): 50 | # print("====nan!!!====\n{}\n{}".format((- targets * log_probs).mean(0).sum(), (- soft_targets * log_probs).mean(0).sum())) 51 | 52 | else: 53 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 54 | loss = (- targets * log_probs).mean(0).sum() 55 | return loss 56 | 57 | class LabelSmoothingCrossEntropy(nn.Module): 58 | """ 59 | NLL loss with label smoothing. 60 | """ 61 | def __init__(self, smoothing=0.1): 62 | """ 63 | Constructor for the LabelSmoothing module. 64 | :param smoothing: label smoothing factor 65 | """ 66 | super(LabelSmoothingCrossEntropy, self).__init__() 67 | assert smoothing < 1.0 68 | self.smoothing = smoothing 69 | self.confidence = 1. - smoothing 70 | 71 | def forward(self, x, target): 72 | logprobs = F.log_softmax(x, dim=-1) 73 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 74 | nll_loss = nll_loss.squeeze(1) 75 | smooth_loss = -logprobs.mean(dim=-1) 76 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 77 | return loss.mean() -------------------------------------------------------------------------------- /loss/triplet_loss.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import torch 3 | from torch import nn 4 | 5 | 6 | def normalize(x, axis=-1): 7 | """Normalizing to unit length along the specified dimension. 8 | Args: 9 | x: pytorch Variable 10 | Returns: 11 | x: pytorch Variable, same shape as input 12 | """ 13 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 14 | return x 15 | 16 | 17 | def euclidean_dist(x, y): 18 | """ 19 | Args: 20 | x: pytorch Variable, with shape [m, d] 21 | y: pytorch Variable, with shape [n, d] 22 | Returns: 23 | dist: pytorch Variable, with shape [m, n] 24 | """ 25 | m, n = x.size(0), y.size(0) 26 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 27 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 28 | dist = xx + yy 29 | dist = dist - 2 * torch.matmul(x, y.t()) 30 | # dist.addmm_(1, -2, x, y.t()) 31 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 32 | return dist 33 | 34 | 35 | def cosine_dist(x, y): 36 | """ 37 | Args: 38 | x: pytorch Variable, with shape [m, d] 39 | y: pytorch Variable, with shape [n, d] 40 | Returns: 41 | dist: pytorch Variable, with shape [m, n] 42 | """ 43 | m, n = x.size(0), y.size(0) 44 | x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n) 45 | y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t() 46 | xy_intersection = torch.mm(x, y.t()) 47 | dist = xy_intersection/(x_norm * y_norm) 48 | dist = (1. - dist) / 2 49 | return dist 50 | 51 | 52 | def hard_example_mining(dist_mat, labels, return_inds=False): 53 | """For each anchor, find the hardest positive and negative sample. 54 | Args: 55 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 56 | labels: pytorch LongTensor, with shape [N] 57 | return_inds: whether to return the indices. Save time if `False`(?) 58 | Returns: 59 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 60 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 61 | p_inds: pytorch LongTensor, with shape [N]; 62 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 63 | n_inds: pytorch LongTensor, with shape [N]; 64 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 65 | NOTE: Only consider the case in which all labels have same num of samples, 66 | thus we can cope with all anchors in parallel. 
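    Example (illustrative): with a PK-sampled batch of P identities x K instances each
    (N = P * K), every row of dist_mat contains exactly K positive entries (the anchor
    itself included) and N - K negative entries, so the masked tensors below reshape
    cleanly to [N, K] and [N, N - K] before taking the row-wise max / min.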
67 | """ 68 | 69 | assert len(dist_mat.size()) == 2 70 | assert dist_mat.size(0) == dist_mat.size(1) 71 | N = dist_mat.size(0) 72 | 73 | # shape [N, N] 74 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 75 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 76 | 77 | # `dist_ap` means distance(anchor, positive) 78 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 79 | dist_ap, relative_p_inds = torch.max( 80 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 81 | # print(dist_mat[is_pos].shape) 82 | # `dist_an` means distance(anchor, negative) 83 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 84 | dist_an, relative_n_inds = torch.min( 85 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 86 | # shape [N] 87 | dist_ap = dist_ap.squeeze(1) 88 | dist_an = dist_an.squeeze(1) 89 | 90 | if return_inds: 91 | # shape [N, N] 92 | ind = (labels.new().resize_as_(labels) 93 | .copy_(torch.arange(0, N).long()) 94 | .unsqueeze(0).expand(N, N)) 95 | # shape [N, 1] 96 | p_inds = torch.gather( 97 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 98 | n_inds = torch.gather( 99 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 100 | # shape [N] 101 | p_inds = p_inds.squeeze(1) 102 | n_inds = n_inds.squeeze(1) 103 | return dist_ap, dist_an, p_inds, n_inds 104 | 105 | return dist_ap, dist_an 106 | 107 | 108 | class TripletLoss(object): 109 | """ 110 | Triplet loss using HARDER example mining, 111 | modified based on original triplet loss using hard example mining 112 | """ 113 | 114 | def __init__(self, margin=None, hard_factor=0.0): 115 | self.margin = margin 116 | self.hard_factor = hard_factor 117 | if margin is not None: 118 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 119 | else: 120 | self.ranking_loss = nn.SoftMarginLoss() 121 | 122 | def __call__(self, global_feat, labels, normalize_feature=False): 123 | if normalize_feature: 124 | global_feat = normalize(global_feat, axis=-1) 125 | dist_mat = euclidean_dist(global_feat, global_feat) 126 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 127 | 128 | dist_ap *= (1.0 + self.hard_factor) 129 | dist_an *= (1.0 - self.hard_factor) 130 | 131 | y = dist_an.new().resize_as_(dist_an).fill_(1) 132 | if self.margin is not None: 133 | loss = self.ranking_loss(dist_an, dist_ap, y) 134 | else: 135 | # min_mat = dist_an.new().resize_as_(dist_an).fill_(-85) 136 | # input = max(min_mat, dist_an - dist_ap) 137 | input = dist_an - dist_ap 138 | loss = self.ranking_loss(input, y) 139 | return loss, dist_ap, dist_an 140 | 141 | 142 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- /model/backbones/IBN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class IBN(nn.Module): 6 | r"""Instance-Batch Normalization layer from 7 | `"Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net" 8 | ` 9 | 10 | Args: 11 | planes (int): Number of channels for the input tensor 12 | ratio (float): Ratio of instance normalization in the IBN layer 13 | """ 14 | def __init__(self, planes, ratio=0.5): 15 | super(IBN, self).__init__() 16 | self.half = int(planes * ratio) 17 | self.IN = nn.InstanceNorm2d(self.half, affine=True) 18 | 
self.BN = nn.BatchNorm2d(planes - self.half) 19 | 20 | def forward(self, x): 21 | split = torch.split(x, self.half, 1) 22 | out1 = self.IN(split[0].contiguous()) 23 | out2 = self.BN(split[1].contiguous()) 24 | out = torch.cat((out1, out2), 1) 25 | return out 26 | 27 | 28 | class SELayer(nn.Module): 29 | def __init__(self, channel, reduction=16): 30 | super(SELayer, self).__init__() 31 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 32 | self.fc = nn.Sequential( 33 | nn.Linear(channel, int(channel/reduction), bias=False), 34 | nn.ReLU(inplace=True), 35 | nn.Linear(int(channel/reduction), channel, bias=False), 36 | nn.Sigmoid() 37 | ) 38 | 39 | def forward(self, x): 40 | b, c, _, _ = x.size() 41 | y = self.avg_pool(x).view(b, c) 42 | y = self.fc(y).view(b, c, 1, 1) 43 | return x * y.expand_as(x) -------------------------------------------------------------------------------- /model/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .vit_pytorch import vit_base_patch16_224_TransReID, vit_small_patch16_224_TransReID, deit_small_patch16_224_TransReID -------------------------------------------------------------------------------- /model/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, last_stride=2, block=Bottleneck,layers=[3, 4, 6, 3]): 86 | 
self.inplanes = 64 87 | super().__init__() 88 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | self.bn1 = nn.BatchNorm2d(64) 91 | # self.relu = nn.ReLU(inplace=True) # add missed relu 92 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0) 93 | self.layer1 = self._make_layer(block, 64, layers[0]) 94 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 96 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 97 | 98 | def _make_layer(self, block, planes, blocks, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | downsample = nn.Sequential( 102 | nn.Conv2d(self.inplanes, planes * block.expansion, 103 | kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(planes * block.expansion), 105 | ) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x, cam_label=None): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | # x = self.relu(x) # add missed relu 119 | x = self.maxpool(x) 120 | x = self.layer1(x) 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | 125 | return x 126 | 127 | def load_param(self, model_path): 128 | param_dict = torch.load(model_path) 129 | for i in param_dict: 130 | if 'fc' in i: 131 | continue 132 | self.state_dict()[i].copy_(param_dict[i]) 133 | 134 | def random_init(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2. / n)) 139 | elif isinstance(m, nn.BatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | 143 | def compute_num_params(self): 144 | total = sum([param.nelement() for param in self.parameters()]) 145 | print("Number of parameter: %.2fM" % (total/1e6)) -------------------------------------------------------------------------------- /processor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyuke65535/Part-Aware-Transformer/104d42e8292f7e5d534689ded15e4afafb453785/processor/__init__.py -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # train PAT 2 | python train.py --config_file "config/PAT.yml" -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import WarmupMultiStepLR 2 | from .make_optimizer import make_optimizer -------------------------------------------------------------------------------- /solver/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import torch 10 | 11 | from .scheduler import Scheduler 12 | 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | class CosineLRScheduler(Scheduler): 18 | """ 19 | Cosine decay with restarts. 20 | This is described in the paper https://arxiv.org/abs/1608.03983. 21 | 22 | Inspiration from 23 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 24 | """ 25 | 26 | def __init__(self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 50 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 51 | "rate since t_initial = t_mul = eta_mul = 1.") 52 | self.t_initial = t_initial 53 | self.t_mul = t_mul 54 | self.lr_min = lr_min 55 | self.decay_rate = decay_rate 56 | self.cycle_limit = cycle_limit 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | self.warmup_prefix = warmup_prefix 60 | self.t_in_epochs = t_in_epochs 61 | if self.warmup_t: 62 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 63 | super().update_groups(self.warmup_lr_init) 64 | else: 65 | self.warmup_steps = [1 for _ in self.base_values] 66 | 67 | def _get_lr(self, t): 68 | if t < self.warmup_t: 69 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 70 | else: 71 | if self.warmup_prefix: 72 | t = t - self.warmup_t 73 | 74 | if self.t_mul != 1: 75 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 76 | t_i = self.t_mul ** i * self.t_initial 77 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 78 | else: 79 | i = t // self.t_initial 80 | t_i = self.t_initial 81 | t_curr = t - (self.t_initial * i) 82 | 83 | gamma = self.decay_rate ** i 84 | lr_min = self.lr_min * gamma 85 | lr_max_values = [v * gamma for v in self.base_values] 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | lrs = [ 89 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min for _ in self.base_values] 93 | 94 | return lrs 95 | 96 | def get_epoch_values(self, epoch: int): 97 | if self.t_in_epochs: 98 | return self._get_lr(epoch) 99 | else: 100 | return None 101 | 102 | def get_update_values(self, num_updates: int): 103 | if not self.t_in_epochs: 104 | return self._get_lr(num_updates) 105 | else: 106 | return None 107 | 108 | def get_cycle_length(self, cycles=0): 109 | if not cycles: 110 | cycles = self.cycle_limit 111 | cycles = max(1, cycles) 112 | if self.t_mul == 1.0: 113 | return self.t_initial * cycles 114 | else: 115 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 116 | -------------------------------------------------------------------------------- 
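# Usage sketch for CosineLRScheduler above (the values here are placeholders; the repo
# builds its real scheduler from cfg.SOLVER inside solver/scheduler_factory.py):
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.008)
#     scheduler = CosineLRScheduler(optimizer, t_initial=120, lr_min=1.6e-5,
#                                   warmup_t=10, warmup_lr_init=8e-5, t_in_epochs=True)
#     for epoch in range(120):
#         scheduler.step(epoch)   # linear warmup for the first 10 epochs, then cosine decay
#         train_one_epoch(...)    # placeholder for the actual training loop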
/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, # steps 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def _get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /solver/make_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def make_optimizer(cfg, model): 5 | params = [] 6 | for key, value in model.named_parameters(): 7 | if not value.requires_grad: 8 | continue 9 | lr = cfg.SOLVER.BASE_LR 10 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 11 | if "bias" in key: 12 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 14 | if cfg.SOLVER.LARGE_FC_LR: 15 | if "classifier" in key or "arcface" in key: 16 | lr = cfg.SOLVER.BASE_LR * 2 17 | print('Using two times learning rate for fc ') 18 | 19 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 20 | 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 24 | optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 25 | else: 26 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 27 | 28 | return optimizer 29 | -------------------------------------------------------------------------------- /solver/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 
9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
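        # The noise_* options below add a bounded random perturbation to the scheduled values
        # (see _add_noise); this repo effectively disables them by passing noise_range_t=None
        # from solver/scheduler_factory.py.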
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /solver/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | 6 | 7 | def create_scheduler(cfg, optimizer): 8 | num_epochs = 120 9 | # type 1 10 | # lr_min = 0.01 * cfg.SOLVER.BASE_LR 11 | # warmup_lr_init = 0.001 * cfg.SOLVER.BASE_LR 12 | # type 2 13 | lr_min = 0.002 * cfg.SOLVER.BASE_LR 14 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 15 | # type 3 16 | # lr_min = 0.001 * cfg.SOLVER.BASE_LR 17 | # warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 18 | 19 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS 20 | noise_range = None 21 | 22 | lr_scheduler = CosineLRScheduler( 23 | optimizer, 24 | t_initial=num_epochs, 25 | lr_min=lr_min, 26 | t_mul= 1., 27 | decay_rate=0.1, 28 | warmup_lr_init=warmup_lr_init, 29 | warmup_t=warmup_t, 30 | cycle_limit=1, 31 | t_in_epochs=True, 32 | noise_range_t=noise_range, 33 | noise_pct= 0.67, 34 | noise_std= 1., 35 | noise_seed=42, 36 | ) 37 | 38 | return lr_scheduler 39 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from config import cfg 3 | import argparse 4 | from 
data.build_DG_dataloader import build_reid_test_loader 5 | from model import make_model 6 | from processor.part_attention_vit_processor import do_inference as do_inf_pat 7 | from processor.ori_vit_processor_with_amp import do_inference as do_inf 8 | from utils.logger import setup_logger 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser(description="ReID Training") 13 | parser.add_argument( 14 | "--config_file", default="./config/PAT.yml", help="path to config file", type=str 15 | ) 16 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 17 | nargs=argparse.REMAINDER) 18 | 19 | args = parser.parse_args() 20 | 21 | 22 | 23 | if args.config_file != "": 24 | cfg.merge_from_file(args.config_file) 25 | cfg.merge_from_list(args.opts) 26 | cfg.freeze() 27 | 28 | output_dir = os.path.join(cfg.LOG_ROOT, cfg.LOG_NAME) 29 | if output_dir and not os.path.exists(output_dir): 30 | os.makedirs(output_dir) 31 | 32 | logger = setup_logger("PAT", output_dir, if_train=False) 33 | logger.info(args) 34 | 35 | if args.config_file != "": 36 | logger.info("Loaded configuration file {}".format(args.config_file)) 37 | with open(args.config_file, 'r') as cf: 38 | config_str = "\n" + cf.read() 39 | logger.info(config_str) 40 | logger.info("Running with config:\n{}".format(cfg)) 41 | 42 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 43 | 44 | model = make_model(cfg, cfg.MODEL.NAME, 0,0,0) 45 | model.load_param(cfg.TEST.WEIGHT) 46 | 47 | for testname in cfg.DATASETS.TEST: 48 | val_loader, num_query = build_reid_test_loader(cfg, testname) 49 | if cfg.MODEL.NAME == 'part_attention_vit': 50 | do_inf_pat(cfg, model, val_loader, num_query) 51 | else: 52 | do_inf(cfg, model, val_loader, num_query) 53 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from processor.part_attention_vit_processor import part_attention_vit_do_train_with_amp 2 | from processor.ori_vit_processor_with_amp import ori_vit_do_train_with_amp 3 | from utils.logger import setup_logger 4 | from data.build_DG_dataloader import build_reid_train_loader, build_reid_test_loader 5 | from model import make_model 6 | from solver import make_optimizer 7 | from solver.scheduler_factory import create_scheduler 8 | from loss.build_loss import build_loss 9 | import random 10 | import torch 11 | import numpy as np 12 | import os 13 | import argparse 14 | from config import cfg 15 | import loss as Patchloss 16 | 17 | def set_seed(seed): 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = True 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser(description="ReID Training") 28 | parser.add_argument( 29 | "--config_file", default="./config/PAT.yml", help="path to config file", type=str 30 | ) 31 | 32 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 33 | nargs=argparse.REMAINDER) 34 | parser.add_argument("--local_rank", default=0, type=int) 35 | args = parser.parse_args() 36 | 37 | if args.config_file != "": 38 | cfg.merge_from_file(args.config_file) 39 | cfg.merge_from_list(args.opts) 40 | cfg.freeze() 41 | 42 | set_seed(cfg.SOLVER.SEED) 43 | 44 | if cfg.MODEL.DIST_TRAIN: 45 | torch.cuda.set_device(args.local_rank) 
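    # Note: only the CUDA device is selected here; the NCCL process group itself is
    # initialized further below (torch.distributed.init_process_group), after the
    # configuration has been logged.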
46 | 47 | output_dir = os.path.join(cfg.LOG_ROOT, cfg.LOG_NAME) 48 | if output_dir and not os.path.exists(output_dir): 49 | os.makedirs(output_dir) 50 | 51 | logger = setup_logger("PAT", output_dir, if_train=True) 52 | logger.info("Saving model in the path :{}".format(output_dir)) 53 | logger.info(args) 54 | 55 | if args.config_file != "": 56 | logger.info("Loaded configuration file {}".format(args.config_file)) 57 | with open(args.config_file, 'r') as cf: 58 | config_str = "\n" + cf.read() 59 | logger.info(config_str) 60 | logger.info("Running with config:\n{}".format(cfg)) 61 | 62 | if cfg.MODEL.DIST_TRAIN: 63 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 64 | 65 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 66 | 67 | # build DG train loader 68 | train_loader = build_reid_train_loader(cfg) 69 | # build DG validate loader 70 | val_name = cfg.DATASETS.TEST[0] 71 | val_loader, num_query = build_reid_test_loader(cfg, val_name) 72 | num_classes = len(train_loader.dataset.pids) 73 | model_name = cfg.MODEL.NAME 74 | model = make_model(cfg, modelname=model_name, num_class=num_classes, camera_num=None, view_num=None) 75 | if cfg.MODEL.FREEZE_PATCH_EMBED and 'resnet' not in cfg.MODEL.NAME: # trick from moco v3 76 | model.base.patch_embed.proj.weight.requires_grad = False 77 | model.base.patch_embed.proj.bias.requires_grad = False 78 | print("====== freeze patch_embed for stability ======") 79 | 80 | loss_func, center_cri = build_loss(cfg, num_classes=num_classes) 81 | 82 | optimizer = make_optimizer(cfg, model) 83 | scheduler = create_scheduler(cfg, optimizer) 84 | 85 | ################## patch loss #################### 86 | patch_centers = Patchloss.PatchMemory(momentum=0.1, num=1) 87 | pc_criterion = Patchloss.Pedal(scale=cfg.MODEL.PC_SCALE, k=cfg.MODEL.CLUSTER_K).cuda() 88 | if cfg.MODEL.SOFT_LABEL and cfg.MODEL.NAME == 'part_attention_vit': 89 | print("========using soft label========") 90 | ################## patch loss #################### 91 | 92 | do_train_dict = { 93 | 'part_attention_vit': part_attention_vit_do_train_with_amp 94 | } 95 | if model_name not in do_train_dict.keys(): 96 | ori_vit_do_train_with_amp( 97 | cfg, 98 | model, 99 | train_loader, 100 | val_loader, 101 | optimizer, 102 | scheduler, 103 | loss_func, 104 | num_query, args.local_rank, 105 | ) 106 | else : 107 | do_train_dict[model_name]( 108 | cfg, 109 | model, 110 | train_loader, 111 | val_loader, 112 | optimizer, 113 | scheduler, 114 | loss_func, 115 | num_query, args.local_rank, 116 | patch_centers = patch_centers, 117 | pc_criterion = pc_criterion 118 | ) 119 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyuke65535/Part-Aware-Transformer/104d42e8292f7e5d534689ded15e4afafb453785/utils/__init__.py -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import json 3 | import os 4 | 5 | import os.path as osp 6 | 7 | 8 | def mkdir_if_missing(directory): 9 | if not osp.exists(directory): 10 | try: 11 | os.makedirs(directory) 12 | except OSError as e: 13 | if e.errno != errno.EEXIST: 14 | raise 15 | 16 | 17 | def check_isfile(path): 18 | isfile = osp.isfile(path) 19 | if not isfile: 20 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 21 | return 
isfile 22 | 23 | 24 | def read_json(fpath): 25 | with open(fpath, 'r') as f: 26 | obj = json.load(f) 27 | return obj 28 | 29 | 30 | def write_json(obj, fpath): 31 | mkdir_if_missing(osp.dirname(fpath)) 32 | with open(fpath, 'w') as f: 33 | json.dump(obj, f, indent=4, separators=(',', ': ')) 34 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import os.path as osp 5 | def setup_logger(name, save_dir, if_train): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.DEBUG) 8 | 9 | ch = logging.StreamHandler(stream=sys.stdout) 10 | ch.setLevel(logging.DEBUG) 11 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 12 | ch.setFormatter(formatter) 13 | logger.addHandler(ch) 14 | 15 | if save_dir: 16 | if not osp.exists(save_dir): 17 | os.makedirs(save_dir) 18 | if if_train: 19 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w') 20 | else: 21 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='w') 22 | fh.setLevel(logging.DEBUG) 23 | fh.setFormatter(formatter) 24 | logger.addHandler(fh) 25 | 26 | return logger -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from re import T 2 | from time import time 3 | import torch 4 | import numpy as np 5 | import os 6 | from utils.reranking import re_ranking 7 | 8 | 9 | def euclidean_distance(qf, gf): 10 | m = qf.shape[0] 11 | n = gf.shape[0] 12 | dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 13 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 14 | dist_mat.addmm_(qf, gf.t(), beta=1, alpha=-2) 15 | return dist_mat.cpu().numpy() 16 | 17 | def cosine_similarity(qf, gf): 18 | epsilon = 0.00001 19 | dist_mat = qf.mm(gf.t()) 20 | qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True) # mx1 21 | gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True) # nx1 22 | qg_normdot = qf_norm.mm(gf_norm.t()) 23 | 24 | dist_mat = dist_mat.mul(1 / qg_normdot).cpu().numpy() 25 | dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon) 26 | dist_mat = np.arccos(dist_mat) 27 | return dist_mat 28 | 29 | 30 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 31 | """Evaluation with market1501 metric 32 | Key: for each query identity, its gallery images from the same camera view are discarded. 
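    Returns the CMC curve (averaged over all valid queries, truncated to max_rank) and the mAP.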
33 | """ 34 | num_q, num_g = distmat.shape 35 | # distmat g 36 | # q 1 3 2 4 37 | # 4 1 2 3 38 | if num_g < max_rank: 39 | max_rank = num_g 40 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 41 | indices = np.argsort(distmat, axis=1) 42 | # 0 2 1 3 43 | # 1 2 3 0 44 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 45 | # compute cmc curve for each query 46 | all_cmc = [] 47 | all_AP = [] 48 | num_valid_q = 0. # number of valid query 49 | for q_idx in range(num_q): 50 | # get query pid and camid 51 | q_pid = q_pids[q_idx] 52 | q_camid = q_camids[q_idx] 53 | 54 | # remove gallery samples that have the same pid and camid with query 55 | order = indices[q_idx] # select one row 56 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 57 | keep = np.invert(remove) 58 | 59 | # compute cmc curve 60 | # binary vector, positions with value 1 are correct matches 61 | orig_cmc = matches[q_idx][keep] 62 | if not np.any(orig_cmc): 63 | # this condition is true when query identity does not appear in gallery 64 | continue 65 | 66 | cmc = orig_cmc.cumsum() 67 | cmc[cmc > 1] = 1 68 | 69 | all_cmc.append(cmc[:max_rank]) 70 | num_valid_q += 1. 71 | 72 | # compute average precision 73 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 74 | num_rel = orig_cmc.sum() 75 | tmp_cmc = orig_cmc.cumsum() 76 | #tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 77 | y = np.arange(1, tmp_cmc.shape[0] + 1) * 1.0 78 | tmp_cmc = tmp_cmc / y 79 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 80 | AP = tmp_cmc.sum() / num_rel 81 | all_AP.append(AP) 82 | 83 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 84 | 85 | all_cmc = np.asarray(all_cmc).astype(np.float32) 86 | all_cmc = all_cmc.sum(0) / num_valid_q 87 | mAP = np.mean(all_AP) 88 | 89 | return all_cmc, mAP 90 | 91 | 92 | class R1_mAP_eval(): 93 | def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False): 94 | super(R1_mAP_eval, self).__init__() 95 | self.num_query = num_query 96 | self.max_rank = max_rank 97 | self.feat_norm = feat_norm 98 | self.reranking = reranking 99 | 100 | def reset(self): 101 | self.feats = [] 102 | self.pids = [] 103 | self.camids = [] 104 | 105 | def update(self, output): # called once for each batch 106 | feat, pid, camid = output 107 | self.feats.append(feat.cpu()) 108 | self.pids.extend(np.asarray(pid)) 109 | self.camids.extend(np.asarray(camid)) 110 | 111 | def compute(self): # called after each epoch 112 | feats = torch.cat(self.feats, dim=0) 113 | if self.feat_norm: 114 | # print("The test feature is normalized") 115 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) # along channel 116 | # query 117 | qf = feats[:self.num_query] 118 | q_pids = np.asarray(self.pids[:self.num_query]) 119 | q_camids = np.asarray(self.camids[:self.num_query]) 120 | # gallery 121 | gf = feats[self.num_query:] 122 | g_pids = np.asarray(self.pids[self.num_query:]) 123 | 124 | g_camids = np.asarray(self.camids[self.num_query:]) 125 | if self.reranking: 126 | print('=> Enter reranking') 127 | # distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 128 | distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3) 129 | 130 | else: 131 | # print('=> Computing DistMat with euclidean_distance') 132 | distmat = euclidean_distance(qf, gf) 133 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 134 | 135 | return cmc, mAP, distmat, self.pids, self.camids, qf, gf 136 | 
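# A minimal usage sketch of R1_mAP_eval (illustrative only -- `model` and `val_loader`
# are assumed to come from make_model / build_reid_test_loader, and the batch layout
# below is an assumption):
#
#   evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=True)
#   evaluator.reset()
#   for imgs, pids, camids in val_loader:
#       with torch.no_grad():
#           feats = model(imgs.cuda())
#       evaluator.update((feats, pids, camids))   # called once per batch
#   cmc, mAP, *_ = evaluator.compute()            # called after the whole epoch
#   print("mAP: {:.1%}  Rank-1: {:.1%}".format(mAP, cmc[0]))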
137 | 138 | 139 | -------------------------------------------------------------------------------- /utils/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | 4 | class Registry(object): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | To create a registry (e.g. a backbone registry): 9 | .. code-block:: python 10 | BACKBONE_REGISTRY = Registry('BACKBONE') 11 | To register an object: 12 | .. code-block:: python 13 | @BACKBONE_REGISTRY.register() 14 | class MyBackbone(): 15 | ... 16 | Or: 17 | .. code-block:: python 18 | BACKBONE_REGISTRY.register(MyBackbone) 19 | """ 20 | 21 | def __init__(self, name: str) -> None: 22 | """ 23 | Args: 24 | name (str): the name of this registry 25 | """ 26 | self._name: str = name 27 | self._obj_map: Dict[str, object] = {} 28 | 29 | def _do_register(self, name: str, obj: object) -> None: 30 | assert ( 31 | name not in self._obj_map 32 | ), "An object named '{}' was already registered in '{}' registry!".format( 33 | name, self._name 34 | ) 35 | self._obj_map[name] = obj 36 | 37 | def register(self, obj: object = None) -> Optional[object]: 38 | """ 39 | Register the given object under the the name `obj.__name__`. 40 | Can be used as either a decorator or not. See docstring of this class for usage. 41 | """ 42 | if obj is None: 43 | # used as a decorator 44 | def deco(func_or_class: object) -> object: 45 | name = func_or_class.__name__ # pyre-ignore 46 | self._do_register(name, func_or_class) 47 | return func_or_class 48 | 49 | return deco 50 | 51 | # used as a function call 52 | name = obj.__name__ # pyre-ignore 53 | self._do_register(name, obj) 54 | 55 | def get(self, name: str) -> object: 56 | ret = self._obj_map.get(name) 57 | if ret is None: 58 | raise KeyError( 59 | "No object named '{}' found in '{}' registry!".format( 60 | name, self._name 61 | ) 62 | ) 63 | return ret 64 | -------------------------------------------------------------------------------- /utils/reranking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Fri, 25 May 2018 20:29:09 3 | 4 | 5 | """ 6 | 7 | """ 8 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
9 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 10 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 11 | """ 12 | 13 | """ 14 | API 15 | 16 | probFea: all feature vectors of the query set (torch tensor) 17 | probFea: all feature vectors of the gallery set (torch tensor) 18 | k1,k2,lambda: parameters, the original paper is (k1=20,k2=6,lambda=0.3) 19 | MemorySave: set to 'True' when using MemorySave mode 20 | Minibatch: avaliable when 'MemorySave' is 'True' 21 | """ 22 | 23 | import numpy as np 24 | import torch 25 | 26 | 27 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False): 28 | # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor 29 | query_num = probFea.size(0) 30 | all_num = query_num + galFea.size(0) 31 | if only_local: 32 | original_dist = local_distmat 33 | else: 34 | feat = torch.cat([probFea, galFea]) 35 | # print('using GPU to compute original distance') 36 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \ 37 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t() 38 | distmat.addmm_(1, -2, feat, feat.t()) 39 | original_dist = distmat.cpu().numpy() 40 | del feat 41 | if not local_distmat is None: 42 | original_dist = original_dist + local_distmat 43 | gallery_num = original_dist.shape[0] 44 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 45 | V = np.zeros_like(original_dist).astype(np.float16) 46 | initial_rank = np.argsort(original_dist).astype(np.int32) 47 | 48 | # print('starting re_ranking') 49 | for i in range(all_num): 50 | # k-reciprocal neighbors 51 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 52 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 53 | fi = np.where(backward_k_neigh_index == i)[0] 54 | k_reciprocal_index = forward_k_neigh_index[fi] 55 | k_reciprocal_expansion_index = k_reciprocal_index 56 | for j in range(len(k_reciprocal_index)): 57 | candidate = k_reciprocal_index[j] 58 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 59 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 60 | :int(np.around(k1 / 2)) + 1] 61 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 62 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 63 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 64 | candidate_k_reciprocal_index): 65 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 66 | 67 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 68 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 69 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 70 | original_dist = original_dist[:query_num, ] 71 | if k2 != 1: 72 | V_qe = np.zeros_like(V, dtype=np.float16) 73 | for i in range(all_num): 74 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 75 | V = V_qe 76 | del V_qe 77 | del initial_rank 78 | invIndex = [] 79 | for i in range(gallery_num): 80 | invIndex.append(np.where(V[:, i] != 0)[0]) 81 | 82 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 83 | 84 | for i in range(query_num): 85 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 86 | indNonZero = np.where(V[i, :] != 0)[0] 87 | indImages = [invIndex[ind] for ind in 
indNonZero] 88 | for j in range(len(indNonZero)): 89 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 90 | V[indImages[j], indNonZero[j]]) 91 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 92 | 93 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 94 | del original_dist 95 | del V 96 | del jaccard_dist 97 | final_dist = final_dist[:query_num, query_num:] 98 | return final_dist 99 | 100 | -------------------------------------------------------------------------------- /visualization/config_vis/__init__.py: -------------------------------------------------------------------------------- 1 | from .vit_b import _C as cfg -------------------------------------------------------------------------------- /visualization/good_samples_market_query.json: -------------------------------------------------------------------------------- 1 | [ 2 | 3 | "1255_c6s3_045442_00.jpg", 4 | "0778_c1s4_018881_00.jpg", 5 | "0600_c1s3_029851_00.jpg", 6 | "0120_c3s1_020126_00.jpg", 7 | "0568_c1s3_019626_00.jpg", 8 | "0964_c5s2_122274_00.jpg", 9 | 10 | "0535_c4s3_005423_00.jpg", 11 | "0505_c4s3_009248_00.jpg", 12 | "1459_c1s6_009666_00.jpg", 13 | "0294_c6s1_066676_00.jpg", 14 | 15 | "1089_c4s6_039016_00.jpg", 16 | "0801_c6s2_088243_00.jpg", 17 | "1183_c6s3_030367_00.jpg", 18 | "0934_c4s4_061216_00.jpg", 19 | "0355_c3s1_081467_00.jpg", 20 | "0618_c6s2_014593_00.jpg", 21 | "0678_c3s2_048787_00.jpg", 22 | "0174_c5s1_053251_00.jpg", 23 | "0911_c3s2_113153_00.jpg", 24 | "1277_c1s5_052541_00.jpg", 25 | "0005_c6s1_004576_00.jpg", 26 | "1146_c2s2_158802_00.jpg", 27 | "0388_c1s2_018716_00.jpg", 28 | "0418_c1s2_027716_00.jpg", 29 | "0538_c2s1_152691_00.jpg", 30 | "0609_c1s3_032151_00.jpg", 31 | "0231_c4s1_047501_00.jpg", 32 | "1195_c6s3_032367_00.jpg", 33 | "0103_c3s1_016876_00.jpg" 34 | ] -------------------------------------------------------------------------------- /visualization/readme.md: -------------------------------------------------------------------------------- 1 | # Attention Rollout 2 | 3 | We updated the visualization codes based on https://github.com/jacobgil/vit-explain. See my examples in visualization/test.jpg. 4 | 5 | ## How to run? 6 | Following the instruction below. 7 | ``` 8 | cd visualization 9 | 10 | python vit_explain.py --save_path xxx --data_path xxx --vit_path xxx --pat_path xxx --pretrain_path xxx 11 | ``` 12 | 13 | For more details, please check visualization/vit_explain.py. 14 | 15 | ## No ideal reults? 16 | 17 | You can modify the options of the line 67 in visualization/vit_explain.py. 18 | 19 | Moreover, learn about attention fusion in visualization/vit_rollout/vit_rollout.py. 
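
As a reference, line 67 of visualization/vit_explain.py constructs the rollout object as shown below; `head_fusion` can be `mean`, `max` or `min`, and a larger `discard_ratio` drops more of the weakest attention weights before fusion.

```
attention_rollout = VITAttentionRollout(model, head_fusion='mean', discard_ratio=0.5)
masks = attention_rollout(input_tensor)
```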
20 | -------------------------------------------------------------------------------- /visualization/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyuke65535/Part-Aware-Transformer/104d42e8292f7e5d534689ded15e4afafb453785/visualization/test.jpg -------------------------------------------------------------------------------- /visualization/vit_explain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from PIL import Image 4 | from torchvision import transforms 5 | import numpy as np 6 | import cv2 7 | 8 | from config_vis import cfg 9 | 10 | from vit_rollout.vit_rollout import VITAttentionRollout 11 | import os 12 | import sys 13 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 14 | from model import make_model 15 | 16 | def show_mask_on_image(img, mask): 17 | img = np.float32(img) / 255 18 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 19 | heatmap = np.float32(heatmap) / 255 20 | cam = heatmap + np.float32(img) 21 | cam = cam / np.max(cam) 22 | return np.uint8(255 * cam) 23 | 24 | def main(args): 25 | os.environ['CUDA_VISIBLE_DEVICES'] = '5' 26 | cfg.MODEL.PRETRAIN_PATH = args.pretrain_path 27 | 28 | # load part_attention_vit 29 | model_ours = make_model(cfg, 'part_attention_vit', num_class=1) 30 | model_ours.load_param(args.pat_path) 31 | model_ours.eval() 32 | model_ours.to('cuda') 33 | 34 | # load vanilla vit 35 | model_vit = make_model(cfg, 'vit', num_class=1) 36 | model_vit.load_param(args.vit_path) 37 | model_vit.eval() 38 | model_vit.to('cuda') 39 | 40 | transform = transforms.Compose([ 41 | transforms.Resize((256,128)), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 44 | ]) 45 | 46 | input_tensor = [] 47 | 48 | # Prepare the original person photos 49 | base_dir = args.data_path 50 | img_path = os.listdir(base_dir) 51 | random.shuffle(img_path) 52 | 53 | length = min(30, len(img_path)) # how many photos to visualize 54 | img_list = [] 55 | for pth in img_path[:length]: 56 | img = Image.open(base_dir+pth) 57 | img = img.resize((128,256)) 58 | np_img = np.array(img)[:, :, ::-1] # BGR -> RGB 59 | input_tensor = transform(img).unsqueeze(0) 60 | input_tensor = input_tensor.cuda() 61 | img_list.append(np_img) 62 | 63 | local_flag = False 64 | 65 | # attention rollout 66 | for model in [model_ours]: 67 | attention_rollout = VITAttentionRollout(model, head_fusion='mean', discard_ratio=0.5) # modify head_fusion type and discard_ratio for better outputs 68 | masks = attention_rollout(input_tensor) 69 | 70 | if isinstance(masks, list): 71 | for msk in masks: 72 | msk = cv2.resize(msk, (np_img.shape[1], np_img.shape[0])) 73 | img_list.append(show_mask_on_image(np_img, msk)) 74 | local_flag = True 75 | else: 76 | masks = cv2.resize(masks, (np_img.shape[1], np_img.shape[0])) 77 | out_img = show_mask_on_image(np_img, masks) 78 | img_list.append(out_img) 79 | 80 | 81 | final_img = [] 82 | line_len = 5 if local_flag else 3 83 | 84 | # concate output images in a column 85 | for i in range(0, len(img_list)-1, line_len): 86 | if i==0: 87 | img_line = [img_list[l] for l in range(line_len)] 88 | final_img = np.concatenate(img_line,axis=1) 89 | else: 90 | img_line = [img_list[i+l] for l in range(line_len)] 91 | x = np.concatenate(img_line,axis=1) 92 | final_img = np.concatenate([final_img,x],axis=0) 93 | 94 | cv2.imwrite(args.save_path, final_img) 95 | for i, pth in 
enumerate(img_path[:30]): 96 | print(i+1, pth) 97 | print(f"save to {args.save_path}") 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument("--save_path", type=str, help="path to save your attention visualized photo. E.g., /home/me/out.jpg") 102 | parser.add_argument("--data_path", type=str, help="path to your dataset. E.g., dataset/market1501/query") 103 | parser.add_argument("--pretrain_path", type=str, help="path to your pretrained vit from imagenet or else. E.g., /home/me/cpt/") 104 | parser.add_argument("--vit_path", type=str, help="path to your trained vanilla vit. E.g., cpt/vit.pth") 105 | parser.add_argument("--pat_path", type=str, help="path to your trained PAT. E.g., cpt/pat.pth") 106 | args = parser.parse_args() 107 | main(args) -------------------------------------------------------------------------------- /visualization/vit_rollout/vit_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import torch 5 | 6 | from pytorch_grad_cam import GradCAM, \ 7 | ScoreCAM, \ 8 | GradCAMPlusPlus, \ 9 | AblationCAM, \ 10 | XGradCAM, \ 11 | EigenCAM, \ 12 | EigenGradCAM, \ 13 | LayerCAM, \ 14 | FullGrad 15 | 16 | from pytorch_grad_cam import GuidedBackpropReLUModel 17 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 18 | preprocess_image 19 | from pytorch_grad_cam.ablation_layer import AblationLayerVit 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--use-cuda', action='store_true', default=False, 24 | help='Use NVIDIA GPU acceleration') 25 | parser.add_argument( 26 | '--image-path', 27 | type=str, 28 | default='./examples/both.png', 29 | help='Input image path') 30 | parser.add_argument('--aug_smooth', action='store_true', 31 | help='Apply test time augmentation to smooth the CAM') 32 | parser.add_argument( 33 | '--eigen_smooth', 34 | action='store_true', 35 | help='Reduce noise by taking the first principle componenet' 36 | 'of cam_weights*activations') 37 | 38 | parser.add_argument( 39 | '--method', 40 | type=str, 41 | default='gradcam', 42 | help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam') 43 | 44 | args = parser.parse_args() 45 | args.use_cuda = args.use_cuda and torch.cuda.is_available() 46 | if args.use_cuda: 47 | print('Using GPU for acceleration') 48 | else: 49 | print('Using CPU for computation') 50 | 51 | return args 52 | 53 | 54 | def reshape_transform(tensor, height=14, width=14): 55 | result = tensor[:, 1:, :].reshape(tensor.size(0), 56 | height, width, tensor.size(2)) 57 | 58 | # Bring the channels to the first dimension, 59 | # like in CNNs. 60 | result = result.transpose(2, 3).transpose(1, 2) 61 | return result 62 | 63 | 64 | if __name__ == '__main__': 65 | """ python vit_gradcam.py --image-path 66 | Example usage of using cam-methods on a VIT network. 
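    For instance: python vit_example.py --image-path ./examples/both.png --method gradcam --use-cuda
    (these flags and their defaults are defined in get_args above).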
67 | 68 | """ 69 | 70 | args = get_args() 71 | methods = \ 72 | {"gradcam": GradCAM, 73 | "scorecam": ScoreCAM, 74 | "gradcam++": GradCAMPlusPlus, 75 | "ablationcam": AblationCAM, 76 | "xgradcam": XGradCAM, 77 | "eigencam": EigenCAM, 78 | "eigengradcam": EigenGradCAM, 79 | "layercam": LayerCAM, 80 | "fullgrad": FullGrad} 81 | 82 | if args.method not in list(methods.keys()): 83 | raise Exception(f"method should be one of {list(methods.keys())}") 84 | 85 | model = torch.hub.load('facebookresearch/deit:main', 86 | 'deit_tiny_patch16_224', pretrained=True) 87 | model.eval() 88 | 89 | if args.use_cuda: 90 | model = model.cuda() 91 | 92 | target_layers = [model.blocks[-1].norm1] 93 | 94 | if args.method not in methods: 95 | raise Exception(f"Method {args.method} not implemented") 96 | 97 | if args.method == "ablationcam": 98 | cam = methods[args.method](model=model, 99 | target_layers=target_layers, 100 | use_cuda=args.use_cuda, 101 | reshape_transform=reshape_transform, 102 | ablation_layer=AblationLayerVit()) 103 | else: 104 | cam = methods[args.method](model=model, 105 | target_layers=target_layers, 106 | use_cuda=args.use_cuda, 107 | reshape_transform=reshape_transform) 108 | 109 | 110 | rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1] 111 | rgb_img = cv2.resize(rgb_img, (224, 224)) 112 | rgb_img = np.float32(rgb_img) / 255 113 | input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], 114 | std=[0.5, 0.5, 0.5]) 115 | 116 | # If None, returns the map for the highest scoring category. 117 | # Otherwise, targets the requested category. 118 | targets = None 119 | 120 | # AblationCAM and ScoreCAM have batched implementations. 121 | # You can override the internal batch size for faster computation. 122 | cam.batch_size = 32 123 | 124 | grayscale_cam = cam(input_tensor=input_tensor, 125 | targets=targets , 126 | eigen_smooth=args.eigen_smooth, 127 | aug_smooth=args.aug_smooth) 128 | 129 | # Here grayscale_cam has only one image in the batch 130 | grayscale_cam = grayscale_cam[0, :] 131 | 132 | cam_image = show_cam_on_image(rgb_img, grayscale_cam) 133 | cv2.imwrite(f'{args.method}_cam.jpg', cam_image) -------------------------------------------------------------------------------- /visualization/vit_rollout/vit_grad_rollout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import numpy 4 | import sys 5 | from torchvision import transforms 6 | import numpy as np 7 | import cv2 8 | 9 | def grad_rollout(attentions, gradients, discard_ratio): 10 | result = torch.eye(attentions[0].size(-1)) 11 | with torch.no_grad(): 12 | for attention, grad in zip(attentions, gradients): 13 | weights = grad 14 | attention_heads_fused = (attention*weights).mean(axis=1) 15 | attention_heads_fused[attention_heads_fused < 0] = 0 16 | 17 | # Drop the lowest attentions, but 18 | # don't drop the class token 19 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1) 20 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False) 21 | #indices = indices[indices != 0] 22 | flat[0, indices] = 0 23 | 24 | I = torch.eye(attention_heads_fused.size(-1)) 25 | a = (attention_heads_fused + 1.0*I)/2 26 | a = a / a.sum(dim=-1) 27 | result = torch.matmul(a, result) 28 | 29 | # Look at the total attention between the class token, 30 | # and the image patches 31 | mask = result[0, 0 , 1 :] 32 | # In case of 224x224 image, this brings us from 196 to 14 33 | width = int(mask.size(-1)**0.5) 34 | mask = mask.reshape(width, 
width).numpy() 35 | mask = mask / np.max(mask) 36 | return mask 37 | 38 | class VITAttentionGradRollout: 39 | def __init__(self, model, attention_layer_name='attn_drop', 40 | discard_ratio=0.9): 41 | self.model = model 42 | self.discard_ratio = discard_ratio 43 | for name, module in self.model.named_modules(): 44 | if attention_layer_name in name: 45 | module.register_forward_hook(self.get_attention) 46 | module.register_backward_hook(self.get_attention_gradient) 47 | 48 | self.attentions = [] 49 | self.attention_gradients = [] 50 | 51 | def get_attention(self, module, input, output): 52 | self.attentions.append(output.cpu()) 53 | 54 | def get_attention_gradient(self, module, grad_input, grad_output): 55 | self.attention_gradients.append(grad_input[0].cpu()) 56 | 57 | def __call__(self, input_tensor, category_index): 58 | self.model.zero_grad() 59 | output = self.model(input_tensor) 60 | category_mask = torch.zeros(output.size()) 61 | category_mask[:, category_index] = 1 62 | loss = (output*category_mask).sum() 63 | loss.backward() 64 | 65 | return grad_rollout(self.attentions, self.attention_gradients, 66 | self.discard_ratio) -------------------------------------------------------------------------------- /visualization/vit_rollout/vit_rollout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import numpy 4 | import sys 5 | from torchvision import transforms 6 | import numpy as np 7 | import cv2 8 | 9 | def rollout(attentions, discard_ratio, head_fusion): 10 | result = torch.eye(attentions[0].size(-1)).unsqueeze(0) 11 | with torch.no_grad(): 12 | attention = attentions[3] # alter this 13 | # num_blocks = 6 14 | # for attention in attentions[4:6]: # alter this 15 | if head_fusion == "mean": 16 | attention_heads_fused = attention.mean(axis=1) 17 | elif head_fusion == "max": 18 | attention_heads_fused = attention.max(axis=1)[0] 19 | elif head_fusion == "min": 20 | attention_heads_fused = attention.min(axis=1)[0] 21 | else: 22 | raise "Attention head fusion type Not supported" 23 | 24 | # Drop the lowest attentions, but 25 | # don't drop the class token 26 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1) 27 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False) 28 | indices = indices[indices != 0] 29 | flat[0, indices] = 0 30 | 31 | # I = torch.eye(attention_heads_fused.size(-1)) 32 | # a = (attention_heads_fused + 1.0*I)/2 33 | # a = a / a.sum(dim=-1) 34 | # result = torch.matmul(a, result) 35 | 36 | result = attention_heads_fused 37 | 38 | # result /= 2 39 | 40 | # Look at the total attention between the class token, 41 | # and the image patches 42 | if result.size(-1) == 132: 43 | masks = [] 44 | for i in range(4): 45 | mask = result[0, i, 4 :].reshape(16,8).numpy() 46 | mask = mask / np.max(mask) 47 | masks.append(mask) 48 | return masks 49 | # mask = result[0, 0, 4 :] 50 | 51 | else: 52 | mask = result[0, 0, 1 :] 53 | # In case of 224x224 image, this brings us from 196 to 14 54 | 55 | mask = mask.reshape(16,8).numpy() 56 | mask = mask / np.max(mask) 57 | return mask 58 | 59 | class VITAttentionRollout: 60 | def __init__(self, model, attention_layer_name='attn_drop', head_fusion="mean", 61 | discard_ratio=0.9): 62 | self.model = model 63 | self.head_fusion = head_fusion 64 | self.discard_ratio = discard_ratio 65 | for name, module in self.model.named_modules(): 66 | if attention_layer_name in name: 67 | module.register_forward_hook(self.get_attention) 68 | 69 | 
self.attentions = [] 70 | 71 | def get_attention(self, module, input, output): 72 | self.attentions.append(output.cpu()) 73 | 74 | def __call__(self, input_tensor): 75 | self.attentions = [] 76 | with torch.no_grad(): 77 | output = self.model(input_tensor) 78 | 79 | return rollout(self.attentions, self.discard_ratio, self.head_fusion) --------------------------------------------------------------------------------