├── requirements.txt ├── Datasets ├── __pycache__ │ ├── iLDSVID.cpython-36.pyc │ ├── MARS_dataset.cpython-36.pyc │ └── PRID_dataset.cpython-36.pyc ├── PRID_dataset.py ├── iLDSVID.py └── MARS_dataset.py ├── loss ├── __pycache__ │ ├── center_loss.cpython-36.pyc │ ├── softmax_loss.cpython-36.pyc │ └── triplet_loss.cpython-36.pyc ├── softmax_loss.py ├── center_loss.py └── triplet_loss.py ├── Loss_fun.py ├── README.md ├── VID_Test.py ├── VID_Trans_ReID.py ├── VID_Trans_model.py ├── Dataloader.py ├── utility.py └── vit_ID.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | timm 4 | yacs 5 | opencv-python 6 | -------------------------------------------------------------------------------- /Datasets/__pycache__/iLDSVID.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/Datasets/__pycache__/iLDSVID.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/center_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/loss/__pycache__/center_loss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/softmax_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/loss/__pycache__/softmax_loss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/triplet_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/loss/__pycache__/triplet_loss.cpython-36.pyc -------------------------------------------------------------------------------- /Datasets/__pycache__/MARS_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/Datasets/__pycache__/MARS_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /Datasets/__pycache__/PRID_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/Datasets/__pycache__/PRID_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /Loss_fun.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from loss.softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy 3 | from loss.triplet_loss import TripletLoss 4 | from loss.center_loss import CenterLoss 5 | 6 | 7 | def make_loss(num_classes): 8 | 9 | feat_dim =768 10 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 11 | center_criterion2 = CenterLoss(num_classes=num_classes, feat_dim=3072, use_gpu=True) 12 | 13 | triplet = TripletLoss() 14 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) 15 | 16 | 17 | def loss_func(score, feat, target, target_cam): 18 | if isinstance(score, list): 19 | ID_LOSS = [xent(scor, target) for scor in score[1:]] 20 | ID_LOSS = sum(ID_LOSS) / 
len(ID_LOSS) 21 | ID_LOSS = 0.25 * ID_LOSS + 0.75 * xent(score[0], target) 22 | else: 23 | ID_LOSS = xent(score, target) 24 | 25 | if isinstance(feat, list): 26 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 27 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 28 | TRI_LOSS = 0.25 * TRI_LOSS + 0.75 * triplet(feat[0], target)[0] 29 | 30 | center=center_criterion(feat[0], target) 31 | centr2 = [center_criterion2(feats, target) for feats in feat[1:]] 32 | centr2 = sum(centr2) / len(centr2) 33 | center=0.25 *centr2 + 0.75 * center 34 | else: 35 | TRI_LOSS = triplet(feat, target)[0] 36 | 37 | return ID_LOSS+ TRI_LOSS, center 38 | 39 | return loss_func,center_criterion 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /loss/softmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | class CrossEntropyLabelSmooth(nn.Module): 5 | """Cross entropy loss with label smoothing regularizer. 6 | 7 | Reference: 8 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 9 | Equation: y = (1 - epsilon) * y + epsilon / K. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | epsilon (float): weight. 14 | """ 15 | 16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 17 | super(CrossEntropyLabelSmooth, self).__init__() 18 | self.num_classes = num_classes 19 | self.epsilon = epsilon 20 | self.use_gpu = use_gpu 21 | self.logsoftmax = nn.LogSoftmax(dim=1) 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (num_classes) 28 | """ 29 | log_probs = self.logsoftmax(inputs) 30 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 31 | if self.use_gpu: targets = targets.cuda() 32 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 33 | loss = (- targets * log_probs).mean(0).sum() 34 | return loss 35 | 36 | class LabelSmoothingCrossEntropy(nn.Module): 37 | """ 38 | NLL loss with label smoothing. 39 | """ 40 | def __init__(self, smoothing=0.1): 41 | """ 42 | Constructor for the LabelSmoothing module. 43 | :param smoothing: label smoothing factor 44 | """ 45 | super(LabelSmoothingCrossEntropy, self).__init__() 46 | assert smoothing < 1.0 47 | self.smoothing = smoothing 48 | self.confidence = 1. - smoothing 49 | 50 | def forward(self, x, target): 51 | logprobs = F.log_softmax(x, dim=-1) 52 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 53 | nll_loss = nll_loss.squeeze(1) 54 | smooth_loss = -logprobs.mean(dim=-1) 55 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 56 | return loss.mean() -------------------------------------------------------------------------------- /loss/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 
16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = [] 48 | for i in range(batch_size): 49 | value = distmat[i][mask[i]] 50 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 51 | dist.append(value) 52 | dist = torch.cat(dist) 53 | loss = dist.mean() 54 | return loss 55 | 56 | 57 | if __name__ == '__main__': 58 | use_gpu = False 59 | center_loss = CenterLoss(use_gpu=use_gpu) 60 | features = torch.rand(16, 2048) 61 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 62 | if use_gpu: 63 | features = torch.rand(16, 2048).cuda() 64 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 65 | 66 | loss = center_loss(features, targets) 67 | print(loss) 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VID-Trans-ReID 2 | This is an Official Pytorch Implementation of our paper VID-Trans-ReID: Enhanced Video Transformers for Person Re-identification 3 | 4 | [![Python 3.6](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/downloads/release/python-370/) Tested using Python 3.7.x and Torch: 1.8.0. 5 | 6 | ## Architecture: 7 |

8 | *(Architecture overview figure: modelupdated2)* 9 |
10 | 11 | ## Abstract 12 | _"Video-based person Re-identification (Re-ID) has received increasing attention recently due to its important role within surveillance video analysis. Video-based Re-ID expands upon earlier image-based methods by extracting person features temporally across multiple video image frames. The key challenge within person Re-ID is extracting a robust feature representation that is invariant to the challenges of pose and illumination variation across multiple camera viewpoints. Whilst most contemporary methods use a CNN based methodology, recent advances in vision transformer (ViT) architectures boost fine-grained feature discrimination via the use of multi-head attention without any loss of feature robustness. To specifically enable ViT architectures to effectively address the challenges of video person Re-ID, we propose two novel module constructs, Temporal Clip Shift and Shuffled (TCSS) and Video Patch Part Feature (VPPF), that boost the robustness of the resultant Re-ID feature representation. Furthermore, we combine our proposed approach with current best practices spanning both image and video based Re-ID including camera view embedding. Our proposed approach outperforms existing state-of-the-art work on the MARS, PRID2011, and iLIDS-VID Re-ID benchmark datasets achieving 96.36%, 96.63%, 94.67% rank-1 accuracy respectively and achieving 90.25% mAP on MARS."_ 13 | 14 | [[A. Alsehaim, T.P. Breckon, In Proc. British Machine Vision Conference, BMVA, 2022](https://breckon.org/toby/publications/papers/alsehaim22vidtransreid.pdf)] [[Talk](https://www.youtube.com/embed/NARrZroYD-U)] [[Poster](https://breckon.org/toby/publications/posters/alsehaim22vidtransreid_poster.pdf)] 15 | 16 | ## 17 | 18 | *(Figure: non-id2)* 19 | 20 | ## 21 | 22 | *(Figure: paper2Dig)* 23 | 24 | 25 | 26 | 27 | ## Requirements 28 | ``` 29 | pip install -r requirements.txt 30 | ``` 31 | ## Getting Started 32 | 33 | 1. Download the ImageNet pre-trained transformer model: [ViT_base](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth). 34 | 2. Download the video person Re-ID datasets: [MARS](http://zheng-lab.cecs.anu.edu.au/Project/project_mars.html), [PRID](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/prid11/) and [iLIDS-VID](https://xiatian-zhu.github.io/downloads_qmul_iLIDS-VID_ReID_dataset.html). 35 | 36 | ## Train and Evaluate 37 | 38 | Use the pre-trained [ViT_base](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth) model to initialize the ViT transformer, then train the whole model.
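The training commands below pass a `--ViT_path` flag alongside `--Dataset_name`. The following is a minimal, illustrative sketch (not the repository's exact code) of how the `VID_Trans_ReID.py` entry point is expected to wire these flags into the data loader and model: the 7-value return of `dataloader()` and the `VID_Trans(num_classes, camera_num, pretrainpath)` signature follow the sources included later in this repository, while the `--ViT_path` argument itself is assumed from the commands shown in this section.

```
import argparse

from Dataloader import dataloader
from VID_Trans_model import VID_Trans

# Sketch only: register both CLI flags used in the commands below and feed
# the ViT checkpoint path to the model as `pretrainpath`.
parser = argparse.ArgumentParser(description="VID-Trans-ReID")
parser.add_argument("--Dataset_name", default="", type=str,
                    help="'Mars', 'PRID' or 'iLIDSVID'")
parser.add_argument("--ViT_path", default="", type=str,
                    help="path to the ImageNet pre-trained ViT weights")  # assumed flag, mirrors the commands below
args = parser.parse_args()

train_loader, num_query, num_classes, camera_num, view_num, q_val_set, g_val_set = \
    dataloader(args.Dataset_name)
model = VID_Trans(num_classes=num_classes, camera_num=camera_num,
                  pretrainpath=args.ViT_path)
```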
39 | 40 | MARS Datasete 41 | ``` 42 | python -u VID_Trans_ReID.py --Dataset_name 'Mars' --ViT_path 'jx_vit_base_p16_224-80ecf9dd.pth' 43 | ``` 44 | 45 | PRID Dataset 46 | ``` 47 | python -u VID_Trans_ReID.py --Dataset_name 'PRID' --ViT_path 'jx_vit_base_p16_224-80ecf9dd.pth' 48 | ``` 49 | 50 | iLIDS-VID Dataset 51 | ``` 52 | python -u VID_Trans_ReID.py --Dataset_name 'iLIDSVID' --ViT_path 'jx_vit_base_p16_224-80ecf9dd.pth' 53 | ``` 54 | 55 | ## Test 56 | To test the model you can use our pretrained model on MARS dataset [download](https://durhamuniversity-my.sharepoint.com/:u:/g/personal/zwjx97_durham_ac_uk/Ec09LVNFG_JKotjNPkVgTaIB7k0eUAwmPq9gawciw2ggBQ?e=swd9DK) 57 | 58 | ``` 59 | python -u VID_Test.py --Dataset_name 'Mars' --model_path 'MarsMain_Model.pth' 60 | ``` 61 | 62 | ## Acknowledgement 63 | Thanks to Hao Luo, using some implementation from his [repository](https://github.com/michuanhaohao) 64 | 65 | ## Citation 66 | 67 | If you are making use of this work in any way, you must please reference the following paper in any report, publication, presentation, software release or any other associated materials: 68 | 69 | [VID-Trans-ReID: Enhanced Video Transformers for Person Re-identification](https://breckon.org/toby/publications/papers/alsehaim22vidtransreid.pdf) (A. Alsehaim, T.P. Breckon), In Proc. British Machine Vision Conference, BMVA, 2022. 70 | 71 | ``` 72 | @inproceedings{alsehaim22vidtransreid, 73 | author = {Alsehaim, A. and Breckon, T.P.}, 74 | title = {VID-Trans-ReID: Enhanced Video Transformers for Person Re-identification}, 75 | booktitle = {Proc. British Machine Vision Conference}, 76 | year = {2022}, 77 | month = {November}, 78 | publisher = {BMVA}, 79 | url = {https://breckon.org/toby/publications/papers/alsehaim22vidtransreid.pdf} 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /loss/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """Normalizing to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Variable 9 | Returns: 10 | x: pytorch Variable, same shape as input 11 | """ 12 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 13 | return x 14 | 15 | 16 | def euclidean_dist(x, y): 17 | """ 18 | Args: 19 | x: pytorch Variable, with shape [m, d] 20 | y: pytorch Variable, with shape [n, d] 21 | Returns: 22 | dist: pytorch Variable, with shape [m, n] 23 | """ 24 | m, n = x.size(0), y.size(0) 25 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 26 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 27 | dist = xx + yy 28 | dist = dist - 2 * torch.matmul(x, y.t()) 29 | # dist.addmm_(1, -2, x, y.t()) 30 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 31 | return dist 32 | 33 | 34 | def cosine_dist(x, y): 35 | """ 36 | Args: 37 | x: pytorch Variable, with shape [m, d] 38 | y: pytorch Variable, with shape [n, d] 39 | Returns: 40 | dist: pytorch Variable, with shape [m, n] 41 | """ 42 | m, n = x.size(0), y.size(0) 43 | x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n) 44 | y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t() 45 | xy_intersection = torch.mm(x, y.t()) 46 | dist = xy_intersection/(x_norm * y_norm) 47 | dist = (1. 
- dist) / 2 48 | return dist 49 | 50 | 51 | def hard_example_mining(dist_mat, labels, return_inds=False): 52 | """For each anchor, find the hardest positive and negative sample. 53 | Args: 54 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 55 | labels: pytorch LongTensor, with shape [N] 56 | return_inds: whether to return the indices. Save time if `False`(?) 57 | Returns: 58 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 59 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 60 | p_inds: pytorch LongTensor, with shape [N]; 61 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 62 | n_inds: pytorch LongTensor, with shape [N]; 63 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 64 | NOTE: Only consider the case in which all labels have same num of samples, 65 | thus we can cope with all anchors in parallel. 66 | """ 67 | 68 | assert len(dist_mat.size()) == 2 69 | assert dist_mat.size(0) == dist_mat.size(1) 70 | N = dist_mat.size(0) 71 | 72 | # shape [N, N] 73 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 74 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 75 | 76 | # `dist_ap` means distance(anchor, positive) 77 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 78 | dist_ap, relative_p_inds = torch.max( 79 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 80 | # print(dist_mat[is_pos].shape) 81 | # `dist_an` means distance(anchor, negative) 82 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 83 | dist_an, relative_n_inds = torch.min( 84 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 85 | # shape [N] 86 | dist_ap = dist_ap.squeeze(1) 87 | dist_an = dist_an.squeeze(1) 88 | 89 | if return_inds: 90 | # shape [N, N] 91 | ind = (labels.new().resize_as_(labels) 92 | .copy_(torch.arange(0, N).long()) 93 | .unsqueeze(0).expand(N, N)) 94 | # shape [N, 1] 95 | p_inds = torch.gather( 96 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 97 | n_inds = torch.gather( 98 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 99 | # shape [N] 100 | p_inds = p_inds.squeeze(1) 101 | n_inds = n_inds.squeeze(1) 102 | return dist_ap, dist_an, p_inds, n_inds 103 | 104 | return dist_ap, dist_an 105 | 106 | 107 | class TripletLoss(object): 108 | """ 109 | Triplet loss using HARDER example mining, 110 | modified based on original triplet loss using hard example mining 111 | """ 112 | 113 | def __init__(self, margin=None, hard_factor=0.0): 114 | self.margin = margin 115 | self.hard_factor = hard_factor 116 | if margin is not None: 117 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 118 | else: 119 | self.ranking_loss = nn.SoftMarginLoss() 120 | 121 | def __call__(self, global_feat, labels, normalize_feature=False): 122 | 123 | if normalize_feature: 124 | global_feat = normalize(global_feat, axis=-1) 125 | dist_mat = euclidean_dist(global_feat, global_feat) 126 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 127 | 128 | dist_ap *= (1.0 + self.hard_factor) 129 | dist_an *= (1.0 - self.hard_factor) 130 | 131 | y = dist_an.new().resize_as_(dist_an).fill_(1) 132 | if self.margin is not None: 133 | loss = self.ranking_loss(dist_an, dist_ap, y) 134 | else: 135 | loss = self.ranking_loss(dist_an - dist_ap, y) 136 | return loss, dist_ap, dist_an 137 | 138 | 139 | -------------------------------------------------------------------------------- /Datasets/PRID_dataset.py: 
-------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function, absolute_import 3 | 4 | from collections import defaultdict 5 | from scipy.io import loadmat 6 | import os.path as osp 7 | import numpy as np 8 | import json 9 | import glob 10 | def read_json(fpath): 11 | with open(fpath, 'r') as f: 12 | obj = json.load(f) 13 | return obj 14 | 15 | class PRID(object): 16 | """ 17 | PRID 18 | Reference: 19 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011. 20 | 21 | Dataset statistics: 22 | # identities: 200 23 | # tracklets: 400 24 | # cameras: 2 25 | Args: 26 | split_id (int): indicates which split to use. There are totally 10 splits. 27 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0). 28 | """ 29 | root = "prid_2011" 30 | 31 | # root = './data/prid2011' 32 | dataset_url = 'https://files.icg.tugraz.at/f/6ab7e8ce8f/?raw=1' 33 | split_path = osp.join(root, 'splits_prid2011.json') 34 | cam_a_path = osp.join(root, 'multi_shot', 'cam_a') 35 | cam_b_path = osp.join(root, 'multi_shot', 'cam_b') 36 | 37 | def __init__(self, split_id=0, min_seq_len=0): 38 | self._check_before_run() 39 | splits = read_json(self.split_path) 40 | if split_id >= len(splits): 41 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 42 | split = splits[split_id] 43 | train_dirs, test_dirs = split['train'], split['test'] 44 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 45 | 46 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 47 | self._process_data(train_dirs, cam1=True, cam2=True) 48 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 49 | self._process_data(test_dirs, cam1=True, cam2=False) 50 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 51 | self._process_data(test_dirs, cam1=False, cam2=True) 52 | 53 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 54 | min_num = np.min(num_imgs_per_tracklet) 55 | max_num = np.max(num_imgs_per_tracklet) 56 | avg_num = np.mean(num_imgs_per_tracklet) 57 | 58 | num_total_pids = num_train_pids + num_query_pids 59 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 60 | 61 | print("=> PRID-2011 loaded") 62 | print("Dataset statistics:") 63 | print(" ------------------------------") 64 | print(" subset | # ids | # tracklets") 65 | print(" ------------------------------") 66 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 67 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 68 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 69 | print(" ------------------------------") 70 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 71 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 72 | print(" ------------------------------") 73 | 74 | self.train = train 75 | self.query = query 76 | self.gallery = gallery 77 | 78 | self.num_train_pids = num_train_pids 79 | self.num_query_pids = num_query_pids 80 | self.num_gallery_pids = num_gallery_pids 81 | self.num_train_cams=2 82 | self.num_query_cams=2 83 | self.num_gallery_cams=2 84 | self.num_train_vids=num_train_tracklets 85 | self.num_query_vids=num_query_tracklets 86 | 
self.num_gallery_vids=num_gallery_tracklets 87 | 88 | def _check_before_run(self): 89 | """Check if all files are available before going deeper""" 90 | if not osp.exists(self.root): 91 | raise RuntimeError("'{}' is not available".format(self.root)) 92 | 93 | def _process_data(self, dirnames, cam1=True, cam2=True): 94 | tracklets = [] 95 | num_imgs_per_tracklet = [] 96 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 97 | 98 | for dirname in dirnames: 99 | if cam1: 100 | person_dir = osp.join(self.cam_a_path, dirname) 101 | img_names = glob.glob(osp.join(person_dir, '*.png')) 102 | assert len(img_names) > 0 103 | img_names = tuple(img_names) 104 | pid = dirname2pid[dirname] 105 | tracklets.append((img_names, pid, 0)) 106 | num_imgs_per_tracklet.append(len(img_names)) 107 | 108 | if cam2: 109 | person_dir = osp.join(self.cam_b_path, dirname) 110 | img_names = glob.glob(osp.join(person_dir, '*.png')) 111 | assert len(img_names) > 0 112 | img_names = tuple(img_names) 113 | pid = dirname2pid[dirname] 114 | tracklets.append((img_names, pid, 1)) 115 | num_imgs_per_tracklet.append(len(img_names)) 116 | 117 | num_tracklets = len(tracklets) 118 | num_pids = len(dirnames) 119 | 120 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 121 | -------------------------------------------------------------------------------- /VID_Test.py: -------------------------------------------------------------------------------- 1 | from Dataloader import dataloader 2 | from VID_Trans_model import VID_Trans 3 | 4 | 5 | from Loss_fun import make_loss 6 | import random 7 | import torch 8 | import numpy as np 9 | import os 10 | import argparse 11 | 12 | import logging 13 | import os 14 | import time 15 | import torch 16 | import torch.nn as nn 17 | 18 | from torch.cuda import amp 19 | from utility import AverageMeter, optimizer,scheduler 20 | 21 | from torch.autograd import Variable 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=21): 31 | num_q, num_g = distmat.shape 32 | if num_g < max_rank: 33 | max_rank = num_g 34 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 35 | indices = np.argsort(distmat, axis=1) 36 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 37 | 38 | # compute cmc curve for each query 39 | all_cmc = [] 40 | all_AP = [] 41 | num_valid_q = 0. 42 | for q_idx in range(num_q): 43 | # get query pid and camid 44 | q_pid = q_pids[q_idx] 45 | q_camid = q_camids[q_idx] 46 | 47 | # remove gallery samples that have the same pid and camid with query 48 | order = indices[q_idx] 49 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 50 | keep = np.invert(remove) 51 | 52 | # compute cmc curve 53 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 54 | if not np.any(orig_cmc): 55 | # this condition is true when query identity does not appear in gallery 56 | continue 57 | 58 | cmc = orig_cmc.cumsum() 59 | cmc[cmc > 1] = 1 60 | 61 | all_cmc.append(cmc[:max_rank]) 62 | num_valid_q += 1. 63 | 64 | # compute average precision 65 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 66 | num_rel = orig_cmc.sum() 67 | tmp_cmc = orig_cmc.cumsum() 68 | tmp_cmc = [x / (i+1.) 
for i, x in enumerate(tmp_cmc)] 69 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 70 | AP = tmp_cmc.sum() / num_rel 71 | all_AP.append(AP) 72 | 73 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 74 | 75 | all_cmc = np.asarray(all_cmc).astype(np.float32) 76 | all_cmc = all_cmc.sum(0) / num_valid_q 77 | mAP = np.mean(all_AP) 78 | 79 | return all_cmc, mAP 80 | 81 | def test(model, queryloader, galleryloader, pool='avg', use_gpu=True, ranks=[1, 5, 10, 20]): 82 | model.eval() 83 | qf, q_pids, q_camids = [], [], [] 84 | with torch.no_grad(): 85 | for batch_idx, (imgs, pids, camids,_) in enumerate(queryloader): 86 | 87 | if use_gpu: 88 | imgs = imgs.cuda() 89 | imgs = Variable(imgs, volatile=True) 90 | 91 | b, s, c, h, w = imgs.size() 92 | 93 | 94 | features = model(imgs,pids,cam_label=camids ) 95 | 96 | features = features.view(b, -1) 97 | features = torch.mean(features, 0) 98 | features = features.data.cpu() 99 | qf.append(features) 100 | 101 | q_pids.append(pids) 102 | q_camids.extend(camids) 103 | qf = torch.stack(qf) 104 | q_pids = np.asarray(q_pids) 105 | q_camids = np.asarray(q_camids) 106 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1))) 107 | gf, g_pids, g_camids = [], [], [] 108 | for batch_idx, (imgs, pids, camids,_) in enumerate(galleryloader): 109 | if use_gpu: 110 | imgs = imgs.cuda() 111 | imgs = Variable(imgs, volatile=True) 112 | b, s,c, h, w = imgs.size() 113 | features = model(imgs,pids,cam_label=camids) 114 | features = features.view(b, -1) 115 | if pool == 'avg': 116 | features = torch.mean(features, 0) 117 | else: 118 | features, _ = torch.max(features, 0) 119 | features = features.data.cpu() 120 | gf.append(features) 121 | g_pids.append(pids) 122 | g_camids.extend(camids) 123 | gf = torch.stack(gf) 124 | g_pids = np.asarray(g_pids) 125 | g_camids = np.asarray(g_camids) 126 | print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1))) 127 | print("Computing distance matrix") 128 | m, n = qf.size(0), gf.size(0) 129 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 130 | distmat.addmm_(1, -2, qf, gf.t()) 131 | distmat = distmat.numpy() 132 | gf = gf.numpy() 133 | qf = qf.numpy() 134 | 135 | print("Original Computing CMC and mAP") 136 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) 137 | 138 | # print("Results ---------- {:.1%} ".format(distmat_rerank)) 139 | print("Results ---------- ") 140 | 141 | print("mAP: {:.1%} ".format(mAP)) 142 | print("CMC curve r1:",cmc[0]) 143 | 144 | return cmc[0], mAP 145 | 146 | 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser(description="VID-Trans-ReID") 150 | parser.add_argument( 151 | "--Dataset_name", default="", help="The name of the DataSet", type=str) 152 | parser.add_argument( 153 | "--model_path", default="", help="pretrained model", type=str) 154 | args = parser.parse_args() 155 | Dataset_name=args.Dataset_name 156 | pretrainpath=args.model_path 157 | 158 | 159 | 160 | train_loader, num_query, num_classes, camera_num, view_num,q_val_set,g_val_set = dataloader(Dataset_name) 161 | model = VID_Trans( num_classes=num_classes, camera_num=camera_num,pretrainpath=None) 162 | 163 | device = "cuda" 164 | model=model.to(device) 165 | 166 | checkpoint = torch.load(pretrainpath) 167 | model.load_state_dict(checkpoint) 168 | 169 | 170 | model.eval() 171 | cmc,map = test(model, q_val_set,g_val_set) 172 | 
print('CMC: %.4f, mAP : %.4f'%(cmc,map)) 173 | 174 | -------------------------------------------------------------------------------- /Datasets/iLDSVID.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function, absolute_import 3 | 4 | from collections import defaultdict 5 | from scipy.io import loadmat 6 | import os.path as osp 7 | import numpy as np 8 | import json 9 | import errno 10 | import os 11 | import tarfile 12 | import glob 13 | def read_json(fpath): 14 | with open(fpath, 'r') as f: 15 | obj = json.load(f) 16 | return obj 17 | import urllib.request 18 | 19 | 20 | def write_json(obj, fpath): 21 | mkdir_if_missing(osp.dirname(fpath)) 22 | with open(fpath, 'w') as f: 23 | json.dump(obj, f, indent=4, separators=(',', ': ')) 24 | 25 | def mkdir_if_missing(directory): 26 | if not osp.exists(directory): 27 | try: 28 | os.makedirs(directory) 29 | except OSError as e: 30 | if e.errno != errno.EEXIST: 31 | raise 32 | 33 | 34 | class iLIDSVID(object): 35 | """ 36 | iLIDS-VID 37 | Reference: 38 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 39 | 40 | Dataset statistics: 41 | # identities: 300 42 | # tracklets: 600 43 | # cameras: 2 44 | Args: 45 | split_id (int): indicates which split to use. There are totally 10 splits. 46 | """ 47 | root ="iLIDS-VID" 48 | # root = '/mnt/scratch/1/pathak/data/iLIDS' 49 | # root = './data/ilids-vid' 50 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 51 | data_dir = osp.join(root, 'i-LIDS-VID') 52 | split_dir = osp.join(root, 'train-test people splits') 53 | split_mat_path = osp.join(split_dir, 'train_test_splits_ilidsvid.mat') 54 | split_path = osp.join(root, 'splits.json') 55 | cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1') 56 | cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2') 57 | 58 | def __init__(self, split_id=0): 59 | self._download_data() 60 | self._check_before_run() 61 | 62 | self._prepare_split() 63 | splits = read_json(self.split_path) 64 | if split_id >= len(splits): 65 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 66 | split = splits[split_id] 67 | train_dirs, test_dirs = split['train'], split['test'] 68 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 69 | 70 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 71 | self._process_data(train_dirs, cam1=True, cam2=True) 72 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 73 | self._process_data(test_dirs, cam1=True, cam2=False) 74 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 75 | self._process_data(test_dirs, cam1=False, cam2=True) 76 | 77 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 78 | min_num = np.min(num_imgs_per_tracklet) 79 | max_num = np.max(num_imgs_per_tracklet) 80 | avg_num = np.mean(num_imgs_per_tracklet) 81 | 82 | num_total_pids = num_train_pids + num_query_pids 83 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 84 | 85 | print("=> iLIDS-VID loaded") 86 | print("Dataset statistics:") 87 | print(" ------------------------------") 88 | print(" subset | # ids | # tracklets") 89 | print(" ------------------------------") 90 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 91 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 92 | print(" gallery | 
{:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 93 | print(" ------------------------------") 94 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 95 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 96 | print(" ------------------------------") 97 | 98 | 99 | self.train = train 100 | self.query = query 101 | self.gallery = gallery 102 | 103 | self.num_train_pids = num_train_pids 104 | self.num_query_pids = num_query_pids 105 | self.num_gallery_pids = num_gallery_pids 106 | self.num_train_cams=2 107 | self.num_query_cams=2 108 | self.num_gallery_cams=2 109 | self.num_train_vids=num_train_tracklets 110 | self.num_query_vids=num_query_tracklets 111 | self.num_gallery_vids=num_gallery_tracklets 112 | 113 | def _download_data(self): 114 | if osp.exists(self.root): 115 | print("This dataset has been downloaded.") 116 | return 117 | 118 | mkdir_if_missing(self.root) 119 | fpath = osp.join(self.root, osp.basename(self.dataset_url)) 120 | 121 | print("Downloading iLIDS-VID dataset") 122 | url_opener = urllib.request 123 | url_opener.urlretrieve(self.dataset_url, fpath) 124 | 125 | print("Extracting files") 126 | tar = tarfile.open(fpath) 127 | tar.extractall(path=self.root) 128 | tar.close() 129 | 130 | def _check_before_run(self): 131 | """Check if all files are available before going deeper""" 132 | if not osp.exists(self.root): 133 | raise RuntimeError("'{}' is not available".format(self.root)) 134 | if not osp.exists(self.data_dir): 135 | raise RuntimeError("'{}' is not available".format(self.data_dir)) 136 | if not osp.exists(self.split_dir): 137 | raise RuntimeError("'{}' is not available".format(self.split_dir)) 138 | 139 | def _prepare_split(self): 140 | if not osp.exists(self.split_path): 141 | print("Creating splits") 142 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 143 | 144 | num_splits = mat_split_data.shape[0] 145 | num_total_ids = mat_split_data.shape[1] 146 | assert num_splits == 10 147 | assert num_total_ids == 300 148 | num_ids_each = int(num_total_ids/2) 149 | 150 | # pids in mat_split_data are indices, so we need to transform them 151 | # to real pids 152 | person_cam1_dirs = os.listdir(self.cam_1_path) 153 | person_cam2_dirs = os.listdir(self.cam_2_path) 154 | 155 | # make sure persons in one camera view can be found in the other camera view 156 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 157 | 158 | splits = [] 159 | for i_split in range(num_splits): 160 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 161 | train_idxs = sorted(list(mat_split_data[i_split,num_ids_each:])) 162 | test_idxs = sorted(list(mat_split_data[i_split,:num_ids_each])) 163 | 164 | train_idxs = [int(i)-1 for i in train_idxs] 165 | test_idxs = [int(i)-1 for i in test_idxs] 166 | 167 | # transform pids to person dir names 168 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 169 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 170 | 171 | split = {'train': train_dirs, 'test': test_dirs} 172 | splits.append(split) 173 | 174 | print("Totally {} splits are created, following Wang et al. 
ECCV'14".format(len(splits))) 175 | print("Split file is saved to {}".format(self.split_path)) 176 | write_json(splits, self.split_path) 177 | 178 | print("Splits created") 179 | 180 | def _process_data(self, dirnames, cam1=True, cam2=True): 181 | tracklets = [] 182 | num_imgs_per_tracklet = [] 183 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 184 | 185 | for dirname in dirnames: 186 | if cam1: 187 | person_dir = osp.join(self.cam_1_path, dirname) 188 | img_names = glob.glob(osp.join(person_dir, '*.png')) 189 | assert len(img_names) > 0 190 | img_names = tuple(img_names) 191 | pid = dirname2pid[dirname] 192 | tracklets.append((img_names, pid, 0)) 193 | num_imgs_per_tracklet.append(len(img_names)) 194 | 195 | if cam2: 196 | person_dir = osp.join(self.cam_2_path, dirname) 197 | img_names = glob.glob(osp.join(person_dir, '*.png')) 198 | assert len(img_names) > 0 199 | img_names = tuple(img_names) 200 | pid = dirname2pid[dirname] 201 | tracklets.append((img_names, pid, 1)) 202 | num_imgs_per_tracklet.append(len(img_names)) 203 | 204 | num_tracklets = len(tracklets) 205 | num_pids = len(dirnames) 206 | 207 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 208 | -------------------------------------------------------------------------------- /VID_Trans_ReID.py: -------------------------------------------------------------------------------- 1 | from Dataloader import dataloader 2 | from VID_Trans_model import VID_Trans 3 | 4 | 5 | from Loss_fun import make_loss 6 | 7 | import random 8 | import torch 9 | import numpy as np 10 | import os 11 | import argparse 12 | 13 | import logging 14 | import os 15 | import time 16 | import torch 17 | import torch.nn as nn 18 | from torch_ema import ExponentialMovingAverage 19 | from torch.cuda import amp 20 | import torch.distributed as dist 21 | 22 | from utility import AverageMeter, optimizer,scheduler 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | from torch.autograd import Variable 31 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=21): 32 | num_q, num_g = distmat.shape 33 | if num_g < max_rank: 34 | max_rank = num_g 35 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 36 | indices = np.argsort(distmat, axis=1) 37 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 38 | 39 | # compute cmc curve for each query 40 | all_cmc = [] 41 | all_AP = [] 42 | num_valid_q = 0. 43 | for q_idx in range(num_q): 44 | # get query pid and camid 45 | q_pid = q_pids[q_idx] 46 | q_camid = q_camids[q_idx] 47 | 48 | # remove gallery samples that have the same pid and camid with query 49 | order = indices[q_idx] 50 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 51 | keep = np.invert(remove) 52 | 53 | # compute cmc curve 54 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 55 | if not np.any(orig_cmc): 56 | # this condition is true when query identity does not appear in gallery 57 | continue 58 | 59 | cmc = orig_cmc.cumsum() 60 | cmc[cmc > 1] = 1 61 | 62 | all_cmc.append(cmc[:max_rank]) 63 | num_valid_q += 1. 64 | 65 | # compute average precision 66 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 67 | num_rel = orig_cmc.sum() 68 | tmp_cmc = orig_cmc.cumsum() 69 | tmp_cmc = [x / (i+1.) 
for i, x in enumerate(tmp_cmc)] 70 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 71 | AP = tmp_cmc.sum() / num_rel 72 | all_AP.append(AP) 73 | 74 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 75 | 76 | all_cmc = np.asarray(all_cmc).astype(np.float32) 77 | all_cmc = all_cmc.sum(0) / num_valid_q 78 | mAP = np.mean(all_AP) 79 | 80 | return all_cmc, mAP 81 | 82 | def test(model, queryloader, galleryloader, pool='avg', use_gpu=True, ranks=[1, 5, 10, 20]): 83 | model.eval() 84 | qf, q_pids, q_camids = [], [], [] 85 | with torch.no_grad(): 86 | for batch_idx, (imgs, pids, camids,_) in enumerate(queryloader): 87 | 88 | if use_gpu: 89 | imgs = imgs.cuda() 90 | imgs = Variable(imgs, volatile=True) 91 | 92 | b, s, c, h, w = imgs.size() 93 | 94 | 95 | features = model(imgs,pids,cam_label=camids ) 96 | 97 | features = features.view(b, -1) 98 | features = torch.mean(features, 0) 99 | features = features.data.cpu() 100 | qf.append(features) 101 | 102 | q_pids.append(pids) 103 | q_camids.extend(camids) 104 | qf = torch.stack(qf) 105 | q_pids = np.asarray(q_pids) 106 | q_camids = np.asarray(q_camids) 107 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1))) 108 | gf, g_pids, g_camids = [], [], [] 109 | for batch_idx, (imgs, pids, camids,_) in enumerate(galleryloader): 110 | if use_gpu: 111 | imgs = imgs.cuda() 112 | imgs = Variable(imgs, volatile=True) 113 | b, s,c, h, w = imgs.size() 114 | features = model(imgs,pids,cam_label=camids) 115 | features = features.view(b, -1) 116 | if pool == 'avg': 117 | features = torch.mean(features, 0) 118 | else: 119 | features, _ = torch.max(features, 0) 120 | features = features.data.cpu() 121 | gf.append(features) 122 | g_pids.append(pids) 123 | g_camids.extend(camids) 124 | gf = torch.stack(gf) 125 | g_pids = np.asarray(g_pids) 126 | g_camids = np.asarray(g_camids) 127 | print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1))) 128 | print("Computing distance matrix") 129 | m, n = qf.size(0), gf.size(0) 130 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 131 | distmat.addmm_(1, -2, qf, gf.t()) 132 | distmat = distmat.numpy() 133 | gf = gf.numpy() 134 | qf = qf.numpy() 135 | 136 | print("Original Computing CMC and mAP") 137 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) 138 | 139 | # print("Results ---------- {:.1%} ".format(distmat_rerank)) 140 | print("Results ---------- ") 141 | 142 | print("mAP: {:.1%} ".format(mAP)) 143 | print("CMC curve r1:",cmc[0]) 144 | 145 | return cmc[0], mAP 146 | 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = argparse.ArgumentParser(description="VID-Trans-ReID") 151 | parser.add_argument( 152 | "--Dataset_name", default="", help="The name of the DataSet", type=str) 153 | args = parser.parse_args() 154 | Dataset_name=args.Dataset_name 155 | torch.manual_seed(1234) 156 | torch.cuda.manual_seed(1234) 157 | torch.cuda.manual_seed_all(1234) 158 | np.random.seed(1234) 159 | random.seed(1234) 160 | torch.backends.cudnn.deterministic = True 161 | torch.backends.cudnn.benchmark = True 162 | train_loader, num_query, num_classes, camera_num, view_num,q_val_set,g_val_set = dataloader(Dataset_name) 163 | model = VID_Trans( num_classes=num_classes, camera_num=camera_num,pretrainpath=pretrainpath) 164 | 165 | loss_fun,center_criterion= make_loss( num_classes=num_classes) 166 | optimizer_center = 
torch.optim.SGD(center_criterion.parameters(), lr= 0.5) 167 | 168 | optimizer= optimizer( model) 169 | scheduler = scheduler(optimizer) 170 | scaler = amp.GradScaler() 171 | 172 | #Train 173 | device = "cuda" 174 | epochs = 120 175 | model=model.to(device) 176 | ema = ExponentialMovingAverage(model.parameters(), decay=0.995) 177 | loss_meter = AverageMeter() 178 | acc_meter = AverageMeter() 179 | 180 | cmc_rank1=0 181 | for epoch in range(1, epochs + 1): 182 | start_time = time.time() 183 | loss_meter.reset() 184 | acc_meter.reset() 185 | 186 | scheduler.step(epoch) 187 | model.train() 188 | 189 | for Epoch_n, (img, pid, target_cam,labels2) in enumerate(train_loader): 190 | 191 | optimizer.zero_grad() 192 | optimizer_center.zero_grad() 193 | 194 | img = img.to(device) 195 | pid = pid.to(device) 196 | target_cam = target_cam.to(device) 197 | 198 | labels2=labels2.to(device) 199 | with amp.autocast(enabled=True): 200 | target_cam=target_cam.view(-1) 201 | score, feat ,a_vals= model(img, pid, cam_label=target_cam) 202 | 203 | labels2=labels2.to(device) 204 | attn_noise = a_vals * labels2 205 | attn_loss = attn_noise.sum(1).mean() 206 | 207 | loss_id ,center= loss_fun(score, feat, pid, target_cam) 208 | loss=loss_id+ 0.0005*center +attn_loss 209 | scaler.scale(loss).backward() 210 | 211 | scaler.step(optimizer) 212 | scaler.update() 213 | ema.update() 214 | 215 | for param in center_criterion.parameters(): 216 | param.grad.data *= (1. / 0.0005) 217 | scaler.step(optimizer_center) 218 | scaler.update() 219 | if isinstance(score, list): 220 | acc = (score[0].max(1)[1] == pid).float().mean() 221 | else: 222 | acc = (score.max(1)[1] == pid).float().mean() 223 | 224 | loss_meter.update(loss.item(), img.shape[0]) 225 | acc_meter.update(acc, 1) 226 | 227 | torch.cuda.synchronize() 228 | if (Epoch_n + 1) % 50 == 0: 229 | print("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 230 | .format(epoch, (Epoch_n + 1), len(train_loader), 231 | loss_meter.avg, acc_meter.avg, scheduler._get_lr(epoch)[0])) 232 | 233 | if (epoch+1)%10 == 0 : 234 | 235 | model.eval() 236 | cmc,map = test(model, q_val_set,g_val_set) 237 | print('CMC: %.4f, mAP : %.4f'%(cmc,map)) 238 | if cmc_rank1 < cmc: 239 | cmc_rank1=cmc 240 | torch.save(model.state_dict(),os.path.join('/VID-Trans-ReID', Dataset_name+'Main_Model.pth')) 241 | 242 | 243 | -------------------------------------------------------------------------------- /Datasets/MARS_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function, absolute_import 3 | 4 | from collections import defaultdict 5 | from scipy.io import loadmat 6 | import os.path as osp 7 | import numpy as np 8 | 9 | 10 | class Mars(object): 11 | """ 12 | MARS 13 | Reference: 14 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 15 | 16 | Dataset statistics: 17 | # identities: 1261 18 | # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery) 19 | # cameras: 6 20 | Args: 21 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0). 
22 | """ 23 | 24 | root ='MARS' #'/home2/zwjx97/STE-NVAN-master/MARS' #"/home/aishahalsehaim/Desktop/STE-NVAN-master/MARS" 25 | 26 | train_name_path = osp.join(root, 'info/train_name.txt') 27 | test_name_path = osp.join(root, 'info/test_name.txt') 28 | track_train_info_path = osp.join(root, 'info/tracks_train_info.mat') 29 | track_test_info_path = osp.join(root, 'info/tracks_test_info.mat') 30 | query_IDX_path = osp.join(root, 'info/query_IDX.mat') 31 | 32 | def __init__(self, min_seq_len=0, ): 33 | self._check_before_run() 34 | 35 | # prepare meta data 36 | train_names = self._get_names(self.train_name_path) 37 | test_names = self._get_names(self.test_name_path) 38 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 39 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 40 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 41 | query_IDX -= 1 # index from 0 42 | track_query = track_test[query_IDX,:] 43 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 44 | track_gallery = track_test[gallery_IDX,:] 45 | 46 | train, num_train_tracklets, num_train_pids, num_train_imgs = self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 47 | 48 | video = self._process_train_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 49 | 50 | 51 | query, num_query_tracklets, num_query_pids, num_query_imgs = self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 52 | 53 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 54 | 55 | num_imgs_per_tracklet = num_train_imgs + num_query_imgs + num_gallery_imgs 56 | min_num = np.min(num_imgs_per_tracklet) 57 | max_num = np.max(num_imgs_per_tracklet) 58 | avg_num = np.mean(num_imgs_per_tracklet) 59 | 60 | num_total_pids = num_train_pids + num_query_pids 61 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 62 | 63 | print("=> MARS loaded") 64 | print("Dataset statistics:") 65 | print(" ------------------------------") 66 | print(" subset | # ids | # tracklets") 67 | print(" ------------------------------") 68 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 69 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 70 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 71 | print(" ------------------------------") 72 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 73 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 74 | print(" ------------------------------") 75 | 76 | # self.train_videos = video 77 | self.train = train 78 | self.query = query 79 | self.gallery = gallery 80 | 81 | self.num_train_pids = num_train_pids 82 | self.num_query_pids = num_query_pids 83 | self.num_gallery_pids = num_gallery_pids 84 | self.num_train_cams=6 85 | self.num_query_cams=6 86 | self.num_gallery_cams=6 87 | self.num_train_vids=num_train_tracklets 88 | self.num_query_vids=num_query_tracklets 89 | self.num_gallery_vids=num_gallery_tracklets 90 | def _check_before_run(self): 91 | """Check if all files are available before going deeper""" 92 | if not 
osp.exists(self.root): 93 | raise RuntimeError("'{}' is not available".format(self.root)) 94 | if not osp.exists(self.train_name_path): 95 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 96 | if not osp.exists(self.test_name_path): 97 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 98 | if not osp.exists(self.track_train_info_path): 99 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 100 | if not osp.exists(self.track_test_info_path): 101 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 102 | if not osp.exists(self.query_IDX_path): 103 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 104 | 105 | def _get_names(self, fpath): 106 | names = [] 107 | with open(fpath, 'r') as f: 108 | for line in f: 109 | new_line = line.rstrip() 110 | names.append(new_line) 111 | return names 112 | 113 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 114 | assert home_dir in ['bbox_train', 'bbox_test'] 115 | num_tracklets = meta_data.shape[0] 116 | pid_list = list(set(meta_data[:,2].tolist())) 117 | num_pids = len(pid_list) 118 | if not relabel: pid2label = {pid:int(pid) for label, pid in enumerate(pid_list)} 119 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 120 | tracklets = [] 121 | num_imgs_per_tracklet = [] 122 | 123 | for tracklet_idx in range(num_tracklets): 124 | data = meta_data[tracklet_idx,...] 125 | start_index, end_index, pid, camid = data 126 | if pid == -1: continue # junk images are just ignored 127 | assert 1 <= camid <= 6 128 | #if relabel: pid = pid2label[pid] 129 | pid = pid2label[pid] 130 | camid -= 1 # index starts from 0 131 | img_names = names[start_index-1:end_index] 132 | 133 | # make sure image names correspond to the same person 134 | pnames = [img_name[:4] for img_name in img_names] 135 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 136 | 137 | # make sure all images are captured under the same camera 138 | camnames = [img_name[5] for img_name in img_names] 139 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 140 | 141 | # append image names with directory information 142 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 143 | if len(img_paths) >= min_seq_len: 144 | img_paths = tuple(img_paths) 145 | tracklets.append((img_paths, pid, camid)) 146 | num_imgs_per_tracklet.append(len(img_paths)) 147 | # if camid in video[pid] : 148 | # video[pid][camid].append(img_paths) 149 | # else: 150 | # video[pid][camid] = img_paths 151 | 152 | num_tracklets = len(tracklets) 153 | 154 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 155 | 156 | def _process_train_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 157 | video = defaultdict(dict) 158 | 159 | assert home_dir in ['bbox_train', 'bbox_test'] 160 | num_tracklets = meta_data.shape[0] 161 | pid_list = list(set(meta_data[:,2].tolist())) 162 | num_pids = len(pid_list) 163 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 164 | for tracklet_idx in range(num_tracklets): 165 | data = meta_data[tracklet_idx,...] 
166 | start_index, end_index, pid, camid = data 167 | if pid == -1: continue # junk images are just ignored 168 | assert 1 <= camid <= 6 169 | if relabel: pid = pid2label[pid] 170 | camid -= 1 # index starts from 0 171 | img_names = names[start_index-1:end_index] 172 | # make sure image names correspond to the same person 173 | pnames = [img_name[:4] for img_name in img_names] 174 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 175 | # make sure all images are captured under the same camera 176 | camnames = [img_name[5] for img_name in img_names] 177 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 178 | 179 | # append image names with directory information 180 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 181 | if len(img_paths) >= min_seq_len: 182 | if camid in video[pid] : 183 | video[pid][camid].extend(img_paths) 184 | else: 185 | video[pid][camid] = img_paths 186 | return video 187 | 188 | 189 | -------------------------------------------------------------------------------- /VID_Trans_model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import copy 5 | from vit_ID import TransReID,Block 6 | from functools import partial 7 | from torch.nn import functional as F 8 | 9 | 10 | def TCSS(features, shift, b,t): 11 | #aggregate features at patch level 12 | features=features.view(b,features.size(1),t*features.size(2)) 13 | token = features[:, 0:1] 14 | 15 | batchsize = features.size(0) 16 | dim = features.size(-1) 17 | 18 | 19 | #shift the patches with amount=shift 20 | features= torch.cat([features[:, shift:], features[:, 1:shift]], dim=1) 21 | 22 | # Patch Shuffling by 2 part 23 | try: 24 | features = features.view(batchsize, 2, -1, dim) 25 | except: 26 | features = torch.cat([features, features[:, -2:-1, :]], dim=1) 27 | features = features.view(batchsize, 2, -1, dim) 28 | 29 | features = torch.transpose(features, 1, 2).contiguous() 30 | features = features.view(batchsize, -1, dim) 31 | 32 | return features,token 33 | 34 | def weights_init_kaiming(m): 35 | classname = m.__class__.__name__ 36 | if classname.find('Linear') != -1: 37 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 38 | nn.init.constant_(m.bias, 0.0) 39 | 40 | elif classname.find('Conv') != -1: 41 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 42 | if m.bias is not None: 43 | nn.init.constant_(m.bias, 0.0) 44 | elif classname.find('BatchNorm') != -1: 45 | if m.affine: 46 | nn.init.constant_(m.weight, 1.0) 47 | nn.init.constant_(m.bias, 0.0) 48 | 49 | def weights_init_classifier(m): 50 | classname = m.__class__.__name__ 51 | if classname.find('Linear') != -1: 52 | nn.init.normal_(m.weight, std=0.001) 53 | if m.bias: 54 | nn.init.constant_(m.bias, 0.0) 55 | 56 | 57 | 58 | 59 | class VID_Trans(nn.Module): 60 | def __init__(self, num_classes, camera_num,pretrainpath): 61 | super(VID_Trans, self).__init__() 62 | self.in_planes = 768 63 | self.num_classes = num_classes 64 | 65 | 66 | self.base =TransReID( 67 | img_size=[256, 128], patch_size=16, stride_size=[16, 16], embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\ 68 | camera=camera_num, drop_path_rate=0.1, drop_rate=0.0, attn_drop_rate=0.0,norm_layer=partial(nn.LayerNorm, eps=1e-6), cam_lambda=3.0) 69 | 70 | 71 | state_dict = torch.load(pretrainpath, map_location='cpu') 72 | self.base.load_param(state_dict,load=True) 73 | 74 | 
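# The layers defined below form two streams:
# - Global stream: a deep copy of the final ViT block + LayerNorm (self.b1),
#   a BNNeck (self.bottleneck) and an ID classifier over the 768-dim [CLS]
#   feature, which is pooled across the t frames of a clip by the temporal
#   attention head (self.attention_conv / self.attention_tconv) defined further down.
# - Local video stream: a transformer Block over 3072-dim tokens
#   (3072 = t x 768, i.e. t = 4 frames per clip), applied in forward() to the
#   shifted and shuffled patch tokens produced by TCSS and split into 4 part
#   features, each with its own BNNeck (bottleneck_1..4) and classifier (classifier_1..4).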
75 | #global stream 76 | block= self.base.blocks[-1] 77 | layer_norm = self.base.norm 78 | self.b1 = nn.Sequential( 79 | copy.deepcopy(block), 80 | copy.deepcopy(layer_norm) 81 | ) 82 | 83 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 84 | self.bottleneck.bias.requires_grad_(False) 85 | self.bottleneck.apply(weights_init_kaiming) 86 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 87 | self.classifier.apply(weights_init_classifier) 88 | 89 | #----------------------------------------------- 90 | #----------------------------------------------- 91 | 92 | 93 | # building local video stream 94 | dpr = [x.item() for x in torch.linspace(0, 0, 12)] # stochastic depth decay rule 95 | 96 | self.block1 = Block( 97 | dim=3072, num_heads=12, mlp_ratio=4, qkv_bias=True, qk_scale=None, 98 | drop=0, attn_drop=0, drop_path=dpr[11], norm_layer=partial(nn.LayerNorm, eps=1e-6)) 99 | 100 | self.b2 = nn.Sequential( 101 | self.block1, 102 | nn.LayerNorm(3072)#copy.deepcopy(layer_norm) 103 | ) 104 | 105 | 106 | self.bottleneck_1 = nn.BatchNorm1d(3072) 107 | self.bottleneck_1.bias.requires_grad_(False) 108 | self.bottleneck_1.apply(weights_init_kaiming) 109 | self.bottleneck_2 = nn.BatchNorm1d(3072) 110 | self.bottleneck_2.bias.requires_grad_(False) 111 | self.bottleneck_2.apply(weights_init_kaiming) 112 | self.bottleneck_3 = nn.BatchNorm1d(3072) 113 | self.bottleneck_3.bias.requires_grad_(False) 114 | self.bottleneck_3.apply(weights_init_kaiming) 115 | self.bottleneck_4 = nn.BatchNorm1d(3072) 116 | self.bottleneck_4.bias.requires_grad_(False) 117 | self.bottleneck_4.apply(weights_init_kaiming) 118 | 119 | 120 | self.classifier_1 = nn.Linear(3072, self.num_classes, bias=False) 121 | self.classifier_1.apply(weights_init_classifier) 122 | self.classifier_2 = nn.Linear(3072, self.num_classes, bias=False) 123 | self.classifier_2.apply(weights_init_classifier) 124 | self.classifier_3 = nn.Linear(3072, self.num_classes, bias=False) 125 | self.classifier_3.apply(weights_init_classifier) 126 | self.classifier_4 = nn.Linear(3072, self.num_classes, bias=False) 127 | self.classifier_4.apply(weights_init_classifier) 128 | 129 | 130 | #-------------------video attention------------- 131 | self.middle_dim = 256 # middle layer dimension 132 | self.attention_conv = nn.Conv2d(self.in_planes, self.middle_dim, [1,1]) # 7,4 cooresponds to 224, 112 input image size 133 | self.attention_tconv = nn.Conv1d(self.middle_dim, 1, 3, padding=1) 134 | self.attention_conv.apply(weights_init_kaiming) 135 | self.attention_tconv.apply(weights_init_kaiming) 136 | #------------------------------------------ 137 | 138 | self.shift_num = 5 139 | self.part = 4 140 | self.rearrange=True 141 | 142 | 143 | 144 | 145 | def forward(self, x, label=None, cam_label= None, view_label=None): # label is unused if self.cos_layer == 'no' 146 | b=x.size(0) 147 | t=x.size(1) 148 | 149 | x=x.view(x.size(0)*x.size(1), x.size(2), x.size(3), x.size(4)) 150 | features = self.base(x, cam_label=cam_label) 151 | 152 | 153 | # global branch 154 | b1_feat = self.b1(features) # [64, 129, 768] 155 | global_feat = b1_feat[:, 0] 156 | 157 | global_feat=global_feat.unsqueeze(dim=2) 158 | global_feat=global_feat.unsqueeze(dim=3) 159 | a = F.relu(self.attention_conv(global_feat)) 160 | a = a.view(b, t, self.middle_dim) 161 | a = a.permute(0,2,1) 162 | a = F.relu(self.attention_tconv(a)) 163 | a = a.view(b, t) 164 | a_vals = a 165 | 166 | a = F.softmax(a, dim=1) 167 | x = global_feat.view(b, t, -1) 168 | a = torch.unsqueeze(a, -1) 169 | a = a.expand(b, 
t, self.in_planes) 170 | att_x = torch.mul(x,a) 171 | att_x = torch.sum(att_x,1) 172 | 173 | global_feat = att_x.view(b,self.in_planes) 174 | feat = self.bottleneck(global_feat) 175 | 176 | 177 | 178 | 179 | #------------------------------------------------- 180 | #------------------------------------------------- 181 | 182 | 183 | # video patch patr features 184 | 185 | feature_length = features.size(1) - 1 186 | patch_length = feature_length // 4 187 | 188 | #Temporal clip shift and shuffled 189 | x ,token=TCSS(features, self.shift_num, b,t) 190 | 191 | 192 | # part1 193 | part1 = x[:, :patch_length] 194 | part1 = self.b2(torch.cat((token, part1), dim=1)) 195 | part1_f = part1[:, 0] 196 | 197 | # part2 198 | part2 = x[:, patch_length:patch_length*2] 199 | part2 = self.b2(torch.cat((token, part2), dim=1)) 200 | part2_f = part2[:, 0] 201 | 202 | # part3 203 | part3 = x[:, patch_length*2:patch_length*3] 204 | part3 = self.b2(torch.cat((token, part3), dim=1)) 205 | part3_f = part3[:, 0] 206 | 207 | # part4 208 | part4 = x[:, patch_length*3:patch_length*4] 209 | part4 = self.b2(torch.cat((token, part4), dim=1)) 210 | part4_f = part4[:, 0] 211 | 212 | 213 | 214 | part1_bn = self.bottleneck_1(part1_f) 215 | part2_bn = self.bottleneck_2(part2_f) 216 | part3_bn = self.bottleneck_3(part3_f) 217 | part4_bn = self.bottleneck_4(part4_f) 218 | 219 | if self.training: 220 | 221 | Global_ID = self.classifier(feat) 222 | Local_ID1 = self.classifier_1(part1_bn) 223 | Local_ID2 = self.classifier_2(part2_bn) 224 | Local_ID3 = self.classifier_3(part3_bn) 225 | Local_ID4 = self.classifier_4(part4_bn) 226 | 227 | return [Global_ID, Local_ID1, Local_ID2, Local_ID3, Local_ID4 ], [global_feat, part1_f, part2_f, part3_f,part4_f], a_vals #[global_feat, part1_f, part2_f, part3_f,part4_f], a_vals 228 | 229 | else: 230 | return torch.cat([feat, part1_bn/4 , part2_bn/4 , part3_bn /4, part4_bn/4 ], dim=1) 231 | 232 | 233 | 234 | def load_param(self, trained_path,load=False): 235 | if not load: 236 | param_dict = torch.load(trained_path) 237 | for i in param_dict: 238 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) 239 | print('Loading pretrained model from {}'.format(trained_path)) 240 | else: 241 | param_dict=trained_path 242 | for i in param_dict: 243 | #print(i) 244 | if i not in self.state_dict() or 'classifier' in i or 'sie_embed' in i: 245 | continue 246 | self.state_dict()[i].copy_(param_dict[i]) 247 | 248 | 249 | 250 | def load_param_finetune(self, model_path): 251 | param_dict = torch.load(model_path) 252 | for i in param_dict: 253 | self.state_dict()[i].copy_(param_dict[i]) 254 | print('Loading pretrained model for finetuning from {}'.format(model_path)) 255 | 256 | 257 | 258 | 259 | 260 | -------------------------------------------------------------------------------- /Dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | from torch.utils.data import DataLoader 4 | from PIL import Image, ImageFile 5 | 6 | from torch.utils.data import Dataset 7 | import os.path as osp 8 | import random 9 | import torch 10 | import numpy as np 11 | import math 12 | 13 | from timm.data.random_erasing import RandomErasing 14 | from utility import RandomIdentitySampler,RandomErasing3 15 | from Datasets.MARS_dataset import Mars 16 | from Datasets.iLDSVID import iLIDSVID 17 | from Datasets.PRID_dataset import PRID 18 | 19 | __factory = { 20 | 'Mars':Mars, 21 | 'iLIDSVID':iLIDSVID, 22 | 'PRID':PRID 23 | } 24 | 25 | 
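# --- Illustrative usage sketch (not part of this file): the __factory mapping above is
# --- indexed by the dataloader() helper defined further down; the names and returned
# --- tuple layout follow that function, and the MARS data is assumed to be prepared on
# --- disk as Datasets/MARS_dataset.py expects.
#
#   from Dataloader import dataloader
#   train_loader, num_query, num_classes, cam_num, view_num, q_val_set, g_val_set = dataloader('Mars')
#   for imgs, pids, cam_ids, erase_labels in train_loader:
#       # imgs: [64, 4, 3, 256, 128] -> 16 identities x 4 tracklets, 4 frames per tracklet
#       print(imgs.shape, pids.shape, cam_ids.shape, erase_labels.shape)
#       break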
def train_collate_fn(batch):
26 | 
27 | imgs, pids, camids, a = zip(*batch)  # "a" holds the per-frame random-erasing flags produced by VideoDataset_inderase
28 | pids = torch.tensor(pids, dtype=torch.int64)
29 | 
30 | camids = torch.tensor(camids, dtype=torch.int64)
31 | return torch.stack(imgs, dim=0), pids, camids, torch.stack(a, dim=0)
32 | 
33 | def val_collate_fn(batch):
34 | 
35 | imgs, pids, camids, img_paths = zip(*batch)
36 | # these datasets carry no view ids, so only pids and camids are collated
37 | camids_batch = torch.tensor(camids, dtype=torch.int64)
38 | return torch.stack(imgs, dim=0), pids, camids_batch, img_paths
39 | 
40 | def dataloader(Dataset_name):
41 | train_transforms = T.Compose([
42 | T.Resize([256, 128], interpolation=3),  # interpolation=3 is PIL bicubic
43 | T.RandomHorizontalFlip(p=0.5),
44 | T.Pad(10),
45 | T.RandomCrop([256, 128]),
46 | T.ToTensor(),
47 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
48 | 
49 | 
50 | ])
51 | 
52 | val_transforms = T.Compose([
53 | T.Resize([256, 128]),
54 | T.ToTensor(),
55 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
56 | ])
57 | 
58 | 
59 | 
60 | dataset = __factory[Dataset_name]()
61 | train_set = VideoDataset_inderase(dataset.train, seq_len=4, sample='intelligent', transform=train_transforms)
62 | num_classes = dataset.num_train_pids
63 | cam_num = dataset.num_train_cams
64 | view_num = dataset.num_train_vids
65 | 
66 | # P x K batching: batch_size 64 = 16 identities x 4 tracklets each (see RandomIdentitySampler)
67 | train_loader = DataLoader(train_set, batch_size=64, sampler=RandomIdentitySampler(dataset.train, 64, 4), num_workers=4, collate_fn=train_collate_fn)
68 | 
69 | q_val_set = VideoDataset(dataset.query, seq_len=4, sample='dense', transform=val_transforms)
70 | g_val_set = VideoDataset(dataset.gallery, seq_len=4, sample='dense', transform=val_transforms)
71 | 
72 | 
73 | return train_loader, len(dataset.query), num_classes, cam_num, view_num, q_val_set, g_val_set
74 | 
75 | 
76 | 
77 | def read_image(img_path):
78 | """Keep reading the image until it succeeds.
79 | This avoids IOError caused by a heavy IO load."""
80 | got_img = False
81 | while not got_img:
82 | try:
83 | img = Image.open(img_path).convert('RGB')
84 | got_img = True
85 | except IOError:
86 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
87 | pass
88 | return img
89 | 
90 | class VideoDataset(Dataset):
91 | """Video Person ReID Dataset.
92 | Note batch data has shape (batch, seq_len, channel, height, width).
93 | """
94 | sample_methods = ['evenly', 'random', 'all']
95 | 
96 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None, max_length=40):
97 | self.dataset = dataset
98 | self.seq_len = seq_len
99 | self.sample = sample
100 | self.transform = transform
101 | self.max_length = max_length
102 | 
103 | def __len__(self):
104 | return len(self.dataset)
105 | 
106 | def __getitem__(self, index):
107 | img_paths, pid, camid = self.dataset[index]
108 | num = len(img_paths)
109 | 
110 | # if self.sample == 'restricted_random':
111 | # frame_indices = range(num)
112 | # chunks =
113 | # rand_end = max(0, len(frame_indices) - self.seq_len - 1)
114 | # begin_index = random.randint(0, rand_end)
115 | 
116 | 
117 | if self.sample == 'random':
118 | """
119 | Randomly sample seq_len consecutive frames from the num available frames;
120 | if num is smaller than seq_len, replicate the last frame.
121 | This sampling strategy is used in the training phase.
122 | """ 123 | frame_indices = range(num) 124 | rand_end = max(0, len(frame_indices) - self.seq_len - 1) 125 | begin_index = random.randint(0, rand_end) 126 | end_index = min(begin_index + self.seq_len, len(frame_indices)) 127 | 128 | indices = frame_indices[begin_index:end_index] 129 | # print(begin_index, end_index, indices) 130 | if len(indices) < self.seq_len: 131 | indices=np.array(indices) 132 | indices = np.append(indices , [indices[-1] for i in range(self.seq_len - len(indices))]) 133 | else: 134 | indices=np.array(indices) 135 | imgs = [] 136 | targt_cam=[] 137 | for index in indices: 138 | index=int(index) 139 | img_path = img_paths[index] 140 | img = read_image(img_path) 141 | if self.transform is not None: 142 | img = self.transform(img) 143 | img = img.unsqueeze(0) 144 | targt_cam.append(camid) 145 | imgs.append(img) 146 | imgs = torch.cat(imgs, dim=0) 147 | #imgs=imgs.permute(1,0,2,3) 148 | return imgs, pid, targt_cam 149 | 150 | elif self.sample == 'dense': 151 | """ 152 | Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1. 153 | This sampling strategy is used in test phase. 154 | """ 155 | # import pdb 156 | # pdb.set_trace() 157 | 158 | cur_index=0 159 | frame_indices = [i for i in range(num)] 160 | indices_list=[] 161 | while num-cur_index > self.seq_len: 162 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len]) 163 | cur_index+=self.seq_len 164 | last_seq=frame_indices[cur_index:] 165 | # print(last_seq) 166 | for index in last_seq: 167 | if len(last_seq) >= self.seq_len: 168 | break 169 | last_seq.append(index) 170 | 171 | 172 | indices_list.append(last_seq) 173 | imgs_list=[] 174 | targt_cam=[] 175 | # print(indices_list , num , img_paths ) 176 | for indices in indices_list: 177 | if len(imgs_list) > self.max_length: 178 | break 179 | imgs = [] 180 | for index in indices: 181 | index=int(index) 182 | img_path = img_paths[index] 183 | img = read_image(img_path) 184 | if self.transform is not None: 185 | img = self.transform(img) 186 | img = img.unsqueeze(0) 187 | imgs.append(img) 188 | targt_cam.append(camid) 189 | imgs = torch.cat(imgs, dim=0) 190 | #imgs=imgs.permute(1,0,2,3) 191 | imgs_list.append(imgs) 192 | imgs_array = torch.stack(imgs_list) 193 | return imgs_array, pid, targt_cam,img_paths 194 | #return imgs_array, pid, int(camid),trackid 195 | 196 | 197 | elif self.sample == 'dense_subset': 198 | """ 199 | Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1. 200 | This sampling strategy is used in test phase. 
201 | """
202 | frame_indices = range(num)
203 | rand_end = max(0, len(frame_indices) - self.max_length - 1)
204 | begin_index = random.randint(0, rand_end)
205 | 
206 | 
207 | cur_index=begin_index
208 | frame_indices = [i for i in range(num)]
209 | indices_list=[]
210 | while num-cur_index > self.seq_len:
211 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len])
212 | cur_index+=self.seq_len
213 | last_seq=frame_indices[cur_index:]
214 | # print(last_seq)
215 | for index in last_seq:
216 | if len(last_seq) >= self.seq_len:
217 | break
218 | last_seq.append(index)
219 | 
220 | 
221 | indices_list.append(last_seq)
222 | imgs_list=[]
223 | # print(indices_list , num , img_paths )
224 | for indices in indices_list:
225 | if len(imgs_list) > self.max_length:
226 | break
227 | imgs = []
228 | for index in indices:
229 | index=int(index)
230 | img_path = img_paths[index]
231 | img = read_image(img_path)
232 | if self.transform is not None:
233 | img = self.transform(img)
234 | img = img.unsqueeze(0)
235 | imgs.append(img)
236 | imgs = torch.cat(imgs, dim=0)
237 | #imgs=imgs.permute(1,0,2,3)
238 | imgs_list.append(imgs)
239 | imgs_array = torch.stack(imgs_list)
240 | return imgs_array, pid, camid
241 | 
242 | elif self.sample == 'intelligent_random':
243 | # spread self.seq_len frames roughly evenly across the tracklet, one random pick per chunk
244 | indices = []
245 | each = max(num // self.seq_len, 1)
246 | for i in range(self.seq_len):
247 | if i != self.seq_len - 1:
248 | indices.append(random.randint(min(i*each , num-1), min( (i+1)*each-1, num-1)) )
249 | else:
250 | indices.append(random.randint(min(i*each , num-1), num-1) )
251 | # print(len(indices))
252 | imgs = []
253 | for index in indices:
254 | index=int(index)
255 | img_path = img_paths[index]
256 | img = read_image(img_path)
257 | if self.transform is not None:
258 | img = self.transform(img)
259 | img = img.unsqueeze(0)
260 | imgs.append(img)
261 | imgs = torch.cat(imgs, dim=0)
262 | #imgs=imgs.permute(1,0,2,3)
263 | return imgs, pid, camid
264 | else:
265 | raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))
266 | 
267 | 
268 | 
269 | 
270 | class VideoDataset_inderase(Dataset):
271 | """Video Person ReID Dataset.
272 | Note batch data has shape (batch, seq_len, channel, height, width).
273 | """ 274 | sample_methods = ['evenly', 'random', 'all'] 275 | 276 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None , max_length=40): 277 | self.dataset = dataset 278 | self.seq_len = seq_len 279 | self.sample = sample 280 | self.transform = transform 281 | self.max_length = max_length 282 | self.erase = RandomErasing3(probability=0.5, mean=[0.485, 0.456, 0.406]) 283 | 284 | def __len__(self): 285 | return len(self.dataset) 286 | 287 | def __getitem__(self, index): 288 | img_paths, pid, camid = self.dataset[index] 289 | num = len(img_paths) 290 | if self.sample != "intelligent": 291 | frame_indices = range(num) 292 | rand_end = max(0, len(frame_indices) - self.seq_len - 1) 293 | begin_index = random.randint(0, rand_end) 294 | end_index = min(begin_index + self.seq_len, len(frame_indices)) 295 | 296 | indices1 = frame_indices[begin_index:end_index] 297 | indices = [] 298 | for index in indices1: 299 | if len(indices1) >= self.seq_len: 300 | break 301 | indices.append(index) 302 | indices=np.array(indices) 303 | else: 304 | # frame_indices = range(num) 305 | indices = [] 306 | each = max(num//self.seq_len,1) 307 | for i in range(self.seq_len): 308 | if i != self.seq_len -1: 309 | indices.append(random.randint(min(i*each , num-1), min( (i+1)*each-1, num-1)) ) 310 | else: 311 | indices.append(random.randint(min(i*each , num-1), num-1) ) 312 | # print(len(indices), indices, num ) 313 | imgs = [] 314 | labels = [] 315 | targt_cam=[] 316 | 317 | for index in indices: 318 | index=int(index) 319 | img_path = img_paths[index] 320 | 321 | img = read_image(img_path) 322 | if self.transform is not None: 323 | img = self.transform(img) 324 | img , temp = self.erase(img) 325 | labels.append(temp) 326 | img = img.unsqueeze(0) 327 | imgs.append(img) 328 | targt_cam.append(camid) 329 | labels = torch.tensor(labels) 330 | imgs = torch.cat(imgs, dim=0) 331 | 332 | return imgs, pid, targt_cam ,labels 333 | 334 | 335 | -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | import torch 7 | import logging 8 | import math 9 | import torch 10 | from typing import Dict, Any 11 | class RandomIdentitySampler(Sampler): 12 | """ 13 | Randomly sample N identities, then for each identity, 14 | randomly sample K instances, therefore batch size is N*K. 15 | Args: 16 | - data_source (list): list of (img_path, pid, camid). 17 | - num_instances (int): number of instances per identity in a batch. 18 | - batch_size (int): number of examples in a batch. 
19 | """ 20 | 21 | def __init__(self, data_source, batch_size, num_instances): 22 | self.data_source = data_source 23 | self.batch_size = batch_size 24 | self.num_instances = num_instances 25 | self.num_pids_per_batch = self.batch_size // self.num_instances 26 | self.index_dic = defaultdict(list) #dict with list value 27 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 28 | for index, (_, pid, _) in enumerate(self.data_source): 29 | self.index_dic[pid].append(index) 30 | self.pids = list(self.index_dic.keys()) 31 | 32 | # estimate number of examples in an epoch 33 | self.length = 0 34 | for pid in self.pids: 35 | idxs = self.index_dic[pid] 36 | num = len(idxs) 37 | if num < self.num_instances: 38 | num = self.num_instances 39 | self.length += num - num % self.num_instances 40 | 41 | def __iter__(self): 42 | batch_idxs_dict = defaultdict(list) 43 | 44 | for pid in self.pids: 45 | idxs = copy.deepcopy(self.index_dic[pid]) 46 | if len(idxs) < self.num_instances: 47 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 48 | random.shuffle(idxs) 49 | batch_idxs = [] 50 | for idx in idxs: 51 | batch_idxs.append(idx) 52 | if len(batch_idxs) == self.num_instances: 53 | batch_idxs_dict[pid].append(batch_idxs) 54 | batch_idxs = [] 55 | 56 | avai_pids = copy.deepcopy(self.pids) 57 | final_idxs = [] 58 | 59 | while len(avai_pids) >= self.num_pids_per_batch: 60 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 61 | for pid in selected_pids: 62 | batch_idxs = batch_idxs_dict[pid].pop(0) 63 | final_idxs.extend(batch_idxs) 64 | if len(batch_idxs_dict[pid]) == 0: 65 | avai_pids.remove(pid) 66 | 67 | return iter(final_idxs) 68 | 69 | def __len__(self): 70 | return self.length 71 | 72 | class AverageMeter(object): 73 | """Computes and stores the average and current value""" 74 | 75 | def __init__(self): 76 | self.val = 0 77 | self.avg = 0 78 | self.sum = 0 79 | self.count = 0 80 | 81 | def reset(self): 82 | self.val = 0 83 | self.avg = 0 84 | self.sum = 0 85 | self.count = 0 86 | 87 | def update(self, val, n=1): 88 | self.val = val 89 | self.sum += val * n 90 | self.count += n 91 | self.avg = self.sum / self.count 92 | 93 | class RandomErasing3(object): 94 | """ Randomly selects a rectangle region in an image and erases its pixels. 95 | 'Random Erasing Data Augmentation' by Zhong et al. 96 | See https://arxiv.org/pdf/1708.04896.pdf 97 | Args: 98 | probability: The probability that the Random Erasing operation will be performed. 99 | sl: Minimum proportion of erased area against input image. 100 | sh: Maximum proportion of erased area against input image. 101 | r1: Minimum aspect ratio of erased area. 102 | mean: Erasing value. 
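Returns:
    (img, flag): the possibly-erased image and a 0/1 flag indicating whether a region
    was actually erased; VideoDataset_inderase in Dataloader.py collects these flags as
    the per-frame erasing labels that train_collate_fn stacks into the batch.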
103 | """ 104 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 105 | self.probability = probability 106 | self.mean = mean 107 | self.sl = sl 108 | self.sh = sh 109 | self.r1 = r1 110 | def __call__(self, img): 111 | if random.uniform(0, 1) >= self.probability: 112 | return img , 0 113 | for attempt in range(100): 114 | area = img.size()[1] * img.size()[2] 115 | target_area = random.uniform(self.sl, self.sh) * area 116 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 117 | h = int(round(math.sqrt(target_area * aspect_ratio))) 118 | w = int(round(math.sqrt(target_area / aspect_ratio))) 119 | if w < img.size()[2] and h < img.size()[1]: 120 | x1 = random.randint(0, img.size()[1] - h) 121 | y1 = random.randint(0, img.size()[2] - w) 122 | if img.size()[0] == 3: 123 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 124 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 125 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 126 | else: 127 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 128 | return img , 1 129 | return img , 0 130 | 131 | 132 | def scheduler(optimizer): 133 | num_epochs = 120 134 | 135 | lr_min = 0.002 * 0.008 136 | warmup_lr_init = 0.01 * 0.008 137 | 138 | warmup_t = 5 139 | noise_range = None 140 | 141 | lr_scheduler = CosineLRScheduler( 142 | optimizer, 143 | t_initial=num_epochs, 144 | lr_min=lr_min, 145 | t_mul= 1., 146 | decay_rate=0.1, 147 | warmup_lr_init=warmup_lr_init, 148 | warmup_t=warmup_t, 149 | cycle_limit=1, 150 | t_in_epochs=True, 151 | noise_range_t=noise_range, 152 | noise_pct= 0.67, 153 | noise_std= 1., 154 | noise_seed=42, 155 | ) 156 | 157 | return lr_scheduler 158 | 159 | 160 | 161 | 162 | def optimizer(model): 163 | params = [] 164 | for key, value in model.named_parameters(): 165 | if not value.requires_grad: 166 | continue 167 | lr = 0.008 168 | weight_decay = 1e-4 169 | if "bias" in key: 170 | lr = 0.008 * 2 171 | weight_decay = 1e-4 172 | 173 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 174 | 175 | 176 | optimizer = getattr(torch.optim, 'SGD')(params, momentum=0.9) 177 | 178 | 179 | 180 | return optimizer 181 | 182 | 183 | 184 | 185 | 186 | class Scheduler: 187 | """ Parameter Scheduler Base Class 188 | A scheduler base class that can be used to schedule any optimizer parameter groups. 189 | 190 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 191 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 192 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 193 | 194 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 195 | 196 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 197 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 198 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 
199 | 200 | Based on ideas from: 201 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 202 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 203 | """ 204 | 205 | def __init__(self, 206 | optimizer: torch.optim.Optimizer, 207 | param_group_field: str, 208 | noise_range_t=None, 209 | noise_type='normal', 210 | noise_pct=0.67, 211 | noise_std=1.0, 212 | noise_seed=None, 213 | initialize: bool = True) -> None: 214 | self.optimizer = optimizer 215 | self.param_group_field = param_group_field 216 | self._initial_param_group_field = f"initial_{param_group_field}" 217 | if initialize: 218 | for i, group in enumerate(self.optimizer.param_groups): 219 | if param_group_field not in group: 220 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 221 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 222 | else: 223 | for i, group in enumerate(self.optimizer.param_groups): 224 | if self._initial_param_group_field not in group: 225 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 226 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 227 | self.metric = None # any point to having this for all? 228 | self.noise_range_t = noise_range_t 229 | self.noise_pct = noise_pct 230 | self.noise_type = noise_type 231 | self.noise_std = noise_std 232 | self.noise_seed = noise_seed if noise_seed is not None else 42 233 | self.update_groups(self.base_values) 234 | 235 | def state_dict(self) -> Dict[str, Any]: 236 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 237 | 238 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 239 | self.__dict__.update(state_dict) 240 | 241 | def get_epoch_values(self, epoch: int): 242 | return None 243 | 244 | def get_update_values(self, num_updates: int): 245 | return None 246 | 247 | def step(self, epoch: int, metric: float = None) -> None: 248 | self.metric = metric 249 | values = self.get_epoch_values(epoch) 250 | if values is not None: 251 | values = self._add_noise(values, epoch) 252 | self.update_groups(values) 253 | 254 | def step_update(self, num_updates: int, metric: float = None): 255 | self.metric = metric 256 | values = self.get_update_values(num_updates) 257 | if values is not None: 258 | values = self._add_noise(values, num_updates) 259 | self.update_groups(values) 260 | 261 | def update_groups(self, values): 262 | if not isinstance(values, (list, tuple)): 263 | values = [values] * len(self.optimizer.param_groups) 264 | for param_group, value in zip(self.optimizer.param_groups, values): 265 | param_group[self.param_group_field] = value 266 | 267 | def _add_noise(self, lrs, t): 268 | if self.noise_range_t is not None: 269 | if isinstance(self.noise_range_t, (list, tuple)): 270 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 271 | else: 272 | apply_noise = t >= self.noise_range_t 273 | if apply_noise: 274 | g = torch.Generator() 275 | g.manual_seed(self.noise_seed + t) 276 | if self.noise_type == 'normal': 277 | while True: 278 | # resample if noise out of percent limit, brute force but shouldn't spin much 279 | noise = torch.randn(1, generator=g).item() 280 | if abs(noise) < self.noise_pct: 281 | break 282 | else: 283 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 284 | lrs = [v + v * noise for v in lrs] 285 | return lrs 286 | 287 | class CosineLRScheduler(Scheduler): 288 | """ 
289 | Cosine decay with restarts. 290 | This is described in the paper https://arxiv.org/abs/1608.03983. 291 | 292 | Inspiration from 293 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 294 | """ 295 | 296 | def __init__(self, 297 | optimizer: torch.optim.Optimizer, 298 | t_initial: int, 299 | t_mul: float = 1., 300 | lr_min: float = 0., 301 | decay_rate: float = 1., 302 | warmup_t=0, 303 | warmup_lr_init=0, 304 | warmup_prefix=False, 305 | cycle_limit=0, 306 | t_in_epochs=True, 307 | noise_range_t=None, 308 | noise_pct=0.67, 309 | noise_std=1.0, 310 | noise_seed=42, 311 | initialize=True) -> None: 312 | super().__init__( 313 | optimizer, param_group_field="lr", 314 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 315 | initialize=initialize) 316 | 317 | assert t_initial > 0 318 | assert lr_min >= 0 319 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 320 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 321 | "rate since t_initial = t_mul = eta_mul = 1.") 322 | self.t_initial = t_initial 323 | self.t_mul = t_mul 324 | self.lr_min = lr_min 325 | self.decay_rate = decay_rate 326 | self.cycle_limit = cycle_limit 327 | self.warmup_t = warmup_t 328 | self.warmup_lr_init = warmup_lr_init 329 | self.warmup_prefix = warmup_prefix 330 | self.t_in_epochs = t_in_epochs 331 | if self.warmup_t: 332 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 333 | super().update_groups(self.warmup_lr_init) 334 | else: 335 | self.warmup_steps = [1 for _ in self.base_values] 336 | 337 | def _get_lr(self, t): 338 | if t < self.warmup_t: 339 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 340 | else: 341 | if self.warmup_prefix: 342 | t = t - self.warmup_t 343 | 344 | if self.t_mul != 1: 345 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 346 | t_i = self.t_mul ** i * self.t_initial 347 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 348 | else: 349 | i = t // self.t_initial 350 | t_i = self.t_initial 351 | t_curr = t - (self.t_initial * i) 352 | 353 | gamma = self.decay_rate ** i 354 | lr_min = self.lr_min * gamma 355 | lr_max_values = [v * gamma for v in self.base_values] 356 | 357 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 358 | lrs = [ 359 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 360 | ] 361 | else: 362 | lrs = [self.lr_min for _ in self.base_values] 363 | 364 | return lrs 365 | 366 | def get_epoch_values(self, epoch: int): 367 | if self.t_in_epochs: 368 | return self._get_lr(epoch) 369 | else: 370 | return None 371 | 372 | def get_update_values(self, num_updates: int): 373 | if not self.t_in_epochs: 374 | return self._get_lr(num_updates) 375 | else: 376 | return None 377 | 378 | def get_cycle_length(self, cycles=0): 379 | if not cycles: 380 | cycles = self.cycle_limit 381 | cycles = max(1, cycles) 382 | if self.t_mul == 1.0: 383 | return self.t_initial * cycles 384 | else: 385 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 386 | 387 | 388 | 389 | -------------------------------------------------------------------------------- /vit_ID.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from functools import partial 4 | from itertools import repeat 5 | 6 | import 
torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch._six import container_abcs 10 | 11 | 12 | # From PyTorch internals 13 | def _ntuple(n): 14 | def parse(x): 15 | if isinstance(x, container_abcs.Iterable): 16 | return x 17 | return tuple(repeat(x, n)) 18 | return parse 19 | 20 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 21 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 22 | to_2tuple = _ntuple(2) 23 | 24 | def drop_path(x, drop_prob: float = 0., training: bool = False): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 26 | 27 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 28 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 29 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 30 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 31 | 'survival rate' as the argument. 32 | 33 | """ 34 | if drop_prob == 0. or not training: 35 | return x 36 | keep_prob = 1 - drop_prob 37 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 38 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 39 | random_tensor.floor_() # binarize 40 | output = x.div(keep_prob) * random_tensor 41 | return output 42 | 43 | class DropPath(nn.Module): 44 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 45 | """ 46 | def __init__(self, drop_prob=None): 47 | super(DropPath, self).__init__() 48 | self.drop_prob = drop_prob 49 | 50 | def forward(self, x): 51 | return drop_path(x, self.drop_prob, self.training) 52 | 53 | 54 | 55 | 56 | class Mlp(nn.Module): 57 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 58 | super().__init__() 59 | out_features = out_features or in_features 60 | hidden_features = hidden_features or in_features 61 | self.fc1 = nn.Linear(in_features, hidden_features) 62 | self.act = act_layer() 63 | self.fc2 = nn.Linear(hidden_features, out_features) 64 | self.drop = nn.Dropout(drop) 65 | 66 | def forward(self, x): 67 | x = self.fc1(x) 68 | x = self.act(x) 69 | x = self.drop(x) 70 | x = self.fc2(x) 71 | x = self.drop(x) 72 | return x 73 | 74 | 75 | class Attention(nn.Module): 76 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 77 | super().__init__() 78 | self.num_heads = num_heads 79 | head_dim = dim // num_heads 80 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 81 | self.scale = qk_scale or head_dim ** -0.5 82 | 83 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 84 | self.attn_drop = nn.Dropout(attn_drop) 85 | self.proj = nn.Linear(dim, dim) 86 | self.proj_drop = nn.Dropout(proj_drop) 87 | 88 | def forward(self, x): 89 | 90 | B, N, C = x.shape 91 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 92 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 93 | 94 | attn = (q @ k.transpose(-2, -1)) * self.scale 95 | attn = attn.softmax(dim=-1) 96 | attn = self.attn_drop(attn) 97 | 98 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 99 | x = self.proj(x) 100 | x = self.proj_drop(x) 101 | return x 102 | 103 | 104 | class Block(nn.Module): 105 | 106 | def 
__init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 107 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 108 | super().__init__() 109 | self.norm1 = norm_layer(dim) 110 | self.attn = Attention( 111 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 112 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 113 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 114 | self.norm2 = norm_layer(dim) 115 | mlp_hidden_dim = int(dim * mlp_ratio) 116 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 117 | 118 | def forward(self, x): 119 | x = x + self.drop_path(self.attn(self.norm1(x))) 120 | x = x + self.drop_path(self.mlp(self.norm2(x))) 121 | return x 122 | 123 | 124 | class PatchEmbed(nn.Module): 125 | """ Image to Patch Embedding 126 | """ 127 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 128 | super().__init__() 129 | img_size = to_2tuple(img_size) 130 | patch_size = to_2tuple(patch_size) 131 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 132 | self.img_size = img_size 133 | self.patch_size = patch_size 134 | self.num_patches = num_patches 135 | 136 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 137 | 138 | def forward(self, x): 139 | B, C, H, W = x.shape 140 | # FIXME look at relaxing size constraints 141 | assert H == self.img_size[0] and W == self.img_size[1], \ 142 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 143 | x = self.proj(x).flatten(2).transpose(1, 2) 144 | return x 145 | 146 | 147 | 148 | 149 | 150 | class PatchEmbed_overlap(nn.Module): 151 | """ Image to Patch Embedding with overlapping patches 152 | """ 153 | def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768): 154 | super().__init__() 155 | img_size = to_2tuple(img_size) 156 | patch_size = to_2tuple(patch_size) 157 | stride_size_tuple = to_2tuple(stride_size) 158 | self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1 159 | self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1 160 | print('using stride: {}, and patch number is num_y{} * num_x{}'.format(stride_size, self.num_y, self.num_x)) 161 | num_patches = self.num_x * self.num_y 162 | self.img_size = img_size 163 | self.patch_size = patch_size 164 | self.num_patches = num_patches 165 | 166 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size) 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 170 | m.weight.data.normal_(0, math.sqrt(2. / n)) 171 | elif isinstance(m, nn.BatchNorm2d): 172 | m.weight.data.fill_(1) 173 | m.bias.data.zero_() 174 | elif isinstance(m, nn.InstanceNorm2d): 175 | m.weight.data.fill_(1) 176 | m.bias.data.zero_() 177 | 178 | def forward(self, x): 179 | B, C, H, W = x.shape 180 | 181 | # FIXME look at relaxing size constraints 182 | assert H == self.img_size[0] and W == self.img_size[1], \ 183 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
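# (added note) self.proj below is a strided Conv2d, so patch extraction and the linear
# patch embedding happen in a single step. The grid size follows num = (img - patch) // stride + 1
# per axis, as computed in __init__: for example, a 256x128 input with patch_size=16 and
# stride_size=16 gives a 16 x 8 grid, i.e. 128 patch tokens of dimension embed_dim, while
# stride_size < patch_size makes neighbouring patches overlap.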
184 | x = self.proj(x) 185 | 186 | x = x.flatten(2).transpose(1, 2) # [64, 8, 768] 187 | return x 188 | 189 | 190 | class TransReID(nn.Module): 191 | """ Transformer-based Object Re-Identification 192 | """ 193 | def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 194 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., camera=0, 195 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, cam_lambda =3.0): 196 | super().__init__() 197 | self.num_classes = num_classes 198 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 199 | 200 | self.cam_num = camera 201 | self.cam_lambda = cam_lambda 202 | 203 | 204 | self.patch_embed = PatchEmbed_overlap(img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans,embed_dim=embed_dim) 205 | num_patches = self.patch_embed.num_patches 206 | 207 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 208 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 209 | self.Cam = nn.Parameter(torch.zeros(camera, 1, embed_dim)) 210 | 211 | trunc_normal_(self.Cam, std=.02) 212 | self.pos_drop = nn.Dropout(p=drop_rate) 213 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 214 | 215 | self.blocks = nn.ModuleList([ 216 | Block( 217 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 218 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 219 | for i in range(depth)]) 220 | 221 | self.norm = norm_layer(embed_dim) 222 | 223 | # Classifier head 224 | self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 225 | trunc_normal_(self.cls_token, std=.02) 226 | trunc_normal_(self.pos_embed, std=.02) 227 | 228 | self.apply(self._init_weights) 229 | 230 | def _init_weights(self, m): 231 | if isinstance(m, nn.Linear): 232 | trunc_normal_(m.weight, std=.02) 233 | if isinstance(m, nn.Linear) and m.bias is not None: 234 | nn.init.constant_(m.bias, 0) 235 | elif isinstance(m, nn.LayerNorm): 236 | nn.init.constant_(m.bias, 0) 237 | nn.init.constant_(m.weight, 1.0) 238 | 239 | @torch.jit.ignore 240 | def no_weight_decay(self): 241 | return {'pos_embed', 'cls_token'} 242 | 243 | def get_classifier(self): 244 | return self.head 245 | 246 | def reset_classifier(self, num_classes, global_pool=''): 247 | self.num_classes = num_classes 248 | self.fc = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 249 | 250 | def forward_features(self, x, camera_id): 251 | B = x.shape[0] 252 | 253 | x = self.patch_embed(x) 254 | 255 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 256 | x = torch.cat((cls_tokens, x), dim=1) 257 | x = x + self.pos_embed + self.cam_lambda * self.Cam[camera_id] 258 | x = self.pos_drop(x) 259 | 260 | for blk in self.blocks[:-1]: 261 | x = blk(x) 262 | return x 263 | 264 | 265 | def forward(self, x, cam_label=None): 266 | x = self.forward_features(x, cam_label) 267 | return x 268 | 269 | def load_param(self, model_path,load=False): 270 | if not load: 271 | param_dict = torch.load(model_path, map_location='cpu') 272 | else: 273 | param_dict= model_path 274 | if 'model' in param_dict: 275 | param_dict = param_dict['model'] 276 | if 'state_dict' in param_dict: 277 | param_dict = param_dict['state_dict'] 278 | for k, v in 
param_dict.items(): 279 | if 'head' in k or 'dist' in k: 280 | continue 281 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 282 | # For old models that I trained prior to conv based patchification 283 | O, I, H, W = self.patch_embed.proj.weight.shape 284 | v = v.reshape(O, -1, H, W) 285 | elif k == 'pos_embed' and v.shape != self.pos_embed.shape: 286 | # To resize pos embedding when using model at different size from pretrained weights 287 | if 'distilled' in model_path: 288 | print('distill need to choose right cls token in the pth') 289 | v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1) 290 | v = resize_pos_embed(v, self.pos_embed, self.patch_embed.num_y, self.patch_embed.num_x) 291 | try: 292 | self.state_dict()[k].copy_(v) 293 | except: 294 | print('===========================ERROR=========================') 295 | print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape)) 296 | 297 | 298 | def resize_pos_embed(posemb, posemb_new, hight, width): 299 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 300 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 301 | ntok_new = posemb_new.shape[1] 302 | 303 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 304 | ntok_new -= 1 305 | 306 | gs_old = int(math.sqrt(len(posemb_grid))) 307 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width)) 308 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 309 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 310 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 311 | posemb = torch.cat([posemb_token, posemb_grid], dim=1) 312 | return posemb 313 | 314 | 315 | 316 | 317 | 318 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 319 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 320 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 321 | def norm_cdf(x): 322 | # Computes standard normal cumulative distribution function 323 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 324 | 325 | if (mean < a - 2 * std) or (mean > b + 2 * std): 326 | print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 327 | "The distribution of values may be incorrect.",) 328 | 329 | with torch.no_grad(): 330 | # Values are generated by using a truncated uniform distribution and 331 | # then using the inverse CDF for the normal distribution. 332 | # Get upper and lower cdf values 333 | l = norm_cdf((a - mean) / std) 334 | u = norm_cdf((b - mean) / std) 335 | 336 | # Uniformly fill tensor with values from [l, u], then translate to 337 | # [2l-1, 2u-1]. 338 | tensor.uniform_(2 * l - 1, 2 * u - 1) 339 | 340 | # Use inverse cdf transform for normal distribution to get truncated 341 | # standard normal 342 | tensor.erfinv_() 343 | 344 | # Transform to proper mean, std 345 | tensor.mul_(std * math.sqrt(2.)) 346 | tensor.add_(mean) 347 | 348 | # Clamp to ensure it's in the proper range 349 | tensor.clamp_(min=a, max=b) 350 | return tensor 351 | 352 | 353 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 354 | # type: (Tensor, float, float, float, float) -> Tensor 355 | r"""Fills the input Tensor with values drawn from a truncated 356 | normal distribution. 
The values are effectively drawn from the 357 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 358 | with values outside :math:`[a, b]` redrawn until they are within 359 | the bounds. The method used for generating the random values works 360 | best when :math:`a \leq \text{mean} \leq b`. 361 | Args: 362 | tensor: an n-dimensional `torch.Tensor` 363 | mean: the mean of the normal distribution 364 | std: the standard deviation of the normal distribution 365 | a: the minimum cutoff value 366 | b: the maximum cutoff value 367 | Examples: 368 | >>> w = torch.empty(3, 5) 369 | >>> nn.init.trunc_normal_(w) 370 | """ 371 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 372 | --------------------------------------------------------------------------------
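A minimal smoke-test sketch for the TransReID backbone above (illustrative only, not part of the repository). It assumes the 256x128 frames produced by Dataloader.py, a MARS-like setup with 6 cameras, and a torch version old enough to still provide torch._six, which the import at the top of vit_ID.py requires. The resulting [B, 129, 768] token sequence matches what the self.base(...) call in VID_Trans_model.py passes to its global and part branches (cf. the [64, 129, 768] comment there).

import torch
import torch.nn as nn
from functools import partial
from vit_ID import TransReID

# Build the ViT backbone at the 256x128 resolution used by the dataloader.
backbone = TransReID(img_size=[256, 128], patch_size=16, stride_size=16,
                     embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                     qkv_bias=True, camera=6,          # 6 camera embeddings (MARS-like)
                     norm_layer=partial(nn.LayerNorm, eps=1e-6))

frames = torch.randn(8, 3, 256, 128)     # e.g. 2 tracklets x 4 frames, already flattened to (B*T, C, H, W)
cam_label = torch.randint(0, 6, (8,))    # one camera id per frame
feats = backbone(frames, cam_label=cam_label)
print(feats.shape)                       # torch.Size([8, 129, 768]): cls token + 16*8 patch tokens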