├── requirements.txt
├── Datasets
│   ├── __pycache__
│   │   ├── iLDSVID.cpython-36.pyc
│   │   ├── MARS_dataset.cpython-36.pyc
│   │   └── PRID_dataset.cpython-36.pyc
│   ├── PRID_dataset.py
│   ├── iLDSVID.py
│   └── MARS_dataset.py
├── loss
│   ├── __pycache__
│   │   ├── center_loss.cpython-36.pyc
│   │   ├── softmax_loss.cpython-36.pyc
│   │   └── triplet_loss.cpython-36.pyc
│   ├── softmax_loss.py
│   ├── center_loss.py
│   └── triplet_loss.py
├── Loss_fun.py
├── README.md
├── VID_Test.py
├── VID_Trans_ReID.py
├── VID_Trans_model.py
├── Dataloader.py
├── utility.py
└── vit_ID.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | timm
4 | yacs
5 | opencv-python
--------------------------------------------------------------------------------
/Datasets/__pycache__/iLDSVID.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/Datasets/__pycache__/iLDSVID.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/center_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/loss/__pycache__/center_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/softmax_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/loss/__pycache__/softmax_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/triplet_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/loss/__pycache__/triplet_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/Datasets/__pycache__/MARS_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/Datasets/__pycache__/MARS_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/Datasets/__pycache__/PRID_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AishahAADU/VID-Trans-ReID/HEAD/Datasets/__pycache__/PRID_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/Loss_fun.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from loss.softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy
3 | from loss.triplet_loss import TripletLoss
4 | from loss.center_loss import CenterLoss
5 |
6 |
7 | def make_loss(num_classes):
8 |
9 |     feat_dim =768
10 |     center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
11 |     center_criterion2 = CenterLoss(num_classes=num_classes, feat_dim=3072, use_gpu=True)
12 |
13 |     triplet = TripletLoss()
14 |     xent = CrossEntropyLabelSmooth(num_classes=num_classes)
15 |
16 |
17 |     def loss_func(score, feat, target, target_cam):
18 |         if isinstance(score, list):
19 |             ID_LOSS = [xent(scor, target) for scor in score[1:]]
20 |             ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
21 |             ID_LOSS = 0.25 * ID_LOSS + 0.75 * xent(score[0], target)
22 |         else:
23 |             ID_LOSS = xent(score, target)
24 |
25 |         if isinstance(feat, list):
26 |             TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
27 |             TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
28 |             TRI_LOSS = 0.25 * TRI_LOSS + 0.75 * triplet(feat[0], target)[0]
29 |
30 |             center=center_criterion(feat[0], target)
31 |             centr2 = [center_criterion2(feats, target) for feats in feat[1:]]
32 |             centr2 = sum(centr2) / len(centr2)
33 |             center=0.25 *centr2 + 0.75 * center
34 |         else:
35 |             TRI_LOSS = triplet(feat, target)[0]
36 |             center = TRI_LOSS * 0  # keep `center` defined when a single feature tensor is passed
37 |
38 |         return ID_LOSS+ TRI_LOSS, center
39 |
40 |     return loss_func,center_criterion
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/loss/softmax_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | class CrossEntropyLabelSmooth(nn.Module):
5 |     """Cross entropy loss with label smoothing regularizer.
6 |
7 |     Reference:
8 |     Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
9 |     Equation: y = (1 - epsilon) * y + epsilon / K.
10 |
11 |     Args:
12 |         num_classes (int): number of classes.
13 |         epsilon (float): weight.
14 |     """
15 |
16 |     def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
17 |         super(CrossEntropyLabelSmooth, self).__init__()
18 |         self.num_classes = num_classes
19 |         self.epsilon = epsilon
20 |         self.use_gpu = use_gpu
21 |         self.logsoftmax = nn.LogSoftmax(dim=1)
22 |
23 |     def forward(self, inputs, targets):
24 |         """
25 |         Args:
26 |             inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
27 |             targets: ground truth labels with shape (num_classes)
28 |         """
29 |         log_probs = self.logsoftmax(inputs)
30 |         targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
31 |         if self.use_gpu: targets = targets.cuda()
32 |         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
33 |         loss = (- targets * log_probs).mean(0).sum()
34 |         return loss
35 |
36 | class LabelSmoothingCrossEntropy(nn.Module):
37 |     """
38 |     NLL loss with label smoothing.
39 |     """
40 |     def __init__(self, smoothing=0.1):
41 |         """
42 |         Constructor for the LabelSmoothing module.
43 |         :param smoothing: label smoothing factor
44 |         """
45 |         super(LabelSmoothingCrossEntropy, self).__init__()
46 |         assert smoothing < 1.0
47 |         self.smoothing = smoothing
48 |         self.confidence = 1. - smoothing
49 |
50 |     def forward(self, x, target):
51 |         logprobs = F.log_softmax(x, dim=-1)
52 |         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
53 |         nll_loss = nll_loss.squeeze(1)
54 |         smooth_loss = -logprobs.mean(dim=-1)
55 |         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
56 |         return loss.mean()
--------------------------------------------------------------------------------
/loss/center_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class CenterLoss(nn.Module):
8 |     """Center loss.
9 |
10 |     Reference:
11 |     Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
12 |
13 |     Args:
14 |         num_classes (int): number of classes.
15 |         feat_dim (int): feature dimension.
16 |     """
17 |
18 |     def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
19 |         super(CenterLoss, self).__init__()
20 |         self.num_classes = num_classes
21 |         self.feat_dim = feat_dim
22 |         self.use_gpu = use_gpu
23 |
24 |         if self.use_gpu:
25 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
26 |         else:
27 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
28 |
29 |     def forward(self, x, labels):
30 |         """
31 |         Args:
32 |             x: feature matrix with shape (batch_size, feat_dim).
33 |             labels: ground truth labels with shape (num_classes).
34 |         """
35 |         assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
36 |
37 |         batch_size = x.size(0)
38 |         distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
39 |                   torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
40 |         distmat.addmm_(1, -2, x, self.centers.t())
41 |
42 |         classes = torch.arange(self.num_classes).long()
43 |         if self.use_gpu: classes = classes.cuda()
44 |         labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
45 |         mask = labels.eq(classes.expand(batch_size, self.num_classes))
46 |
47 |         dist = []
48 |         for i in range(batch_size):
49 |             value = distmat[i][mask[i]]
50 |             value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
51 |             dist.append(value)
52 |         dist = torch.cat(dist)
53 |         loss = dist.mean()
54 |         return loss
55 |
56 |
57 | if __name__ == '__main__':
58 |     use_gpu = False
59 |     center_loss = CenterLoss(use_gpu=use_gpu)
60 |     features = torch.rand(16, 2048)
61 |     targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long()
62 |     if use_gpu:
63 |         features = torch.rand(16, 2048).cuda()
64 |         targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda()
65 |
66 |     loss = center_loss(features, targets)
67 |     print(loss)
68 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VID-Trans-ReID
2 | This is the official PyTorch implementation of our paper VID-Trans-ReID: Enhanced Video Transformers for Person Re-identification.
3 |
4 | Tested using [Python 3.7.x](https://www.python.org/downloads/release/python-370/) and PyTorch 1.8.0.
5 |
6 | ## Architecture:
7 |
8 | *(architecture overview figure)*
9 |
27 | ## Requirements
28 | ```
29 | pip install -r requirements.txt
30 | ```
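
The training and dataset code also imports `torch_ema` (`ExponentialMovingAverage` in `VID_Trans_ReID.py`) and `scipy` (`scipy.io.loadmat` in `Datasets/`), neither of which is listed in `requirements.txt`. If they are missing from your environment, they can be installed with (assuming the usual PyPI package names):
```
pip install torch-ema scipy
```
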
31 | ## Getting Started
32 |
33 | 1. Download the ImageNet-pretrained transformer model: [ViT_base](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth).
34 | 2. Download the video person Re-ID datasets [MARS](http://zheng-lab.cecs.anu.edu.au/Project/project_mars.html), [PRID](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/prid11/) and [iLIDS-VID](https://xiatian-zhu.github.io/downloads_qmul_iLIDS-VID_ReID_dataset.html); the expected folder layout is sketched below.
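
The dataset loaders in `Datasets/` look for the data relative to the working directory, using the folder names hard-coded in their `root` attributes. A sketch of the expected layout, inferred from `MARS_dataset.py`, `PRID_dataset.py` and `iLDSVID.py`:
```
MARS/
    info/                  # train_name.txt, test_name.txt, tracks_*_info.mat, query_IDX.mat
    bbox_train/
    bbox_test/
prid_2011/
    splits_prid2011.json
    multi_shot/cam_a/
    multi_shot/cam_b/
iLIDS-VID/                 # downloaded and extracted automatically by iLDSVID.py if absent
    i-LIDS-VID/sequences/cam1/
    i-LIDS-VID/sequences/cam2/
    train-test people splits/
```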
35 |
36 | ## Train and Evaluate
37 |
38 | Use the pre-trained [ViT_base](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth) weights to initialize the ViT backbone, then train the whole model.
39 |
40 | MARS Dataset
41 | ```
42 | python -u VID_Trans_ReID.py --Dataset_name 'Mars' --ViT_path 'jx_vit_base_p16_224-80ecf9dd.pth'
43 | ```
44 |
45 | PRID Dataset
46 | ```
47 | python -u VID_Trans_ReID.py --Dataset_name 'PRID' --ViT_path 'jx_vit_base_p16_224-80ecf9dd.pth'
48 | ```
49 |
50 | iLIDS-VID Dataset
51 | ```
52 | python -u VID_Trans_ReID.py --Dataset_name 'iLIDSVID' --ViT_path 'jx_vit_base_p16_224-80ecf9dd.pth'
53 | ```
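
Every 10 epochs the script evaluates on the query/gallery split and, whenever rank-1 improves, saves the model weights to a hard-coded location at the end of `VID_Trans_ReID.py`; adjust that path to your setup if needed:
```
/VID-Trans-ReID/<Dataset_name>Main_Model.pth
```
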
54 |
55 | ## Test
56 | To test the model, you can use our pretrained model on the MARS dataset ([download](https://durhamuniversity-my.sharepoint.com/:u:/g/personal/zwjx97_durham_ac_uk/Ec09LVNFG_JKotjNPkVgTaIB7k0eUAwmPq9gawciw2ggBQ?e=swd9DK)).
57 |
58 | ```
59 | python -u VID_Test.py --Dataset_name 'Mars' --model_path 'MarsMain_Model.pth'
60 | ```
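
The same evaluation can also be run from Python; a minimal sketch that mirrors the `__main__` block of `VID_Test.py` (assuming the MARS data and the downloaded checkpoint sit in the working directory):
```
import torch
from Dataloader import dataloader
from VID_Trans_model import VID_Trans
from VID_Test import test

# Build the MARS loaders; dataloader() also returns the class and camera counts the model needs.
train_loader, num_query, num_classes, camera_num, view_num, q_val_set, g_val_set = dataloader('Mars')

# pretrainpath=None: the ImageNet ViT weights are not needed here, the full checkpoint is loaded below.
model = VID_Trans(num_classes=num_classes, camera_num=camera_num, pretrainpath=None).to('cuda')
model.load_state_dict(torch.load('MarsMain_Model.pth'))
model.eval()

cmc_rank1, mAP = test(model, q_val_set, g_val_set)
print('CMC rank-1: %.4f, mAP: %.4f' % (cmc_rank1, mAP))
```
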
61 |
62 | ## Acknowledgement
63 | Thanks to Hao Luo; parts of the implementation are adapted from his [repository](https://github.com/michuanhaohao).
64 |
65 | ## Citation
66 |
67 | If you are making use of this work in any way, please reference the following paper in any report, publication, presentation, software release or any other associated materials:
68 |
69 | [VID-Trans-ReID: Enhanced Video Transformers for Person Re-identification](https://breckon.org/toby/publications/papers/alsehaim22vidtransreid.pdf) (A. Alsehaim, T.P. Breckon), In Proc. British Machine Vision Conference, BMVA, 2022.
70 |
71 | ```
72 | @inproceedings{alsehaim22vidtransreid,
73 | author = {Alsehaim, A. and Breckon, T.P.},
74 | title = {VID-Trans-ReID: Enhanced Video Transformers for Person Re-identification},
75 | booktitle = {Proc. British Machine Vision Conference},
76 | year = {2022},
77 | month = {November},
78 | publisher = {BMVA},
79 | url = {https://breckon.org/toby/publications/papers/alsehaim22vidtransreid.pdf}
80 | }
81 | ```
82 |
--------------------------------------------------------------------------------
/loss/triplet_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def normalize(x, axis=-1):
6 | """Normalizing to unit length along the specified dimension.
7 | Args:
8 | x: pytorch Variable
9 | Returns:
10 | x: pytorch Variable, same shape as input
11 | """
12 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
13 | return x
14 |
15 |
16 | def euclidean_dist(x, y):
17 | """
18 | Args:
19 | x: pytorch Variable, with shape [m, d]
20 | y: pytorch Variable, with shape [n, d]
21 | Returns:
22 | dist: pytorch Variable, with shape [m, n]
23 | """
24 | m, n = x.size(0), y.size(0)
25 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
26 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
27 | dist = xx + yy
28 | dist = dist - 2 * torch.matmul(x, y.t())
29 | # dist.addmm_(1, -2, x, y.t())
30 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
31 | return dist
32 |
33 |
34 | def cosine_dist(x, y):
35 | """
36 | Args:
37 | x: pytorch Variable, with shape [m, d]
38 | y: pytorch Variable, with shape [n, d]
39 | Returns:
40 | dist: pytorch Variable, with shape [m, n]
41 | """
42 | m, n = x.size(0), y.size(0)
43 | x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n)
44 | y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t()
45 | xy_intersection = torch.mm(x, y.t())
46 | dist = xy_intersection/(x_norm * y_norm)
47 | dist = (1. - dist) / 2
48 | return dist
49 |
50 |
51 | def hard_example_mining(dist_mat, labels, return_inds=False):
52 | """For each anchor, find the hardest positive and negative sample.
53 | Args:
54 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
55 | labels: pytorch LongTensor, with shape [N]
56 | return_inds: whether to return the indices. Save time if `False`(?)
57 | Returns:
58 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
59 | dist_an: pytorch Variable, distance(anchor, negative); shape [N]
60 | p_inds: pytorch LongTensor, with shape [N];
61 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
62 | n_inds: pytorch LongTensor, with shape [N];
63 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
64 |     NOTE: Only consider the case in which all labels have the same number of samples,
65 | thus we can cope with all anchors in parallel.
66 | """
67 |
68 | assert len(dist_mat.size()) == 2
69 | assert dist_mat.size(0) == dist_mat.size(1)
70 | N = dist_mat.size(0)
71 |
72 | # shape [N, N]
73 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
74 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
75 |
76 | # `dist_ap` means distance(anchor, positive)
77 | # both `dist_ap` and `relative_p_inds` with shape [N, 1]
78 | dist_ap, relative_p_inds = torch.max(
79 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
80 | # print(dist_mat[is_pos].shape)
81 | # `dist_an` means distance(anchor, negative)
82 | # both `dist_an` and `relative_n_inds` with shape [N, 1]
83 | dist_an, relative_n_inds = torch.min(
84 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
85 | # shape [N]
86 | dist_ap = dist_ap.squeeze(1)
87 | dist_an = dist_an.squeeze(1)
88 |
89 | if return_inds:
90 | # shape [N, N]
91 | ind = (labels.new().resize_as_(labels)
92 | .copy_(torch.arange(0, N).long())
93 | .unsqueeze(0).expand(N, N))
94 | # shape [N, 1]
95 | p_inds = torch.gather(
96 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
97 | n_inds = torch.gather(
98 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
99 | # shape [N]
100 | p_inds = p_inds.squeeze(1)
101 | n_inds = n_inds.squeeze(1)
102 | return dist_ap, dist_an, p_inds, n_inds
103 |
104 | return dist_ap, dist_an
105 |
106 |
107 | class TripletLoss(object):
108 | """
109 | Triplet loss using HARDER example mining,
110 | modified based on original triplet loss using hard example mining
111 | """
112 |
113 | def __init__(self, margin=None, hard_factor=0.0):
114 | self.margin = margin
115 | self.hard_factor = hard_factor
116 | if margin is not None:
117 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
118 | else:
119 | self.ranking_loss = nn.SoftMarginLoss()
120 |
121 | def __call__(self, global_feat, labels, normalize_feature=False):
122 |
123 | if normalize_feature:
124 | global_feat = normalize(global_feat, axis=-1)
125 | dist_mat = euclidean_dist(global_feat, global_feat)
126 | dist_ap, dist_an = hard_example_mining(dist_mat, labels)
127 |
128 | dist_ap *= (1.0 + self.hard_factor)
129 | dist_an *= (1.0 - self.hard_factor)
130 |
131 | y = dist_an.new().resize_as_(dist_an).fill_(1)
132 | if self.margin is not None:
133 | loss = self.ranking_loss(dist_an, dist_ap, y)
134 | else:
135 | loss = self.ranking_loss(dist_an - dist_ap, y)
136 | return loss, dist_ap, dist_an
137 |
138 |
139 |
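# Minimal usage sketch (an assumption for illustration, mirroring how Loss_fun.py builds the criterion):
# TripletLoss() with margin=None uses the soft-margin formulation via nn.SoftMarginLoss.
if __name__ == '__main__':
    feats = torch.rand(16, 768)               # [N, d] features
    labels = torch.tensor([0, 1, 2, 3] * 4)   # 4 identities x 4 samples each
    loss, dist_ap, dist_an = TripletLoss()(feats, labels)
    print(loss.item(), dist_ap.shape, dist_an.shape)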
--------------------------------------------------------------------------------
/Datasets/PRID_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function, absolute_import
3 |
4 | from collections import defaultdict
5 | from scipy.io import loadmat
6 | import os.path as osp
7 | import numpy as np
8 | import json
9 | import glob
10 | def read_json(fpath):
11 | with open(fpath, 'r') as f:
12 | obj = json.load(f)
13 | return obj
14 |
15 | class PRID(object):
16 | """
17 | PRID
18 | Reference:
19 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011.
20 |
21 | Dataset statistics:
22 | # identities: 200
23 | # tracklets: 400
24 | # cameras: 2
25 | Args:
26 | split_id (int): indicates which split to use. There are totally 10 splits.
27 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0).
28 | """
29 | root = "prid_2011"
30 |
31 | # root = './data/prid2011'
32 | dataset_url = 'https://files.icg.tugraz.at/f/6ab7e8ce8f/?raw=1'
33 | split_path = osp.join(root, 'splits_prid2011.json')
34 | cam_a_path = osp.join(root, 'multi_shot', 'cam_a')
35 | cam_b_path = osp.join(root, 'multi_shot', 'cam_b')
36 |
37 | def __init__(self, split_id=0, min_seq_len=0):
38 | self._check_before_run()
39 | splits = read_json(self.split_path)
40 | if split_id >= len(splits):
41 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1))
42 | split = splits[split_id]
43 | train_dirs, test_dirs = split['train'], split['test']
44 |         print("# train identities: {}, # test identities: {}".format(len(train_dirs), len(test_dirs)))
45 |
46 | train, num_train_tracklets, num_train_pids, num_imgs_train = \
47 | self._process_data(train_dirs, cam1=True, cam2=True)
48 | query, num_query_tracklets, num_query_pids, num_imgs_query = \
49 | self._process_data(test_dirs, cam1=True, cam2=False)
50 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \
51 | self._process_data(test_dirs, cam1=False, cam2=True)
52 |
53 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery
54 | min_num = np.min(num_imgs_per_tracklet)
55 | max_num = np.max(num_imgs_per_tracklet)
56 | avg_num = np.mean(num_imgs_per_tracklet)
57 |
58 | num_total_pids = num_train_pids + num_query_pids
59 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets
60 |
61 | print("=> PRID-2011 loaded")
62 | print("Dataset statistics:")
63 | print(" ------------------------------")
64 | print(" subset | # ids | # tracklets")
65 | print(" ------------------------------")
66 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
67 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
68 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
69 | print(" ------------------------------")
70 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
71 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
72 | print(" ------------------------------")
73 |
74 | self.train = train
75 | self.query = query
76 | self.gallery = gallery
77 |
78 | self.num_train_pids = num_train_pids
79 | self.num_query_pids = num_query_pids
80 | self.num_gallery_pids = num_gallery_pids
81 | self.num_train_cams=2
82 | self.num_query_cams=2
83 | self.num_gallery_cams=2
84 | self.num_train_vids=num_train_tracklets
85 | self.num_query_vids=num_query_tracklets
86 | self.num_gallery_vids=num_gallery_tracklets
87 |
88 | def _check_before_run(self):
89 | """Check if all files are available before going deeper"""
90 | if not osp.exists(self.root):
91 | raise RuntimeError("'{}' is not available".format(self.root))
92 |
93 | def _process_data(self, dirnames, cam1=True, cam2=True):
94 | tracklets = []
95 | num_imgs_per_tracklet = []
96 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)}
97 |
98 | for dirname in dirnames:
99 | if cam1:
100 | person_dir = osp.join(self.cam_a_path, dirname)
101 | img_names = glob.glob(osp.join(person_dir, '*.png'))
102 | assert len(img_names) > 0
103 | img_names = tuple(img_names)
104 | pid = dirname2pid[dirname]
105 | tracklets.append((img_names, pid, 0))
106 | num_imgs_per_tracklet.append(len(img_names))
107 |
108 | if cam2:
109 | person_dir = osp.join(self.cam_b_path, dirname)
110 | img_names = glob.glob(osp.join(person_dir, '*.png'))
111 | assert len(img_names) > 0
112 | img_names = tuple(img_names)
113 | pid = dirname2pid[dirname]
114 | tracklets.append((img_names, pid, 1))
115 | num_imgs_per_tracklet.append(len(img_names))
116 |
117 | num_tracklets = len(tracklets)
118 | num_pids = len(dirnames)
119 |
120 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet
121 |
--------------------------------------------------------------------------------
/VID_Test.py:
--------------------------------------------------------------------------------
1 | from Dataloader import dataloader
2 | from VID_Trans_model import VID_Trans
3 |
4 |
5 | from Loss_fun import make_loss
6 | import random
7 | import torch
8 | import numpy as np
9 | import os
10 | import argparse
11 |
12 | import logging
13 | import os
14 | import time
15 | import torch
16 | import torch.nn as nn
17 |
18 | from torch.cuda import amp
19 | from utility import AverageMeter, optimizer,scheduler
20 |
21 | from torch.autograd import Variable
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=21):
31 | num_q, num_g = distmat.shape
32 | if num_g < max_rank:
33 | max_rank = num_g
34 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
35 | indices = np.argsort(distmat, axis=1)
36 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
37 |
38 | # compute cmc curve for each query
39 | all_cmc = []
40 | all_AP = []
41 | num_valid_q = 0.
42 | for q_idx in range(num_q):
43 | # get query pid and camid
44 | q_pid = q_pids[q_idx]
45 | q_camid = q_camids[q_idx]
46 |
47 | # remove gallery samples that have the same pid and camid with query
48 | order = indices[q_idx]
49 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
50 | keep = np.invert(remove)
51 |
52 | # compute cmc curve
53 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
54 | if not np.any(orig_cmc):
55 | # this condition is true when query identity does not appear in gallery
56 | continue
57 |
58 | cmc = orig_cmc.cumsum()
59 | cmc[cmc > 1] = 1
60 |
61 | all_cmc.append(cmc[:max_rank])
62 | num_valid_q += 1.
63 |
64 | # compute average precision
65 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
66 | num_rel = orig_cmc.sum()
67 | tmp_cmc = orig_cmc.cumsum()
68 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
69 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
70 | AP = tmp_cmc.sum() / num_rel
71 | all_AP.append(AP)
72 |
73 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
74 |
75 | all_cmc = np.asarray(all_cmc).astype(np.float32)
76 | all_cmc = all_cmc.sum(0) / num_valid_q
77 | mAP = np.mean(all_AP)
78 |
79 | return all_cmc, mAP
80 |
81 | def test(model, queryloader, galleryloader, pool='avg', use_gpu=True, ranks=[1, 5, 10, 20]):
82 | model.eval()
83 | qf, q_pids, q_camids = [], [], []
84 | with torch.no_grad():
85 | for batch_idx, (imgs, pids, camids,_) in enumerate(queryloader):
86 |
87 | if use_gpu:
88 | imgs = imgs.cuda()
89 | imgs = Variable(imgs, volatile=True)
90 |
91 | b, s, c, h, w = imgs.size()
92 |
93 |
94 | features = model(imgs,pids,cam_label=camids )
95 |
96 | features = features.view(b, -1)
97 | features = torch.mean(features, 0)
98 | features = features.data.cpu()
99 | qf.append(features)
100 |
101 | q_pids.append(pids)
102 | q_camids.extend(camids)
103 | qf = torch.stack(qf)
104 | q_pids = np.asarray(q_pids)
105 | q_camids = np.asarray(q_camids)
106 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
107 | gf, g_pids, g_camids = [], [], []
108 | for batch_idx, (imgs, pids, camids,_) in enumerate(galleryloader):
109 | if use_gpu:
110 | imgs = imgs.cuda()
111 | imgs = Variable(imgs, volatile=True)
112 | b, s,c, h, w = imgs.size()
113 | features = model(imgs,pids,cam_label=camids)
114 | features = features.view(b, -1)
115 | if pool == 'avg':
116 | features = torch.mean(features, 0)
117 | else:
118 | features, _ = torch.max(features, 0)
119 | features = features.data.cpu()
120 | gf.append(features)
121 | g_pids.append(pids)
122 | g_camids.extend(camids)
123 | gf = torch.stack(gf)
124 | g_pids = np.asarray(g_pids)
125 | g_camids = np.asarray(g_camids)
126 | print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
127 | print("Computing distance matrix")
128 | m, n = qf.size(0), gf.size(0)
129 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
130 | distmat.addmm_(1, -2, qf, gf.t())
131 | distmat = distmat.numpy()
132 | gf = gf.numpy()
133 | qf = qf.numpy()
134 |
135 | print("Original Computing CMC and mAP")
136 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
137 |
138 | # print("Results ---------- {:.1%} ".format(distmat_rerank))
139 | print("Results ---------- ")
140 |
141 | print("mAP: {:.1%} ".format(mAP))
142 | print("CMC curve r1:",cmc[0])
143 |
144 | return cmc[0], mAP
145 |
146 |
147 |
148 | if __name__ == '__main__':
149 | parser = argparse.ArgumentParser(description="VID-Trans-ReID")
150 | parser.add_argument(
151 | "--Dataset_name", default="", help="The name of the DataSet", type=str)
152 | parser.add_argument(
153 | "--model_path", default="", help="pretrained model", type=str)
154 | args = parser.parse_args()
155 | Dataset_name=args.Dataset_name
156 | pretrainpath=args.model_path
157 |
158 |
159 |
160 | train_loader, num_query, num_classes, camera_num, view_num,q_val_set,g_val_set = dataloader(Dataset_name)
161 | model = VID_Trans( num_classes=num_classes, camera_num=camera_num,pretrainpath=None)
162 |
163 | device = "cuda"
164 | model=model.to(device)
165 |
166 | checkpoint = torch.load(pretrainpath)
167 | model.load_state_dict(checkpoint)
168 |
169 |
170 | model.eval()
171 | cmc,map = test(model, q_val_set,g_val_set)
172 | print('CMC: %.4f, mAP : %.4f'%(cmc,map))
173 |
174 |
--------------------------------------------------------------------------------
/Datasets/iLDSVID.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function, absolute_import
3 |
4 | from collections import defaultdict
5 | from scipy.io import loadmat
6 | import os.path as osp
7 | import numpy as np
8 | import json
9 | import errno
10 | import os
11 | import tarfile
12 | import glob
13 | def read_json(fpath):
14 | with open(fpath, 'r') as f:
15 | obj = json.load(f)
16 | return obj
17 | import urllib.request
18 |
19 |
20 | def write_json(obj, fpath):
21 | mkdir_if_missing(osp.dirname(fpath))
22 | with open(fpath, 'w') as f:
23 | json.dump(obj, f, indent=4, separators=(',', ': '))
24 |
25 | def mkdir_if_missing(directory):
26 | if not osp.exists(directory):
27 | try:
28 | os.makedirs(directory)
29 | except OSError as e:
30 | if e.errno != errno.EEXIST:
31 | raise
32 |
33 |
34 | class iLIDSVID(object):
35 | """
36 | iLIDS-VID
37 | Reference:
38 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014.
39 |
40 | Dataset statistics:
41 | # identities: 300
42 | # tracklets: 600
43 | # cameras: 2
44 | Args:
45 | split_id (int): indicates which split to use. There are totally 10 splits.
46 | """
47 | root ="iLIDS-VID"
48 | # root = '/mnt/scratch/1/pathak/data/iLIDS'
49 | # root = './data/ilids-vid'
50 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar'
51 | data_dir = osp.join(root, 'i-LIDS-VID')
52 | split_dir = osp.join(root, 'train-test people splits')
53 | split_mat_path = osp.join(split_dir, 'train_test_splits_ilidsvid.mat')
54 | split_path = osp.join(root, 'splits.json')
55 | cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1')
56 | cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2')
57 |
58 | def __init__(self, split_id=0):
59 | self._download_data()
60 | self._check_before_run()
61 |
62 | self._prepare_split()
63 | splits = read_json(self.split_path)
64 | if split_id >= len(splits):
65 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1))
66 | split = splits[split_id]
67 | train_dirs, test_dirs = split['train'], split['test']
68 |         print("# train identities: {}, # test identities: {}".format(len(train_dirs), len(test_dirs)))
69 |
70 | train, num_train_tracklets, num_train_pids, num_imgs_train = \
71 | self._process_data(train_dirs, cam1=True, cam2=True)
72 | query, num_query_tracklets, num_query_pids, num_imgs_query = \
73 | self._process_data(test_dirs, cam1=True, cam2=False)
74 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \
75 | self._process_data(test_dirs, cam1=False, cam2=True)
76 |
77 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery
78 | min_num = np.min(num_imgs_per_tracklet)
79 | max_num = np.max(num_imgs_per_tracklet)
80 | avg_num = np.mean(num_imgs_per_tracklet)
81 |
82 | num_total_pids = num_train_pids + num_query_pids
83 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets
84 |
85 | print("=> iLIDS-VID loaded")
86 | print("Dataset statistics:")
87 | print(" ------------------------------")
88 | print(" subset | # ids | # tracklets")
89 | print(" ------------------------------")
90 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
91 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
92 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
93 | print(" ------------------------------")
94 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
95 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
96 | print(" ------------------------------")
97 |
98 |
99 | self.train = train
100 | self.query = query
101 | self.gallery = gallery
102 |
103 | self.num_train_pids = num_train_pids
104 | self.num_query_pids = num_query_pids
105 | self.num_gallery_pids = num_gallery_pids
106 | self.num_train_cams=2
107 | self.num_query_cams=2
108 | self.num_gallery_cams=2
109 | self.num_train_vids=num_train_tracklets
110 | self.num_query_vids=num_query_tracklets
111 | self.num_gallery_vids=num_gallery_tracklets
112 |
113 | def _download_data(self):
114 | if osp.exists(self.root):
115 | print("This dataset has been downloaded.")
116 | return
117 |
118 | mkdir_if_missing(self.root)
119 | fpath = osp.join(self.root, osp.basename(self.dataset_url))
120 |
121 | print("Downloading iLIDS-VID dataset")
122 | url_opener = urllib.request
123 | url_opener.urlretrieve(self.dataset_url, fpath)
124 |
125 | print("Extracting files")
126 | tar = tarfile.open(fpath)
127 | tar.extractall(path=self.root)
128 | tar.close()
129 |
130 | def _check_before_run(self):
131 | """Check if all files are available before going deeper"""
132 | if not osp.exists(self.root):
133 | raise RuntimeError("'{}' is not available".format(self.root))
134 | if not osp.exists(self.data_dir):
135 | raise RuntimeError("'{}' is not available".format(self.data_dir))
136 | if not osp.exists(self.split_dir):
137 | raise RuntimeError("'{}' is not available".format(self.split_dir))
138 |
139 | def _prepare_split(self):
140 | if not osp.exists(self.split_path):
141 | print("Creating splits")
142 | mat_split_data = loadmat(self.split_mat_path)['ls_set']
143 |
144 | num_splits = mat_split_data.shape[0]
145 | num_total_ids = mat_split_data.shape[1]
146 | assert num_splits == 10
147 | assert num_total_ids == 300
148 | num_ids_each = int(num_total_ids/2)
149 |
150 | # pids in mat_split_data are indices, so we need to transform them
151 | # to real pids
152 | person_cam1_dirs = os.listdir(self.cam_1_path)
153 | person_cam2_dirs = os.listdir(self.cam_2_path)
154 |
155 | # make sure persons in one camera view can be found in the other camera view
156 | assert set(person_cam1_dirs) == set(person_cam2_dirs)
157 |
158 | splits = []
159 | for i_split in range(num_splits):
160 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14.
161 | train_idxs = sorted(list(mat_split_data[i_split,num_ids_each:]))
162 | test_idxs = sorted(list(mat_split_data[i_split,:num_ids_each]))
163 |
164 | train_idxs = [int(i)-1 for i in train_idxs]
165 | test_idxs = [int(i)-1 for i in test_idxs]
166 |
167 | # transform pids to person dir names
168 | train_dirs = [person_cam1_dirs[i] for i in train_idxs]
169 | test_dirs = [person_cam1_dirs[i] for i in test_idxs]
170 |
171 | split = {'train': train_dirs, 'test': test_dirs}
172 | splits.append(split)
173 |
174 | print("Totally {} splits are created, following Wang et al. ECCV'14".format(len(splits)))
175 | print("Split file is saved to {}".format(self.split_path))
176 | write_json(splits, self.split_path)
177 |
178 | print("Splits created")
179 |
180 | def _process_data(self, dirnames, cam1=True, cam2=True):
181 | tracklets = []
182 | num_imgs_per_tracklet = []
183 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)}
184 |
185 | for dirname in dirnames:
186 | if cam1:
187 | person_dir = osp.join(self.cam_1_path, dirname)
188 | img_names = glob.glob(osp.join(person_dir, '*.png'))
189 | assert len(img_names) > 0
190 | img_names = tuple(img_names)
191 | pid = dirname2pid[dirname]
192 | tracklets.append((img_names, pid, 0))
193 | num_imgs_per_tracklet.append(len(img_names))
194 |
195 | if cam2:
196 | person_dir = osp.join(self.cam_2_path, dirname)
197 | img_names = glob.glob(osp.join(person_dir, '*.png'))
198 | assert len(img_names) > 0
199 | img_names = tuple(img_names)
200 | pid = dirname2pid[dirname]
201 | tracklets.append((img_names, pid, 1))
202 | num_imgs_per_tracklet.append(len(img_names))
203 |
204 | num_tracklets = len(tracklets)
205 | num_pids = len(dirnames)
206 |
207 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet
208 |
--------------------------------------------------------------------------------
/VID_Trans_ReID.py:
--------------------------------------------------------------------------------
1 | from Dataloader import dataloader
2 | from VID_Trans_model import VID_Trans
3 |
4 |
5 | from Loss_fun import make_loss
6 |
7 | import random
8 | import torch
9 | import numpy as np
10 | import os
11 | import argparse
12 |
13 | import logging
14 | import os
15 | import time
16 | import torch
17 | import torch.nn as nn
18 | from torch_ema import ExponentialMovingAverage
19 | from torch.cuda import amp
20 | import torch.distributed as dist
21 |
22 | from utility import AverageMeter, optimizer,scheduler
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | from torch.autograd import Variable
31 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=21):
32 | num_q, num_g = distmat.shape
33 | if num_g < max_rank:
34 | max_rank = num_g
35 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
36 | indices = np.argsort(distmat, axis=1)
37 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
38 |
39 | # compute cmc curve for each query
40 | all_cmc = []
41 | all_AP = []
42 | num_valid_q = 0.
43 | for q_idx in range(num_q):
44 | # get query pid and camid
45 | q_pid = q_pids[q_idx]
46 | q_camid = q_camids[q_idx]
47 |
48 | # remove gallery samples that have the same pid and camid with query
49 | order = indices[q_idx]
50 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
51 | keep = np.invert(remove)
52 |
53 | # compute cmc curve
54 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
55 | if not np.any(orig_cmc):
56 | # this condition is true when query identity does not appear in gallery
57 | continue
58 |
59 | cmc = orig_cmc.cumsum()
60 | cmc[cmc > 1] = 1
61 |
62 | all_cmc.append(cmc[:max_rank])
63 | num_valid_q += 1.
64 |
65 | # compute average precision
66 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
67 | num_rel = orig_cmc.sum()
68 | tmp_cmc = orig_cmc.cumsum()
69 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
70 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
71 | AP = tmp_cmc.sum() / num_rel
72 | all_AP.append(AP)
73 |
74 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
75 |
76 | all_cmc = np.asarray(all_cmc).astype(np.float32)
77 | all_cmc = all_cmc.sum(0) / num_valid_q
78 | mAP = np.mean(all_AP)
79 |
80 | return all_cmc, mAP
81 |
82 | def test(model, queryloader, galleryloader, pool='avg', use_gpu=True, ranks=[1, 5, 10, 20]):
83 | model.eval()
84 | qf, q_pids, q_camids = [], [], []
85 | with torch.no_grad():
86 | for batch_idx, (imgs, pids, camids,_) in enumerate(queryloader):
87 |
88 | if use_gpu:
89 | imgs = imgs.cuda()
90 | imgs = Variable(imgs, volatile=True)
91 |
92 | b, s, c, h, w = imgs.size()
93 |
94 |
95 | features = model(imgs,pids,cam_label=camids )
96 |
97 | features = features.view(b, -1)
98 | features = torch.mean(features, 0)
99 | features = features.data.cpu()
100 | qf.append(features)
101 |
102 | q_pids.append(pids)
103 | q_camids.extend(camids)
104 | qf = torch.stack(qf)
105 | q_pids = np.asarray(q_pids)
106 | q_camids = np.asarray(q_camids)
107 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
108 | gf, g_pids, g_camids = [], [], []
109 | for batch_idx, (imgs, pids, camids,_) in enumerate(galleryloader):
110 | if use_gpu:
111 | imgs = imgs.cuda()
112 | imgs = Variable(imgs, volatile=True)
113 | b, s,c, h, w = imgs.size()
114 | features = model(imgs,pids,cam_label=camids)
115 | features = features.view(b, -1)
116 | if pool == 'avg':
117 | features = torch.mean(features, 0)
118 | else:
119 | features, _ = torch.max(features, 0)
120 | features = features.data.cpu()
121 | gf.append(features)
122 | g_pids.append(pids)
123 | g_camids.extend(camids)
124 | gf = torch.stack(gf)
125 | g_pids = np.asarray(g_pids)
126 | g_camids = np.asarray(g_camids)
127 | print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
128 | print("Computing distance matrix")
129 | m, n = qf.size(0), gf.size(0)
130 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
131 | distmat.addmm_(1, -2, qf, gf.t())
132 | distmat = distmat.numpy()
133 | gf = gf.numpy()
134 | qf = qf.numpy()
135 |
136 | print("Original Computing CMC and mAP")
137 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
138 |
139 | # print("Results ---------- {:.1%} ".format(distmat_rerank))
140 | print("Results ---------- ")
141 |
142 | print("mAP: {:.1%} ".format(mAP))
143 | print("CMC curve r1:",cmc[0])
144 |
145 | return cmc[0], mAP
146 |
147 |
148 |
149 | if __name__ == '__main__':
150 | parser = argparse.ArgumentParser(description="VID-Trans-ReID")
151 |     parser.add_argument("--Dataset_name", default="", help="The name of the DataSet", type=str)
152 |     parser.add_argument("--ViT_path", default="", help="Path to the ImageNet-pretrained ViT weights", type=str)
153 |     args = parser.parse_args()
154 |     Dataset_name, pretrainpath = args.Dataset_name, args.ViT_path
155 | torch.manual_seed(1234)
156 | torch.cuda.manual_seed(1234)
157 | torch.cuda.manual_seed_all(1234)
158 | np.random.seed(1234)
159 | random.seed(1234)
160 | torch.backends.cudnn.deterministic = True
161 | torch.backends.cudnn.benchmark = True
162 | train_loader, num_query, num_classes, camera_num, view_num,q_val_set,g_val_set = dataloader(Dataset_name)
163 | model = VID_Trans( num_classes=num_classes, camera_num=camera_num,pretrainpath=pretrainpath)
164 |
165 | loss_fun,center_criterion= make_loss( num_classes=num_classes)
166 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr= 0.5)
167 |
168 | optimizer= optimizer( model)
169 | scheduler = scheduler(optimizer)
170 | scaler = amp.GradScaler()
171 |
172 | #Train
173 | device = "cuda"
174 | epochs = 120
175 | model=model.to(device)
176 | ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
177 | loss_meter = AverageMeter()
178 | acc_meter = AverageMeter()
179 |
180 | cmc_rank1=0
181 | for epoch in range(1, epochs + 1):
182 | start_time = time.time()
183 | loss_meter.reset()
184 | acc_meter.reset()
185 |
186 | scheduler.step(epoch)
187 | model.train()
188 |
189 | for Epoch_n, (img, pid, target_cam,labels2) in enumerate(train_loader):
190 |
191 | optimizer.zero_grad()
192 | optimizer_center.zero_grad()
193 |
194 | img = img.to(device)
195 | pid = pid.to(device)
196 | target_cam = target_cam.to(device)
197 |
198 | labels2=labels2.to(device)
199 | with amp.autocast(enabled=True):
200 | target_cam=target_cam.view(-1)
201 | score, feat ,a_vals= model(img, pid, cam_label=target_cam)
202 |
203 | labels2=labels2.to(device)
204 | attn_noise = a_vals * labels2
205 | attn_loss = attn_noise.sum(1).mean()
206 |
207 | loss_id ,center= loss_fun(score, feat, pid, target_cam)
208 | loss=loss_id+ 0.0005*center +attn_loss
209 | scaler.scale(loss).backward()
210 |
211 | scaler.step(optimizer)
212 | scaler.update()
213 | ema.update()
214 |
215 | for param in center_criterion.parameters():
216 | param.grad.data *= (1. / 0.0005)
217 | scaler.step(optimizer_center)
218 | scaler.update()
219 | if isinstance(score, list):
220 | acc = (score[0].max(1)[1] == pid).float().mean()
221 | else:
222 | acc = (score.max(1)[1] == pid).float().mean()
223 |
224 | loss_meter.update(loss.item(), img.shape[0])
225 | acc_meter.update(acc, 1)
226 |
227 | torch.cuda.synchronize()
228 | if (Epoch_n + 1) % 50 == 0:
229 | print("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
230 | .format(epoch, (Epoch_n + 1), len(train_loader),
231 | loss_meter.avg, acc_meter.avg, scheduler._get_lr(epoch)[0]))
232 |
233 | if (epoch+1)%10 == 0 :
234 |
235 | model.eval()
236 | cmc,map = test(model, q_val_set,g_val_set)
237 | print('CMC: %.4f, mAP : %.4f'%(cmc,map))
238 | if cmc_rank1 < cmc:
239 | cmc_rank1=cmc
240 | torch.save(model.state_dict(),os.path.join('/VID-Trans-ReID', Dataset_name+'Main_Model.pth'))
241 |
242 |
243 |
--------------------------------------------------------------------------------
/Datasets/MARS_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function, absolute_import
3 |
4 | from collections import defaultdict
5 | from scipy.io import loadmat
6 | import os.path as osp
7 | import numpy as np
8 |
9 |
10 | class Mars(object):
11 | """
12 | MARS
13 | Reference:
14 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016.
15 |
16 | Dataset statistics:
17 | # identities: 1261
18 | # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery)
19 | # cameras: 6
20 | Args:
21 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0).
22 | """
23 |
24 | root ='MARS' #'/home2/zwjx97/STE-NVAN-master/MARS' #"/home/aishahalsehaim/Desktop/STE-NVAN-master/MARS"
25 |
26 | train_name_path = osp.join(root, 'info/train_name.txt')
27 | test_name_path = osp.join(root, 'info/test_name.txt')
28 | track_train_info_path = osp.join(root, 'info/tracks_train_info.mat')
29 | track_test_info_path = osp.join(root, 'info/tracks_test_info.mat')
30 | query_IDX_path = osp.join(root, 'info/query_IDX.mat')
31 |
32 | def __init__(self, min_seq_len=0, ):
33 | self._check_before_run()
34 |
35 | # prepare meta data
36 | train_names = self._get_names(self.train_name_path)
37 | test_names = self._get_names(self.test_name_path)
38 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4)
39 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4)
40 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,)
41 | query_IDX -= 1 # index from 0
42 | track_query = track_test[query_IDX,:]
43 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX]
44 | track_gallery = track_test[gallery_IDX,:]
45 |
46 | train, num_train_tracklets, num_train_pids, num_train_imgs = self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len)
47 |
48 | video = self._process_train_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len)
49 |
50 |
51 | query, num_query_tracklets, num_query_pids, num_query_imgs = self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
52 |
53 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
54 |
55 | num_imgs_per_tracklet = num_train_imgs + num_query_imgs + num_gallery_imgs
56 | min_num = np.min(num_imgs_per_tracklet)
57 | max_num = np.max(num_imgs_per_tracklet)
58 | avg_num = np.mean(num_imgs_per_tracklet)
59 |
60 | num_total_pids = num_train_pids + num_query_pids
61 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets
62 |
63 | print("=> MARS loaded")
64 | print("Dataset statistics:")
65 | print(" ------------------------------")
66 | print(" subset | # ids | # tracklets")
67 | print(" ------------------------------")
68 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
69 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
70 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
71 | print(" ------------------------------")
72 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
73 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
74 | print(" ------------------------------")
75 |
76 | # self.train_videos = video
77 | self.train = train
78 | self.query = query
79 | self.gallery = gallery
80 |
81 | self.num_train_pids = num_train_pids
82 | self.num_query_pids = num_query_pids
83 | self.num_gallery_pids = num_gallery_pids
84 | self.num_train_cams=6
85 | self.num_query_cams=6
86 | self.num_gallery_cams=6
87 | self.num_train_vids=num_train_tracklets
88 | self.num_query_vids=num_query_tracklets
89 | self.num_gallery_vids=num_gallery_tracklets
90 | def _check_before_run(self):
91 | """Check if all files are available before going deeper"""
92 | if not osp.exists(self.root):
93 | raise RuntimeError("'{}' is not available".format(self.root))
94 | if not osp.exists(self.train_name_path):
95 | raise RuntimeError("'{}' is not available".format(self.train_name_path))
96 | if not osp.exists(self.test_name_path):
97 | raise RuntimeError("'{}' is not available".format(self.test_name_path))
98 | if not osp.exists(self.track_train_info_path):
99 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path))
100 | if not osp.exists(self.track_test_info_path):
101 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path))
102 | if not osp.exists(self.query_IDX_path):
103 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path))
104 |
105 | def _get_names(self, fpath):
106 | names = []
107 | with open(fpath, 'r') as f:
108 | for line in f:
109 | new_line = line.rstrip()
110 | names.append(new_line)
111 | return names
112 |
113 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0):
114 | assert home_dir in ['bbox_train', 'bbox_test']
115 | num_tracklets = meta_data.shape[0]
116 | pid_list = list(set(meta_data[:,2].tolist()))
117 | num_pids = len(pid_list)
118 | if not relabel: pid2label = {pid:int(pid) for label, pid in enumerate(pid_list)}
119 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)}
120 | tracklets = []
121 | num_imgs_per_tracklet = []
122 |
123 | for tracklet_idx in range(num_tracklets):
124 | data = meta_data[tracklet_idx,...]
125 | start_index, end_index, pid, camid = data
126 | if pid == -1: continue # junk images are just ignored
127 | assert 1 <= camid <= 6
128 | #if relabel: pid = pid2label[pid]
129 | pid = pid2label[pid]
130 | camid -= 1 # index starts from 0
131 | img_names = names[start_index-1:end_index]
132 |
133 | # make sure image names correspond to the same person
134 | pnames = [img_name[:4] for img_name in img_names]
135 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images"
136 |
137 | # make sure all images are captured under the same camera
138 | camnames = [img_name[5] for img_name in img_names]
139 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
140 |
141 | # append image names with directory information
142 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names]
143 | if len(img_paths) >= min_seq_len:
144 | img_paths = tuple(img_paths)
145 | tracklets.append((img_paths, pid, camid))
146 | num_imgs_per_tracklet.append(len(img_paths))
147 | # if camid in video[pid] :
148 | # video[pid][camid].append(img_paths)
149 | # else:
150 | # video[pid][camid] = img_paths
151 |
152 | num_tracklets = len(tracklets)
153 |
154 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet
155 |
156 | def _process_train_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0):
157 | video = defaultdict(dict)
158 |
159 | assert home_dir in ['bbox_train', 'bbox_test']
160 | num_tracklets = meta_data.shape[0]
161 | pid_list = list(set(meta_data[:,2].tolist()))
162 | num_pids = len(pid_list)
163 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)}
164 | for tracklet_idx in range(num_tracklets):
165 | data = meta_data[tracklet_idx,...]
166 | start_index, end_index, pid, camid = data
167 | if pid == -1: continue # junk images are just ignored
168 | assert 1 <= camid <= 6
169 | if relabel: pid = pid2label[pid]
170 | camid -= 1 # index starts from 0
171 | img_names = names[start_index-1:end_index]
172 | # make sure image names correspond to the same person
173 | pnames = [img_name[:4] for img_name in img_names]
174 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images"
175 | # make sure all images are captured under the same camera
176 | camnames = [img_name[5] for img_name in img_names]
177 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
178 |
179 | # append image names with directory information
180 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names]
181 | if len(img_paths) >= min_seq_len:
182 | if camid in video[pid] :
183 | video[pid][camid].extend(img_paths)
184 | else:
185 | video[pid][camid] = img_paths
186 | return video
187 |
188 |
189 |
--------------------------------------------------------------------------------
/VID_Trans_model.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import copy
5 | from vit_ID import TransReID,Block
6 | from functools import partial
7 | from torch.nn import functional as F
8 |
9 |
10 | def TCSS(features, shift, b,t):
11 | #aggregate features at patch level
12 | features=features.view(b,features.size(1),t*features.size(2))
13 | token = features[:, 0:1]
14 |
15 | batchsize = features.size(0)
16 | dim = features.size(-1)
17 |
18 |
19 | #shift the patches with amount=shift
20 | features= torch.cat([features[:, shift:], features[:, 1:shift]], dim=1)
21 |
22 | # Patch Shuffling by 2 part
23 | try:
24 | features = features.view(batchsize, 2, -1, dim)
25 | except:
26 | features = torch.cat([features, features[:, -2:-1, :]], dim=1)
27 | features = features.view(batchsize, 2, -1, dim)
28 |
29 | features = torch.transpose(features, 1, 2).contiguous()
30 | features = features.view(batchsize, -1, dim)
31 |
32 | return features,token
33 |
34 | def weights_init_kaiming(m):
35 | classname = m.__class__.__name__
36 | if classname.find('Linear') != -1:
37 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
38 | nn.init.constant_(m.bias, 0.0)
39 |
40 | elif classname.find('Conv') != -1:
41 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
42 | if m.bias is not None:
43 | nn.init.constant_(m.bias, 0.0)
44 | elif classname.find('BatchNorm') != -1:
45 | if m.affine:
46 | nn.init.constant_(m.weight, 1.0)
47 | nn.init.constant_(m.bias, 0.0)
48 |
49 | def weights_init_classifier(m):
50 | classname = m.__class__.__name__
51 | if classname.find('Linear') != -1:
52 | nn.init.normal_(m.weight, std=0.001)
53 | if m.bias:
54 | nn.init.constant_(m.bias, 0.0)
55 |
56 |
57 |
58 |
59 | class VID_Trans(nn.Module):
60 | def __init__(self, num_classes, camera_num,pretrainpath):
61 | super(VID_Trans, self).__init__()
62 | self.in_planes = 768
63 | self.num_classes = num_classes
64 |
65 |
66 | self.base =TransReID(
67 | img_size=[256, 128], patch_size=16, stride_size=[16, 16], embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\
68 | camera=camera_num, drop_path_rate=0.1, drop_rate=0.0, attn_drop_rate=0.0,norm_layer=partial(nn.LayerNorm, eps=1e-6), cam_lambda=3.0)
69 |
70 |
71 |         if pretrainpath is not None:  # VID_Test.py builds the model with pretrainpath=None and loads a full checkpoint afterwards
72 |             self.base.load_param(torch.load(pretrainpath, map_location='cpu'), load=True)
73 |
74 |
75 | #global stream
76 | block= self.base.blocks[-1]
77 | layer_norm = self.base.norm
78 | self.b1 = nn.Sequential(
79 | copy.deepcopy(block),
80 | copy.deepcopy(layer_norm)
81 | )
82 |
83 | self.bottleneck = nn.BatchNorm1d(self.in_planes)
84 | self.bottleneck.bias.requires_grad_(False)
85 | self.bottleneck.apply(weights_init_kaiming)
86 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
87 | self.classifier.apply(weights_init_classifier)
88 |
89 | #-----------------------------------------------
90 | #-----------------------------------------------
91 |
92 |
93 | # building local video stream
94 | dpr = [x.item() for x in torch.linspace(0, 0, 12)] # stochastic depth decay rule
95 |
96 | self.block1 = Block(
97 | dim=3072, num_heads=12, mlp_ratio=4, qkv_bias=True, qk_scale=None,
98 | drop=0, attn_drop=0, drop_path=dpr[11], norm_layer=partial(nn.LayerNorm, eps=1e-6))
99 |
100 | self.b2 = nn.Sequential(
101 | self.block1,
102 | nn.LayerNorm(3072)#copy.deepcopy(layer_norm)
103 | )
104 |
105 |
106 | self.bottleneck_1 = nn.BatchNorm1d(3072)
107 | self.bottleneck_1.bias.requires_grad_(False)
108 | self.bottleneck_1.apply(weights_init_kaiming)
109 | self.bottleneck_2 = nn.BatchNorm1d(3072)
110 | self.bottleneck_2.bias.requires_grad_(False)
111 | self.bottleneck_2.apply(weights_init_kaiming)
112 | self.bottleneck_3 = nn.BatchNorm1d(3072)
113 | self.bottleneck_3.bias.requires_grad_(False)
114 | self.bottleneck_3.apply(weights_init_kaiming)
115 | self.bottleneck_4 = nn.BatchNorm1d(3072)
116 | self.bottleneck_4.bias.requires_grad_(False)
117 | self.bottleneck_4.apply(weights_init_kaiming)
118 |
119 |
120 | self.classifier_1 = nn.Linear(3072, self.num_classes, bias=False)
121 | self.classifier_1.apply(weights_init_classifier)
122 | self.classifier_2 = nn.Linear(3072, self.num_classes, bias=False)
123 | self.classifier_2.apply(weights_init_classifier)
124 | self.classifier_3 = nn.Linear(3072, self.num_classes, bias=False)
125 | self.classifier_3.apply(weights_init_classifier)
126 | self.classifier_4 = nn.Linear(3072, self.num_classes, bias=False)
127 | self.classifier_4.apply(weights_init_classifier)
128 |
129 |
130 | #-------------------video attention-------------
131 | self.middle_dim = 256 # middle layer dimension
132 | self.attention_conv = nn.Conv2d(self.in_planes, self.middle_dim, [1, 1])  # 1x1 conv over each frame's 768-d global token
133 | self.attention_tconv = nn.Conv1d(self.middle_dim, 1, 3, padding=1)
134 | self.attention_conv.apply(weights_init_kaiming)
135 | self.attention_tconv.apply(weights_init_kaiming)
136 | #------------------------------------------
137 |
138 | self.shift_num = 5
139 | self.part = 4
140 | self.rearrange=True
141 |
142 |
143 |
144 |
145 | def forward(self, x, label=None, cam_label=None, view_label=None):  # label and view_label are unused in this implementation
146 | b=x.size(0)
147 | t=x.size(1)
148 |
149 | x=x.view(x.size(0)*x.size(1), x.size(2), x.size(3), x.size(4))
150 | features = self.base(x, cam_label=cam_label)
151 |
152 |
153 | # global branch
154 | b1_feat = self.b1(features) # [64, 129, 768]
155 | global_feat = b1_feat[:, 0]
156 |
157 | global_feat=global_feat.unsqueeze(dim=2)
158 | global_feat=global_feat.unsqueeze(dim=3)
159 | a = F.relu(self.attention_conv(global_feat))
160 | a = a.view(b, t, self.middle_dim)
161 | a = a.permute(0,2,1)
162 | a = F.relu(self.attention_tconv(a))
163 | a = a.view(b, t)
164 | a_vals = a
165 |
166 | a = F.softmax(a, dim=1)
167 | x = global_feat.view(b, t, -1)
168 | a = torch.unsqueeze(a, -1)
169 | a = a.expand(b, t, self.in_planes)
170 | att_x = torch.mul(x,a)
171 | att_x = torch.sum(att_x,1)
172 |
173 | global_feat = att_x.view(b,self.in_planes)
174 | feat = self.bottleneck(global_feat)
175 |
176 |
177 |
178 |
179 | #-------------------------------------------------
180 | #-------------------------------------------------
181 |
182 |
183 | # video patch part features
184 |
185 | feature_length = features.size(1) - 1
186 | patch_length = feature_length // 4
187 |
188 | # Temporal clip shift and shuffled (TCSS)
189 | x ,token=TCSS(features, self.shift_num, b,t)
190 |
191 |
192 | # part1
193 | part1 = x[:, :patch_length]
194 | part1 = self.b2(torch.cat((token, part1), dim=1))
195 | part1_f = part1[:, 0]
196 |
197 | # part2
198 | part2 = x[:, patch_length:patch_length*2]
199 | part2 = self.b2(torch.cat((token, part2), dim=1))
200 | part2_f = part2[:, 0]
201 |
202 | # part3
203 | part3 = x[:, patch_length*2:patch_length*3]
204 | part3 = self.b2(torch.cat((token, part3), dim=1))
205 | part3_f = part3[:, 0]
206 |
207 | # part4
208 | part4 = x[:, patch_length*3:patch_length*4]
209 | part4 = self.b2(torch.cat((token, part4), dim=1))
210 | part4_f = part4[:, 0]
211 |
212 |
213 |
214 | part1_bn = self.bottleneck_1(part1_f)
215 | part2_bn = self.bottleneck_2(part2_f)
216 | part3_bn = self.bottleneck_3(part3_f)
217 | part4_bn = self.bottleneck_4(part4_f)
218 |
219 | if self.training:
220 |
221 | Global_ID = self.classifier(feat)
222 | Local_ID1 = self.classifier_1(part1_bn)
223 | Local_ID2 = self.classifier_2(part2_bn)
224 | Local_ID3 = self.classifier_3(part3_bn)
225 | Local_ID4 = self.classifier_4(part4_bn)
226 |
227 | return [Global_ID, Local_ID1, Local_ID2, Local_ID3, Local_ID4], [global_feat, part1_f, part2_f, part3_f, part4_f], a_vals
228 |
229 | else:
230 | return torch.cat([feat, part1_bn/4 , part2_bn/4 , part3_bn /4, part4_bn/4 ], dim=1)
231 |
232 |
233 |
234 | def load_param(self, trained_path,load=False):
235 | if not load:
236 | param_dict = torch.load(trained_path)
237 | for i in param_dict:
238 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
239 | print('Loading pretrained model from {}'.format(trained_path))
240 | else:
241 | param_dict=trained_path
242 | for i in param_dict:
243 | #print(i)
244 | if i not in self.state_dict() or 'classifier' in i or 'sie_embed' in i:
245 | continue
246 | self.state_dict()[i].copy_(param_dict[i])
247 |
248 |
249 |
250 | def load_param_finetune(self, model_path):
251 | param_dict = torch.load(model_path)
252 | for i in param_dict:
253 | self.state_dict()[i].copy_(param_dict[i])
254 | print('Loading pretrained model for finetuning from {}'.format(model_path))
255 |
256 |
257 |
258 |
259 |
260 |
--------------------------------------------------------------------------------
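A minimal usage sketch for VID_Trans_model.py (illustrative, not part of the repository): the class count, camera count and checkpoint name below are assumptions for MARS-style data, and pretrainpath must point to an ImageNet-pretrained ViT-B/16 state dict. The sketch builds the model, runs a forward pass on a random clip batch shaped (batch, seq_len, 3, 256, 128), and contrasts the training-time outputs with the evaluation descriptor.

import torch
from VID_Trans_model import VID_Trans

# Assumed values: MARS has 625 training identities and 6 cameras; the checkpoint
# path is a placeholder for an ImageNet-pretrained ViT-B/16 state dict.
model = VID_Trans(num_classes=625, camera_num=6, pretrainpath='jx_vit_base_p16_224.pth')

clip = torch.randn(2, 4, 3, 256, 128)               # (batch, seq_len, C, H, W)
cam_label = torch.zeros(2 * 4, dtype=torch.int64)   # one camera id per frame

model.train()
scores, feats, a_vals = model(clip, cam_label=cam_label)
# scores: [global logits, 4 part logits]; feats: [768-d global, 4 x 3072-d part features]
# a_vals: per-frame temporal attention scores (before the softmax)

model.eval()
with torch.no_grad():
    descriptor = model(clip, cam_label=cam_label)
print(descriptor.shape)                              # (2, 768 + 4 * 3072)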
/Dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as T
3 | from torch.utils.data import DataLoader
4 | from PIL import Image, ImageFile
5 |
6 | from torch.utils.data import Dataset
7 | import os.path as osp
8 | import random
9 | import torch
10 | import numpy as np
11 | import math
12 |
13 | from timm.data.random_erasing import RandomErasing
14 | from utility import RandomIdentitySampler,RandomErasing3
15 | from Datasets.MARS_dataset import Mars
16 | from Datasets.iLDSVID import iLIDSVID
17 | from Datasets.PRID_dataset import PRID
18 |
19 | __factory = {
20 | 'Mars':Mars,
21 | 'iLIDSVID':iLIDSVID,
22 | 'PRID':PRID
23 | }
24 |
25 | def train_collate_fn(batch):
26 |
27 | imgs, pids, camids,a= zip(*batch)
28 | pids = torch.tensor(pids, dtype=torch.int64)
29 |
30 | camids = torch.tensor(camids, dtype=torch.int64)
31 | return torch.stack(imgs, dim=0), pids, camids, torch.stack(a, dim=0)
32 |
33 | def val_collate_fn(batch):
34 |
35 | imgs, pids, camids, img_paths = zip(*batch)
36 |
37 | camids_batch = torch.tensor(camids, dtype=torch.int64)
38 | return torch.stack(imgs, dim=0), pids, camids_batch, img_paths
39 |
40 | def dataloader(Dataset_name):
41 | train_transforms = T.Compose([
42 | T.Resize([256, 128], interpolation=3),
43 | T.RandomHorizontalFlip(p=0.5),
44 | T.Pad(10),
45 | T.RandomCrop([256, 128]),
46 | T.ToTensor(),
47 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
48 |
49 |
50 | ])
51 |
52 | val_transforms = T.Compose([
53 | T.Resize([256, 128]),
54 | T.ToTensor(),
55 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
56 | ])
57 |
58 |
59 |
60 | dataset = __factory[Dataset_name]()
61 | train_set = VideoDataset_inderase(dataset.train, seq_len=4, sample='intelligent',transform=train_transforms)
62 | num_classes = dataset.num_train_pids
63 | cam_num = dataset.num_train_cams
64 | view_num = dataset.num_train_vids
65 |
66 |
67 | train_loader = DataLoader(train_set, batch_size=64,sampler=RandomIdentitySampler(dataset.train, 64,4),num_workers=4, collate_fn=train_collate_fn)
68 |
69 | q_val_set = VideoDataset(dataset.query, seq_len=4, sample='dense', transform=val_transforms)
70 | g_val_set = VideoDataset(dataset.gallery, seq_len=4, sample='dense', transform=val_transforms)
71 |
72 |
73 | return train_loader, len(dataset.query), num_classes, cam_num, view_num,q_val_set,g_val_set
74 |
75 |
76 |
77 | def read_image(img_path):
78 | """Keep reading image until succeed.
79 | This can avoid IOError incurred by heavy IO process."""
80 | got_img = False
81 | while not got_img:
82 | try:
83 | img = Image.open(img_path).convert('RGB')
84 | got_img = True
85 | except IOError:
86 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
87 | pass
88 | return img
89 |
90 | class VideoDataset(Dataset):
91 | """Video Person ReID Dataset.
92 | Note batch data has shape (batch, seq_len, channel, height, width).
93 | """
94 | sample_methods = ['evenly', 'random', 'all']
95 |
96 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None , max_length=40):
97 | self.dataset = dataset
98 | self.seq_len = seq_len
99 | self.sample = sample
100 | self.transform = transform
101 | self.max_length = max_length
102 |
103 | def __len__(self):
104 | return len(self.dataset)
105 |
106 | def __getitem__(self, index):
107 | img_paths, pid, camid = self.dataset[index]
108 | num = len(img_paths)
109 |
110 | # if self.sample == 'restricted_random':
111 | # frame_indices = range(num)
112 | # chunks =
113 | # rand_end = max(0, len(frame_indices) - self.seq_len - 1)
114 | # begin_index = random.randint(0, rand_end)
115 |
116 |
117 | if self.sample == 'random':
118 | """
119 | Randomly sample seq_len consecutive frames from num frames,
120 | if num is smaller than seq_len, then replicate items.
121 | This sampling strategy is used in training phase.
122 | """
123 | frame_indices = range(num)
124 | rand_end = max(0, len(frame_indices) - self.seq_len - 1)
125 | begin_index = random.randint(0, rand_end)
126 | end_index = min(begin_index + self.seq_len, len(frame_indices))
127 |
128 | indices = frame_indices[begin_index:end_index]
129 | # print(begin_index, end_index, indices)
130 | if len(indices) < self.seq_len:
131 | indices=np.array(indices)
132 | indices = np.append(indices , [indices[-1] for i in range(self.seq_len - len(indices))])
133 | else:
134 | indices=np.array(indices)
135 | imgs = []
136 | targt_cam=[]
137 | for index in indices:
138 | index=int(index)
139 | img_path = img_paths[index]
140 | img = read_image(img_path)
141 | if self.transform is not None:
142 | img = self.transform(img)
143 | img = img.unsqueeze(0)
144 | targt_cam.append(camid)
145 | imgs.append(img)
146 | imgs = torch.cat(imgs, dim=0)
147 | #imgs=imgs.permute(1,0,2,3)
148 | return imgs, pid, targt_cam
149 |
150 | elif self.sample == 'dense':
151 | """
152 | Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1.
153 | This sampling strategy is used in test phase.
154 | """
155 | # import pdb
156 | # pdb.set_trace()
157 |
158 | cur_index=0
159 | frame_indices = [i for i in range(num)]
160 | indices_list=[]
161 | while num-cur_index > self.seq_len:
162 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len])
163 | cur_index+=self.seq_len
164 | last_seq=frame_indices[cur_index:]
165 | # print(last_seq)
166 | for index in last_seq:
167 | if len(last_seq) >= self.seq_len:
168 | break
169 | last_seq.append(index)
170 |
171 |
172 | indices_list.append(last_seq)
173 | imgs_list=[]
174 | targt_cam=[]
175 | # print(indices_list , num , img_paths )
176 | for indices in indices_list:
177 | if len(imgs_list) > self.max_length:
178 | break
179 | imgs = []
180 | for index in indices:
181 | index=int(index)
182 | img_path = img_paths[index]
183 | img = read_image(img_path)
184 | if self.transform is not None:
185 | img = self.transform(img)
186 | img = img.unsqueeze(0)
187 | imgs.append(img)
188 | targt_cam.append(camid)
189 | imgs = torch.cat(imgs, dim=0)
190 | #imgs=imgs.permute(1,0,2,3)
191 | imgs_list.append(imgs)
192 | imgs_array = torch.stack(imgs_list)
193 | return imgs_array, pid, targt_cam,img_paths
194 | #return imgs_array, pid, int(camid),trackid
195 |
196 |
197 | elif self.sample == 'dense_subset':
198 | """
199 | Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1.
200 | This sampling strategy is used in test phase.
201 | """
202 | frame_indices = range(num)
203 | rand_end = max(0, len(frame_indices) - self.max_length - 1)
204 | begin_index = random.randint(0, rand_end)
205 |
206 |
207 | cur_index=begin_index
208 | frame_indices = [i for i in range(num)]
209 | indices_list=[]
210 | while num-cur_index > self.seq_len:
211 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len])
212 | cur_index+=self.seq_len
213 | last_seq=frame_indices[cur_index:]
214 | # print(last_seq)
215 | for index in last_seq:
216 | if len(last_seq) >= self.seq_len:
217 | break
218 | last_seq.append(index)
219 |
220 |
221 | indices_list.append(last_seq)
222 | imgs_list=[]
223 | # print(indices_list , num , img_paths )
224 | for indices in indices_list:
225 | if len(imgs_list) > self.max_length:
226 | break
227 | imgs = []
228 | for index in indices:
229 | index=int(index)
230 | img_path = img_paths[index]
231 | img = read_image(img_path)
232 | if self.transform is not None:
233 | img = self.transform(img)
234 | img = img.unsqueeze(0)
235 | imgs.append(img)
236 | imgs = torch.cat(imgs, dim=0)
237 | #imgs=imgs.permute(1,0,2,3)
238 | imgs_list.append(imgs)
239 | imgs_array = torch.stack(imgs_list)
240 | return imgs_array, pid, camid
241 |
242 | elif self.sample == 'intelligent_random':
243 | # frame_indices = range(num)
244 | indices = []
245 | each = max(num//self.seq_len, 1)
246 | for i in range(self.seq_len):
247 | if i != self.seq_len - 1:
248 | indices.append(random.randint(min(i*each , num-1), min( (i+1)*each-1, num-1)) )
249 | else:
250 | indices.append(random.randint(min(i*each , num-1), num-1) )
251 |
252 | imgs = []
253 | for index in indices:
254 | index=int(index)
255 | img_path = img_paths[index]
256 | img = read_image(img_path)
257 | if self.transform is not None:
258 | img = self.transform(img)
259 | img = img.unsqueeze(0)
260 | imgs.append(img)
261 | imgs = torch.cat(imgs, dim=0)
262 | #imgs=imgs.permute(1,0,2,3)
263 | return imgs, pid, camid
264 | else:
265 | raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))
266 |
267 |
268 |
269 |
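# Note on how the two dataset classes are used by dataloader() above: the query and
# gallery sets are wrapped in VideoDataset with sample='dense', so each tracklet yields
# a stack of seq_len=4 clips for evaluation, while training uses VideoDataset_inderase
# below with sample='intelligent', which spreads the 4 sampled frames across the tracklet.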
270 | class VideoDataset_inderase(Dataset):
271 | """Video Person ReID Dataset.
272 | Note batch data has shape (batch, seq_len, channel, height, width).
273 | """
274 | sample_methods = ['evenly', 'random', 'all']
275 |
276 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None , max_length=40):
277 | self.dataset = dataset
278 | self.seq_len = seq_len
279 | self.sample = sample
280 | self.transform = transform
281 | self.max_length = max_length
282 | self.erase = RandomErasing3(probability=0.5, mean=[0.485, 0.456, 0.406])
283 |
284 | def __len__(self):
285 | return len(self.dataset)
286 |
287 | def __getitem__(self, index):
288 | img_paths, pid, camid = self.dataset[index]
289 | num = len(img_paths)
290 | if self.sample != "intelligent":
291 | frame_indices = range(num)
292 | rand_end = max(0, len(frame_indices) - self.seq_len - 1)
293 | begin_index = random.randint(0, rand_end)
294 | end_index = min(begin_index + self.seq_len, len(frame_indices))
295 |
296 | indices1 = frame_indices[begin_index:end_index]
297 | indices = []
298 | for index in indices1:
299 | if len(indices) >= self.seq_len:
300 | break
301 | indices.append(index)
302 | indices=np.array(indices)
303 | else:
304 | # frame_indices = range(num)
305 | indices = []
306 | each = max(num//self.seq_len,1)
307 | for i in range(self.seq_len):
308 | if i != self.seq_len -1:
309 | indices.append(random.randint(min(i*each , num-1), min( (i+1)*each-1, num-1)) )
310 | else:
311 | indices.append(random.randint(min(i*each , num-1), num-1) )
312 | # print(len(indices), indices, num )
313 | imgs = []
314 | labels = []
315 | targt_cam=[]
316 |
317 | for index in indices:
318 | index=int(index)
319 | img_path = img_paths[index]
320 |
321 | img = read_image(img_path)
322 | if self.transform is not None:
323 | img = self.transform(img)
324 | img , temp = self.erase(img)
325 | labels.append(temp)
326 | img = img.unsqueeze(0)
327 | imgs.append(img)
328 | targt_cam.append(camid)
329 | labels = torch.tensor(labels)
330 | imgs = torch.cat(imgs, dim=0)
331 |
332 | return imgs, pid, targt_cam ,labels
333 |
334 |
335 |
--------------------------------------------------------------------------------
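A short sketch of driving Dataloader.py (assuming the MARS images are on disk where Datasets/MARS_dataset.py expects them; shapes follow train_collate_fn with the seq_len=4 and batch_size=64 settings above).

from Dataloader import dataloader

train_loader, num_query, num_classes, cam_num, view_num, q_val_set, g_val_set = dataloader('Mars')

for imgs, pids, camids, erase_labels in train_loader:
    # imgs:  (64, 4, 3, 256, 128) clips from VideoDataset_inderase
    # pids:  (64,) identity labels; camids: (64, 4) per-frame camera ids
    # erase_labels: (64, 4), 1 where RandomErasing3 erased a region in that frame
    print(imgs.shape, pids.shape, camids.shape, erase_labels.shape)
    break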
/utility.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.sampler import Sampler
2 | from collections import defaultdict
3 | import copy
4 | import random
5 | import numpy as np
6 | import torch
7 | import logging
8 | import math
9 | _logger = logging.getLogger(__name__)  # module-level logger used by CosineLRScheduler below
10 | from typing import Dict, Any
11 | class RandomIdentitySampler(Sampler):
12 | """
13 | Randomly sample N identities, then for each identity,
14 | randomly sample K instances, so the batch size is N*K (a worked example follows this class).
15 | Args:
16 | - data_source (list): list of (img_paths, pid, camid) tuples.
17 | - batch_size (int): number of examples in a batch.
18 | - num_instances (int): number of instances per identity in a batch.
19 | """
20 |
21 | def __init__(self, data_source, batch_size, num_instances):
22 | self.data_source = data_source
23 | self.batch_size = batch_size
24 | self.num_instances = num_instances
25 | self.num_pids_per_batch = self.batch_size // self.num_instances
26 | self.index_dic = defaultdict(list) #dict with list value
27 | #{783: [0, 5, 116, 876, 1554, 2041],...,}
28 | for index, (_, pid, _) in enumerate(self.data_source):
29 | self.index_dic[pid].append(index)
30 | self.pids = list(self.index_dic.keys())
31 |
32 | # estimate number of examples in an epoch
33 | self.length = 0
34 | for pid in self.pids:
35 | idxs = self.index_dic[pid]
36 | num = len(idxs)
37 | if num < self.num_instances:
38 | num = self.num_instances
39 | self.length += num - num % self.num_instances
40 |
41 | def __iter__(self):
42 | batch_idxs_dict = defaultdict(list)
43 |
44 | for pid in self.pids:
45 | idxs = copy.deepcopy(self.index_dic[pid])
46 | if len(idxs) < self.num_instances:
47 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
48 | random.shuffle(idxs)
49 | batch_idxs = []
50 | for idx in idxs:
51 | batch_idxs.append(idx)
52 | if len(batch_idxs) == self.num_instances:
53 | batch_idxs_dict[pid].append(batch_idxs)
54 | batch_idxs = []
55 |
56 | avai_pids = copy.deepcopy(self.pids)
57 | final_idxs = []
58 |
59 | while len(avai_pids) >= self.num_pids_per_batch:
60 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
61 | for pid in selected_pids:
62 | batch_idxs = batch_idxs_dict[pid].pop(0)
63 | final_idxs.extend(batch_idxs)
64 | if len(batch_idxs_dict[pid]) == 0:
65 | avai_pids.remove(pid)
66 |
67 | return iter(final_idxs)
68 |
69 | def __len__(self):
70 | return self.length
71 |
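# Worked example with the values used in Dataloader.py: RandomIdentitySampler(dataset.train,
# batch_size=64, num_instances=4) gives num_pids_per_batch = 64 // 4 = 16, so every batch of
# 64 tracklets covers 16 identities with 4 tracklets each, guaranteeing positive pairs for
# the triplet loss.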
72 | class AverageMeter(object):
73 | """Computes and stores the average and current value"""
74 |
75 | def __init__(self):
76 | self.val = 0
77 | self.avg = 0
78 | self.sum = 0
79 | self.count = 0
80 |
81 | def reset(self):
82 | self.val = 0
83 | self.avg = 0
84 | self.sum = 0
85 | self.count = 0
86 |
87 | def update(self, val, n=1):
88 | self.val = val
89 | self.sum += val * n
90 | self.count += n
91 | self.avg = self.sum / self.count
92 |
93 | class RandomErasing3(object):
94 | """ Randomly selects a rectangle region in an image and erases its pixels.
95 | 'Random Erasing Data Augmentation' by Zhong et al.
96 | See https://arxiv.org/pdf/1708.04896.pdf
97 | Args:
98 | probability: The probability that the Random Erasing operation will be performed.
99 | sl: Minimum proportion of erased area against input image.
100 | sh: Maximum proportion of erased area against input image.
101 | r1: Minimum aspect ratio of erased area.
102 | mean: Erasing value.
103 | """
104 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
105 | self.probability = probability
106 | self.mean = mean
107 | self.sl = sl
108 | self.sh = sh
109 | self.r1 = r1
110 | def __call__(self, img):
111 | if random.uniform(0, 1) >= self.probability:
112 | return img , 0
113 | for attempt in range(100):
114 | area = img.size()[1] * img.size()[2]
115 | target_area = random.uniform(self.sl, self.sh) * area
116 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
117 | h = int(round(math.sqrt(target_area * aspect_ratio)))
118 | w = int(round(math.sqrt(target_area / aspect_ratio)))
119 | if w < img.size()[2] and h < img.size()[1]:
120 | x1 = random.randint(0, img.size()[1] - h)
121 | y1 = random.randint(0, img.size()[2] - w)
122 | if img.size()[0] == 3:
123 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
124 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
125 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
126 | else:
127 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
128 | return img , 1
129 | return img , 0
130 |
131 |
132 | def scheduler(optimizer):
133 | num_epochs = 120
134 |
135 | lr_min = 0.002 * 0.008
136 | warmup_lr_init = 0.01 * 0.008
137 |
138 | warmup_t = 5
139 | noise_range = None
140 |
141 | lr_scheduler = CosineLRScheduler(
142 | optimizer,
143 | t_initial=num_epochs,
144 | lr_min=lr_min,
145 | t_mul= 1.,
146 | decay_rate=0.1,
147 | warmup_lr_init=warmup_lr_init,
148 | warmup_t=warmup_t,
149 | cycle_limit=1,
150 | t_in_epochs=True,
151 | noise_range_t=noise_range,
152 | noise_pct= 0.67,
153 | noise_std= 1.,
154 | noise_seed=42,
155 | )
156 |
157 | return lr_scheduler
158 |
159 |
160 |
161 |
162 | def optimizer(model):
163 | params = []
164 | for key, value in model.named_parameters():
165 | if not value.requires_grad:
166 | continue
167 | lr = 0.008
168 | weight_decay = 1e-4
169 | if "bias" in key:
170 | lr = 0.008 * 2
171 | weight_decay = 1e-4
172 |
173 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
174 |
175 |
176 | optimizer = getattr(torch.optim, 'SGD')(params, momentum=0.9)
177 |
178 |
179 |
180 | return optimizer
181 |
182 |
183 |
184 |
185 |
186 | class Scheduler:
187 | """ Parameter Scheduler Base Class
188 | A scheduler base class that can be used to schedule any optimizer parameter groups.
189 |
190 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called
191 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
192 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
193 |
194 | The schedulers built on this should try to remain as stateless as possible (for simplicity).
195 |
196 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
197 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training
198 | code and explicitly passed in to the schedulers on the corresponding step or step_update call.
199 |
200 | Based on ideas from:
201 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
202 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
203 | """
204 |
205 | def __init__(self,
206 | optimizer: torch.optim.Optimizer,
207 | param_group_field: str,
208 | noise_range_t=None,
209 | noise_type='normal',
210 | noise_pct=0.67,
211 | noise_std=1.0,
212 | noise_seed=None,
213 | initialize: bool = True) -> None:
214 | self.optimizer = optimizer
215 | self.param_group_field = param_group_field
216 | self._initial_param_group_field = f"initial_{param_group_field}"
217 | if initialize:
218 | for i, group in enumerate(self.optimizer.param_groups):
219 | if param_group_field not in group:
220 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
221 | group.setdefault(self._initial_param_group_field, group[param_group_field])
222 | else:
223 | for i, group in enumerate(self.optimizer.param_groups):
224 | if self._initial_param_group_field not in group:
225 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
226 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
227 | self.metric = None # any point to having this for all?
228 | self.noise_range_t = noise_range_t
229 | self.noise_pct = noise_pct
230 | self.noise_type = noise_type
231 | self.noise_std = noise_std
232 | self.noise_seed = noise_seed if noise_seed is not None else 42
233 | self.update_groups(self.base_values)
234 |
235 | def state_dict(self) -> Dict[str, Any]:
236 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
237 |
238 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
239 | self.__dict__.update(state_dict)
240 |
241 | def get_epoch_values(self, epoch: int):
242 | return None
243 |
244 | def get_update_values(self, num_updates: int):
245 | return None
246 |
247 | def step(self, epoch: int, metric: float = None) -> None:
248 | self.metric = metric
249 | values = self.get_epoch_values(epoch)
250 | if values is not None:
251 | values = self._add_noise(values, epoch)
252 | self.update_groups(values)
253 |
254 | def step_update(self, num_updates: int, metric: float = None):
255 | self.metric = metric
256 | values = self.get_update_values(num_updates)
257 | if values is not None:
258 | values = self._add_noise(values, num_updates)
259 | self.update_groups(values)
260 |
261 | def update_groups(self, values):
262 | if not isinstance(values, (list, tuple)):
263 | values = [values] * len(self.optimizer.param_groups)
264 | for param_group, value in zip(self.optimizer.param_groups, values):
265 | param_group[self.param_group_field] = value
266 |
267 | def _add_noise(self, lrs, t):
268 | if self.noise_range_t is not None:
269 | if isinstance(self.noise_range_t, (list, tuple)):
270 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
271 | else:
272 | apply_noise = t >= self.noise_range_t
273 | if apply_noise:
274 | g = torch.Generator()
275 | g.manual_seed(self.noise_seed + t)
276 | if self.noise_type == 'normal':
277 | while True:
278 | # resample if noise out of percent limit, brute force but shouldn't spin much
279 | noise = torch.randn(1, generator=g).item()
280 | if abs(noise) < self.noise_pct:
281 | break
282 | else:
283 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
284 | lrs = [v + v * noise for v in lrs]
285 | return lrs
286 |
287 | class CosineLRScheduler(Scheduler):
288 | """
289 | Cosine decay with restarts.
290 | This is described in the paper https://arxiv.org/abs/1608.03983.
291 |
292 | Inspiration from
293 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
294 | """
295 |
296 | def __init__(self,
297 | optimizer: torch.optim.Optimizer,
298 | t_initial: int,
299 | t_mul: float = 1.,
300 | lr_min: float = 0.,
301 | decay_rate: float = 1.,
302 | warmup_t=0,
303 | warmup_lr_init=0,
304 | warmup_prefix=False,
305 | cycle_limit=0,
306 | t_in_epochs=True,
307 | noise_range_t=None,
308 | noise_pct=0.67,
309 | noise_std=1.0,
310 | noise_seed=42,
311 | initialize=True) -> None:
312 | super().__init__(
313 | optimizer, param_group_field="lr",
314 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
315 | initialize=initialize)
316 |
317 | assert t_initial > 0
318 | assert lr_min >= 0
319 | if t_initial == 1 and t_mul == 1 and decay_rate == 1:
320 | _logger.warning("Cosine annealing scheduler will have no effect on the learning "
321 | "rate since t_initial = t_mul = eta_mul = 1.")
322 | self.t_initial = t_initial
323 | self.t_mul = t_mul
324 | self.lr_min = lr_min
325 | self.decay_rate = decay_rate
326 | self.cycle_limit = cycle_limit
327 | self.warmup_t = warmup_t
328 | self.warmup_lr_init = warmup_lr_init
329 | self.warmup_prefix = warmup_prefix
330 | self.t_in_epochs = t_in_epochs
331 | if self.warmup_t:
332 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
333 | super().update_groups(self.warmup_lr_init)
334 | else:
335 | self.warmup_steps = [1 for _ in self.base_values]
336 |
337 | def _get_lr(self, t):
338 | if t < self.warmup_t:
339 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
340 | else:
341 | if self.warmup_prefix:
342 | t = t - self.warmup_t
343 |
344 | if self.t_mul != 1:
345 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
346 | t_i = self.t_mul ** i * self.t_initial
347 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
348 | else:
349 | i = t // self.t_initial
350 | t_i = self.t_initial
351 | t_curr = t - (self.t_initial * i)
352 |
353 | gamma = self.decay_rate ** i
354 | lr_min = self.lr_min * gamma
355 | lr_max_values = [v * gamma for v in self.base_values]
356 |
357 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
358 | lrs = [
359 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
360 | ]
361 | else:
362 | lrs = [self.lr_min for _ in self.base_values]
363 |
364 | return lrs
365 |
366 | def get_epoch_values(self, epoch: int):
367 | if self.t_in_epochs:
368 | return self._get_lr(epoch)
369 | else:
370 | return None
371 |
372 | def get_update_values(self, num_updates: int):
373 | if not self.t_in_epochs:
374 | return self._get_lr(num_updates)
375 | else:
376 | return None
377 |
378 | def get_cycle_length(self, cycles=0):
379 | if not cycles:
380 | cycles = self.cycle_limit
381 | cycles = max(1, cycles)
382 | if self.t_mul == 1.0:
383 | return self.t_initial * cycles
384 | else:
385 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
386 |
387 |
388 |
389 |
--------------------------------------------------------------------------------
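A small sketch of wiring the optimizer and scheduler helpers above together (the toy model is a stand-in assumption; the training script passes the VID_Trans model instead): optimizer() builds per-parameter SGD groups at lr=0.008 (doubled for biases), and the cosine scheduler is stepped once per epoch.

import torch.nn as nn
from utility import optimizer as make_optimizer, scheduler as make_scheduler

toy_model = nn.Linear(10, 5)        # stand-in for VID_Trans
opt = make_optimizer(toy_model)     # SGD with lr and weight_decay set per parameter group
sched = make_scheduler(opt)         # 5 warmup epochs, then cosine decay over 120 epochs

for epoch in range(120):
    # ... one training epoch would run here ...
    sched.step(epoch)               # sets every param group's lr for this epoch
    if epoch in (0, 4, 60, 119):
        print(epoch, opt.param_groups[0]['lr'])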
/vit_ID.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | from functools import partial
4 | from itertools import repeat
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import collections.abc as container_abcs  # torch._six.container_abcs was removed in recent PyTorch releases
10 |
11 |
12 | # From PyTorch internals
13 | def _ntuple(n):
14 | def parse(x):
15 | if isinstance(x, container_abcs.Iterable):
16 | return x
17 | return tuple(repeat(x, n))
18 | return parse
19 |
20 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
21 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
22 | to_2tuple = _ntuple(2)
23 |
24 | def drop_path(x, drop_prob: float = 0., training: bool = False):
25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
26 |
27 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
28 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
29 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
30 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
31 | 'survival rate' as the argument.
32 |
33 | """
34 | if drop_prob == 0. or not training:
35 | return x
36 | keep_prob = 1 - drop_prob
37 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
38 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
39 | random_tensor.floor_() # binarize
40 | output = x.div(keep_prob) * random_tensor
41 | return output
42 |
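# Example: with drop_prob=0.1 the residual branch of a given sample is zeroed with
# probability 0.1 and scaled by 1/0.9 otherwise, so the expected output is unchanged.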
43 | class DropPath(nn.Module):
44 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
45 | """
46 | def __init__(self, drop_prob=None):
47 | super(DropPath, self).__init__()
48 | self.drop_prob = drop_prob
49 |
50 | def forward(self, x):
51 | return drop_path(x, self.drop_prob, self.training)
52 |
53 |
54 |
55 |
56 | class Mlp(nn.Module):
57 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
58 | super().__init__()
59 | out_features = out_features or in_features
60 | hidden_features = hidden_features or in_features
61 | self.fc1 = nn.Linear(in_features, hidden_features)
62 | self.act = act_layer()
63 | self.fc2 = nn.Linear(hidden_features, out_features)
64 | self.drop = nn.Dropout(drop)
65 |
66 | def forward(self, x):
67 | x = self.fc1(x)
68 | x = self.act(x)
69 | x = self.drop(x)
70 | x = self.fc2(x)
71 | x = self.drop(x)
72 | return x
73 |
74 |
75 | class Attention(nn.Module):
76 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
77 | super().__init__()
78 | self.num_heads = num_heads
79 | head_dim = dim // num_heads
80 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
81 | self.scale = qk_scale or head_dim ** -0.5
82 |
83 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
84 | self.attn_drop = nn.Dropout(attn_drop)
85 | self.proj = nn.Linear(dim, dim)
86 | self.proj_drop = nn.Dropout(proj_drop)
87 |
88 | def forward(self, x):
89 |
90 | B, N, C = x.shape
91 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
92 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
93 |
94 | attn = (q @ k.transpose(-2, -1)) * self.scale
95 | attn = attn.softmax(dim=-1)
96 | attn = self.attn_drop(attn)
97 |
98 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
99 | x = self.proj(x)
100 | x = self.proj_drop(x)
101 | return x
102 |
103 |
104 | class Block(nn.Module):
105 |
106 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
107 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
108 | super().__init__()
109 | self.norm1 = norm_layer(dim)
110 | self.attn = Attention(
111 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
112 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
113 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114 | self.norm2 = norm_layer(dim)
115 | mlp_hidden_dim = int(dim * mlp_ratio)
116 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
117 |
118 | def forward(self, x):
119 | x = x + self.drop_path(self.attn(self.norm1(x)))
120 | x = x + self.drop_path(self.mlp(self.norm2(x)))
121 | return x
122 |
123 |
124 | class PatchEmbed(nn.Module):
125 | """ Image to Patch Embedding
126 | """
127 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
128 | super().__init__()
129 | img_size = to_2tuple(img_size)
130 | patch_size = to_2tuple(patch_size)
131 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
132 | self.img_size = img_size
133 | self.patch_size = patch_size
134 | self.num_patches = num_patches
135 |
136 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
137 |
138 | def forward(self, x):
139 | B, C, H, W = x.shape
140 | # FIXME look at relaxing size constraints
141 | assert H == self.img_size[0] and W == self.img_size[1], \
142 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
143 | x = self.proj(x).flatten(2).transpose(1, 2)
144 | return x
145 |
146 |
147 |
148 |
149 |
150 | class PatchEmbed_overlap(nn.Module):
151 | """ Image to Patch Embedding with overlapping patches
152 | """
153 | def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768):
154 | super().__init__()
155 | img_size = to_2tuple(img_size)
156 | patch_size = to_2tuple(patch_size)
157 | stride_size_tuple = to_2tuple(stride_size)
158 | self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
159 | self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
160 | print('using stride: {}, patch grid: num_y={} * num_x={}'.format(stride_size, self.num_y, self.num_x))
161 | num_patches = self.num_x * self.num_y
162 | self.img_size = img_size
163 | self.patch_size = patch_size
164 | self.num_patches = num_patches
165 |
166 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size)
167 | for m in self.modules():
168 | if isinstance(m, nn.Conv2d):
169 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
170 | m.weight.data.normal_(0, math.sqrt(2. / n))
171 | elif isinstance(m, nn.BatchNorm2d):
172 | m.weight.data.fill_(1)
173 | m.bias.data.zero_()
174 | elif isinstance(m, nn.InstanceNorm2d):
175 | m.weight.data.fill_(1)
176 | m.bias.data.zero_()
177 |
178 | def forward(self, x):
179 | B, C, H, W = x.shape
180 |
181 | # FIXME look at relaxing size constraints
182 | assert H == self.img_size[0] and W == self.img_size[1], \
183 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
184 | x = self.proj(x)
185 |
186 | x = x.flatten(2).transpose(1, 2) # [64, 8, 768]
187 | return x
188 |
189 |
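# Patch-grid arithmetic for the configuration used in VID_Trans_model.py
# (img_size=[256, 128], patch_size=16, stride_size=[16, 16]):
#   num_y = (256 - 16) // 16 + 1 = 16 and num_x = (128 - 16) // 16 + 1 = 8,
#   i.e. 128 patch tokens per frame, plus the cls token = 129 tokens,
#   matching the [*, 129, 768] shape noted in the global branch.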
190 | class TransReID(nn.Module):
191 | """ Transformer-based Object Re-Identification
192 | """
193 | def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
194 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., camera=0,
195 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, cam_lambda =3.0):
196 | super().__init__()
197 | self.num_classes = num_classes
198 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
199 |
200 | self.cam_num = camera
201 | self.cam_lambda = cam_lambda
202 |
203 |
204 | self.patch_embed = PatchEmbed_overlap(img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans,embed_dim=embed_dim)
205 | num_patches = self.patch_embed.num_patches
206 |
207 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
208 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
209 | self.Cam = nn.Parameter(torch.zeros(camera, 1, embed_dim))
210 |
211 | trunc_normal_(self.Cam, std=.02)
212 | self.pos_drop = nn.Dropout(p=drop_rate)
213 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
214 |
215 | self.blocks = nn.ModuleList([
216 | Block(
217 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
218 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
219 | for i in range(depth)])
220 |
221 | self.norm = norm_layer(embed_dim)
222 |
223 | # Classifier head
224 | self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
225 | trunc_normal_(self.cls_token, std=.02)
226 | trunc_normal_(self.pos_embed, std=.02)
227 |
228 | self.apply(self._init_weights)
229 |
230 | def _init_weights(self, m):
231 | if isinstance(m, nn.Linear):
232 | trunc_normal_(m.weight, std=.02)
233 | if isinstance(m, nn.Linear) and m.bias is not None:
234 | nn.init.constant_(m.bias, 0)
235 | elif isinstance(m, nn.LayerNorm):
236 | nn.init.constant_(m.bias, 0)
237 | nn.init.constant_(m.weight, 1.0)
238 |
239 | @torch.jit.ignore
240 | def no_weight_decay(self):
241 | return {'pos_embed', 'cls_token'}
242 |
243 | def get_classifier(self):
244 | return self.fc
245 |
246 | def reset_classifier(self, num_classes, global_pool=''):
247 | self.num_classes = num_classes
248 | self.fc = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
249 |
250 | def forward_features(self, x, camera_id):
251 | B = x.shape[0]
252 |
253 | x = self.patch_embed(x)
254 |
255 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
256 | x = torch.cat((cls_tokens, x), dim=1)
257 | x = x + self.pos_embed + self.cam_lambda * self.Cam[camera_id]
258 | x = self.pos_drop(x)
259 |
260 | for blk in self.blocks[:-1]:
261 | x = blk(x)
262 | return x
263 |
264 |
265 | def forward(self, x, cam_label=None):
266 | x = self.forward_features(x, cam_label)
267 | return x
268 |
269 | def load_param(self, model_path,load=False):
270 | if not load:
271 | param_dict = torch.load(model_path, map_location='cpu')
272 | else:
273 | param_dict= model_path
274 | if 'model' in param_dict:
275 | param_dict = param_dict['model']
276 | if 'state_dict' in param_dict:
277 | param_dict = param_dict['state_dict']
278 | for k, v in param_dict.items():
279 | if 'head' in k or 'dist' in k:
280 | continue
281 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
282 | # For old models that I trained prior to conv based patchification
283 | O, I, H, W = self.patch_embed.proj.weight.shape
284 | v = v.reshape(O, -1, H, W)
285 | elif k == 'pos_embed' and v.shape != self.pos_embed.shape:
286 | # To resize pos embedding when using model at different size from pretrained weights
287 | if 'distilled' in model_path:
288 | print('distill need to choose right cls token in the pth')
289 | v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1)
290 | v = resize_pos_embed(v, self.pos_embed, self.patch_embed.num_y, self.patch_embed.num_x)
291 | try:
292 | self.state_dict()[k].copy_(v)
293 | except Exception:
294 | print('===========================ERROR=========================')
295 | print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape))
296 |
297 |
298 | def resize_pos_embed(posemb, posemb_new, height, width):
299 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
300 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
301 | ntok_new = posemb_new.shape[1]
302 |
303 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
304 | ntok_new -= 1
305 |
306 | gs_old = int(math.sqrt(len(posemb_grid)))
307 | print('Resized position embedding from size: {} to size: {} with height: {} width: {}'.format(posemb.shape, posemb_new.shape, height, width))
308 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
309 | posemb_grid = F.interpolate(posemb_grid, size=(height, width), mode='bilinear')
310 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, height * width, -1)
311 | posemb = torch.cat([posemb_token, posemb_grid], dim=1)
312 | return posemb
313 |
314 |
315 |
316 |
317 |
318 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
319 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
320 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
321 | def norm_cdf(x):
322 | # Computes standard normal cumulative distribution function
323 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
324 |
325 | if (mean < a - 2 * std) or (mean > b + 2 * std):
326 | print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
327 | "The distribution of values may be incorrect.",)
328 |
329 | with torch.no_grad():
330 | # Values are generated by using a truncated uniform distribution and
331 | # then using the inverse CDF for the normal distribution.
332 | # Get upper and lower cdf values
333 | l = norm_cdf((a - mean) / std)
334 | u = norm_cdf((b - mean) / std)
335 |
336 | # Uniformly fill tensor with values from [l, u], then translate to
337 | # [2l-1, 2u-1].
338 | tensor.uniform_(2 * l - 1, 2 * u - 1)
339 |
340 | # Use inverse cdf transform for normal distribution to get truncated
341 | # standard normal
342 | tensor.erfinv_()
343 |
344 | # Transform to proper mean, std
345 | tensor.mul_(std * math.sqrt(2.))
346 | tensor.add_(mean)
347 |
348 | # Clamp to ensure it's in the proper range
349 | tensor.clamp_(min=a, max=b)
350 | return tensor
351 |
352 |
353 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
354 | # type: (Tensor, float, float, float, float) -> Tensor
355 | r"""Fills the input Tensor with values drawn from a truncated
356 | normal distribution. The values are effectively drawn from the
357 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
358 | with values outside :math:`[a, b]` redrawn until they are within
359 | the bounds. The method used for generating the random values works
360 | best when :math:`a \leq \text{mean} \leq b`.
361 | Args:
362 | tensor: an n-dimensional `torch.Tensor`
363 | mean: the mean of the normal distribution
364 | std: the standard deviation of the normal distribution
365 | a: the minimum cutoff value
366 | b: the maximum cutoff value
367 | Examples:
368 | >>> w = torch.empty(3, 5)
369 | >>> nn.init.trunc_normal_(w)
370 | """
371 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
372 |
--------------------------------------------------------------------------------
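To make the position-embedding resizing concrete, the sketch below uses dummy tensors to mimic loading ViT-B/16 ImageNet weights (a 14x14 patch grid plus the cls token) into the 256x128 ReID configuration (a 16x8 grid), which is the path exercised by TransReID.load_param when the pos_embed shapes differ.

import torch
from vit_ID import resize_pos_embed

posemb_pretrained = torch.randn(1, 197, 768)   # 14*14 patches + cls token (ViT-B/16 at 224x224)
posemb_model = torch.zeros(1, 129, 768)        # 16*8 patches + cls token (256x128, stride 16)

resized = resize_pos_embed(posemb_pretrained, posemb_model, 16, 8)
print(resized.shape)                           # torch.Size([1, 129, 768])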