├── LICENSE ├── README.md ├── configs ├── default │ ├── __init__.py │ ├── dataset.py │ └── strategy.py ├── duke2market.yml ├── market2duke.yml └── single_domain.yml ├── data ├── __init__.py ├── dataset.py └── sampler.py ├── engine ├── __init__.py ├── engine.py └── metric.py ├── eval.py ├── extract.py ├── layers ├── __init__.py ├── loss │ ├── __init__.py │ ├── am_softmax.py │ ├── center_loss.py │ ├── nca_loss.py │ ├── nn_loss.py │ └── triplet_loss.py └── module │ ├── __init__.py │ ├── block_grad.py │ ├── exemplar_linear.py │ └── reverse_grad.py ├── models ├── model.py └── resnet.py ├── train.py ├── train.sh └── utils ├── calc_acc.py ├── curve.py ├── dist_utils.py ├── eval_cmc.py ├── eval_model.py ├── fig.py ├── fp16_utils.py ├── misc.py ├── mod_utils.py ├── rank_vis.py └── tsne.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Chuanchen Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # Generalizing Person Re-identification 2 | Implementation of the ECCV 2020 paper [*Generalizing Person Re-Identification by Camera-Aware Invariance Learning and Cross-Domain Mixup*](https://www.ecva.net/papers/eccv_2020/papers_ECCV/html/2329_ECCV_2020_paper.php). 3 | 4 | ## Dependencies 5 | * python 3.6 6 | * pytorch 1.3 7 | * [apex](https://github.com/NVIDIA/apex) 8 | * [ignite](https://github.com/pytorch/ignite) 0.2.0 9 | 10 | ## Preparation 11 | Download and extract Market-1501, DukeMTMC-reID, CUHK03 and MSMT17. 12 | Replace the root paths of the corresponding datasets in the config file `configs/default/dataset.py`. 13 | 14 | 15 | ## Train 16 | ```shell script 17 | bash train.sh GPU_ID_0,GPU_ID_1 PATH_TO_YOUR_YAML_FILE 18 | ``` 19 | Our code is validated under a 2-GPU setting. `GPU_ID_0` and `GPU_ID_1` are the indices of the selected GPUs. `PATH_TO_YOUR_YAML_FILE` is the path to your YAML config file. We also provide template config files: `configs/duke2market.yml`, `configs/market2duke.yml`, and `configs/single_domain.yml`. You can optionally adjust the hyper-parameters in the YAML config file. All of our experiments are conducted with mixed-precision training to reduce the GPU memory footprint, *i.e.*, we set the flag `fp16=true`. 
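For example, to run the provided Duke→Market cross-domain template on GPUs 0 and 1, you would call `bash train.sh 0,1 configs/duke2market.yml`.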
20 | 21 | During the training, the checkpoint files and logs will be saved in `./checkpoints` and `./logs` directories, respectively. 22 | 23 | ## Test 24 | In our code, the model is evaluated on the target domain at intervals automatically. 25 | You can also evaluate the trained model manually by running: 26 | ```shell script 27 | python3 eval.py GPU_ID PATH_TO_CHECKPOINT_FILE [--dataset DATASET] 28 | ``` 29 | 30 | `PATH_TO_CHECKPOINT_FILE` is the path to the checkpoint file of the trained model. `DATASET` is the name of the target dataset. Its value can be `{market,duke,cuhk,msmt}`. As an intermediate product, the feature matrix of target images is stored in the `./features` directory. 31 | 32 | ## Citation 33 | 34 | @inproceedings{luo2020generalizing, 35 | title={Generalizing Person Re-Identification by Camera-Aware Invariance Learning and Cross-Domain Mixup}, 36 | author={Luo, Chuanchen and Song, Chunfeng and Zhang, Zhaoxiang}, 37 | booktitle={European Conference on Computer Vision}, 38 | year={2020} 39 | } 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /configs/default/__init__.py: -------------------------------------------------------------------------------- 1 | from configs.default.dataset import dataset_cfg 2 | from configs.default.strategy import strategy_cfg 3 | 4 | __all__ = ["dataset_cfg", "strategy_cfg"] 5 | -------------------------------------------------------------------------------- /configs/default/dataset.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | dataset_cfg = CfgNode() 4 | 5 | # config for dataset 6 | dataset_cfg.market = CfgNode() 7 | dataset_cfg.market.num_id = 751 8 | dataset_cfg.market.num_cam = 6 9 | dataset_cfg.market.root = "/home/chuanchen_luo/data/Market-1501-v15.09.15" 10 | dataset_cfg.market.train = "bounding_box_train" 11 | dataset_cfg.market.query = "query" 12 | dataset_cfg.market.gallery = "bounding_box_test" 13 | 14 | dataset_cfg.duke = CfgNode() 15 | dataset_cfg.duke.num_id = 702 16 | dataset_cfg.duke.num_cam = 8 17 | dataset_cfg.duke.root = "/home/chuanchen_luo/data/DukeMTMC-reID" 18 | dataset_cfg.duke.train = "bounding_box_train" 19 | dataset_cfg.duke.query = "query" 20 | dataset_cfg.duke.gallery = "bounding_box_test" 21 | 22 | dataset_cfg.cuhk = CfgNode() 23 | dataset_cfg.cuhk.num_id = 767 24 | dataset_cfg.cuhk.num_cam = 2 25 | dataset_cfg.cuhk.root = "/home/chuanchen_luo/data/cuhk03-np/labeled" 26 | dataset_cfg.cuhk.train = "bounding_box_train" 27 | dataset_cfg.cuhk.query = "query" 28 | dataset_cfg.cuhk.gallery = "bounding_box_test" 29 | 30 | dataset_cfg.msmt = CfgNode() 31 | dataset_cfg.msmt.num_id = 1041 32 | dataset_cfg.msmt.num_cam = 15 33 | dataset_cfg.msmt.root = "/home/chuanchen_luo/data/MSMT17_V1" 34 | dataset_cfg.msmt.train = "train" 35 | dataset_cfg.msmt.query = "list_query.txt" 36 | dataset_cfg.msmt.gallery = "list_gallery.txt" 37 | -------------------------------------------------------------------------------- /configs/default/strategy.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | strategy_cfg = CfgNode() 4 | 5 | strategy_cfg.prefix = "baseline" 6 | 7 | # setting for loader 8 | strategy_cfg.batch_size = 128 9 | 10 | # settings for optimizer 11 | strategy_cfg.optimizer = "sgd" 12 | strategy_cfg.new_params_lr = 0.1 13 | strategy_cfg.ft_lr = 0.01 14 | strategy_cfg.wd = 5e-4 15 | 
strategy_cfg.lr_step = [40] 16 | 17 | strategy_cfg.fp16 = False 18 | 19 | strategy_cfg.num_epoch = 60 20 | 21 | # settings for dataset 22 | strategy_cfg.joint_training = False 23 | strategy_cfg.source_dataset = "market" 24 | strategy_cfg.target_dataset = "duke" 25 | strategy_cfg.image_size = (256, 128) 26 | 27 | # settings for augmentation 28 | strategy_cfg.random_flip = True 29 | strategy_cfg.random_crop = True 30 | strategy_cfg.random_erase = True 31 | strategy_cfg.color_jitter = False 32 | strategy_cfg.padding = 10 33 | 34 | # settings for base architecture 35 | strategy_cfg.drop_last_stride = False 36 | 37 | 38 | # settings for neighborhood consistency 39 | strategy_cfg.neighbor_mode = 1 40 | strategy_cfg.neighbor_eps = 0.8 41 | strategy_cfg.scale = 20 42 | 43 | # settings for mix-up 44 | strategy_cfg.mix = False 45 | strategy_cfg.alpha = 0.5 46 | 47 | # logging 48 | strategy_cfg.eval_interval = -1 49 | strategy_cfg.log_period = 50 50 | -------------------------------------------------------------------------------- /configs/duke2market.yml: -------------------------------------------------------------------------------- 1 | prefix: duke2market-all-eps0.8-mix0.6 2 | 3 | fp16: true 4 | 5 | 6 | # cross-domain mix-up 7 | mix: true 8 | alpha: 0.6 9 | 10 | 11 | # neighborhood constraint 12 | neighbor_mode: 1 # 0: camera-agnostic neighborhood; 1: camera-aware neighborhood 13 | neighbor_eps: 0.8 14 | 15 | scale: 10 16 | 17 | 18 | #dataset 19 | batch_size: 256 20 | joint_training: true 21 | 22 | source_dataset: duke 23 | target_dataset: market 24 | 25 | 26 | # architecture 27 | drop_last_stride: true 28 | 29 | # optimizer 30 | new_params_lr: 0.05 31 | ft_lr: 0.01 32 | optimizer: sgd 33 | num_epoch: 70 34 | lr_step: [60] 35 | 36 | # augmentation 37 | random_flip: true 38 | random_crop: true 39 | random_erase: true 40 | color_jitter: false 41 | padding: 10 42 | 43 | 44 | # log 45 | eval_interval: 10 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /configs/market2duke.yml: -------------------------------------------------------------------------------- 1 | prefix: market2duke-all-eps0.8-mix0.6 2 | 3 | fp16: true 4 | 5 | 6 | # cross-domain mix-up 7 | mix: true 8 | alpha: 0.6 9 | 10 | 11 | # neighborhood constraint 12 | neighbor_mode: 1 # 0: camera-agnostic neighborhood; 1: camera-aware neighborhood 13 | neighbor_eps: 0.8 14 | 15 | scale: 10 16 | 17 | 18 | #dataset 19 | batch_size: 256 20 | joint_training: true 21 | 22 | source_dataset: market 23 | target_dataset: duke 24 | 25 | 26 | # architecture 27 | drop_last_stride: true 28 | 29 | # optimizer 30 | new_params_lr: 0.05 31 | ft_lr: 0.01 32 | optimizer: sgd 33 | num_epoch: 70 34 | lr_step: [60] 35 | 36 | # augmentation 37 | random_flip: true 38 | random_crop: true 39 | random_erase: true 40 | color_jitter: false 41 | padding: 10 42 | 43 | 44 | # log 45 | eval_interval: 10 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /configs/single_domain.yml: -------------------------------------------------------------------------------- 1 | prefix: baseline-single-domain 2 | 3 | fp16: true 4 | 5 | #dataset 6 | batch_size: 128 7 | joint_training: false 8 | 9 | source_dataset: duke 10 | target_dataset: duke 11 | 12 | # architecture 13 | drop_last_stride: true 14 | 15 | # optimizer 16 | new_params_lr: 0.05 17 | ft_lr: 0.01 18 | optimizer: 
sgd 19 | num_epoch: 70 20 | lr_step: [60] 21 | 22 | 23 | 24 | # augmentation 25 | random_flip: true 26 | random_crop: true 27 | random_erase: true 28 | color_jitter: false 29 | padding: 10 30 | 31 | 32 | # log 33 | eval_interval: 10 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torchvision.transforms as T 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import DistributedSampler 8 | from torch.utils.data import RandomSampler 9 | 10 | from data.dataset import CrossDataset 11 | from data.dataset import ImageFolder 12 | from data.dataset import ImageListFile 13 | from data.sampler import CrossDatasetDistributedSampler 14 | from data.sampler import CrossDatasetRandomSampler 15 | from data.sampler import RandomIdentitySampler 16 | 17 | 18 | def collate_fn(batch): # img, label, cam_id, img_path, img_id 19 | samples = list(zip(*batch)) 20 | 21 | data = [torch.stack(x, 0) for i, x in enumerate(samples) if i != 3] 22 | data.insert(3, samples[3]) 23 | return data 24 | 25 | 26 | def get_train_loader(root, batch_size, image_size, random_flip=False, random_crop=False, random_erase=False, 27 | color_jitter=False, padding=0, num_workers=4): 28 | # data pre-processing 29 | t = [T.Resize(image_size)] 30 | 31 | if random_flip: 32 | t.append(T.RandomHorizontalFlip()) 33 | 34 | if color_jitter: 35 | t.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)) 36 | 37 | if random_crop: 38 | t.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)]) 39 | 40 | t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 41 | 42 | if random_erase: 43 | t.append(T.RandomErasing(scale=(0.02, 0.25))) 44 | 45 | transform = T.Compose(t) 46 | 47 | # dataset 48 | train_dataset = ImageFolder(root, transform=transform, recursive=True, label_organize=True) 49 | 50 | if dist.is_initialized(): 51 | rand_sampler = DistributedSampler(train_dataset) 52 | else: 53 | rand_sampler = RandomSampler(train_dataset) 54 | 55 | # loader 56 | train_loader = DataLoader(train_dataset, batch_size, drop_last=True, pin_memory=True, 57 | collate_fn=collate_fn, num_workers=num_workers, sampler=rand_sampler) 58 | 59 | return train_loader 60 | 61 | 62 | def get_cross_domain_train_loader(source_root, target_root, batch_size, image_size, random_flip=False, 63 | random_crop=False, random_erase=False, color_jitter=False, padding=0, num_workers=4): 64 | if isinstance(random_crop, bool): 65 | random_crop = (random_crop, random_crop) 66 | if isinstance(random_flip, bool): 67 | random_flip = (random_flip, random_flip) 68 | if isinstance(random_erase, bool): 69 | random_erase = (random_erase, random_erase) 70 | if isinstance(color_jitter, bool): 71 | color_jitter = (color_jitter, color_jitter) 72 | 73 | # data pre-processing 74 | source_transform = [T.Resize(image_size)] 75 | target_transform = [T.Resize(image_size)] 76 | 77 | if random_flip[0]: 78 | source_transform.append(T.RandomHorizontalFlip()) 79 | if random_flip[1]: 80 | target_transform.append(T.RandomHorizontalFlip()) 81 | 82 | if color_jitter[0]: 83 | source_transform.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)) 84 | if color_jitter[1]: 85 | target_transform.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)) 86 | 87 | if random_crop[0]: 
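# random-crop augmentation below: pad the image with mid-grey pixels (fill=127) and randomly crop back to image_size, so the person is slightly shifted within the frame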
88 | source_transform.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)]) 89 | if random_crop[1]: 90 | target_transform.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)]) 91 | 92 | source_transform.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 93 | target_transform.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 94 | 95 | if random_erase[0]: 96 | source_transform.append(T.RandomErasing(scale=(0.02, 0.25))) 97 | if random_erase[1]: 98 | target_transform.append(T.RandomErasing(scale=(0.02, 0.25))) 99 | 100 | source_transform = T.Compose(source_transform) 101 | target_transform = T.Compose(target_transform) 102 | 103 | # dataset 104 | source_dataset = ImageFolder(source_root, transform=source_transform, recursive=True, label_organize=True) 105 | target_dataset = ImageFolder(target_root, transform=target_transform, recursive=True, label_organize=True) 106 | 107 | concat_dataset = CrossDataset(source_dataset, target_dataset) 108 | 109 | # sampler 110 | if dist.is_initialized(): 111 | cross_sampler = CrossDatasetDistributedSampler(source_dataset, target_dataset, batch_size) 112 | else: 113 | cross_sampler = CrossDatasetRandomSampler(source_dataset, target_dataset, batch_size) 114 | 115 | # data loader 116 | train_loader = DataLoader(concat_dataset, batch_size, sampler=cross_sampler, drop_last=True, pin_memory=True, 117 | collate_fn=collate_fn, num_workers=num_workers) 118 | 119 | return train_loader 120 | 121 | 122 | def get_test_loader(root, batch_size, image_size, num_workers=4): 123 | # transform 124 | transform = T.Compose([ 125 | T.Resize(image_size), 126 | T.ToTensor(), 127 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 128 | ]) 129 | 130 | # dataset 131 | if root.lower().find("msmt") != -1: 132 | prefix = os.path.join(os.path.dirname(root), "test") 133 | test_dataset = ImageListFile(root, prefix=prefix, transform=transform) 134 | else: 135 | test_dataset = ImageFolder(root, transform=transform) 136 | 137 | # dataloader 138 | test_loader = DataLoader(dataset=test_dataset, 139 | batch_size=batch_size, 140 | shuffle=False, 141 | pin_memory=True, 142 | drop_last=False, 143 | collate_fn=collate_fn, 144 | num_workers=num_workers) 145 | 146 | return test_loader 147 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | 9 | ''' 10 | Specific dataset classes for person re-identification dataset. 
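Image file names are assumed to follow the standard re-ID naming conventions: the identity is the first underscore-separated field and the camera id is parsed from the second field (e.g. 0001_c1s1_000151_00.jpg) for Market-1501/DukeMTMC-reID style names, or from the third field for MSMT17. Images labeled with identity -1 (junk/distractor images) are discarded.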
11 | ''' 12 | 13 | 14 | class ImageFolder(Dataset): 15 | def __init__(self, root, transform=None, recursive=False, label_organize=False): 16 | if recursive: 17 | image_list = glob(os.path.join(root, "**", "*.jpg"), recursive=recursive) + \ 18 | glob(os.path.join(root, "**", "*.png"), recursive=recursive) 19 | else: 20 | image_list = glob(os.path.join(root, "*.jpg")) + glob(os.path.join(root, "*.png")) 21 | 22 | self.image_list = list(filter(lambda x: int(os.path.basename(x).split("_")[0]) != -1, image_list)) 23 | self.image_list.sort() 24 | 25 | ids = [] 26 | cam_ids = [] 27 | for img_path in self.image_list: 28 | splits = os.path.basename(img_path).split("_") 29 | ids.append(int(splits[0])) 30 | 31 | if root.lower().find("msmt") != -1: 32 | cam_id = int(splits[2]) 33 | else: 34 | cam_id = int(splits[1][1]) 35 | 36 | cam_ids.append(cam_id - 1) 37 | 38 | if label_organize: 39 | # organize identity label 40 | unique_ids = set(ids) 41 | label_map = dict(zip(unique_ids, range(len(unique_ids)))) 42 | 43 | ids = map(lambda x: label_map[x], ids) 44 | ids = list(ids) 45 | 46 | self.ids = ids 47 | self.cam_ids = cam_ids 48 | self.num_id = len(set(ids)) 49 | 50 | self.transform = transform 51 | 52 | def __len__(self): 53 | return len(self.image_list) 54 | 55 | def __getitem__(self, item): 56 | img_path = self.image_list[item] 57 | 58 | img = Image.open(img_path) 59 | 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | 63 | label = torch.tensor(self.ids[item], dtype=torch.long) 64 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 65 | item = torch.tensor(item, dtype=torch.long) 66 | 67 | return img, label, cam, img_path, item 68 | 69 | 70 | class ImageListFile(Dataset): 71 | def __init__(self, path, prefix=None, transform=None, label_organize=False): 72 | if not os.path.isfile(path): 73 | raise ValueError("The file %s does not exist." 
% path) 74 | 75 | image_list = list(np.loadtxt(path, delimiter=" ", dtype=np.str)[:, 0]) 76 | 77 | if prefix is not None: 78 | image_list = map(lambda x: os.path.join(prefix, x), image_list) 79 | 80 | self.image_list = list(filter(lambda x: int(os.path.basename(x).split("_")[0]) != -1, image_list)) 81 | self.image_list.sort() 82 | 83 | ids = [] 84 | cam_ids = [] 85 | for img_path in self.image_list: 86 | splits = os.path.basename(img_path).split("_") 87 | ids.append(int(splits[0])) 88 | 89 | if path.lower().find("msmt") != -1: 90 | cam_id = int(splits[2]) 91 | else: 92 | cam_id = int(splits[1][1]) 93 | 94 | cam_ids.append(cam_id - 1) 95 | 96 | if label_organize: 97 | # organize identity label 98 | unique_ids = set(ids) 99 | label_map = dict(zip(unique_ids, range(len(unique_ids)))) 100 | 101 | ids = map(lambda x: label_map[x], ids) 102 | ids = list(ids) 103 | 104 | self.cam_ids = cam_ids 105 | self.ids = ids 106 | self.num_id = len(set(ids)) 107 | 108 | self.transform = transform 109 | 110 | def __len__(self): 111 | return len(self.image_list) 112 | 113 | def __getitem__(self, item): 114 | img_path = self.image_list[item] 115 | img = Image.open(img_path) 116 | 117 | if self.transform is not None: 118 | img = self.transform(img) 119 | 120 | label = torch.tensor(self.ids[item], dtype=torch.long) 121 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 122 | item = torch.tensor(item, dtype=torch.long) 123 | 124 | return img, label, cam, img_path, item 125 | 126 | 127 | class CrossDataset(Dataset): 128 | def __init__(self, source_dataset, target_dataset): 129 | super(CrossDataset, self).__init__() 130 | 131 | self.source_dataset = source_dataset 132 | self.target_dataset = target_dataset 133 | 134 | self.source_size = len(self.source_dataset) 135 | self.target_size = len(target_dataset) 136 | 137 | self.num_source_cams = len(set(source_dataset.cam_ids)) 138 | self.num_target_cams = len(set(target_dataset.cam_ids)) 139 | 140 | def __len__(self): 141 | return self.source_size + self.target_size 142 | 143 | def __getitem__(self, idx): 144 | # from source dataset 145 | if idx < self.source_size: 146 | sample = self.source_dataset[idx] 147 | sample[2].add_(self.num_target_cams) 148 | 149 | return sample 150 | # from target dataset 151 | else: 152 | idx = idx - self.source_size 153 | sample = list(self.target_dataset[idx]) 154 | sample[1].fill_(-1) # set target label to -1 155 | 156 | return sample 157 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict 3 | 4 | import torch 5 | import numpy as np 6 | import torch.distributed as dist 7 | from torch.utils.data import Sampler 8 | 9 | 10 | class RandomIdentitySampler(Sampler): 11 | def __init__(self, dataset, batch_size, num_instance): 12 | assert batch_size % num_instance == 0 13 | 14 | self.dataset = dataset 15 | self.batch_size = batch_size 16 | self.p_size = batch_size // num_instance 17 | self.k_size = num_instance 18 | 19 | self.id2idx = defaultdict(list) 20 | for i, identity in enumerate(dataset.ids): 21 | self.id2idx[identity].append(i) 22 | 23 | def __len__(self): 24 | return self.dataset.num_id * self.k_size 25 | 26 | def __iter__(self): 27 | sample_list = [] 28 | 29 | id_perm = np.random.permutation(self.dataset.num_id) 30 | for start in range(0, self.dataset.num_id, self.p_size): 31 | selected_ids = id_perm[start:start + self.p_size] 32 | 33 | sample = 
[] 34 | for identity in selected_ids: 35 | if len(self.id2idx[identity]) < self.k_size: 36 | s = np.random.choice(self.id2idx[identity], size=self.k_size, replace=True) 37 | else: 38 | s = np.random.choice(self.id2idx[identity], size=self.k_size, replace=False) 39 | 40 | sample.extend(s) 41 | 42 | sample_list.extend(sample) 43 | 44 | return iter(sample_list) 45 | 46 | 47 | class CrossDatasetRandomSampler(Sampler): 48 | def __init__(self, source_dataset, target_dataset, batch_size): 49 | self.source_dataset = source_dataset 50 | self.target_dataset = target_dataset 51 | 52 | self.batch_size = batch_size 53 | 54 | self.source_size = len(source_dataset) 55 | self.target_size = len(target_dataset) 56 | 57 | def __len__(self): 58 | if self.source_size <= self.target_size: 59 | return self.target_size * 2 60 | else: 61 | return self.source_size * 2 62 | 63 | def __iter__(self): 64 | perm = [] 65 | half_bs = self.batch_size // 2 66 | 67 | if self.source_size <= self.target_size: 68 | multiplier = self.target_size // self.source_size 69 | reminder = self.target_size % self.source_size 70 | 71 | source_perm = [np.random.permutation(self.source_size) for _ in range(multiplier)] 72 | source_perm = np.concatenate(source_perm, axis=0).tolist() 73 | source_perm += np.random.randint(low=0, high=self.source_size, size=(reminder,), dtype=np.int64).tolist() 74 | 75 | target_perm = np.random.permutation(self.target_size) + self.source_size 76 | target_perm = target_perm.tolist() 77 | 78 | for i in range(math.ceil(self.target_size / half_bs)): 79 | perm.extend(source_perm[i * half_bs:(i + 1) * half_bs]) 80 | perm.extend(target_perm[i * half_bs:(i + 1) * half_bs]) 81 | 82 | else: 83 | multiplier = self.source_size // self.target_size 84 | reminder = self.source_size % self.target_size 85 | 86 | target_perm = [np.random.permutation(self.target_size) for _ in range(multiplier)] 87 | target_perm = np.concatenate(target_perm, axis=0) 88 | pad_perm = np.random.randint(low=0, high=self.target_size, size=(reminder,), dtype=np.int64) 89 | target_perm = np.concatenate([target_perm, pad_perm], axis=0) 90 | target_perm += self.source_size 91 | target_perm = target_perm.tolist() 92 | 93 | source_perm = np.random.permutation(self.source_size) 94 | source_perm = source_perm.tolist() 95 | 96 | for i in range(math.ceil(self.source_size / half_bs)): 97 | perm.extend(source_perm[i * half_bs:(i + 1) * half_bs]) 98 | perm.extend(target_perm[i * half_bs:(i + 1) * half_bs]) 99 | 100 | return iter(perm) 101 | 102 | 103 | class CrossDatasetDistributedSampler(Sampler): 104 | def __init__(self, source_dataset, target_dataset, batch_size, num_replicas=None, rank=None): 105 | if num_replicas is None: 106 | if not dist.is_available(): 107 | raise RuntimeError("Requires distributed package to be available") 108 | num_replicas = dist.get_world_size() 109 | batch_size *= num_replicas 110 | 111 | if rank is None: 112 | if not dist.is_available(): 113 | raise RuntimeError("Requires distributed package to be available") 114 | rank = dist.get_rank() 115 | 116 | self.source_dataset = source_dataset 117 | self.target_dataset = target_dataset 118 | self.batch_size = batch_size 119 | 120 | self.num_replicas = num_replicas 121 | self.rank = rank 122 | self.epoch = None 123 | 124 | self.source_size = len(source_dataset) 125 | self.target_size = len(target_dataset) 126 | 127 | def __len__(self): 128 | if self.source_size <= self.target_size: 129 | return int(np.round(self.target_size * 2 / self.num_replicas)) 130 | else: 131 | return 
int(np.round(self.source_size * 2 / self.num_replicas)) 132 | 133 | def __iter__(self): 134 | g = torch.Generator() 135 | g.manual_seed(self.epoch) 136 | 137 | perm = [] 138 | half_bs = self.batch_size // 2 139 | 140 | if self.source_size <= self.target_size: 141 | multiplier = self.target_size // self.source_size 142 | reminder = self.target_size % self.source_size 143 | pad = self.target_size % self.num_replicas 144 | 145 | source_perm = [torch.randperm(self.source_size, generator=g) for _ in range(multiplier)] 146 | source_perm = torch.cat(source_perm, dim=0).tolist() 147 | source_perm += torch.randint(low=0, high=self.source_size, size=(reminder + pad,), generator=g).tolist() 148 | 149 | target_perm = torch.randperm(self.target_size, generator=g) 150 | pad_perm = torch.randint(low=0, high=self.target_size, size=(pad,), generator=g) 151 | target_perm = torch.cat([target_perm, pad_perm]) 152 | target_perm += self.source_size 153 | target_perm = target_perm.tolist() 154 | 155 | for i in range(math.ceil(self.target_size / half_bs)): 156 | perm.extend(source_perm[i * half_bs:(i + 1) * half_bs]) 157 | perm.extend(target_perm[i * half_bs:(i + 1) * half_bs]) 158 | 159 | else: 160 | multiplier = self.source_size // self.target_size 161 | reminder = self.source_size % self.target_size 162 | pad = self.source_size % self.num_replicas 163 | 164 | target_perm = [torch.randperm(self.target_size, generator=g) for _ in range(multiplier)] 165 | target_perm = torch.cat(target_perm, dim=0) 166 | pad_perm = torch.randint(low=0, high=self.target_size, size=(reminder + pad,), generator=g) 167 | target_perm = torch.cat([target_perm, pad_perm], dim=0) 168 | target_perm += self.source_size 169 | target_perm = target_perm.tolist() 170 | 171 | source_perm = torch.randperm(self.source_size, generator=g) 172 | source_perm = source_perm.tolist() 173 | source_perm += torch.randint(low=0, high=self.source_size, size=(pad,), generator=g).tolist() 174 | 175 | for i in range(math.ceil(self.source_size / half_bs)): 176 | perm.extend(source_perm[i * half_bs:(i + 1) * half_bs]) 177 | perm.extend(target_perm[i * half_bs:(i + 1) * half_bs]) 178 | 179 | # subsample 180 | perm = perm[self.rank::self.num_replicas] 181 | 182 | return iter(perm) 183 | 184 | def set_epoch(self, epoch): 185 | self.epoch = epoch 186 | -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | from ignite.engine import Events 7 | from ignite.handlers import ModelCheckpoint 8 | from ignite.handlers import Timer 9 | from torch.nn.functional import normalize 10 | 11 | from engine.engine import create_eval_engine 12 | from engine.engine import create_train_engine 13 | from engine.metric import AutoKVMetric 14 | from utils.eval_cmc import eval_rank_list 15 | 16 | 17 | def get_trainer(model, optimizer, lr_scheduler=None, logger=None, writer=None, non_blocking=False, log_period=10, 18 | save_interval=10, save_dir="checkpoints", prefix="model", query_loader=None, gallery_loader=None, 19 | eval_interval=None): 20 | if logger is None: 21 | logger = logging.getLogger() 22 | logger.setLevel(logging.WARN) 23 | 24 | # trainer 25 | trainer = create_train_engine(model, optimizer, non_blocking) 26 | 27 | # checkpoint handler 28 | if not dist.is_initialized() or dist.get_rank() == 0: 29 | handler = ModelCheckpoint(save_dir, prefix, 
save_interval=save_interval, n_saved=4, create_dir=True, 30 | save_as_state_dict=True, require_empty=False) 31 | trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, 32 | {"model": model.module if dist.is_initialized() else model}) 33 | 34 | # metric 35 | timer = Timer(average=True) 36 | 37 | kv_metric = AutoKVMetric() 38 | 39 | # evaluator 40 | evaluator = None 41 | if not type(eval_interval) == int: 42 | raise TypeError("The parameter 'validate_interval' must be type INT.") 43 | if eval_interval > 0 and query_loader and gallery_loader: 44 | evaluator = create_eval_engine(model, non_blocking) 45 | 46 | @trainer.on(Events.EPOCH_STARTED) 47 | def epoch_started_callback(engine): 48 | epoch = engine.state.epoch 49 | 50 | if dist.is_initialized(): 51 | engine.state.dataloader.sampler.set_epoch(epoch) 52 | 53 | engine.state.dataloader.batch_sampler.drop_last = True if epoch > 1 else False 54 | 55 | kv_metric.reset() 56 | timer.reset() 57 | 58 | @trainer.on(Events.EPOCH_COMPLETED) 59 | def epoch_completed_callback(engine): 60 | epoch = engine.state.epoch 61 | 62 | if lr_scheduler is not None: 63 | lr_scheduler.step() 64 | 65 | if epoch % save_interval == 0: 66 | logger.info("Model saved at {}/{}_model_{}.pth".format(save_dir, prefix, epoch)) 67 | 68 | if evaluator and epoch % eval_interval == 0 and (not dist.is_initialized() or dist.get_rank() == 0): 69 | torch.cuda.empty_cache() 70 | 71 | # extract query feature 72 | evaluator.run(query_loader) 73 | 74 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 75 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 76 | q_cam = torch.cat(evaluator.state.cam_list, dim=0).numpy() 77 | 78 | # extract gallery feature 79 | evaluator.run(gallery_loader) 80 | 81 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 82 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 83 | g_cam = torch.cat(evaluator.state.cam_list, dim=0).numpy() 84 | 85 | distance = -torch.mm(normalize(q_feats), normalize(g_feats).transpose(0, 1)).numpy() 86 | rank_list = np.argsort(distance, axis=1) 87 | 88 | mAP, r1, r5, _ = eval_rank_list(rank_list, q_ids, q_cam, g_ids, g_cam) 89 | 90 | if writer is not None: 91 | writer.add_scalar('eval/mAP', mAP, epoch) 92 | writer.add_scalar('eval/r1', r1, epoch) 93 | writer.add_scalar('eval/r5', r5, epoch) 94 | 95 | torch.cuda.empty_cache() 96 | 97 | if dist.is_initialized(): 98 | dist.barrier() 99 | 100 | @trainer.on(Events.ITERATION_COMPLETED) 101 | def iteration_complete_callback(engine): 102 | timer.step() 103 | 104 | kv_metric.update(engine.state.output) 105 | 106 | epoch = engine.state.epoch 107 | iteration = engine.state.iteration 108 | iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader) 109 | 110 | if iter_in_epoch % log_period == 0: 111 | batch_size = engine.state.batch[0].size(0) 112 | speed = batch_size / timer.value() 113 | 114 | if dist.is_initialized(): 115 | speed *= dist.get_world_size() 116 | 117 | msg = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec" % (epoch, iter_in_epoch, speed) 118 | 119 | metric_dict = kv_metric.compute() 120 | 121 | # log output information 122 | for k in sorted(metric_dict.keys()): 123 | msg += "\t%s: %.4f" % (k, metric_dict[k]) 124 | 125 | if writer is not None: 126 | writer.add_scalar('metric/{}'.format(k), metric_dict[k], iteration) 127 | 128 | logger.info(msg) 129 | 130 | kv_metric.reset() 131 | timer.reset() 132 | 133 | return trainer 134 | -------------------------------------------------------------------------------- /engine/engine.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from apex import amp 3 | from ignite.engine import Engine 4 | from ignite.engine import Events 5 | from torch.autograd import no_grad 6 | 7 | 8 | def create_train_engine(model, optimizer, non_blocking=False): 9 | device = torch.device("cuda", torch.cuda.current_device()) 10 | 11 | def _process_func(engine, batch): 12 | model.train() 13 | 14 | data, labels, cam_ids, img_paths, img_ids = batch 15 | epoch = engine.state.epoch 16 | 17 | data = data.to(device, non_blocking=non_blocking) 18 | labels = labels.to(device, non_blocking=non_blocking) 19 | cam_ids = cam_ids.to(device, non_blocking=non_blocking) 20 | img_ids = img_ids.to(device, non_blocking=non_blocking) 21 | 22 | optimizer.zero_grad() 23 | 24 | loss, metric = model(data, labels, 25 | cam_ids=cam_ids, 26 | img_ids=img_ids, 27 | epoch=epoch) 28 | 29 | with amp.scale_loss(loss, optimizer) as scaled_loss: 30 | scaled_loss.backward() 31 | 32 | optimizer.step() 33 | 34 | return metric 35 | 36 | return Engine(_process_func) 37 | 38 | 39 | def create_eval_engine(model, non_blocking=False): 40 | device = torch.device("cuda", torch.cuda.current_device()) 41 | 42 | def _process_func(engine, batch): 43 | model.eval() 44 | 45 | data, label, cam = batch[:3] 46 | 47 | data = data.to(device, non_blocking=non_blocking) 48 | 49 | with no_grad(): 50 | feat = model(data, cam=cam.clone().to(device, non_blocking=non_blocking)) 51 | 52 | return feat.data.float().cpu(), label, cam 53 | 54 | engine = Engine(_process_func) 55 | 56 | @engine.on(Events.EPOCH_STARTED) 57 | def clear_data(engine): 58 | # feat list 59 | if not hasattr(engine.state, "feat_list"): 60 | setattr(engine.state, "feat_list", []) 61 | else: 62 | engine.state.feat_list.clear() 63 | 64 | # id_list 65 | if not hasattr(engine.state, "id_list"): 66 | setattr(engine.state, "id_list", []) 67 | else: 68 | engine.state.id_list.clear() 69 | 70 | # cam list 71 | if not hasattr(engine.state, "cam_list"): 72 | setattr(engine.state, "cam_list", []) 73 | else: 74 | engine.state.cam_list.clear() 75 | 76 | @engine.on(Events.ITERATION_COMPLETED) 77 | def store_data(engine): 78 | engine.state.feat_list.append(engine.state.output[0]) 79 | engine.state.id_list.append(engine.state.output[1]) 80 | engine.state.cam_list.append(engine.state.output[2]) 81 | 82 | return engine 83 | -------------------------------------------------------------------------------- /engine/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from ignite.exceptions import NotComputableError 6 | from ignite.metrics import Metric, Accuracy 7 | 8 | 9 | class ScalarMetric(Metric): 10 | 11 | def update(self, value): 12 | self.sum_metric += value 13 | self.sum_inst += 1 14 | 15 | def reset(self): 16 | self.sum_inst = 0 17 | self.sum_metric = 0 18 | 19 | def compute(self): 20 | if self.sum_inst == 0: 21 | raise NotComputableError('Accuracy must have at least one example before it can be computed') 22 | return self.sum_metric / self.sum_inst 23 | 24 | 25 | class IgnoreAccuracy(Accuracy): 26 | def __init__(self, ignore_index=-1): 27 | super(IgnoreAccuracy, self).__init__() 28 | 29 | self.ignore_index = ignore_index 30 | 31 | def reset(self): 32 | self._num_correct = 0 33 | self._num_examples = 0 34 | 35 | def update(self, output): 36 | 37 | y_pred, y = self._check_shape(output) 38 | 
self._check_type((y_pred, y)) 39 | 40 | if self._type == "binary": 41 | indices = torch.round(y_pred).type(y.type()) 42 | elif self._type == "multiclass": 43 | indices = torch.max(y_pred, dim=1)[1] 44 | 45 | correct = torch.eq(indices, y).view(-1) 46 | ignore = torch.eq(y, self.ignore_index).view(-1) 47 | self._num_correct += torch.sum(correct).item() 48 | self._num_examples += correct.shape[0] - ignore.sum().item() 49 | 50 | def compute(self): 51 | if self._num_examples == 0: 52 | raise NotComputableError('Accuracy must have at least one example before it can be computed') 53 | return self._num_correct / self._num_examples 54 | 55 | 56 | class AutoKVMetric(Metric): 57 | def __init__(self): 58 | self.kv_sum_metric = defaultdict(lambda: torch.tensor(0., device="cuda")) 59 | self.kv_sum_inst = defaultdict(lambda: torch.tensor(0., device="cuda")) 60 | 61 | self.kv_metric = defaultdict(lambda: 0) 62 | 63 | super(AutoKVMetric, self).__init__() 64 | 65 | def update(self, output): 66 | if not isinstance(output, dict): 67 | raise TypeError('The output must be a key-value dict.') 68 | 69 | for k in output.keys(): 70 | self.kv_sum_metric[k].add_(output[k]) 71 | self.kv_sum_inst[k].add_(1) 72 | 73 | def reset(self): 74 | self.kv_sum_metric.clear() 75 | self.kv_sum_inst.clear() 76 | self.kv_metric.clear() 77 | 78 | def compute(self): 79 | for k in self.kv_sum_metric.keys(): 80 | if self.kv_sum_inst[k] == 0: 81 | raise NotComputableError('Accuracy must have at least one example before it can be computed') 82 | 83 | metric_value = self.kv_sum_metric[k] / self.kv_sum_inst[k] 84 | 85 | if dist.is_initialized(): 86 | dist.barrier() 87 | dist.all_reduce(metric_value) 88 | dist.barrier() 89 | metric_value /= dist.get_world_size() 90 | 91 | self.kv_metric[k] = metric_value.item() 92 | 93 | return self.kv_metric 94 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import subprocess 5 | import sys 6 | 7 | import scipy.io as sio 8 | import torch 9 | 10 | from utils.eval_cmc import eval_feature,eval_rank_list 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("gpu", type=int) 15 | parser.add_argument("model_path", type=str) 16 | parser.add_argument("--dataset", type=str, default=None) 17 | 18 | args = parser.parse_args() 19 | dataset, fname = args.model_path.split("/")[1], args.model_path.split("/")[-1] 20 | 21 | if args.dataset is not None: 22 | dataset = args.dataset 23 | 24 | prefix = os.path.splitext(fname)[0] 25 | 26 | logging.basicConfig(level=logging.INFO, format="%(message)s") 27 | 28 | # extract feature 29 | cmd = "python{} extract.py {} {} ".format(sys.version[0], args.gpu, args.model_path) 30 | if args.dataset is not None: 31 | cmd += "--dataset {}".format(args.dataset) 32 | 33 | subprocess.check_call(cmd.strip().split(" ")) 34 | 35 | # evaluation 36 | query_features_path = 'features/%s/query-%s.mat' % (dataset, prefix) 37 | gallery_features_path = "features/%s/gallery-%s.mat" % (dataset, prefix) 38 | 39 | assert os.path.exists(query_features_path) and os.path.exists(gallery_features_path) 40 | 41 | query_mat = sio.loadmat(query_features_path) 42 | gallery_mat = sio.loadmat(gallery_features_path) 43 | 44 | query_features = query_mat["feat"] 45 | query_ids = query_mat["ids"].squeeze() 46 | query_cam_ids = query_mat["cam_ids"].squeeze() 47 | 48 | gallery_features = 
gallery_mat["feat"] 49 | gallery_ids = gallery_mat["ids"].squeeze() 50 | gallery_cam_ids = gallery_mat["cam_ids"].squeeze() 51 | 52 | eval_feature(query_features, gallery_features, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids, 53 | torch.device("cuda", args.gpu)) 54 | 55 | 56 | -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import scipy.io as sio 5 | import torch 6 | 7 | from configs.default import dataset_cfg 8 | from data import get_test_loader 9 | from models.model import Model 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("gpu", type=int) 14 | parser.add_argument("model_path", type=str) # TODO compatible for different models 15 | parser.add_argument("--img-h", type=int, default=256) 16 | parser.add_argument("--dataset", type=str, default=None) 17 | 18 | args = parser.parse_args() 19 | model_path = args.model_path 20 | fname = model_path.split("/")[-1] 21 | 22 | if args.dataset is not None: 23 | dataset = args.dataset 24 | else: 25 | dataset = model_path.split("/")[1] 26 | 27 | prefix = os.path.splitext(fname)[0] 28 | 29 | dataset_config = dataset_cfg.get(dataset) 30 | image_size = (args.img_h, 128) 31 | 32 | torch.backends.cudnn.benchmark = True 33 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 34 | 35 | model = Model(eval=True, drop_last_stride=True) 36 | 37 | state_dict = torch.load(model_path) 38 | 39 | model.load_state_dict(state_dict, strict=False) 40 | model.float() 41 | model.eval() 42 | model.cuda() 43 | 44 | # extract query feature 45 | query = get_test_loader(root=os.path.join(dataset_config.root, dataset_config.query), 46 | batch_size=512, 47 | image_size=image_size, 48 | num_workers=16) 49 | 50 | query_feat = [] 51 | query_label = [] 52 | query_cam_id = [] 53 | query_img_path = [] 54 | for data, label, cam_id, img_path, _ in query: 55 | with torch.autograd.no_grad(): 56 | feat = model(data.cuda(non_blocking=True)) 57 | 58 | query_feat.append(feat.data.cpu().numpy()) 59 | query_label.append(label.data.cpu().numpy()) 60 | query_cam_id.append(cam_id.data.cpu().numpy()) 61 | query_img_path.extend(img_path) 62 | 63 | query_feat = np.concatenate(query_feat, axis=0) 64 | query_label = np.concatenate(query_label, axis=0) 65 | query_cam_id = np.concatenate(query_cam_id, axis=0) 66 | print(query_feat.shape) 67 | 68 | dir_name = "features/{}".format(dataset, prefix) 69 | if not os.path.isdir(dir_name): 70 | os.makedirs(dir_name) 71 | 72 | save_name = "{}/query-{}.mat".format(dir_name, prefix) 73 | sio.savemat(save_name, 74 | {"feat": query_feat, 75 | "ids": query_label, 76 | "cam_ids": query_cam_id, 77 | "img_path": query_img_path}) 78 | 79 | # extract gallery feature 80 | gallery = get_test_loader(root=os.path.join(dataset_config.root, dataset_config.gallery), 81 | batch_size=512, 82 | image_size=image_size, 83 | num_workers=16) 84 | 85 | gallery_feat = [] 86 | gallery_label = [] 87 | gallery_cam_id = [] 88 | gallery_img_path = [] 89 | for data, label, cam_id, img_path, _ in gallery: 90 | with torch.autograd.no_grad(): 91 | feat = model(data.cuda(non_blocking=True)) 92 | 93 | gallery_feat.append(feat.data.cpu().numpy()) 94 | gallery_label.append(label) 95 | gallery_cam_id.append(cam_id) 96 | gallery_img_path.extend(img_path) 97 | 98 | gallery_feat = np.concatenate(gallery_feat, axis=0) 99 | gallery_label = 
np.concatenate(gallery_label, axis=0) 100 | gallery_cam_id = np.concatenate(gallery_cam_id, axis=0) 101 | print(gallery_feat.shape) 102 | 103 | save_name = "{}/gallery-{}.mat".format(dir_name, prefix) 104 | sio.savemat(save_name, 105 | {"feat": gallery_feat, 106 | "ids": gallery_label, 107 | "cam_ids": gallery_cam_id, 108 | "img_path": gallery_img_path}) 109 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from layers.loss.am_softmax import AMSoftmaxLoss 2 | from layers.loss.center_loss import CenterLoss 3 | from layers.loss.nca_loss import NCALoss 4 | from layers.loss.nn_loss import NNLoss 5 | from layers.loss.triplet_loss import TripletLoss 6 | from layers.module.exemplar_linear import ExemplarLinear 7 | from layers.module.reverse_grad import ReverseGrad 8 | from layers.module.block_grad import BlockGrad 9 | 10 | __all__ = ['AMSoftmaxLoss', 'TripletLoss', 'NCALoss', 'ReverseGrad', 'BlockGrad', 'CenterLoss', 11 | 'NNLoss', 'ExemplarLinear'] 12 | -------------------------------------------------------------------------------- /layers/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuckyDC/generalizing-reid/c77cb7a8317408e632a1e435d9ab533170bdb37d/layers/loss/__init__.py -------------------------------------------------------------------------------- /layers/loss/am_softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AMSoftmaxLoss(nn.Module): 8 | def __init__(self, scale, margin, weight=None, ignore_index=-100, reduction='mean'): 9 | super(AMSoftmaxLoss, self).__init__() 10 | self.weight = weight 11 | self.ignore_index = ignore_index 12 | self.reduction = reduction 13 | self.scale = scale 14 | self.margin = margin 15 | 16 | def forward(self, x, y): 17 | y_onehot = torch.zeros_like(x, device=x.device) 18 | y_onehot.scatter_(1, y.data.view(-1, 1), self.margin) 19 | 20 | out = self.scale * (x - y_onehot) 21 | loss = F.cross_entropy(out, y, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) 22 | 23 | return loss 24 | -------------------------------------------------------------------------------- /layers/loss/center_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CenterLoss(nn.Module): 6 | """Center loss. 7 | 8 | Reference: 9 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | feat_dim (int): feature dimension. 14 | """ 15 | 16 | def __init__(self, num_classes=751, feat_dim=2048): 17 | super(CenterLoss, self).__init__() 18 | self.num_classes = num_classes 19 | self.feat_dim = feat_dim 20 | 21 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 22 | 23 | def forward(self, x, labels): 24 | """ 25 | Args: 26 | x: feature matrix with shape (batch_size, feat_dim). 27 | labels: ground truth labels with shape (num_classes). 
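Returns: a scalar loss, the mean squared Euclidean distance between each feature and the center of its class.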
28 | """ 29 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 30 | 31 | batch_size = x.size(0) 32 | 33 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 34 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 35 | distmat.addmm_(1, -2, x, self.centers.t()) 36 | 37 | classes = torch.arange(self.num_classes).long() 38 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 39 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 40 | 41 | dist = [] 42 | for i in range(batch_size): 43 | value = distmat[i][mask[i]] 44 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 45 | dist.append(value) 46 | 47 | dist = torch.cat(dist) 48 | loss = dist.mean() 49 | 50 | return loss 51 | 52 | 53 | if __name__ == '__main__': 54 | center_loss = CenterLoss() 55 | features = torch.rand(16, 2048) 56 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 57 | 58 | loss = center_loss(features, targets) 59 | print(loss) 60 | -------------------------------------------------------------------------------- /layers/loss/nca_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class NCALoss(nn.Module): 6 | def __init__(self, dim, reduction='mean'): 7 | super(NCALoss, self).__init__() 8 | 9 | self.dim = dim 10 | self.reduction = reduction 11 | 12 | def forward(self, inputs, k): 13 | values, indices = torch.topk(inputs, k=k, dim=self.dim) 14 | 15 | top_sum = torch.sum(values, dim=self.dim) 16 | loss = - torch.log(top_sum) 17 | 18 | if self.reduction == 'mean': 19 | loss = loss.mean() 20 | elif self.reduction == 'sum': 21 | loss = loss.sum() 22 | 23 | return loss 24 | -------------------------------------------------------------------------------- /layers/loss/nn_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class NNLoss(nn.Module): 6 | def __init__(self, dim, reduction='mean'): 7 | super(NNLoss, self).__init__() 8 | self.dim = dim 9 | self.reduction = reduction 10 | 11 | def forward(self, inputs, k): 12 | top_values, _ = torch.topk(inputs, k=k, dim=self.dim) 13 | loss = -torch.log(top_values).sum(dim=1) / k 14 | 15 | if self.reduction == 'mean': 16 | loss = loss.mean() 17 | elif self.reduction == 'sum': 18 | loss = loss.sum() 19 | return loss 20 | -------------------------------------------------------------------------------- /layers/loss/triplet_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import torch 7 | from torch import nn 8 | 9 | 10 | def normalize(x, axis=-1): 11 | """Normalizing to unit length along the specified dimension. 12 | Args: 13 | x: pytorch Variable 14 | Returns: 15 | x: pytorch Variable, same shape as input 16 | """ 17 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 18 | return x 19 | 20 | 21 | def euclidean_dist(x, y): 22 | """ 23 | Args: 24 | x: pytorch Variable, with shape [m, d] 25 | y: pytorch Variable, with shape [n, d] 26 | Returns: 27 | dist: pytorch Variable, with shape [m, n] 28 | """ 29 | m, n = x.size(0), y.size(0) 30 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 31 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 32 | dist = xx + yy 33 | dist.addmm_(1, -2, x, y.t()) 34 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 35 | return dist 36 | 37 | 38 | def hard_example_mining(dist_mat, labels, return_inds=False): 39 | """For each anchor, find the hardest positive and negative sample. 40 | Args: 41 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 42 | labels: pytorch LongTensor, with shape [N] 43 | return_inds: whether to return the indices. Save time if `False`(?) 44 | Returns: 45 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 46 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 47 | p_inds: pytorch LongTensor, with shape [N]; 48 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 49 | n_inds: pytorch LongTensor, with shape [N]; 50 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 51 | NOTE: Only consider the case in which all labels have same num of samples, 52 | thus we can cope with all anchors in parallel. 53 | """ 54 | 55 | assert len(dist_mat.size()) == 2 56 | assert dist_mat.size(0) == dist_mat.size(1) 57 | N = dist_mat.size(0) 58 | 59 | # shape [N, N] 60 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 61 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 62 | 63 | # `dist_ap` means distance(anchor, positive) 64 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 65 | dist_ap, relative_p_inds = torch.max( 66 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 67 | # `dist_an` means distance(anchor, negative) 68 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 69 | dist_an, relative_n_inds = torch.min( 70 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 71 | # shape [N] 72 | dist_ap = dist_ap.squeeze(1) 73 | dist_an = dist_an.squeeze(1) 74 | 75 | if return_inds: 76 | # shape [N, N] 77 | ind = (labels.new().resize_as_(labels) 78 | .copy_(torch.arange(0, N).long()) 79 | .unsqueeze(0).expand(N, N)) 80 | # shape [N, 1] 81 | p_inds = torch.gather( 82 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 83 | n_inds = torch.gather( 84 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 85 | # shape [N] 86 | p_inds = p_inds.squeeze(1) 87 | n_inds = n_inds.squeeze(1) 88 | return dist_ap, dist_an, p_inds, n_inds 89 | 90 | return dist_ap, dist_an 91 | 92 | 93 | class TripletLoss(object): 94 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 
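With batch-hard mining the loss is max(0, d_ap - d_an + margin) when a margin is given, and the soft-margin form log(1 + exp(d_ap - d_an)) otherwise, where d_ap / d_an denote the hardest positive / negative distance for each anchor.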
95 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 96 | Loss for Person Re-Identification'.""" 97 | 98 | def __init__(self, margin=None): 99 | self.margin = margin 100 | if margin is not None: 101 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 102 | else: 103 | self.ranking_loss = nn.SoftMarginLoss() 104 | 105 | def __call__(self, global_feat, labels, normalize_feature=False): 106 | if normalize_feature: 107 | global_feat = normalize(global_feat, axis=-1) 108 | dist_mat = euclidean_dist(global_feat, global_feat) 109 | dist_ap, dist_an = hard_example_mining( 110 | dist_mat, labels) 111 | y = dist_an.new().resize_as_(dist_an).fill_(1) 112 | if self.margin is not None: 113 | loss = self.ranking_loss(dist_an, dist_ap, y) 114 | else: 115 | loss = self.ranking_loss(dist_an - dist_ap, y) 116 | return loss, dist_ap, dist_an 117 | -------------------------------------------------------------------------------- /layers/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuckyDC/generalizing-reid/c77cb7a8317408e632a1e435d9ab533170bdb37d/layers/module/__init__.py -------------------------------------------------------------------------------- /layers/module/block_grad.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | 4 | 5 | class BlockGradFunction(Function): 6 | @staticmethod 7 | def forward(ctx, data): 8 | return data 9 | 10 | @staticmethod 11 | def backward(ctx, grad_outputs): 12 | grad = None 13 | 14 | if ctx.needs_input_grad[0]: 15 | grad = grad_outputs.new_zeros(grad_outputs.size()) 16 | 17 | return grad 18 | 19 | 20 | class BlockGrad(nn.Module): 21 | def __init__(self): 22 | super(BlockGrad, self).__init__() 23 | 24 | def forward(self, data): 25 | return BlockGradFunction.apply(data) 26 | -------------------------------------------------------------------------------- /layers/module/exemplar_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.autograd import Function 5 | from torch.autograd.function import once_differentiable 6 | from torch.nn import init 7 | 8 | 9 | class ExemplarLinearFunc(Function): 10 | 11 | @staticmethod 12 | def forward(ctx, input, memory, target, momentum=0.1): 13 | ctx.save_for_backward(memory, input, target) 14 | ctx.momentum = momentum 15 | 16 | return torch.mm(input, memory.t()) 17 | 18 | @staticmethod 19 | @once_differentiable 20 | def backward(ctx, grad_output): 21 | memory, input, target = ctx.saved_tensors 22 | 23 | grad_input = None 24 | if ctx.needs_input_grad[0]: 25 | grad_input = grad_output.mm(memory) 26 | 27 | momentum = ctx.momentum 28 | memory[target] *= momentum 29 | memory[target] += (1 - momentum) * input 30 | memory[target] /= torch.norm(memory[target], p=2, dim=1, keepdim=True) 31 | 32 | return grad_input, None, None, None 33 | 34 | 35 | class ExemplarLinear(nn.Module): 36 | def __init__(self, num_instances, num_features, momentum=0.1): 37 | super(ExemplarLinear, self).__init__() 38 | 39 | self.num_instances = num_instances 40 | self.num_features = num_features 41 | self.momentum = momentum 42 | 43 | self.register_buffer('memory', torch.Tensor(num_instances, num_features)) 44 | 45 | self.reset_buffers() 46 | 47 | def set_momentum(self, value): 48 | self.momentum = value 49 | 50 | def reset_buffers(self): 51 | 
init.normal_(self.memory, std=0.001) 52 | 53 | def forward(self, x, targets): 54 | return ExemplarLinearFunc.apply(x, self.memory, targets, self.momentum) 55 | -------------------------------------------------------------------------------- /layers/module/reverse_grad.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | 4 | 5 | class ReverseGradFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, data, alpha=1.0): 9 | ctx.alpha = alpha 10 | return data 11 | 12 | @staticmethod 13 | def backward(ctx, grad_outputs): 14 | grad = None 15 | 16 | if ctx.needs_input_grad[0]: 17 | grad = -ctx.alpha * grad_outputs 18 | 19 | return grad, None 20 | 21 | 22 | class ReverseGrad(nn.Module): 23 | def __init__(self): 24 | super(ReverseGrad, self).__init__() 25 | 26 | def forward(self, x, alpha=1.0): 27 | return ReverseGradFunction.apply(x, alpha) 28 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from layers import * 8 | from models.resnet import resnet50 9 | from utils.calc_acc import calc_acc 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, num_classes=None, drop_last_stride=False, joint_training=False, mix=False, neighbor_mode=1, 14 | **kwargs): 15 | super(Model, self).__init__() 16 | 17 | self.drop_last_stride = drop_last_stride 18 | self.joint_training = joint_training 19 | self.mix = mix 20 | self.neighbor_mode = neighbor_mode 21 | 22 | self.backbone = resnet50(pretrained=True, drop_last_stride=drop_last_stride) 23 | self.bn_neck = nn.BatchNorm1d(2048) 24 | self.bn_neck.bias.requires_grad_(False) 25 | 26 | if kwargs.get('eval'): 27 | return 28 | 29 | self.scale = kwargs.get('scale') 30 | 31 | # ----------- tasks for source domain -------------- 32 | if num_classes is not None: 33 | self.classifier = nn.Linear(2048, num_classes, bias=False) 34 | self.id_loss = nn.CrossEntropyLoss(ignore_index=-1) 35 | 36 | # ----------- tasks for target domain -------------- 37 | if self.joint_training: 38 | cam_ids = kwargs.get('cam_ids') 39 | num_instances = kwargs.get('num_instances', None) 40 | self.neighbor_eps = kwargs.get('neighbor_eps') 41 | 42 | # identities captured by each camera 43 | uid2cam = zip(range(num_instances), cam_ids) 44 | self.cam2uid = defaultdict(list) 45 | for uid, cam in uid2cam: 46 | self.cam2uid[cam].append(uid) 47 | 48 | # components for neighborhood consistency 49 | self.exemplar_linear = ExemplarLinear(num_instances, 2048) 50 | self.nn_loss = NNLoss(dim=1) 51 | 52 | alpha = kwargs.get('alpha') 53 | self.beta_dist = torch.distributions.beta.Beta(alpha, alpha) 54 | 55 | self.lambd_st = None 56 | 57 | @staticmethod 58 | def mix_source_target(inputs, beta_dist): 59 | half_batch_size = inputs.size(0) // 2 60 | source_input = inputs[:half_batch_size] 61 | target_input = inputs[half_batch_size:] 62 | 63 | lambd = beta_dist.sample().item() 64 | mixed_input = lambd * source_input + (1 - lambd) * target_input 65 | return mixed_input, lambd 66 | 67 | def forward(self, inputs, labels=None, **kwargs): 68 | if not self.training: 69 | global_feat = self.backbone(inputs) 70 | global_feat = self.bn_neck(global_feat) 71 | return global_feat 72 | else: 73 | batch_size = inputs.size(0) 74 | epoch = kwargs.get('epoch') 75 | 
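# Cross-domain mix-up (joint training only, applied after the first epoch): every batch is arranged as
# [source half, target half], so mix_source_target blends the two halves with a Beta(alpha, alpha)-sampled
# coefficient; the mixed images replace the source half while the original target half is kept unchanged.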
76 | if self.joint_training and self.mix and epoch > 1: 77 | mixed_st, self.lambda_st = self.mix_source_target(inputs, self.beta_dist) 78 | inputs = torch.cat([mixed_st, inputs[batch_size // 2:]], dim=0) 79 | 80 | return self.train_forward(inputs, labels, batch_size, **kwargs) 81 | 82 | def train_forward(self, inputs, labels, batch_size, **kwargs): 83 | epoch = kwargs.get('epoch') 84 | 85 | inputs = self.backbone(inputs) 86 | 87 | if not self.joint_training: # single domain 88 | inputs = self.bn_neck(inputs) 89 | return self.source_train_forward(inputs, labels) 90 | else: # cross domain 91 | half_batch_size = batch_size // 2 92 | label_s = labels[:half_batch_size] 93 | input_t = inputs[-half_batch_size:] 94 | 95 | # source task or mixed task 96 | input_s = inputs[:half_batch_size] 97 | feat_s = F.batch_norm(input_s, None, None, self.bn_neck.weight, self.bn_neck.bias, True) 98 | if not self.mix or epoch <= 1: 99 | loss, metric = self.source_train_forward(feat_s, label_s) 100 | else: 101 | loss, metric = self.mixed_st_forward(feat_s, label_s, **kwargs) 102 | 103 | # target task 104 | feat_t = self.bn_neck(input_t) 105 | target_loss, target_metric = self.target_train_forward(feat_t, **kwargs) 106 | 107 | # summarize loss and metric 108 | loss += target_loss 109 | metric.update(target_metric) 110 | 111 | return loss, metric 112 | 113 | # Tasks for source domain 114 | def source_train_forward(self, inputs, labels): 115 | metric_dict = {} 116 | 117 | cls_score = self.classifier(inputs) 118 | loss = self.id_loss(cls_score.float(), labels) 119 | 120 | metric_dict.update({'id_ce': loss.data, 121 | 'id_acc': calc_acc(cls_score.data, labels.data, ignore_index=-1)}) 122 | 123 | return loss, metric_dict 124 | 125 | # Tasks for target domain 126 | def target_train_forward(self, inputs, **kwargs): 127 | metric_dict = {} 128 | 129 | target_batch_size = inputs.size(0) 130 | 131 | epoch = kwargs.get('epoch') 132 | img_ids = kwargs.get('img_ids')[-target_batch_size:] 133 | cam_ids = kwargs.get('cam_ids')[-target_batch_size:] 134 | 135 | # inputs = self.dropout(inputs) 136 | feat = F.normalize(inputs) 137 | 138 | # Set updating momentum of the exemplar memory. 139 | # Note the momentum must be 0 at the first iteration. 
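        # The exemplar memory starts from a near-zero random initialization, so during
        # the first epoch each visited slot should simply take the current feature
        # (momentum 0). Afterwards each slot is updated as a running average controlled
        # by the momentum, roughly
        #   memory[img_id] = mom * memory[img_id] + (1 - mom) * feat
        # (the exact update is implemented inside ExemplarLinearFunc).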
140 | mom = 0.6 141 | self.exemplar_linear.set_momentum(mom if epoch > 1 else 0) 142 | sim = self.exemplar_linear(feat, img_ids).float() 143 | 144 | # ----------------------Neighborhood Constraint------------------------- # 145 | 146 | # Camera-agnostic neighborhood loss 147 | if self.neighbor_mode == 0: 148 | loss = self.cam_agnostic_eps_nn_loss(sim, img_ids) 149 | metric_dict.update({'neighbor': loss.data}) 150 | 151 | weight = 0.1 if epoch > 10 else 0 152 | loss = weight * loss 153 | 154 | # Camera-aware neighborhood loss (intra_loss and inter_loss) 155 | elif self.neighbor_mode == 1: 156 | intra_loss, inter_loss = self.cam_aware_eps_nn_loss(sim, cam_ids, img_ids=img_ids, epoch=epoch) 157 | metric_dict.update({'intra': intra_loss.data, 'inter': inter_loss.data}) 158 | 159 | intra_weight = 1.0 if epoch > 10 else 0 160 | inter_weight = 0.5 if epoch > 30 else 0 161 | 162 | loss = intra_weight * intra_loss + inter_weight * inter_loss 163 | 164 | return loss, metric_dict 165 | 166 | def mixed_st_forward(self, inputs, labels, **kwargs): 167 | img_ids = kwargs.get('img_ids')[-inputs.size(0):] 168 | agent = self.exemplar_linear.memory[img_ids] 169 | 170 | cls_score = F.linear(inputs, self.classifier.weight) 171 | 172 | sim_agent = inputs.mul(agent).sum(dim=1, keepdim=True) 173 | sim_agent = sim_agent.mul(self.classifier.weight.data[labels].norm(dim=1, keepdim=True)) 174 | cls_score = torch.cat([cls_score, sim_agent], dim=1).float() 175 | 176 | virtual_label = labels.clone().fill_(cls_score.size(1) - 1) 177 | loss = self.lambda_st * self.id_loss(cls_score, labels) 178 | loss += (1 - self.lambda_st) * self.id_loss(cls_score, virtual_label) 179 | 180 | metric = {'mix_st': loss.data} 181 | 182 | return loss, metric 183 | 184 | def cam_aware_eps_nn_loss(self, sim, cam_ids, **kwargs): 185 | img_ids = kwargs.get('img_ids') 186 | 187 | sim_exp = torch.exp(sim * self.scale) 188 | 189 | # calculate mask for intra-camera matching and inter-camera matching 190 | mask_instance, mask_intra, mask_inter = self.compute_mask(sim.size(), img_ids, cam_ids, sim.device) 191 | 192 | # intra-camera neighborhood loss 193 | sim_intra = (sim.data + 1) * mask_intra * (1 - mask_instance) - 1 194 | nearest_intra = sim_intra.max(dim=1, keepdim=True)[0] 195 | neighbor_mask_intra = torch.gt(sim_intra, nearest_intra * self.neighbor_eps) 196 | num_neighbor_intra = neighbor_mask_intra.sum(dim=1) 197 | 198 | sim_exp_intra = sim_exp * mask_intra 199 | score_intra = sim_exp_intra / sim_exp_intra.sum(dim=1, keepdim=True) 200 | score_intra = score_intra.clamp_min(1e-5) 201 | intra_loss = -score_intra.log().mul(neighbor_mask_intra).sum(dim=1).div(num_neighbor_intra).mean() 202 | intra_loss -= score_intra.masked_select(mask_instance.bool()).log().mean() 203 | 204 | # inter-camera neighborhood loss 205 | sim_inter = (sim.data + 1) * mask_inter - 1 206 | nearest_inter = sim_inter.max(dim=1, keepdim=True)[0] 207 | neighbor_mask_inter = torch.gt(sim_inter, nearest_inter * self.neighbor_eps) 208 | num_neighbor_inter = neighbor_mask_inter.sum(dim=1) 209 | 210 | sim_exp_inter = mask_inter * sim_exp 211 | score_inter = sim_exp_inter / sim_exp_inter.sum(dim=1, keepdim=True) 212 | score_inter = score_inter.clamp_min(1e-5) 213 | inter_loss = -score_inter.log().mul(neighbor_mask_inter).sum(dim=1).div(num_neighbor_inter).mean() 214 | 215 | return intra_loss, inter_loss 216 | 217 | def cam_agnostic_eps_nn_loss(self, sim, img_ids): 218 | mask_instance = torch.zeros_like(sim) 219 | mask_instance[torch.arange(sim.size(0)), img_ids] = 1 220 | 221 | 
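        # eps-neighborhood selection: shifting the similarities from [-1, 1] to [0, 2],
        # zeroing the diagonal (self-match) entries and shifting back pins the self
        # entries at -1 so they never win the max. A sample is then counted as a
        # neighbor if its similarity exceeds neighbor_eps times that of the nearest
        # non-self sample.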
sim_neighbor = (sim.data + 1) * (1 - mask_instance) - 1 222 | nearest = sim_neighbor.max(dim=1, keepdim=True)[0] 223 | neighbor_mask = torch.gt(sim_neighbor, nearest * self.neighbor_eps) 224 | num_neighbor = neighbor_mask.sum(dim=1) 225 | 226 | score = F.log_softmax(sim * self.scale, dim=1) 227 | loss = -score.mul(neighbor_mask).sum(dim=1).div(num_neighbor).mean() 228 | loss -= score.masked_select(mask_instance.bool()).mean() 229 | 230 | return loss 231 | 232 | def compute_mask(self, size, img_ids, cam_ids, device): 233 | mask_inter = torch.ones(size, device=device) 234 | for i, cam in enumerate(cam_ids.tolist()): 235 | intra_cam_ids = self.cam2uid[cam] 236 | mask_inter[i, intra_cam_ids] = 0 237 | 238 | mask_intra = 1 - mask_inter 239 | mask_instance = torch.zeros(size, device=device) 240 | mask_instance[torch.arange(size[0]), img_ids] = 1 241 | 242 | return mask_instance, mask_intra, mask_inter 243 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models.utils import load_state_dict_from_url 3 | 4 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 5 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 6 | 7 | model_urls = { 8 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 9 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 10 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 11 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 12 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 13 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 14 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, groups=groups, bias=False, dilation=dilation) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 33 | base_width=64, dilation=1, norm_layer=None): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 42 | self.conv1 = conv3x3(inplanes, planes, stride) 43 | self.bn1 = norm_layer(planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = norm_layer(planes) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | 60 | if self.downsample is not None: 61 | identity = self.downsample(x) 62 | 63 | out += identity 64 | out = self.relu(out) 
65 | 66 | return out 67 | 68 | 69 | class Bottleneck(nn.Module): 70 | expansion = 4 71 | 72 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 73 | base_width=64, dilation=1, norm_layer=None): 74 | super(Bottleneck, self).__init__() 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | width = int(planes * (base_width / 64.)) * groups 78 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 79 | self.conv1 = conv1x1(inplanes, width) 80 | self.bn1 = norm_layer(width) 81 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 82 | self.bn2 = norm_layer(width) 83 | self.conv3 = conv1x1(width, planes * self.expansion) 84 | self.bn3 = norm_layer(planes * self.expansion) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | identity = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | 106 | out += identity 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, block, layers, zero_init_residual=False, 115 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 116 | norm_layer=None, drop_last_stride=False): 117 | super(ResNet, self).__init__() 118 | if norm_layer is None: 119 | norm_layer = nn.BatchNorm2d 120 | self._norm_layer = norm_layer 121 | 122 | self.inplanes = 64 123 | self.dilation = 1 124 | if replace_stride_with_dilation is None: 125 | # each element in the tuple indicates if we should replace 126 | # the 2x2 stride with a dilated convolution instead 127 | replace_stride_with_dilation = [False, False, False] 128 | if len(replace_stride_with_dilation) != 3: 129 | raise ValueError("replace_stride_with_dilation should be None " 130 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 131 | self.groups = groups 132 | self.base_width = width_per_group 133 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 134 | bias=False) 135 | self.bn1 = norm_layer(self.inplanes) 136 | self.relu = nn.ReLU(inplace=True) 137 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 138 | self.layer1 = self._make_layer(block, 64, layers[0]) 139 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 140 | dilate=replace_stride_with_dilation[0]) 141 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 142 | dilate=replace_stride_with_dilation[1]) 143 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1 if drop_last_stride else 2, 144 | dilate=replace_stride_with_dilation[2]) 145 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 146 | 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 150 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 151 | nn.init.constant_(m.weight, 1) 152 | nn.init.constant_(m.bias, 0) 153 | 154 | # Zero-initialize the last BN in each residual branch, 155 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
156 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 157 | if zero_init_residual: 158 | for m in self.modules(): 159 | if isinstance(m, Bottleneck): 160 | nn.init.constant_(m.bn3.weight, 0) 161 | elif isinstance(m, BasicBlock): 162 | nn.init.constant_(m.bn2.weight, 0) 163 | 164 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 165 | norm_layer = self._norm_layer 166 | downsample = None 167 | previous_dilation = self.dilation 168 | if dilate: 169 | self.dilation *= stride 170 | stride = 1 171 | if stride != 1 or self.inplanes != planes * block.expansion: 172 | downsample = nn.Sequential( 173 | conv1x1(self.inplanes, planes * block.expansion, stride), 174 | norm_layer(planes * block.expansion), 175 | ) 176 | 177 | layers = [] 178 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 179 | self.base_width, previous_dilation, norm_layer)) 180 | self.inplanes = planes * block.expansion 181 | for _ in range(1, blocks): 182 | layers.append(block(self.inplanes, planes, groups=self.groups, 183 | base_width=self.base_width, dilation=self.dilation, 184 | norm_layer=norm_layer)) 185 | 186 | return nn.Sequential(*layers) 187 | 188 | def forward(self, x): 189 | x = self.conv1(x) 190 | x = self.bn1(x) 191 | x = self.relu(x) 192 | x = self.maxpool(x) 193 | 194 | x = self.layer1(x) 195 | x = self.layer2(x) 196 | x = self.layer3(x) 197 | x = self.layer4(x) 198 | 199 | x = self.avgpool(x) 200 | x = x.reshape(x.size(0), -1) 201 | 202 | return x 203 | 204 | 205 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 206 | model = ResNet(block, layers, **kwargs) 207 | if pretrained: 208 | state_dict = load_state_dict_from_url(model_urls[arch], 209 | progress=progress) 210 | model.load_state_dict(state_dict, strict=False) 211 | return model 212 | 213 | 214 | def resnet18(pretrained=False, progress=True, **kwargs): 215 | """Constructs a ResNet-18 model. 216 | 217 | Args: 218 | pretrained (bool): If True, returns a model pre-trained on ImageNet 219 | progress (bool): If True, displays a progress bar of the download to stderr 220 | """ 221 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 222 | **kwargs) 223 | 224 | 225 | def resnet34(pretrained=False, progress=True, **kwargs): 226 | """Constructs a ResNet-34 model. 227 | 228 | Args: 229 | pretrained (bool): If True, returns a model pre-trained on ImageNet 230 | progress (bool): If True, displays a progress bar of the download to stderr 231 | """ 232 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 233 | **kwargs) 234 | 235 | 236 | def resnet50(pretrained=False, progress=True, **kwargs): 237 | """Constructs a ResNet-50 model. 238 | 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 244 | **kwargs) 245 | 246 | 247 | def resnet101(pretrained=False, progress=True, **kwargs): 248 | """Constructs a ResNet-101 model. 249 | 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet152(pretrained=False, progress=True, **kwargs): 259 | """Constructs a ResNet-152 model. 
260 | 261 | Args: 262 | pretrained (bool): If True, returns a model pre-trained on ImageNet 263 | progress (bool): If True, displays a progress bar of the download to stderr 264 | """ 265 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 266 | **kwargs) 267 | 268 | 269 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 270 | """Constructs a ResNeXt-50 32x4d model. 271 | 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | kwargs['groups'] = 32 277 | kwargs['width_per_group'] = 4 278 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 279 | pretrained, progress, **kwargs) 280 | 281 | 282 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 283 | """Constructs a ResNeXt-101 32x8d model. 284 | 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | progress (bool): If True, displays a progress bar of the download to stderr 288 | """ 289 | kwargs['groups'] = 32 290 | kwargs['width_per_group'] = 8 291 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 292 | pretrained, progress, **kwargs) 293 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pprint 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import yaml 8 | from apex import amp 9 | from apex import parallel 10 | from tensorboardX import SummaryWriter 11 | from torch import optim 12 | 13 | from data import get_cross_domain_train_loader 14 | from data import get_test_loader 15 | from data import get_train_loader 16 | from engine import get_trainer 17 | from models.model import Model 18 | 19 | 20 | def train(cfg): 21 | num_gpus = torch.cuda.device_count() 22 | if num_gpus > 1: 23 | torch.distributed.init_process_group(backend="nccl", world_size=num_gpus) 24 | 25 | # set logger 26 | log_dir = os.path.join("logs/", cfg.source_dataset, cfg.prefix) 27 | if not os.path.isdir(log_dir): 28 | os.makedirs(log_dir, exist_ok=True) 29 | 30 | logging.basicConfig(format="%(asctime)s %(message)s", 31 | filename=log_dir + "/" + "log.txt", 32 | filemode="a") 33 | 34 | logger = logging.getLogger() 35 | logger.setLevel(logging.INFO) 36 | stream_handler = logging.StreamHandler() 37 | stream_handler.setLevel(logging.INFO) 38 | logger.addHandler(stream_handler) 39 | 40 | # writer = SummaryWriter(log_dir, purge_step=0) 41 | 42 | if dist.is_initialized() and dist.get_rank() != 0: 43 | 44 | logger = writer = None 45 | else: 46 | logger.info(pprint.pformat(cfg)) 47 | 48 | # training data loader 49 | if not cfg.joint_training: # single domain 50 | train_loader = get_train_loader(root=os.path.join(cfg.source.root, cfg.source.train), 51 | batch_size=cfg.batch_size, 52 | image_size=cfg.image_size, 53 | random_flip=cfg.random_flip, 54 | random_crop=cfg.random_crop, 55 | random_erase=cfg.random_erase, 56 | color_jitter=cfg.color_jitter, 57 | padding=cfg.padding, 58 | num_workers=4) 59 | else: # cross domain 60 | source_root = os.path.join(cfg.source.root, cfg.source.train) 61 | target_root = os.path.join(cfg.target.root, cfg.target.train) 62 | 63 | train_loader = get_cross_domain_train_loader(source_root=source_root, 64 | target_root=target_root, 65 | batch_size=cfg.batch_size, 66 | random_flip=cfg.random_flip, 67 | random_crop=cfg.random_crop, 68 | 
random_erase=cfg.random_erase, 69 | color_jitter=cfg.color_jitter, 70 | padding=cfg.padding, 71 | image_size=cfg.image_size, 72 | num_workers=8) 73 | 74 | # evaluation data loader 75 | query_loader = None 76 | gallery_loader = None 77 | if cfg.eval_interval > 0: 78 | query_loader = get_test_loader(root=os.path.join(cfg.target.root, cfg.target.query), 79 | batch_size=512, 80 | image_size=cfg.image_size, 81 | num_workers=4) 82 | 83 | gallery_loader = get_test_loader(root=os.path.join(cfg.target.root, cfg.target.gallery), 84 | batch_size=512, 85 | image_size=cfg.image_size, 86 | num_workers=4) 87 | 88 | # model 89 | num_classes = cfg.source.num_id 90 | num_cam = cfg.source.num_cam + cfg.target.num_cam 91 | cam_ids = train_loader.dataset.target_dataset.cam_ids if cfg.joint_training else train_loader.dataset.cam_ids 92 | num_instances = len(train_loader.dataset.target_dataset) if cfg.joint_training else None 93 | 94 | model = Model(num_classes=num_classes, 95 | drop_last_stride=cfg.drop_last_stride, 96 | joint_training=cfg.joint_training, 97 | num_instances=num_instances, 98 | cam_ids=cam_ids, 99 | num_cam=num_cam, 100 | neighbor_mode=cfg.neighbor_mode, 101 | neighbor_eps=cfg.neighbor_eps, 102 | scale=cfg.scale, 103 | mix=cfg.mix, 104 | alpha=cfg.alpha) 105 | 106 | model.cuda() 107 | 108 | # optimizer 109 | ft_params = model.backbone.parameters() 110 | new_params = [param for name, param in model.named_parameters() if not name.startswith("backbone.")] 111 | param_groups = [{'params': ft_params, 'lr': cfg.ft_lr}, 112 | {'params': new_params, 'lr': cfg.new_params_lr}] 113 | 114 | optimizer = optim.SGD(param_groups, momentum=0.9, weight_decay=cfg.wd) 115 | 116 | # convert model for mixed precision distributed training 117 | 118 | model, optimizer = amp.initialize(model, optimizer, enabled=cfg.fp16, opt_level="O2") 119 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, 120 | milestones=cfg.lr_step, 121 | gamma=0.1) 122 | 123 | if dist.is_initialized(): 124 | model = parallel.DistributedDataParallel(model, delay_allreduce=True) 125 | 126 | # engine 127 | checkpoint_dir = os.path.join("checkpoints", cfg.source_dataset, cfg.prefix) 128 | engine = get_trainer(model=model, 129 | optimizer=optimizer, 130 | lr_scheduler=lr_scheduler, 131 | logger=logger, 132 | # writer=writer, 133 | non_blocking=True, 134 | log_period=cfg.log_period, 135 | save_interval=10, 136 | save_dir=checkpoint_dir, 137 | prefix=cfg.prefix, 138 | eval_interval=cfg.eval_interval, 139 | query_loader=query_loader, 140 | gallery_loader=gallery_loader) 141 | 142 | # training 143 | engine.run(train_loader, max_epochs=cfg.num_epoch) 144 | 145 | if dist.is_initialized(): 146 | dist.destroy_process_group() 147 | 148 | 149 | if __name__ == '__main__': 150 | import argparse 151 | import random 152 | import numpy as np 153 | from yacs.config import CfgNode 154 | from configs.default import strategy_cfg 155 | from configs.default import dataset_cfg 156 | 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("--cfg", type=str, default="configs/market2duke.yml") 159 | parser.add_argument("--local_rank", type=int, default=None) 160 | args = parser.parse_args() 161 | 162 | # enable cudnn backend 163 | if args.local_rank is not None: 164 | torch.cuda.set_device(args.local_rank) 165 | torch.backends.cudnn.benchmark = True 166 | 167 | # load configuration 168 | customized_cfg = yaml.load(open(args.cfg, "r"), Loader=yaml.SafeLoader) 169 | 170 | cfg = strategy_cfg 171 | cfg.merge_from_file(args.cfg) 172 | 173 | source_cfg = 
dataset_cfg.get(cfg.source_dataset) 174 | target_cfg = dataset_cfg.get(cfg.target_dataset) 175 | 176 | cfg.source = CfgNode() 177 | for k, v in source_cfg.items(): 178 | cfg.source[k] = v 179 | 180 | cfg.target = CfgNode() 181 | for k, v in target_cfg.items(): 182 | cfg.target[k] = v 183 | 184 | # split data batch into all devices evenly 185 | cfg.batch_size = cfg.batch_size // torch.cuda.device_count() 186 | 187 | cfg.freeze() 188 | 189 | train(cfg) 190 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # generate a port number randomly 4 | MASTER_PORT=$RANDOM 5 | let MASTER_PORT+=10000 6 | 7 | # check whether the port has been occupied 8 | ret=$(netstat -tlpn 2>/dev/null | grep "\b${MASTER_PORT}\b" | awk '{print $0}') 9 | 10 | while [[ "$ret" != "" ]] 11 | do 12 | MASTER_PORT=$RANDOM 13 | let MASTER_PORT+=10000 14 | ret=$(netstat -tlpn 2>/dev/null | grep "\b${MASTER_PORT}\b" | awk '{print $0}') 15 | done 16 | 17 | # get the number of process 18 | NUM_PROC=0 19 | for _ in ${1//,/ } 20 | do 21 | let NUM_PROC++ 22 | done 23 | 24 | CORES=`lscpu | grep Core | awk '{print $4}'` 25 | SOCKETS=`lscpu | grep Socket | awk '{print $2}'` 26 | TOTAL_CORES=`expr $CORES \* $SOCKETS` 27 | 28 | # launch the distributed program 29 | OMP_NUM_THREADS=$TOTAL_CORES \ 30 | KMP_AFFINITY=granularity=fine,compact,1,0 \ 31 | KMP_BLOCKTIME=1 \ 32 | CUDA_VISIBLE_DEVICES=$1 \ 33 | python3 -m torch.distributed.launch --nproc_per_node=${NUM_PROC} --master_port ${MASTER_PORT} train.py --cfg $2 -------------------------------------------------------------------------------- /utils/calc_acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calc_acc(logits, label, ignore_index=-100, mode="multiclass"): 5 | if mode == "binary": 6 | indices = torch.round(logits).type(label.type()) 7 | elif mode == "multiclass": 8 | indices = torch.max(logits, dim=1)[1] 9 | 10 | if label.size() == logits.size(): 11 | ignore = 1 - torch.round(label.sum(dim=1)) 12 | label = torch.max(label, dim=1)[1] 13 | else: 14 | ignore = torch.eq(label, ignore_index).view(-1) 15 | 16 | correct = torch.eq(indices, label).view(-1) 17 | num_correct = torch.sum(correct) 18 | num_examples = logits.shape[0] - ignore.sum() 19 | 20 | return num_correct.float() / num_examples.float() 21 | -------------------------------------------------------------------------------- /utils/curve.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | x=[0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] 5 | 6 | d2m_map = [0.233131, 0.317582, 0.486466, 0.656447, 0.715028, 0.673071, 0.585358, 0.527190] 7 | d2m_r1 = [0.472981, 0.570962, 0.724762, 0.843230, 0.880641, 0.880344, 0.855701, 0.834620] 8 | m2d_map = [0.182909, 0.317147, 0.448963, 0.618837, 0.652238, 0.650040, 0.596468, 0.471318] 9 | m2d_r1 = [0.315978, 0.494614, 0.640036, 0.770646, 0.794883, 0.798923, 0.765709, 0.717684] 10 | 11 | d2m_map = np.array(d2m_map) * 100 12 | d2m_r1 = np.array(d2m_r1) * 100 13 | m2d_map = np.array(m2d_map) * 100 14 | m2d_r1 = np.array(m2d_r1) * 100 15 | 16 | fig = plt.figure() 17 | 18 | ax_1 = plt.subplot(121) 19 | fig.add_subplot(ax_1) 20 | 21 | ax_1.set_xticks(x) 22 | l1 = ax_1.plot(x,d2m_r1, marker="o", linewidth=2, label=r"Duke$\rightarrow$ Market")[0] 23 | l2 = ax_1.plot(x,m2d_r1, marker="o", 
linewidth=2, label=r"Market$\rightarrow$ Duke")[0] 24 | #ax_1.set_ylim(70,95) 25 | ax_1.set_xlabel(r'$\alpha$') 26 | ax_1.set_ylabel('Rank-1 accuracy (%)') 27 | ax_1.grid() 28 | 29 | 30 | ax_2 = plt.subplot(122) 31 | fig.add_subplot(ax_2) 32 | 33 | ax_2.set_xticks(x) 34 | ax_2.plot(x,d2m_map, marker="o", linewidth=2,label=r"Duke$\rightarrow$ Market") 35 | ax_2.plot(x,m2d_map, marker="o", linewidth=2,label=r"Market$\rightarrow$ Duke") 36 | #ax_2.set_ylim(50,70) 37 | ax_2.set_xlabel(r'$\alpha$') 38 | ax_2.set_ylabel('mAP (%)') 39 | ax_2.grid() 40 | 41 | fig.legend([l1, l2], labels=[r"Duke$\rightarrow$ Market", r"Market$\rightarrow$ Duke"], ncol=2, loc=9, bbox_to_anchor=(0, 1.02, 1, 0)) 42 | 43 | plt.show() 44 | 45 | 46 | -------------------------------------------------------------------------------- /utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | 4 | def reduce_tensor(tensor): 5 | tensor = tensor.clone() 6 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 7 | tensor /= dist.get_world_size() 8 | 9 | return tensor 10 | -------------------------------------------------------------------------------- /utils/eval_cmc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numba 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | 10 | @numba.jit 11 | def compute_ap(good_index, junk_index, sort_index): 12 | cmc = np.zeros((sort_index.shape[0],)) 13 | n_good = good_index.shape[0] 14 | 15 | old_recall = 0 16 | old_precision = 1.0 17 | ap = 0 18 | intersect_size = 0 19 | j = 0 20 | good_now = 0 21 | n_junk = 0 22 | for i in range(sort_index.shape[0]): 23 | flag = 0 24 | if np.any(good_index == sort_index[i]): 25 | cmc[i - n_junk:] = 1 26 | flag = 1 27 | good_now = good_now + 1 28 | 29 | if np.any(junk_index == sort_index[i]): 30 | n_junk = n_junk + 1 31 | continue 32 | 33 | if flag == 1: 34 | intersect_size = intersect_size + 1 35 | 36 | recall = intersect_size / n_good 37 | precision = intersect_size / (j + 1) 38 | ap = ap + (recall - old_recall) * ((old_precision + precision) / 2) 39 | old_recall = recall 40 | old_precision = precision 41 | j = j + 1 42 | 43 | if good_now == n_good: 44 | break 45 | 46 | return ap, cmc 47 | 48 | 49 | def eval_feature(query_features, gallery_features, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids, device): 50 | if isinstance(gallery_features, np.ndarray): 51 | gallery_features = torch.from_numpy(gallery_features) 52 | 53 | if isinstance(query_features, np.ndarray): 54 | query_features = torch.from_numpy(query_features) 55 | 56 | gallery_features = gallery_features.to(device) 57 | query_features = query_features.to(device) 58 | 59 | num_query = query_ids.shape[0] 60 | num_gallery = gallery_ids.shape[0] 61 | 62 | gallery_features = F.normalize(gallery_features, p=2, dim=1) 63 | query_features = F.normalize(query_features, p=2, dim=1) 64 | 65 | dist_array = -torch.mm(query_features, gallery_features.transpose(0, 1)).cpu().numpy() 66 | 67 | ap = np.zeros((num_query,)) # average precision 68 | cmc = np.zeros((num_query, num_gallery)) 69 | 70 | index = np.arange(num_gallery) 71 | for i in tqdm(range(num_query)): 72 | good_flag = np.logical_and(np.not_equal(gallery_cam_ids, query_cam_ids[i]), np.equal(gallery_ids, query_ids[i])) 73 | junk_flag_1 = np.equal(gallery_ids, -1) 74 | junk_flag_2 = np.logical_and(np.equal(gallery_cam_ids, query_cam_ids[i]), 
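                                     # same-identity images captured by the same camera
                                     # count as junk and are ignored, following the
                                     # standard re-ID evaluation protocol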
75 | np.equal(gallery_ids, query_ids[i])) 76 | 77 | good_index = index[good_flag] 78 | junk_index = index[np.logical_or(junk_flag_1, junk_flag_2)] 79 | 80 | dist = dist_array[i] 81 | 82 | sort_index = np.argsort(dist) 83 | 84 | ap[i], cmc[i, :] = compute_ap(good_index, junk_index, sort_index) 85 | 86 | mAP = np.mean(ap) 87 | r1 = np.mean(cmc, axis=0)[0] 88 | r5 = np.mean(np.clip(np.sum(cmc[:, :5], axis=1), 0, 1), axis=0) 89 | r10 = np.mean(np.clip(np.sum(cmc[:, :10], axis=1), 0, 1), axis=0) 90 | 91 | logging.info('mAP = %f , r1 precision = %f , r5 precision = %f , r10 precision = %f' % (mAP, r1, r5, r10)) 92 | 93 | return mAP, r1, r5, r10 94 | 95 | 96 | def eval_rank_list(rank_list, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 97 | num_query = len(query_ids) 98 | num_gallery = len(gallery_ids) 99 | 100 | ap = np.zeros((num_query,)) # average precision 101 | cmc = np.zeros((num_query, num_gallery)) 102 | for i in tqdm(range(num_query)): 103 | index = np.arange(num_gallery) 104 | good_flag = np.logical_and(np.not_equal(gallery_cam_ids, query_cam_ids[i]), np.equal(gallery_ids, query_ids[i])) 105 | junk_flag_1 = np.equal(gallery_ids, -1) 106 | junk_flag_2 = np.logical_and(np.equal(gallery_cam_ids, query_cam_ids[i]), 107 | np.equal(gallery_ids, query_ids[i])) 108 | 109 | good_index = index[good_flag] 110 | junk_index = index[np.logical_or(junk_flag_1, junk_flag_2)] 111 | 112 | sort_index = rank_list[i] 113 | 114 | ap[i], cmc[i, :] = compute_ap(good_index, junk_index, sort_index) 115 | 116 | mAP = np.mean(ap) 117 | r1 = np.mean(cmc, axis=0)[0] 118 | r5 = np.mean(np.clip(np.sum(cmc[:, :5], axis=1), 0, 1), axis=0) 119 | r10 = np.mean(np.clip(np.sum(cmc[:, :10], axis=1), 0, 1), axis=0) 120 | 121 | logging.info('mAP = %f , r1 precision = %f , r5 precision = %f , r10 precision = %f' % (mAP, r1, r5, r10)) 122 | 123 | return mAP, r1, r5, r10 124 | -------------------------------------------------------------------------------- /utils/eval_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from data import get_test_loader 6 | from utils.eval_cmc import eval_feature 7 | 8 | 9 | def eval_model(model, dataset_config, image_size, device): 10 | # extract query feature 11 | query = get_test_loader(root=os.path.join(dataset_config.root, dataset_config.query), 12 | batch_size=512, 13 | image_size=image_size, 14 | num_workers=16) 15 | 16 | query_feat = [] 17 | query_label = [] 18 | query_cam_id = [] 19 | for data, label, cam_id, _ in query: 20 | feat = model(data.cuda(non_blocking=True)) 21 | 22 | query_feat.append(feat.data.cpu().numpy()) 23 | query_label.append(label.data.cpu().numpy()) 24 | query_cam_id.append(cam_id.data.cpu().numpy()) 25 | 26 | query_feat = np.concatenate(query_feat, axis=0) 27 | query_label = np.concatenate(query_label, axis=0) 28 | query_cam_id = np.concatenate(query_cam_id, axis=0) 29 | 30 | # extract gallery feature 31 | gallery = get_test_loader(root=os.path.join(dataset_config.root, dataset_config.gallery), 32 | batch_size=512, 33 | image_size=image_size, 34 | num_workers=16) 35 | 36 | gallery_feat = [] 37 | gallery_label = [] 38 | gallery_cam_id = [] 39 | for data, label, cam_id, _ in gallery: 40 | feat = model(data.cuda(non_blocking=True)) 41 | 42 | gallery_feat.append(feat.data.cpu().numpy()) 43 | gallery_label.append(label) 44 | gallery_cam_id.append(cam_id) 45 | 46 | gallery_feat = np.concatenate(gallery_feat, axis=0) 47 | gallery_label = np.concatenate(gallery_label, axis=0) 48 | 
gallery_cam_id = np.concatenate(gallery_cam_id, axis=0) 49 | 50 | mAP, r1, r5, r10 = eval_feature(query_feat, gallery_feat, query_label, query_cam_id, gallery_label, gallery_cam_id, 51 | device) 52 | print('mAP = %f , r1 precision = %f , r5 precision = %f , r10 precision = %f' % (mAP, r1, r5, r10)) 53 | -------------------------------------------------------------------------------- /utils/fig.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from mpl_toolkits.axisartist.axislines import SubplotZero 3 | import numpy as np 4 | 5 | 6 | def autolabel(rects): 7 | for rect in rects: 8 | height = rect.get_height() 9 | plt.text(rect.get_x()+rect.get_width()/2.-0.15, 1.03*height, '%s' % float(height),fontsize=8) 10 | 11 | 12 | x=np.arange(6) 13 | y1=[87.1,87.4,88.1,87.4,87.1,86.3] 14 | y2=[60.5,63.3,64.6,63.9,63.1,62.3] 15 | 16 | y3 = [84.6,87.8,88.1,87.1,86.7,85.8] 17 | y4 = [57.2,61.6,64.6,64.3,64.0,63.1] 18 | 19 | 20 | bar_width=0.32 21 | 22 | 23 | fig = plt.figure() 24 | 25 | ax_1 = SubplotZero(fig, 212) 26 | fig.add_subplot(ax_1) 27 | 28 | a = ax_1.bar(x,y1,bar_width,label='Rank-1') 29 | b = ax_1.bar(x+bar_width+0.02,y2,bar_width,label='mAP') 30 | autolabel(a) 31 | autolabel(b) 32 | ax_1.set_ylim(30,90) 33 | ax_1.set_xticks(x+bar_width/2+0.01) 34 | ax_1.set_xticklabels(['5','10','15','20','25','30']) 35 | ax_1.set_xlabel(r"$\alpha$") 36 | ax_1.axis['top'].set_visible(False) 37 | ax_1.axis['right'].set_visible(False) 38 | ax_1.axis['left'].set_axisline_style("->") 39 | ax_1.axis['bottom'].set_axisline_style("->") 40 | 41 | 42 | ax_2 = SubplotZero(fig, 211) 43 | fig.add_subplot(ax_2) 44 | 45 | c = ax_2.bar(x,y3,bar_width,label='Rank-1') 46 | d = ax_2.bar(x+bar_width+0.01,y4,bar_width,label='mAP') 47 | autolabel(c) 48 | autolabel(d) 49 | ax_2.set_ylim(30,90) 50 | ax_2.set_xticks(x+bar_width/2+0.01) 51 | ax_2.set_xlabel(r"$\alpha$") 52 | ax_2.set_xticklabels(['1','3','5','7','9','11']) 53 | ax_2.axis['top'].set_visible(False) 54 | ax_2.axis['right'].set_visible(False) 55 | ax_2.axis['left'].set_axisline_style("->") 56 | ax_2.axis['bottom'].set_axisline_style("->") 57 | 58 | ax_2.legend(loc=8, ncol=2,bbox_to_anchor=(0, 1.1, 1, 0)) 59 | 60 | plt.tight_layout() 61 | plt.show() -------------------------------------------------------------------------------- /utils/fp16_utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.batchnorm import _BatchNorm 2 | 3 | 4 | def network_to_half(module): 5 | return norm_convert_float(module.half()) 6 | 7 | 8 | def norm_convert_float(module): 9 | ''' 10 | BatchNorm layers need parameters in single precision. 11 | Find all layers and convert them back to float. This can't 12 | be done with built in .apply as that function will apply 13 | fn to all modules, parameters, and buffers. Thus we wouldn't 14 | be able to guard the float conversion based on the module type. 
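    Instead, recurse over child modules and call .float() only on _BatchNorm
    instances, leaving everything else in half precision.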
15 | ''' 16 | if isinstance(module, _BatchNorm): 17 | module.float() 18 | for child in module.children(): 19 | norm_convert_float(child) 20 | return module 21 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def one_hot(indices, depth, dtype=torch.float): 5 | y_onehot = torch.zeros(size=(indices.size(0), depth), dtype=dtype, device=indices.device) 6 | y_onehot.scatter_(1, indices.unsqueeze(1), 1) 7 | 8 | return y_onehot 9 | -------------------------------------------------------------------------------- /utils/mod_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from torch import nn 4 | 5 | 6 | def clone_without_grad(module): 7 | assert isinstance(module, nn.Module) 8 | 9 | mod = copy.deepcopy(module) 10 | 11 | for name, param in module.named_parameters(): 12 | if '.' not in name: 13 | setattr(mod, name, nn.Parameter(param.data, requires_grad=False)) 14 | else: 15 | splits = name.split('.') 16 | attr = mod 17 | 18 | for s in splits[:-1]: 19 | attr = getattr(attr, s) 20 | 21 | setattr(attr, splits[-1], nn.Parameter(param.data, requires_grad=False)) 22 | 23 | return mod 24 | -------------------------------------------------------------------------------- /utils/rank_vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import cv2 3 | import matplotlib as mpl 4 | import numpy as np 5 | from sklearn.preprocessing import normalize 6 | from tqdm import tqdm 7 | 8 | mpl.use('agg') 9 | 10 | 11 | def rank_vis(query_mat, gallery_mat, n_match, save_dir): 12 | if not os.path.isdir(save_dir): 13 | os.makedirs(save_dir) 14 | 15 | query_feat = query_mat['feat'] 16 | query_ids = query_mat['ids'].squeeze() 17 | query_cam_ids = query_mat['cam_ids'].squeeze() 18 | query_img_paths = query_mat['img_path'].tolist() 19 | 20 | gallery_feat = gallery_mat['feat'] 21 | gallery_ids = gallery_mat['ids'].squeeze() 22 | gallery_cam_ids = gallery_mat['cam_ids'].squeeze() 23 | gallery_img_paths = gallery_mat['img_path'].tolist() 24 | 25 | dist = -np.dot(normalize(query_feat), normalize(gallery_feat).T) 26 | 27 | rank_list = np.argsort(dist, axis=1) 28 | 29 | num_query = len(query_ids) 30 | num_gallery = len(gallery_ids) 31 | 32 | for i in tqdm(range(num_query)): 33 | index = np.arange(num_gallery) 34 | 35 | intra_flag = np.equal(gallery_cam_ids, query_cam_ids[i]) 36 | inter_flag = np.not_equal(gallery_cam_ids, query_cam_ids[i]) 37 | junk_flag = np.logical_or(np.equal(gallery_ids, -1), 38 | np.logical_and(intra_flag, np.equal(gallery_ids, query_ids[i]))) 39 | 40 | junk_index = index[junk_flag] 41 | # intra_index = index[intra_flag] 42 | # inter_index = index[inter_flag] 43 | 44 | sort_index = rank_list[i] 45 | flag = np.logical_not(np.isin(sort_index, junk_index)) 46 | # intra_flag = np.isin(sort_index, intra_index) 47 | # inter_flag = np.isin(sort_index, inter_index) 48 | 49 | sort_index = sort_index[flag] 50 | # intra_sort_index = sort_index[np.logical_and(flag, intra_flag)] 51 | # inter_sort_index = sort_index[np.logical_and(flag, inter_flag)] 52 | 53 | plt.figure(figsize=(4 * n_match, 10)) 54 | 55 | q_img_path = query_img_paths[i].strip() 56 | q_img = cv2.imread(q_img_path) 57 | q_img = cv2.cvtColor(q_img, cv2.COLOR_BGR2RGB) 58 | q_img = cv2.resize(q_img, (128, 256)) 59 | 60 | # --------------------------------- # 61 | 
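        # Layout: one row per query, with the query image in the first column and its
        # top n_match (non-junk) gallery matches to the right; each match is framed in
        # green when the identity is correct and in red otherwise.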
plt.subplot(1, n_match + 1, 1) 62 | 63 | plt.imshow(q_img) 64 | # plt.text(30, -8, 'cam %d' % query_cam_ids[i], size=45) 65 | plt.xticks([]) 66 | plt.yticks([]) 67 | 68 | # # --------------------------------- # 69 | # plt.subplot(3, n_match + 1, 1 + n_match + 1) 70 | 71 | # plt.imshow(q_img) 72 | # plt.text(30, -8, 'cam %d' % query_cam_ids[i], size=45) 73 | # plt.xticks([]) 74 | # plt.yticks([]) 75 | 76 | # # --------------------------------- # 77 | # plt.subplot(3, n_match + 1, 1 + (n_match + 1) * 2) 78 | 79 | # plt.imshow(q_img) 80 | # plt.text(30, -8, 'cam %d' % query_cam_ids[i], size=45) 81 | # plt.xticks([]) 82 | # plt.yticks([]) 83 | 84 | for j in range(n_match): 85 | plt.subplot(1, n_match + 1, j + 2) 86 | 87 | g_img_path = gallery_img_paths[sort_index[j]].strip() 88 | g_img = cv2.imread(g_img_path) 89 | g_img = cv2.cvtColor(g_img, cv2.COLOR_BGR2RGB) 90 | 91 | # plt.text(30, -8, 'cam %d' % gallery_cam_ids[sort_index[j]], size=45) 92 | 93 | plt.imshow(cv2.resize(g_img, (128, 256))) 94 | plt.xticks([]) 95 | plt.yticks([]) 96 | 97 | if query_ids[i] != gallery_ids[sort_index[j]]: 98 | color = "red" 99 | else: 100 | color = "green" 101 | 102 | ax = plt.gca() 103 | ax.spines['right'].set_linewidth(8) 104 | ax.spines['top'].set_linewidth(8) 105 | ax.spines['left'].set_linewidth(8) 106 | ax.spines['bottom'].set_linewidth(8) 107 | ax.spines['right'].set_color(color) 108 | ax.spines['top'].set_color(color) 109 | ax.spines['left'].set_color(color) 110 | ax.spines['bottom'].set_color(color) 111 | 112 | # # --------------------------------- # 113 | # plt.subplot(3, n_match + 1, j + 2 + n_match + 1) 114 | 115 | # g_img_path = gallery_img_paths[intra_sort_index[j]].strip() 116 | # g_img = cv2.imread(g_img_path) 117 | # g_img = cv2.cvtColor(g_img, cv2.COLOR_BGR2RGB) 118 | 119 | # plt.text(30, -8, 'cam %d' % gallery_cam_ids[intra_sort_index[j]], size=45) 120 | # # plt.text(30, -8, '%.3f' % dist[i][intra_sort_index[j]], size=45) 121 | 122 | # plt.imshow(cv2.resize(g_img, (128, 256))) 123 | # plt.xticks([]) 124 | # plt.yticks([]) 125 | 126 | # if query_ids[i] != gallery_ids[intra_sort_index[j]]: 127 | # color = "red" 128 | # else: 129 | # color = "green" 130 | 131 | # ax = plt.gca() 132 | # ax.spines['right'].set_linewidth(8) 133 | # ax.spines['top'].set_linewidth(8) 134 | # ax.spines['left'].set_linewidth(8) 135 | # ax.spines['bottom'].set_linewidth(8) 136 | # ax.spines['right'].set_color(color) 137 | # ax.spines['top'].set_color(color) 138 | # ax.spines['left'].set_color(color) 139 | # ax.spines['bottom'].set_color(color) 140 | 141 | # # --------------------------------- # 142 | # plt.subplot(3, n_match + 1, j + 2 + (n_match + 1) * 2) 143 | 144 | # g_img_path = gallery_img_paths[inter_sort_index[j]].strip() 145 | # g_img = cv2.imread(g_img_path) 146 | # g_img = cv2.cvtColor(g_img, cv2.COLOR_BGR2RGB) 147 | 148 | # plt.text(30, -8, 'cam %d' % gallery_cam_ids[inter_sort_index[j]], size=45) 149 | # # plt.text(30, -8, '%.3f' % dist[i][inter_sort_index[j]], size=45) 150 | 151 | # plt.imshow(cv2.resize(g_img, (128, 256))) 152 | # plt.xticks([]) 153 | # plt.yticks([]) 154 | 155 | # if query_ids[i] != gallery_ids[inter_sort_index[j]]: 156 | # color = "red" 157 | # else: 158 | # color = "green" 159 | 160 | # ax = plt.gca() 161 | # ax.spines['right'].set_linewidth(8) 162 | # ax.spines['top'].set_linewidth(8) 163 | # ax.spines['left'].set_linewidth(8) 164 | # ax.spines['bottom'].set_linewidth(8) 165 | # ax.spines['right'].set_color(color) 166 | # ax.spines['top'].set_color(color) 167 | # 
ax.spines['left'].set_color(color) 168 | # ax.spines['bottom'].set_color(color) 169 | 170 | plt.tight_layout() 171 | plt.savefig("{}/{}.jpg".format(save_dir, os.path.basename(q_img_path))) 172 | plt.close() 173 | 174 | 175 | if __name__ == '__main__': 176 | import os 177 | import argparse 178 | 179 | import scipy.io as sio 180 | 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument("dataset", type=str, choices=[ 183 | "market", "duke", "msmt", "cuhk"]) 184 | parser.add_argument("prefix", type=str) 185 | args = parser.parse_args() 186 | 187 | q_mat_path = os.path.join("features", args.dataset, 188 | "query-%s.mat" % args.prefix) 189 | g_mat_path = os.path.join("features", args.dataset, 190 | "gallery-%s.mat" % args.prefix) 191 | 192 | g_mat = sio.loadmat(g_mat_path) 193 | q_mat = sio.loadmat(q_mat_path) 194 | 195 | save_dir = "./vis/baseline" 196 | os.makedirs(save_dir, exist_ok=True) 197 | 198 | rank_vis(q_mat, g_mat, n_match=10, save_dir=save_dir) 199 | -------------------------------------------------------------------------------- /utils/tsne.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import scipy.io as sio 5 | import matplotlib as mpl 6 | 7 | mpl.use('agg') 8 | 9 | import matplotlib.pyplot as plt 10 | from sklearn.manifold import TSNE 11 | from collections import Counter 12 | 13 | k = 25 14 | 15 | random.seed(0) 16 | plt.figure(figsize=(7.5, 3.5)) 17 | 18 | source_features_path = 'features/duke/gallery-duke2market-all-eps0.8-nomix_model_70.mat' 19 | target_features_path = 'features/market/gallery-duke2market-all-eps0.8-nomix_model_70.mat' 20 | 21 | print('Loading...') 22 | source_mat = sio.loadmat(source_features_path) 23 | target_mat = sio.loadmat(target_features_path) 24 | print('Done!') 25 | 26 | source_features = source_mat["feat"] 27 | source_ids = source_mat["ids"].squeeze() 28 | source_cam_ids = source_mat["cam_ids"].squeeze() 29 | source_img_paths = source_mat['img_path'] 30 | 31 | target_features = target_mat["feat"] 32 | target_ids = -target_mat["ids"].squeeze() 33 | target_cam_ids = target_mat["cam_ids"].squeeze() 34 | target_img_paths = target_mat['img_path'] 35 | 36 | s_counter = Counter(source_ids) 37 | t_counter = Counter(target_ids) 38 | 39 | s_select_ids = [] 40 | t_select_ids = [] 41 | for idx, num in s_counter.items(): 42 | if 30 < num < 60 and idx not in [0, -1]: 43 | s_select_ids.append(idx) 44 | for idx, num in t_counter.items(): 45 | if 30 < num < 60 and idx not in [0, -1]: 46 | t_select_ids.append(idx) 47 | 48 | assert len(s_select_ids) >= k 49 | assert len(t_select_ids) >= k 50 | 51 | s_select_ids = random.sample(s_select_ids, k) 52 | t_select_ids = random.sample(t_select_ids, k) 53 | 54 | s_flags = np.in1d(source_ids, s_select_ids) 55 | t_flags = np.in1d(target_ids, t_select_ids) 56 | 57 | s_ids = source_ids[s_flags] 58 | t_ids = target_ids[t_flags] 59 | 60 | ids = np.concatenate([s_ids, t_ids], axis=0).tolist() 61 | 62 | id_map = dict(zip(s_select_ids + t_select_ids, range(2 * k))) 63 | 64 | new_ids = [] 65 | for x in ids: 66 | new_ids.append(id_map[x]) 67 | 68 | s_feats = source_features[s_flags] 69 | t_feats = target_features[t_flags] 70 | feats = np.concatenate([s_feats, t_feats], axis=0) 71 | 72 | tsne = TSNE(n_components=2, random_state=0) 73 | proj = tsne.fit_transform(feats) 74 | 75 | ax = plt.subplot(121) 76 | ax.spines['top'].set_visible(False) 77 | ax.spines['right'].set_visible(False) 78 | ax.spines['bottom'].set_visible(False) 79 | 
ax.spines['left'].set_visible(False) 80 | ax.set_xticks([]) 81 | ax.set_yticks([]) 82 | 83 | t_size = t_feats.shape[0] 84 | s_size = s_feats.shape[0] 85 | ax.scatter(proj[-t_size:, 0], proj[-t_size:, 1], c=['b'] * t_size, marker='.') 86 | ax.scatter(proj[:s_size, 0], proj[:s_size, 1], c=['r'] * s_size, marker='.') 87 | 88 | # --------------------------------------------------------------------- # 89 | source_features_path = 'features/duke/gallery-duke2market-all-eps0.8-mix0.6_model_70.mat' 90 | target_features_path = 'features/market/gallery-duke2market-all-eps0.8-mix0.6_model_70.mat' 91 | 92 | print('Loading...') 93 | source_mat = sio.loadmat(source_features_path) 94 | target_mat = sio.loadmat(target_features_path) 95 | print('Done!') 96 | 97 | source_features = source_mat["feat"] 98 | source_ids = source_mat["ids"].squeeze() 99 | source_cam_ids = source_mat["cam_ids"].squeeze() 100 | 101 | target_features = target_mat["feat"] 102 | target_ids = -target_mat["ids"].squeeze() 103 | target_cam_ids = target_mat["cam_ids"].squeeze() 104 | 105 | s_flags = np.in1d(source_ids, s_select_ids) 106 | t_flags = np.in1d(target_ids, t_select_ids) 107 | 108 | s_ids = source_ids[s_flags] 109 | t_ids = target_ids[t_flags] 110 | 111 | ids = np.concatenate([s_ids, t_ids], axis=0).tolist() 112 | 113 | new_ids = [] 114 | for x in ids: 115 | new_ids.append(id_map[x]) 116 | 117 | s_feats = source_features[s_flags] 118 | t_feats = target_features[t_flags] 119 | feats = np.concatenate([s_feats, t_feats], axis=0) 120 | 121 | tsne = TSNE(n_components=2, random_state=0) 122 | proj = tsne.fit_transform(feats) 123 | 124 | ax = plt.subplot(122) 125 | ax.spines['top'].set_visible(False) 126 | ax.spines['right'].set_visible(False) 127 | ax.spines['bottom'].set_visible(False) 128 | ax.spines['left'].set_visible(False) 129 | ax.set_xticks([]) 130 | ax.set_yticks([]) 131 | 132 | ax.scatter(proj[-t_size:, 0], proj[-t_size:, 1], c=['b'] * t_size, marker='.') 133 | ax.scatter(proj[:s_size, 0], proj[:s_size, 1], c=['r'] * s_size, marker='.') 134 | 135 | plt.tight_layout() 136 | plt.savefig('tsne.pdf') 137 | 138 | s_paths = source_img_paths[s_flags] 139 | t_paths = target_img_paths[t_flags] 140 | for path in s_paths.tolist(): 141 | os.system('cp %s /home/chuanchen_luo/vis' % path) 142 | for path in t_paths.tolist(): 143 | os.system('cp %s /home/chuanchen_luo/vis' % path) 144 | --------------------------------------------------------------------------------