├── .gitignore
├── README.md
├── config
│   ├── PAT.yml
│   ├── __init__.py
│   ├── defaults.py
│   └── vit.yml
├── data
│   ├── __init__.py
│   ├── build_DG_dataloader.py
│   ├── common.py
│   ├── data_utils.py
│   ├── datasets
│   │   ├── AirportALERT.py
│   │   ├── DG_cuhk02.py
│   │   ├── DG_cuhk03_detected.py
│   │   ├── DG_cuhk03_labeled.py
│   │   ├── DG_cuhk_sysu.py
│   │   ├── DG_dukemtmcreid.py
│   │   ├── DG_grid.py
│   │   ├── DG_iLIDS.py
│   │   ├── DG_market1501.py
│   │   ├── DG_prid.py
│   │   ├── DG_viper.py
│   │   ├── __init__.py
│   │   ├── bases.py
│   │   ├── caviara.py
│   │   ├── cuhk03.py
│   │   ├── dukemtmcreid.py
│   │   ├── grid.py
│   │   ├── iLIDS.py
│   │   ├── lpw.py
│   │   ├── market1501.py
│   │   ├── msmt17.py
│   │   ├── pes3d.py
│   │   ├── pku.py
│   │   ├── prai.py
│   │   ├── prid.py
│   │   ├── randperson.py
│   │   ├── sensereid.py
│   │   ├── shinpuhkan.py
│   │   ├── sysu_mm.py
│   │   ├── thermalworld.py
│   │   ├── vehicleid.py
│   │   ├── veri.py
│   │   ├── veri_keypoint.py
│   │   ├── veriwild.py
│   │   └── viper.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   └── triplet_sampler.py
│   └── transforms
│       ├── __init__.py
│       ├── autoaugment.py
│       ├── build.py
│       ├── functional.py
│       └── transforms.py
├── enviroments.sh
├── loss
│   ├── __init__.py
│   ├── arcface.py
│   ├── build_loss.py
│   ├── ce_labelSmooth.py
│   ├── center_loss.py
│   ├── make_loss.py
│   ├── metric_learning.py
│   ├── myloss.py
│   ├── smooth.py
│   ├── softmax_loss.py
│   └── triplet_loss.py
├── model
│   ├── __init__.py
│   ├── backbones
│   │   ├── IBN.py
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   ├── resnet_ibn.py
│   │   └── vit_pytorch.py
│   └── make_model.py
├── processor
│   ├── __init__.py
│   ├── ori_vit_processor_with_amp.py
│   └── part_attention_vit_processor.py
├── run.sh
├── solver
│   ├── __init__.py
│   ├── cosine_lr.py
│   ├── lr_scheduler.py
│   ├── make_optimizer.py
│   ├── scheduler.py
│   └── scheduler_factory.py
├── test.py
├── train.py
├── utils
│   ├── __init__.py
│   ├── comm.py
│   ├── file_io.py
│   ├── iotools.py
│   ├── logger.py
│   ├── meter.py
│   ├── metrics.py
│   ├── registry.py
│   └── reranking.py
└── visualization
    ├── config_vis
    │   ├── __init__.py
    │   └── vit_b.py
    ├── good_samples_market_query.json
    ├── readme.md
    ├── test.jpg
    ├── vit_explain.py
    └── vit_rollout
        ├── vit_example.py
        ├── vit_grad_rollout.py
        └── vit_rollout.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | tb_log
3 | .vscode
4 | *.yml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Part-Aware-Transformer
2 |
3 | ## 🔥 News
4 | We updated the visualization code.
5 |
6 | See instructions in /visualization/readme.md.
7 |
8 | ## Welcome
9 |
10 | This is the official repo for "Part-Aware Transformer for Generalizable Person Re-identification" [ICCV 2023]
11 |
12 |
13 |
14 |
15 |
16 |
17 | ## Abstract
18 | Domain generalization person re-identification (DG-ReID) aims to train a model on source domains and generalize well on unseen domains.
19 | Vision Transformer usually yields better generalization ability than common CNN networks under distribution shifts.
20 | However, Transformer-based ReID models inevitably over-fit to domain-specific biases due to the supervised learning strategy on the source domain.
21 | We observe that while the global images of different IDs should have different features, their similar local parts (e.g., black backpack) are not bounded by this constraint.
22 | Motivated by this, we propose a pure Transformer model (termed Part-aware Transformer) for DG-ReID by designing a proxy task, named Cross-ID Similarity Learning (CSL), to mine local visual information shared by different IDs. This proxy task allows the model to learn generic features because it only cares about the visual similarity of the parts regardless of the ID labels, thus alleviating the side effect of domain-specific biases.
23 | Based on the local similarity obtained in CSL, a Part-guided Self-Distillation (PSD) is proposed to further improve the generalization of global features.
24 | Our method achieves state-of-the-art performance under most DG ReID settings.
25 |
26 | ## Framework
27 | 
28 |
29 | ## Visualizations
30 | 
31 | 
32 |
33 | # Instructions
34 |
35 | Here are some instructions to run our code.
36 | Our code is based on [TransReID](https://github.com/damo-cv/TransReID), thanks for their excellent work.
37 |
38 | ## 1. Clone this repo
39 | ```
40 | git clone https://github.com/liyuke65535/Part-Aware-Transformer.git
41 | ```
42 |
43 | ## 2. Prepare your environment
44 | ```
45 | conda create -n pat python==3.10
46 | conda activate pat
47 | bash enviroments.sh
48 | ```
49 |
50 | ## 3. Prepare pretrained model (ViT-B) and datasets
51 | You can download it from Hugging Face, rwightman's pytorch-image-models releases, or elsewhere.
52 | For example, a pretrained model is available at [ViT-B](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth).
53 |
54 | As for datasets, follow the instructions in [MetaBIN](https://github.com/bismex/MetaBIN#8-datasets).
55 |
56 | ## 4. Modify the config file
57 | ```
58 | # modify the model path and dataset paths of the config file
59 | vim ./config/PAT.yml
60 | ```
61 |
62 | ## 5. Train a model
63 | ```
64 | bash run.sh
65 | ```
66 |
67 | ## 6. Evaluation only
68 | ```
69 | # modify the trained path in config
70 | vim ./config/PAT.yml
71 |
72 | # evaluation
73 | python test.py --config ./config/PAT.yml
74 | ```
75 | ## Citation
76 | If you find this repo useful for your research, you're welcome to cite our paper.
77 | ```
78 | @inproceedings{ni2023part,
79 | title={Part-Aware Transformer for Generalizable Person Re-identification},
80 | author={Ni, Hao and Li, Yuke and Gao, Lianli and Shen, Heng Tao and Song, Jingkuan},
81 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
82 | pages={11280--11289},
83 | year={2023}
84 | }
85 | ```
86 |
--------------------------------------------------------------------------------
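
A minimal sketch of fetching the ViT-B/16 checkpoint from step 3 of the README into the directory that `MODEL.PRETRAIN_PATH` points to in `config/PAT.yml` below. The destination path is taken from that config; keeping the release filename unchanged is an assumption, since the loading code in `model/make_model.py` is not reproduced here.

```
# Sketch: download the ViT-B/16 weights linked in the README into the directory
# used as MODEL.PRETRAIN_PATH in config/PAT.yml. The filename kept here is simply
# the one from the release URL; adjust it if the model loader expects another.
import os
import urllib.request

URL = ("https://github.com/rwightman/pytorch-image-models/releases/download/"
       "v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth")
DST_DIR = "../../.cache/torch/hub/checkpoints"  # MODEL.PRETRAIN_PATH in PAT.yml

os.makedirs(DST_DIR, exist_ok=True)
urllib.request.urlretrieve(URL, os.path.join(DST_DIR, os.path.basename(URL)))
```
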
/config/PAT.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: "../../.cache/torch/hub/checkpoints" # root of pretrain path
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'on'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'part_attention_vit'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256,128]
15 | SIZE_TEST: [256,128]
16 | REA:
17 | ENABLED: False
18 | PIXEL_MEAN: [0.5, 0.5, 0.5]
19 | PIXEL_STD: [0.5, 0.5, 0.5]
20 | LGT: # Local Grayscale Transformation
21 | DO_LGT: False
22 | PROB: 0.5
23 |
24 | DATASETS:
25 | TRAIN: ('Market1501',)
26 | TEST: ("DukeMTMC",)
27 | ROOT_DIR: ('../../data') # root of datasets
28 |
29 | DATALOADER:
30 | SAMPLER: 'softmax_triplet'
31 | NUM_INSTANCE: 4
32 | NUM_WORKERS: 8
33 |
34 | SOLVER:
35 | OPTIMIZER_NAME: 'SGD'
36 | MAX_EPOCHS: 60
37 | BASE_LR: 0.001 # 0.0004 for msmt
38 | IMS_PER_BATCH: 64
39 | WARMUP_METHOD: 'linear'
40 | LARGE_FC_LR: False
41 | CHECKPOINT_PERIOD: 5
42 | LOG_PERIOD: 60
43 | EVAL_PERIOD: 1
44 | WEIGHT_DECAY: 1e-4
45 | WEIGHT_DECAY_BIAS: 1e-4
46 | BIAS_LR_FACTOR: 2
47 | SEED: 1234
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 128
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: True
56 |
57 | LOG_ROOT: '../../data/exp/' # root of log file
58 | TB_LOG_ROOT: './tb_log/'
59 | LOG_NAME: 'PAT/market/vit_base'
60 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
3 | from .defaults import _C as cfg
4 | from .defaults import _C as cfg_test
5 |
--------------------------------------------------------------------------------
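
`config/defaults.py` is not reproduced in this dump, but `__init__.py` exposes its `_C` node as `cfg` in the yacs style used by TransReID-derived code. A minimal, hedged sketch of overlaying `PAT.yml` on those defaults; the `--config` flag mirrors `python test.py --config ./config/PAT.yml` from the README, and the yacs `CfgNode` API is an assumption about `defaults.py`.

```
# Hedged sketch: merge config/PAT.yml onto the default yacs node exported by
# config/__init__.py. Assumes defaults.py defines _C as a yacs CfgNode.
import argparse
from config import cfg

parser = argparse.ArgumentParser()
parser.add_argument("--config", default="./config/PAT.yml")
args = parser.parse_args()

cfg.merge_from_file(args.config)   # YAML values override the defaults
cfg.freeze()
print(cfg.MODEL.NAME, cfg.SOLVER.BASE_LR)   # e.g. part_attention_vit 0.001
```
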
/config/vit.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: "../../.cache/torch/hub/checkpoints" # root of pretrain path
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'on'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'vit'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256,128]
15 | SIZE_TEST: [256,128]
16 | REA:
17 | ENABLED: False
18 | PIXEL_MEAN: [0.5, 0.5, 0.5]
19 | PIXEL_STD: [0.5, 0.5, 0.5]
20 | LGT: # Local Grayscale Transformation
21 | DO_LGT: False
22 | PROB: 0.5
23 |
24 | DATASETS:
25 | TRAIN: ('Market1501',)
26 | TEST: ("DukeMTMC",)
27 | ROOT_DIR: ('../../data') # root of datasets
28 |
29 | DATALOADER:
30 | SAMPLER: 'softmax_triplet'
31 | NUM_INSTANCE: 4
32 | NUM_WORKERS: 8
33 |
34 | SOLVER:
35 | OPTIMIZER_NAME: 'SGD'
36 | MAX_EPOCHS: 60
37 | BASE_LR: 0.008 # 0.0004 for msmt
38 | IMS_PER_BATCH: 64
39 | WARMUP_METHOD: 'linear'
40 | LARGE_FC_LR: False
41 | CHECKPOINT_PERIOD: 5
42 | LOG_PERIOD: 60
43 | EVAL_PERIOD: 5
44 | WEIGHT_DECAY: 1e-4
45 | WEIGHT_DECAY_BIAS: 1e-4
46 | BIAS_LR_FACTOR: 2
47 | SEED: 1234
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 128
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: True
56 |
57 | LOG_ROOT: '../../data/exp/' # root of log file
58 | TB_LOG_ROOT: './tb_log/'
59 | LOG_NAME: 'vit/market/vit_base'
60 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .build_DG_dataloader import build_reid_train_loader, build_reid_test_loader
2 |
--------------------------------------------------------------------------------
/data/build_DG_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import sys
4 | import collections.abc as container_abcs
5 |
6 | # from torch._six import container_abcs, string_classes, int_classes
7 | int_classes = int
8 | string_classes = str
9 | from torch.utils.data import DataLoader
10 | from utils import comm
11 | import random
12 |
13 | from . import samplers
14 | from .common import CommDataset
15 | from .datasets import DATASET_REGISTRY
16 | from .transforms import build_transforms
17 |
18 | _root = os.getenv("REID_DATASETS", "../../data")
19 |
20 |
21 | def build_reid_train_loader(cfg):
22 | gettrace = getattr(sys, 'gettrace', None)
23 | if gettrace():
24 | print('*'*100)
25 | print('Hmm, Big Debugger is watching me')
26 | print('*'*100)
27 | num_workers = 0
28 | else:
29 | num_workers = cfg.DATALOADER.NUM_WORKERS
30 |
31 | train_transforms = build_transforms(cfg, is_train=True, is_fake=False)
32 | train_items = list()
33 | domain_idx = 0
34 | camera_all = list()
35 |
36 | # load datasets
37 | _root = cfg.DATASETS.ROOT_DIR
38 | for d in cfg.DATASETS.TRAIN:
39 | if d == 'CUHK03_NP':
40 | dataset = DATASET_REGISTRY.get('CUHK03')(root=_root, cuhk03_labeled=False)
41 | else:
42 | dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
43 | if comm.is_main_process():
44 | dataset.show_train()
45 | if len(dataset.train[0]) < 4:
46 | for i, x in enumerate(dataset.train):
47 | add_info = {} # dictionary
48 |
49 | if cfg.DATALOADER.CAMERA_TO_DOMAIN:
50 | add_info['domains'] = dataset.train[i][2]
51 | camera_all.append(dataset.train[i][2])
52 | else:
53 | add_info['domains'] = int(domain_idx)
54 | dataset.train[i] = list(dataset.train[i])
55 | dataset.train[i].append(add_info)
56 | dataset.train[i] = tuple(dataset.train[i])
57 | domain_idx += 1
58 | train_items.extend(dataset.train)
59 |
60 | train_set = CommDataset(train_items, train_transforms, relabel=True)
61 |
62 | train_loader = make_sampler(
63 | train_set=train_set,
64 | num_batch=cfg.SOLVER.IMS_PER_BATCH,
65 | num_instance=cfg.DATALOADER.NUM_INSTANCE,
66 | num_workers=num_workers,
67 | mini_batch_size=cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(),
68 | drop_last=cfg.DATALOADER.DROP_LAST,
69 | flag1=cfg.DATALOADER.NAIVE_WAY,
70 | flag2=cfg.DATALOADER.DELETE_REM,
71 | cfg = cfg)
72 |
73 | return train_loader
74 |
75 |
76 | def build_reid_test_loader(cfg, dataset_name, opt=None, flag_test=True, shuffle=False, only_gallery=False, only_query=False, eval_time=False):
77 | test_transforms = build_transforms(cfg, is_train=False)
78 | _root = cfg.DATASETS.ROOT_DIR
79 | if opt is None:
80 | dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
81 | if comm.is_main_process():
82 | if flag_test:
83 | dataset.show_test()
84 | else:
85 | dataset.show_train()
86 | else:
87 | dataset = DATASET_REGISTRY.get(dataset_name)(root=[_root, opt])
88 | if flag_test:
89 | if only_gallery:
90 | test_items = dataset.gallery
91 | elif only_query:
92 | test_set = CommDataset([random.choice(dataset.query)], test_transforms, relabel=False)
93 | return test_set
94 | else:
95 | test_items = dataset.query + dataset.gallery
96 | if shuffle: # only for visualization
97 | random.shuffle(test_items)
98 | else:
99 | test_items = dataset.train
100 |
101 | test_set = CommDataset(test_items, test_transforms, relabel=False)
102 |
103 | batch_size = cfg.TEST.IMS_PER_BATCH
104 | data_sampler = samplers.InferenceSampler(len(test_set))
105 | batch_sampler = torch.utils.data.BatchSampler(data_sampler, batch_size, False)
106 |
107 | gettrace = getattr(sys, 'gettrace', None)
108 | if gettrace():
109 | num_workers = 0
110 | else:
111 | num_workers = cfg.DATALOADER.NUM_WORKERS
112 |
113 | test_loader = DataLoader(
114 | test_set,
115 | batch_sampler=batch_sampler,
116 | num_workers=num_workers, # save some memory
117 | collate_fn=fast_batch_collator)
118 | return test_loader, len(dataset.query)
119 |
120 |
121 | def trivial_batch_collator(batch):
122 | """
123 | A batch collator that does nothing.
124 | """
125 | return batch
126 |
127 |
128 | def fast_batch_collator(batched_inputs):
129 | """
130 | A simple batch collator for most common reid tasks
131 | """
132 | elem = batched_inputs[0]
133 | if isinstance(elem, torch.Tensor):
134 | out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype)
135 | for i, tensor in enumerate(batched_inputs):
136 | out[i] += tensor
137 | return out
138 |
139 | elif isinstance(elem, container_abcs.Mapping):
140 | return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
141 |
142 | elif isinstance(elem, float):
143 | return torch.tensor(batched_inputs, dtype=torch.float64)
144 | elif isinstance(elem, int_classes):
145 | return torch.tensor(batched_inputs)
146 | elif isinstance(elem, string_classes):
147 | return batched_inputs
148 | elif isinstance(elem, list):
149 | out_g = []
150 | out_pt1 = []
151 | out_pt2 = []
152 | out_pt3 = []
153 | # out = torch.stack(elem, dim=0)
154 | for i, tensor_list in enumerate(batched_inputs):
155 | out_g.append(tensor_list[0])
156 | out_pt1.append(tensor_list[1])
157 | out_pt2.append(tensor_list[2])
158 | out_pt3.append(tensor_list[3])
159 | out = torch.stack(out_g, dim=0)
160 | out_pt1 = torch.stack(out_pt1, dim=0)
161 | out_pt2 = torch.stack(out_pt2, dim=0)
162 | out_pt3 = torch.stack(out_pt3, dim=0)
163 | return out, out_pt1, out_pt2, out_pt3
164 |
165 |
166 | def make_sampler(train_set, num_batch, num_instance, num_workers,
167 | mini_batch_size, drop_last=True, flag1=True, flag2=True, seed=None, cfg=None):
168 |
169 | if flag1:
170 | data_sampler = samplers.RandomIdentitySampler(train_set.img_items,
171 | mini_batch_size, num_instance)
172 | else:
173 | data_sampler = samplers.DomainSuffleSampler(train_set.img_items,
174 | num_batch, num_instance, flag2, seed, cfg)
175 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, drop_last)
176 | train_loader = torch.utils.data.DataLoader(
177 | train_set,
178 | num_workers=num_workers,
179 | batch_sampler=batch_sampler,
180 | collate_fn=fast_batch_collator,
181 | )
182 | return train_loader
--------------------------------------------------------------------------------
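
A standalone illustration of what `fast_batch_collator` produces for the dict samples returned by `CommDataset` (next file). The image tensors and paths are stand-ins, and the import assumes the repository root is on `PYTHONPATH` with the dependencies from `enviroments.sh` installed.

```
# Two CommDataset-style samples collated into one batch: tensors are stacked,
# ints become a 1-D tensor, and strings are passed through as a list.
import torch
from data.build_DG_dataloader import fast_batch_collator

samples = [
    {"images": torch.rand(3, 256, 128), "targets": 0, "camid": 1, "img_path": "a.jpg", "others": ""},
    {"images": torch.rand(3, 256, 128), "targets": 3, "camid": 5, "img_path": "b.jpg", "others": ""},
]
batch = fast_batch_collator(samples)
print(batch["images"].shape)   # torch.Size([2, 3, 256, 128])
print(batch["targets"])        # tensor([0, 3])
print(batch["img_path"])       # ['a.jpg', 'b.jpg']
```
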
/data/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 | from .data_utils import read_image
5 |
6 |
7 | class CommDataset(Dataset):
8 | """Image Person ReID Dataset"""
9 |
10 | def __init__(self, img_items, transform=None, relabel=True):
11 | self.img_items = img_items
12 | self.transform = transform
13 | self.relabel = relabel
14 |
15 | self.pid_dict = {}
16 | if self.relabel:
17 | pids = list()
18 | for i, item in enumerate(img_items):
19 | if item[1] in pids: continue
20 | pids.append(item[1])
21 | self.pids = pids
22 | self.pid_dict = dict([(p, i) for i, p in enumerate(self.pids)])
23 |
24 | def __len__(self):
25 | return len(self.img_items)
26 |
27 | def __getitem__(self, index):
28 | if len(self.img_items[index]) > 3:
29 | img_path, pid, camid, others = self.img_items[index]
30 | else:
31 | img_path, pid, camid = self.img_items[index]
32 | others = ''
33 | img = read_image(img_path)
34 | if self.transform is not None: img = self.transform(img)
35 | if self.relabel: pid = self.pid_dict[pid]
36 | return {
37 | "images": img,
38 | "targets": pid,
39 | "camid": camid,
40 | "img_path": img_path,
41 | "others": others
42 | }
43 |
44 | @property
45 | def num_classes(self):
46 | return len(self.pids)
47 |
--------------------------------------------------------------------------------
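
A small sketch of the relabel behaviour above: string pids (already prefixed with a dataset name by the dataset classes) are mapped to contiguous integer targets in first-seen order. The paths are placeholders and `__getitem__` is never called, so no image files are needed; the import again assumes the repository root is on `PYTHONPATH`.

```
# CommDataset builds pid_dict from unique pids when relabel=True.
from data.common import CommDataset

items = [
    ("m/0001_c1.jpg", "market1501_1", 0),
    ("m/0001_c2.jpg", "market1501_1", 1),
    ("m/0007_c3.jpg", "market1501_7", 2),
]
ds = CommDataset(items, transform=None, relabel=True)
print(ds.num_classes)   # 2
print(ds.pid_dict)      # {'market1501_1': 0, 'market1501_7': 1}
```
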
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageOps
3 |
4 | from utils.file_io import PathManager
5 |
6 |
7 | def read_image(file_name, format=None):
8 | """
9 | Read an image into the given format.
10 | Will apply rotation and flipping if the image has such exif information.
11 | Args:
12 | file_name (str): image file path
13 | format (str): one of the supported image modes in PIL, or "BGR"
14 | Returns:
15 | image (np.ndarray): an HWC image
16 | """
17 | with PathManager.open(file_name, "rb") as f:
18 | image = Image.open(f)
19 |
20 | # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
21 | try:
22 | image = ImageOps.exif_transpose(image)
23 | except Exception:
24 | pass
25 |
26 | if format is not None:
27 | # PIL only supports RGB, so convert to RGB and flip channels over below
28 | conversion_format = format
29 | if format == "BGR":
30 | conversion_format = "RGB"
31 | image = image.convert(conversion_format)
32 | image = np.asarray(image)
33 | if format == "BGR":
34 | # flip channels if needed
35 | image = image[:, :, ::-1]
36 | # PIL squeezes out the channel dimension for "L", so make it HWC
37 | if format == "L":
38 | image = np.expand_dims(image, -1)
39 | image = Image.fromarray(image)
40 | return image
41 |
--------------------------------------------------------------------------------
/data/datasets/AirportALERT.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from . import DATASET_REGISTRY
4 | from .bases import ImageDataset
5 |
6 | __all__ = ['AirportALERT', ]
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class AirportALERT(ImageDataset):
11 | dataset_dir = "AirportALERT"
12 | dataset_name = "airport"
13 |
14 | def __init__(self, root='datasets', **kwargs):
15 | self.root = root
16 | self.train_path = os.path.join(self.root, self.dataset_dir)
17 | self.train_file = os.path.join(self.root, self.dataset_dir, 'filepath.txt')
18 |
19 | required_files = [self.train_file, self.train_path]
20 | self.check_before_run(required_files)
21 |
22 | train = self.process_train(self.train_path, self.train_file)
23 |
24 | super().__init__(train, [], [], **kwargs)
25 |
26 | def process_train(self, dir_path, train_file):
27 | data = []
28 | with open(train_file, "r") as f:
29 | img_paths = [line.strip('\n') for line in f.readlines()]
30 |
31 | for path in img_paths:
32 | split_path = path.split('\\')
33 | img_path = '/'.join(split_path)
34 | camid = self.dataset_name + "_" + split_path[0]
35 | pid = self.dataset_name + "_" + split_path[1]
36 | img_path = os.path.join(dir_path, img_path)
37 | if 11001 <= int(split_path[1]) <= 401999:
38 | data.append([img_path, pid, camid])
39 |
40 | return data
41 |
--------------------------------------------------------------------------------
/data/datasets/DG_cuhk02.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | from . import DATASET_REGISTRY
5 | from .bases import ImageDataset
6 |
7 | __all__ = ['DG_CUHK02', ]
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class DG_CUHK02(ImageDataset):
12 | dataset_dir = "cuhk02"
13 | dataset_name = "cuhk02"
14 |
15 | def __init__(self, root='datasets', **kwargs):
16 | self.root = root
17 | self.train_path = os.path.join(self.root, self.dataset_dir)
18 |
19 | required_files = [self.train_path]
20 | self.check_before_run(required_files)
21 |
22 | train = self.process_train(self.train_path)
23 |
24 | super().__init__(train, [], [], **kwargs)
25 |
26 | def process_train(self, train_path):
27 |
28 |
29 | cam_split = True
30 |
31 | data = []
32 | file_path = os.listdir(train_path)
33 | for pid_dir in file_path:
34 | img_file = os.path.join(train_path, pid_dir)
35 | cam1_folder = os.path.join(img_file, 'cam1')
36 | cam = '1'
37 |
38 | # if os.path.join(img_file, 'cam1'):
39 | img_paths = glob(os.path.join(cam1_folder, "*.png"))
40 | for img_path in img_paths:
41 | split_path = img_path.split('/')[-1].split('_')
42 | pid = self.dataset_name + "_" + pid_dir + "_" + split_path[0]
43 | camid = int(cam)
44 | # if cam_split:
45 | # camid = self.dataset_name + "_" + pid_dir + "_" + cam
46 | # else:
47 | # camid = self.dataset_name + "_" + cam
48 | data.append([img_path, pid, camid])
49 |
50 | cam2_folder = os.path.join(img_file, 'cam2')
51 | cam = '2'
52 |
53 | img_paths = glob(os.path.join(cam2_folder, "*.png"))
54 | for img_path in img_paths:
55 | split_path = img_path.split('/')[-1].split('_')
56 | pid = self.dataset_name + "_" + pid_dir + "_" + split_path[0]
57 | camid = int(cam)
58 | # if cam_split:
59 | # camid = self.dataset_name + "_" + pid_dir + "_" + cam
60 | # else:
61 | # camid = self.dataset_name + "_" + cam
62 | data.append([img_path, pid, camid])
63 | return data
64 |
--------------------------------------------------------------------------------
/data/datasets/DG_cuhk_sysu.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | from . import DATASET_REGISTRY
5 | from .bases import ImageDataset
6 |
7 | __all__ = ['DG_CUHK_SYSU', ]
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class DG_CUHK_SYSU(ImageDataset):
12 | dataset_dir = "CUHK-SYSU"
13 | dataset_name = "CUHK-SYSU"
14 |
15 | def __init__(self, root='datasets', **kwargs):
16 | self.root = root
17 | self.train_path = os.path.join(self.root, self.dataset_dir, 'cropped_image')
18 |
19 | required_files = [self.train_path]
20 | self.check_before_run(required_files)
21 |
22 | train = self.process_train(self.train_path)
23 |
24 | super().__init__(train, [], [], **kwargs)
25 |
26 | def process_train(self, train_path):
27 | data = []
28 | img_paths = glob(os.path.join(train_path, "*.png"))
29 | for img_path in img_paths:
30 | split_path = img_path.split('/')[-1].split('_') # p00001_n01_s00001_hard0.png
31 | pid = self.dataset_name + "_" + split_path[0][1:]
32 | camid = int(split_path[2][1:])
33 | # camid = self.dataset_name + "_" + split_path[2][1:]
34 | data.append([img_path, pid, camid])
35 | return data
36 |
--------------------------------------------------------------------------------
/data/datasets/DG_dukemtmcreid.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | import re
4 |
5 | from .bases import ImageDataset
6 | from ..datasets import DATASET_REGISTRY
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class DG_DukeMTMC(ImageDataset):
11 | """DukeMTMC-reID.
12 |
13 | Reference:
14 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
15 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
16 |
17 | URL: ``_
18 |
19 | Dataset statistics:
20 | - identities: 1404 (train + query).
21 | - images:16522 (train) + 2228 (query) + 17661 (gallery).
22 | - cameras: 8.
23 | """
24 | dataset_dir = 'DukeMTMC-reID'
25 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
26 | dataset_name = "dukemtmc"
27 |
28 | def __init__(self, root='datasets', **kwargs):
29 | # self.root = osp.abspath(osp.expanduser(root))
30 | self.root = root
31 | self.dataset_dir = osp.join(self.root, self.dataset_dir)
32 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
33 | self.query_dir = osp.join(self.dataset_dir, 'query')
34 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
35 |
36 | required_files = [
37 | self.dataset_dir,
38 | self.train_dir,
39 | self.query_dir,
40 | self.gallery_dir,
41 | ]
42 | self.check_before_run(required_files)
43 |
44 | train = self.process_dir(self.train_dir)
45 | query = self.process_dir(self.query_dir, is_train=True)
46 | gallery = self.process_dir(self.gallery_dir, is_train=True)
47 | train = train + query + gallery
48 |
49 | super(DG_DukeMTMC, self).__init__(train, [], [], **kwargs)
50 |
51 | def process_dir(self, dir_path, is_train=True):
52 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
53 | pattern = re.compile(r'([-\d]+)_c(\d)')
54 |
55 | data = []
56 | for img_path in img_paths:
57 | pid, camid = map(int, pattern.search(img_path).groups())
58 | assert 1 <= camid <= 8
59 | camid -= 1 # index starts from 0
60 | if is_train:
61 | pid = self.dataset_name + "_" + str(pid)
62 | data.append((img_path, pid, camid))
63 |
64 | return data
65 |
--------------------------------------------------------------------------------
/data/datasets/DG_grid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 |
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 | import pdb
8 |
9 | __all__ = ['DG_GRID',]
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class DG_GRID(ImageDataset):
14 | dataset_dir = "GRID"
15 | dataset_name = 'grid'
16 |
17 | def __init__(self, root='datasets', split_id = 0, **kwargs):
18 |
19 | if isinstance(root, list):
20 | split_id = root[1]
21 | self.root = root[0]
22 | else:
23 | self.root = root
24 | split_id = 0
25 | self.dataset_dir = os.path.join(self.root, self.dataset_dir)
26 |
27 | self.probe_path = os.path.join(
28 | self.dataset_dir, 'probe'
29 | )
30 | self.gallery_path = os.path.join(
31 | self.dataset_dir, 'gallery'
32 | )
33 | self.split_mat_path = os.path.join(
34 | self.dataset_dir, 'features_and_partitions.mat'
35 | )
36 | self.split_path = os.path.join(self.dataset_dir, 'splits.json')
37 |
38 | required_files = [
39 | self.dataset_dir, self.probe_path, self.gallery_path,
40 | self.split_mat_path
41 | ]
42 | self.check_before_run(required_files)
43 |
44 | self.prepare_split()
45 | splits = self.read_json(self.split_path)
46 | if split_id >= len(splits):
47 | raise ValueError(
48 | 'split_id exceeds range, received {}, '
49 | 'but expected between 0 and {}'.format(
50 | split_id,
51 | len(splits) - 1
52 | )
53 | )
54 | split = splits[split_id]
55 |
56 | train = split['train']
57 | query = split['query']
58 | gallery = split['gallery']
59 |
60 | train = [tuple(item) for item in train]
61 | query = [tuple(item) for item in query]
62 | gallery = [tuple(item) for item in gallery]
63 |
64 | super(DG_GRID, self).__init__(train, query, gallery, **kwargs)
65 |
66 | def prepare_split(self):
67 | if not os.path.exists(self.split_path):
68 | print('Creating 10 random splits')
69 | split_mat = loadmat(self.split_mat_path)
70 | trainIdxAll = split_mat['trainIdxAll'][0] # length = 10
71 | probe_img_paths = sorted(
72 | glob(os.path.join(self.probe_path, '*.jpeg'))
73 | )
74 | gallery_img_paths = sorted(
75 | glob(os.path.join(self.gallery_path, '*.jpeg'))
76 | )
77 |
78 | splits = []
79 | for split_idx in range(10):
80 | train_idxs = trainIdxAll[split_idx][0][0][2][0].tolist()
81 | assert len(train_idxs) == 125
82 | idx2label = {
83 | idx: label
84 | for label, idx in enumerate(train_idxs)
85 | }
86 |
87 | train, query, gallery = [], [], []
88 |
89 | # processing probe folder
90 | for img_path in probe_img_paths:
91 | img_name = os.path.basename(img_path)
92 | img_idx = int(img_name.split('_')[0])
93 | camid = int(
94 | img_name.split('_')[1]
95 | ) - 1 # index starts from 0
96 | if img_idx in train_idxs:
97 | train.append((img_path, idx2label[img_idx], camid))
98 | else:
99 | query.append((img_path, img_idx, camid))
100 |
101 | # process gallery folder
102 | for img_path in gallery_img_paths:
103 | img_name = os.path.basename(img_path)
104 | img_idx = int(img_name.split('_')[0])
105 | camid = int(
106 | img_name.split('_')[1]
107 | ) - 1 # index starts from 0
108 | if img_idx in train_idxs:
109 | train.append((img_path, idx2label[img_idx], camid))
110 | else:
111 | gallery.append((img_path, img_idx, camid))
112 |
113 | split = {
114 | 'train': train,
115 | 'query': query,
116 | 'gallery': gallery,
117 | 'num_train_pids': 125,
118 | 'num_query_pids': 125,
119 | 'num_gallery_pids': 900
120 | }
121 | splits.append(split)
122 |
123 | print('Totally {} splits are created'.format(len(splits)))
124 | self.write_json(splits, self.split_path)
125 | print('Split file saved to {}'.format(self.split_path))
126 |
127 |
128 | def read_json(self, fpath):
129 | import json
130 | """Reads json file from a path."""
131 | with open(fpath, 'r') as f:
132 | obj = json.load(f)
133 | return obj
134 |
135 |
136 | def write_json(self, obj, fpath):
137 | import json
138 | """Writes to a json file."""
139 | self.mkdir_if_missing(os.path.dirname(fpath))
140 | with open(fpath, 'w') as f:
141 | json.dump(obj, f, indent=4, separators=(',', ': '))
142 |
143 |
144 | def mkdir_if_missing(self, dirname):
145 | import errno
146 | """Creates dirname if it is missing."""
147 | if not os.path.exists(dirname):
148 | try:
149 | os.makedirs(dirname)
150 | except OSError as e:
151 | if e.errno != errno.EEXIST:
152 | raise
--------------------------------------------------------------------------------
/data/datasets/DG_iLIDS.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import copy
4 | import random
5 | from collections import defaultdict
6 | from . import DATASET_REGISTRY
7 | from .bases import ImageDataset
8 |
9 | __all__ = ['DG_iLIDS', ]
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class DG_iLIDS(ImageDataset):
14 | dataset_dir = "QMUL-iLIDS"
15 | dataset_name = "ilids"
16 |
17 | def __init__(self, root='datasets', split_id = 0, **kwargs):
18 |
19 | if isinstance(root, list):
20 | split_id = root[1]
21 | self.root = root[0]
22 | else:
23 | self.root = root
24 | split_id = 0
25 | self.dataset_dir = os.path.join(self.root, self.dataset_dir)
26 | # self.download_dataset(self.dataset_dir, self.dataset_url)
27 |
28 | self.data_dir = os.path.join(self.dataset_dir, 'images')
29 | self.split_path = os.path.join(self.dataset_dir, 'splits.json')
30 |
31 | required_files = [self.dataset_dir, self.data_dir]
32 | self.check_before_run(required_files)
33 |
34 | self.prepare_split()
35 | splits = self.read_json(self.split_path)
36 | if split_id >= len(splits):
37 | raise ValueError(
38 | 'split_id exceeds range, received {}, but '
39 | 'expected between 0 and {}'.format(split_id,
40 | len(splits) - 1)
41 | )
42 | split = splits[split_id]
43 |
44 | train, query, gallery = self.process_split(split)
45 |
46 | super(DG_iLIDS, self).__init__(train, query, gallery, **kwargs)
47 |
48 | def prepare_split(self):
49 | if not os.path.exists(self.split_path):
50 | print('Creating splits ...')
51 |
52 | paths = glob.glob(os.path.join(self.data_dir, '*.jpg'))
53 | img_names = [os.path.basename(path) for path in paths]
54 | num_imgs = len(img_names)
55 | assert num_imgs == 476, 'There should be 476 images, but ' \
56 | 'got {}, please check the data'.format(num_imgs)
57 |
58 | # store image names
59 | # image naming format:
60 | # the first four digits denote the person ID
61 | # the last four digits denote the sequence index
62 | pid_dict = defaultdict(list)
63 | for img_name in img_names:
64 | pid = int(img_name[:4])
65 | pid_dict[pid].append(img_name)
66 | pids = list(pid_dict.keys())
67 | num_pids = len(pids)
68 | assert num_pids == 119, 'There should be 119 identities, ' \
69 | 'but got {}, please check the data'.format(num_pids)
70 |
71 | num_train_pids = int(num_pids * 0.5)
72 |
73 | splits = []
74 | for _ in range(10):
75 | # randomly choose num_train_pids train IDs and the rest for test IDs
76 | pids_copy = copy.deepcopy(pids)
77 | random.shuffle(pids_copy)
78 | train_pids = pids_copy[:num_train_pids]
79 | test_pids = pids_copy[num_train_pids:]
80 |
81 | train = []
82 | query = []
83 | gallery = []
84 |
85 | # for train IDs, all images are used in the train set.
86 | for pid in train_pids:
87 | img_names = pid_dict[pid]
88 | train.extend(img_names)
89 |
90 | # for each test ID, randomly choose two images, one for
91 | # query and the other one for gallery.
92 | for pid in test_pids:
93 | img_names = pid_dict[pid]
94 | samples = random.sample(img_names, 2)
95 | query.append(samples[0])
96 | gallery.append(samples[1])
97 |
98 | split = {'train': train, 'query': query, 'gallery': gallery}
99 | splits.append(split)
100 |
101 | print('Totally {} splits are created'.format(len(splits)))
102 | self.write_json(splits, self.split_path)
103 | print('Split file is saved to {}'.format(self.split_path))
104 |
105 | def get_pid2label(self, img_names):
106 | pid_container = set()
107 | for img_name in img_names:
108 | pid = int(img_name[:4])
109 | pid_container.add(pid)
110 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
111 | return pid2label
112 |
113 | def parse_img_names(self, img_names, pid2label=None):
114 | data = []
115 |
116 | for img_name in img_names:
117 | pid = int(img_name[:4])
118 | if pid2label is not None:
119 | pid = pid2label[pid]
120 | camid = int(img_name[4:7]) - 1 # 0-based
121 | img_path = os.path.join(self.data_dir, img_name)
122 | data.append((img_path, pid, camid))
123 |
124 | return data
125 |
126 | def process_split(self, split):
127 | train_pid2label = self.get_pid2label(split['train'])
128 | train = self.parse_img_names(split['train'], train_pid2label)
129 | query = self.parse_img_names(split['query'])
130 | gallery = self.parse_img_names(split['gallery'])
131 | return train, query, gallery
132 |
133 | def read_json(self, fpath):
134 | import json
135 | """Reads json file from a path."""
136 | with open(fpath, 'r') as f:
137 | obj = json.load(f)
138 | return obj
139 |
140 | def write_json(self, obj, fpath):
141 | import json
142 | """Writes to a json file."""
143 | self.mkdir_if_missing(os.path.dirname(fpath))
144 | with open(fpath, 'w') as f:
145 | json.dump(obj, f, indent=4, separators=(',', ': '))
146 |
147 | def mkdir_if_missing(self, dirname):
148 | import errno
149 | """Creates dirname if it is missing."""
150 | if not os.path.exists(dirname):
151 | try:
152 | os.makedirs(dirname)
153 | except OSError as e:
154 | if e.errno != errno.EEXIST:
155 | raise
--------------------------------------------------------------------------------
/data/datasets/DG_market1501.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | import re
4 | import warnings
5 |
6 | from .bases import ImageDataset
7 | from ..datasets import DATASET_REGISTRY
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class DG_Market1501(ImageDataset):
12 | """Market1501.
13 |
14 | Reference:
15 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
16 |
17 | URL: ``_
18 |
19 | Dataset statistics:
20 | - identities: 1501 (+1 for background).
21 | - images: 12936 (train) + 3368 (query) + 15913 (gallery).
22 | """
23 | _junk_pids = [0, -1]
24 | dataset_dir = ''
25 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
26 | dataset_name = "market1501"
27 |
28 | def __init__(self, root='datasets', market1501_500k=False, **kwargs):
29 | # self.root = osp.abspath(osp.expanduser(root))
30 | self.root = root
31 | self.dataset_dir = osp.join(self.root, self.dataset_dir)
32 |
33 | # allow alternative directory structure
34 | self.data_dir = self.dataset_dir
35 | data_dir = osp.join(self.data_dir, 'market1501')
36 | if osp.isdir(data_dir):
37 | self.data_dir = data_dir
38 | else:
39 | warnings.warn('The current data structure is deprecated. Please '
40 | 'put data folders such as "bounding_box_train" under '
41 | '"Market-1501-v15.09.15".')
42 |
43 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
44 | self.query_dir = osp.join(self.data_dir, 'query')
45 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
46 | self.extra_gallery_dir = osp.join(self.data_dir, 'images')
47 | self.market1501_500k = market1501_500k
48 |
49 | required_files = [
50 | self.data_dir,
51 | self.train_dir,
52 | self.query_dir,
53 | self.gallery_dir,
54 | ]
55 | if self.market1501_500k:
56 | required_files.append(self.extra_gallery_dir)
57 | self.check_before_run(required_files)
58 |
59 | train = self.process_dir(self.train_dir)
60 | query = self.process_dir(self.query_dir, is_train=True)
61 | gallery = self.process_dir(self.gallery_dir, is_train=True)
62 | train = train + query + gallery
63 | if self.market1501_500k:
64 | gallery += self.process_dir(self.extra_gallery_dir, is_train=False)
65 |
66 | super(DG_Market1501, self).__init__(train, [], [], **kwargs)
67 |
68 | def process_dir(self, dir_path, is_train=True):
69 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
70 | pattern = re.compile(r'([-\d]+)_c(\d)')
71 |
72 | data = []
73 | for img_path in img_paths:
74 | pid, camid = map(int, pattern.search(img_path).groups())
75 | if pid == -1 or pid == 0:
76 | continue # junk images are just ignored
77 | assert 0 <= pid <= 1501 # pid == 0 means background
78 | assert 1 <= camid <= 6
79 | camid -= 1 # index starts from 0
80 | if is_train:
81 | pid = self.dataset_name + "_" + str(pid)
82 | data.append((img_path, pid, camid))
83 |
84 | return data
85 |
--------------------------------------------------------------------------------
/data/datasets/DG_prid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 | import random
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 | import pdb
8 |
9 | __all__ = ['DG_PRID', ]
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class DG_PRID(ImageDataset):
14 | dataset_dir = "prid_2011"
15 | dataset_name = 'prid'
16 | _junk_pids = list(range(201, 750))
17 |
18 | def __init__(self, root='datasets', split_id=0, **kwargs):
19 |
20 | if isinstance(root, list):
21 | split_id = root[1]
22 | self.root = root[0]
23 | else:
24 | self.root = root
25 | split_id = 0
26 | self.dataset_dir = os.path.join(self.root, self.dataset_dir)
27 | # self.download_dataset(self.dataset_dir, self.dataset_url)
28 |
29 | self.cam_a_dir = os.path.join(
30 | self.dataset_dir, 'single_shot', 'cam_a'
31 | )
32 | self.cam_b_dir = os.path.join(
33 | self.dataset_dir, 'single_shot', 'cam_b'
34 | )
35 | self.split_path = os.path.join(self.dataset_dir, 'splits_single_shot.json')
36 |
37 | required_files = [
38 | self.dataset_dir,
39 | self.cam_a_dir,
40 | self.cam_b_dir
41 | ]
42 | self.check_before_run(required_files)
43 |
44 | self.prepare_split()
45 | splits = self.read_json(self.split_path)
46 | if split_id >= len(splits):
47 | raise ValueError(
48 | 'split_id exceeds range, received {}, but expected between 0 and {}'
49 | .format(split_id,
50 | len(splits) - 1)
51 | )
52 | split = splits[split_id]
53 |
54 | train, query, gallery = self.process_split(split)
55 |
56 | super(DG_PRID, self).__init__(train, query, gallery, **kwargs)
57 |
58 | def prepare_split(self):
59 | if not os.path.exists(self.split_path):
60 | print('Creating splits ...')
61 |
62 | splits = []
63 | for _ in range(10):
64 | # randomly sample 100 IDs for train and use the rest 100 IDs for test
65 | # (note: there are only 200 IDs appearing in both views)
66 | pids = [i for i in range(1, 201)]
67 | train_pids = random.sample(pids, 100)
68 | train_pids.sort()
69 | test_pids = [i for i in pids if i not in train_pids]
70 | split = {'train': train_pids, 'test': test_pids}
71 | splits.append(split)
72 |
73 | print('Totally {} splits are created'.format(len(splits)))
74 | self.write_json(splits, self.split_path)
75 | print('Split file is saved to {}'.format(self.split_path))
76 |
77 | def process_split(self, split):
78 | train_pids = split['train']
79 | test_pids = split['test']
80 |
81 | train_pid2label = {pid: label for label, pid in enumerate(train_pids)}
82 |
83 | # train
84 | train = []
85 | for pid in train_pids:
86 | img_name = 'person_' + str(pid).zfill(4) + '.png'
87 | pid = train_pid2label[pid]
88 | img_a_path = os.path.join(self.cam_a_dir, img_name)
89 | train.append((img_a_path, pid, 0))
90 | img_b_path = os.path.join(self.cam_b_dir, img_name)
91 | train.append((img_b_path, pid, 1))
92 |
93 | # query and gallery
94 | query, gallery = [], []
95 | for pid in test_pids:
96 | img_name = 'person_' + str(pid).zfill(4) + '.png'
97 | img_a_path = os.path.join(self.cam_a_dir, img_name)
98 | query.append((img_a_path, pid, 0))
99 | img_b_path = os.path.join(self.cam_b_dir, img_name)
100 | gallery.append((img_b_path, pid, 1))
101 | for pid in range(201, 750):
102 | img_name = 'person_' + str(pid).zfill(4) + '.png'
103 | img_b_path = os.path.join(self.cam_b_dir, img_name)
104 | gallery.append((img_b_path, pid, 1))
105 |
106 | return train, query, gallery
107 |
108 | def read_json(self, fpath):
109 | import json
110 | """Reads json file from a path."""
111 | with open(fpath, 'r') as f:
112 | obj = json.load(f)
113 | return obj
114 |
115 | def write_json(self, obj, fpath):
116 | import json
117 | """Writes to a json file."""
118 | self.mkdir_if_missing(os.path.dirname(fpath))
119 | with open(fpath, 'w') as f:
120 | json.dump(obj, f, indent=4, separators=(',', ': '))
121 |
122 | def mkdir_if_missing(self, dirname):
123 | import errno
124 | """Creates dirname if it is missing."""
125 | if not os.path.exists(dirname):
126 | try:
127 | os.makedirs(dirname)
128 | except OSError as e:
129 | if e.errno != errno.EEXIST:
130 | raise
--------------------------------------------------------------------------------
/data/datasets/DG_viper.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | from . import DATASET_REGISTRY
5 | from .bases import ImageDataset
6 |
7 | __all__ = ['DG_VIPeR', ]
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class DG_VIPeR(ImageDataset):
12 | dataset_dir = "viper"
13 | dataset_name = "viper"
14 |
15 | def __init__(self, root='datasets', **kwargs):
16 | if isinstance(root, list):
17 | type = root[1]
18 | self.root = root[0]
19 | else:
20 | self.root = root
21 | type = 'split_1a'
22 | self.train_dir = os.path.join(self.root, self.dataset_dir, type, 'train')
23 | self.query_dir = os.path.join(self.root, self.dataset_dir, type, 'query')
24 | self.gallery_dir = os.path.join(self.root, self.dataset_dir, type, 'gallery')
25 |
26 | required_files = [
27 | self.train_dir,
28 | self.query_dir,
29 | self.gallery_dir,
30 | ]
31 | self.check_before_run(required_files)
32 |
33 | train = self.process_train(self.train_dir, is_train = True)
34 | query = self.process_train(self.query_dir, is_train = False)
35 | gallery = self.process_train(self.gallery_dir, is_train = False)
36 |
37 | super().__init__(train, query, gallery, **kwargs)
38 |
39 | def process_train(self, path, is_train = True):
40 | data = []
41 | img_list = glob(os.path.join(path, '*.png'))
42 | for img_path in img_list:
43 | img_name = img_path.split('/')[-1] # p000_c1_d045.png
44 | split_name = img_name.split('_')
45 | pid = int(split_name[0][1:])
46 | if is_train:
47 | pid = self.dataset_name + "_" + str(pid)
48 | camid = int(split_name[1][1:])
49 | # dirid = int(split_name[2][1:-4])
50 | data.append([img_path, pid, camid])
51 |
52 | return data
--------------------------------------------------------------------------------
/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.registry import Registry
2 |
3 | DATASET_REGISTRY = Registry("DATASET")
4 | DATASET_REGISTRY.__doc__ = """
5 | Registry for datasets.
6 | It must return an instance of :class:`ImageDataset`.
7 | """
8 |
9 |
10 | # Person re-id datasets
11 | from .cuhk03 import CUHK03
12 | from .DG_cuhk_sysu import DG_CUHK_SYSU
13 | from .DG_cuhk02 import DG_CUHK02
14 | from .DG_cuhk03_labeled import DG_CUHK03_labeled
15 | from .DG_cuhk03_detected import DG_CUHK03_detected
16 | from .dukemtmcreid import DukeMTMC
17 | from .DG_dukemtmcreid import DG_DukeMTMC
18 | from .market1501 import Market1501
19 | from .DG_market1501 import DG_Market1501
20 | from .msmt17 import MSMT17
21 | from .AirportALERT import AirportALERT
22 | from .iLIDS import iLIDS
23 | from .pku import PKU
24 | from .grid import GRID
25 | from .prai import PRAI
26 | from .prid import PRID
27 | from .DG_prid import DG_PRID
28 | from .DG_grid import DG_GRID
29 | from .sensereid import SenseReID
30 | from .sysu_mm import SYSU_mm
31 | from .thermalworld import Thermalworld
32 | from .pes3d import PeS3D
33 | from .caviara import CAVIARa
34 | from .viper import VIPeR
35 | from .DG_viper import DG_VIPeR
36 | from .DG_iLIDS import DG_iLIDS
37 | from .lpw import LPW
38 | from .shinpuhkan import Shinpuhkan
39 | # Vehicle re-id datasets
40 | from .veri import VeRi
41 | from .veri_keypoint import VeRi_keypoint
42 | from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID
43 | from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild
44 | from .randperson import RandPerson
45 |
46 |
47 | __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
48 |
--------------------------------------------------------------------------------
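
The `Registry` class itself lives in `utils/registry.py` (not reproduced here), so its API is inferred from how `build_DG_dataloader.py` uses it: each dataset class registers under its class name, and the loader resolves the strings listed in `cfg.DATASETS.TRAIN` / `TEST`. A hedged lookup sketch, assuming Market-1501 is already prepared under the configured root.

```
# Resolve a dataset by name, the same way build_reid_train_loader does.
from data.datasets import DATASET_REGISTRY

dataset_cls = DATASET_REGISTRY.get("Market1501")            # string from cfg.DATASETS.TRAIN
dataset = dataset_cls(root="../../data", combineall=False)  # same kwargs the loader passes
dataset.show_train()  # logs the ids/images/cameras table once the 'PAT' logger is configured
```
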
/data/datasets/bases.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 |
5 | class Dataset(object):
6 | """An abstract class representing a Dataset.
7 | This is the base class for ``ImageDataset`` and ``VideoDataset``.
8 | Args:
9 | train (list): contains tuples of (img_path(s), pid, camid).
10 | query (list): contains tuples of (img_path(s), pid, camid).
11 | gallery (list): contains tuples of (img_path(s), pid, camid).
12 | transform: transform function.
13 | mode (str): 'train', 'query' or 'gallery'.
14 | combineall (bool): combines train, query and gallery in a
15 | dataset for training.
16 | verbose (bool): show information.
17 | """
18 | _junk_pids = [] # contains useless person IDs, e.g. background, false detections
19 |
20 | def __init__(self, train, query, gallery, transform=None, mode='train',
21 | combineall=False, verbose=True, **kwargs):
22 | self.train = train
23 | self.query = query
24 | self.gallery = gallery
25 | self.query = [tuple(q_tuple)+({'q_or_g': 'query'},) for q_tuple in self.query]
26 | self.gallery = [tuple(g_tuple)+({'q_or_g': 'gallery'},) for g_tuple in self.gallery]
27 | self.transform = transform
28 | self.mode = mode
29 | self.combineall = combineall
30 | self.verbose = verbose
31 |
32 | # if self.train != []:
33 | self.num_train_pids = self.get_num_pids(self.train)
34 | self.num_train_cams = self.get_num_cams(self.train)
35 |
36 | if self.combineall:
37 | self.combine_all()
38 |
39 | if self.mode == 'train':
40 | self.data = self.train
41 | elif self.mode == 'query':
42 | self.data = self.query
43 | elif self.mode == 'gallery':
44 | self.data = self.gallery
45 | else:
46 | raise ValueError('Invalid mode. Got {}, but expected to be '
47 | 'one of [train | query | gallery]'.format(self.mode))
48 |
49 | # if self.verbose:
50 | # self.show_summary()
51 |
52 | def __getitem__(self, index):
53 | raise NotImplementedError
54 |
55 | def __len__(self):
56 | return len(self.data)
57 |
58 | def __radd__(self, other):
59 | """Supports sum([dataset1, dataset2, dataset3])."""
60 | if other == 0:
61 | return self
62 | else:
63 | return self.__add__(other)
64 |
65 | def parse_data(self, data):
66 | """Parses data list and returns the number of person IDs
67 | and the number of camera views.
68 | Args:
69 | data (list): contains tuples of (img_path(s), pid, camid)
70 | """
71 | pids = set()
72 | cams = set()
73 | if len(data[0]) > 3:
74 | for _, pid, camid, _ in data:
75 | pids.add(pid)
76 | cams.add(camid)
77 | else:
78 | for _, pid, camid in data:
79 | pids.add(pid)
80 | cams.add(camid)
81 | return len(pids), len(cams)
82 |
83 | def get_num_pids(self, data):
84 | """Returns the number of training person identities."""
85 | return self.parse_data(data)[0]
86 |
87 | def get_num_cams(self, data):
88 | """Returns the number of training cameras."""
89 | return self.parse_data(data)[1]
90 |
91 | def show_summary(self):
92 | """Shows dataset statistics."""
93 | pass
94 |
95 | def combine_all(self):
96 | """Combines train, query and gallery in a dataset for training."""
97 | combined = copy.deepcopy(self.train)
98 |
99 | def _combine_data(data):
100 | for img_path, pid, camid, _ in data:
101 | if pid in self._junk_pids:
102 | continue
103 | pid = self.dataset_name + "_" + str(pid)
104 | combined.append((img_path, pid, camid))
105 |
106 | _combine_data(self.query)
107 | _combine_data(self.gallery)
108 |
109 | self.train = combined
110 | self.num_train_pids = self.get_num_pids(self.train)
111 |
112 | def check_before_run(self, required_files):
113 | """Checks if required files exist before going deeper.
114 | Args:
115 | required_files (str or list): string file name(s).
116 | """
117 | if isinstance(required_files, str):
118 | required_files = [required_files]
119 |
120 | for fpath in required_files:
121 | if not os.path.exists(fpath):
122 | raise RuntimeError('"{}" is not found'.format(fpath))
123 |
124 | def __repr__(self):
125 | num_train_pids, num_train_cams = self.parse_data(self.train)
126 | num_query_pids, num_query_cams = self.parse_data(self.query)
127 | num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
128 |
129 | msg = ' ----------------------------------------\n' \
130 | ' subset | # ids | # items | # cameras\n' \
131 | ' ----------------------------------------\n' \
132 | ' train | {:5d} | {:7d} | {:9d}\n' \
133 | ' query | {:5d} | {:7d} | {:9d}\n' \
134 | ' gallery | {:5d} | {:7d} | {:9d}\n' \
135 | ' ----------------------------------------\n' \
136 | ' items: images/tracklets for image/video dataset\n'.format(
137 | num_train_pids, len(self.train), num_train_cams,
138 | num_query_pids, len(self.query), num_query_cams,
139 | num_gallery_pids, len(self.gallery), num_gallery_cams
140 | )
141 |
142 | return msg
143 |
144 |
145 | class ImageDataset(Dataset):
146 | """A base class representing ImageDataset.
147 | All other image datasets should subclass it.
148 | ``__getitem__`` returns an image given index.
149 | It will return ``img``, ``pid``, ``camid`` and ``img_path``
150 | where ``img`` has shape (channel, height, width). As a result,
151 | data in each batch has shape (batch_size, channel, height, width).
152 | """
153 |
154 | def __init__(self, train, query, gallery, **kwargs):
155 | super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
156 |
157 | def show_train(self):
158 | logger = logging.getLogger('PAT')
159 | num_train_pids, num_train_cams = self.parse_data(self.train)
160 | logger.info('=> Loaded {}'.format(self.__class__.__name__))
161 | logger.info(' ----------------------------------------')
162 | logger.info(' subset | # ids | # images | # cameras')
163 | logger.info(' ----------------------------------------')
164 | logger.info(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
165 | logger.info(' ----------------------------------------')
166 |
167 | def show_test(self):
168 | logger = logging.getLogger('PAT')
169 | num_query_pids, num_query_cams = self.parse_data(self.query)
170 | num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
171 | logger.info('=> Loaded {}'.format(self.__class__.__name__))
172 | logger.info(' ----------------------------------------')
173 | logger.info(' subset | # ids | # images | # cameras')
174 | logger.info(' ----------------------------------------')
175 | logger.info(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
176 | logger.info(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
177 | logger.info(' ----------------------------------------')
178 |
--------------------------------------------------------------------------------
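
For reference, a hedged sketch of what a new source-only dataset module would look like next to the files below: the class name, folder layout, and file-naming scheme are hypothetical, but the structure mirrors `caviara.py` and `AirportALERT.py`. Such a module would live in `data/datasets/`, get an import line in `data/datasets/__init__.py`, and be referenced by name in `cfg.DATASETS.TRAIN`.

```
# Hypothetical dataset module following the repository's pattern: register the
# class, build a list of [img_path, pid, camid] triples, and pass it to ImageDataset.
import os
from glob import glob

from . import DATASET_REGISTRY
from .bases import ImageDataset


@DATASET_REGISTRY.register()
class MyDataset(ImageDataset):
    dataset_dir = "MyDataset"      # folder expected under cfg.DATASETS.ROOT_DIR
    dataset_name = "mydataset"     # prefix that keeps pids unique across sources

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir)
        self.check_before_run([self.train_path])
        train = self.process_train(self.train_path)
        super().__init__(train, [], [], **kwargs)

    def process_train(self, train_path):
        data = []
        for img_path in glob(os.path.join(train_path, "*.jpg")):
            name = os.path.basename(img_path)      # e.g. 0003_c2_000451.jpg (hypothetical naming)
            pid = self.dataset_name + "_" + name.split('_')[0]
            camid = int(name.split('_')[1][1:])    # "c2" -> 2
            data.append([img_path, pid, camid])
        return data
```
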
/data/datasets/caviara.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 |
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 | import pdb
8 | import random
9 | import numpy as np
10 |
11 | __all__ = ['CAVIARa',]
12 |
13 |
14 | @DATASET_REGISTRY.register()
15 | class CAVIARa(ImageDataset):
16 | dataset_dir = "CAVIARa"
17 | dataset_name = "caviara"
18 |
19 | def __init__(self, root='datasets', **kwargs):
20 | self.root = root
21 | self.train_path = os.path.join(self.root, self.dataset_dir)
22 |
23 | required_files = [self.train_path]
24 | self.check_before_run(required_files)
25 |
26 | train = self.process_train(self.train_path)
27 |
28 | super().__init__(train, [], [], **kwargs)
29 |
30 | def process_train(self, train_path):
31 | data = []
32 |
33 | img_list = glob(os.path.join(train_path, "*.jpg"))
34 | for img_path in img_list:
35 | img_name = img_path.split('/')[-1]
36 | pid = self.dataset_name + "_" + img_name[:4]
37 | camid = self.dataset_name + "_cam0"
38 | data.append([img_path, pid, camid])
39 |
40 | return data
41 |
--------------------------------------------------------------------------------
/data/datasets/dukemtmcreid.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | import re
4 |
5 | from .bases import ImageDataset
6 | from ..datasets import DATASET_REGISTRY
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class DukeMTMC(ImageDataset):
11 | """DukeMTMC-reID.
12 |
13 | Reference:
14 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
15 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
16 |
17 | URL: ``_
18 |
19 | Dataset statistics:
20 | - identities: 1404 (train + query).
21 | - images:16522 (train) + 2228 (query) + 17661 (gallery).
22 | - cameras: 8.
23 | """
24 | dataset_dir = 'DukeMTMC-reID'
25 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
26 | dataset_name = "dukemtmc"
27 |
28 | def __init__(self, root='datasets', **kwargs):
29 | # self.root = osp.abspath(osp.expanduser(root))
30 | self.root = root
31 | self.dataset_dir = osp.join(self.root, self.dataset_dir)
32 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
33 | self.query_dir = osp.join(self.dataset_dir, 'query')
34 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
35 |
36 | required_files = [
37 | self.dataset_dir,
38 | self.train_dir,
39 | self.query_dir,
40 | self.gallery_dir,
41 | ]
42 | self.check_before_run(required_files)
43 |
44 | train = self.process_dir(self.train_dir)
45 | query = self.process_dir(self.query_dir, is_train=False)
46 | gallery = self.process_dir(self.gallery_dir, is_train=False)
47 |
48 | super(DukeMTMC, self).__init__(train, query, gallery, **kwargs)
49 |
50 | def process_dir(self, dir_path, is_train=True):
51 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
52 | pattern = re.compile(r'([-\d]+)_c(\d)')
53 |
54 | data = []
55 | for img_path in img_paths:
56 | pid, camid = map(int, pattern.search(img_path).groups())
57 | assert 1 <= camid <= 8
58 | camid -= 1 # index starts from 0
59 | if is_train:
60 | pid = self.dataset_name + "_" + str(pid)
61 | data.append((img_path, pid, camid))
62 |
63 | return data
64 |
--------------------------------------------------------------------------------
/data/datasets/grid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 |
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 | import pdb
8 |
9 | __all__ = ['GRID',]
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class GRID(ImageDataset):
14 | dataset_dir = "GRID"
15 | dataset_name = 'grid'
16 |
17 | def __init__(self, root='datasets', split_id = 0, **kwargs):
18 | self.root = root
19 | self.dataset_dir = os.path.join(self.root, self.dataset_dir)
20 |
21 | self.probe_path = os.path.join(
22 | self.dataset_dir, 'probe'
23 | )
24 | self.gallery_path = os.path.join(
25 | self.dataset_dir, 'gallery'
26 | )
27 | self.split_mat_path = os.path.join(
28 | self.dataset_dir, 'features_and_partitions.mat'
29 | )
30 | self.split_path = os.path.join(self.dataset_dir, 'splits.json')
31 |
32 | required_files = [
33 | self.dataset_dir, self.probe_path, self.gallery_path,
34 | self.split_mat_path
35 | ]
36 | self.check_before_run(required_files)
37 |
38 | self.prepare_split()
39 | splits = self.read_json(self.split_path)
40 | if split_id >= len(splits):
41 | raise ValueError(
42 | 'split_id exceeds range, received {}, '
43 | 'but expected between 0 and {}'.format(
44 | split_id,
45 | len(splits) - 1
46 | )
47 | )
48 | split = splits[split_id]
49 |
50 | train = split['train']
51 | query = split['query']
52 | gallery = split['gallery']
53 |
54 | train = [tuple(item) for item in train]
55 | query = [tuple(item) for item in query]
56 | gallery = [tuple(item) for item in gallery]
57 |
58 | super(GRID, self).__init__(train, query, gallery, **kwargs)
59 |
60 | def prepare_split(self):
61 | if not os.path.exists(self.split_path):
62 | print('Creating 10 random splits')
63 | split_mat = loadmat(self.split_mat_path)
64 | trainIdxAll = split_mat['trainIdxAll'][0] # length = 10
65 | probe_img_paths = sorted(
66 | glob(os.path.join(self.probe_path, '*.jpeg'))
67 | )
68 | gallery_img_paths = sorted(
69 | glob(os.path.join(self.gallery_path, '*.jpeg'))
70 | )
71 |
72 | splits = []
73 | for split_idx in range(10):
74 | train_idxs = trainIdxAll[split_idx][0][0][2][0].tolist()
75 | assert len(train_idxs) == 125
76 | idx2label = {
77 | idx: label
78 | for label, idx in enumerate(train_idxs)
79 | }
80 |
81 | train, query, gallery = [], [], []
82 |
83 | # processing probe folder
84 | for img_path in probe_img_paths:
85 | img_name = os.path.basename(img_path)
86 | img_idx = int(img_name.split('_')[0])
87 | camid = int(
88 | img_name.split('_')[1]
89 | ) - 1 # index starts from 0
90 | if img_idx in train_idxs:
91 | train.append((img_path, idx2label[img_idx], camid))
92 | else:
93 | query.append((img_path, img_idx, camid))
94 |
95 | # process gallery folder
96 | for img_path in gallery_img_paths:
97 | img_name = os.path.basename(img_path)
98 | img_idx = int(img_name.split('_')[0])
99 | camid = int(
100 | img_name.split('_')[1]
101 | ) - 1 # index starts from 0
102 | if img_idx in train_idxs:
103 | train.append((img_path, idx2label[img_idx], camid))
104 | else:
105 | gallery.append((img_path, img_idx, camid))
106 |
107 | split = {
108 | 'train': train,
109 | 'query': query,
110 | 'gallery': gallery,
111 | 'num_train_pids': 125,
112 | 'num_query_pids': 125,
113 | 'num_gallery_pids': 900
114 | }
115 | splits.append(split)
116 |
117 | print('Totally {} splits are created'.format(len(splits)))
118 | self.write_json(splits, self.split_path)
119 | print('Split file saved to {}'.format(self.split_path))
120 |
121 |
122 | def read_json(self, fpath):
123 |         """Reads json file from a path."""
124 |         import json
125 | with open(fpath, 'r') as f:
126 | obj = json.load(f)
127 | return obj
128 |
129 |
130 | def write_json(self, obj, fpath):
131 |         """Writes to a json file."""
132 |         import json
133 | self.mkdir_if_missing(os.path.dirname(fpath))
134 | with open(fpath, 'w') as f:
135 | json.dump(obj, f, indent=4, separators=(',', ': '))
136 |
137 |
138 | def mkdir_if_missing(self, dirname):
139 |         """Creates dirname if it is missing."""
140 |         import errno
141 | if not os.path.exists(dirname):
142 | try:
143 | os.makedirs(dirname)
144 | except OSError as e:
145 | if e.errno != errno.EEXIST:
146 | raise
--------------------------------------------------------------------------------
/data/datasets/iLIDS.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import copy
4 | import random
5 | from collections import defaultdict
6 | from . import DATASET_REGISTRY
7 | from .bases import ImageDataset
8 |
9 | __all__ = ['iLIDS', ]
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class iLIDS(ImageDataset):
14 | dataset_dir = "QMUL-iLIDS"
15 | dataset_name = "ilids"
16 |
17 | def __init__(self, root='datasets', split_id = 0, **kwargs):
18 | # self.root = os.path.abspath(os.path.expanduser(root))
19 | self.root = root
20 | self.dataset_dir = os.path.join(self.root, self.dataset_dir)
21 | # self.download_dataset(self.dataset_dir, self.dataset_url)
22 |
23 | self.data_dir = os.path.join(self.dataset_dir, 'images')
24 | self.split_path = os.path.join(self.dataset_dir, 'splits.json')
25 |
26 | required_files = [self.dataset_dir, self.data_dir]
27 | self.check_before_run(required_files)
28 |
29 | self.prepare_split()
30 | splits = self.read_json(self.split_path)
31 | if split_id >= len(splits):
32 | raise ValueError(
33 | 'split_id exceeds range, received {}, but '
34 | 'expected between 0 and {}'.format(split_id,
35 | len(splits) - 1)
36 | )
37 | split = splits[split_id]
38 |
39 | train, query, gallery = self.process_split(split)
40 |
41 | super(iLIDS, self).__init__(train, query, gallery, **kwargs)
42 |
43 | def prepare_split(self):
44 | if not os.path.exists(self.split_path):
45 | print('Creating splits ...')
46 |
47 | paths = glob.glob(os.path.join(self.data_dir, '*.jpg'))
48 | img_names = [os.path.basename(path) for path in paths]
49 | num_imgs = len(img_names)
50 | assert num_imgs == 476, 'There should be 476 images, but ' \
51 | 'got {}, please check the data'.format(num_imgs)
52 |
53 | # store image names
54 | # image naming format:
55 | # the first four digits denote the person ID
56 | # the last four digits denote the sequence index
57 | pid_dict = defaultdict(list)
58 | for img_name in img_names:
59 | pid = int(img_name[:4])
60 | pid_dict[pid].append(img_name)
61 | pids = list(pid_dict.keys())
62 | num_pids = len(pids)
63 | assert num_pids == 119, 'There should be 119 identities, ' \
64 | 'but got {}, please check the data'.format(num_pids)
65 |
66 | num_train_pids = int(num_pids * 0.5)
67 |
68 | splits = []
69 | for _ in range(10):
70 | # randomly choose num_train_pids train IDs and the rest for test IDs
71 | pids_copy = copy.deepcopy(pids)
72 | random.shuffle(pids_copy)
73 | train_pids = pids_copy[:num_train_pids]
74 | test_pids = pids_copy[num_train_pids:]
75 |
76 | train = []
77 | query = []
78 | gallery = []
79 |
80 | # for train IDs, all images are used in the train set.
81 | for pid in train_pids:
82 | img_names = pid_dict[pid]
83 | train.extend(img_names)
84 |
85 | # for each test ID, randomly choose two images, one for
86 | # query and the other one for gallery.
87 | for pid in test_pids:
88 | img_names = pid_dict[pid]
89 | samples = random.sample(img_names, 2)
90 | query.append(samples[0])
91 | gallery.append(samples[1])
92 |
93 | split = {'train': train, 'query': query, 'gallery': gallery}
94 | splits.append(split)
95 |
96 | print('Totally {} splits are created'.format(len(splits)))
97 | self.write_json(splits, self.split_path)
98 | print('Split file is saved to {}'.format(self.split_path))
99 |
100 | def get_pid2label(self, img_names):
101 | pid_container = set()
102 | for img_name in img_names:
103 | pid = int(img_name[:4])
104 | pid_container.add(pid)
105 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
106 | return pid2label
107 |
108 | def parse_img_names(self, img_names, pid2label=None):
109 | data = []
110 |
111 | for img_name in img_names:
112 | pid = int(img_name[:4])
113 | if pid2label is not None:
114 | pid = pid2label[pid]
115 | camid = int(img_name[4:7]) - 1 # 0-based
116 | img_path = os.path.join(self.data_dir, img_name)
117 | data.append((img_path, pid, camid))
118 |
119 | return data
120 |
121 | def process_split(self, split):
122 | train_pid2label = self.get_pid2label(split['train'])
123 | train = self.parse_img_names(split['train'], train_pid2label)
124 | query = self.parse_img_names(split['query'])
125 | gallery = self.parse_img_names(split['gallery'])
126 | return train, query, gallery
127 |
128 | def read_json(self, fpath):
129 |         """Reads json file from a path."""
130 |         import json
131 | with open(fpath, 'r') as f:
132 | obj = json.load(f)
133 | return obj
134 |
135 | def write_json(self, obj, fpath):
136 |         """Writes to a json file."""
137 |         import json
138 | self.mkdir_if_missing(os.path.dirname(fpath))
139 | with open(fpath, 'w') as f:
140 | json.dump(obj, f, indent=4, separators=(',', ': '))
141 |
142 | def mkdir_if_missing(self, dirname):
143 |         """Creates dirname if it is missing."""
144 |         import errno
145 | if not os.path.exists(dirname):
146 | try:
147 | os.makedirs(dirname)
148 | except OSError as e:
149 | if e.errno != errno.EEXIST:
150 | raise
151 |
--------------------------------------------------------------------------------
/data/datasets/lpw.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | from . import DATASET_REGISTRY
5 | from .bases import ImageDataset
6 |
7 | __all__ = ['LPW', ]
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class LPW(ImageDataset):
12 | dataset_dir = "pep_256x128"
13 | dataset_name = "lpw"
14 |
15 | def __init__(self, root='datasets', **kwargs):
16 | self.root = root
17 | self.train_path = os.path.join(self.root, self.dataset_dir)
18 |
19 | required_files = [self.train_path]
20 | self.check_before_run(required_files)
21 |
22 | train = self.process_train(self.train_path)
23 |
24 | super().__init__(train, [], [], **kwargs)
25 |
26 | def process_train(self, train_path):
27 | data = []
28 |
29 | file_path_list = ['scen1', 'scen2', 'scen3']
30 |
31 | for scene in file_path_list:
32 | cam_list = os.listdir(os.path.join(train_path, scene))
33 | for cam in cam_list:
34 | camid = self.dataset_name + "_" + cam
35 | pid_list = os.listdir(os.path.join(train_path, scene, cam))
36 | for pid_dir in pid_list:
37 | img_paths = glob(os.path.join(train_path, scene, cam, pid_dir, "*.jpg"))
38 | for img_path in img_paths:
39 | pid = self.dataset_name + "_" + scene + "-" + pid_dir
40 | data.append([img_path, pid, camid])
41 | return data
42 |
--------------------------------------------------------------------------------
/data/datasets/market1501.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | import re
4 | import warnings
5 |
6 | from .bases import ImageDataset
7 | from ..datasets import DATASET_REGISTRY
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class Market1501(ImageDataset):
12 | """Market1501.
13 |
14 | Reference:
15 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
16 |
17 | URL: ``_
18 |
19 | Dataset statistics:
20 | - identities: 1501 (+1 for background).
21 | - images: 12936 (train) + 3368 (query) + 15913 (gallery).
22 | """
23 | _junk_pids = [0, -1]
24 | dataset_dir = 'market1501'
25 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
26 | dataset_name = "market1501"
27 |
28 | def __init__(self, root='datasets', market1501_500k=False, **kwargs):
29 | # self.root = osp.abspath(osp.expanduser(root))
30 | self.root = root
31 | self.dataset_dir = osp.join(self.root, self.dataset_dir)
32 |
33 | # allow alternative directory structure
34 | self.data_dir = self.dataset_dir
35 | data_dir = osp.join(self.data_dir, 'Market1501')
36 | if osp.isdir(data_dir):
37 | self.data_dir = data_dir
38 | else:
39 | warnings.warn('The current data structure is deprecated. Please '
40 | 'put data folders such as "bounding_box_train" under '
41 | '"Market-1501-v15.09.15".')
42 |
43 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
44 | self.query_dir = osp.join(self.data_dir, 'query')
45 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
46 | self.extra_gallery_dir = osp.join(self.data_dir, 'images')
47 | self.market1501_500k = market1501_500k
48 |
49 | required_files = [
50 | self.data_dir,
51 | self.train_dir,
52 | self.query_dir,
53 | self.gallery_dir,
54 | ]
55 | if self.market1501_500k:
56 | required_files.append(self.extra_gallery_dir)
57 | self.check_before_run(required_files)
58 |
59 | train = self.process_dir(self.train_dir)
60 | query = self.process_dir(self.query_dir, is_train=False)
61 | gallery = self.process_dir(self.gallery_dir, is_train=False)
62 | if self.market1501_500k:
63 | gallery += self.process_dir(self.extra_gallery_dir, is_train=False)
64 |
65 | super(Market1501, self).__init__(train, query, gallery, **kwargs)
66 |
67 | def process_dir(self, dir_path, is_train=True):
68 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
69 | pattern = re.compile(r'([-\d]+)_c(\d)')
70 |
71 | data = []
72 | for img_path in img_paths:
73 | pid, camid = map(int, pattern.search(img_path).groups())
74 | if pid == -1:
75 | continue # junk images are just ignored
76 | assert 0 <= pid <= 1501 # pid == 0 means background
77 | assert 1 <= camid <= 6
78 | camid -= 1 # index starts from 0
79 | if is_train:
80 | pid = self.dataset_name + "_" + str(pid)
81 | data.append((img_path, pid, camid))
82 |
83 | return data
84 |
--------------------------------------------------------------------------------
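
The regular expression ([-\d]+)_c(\d) used above is what turns a Market-1501 file name into a (pid, camid) pair. A small sketch of that parsing in isolation; the file name follows the usual Market-1501 naming and is only illustrative:

    import re

    pattern = re.compile(r'([-\d]+)_c(\d)')

    # Typical bounding-box file name: <pid>_c<cam>s<seq>_<frame>_<box>.jpg
    img_path = 'bounding_box_train/0002_c1s1_000451_03.jpg'

    pid, camid = map(int, pattern.search(img_path).groups())
    camid -= 1          # the loaders above shift camera IDs to start from 0
    print(pid, camid)   # -> 2 0

    # pid == -1 marks junk images (skipped in process_dir), while pid == 0
    # marks background crops (kept, but listed in _junk_pids).
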
/data/datasets/msmt17.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import os.path as osp
4 |
5 | from .bases import ImageDataset
6 | from ..datasets import DATASET_REGISTRY
7 | ##### Log #####
8 | # 22.01.2019
9 | # - add v2
10 | # - v1 and v2 differ in dir names
11 | # - note that faces in v2 are blurred
12 | TRAIN_DIR_KEY = 'train_dir'
13 | TEST_DIR_KEY = 'test_dir'
14 | VERSION_DICT = {
15 | 'MSMT17': {
16 | TRAIN_DIR_KEY: 'train',
17 | TEST_DIR_KEY: 'test',
18 | },
19 | 'MSMT17_V2': {
20 | TRAIN_DIR_KEY: 'mask_train_v2',
21 | TEST_DIR_KEY: 'mask_test_v2',
22 | }
23 | }
24 |
25 |
26 | @DATASET_REGISTRY.register()
27 | class MSMT17(ImageDataset):
28 | """MSMT17.
29 | Reference:
30 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
31 | URL: ``_
32 |
33 | Dataset statistics:
34 | - identities: 4101.
35 | - images: 32621 (train) + 11659 (query) + 82161 (gallery).
36 | - cameras: 15.
37 | """
38 | # dataset_dir = 'MSMT17_V2'
39 | dataset_url = None
40 | dataset_name = 'MSMT17'
41 |
42 | def __init__(self, root='datasets', **kwargs):
43 | self.root = root
44 | self.dataset_dir = self.root
45 |
46 | has_main_dir = False
47 | for main_dir in VERSION_DICT:
48 | if osp.exists(osp.join(self.dataset_dir, main_dir)):
49 | train_dir = VERSION_DICT[main_dir][TRAIN_DIR_KEY]
50 | test_dir = VERSION_DICT[main_dir][TEST_DIR_KEY]
51 | has_main_dir = True
52 | break
53 | assert has_main_dir, 'Dataset folder not found'
54 |
55 | self.train_dir = osp.join(self.dataset_dir, main_dir, train_dir)
56 | self.test_dir = osp.join(self.dataset_dir, main_dir, test_dir)
57 | self.list_train_path = osp.join(self.dataset_dir, main_dir, 'list_train.txt')
58 | self.list_val_path = osp.join(self.dataset_dir, main_dir, 'list_val.txt')
59 | self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt')
60 | self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt')
61 |
62 | required_files = [
63 | self.dataset_dir,
64 | self.train_dir,
65 | self.test_dir
66 | ]
67 | self.check_before_run(required_files)
68 |
69 | train = self.process_dir(self.train_dir, self.list_train_path)
70 | val = self.process_dir(self.train_dir, self.list_val_path)
71 | query = self.process_dir(self.test_dir, self.list_query_path, is_train=False)
72 | gallery = self.process_dir(self.test_dir, self.list_gallery_path, is_train=False)
73 |
74 | num_train_pids = self.get_num_pids(train)
75 | query_tmp = []
76 | for img_path, pid, camid in query:
77 | query_tmp.append((img_path, pid+num_train_pids, camid))
78 | del query
79 | query = query_tmp
80 |
81 | gallery_temp = []
82 | for img_path, pid, camid in gallery:
83 | gallery_temp.append((img_path, pid+num_train_pids, camid))
84 | del gallery
85 | gallery = gallery_temp
86 |
87 | # Note: to fairly compare with published methods on the conventional ReID setting,
88 | # do not add val images to the training set.
89 | if 'combineall' in kwargs and kwargs['combineall']:
90 | train += val
91 |
92 | super(MSMT17, self).__init__(train, query, gallery, **kwargs)
93 |
94 | def process_dir(self, dir_path, list_path, is_train=True):
95 | with open(list_path, 'r') as txt:
96 | lines = txt.readlines()
97 |
98 | data = []
99 |
100 | for img_idx, img_info in enumerate(lines):
101 | img_path, pid = img_info.split(' ')
102 | pid = int(pid) # no need to relabel
103 | camid = int(img_path.split('_')[2]) - 1 # index starts from 0
104 | img_path = osp.join(dir_path, img_path)
105 | if is_train:
106 | pid = self.dataset_name + "_" + str(pid)
107 | data.append((img_path, pid, camid))
108 |
109 | return data
--------------------------------------------------------------------------------
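
MSMT17 is driven by list files rather than directory scans: each line of list_train.txt is "<relative path> <pid>", and the camera ID is recovered from the file name itself. A standalone sketch of what process_dir does to a single line; the sample line follows the usual MSMT17 naming and is illustrative only:

    import os.path as osp

    dir_path = 'MSMT17_V2/mask_train_v2'
    dataset_name = 'MSMT17'

    # One line of list_train.txt: "<pid_dir>/<file name> <pid>"
    img_info = '0000/0000_000_01_0303morning_0015_0.jpg 0'

    img_path, pid = img_info.split(' ')
    pid = int(pid)                           # labels are already contiguous
    camid = int(img_path.split('_')[2]) - 1  # the third '_' field is the camera
    img_path = osp.join(dir_path, img_path)
    pid = dataset_name + "_" + str(pid)      # training IDs get the dataset prefix

    print(img_path, pid, camid)
    # -> MSMT17_V2/mask_train_v2/0000/0000_000_01_0303morning_0015_0.jpg MSMT17_0 0
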
/data/datasets/pes3d.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 |
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 | import pdb
8 | import random
9 | import numpy as np
10 |
11 | __all__ = ['PeS3D',]
12 |
13 |
14 | @DATASET_REGISTRY.register()
15 | class PeS3D(ImageDataset):
16 | dataset_dir = "3DPeS"
17 | dataset_name = "pes3d"
18 |
19 | def __init__(self, root='datasets', **kwargs):
20 | self.root = root
21 | self.train_path = os.path.join(self.root, self.dataset_dir)
22 |
23 | required_files = [self.train_path]
24 | self.check_before_run(required_files)
25 |
26 | train = self.process_train(self.train_path)
27 |
28 | super().__init__(train, [], [], **kwargs)
29 |
30 | def process_train(self, train_path):
31 | data = []
32 |
33 | pid_list = os.listdir(train_path)
34 | for pid_dir in pid_list:
35 | pid = self.dataset_name + "_" + pid_dir
36 | img_list = glob(os.path.join(train_path, pid_dir, "*.bmp"))
37 | for img_path in img_list:
38 | camid = self.dataset_name + "_cam0"
39 | data.append([img_path, pid, camid])
40 | return data
41 |
--------------------------------------------------------------------------------
/data/datasets/pku.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | from . import DATASET_REGISTRY
5 | from .bases import ImageDataset
6 |
7 | __all__ = ['PKU', ]
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class PKU(ImageDataset):
12 | dataset_dir = "PKUv1a_128x48"
13 | dataset_name = 'pku'
14 |
15 | def __init__(self, root='datasets', **kwargs):
16 | self.root = root
17 | self.train_path = os.path.join(self.root, self.dataset_dir)
18 |
19 | required_files = [self.train_path]
20 | self.check_before_run(required_files)
21 |
22 | train = self.process_train(self.train_path)
23 |
24 | super().__init__(train, [], [], **kwargs)
25 |
26 | def process_train(self, train_path):
27 | data = []
28 | img_paths = glob(os.path.join(train_path, "*.png"))
29 |
30 | for img_path in img_paths:
31 | split_path = img_path.split('/')
32 | img_info = split_path[-1].split('_')
33 | pid = self.dataset_name + "_" + img_info[0]
34 | camid = self.dataset_name + "_" + img_info[1]
35 | data.append([img_path, pid, camid])
36 | return data
37 |
--------------------------------------------------------------------------------
/data/datasets/prai.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 |
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 |
8 | __all__ = ['PRAI',]
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class PRAI(ImageDataset):
13 | dataset_dir = "PRAI-1581"
14 | dataset_name = 'prai'
15 |
16 | def __init__(self, root='datasets', **kwargs):
17 | self.root = root
18 | self.train_path = os.path.join(self.root, self.dataset_dir, 'images')
19 |
20 | required_files = [self.train_path]
21 | self.check_before_run(required_files)
22 |
23 | train = self.process_train(self.train_path)
24 |
25 | super().__init__(train, [], [], **kwargs)
26 |
27 | def process_train(self, train_path):
28 | data = []
29 | img_paths = glob(os.path.join(train_path, "*.jpg"))
30 | for img_path in img_paths:
31 | split_path = img_path.split('/')
32 | img_info = split_path[-1].split('_')
33 | pid = self.dataset_name + "_" + img_info[0]
34 | camid = self.dataset_name + "_" + img_info[1]
35 | data.append([img_path, pid, camid])
36 | return data
37 |
38 |
--------------------------------------------------------------------------------
/data/datasets/prid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 | import random
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 | import pdb
8 |
9 | __all__ = ['PRID',]
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class PRID(ImageDataset):
14 | dataset_dir = "prid_2011"
15 | dataset_name = 'prid'
16 | _junk_pids = list(range(201, 750))
17 |
18 | def __init__(self, root='datasets', split_id=0, **kwargs):
19 |
20 |
21 | self.root = root
22 | self.dataset_dir = os.path.join(self.root, self.dataset_dir)
23 | # self.download_dataset(self.dataset_dir, self.dataset_url)
24 |
25 | self.cam_a_dir = os.path.join(
26 | self.dataset_dir, 'single_shot', 'cam_a'
27 | )
28 | self.cam_b_dir = os.path.join(
29 | self.dataset_dir, 'single_shot', 'cam_b'
30 | )
31 | self.split_path = os.path.join(self.dataset_dir, 'splits_single_shot.json')
32 |
33 | required_files = [
34 | self.dataset_dir,
35 | self.cam_a_dir,
36 | self.cam_b_dir
37 | ]
38 | self.check_before_run(required_files)
39 |
40 | self.prepare_split()
41 | splits = self.read_json(self.split_path)
42 | if split_id >= len(splits):
43 | raise ValueError(
44 | 'split_id exceeds range, received {}, but expected between 0 and {}'
45 | .format(split_id,
46 | len(splits) - 1)
47 | )
48 | split = splits[split_id]
49 |
50 | train, query, gallery = self.process_split(split)
51 |
52 | super(PRID, self).__init__(train, query, gallery, **kwargs)
53 |
54 | def prepare_split(self):
55 | if not os.path.exists(self.split_path):
56 | print('Creating splits ...')
57 |
58 | splits = []
59 | for _ in range(10):
60 | # randomly sample 100 IDs for train and use the rest 100 IDs for test
61 | # (note: there are only 200 IDs appearing in both views)
62 | pids = [i for i in range(1, 201)]
63 | train_pids = random.sample(pids, 100)
64 | train_pids.sort()
65 | test_pids = [i for i in pids if i not in train_pids]
66 | split = {'train': train_pids, 'test': test_pids}
67 | splits.append(split)
68 |
69 | print('Totally {} splits are created'.format(len(splits)))
70 | self.write_json(splits, self.split_path)
71 | print('Split file is saved to {}'.format(self.split_path))
72 |
73 | def process_split(self, split):
74 | train_pids = split['train']
75 | test_pids = split['test']
76 |
77 | train_pid2label = {pid: label for label, pid in enumerate(train_pids)}
78 |
79 | # train
80 | train = []
81 | for pid in train_pids:
82 | img_name = 'person_' + str(pid).zfill(4) + '.png'
83 | pid = train_pid2label[pid]
84 | img_a_path = os.path.join(self.cam_a_dir, img_name)
85 | train.append((img_a_path, pid, 0))
86 | img_b_path = os.path.join(self.cam_b_dir, img_name)
87 | train.append((img_b_path, pid, 1))
88 |
89 | # query and gallery
90 | query, gallery = [], []
91 | for pid in test_pids:
92 | img_name = 'person_' + str(pid).zfill(4) + '.png'
93 | img_a_path = os.path.join(self.cam_a_dir, img_name)
94 | query.append((img_a_path, pid, 0))
95 | img_b_path = os.path.join(self.cam_b_dir, img_name)
96 | gallery.append((img_b_path, pid, 1))
97 | for pid in range(201, 750):
98 | img_name = 'person_' + str(pid).zfill(4) + '.png'
99 | img_b_path = os.path.join(self.cam_b_dir, img_name)
100 | gallery.append((img_b_path, pid, 1))
101 |
102 | return train, query, gallery
103 |
104 | def read_json(self, fpath):
105 |         """Reads json file from a path."""
106 |         import json
107 | with open(fpath, 'r') as f:
108 | obj = json.load(f)
109 | return obj
110 |
111 |
112 | def write_json(self, obj, fpath):
113 |         """Writes to a json file."""
114 |         import json
115 | self.mkdir_if_missing(os.path.dirname(fpath))
116 | with open(fpath, 'w') as f:
117 | json.dump(obj, f, indent=4, separators=(',', ': '))
118 |
119 | def mkdir_if_missing(self, dirname):
120 |         """Creates dirname if it is missing."""
121 |         import errno
122 | if not os.path.exists(dirname):
123 | try:
124 | os.makedirs(dirname)
125 | except OSError as e:
126 | if e.errno != errno.EEXIST:
127 | raise
--------------------------------------------------------------------------------
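
PRID persists ten ID splits rather than image lists: each split stores only the 100 training IDs and 100 test IDs, and process_split rebuilds image paths from them (cam_a for queries, cam_b for the gallery, plus the gallery-only distractor IDs 201-749). The resulting set sizes follow directly from the code above, assuming every expected image exists on disk:

    # Sizes implied by process_split for any one split (no file I/O needed).
    num_train_pids = 100
    num_test_pids = 100
    num_distractors = len(range(201, 750))   # 549 gallery-only IDs

    train_size = 2 * num_train_pids          # one cam_a + one cam_b image per train ID
    query_size = num_test_pids               # one cam_a image per test ID
    gallery_size = num_test_pids + num_distractors

    print(train_size, query_size, gallery_size)   # -> 200 100 649
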
/data/datasets/randperson.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | from glob import glob
4 |
5 | from ..datasets import DATASET_REGISTRY
6 |
7 | @DATASET_REGISTRY.register()
8 | class RandPerson(object):
9 |
10 | def __init__(self, root, combineall=True):
11 |
12 | self.images_dir = osp.join(root)
13 | self.img_path = 'your_path/randperson_subset/randperson_subset'
14 | self.train_path = self.img_path
15 | self.gallery_path = ''
16 | self.query_path = ''
17 | self.train = []
18 | self.gallery = []
19 | self.query = []
20 | self.num_train_ids = 0
21 | self.has_time_info = True
22 | # self.show_train()
23 |
24 | def preprocess(self):
25 | fpaths = sorted(glob(osp.join(self.images_dir, self.train_path, '*g')))
26 |
27 | data = []
28 | all_pids = {}
29 | camera_offset = [0, 2, 4, 6, 8, 9, 10, 12, 13, 14, 15]
30 |         frame_offset = [0, 160000, 340000, 490000, 640000, 1070000, 1330000, 1590000, 1890000, 3190000, 3490000]
31 | fps = 24
32 |
33 | for fpath in fpaths:
34 | fname = osp.basename(fpath) # filename: id6_s2_c2_f6.jpg
35 | fields = fname.split('_')
36 | pid = int(fields[0])
37 | if pid not in all_pids:
38 | all_pids[pid] = len(all_pids)
39 | pid = all_pids[pid] # relabel
40 | camid = camera_offset[int(fields[1][1:])] + int(fields[2][1:]) # make it starting from 0
41 | time = (frame_offset[int(fields[1][1:])] + int(fields[3][1:7])) / fps
42 | data.append((fpath, pid, camid, time))
43 | # print(fname, pid, camid, time)
44 | return data, int(len(all_pids))
45 |
46 | def show_train(self):
47 | self.train, self.num_train_ids = self.preprocess()
48 |
49 | print(self.__class__.__name__, "dataset loaded")
50 | print(" subset | # ids | # images")
51 | print(" ---------------------------")
52 | print(" all | {:5d} | {:8d}\n"
53 | .format(self.num_train_ids, len(self.train)))
54 |
--------------------------------------------------------------------------------
/data/datasets/sensereid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | from . import DATASET_REGISTRY
5 | from .bases import ImageDataset
6 |
7 | __all__ = ['SenseReID', ]
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class SenseReID(ImageDataset):
12 | dataset_dir = "SenseReID"
13 | dataset_name = "senseid"
14 |
15 | def __init__(self, root='datasets', **kwargs):
16 | self.root = root
17 | self.train_path = os.path.join(self.root, self.dataset_dir)
18 |
19 | required_files = [self.train_path]
20 | self.check_before_run(required_files)
21 |
22 | train = self.process_train(self.train_path)
23 |
24 | super().__init__(train, [], [], **kwargs)
25 |
26 | def process_train(self, train_path):
27 | data = []
28 | file_path_list = ['test_gallery', 'test_prob']
29 |
30 | for file_path in file_path_list:
31 | sub_file = os.path.join(train_path, file_path)
32 |             img_paths = glob(os.path.join(sub_file, "*.jpg"))
33 |             for img_path in img_paths:
34 | img_name = img_path.split('/')[-1]
35 | img_info = img_name.split('_')
36 | pid = self.dataset_name + "_" + img_info[0]
37 | camid = self.dataset_name + "_" + img_info[1].split('.')[0]
38 | data.append([img_path, pid, camid])
39 | return data
40 |
--------------------------------------------------------------------------------
/data/datasets/shinpuhkan.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from . import DATASET_REGISTRY
4 | from .bases import ImageDataset
5 |
6 | __all__ = ['Shinpuhkan', ]
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class Shinpuhkan(ImageDataset):
11 | dataset_dir = "shinpuhkan"
12 | dataset_name = 'shinpuhkan'
13 |
14 | def __init__(self, root='datasets', **kwargs):
15 | self.root = root
16 | self.train_path = os.path.join(self.root, self.dataset_dir)
17 |
18 | required_files = [self.train_path]
19 | self.check_before_run(required_files)
20 |
21 | train = self.process_train(self.train_path)
22 |
23 | super().__init__(train, [], [], **kwargs)
24 |
25 | def process_train(self, train_path):
26 | data = []
27 |
28 | for root, dirs, files in os.walk(train_path):
29 | img_names = list(filter(lambda x: x.endswith(".jpg"), files))
30 | # fmt: off
31 | if len(img_names) == 0: continue
32 | # fmt: on
33 | for img_name in img_names:
34 | img_path = os.path.join(root, img_name)
35 | split_path = img_name.split('_')
36 | pid = self.dataset_name + "_" + split_path[0]
37 | camid = self.dataset_name + "_" + split_path[2]
38 | data.append((img_path, pid, camid))
39 |
40 | return data
41 |
--------------------------------------------------------------------------------
/data/datasets/sysu_mm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 |
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 | import pdb
8 |
9 | __all__ = ['SYSU_mm', ]
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class SYSU_mm(ImageDataset):
14 | dataset_dir = "SYSU-MM01"
15 | dataset_name = "sysumm01"
16 |
17 | def __init__(self, root='datasets', **kwargs):
18 | self.root = root
19 | self.train_path = os.path.join(self.root, self.dataset_dir)
20 |
21 | required_files = [self.train_path]
22 | self.check_before_run(required_files)
23 |
24 | train = self.process_train(self.train_path)
25 |
26 | super().__init__(train, [], [], **kwargs)
27 |
28 | def process_train(self, train_path):
29 | data = []
30 |
31 | file_path_list = ['cam1', 'cam2', 'cam4', 'cam5']
32 |
33 | for file_path in file_path_list:
34 | camid = self.dataset_name + "_" + file_path
35 | pid_list = os.listdir(os.path.join(train_path, file_path))
36 | for pid_dir in pid_list:
37 | pid = self.dataset_name + "_" + pid_dir
38 | img_list = glob(os.path.join(train_path, file_path, pid_dir, "*.jpg"))
39 | for img_path in img_list:
40 | data.append([img_path, pid, camid])
41 | return data
42 |
43 |
--------------------------------------------------------------------------------
/data/datasets/thermalworld.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 | from glob import glob
4 |
5 | from . import DATASET_REGISTRY
6 | from .bases import ImageDataset
7 | import pdb
8 | import random
9 | import numpy as np
10 |
11 | __all__ = ['Thermalworld',]
12 |
13 |
14 | @DATASET_REGISTRY.register()
15 | class Thermalworld(ImageDataset):
16 | dataset_dir = "thermalworld_rgb"
17 | dataset_name = "thermalworld"
18 |
19 | def __init__(self, root='datasets', **kwargs):
20 | self.root = root
21 | self.train_path = os.path.join(self.root, self.dataset_dir)
22 |
23 | required_files = [self.train_path]
24 | self.check_before_run(required_files)
25 |
26 | train = self.process_train(self.train_path)
27 |
28 | super().__init__(train, [], [], **kwargs)
29 |
30 | def process_train(self, train_path):
31 | data = []
32 | pid_list = os.listdir(train_path)
33 | for pid_dir in pid_list:
34 | pid = self.dataset_name + "_" + pid_dir
35 | img_list = glob(os.path.join(train_path, pid_dir, "*.jpg"))
36 | for img_path in img_list:
37 | camid = self.dataset_name + "_cam0"
38 | data.append([img_path, pid, camid])
39 | return data
40 |
--------------------------------------------------------------------------------
/data/datasets/vehicleid.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import random
3 |
4 | from .bases import ImageDataset
5 | from ..datasets import DATASET_REGISTRY
6 |
7 |
8 | @DATASET_REGISTRY.register()
9 | class VehicleID(ImageDataset):
10 | """VehicleID.
11 |
12 | Reference:
13 | Liu et al. Deep relative distance learning: Tell the difference between similar vehicles. CVPR 2016.
14 |
15 | URL: ``_
16 |
17 | Train dataset statistics:
18 | - identities: 13164.
19 | - images: 113346.
20 | """
21 | dataset_dir = "vehicleid"
22 | dataset_name = "vehicleid"
23 |
24 | def __init__(self, root='datasets', test_list='', **kwargs):
25 | self.dataset_dir = osp.join(root, self.dataset_dir)
26 | self.image_dir = osp.join(self.dataset_dir, 'image')
27 | self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt')
28 | if test_list:
29 | self.test_list = test_list
30 | else:
31 | self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_13164.txt')
32 |
33 | required_files = [
34 | self.dataset_dir,
35 | self.image_dir,
36 | self.train_list,
37 | self.test_list,
38 | ]
39 | self.check_before_run(required_files)
40 |
41 | train = self.process_dir(self.train_list, is_train=True)
42 | query, gallery = self.process_dir(self.test_list, is_train=False)
43 |
44 | super(VehicleID, self).__init__(train, query, gallery, **kwargs)
45 |
46 | def process_dir(self, list_file, is_train=True):
47 | img_list_lines = open(list_file, 'r').readlines()
48 |
49 | dataset = []
50 | for idx, line in enumerate(img_list_lines):
51 | line = line.strip()
52 | vid = int(line.split(' ')[1])
53 | imgid = line.split(' ')[0]
54 | img_path = osp.join(self.image_dir, imgid + '.jpg')
55 | if is_train:
56 | vid = self.dataset_name + "_" + str(vid)
57 | dataset.append((img_path, vid, int(imgid)))
58 |
59 | if is_train: return dataset
60 | else:
61 | random.shuffle(dataset)
62 | vid_container = set()
63 | query = []
64 | gallery = []
65 | for sample in dataset:
66 | if sample[1] not in vid_container:
67 | vid_container.add(sample[1])
68 | gallery.append(sample)
69 | else:
70 | query.append(sample)
71 |
72 | return query, gallery
73 |
74 |
75 | @DATASET_REGISTRY.register()
76 | class SmallVehicleID(VehicleID):
77 | """VehicleID.
78 | Small test dataset statistics:
79 | - identities: 800.
80 | - images: 6493.
81 | """
82 |
83 | def __init__(self, root='datasets', **kwargs):
84 | # self.dataset_dir = osp.join(root, self.dataset_dir)
85 | self.test_list = osp.join(root, self.dataset_dir, 'train_test_split/test_list_800.txt')
86 |
87 | super(SmallVehicleID, self).__init__(root, self.test_list, **kwargs)
88 |
89 |
90 | @DATASET_REGISTRY.register()
91 | class MediumVehicleID(VehicleID):
92 | """VehicleID.
93 | Medium test dataset statistics:
94 | - identities: 1600.
95 | - images: 13377.
96 | """
97 |
98 | def __init__(self, root='datasets', **kwargs):
99 | # self.dataset_dir = osp.join(root, self.dataset_dir)
100 | self.test_list = osp.join(root, self.dataset_dir, 'train_test_split/test_list_1600.txt')
101 |
102 | super(MediumVehicleID, self).__init__(root, self.test_list, **kwargs)
103 |
104 |
105 | @DATASET_REGISTRY.register()
106 | class LargeVehicleID(VehicleID):
107 | """VehicleID.
108 | Large test dataset statistics:
109 | - identities: 2400.
110 | - images: 19777.
111 | """
112 |
113 | def __init__(self, root='datasets', **kwargs):
114 | # self.dataset_dir = osp.join(root, self.dataset_dir)
115 | self.test_list = osp.join(root, self.dataset_dir, 'train_test_split/test_list_2400.txt')
116 |
117 | super(LargeVehicleID, self).__init__(root, self.test_list, **kwargs)
118 |
--------------------------------------------------------------------------------
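
For the VehicleID test lists, process_dir shuffles the samples, then sends the first occurrence of each vehicle ID to the gallery and every remaining occurrence to the query set, so each ID contributes exactly one gallery image. A minimal sketch of that split on made-up records:

    import random

    # (img_path, vid, imgid) records, as produced by the test branch above.
    dataset = [('a.jpg', 1, 10), ('b.jpg', 1, 11), ('c.jpg', 2, 12), ('d.jpg', 2, 13)]

    random.shuffle(dataset)
    vid_container = set()
    query, gallery = [], []
    for sample in dataset:
        if sample[1] not in vid_container:   # first time this vehicle ID is seen
            vid_container.add(sample[1])
            gallery.append(sample)
        else:
            query.append(sample)

    print(len(gallery), len(query))          # -> 2 2 (one gallery image per ID)
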
/data/datasets/veri.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | import re
4 |
5 | from .bases import ImageDataset
6 | from ..datasets import DATASET_REGISTRY
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class VeRi(ImageDataset):
11 | """VeRi.
12 |
13 | Reference:
14 | Liu et al. A Deep Learning based Approach for Progressive Vehicle Re-Identification. ECCV 2016.
15 |
16 | URL: ``_
17 |
18 | Dataset statistics:
19 | - identities: 775.
20 | - images: 37778 (train) + 1678 (query) + 11579 (gallery).
21 | """
22 | dataset_dir = "veri"
23 | dataset_name = "veri"
24 |
25 | def __init__(self, root='datasets', **kwargs):
26 | self.dataset_dir = osp.join(root, self.dataset_dir)
27 |
28 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
29 | self.query_dir = osp.join(self.dataset_dir, 'image_query')
30 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
31 |
32 | required_files = [
33 | self.dataset_dir,
34 | self.train_dir,
35 | self.query_dir,
36 | self.gallery_dir,
37 | ]
38 | self.check_before_run(required_files)
39 |
40 | train = self.process_dir(self.train_dir)
41 | query = self.process_dir(self.query_dir, is_train=False)
42 | gallery = self.process_dir(self.gallery_dir, is_train=False)
43 |
44 | super(VeRi, self).__init__(train, query, gallery, **kwargs)
45 |
46 | def process_dir(self, dir_path, is_train=True):
47 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
48 | pattern = re.compile(r'([\d]+)_c(\d\d\d)')
49 |
50 | data = []
51 | for img_path in img_paths:
52 | pid, camid = map(int, pattern.search(img_path).groups())
53 | if pid == -1: continue # junk images are just ignored
54 | assert 1 <= pid <= 776
55 | assert 1 <= camid <= 20
56 | camid -= 1 # index starts from 0
57 | if is_train:
58 | pid = self.dataset_name + "_" + str(pid)
59 | data.append((img_path, pid, camid))
60 |
61 | return data
62 |
--------------------------------------------------------------------------------
/data/datasets/veri_keypoint.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | import re
4 |
5 | from .bases import ImageDataset
6 | from ..datasets import DATASET_REGISTRY
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class VeRi_keypoint(ImageDataset):
11 | """VeRi.
12 |
13 | Reference:
14 | Liu et al. A Deep Learning based Approach for Progressive Vehicle Re-Identification. ECCV 2016.
15 |
16 | URL: ``_
17 |
18 | Dataset statistics:
19 | - identities: 775.
20 | - images: 37778 (train) + 1678 (query) + 11579 (gallery).
21 | """
22 | dataset_dir = "veri"
23 | dataset_name = "veri"
24 |
25 | def __init__(self, root='datasets', **kwargs):
26 | self.dataset_dir = osp.join(root, self.dataset_dir)
27 | self.keypoint_dir = osp.join(root, 'veri_keypoint')
28 |
29 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
30 | self.query_dir = osp.join(self.dataset_dir, 'image_query')
31 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
32 |
33 | required_files = [
34 | self.dataset_dir,
35 | self.train_dir,
36 | self.query_dir,
37 | self.gallery_dir,
38 | self.keypoint_dir
39 | ]
40 | self.check_before_run(required_files)
41 |
42 | train = self.process_dir(self.train_dir)
43 | train = self.process_keypoint(self.keypoint_dir, train)
44 | query = self.process_dir(self.query_dir, is_train=False)
45 | gallery = self.process_dir(self.gallery_dir, is_train=False)
46 |
47 | super(VeRi_keypoint, self).__init__(train, query, gallery, **kwargs)
48 |
49 | def process_dir(self, dir_path, is_train=True):
50 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
51 | pattern = re.compile(r'([\d]+)_c(\d\d\d)')
52 |
53 | data = []
54 | for img_path in img_paths:
55 | pid, camid = map(int, pattern.search(img_path).groups())
56 | if pid == -1: continue # junk images are just ignored
57 | assert 1 <= pid <= 776
58 | assert 1 <= camid <= 20
59 | camid -= 1 # index starts from 0
60 | if is_train:
61 | pid = self.dataset_name + "_" + str(pid)
62 | data.append((img_path, pid, camid))
63 |
64 |
65 | return data
66 |
67 |
68 | def process_keypoint(self, dir_path, data):
69 | train_name = []
70 | train_raw = []
71 | train_keypoint = []
72 | train_orientation = []
73 | is_keypoint = False
74 | is_orientation = True
75 | is_aligned = False
76 | with open(osp.join(dir_path, 'keypoint_train_aligned.txt')) as f:
77 | for line in f:
78 | train_raw.append(line)
79 | line_split = line.split(' ')
80 | train_name.append(line_split[0].split('/')[-1])
81 |
82 | if is_keypoint:
83 | train_keypoint.append(line_split[1:41])
84 | if is_orientation:
85 | tmp = line_split[-1]
86 | if '\n' in tmp:
87 | tmp = tmp[0]
88 | assert 0 <= int(tmp) <= 7 # orientation should be 0~7
89 | train_orientation.append(int(tmp))
90 |
91 | if is_aligned:
92 | train_name = sorted(tuple(train_name))
93 | train_raw = sorted(tuple(train_raw))
94 |
95 | with open(osp.join(dir_path, 'keypoint_train_aligned.txt'), 'w') as f:
96 | for i, x in enumerate(data):
97 | j = 0
98 | flag_break = False
99 | while (j < len(train_name) and not flag_break):
100 | if train_name[j] in x[0]:
101 | if train_name[j] in train_raw[j]:
102 | f.write(train_raw[j])
103 | flag_break = True
104 | del train_name[j]
105 | del train_raw[j]
106 | print(i)
107 | else:
108 | assert()
109 | j += 1
110 |
111 |
112 | for i, x in enumerate(data):
113 | j = 0
114 | flag_break = False
115 | while(j < len(train_name) and not flag_break):
116 | if train_name[j] in x[0]:
117 | add_info = {} # dictionary
118 | add_info['domains'] = int(train_orientation[j])
119 | data[i] = list(data[i])
120 | data[i].append(add_info)
121 | data[i] = tuple(data[i])
122 | flag_break = True
123 | del train_name[j]
124 | del train_orientation[j]
125 | # print(i)
126 | j += 1
127 |
128 | cnt = 0
129 | no_title = []
130 | no_title_local = []
131 | for line in data:
132 | if len(line) != 4:
133 | assert()
134 | # no_title.append(line[0])
135 | # tmp1 = line[0].split('/')[-1]
136 | # tmp2 = tmp1.split('_')
137 | # tmp3 = '_'.join(tmp2[2:])
138 | # for line2 in train_name:
139 | # if tmp3 in line2:
140 | # print(line2)
141 | # no_title_local.append(tmp3)
142 | # cnt += 1
143 |
144 | return data
--------------------------------------------------------------------------------
/data/datasets/veriwild.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | from .bases import ImageDataset
4 | from ..datasets import DATASET_REGISTRY
5 |
6 |
7 | @DATASET_REGISTRY.register()
8 | class VeRiWild(ImageDataset):
9 | """VeRi-Wild.
10 |
11 | Reference:
12 | Lou et al. A Large-Scale Dataset for Vehicle Re-Identification in the Wild. CVPR 2019.
13 |
14 | URL: ``_
15 |
16 | Train dataset statistics:
17 | - identities: 30671.
18 | - images: 277797.
19 | """
20 | dataset_dir = "VERI-Wild"
21 | dataset_name = "veriwild"
22 |
23 | def __init__(self, root='datasets', query_list='', gallery_list='', **kwargs):
24 | self.dataset_dir = osp.join(root, self.dataset_dir)
25 |
26 | self.image_dir = osp.join(self.dataset_dir, 'images')
27 | self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt')
28 | self.vehicle_info = osp.join(self.dataset_dir, 'train_test_split/vehicle_info.txt')
29 | if query_list and gallery_list:
30 | self.query_list = query_list
31 | self.gallery_list = gallery_list
32 | else:
33 | self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_10000_query.txt')
34 | self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_10000.txt')
35 |
36 | required_files = [
37 | self.image_dir,
38 | self.train_list,
39 | self.query_list,
40 | self.gallery_list,
41 | self.vehicle_info,
42 | ]
43 | self.check_before_run(required_files)
44 |
45 | self.imgid2vid, self.imgid2camid, self.imgid2imgpath = self.process_vehicle(self.vehicle_info)
46 |
47 | train = self.process_dir(self.train_list)
48 | query = self.process_dir(self.query_list, is_train=False)
49 | gallery = self.process_dir(self.gallery_list, is_train=False)
50 |
51 | super(VeRiWild, self).__init__(train, query, gallery, **kwargs)
52 |
53 | def process_dir(self, img_list, is_train=True):
54 | img_list_lines = open(img_list, 'r').readlines()
55 |
56 | dataset = []
57 | for idx, line in enumerate(img_list_lines):
58 | line = line.strip()
59 | vid = int(line.split('/')[0])
60 | imgid = line.split('/')[1]
61 | if is_train:
62 | vid = self.dataset_name + "_" + str(vid)
63 | dataset.append((self.imgid2imgpath[imgid], vid, int(self.imgid2camid[imgid])))
64 |
65 | assert len(dataset) == len(img_list_lines)
66 | return dataset
67 |
68 | def process_vehicle(self, vehicle_info):
69 | imgid2vid = {}
70 | imgid2camid = {}
71 | imgid2imgpath = {}
72 | vehicle_info_lines = open(vehicle_info, 'r').readlines()
73 |
74 | for idx, line in enumerate(vehicle_info_lines[1:]):
75 | vid = line.strip().split('/')[0]
76 | imgid = line.strip().split(';')[0].split('/')[1]
77 | camid = line.strip().split(';')[1]
78 | # img_path = osp.join(self.image_dir, vid, imgid + '.jpg')
79 | img_path = osp.join(self.image_dir, imgid + '.jpg')
80 | imgid2vid[imgid] = vid
81 | imgid2camid[imgid] = camid
82 | imgid2imgpath[imgid] = img_path
83 |
84 | assert len(imgid2vid) == len(vehicle_info_lines) - 1
85 | return imgid2vid, imgid2camid, imgid2imgpath
86 |
87 |
88 | @DATASET_REGISTRY.register()
89 | class SmallVeRiWild(VeRiWild):
90 | """VeRi-Wild.
91 | Small test dataset statistics:
92 | - identities: 3000.
93 | - images: 41861.
94 | """
95 |
96 | def __init__(self, root='datasets', **kwargs):
97 | # self.dataset_dir = osp.join(root, self.dataset_dir)
98 | self.query_list = osp.join(root, self.dataset_dir, 'train_test_split/test_3000_query.txt')
99 | self.gallery_list = osp.join(root, self.dataset_dir, 'train_test_split/test_3000.txt')
100 |
101 | super(SmallVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
102 |
103 |
104 | @DATASET_REGISTRY.register()
105 | class MediumVeRiWild(VeRiWild):
106 | """VeRi-Wild.
107 | Medium test dataset statistics:
108 | - identities: 5000.
109 | - images: 69389.
110 | """
111 |
112 | def __init__(self, root='datasets', **kwargs):
113 | # self.dataset_dir = osp.join(root, self.dataset_dir)
114 | self.query_list = osp.join(root, self.dataset_dir, 'train_test_split/test_5000_query.txt')
115 | self.gallery_list = osp.join(root, self.dataset_dir, 'train_test_split/test_5000.txt')
116 |
117 | super(MediumVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
118 |
119 |
120 | @DATASET_REGISTRY.register()
121 | class LargeVeRiWild(VeRiWild):
122 | """VeRi-Wild.
123 | Large test dataset statistics:
124 | - identities: 10000.
125 | - images: 138517.
126 | """
127 |
128 | def __init__(self, root='datasets', **kwargs):
129 | # self.dataset_dir = osp.join(root, self.dataset_dir)
130 | self.query_list = osp.join(root, self.dataset_dir, 'train_test_split/test_10000_query.txt')
131 | self.gallery_list = osp.join(root, self.dataset_dir, 'train_test_split/test_10000.txt')
132 |
133 | super(LargeVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
134 |
--------------------------------------------------------------------------------
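
VeRi-Wild resolves everything through the three dictionaries built from vehicle_info.txt; the list files themselves contain only "vid/imgid" entries. A small sketch of how process_vehicle digests one info line, assuming the "vid/imgid;camid;..." layout implied by the parsing above (the sample line and its trailing fields are illustrative only):

    import os.path as osp

    image_dir = 'VERI-Wild/images'

    # Assumed layout of one vehicle_info.txt line: "<vid>/<imgid>;<camid>;<other fields>"
    line = '00001/000001;3;2018-03 10:00;other\n'

    vid = line.strip().split('/')[0]
    imgid = line.strip().split(';')[0].split('/')[1]
    camid = line.strip().split(';')[1]
    img_path = osp.join(image_dir, imgid + '.jpg')

    print(vid, imgid, camid, img_path)
    # -> 00001 000001 3 VERI-Wild/images/000001.jpg
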
/data/datasets/viper.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | from . import DATASET_REGISTRY
5 | from .bases import ImageDataset
6 |
7 | __all__ = ['VIPeR', ]
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class VIPeR(ImageDataset):
12 | dataset_dir = "VIPeR"
13 | dataset_name = "viper"
14 |
15 | def __init__(self, root='datasets', **kwargs):
16 | self.root = root
17 | self.train_path = os.path.join(self.root, self.dataset_dir)
18 |
19 | required_files = [self.train_path]
20 | self.check_before_run(required_files)
21 |
22 | train = self.process_train(self.train_path)
23 |
24 | super().__init__(train, [], [], **kwargs)
25 |
26 | def process_train(self, train_path):
27 | data = []
28 |
29 | file_path_list = ['cam_a', 'cam_b']
30 |
31 | for file_path in file_path_list:
32 | camid = self.dataset_name + "_" + file_path
33 | img_list = glob(os.path.join(train_path, file_path, "*.bmp"))
34 | for img_path in img_list:
35 | img_name = img_path.split('/')[-1]
36 | pid = self.dataset_name + "_" + img_name.split('_')[0]
37 | data.append([img_path, pid, camid])
38 |
39 | return data
--------------------------------------------------------------------------------
/data/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, DomainSuffleSampler, RandomIdentitySampler
2 | from .data_sampler import TrainingSampler, InferenceSampler
3 |
--------------------------------------------------------------------------------
/data/samplers/data_sampler.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import Optional
3 |
4 | import numpy as np
5 | from torch.utils.data import Sampler
6 |
7 | from utils import comm
8 |
9 |
10 | class TrainingSampler(Sampler):
11 | """
12 | In training, we only care about the "infinite stream" of training data.
13 | So this sampler produces an infinite stream of indices and
14 | all workers cooperate to correctly shuffle the indices and sample different indices.
15 |     The sampler in each worker effectively produces `indices[worker_id::num_workers]`
16 | where `indices` is an infinite stream of indices consisting of
17 | `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
18 | or `range(size) + range(size) + ...` (if shuffle is False)
19 | """
20 |
21 | def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
22 | """
23 | Args:
24 | size (int): the total number of data of the underlying dataset to sample from
25 | shuffle (bool): whether to shuffle the indices or not
26 | seed (int): the initial seed of the shuffle. Must be the same
27 | across all workers. If None, will use a random seed shared
28 | among workers (require synchronization among all workers).
29 | """
30 | self._size = size
31 | assert size > 0
32 | self._shuffle = shuffle
33 | if seed is None:
34 | seed = comm.shared_random_seed()
35 | self._seed = int(seed)
36 |
37 | self._rank = comm.get_rank()
38 | self._world_size = comm.get_world_size()
39 |
40 | def __iter__(self):
41 | start = self._rank
42 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
43 |
44 | def _infinite_indices(self):
45 | np.random.seed(self._seed)
46 | while True:
47 | if self._shuffle:
48 | yield from np.random.permutation(self._size)
49 | else:
50 | yield from np.arange(self._size)
51 |
52 |
53 | class InferenceSampler(Sampler):
54 | """
55 | Produce indices for inference.
56 | Inference needs to run on the __exact__ set of samples,
57 | therefore when the total number of samples is not divisible by the number of workers,
58 |     this sampler produces a different number of samples on different workers.
59 | """
60 |
61 | def __init__(self, size: int):
62 | """
63 | Args:
64 | size (int): the total number of data of the underlying dataset to sample from
65 | """
66 | self._size = size
67 | assert size > 0
68 |
69 | begin = 0
70 | end = self._size
71 | self._local_indices = range(begin, end)
72 |
73 | def __iter__(self):
74 | yield from self._local_indices
75 |
76 | def __len__(self):
77 | return len(self._local_indices)
78 |
--------------------------------------------------------------------------------
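
TrainingSampler's sharding is just itertools.islice over a shared infinite permutation stream, offset by the worker's rank and strided by the world size, so every worker consumes a disjoint slice of the same shuffled order. A self-contained sketch of that mechanism with the distributed plumbing (utils.comm) replaced by fixed rank/world-size values:

    import itertools
    import numpy as np

    def infinite_indices(size, seed=0, shuffle=True):
        # Same idea as TrainingSampler._infinite_indices: an endless, reshuffled index stream.
        rng = np.random.RandomState(seed)
        while True:
            yield from (rng.permutation(size) if shuffle else np.arange(size))

    def worker_indices(size, rank, world_size, num=6, seed=0):
        # Same idea as TrainingSampler.__iter__: start at `rank`, stride by `world_size`.
        sharded = itertools.islice(infinite_indices(size, seed), rank, None, world_size)
        return [int(i) for i in itertools.islice(sharded, num)]  # finite prefix for the demo

    print(worker_indices(10, rank=0, world_size=2))  # even positions of the shared stream
    print(worker_indices(10, rank=1, world_size=2))  # odd positions, disjoint from rank 0
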
/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_transforms
2 | from .transforms import *
3 | from .autoaugment import *
4 |
--------------------------------------------------------------------------------
/data/transforms/build.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as T
3 |
4 | from .transforms import *
5 | from .autoaugment import AutoAugment
6 | from PIL import Image, ImageFilter, ImageOps
7 |
8 | from .transforms import LGT
9 |
10 | class GaussianBlur(object):
11 | """
12 | Apply Gaussian Blur to the PIL image.
13 | """
14 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
15 | self.prob = p
16 | self.radius_min = radius_min
17 | self.radius_max = radius_max
18 |
19 | def __call__(self, img):
20 | do_it = random.random() <= self.prob
21 | if not do_it:
22 | return img
23 |
24 | return img.filter(
25 | ImageFilter.GaussianBlur(
26 | radius=random.uniform(self.radius_min, self.radius_max)
27 | )
28 | )
29 |
30 |
31 | class Solarization(object):
32 | """
33 | Apply Solarization to the PIL image.
34 | """
35 | def __init__(self, p):
36 | self.p = p
37 |
38 | def __call__(self, img):
39 | if random.random() < self.p:
40 | return ImageOps.solarize(img)
41 | else:
42 | return img
43 |
44 | def build_transforms(cfg, is_train=True, is_fake=False):
45 | res = []
46 |
47 | if is_train:
48 | size_train = cfg.INPUT.SIZE_TRAIN
49 |
50 | # augmix augmentation
51 | do_augmix = cfg.INPUT.DO_AUGMIX
52 |
53 | # auto augmentation
54 | do_autoaug = cfg.INPUT.DO_AUTOAUG
55 | # total_iter = cfg.SOLVER.MAX_ITER
56 | total_iter = cfg.SOLVER.MAX_EPOCHS
57 |
58 |         # horizontal flip
59 | do_flip = cfg.INPUT.DO_FLIP
60 | flip_prob = cfg.INPUT.FLIP_PROB
61 |
62 | # padding
63 | do_pad = cfg.INPUT.DO_PAD
64 | padding = cfg.INPUT.PADDING
65 | padding_mode = cfg.INPUT.PADDING_MODE
66 |
67 |         # Local Grayscale Transformation
68 | do_lgt = cfg.INPUT.LGT.DO_LGT
69 | lgt_prob = cfg.INPUT.LGT.PROB
70 |
71 | # color jitter
72 | do_cj = cfg.INPUT.CJ.ENABLED
73 | cj_prob = cfg.INPUT.CJ.PROB
74 | cj_brightness = cfg.INPUT.CJ.BRIGHTNESS
75 | cj_contrast = cfg.INPUT.CJ.CONTRAST
76 | cj_saturation = cfg.INPUT.CJ.SATURATION
77 | cj_hue = cfg.INPUT.CJ.HUE
78 |
79 | # random erasing
80 | do_rea = cfg.INPUT.REA.ENABLED
81 | rea_prob = cfg.INPUT.REA.PROB
82 | rea_mean = cfg.INPUT.REA.MEAN
83 | # random patch
84 | do_rpt = cfg.INPUT.RPT.ENABLED
85 | rpt_prob = cfg.INPUT.RPT.PROB
86 |
87 | if do_autoaug:
88 | res.append(AutoAugment(total_iter))
89 | res.append(T.Resize(size_train, interpolation=3))
90 | if do_flip:
91 | res.append(T.RandomHorizontalFlip(p=flip_prob))
92 | if do_pad:
93 | res.extend([T.Pad(padding, padding_mode=padding_mode),
94 | T.RandomCrop(size_train)])
95 | if do_lgt:
96 | res.append(LGT(lgt_prob))
97 | if do_cj:
98 | res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
99 | if do_augmix:
100 | res.append(AugMix())
101 | # if do_rea:
102 | # res.append(RandomErasing(probability=rea_prob, mean=rea_mean, sh=1/3))
103 | if do_rpt:
104 | res.append(RandomPatch(prob_happen=rpt_prob))
105 | if is_fake:
106 | if cfg.META.DATA.SYNTH_FLAG == 'jitter':
107 | res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=1.0))
108 | elif cfg.META.DATA.SYNTH_FLAG == 'augmix':
109 | res.append(AugMix())
110 | elif cfg.META.DATA.SYNTH_FLAG == 'both':
111 | res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
112 | res.append(AugMix())
113 | res.extend([
114 | T.ToTensor(),
115 | T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
116 | ])
117 | if do_rea:
118 | from timm.data.random_erasing import RandomErasing as RE
119 | res.append(RE(probability=rea_prob, mode='pixel', max_count=1, device='cpu'))
120 | else:
121 | size_test = cfg.INPUT.SIZE_TEST
122 | res.append(T.Resize(size_test, interpolation=3))
123 | res.extend([
124 | T.ToTensor(),
125 | T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
126 | ])
127 | return T.Compose(res)
128 |
--------------------------------------------------------------------------------
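
At test time the pipeline above reduces to a resize, ToTensor and a 0.5/0.5 normalization. Written out directly for reference, with SIZE_TEST assumed to be the common 256x128 person-ReID resolution (the real value comes from cfg.INPUT.SIZE_TEST):

    import torchvision.transforms as T

    size_test = [256, 128]   # assumed; build_transforms reads cfg.INPUT.SIZE_TEST

    test_transform = T.Compose([
        T.Resize(size_test, interpolation=3),   # 3 == PIL bicubic, as in build_transforms;
                                                # newer torchvision prefers InterpolationMode.BICUBIC
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # Usage: tensor = test_transform(pil_image) -> shape (3, 256, 128), values roughly in [-1, 1]
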
/data/transforms/functional.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image, ImageOps, ImageEnhance
4 |
5 |
6 | def to_tensor(pic):
7 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
8 |
9 | See ``ToTensor`` for more details.
10 |
11 | Args:
12 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
13 |
14 | Returns:
15 | Tensor: Converted image.
16 | """
17 | if isinstance(pic, np.ndarray):
18 | assert len(pic.shape) in (2, 3)
19 | # handle numpy array
20 | if pic.ndim == 2:
21 | pic = pic[:, :, None]
22 |
23 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
24 | # backward compatibility
25 | if isinstance(img, torch.ByteTensor):
26 | return img.float()
27 | else:
28 | return img
29 |
30 | # handle PIL Image
31 | if pic.mode == 'I':
32 | img = torch.from_numpy(np.array(pic, np.int32, copy=False))
33 | elif pic.mode == 'I;16':
34 | img = torch.from_numpy(np.array(pic, np.int16, copy=False))
35 | elif pic.mode == 'F':
36 | img = torch.from_numpy(np.array(pic, np.float32, copy=False))
37 | elif pic.mode == '1':
38 | img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
39 | else:
40 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
41 | # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
42 | if pic.mode == 'YCbCr':
43 | nchannel = 3
44 | elif pic.mode == 'I;16':
45 | nchannel = 1
46 | else:
47 | nchannel = len(pic.mode)
48 | img = img.view(pic.size[1], pic.size[0], nchannel)
49 | # put it from HWC to CHW format
50 | # yikes, this transpose takes 80% of the loading time/CPU
51 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
52 | if isinstance(img, torch.ByteTensor):
53 | return img.float()
54 | else:
55 | return img
56 |
57 |
58 | def int_parameter(level, maxval):
59 | """Helper function to scale `val` between 0 and maxval .
60 | Args:
61 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
62 | maxval: Maximum value that the operation can have. This will be scaled to
63 | level/PARAMETER_MAX.
64 | Returns:
65 | An int that results from scaling `maxval` according to `level`.
66 | """
67 | return int(level * maxval / 10)
68 |
69 |
70 | def float_parameter(level, maxval):
71 | """Helper function to scale `val` between 0 and maxval.
72 | Args:
73 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
74 | maxval: Maximum value that the operation can have. This will be scaled to
75 | level/PARAMETER_MAX.
76 | Returns:
77 | A float that results from scaling `maxval` according to `level`.
78 | """
79 | return float(level) * maxval / 10.
80 |
81 |
82 | def sample_level(n):
83 | return np.random.uniform(low=0.1, high=n)
84 |
85 |
86 | def autocontrast(pil_img, *args):
87 | return ImageOps.autocontrast(pil_img)
88 |
89 |
90 | def equalize(pil_img, *args):
91 | return ImageOps.equalize(pil_img)
92 |
93 |
94 | def posterize(pil_img, level, *args):
95 | level = int_parameter(sample_level(level), 4)
96 | return ImageOps.posterize(pil_img, 4 - level)
97 |
98 |
99 | def rotate(pil_img, level, *args):
100 | degrees = int_parameter(sample_level(level), 30)
101 | if np.random.uniform() > 0.5:
102 | degrees = -degrees
103 | return pil_img.rotate(degrees, resample=Image.BILINEAR)
104 |
105 |
106 | def solarize(pil_img, level, *args):
107 | level = int_parameter(sample_level(level), 256)
108 | return ImageOps.solarize(pil_img, 256 - level)
109 |
110 |
111 | def shear_x(pil_img, level, image_size):
112 | level = float_parameter(sample_level(level), 0.3)
113 | if np.random.uniform() > 0.5:
114 | level = -level
115 | return pil_img.transform(image_size,
116 | Image.AFFINE, (1, level, 0, 0, 1, 0),
117 | resample=Image.BILINEAR)
118 |
119 |
120 | def shear_y(pil_img, level, image_size):
121 | level = float_parameter(sample_level(level), 0.3)
122 | if np.random.uniform() > 0.5:
123 | level = -level
124 | return pil_img.transform(image_size,
125 | Image.AFFINE, (1, 0, 0, level, 1, 0),
126 | resample=Image.BILINEAR)
127 |
128 |
129 | def translate_x(pil_img, level, image_size):
130 | level = int_parameter(sample_level(level), image_size[0] / 3)
131 | if np.random.random() > 0.5:
132 | level = -level
133 | return pil_img.transform(image_size,
134 | Image.AFFINE, (1, 0, level, 0, 1, 0),
135 | resample=Image.BILINEAR)
136 |
137 |
138 | def translate_y(pil_img, level, image_size):
139 | level = int_parameter(sample_level(level), image_size[1] / 3)
140 | if np.random.random() > 0.5:
141 | level = -level
142 | return pil_img.transform(image_size,
143 | Image.AFFINE, (1, 0, 0, 0, 1, level),
144 | resample=Image.BILINEAR)
145 |
146 |
147 | # operation that overlaps with ImageNet-C's test set
148 | def color(pil_img, level, *args):
149 | level = float_parameter(sample_level(level), 1.8) + 0.1
150 | return ImageEnhance.Color(pil_img).enhance(level)
151 |
152 |
153 | # operation that overlaps with ImageNet-C's test set
154 | def contrast(pil_img, level, *args):
155 | level = float_parameter(sample_level(level), 1.8) + 0.1
156 | return ImageEnhance.Contrast(pil_img).enhance(level)
157 |
158 |
159 | # operation that overlaps with ImageNet-C's test set
160 | def brightness(pil_img, level, *args):
161 | level = float_parameter(sample_level(level), 1.8) + 0.1
162 | return ImageEnhance.Brightness(pil_img).enhance(level)
163 |
164 |
165 | # operation that overlaps with ImageNet-C's test set
166 | def sharpness(pil_img, level, *args):
167 | level = float_parameter(sample_level(level), 1.8) + 0.1
168 | return ImageEnhance.Sharpness(pil_img).enhance(level)
169 |
170 |
171 | augmentations_reid = [
172 | autocontrast, equalize, posterize, shear_x, shear_y,
173 | color, contrast, brightness, sharpness
174 | ]
175 |
176 | augmentations = [
177 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
178 | translate_x, translate_y
179 | ]
180 |
181 | augmentations_all = [
182 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
183 | translate_x, translate_y, color, contrast, brightness, sharpness
184 | ]
185 |
--------------------------------------------------------------------------------
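A small sketch of the level-scaling helpers above: `sample_level(n)` draws a level uniformly from [0.1, n), and the scaled magnitude is level/10 of `maxval` (the import path and sample image are assumptions based on the repository layout):

    from PIL import Image
    from data.transforms.functional import int_parameter, float_parameter, rotate

    assert int_parameter(10, 30) == 30    # full level (10) maps straight to maxval
    print(float_parameter(5, 0.3))        # ~0.15: half level maps to half of maxval

    img = Image.open("visualization/test.jpg").convert("RGB")
    out = rotate(img, 3)                  # rotation by up to +/- int_parameter(sample_level(3), 30) degrees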
/enviroments.sh:
--------------------------------------------------------------------------------
1 | pip install torch torchvision torchaudio
2 | pip install einops
3 | pip install timm
4 | pip install scikit-image
5 | pip install opencv-python
6 | pip install tensorboard
7 | pip install yacs
8 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_loss import make_loss
2 | from .arcface import ArcFace
3 | from .smooth import *
4 | from .myloss import *
5 |
--------------------------------------------------------------------------------
/loss/arcface.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Parameter
5 | import math
6 |
7 |
8 | class ArcFace(nn.Module):
9 | def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False):
10 | super(ArcFace, self).__init__()
11 | self.in_features = in_features
12 | self.out_features = out_features
13 | self.s = s
14 | self.m = m
15 | self.cos_m = math.cos(m)
16 | self.sin_m = math.sin(m)
17 |
18 | self.th = math.cos(math.pi - m)
19 | self.mm = math.sin(math.pi - m) * m
20 |
21 | self.weight = Parameter(torch.Tensor(out_features, in_features))
22 | if bias:
23 | self.bias = Parameter(torch.Tensor(out_features))
24 | else:
25 | self.register_parameter('bias', None)
26 | self.reset_parameters()
27 |
28 | def reset_parameters(self):
29 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
30 | if self.bias is not None:
31 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
32 | bound = 1 / math.sqrt(fan_in)
33 | nn.init.uniform_(self.bias, -bound, bound)
34 |
35 | def forward(self, input, label):
36 | cosine = F.linear(F.normalize(input), F.normalize(self.weight))
37 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
38 | phi = cosine * self.cos_m - sine * self.sin_m
39 | phi = torch.where(cosine > self.th, phi, cosine - self.mm)
40 | # --------------------------- convert label to one-hot ---------------------------
41 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
42 | one_hot = torch.zeros(cosine.size(), device='cuda')
43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
44 | # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
45 | output = (one_hot * phi) + (
46 | (1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
47 | output *= self.s
48 | # print(output)
49 |
50 | return output
51 |
52 | class CircleLoss(nn.Module):
53 | def __init__(self, in_features, num_classes, s=256, m=0.25):
54 | super(CircleLoss, self).__init__()
55 | self.weight = Parameter(torch.Tensor(num_classes, in_features))
56 | self.s = s
57 | self.m = m
58 | self._num_classes = num_classes
59 | self.reset_parameters()
60 |
61 |
62 | def reset_parameters(self):
63 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
64 |
65 | def __call__(self, bn_feat, targets):
66 |
67 | sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
68 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.)
69 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.)
70 | delta_p = 1 - self.m
71 | delta_n = self.m
72 |
73 | s_p = self.s * alpha_p * (sim_mat - delta_p)
74 | s_n = self.s * alpha_n * (sim_mat - delta_n)
75 |
76 | targets = F.one_hot(targets, num_classes=self._num_classes)
77 |
78 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n
79 |
80 | return pred_class_logits
--------------------------------------------------------------------------------
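A shape-level sketch of driving the ArcFace head above (a sketch, not part of the repo's training code; it assumes a CUDA device because `forward` allocates its one-hot buffer on 'cuda', and the 768-dim / 751-class sizes are illustrative):

    import torch
    import torch.nn.functional as F
    from loss.arcface import ArcFace

    head = ArcFace(in_features=768, out_features=751, s=30.0, m=0.50).cuda()
    feat = torch.randn(16, 768).cuda()              # backbone embeddings for a batch of 16
    labels = torch.randint(0, 751, (16,)).cuda()    # identity labels
    logits = head(feat, labels)                     # (16, 751); the angular margin m is applied on the target class
    loss = F.cross_entropy(logits, labels)          # the head returns scaled logits, not a loss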
/loss/build_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy
3 | from .triplet_loss import TripletLoss
4 | from .center_loss import CenterLoss
5 | from .ce_labelSmooth import CrossEntropyLabelSmooth as CE_LS
6 |
7 | feat_dim_dict = {
8 | 'local_attention_vit': 768,
9 | 'vit': 768,
10 | 'resnet18': 512,
11 | 'resnet34': 512
12 | }
13 |
14 | def build_loss(cfg, num_classes):
15 | name = cfg.MODEL.NAME
16 | sampler = cfg.DATALOADER.SAMPLER
17 | if cfg.MODEL.NAME not in feat_dim_dict.keys():
18 | feat_dim = 2048
19 | else:
20 | feat_dim = feat_dim_dict[cfg.MODEL.NAME]
21 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
22 | if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
23 | if cfg.MODEL.NO_MARGIN:
24 | triplet = TripletLoss()
25 | print("using soft triplet loss for training")
26 | else:
27 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
28 | print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
29 | else:
30 |         print('expected METRIC_LOSS_TYPE should be triplet, '
31 |               'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
32 |
33 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
34 | if name == 'local_attention_vit' and cfg.MODEL.PC_LOSS:
35 | xent = CrossEntropyLabelSmooth(num_classes=num_classes)
36 | else:
37 | xent = CE_LS(num_classes=num_classes)
38 | print("label smooth on, numclasses:", num_classes)
39 |
40 | if sampler == 'softmax': # softmax loss only
41 | def loss_func(score, feat, target):
42 | return F.cross_entropy(score, target)
43 |
44 | # softmax & triplet
45 |     elif cfg.DATALOADER.SAMPLER in ('softmax_triplet', 'GS'):
46 | def loss_func(score, feat, target, domains=None, t_domains=None, all_posvid=None, soft_label=False, soft_weight=0.1, soft_lambda=0.2):
47 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
48 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
49 | if name == 'local_attention_vit' and cfg.MODEL.PC_LOSS:
50 | ID_LOSS = xent(score, target, all_posvid=all_posvid, soft_label=soft_label,soft_weight=soft_weight, soft_lambda=soft_lambda)
51 | else:
52 | ID_LOSS = xent(score, target)
53 | else:
54 | ID_LOSS = F.cross_entropy(score, target)
55 |
56 | TRI_LOSS = triplet(feat, target)[0]
57 | # DOMAIN_LOSS = xent(domains, t_domains)
58 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
59 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
60 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
61 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
62 | return xent(score, target) + \
63 | triplet(feat, target)[0] + \
64 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
65 | else:
66 | return F.cross_entropy(score, target) + \
67 | triplet(feat, target)[0] + \
68 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
69 | else:
70 |                 print('expected METRIC_LOSS_TYPE with center should be center, triplet_center, '
71 |                       'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
72 |
73 | else:
74 |         print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center, '
75 |               'but got {}'.format(cfg.DATALOADER.SAMPLER))
76 | return loss_func, center_criterion
77 |
78 |
79 |
--------------------------------------------------------------------------------
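A minimal sketch of calling the factory above inside a training step, assuming `cfg.DATALOADER.SAMPLER` is 'softmax_triplet'/'GS' and the model is not the part-attention head (which additionally passes `all_posvid` from the patch memory). Tensors are placed on the GPU because the smoothed cross-entropy and the center criterion allocate CUDA buffers; the batch shapes are illustrative:

    import torch
    from config import cfg
    from loss.build_loss import build_loss

    num_classes = 751                               # e.g. Market-1501 identities
    loss_func, center_criterion = build_loss(cfg, num_classes)

    score = torch.randn(16, num_classes, requires_grad=True, device='cuda')   # classification logits
    feat = torch.randn(16, 768, requires_grad=True, device='cuda')            # embeddings for the triplet term
    target = torch.randint(0, num_classes, (16,), device='cuda')

    loss = loss_func(score, feat, target)           # ID_LOSS_WEIGHT * CE + TRIPLET_LOSS_WEIGHT * triplet
    loss.backward()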
/loss/ce_labelSmooth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | class CrossEntropyLabelSmooth(nn.Module):
5 | """Cross entropy loss with label smoothing regularizer.
6 |
7 | Reference:
8 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
9 | Equation: y = (1 - epsilon) * y + epsilon / K.
10 |
11 | Args:
12 | num_classes (int): number of classes.
13 | epsilon (float): weight.
14 | """
15 |
16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
17 | super(CrossEntropyLabelSmooth, self).__init__()
18 | self.num_classes = num_classes
19 | self.epsilon = epsilon
20 | self.use_gpu = use_gpu
21 | self.logsoftmax = nn.LogSoftmax(dim=1)
22 |
23 | def forward(self, inputs, targets):
24 | """
25 | Args:
26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
27 |             targets: ground truth labels with shape (batch_size)
28 | """
29 | log_probs = self.logsoftmax(inputs)
30 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
31 | if self.use_gpu: targets = targets.cuda()
32 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
33 | loss = (- targets * log_probs).mean(0).sum()
34 | return loss
35 |
36 | class LabelSmoothingCrossEntropy(nn.Module):
37 | """
38 | NLL loss with label smoothing.
39 | """
40 | def __init__(self, smoothing=0.1):
41 | """
42 | Constructor for the LabelSmoothing module.
43 | :param smoothing: label smoothing factor
44 | """
45 | super(LabelSmoothingCrossEntropy, self).__init__()
46 | assert smoothing < 1.0
47 | self.smoothing = smoothing
48 | self.confidence = 1. - smoothing
49 |
50 | def forward(self, x, target):
51 | logprobs = F.log_softmax(x, dim=-1)
52 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
53 | nll_loss = nll_loss.squeeze(1)
54 | smooth_loss = -logprobs.mean(dim=-1)
55 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
56 | return loss.mean()
--------------------------------------------------------------------------------
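To make the docstring's equation concrete, a small CPU-only sketch with K = 4 classes and epsilon = 0.1: the one-hot target [0, 1, 0, 0] becomes [0.025, 0.925, 0.025, 0.025], since (1 - 0.1) * 1 + 0.1 / 4 = 0.925 on the true class and 0.1 / 4 = 0.025 elsewhere:

    import torch
    from loss.ce_labelSmooth import CrossEntropyLabelSmooth

    criterion = CrossEntropyLabelSmooth(num_classes=4, epsilon=0.1, use_gpu=False)
    logits = torch.tensor([[2.0, 0.5, -1.0, 0.3]])   # (batch_size=1, num_classes=4), pre-softmax
    targets = torch.tensor([1])                      # true class index
    loss = criterion(logits, targets)                # cross-entropy against the smoothed target above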
/loss/center_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class CenterLoss(nn.Module):
8 | """Center loss.
9 |
10 | Reference:
11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
12 |
13 | Args:
14 | num_classes (int): number of classes.
15 | feat_dim (int): feature dimension.
16 | """
17 |
18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
19 | super(CenterLoss, self).__init__()
20 | self.num_classes = num_classes
21 | self.feat_dim = feat_dim
22 | self.use_gpu = use_gpu
23 |
24 | if self.use_gpu:
25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
26 | else:
27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
28 |
29 | def forward(self, x, labels):
30 | """
31 | Args:
32 | x: feature matrix with shape (batch_size, feat_dim).
33 |             labels: ground truth labels with shape (batch_size).
34 | """
35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
36 |
37 | batch_size = x.size(0)
38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
40 | # distmat = torch.addmm(distmat, x, self.centers.t(), beta=1, alpha=-2)
41 |         distmat.addmm_(x.float(), self.centers.t(), beta=1, alpha=-2)
42 |
43 | classes = torch.arange(self.num_classes).long()
44 | if self.use_gpu: classes = classes.cuda()
45 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
46 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
47 |
48 | dist = []
49 | for i in range(batch_size):
50 | value = distmat[i][mask[i]]
51 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
52 | dist.append(value)
53 | dist = torch.cat(dist)
54 | loss = dist.mean()
55 | return loss
56 |
57 |
58 | if __name__ == '__main__':
59 | use_gpu = False
60 | center_loss = CenterLoss(use_gpu=use_gpu)
61 | features = torch.rand(16, 2048)
62 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long()
63 | if use_gpu:
64 | features = torch.rand(16, 2048).cuda()
65 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda()
66 |
67 | loss = center_loss(features, targets)
68 | print(loss)
69 |
--------------------------------------------------------------------------------
/loss/make_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy
3 | from .triplet_loss import TripletLoss
4 | from .center_loss import CenterLoss
5 |
6 |
7 | def make_loss(cfg, num_classes):
8 | sampler = cfg.DATALOADER.SAMPLER
9 | if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
10 | if cfg.MODEL.NO_MARGIN:
11 | triplet = TripletLoss()
12 | print("using soft triplet loss for training")
13 | else:
14 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
15 | print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
16 | else:
17 |         print('expected METRIC_LOSS_TYPE should be triplet, '
18 |               'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
19 |
20 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
21 | xent = CrossEntropyLabelSmooth(num_classes=num_classes)
22 | print("label smooth on, numclasses:", num_classes)
23 |
24 | if sampler == 'softmax':
25 | def loss_func(score, feat, target):
26 | return F.cross_entropy(score, target)
27 |
28 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
29 | def loss_func(score, feat, target, target_cam):
30 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
31 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
32 | if isinstance(score, list):
33 | ID_LOSS = [xent(scor, target) for scor in score[1:]]
34 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
35 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target)
36 | else:
37 | ID_LOSS = xent(score, target)
38 |
39 | if isinstance(feat, list):
40 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
41 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
42 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
43 | else:
44 | TRI_LOSS = triplet(feat, target)[0]
45 |
46 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
47 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
48 | else:
49 | if isinstance(score, list):
50 | ID_LOSS = [F.cross_entropy(scor, target) for scor in score[1:]]
51 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
52 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * F.cross_entropy(score[0], target)
53 | else:
54 | ID_LOSS = F.cross_entropy(score, target)
55 |
56 | if isinstance(feat, list):
57 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
58 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
59 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
60 | else:
61 | TRI_LOSS = triplet(feat, target)[0]
62 |
63 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
64 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
65 | else:
66 |                 print('expected METRIC_LOSS_TYPE should be triplet, '
67 |                       'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
68 |
69 | else:
70 |         print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center, '
71 |               'but got {}'.format(cfg.DATALOADER.SAMPLER))
72 | return loss_func
73 |
74 |
75 |
--------------------------------------------------------------------------------
/loss/myloss.py:
--------------------------------------------------------------------------------
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class Pedal(nn.Module):
10 |
11 | def __init__(self, scale=10, k=10):
12 | super(Pedal, self).__init__()
13 | self.scale =scale
14 | self.k = k
15 |
16 |
17 | def forward(self, feature, centers, position, PatchMemory = None, vid=None, camid=None):
18 |
19 | loss = 0
20 |
21 |
22 | all_posvid = []
23 | for p in range(feature.size(0)):
24 | part_feat = feature[p, :, :]
25 | part_centers = centers[p, :, :]
26 | m, n = part_feat.size(0), part_centers.size(0)
27 | dist_map = part_feat.pow(2).sum(dim=1, keepdim=True).expand(m, n) + \
28 | part_centers.pow(2).sum(dim=1, keepdim=True).expand(n, m).t()
29 |             dist_map.addmm_(part_feat, part_centers.t(), beta=1, alpha=-2)
30 |
31 | trick = torch.arange(dist_map.size(1)).cuda().expand_as(dist_map)
32 |
33 | neg, index = dist_map[trick!=position.unsqueeze(dim=1).expand_as(dist_map)].view(dist_map.size(0), -1).sort(dim=1)
34 |
35 | pos_camid = torch.tensor(PatchMemory.camid).cuda()
36 | pos_camid = pos_camid[(index[:,:self.k])]
37 | flag = pos_camid != camid.unsqueeze(dim=1).expand_as(pos_camid)
38 |
39 | pos_vid = torch.tensor(PatchMemory.vid).cuda()
40 | pos_vid = pos_vid[(index[:,:self.k])]
41 | all_posvid.append(pos_vid)
42 |
43 | x = ((-1 * self.scale * neg[:, :self.k]).exp().sum(dim=1)).log()
44 |
45 | y = ((-1 * self.scale * neg).exp().sum(dim=1)).log()
46 |
47 | l = (-x + y).sum().div(feature.size(1))
48 | l = torch.where(torch.isnan(l), torch.full_like(l, 0.), l)
49 | loss += l
50 | loss = loss.div(feature.size(0))
51 |
52 | return loss, all_posvid
53 |
54 |
55 |
56 | class Ipfl(nn.Module):
57 | def __init__(self, margin=1.0, p=2, eps=1e-6, max_iter=15, nearest=3, num=2, swap=False):
58 |
59 | super(Ipfl, self).__init__()
60 | self.margin = margin
61 | self.p = p
62 | self.eps = eps
63 | self.swap = swap
64 | self.max_iter = max_iter
65 | self.num = num
66 | self.nearest = nearest
67 |
68 |
69 | def forward(self, feature, centers):
70 |
71 | image_label = torch.arange(feature.size(0) // self.num).repeat(self.num, 1).transpose(0, 1).contiguous().view(-1)
72 | center_label = torch.arange(feature.size(0) // self.num)
73 | loss = 0
74 | size = 0
75 |
76 | for i in range(0, feature.size(0), 1):
77 | label = image_label[i]
78 | diff = (feature[i, :].expand_as(centers) - centers).pow(self.p).sum(dim=1)
79 | diff = torch.sqrt(diff)
80 |
81 | same = diff[center_label == label]
82 | sorted, index = diff[center_label != label].sort()
83 | trust_diff_label = []
84 | trust_diff = []
85 |
86 | # cycle ranking
87 | max_iter = self.max_iter if self.max_iter < index.size(0) else index.size(0)
88 | for j in range(max_iter):
89 | s = centers[center_label != label, :][index[j]]
90 | l = center_label[center_label != label][index[j]]
91 |
92 | sout = (s.expand_as(centers) - centers).pow(self.p).sum(dim=1)
93 | sout = sout.pow(1. / self.p)
94 |
95 | ssorted, sindex = torch.sort(sout)
96 | near = center_label[sindex[:self.nearest]]
97 | if (label not in near): # view as different identity
98 | trust_diff.append(sorted[j])
99 | trust_diff_label.append(l)
100 | break
101 |
102 | if len(trust_diff) == 0:
103 | trust_diff.append(torch.tensor([0.]).cuda())
104 |
105 | min_diff = torch.stack(trust_diff, dim=0).min()
106 |
107 | dist_hinge = torch.clamp(self.margin + same.mean() - min_diff, min=0.0)
108 |
109 | size += 1
110 | loss += dist_hinge
111 |
112 | loss = loss / size
113 | return loss
114 |
115 |
116 | class TripletHard(nn.Module):
117 | def __init__(self, margin=1.0, p=2, eps=1e-5, swap=False, norm=False):
118 | super(TripletHard, self).__init__()
119 | self.margin = margin
120 | self.p = p
121 | self.eps = eps
122 | self.swap = swap
123 | self.norm = norm
124 | self.sigma = 3
125 |
126 |
127 | def forward(self, feature, label):
128 |
129 | if self.norm:
130 | feature = feature.div(feature.norm(dim=1).unsqueeze(1))
131 | loss = 0
132 |
133 | m, n = feature.size(0), feature.size(0)
134 | dist_map = feature.pow(2).sum(dim=1, keepdim=True).expand(m, n) + \
135 | feature.pow(2).sum(dim=1, keepdim=True).expand(n, m).t() + self.eps
136 |         dist_map.addmm_(feature, feature.t(), beta=1, alpha=-2).sqrt_()
137 |
138 | sorted, index = dist_map.sort(dim=1)
139 |
140 | for i in range(feature.size(0)):
141 |
142 | same = sorted[i, :][label[index[i, :]] == label[i]]
143 | diff = sorted[i, :][label[index[i, :]] != label[i]]
144 | dist_hinge = torch.clamp(self.margin + same[1] - diff.min(), min=0.0)
145 | loss += dist_hinge
146 |
147 | loss = loss / (feature.size(0))
148 | return loss
149 |
--------------------------------------------------------------------------------
/loss/smooth.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 |
4 |
5 | class PatchMemory(object):
6 |
7 | def __init__(self, momentum=0.1, num=1):
8 |
9 | self.name = []
10 | self.agent = []
11 | self.momentum = momentum
12 | self.num = num
13 |
14 | self.camid = []
15 | self.vid = []
16 |
17 |
18 |
19 | def get_soft_label(self, path, feat_list, vid=None, camid=None):
20 |
21 | feat = torch.stack(feat_list, dim=0)
22 |
23 | feat = feat[:, ::self.num, :]
24 |
25 |
26 | position = []
27 |
28 |
29 | # update the agent
30 | for j,p in enumerate(path):
31 |
32 | current_soft_feat = feat[:, j, :].detach()
33 | if current_soft_feat.is_cuda:
34 | current_soft_feat = current_soft_feat.cpu()
35 | key = p
36 | if key not in self.name:
37 | self.name.append(key)
38 | self.camid.append(camid[j])
39 | self.vid.append(vid[j])
40 | self.agent.append(current_soft_feat)
41 | ind = self.name.index(key)
42 | position.append(ind)
43 |
44 | else:
45 | ind = self.name.index(key)
46 | tmp = self.agent.pop(ind)
47 | tmp = tmp*(1-self.momentum) + self.momentum*current_soft_feat
48 | self.agent.insert(ind, tmp)
49 | position.append(ind)
50 |
51 | if len(position) != 0:
52 | position = torch.tensor(position).cuda()
53 |
54 | agent = torch.stack(self.agent, dim=1).cuda()
55 | return agent, position
56 |
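    # NOTE: this helper assumes MoCo-style buffers (self.queue, self.queue_ptr, self.K) that
    # PatchMemory does not define in __init__; it appears to be unused by the rest of this file.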
57 | def _dequeue_and_enqueue(self, keys):
58 | # gather keys before updating queue
59 | keys = concat_all_gather(keys)
60 |
61 | batch_size = keys.shape[0]
62 |
63 | ptr = int(self.queue_ptr)
64 | assert self.K % batch_size == 0 # for simplicity
65 |
66 | # replace the keys at ptr (dequeue and enqueue)
67 | self.queue[:, ptr:ptr + batch_size] = keys.T
68 | ptr = (ptr + batch_size) % self.K # move pointer
69 |
70 | self.queue_ptr[0] = ptr
71 |
72 |
73 | # utils
74 | @torch.no_grad()
75 | def concat_all_gather(tensor):
76 | """
77 | Performs all_gather operation on the provided tensors.
78 | *** Warning ***: torch.distributed.all_gather has no gradient.
79 | """
80 | tensors_gather = [torch.ones_like(tensor)
81 | for _ in range(torch.distributed.get_world_size())]
82 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
83 |
84 | output = torch.cat(tensors_gather, dim=0)
85 | return output
86 |
87 |
88 |
89 | class SmoothingForImage(object):
90 | def __init__(self, momentum=0.1, num=1):
91 |
92 | self.map = dict()
93 | self.momentum = momentum
94 | self.num = num
95 |
96 |
97 | def get_soft_label(self, path, feature):
98 |
99 | feature = torch.cat(feature, dim=1)
100 | soft_label = []
101 |
102 | for j,p in enumerate(path):
103 |
104 | current_soft_feat = feature[j*self.num:(j+1)*self.num, :].detach().mean(dim=0)
105 | if current_soft_feat.is_cuda:
106 | current_soft_feat = current_soft_feat.cpu()
107 |
108 | key = p
109 | if key not in self.map:
110 | self.map.setdefault(key, current_soft_feat)
111 | soft_label.append(self.map[key])
112 | else:
113 | self.map[key] = self.map[key]*(1-self.momentum) + self.momentum*current_soft_feat
114 | soft_label.append(self.map[key])
115 | soft_label = torch.stack(soft_label, dim=0).cuda()
116 | return soft_label
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
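The memory classes above refresh each stored agent with an exponential moving average, a_new = (1 - momentum) * a_old + momentum * f; a minimal standalone sketch of that update rule with hypothetical 768-dim part features:

    import torch

    momentum = 0.1
    agent = torch.zeros(768)          # stored feature for one image path / patch
    for _ in range(3):                # the same patch observed in three batches
        feat = torch.randn(768)       # freshly extracted part feature
        agent = agent * (1 - momentum) + momentum * feat   # same rule as PatchMemory / SmoothingForImage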
/loss/softmax_loss.py:
--------------------------------------------------------------------------------
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | class CrossEntropyLabelSmooth(nn.Module):
6 | """Cross entropy loss with label smoothing regularizer.
7 |
8 | Reference:
9 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
10 | Equation: y = (1 - epsilon) * y + epsilon / K.
11 |
12 | Args:
13 | num_classes (int): number of classes.
14 | epsilon (float): weight.
15 | """
16 |
17 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
18 | super(CrossEntropyLabelSmooth, self).__init__()
19 | self.num_classes = num_classes
20 | self.epsilon = epsilon
21 | self.use_gpu = use_gpu
22 | self.logsoftmax = nn.LogSoftmax(dim=1)
23 |
24 | def forward(self, inputs, targets, all_posvid=None, soft_label=False, soft_weight=0.1, soft_lambda=0.2):
25 | """
26 | Args:
27 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
28 |             targets: ground truth labels with shape (batch_size)
29 | """
30 | all_posvid = torch.cat(all_posvid, dim=1)
31 | soft_targets = []
32 | for i in range(all_posvid.size(0)):
33 | s_id, s_num = torch.unique(all_posvid[i,:], return_counts=True)
34 | sum_num = s_num.sum()
35 | temp = torch.zeros(inputs.size(1)).cuda().scatter_(0, s_id, (soft_lambda/sum_num)*s_num)
36 | soft_targets.append(temp)
37 |
38 | soft_targets = torch.stack(soft_targets, dim=0)
39 |
40 |
41 | log_probs = self.logsoftmax(inputs)
42 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
43 | if self.use_gpu: targets = targets.cuda()
44 | if soft_label:
45 | soft_targets = (1 - soft_lambda) * targets + soft_targets
46 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
47 | loss = (- targets * log_probs).mean(0).sum()*(1 - soft_weight) + \
48 | (- soft_targets * log_probs).mean(0).sum()*soft_weight
49 | # if torch.isnan(loss).item():
50 | # print("====nan!!!====\n{}\n{}".format((- targets * log_probs).mean(0).sum(), (- soft_targets * log_probs).mean(0).sum()))
51 |
52 | else:
53 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
54 | loss = (- targets * log_probs).mean(0).sum()
55 | return loss
56 |
57 | class LabelSmoothingCrossEntropy(nn.Module):
58 | """
59 | NLL loss with label smoothing.
60 | """
61 | def __init__(self, smoothing=0.1):
62 | """
63 | Constructor for the LabelSmoothing module.
64 | :param smoothing: label smoothing factor
65 | """
66 | super(LabelSmoothingCrossEntropy, self).__init__()
67 | assert smoothing < 1.0
68 | self.smoothing = smoothing
69 | self.confidence = 1. - smoothing
70 |
71 | def forward(self, x, target):
72 | logprobs = F.log_softmax(x, dim=-1)
73 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
74 | nll_loss = nll_loss.squeeze(1)
75 | smooth_loss = -logprobs.mean(dim=-1)
76 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
77 | return loss.mean()
--------------------------------------------------------------------------------
/loss/triplet_loss.py:
--------------------------------------------------------------------------------
2 | import torch
3 | from torch import nn
4 |
5 |
6 | def normalize(x, axis=-1):
7 | """Normalizing to unit length along the specified dimension.
8 | Args:
9 | x: pytorch Variable
10 | Returns:
11 | x: pytorch Variable, same shape as input
12 | """
13 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
14 | return x
15 |
16 |
17 | def euclidean_dist(x, y):
18 | """
19 | Args:
20 | x: pytorch Variable, with shape [m, d]
21 | y: pytorch Variable, with shape [n, d]
22 | Returns:
23 | dist: pytorch Variable, with shape [m, n]
24 | """
25 | m, n = x.size(0), y.size(0)
26 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
27 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
28 | dist = xx + yy
29 | dist = dist - 2 * torch.matmul(x, y.t())
30 | # dist.addmm_(1, -2, x, y.t())
31 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
32 | return dist
33 |
34 |
35 | def cosine_dist(x, y):
36 | """
37 | Args:
38 | x: pytorch Variable, with shape [m, d]
39 | y: pytorch Variable, with shape [n, d]
40 | Returns:
41 | dist: pytorch Variable, with shape [m, n]
42 | """
43 | m, n = x.size(0), y.size(0)
44 | x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n)
45 | y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t()
46 | xy_intersection = torch.mm(x, y.t())
47 | dist = xy_intersection/(x_norm * y_norm)
48 | dist = (1. - dist) / 2
49 | return dist
50 |
51 |
52 | def hard_example_mining(dist_mat, labels, return_inds=False):
53 | """For each anchor, find the hardest positive and negative sample.
54 | Args:
55 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
56 | labels: pytorch LongTensor, with shape [N]
57 | return_inds: whether to return the indices. Save time if `False`(?)
58 | Returns:
59 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
60 | dist_an: pytorch Variable, distance(anchor, negative); shape [N]
61 | p_inds: pytorch LongTensor, with shape [N];
62 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
63 | n_inds: pytorch LongTensor, with shape [N];
64 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
65 | NOTE: Only consider the case in which all labels have same num of samples,
66 | thus we can cope with all anchors in parallel.
67 | """
68 |
69 | assert len(dist_mat.size()) == 2
70 | assert dist_mat.size(0) == dist_mat.size(1)
71 | N = dist_mat.size(0)
72 |
73 | # shape [N, N]
74 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
75 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
76 |
77 | # `dist_ap` means distance(anchor, positive)
78 | # both `dist_ap` and `relative_p_inds` with shape [N, 1]
79 | dist_ap, relative_p_inds = torch.max(
80 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
81 | # print(dist_mat[is_pos].shape)
82 | # `dist_an` means distance(anchor, negative)
83 | # both `dist_an` and `relative_n_inds` with shape [N, 1]
84 | dist_an, relative_n_inds = torch.min(
85 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
86 | # shape [N]
87 | dist_ap = dist_ap.squeeze(1)
88 | dist_an = dist_an.squeeze(1)
89 |
90 | if return_inds:
91 | # shape [N, N]
92 | ind = (labels.new().resize_as_(labels)
93 | .copy_(torch.arange(0, N).long())
94 | .unsqueeze(0).expand(N, N))
95 | # shape [N, 1]
96 | p_inds = torch.gather(
97 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
98 | n_inds = torch.gather(
99 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
100 | # shape [N]
101 | p_inds = p_inds.squeeze(1)
102 | n_inds = n_inds.squeeze(1)
103 | return dist_ap, dist_an, p_inds, n_inds
104 |
105 | return dist_ap, dist_an
106 |
107 |
108 | class TripletLoss(object):
109 | """
110 | Triplet loss using HARDER example mining,
111 | modified based on original triplet loss using hard example mining
112 | """
113 |
114 | def __init__(self, margin=None, hard_factor=0.0):
115 | self.margin = margin
116 | self.hard_factor = hard_factor
117 | if margin is not None:
118 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
119 | else:
120 | self.ranking_loss = nn.SoftMarginLoss()
121 |
122 | def __call__(self, global_feat, labels, normalize_feature=False):
123 | if normalize_feature:
124 | global_feat = normalize(global_feat, axis=-1)
125 | dist_mat = euclidean_dist(global_feat, global_feat)
126 | dist_ap, dist_an = hard_example_mining(dist_mat, labels)
127 |
128 | dist_ap *= (1.0 + self.hard_factor)
129 | dist_an *= (1.0 - self.hard_factor)
130 |
131 | y = dist_an.new().resize_as_(dist_an).fill_(1)
132 | if self.margin is not None:
133 | loss = self.ranking_loss(dist_an, dist_ap, y)
134 | else:
135 | # min_mat = dist_an.new().resize_as_(dist_an).fill_(-85)
136 | # input = max(min_mat, dist_an - dist_ap)
137 | input = dist_an - dist_ap
138 | loss = self.ranking_loss(input, y)
139 | return loss, dist_ap, dist_an
140 |
141 |
142 |
--------------------------------------------------------------------------------
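A shape-level sketch of the hard-mining triplet loss above on an identity-balanced batch (the P identities x K images layout produced by the triplet sampler); with margin=None the soft-margin variant is used:

    import torch
    from loss.triplet_loss import TripletLoss

    feats = torch.randn(16, 768)                    # 4 identities x 4 images each
    labels = torch.arange(4).repeat_interleave(4)   # [0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3]

    criterion = TripletLoss(margin=0.3)             # MarginRankingLoss with a fixed margin
    loss, dist_ap, dist_an = criterion(feats, labels)

    soft_criterion = TripletLoss()                  # margin=None -> SoftMarginLoss ("soft" triplet)
    soft_loss, _, _ = soft_criterion(feats, labels)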
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_model import make_model
--------------------------------------------------------------------------------
/model/backbones/IBN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class IBN(nn.Module):
6 | r"""Instance-Batch Normalization layer from
7 | `"Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"
8 | `
9 |
10 | Args:
11 | planes (int): Number of channels for the input tensor
12 | ratio (float): Ratio of instance normalization in the IBN layer
13 | """
14 | def __init__(self, planes, ratio=0.5):
15 | super(IBN, self).__init__()
16 | self.half = int(planes * ratio)
17 | self.IN = nn.InstanceNorm2d(self.half, affine=True)
18 | self.BN = nn.BatchNorm2d(planes - self.half)
19 |
20 | def forward(self, x):
21 | split = torch.split(x, self.half, 1)
22 | out1 = self.IN(split[0].contiguous())
23 | out2 = self.BN(split[1].contiguous())
24 | out = torch.cat((out1, out2), 1)
25 | return out
26 |
27 |
28 | class SELayer(nn.Module):
29 | def __init__(self, channel, reduction=16):
30 | super(SELayer, self).__init__()
31 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
32 | self.fc = nn.Sequential(
33 | nn.Linear(channel, int(channel/reduction), bias=False),
34 | nn.ReLU(inplace=True),
35 | nn.Linear(int(channel/reduction), channel, bias=False),
36 | nn.Sigmoid()
37 | )
38 |
39 | def forward(self, x):
40 | b, c, _, _ = x.size()
41 | y = self.avg_pool(x).view(b, c)
42 | y = self.fc(y).view(b, c, 1, 1)
43 | return x * y.expand_as(x)
--------------------------------------------------------------------------------
/model/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .vit_pytorch import vit_base_patch16_224_TransReID, vit_small_patch16_224_TransReID, deit_small_patch16_224_TransReID
--------------------------------------------------------------------------------
/model/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | """3x3 convolution with padding"""
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | class BasicBlock(nn.Module):
14 | expansion = 1
15 |
16 | def __init__(self, inplanes, planes, stride=1, downsample=None):
17 | super(BasicBlock, self).__init__()
18 | self.conv1 = conv3x3(inplanes, planes, stride)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.relu = nn.ReLU(inplace=True)
21 | self.conv2 = conv3x3(planes, planes)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.downsample = downsample
24 | self.stride = stride
25 |
26 | def forward(self, x):
27 | residual = x
28 |
29 | out = self.conv1(x)
30 | out = self.bn1(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv2(out)
34 | out = self.bn2(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | out = self.relu(out)
41 |
42 | return out
43 |
44 |
45 | class Bottleneck(nn.Module):
46 | expansion = 4
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None):
49 | super(Bottleneck, self).__init__()
50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
53 | padding=1, bias=False)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
56 | self.bn3 = nn.BatchNorm2d(planes * 4)
57 | self.relu = nn.ReLU(inplace=True)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 | residual = x
63 |
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv3(out)
73 | out = self.bn3(out)
74 |
75 | if self.downsample is not None:
76 | residual = self.downsample(x)
77 |
78 | out += residual
79 | out = self.relu(out)
80 |
81 | return out
82 |
83 |
84 | class ResNet(nn.Module):
85 | def __init__(self, last_stride=2, block=Bottleneck,layers=[3, 4, 6, 3]):
86 | self.inplanes = 64
87 | super().__init__()
88 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
89 | bias=False)
90 | self.bn1 = nn.BatchNorm2d(64)
91 | # self.relu = nn.ReLU(inplace=True) # add missed relu
92 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0)
93 | self.layer1 = self._make_layer(block, 64, layers[0])
94 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
96 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)
97 |
98 | def _make_layer(self, block, planes, blocks, stride=1):
99 | downsample = None
100 | if stride != 1 or self.inplanes != planes * block.expansion:
101 | downsample = nn.Sequential(
102 | nn.Conv2d(self.inplanes, planes * block.expansion,
103 | kernel_size=1, stride=stride, bias=False),
104 | nn.BatchNorm2d(planes * block.expansion),
105 | )
106 |
107 | layers = []
108 | layers.append(block(self.inplanes, planes, stride, downsample))
109 | self.inplanes = planes * block.expansion
110 | for i in range(1, blocks):
111 | layers.append(block(self.inplanes, planes))
112 |
113 | return nn.Sequential(*layers)
114 |
115 | def forward(self, x, cam_label=None):
116 | x = self.conv1(x)
117 | x = self.bn1(x)
118 | # x = self.relu(x) # add missed relu
119 | x = self.maxpool(x)
120 | x = self.layer1(x)
121 | x = self.layer2(x)
122 | x = self.layer3(x)
123 | x = self.layer4(x)
124 |
125 | return x
126 |
127 | def load_param(self, model_path):
128 | param_dict = torch.load(model_path)
129 | for i in param_dict:
130 | if 'fc' in i:
131 | continue
132 | self.state_dict()[i].copy_(param_dict[i])
133 |
134 | def random_init(self):
135 | for m in self.modules():
136 | if isinstance(m, nn.Conv2d):
137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
138 | m.weight.data.normal_(0, math.sqrt(2. / n))
139 | elif isinstance(m, nn.BatchNorm2d):
140 | m.weight.data.fill_(1)
141 | m.bias.data.zero_()
142 |
143 | def compute_num_params(self):
144 | total = sum([param.nelement() for param in self.parameters()])
145 | print("Number of parameter: %.2fM" % (total/1e6))
--------------------------------------------------------------------------------
/processor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyuke65535/Part-Aware-Transformer/104d42e8292f7e5d534689ded15e4afafb453785/processor/__init__.py
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | # train PAT
2 | python train.py --config_file "config/PAT.yml"
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_scheduler import WarmupMultiStepLR
2 | from .make_optimizer import make_optimizer
--------------------------------------------------------------------------------
/solver/cosine_lr.py:
--------------------------------------------------------------------------------
1 | """ Cosine Scheduler
2 |
3 | Cosine LR schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import torch
10 |
11 | from .scheduler import Scheduler
12 |
13 |
14 | _logger = logging.getLogger(__name__)
15 |
16 |
17 | class CosineLRScheduler(Scheduler):
18 | """
19 | Cosine decay with restarts.
20 | This is described in the paper https://arxiv.org/abs/1608.03983.
21 |
22 | Inspiration from
23 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
24 | """
25 |
26 | def __init__(self,
27 | optimizer: torch.optim.Optimizer,
28 | t_initial: int,
29 | t_mul: float = 1.,
30 | lr_min: float = 0.,
31 | decay_rate: float = 1.,
32 | warmup_t=0,
33 | warmup_lr_init=0,
34 | warmup_prefix=False,
35 | cycle_limit=0,
36 | t_in_epochs=True,
37 | noise_range_t=None,
38 | noise_pct=0.67,
39 | noise_std=1.0,
40 | noise_seed=42,
41 | initialize=True) -> None:
42 | super().__init__(
43 | optimizer, param_group_field="lr",
44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
45 | initialize=initialize)
46 |
47 | assert t_initial > 0
48 | assert lr_min >= 0
49 | if t_initial == 1 and t_mul == 1 and decay_rate == 1:
50 |             _logger.warning("Cosine annealing scheduler will have no effect on the learning "
51 |                             "rate since t_initial = t_mul = decay_rate = 1.")
52 | self.t_initial = t_initial
53 | self.t_mul = t_mul
54 | self.lr_min = lr_min
55 | self.decay_rate = decay_rate
56 | self.cycle_limit = cycle_limit
57 | self.warmup_t = warmup_t
58 | self.warmup_lr_init = warmup_lr_init
59 | self.warmup_prefix = warmup_prefix
60 | self.t_in_epochs = t_in_epochs
61 | if self.warmup_t:
62 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
63 | super().update_groups(self.warmup_lr_init)
64 | else:
65 | self.warmup_steps = [1 for _ in self.base_values]
66 |
67 | def _get_lr(self, t):
68 | if t < self.warmup_t:
69 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
70 | else:
71 | if self.warmup_prefix:
72 | t = t - self.warmup_t
73 |
74 | if self.t_mul != 1:
75 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
76 | t_i = self.t_mul ** i * self.t_initial
77 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
78 | else:
79 | i = t // self.t_initial
80 | t_i = self.t_initial
81 | t_curr = t - (self.t_initial * i)
82 |
83 | gamma = self.decay_rate ** i
84 | lr_min = self.lr_min * gamma
85 | lr_max_values = [v * gamma for v in self.base_values]
86 |
87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
88 | lrs = [
89 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
90 | ]
91 | else:
92 | lrs = [self.lr_min for _ in self.base_values]
93 |
94 | return lrs
95 |
96 | def get_epoch_values(self, epoch: int):
97 | if self.t_in_epochs:
98 | return self._get_lr(epoch)
99 | else:
100 | return None
101 |
102 | def get_update_values(self, num_updates: int):
103 | if not self.t_in_epochs:
104 | return self._get_lr(num_updates)
105 | else:
106 | return None
107 |
108 | def get_cycle_length(self, cycles=0):
109 | if not cycles:
110 | cycles = self.cycle_limit
111 | cycles = max(1, cycles)
112 | if self.t_mul == 1.0:
113 | return self.t_initial * cycles
114 | else:
115 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
116 |
--------------------------------------------------------------------------------
/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 |
9 |
10 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
11 | # separating MultiStepLR with WarmupLR
12 | # but the current LRScheduler design doesn't allow it
13 |
14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
15 | def __init__(
16 | self,
17 | optimizer,
18 | milestones, # steps
19 | gamma=0.1,
20 | warmup_factor=1.0 / 3,
21 | warmup_iters=500,
22 | warmup_method="linear",
23 | last_epoch=-1,
24 | ):
25 | if not list(milestones) == sorted(milestones):
26 |             raise ValueError(
27 |                 "Milestones should be a list of increasing integers. Got {}".format(milestones)
28 |             )
30 |
31 | if warmup_method not in ("constant", "linear"):
32 | raise ValueError(
33 | "Only 'constant' or 'linear' warmup_method accepted"
34 | "got {}".format(warmup_method)
35 | )
36 | self.milestones = milestones
37 | self.gamma = gamma
38 | self.warmup_factor = warmup_factor
39 | self.warmup_iters = warmup_iters
40 | self.warmup_method = warmup_method
41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
42 |
43 |     def get_lr(self):
44 | warmup_factor = 1
45 | if self.last_epoch < self.warmup_iters:
46 | if self.warmup_method == "constant":
47 | warmup_factor = self.warmup_factor
48 | elif self.warmup_method == "linear":
49 | alpha = self.last_epoch / self.warmup_iters
50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
51 | return [
52 | base_lr
53 | * warmup_factor
54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
55 | for base_lr in self.base_lrs
56 | ]
57 |
--------------------------------------------------------------------------------
/solver/make_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def make_optimizer(cfg, model):
5 | params = []
6 | for key, value in model.named_parameters():
7 | if not value.requires_grad:
8 | continue
9 | lr = cfg.SOLVER.BASE_LR
10 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
11 | if "bias" in key:
12 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
14 | if cfg.SOLVER.LARGE_FC_LR:
15 | if "classifier" in key or "arcface" in key:
16 | lr = cfg.SOLVER.BASE_LR * 2
17 | print('Using two times learning rate for fc ')
18 |
19 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
20 |
21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
23 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW':
24 | optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
25 | else:
26 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
27 |
28 | return optimizer
29 |
--------------------------------------------------------------------------------
/solver/scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | import torch
4 |
5 |
6 | class Scheduler:
7 | """ Parameter Scheduler Base Class
8 | A scheduler base class that can be used to schedule any optimizer parameter groups.
9 |
10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called
11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
13 |
14 | The schedulers built on this should try to remain as stateless as possible (for simplicity).
15 |
16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training
18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call.
19 |
20 | Based on ideas from:
21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
23 | """
24 |
25 | def __init__(self,
26 | optimizer: torch.optim.Optimizer,
27 | param_group_field: str,
28 | noise_range_t=None,
29 | noise_type='normal',
30 | noise_pct=0.67,
31 | noise_std=1.0,
32 | noise_seed=None,
33 | initialize: bool = True) -> None:
34 | self.optimizer = optimizer
35 | self.param_group_field = param_group_field
36 | self._initial_param_group_field = f"initial_{param_group_field}"
37 | if initialize:
38 | for i, group in enumerate(self.optimizer.param_groups):
39 | if param_group_field not in group:
40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
41 | group.setdefault(self._initial_param_group_field, group[param_group_field])
42 | else:
43 | for i, group in enumerate(self.optimizer.param_groups):
44 | if self._initial_param_group_field not in group:
45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
47 | self.metric = None # any point to having this for all?
48 | self.noise_range_t = noise_range_t
49 | self.noise_pct = noise_pct
50 | self.noise_type = noise_type
51 | self.noise_std = noise_std
52 | self.noise_seed = noise_seed if noise_seed is not None else 42
53 | self.update_groups(self.base_values)
54 |
55 | def state_dict(self) -> Dict[str, Any]:
56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
57 |
58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
59 | self.__dict__.update(state_dict)
60 |
61 | def get_epoch_values(self, epoch: int):
62 | return None
63 |
64 | def get_update_values(self, num_updates: int):
65 | return None
66 |
67 | def step(self, epoch: int, metric: float = None) -> None:
68 | self.metric = metric
69 | values = self.get_epoch_values(epoch)
70 | if values is not None:
71 | values = self._add_noise(values, epoch)
72 | self.update_groups(values)
73 |
74 | def step_update(self, num_updates: int, metric: float = None):
75 | self.metric = metric
76 | values = self.get_update_values(num_updates)
77 | if values is not None:
78 | values = self._add_noise(values, num_updates)
79 | self.update_groups(values)
80 |
81 | def update_groups(self, values):
82 | if not isinstance(values, (list, tuple)):
83 | values = [values] * len(self.optimizer.param_groups)
84 | for param_group, value in zip(self.optimizer.param_groups, values):
85 | param_group[self.param_group_field] = value
86 |
87 | def _add_noise(self, lrs, t):
88 | if self.noise_range_t is not None:
89 | if isinstance(self.noise_range_t, (list, tuple)):
90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
91 | else:
92 | apply_noise = t >= self.noise_range_t
93 | if apply_noise:
94 | g = torch.Generator()
95 | g.manual_seed(self.noise_seed + t)
96 | if self.noise_type == 'normal':
97 | while True:
98 | # resample if noise out of percent limit, brute force but shouldn't spin much
99 | noise = torch.randn(1, generator=g).item()
100 | if abs(noise) < self.noise_pct:
101 | break
102 | else:
103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
104 | lrs = [v + v * noise for v in lrs]
105 | return lrs
106 |
--------------------------------------------------------------------------------
/solver/scheduler_factory.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from .cosine_lr import CosineLRScheduler
5 |
6 |
7 | def create_scheduler(cfg, optimizer):
8 | num_epochs = 120
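    # NOTE: the schedule length is hard-coded here rather than read from cfg.SOLVER.MAX_EPOCHS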
9 | # type 1
10 | # lr_min = 0.01 * cfg.SOLVER.BASE_LR
11 | # warmup_lr_init = 0.001 * cfg.SOLVER.BASE_LR
12 | # type 2
13 | lr_min = 0.002 * cfg.SOLVER.BASE_LR
14 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR
15 | # type 3
16 | # lr_min = 0.001 * cfg.SOLVER.BASE_LR
17 | # warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR
18 |
19 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS
20 | noise_range = None
21 |
22 | lr_scheduler = CosineLRScheduler(
23 | optimizer,
24 | t_initial=num_epochs,
25 | lr_min=lr_min,
26 | t_mul= 1.,
27 | decay_rate=0.1,
28 | warmup_lr_init=warmup_lr_init,
29 | warmup_t=warmup_t,
30 | cycle_limit=1,
31 | t_in_epochs=True,
32 | noise_range_t=noise_range,
33 | noise_pct= 0.67,
34 | noise_std= 1.,
35 | noise_seed=42,
36 | )
37 |
38 | return lr_scheduler
39 |
--------------------------------------------------------------------------------
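A minimal sketch of pairing the factory above with the project's optimizer factory and stepping it once per epoch, as the epoch-based (t_in_epochs=True) setting expects; the make_model call mirrors the one in test.py and the loop body is elided:

    from config import cfg
    from model import make_model
    from solver import make_optimizer
    from solver.scheduler_factory import create_scheduler

    model = make_model(cfg, cfg.MODEL.NAME, 0, 0, 0)
    optimizer = make_optimizer(cfg, model)
    scheduler = create_scheduler(cfg, optimizer)

    for epoch in range(1, cfg.SOLVER.MAX_EPOCHS + 1):
        # ... run one training epoch with `optimizer` ...
        scheduler.step(epoch)   # warmup for WARMUP_EPOCHS, then cosine decay towards lr_min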
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from config import cfg
3 | import argparse
4 | from data.build_DG_dataloader import build_reid_test_loader
5 | from model import make_model
6 | from processor.part_attention_vit_processor import do_inference as do_inf_pat
7 | from processor.ori_vit_processor_with_amp import do_inference as do_inf
8 | from utils.logger import setup_logger
9 |
10 |
11 | if __name__ == "__main__":
12 | parser = argparse.ArgumentParser(description="ReID Training")
13 | parser.add_argument(
14 | "--config_file", default="./config/PAT.yml", help="path to config file", type=str
15 | )
16 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
17 | nargs=argparse.REMAINDER)
18 |
19 | args = parser.parse_args()
20 |
21 |
22 |
23 | if args.config_file != "":
24 | cfg.merge_from_file(args.config_file)
25 | cfg.merge_from_list(args.opts)
26 | cfg.freeze()
27 |
28 | output_dir = os.path.join(cfg.LOG_ROOT, cfg.LOG_NAME)
29 | if output_dir and not os.path.exists(output_dir):
30 | os.makedirs(output_dir)
31 |
32 | logger = setup_logger("PAT", output_dir, if_train=False)
33 | logger.info(args)
34 |
35 | if args.config_file != "":
36 | logger.info("Loaded configuration file {}".format(args.config_file))
37 | with open(args.config_file, 'r') as cf:
38 | config_str = "\n" + cf.read()
39 | logger.info(config_str)
40 | logger.info("Running with config:\n{}".format(cfg))
41 |
42 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
43 |
44 | model = make_model(cfg, cfg.MODEL.NAME, 0,0,0)
45 | model.load_param(cfg.TEST.WEIGHT)
46 |
47 | for testname in cfg.DATASETS.TEST:
48 | val_loader, num_query = build_reid_test_loader(cfg, testname)
49 | if cfg.MODEL.NAME == 'part_attention_vit':
50 | do_inf_pat(cfg, model, val_loader, num_query)
51 | else:
52 | do_inf(cfg, model, val_loader, num_query)
53 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from processor.part_attention_vit_processor import part_attention_vit_do_train_with_amp
2 | from processor.ori_vit_processor_with_amp import ori_vit_do_train_with_amp
3 | from utils.logger import setup_logger
4 | from data.build_DG_dataloader import build_reid_train_loader, build_reid_test_loader
5 | from model import make_model
6 | from solver import make_optimizer
7 | from solver.scheduler_factory import create_scheduler
8 | from loss.build_loss import build_loss
9 | import random
10 | import torch
11 | import numpy as np
12 | import os
13 | import argparse
14 | from config import cfg
15 | import loss as Patchloss
16 |
17 | def set_seed(seed):
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | np.random.seed(seed)
22 | random.seed(seed)
23 | torch.backends.cudnn.deterministic = True
24 |     torch.backends.cudnn.benchmark = True  # note: benchmark=True trades strict reproducibility for speed
25 |
26 | if __name__ == '__main__':
27 | parser = argparse.ArgumentParser(description="ReID Training")
28 | parser.add_argument(
29 | "--config_file", default="./config/PAT.yml", help="path to config file", type=str
30 | )
31 |
32 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
33 | nargs=argparse.REMAINDER)
34 | parser.add_argument("--local_rank", default=0, type=int)
35 | args = parser.parse_args()
36 |
37 | if args.config_file != "":
38 | cfg.merge_from_file(args.config_file)
39 | cfg.merge_from_list(args.opts)
40 | cfg.freeze()
41 |
42 | set_seed(cfg.SOLVER.SEED)
43 |
44 | if cfg.MODEL.DIST_TRAIN:
45 | torch.cuda.set_device(args.local_rank)
46 |
47 | output_dir = os.path.join(cfg.LOG_ROOT, cfg.LOG_NAME)
48 | if output_dir and not os.path.exists(output_dir):
49 | os.makedirs(output_dir)
50 |
51 | logger = setup_logger("PAT", output_dir, if_train=True)
52 | logger.info("Saving model in the path :{}".format(output_dir))
53 | logger.info(args)
54 |
55 | if args.config_file != "":
56 | logger.info("Loaded configuration file {}".format(args.config_file))
57 | with open(args.config_file, 'r') as cf:
58 | config_str = "\n" + cf.read()
59 | logger.info(config_str)
60 | logger.info("Running with config:\n{}".format(cfg))
61 |
62 | if cfg.MODEL.DIST_TRAIN:
63 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
64 |
65 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
66 |
67 | # build DG train loader
68 | train_loader = build_reid_train_loader(cfg)
69 | # build DG validate loader
70 | val_name = cfg.DATASETS.TEST[0]
71 | val_loader, num_query = build_reid_test_loader(cfg, val_name)
72 | num_classes = len(train_loader.dataset.pids)
73 | model_name = cfg.MODEL.NAME
74 | model = make_model(cfg, modelname=model_name, num_class=num_classes, camera_num=None, view_num=None)
75 | if cfg.MODEL.FREEZE_PATCH_EMBED and 'resnet' not in cfg.MODEL.NAME: # trick from moco v3
76 | model.base.patch_embed.proj.weight.requires_grad = False
77 | model.base.patch_embed.proj.bias.requires_grad = False
78 | print("====== freeze patch_embed for stability ======")
79 |
80 | loss_func, center_cri = build_loss(cfg, num_classes=num_classes)
81 |
82 | optimizer = make_optimizer(cfg, model)
83 | scheduler = create_scheduler(cfg, optimizer)
84 |
85 | ################## patch loss ####################
86 | patch_centers = Patchloss.PatchMemory(momentum=0.1, num=1)
87 | pc_criterion = Patchloss.Pedal(scale=cfg.MODEL.PC_SCALE, k=cfg.MODEL.CLUSTER_K).cuda()
88 | if cfg.MODEL.SOFT_LABEL and cfg.MODEL.NAME == 'part_attention_vit':
89 | print("========using soft label========")
90 | ################## patch loss ####################
91 |
92 | do_train_dict = {
93 | 'part_attention_vit': part_attention_vit_do_train_with_amp
94 | }
95 | if model_name not in do_train_dict.keys():
96 | ori_vit_do_train_with_amp(
97 | cfg,
98 | model,
99 | train_loader,
100 | val_loader,
101 | optimizer,
102 | scheduler,
103 | loss_func,
104 | num_query, args.local_rank,
105 | )
106 | else :
107 | do_train_dict[model_name](
108 | cfg,
109 | model,
110 | train_loader,
111 | val_loader,
112 | optimizer,
113 | scheduler,
114 | loss_func,
115 | num_query, args.local_rank,
116 | patch_centers = patch_centers,
117 | pc_criterion = pc_criterion
118 | )
119 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyuke65535/Part-Aware-Transformer/104d42e8292f7e5d534689ded15e4afafb453785/utils/__init__.py
--------------------------------------------------------------------------------
/utils/iotools.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import json
3 | import os
4 |
5 | import os.path as osp
6 |
7 |
8 | def mkdir_if_missing(directory):
9 | if not osp.exists(directory):
10 | try:
11 | os.makedirs(directory)
12 | except OSError as e:
13 | if e.errno != errno.EEXIST:
14 | raise
15 |
16 |
17 | def check_isfile(path):
18 | isfile = osp.isfile(path)
19 | if not isfile:
20 | print("=> Warning: no file found at '{}' (ignored)".format(path))
21 | return isfile
22 |
23 |
24 | def read_json(fpath):
25 | with open(fpath, 'r') as f:
26 | obj = json.load(f)
27 | return obj
28 |
29 |
30 | def write_json(obj, fpath):
31 | mkdir_if_missing(osp.dirname(fpath))
32 | with open(fpath, 'w') as f:
33 | json.dump(obj, f, indent=4, separators=(',', ': '))
34 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import os.path as osp
5 | def setup_logger(name, save_dir, if_train):
6 | logger = logging.getLogger(name)
7 | logger.setLevel(logging.DEBUG)
8 |
9 | ch = logging.StreamHandler(stream=sys.stdout)
10 | ch.setLevel(logging.DEBUG)
11 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
12 | ch.setFormatter(formatter)
13 | logger.addHandler(ch)
14 |
15 | if save_dir:
16 | if not osp.exists(save_dir):
17 | os.makedirs(save_dir)
18 | if if_train:
19 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w')
20 | else:
21 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='w')
22 | fh.setLevel(logging.DEBUG)
23 | fh.setFormatter(formatter)
24 | logger.addHandler(fh)
25 |
26 | return logger
--------------------------------------------------------------------------------
/utils/meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """Computes and stores the average and current value"""
3 |
4 | def __init__(self):
5 | self.val = 0
6 | self.avg = 0
7 | self.sum = 0
8 | self.count = 0
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from re import T
2 | from time import time
3 | import torch
4 | import numpy as np
5 | import os
6 | from utils.reranking import re_ranking
7 |
8 |
9 | def euclidean_distance(qf, gf):
10 | m = qf.shape[0]
11 | n = gf.shape[0]
12 | dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
13 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
14 | dist_mat.addmm_(qf, gf.t(), beta=1, alpha=-2)
15 | return dist_mat.cpu().numpy()
16 |
17 | def cosine_similarity(qf, gf):
18 | epsilon = 0.00001
19 | dist_mat = qf.mm(gf.t())
20 | qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True) # mx1
21 | gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True) # nx1
22 | qg_normdot = qf_norm.mm(gf_norm.t())
23 |
24 | dist_mat = dist_mat.mul(1 / qg_normdot).cpu().numpy()
25 | dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon)
26 | dist_mat = np.arccos(dist_mat)
27 | return dist_mat
28 |
29 |
30 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
31 | """Evaluation with market1501 metric
32 | Key: for each query identity, its gallery images from the same camera view are discarded.
33 | """
34 | num_q, num_g = distmat.shape
35 | # distmat g
36 | # q 1 3 2 4
37 | # 4 1 2 3
38 | if num_g < max_rank:
39 | max_rank = num_g
40 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
41 | indices = np.argsort(distmat, axis=1)
42 | # 0 2 1 3
43 | # 1 2 3 0
44 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
45 | # compute cmc curve for each query
46 | all_cmc = []
47 | all_AP = []
48 | num_valid_q = 0. # number of valid query
49 | for q_idx in range(num_q):
50 | # get query pid and camid
51 | q_pid = q_pids[q_idx]
52 | q_camid = q_camids[q_idx]
53 |
54 | # remove gallery samples that have the same pid and camid with query
55 | order = indices[q_idx] # select one row
56 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
57 | keep = np.invert(remove)
58 |
59 | # compute cmc curve
60 | # binary vector, positions with value 1 are correct matches
61 | orig_cmc = matches[q_idx][keep]
62 | if not np.any(orig_cmc):
63 | # this condition is true when query identity does not appear in gallery
64 | continue
65 |
66 | cmc = orig_cmc.cumsum()
67 | cmc[cmc > 1] = 1
68 |
69 | all_cmc.append(cmc[:max_rank])
70 | num_valid_q += 1.
71 |
72 | # compute average precision
73 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
74 | num_rel = orig_cmc.sum()
75 | tmp_cmc = orig_cmc.cumsum()
76 | #tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
77 | y = np.arange(1, tmp_cmc.shape[0] + 1) * 1.0
78 | tmp_cmc = tmp_cmc / y
79 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
80 | AP = tmp_cmc.sum() / num_rel
81 | all_AP.append(AP)
82 |
83 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
84 |
85 | all_cmc = np.asarray(all_cmc).astype(np.float32)
86 | all_cmc = all_cmc.sum(0) / num_valid_q
87 | mAP = np.mean(all_AP)
88 |
89 | return all_cmc, mAP
90 |
91 |
92 | class R1_mAP_eval():
93 | def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False):
94 | super(R1_mAP_eval, self).__init__()
95 | self.num_query = num_query
96 | self.max_rank = max_rank
97 | self.feat_norm = feat_norm
98 | self.reranking = reranking
99 |
100 | def reset(self):
101 | self.feats = []
102 | self.pids = []
103 | self.camids = []
104 |
105 | def update(self, output): # called once for each batch
106 | feat, pid, camid = output
107 | self.feats.append(feat.cpu())
108 | self.pids.extend(np.asarray(pid))
109 | self.camids.extend(np.asarray(camid))
110 |
111 | def compute(self): # called after each epoch
112 | feats = torch.cat(self.feats, dim=0)
113 | if self.feat_norm:
114 | # print("The test feature is normalized")
115 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) # along channel
116 | # query
117 | qf = feats[:self.num_query]
118 | q_pids = np.asarray(self.pids[:self.num_query])
119 | q_camids = np.asarray(self.camids[:self.num_query])
120 | # gallery
121 | gf = feats[self.num_query:]
122 | g_pids = np.asarray(self.pids[self.num_query:])
123 |
124 | g_camids = np.asarray(self.camids[self.num_query:])
125 | if self.reranking:
126 | print('=> Enter reranking')
127 | # distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
128 | distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3)
129 |
130 | else:
131 | # print('=> Computing DistMat with euclidean_distance')
132 | distmat = euclidean_distance(qf, gf)
133 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
134 |
135 | return cmc, mAP, distmat, self.pids, self.camids, qf, gf
136 |
137 |
138 |
139 |
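A minimal sketch of how the evaluator is driven (in the repo the loop lives in the processor files; the random features, pids and camids are stand-ins, and the script assumes it is run from the repo root):

```
import numpy as np
import torch
from utils.metrics import R1_mAP_eval

num_query = 10
evaluator = R1_mAP_eval(num_query=num_query, max_rank=50, feat_norm=True)
evaluator.reset()

# pretend batches: the first num_query samples are queries, the rest gallery
for _ in range(4):
    feat = torch.randn(16, 768)
    pid = np.random.randint(0, 8, size=16)
    camid = np.random.randint(0, 6, size=16)
    evaluator.update((feat, pid, camid))

cmc, mAP, distmat, pids, camids, qf, gf = evaluator.compute()
print("Rank-1: {:.1%}  Rank-5: {:.1%}  mAP: {:.1%}".format(cmc[0], cmc[4], mAP))
```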
--------------------------------------------------------------------------------
/utils/registry.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 |
4 | class Registry(object):
5 | """
6 | The registry that provides name -> object mapping, to support third-party
7 | users' custom modules.
8 | To create a registry (e.g. a backbone registry):
9 | .. code-block:: python
10 | BACKBONE_REGISTRY = Registry('BACKBONE')
11 | To register an object:
12 | .. code-block:: python
13 | @BACKBONE_REGISTRY.register()
14 | class MyBackbone():
15 | ...
16 | Or:
17 | .. code-block:: python
18 | BACKBONE_REGISTRY.register(MyBackbone)
19 | """
20 |
21 | def __init__(self, name: str) -> None:
22 | """
23 | Args:
24 | name (str): the name of this registry
25 | """
26 | self._name: str = name
27 | self._obj_map: Dict[str, object] = {}
28 |
29 | def _do_register(self, name: str, obj: object) -> None:
30 | assert (
31 | name not in self._obj_map
32 | ), "An object named '{}' was already registered in '{}' registry!".format(
33 | name, self._name
34 | )
35 | self._obj_map[name] = obj
36 |
37 | def register(self, obj: object = None) -> Optional[object]:
38 | """
39 |         Register the given object under the name `obj.__name__`.
40 | Can be used as either a decorator or not. See docstring of this class for usage.
41 | """
42 | if obj is None:
43 | # used as a decorator
44 | def deco(func_or_class: object) -> object:
45 | name = func_or_class.__name__ # pyre-ignore
46 | self._do_register(name, func_or_class)
47 | return func_or_class
48 |
49 | return deco
50 |
51 | # used as a function call
52 | name = obj.__name__ # pyre-ignore
53 | self._do_register(name, obj)
54 |
55 | def get(self, name: str) -> object:
56 | ret = self._obj_map.get(name)
57 | if ret is None:
58 | raise KeyError(
59 | "No object named '{}' found in '{}' registry!".format(
60 | name, self._name
61 | )
62 | )
63 | return ret
64 |
--------------------------------------------------------------------------------
/utils/reranking.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Fri, 25 May 2018 20:29:09
3 |
4 |
5 | """
6 |
7 | """
8 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
9 | url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
10 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
11 | """
12 |
13 | """
14 | API
15 |
16 | probFea: all feature vectors of the query set (torch tensor)
17 | galFea: all feature vectors of the gallery set (torch tensor)
18 | k1, k2, lambda_value: parameters; the original paper uses (k1=20, k2=6, lambda=0.3)
19 | local_distmat: optional precomputed local distance matrix, added to the original distance
20 | only_local: if True, use local_distmat alone as the original distance
21 | """
22 |
23 | import numpy as np
24 | import torch
25 |
26 |
27 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
28 | # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor
29 | query_num = probFea.size(0)
30 | all_num = query_num + galFea.size(0)
31 | if only_local:
32 | original_dist = local_distmat
33 | else:
34 | feat = torch.cat([probFea, galFea])
35 | # print('using GPU to compute original distance')
36 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \
37 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
38 |         distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)
39 | original_dist = distmat.cpu().numpy()
40 | del feat
41 | if not local_distmat is None:
42 | original_dist = original_dist + local_distmat
43 | gallery_num = original_dist.shape[0]
44 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
45 | V = np.zeros_like(original_dist).astype(np.float16)
46 | initial_rank = np.argsort(original_dist).astype(np.int32)
47 |
48 | # print('starting re_ranking')
49 | for i in range(all_num):
50 | # k-reciprocal neighbors
51 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
52 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
53 | fi = np.where(backward_k_neigh_index == i)[0]
54 | k_reciprocal_index = forward_k_neigh_index[fi]
55 | k_reciprocal_expansion_index = k_reciprocal_index
56 | for j in range(len(k_reciprocal_index)):
57 | candidate = k_reciprocal_index[j]
58 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
59 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
60 | :int(np.around(k1 / 2)) + 1]
61 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
62 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
63 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
64 | candidate_k_reciprocal_index):
65 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
66 |
67 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
68 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
69 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
70 | original_dist = original_dist[:query_num, ]
71 | if k2 != 1:
72 | V_qe = np.zeros_like(V, dtype=np.float16)
73 | for i in range(all_num):
74 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
75 | V = V_qe
76 | del V_qe
77 | del initial_rank
78 | invIndex = []
79 | for i in range(gallery_num):
80 | invIndex.append(np.where(V[:, i] != 0)[0])
81 |
82 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
83 |
84 | for i in range(query_num):
85 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
86 | indNonZero = np.where(V[i, :] != 0)[0]
87 | indImages = [invIndex[ind] for ind in indNonZero]
88 | for j in range(len(indNonZero)):
89 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
90 | V[indImages[j], indNonZero[j]])
91 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
92 |
93 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
94 | del original_dist
95 | del V
96 | del jaccard_dist
97 | final_dist = final_dist[:query_num, query_num:]
98 | return final_dist
99 |
100 |
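A small usage sketch on random features (in the repo `re_ranking` is called from `R1_mAP_eval.compute` when re-ranking is enabled; the feature sizes below are stand-ins, and the script assumes it is run from the repo root):

```
import torch
from utils.reranking import re_ranking

qf = torch.randn(10, 768)   # query features
gf = torch.randn(100, 768)  # gallery features

# parameters from the original paper; metrics.py uses k1=50, k2=15 instead
dist = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
print(dist.shape)  # (10, 100); smaller values mean more similar query/gallery pairs
```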
--------------------------------------------------------------------------------
/visualization/config_vis/__init__.py:
--------------------------------------------------------------------------------
1 | from .vit_b import _C as cfg
--------------------------------------------------------------------------------
/visualization/good_samples_market_query.json:
--------------------------------------------------------------------------------
1 | [
2 |
3 | "1255_c6s3_045442_00.jpg",
4 | "0778_c1s4_018881_00.jpg",
5 | "0600_c1s3_029851_00.jpg",
6 | "0120_c3s1_020126_00.jpg",
7 | "0568_c1s3_019626_00.jpg",
8 | "0964_c5s2_122274_00.jpg",
9 |
10 | "0535_c4s3_005423_00.jpg",
11 | "0505_c4s3_009248_00.jpg",
12 | "1459_c1s6_009666_00.jpg",
13 | "0294_c6s1_066676_00.jpg",
14 |
15 | "1089_c4s6_039016_00.jpg",
16 | "0801_c6s2_088243_00.jpg",
17 | "1183_c6s3_030367_00.jpg",
18 | "0934_c4s4_061216_00.jpg",
19 | "0355_c3s1_081467_00.jpg",
20 | "0618_c6s2_014593_00.jpg",
21 | "0678_c3s2_048787_00.jpg",
22 | "0174_c5s1_053251_00.jpg",
23 | "0911_c3s2_113153_00.jpg",
24 | "1277_c1s5_052541_00.jpg",
25 | "0005_c6s1_004576_00.jpg",
26 | "1146_c2s2_158802_00.jpg",
27 | "0388_c1s2_018716_00.jpg",
28 | "0418_c1s2_027716_00.jpg",
29 | "0538_c2s1_152691_00.jpg",
30 | "0609_c1s3_032151_00.jpg",
31 | "0231_c4s1_047501_00.jpg",
32 | "1195_c6s3_032367_00.jpg",
33 | "0103_c3s1_016876_00.jpg"
34 | ]
--------------------------------------------------------------------------------
/visualization/readme.md:
--------------------------------------------------------------------------------
1 | # Attention Rollout
2 |
3 | We updated the visualization code based on https://github.com/jacobgil/vit-explain. See our examples in visualization/test.jpg.
4 |
5 | ## How to run?
6 | Follow the instructions below.
7 | ```
8 | cd visualization
9 |
10 | python vit_explain.py --save_path xxx --data_path xxx --vit_path xxx --pat_path xxx --pretrain_path xxx
11 | ```
12 |
13 | For more details, please check visualization/vit_explain.py.
14 |
15 | ## No ideal results?
16 |
17 | You can modify the options (head_fusion and discard_ratio) at line 67 of visualization/vit_explain.py; see the sketch below.
18 |
19 | Moreover, learn about attention fusion in visualization/vit_rollout/vit_rollout.py.
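A sketch of that call (line 67 of vit_explain.py) with the knobs to tweak; the allowed values come from the rollout() function in vit_rollout.py, and `model` / `input_tensor` are assumed to be prepared as in vit_explain.py:

```
from vit_rollout.vit_rollout import VITAttentionRollout

# head_fusion can be 'mean', 'max' or 'min'; discard_ratio in [0, 1) drops the
# lowest attentions before the map is built (0.9 is the class default)
attention_rollout = VITAttentionRollout(model, head_fusion='mean', discard_ratio=0.5)
masks = attention_rollout(input_tensor)
```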
20 |
--------------------------------------------------------------------------------
/visualization/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyuke65535/Part-Aware-Transformer/104d42e8292f7e5d534689ded15e4afafb453785/visualization/test.jpg
--------------------------------------------------------------------------------
/visualization/vit_explain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from PIL import Image
4 | from torchvision import transforms
5 | import numpy as np
6 | import cv2
7 |
8 | from config_vis import cfg
9 |
10 | from vit_rollout.vit_rollout import VITAttentionRollout
11 | import os
12 | import sys
13 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
14 | from model import make_model
15 |
16 | def show_mask_on_image(img, mask):
17 | img = np.float32(img) / 255
18 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
19 | heatmap = np.float32(heatmap) / 255
20 | cam = heatmap + np.float32(img)
21 | cam = cam / np.max(cam)
22 | return np.uint8(255 * cam)
23 |
24 | def main(args):
25 | os.environ['CUDA_VISIBLE_DEVICES'] = '5'
26 | cfg.MODEL.PRETRAIN_PATH = args.pretrain_path
27 |
28 | # load part_attention_vit
29 | model_ours = make_model(cfg, 'part_attention_vit', num_class=1)
30 | model_ours.load_param(args.pat_path)
31 | model_ours.eval()
32 | model_ours.to('cuda')
33 |
34 | # load vanilla vit
35 | model_vit = make_model(cfg, 'vit', num_class=1)
36 | model_vit.load_param(args.vit_path)
37 | model_vit.eval()
38 | model_vit.to('cuda')
39 |
40 | transform = transforms.Compose([
41 | transforms.Resize((256,128)),
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
44 | ])
45 |
46 | input_tensor = []
47 |
48 | # Prepare the original person photos
49 | base_dir = args.data_path
50 | img_path = os.listdir(base_dir)
51 | random.shuffle(img_path)
52 |
53 | length = min(30, len(img_path)) # how many photos to visualize
54 | img_list = []
55 | for pth in img_path[:length]:
56 | img = Image.open(base_dir+pth)
57 | img = img.resize((128,256))
58 |         np_img = np.array(img)[:, :, ::-1] # RGB (PIL) -> BGR for cv2
59 | input_tensor = transform(img).unsqueeze(0)
60 | input_tensor = input_tensor.cuda()
61 | img_list.append(np_img)
62 |
63 | local_flag = False
64 |
65 | # attention rollout
66 | for model in [model_ours]:
67 | attention_rollout = VITAttentionRollout(model, head_fusion='mean', discard_ratio=0.5) # modify head_fusion type and discard_ratio for better outputs
68 | masks = attention_rollout(input_tensor)
69 |
70 | if isinstance(masks, list):
71 | for msk in masks:
72 | msk = cv2.resize(msk, (np_img.shape[1], np_img.shape[0]))
73 | img_list.append(show_mask_on_image(np_img, msk))
74 | local_flag = True
75 | else:
76 | masks = cv2.resize(masks, (np_img.shape[1], np_img.shape[0]))
77 | out_img = show_mask_on_image(np_img, masks)
78 | img_list.append(out_img)
79 |
80 |
81 | final_img = []
82 | line_len = 5 if local_flag else 3
83 |
84 |     # concatenate output images into rows, then stack the rows vertically
85 | for i in range(0, len(img_list)-1, line_len):
86 | if i==0:
87 | img_line = [img_list[l] for l in range(line_len)]
88 | final_img = np.concatenate(img_line,axis=1)
89 | else:
90 | img_line = [img_list[i+l] for l in range(line_len)]
91 | x = np.concatenate(img_line,axis=1)
92 | final_img = np.concatenate([final_img,x],axis=0)
93 |
94 | cv2.imwrite(args.save_path, final_img)
95 | for i, pth in enumerate(img_path[:30]):
96 | print(i+1, pth)
97 | print(f"save to {args.save_path}")
98 |
99 | if __name__ == '__main__':
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument("--save_path", type=str, help="path to save your attention visualized photo. E.g., /home/me/out.jpg")
102 | parser.add_argument("--data_path", type=str, help="path to your dataset. E.g., dataset/market1501/query")
103 | parser.add_argument("--pretrain_path", type=str, help="path to your pretrained vit from imagenet or else. E.g., /home/me/cpt/")
104 | parser.add_argument("--vit_path", type=str, help="path to your trained vanilla vit. E.g., cpt/vit.pth")
105 | parser.add_argument("--pat_path", type=str, help="path to your trained PAT. E.g., cpt/pat.pth")
106 | args = parser.parse_args()
107 | main(args)
--------------------------------------------------------------------------------
/visualization/vit_rollout/vit_example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import torch
5 |
6 | from pytorch_grad_cam import GradCAM, \
7 | ScoreCAM, \
8 | GradCAMPlusPlus, \
9 | AblationCAM, \
10 | XGradCAM, \
11 | EigenCAM, \
12 | EigenGradCAM, \
13 | LayerCAM, \
14 | FullGrad
15 |
16 | from pytorch_grad_cam import GuidedBackpropReLUModel
17 | from pytorch_grad_cam.utils.image import show_cam_on_image, \
18 | preprocess_image
19 | from pytorch_grad_cam.ablation_layer import AblationLayerVit
20 |
21 | def get_args():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--use-cuda', action='store_true', default=False,
24 | help='Use NVIDIA GPU acceleration')
25 | parser.add_argument(
26 | '--image-path',
27 | type=str,
28 | default='./examples/both.png',
29 | help='Input image path')
30 | parser.add_argument('--aug_smooth', action='store_true',
31 | help='Apply test time augmentation to smooth the CAM')
32 | parser.add_argument(
33 | '--eigen_smooth',
34 | action='store_true',
35 |         help='Reduce noise by taking the first principal component '
36 | 'of cam_weights*activations')
37 |
38 | parser.add_argument(
39 | '--method',
40 | type=str,
41 | default='gradcam',
42 | help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam')
43 |
44 | args = parser.parse_args()
45 | args.use_cuda = args.use_cuda and torch.cuda.is_available()
46 | if args.use_cuda:
47 | print('Using GPU for acceleration')
48 | else:
49 | print('Using CPU for computation')
50 |
51 | return args
52 |
53 |
54 | def reshape_transform(tensor, height=14, width=14):
55 | result = tensor[:, 1:, :].reshape(tensor.size(0),
56 | height, width, tensor.size(2))
57 |
58 | # Bring the channels to the first dimension,
59 | # like in CNNs.
60 | result = result.transpose(2, 3).transpose(1, 2)
61 | return result
62 |
63 |
64 | if __name__ == '__main__':
65 |     """ python vit_example.py --image-path <path>
66 |     Example of using CAM methods on a ViT network.
67 |
68 | """
69 |
70 | args = get_args()
71 | methods = \
72 | {"gradcam": GradCAM,
73 | "scorecam": ScoreCAM,
74 | "gradcam++": GradCAMPlusPlus,
75 | "ablationcam": AblationCAM,
76 | "xgradcam": XGradCAM,
77 | "eigencam": EigenCAM,
78 | "eigengradcam": EigenGradCAM,
79 | "layercam": LayerCAM,
80 | "fullgrad": FullGrad}
81 |
82 | if args.method not in list(methods.keys()):
83 | raise Exception(f"method should be one of {list(methods.keys())}")
84 |
85 | model = torch.hub.load('facebookresearch/deit:main',
86 | 'deit_tiny_patch16_224', pretrained=True)
87 | model.eval()
88 |
89 | if args.use_cuda:
90 | model = model.cuda()
91 |
92 | target_layers = [model.blocks[-1].norm1]
93 |
94 | if args.method not in methods:
95 | raise Exception(f"Method {args.method} not implemented")
96 |
97 | if args.method == "ablationcam":
98 | cam = methods[args.method](model=model,
99 | target_layers=target_layers,
100 | use_cuda=args.use_cuda,
101 | reshape_transform=reshape_transform,
102 | ablation_layer=AblationLayerVit())
103 | else:
104 | cam = methods[args.method](model=model,
105 | target_layers=target_layers,
106 | use_cuda=args.use_cuda,
107 | reshape_transform=reshape_transform)
108 |
109 |
110 | rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
111 | rgb_img = cv2.resize(rgb_img, (224, 224))
112 | rgb_img = np.float32(rgb_img) / 255
113 | input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5],
114 | std=[0.5, 0.5, 0.5])
115 |
116 | # If None, returns the map for the highest scoring category.
117 | # Otherwise, targets the requested category.
118 | targets = None
119 |
120 | # AblationCAM and ScoreCAM have batched implementations.
121 | # You can override the internal batch size for faster computation.
122 | cam.batch_size = 32
123 |
124 | grayscale_cam = cam(input_tensor=input_tensor,
125 | targets=targets ,
126 | eigen_smooth=args.eigen_smooth,
127 | aug_smooth=args.aug_smooth)
128 |
129 | # Here grayscale_cam has only one image in the batch
130 | grayscale_cam = grayscale_cam[0, :]
131 |
132 | cam_image = show_cam_on_image(rgb_img, grayscale_cam)
133 | cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
--------------------------------------------------------------------------------
/visualization/vit_rollout/vit_grad_rollout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import numpy
4 | import sys
5 | from torchvision import transforms
6 | import numpy as np
7 | import cv2
8 |
9 | def grad_rollout(attentions, gradients, discard_ratio):
10 | result = torch.eye(attentions[0].size(-1))
11 | with torch.no_grad():
12 | for attention, grad in zip(attentions, gradients):
13 | weights = grad
14 | attention_heads_fused = (attention*weights).mean(axis=1)
15 | attention_heads_fused[attention_heads_fused < 0] = 0
16 |
17 | # Drop the lowest attentions, but
18 | # don't drop the class token
19 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
20 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
21 | #indices = indices[indices != 0]
22 | flat[0, indices] = 0
23 |
24 | I = torch.eye(attention_heads_fused.size(-1))
25 | a = (attention_heads_fused + 1.0*I)/2
26 | a = a / a.sum(dim=-1)
27 | result = torch.matmul(a, result)
28 |
29 | # Look at the total attention between the class token,
30 | # and the image patches
31 | mask = result[0, 0 , 1 :]
32 | # In case of 224x224 image, this brings us from 196 to 14
33 | width = int(mask.size(-1)**0.5)
34 | mask = mask.reshape(width, width).numpy()
35 | mask = mask / np.max(mask)
36 | return mask
37 |
38 | class VITAttentionGradRollout:
39 | def __init__(self, model, attention_layer_name='attn_drop',
40 | discard_ratio=0.9):
41 | self.model = model
42 | self.discard_ratio = discard_ratio
43 | for name, module in self.model.named_modules():
44 | if attention_layer_name in name:
45 | module.register_forward_hook(self.get_attention)
46 | module.register_backward_hook(self.get_attention_gradient)
47 |
48 | self.attentions = []
49 | self.attention_gradients = []
50 |
51 | def get_attention(self, module, input, output):
52 | self.attentions.append(output.cpu())
53 |
54 | def get_attention_gradient(self, module, grad_input, grad_output):
55 | self.attention_gradients.append(grad_input[0].cpu())
56 |
57 | def __call__(self, input_tensor, category_index):
58 | self.model.zero_grad()
59 | output = self.model(input_tensor)
60 | category_mask = torch.zeros(output.size())
61 | category_mask[:, category_index] = 1
62 | loss = (output*category_mask).sum()
63 | loss.backward()
64 |
65 | return grad_rollout(self.attentions, self.attention_gradients,
66 | self.discard_ratio)
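`VITAttentionGradRollout` is not used by vit_explain.py; below is a minimal sketch of how it could be driven. Since `grad_rollout` reshapes the mask to a square `sqrt(N) x sqrt(N)` grid, the sketch assumes a square patch layout such as the DeiT model loaded in vit_example.py above, not the 16x8 grid of the ReID models; run it from the visualization directory so `vit_rollout` is importable:

```
import torch
from vit_rollout.vit_grad_rollout import VITAttentionGradRollout

# DeiT from torch.hub, as in vit_example.py; a random tensor stands in for a
# preprocessed 224x224 image
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True).eval()
input_tensor = torch.randn(1, 3, 224, 224)

grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9)
mask = grad_rollout(input_tensor, category_index=0)  # 14x14 map for class index 0
print(mask.shape)
```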
--------------------------------------------------------------------------------
/visualization/vit_rollout/vit_rollout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import numpy
4 | import sys
5 | from torchvision import transforms
6 | import numpy as np
7 | import cv2
8 |
9 | def rollout(attentions, discard_ratio, head_fusion):
10 | result = torch.eye(attentions[0].size(-1)).unsqueeze(0)
11 | with torch.no_grad():
12 | attention = attentions[3] # alter this
13 | # num_blocks = 6
14 | # for attention in attentions[4:6]: # alter this
15 | if head_fusion == "mean":
16 | attention_heads_fused = attention.mean(axis=1)
17 | elif head_fusion == "max":
18 | attention_heads_fused = attention.max(axis=1)[0]
19 | elif head_fusion == "min":
20 | attention_heads_fused = attention.min(axis=1)[0]
21 | else:
22 |             raise ValueError("Attention head fusion type not supported")
23 |
24 | # Drop the lowest attentions, but
25 | # don't drop the class token
26 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
27 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
28 | indices = indices[indices != 0]
29 | flat[0, indices] = 0
30 |
31 | # I = torch.eye(attention_heads_fused.size(-1))
32 | # a = (attention_heads_fused + 1.0*I)/2
33 | # a = a / a.sum(dim=-1)
34 | # result = torch.matmul(a, result)
35 |
36 | result = attention_heads_fused
37 |
38 | # result /= 2
39 |
40 | # Look at the total attention between the class token,
41 | # and the image patches
42 | if result.size(-1) == 132:
43 | masks = []
44 | for i in range(4):
45 | mask = result[0, i, 4 :].reshape(16,8).numpy()
46 | mask = mask / np.max(mask)
47 | masks.append(mask)
48 | return masks
49 | # mask = result[0, 0, 4 :]
50 |
51 | else:
52 | mask = result[0, 0, 1 :]
53 | # In case of 224x224 image, this brings us from 196 to 14
54 |
55 | mask = mask.reshape(16,8).numpy()
56 | mask = mask / np.max(mask)
57 | return mask
58 |
59 | class VITAttentionRollout:
60 | def __init__(self, model, attention_layer_name='attn_drop', head_fusion="mean",
61 | discard_ratio=0.9):
62 | self.model = model
63 | self.head_fusion = head_fusion
64 | self.discard_ratio = discard_ratio
65 | for name, module in self.model.named_modules():
66 | if attention_layer_name in name:
67 | module.register_forward_hook(self.get_attention)
68 |
69 | self.attentions = []
70 |
71 | def get_attention(self, module, input, output):
72 | self.attentions.append(output.cpu())
73 |
74 | def __call__(self, input_tensor):
75 | self.attentions = []
76 | with torch.no_grad():
77 | output = self.model(input_tensor)
78 |
79 | return rollout(self.attentions, self.discard_ratio, self.head_fusion)
--------------------------------------------------------------------------------