├── .gitignore ├── LICENSE ├── README.md ├── configs ├── baseline.yml └── default │ ├── __init__.py │ ├── dataset.py │ └── strategy.py ├── data ├── __init__.py ├── datasets.py └── samplers.py ├── engine ├── __init__.py ├── engine.py └── metric.py ├── eval.py ├── extract.py ├── layers └── separate_bn.py ├── losses ├── center_loss.py └── triplet_loss.py ├── models ├── baseline.py └── resnet.py ├── train.py └── utils ├── calc_acc.py ├── eval_sysu.py ├── eval_utils.py ├── lr_scheduler.py ├── net_utils.py └── tsne.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | logs/ 4 | checkpoints/ 5 | features/ 6 | vis/ 7 | configs/exp 8 | 9 | *.pth 10 | *.pkl 11 | *.mat 12 | *.txt 13 | *.py[cod] 14 | *.rar 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Chuanchen Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Strong Baseline for RGB-Infrared Cross-Modality Person Re-Identification 2 | 3 | ## Dependencies 4 | * Python 3.7 5 | * PyTorch 1.10 6 | * Ignite 0.4.7 7 | * Yacs 8 | 9 | ## Usage 10 | Download the [SYSU-MM01](https://github.com/wuancong/SYSU-MM01) dataset and uncompress it. 11 | Change the entry `data_root` in `configs/default/dataset.py` to the path of the dataset. 12 | Put [rand_perm_cam.mat](https://github.com/wuancong/SYSU-MM01/blob/master/evaluation/data_split/rand_perm_cam.mat) into the `exp` directory under the dataset root. This file assigns the gallery items for each trial during testing. 13 | Then run 14 | ```shell script 15 | CUDA_VISIBLE_DEVICES=0 python3 train.py configs/baseline.yml --work-dir work_dirs/baseline 16 | ``` 17 | 18 | ## Performance 19 | 20 | We evaluate the performance on [SYSU-MM01](https://github.com/wuancong/SYSU-MM01) under the **one-shot** & **all-search** setting. 
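A saved checkpoint can also be scored offline with `eval.py`, which first calls `extract.py` to dump query/gallery features and then runs the SYSU-MM01 evaluation; its positional arguments are the GPU id, the checkpoint path, and the dataset name, and the checkpoint path below is only illustrative: ```shell script python3 eval.py 0 work_dirs/baseline/checkpoint_ep120.pt sysu ```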
21 | 22 | | model | mAP | rank-1 | rank-5 | rank-10 | rank-20 | 23 | | ----------------- | ------ | ------ | ------- | ------- | ------- | 24 | | baseline | 54.60 | 57.51 | 82.77 | 90.05 | 95.28 | 25 | 26 | 27 | ## Reference 28 | 29 | [L1aoXingyu/reid_baseline](https://github.com/L1aoXingyu/reid_baseline) 30 | -------------------------------------------------------------------------------- /configs/baseline.yml: -------------------------------------------------------------------------------- 1 | work_dir: sysu-adam0.9 2 | 3 | fp16: false 4 | 5 | # dataset 6 | sample_method: identity_uniform 7 | p_size: 8 8 | k_size: 8 9 | 10 | dataset: sysu 11 | 12 | # architecture 13 | last_stride: 1 14 | 15 | # optimizer 16 | lr: 0.01 17 | optimizer: adam 18 | betas: [ 0.5,0.999 ] 19 | num_epoch: 120 20 | lr_step: [ 40,70 ] 21 | 22 | # augmentation 23 | random_flip: true 24 | random_crop: true 25 | random_erase: true 26 | color_jitter: false 27 | padding: 10 28 | 29 | 30 | # log 31 | log_period: 50 32 | eval_interval: 10 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /configs/default/__init__.py: -------------------------------------------------------------------------------- 1 | from configs.default.dataset import dataset_cfg 2 | from configs.default.strategy import strategy_cfg 3 | 4 | __all__ = ["dataset_cfg", "strategy_cfg"] 5 | -------------------------------------------------------------------------------- /configs/default/dataset.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | dataset_cfg = CfgNode() 4 | 5 | # config for dataset 6 | dataset_cfg.sysu = CfgNode() 7 | dataset_cfg.sysu.num_id = 395 8 | dataset_cfg.sysu.num_cam = 6 9 | dataset_cfg.sysu.data_root = "/home/chuanchen_luo/data/SYSU-MM01" 10 | -------------------------------------------------------------------------------- /configs/default/strategy.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | strategy_cfg = CfgNode() 4 | 5 | strategy_cfg.work_dir = "baseline" 6 | 7 | # setting for loader 8 | strategy_cfg.sample_method = "random" 9 | strategy_cfg.batch_size = 128 10 | strategy_cfg.p_size = 16 11 | strategy_cfg.k_size = 8 12 | 13 | # settings for optimizer 14 | strategy_cfg.optimizer = "sgd" 15 | strategy_cfg.lr = 0.1 16 | strategy_cfg.wd = 5e-4 17 | strategy_cfg.momentum = 0.9 18 | strategy_cfg.betas = (0.5, 0.999) 19 | strategy_cfg.lr_step = (40,) 20 | 21 | strategy_cfg.fp16 = False 22 | 23 | strategy_cfg.num_epoch = 60 24 | 25 | # settings for dataset 26 | strategy_cfg.dataset = "sysu" 27 | strategy_cfg.image_size = (256, 128) 28 | 29 | # settings for augmentation 30 | strategy_cfg.random_flip = True 31 | strategy_cfg.random_crop = True 32 | strategy_cfg.random_erase = True 33 | strategy_cfg.color_jitter = False 34 | strategy_cfg.padding = 10 35 | 36 | # settings for base architecture 37 | strategy_cfg.last_stride = 1 38 | 39 | # logging 40 | strategy_cfg.eval_interval = -1 41 | strategy_cfg.log_period = 10 42 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | from torch.utils.data import DataLoader 3 | 4 | from data.datasets import SYSUDataset 5 | from data.samplers import CrossModalityIdentitySampler 6 | from data.samplers import 
CrossModalityRandomSampler 7 | 8 | 9 | def get_train_loader(root, sample_method, batch_size, p_size, k_size, image_size, random_flip=False, random_crop=False, 10 | random_erase=False, color_jitter=False, padding=0, num_workers=4): 11 | # data pre-processing 12 | t = [T.Resize(image_size)] 13 | 14 | if random_flip: 15 | t.append(T.RandomHorizontalFlip()) 16 | 17 | if color_jitter: 18 | t.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)) 19 | 20 | if random_crop: 21 | t.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)]) 22 | 23 | t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 24 | 25 | if random_erase: 26 | t.append(T.RandomErasing()) 27 | 28 | transform = T.Compose(t) 29 | 30 | # dataset 31 | train_dataset = SYSUDataset(root, mode='train', transform=transform) 32 | 33 | # sampler 34 | assert sample_method in ['random', 'identity_uniform'] 35 | if sample_method == 'identity_uniform': 36 | sampler = CrossModalityIdentitySampler(train_dataset, p_size, k_size) 37 | else: 38 | sampler = CrossModalityRandomSampler(train_dataset, batch_size) 39 | 40 | # loader 41 | train_loader = DataLoader(train_dataset, 42 | batch_size=batch_size, 43 | sampler=sampler, 44 | drop_last=True, 45 | pin_memory=True, 46 | num_workers=num_workers) 47 | 48 | return train_loader 49 | 50 | 51 | def get_test_loader(root, batch_size, image_size, num_workers=4): 52 | # transform 53 | transform = T.Compose([ 54 | T.Resize(image_size), 55 | T.ToTensor(), 56 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 57 | ]) 58 | 59 | # dataset 60 | gallery_dataset = SYSUDataset(root, mode='gallery', transform=transform) 61 | query_dataset = SYSUDataset(root, mode='query', transform=transform) 62 | 63 | # dataloader 64 | query_loader = DataLoader(dataset=query_dataset, 65 | batch_size=batch_size, 66 | shuffle=False, 67 | pin_memory=False, 68 | drop_last=False, 69 | num_workers=num_workers) 70 | 71 | gallery_loader = DataLoader(dataset=gallery_dataset, 72 | batch_size=batch_size, 73 | shuffle=False, 74 | pin_memory=False, 75 | drop_last=False, 76 | num_workers=num_workers) 77 | 78 | return gallery_loader, query_loader 79 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SYSUDataset(Dataset): 10 | def __init__(self, root, mode='train', transform=None): 11 | assert os.path.isdir(root) 12 | assert mode in ['train', 'gallery', 'query'] 13 | 14 | if mode == 'train': 15 | train_ids = open(os.path.join(root, 'exp', 'train_id.txt')).readline() 16 | val_ids = open(os.path.join(root, 'exp', 'val_id.txt')).readline() 17 | 18 | train_ids = train_ids.strip('\n').split(',') 19 | val_ids = val_ids.strip('\n').split(',') 20 | selected_ids = train_ids + val_ids 21 | else: 22 | test_ids = open(os.path.join(root, 'exp', 'test_id.txt')).readline() 23 | selected_ids = test_ids.strip('\n').split(',') 24 | 25 | selected_ids = [int(i) for i in selected_ids] 26 | num_ids = len(selected_ids) 27 | 28 | img_paths = glob(os.path.join(root, '**/*.jpg'), recursive=True) 29 | img_paths = [path for path in img_paths if int(path.split('/')[-2]) in selected_ids] 30 | 31 | if mode == 'gallery': 32 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (1, 2, 4, 5)] 33 | elif mode 
== 'query': 34 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (3, 6)] 35 | 36 | img_paths = sorted(img_paths) 37 | self.img_paths = img_paths 38 | self.cam_ids = [int(path.split('/')[-3][-1]) for path in img_paths] 39 | self.num_ids = num_ids 40 | self.transform = transform 41 | 42 | if mode == 'train': 43 | id_map = dict(zip(selected_ids, range(num_ids))) 44 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths] 45 | else: 46 | self.ids = [int(path.split('/')[-2]) for path in img_paths] 47 | 48 | def __len__(self): 49 | return len(self.img_paths) 50 | 51 | def __getitem__(self, item): 52 | path = self.img_paths[item] 53 | img = Image.open(path) 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | 57 | label = torch.tensor(self.ids[item], dtype=torch.long) 58 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 59 | item = torch.tensor(item, dtype=torch.long) 60 | 61 | return img, label, cam, path, item 62 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torch.utils.data import Sampler 4 | from collections import defaultdict 5 | 6 | 7 | class CrossModalityRandomSampler(Sampler): 8 | def __init__(self, dataset, batch_size): 9 | self.dataset = dataset 10 | self.batch_size = batch_size 11 | 12 | self.rgb_list = [] 13 | self.ir_list = [] 14 | for i, cam in enumerate(dataset.cam_ids): 15 | if cam in [3, 6]: 16 | self.ir_list.append(i) 17 | else: 18 | self.rgb_list.append(i) 19 | 20 | def __len__(self): 21 | return max(len(self.rgb_list), len(self.ir_list)) * 2 22 | 23 | def __iter__(self): 24 | sample_list = [] 25 | rgb_list = np.random.permutation(self.rgb_list).tolist() 26 | ir_list = np.random.permutation(self.ir_list).tolist() 27 | 28 | rgb_size = len(self.rgb_list) 29 | ir_size = len(self.ir_list) 30 | if rgb_size >= ir_size: 31 | diff = rgb_size - ir_size 32 | reps = diff // ir_size 33 | pad_size = diff % ir_size 34 | for _ in range(reps): 35 | ir_list.extend(np.random.permutation(self.ir_list).tolist()) 36 | ir_list.extend(np.random.choice(self.ir_list, pad_size, replace=False).tolist()) 37 | else: 38 | diff = ir_size - rgb_size 39 | reps = diff // ir_size 40 | pad_size = diff % ir_size 41 | for _ in range(reps): 42 | rgb_list.extend(np.random.permutation(self.rgb_list).tolist()) 43 | rgb_list.extend(np.random.choice(self.rgb_list, pad_size, replace=False).tolist()) 44 | 45 | assert len(rgb_list) == len(ir_list) 46 | 47 | half_bs = self.batch_size // 2 48 | for start in range(0, len(rgb_list), half_bs): 49 | sample_list.extend(rgb_list[start:start + half_bs]) 50 | sample_list.extend(ir_list[start:start + half_bs]) 51 | 52 | return iter(sample_list) 53 | 54 | 55 | class CrossModalityIdentitySampler(Sampler): 56 | def __init__(self, dataset, p_size, k_size): 57 | self.dataset = dataset 58 | self.p_size = p_size 59 | self.k_size = k_size 60 | self.batch_size = p_size * k_size 61 | 62 | self.id2idx_rgb = defaultdict(list) 63 | self.id2idx_ir = defaultdict(list) 64 | for i, identity in enumerate(dataset.ids): 65 | if dataset.cam_ids[i] in [3, 6]: 66 | self.id2idx_ir[identity].append(i) 67 | else: 68 | self.id2idx_rgb[identity].append(i) 69 | 70 | self.size_rgb = sum([len(a) for a in self.id2idx_rgb.values()]) 71 | self.size_ir = sum([len(a) for a in self.id2idx_ir.values()]) 72 | self.num_iters = max(self.size_ir, self.size_rgb) // self.batch_size + 1 73 | 
74 | def __len__(self): 75 | return self.num_iters * self.batch_size 76 | 77 | def __iter__(self): 78 | sample_list = [] 79 | 80 | id_perm = np.arange(self.dataset.num_ids) 81 | # first half batch is visible and second half batch is infrared 82 | for _ in range(self.num_iters): 83 | selected_ids = np.random.choice(id_perm, size=self.k_size, replace=False) 84 | 85 | sample = [] 86 | for identity in selected_ids: 87 | replace = len(self.id2idx_rgb[identity]) < self.k_size 88 | s = np.random.choice(self.id2idx_rgb[identity], size=self.k_size, replace=replace) 89 | sample.extend(s) 90 | 91 | sample_list.extend(sample) 92 | 93 | sample.clear() 94 | for identity in selected_ids: 95 | replace = len(self.id2idx_ir[identity]) < self.k_size 96 | s = np.random.choice(self.id2idx_ir[identity], size=self.k_size, replace=replace) 97 | sample.extend(s) 98 | 99 | sample_list.extend(sample) 100 | 101 | return iter(sample_list) 102 | -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import scipy.io as sio 5 | from runx.logx import logx 6 | from ignite.engine import Events 7 | from ignite.handlers import Timer 8 | 9 | from engine.engine import create_eval_engine 10 | from engine.engine import create_train_engine 11 | from engine.metric import AutoKVMetric 12 | from utils.eval_sysu import eval_sysu 13 | from configs.default.dataset import dataset_cfg 14 | 15 | 16 | def get_trainer(model, optimizer, lr_scheduler=None, enable_amp=False, log_period=10, save_interval=10, 17 | query_loader=None, gallery_loader=None, eval_interval=None): 18 | # Trainer 19 | trainer = create_train_engine(model, optimizer, enable_amp) 20 | 21 | # Evaluator 22 | evaluator = None 23 | if not type(eval_interval) == int: 24 | raise TypeError("The parameter 'validate_interval' must be type INT.") 25 | if eval_interval > 0 and query_loader and gallery_loader: 26 | evaluator = create_eval_engine(model) 27 | 28 | # Metric 29 | timer = Timer(average=True) 30 | kv_metric = AutoKVMetric() 31 | 32 | @trainer.on(Events.EPOCH_STARTED) 33 | def epoch_started_callback(engine): 34 | 35 | kv_metric.reset() 36 | timer.reset() 37 | 38 | @trainer.on(Events.EPOCH_COMPLETED) 39 | def epoch_completed_callback(engine): 40 | epoch = engine.state.epoch 41 | logx.msg('Epoch[{}] completed.'.format(epoch)) 42 | 43 | if lr_scheduler is not None: 44 | lr_scheduler.step() 45 | 46 | if epoch % save_interval == 0: 47 | state_dict = model.state_dict() 48 | save_path = os.path.join(logx.logdir, 'checkpoint_ep{}.pt'.format(epoch)) 49 | torch.save(state_dict, save_path) 50 | logx.msg("Model saved at {}".format(save_path)) 51 | 52 | if evaluator and epoch % eval_interval == 0: 53 | torch.cuda.empty_cache() 54 | 55 | # extract query feature 56 | model.eval_mode = 'infrared' 57 | evaluator.run(query_loader) 58 | 59 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 60 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 61 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 62 | 63 | # extract gallery feature 64 | model.eval_mode = 'visible' 65 | evaluator.run(gallery_loader) 66 | 67 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 68 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 69 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 70 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 71 | 72 | perm_path = 
os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat') 73 | perm = sio.loadmat(perm_path)['rand_perm_cam'] 74 | mAP, r1, r5, r10, r20 = eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm) 75 | logx.msg('mAP = %f , r1 = %f , r5 = %f , r10 = %f , r20 = %f' % (mAP, r1, r5, r10, r20)) 76 | 77 | val_metrics = {'mAP': mAP, 'rank-1': r1, 'rank-5': r5, 'rank-10': r10, 'rank-20': r20} 78 | logx.metric('val', val_metrics, epoch) 79 | 80 | # clear temporary storage 81 | evaluator.state.feat_list.clear() 82 | evaluator.state.id_list.clear() 83 | evaluator.state.cam_list.clear() 84 | evaluator.state.img_path_list.clear() 85 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams 86 | 87 | torch.cuda.empty_cache() 88 | 89 | @trainer.on(Events.ITERATION_COMPLETED) 90 | def iteration_complete_callback(engine): 91 | timer.step() 92 | 93 | kv_metric.update(engine.state.output) 94 | 95 | epoch = engine.state.epoch 96 | iteration = engine.state.iteration 97 | iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader) 98 | 99 | if iter_in_epoch % log_period == 0: 100 | batch_size = engine.state.batch[0].size(0) 101 | speed = batch_size / timer.value() 102 | 103 | msg = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec" % (epoch, iter_in_epoch, speed) 104 | metric_dict = kv_metric.compute() 105 | for k in sorted(metric_dict.keys()): 106 | msg += "\t%s: %.4f" % (k, metric_dict[k]) 107 | logx.msg(msg) 108 | logx.metric('train', metric_dict, iteration) 109 | 110 | kv_metric.reset() 111 | timer.reset() 112 | 113 | return trainer 114 | -------------------------------------------------------------------------------- /engine/engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.cuda import amp 4 | from ignite.engine import Engine 5 | from ignite.engine import Events 6 | 7 | 8 | def create_train_engine(model, optimizer, enable_amp=False): 9 | device = torch.device("cuda", torch.cuda.current_device()) 10 | scaler = amp.GradScaler(enabled=enable_amp) 11 | 12 | def _process_func(engine, batch): 13 | model.train() 14 | 15 | data, labels, cam_ids, img_paths, img_ids = batch 16 | epoch = engine.state.epoch 17 | 18 | data = data.to(device, non_blocking=True) 19 | labels = labels.to(device, non_blocking=True) 20 | cam_ids = cam_ids.to(device, non_blocking=True) 21 | 22 | optimizer.zero_grad(set_to_none=True) 23 | with amp.autocast(enabled=enable_amp): 24 | loss, metric = model(data, labels=labels, cam_ids=cam_ids, epoch=epoch) 25 | 26 | scaler.scale(loss).backward() 27 | scaler.step(optimizer) 28 | scaler.update() 29 | 30 | return metric 31 | 32 | return Engine(_process_func) 33 | 34 | 35 | def create_eval_engine(model): 36 | device = torch.device("cuda", torch.cuda.current_device()) 37 | 38 | def _process_func(engine, batch): 39 | model.eval() 40 | 41 | data, labels, cam_ids, img_paths = batch[:4] 42 | 43 | data = data.to(device, non_blocking=True) 44 | with torch.no_grad(): 45 | feat = model(data, cam_ids=cam_ids.to(device, non_blocking=False)) 46 | 47 | return feat.data.float().cpu(), labels, cam_ids, np.array(img_paths) 48 | 49 | engine = Engine(_process_func) 50 | 51 | @engine.on(Events.EPOCH_STARTED) 52 | def clear_data(engine): 53 | # feat list 54 | if not hasattr(engine.state, "feat_list"): 55 | setattr(engine.state, "feat_list", []) 56 | else: 57 | engine.state.feat_list.clear() 58 | 59 | # id_list 60 | if not hasattr(engine.state, "id_list"): 61 | setattr(engine.state, "id_list", 
[]) 62 | else: 63 | engine.state.id_list.clear() 64 | 65 | # cam list 66 | if not hasattr(engine.state, "cam_list"): 67 | setattr(engine.state, "cam_list", []) 68 | else: 69 | engine.state.cam_list.clear() 70 | 71 | # cam list 72 | if not hasattr(engine.state, "img_path_list"): 73 | setattr(engine.state, "img_path_list", []) 74 | else: 75 | engine.state.img_path_list.clear() 76 | 77 | @engine.on(Events.ITERATION_COMPLETED) 78 | def store_data(engine): 79 | engine.state.feat_list.append(engine.state.output[0]) 80 | engine.state.id_list.append(engine.state.output[1]) 81 | engine.state.cam_list.append(engine.state.output[2]) 82 | engine.state.img_path_list.append(engine.state.output[3]) 83 | 84 | return engine 85 | -------------------------------------------------------------------------------- /engine/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from collections import defaultdict 4 | from ignite.metrics import Metric 5 | from ignite.exceptions import NotComputableError 6 | from ignite.metrics.metric import reinit__is_reduced 7 | 8 | 9 | class AutoKVMetric(Metric): 10 | def __init__(self): 11 | self.kv_sum_metric = defaultdict(lambda: torch.tensor(0., device="cuda")) 12 | self.kv_sum_inst = defaultdict(lambda: torch.tensor(0., device="cuda")) 13 | self.kv_metric = defaultdict(lambda: 0) 14 | 15 | self.reset() 16 | 17 | @reinit__is_reduced 18 | def update(self, output): 19 | if not isinstance(output, dict): 20 | raise TypeError('The output must be a key-value dict.') 21 | 22 | for k in output.keys(): 23 | self.kv_sum_metric[k].add_(output[k]) 24 | self.kv_sum_inst[k].add_(1) 25 | 26 | @reinit__is_reduced 27 | def reset(self): 28 | self.kv_sum_metric.clear() 29 | self.kv_sum_inst.clear() 30 | self.kv_metric.clear() 31 | 32 | def compute(self): 33 | for k in self.kv_sum_metric.keys(): 34 | if self.kv_sum_inst[k] == 0: 35 | raise NotComputableError('Accuracy must have at least one example before it can be computed') 36 | 37 | metric_value = self.kv_sum_metric[k] / self.kv_sum_inst[k] 38 | 39 | if dist.is_initialized(): 40 | dist.barrier() 41 | dist.all_reduce(metric_value) 42 | dist.barrier() 43 | metric_value /= dist.get_world_size() 44 | 45 | self.kv_metric[k] = metric_value.item() 46 | 47 | return self.kv_metric 48 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | import sys 5 | 6 | import scipy.io as sio 7 | 8 | from utils.eval_sysu import eval_sysu 9 | from configs.default.dataset import dataset_cfg 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("gpu", type=int) 14 | parser.add_argument("model_path", type=str) 15 | parser.add_argument("dataset", type=str) 16 | 17 | args = parser.parse_args() 18 | basename = os.path.basename(args.model_path) 19 | prefix = os.path.splitext(basename)[0] 20 | 21 | # extract feature 22 | cmd = "python{} extract.py {} {} {} ".format(sys.version[0], args.gpu, args.model_path, args.dataset) 23 | subprocess.check_call(cmd.strip().split(" ")) 24 | 25 | # evaluation 26 | dirname = os.path.dirname(args.model_path) 27 | q_mat_path = os.path.join(dirname, 'query-%s.mat' % prefix) 28 | g_mat_path = os.path.join(dirname, 'gallery-%s.mat' % prefix) 29 | 30 | assert os.path.exists(q_mat_path) 31 | assert os.path.exists(g_mat_path) 32 | 
33 | mat = sio.loadmat(q_mat_path) 34 | q_feats = mat["feat"] 35 | q_ids = mat["ids"].squeeze() 36 | q_cam_ids = mat["cam_ids"].squeeze() 37 | 38 | mat = sio.loadmat(g_mat_path) 39 | g_feats = mat["feat"] 40 | g_ids = mat["ids"].squeeze() 41 | g_cam_ids = mat["cam_ids"].squeeze() 42 | g_img_paths = mat['img_path'].squeeze() 43 | 44 | data_root = dataset_cfg.get(args.dataset).data_root 45 | perm = sio.loadmat(os.path.join(data_root, 'exp', 'rand_perm_cam.mat')) 46 | perm = perm['rand_perm_cam'] 47 | mAP, r1, r5, r10, r20 = eval_sysu(q_feats, q_ids, q_cam_ids, g_feats, g_ids, g_cam_ids, g_img_paths, perm) 48 | print('mAP = %f , r1 = %f , r5 = %f , r10 = %f , r20 = %f' % (mAP, r1, r5, r10, r20)) 49 | -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import scipy.io as sio 5 | import torch 6 | 7 | from configs.default import dataset_cfg 8 | from data import get_test_loader 9 | from models.baseline import Baseline 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("gpu", type=int) 14 | parser.add_argument("model_path", type=str) # TODO compatible for different models 15 | parser.add_argument("dataset", type=str, default=None) 16 | parser.add_argument("--img-h", type=int, default=256) 17 | 18 | args = parser.parse_args() 19 | model_path = args.model_path 20 | basename = os.path.basename(model_path) 21 | prefix = os.path.splitext(basename)[0] 22 | 23 | dataset = args.dataset 24 | dataset_config = dataset_cfg.get(dataset) 25 | image_size = (args.img_h, 128) 26 | 27 | torch.backends.cudnn.benchmark = True 28 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 29 | 30 | model = Baseline() 31 | state_dict = torch.load(model_path) 32 | 33 | model.load_state_dict(state_dict, strict=False) 34 | model.float() 35 | model.eval() 36 | model.cuda() 37 | 38 | # extract test feature 39 | gallery_loader, query_loader = get_test_loader(root=dataset_config.data_root, 40 | batch_size=512, 41 | image_size=image_size, 42 | num_workers=16) 43 | # extract query features 44 | feats = [] 45 | labels = [] 46 | cam_ids = [] 47 | img_paths = [] 48 | model.eval_mode = 'infrared' 49 | for data, label, cam_id, img_path, _ in query_loader: 50 | with torch.autograd.no_grad(): 51 | feat = model(data.cuda(non_blocking=True), cam_ids=cam_id) 52 | 53 | feats.append(feat.data.cpu().numpy()) 54 | labels.append(label.data.cpu().numpy()) 55 | cam_ids.append(cam_id.data.cpu().numpy()) 56 | img_paths.extend(img_path) 57 | 58 | feats = np.concatenate(feats, axis=0) 59 | labels = np.concatenate(labels, axis=0) 60 | cam_ids = np.concatenate(cam_ids, axis=0) 61 | print(feats.shape) 62 | 63 | dirname = os.path.dirname(args.model_path) 64 | if not os.path.isdir(dirname): 65 | os.makedirs(dirname) 66 | 67 | save_name = "{}/query-{}.mat".format(dirname, prefix) 68 | sio.savemat(save_name, 69 | {"feat": feats, 70 | "ids": labels, 71 | "cam_ids": cam_ids, 72 | "img_path": img_paths}) 73 | 74 | # extract gallery features 75 | feats = [] 76 | labels = [] 77 | cam_ids = [] 78 | img_paths = [] 79 | model.eval_mode = 'visible' 80 | for data, label, cam_id, img_path, _ in gallery_loader: 81 | with torch.autograd.no_grad(): 82 | feat = model(data.cuda(non_blocking=True), cam_ids=cam_id) 83 | 84 | feats.append(feat.data.cpu().numpy()) 85 | labels.append(label.data.cpu().numpy()) 86 | 
cam_ids.append(cam_id.data.cpu().numpy()) 87 | img_paths.extend(img_path) 88 | 89 | feats = np.concatenate(feats, axis=0) 90 | labels = np.concatenate(labels, axis=0) 91 | cam_ids = np.concatenate(cam_ids, axis=0) 92 | print(feats.shape) 93 | 94 | save_name = "{}/gallery-{}.mat".format(dirname, prefix) 95 | sio.savemat(save_name, 96 | {"feat": feats, 97 | "ids": labels, 98 | "cam_ids": cam_ids, 99 | "img_path": img_paths}) 100 | -------------------------------------------------------------------------------- /layers/separate_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.batchnorm import _BatchNorm 5 | 6 | 7 | def convert_sep_bn_model(module, separate_affine=False): 8 | mod = module 9 | 10 | if isinstance(module, nn.BatchNorm2d): 11 | mod = SeparateBatchNorm(num_features=module.num_features, separate_affine=separate_affine, 12 | momentum=module.momentum, affine=module.affine, eps=module.eps, 13 | track_running_stats=module.track_running_stats) 14 | 15 | mod.running_mean = module.running_mean.clone() 16 | mod.running_var = module.running_var.clone() 17 | mod.running_mean_s = module.running_mean.clone() 18 | mod.running_var_s = module.running_var.clone() 19 | 20 | if module.affine: 21 | mod.weight.data = module.weight.data.clone().detach() 22 | mod.bias.data = module.bias.data.clone().detach() 23 | if separate_affine: 24 | mod.weight_s.data = module.bias.data.clone().detach() 25 | mod.bias_s.data = module.bias.data.clone().detach() 26 | 27 | for name, children in module.named_children(): 28 | mod.add_module(name, convert_sep_bn_model(children, separate_affine)) 29 | 30 | del module 31 | return mod 32 | 33 | 34 | class SeparateBatchNorm(_BatchNorm): 35 | def __init__(self, num_features, separate_affine=False, eps=1e-5, momentum=0.1, affine=True, 36 | track_running_stats=True): 37 | super(SeparateBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats) 38 | 39 | self.register_buffer('running_mean_s', torch.zeros(num_features)) 40 | self.register_buffer('running_var_s', torch.ones(num_features)) 41 | 42 | if affine and separate_affine: 43 | self.weight = nn.Parameter(torch.ones(num_features)) 44 | self.bias = nn.Parameter(torch.zeros(num_features)) 45 | else: 46 | self.weight_s = self.weight 47 | self.bias_s = self.bias 48 | 49 | self._eval_mode = 'visible' 50 | 51 | @property 52 | def eval_mode(self): 53 | return self._eval_mode 54 | 55 | @eval_mode.setter 56 | def eval_mode(self, mode): 57 | if mode not in ('visible', 'infrared'): 58 | raise ValueError('The choice of mode is visible or infrared!') 59 | self._eval_mode = mode 60 | 61 | def forward(self, x): 62 | if self.training: 63 | source_split, target_split = x.tensor_split(2, dim=0) 64 | source_result = F.batch_norm(source_split, self.running_mean_s, self.running_var_s, self.weight_s, 65 | self.bias_s, True, self.momentum, self.eps) 66 | target_result = F.batch_norm(target_split, self.running_mean, self.running_var, self.weight, 67 | self.bias, True, self.momentum, self.eps) 68 | 69 | return torch.cat([source_result, target_result], dim=0) 70 | else: 71 | if self.eval_mode == 'infrared': 72 | result = F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 73 | not self.track_running_stats, self.momentum, self.eps) 74 | else: 75 | result = F.batch_norm(x, self.running_mean_s, self.running_var_s, self.weight_s, self.bias_s, 76 | not 
self.track_running_stats, self.momentum, self.eps) 77 | 78 | return result 79 | -------------------------------------------------------------------------------- /losses/center_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CenterLoss(nn.Module): 6 | """Center loss. 7 | 8 | Reference: 9 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | feature_dim (int): feature dimension. 14 | """ 15 | 16 | def __init__(self, num_classes, feature_dim, reduction='mean'): 17 | super(CenterLoss, self).__init__() 18 | self.num_classes = num_classes 19 | self.feature_dim = feature_dim 20 | self.reduction = reduction 21 | self.centers = nn.Parameter(torch.Tensor(self.num_classes, self.feature_dim)) 22 | 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | nn.init.normal_(self.centers) 27 | 28 | @torch.no_grad() 29 | def ema_update(self, feats, labels, momentum=0.0): 30 | unique_labels, inverse_indices = torch.unique(labels, return_inverse=True) 31 | 32 | feat_list = [] 33 | for i in range(unique_labels.size(0)): 34 | mask = inverse_indices.eq(i) 35 | feat_list.append(feats[mask].mean(dim=0)) 36 | 37 | feats = torch.stack(feat_list, dim=0).to(dtype=self.centers.dtype) 38 | self.centers.data[unique_labels] *= momentum 39 | self.centers.data[unique_labels] += (1 - momentum) * feats 40 | 41 | def forward(self, x, labels): 42 | """ 43 | Args: 44 | x: feature matrix with shape (batch_size, feat_dim). 45 | labels: ground truth labels with shape (batch_size). 46 | """ 47 | dist_mat = torch.cdist(x, self.centers, p=2) ** 2 / 2 48 | 49 | classes = torch.arange(self.num_classes, device=x.device, dtype=torch.long) 50 | classes = classes.unsqueeze(0).expand(x.size(0), -1) 51 | labels = labels.unsqueeze(1).expand(-1, self.num_classes) 52 | mask = labels.eq(classes) 53 | 54 | dist = dist_mat * mask.float() 55 | loss = dist.sum(dim=1) 56 | 57 | if self.reduction == 'mean': 58 | loss = loss.mean() 59 | elif self.reduction == 'sum': 60 | loss = loss.sum() 61 | 62 | return loss 63 | -------------------------------------------------------------------------------- /losses/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def hard_example_mining(dist_mat, labels_1, labels_2=None): 7 | """For each anchor, find the hardest positive and negative sample. 
8 | Args: 9 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 10 | labels_1: pytorch LongTensor, with shape [N] 11 | labels_2: pytorch LongTensor, with shape [N] 12 | Returns: 13 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 14 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 15 | """ 16 | if labels_2 is None: 17 | labels_2 = labels_1 18 | 19 | assert dist_mat.dim() == 2 20 | assert dist_mat.size(0) == labels_1.size(0) 21 | assert dist_mat.size(1) == labels_2.size(0) 22 | m, n = dist_mat.shape 23 | 24 | labels_1 = labels_1.view(-1, 1).expand(-1, n) 25 | labels_2 = labels_2.view(1, -1).expand(m, -1) 26 | pos_mask = labels_1.eq(labels_2).to(dtype=dist_mat.dtype) 27 | neg_mask = 1 - pos_mask 28 | 29 | dist_ap, _ = torch.max(dist_mat - neg_mask * 1e4, dim=1) 30 | dist_an, _ = torch.min(dist_mat + pos_mask * 1e4, dim=1) 31 | 32 | return dist_ap, dist_an 33 | 34 | 35 | class TripletLoss(nn.Module): 36 | def __init__(self, margin=None, normalize=False, reduction='mean'): 37 | super(TripletLoss, self).__init__() 38 | assert reduction in ['sum', 'mean', 'none'], 'reduction = "{}" is not supported.'.format(reduction) 39 | 40 | self.margin = margin 41 | self.normalize = normalize 42 | self.reduction = reduction 43 | if margin is not None: 44 | self.ranking_loss = nn.MarginRankingLoss(margin=margin, reduction='none') 45 | else: 46 | self.ranking_loss = nn.SoftMarginLoss(reduction='none') 47 | 48 | def forward(self, feat_1, labels_1, feat_2=None, labels_2=None): 49 | if feat_2 is None: 50 | feat_2 = feat_1 51 | labels_2 = labels_1 52 | 53 | if self.normalize: 54 | feat_1 = F.normalize(feat_1, p=2, dim=-1) 55 | feat_2 = F.normalize(feat_2, p=2, dim=-1) 56 | 57 | dist_mat = torch.cdist(feat_1, feat_2, p=2) 58 | dist_ap, dist_an = hard_example_mining(dist_mat, labels_1, labels_2) 59 | y = torch.ones_like(dist_ap, dtype=torch.long) 60 | if self.margin is not None: 61 | loss = self.ranking_loss(dist_an, dist_ap, y) 62 | else: 63 | loss = self.ranking_loss(dist_an - dist_ap, y) 64 | 65 | if self.reduction == 'none': 66 | return loss 67 | elif self.reduction == 'sum': 68 | return loss.sum() 69 | else: 70 | return loss.mean() 71 | -------------------------------------------------------------------------------- /models/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from utils.calc_acc import calc_acc 6 | from models.resnet import resnet50 7 | from layers.separate_bn import SeparateBatchNorm 8 | from losses.triplet_loss import TripletLoss 9 | 10 | 11 | class Baseline(nn.Module): 12 | def __init__(self, num_classes=None, **kwargs): 13 | super(Baseline, self).__init__() 14 | self.backbone = resnet50(pretrained=True, last_stride=1) 15 | 16 | self.bn_neck = nn.BatchNorm1d(2048) 17 | self.bn_neck.bias.requires_grad_(False) 18 | 19 | 20 | self._eval_mode = 'visible' 21 | 22 | if num_classes is not None: 23 | self.classifier = nn.Linear(2048, num_classes, bias=False) 24 | nn.init.normal_(self.classifier.weight, std=0.001) 25 | 26 | # losses 27 | self.id_loss = nn.CrossEntropyLoss(ignore_index=-1) 28 | self.triplet_loss = TripletLoss(margin=0.3) 29 | 30 | @property 31 | def eval_mode(self): 32 | return self._eval_mode 33 | 34 | @eval_mode.setter 35 | def eval_mode(self, mode): 36 | if mode not in ('visible', 'infrared'): 37 | raise ValueError('The choice of mode is visible or infrared!') 38 | self._eval_mode = mode 39 | 40 
| def set_sep_bn_mode(m, eval_mode): 41 | if isinstance(m, SeparateBatchNorm): 42 | m.eval_mode = eval_mode 43 | 44 | set_sep_bn_mode = partial(set_sep_bn_mode, eval_mode=mode) 45 | self.backbone.apply(set_sep_bn_mode) 46 | 47 | def get_param_groups(self, lr, weight_decay): 48 | ft_params = self.backbone.parameters() 49 | new_params = [param for name, param in self.named_parameters() if not name.startswith("backbone.")] 50 | param_groups = [{'params': ft_params, 'lr': lr * 0.1, 'weight_decay': weight_decay}, 51 | {'params': new_params, 'lr': lr, 'weight_decay': weight_decay}] 52 | return param_groups 53 | 54 | def forward(self, inputs, labels=None, **kwargs): 55 | global_feat = self.backbone(inputs) 56 | if self.training: 57 | return self.train_forward(global_feat, labels, **kwargs) 58 | return self.test_forward(global_feat) 59 | 60 | def test_forward(self, feats): 61 | return self.bn_neck(feats) 62 | 63 | def train_forward(self, feats, labels, **kwargs): 64 | triplet_loss = self.triplet_loss(feats, labels) 65 | 66 | logits = self.classifier(self.bn_neck(feats)) 67 | cls_loss = self.id_loss(logits, labels) 68 | 69 | loss = triplet_loss + cls_loss 70 | metrics = {'ce': cls_loss.item(), 'acc': calc_acc(logits.data, labels.data), 'tri': triplet_loss.item()} 71 | 72 | return loss, metrics 73 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.utils.model_zoo import load_url 3 | 4 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 5 | 6 | model_urls = { 7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 8 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 11 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth' 12 | } 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=dilation, groups=groups, bias=False, dilation=dilation) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | """1x1 convolution""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 30 | base_width=64, dilation=1, norm_layer=None): 31 | super(BasicBlock, self).__init__() 32 | if norm_layer is None: 33 | norm_layer = nn.BatchNorm2d 34 | if groups != 1 or base_width != 64: 35 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 36 | if dilation > 1: 37 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 38 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 39 | self.conv1 = conv3x3(inplanes, planes, stride) 40 | self.bn1 = norm_layer(planes) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.conv2 = conv3x3(planes, planes) 43 | self.bn2 = norm_layer(planes) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | identity = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out 
= self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | identity = self.downsample(x) 59 | 60 | out += identity 61 | out = self.relu(out) 62 | 63 | return out 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | expansion = 4 68 | 69 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 70 | base_width=64, dilation=1, norm_layer=None): 71 | super(Bottleneck, self).__init__() 72 | if norm_layer is None: 73 | norm_layer = nn.BatchNorm2d 74 | width = int(planes * (base_width / 64.)) * groups 75 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 76 | self.conv1 = conv1x1(inplanes, width) 77 | self.bn1 = norm_layer(width) 78 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 79 | self.bn2 = norm_layer(width) 80 | self.conv3 = conv1x1(width, planes * self.expansion) 81 | self.bn3 = norm_layer(planes * self.expansion) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | identity = x 88 | 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv2(out) 94 | out = self.bn2(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv3(out) 98 | out = self.bn3(out) 99 | 100 | if self.downsample is not None: 101 | identity = self.downsample(x) 102 | 103 | out += identity 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | class ResNet(nn.Module): 110 | 111 | def __init__(self, block, layers, zero_init_residual=False, 112 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 113 | norm_layer=None, last_stride=2): 114 | super(ResNet, self).__init__() 115 | if norm_layer is None: 116 | norm_layer = nn.BatchNorm2d 117 | self._norm_layer = norm_layer 118 | 119 | self.inplanes = 64 120 | self.dilation = 1 121 | if replace_stride_with_dilation is None: 122 | # each element in the tuple indicates if we should replace 123 | # the 2x2 stride with a dilated convolution instead 124 | replace_stride_with_dilation = [False, False, False] 125 | if len(replace_stride_with_dilation) != 3: 126 | raise ValueError("replace_stride_with_dilation should be None " 127 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 128 | self.groups = groups 129 | self.base_width = width_per_group 130 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 131 | bias=False) 132 | self.bn1 = norm_layer(self.inplanes) 133 | self.relu = nn.ReLU(inplace=True) 134 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 135 | self.layer1 = self._make_layer(block, 64, layers[0]) 136 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 137 | dilate=replace_stride_with_dilation[0]) 138 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 139 | dilate=replace_stride_with_dilation[1]) 140 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, 141 | dilate=replace_stride_with_dilation[2]) 142 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 143 | 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 147 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 148 | nn.init.constant_(m.weight, 1) 149 | nn.init.constant_(m.bias, 0) 150 | 151 | # Zero-initialize the last BN in each residual branch, 152 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
153 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 154 | if zero_init_residual: 155 | for m in self.modules(): 156 | if isinstance(m, Bottleneck): 157 | nn.init.constant_(m.bn3.weight, 0) 158 | elif isinstance(m, BasicBlock): 159 | nn.init.constant_(m.bn2.weight, 0) 160 | 161 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 162 | norm_layer = self._norm_layer 163 | downsample = None 164 | previous_dilation = self.dilation 165 | if dilate: 166 | self.dilation *= stride 167 | stride = 1 168 | if stride != 1 or self.inplanes != planes * block.expansion: 169 | downsample = nn.Sequential( 170 | conv1x1(self.inplanes, planes * block.expansion, stride), 171 | norm_layer(planes * block.expansion), 172 | ) 173 | 174 | layers = [] 175 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 176 | self.base_width, previous_dilation, norm_layer)) 177 | self.inplanes = planes * block.expansion 178 | for _ in range(1, blocks): 179 | layers.append(block(self.inplanes, planes, groups=self.groups, 180 | base_width=self.base_width, dilation=self.dilation, 181 | norm_layer=norm_layer)) 182 | 183 | return nn.Sequential(*layers) 184 | 185 | def forward(self, x): 186 | x = self.conv1(x) 187 | x = self.bn1(x) 188 | x = self.relu(x) 189 | x = self.maxpool(x) 190 | 191 | x = self.layer1(x) 192 | x = self.layer2(x) 193 | x = self.layer3(x) 194 | x = self.layer4(x) 195 | 196 | x = self.avgpool(x) 197 | x = x.flatten(1) 198 | 199 | return x 200 | 201 | 202 | def remove_fc(state_dict): 203 | """Remove the fc layer parameters from state_dict.""" 204 | for key, value in list(state_dict.items()): 205 | if key.startswith('fc.'): 206 | del state_dict[key] 207 | return state_dict 208 | 209 | 210 | def resnet18(pretrained=False, **kwargs): 211 | """Constructs a ResNet-18 model. 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 216 | if pretrained: 217 | model.load_state_dict(remove_fc(load_url(model_urls['resnet18']))) 218 | return model 219 | 220 | 221 | def resnet34(pretrained=False, **kwargs): 222 | """Constructs a ResNet-34 model. 223 | Args: 224 | pretrained (bool): If True, returns a model pre-trained on ImageNet 225 | """ 226 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 227 | if pretrained: 228 | model.load_state_dict(remove_fc(load_url(model_urls['resnet34']))) 229 | return model 230 | 231 | 232 | def resnet50(pretrained=False, **kwargs): 233 | """Constructs a ResNet-50 model. 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | """ 237 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 238 | if pretrained: 239 | model.load_state_dict(remove_fc(load_url(model_urls['resnet50']))) 240 | return model 241 | 242 | 243 | def resnet101(pretrained=False, **kwargs): 244 | """Constructs a ResNet-101 model. 245 | Args: 246 | pretrained (bool): If True, returns a model pre-trained on ImageNet 247 | """ 248 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 249 | if pretrained: 250 | model.load_state_dict(remove_fc(load_url(model_urls['resnet101']))) 251 | return model 252 | 253 | 254 | def resnet152(pretrained=False, **kwargs): 255 | """Constructs a ResNet-152 model. 
256 | Args: 257 | pretrained (bool): If True, returns a model pre-trained on ImageNet 258 | """ 259 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 260 | if pretrained: 261 | model.load_state_dict(remove_fc(load_url(model_urls['resnet152']))) 262 | return model 263 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | from torch import optim 6 | import torch.distributed as dist 7 | 8 | from data import get_test_loader 9 | from data import get_train_loader 10 | from engine import get_trainer 11 | from models.baseline import Baseline 12 | from utils.lr_scheduler import WarmupMultiStepLR 13 | 14 | 15 | def train(cfg): 16 | # training data loader 17 | train_loader = get_train_loader(root=cfg.data_root, 18 | sample_method=cfg.sample_method, 19 | batch_size=cfg.batch_size, 20 | p_size=cfg.p_size, 21 | k_size=cfg.k_size, 22 | random_flip=cfg.random_flip, 23 | random_crop=cfg.random_crop, 24 | random_erase=cfg.random_erase, 25 | color_jitter=cfg.color_jitter, 26 | padding=cfg.padding, 27 | image_size=cfg.image_size, 28 | num_workers=8) 29 | 30 | # evaluation data loader 31 | gallery_loader, query_loader = None, None 32 | if cfg.eval_interval > 0: 33 | gallery_loader, query_loader = get_test_loader(root=cfg.data_root, 34 | batch_size=512, 35 | image_size=cfg.image_size, 36 | num_workers=4) 37 | 38 | # model 39 | model = Baseline(num_classes=cfg.num_id) 40 | model.cuda() 41 | 42 | # optimizer 43 | assert cfg.optimizer in ['adam', 'sgd'] 44 | param_groups = model.get_param_groups(lr=cfg.lr, weight_decay=cfg.wd) 45 | if cfg.optimizer == 'adam': 46 | optimizer = optim.Adam(param_groups, lr=cfg.lr, betas=cfg.betas, weight_decay=cfg.wd) 47 | else: 48 | optimizer = optim.SGD(param_groups, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.wd) 49 | 50 | # convert model for mixed precision training 51 | lr_scheduler = WarmupMultiStepLR(optimizer=optimizer, 52 | milestones=cfg.lr_step, 53 | gamma=0.1, 54 | warmup_epochs=10, 55 | warmup_factor=0.01) 56 | # engine 57 | engine = get_trainer(model=model, 58 | optimizer=optimizer, 59 | lr_scheduler=lr_scheduler, 60 | log_period=cfg.log_period, 61 | eval_interval=cfg.eval_interval, 62 | gallery_loader=gallery_loader, 63 | query_loader=query_loader, 64 | enable_amp=cfg.fp16) 65 | 66 | # training 67 | engine.run(train_loader, max_epochs=cfg.num_epoch) 68 | 69 | 70 | if __name__ == '__main__': 71 | import yaml 72 | import time 73 | import argparse 74 | import random 75 | import numpy as np 76 | from pprint import pformat 77 | from datetime import timedelta 78 | from runx.logx import logx 79 | from configs.default import strategy_cfg 80 | from configs.default import dataset_cfg 81 | 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('cfg', type=str, help='Path to config file') 84 | parser.add_argument('--work-dir', type=str, help='Directory for log and checkpoint') 85 | parser.add_argument('--gpu', type=int, help='GPU device for training') 86 | parser.add_argument('--local-rank', type=int, help='Rank of distributed training') 87 | args = parser.parse_args() 88 | 89 | # Load configuration 90 | customized_cfg = yaml.load(open(args.cfg, "r"), yaml.SafeLoader) 91 | cfg = strategy_cfg 92 | cfg.merge_from_file(args.cfg) 93 | 94 | data_cfg = dataset_cfg.get(cfg.dataset) 95 | for k, v in data_cfg.items(): 96 | cfg[k] = v 97 | 98 | if args.work_dir is not None: 99 | cfg.work_dir = 
args.work_dir 100 | 101 | cfg.freeze() 102 | 103 | # Set random seed 104 | seed = 0 105 | random.seed(seed) 106 | np.random.seed(seed) 107 | torch.manual_seed(seed) 108 | torch.cuda.manual_seed(seed) 109 | 110 | # Setup logger 111 | logx.initialize(logdir=cfg.work_dir, hparams=cfg, tensorboard=True, 112 | global_rank=args.local_rank if dist.is_initialized() else 0) 113 | logx.msg(pformat(cfg)) 114 | shutil.copytree('models', os.path.join(cfg.work_dir, 'models'), dirs_exist_ok=True) 115 | 116 | # Setup CUDNN and GPU device 117 | torch.backends.cudnn.benchmark = True 118 | if args.gpu is not None: 119 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 120 | 121 | start_time = time.monotonic() 122 | train(cfg) 123 | end_time = time.monotonic() 124 | print('Total running time: ', timedelta(seconds=end_time - start_time)) 125 | -------------------------------------------------------------------------------- /utils/calc_acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calc_acc(logits, label, ignore_index=-100, mode="multiclass"): 5 | if mode == "binary": 6 | indices = torch.round(logits).type(label.type()) 7 | elif mode == "multiclass": 8 | indices = torch.max(logits, dim=1)[1] 9 | 10 | if label.size() == logits.size(): 11 | ignore = 1 - torch.round(label.sum(dim=1)) 12 | label = torch.max(label, dim=1)[1] 13 | else: 14 | ignore = torch.eq(label, ignore_index).view(-1) 15 | 16 | correct = torch.eq(indices, label).view(-1) 17 | num_correct = torch.sum(correct) 18 | num_examples = logits.shape[0] - ignore.sum() 19 | 20 | return num_correct.float() / num_examples.float() 21 | -------------------------------------------------------------------------------- /utils/eval_sysu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | from sklearn.preprocessing import normalize 5 | 6 | 7 | def get_gallery_names(perm, cams, ids, trial_id, num_shots=1): 8 | names = [] 9 | for cam in cams: 10 | cam_perm = perm[cam - 1][0].squeeze() 11 | for i in ids: 12 | instance_id = cam_perm[i - 1][trial_id][:num_shots] 13 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()]) 14 | 15 | return names 16 | 17 | 18 | def get_unique(array): 19 | _, idx = np.unique(array, return_index=True) 20 | return array[np.sort(idx)] 21 | 22 | 23 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 24 | gallery_unique_count = get_unique(gallery_ids).shape[0] 25 | match_counter = np.zeros((gallery_unique_count,)) 26 | 27 | result = gallery_ids[sorted_indices] 28 | cam_locations_result = gallery_cam_ids[sorted_indices] 29 | 30 | valid_probe_sample_count = 0 31 | 32 | for probe_index in range(sorted_indices.shape[0]): 33 | # remove gallery samples from the same camera of the probe 34 | result_i = result[probe_index, :] 35 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1 36 | 37 | # remove the -1 entries from the label result 38 | result_i = np.array([i for i in result_i if i != -1]) 39 | 40 | # remove duplicated id in "stable" manner 41 | result_i_unique = get_unique(result_i) 42 | 43 | # match for probe i 44 | match_i = np.equal(result_i_unique, query_ids[probe_index]) 45 | 46 | if np.sum(match_i) != 0: # if there is true matching in gallery 47 | valid_probe_sample_count += 1 48 | match_counter += match_i 49 | 50 | rank = match_counter / valid_probe_sample_count 51 | 
cmc = np.cumsum(rank) 52 | return cmc 53 | 54 | 55 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 56 | result = gallery_ids[sorted_indices] 57 | cam_locations_result = gallery_cam_ids[sorted_indices] 58 | 59 | valid_probe_sample_count = 0 60 | avg_precision_sum = 0 61 | 62 | for probe_index in range(sorted_indices.shape[0]): 63 | # remove gallery samples from the same camera of the probe 64 | result_i = result[probe_index, :] 65 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1 66 | 67 | # remove the -1 entries from the label result 68 | result_i = np.array([i for i in result_i if i != -1]) 69 | 70 | # match for probe i 71 | match_i = result_i == query_ids[probe_index] 72 | true_match_count = np.sum(match_i) 73 | 74 | if true_match_count != 0: # if there is true matching in gallery 75 | valid_probe_sample_count += 1 76 | true_match_rank = np.where(match_i)[0] 77 | 78 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1)) 79 | avg_precision_sum += ap 80 | 81 | mAP = avg_precision_sum / valid_probe_sample_count 82 | return mAP 83 | 84 | 85 | def eval_sysu(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, 86 | perm, mode='all', num_shots=1, num_trials=10): 87 | assert mode in ['indoor', 'all'] 88 | 89 | gallery_cams = [1, 2] if mode == 'indoor' else [1, 2, 4, 5] 90 | 91 | # cam2 and cam3 are in the same location 92 | query_cam_ids[np.equal(query_cam_ids, 3)] = 2 93 | 94 | gallery_indices = np.in1d(gallery_cam_ids, gallery_cams) 95 | 96 | gallery_feats = normalize(gallery_feats[gallery_indices], axis=1) 97 | gallery_cam_ids = gallery_cam_ids[gallery_indices] 98 | gallery_ids = gallery_ids[gallery_indices] 99 | gallery_img_paths = gallery_img_paths[gallery_indices] 100 | gallery_names = np.array(['/'.join(os.path.splitext(path)[0].split('/')[-3:]) for path in gallery_img_paths]) 101 | 102 | gallery_id_set = np.unique(gallery_ids) 103 | 104 | mAP, r1, r5, r10, r20 = 0, 0, 0, 0, 0 105 | for t in range(num_trials): 106 | names = get_gallery_names(perm, gallery_cams, gallery_id_set, t, num_shots) 107 | flag = np.in1d(gallery_names, names) 108 | 109 | g_feat = gallery_feats[flag] 110 | g_ids = gallery_ids[flag] 111 | g_cam_ids = gallery_cam_ids[flag] 112 | 113 | sorted_indices = np.argsort(-np.dot(query_feats, g_feat.T), axis=1) 114 | 115 | mAP += get_mAP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 116 | 117 | cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 118 | r1 += cmc[0] 119 | r5 += cmc[4] 120 | r10 += cmc[9] 121 | r20 += cmc[19] 122 | 123 | mAP /= num_trials 124 | r1 /= num_trials 125 | r5 /= num_trials 126 | r10 /= num_trials 127 | r20 /= num_trials 128 | 129 | return mAP, r1, r5, r10, r20 130 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch.nn.functional as F 6 | 7 | 8 | @numba.jit 9 | def compute_ap(good_index, junk_index, sort_index): 10 | cmc = np.zeros((sort_index.shape[0],)) 11 | n_good = good_index.shape[0] 12 | 13 | old_recall = 0 14 | old_precision = 1.0 15 | ap = 0 16 | intersect_size = 0 17 | j = 0 18 | good_now = 0 19 | n_junk = 0 20 | for i in range(sort_index.shape[0]): 21 | flag = 0 22 | if np.any(good_index == sort_index[i]): 23 | cmc[i - n_junk:] = 1 24 | 
flag = 1 25 | good_now = good_now + 1 26 | 27 | if np.any(junk_index == sort_index[i]): 28 | n_junk = n_junk + 1 29 | continue 30 | 31 | if flag == 1: 32 | intersect_size = intersect_size + 1 33 | 34 | recall = intersect_size / n_good 35 | precision = intersect_size / (j + 1) 36 | ap = ap + (recall - old_recall) * ((old_precision + precision) / 2) 37 | old_recall = recall 38 | old_precision = precision 39 | j = j + 1 40 | 41 | if good_now == n_good: 42 | break 43 | 44 | return ap, cmc 45 | 46 | 47 | def eval_feature(query_features, gallery_features, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids, device): 48 | if isinstance(gallery_features, np.ndarray): 49 | gallery_features = torch.from_numpy(gallery_features) 50 | 51 | if isinstance(query_features, np.ndarray): 52 | query_features = torch.from_numpy(query_features) 53 | 54 | gallery_features = gallery_features.to(device) 55 | query_features = query_features.to(device) 56 | 57 | num_query = query_ids.shape[0] 58 | num_gallery = gallery_ids.shape[0] 59 | 60 | gallery_features = F.normalize(gallery_features, p=2, dim=1) 61 | query_features = F.normalize(query_features, p=2, dim=1) 62 | 63 | dist_array = -torch.mm(query_features, gallery_features.transpose(0, 1)).cpu().numpy() 64 | 65 | ap = np.zeros((num_query,)) # average precision 66 | cmc = np.zeros((num_query, num_gallery)) 67 | 68 | index = np.arange(num_gallery) 69 | for i in tqdm(range(num_query)): 70 | good_flag = np.logical_and(np.not_equal(gallery_cam_ids, query_cam_ids[i]), np.equal(gallery_ids, query_ids[i])) 71 | junk_flag_1 = np.equal(gallery_ids, -1) 72 | junk_flag_2 = np.logical_and(np.equal(gallery_cam_ids, query_cam_ids[i]), 73 | np.equal(gallery_ids, query_ids[i])) 74 | 75 | good_index = index[good_flag] 76 | junk_index = index[np.logical_or(junk_flag_1, junk_flag_2)] 77 | 78 | dist = dist_array[i] 79 | 80 | sort_index = np.argsort(dist) 81 | 82 | ap[i], cmc[i, :] = compute_ap(good_index, junk_index, sort_index) 83 | 84 | mAP = np.mean(ap) 85 | r1 = np.mean(cmc, axis=0)[0] 86 | r5 = np.mean(np.clip(np.sum(cmc[:, :5], axis=1), 0, 1), axis=0) 87 | r10 = np.mean(np.clip(np.sum(cmc[:, :10], axis=1), 0, 1), axis=0) 88 | 89 | return mAP, r1, r5, r10 90 | 91 | 92 | def eval_rank_list(rank_list, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 93 | num_query = len(query_ids) 94 | num_gallery = len(gallery_ids) 95 | 96 | ap = np.zeros((num_query,)) # average precision 97 | cmc = np.zeros((num_query, num_gallery)) 98 | for i in tqdm(range(num_query)): 99 | index = np.arange(num_gallery) 100 | good_flag = np.logical_and(np.not_equal(gallery_cam_ids, query_cam_ids[i]), np.equal(gallery_ids, query_ids[i])) 101 | junk_flag_1 = np.equal(gallery_ids, -1) 102 | junk_flag_2 = np.logical_and(np.equal(gallery_cam_ids, query_cam_ids[i]), 103 | np.equal(gallery_ids, query_ids[i])) 104 | 105 | good_index = index[good_flag] 106 | junk_index = index[np.logical_or(junk_flag_1, junk_flag_2)] 107 | 108 | sort_index = rank_list[i] 109 | 110 | ap[i], cmc[i, :] = compute_ap(good_index, junk_index, sort_index) 111 | 112 | mAP = np.mean(ap) 113 | r1 = np.mean(cmc, axis=0)[0] 114 | r5 = np.mean(np.clip(np.sum(cmc[:, :5], axis=1), 0, 1), axis=0) 115 | r10 = np.mean(np.clip(np.sum(cmc[:, :10], axis=1), 0, 1), axis=0) 116 | 117 | return mAP, r1, r5, r10 118 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from 
torch.optim import lr_scheduler 3 | 4 | 5 | # MultiStepLR with linear warmup 6 | class WarmupMultiStepLR(lr_scheduler._LRScheduler): 7 | def __init__(self, 8 | optimizer, 9 | milestones, 10 | gamma, 11 | warmup_epochs, 12 | warmup_factor, 13 | last_epoch=-1): 14 | if not list(milestones) == sorted(milestones): 15 | raise ValueError("Milestones should be a list of increasing ints. Got {}".format(milestones)) 16 | 17 | self.milestones = milestones 18 | self.gamma = gamma 19 | self.warmup_epochs = warmup_epochs 20 | self.warmup_factor = warmup_factor 21 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 22 | 23 | def get_lr(self): 24 | warmup_factor = 1 25 | if self.last_epoch < self.warmup_epochs: 26 | alpha = self.last_epoch / self.warmup_epochs 27 | warmup_factor = self.warmup_factor * (1 - alpha) + 1 * alpha 28 | 29 | return [ 30 | base_lr 31 | * warmup_factor 32 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 33 | for base_lr in self.base_lrs 34 | ] 35 | -------------------------------------------------------------------------------- /utils/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def weights_init_kaiming(m): 5 | classname = m.__class__.__name__ 6 | if 'Linear' in classname: 7 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 8 | if m.bias is not None: 9 | nn.init.constant_(m.bias, 0.0) 10 | elif 'Conv' in classname: 11 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 12 | if m.bias is not None: 13 | nn.init.constant_(m.bias, 0.0) 14 | elif 'BatchNorm' in classname: 15 | if m.affine: 16 | nn.init.normal_(m.weight, 1.0, 0.01) 17 | nn.init.constant_(m.bias, 0.0) 18 | 19 | 20 | def weights_init_classifier(m): 21 | classname = m.__class__.__name__ 22 | if 'Linear' in classname: 23 | nn.init.normal_(m.weight, std=0.001) 24 | if m.bias is not None: 25 | nn.init.constant_(m.bias, 0.0) 26 | -------------------------------------------------------------------------------- /utils/tsne.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import scipy.io as sio 5 | import matplotlib as mpl 6 | 7 | mpl.use('AGG') 8 | import matplotlib.pyplot as plt 9 | from sklearn.manifold import TSNE 10 | 11 | if __name__ == '__main__': 12 | test_ids = [ 13 | 6, 10, 17, 21, 24, 25, 27, 28, 31, 34, 36, 37, 40, 41, 42, 43, 44, 45, 49, 50, 51, 54, 63, 69, 75, 80, 81, 82, 14 | 83, 84, 85, 86, 87, 88, 89, 90, 93, 102, 104, 105, 106, 108, 112, 116, 117, 122, 125, 129, 130, 134, 138, 139, 15 | 150, 152, 162, 166, 167, 170, 172, 176, 185, 190, 192, 202, 204, 207, 210, 215, 223, 229, 232, 237, 252, 253, 16 | 257, 259, 263, 266, 269, 272, 273, 274, 275, 282, 285, 291, 300, 301, 302, 303, 307, 312, 315, 318, 331, 333 17 | ] 18 | random.seed(0) 19 | tsne = TSNE(n_components=2, init='pca') 20 | selected_ids = random.sample(test_ids, 20) 21 | plt.figure(figsize=(5, 5)) 22 | 23 | # features without dual path 24 | q_mat_path = 'features/sysu/query-sysu-test-nodual-nore-adam-16x8-grey_model_150.mat' 25 | g_mat_path = 'features/sysu/gallery-sysu-test-nodual-nore-adam-16x8-grey_model_150.mat' 26 | 27 | mat = sio.loadmat(q_mat_path) 28 | q_feats = mat["feat"] 29 | q_ids = mat["ids"].squeeze() 30 | flag = np.in1d(q_ids, selected_ids) 31 | q_feats = q_feats[flag] 32 | 33 | mat = sio.loadmat(g_mat_path) 34 | g_feats = mat["feat"] 35 | g_ids = mat["ids"].squeeze() 36 | flag = np.in1d(g_ids, selected_ids) 37 | g_feats = g_feats[flag]
38 | 39 | embed = tsne.fit_transform(np.concatenate([q_feats, g_feats], axis=0)) 40 | c = ['r'] * q_feats.shape[0] + ['b'] * g_feats.shape[0] 41 | # plt.subplot(1, 2, 1) 42 | plt.scatter(embed[:, 0], embed[:, 1], c=c) 43 | 44 | # # features with dual path 45 | # q_mat_path = 'features/sysu/query-sysu-test-dual-nore-separatelayer12-0.05_model_30.mat' 46 | # g_mat_path = 'features/sysu/gallery-sysu-test-dual-nore-separatelayer12-0.05_model_30.mat' 47 | # 48 | # mat = sio.loadmat(q_mat_path) 49 | # q_feats = mat["feat"] 50 | # q_ids = mat["ids"].squeeze() 51 | # flag = np.in1d(q_ids, selected_ids) 52 | # q_feats = q_feats[flag] 53 | # 54 | # mat = sio.loadmat(g_mat_path) 55 | # g_feats = mat["feat"] 56 | # g_ids = mat["ids"].squeeze() 57 | # flag = np.in1d(g_ids, selected_ids) 58 | # g_feats = g_feats[flag] 59 | # 60 | # embed = tsne.fit_transform(np.concatenate([q_feats, g_feats], axis=0)) 61 | # c = ['r'] * q_feats.shape[0] + ['b'] * g_feats.shape[0] 62 | # plt.subplot(1, 2, 2) 63 | # plt.scatter(embed[:, 0], embed[:, 1], c=c) 64 | 65 | plt.tight_layout() 66 | plt.savefig('tsne-adv-layer2-separate-l2.jpg') 67 | --------------------------------------------------------------------------------
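
The files above are raw sources, so a few hedged usage sketches follow; none of the snippets below exist in the repository, and any names, paths, or values they introduce are illustrative only. First, `utils/calc_acc.py` computes plain top-1 accuracy over a batch of logits; a minimal check:

```python
import torch
from utils.calc_acc import calc_acc

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1]])
labels = torch.tensor([0, 2])

# argmax predictions are [0, 1]; only the first matches, so accuracy is 0.5
print(calc_acc(logits, labels))  # tensor(0.5000)
```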
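`utils/eval_sysu.py` implements the official SYSU-MM01 protocol: for each of `num_trials` trials it draws a one-shot (or multi-shot) gallery from the `rand_perm_cam.mat` permutation, ranks the gallery by cosine similarity, and averages mAP and CMC over trials. A calling-convention sketch only; the `.mat` variable name and the pre-extracted feature/label arrays are assumptions, not code from this repository:

```python
import scipy.io as sio
from utils.eval_sysu import eval_sysu

# Fixed per-camera gallery permutation shipped with the SYSU-MM01 evaluation kit;
# the key 'rand_perm_cam' inside the .mat file is an assumption.
perm = sio.loadmat('/path/to/SYSU-MM01/exp/rand_perm_cam.mat')['rand_perm_cam']

# query_* and gallery_* arrays are assumed to come from a prior feature-extraction
# step (e.g. extract.py): float feature matrices plus integer id / camera-id vectors
# and the gallery image paths used to match entries against the permutation file.
mAP, r1, r5, r10, r20 = eval_sysu(query_feats, query_ids, query_cam_ids,
                                  gallery_feats, gallery_ids, gallery_cam_ids,
                                  gallery_img_paths, perm,
                                  mode='all', num_shots=1, num_trials=10)
print('mAP {:.2%} | r1 {:.2%} | r20 {:.2%}'.format(mAP, r1, r20))
```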
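As a concrete check of the average-precision formula shared by `get_mAP` and `compute_ap`: if a probe's true matches sit at ranks 1 and 3 of the filtered ranking, the precisions at those ranks are 1/1 and 2/3, so AP ≈ 0.833.

```python
import numpy as np

# True matches found at (zero-based) positions 0 and 2 of the ranked gallery.
true_match_rank = np.array([0, 2])
true_match_count = true_match_rank.shape[0]

# Same expression as in get_mAP: precision at each true match, averaged.
ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1))
print(ap)  # 0.8333... = (1/1 + 2/3) / 2
```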
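`utils/lr_scheduler.py` scales each base learning rate linearly from `warmup_factor * base_lr` up to `base_lr` over the first `warmup_epochs` epochs, then decays it by `gamma` at every milestone. A usage sketch; only the milestones mirror `lr_step` in configs/baseline.yml, while the warmup values and the model/optimizer are assumed for illustration:

```python
import torch
from utils.lr_scheduler import WarmupMultiStepLR

model = torch.nn.Linear(2048, 395)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Decay by 10x at epochs 40 and 70; warmup_epochs / warmup_factor are assumed values.
scheduler = WarmupMultiStepLR(optimizer, milestones=[40, 70], gamma=0.1,
                              warmup_epochs=10, warmup_factor=0.01)

for epoch in range(120):
    # ... run one training epoch, calling optimizer.step() per batch ...
    scheduler.step()
    print('epoch {:3d}  lr {:.5f}'.format(epoch, optimizer.param_groups[0]['lr']))
```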
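Finally, the initialisers in `utils/net_utils.py` are intended to be applied module-wise via `nn.Module.apply`; the BN-neck / classifier split below is a common re-ID pattern and only illustrative of how the two functions divide the work (the layer names are not taken from models/baseline.py):

```python
import torch.nn as nn
from utils.net_utils import weights_init_kaiming, weights_init_classifier

# Hypothetical embedding head: a batch-norm neck followed by an identity classifier.
bn_neck = nn.BatchNorm1d(2048)
classifier = nn.Linear(2048, 395, bias=False)

bn_neck.apply(weights_init_kaiming)        # Kaiming init for conv / linear / BN modules
classifier.apply(weights_init_classifier)  # small-std normal init for the final classifier
```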