├── README.md
├── figs
│   └── pipeline.png
├── reid
│   ├── criterion
│   │   ├── __init__.py
│   │   ├── build_criterion.py
│   │   ├── softmax_loss.py
│   │   ├── spatial_align_loss.py
│   │   └── triplet_loss.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── bases.py
│   │   ├── build_data.py
│   │   ├── market1501.py
│   │   ├── msmt17.py
│   │   ├── preprocessing.py
│   │   ├── sampler.py
│   │   ├── vehicleid.py
│   │   └── veri.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── evaluator.py
│   │   ├── searcher.py
│   │   └── trainer.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cm.py
│   │   ├── msinet.py
│   │   ├── operations.py
│   │   ├── sam.py
│   │   └── search_cnn.py
│   ├── solver
│   │   ├── __init__.py
│   │   ├── build_optimizer.py
│   │   └── lr_scheduler.py
│   └── utils
│       ├── __init__.py
│       ├── logging.py
│       ├── meters.py
│       ├── metrics.py
│       ├── osutils.py
│       ├── rerank.py
│       └── serialization.py
├── search.py
└── train.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MSINet
2 | Official implementation of "[MSINet: Twins Contrastive Search of Multi-Scale Interaction for Object ReID](https://arxiv.org/abs/2303.07065)".
3 | 
4 | ## Highlights
5 | 
6 | * MSINet has been accepted by **CVPR 2023**.
7 | * MSINet is a lightweight network architecture for object Re-ID tasks.
8 | * MSINet trained from scratch achieves higher retrieval performance than pre-trained ResNet-50.
9 | 
10 | ## Abstract
11 | 
12 | Neural Architecture Search (NAS) has become increasingly appealing to the object re-identification (ReID) community, as task-specific architectures significantly improve retrieval performance. Previous works explore new optimization targets and search spaces for NAS ReID, yet they neglect the difference in training schemes between image classification and ReID. In this work, we propose a novel Twins Contrastive Mechanism (TCM) to provide more appropriate supervision for ReID architecture search. TCM reduces the category overlap between the training and validation data, and assists NAS in simulating real-world ReID training schemes. We then design a Multi-Scale Interaction (MSI) search space to search for rational interaction operations between multi-scale features. In addition, we introduce a Spatial Alignment Module (SAM) to further enhance attention consistency when confronted with images from different sources. Under the proposed NAS scheme, a specific architecture is automatically searched, named MSINet. Extensive experiments demonstrate that our method surpasses state-of-the-art ReID methods on both in-domain and cross-domain scenarios.
13 | 
14 | ![pipeline](figs/pipeline.png)
15 | 
16 | ## Datasets
17 | 
18 | Put the datasets into `./data`:
19 | 
20 | * Market-1501
21 | * MSMT17
22 | * VeRi-776
23 | * VehicleID
24 | 
25 | ## Experiment Commands
26 | 
27 | ### Evaluate the Performance of MSINet
28 | 
29 | **Train from Scratch**
30 | 
31 | ```bash
32 | python train.py
33 | ```
34 | 
35 | **Train after Pre-training**
36 | 
37 | Download the pre-trained model from [Google Drive](https://drive.google.com/file/d/1ZNLDbtpsraiF149htbyhbh3UjCwREi9p/view?usp=sharing) and put it into `./pretrained`.
38 | 
39 | ```bash
40 | python train.py --pretrained
41 | ```
42 | 
43 | **Train Cross-domain Experiments**
44 | 
45 | ```bash
46 | python train.py -ds market1501 -dt msmt17 --pretrained --epochs 250
47 | ```
48 | 
49 | **Add SAM for Cross-domain Experiments**
50 | 
51 | ```bash
52 | python train.py -ds market1501 -dt msmt17 --pretrained --epochs 250 --sam-mode pos_neg
53 | ```
54 | 
55 | To train on vehicle datasets, add `--width 256`.
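For example, the vehicle-dataset flags combine with the commands above as follows (a usage sketch: `veri` is the dataset key registered in `reid/data/build_data.py`; the remaining flags all appear in the commands above):

```bash
python train.py -ds veri --width 256 --pretrained
```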
56 | 
57 | ### Conduct Search for Other Re-ID Datasets
58 | 
59 | ```bash
60 | python search.py -ds market1501
61 | ```
62 | 
63 | ## Citation
64 | 
65 | ```
66 | @inproceedings{gu2023msinet,
67 |   title={MSINet: Twins Contrastive Search of Multi-Scale Interaction for Object ReID},
68 |   author={Gu, Jianyang and Wang, Kai and Luo, Hao and Chen, Chen and Jiang, Wei and Fang, Yuqiang and Zhang, Shanghang and You, Yang and Zhao, Jian},
69 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
70 |   pages={19243--19253},
71 |   year={2023}
72 | }
73 | ```
74 | 
75 | 
--------------------------------------------------------------------------------
/figs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vimar-gu/MSINet/2a8845b6b3d1a3b8baeb864b92f9423c2dc711ee/figs/pipeline.png
--------------------------------------------------------------------------------
/reid/criterion/__init__.py:
--------------------------------------------------------------------------------
1 | from .build_criterion import build_criterion
2 | 
--------------------------------------------------------------------------------
/reid/criterion/build_criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | from .softmax_loss import CrossEntropyLabelSmooth
5 | from .triplet_loss import TripletLoss, cosine_dist
6 | from .spatial_align_loss import SpatialAlignLoss
7 | 
8 | 
9 | class ReIDLoss(nn.Module):
10 |     """Build the loss function for ReID tasks.
11 |     """
12 |     def __init__(self, args, num_classes):
13 |         super(ReIDLoss, self).__init__()
14 | 
15 |         self.triplet = TripletLoss(args.margin, normalize_feature=True)
16 |         self.softmax = CrossEntropyLabelSmooth(num_classes=num_classes)
17 |         if args.sam_mode != 'none':
18 |             self.spatial = SpatialAlignLoss(args.sam_mode)
19 |             self.sam_mode = args.sam_mode
20 |             self.sam_ratio = args.sam_ratio
21 |         else:
22 |             self.sam_mode = None
23 | 
24 |     def forward(self, feats, logits, sam_logits, target, sam=False):
25 |         triplet_loss = self.triplet(feats, target)[0]
26 |         softmax_loss = self.softmax(logits, target)
27 |         if self.sam_mode is not None and sam:
28 |             spatial_loss = self.spatial(sam_logits, target)
29 |             return softmax_loss + triplet_loss + spatial_loss * self.sam_ratio
30 |         else:
31 |             return softmax_loss + triplet_loss
32 | 
33 | 
34 | def build_criterion(args, num_classes):
35 |     return ReIDLoss(args, num_classes)
36 | 
--------------------------------------------------------------------------------
/reid/criterion/softmax_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class CrossEntropyLabelSmooth(nn.Module):
6 |     """Cross entropy loss with label smoothing regularizer.
7 | 
8 |     Reference:
9 |     Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
10 |     Equation: y = (1 - epsilon) * y + epsilon / K.
11 | 
12 |     Args:
13 |         num_classes (int): number of classes.
14 |         epsilon (float): smoothing weight.
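        For example, with epsilon = 0.1 and num_classes = 751 (e.g. the
        Market-1501 training identities), the smoothed target of the true
        class is 0.9 + 0.1 / 751 ≈ 0.9001, while every other class receives
        0.1 / 751 ≈ 0.000133 (worked numbers for illustration only).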
15 |     """
16 | 
17 |     def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
18 |         super(CrossEntropyLabelSmooth, self).__init__()
19 |         self.num_classes = num_classes
20 |         self.epsilon = epsilon
21 |         self.use_gpu = use_gpu
22 |         self.logsoftmax = nn.LogSoftmax(dim=1)
23 | 
24 |     def forward(self, inputs, targets):
25 |         """
26 |         Args:
27 |             inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
28 |             targets: ground truth labels with shape (batch_size,)
29 |         """
30 |         log_probs = self.logsoftmax(inputs)
31 |         targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
32 |         if self.use_gpu: targets = targets.cuda()
33 |         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
34 |         loss = (- targets * log_probs).mean(0).sum()
35 |         return loss
--------------------------------------------------------------------------------
/reid/criterion/spatial_align_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class SpatialAlignLoss(nn.Module):
6 |     """The spatial alignment loss for cross-domain Re-ID.
7 |     """
8 |     def __init__(self, mode='pos'):
9 |         super(SpatialAlignLoss, self).__init__()
10 |         self.mode = mode
11 | 
12 |     def forward(self, sam_logits, labels):
13 |         unsup_corrs = sam_logits[:-1]
14 |         pos_corrs = sam_logits[-1].unsqueeze(1)
15 | 
16 |         N = unsup_corrs.shape[0]
17 |         M = unsup_corrs.shape[-1]
18 |         is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
19 |         is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
20 |         # Calculate the non-parametric self-attention activation.
21 |         unsup_pos = unsup_corrs[is_pos].contiguous().view(N, -1, M)
22 |         unsup_neg = unsup_corrs[is_neg].contiguous().view(N, -1, M)
23 | 
24 |         def batch_cosine_dist(x, y):
25 |             bs = x.shape[0]
26 |             bs1, bs2 = x.shape[1], y.shape[1]
27 |             frac_up = torch.bmm(x, y.transpose(1, 2))
28 |             frac_down1 = torch.sqrt(torch.sum(torch.pow(x, 2), dim=2)).view(bs, bs1, 1).repeat(1, 1, bs2)
29 |             frac_down2 = torch.sqrt(torch.sum(torch.pow(y, 2), dim=2)).view(bs, 1, bs2).repeat(1, bs1, 1)
30 | 
31 |             return 1 - frac_up / (frac_down1 * frac_down2)
32 | 
33 |         all_losses = 0
34 |         # Align the positive attention.
35 |         if 'pos' in self.mode:
36 |             pos_dist = batch_cosine_dist(pos_corrs, unsup_pos)
37 |             all_losses += pos_dist.mean()
38 |         # Align the negative attention.
39 | if 'neg' in self.mode: 40 | neg_dist = batch_cosine_dist(unsup_neg, unsup_neg) 41 | all_losses += neg_dist.mean() 42 | 43 | return all_losses 44 | -------------------------------------------------------------------------------- /reid/criterion/triplet_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def euclidean_dist(x, y): 9 | m, n = x.size(0), y.size(0) 10 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 11 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 12 | dist = xx + yy 13 | dist.addmm_(x, y.t(), beta=1, alpha=-2) 14 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 15 | return dist 16 | 17 | def cosine_dist(x, y): 18 | bs1, bs2 = x.size(0), y.size(0) 19 | frac_up = torch.matmul(x, y.transpose(0, 1)) 20 | frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \ 21 | (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1) 22 | cosine = frac_up / frac_down 23 | return 1-cosine 24 | 25 | def _batch_hard(mat_distance, mat_similarity, indice=False): 26 | sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1, descending=True) 27 | hard_p = sorted_mat_distance[:, 0] 28 | hard_p_indice = positive_indices[:, 0] 29 | sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1, descending=False) 30 | hard_n = sorted_mat_distance[:, 0] 31 | hard_n_indice = negative_indices[:, 0] 32 | if(indice): 33 | return hard_p, hard_n, hard_p_indice, hard_n_indice 34 | return hard_p, hard_n 35 | 36 | class TripletLoss(nn.Module): 37 | ''' 38 | Compute Triplet loss augmented with Batch Hard 39 | Details can be seen in 'In defense of the Triplet Loss for Person Re-Identification' 40 | ''' 41 | 42 | def __init__(self, margin, normalize_feature=False): 43 | super(TripletLoss, self).__init__() 44 | self.margin = margin 45 | self.normalize_feature = normalize_feature 46 | self.margin_loss = nn.MarginRankingLoss(margin=margin).cuda() 47 | 48 | def forward(self, emb, label): 49 | if self.normalize_feature: 50 | # equal to cosine similarity 51 | emb = F.normalize(emb) 52 | mat_dist = euclidean_dist(emb, emb) 53 | # mat_dist = cosine_dist(emb, emb) 54 | assert mat_dist.size(0) == mat_dist.size(1) 55 | N = mat_dist.size(0) 56 | mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float() 57 | 58 | dist_ap, dist_an = _batch_hard(mat_dist, mat_sim) 59 | assert dist_an.size(0)==dist_ap.size(0) 60 | y = torch.ones_like(dist_ap) 61 | loss = self.margin_loss(dist_an, dist_ap, y) 62 | prec = (dist_an.data > dist_ap.data).sum() * 1. 
/ y.size(0) 63 | return loss, prec 64 | 65 | class SoftTripletLoss(nn.Module): 66 | 67 | def __init__(self, margin=None, normalize_feature=False): 68 | super(SoftTripletLoss, self).__init__() 69 | self.margin = margin 70 | self.normalize_feature = normalize_feature 71 | 72 | def forward(self, emb1, emb2, label): 73 | if self.normalize_feature: 74 | # equal to cosine similarity 75 | emb1 = F.normalize(emb1) 76 | emb2 = F.normalize(emb2) 77 | 78 | mat_dist = euclidean_dist(emb1, emb1) 79 | assert mat_dist.size(0) == mat_dist.size(1) 80 | N = mat_dist.size(0) 81 | mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float() 82 | 83 | dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True) 84 | assert dist_an.size(0)==dist_ap.size(0) 85 | triple_dist = torch.stack((dist_ap, dist_an), dim=1) 86 | triple_dist = F.log_softmax(triple_dist, dim=1) 87 | if (self.margin is not None): 88 | loss = (- self.margin * triple_dist[:,0] - (1 - self.margin) * triple_dist[:,1]).mean() 89 | return loss 90 | 91 | mat_dist_ref = euclidean_dist(emb2, emb2) 92 | dist_ap_ref = torch.gather(mat_dist_ref, 1, ap_idx.view(N,1).expand(N,N))[:,0] 93 | dist_an_ref = torch.gather(mat_dist_ref, 1, an_idx.view(N,1).expand(N,N))[:,0] 94 | triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1) 95 | triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach() 96 | 97 | loss = (- triple_dist_ref * triple_dist).mean(0).sum() 98 | return loss 99 | -------------------------------------------------------------------------------- /reid/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_data import build_data 2 | -------------------------------------------------------------------------------- /reid/data/bases.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFile 2 | 3 | from torch.utils.data import Dataset 4 | import os.path as osp 5 | 6 | ImageFile.LOAD_TRUNCATED_IMAGES = True 7 | 8 | 9 | def read_image(img_path): 10 | """Keep reading image until succeed. 11 | This can avoid IOError incurred by heavy IO process.""" 12 | got_img = False 13 | if not osp.exists(img_path): 14 | raise IOError("{} does not exist".format(img_path)) 15 | while not got_img: 16 | try: 17 | img = Image.open(img_path).convert('RGB') 18 | got_img = True 19 | except IOError: 20 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 21 | pass 22 | return img 23 | 24 | 25 | class BaseDataset(object): 26 | """ 27 | Base class of reid dataset 28 | """ 29 | 30 | def get_imagedata_info(self, data): 31 | pids, cams = [], [] 32 | for _, pid, camid in data: 33 | pids += [pid] 34 | cams += [camid] 35 | pids = set(pids) 36 | cams = set(cams) 37 | num_pids = len(pids) 38 | num_cams = len(cams) 39 | num_imgs = len(data) 40 | return num_pids, num_imgs, num_cams 41 | 42 | def print_dataset_statistics(self): 43 | raise NotImplementedError 44 | 45 | 46 | class BaseImageDataset(BaseDataset): 47 | """ 48 | Base class of image reid dataset 49 | """ 50 | 51 | def print_dataset_statistics(self, train, query, gallery): 52 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 53 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 54 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 55 | 56 | print("Dataset statistics:") 57 | print(" ----------------------------------------") 58 | print(" subset | # ids | # images | # cameras") 59 | print(" ----------------------------------------") 60 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 61 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 62 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 63 | print(" ----------------------------------------") 64 | 65 | 66 | class ImageDataset(Dataset): 67 | def __init__(self, dataset, transform=None): 68 | self.dataset = dataset 69 | self.transform = transform 70 | 71 | def __len__(self): 72 | return len(self.dataset) 73 | 74 | def __getitem__(self, index): 75 | img_path, pid, camid = self.dataset[index] 76 | img = read_image(img_path) 77 | 78 | if self.transform is not None: 79 | img = self.transform(img) 80 | 81 | return img, pid, camid 82 | -------------------------------------------------------------------------------- /reid/data/build_data.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | from torch.utils.data import DataLoader 3 | from collections import defaultdict 4 | import numpy as np 5 | 6 | from .market1501 import Market1501 7 | from .msmt17 import MSMT17 8 | from .vehicleid import VehicleID 9 | from .veri import VeRi 10 | from .bases import ImageDataset 11 | from .preprocessing import RandomErasing 12 | from .sampler import RandomIdentitySampler 13 | 14 | 15 | __factory = { 16 | 'market1501': Market1501, 17 | 'msmt17': MSMT17, 18 | 'vehicleid': VehicleID, 19 | 'veri': VeRi, 20 | } 21 | 22 | 23 | class IterLoader: 24 | def __init__(self, loader, length=None): 25 | self.loader = loader 26 | self.length = length 27 | self.iter = None 28 | 29 | def __len__(self): 30 | if self.length is not None: 31 | return self.length 32 | else: 33 | return len(self.loader) 34 | 35 | def new_epoch(self): 36 | self.iter = iter(self.loader) 37 | 38 | def next(self): 39 | try: 40 | return next(self.iter) 41 | except: 42 | self.iter = iter(self.loader) 43 | return next(self.iter) 44 | 45 | 46 | def separate_trainval(train): 47 | pid2data_dict = defaultdict(list) 48 | for data in train: 49 | pid2data_dict[data[1]].append(data) 50 | num_pids = len(pid2data_dict.keys()) 51 | val_pids = list(np.arange(num_pids // 5 * 1, num_pids)) 52 | val_pid2label = {pid: idx for idx, pid in enumerate(val_pids)} 53 | new_train = [] 54 | 
new_valid = [] 55 | for pid in pid2data_dict.keys(): 56 | if pid < num_pids // 5 * 1: 57 | new_train += pid2data_dict[pid] 58 | elif pid >= num_pids // 5 * 3: 59 | for data in pid2data_dict[pid]: 60 | new_valid.append((data[0], val_pid2label[pid], data[2])) 61 | else: 62 | data = pid2data_dict[pid] 63 | data_len = len(data) 64 | for idx in range(data_len // 2): 65 | new_train.append((data[idx][0], pid, data[idx][2])) 66 | for idx in range(data_len // 2, data_len): 67 | new_valid.append((data[idx][0], val_pid2label[pid], data[idx][2])) 68 | 69 | return new_train, new_valid, num_pids // 5 * 3, num_pids - num_pids // 5 * 1 70 | 71 | 72 | def build_data(args, target=False, search=False): 73 | if target: 74 | data_name = args.target_dataset 75 | else: 76 | data_name = args.source_dataset 77 | 78 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 79 | std=[0.229, 0.224, 0.225]) 80 | if args.target_dataset != 'none': 81 | train_transforms = T.Compose([ 82 | T.Resize((args.height, args.width)), 83 | T.RandomHorizontalFlip(p=0.5), 84 | T.Pad(10), 85 | T.RandomCrop((args.height, args.width)), 86 | T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, 87 | hue=0), 88 | T.ToTensor(), 89 | normalizer 90 | ]) 91 | else: 92 | train_transforms = T.Compose([ 93 | T.Resize((args.height, args.width)), 94 | T.RandomHorizontalFlip(p=0.5), 95 | T.Pad(10), 96 | T.RandomCrop((args.height, args.width)), 97 | T.ToTensor(), 98 | normalizer, 99 | RandomErasing(probability=0.5, sh=0.4, 100 | mean=(0.4914, 0.4822, 0.4465)) 101 | ]) 102 | test_transforms = T.Compose([ 103 | T.Resize((args.height, args.width)), 104 | T.ToTensor(), 105 | normalizer 106 | ]) 107 | 108 | dataset = __factory[data_name](args.data_dir) 109 | 110 | num_workers = args.workers 111 | num_classes = dataset.num_train_pids 112 | 113 | testset = ImageDataset(dataset.query + dataset.gallery, test_transforms) 114 | test_loader = DataLoader( 115 | testset, batch_size=args.test_batch_size, shuffle=False, 116 | num_workers=num_workers 117 | ) 118 | 119 | if not search: 120 | trainset = ImageDataset(dataset.train, train_transforms) 121 | train_loader = DataLoader( 122 | trainset, batch_size=args.batch_size, num_workers=num_workers, 123 | sampler=RandomIdentitySampler(dataset.train, args.batch_size, args.num_instance), 124 | drop_last=True 125 | ) 126 | 127 | return train_loader, test_loader, len(dataset.query), num_classes 128 | else: 129 | all_train_data = dataset.train 130 | new_train, new_valid, num_train_classes, num_valid_classes =\ 131 | separate_trainval(all_train_data) 132 | 133 | trainset = ImageDataset(new_train, train_transforms) 134 | train_loader = DataLoader( 135 | trainset, batch_size=args.batch_size, num_workers=num_workers, 136 | sampler=RandomIdentitySampler( 137 | new_train, args.batch_size, args.num_instance 138 | ), drop_last=True 139 | ) 140 | validset = ImageDataset(new_valid, train_transforms) 141 | valid_loader = DataLoader( 142 | validset, batch_size=args.batch_size, num_workers=num_workers, 143 | sampler=RandomIdentitySampler( 144 | new_valid, args.batch_size, args.num_instance 145 | ), drop_last=True 146 | ) 147 | valid_loader = IterLoader(valid_loader) 148 | return train_loader, valid_loader, test_loader, len(dataset.query),\ 149 | num_train_classes, num_valid_classes 150 | 151 | -------------------------------------------------------------------------------- /reid/data/market1501.py: -------------------------------------------------------------------------------- 1 | from .bases import BaseImageDataset 2 | import 
os.path as osp 3 | import glob 4 | import re 5 | import os 6 | 7 | 8 | class Market1501(BaseImageDataset): 9 | def __init__(self, data_dir='data_dir', verbose=True): 10 | super(Market1501, self).__init__() 11 | self.dataset_dir = osp.join(data_dir, 'Market-1501-v15.09.15') 12 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 13 | self.query_dir = osp.join(self.dataset_dir, 'query') 14 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 15 | 16 | train = self._process_dir(self.train_dir, relabel=True) 17 | query = self._process_dir(self.query_dir, relabel=False) 18 | gallery = self._process_dir(self.gallery_dir, relabel=False) 19 | 20 | if verbose: 21 | print("=> Market1501 loaded") 22 | self.print_dataset_statistics(train, query, gallery) 23 | 24 | self.train = train 25 | self.query = query 26 | self.gallery = gallery 27 | 28 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 29 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 30 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 31 | 32 | def _process_dir(self, data_dir, relabel=True): 33 | img_paths = glob.glob(osp.join(data_dir, '*.jpg')) 34 | pattern = re.compile(r'([-\d]+)_c(\d)') 35 | 36 | pid_container = set() 37 | for img_path in img_paths: 38 | pid, _ = map(int, pattern.search(img_path).groups()) 39 | if pid == -1: continue # junk images are just ignored 40 | pid_container.add(pid) 41 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 42 | 43 | dataset = [] 44 | for img_path in img_paths: 45 | pid, camid = map(int, pattern.search(img_path).groups()) 46 | if pid == -1: continue # junk images are just ignored 47 | #assert 0 <= pid <= 2501 # pid == 0 means background 48 | assert 1 <= camid <= 6 49 | camid -= 1 # index starts from 0 50 | if relabel: pid = pid2label[pid] 51 | dataset.append((img_path, pid, camid)) 52 | 53 | return dataset 54 | -------------------------------------------------------------------------------- /reid/data/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import tarfile 4 | 5 | import glob 6 | import re 7 | import urllib 8 | import zipfile 9 | 10 | 11 | style='MSMT17_V1' 12 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')): 13 | with open(list_file, 'r') as f: 14 | lines = f.readlines() 15 | ret = [] 16 | pids_ = [] 17 | cams_ = [] 18 | for line in lines: 19 | line = line.strip() 20 | fname = line.split(' ')[0] 21 | pid, _, cam = map(int, pattern.search(osp.basename(fname)).groups()) 22 | if pid not in pids_: 23 | pids_.append(pid) 24 | if cam not in cams_: 25 | cams_.append(cam) 26 | 27 | img_path=osp.join(subdir,fname) 28 | ret.append((osp.join(subdir,fname), pid, cam)) 29 | 30 | return ret, pids_, cams_ 31 | 32 | 33 | class Dataset_MSMT(object): 34 | def __init__(self, root): 35 | self.root = root 36 | self.train, self.val, self.trainval = [], [], [] 37 | self.query, self.gallery = [], [] 38 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 39 | 40 | @property 41 | def images_dir(self): 42 | return osp.join(self.root, style) 43 | 44 | def load(self, verbose=True): 45 | exdir = osp.join(self.root, style) 46 | nametrain= osp.join(exdir, 'train') 47 | nametest = osp.join(exdir, 'test') 48 | self.train, 
train_pids, train_cams = _pluck_msmt(osp.join(exdir, 'list_train.txt'), nametrain) 49 | self.val, val_pids, val_cams = _pluck_msmt(osp.join(exdir, 'list_val.txt'), nametrain) 50 | self.train = self.train + self.val 51 | self.query, query_pids, query_cams = _pluck_msmt(osp.join(exdir, 'list_query.txt'), nametest) 52 | self.gallery, gallery_pids, gallery_cams = _pluck_msmt(osp.join(exdir, 'list_gallery.txt'), nametest) 53 | self.num_train_pids = len(list(set(train_pids).union(set(val_pids)))) 54 | self.num_train_cams = len(list(set(train_cams).union(set(val_cams)))) 55 | 56 | if verbose: 57 | print(self.__class__.__name__, "v1~~~ dataset loaded") 58 | print(" ---------------------------------------") 59 | print(" subset | # ids | # images | # cams") 60 | print(" ---------------------------------------") 61 | print(" train | {:5d} | {:8d} | {:5d}" 62 | .format(self.num_train_pids, len(self.train), self.num_train_cams)) 63 | print(" query | {:5d} | {:8d} | {:5d}" 64 | .format(len(query_pids), len(self.query), len(query_cams))) 65 | print(" gallery | {:5d} | {:8d} | {:5d}" 66 | .format(len(gallery_pids), len(self.gallery), len(gallery_cams))) 67 | print(" ---------------------------------------") 68 | 69 | 70 | class MSMT17(Dataset_MSMT): 71 | 72 | def __init__(self, data_dir, split_id=0, download=False): 73 | super(MSMT17, self).__init__(data_dir) 74 | 75 | if download: 76 | self.download() 77 | 78 | self.load() 79 | 80 | def download(self): 81 | 82 | import re 83 | import hashlib 84 | import shutil 85 | from glob import glob 86 | from zipfile import ZipFile 87 | 88 | raw_dir = osp.join(self.root) 89 | mkdir_if_missing(raw_dir) 90 | 91 | # Download the raw zip file 92 | fpath = osp.join(raw_dir, style) 93 | if osp.isdir(fpath): 94 | print("Using downloaded file: " + fpath) 95 | else: 96 | raise RuntimeError("Please download the dataset manually to {}".format(fpath)) 97 | -------------------------------------------------------------------------------- /reid/data/preprocessing.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import random 4 | import math 5 | 6 | 7 | class GaussianMask(object): 8 | def __init__(self, probability=0.5): 9 | self.probability = probability 10 | 11 | def __call__(self, img): 12 | if random.uniform(0, 1) >= self.probability: 13 | return img 14 | width = img.size[0] 15 | height = img.size[1] 16 | mask = np.zeros((height, width)) 17 | mask_h = np.zeros((height, width)) 18 | mask_h += np.arange(0, width) - width / 2 19 | mask_v = np.zeros((width, height)) 20 | mask_v += np.arange(0, height) - height / 2 21 | mask_v = mask_v.T 22 | 23 | numerator = np.power(mask_h, 2) + np.power(mask_v, 2) 24 | denominator = 2 * (height * height + width * width) 25 | mask = np.exp(-(numerator / denominator)) 26 | 27 | img = np.asarray(img) 28 | new_img = np.zeros_like(img) 29 | new_img[:, :, 0] = np.multiply(mask, img[:, :, 0]) 30 | new_img[:, :, 1] = np.multiply(mask, img[:, :, 1]) 31 | new_img[:, :, 2] = np.multiply(mask, img[:, :, 2]) 32 | 33 | return Image.fromarray(new_img) 34 | 35 | 36 | class RandomErasing(object): 37 | """ Randomly selects a rectangle region in an image and erases its pixels. 38 | 'Random Erasing Data Augmentation' by Zhong et al. 39 | See https://arxiv.org/pdf/1708.04896.pdf 40 | Args: 41 | probability: The probability that the Random Erasing operation will be performed. 42 | sl: Minimum proportion of erased area against input image. 
43 | sh: Maximum proportion of erased area against input image. 44 | r1: Minimum aspect ratio of erased area. 45 | mean: Erasing value. 46 | """ 47 | 48 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 49 | self.probability = probability 50 | self.mean = mean 51 | self.sl = sl 52 | self.sh = sh 53 | self.r1 = r1 54 | 55 | def __call__(self, img): 56 | 57 | if random.uniform(0, 1) >= self.probability: 58 | return img 59 | 60 | for attempt in range(100): 61 | area = img.size()[1] * img.size()[2] 62 | 63 | target_area = random.uniform(self.sl, self.sh) * area 64 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 65 | 66 | h = int(round(math.sqrt(target_area * aspect_ratio))) 67 | w = int(round(math.sqrt(target_area / aspect_ratio))) 68 | 69 | if w < img.size()[2] and h < img.size()[1]: 70 | x1 = random.randint(0, img.size()[1] - h) 71 | y1 = random.randint(0, img.size()[2] - w) 72 | if img.size()[0] == 3: 73 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 74 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 75 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 76 | else: 77 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 78 | return img 79 | 80 | return img 81 | -------------------------------------------------------------------------------- /reid/data/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | 7 | 8 | class RandomIdentitySampler(Sampler): 9 | """ 10 | Randomly sample N identities, then for each identity, 11 | randomly sample K instances, therefore batch size is N*K. 12 | Args: 13 | - data_source (list): list of (img_path, pid, camid). 14 | - num_instances (int): number of instances per identity in a batch. 15 | - batch_size (int): number of examples in a batch. 
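    Example (illustrative numbers): with batch_size = 64 and num_instances = 4,
    each batch covers 64 // 4 = 16 distinct identities, matching the
    `num_pids_per_batch` computation below.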
16 | """ 17 | 18 | def __init__(self, data_source, batch_size, num_instances): 19 | self.data_source = data_source 20 | self.batch_size = batch_size 21 | self.num_instances = num_instances 22 | self.num_pids_per_batch = self.batch_size // self.num_instances 23 | self.index_dic = defaultdict(list) #dict with list value 24 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 25 | for index, (_, pid, _) in enumerate(self.data_source): 26 | self.index_dic[pid].append(index) 27 | self.pids = list(self.index_dic.keys()) 28 | 29 | # estimate number of examples in an epoch 30 | self.length = 0 31 | for pid in self.pids: 32 | idxs = self.index_dic[pid] 33 | num = len(idxs) 34 | if num < self.num_instances: 35 | num = self.num_instances 36 | self.length += num - num % self.num_instances 37 | 38 | def __iter__(self): 39 | batch_idxs_dict = defaultdict(list) 40 | 41 | for pid in self.pids: 42 | idxs = copy.deepcopy(self.index_dic[pid]) 43 | if len(idxs) < self.num_instances: 44 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 45 | random.shuffle(idxs) 46 | batch_idxs = [] 47 | for idx in idxs: 48 | batch_idxs.append(idx) 49 | if len(batch_idxs) == self.num_instances: 50 | batch_idxs_dict[pid].append(batch_idxs) 51 | batch_idxs = [] 52 | 53 | avai_pids = copy.deepcopy(self.pids) 54 | final_idxs = [] 55 | 56 | while len(avai_pids) >= self.num_pids_per_batch: 57 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 58 | for pid in selected_pids: 59 | batch_idxs = batch_idxs_dict[pid].pop(0) 60 | final_idxs.extend(batch_idxs) 61 | if len(batch_idxs_dict[pid]) == 0: 62 | avai_pids.remove(pid) 63 | 64 | return iter(final_idxs) 65 | 66 | def __len__(self): 67 | return self.length 68 | -------------------------------------------------------------------------------- /reid/data/vehicleid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | import os.path as osp 7 | 8 | from .bases import BaseImageDataset 9 | from collections import defaultdict 10 | 11 | 12 | class VehicleID(BaseImageDataset): 13 | """ 14 | VehicleID 15 | Reference: 16 | Deep Relative Distance Learning: Tell the Difference Between Similar Vehicles 17 | 18 | Dataset statistics: 19 | # train_list: 13164 vehicles for model training 20 | # test_list_800: 800 vehicles for model testing(small test set in paper 21 | # test_list_1600: 1600 vehicles for model testing(medium test set in paper 22 | # test_list_2400: 2400 vehicles for model testing(large test set in paper 23 | # test_list_3200: 3200 vehicles for model testing 24 | # test_list_6000: 6000 vehicles for model testing 25 | # test_list_13164: 13164 vehicles for model testing 26 | """ 27 | dataset_dir = 'VehicleID' 28 | 29 | def __init__(self, root, verbose=True, test_size=2400, **kwargs): 30 | super(VehicleID, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.img_dir = osp.join(self.dataset_dir, 'image') 33 | self.split_dir = osp.join(self.dataset_dir, 'train_test_split') 34 | self.train_list = osp.join(self.split_dir, 'train_list.txt') 35 | self.test_size = test_size 36 | 37 | if self.test_size == 800: 38 | self.test_list = osp.join(self.split_dir, 'test_list_800.txt') 39 | elif self.test_size == 1600: 40 | self.test_list = osp.join(self.split_dir, 'test_list_1600.txt') 41 | elif self.test_size == 2400: 42 | self.test_list = osp.join(self.split_dir, 
'test_list_2400.txt')
43 | 
44 |         print(self.test_list)
45 | 
46 |         self.check_before_run()
47 | 
48 |         train, query, gallery = self.process_split(relabel=True)
49 |         self.train = train
50 |         self.query = query
51 |         self.gallery = gallery
52 | 
53 |         if verbose:
54 |             print('=> VehicleID loaded')
55 |             self.print_dataset_statistics(train, query, gallery)
56 | 
57 |         self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
58 |         self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
59 |         self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
60 | 
61 |     def check_before_run(self):
62 |         """Check if all files are available before going deeper"""
63 |         if not osp.exists(self.dataset_dir):
64 |             raise RuntimeError('"{}" is not available'.format(self.dataset_dir))
65 |         if not osp.exists(self.split_dir):
66 |             raise RuntimeError('"{}" is not available'.format(self.split_dir))
67 |         if not osp.exists(self.train_list):
68 |             raise RuntimeError('"{}" is not available'.format(self.train_list))
69 |         if self.test_size not in [800, 1600, 2400]:
70 |             raise RuntimeError('"{}" is not available'.format(self.test_size))
71 |         if not osp.exists(self.test_list):
72 |             raise RuntimeError('"{}" is not available'.format(self.test_list))
73 | 
74 |     def get_pid2label(self, pids):
75 |         pid_container = set(pids)
76 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
77 |         return pid2label
78 | 
79 |     def parse_img_pids(self, nl_pairs, pid2label=None):
80 |         # nl_pairs is a list of [image name, pid] pairs.
81 |         output = []
82 |         for info in nl_pairs:
83 |             name = info[0]
84 |             pid = info[1]
85 |             if pid2label is not None:
86 |                 pid = pid2label[pid]
87 |             camid = 0  # no camera information is available; use 0 for all
88 |             img_path = osp.join(self.img_dir, name+'.jpg')
89 |             output.append((img_path, pid, camid))
90 |         return output
91 | 
92 |     def process_split(self, relabel=False):
93 |         # read train paths
94 |         train_pid_dict = defaultdict(list)
95 | 
96 |         # 'train_list.txt' format:
97 |         # the first token is the image name
98 |         # the second token is the vehicle id
99 |         with open(self.train_list) as f_train:
100 |             train_data = f_train.readlines()
101 |             for data in train_data:
102 |                 name, pid = data.strip().split(' ')
103 |                 pid = int(pid)
104 |                 train_pid_dict[pid].append([name, pid])
105 |         train_pids = list(train_pid_dict.keys())
106 |         num_train_pids = len(train_pids)
107 |         assert num_train_pids == 13164, 'There should be 13164 vehicles for training,' \
108 |                                         ' but got {}, please check the data'\
109 |             .format(num_train_pids)
110 |         # print('num of train ids: {}'.format(num_train_pids))
111 |         test_pid_dict = defaultdict(list)
112 |         with open(self.test_list) as f_test:
113 |             test_data = f_test.readlines()
114 |             for data in test_data:
115 |                 name, pid = data.split(' ')
116 |                 pid = int(pid)
117 |                 test_pid_dict[pid].append([name, pid])
118 |         test_pids = list(test_pid_dict.keys())
119 |         num_test_pids = len(test_pids)
120 |         assert num_test_pids == self.test_size, 'There should be {} vehicles for testing,' \
121 |                                                 ' but got {}, please check the data'\
122 |             .format(self.test_size, num_test_pids)
123 | 
124 |         train_data = []
125 |         query_data = []
126 |         gallery_data = []
127 | 
128 |         # for train ids, all images are used in the train set.
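        # Note: entries are still [image name, pid] pairs at this point;
        # parse_img_pids() later joins them with self.img_dir and attaches
        # camid = 0, since VehicleID provides no camera annotations.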
129 | for pid in train_pids: 130 | imginfo = train_pid_dict[pid] # imginfo include image name and id 131 | train_data.extend(imginfo) 132 | 133 | # for each test id, random choose one image for gallery 134 | # and the other ones for query. 135 | for pid in test_pids: 136 | imginfo = test_pid_dict[pid] 137 | sample = random.choice(imginfo) 138 | imginfo.remove(sample) 139 | query_data.extend(imginfo) 140 | gallery_data.append(sample) 141 | 142 | if relabel: 143 | train_pid2label = self.get_pid2label(train_pids) 144 | else: 145 | train_pid2label = None 146 | # for key, value in train_pid2label.items(): 147 | # print('{key}:{value}'.format(key=key, value=value)) 148 | 149 | train = self.parse_img_pids(train_data, train_pid2label) 150 | query = self.parse_img_pids(query_data) 151 | gallery = self.parse_img_pids(gallery_data) 152 | return train, query, gallery 153 | 154 | -------------------------------------------------------------------------------- /reid/data/veri.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | 9 | from .bases import BaseImageDataset 10 | 11 | 12 | class VeRi(BaseImageDataset): 13 | """ 14 | VeRi 15 | Reference: 16 | Liu, X., Liu, W., Ma, H., Fu, H.: Large-scale vehicle re-identification in urban surveillance videos. In: IEEE % 17 | International Conference on Multimedia and Expo. (2016) accepted. 18 | Dataset statistics: 19 | # identities: 776 vehicles(576 for training and 200 for testing) 20 | # images: 37778 (train) + 11579 (query) 21 | """ 22 | dataset_dir = 'VeRi' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(VeRi, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 30 | 31 | self.check_before_run() 32 | 33 | train = self.process_dir(self.train_dir, relabel=True) 34 | query = self.process_dir(self.query_dir, relabel=False) 35 | gallery = self.process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print('=> VeRi loaded') 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 48 | 49 | def check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 55 | if not osp.exists(self.query_dir): 56 | raise RuntimeError('"{}" is not available'.format(self.query_dir)) 57 | if not osp.exists(self.gallery_dir): 58 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) 59 | 60 | def process_dir(self, dir_path, relabel=False): 61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 62 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 63 | 64 | 
pid_container = set() 65 | for img_path in img_paths: 66 | pid, _ = map(int, pattern.search(img_path).groups()) 67 | if pid == -1: 68 | continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: 76 | continue # junk images are just ignored 77 | assert 0 <= pid <= 776 # pid == 0 means background 78 | assert 1 <= camid <= 20 79 | camid -= 1 # index starts from 0 80 | if relabel: 81 | pid = pid2label[pid] 82 | dataset.append((img_path, pid, camid)) 83 | 84 | return dataset 85 | 86 | -------------------------------------------------------------------------------- /reid/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import do_train 2 | from .evaluator import evaluate 3 | -------------------------------------------------------------------------------- /reid/engine/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import os.path as osp 4 | 5 | from ..utils.meters import AverageMeter 6 | from ..utils.metrics import R1_mAP 7 | 8 | 9 | def evaluate(args, model, test_loader, num_query, remove_cam): 10 | """Standard Re-ID evaluating engine.""" 11 | print_freq = args.print_freq 12 | evaluator = R1_mAP(num_query, max_rank=50, feat_norm=True, remove_cam=remove_cam) 13 | evaluator.reset() 14 | 15 | model.eval() 16 | for n_iter, (img, pid, camid) in enumerate(test_loader): 17 | with torch.no_grad(): 18 | img = img.cuda() 19 | feat = model(img) 20 | evaluator.update((feat, pid, camid)) 21 | 22 | if n_iter % print_freq == 0: 23 | print('Evaluating: [{}/{}]'.format(n_iter, len(test_loader))) 24 | 25 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 26 | print('Validation Results') 27 | print('mAP: {:.1%}'.format(mAP)) 28 | for r in [1, 5, 10]: 29 | print('CMC curve, Rank-{:<3}:{:.1%}'.format(r, cmc[r - 1])) 30 | 31 | -------------------------------------------------------------------------------- /reid/engine/searcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import os.path as osp 5 | from torch.nn import functional as F 6 | 7 | from ..utils.meters import AverageMeter 8 | from ..utils.metrics import R1_mAP 9 | 10 | 11 | class Architect(object): 12 | """Architecture parameter maintenance and update""" 13 | def __init__(self, args, model): 14 | self.model = model 15 | self.optimizer = torch.optim.Adam( 16 | self.model.arch_parameters(), 17 | lr=args.arch_lr, betas=(0.5, 0.999), 18 | weight_decay=args.arch_weight_decay 19 | ) 20 | self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 21 | self.optimizer, milestones=[150, 225, 350] 22 | ) 23 | 24 | def step(self, input_valid, target_valid): 25 | self.optimizer.zero_grad() 26 | self._backward_step(input_valid, target_valid) 27 | self.optimizer.step() 28 | 29 | def _backward_step(self, input_valid, target_valid): 30 | loss = self.model._loss(input_valid, target_valid) 31 | loss.backward() 32 | 33 | def step_scheduler(self): 34 | self.lr_scheduler.step() 35 | 36 | 37 | def do_search(args, model, criterion, train_loader, valid_loader, test_loader, 38 | optimizer, lr_scheduler, num_query, remove_cam): 39 | """The engine for searching Re-ID architectures.""" 40 | loss_meter = AverageMeter() 41 | acc_meter = 
AverageMeter() 42 | 43 | print_freq = args.print_freq 44 | geno_interval = args.geno_interval 45 | eval_interval = args.eval_interval 46 | 47 | architect = Architect(args, model) 48 | 49 | for epoch in range(args.epochs): 50 | loss_meter.reset() 51 | acc_meter.reset() 52 | 53 | if (epoch + 1) % geno_interval == 0 or (epoch + 1) == args.epochs: 54 | genotype = model.genotype() 55 | print('genotype = {}'.format(genotype)) 56 | print(F.softmax(model.arch_parameters()[0], dim=-1)) 57 | 58 | model.train() 59 | for n_iter, (img, target, _) in enumerate(train_loader): 60 | optimizer.zero_grad() 61 | img = img.cuda() 62 | target = target.cuda() 63 | (img_s, target_s, _) = valid_loader.next() 64 | img_s = img_s.cuda() 65 | target_s = target_s.cuda() 66 | 67 | if n_iter > 0: 68 | # Update the architecture parameter. 69 | architect.step(img_s, target_s) 70 | 71 | feats, logits = model(img) 72 | loss = criterion(feats, target) 73 | 74 | # Update the model parameter. 75 | loss.backward() 76 | optimizer.step() 77 | 78 | acc = (logits.max(1)[1] == target).float().mean() 79 | loss_meter.update(loss.item(), img.shape[0]) 80 | acc_meter.update(acc, 1) 81 | 82 | if n_iter % print_freq == 0: 83 | print('Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Lr: {:.2e}' 84 | .format(epoch, n_iter, len(train_loader), loss_meter.avg, 85 | acc_meter.avg, lr_scheduler.get_last_lr()[0])) 86 | 87 | lr_scheduler.step() 88 | architect.step_scheduler() 89 | 90 | if (epoch + 1) % eval_interval == 0 or (epoch + 1) == args.epochs: 91 | torch.save(model.state_dict(), osp.join(args.logs_dir, 'model_{}.pth'.format(epoch))) 92 | evaluator = R1_mAP(num_query, max_rank=50, feat_norm=True, remove_cam=remove_cam) 93 | evaluator.reset() 94 | model.eval() 95 | for n_iter, (img, pid, camid) in enumerate(test_loader): 96 | with torch.no_grad(): 97 | img = img.cuda() 98 | feat = model(img) 99 | evaluator.update((feat, pid, camid)) 100 | 101 | if n_iter % print_freq == 0: 102 | print('Evaluating: [{}/{}]'.format(n_iter, len(test_loader))) 103 | 104 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 105 | print('Validation Results - Epoch[{}]'.format(epoch)) 106 | print('mAP: {:.1%}'.format(mAP)) 107 | for r in [1, 5, 10]: 108 | print('CMC curve, Rank-{:<3}:{:.1%}'.format(r, cmc[r - 1])) 109 | del evaluator 110 | -------------------------------------------------------------------------------- /reid/engine/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import os.path as osp 4 | 5 | from ..utils.meters import AverageMeter 6 | from ..utils.metrics import R1_mAP 7 | 8 | 9 | def do_train(args, model, criterion, train_loader, test_loader, 10 | optimizer, lr_scheduler, num_query, remove_cam): 11 | """Standard Re-ID training engine.""" 12 | loss_meter = AverageMeter() 13 | acc_meter = AverageMeter() 14 | 15 | print_freq = args.print_freq 16 | eval_interval = args.eval_interval 17 | 18 | for epoch in range(args.epochs): 19 | loss_meter.reset() 20 | acc_meter.reset() 21 | 22 | model.train() 23 | for n_iter, (img, pid, _) in enumerate(train_loader): 24 | optimizer.zero_grad() 25 | img = img.cuda() 26 | target = pid.cuda() 27 | 28 | feats, logits, f_feats, f_logits, sam_logits = model(img) 29 | loss = criterion(feats, logits, sam_logits, target, sam=True) + criterion(f_feats, f_logits, sam_logits, target) 30 | 31 | loss.backward() 32 | optimizer.step() 33 | 34 | acc = (logits.max(1)[1] == target).float().mean() 35 | loss_meter.update(loss.item(), img.shape[0]) 36 | 
acc_meter.update(acc, 1) 37 | 38 | if n_iter % print_freq == 0: 39 | print('Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Lr: {:.2e}' 40 | .format(epoch, n_iter, len(train_loader), loss_meter.avg, 41 | acc_meter.avg, lr_scheduler.get_last_lr()[0])) 42 | 43 | lr_scheduler.step() 44 | 45 | if (epoch + 1) % eval_interval == 0 or (epoch + 1) == args.epochs: 46 | torch.save(model.state_dict(), osp.join(args.logs_dir, 'model_{}.pth'.format(epoch))) 47 | evaluator = R1_mAP(num_query, max_rank=50, feat_norm=True, remove_cam=remove_cam) 48 | evaluator.reset() 49 | model.eval() 50 | for n_iter, (img, pid, camid) in enumerate(test_loader): 51 | with torch.no_grad(): 52 | img = img.cuda() 53 | feat = model(img) 54 | evaluator.update((feat, pid, camid)) 55 | 56 | if n_iter % print_freq == 0: 57 | print('Evaluating: [{}/{}]'.format(n_iter, len(test_loader))) 58 | 59 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 60 | print('Validation Results - Epoch[{}]'.format(epoch)) 61 | print('mAP: {:.1%}'.format(mAP)) 62 | for r in [1, 5, 10]: 63 | print('CMC curve, Rank-{:<3}:{:.1%}'.format(r, cmc[r - 1])) 64 | del evaluator 65 | 66 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .msinet import msinet_x1_0 2 | -------------------------------------------------------------------------------- /reid/models/cm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd as autograd 5 | 6 | 7 | class CM(autograd.Function): 8 | 9 | @staticmethod 10 | def forward(ctx, inputs, targets, features, momentum): 11 | ctx.features = features 12 | ctx.momentum = momentum 13 | ctx.save_for_backward(inputs, targets) 14 | outputs = inputs.mm(ctx.features.t()) 15 | 16 | return outputs 17 | 18 | @staticmethod 19 | def backward(ctx, grad_outputs): 20 | inputs, targets = ctx.saved_tensors 21 | grad_inputs = None 22 | if ctx.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(ctx.features) 24 | 25 | # momentum update 26 | for x, y in zip(inputs, targets): 27 | ctx.features[y] = ctx.momentum * ctx.features[y] + (1. 
- ctx.momentum) * x 28 | ctx.features[y] /= ctx.features[y].norm() 29 | 30 | return grad_inputs, None, None, None 31 | 32 | 33 | def cm(inputs, indexes, features, momentum=0.5): 34 | return CM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device)) 35 | 36 | 37 | class ClusterMemory(nn.Module): 38 | def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2, use_hard=False): 39 | super(ClusterMemory, self).__init__() 40 | self.num_features = num_features 41 | self.num_samples = num_samples 42 | 43 | self.momentum = momentum 44 | self.temp = temp 45 | self.use_hard = use_hard 46 | 47 | self.register_buffer('features', torch.zeros(num_samples, num_features)) 48 | 49 | def forward(self, inputs, targets): 50 | 51 | inputs = F.normalize(inputs, dim=1).cuda() 52 | if self.use_hard: 53 | outputs = cm_hard(inputs, targets, self.features, self.momentum) 54 | else: 55 | outputs = cm(inputs, targets, self.features, self.momentum) 56 | 57 | outputs /= self.temp 58 | loss = F.cross_entropy(outputs, targets) 59 | return loss 60 | 61 | -------------------------------------------------------------------------------- /reid/models/msinet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .operations import * 8 | from .sam import AlignModule 9 | from reid.utils.serialization import copy_state_dict 10 | 11 | 12 | class Cell(nn.Module): 13 | """Basic form of a cell. 14 | Consisted of two branches, with 2 and 6 light conv 3x3, respectively. 15 | There are two interaction modules in the middle and tail of the cell.""" 16 | 17 | def __init__(self, in_channels, out_channels, genotypes): 18 | super(Cell, self).__init__() 19 | mid_channels = out_channels // 4 20 | self.conv1a = Conv1x1(in_channels, mid_channels) 21 | self.conv1b = Conv1x1(in_channels, mid_channels) 22 | 23 | self.conv2a = LightConv3x3(mid_channels, mid_channels) 24 | self.conv2b = nn.Sequential( 25 | LightConv3x3(mid_channels, mid_channels), 26 | LightConv3x3(mid_channels, mid_channels), 27 | LightConv3x3(mid_channels, mid_channels), 28 | ) 29 | # The first interaction module. 30 | self._op2 = OPS[genotypes[0]](mid_channels, mid_channels) 31 | 32 | self.conv3a = LightConv3x3(mid_channels, mid_channels) 33 | self.conv3b = nn.Sequential( 34 | LightConv3x3(mid_channels, mid_channels), 35 | LightConv3x3(mid_channels, mid_channels), 36 | LightConv3x3(mid_channels, mid_channels), 37 | ) 38 | # The second interaction module. 39 | self._op3 = OPS[genotypes[1]](mid_channels, mid_channels) 40 | 41 | # Fusing operation. 
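        # conv4a/conv4b project the two branches back to out_channels;
        # forward() sums them and adds the (possibly downsampled) residual
        # before the final ReLU.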
42 | self.conv4a = Conv1x1Linear(mid_channels, out_channels) 43 | self.conv4b = Conv1x1Linear(mid_channels, out_channels) 44 | self.downsample = None 45 | if in_channels != out_channels: 46 | self.downsample = Conv1x1Linear(in_channels, out_channels) 47 | 48 | def forward(self, x): 49 | identity = x 50 | x1a = self.conv1a(x) 51 | x1b = self.conv1b(x) 52 | 53 | x2a = self.conv2a(x1a) 54 | x2b = self.conv2b(x1b) 55 | x2a, x2b = self._op2((x2a, x2b)) 56 | 57 | x3a = self.conv3a(x2a) 58 | x3b = self.conv3b(x2b) 59 | x3a, x3b = self._op3((x3a, x3b)) 60 | 61 | x4 = self.conv4a(x3a) + self.conv4b(x3b) 62 | if self.downsample is not None: 63 | identity = self.downsample(identity) 64 | out = x4 + identity 65 | return F.relu(out) 66 | 67 | 68 | class MSINet(nn.Module): 69 | """The basic structure of the proposed MSINet.""" 70 | 71 | def __init__(self, args, num_classes, channels, genotypes): 72 | super(MSINet, self).__init__() 73 | self.num_classes = num_classes 74 | self.channels = channels 75 | 76 | self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3) 77 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 78 | self.cells = nn.ModuleList() 79 | # Consisted of 6 cells in total. 80 | for i in range(3): 81 | in_channels = self.channels[i] 82 | out_channels = self.channels[i + 1] 83 | print(genotypes[i * 4 : i * 4 + 4]) 84 | self.cells += [ 85 | Cell(in_channels, out_channels, genotypes[i * 4 : i * 4 + 2]), 86 | Cell(out_channels, out_channels, genotypes[i * 4 + 2 : i * 4 + 4]) 87 | ] 88 | if i != 2: 89 | # Downsample 90 | self.cells += [ 91 | nn.Sequential( 92 | Conv1x1(out_channels, out_channels), 93 | nn.AvgPool2d(2, stride=2) 94 | ) 95 | ] 96 | 97 | self.gap = nn.AdaptiveAvgPool2d(1) 98 | self.head = nn.BatchNorm1d(channels[-1]) 99 | self.head.bias.requires_grad_(False) 100 | self.fc = nn.Linear(channels[-1], channels[-1], bias=False) 101 | self.classifier = nn.Linear(channels[-1], num_classes) 102 | 103 | self.f_conv = nn.Conv2d(channels[-2], 128, 1) 104 | self.f_head = nn.BatchNorm1d(256) 105 | self.f_head.bias.requires_grad_(False) 106 | self.f_fc = nn.Linear(256, 256, bias=False) 107 | self.f_classifier = nn.Linear(256, num_classes) 108 | 109 | self.sam_mode = args.sam_mode 110 | if args.sam_mode != 'none': 111 | if args.source_dataset in ['veri', 'vehicleid']: 112 | self.align_module = AlignModule(16, 16, channels[-1]) 113 | else: 114 | self.align_module = AlignModule(16, 8, channels[-1]) 115 | 116 | self._init_params() 117 | 118 | def featuremaps(self, x): 119 | x = self.conv1(x) 120 | x = self.maxpool(x) 121 | for cell_idx, cell in enumerate(self.cells): 122 | x = cell(x) 123 | if cell_idx == 5: 124 | f_x = x 125 | 126 | return x, f_x 127 | 128 | def forward(self, x, train_transfer=False, test_transfer=False): 129 | x, f_x = self.featuremaps(x) 130 | 131 | if self.sam_mode != 'none': 132 | sam_scores = self.align_module(x) 133 | else: 134 | sam_scores = None 135 | 136 | v = self.gap(x).view(x.shape[0], -1) 137 | f_x = self.f_conv(f_x) 138 | height = f_x.shape[2] 139 | f_v_up = self.gap(f_x[:, :, :height // 2, :]).view(f_x.shape[0], -1) 140 | f_v_down = self.gap(f_x[:, :, height // 2:, :]).view(f_x.shape[0], -1) 141 | f_v = torch.cat((f_v_up, f_v_down), dim=1) 142 | 143 | n_v = self.head(self.fc(v)) 144 | n_f_v = self.f_head(self.f_fc(f_v)) 145 | 146 | if not self.training: 147 | if test_transfer: 148 | return torch.cat((n_v, n_f_v), dim=1) 149 | else: 150 | return torch.cat((v, f_v), dim=1) 151 | 152 | y = self.classifier(n_v) 153 | f_y = self.f_classifier(n_f_v) 154 | 155 | if 
train_transfer: 156 | return n_v, y, n_f_v, f_y, sam_scores 157 | else: 158 | return v, y, f_v, f_y, sam_scores 159 | 160 | def _init_params(self): 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | nn.init.kaiming_normal_( 164 | m.weight, mode='fan_out', nonlinearity='relu' 165 | ) 166 | if m.bias is not None: 167 | nn.init.constant_(m.bias, 0) 168 | 169 | elif isinstance(m, nn.BatchNorm2d): 170 | nn.init.constant_(m.weight, 1) 171 | nn.init.constant_(m.bias, 0) 172 | 173 | elif isinstance(m, nn.BatchNorm1d): 174 | nn.init.constant_(m.weight, 1) 175 | nn.init.constant_(m.bias, 0) 176 | 177 | elif isinstance(m, nn.Linear): 178 | nn.init.normal_(m.weight, 0, 0.01) 179 | if m.bias is not None: 180 | nn.init.constant_(m.bias, 0) 181 | 182 | 183 | def msinet_x1_0(args, num_classes=1000): 184 | genotypes, pretrained_weight = genotype_factory[args.genotypes] 185 | model = MSINet( 186 | args, 187 | num_classes, 188 | channels=[64, 256, 384, 512], 189 | genotypes=genotypes 190 | ) 191 | if args.pretrained: 192 | copy_state_dict( 193 | torch.load( 194 | osp.join(args.pretrain_dir, pretrained_weight) 195 | )['state_dict'], model 196 | ) 197 | return model 198 | -------------------------------------------------------------------------------- /reid/models/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | OPS = { 7 | 'none': lambda C_in, C_out: Identity(), 8 | 'exc': lambda C_in, C_out: Exchange(), 9 | 'ag': lambda C_in, C_out: ChannelGate(C_in, C_out), 10 | 'cross_att': lambda C_in, C_out: CrossAttention(C_in, C_out) 11 | } 12 | 13 | 14 | genotype_factory = { 15 | 'msmt': (['ag', 'ag', 'exc', 'ag', 16 | 'cross_att', 'ag', 'ag', 'none', 17 | 'ag', 'cross_att', 'exc', 'cross_att'], 18 | 'msinet_msmt.pth.tar'), 19 | } 20 | 21 | 22 | ########## 23 | # Basic layers 24 | ########## 25 | class ConvLayer(nn.Module): 26 | """Convolution layer (conv + bn + relu).""" 27 | 28 | def __init__( 29 | self, in_channels, out_channels, kernel_size, 30 | stride=1, padding=0, groups=1, IN=False 31 | ): 32 | super(ConvLayer, self).__init__() 33 | self.conv = nn.Conv2d( 34 | in_channels, out_channels, kernel_size, 35 | stride=stride, padding=padding, bias=False, 36 | groups=groups 37 | ) 38 | if IN: 39 | self.bn = nn.InstanceNorm2d(out_channels, affine=True) 40 | else: 41 | self.bn = nn.BatchNorm2d(out_channels) 42 | self.relu = nn.ReLU(inplace=True) 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.bn(x) 47 | x = self.relu(x) 48 | return x 49 | 50 | 51 | class Conv1x1(nn.Module): 52 | """1x1 convolution + bn + relu.""" 53 | 54 | def __init__(self, in_channels, out_channels, stride=1, groups=1): 55 | super(Conv1x1, self).__init__() 56 | self.conv = nn.Conv2d( 57 | in_channels, out_channels, 1, stride=stride, 58 | padding=0, bias=False, groups=groups 59 | ) 60 | self.bn = nn.BatchNorm2d(out_channels) 61 | self.relu = nn.ReLU(inplace=True) 62 | 63 | def forward(self, x): 64 | x = self.conv(x) 65 | x = self.bn(x) 66 | x = self.relu(x) 67 | return x 68 | 69 | 70 | class Conv1x1Linear(nn.Module): 71 | """1x1 convolution + bn (w/o non-linearity).""" 72 | 73 | def __init__(self, in_channels, out_channels, stride=1): 74 | super(Conv1x1Linear, self).__init__() 75 | self.conv = nn.Conv2d( 76 | in_channels, out_channels, 1, stride=stride, padding=0, bias=False 77 | ) 78 | self.bn = nn.BatchNorm2d(out_channels) 79 | 80 | def forward(self, x): 81 | x = 
22 | ##########
23 | # Basic layers
24 | ##########
25 | class ConvLayer(nn.Module):
26 |     """Convolution layer (conv + bn + relu)."""
27 | 
28 |     def __init__(
29 |         self, in_channels, out_channels, kernel_size,
30 |         stride=1, padding=0, groups=1, IN=False
31 |     ):
32 |         super(ConvLayer, self).__init__()
33 |         self.conv = nn.Conv2d(
34 |             in_channels, out_channels, kernel_size,
35 |             stride=stride, padding=padding, bias=False,
36 |             groups=groups
37 |         )
38 |         if IN:
39 |             self.bn = nn.InstanceNorm2d(out_channels, affine=True)
40 |         else:
41 |             self.bn = nn.BatchNorm2d(out_channels)
42 |         self.relu = nn.ReLU(inplace=True)
43 | 
44 |     def forward(self, x):
45 |         x = self.conv(x)
46 |         x = self.bn(x)
47 |         x = self.relu(x)
48 |         return x
49 | 
50 | 
51 | class Conv1x1(nn.Module):
52 |     """1x1 convolution + bn + relu."""
53 | 
54 |     def __init__(self, in_channels, out_channels, stride=1, groups=1):
55 |         super(Conv1x1, self).__init__()
56 |         self.conv = nn.Conv2d(
57 |             in_channels, out_channels, 1, stride=stride,
58 |             padding=0, bias=False, groups=groups
59 |         )
60 |         self.bn = nn.BatchNorm2d(out_channels)
61 |         self.relu = nn.ReLU(inplace=True)
62 | 
63 |     def forward(self, x):
64 |         x = self.conv(x)
65 |         x = self.bn(x)
66 |         x = self.relu(x)
67 |         return x
68 | 
69 | 
70 | class Conv1x1Linear(nn.Module):
71 |     """1x1 convolution + bn (w/o non-linearity)."""
72 | 
73 |     def __init__(self, in_channels, out_channels, stride=1):
74 |         super(Conv1x1Linear, self).__init__()
75 |         self.conv = nn.Conv2d(
76 |             in_channels, out_channels, 1, stride=stride, padding=0, bias=False
77 |         )
78 |         self.bn = nn.BatchNorm2d(out_channels)
79 | 
80 |     def forward(self, x):
81 |         x = self.conv(x)
82 |         x = self.bn(x)
83 |         return x
84 | 
85 | 
86 | class Conv3x3(nn.Module):
87 |     """3x3 convolution + bn + relu."""
88 | 
89 |     def __init__(self, in_channels, out_channels, stride=1, groups=1):
90 |         super(Conv3x3, self).__init__()
91 |         self.conv = nn.Conv2d(
92 |             in_channels, out_channels, 3, stride=stride,
93 |             padding=1, bias=False, groups=groups
94 |         )
95 |         self.bn = nn.BatchNorm2d(out_channels)
96 |         self.relu = nn.ReLU(inplace=True)
97 | 
98 |     def forward(self, x):
99 |         x = self.conv(x)
100 |         x = self.bn(x)
101 |         x = self.relu(x)
102 |         return x
103 | 
104 | 
105 | class LightConv3x3(nn.Module):
106 |     """Lightweight 3x3 convolution.
107 | 
108 |     1x1 (linear) + dw 3x3 (nonlinear).
109 |     """
110 | 
111 |     def __init__(self, in_channels, out_channels):
112 |         super(LightConv3x3, self).__init__()
113 |         self.conv1 = nn.Conv2d(
114 |             in_channels, out_channels, 1, stride=1, padding=0, bias=False
115 |         )
116 |         self.conv2 = nn.Conv2d(
117 |             out_channels, out_channels, 3, stride=1, padding=1,
118 |             bias=False, groups=out_channels
119 |         )
120 |         self.bn = nn.BatchNorm2d(out_channels)
121 |         self.relu = nn.ReLU(inplace=True)
122 | 
123 |     def forward(self, x):
124 |         x = self.conv1(x)
125 |         x = self.conv2(x)
126 |         x = self.bn(x)
127 |         x = self.relu(x)
128 |         return x
129 | 
130 | 
131 | ##########
132 | # Blocks for Multi-scale Feature Interaction between Two Branches
133 | ##########
134 | class Identity(nn.Module):
135 |     """Return the original features."""
136 |     def __init__(self):
137 |         super(Identity, self).__init__()
138 | 
139 |     def forward(self, x):
140 |         return x
141 | 
142 | 
143 | class Exchange(nn.Module):
144 |     """Directly exchange the features of two branches."""
145 |     def __init__(self):
146 |         super(Exchange, self).__init__()
147 | 
148 |     def forward(self, x):
149 |         x1, x2 = x
150 |         return x2, x1
151 | 
152 | 
153 | class ChannelGate(nn.Module):
154 |     """A mini-network that generates channel-wise gates conditioned on the input tensor."""
155 | 
156 |     def __init__(
157 |         self, in_channels, num_gates=None, return_gates=False,
158 |         gate_activation='sigmoid', reduction=16, layer_norm=False
159 |     ):
160 |         super(ChannelGate, self).__init__()
161 |         if num_gates is None:
162 |             num_gates = in_channels
163 |         self.return_gates = return_gates
164 |         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
165 |         self.fc1 = nn.Conv2d(
166 |             in_channels, in_channels // reduction, kernel_size=1,
167 |             bias=True, padding=0
168 |         )
169 |         self.norm1 = None
170 |         if layer_norm:
171 |             self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
172 |         self.relu = nn.ReLU(inplace=True)
173 |         self.fc2 = nn.Conv2d(
174 |             in_channels // reduction, num_gates, kernel_size=1,
175 |             bias=True, padding=0
176 |         )
177 |         if gate_activation == 'sigmoid':
178 |             self.gate_activation = nn.Sigmoid()
179 |         elif gate_activation == 'relu':
180 |             self.gate_activation = nn.ReLU(inplace=True)
181 |         elif gate_activation == 'linear':
182 |             self.gate_activation = None
183 |         else:
184 |             raise RuntimeError(
185 |                 "Unknown gate activation: {}".format(gate_activation)
186 |             )
187 | 
188 |     def forward(self, xs):
189 |         out = []
190 |         for x in xs:
191 |             inp = x  # keep the un-gated features (avoid shadowing builtin `input`)
192 |             x = self.global_avgpool(x)
193 |             x = self.fc1(x)
194 |             if self.norm1 is not None:
195 |                 x = self.norm1(x)
196 |             x = self.relu(x)
197 |             x = self.fc2(x)
198 |             if self.gate_activation is not None:
199 |                 x = self.gate_activation(x)
200 |             out.append(inp * x)
201 |         return out
202 | 
203 | 
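# Illustrative shape check (added, not part of the original file): ChannelGate
# squeezes each branch by global average pooling, runs a 1x1 bottleneck
# (reduction 16 by default) and rescales the branch with the resulting
# per-channel gate:
#
#   gate = ChannelGate(64)
#   a, b = torch.randn(2, 64, 16, 8), torch.randn(2, 64, 16, 8)
#   ga, gb = gate((a, b))  # both outputs keep the shape (2, 64, 16, 8)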
branches.""" 206 | def __init__(self, in_channels, out_channels): 207 | super(CrossAttention, self).__init__() 208 | self.in_channels = in_channels 209 | self.gamma = nn.Parameter(torch.zeros(1)) 210 | 211 | def forward(self, x): 212 | xa, xb = x 213 | m_bs, C, height, width = xa.size() 214 | 215 | querya = xa.view(m_bs, C, -1) 216 | keya = xa.view(m_bs, C, -1).permute(0, 2, 1) 217 | 218 | queryb = xb.view(m_bs, C, -1) 219 | keyb = xb.view(m_bs, C, -1).permute(0, 2, 1) 220 | 221 | energya = torch.bmm(querya, keyb) 222 | energyb = torch.bmm(queryb, keya) 223 | 224 | def get_output(energy, xin): 225 | max_energy_0 = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) 226 | energy_new = max_energy_0 - energy 227 | attention = F.softmax(energy_new, dim=-1) 228 | proj_value = xin.view(m_bs, C, -1) 229 | 230 | out = torch.bmm(attention, proj_value) 231 | out = out.view(m_bs, C, height, width) 232 | 233 | gamma = self.gamma.to(out.device) 234 | out = gamma * out + xin 235 | return out 236 | 237 | return get_output(energya, xa), get_output(energyb, xb) 238 | -------------------------------------------------------------------------------- /reid/models/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AlignModule(nn.Module): 6 | def __init__(self, input_height, input_width, in_planes): 7 | super(AlignModule, self).__init__() 8 | self.qaconv = QAConv(in_planes, input_height, input_width) 9 | self.pos_pam = PAM_Module(in_planes) 10 | 11 | def forward(self, x): 12 | self.qaconv.make_kernel(x) 13 | kernel_score = self.qaconv(x) 14 | pos_score = self.pos_pam(x) 15 | scores = torch.cat( 16 | (kernel_score.max(dim=3)[0], 17 | pos_score.unsqueeze(0).max(dim=3)[0]), 18 | dim=0 19 | ) 20 | 21 | return scores 22 | 23 | 24 | class QAConv(nn.Module): 25 | """Un-parametric correlation calculation""" 26 | def __init__(self, num_features, height, width): 27 | super(QAConv, self).__init__() 28 | self.num_features = num_features 29 | self.height = height 30 | self.width = width 31 | 32 | def make_kernel(self, features): 33 | self.kernel = features 34 | 35 | def forward(self, features): 36 | hw = self.height * self.width 37 | batch_size = features.shape[0] 38 | score = torch.einsum('g c h w, p c y x -> g p y x h w', features, self.kernel) 39 | score = score.view(batch_size, -1, hw, hw) 40 | 41 | return score 42 | 43 | 44 | class PAM_Module(nn.Module): 45 | """Position Attention Module.""" 46 | def __init__(self, in_planes): 47 | super(PAM_Module, self).__init__() 48 | self.in_planes = in_planes 49 | 50 | self.query_conv = nn.Conv2d( 51 | in_channels=in_planes, out_channels=in_planes // 8, kernel_size=1 52 | ) 53 | self.key_conv = nn.Conv2d( 54 | in_channels=in_planes, out_channels=in_planes // 8, kernel_size=1 55 | ) 56 | 57 | def forward(self, x): 58 | batch_size, C, height, width = x.shape 59 | proj_query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1) 60 | proj_key = self.key_conv(x).view(batch_size, -1, height * width) 61 | energy = torch.bmm(proj_query, proj_key) 62 | 63 | return energy 64 | -------------------------------------------------------------------------------- /reid/models/search_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torch.nn import functional as F 5 | from torch.autograd import Variable 6 | 7 | from .operations import * 8 | 9 | 10 | 
--------------------------------------------------------------------------------
/reid/models/search_cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torch.nn import functional as F
5 | from torch.autograd import Variable
6 | 
7 | from .operations import *
8 | 
9 | 
10 | PRIMITIVES = [
11 |     'none',
12 |     'ag',
13 |     'cross_att',
14 |     'exc'
15 | ]
16 | 
17 | 
18 | class Cell(nn.Module):
19 |     """The searchable form of the cells.
20 |     Adopts the standard DARTS scheme, where the output is the
21 |     weighted sum of all candidate operations."""
22 | 
23 |     def __init__(self, in_channels, out_channels):
24 |         super(Cell, self).__init__()
25 |         mid_channels = in_channels // 4
26 |         self.conv1a = Conv1x1(in_channels, mid_channels)
27 |         self.conv1b = Conv1x1(in_channels, mid_channels)
28 | 
29 |         self.conv2a = LightConv3x3(mid_channels, mid_channels)
30 |         self.conv2b = nn.Sequential(
31 |             LightConv3x3(mid_channels, mid_channels),
32 |             LightConv3x3(mid_channels, mid_channels),
33 |             LightConv3x3(mid_channels, mid_channels),
34 |         )
35 |         # Create interaction options.
36 |         self._op2s = nn.ModuleList()
37 |         for primitive in PRIMITIVES:
38 |             op = OPS[primitive](mid_channels, mid_channels)
39 |             self._op2s.append(op)
40 | 
41 |         self.conv3a = LightConv3x3(mid_channels, mid_channels)
42 |         self.conv3b = nn.Sequential(
43 |             LightConv3x3(mid_channels, mid_channels),
44 |             LightConv3x3(mid_channels, mid_channels),
45 |             LightConv3x3(mid_channels, mid_channels),
46 |         )
47 |         # Create interaction options.
48 |         self._op3s = nn.ModuleList()
49 |         for primitive in PRIMITIVES:
50 |             op = OPS[primitive](mid_channels, mid_channels)
51 |             self._op3s.append(op)
52 | 
53 |         self.conv4a = Conv1x1Linear(mid_channels, out_channels)
54 |         self.conv4b = Conv1x1Linear(mid_channels, out_channels)
55 |         self.downsample = None
56 |         if in_channels != out_channels:
57 |             self.downsample = Conv1x1Linear(in_channels, out_channels)
58 | 
59 |     def forward(self, x, weights):
60 |         identity = x
61 |         x1a = self.conv1a(x)
62 |         x1b = self.conv1b(x)
63 | 
64 |         x2a = self.conv2a(x1a)
65 |         x2b = self.conv2b(x1b)
66 |         x2as, x2bs = [], []
67 |         for op, w in zip(self._op2s, weights[0]):
68 |             x2a_op, x2b_op = op((x2a, x2b))
69 |             x2as.append(x2a_op * w)
70 |             x2bs.append(x2b_op * w)
71 |         x2a = sum(x2as)
72 |         x2b = sum(x2bs)
73 | 
74 |         x3a = self.conv3a(x2a)
75 |         x3b = self.conv3b(x2b)
76 |         x3as, x3bs = [], []
77 |         for op, w in zip(self._op3s, weights[1]):
78 |             x3a_op, x3b_op = op((x3a, x3b))
79 |             x3as.append(x3a_op * w)
80 |             x3bs.append(x3b_op * w)
81 |         x3a = sum(x3as)
82 |         x3b = sum(x3bs)
83 | 
84 |         x4 = self.conv4a(x3a) + self.conv4b(x3b)
85 |         if self.downsample is not None:
86 |             identity = self.downsample(identity)
87 |         out = x4 + identity
88 |         return F.relu(out)
89 | 
90 | 
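# Illustrative sketch (added): a search Cell blends the four PRIMITIVES at
# each of its two interaction points, weighted by one softmax-normalised row
# of architecture parameters apiece:
#
#   cell = Cell(64, 64)
#   w = torch.full((2, len(PRIMITIVES)), 0.25)  # uniform relaxation weights
#   out = cell(torch.randn(2, 64, 16, 8), w)    # -> (2, 64, 16, 8)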
91 | class SearchCNN(nn.Module):
92 |     """The searching form of MSINet.
93 |     Simultaneously maintains the model parameters and the architecture parameters."""
94 | 
95 |     def __init__(self, num_classes, channels, criterion):
96 |         super(SearchCNN, self).__init__()
97 |         self.num_classes = num_classes
98 |         self.channels = channels
99 |         self._criterion = criterion
100 | 
101 |         self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3)
102 |         self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
103 |         self.cells = nn.ModuleList()
104 |         for i in range(3):
105 |             in_channels = self.channels[i]
106 |             out_channels = self.channels[i + 1]
107 |             self.cells += [
108 |                 Cell(in_channels, out_channels),
109 |                 Cell(out_channels, out_channels),
110 |             ]
111 |             if i != 2:
112 |                 self.cells += [
113 |                     nn.Sequential(
114 |                         Conv1x1(out_channels, out_channels),
115 |                         nn.AvgPool2d(2, stride=2)
116 |                     )
117 |                 ]
118 | 
119 |         out_planes = channels[-1]
120 |         self.gap = nn.AdaptiveAvgPool2d(1)
121 | 
122 |         self.head = nn.BatchNorm1d(out_planes)
123 |         self.head.bias.requires_grad_(False)
124 |         self.classifier = nn.Linear(out_planes, num_classes, bias=False)
125 | 
126 |         self.reset_params()
127 |         self.reset_alphas()
128 | 
129 |     def featuremaps(self, x):
130 |         x = self.conv1(x)
131 |         x = self.maxpool(x)
132 | 
133 |         weights = F.softmax(self._arch_weights, dim=-1)
134 |         for cell_idx, cell in enumerate(self.cells):
135 |             if cell_idx % 3 == 2:
136 |                 x = cell(x)
137 |             else:
138 |                 tmp_idx = cell_idx - cell_idx // 3
139 |                 x = cell(x, weights[tmp_idx * 2 : tmp_idx * 2 + 2])
140 | 
141 |         return x
142 | 
143 |     def forward(self, x):
144 |         x = self.featuremaps(x)
145 |         x = self.gap(x)
146 |         x = x.view(x.shape[0], -1)
147 | 
148 |         if not self.training:
149 |             return x
150 | 
151 |         bn_x = self.head(x)
152 |         prob = self.classifier(bn_x)
153 | 
154 |         return x, prob
155 | 
156 |     def _loss(self, img, target):
157 |         x, prob = self(img)
158 |         return self._criterion(x, target)
159 | 
160 |     def reset_alphas(self):
161 |         num_ops = len(PRIMITIVES)
162 |         self._arch_weights = Variable(
163 |             1e-3 * torch.randn((12, num_ops)).cuda(), requires_grad=True
164 |         )
165 |         self._arch_parameters = [self._arch_weights,]
166 | 
167 |     def arch_parameters(self):
168 |         return self._arch_parameters
169 | 
170 |     def genotype(self):
171 |         weights = self._arch_parameters[0]
172 |         gene = []
173 |         for i in range(12):
174 |             w = weights[i]
175 |             best = torch.argmax(w)
176 |             gene.append(PRIMITIVES[best])
177 | 
178 |         return gene
179 | 
180 |     def reset_params(self):
181 |         for m in self.modules():
182 |             if isinstance(m, nn.Conv2d):
183 |                 init.kaiming_normal_(m.weight, mode='fan_out')
184 |                 if m.bias is not None:
185 |                     init.constant_(m.bias, 0)
186 |             elif isinstance(m, nn.BatchNorm2d):
187 |                 init.constant_(m.weight, 1)
188 |                 init.constant_(m.bias, 0)
189 |             elif isinstance(m, nn.BatchNorm1d):
190 |                 init.constant_(m.weight, 1)
191 |                 init.constant_(m.bias, 0)
192 |             elif isinstance(m, nn.Linear):
193 |                 init.normal_(m.weight, std=0.001)
194 |                 if m.bias is not None:
195 |                     init.constant_(m.bias, 0)
196 | 
197 | 
198 | def build_search_model(criterion, num_classes=1000, pretrained=False):
199 |     model = SearchCNN(
200 |         num_classes,
201 |         channels=[64, 256, 384, 512],
202 |         criterion=criterion
203 |     )
204 |     return model
205 | 
206 | 
--------------------------------------------------------------------------------
/reid/solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .build_optimizer import build_optimizer
2 | 
--------------------------------------------------------------------------------
/reid/solver/build_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .lr_scheduler import WarmupMultiStepLR
3 | 
4 | 
5 | def build_optimizer(args, model):
6 |     params = []
7 |     for key, value in model.named_parameters():
8 |         if not value.requires_grad:
9 |             continue
10 |         if 'criterion' in key:
11 |             continue
12 |         params += [{"params": [value], "lr": args.lr, "weight_decay": args.weight_decay, "momentum": args.momentum}]
13 |     optimizer = torch.optim.SGD(params)
14 |     lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.01, warmup_iters=args.warmup_step)
15 | 
16 |     return optimizer, lr_scheduler
17 | 
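# Hedged usage sketch (added): the builder reads the argparse fields defined
# in train.py / search.py (lr, weight-decay, momentum, milestones,
# warmup-step) and skips parameters whose name contains 'criterion', so
# loss-side states are never optimised here:
#
#   optimizer, lr_scheduler = build_optimizer(args, model)
#   for epoch in range(args.epochs):
#       ...  # forward / backward / optimizer.step() over the epoch
#       lr_scheduler.step()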
--------------------------------------------------------------------------------
/reid/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 | from torch.optim.lr_scheduler import *
9 | 
10 | 
11 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
12 | # separating MultiStepLR from WarmupLR,
13 | # but the current LRScheduler design doesn't allow it
14 | 
15 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
16 |     def __init__(
17 |         self,
18 |         optimizer,
19 |         milestones,
20 |         gamma=0.1,
21 |         warmup_factor=1.0 / 3,
22 |         warmup_iters=500,
23 |         warmup_method="linear",
24 |         last_epoch=-1,
25 |     ):
26 |         if not list(milestones) == sorted(milestones):
27 |             raise ValueError(
28 |                 "Milestones should be a list of increasing"
29 |                 " integers. Got {}".format(milestones)
30 |             )
31 | 
32 |         if warmup_method not in ("constant", "linear"):
33 |             raise ValueError(
34 |                 "Only 'constant' or 'linear' warmup_method accepted, "
35 |                 "got {}".format(warmup_method)
36 |             )
37 |         self.milestones = milestones
38 |         self.gamma = gamma
39 |         self.warmup_factor = warmup_factor
40 |         self.warmup_iters = warmup_iters
41 |         self.warmup_method = warmup_method
42 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
43 | 
44 |     def get_lr(self):
45 |         warmup_factor = 1
46 |         if self.last_epoch < self.warmup_iters:
47 |             if self.warmup_method == "constant":
48 |                 warmup_factor = self.warmup_factor
49 |             elif self.warmup_method == "linear":
50 |                 alpha = float(self.last_epoch) / float(self.warmup_iters)
51 |                 warmup_factor = self.warmup_factor * (1 - alpha) + alpha
52 |         return [
53 |             base_lr
54 |             * warmup_factor
55 |             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
56 |             for base_lr in self.base_lrs
57 |         ]
58 | 
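# Worked example (added, illustrative): with the values passed by
# build_optimizer (warmup_factor=0.01, warmup_iters=10, gamma=0.1,
# milestones=[150, 225, 300]) and the train.py base lr of 0.065:
#
#   epoch 0   -> 0.065 * 0.01                 = 0.00065
#   epoch 5   -> 0.065 * (0.01 * 0.5 + 0.5)   ~ 0.0328   (linear warmup)
#   epoch 160 -> 0.065 * 0.1 ** 1             = 0.0065   (past milestone 150)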
--------------------------------------------------------------------------------
/reid/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vimar-gu/MSINet/2a8845b6b3d1a3b8baeb864b92f9423c2dc711ee/reid/utils/__init__.py
--------------------------------------------------------------------------------
/reid/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | 
5 | from .osutils import mkdir_if_missing
6 | 
7 | 
8 | class Logger(object):
9 |     def __init__(self, fpath=None):
10 |         self.console = sys.stdout
11 |         self.file = None
12 |         if fpath is not None:
13 |             mkdir_if_missing(os.path.dirname(fpath))
14 |             self.file = open(fpath, 'w')
15 | 
16 |     def __del__(self):
17 |         self.close()
18 | 
19 |     def __enter__(self):
20 |         return self
21 | 
22 |     def __exit__(self, *args):
23 |         self.close()
24 | 
25 |     def write(self, msg):
26 |         self.console.write(msg)
27 |         if self.file is not None:
28 |             self.file.write(msg)
29 | 
30 |     def flush(self):
31 |         self.console.flush()
32 |         if self.file is not None:
33 |             self.file.flush()
34 |             os.fsync(self.file.fileno())
35 | 
36 |     def close(self):
37 |         self.console.close()
38 |         if self.file is not None:
39 |             self.file.close()
40 | 
--------------------------------------------------------------------------------
/reid/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | 
4 | class AverageMeter(object):
5 |     """Computes and stores the average and current value"""
6 | 
7 |     def __init__(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/reid/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | from .rerank import re_ranking
5 | 
6 | 
7 | def euclidean_distance(qf, gf):
8 |     m = qf.shape[0]
9 |     n = gf.shape[0]
10 |     dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
11 |         torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
12 |     dist_mat.addmm_(qf, gf.t(), beta=1, alpha=-2)
13 |     return dist_mat.cpu().numpy()
14 | 
15 | 
16 | def cosine_similarity(qf, gf):
17 |     epsilon = 0.00001
18 |     dist_mat = qf.mm(gf.t())
19 |     qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True)  # mx1
20 |     gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True)  # nx1
21 |     qg_normdot = qf_norm.mm(gf_norm.t())
22 | 
23 |     dist_mat = dist_mat.mul(1 / qg_normdot).cpu().numpy()
24 |     dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon)
25 |     dist_mat = np.arccos(dist_mat)  # angular distance: smaller means more similar
26 |     return dist_mat
27 | 
28 | 
29 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, remove_cam=True):
30 |     """Evaluation with market1501 metric
31 |     Key: for each query identity, its gallery images from the same camera view are discarded.
32 |     """
33 |     num_q, num_g = distmat.shape
34 |     # distmat g
35 |     #    q    1 3 2 4
36 |     #         4 1 2 3
37 |     if num_g < max_rank:
38 |         max_rank = num_g
39 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
40 |     indices = np.argsort(distmat, axis=1)
41 |     # 0 2 1 3
42 |     # 1 2 3 0
43 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
44 |     # compute cmc curve for each query
45 |     all_cmc = []
46 |     all_AP = []
47 |     num_valid_q = 0.  # number of valid query
48 |     for q_idx in range(num_q):
49 |         # get query pid and camid
50 |         q_pid = q_pids[q_idx]
51 |         q_camid = q_camids[q_idx]
52 | 
53 |         # remove gallery samples that have the same pid and camid with query
54 |         order = indices[q_idx]  # select one row
55 |         remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
56 |         keep = np.invert(remove)
57 | 
58 |         # compute cmc curve
59 |         # binary vector, positions with value 1 are correct matches
60 |         if remove_cam:
61 |             orig_cmc = matches[q_idx][keep]
62 |         else:
63 |             orig_cmc = matches[q_idx]
64 |         if not np.any(orig_cmc):
65 |             # this condition is true when query identity does not appear in gallery
66 |             continue
67 | 
68 |         cmc = orig_cmc.cumsum()
69 |         cmc[cmc > 1] = 1
70 | 
71 |         all_cmc.append(cmc[:max_rank])
72 |         num_valid_q += 1.
73 | 
74 |         # compute average precision
75 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
76 |         num_rel = orig_cmc.sum()
77 |         tmp_cmc = orig_cmc.cumsum()
78 |         tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
79 |         tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
80 |         AP = tmp_cmc.sum() / num_rel
81 |         all_AP.append(AP)
82 | 
83 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
84 | 
85 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
86 |     all_cmc = all_cmc.sum(0) / num_valid_q
87 |     mAP = np.mean(all_AP)
88 | 
89 |     return all_cmc, mAP
90 | 
91 | 
92 | class R1_mAP():
93 |     def __init__(self, num_query, max_rank=50, feat_norm=True, method='euclidean', reranking=False, remove_cam=True):
94 |         super(R1_mAP, self).__init__()
95 |         self.num_query = num_query
96 |         self.max_rank = max_rank
97 |         self.feat_norm = feat_norm
98 |         self.method = method
99 |         self.reranking = reranking
100 |         self.remove_cam = remove_cam
101 | 
102 |     def reset(self):
103 |         self.feats = []
104 |         self.pids = []
105 |         self.camids = []
106 | 
107 |     def update(self, output):  # called once for each batch
108 |         feat, pid, camid = output
109 |         self.feats.append(feat)
110 |         self.pids.extend(np.asarray(pid))
111 |         self.camids.extend(np.asarray(camid))
112 | 
113 |     def compute(self):  # called after each epoch
114 |         feats = torch.cat(self.feats, dim=0)
115 |         if self.feat_norm:
116 |             print("The test feature is normalized")
117 |             feats = torch.nn.functional.normalize(feats, dim=1, p=2)  # along channel
118 |         # query
119 |         qf = feats[:self.num_query]
120 |         q_pids = np.asarray(self.pids[:self.num_query])
121 |         q_camids = np.asarray(self.camids[:self.num_query])
122 |         # gallery
123 |         gf = feats[self.num_query:]
124 |         g_pids = np.asarray(self.pids[self.num_query:])
125 |         g_camids = np.asarray(self.camids[self.num_query:])
126 |         if self.reranking:
127 |             print('=> Enter reranking')
128 |             distmat = re_ranking(euclidean_distance(qf, gf), euclidean_distance(qf, qf),
129 |                                  euclidean_distance(gf, gf), k1=30, k2=10, lambda_value=0.2)
130 |         else:
131 |             if self.method == 'euclidean':
132 |                 print('=> Computing DistMat with euclidean distance')
133 |                 distmat = euclidean_distance(qf, gf)
134 |             elif self.method == 'cosine':
135 |                 print('=> Computing DistMat with cosine similarity')
136 |                 distmat = cosine_similarity(qf, gf)
137 |         cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids, remove_cam=self.remove_cam)
138 | 
139 |         return cmc, mAP, distmat, self.pids, self.camids, qf, gf
140 | 
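# Tiny worked example (added, illustrative) of the metrics in eval_func:
# a query whose ranked, camera-filtered gallery matches are [1, 0, 1] hits
# at rank 1 (CMC@1 = 1) and has AP = (1/1 + 2/3) / 2 ~ 0.83; mAP averages
# this AP over all valid queries.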
--------------------------------------------------------------------------------
/reid/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 | 
5 | 
6 | def mkdir_if_missing(dir_path):
7 |     try:
8 |         os.makedirs(dir_path)
9 |     except OSError as e:
10 |         if e.errno != errno.EEXIST:
11 |             raise
12 | 
--------------------------------------------------------------------------------
/reid/utils/rerank.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Source: https://github.com/zhunzhong07/person-re-ranking
5 | Created on Mon Jun 26 14:46:56 2017
6 | @author: luohao
7 | Modified by Houjing Huang, 2017-12-22.
8 | - This version accepts distance matrix instead of raw features.
9 | - The difference of `/` division between python 2 and 3 is handled.
10 | - numpy.float16 is replaced by numpy.float32 for numerical precision.
11 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
12 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
14 | API
15 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
16 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
17 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
18 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
19 | Returns:
20 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
21 | """
22 | from __future__ import absolute_import
23 | from __future__ import print_function
24 | from __future__ import division
25 | 
26 | __all__ = ['re_ranking']
27 | 
28 | import numpy as np
29 | 
30 | 
31 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
32 | 
33 |     # The following naming, e.g. gallery_num, is different from outer scope.
34 |     # Don't care about it.
35 | 
36 |     original_dist = np.concatenate(
37 |         [np.concatenate([q_q_dist, q_g_dist], axis=1),
38 |          np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
39 |         axis=0)
40 |     original_dist = np.power(original_dist, 2).astype(np.float32)
41 |     original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
42 |     V = np.zeros_like(original_dist).astype(np.float32)
43 |     initial_rank = np.argsort(original_dist).astype(np.int32)
44 | 
45 |     query_num = q_g_dist.shape[0]
46 |     gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
47 |     all_num = gallery_num
48 | 
49 |     for i in range(all_num):
50 |         # k-reciprocal neighbors
51 |         forward_k_neigh_index = initial_rank[i,:k1+1]
52 |         backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
53 |         fi = np.where(backward_k_neigh_index==i)[0]
54 |         k_reciprocal_index = forward_k_neigh_index[fi]
55 |         k_reciprocal_expansion_index = k_reciprocal_index
56 |         for j in range(len(k_reciprocal_index)):
57 |             candidate = k_reciprocal_index[j]
58 |             candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1]
59 |             candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1]
60 |             fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
61 |             candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
62 |             if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
63 |                 k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
64 | 
65 |         k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
66 |         weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
67 |         V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
68 |     original_dist = original_dist[:query_num,]
69 |     if k2 != 1:
70 |         V_qe = np.zeros_like(V,dtype=np.float32)
71 |         for i in range(all_num):
72 |             V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
73 |         V = V_qe
74 |         del V_qe
75 |     del initial_rank
76 |     invIndex = []
77 |     for i in range(gallery_num):
78 |         invIndex.append(np.where(V[:,i] != 0)[0])
79 | 
80 |     jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
81 | 
82 | 
83 |     for i in range(query_num):
84 |         temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32)
85 |         indNonZero = np.where(V[i,:] != 0)[0]
86 |         indImages = []
87 |         indImages = [invIndex[ind] for ind in indNonZero]
88 |         for j in range(len(indNonZero)):
89 |             temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
90 |         jaccard_dist[i] = 1-temp_min/(2.-temp_min)
91 | 
92 |     final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
93 |     del original_dist
94 |     del V
95 |     del jaccard_dist
96 |     final_dist = final_dist[:query_num,query_num:]
97 |     return final_dist
98 | 
--------------------------------------------------------------------------------
/reid/utils/serialization.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os.path as osp
3 | import shutil
4 | 
5 | import torch
6 | from torch.nn import Parameter
7 | 
8 | from .osutils import mkdir_if_missing
9 | 
10 | 
11 | def read_json(fpath):
12 |     with open(fpath, 'r') as f:
13 |         obj = json.load(f)
14 |     return obj
15 | 
16 | 
17 | def write_json(obj, fpath):
18 |     mkdir_if_missing(osp.dirname(fpath))
19 |     with open(fpath, 'w') as f:
20 |         json.dump(obj, f, indent=4, separators=(',', ': '))
21 | 
22 | 
23 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
24 |     mkdir_if_missing(osp.dirname(fpath))
25 |     torch.save(state, fpath)
26 |     if is_best:
27 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
28 | 
29 | 
30 | def load_checkpoint(fpath):
31 |     if osp.isfile(fpath):
32 |         checkpoint = torch.load(fpath)
33 |         print("=> Loaded checkpoint '{}'".format(fpath))
34 |         return checkpoint
35 |     else:
36 |         raise ValueError("=> No checkpoint found at '{}'".format(fpath))
37 | 
38 | 
39 | def copy_state_dict(state_dict, model, strip=None):
40 |     tgt_state = model.state_dict()
41 |     copied_names = set()
42 |     for name, param in state_dict.items():
43 |         if strip is not None and name.startswith(strip):
44 |             name = name[len(strip):]
45 |         if name not in tgt_state:
46 |             continue
47 |         if isinstance(param, Parameter):
48 |             param = param.data
49 |         if param.size() != tgt_state[name].size():
50 |             print('mismatch:', name, param.size(), tgt_state[name].size())
51 |             continue
52 |         tgt_state[name].copy_(param)
53 |         copied_names.add(name)
54 | 
55 |     missing = set(tgt_state.keys()) - copied_names
56 |     if len(missing) > 0:
57 |         print("missing keys in state_dict:", missing)
58 | 
59 |     return model
60 | 
--------------------------------------------------------------------------------
/search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import random
5 | import logging
6 | import argparse
7 | import numpy as np
8 | import os.path as osp
9 | from torch.backends import cudnn
10 | 
11 | from reid.data import build_data
12 | from reid.models.search_cnn import build_search_model
13 | from reid.models.cm import ClusterMemory
14 | from reid.solver import build_optimizer
15 | from reid.engine.searcher import do_search
16 | from reid.utils.logging import Logger
17 | 
18 | 
19 | def main(args):
20 |     random.seed(args.seed)
21 |     np.random.seed(args.seed)
22 |     torch.manual_seed(args.seed)
23 |     cudnn.deterministic = True
24 |     cudnn.benchmark = True
25 | 
26 |     if not args.evaluate:
27 |         sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
28 | 
29 |     print('Running with:\n{}'.format(args))
30 | 
31 |     train_loader, valid_loader, test_loader, num_query,\
32 |         num_train_classes, num_valid_classes = build_data(args, search=True)
33 |     train_memory = ClusterMemory(512, num_train_classes).cuda()
34 |     valid_memory = ClusterMemory(512, num_valid_classes).cuda()
35 |     model = build_search_model(valid_memory, num_train_classes)
36 |     model = model.cuda()
37 | 
38 |     optimizer, lr_scheduler = build_optimizer(args, model)
39 | 
40 |     do_search(args, model, train_memory, train_loader, valid_loader,
41 |               test_loader, optimizer, lr_scheduler, num_query,
42 |               remove_cam=(args.source_dataset != 'vehicleid'))
43 | 
44 | 
45 | if __name__ == '__main__':
46 |     parser = argparse.ArgumentParser()
47 | 
48 |     # data
49 |     parser.add_argument('-ds', '--source-dataset', type=str, default='msmt17')
50 |     parser.add_argument('-dt', '--target-dataset', type=str, default='none')
51 |     parser.add_argument('-b', '--batch-size', type=int, default=64)
52 |     parser.add_argument('--test-batch-size', type=int, default=128)
53 |     parser.add_argument('-j', '--workers', type=int, default=4)
54 |     parser.add_argument('--height', type=int, default=256)
55 |     parser.add_argument('--width', type=int, default=128)
56 |     parser.add_argument('--num-instance', type=int, default=4)
57 | 
58 |     # model
59 |     parser.add_argument('-a', '--arch', type=str, default='resnet50')
60 |     parser.add_argument('--pretrained', action='store_true', default=False)
61 |     parser.add_argument('--reset-params', type=bool, default=False)
62 |     parser.add_argument('--genotypes', type=str, default='msmt')
63 | 
64 |     # loss
65 |     parser.add_argument('--loss', type=str, default='triplet_softmax')
66 |     parser.add_argument('--triplet_margin', type=float, default=0.3)
67 | 
68 |     # optimizer
69 |     parser.add_argument('--optim', type=str, default='SGD')
70 |     parser.add_argument('--lr', type=float, default=0.025)
71 |     parser.add_argument('--weight-decay', type=float, default=5e-4)
72 |     parser.add_argument('--arch-lr', type=float, default=0.002)
73 |     parser.add_argument('--arch-weight-decay', type=float, default=1e-3)
74 |     parser.add_argument('--momentum', type=float, default=0.9)
75 |     parser.add_argument('--milestones', nargs='+', type=int,
76 |                         default=[150, 225, 300])
77 |     parser.add_argument('--warmup-step', type=int, default=10)
78 | 
79 |     # training configs
80 |     parser.add_argument('--resume', type=str, default='')
81 |     parser.add_argument('--evaluate', action='store_true', default=False)
82 |     parser.add_argument('--epochs', type=int, default=350)
83 |     parser.add_argument('--seed', type=int, default=0)
84 |     parser.add_argument('--print-freq', type=int, default=100)
85 |     parser.add_argument('--geno-interval', type=int, default=5)
86 |     parser.add_argument('--eval-interval', type=int, default=40)
87 | 
88 |     # misc
89 |     parser.add_argument('--data-dir', type=str, default='./data')
90 |     parser.add_argument('--logs-dir', type=str, default='./logs')
91 |     parser.add_argument('--pretrain-dir', type=str, default='./pretrained')
92 | 
93 |     args = parser.parse_args()
94 |     main(args)
95 | 
96 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import random
5 | import argparse
6 | import numpy as np
7 | import os.path as osp
8 | from torch.backends import cudnn
9 | 
10 | from reid.utils.logging import Logger
11 | from reid.data import build_data
12 | from reid.criterion import build_criterion
13 | from reid.solver import build_optimizer
14 | from reid.engine import do_train, evaluate
15 | from reid.models.msinet import msinet_x1_0
16 | from reid.utils.serialization import copy_state_dict
17 | 
18 | 
19 | def count_parameters(model):
20 |     return sum(np.prod(v.size()) for name, v in model.named_parameters() if 'classifier' not in name) / 1e6
21 | 
22 | 
23 | def main(args):
24 |     random.seed(args.seed)
25 |     np.random.seed(args.seed)
26 |     torch.manual_seed(args.seed)
27 |     cudnn.deterministic = True
28 |     cudnn.benchmark = True
29 | 
30 |     if not args.evaluate:
31 |         sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
32 | 
33 |     print('Running with:\n{}'.format(args))
34 | 
35 |     train_loader, test_loader, num_query, num_classes = build_data(args)
36 |     model = msinet_x1_0(args, num_classes)
37 |     print('Model Params: {}'.format(count_parameters(model)))
38 |     model = model.cuda()
39 | 
40 |     if args.resume != '':
41 |         copy_state_dict(torch.load(args.resume), model)
42 | 
43 |     if args.evaluate:
44 |         evaluate(args, model, test_loader, num_query)
45 |         if args.target_dataset != 'none':
46 |             _, tar_test_loader, tar_num_query, _ = build_data(args, target=True)
47 |             evaluate(args, model, tar_test_loader, tar_num_query)
48 |         return
49 | 
50 |     criterion = build_criterion(args, num_classes)
51 |     optimizer, lr_scheduler = build_optimizer(args, model)
52 | 
53 |     do_train(args, model, criterion, train_loader, test_loader,
54 |              optimizer, lr_scheduler, num_query,
55 |              remove_cam=(args.source_dataset != 'vehicleid'))
56 | 
57 |     if args.target_dataset != 'none':
58 |         _, tar_test_loader, tar_num_query, _ = build_data(args, target=True)
59 |         evaluate(args, model, tar_test_loader, tar_num_query,
60 |                  remove_cam=(args.target_dataset != 'vehicleid'))
61 | 
62 | 
63 | if __name__ == '__main__':
64 |     parser = argparse.ArgumentParser()
65 | 
66 |     # data
67 |     parser.add_argument('-ds', '--source-dataset', type=str, default='market1501')
68 |     parser.add_argument('-dt', '--target-dataset', type=str, default='none')
69 |     parser.add_argument('-b', '--batch-size', type=int, default=64)
70 |     parser.add_argument('--test-batch-size', type=int, default=128)
71 |     parser.add_argument('-j', '--workers', type=int, default=4)
72 |     parser.add_argument('--height', type=int, default=256)
73 |     parser.add_argument('--width', type=int, default=128)
74 |     parser.add_argument('--num-instance', type=int, default=4)
75 | 
76 |     # model
77 |     parser.add_argument('-a', '--arch', type=str, default='resnet50')
78 |     parser.add_argument('--pretrained', action='store_true', default=False)
79 |     parser.add_argument('--reset-params', type=bool, default=False)
80 |     parser.add_argument('--genotypes', type=str, default='msmt')
81 | 
82 |     # loss
83 |     parser.add_argument('--margin', type=float, default=0.3)
84 |     parser.add_argument('--sam-mode', type=str, default='none')
85 |     parser.add_argument('--sam-ratio', type=float, default=2.0)
86 | 
87 |     # optimizer
88 |     parser.add_argument('--optim', type=str, default='sgd')
89 |     parser.add_argument('--lr', type=float, default=0.065)
90 |     parser.add_argument('--weight-decay', type=float, default=5e-4)
91 |     parser.add_argument('--momentum', type=float, default=0.9)
92 |     parser.add_argument('--milestones', nargs='+', type=int, default=[150, 225, 300])
93 |     parser.add_argument('--warmup-step', type=int, default=10)
94 | 
95 |     # training configs
96 |     parser.add_argument('--resume', type=str, default='')
97 |     parser.add_argument('--evaluate', action='store_true', default=False)
98 |     parser.add_argument('--epochs', type=int, default=350)
99 |     parser.add_argument('--seed', type=int, default=0)
100 |     parser.add_argument('--print-freq', type=int, default=100)
101 |     parser.add_argument('--eval-interval', type=int, default=40)
102 | 
103 |     # misc
104 |     parser.add_argument('--data-dir', type=str, default='./data')
105 |     parser.add_argument('--logs-dir', type=str, default='./logs')
106 |     parser.add_argument('--pretrain-dir', type=str, default='./pretrained')
107 | 
108 |     args = parser.parse_args()
109 |     main(args)
110 | 
--------------------------------------------------------------------------------
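A minimal usage sketch (not part of the repository): loading a trained checkpoint with the helpers above and extracting retrieval features programmatically, mirroring the `--evaluate` path of `train.py`. The checkpoint path and `num_classes` are placeholders that must match the trained model.

```python
import argparse
import torch

from reid.models.msinet import msinet_x1_0
from reid.utils.serialization import copy_state_dict

args = argparse.Namespace(
    genotypes='msmt', pretrained=False, sam_mode='none',
    source_dataset='market1501', pretrain_dir='./pretrained')
model = msinet_x1_0(args, num_classes=751).cuda().eval()  # 751 = Market-1501 train IDs
copy_state_dict(torch.load('./logs/checkpoint.pth.tar'), model)  # placeholder path

imgs = torch.randn(8, 3, 256, 128).cuda()  # (B, C, H, W), the default person ReID input size
with torch.no_grad():
    feats = model(imgs)  # (8, 768): 512-d global + 256-d two-part features
```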