├── layers
│   ├── loss
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── JSD.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── am_softmax.cpython-37.pyc
│   │   │   ├── center_loss.cpython-37.pyc
│   │   │   ├── triplet_loss.cpython-37.pyc
│   │   │   ├── crossquad_loss.cpython-37.pyc
│   │   │   ├── mixtriplet_loss.cpython-37.pyc
│   │   │   ├── trapezoid_loss.cpython-37.pyc
│   │   │   ├── crosstriplet_loss.cpython-37.pyc
│   │   │   └── local_center_loss.cpython-37.pyc
│   │   ├── JSD.py
│   │   ├── am_softmax.py
│   │   ├── triplet_loss.py
│   │   ├── local_center_loss.py
│   │   └── center_loss.py
│   ├── module
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── CBAM.cpython-37.pyc
│   │   │   ├── NonLocal.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── norm_linear.cpython-37.pyc
│   │   │   └── reverse_grad.cpython-37.pyc
│   │   ├── norm_linear.py
│   │   ├── reverse_grad.py
│   │   ├── CBAM.py
│   │   └── NonLocal.py
│   ├── __pycache__
│   │   └── __init__.cpython-37.pyc
│   └── __init__.py
├── figs
│   └── backbone.png
├── data
│   ├── __pycache__
│   │   ├── clone.cpython-37.pyc
│   │   ├── dataset.cpython-36.pyc
│   │   ├── dataset.cpython-37.pyc
│   │   ├── sampler.cpython-36.pyc
│   │   ├── sampler.cpython-37.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── __init__.cpython-37.pyc
│   ├── __init__.py
│   ├── dataset.py
│   └── sampler.py
├── utils
│   ├── __pycache__
│   │   ├── rerank.cpython-37.pyc
│   │   ├── calc_acc.cpython-37.pyc
│   │   ├── neighbor.cpython-37.pyc
│   │   ├── eval_regdb.cpython-37.pyc
│   │   └── eval_sysu.cpython-37.pyc
│   ├── calc_acc.py
│   ├── tsne.py
│   ├── rerank.py
│   ├── eval_regdb.py
│   └── eval_sysu.py
├── engine
│   ├── __pycache__
│   │   ├── engine.cpython-37.pyc
│   │   ├── metric.cpython-37.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── __init__.cpython-37.pyc
│   ├── metric.py
│   ├── engine.py
│   └── __init__.py
├── models
│   ├── __pycache__
│   │   ├── resnet.cpython-36.pyc
│   │   ├── resnet.cpython-37.pyc
│   │   ├── baseline.cpython-36.pyc
│   │   └── baseline.cpython-37.pyc
│   ├── baseline.py
│   └── resnet.py
├── configs
│   ├── default
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── dataset.cpython-37.pyc
│   │   │   └── strategy.cpython-37.pyc
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── strategy.py
│   ├── RegDB.yml
│   └── SYSU.yml
├── LICENSE
├── README.md
└── train.py

--------------------------------------------------------------------------------
/layers/loss/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/layers/module/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/figs/backbone.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DoubtedSteam/MPANet/HEAD/figs/backbone.png

--------------------------------------------------------------------------------
/configs/default/__init__.py:
--------------------------------------------------------------------------------
from configs.default.dataset import dataset_cfg
from configs.default.strategy import strategy_cfg

__all__ = ["dataset_cfg", "strategy_cfg"]

--------------------------------------------------------------------------------
/configs/default/dataset.py:
--------------------------------------------------------------------------------
from yacs.config import CfgNode

dataset_cfg = CfgNode()

# config for dataset
dataset_cfg.sysu = CfgNode()
dataset_cfg.sysu.num_id = 395
dataset_cfg.sysu.num_cam = 6
dataset_cfg.sysu.data_root = "../dataset/SYSU-MM01"

dataset_cfg.regdb = CfgNode()
dataset_cfg.regdb.num_id = 206
dataset_cfg.regdb.num_cam = 2
dataset_cfg.regdb.data_root = "../dataset/RegDB"

dataset_cfg.market = CfgNode()
dataset_cfg.market.num_id = 751
dataset_cfg.market.num_cam = 6
dataset_cfg.market.data_root = "../dataset/market"

--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
from layers.loss.am_softmax import AMSoftmaxLoss
from layers.loss.center_loss import CenterLoss
from layers.loss.triplet_loss import TripletLoss
from layers.loss.local_center_loss import CenterTripletLoss
from layers.module.norm_linear import NormalizeLinear
from layers.module.reverse_grad import ReverseGrad
from layers.loss.JSD import js_div
from layers.module.CBAM import cbam
from layers.module.NonLocal import NonLocalBlockND


__all__ = ['CenterLoss', 'CenterTripletLoss', 'AMSoftmaxLoss', 'TripletLoss', 'NormalizeLinear', 'js_div', 'cbam', 'NonLocalBlockND']

--------------------------------------------------------------------------------
/layers/module/norm_linear.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


class NormalizeLinear(nn.Module):
    def __init__(self, in_features, num_class):
        super(NormalizeLinear, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(num_class, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        w = F.normalize(self.weight.float(), p=2, dim=1)
        return F.linear(x.float(), w)
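Note that NormalizeLinear only normalizes the class weights, so its outputs are cosine similarities exactly when the caller L2-normalizes the input features. A minimal sketch of that property (the shapes and the normalization step are illustrative assumptions, not taken from train.py):

```python
import torch
import torch.nn.functional as F
from layers import NormalizeLinear

head = NormalizeLinear(in_features=2048, num_class=395)  # 395 = SYSU-MM01 train ids
feats = F.normalize(torch.randn(4, 2048), dim=1)         # unit-length features
logits = head(feats)                                     # cosine similarities in [-1, 1]
assert logits.shape == (4, 395)
```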
--------------------------------------------------------------------------------
/layers/loss/JSD.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn import functional as F

class js_div:
    def __init__(self):
        self.KLDivLoss = nn.KLDivLoss(reduction='batchmean')

    def __call__(self, p_output, q_output, get_softmax=True):
        """
        Measures the JS divergence between target and output logits.
        """
        if get_softmax:
            p_output = F.softmax(p_output, 1)
            q_output = F.softmax(q_output, 1)
        log_mean_output = ((p_output + q_output) / 2).log()
        return (self.KLDivLoss(log_mean_output, p_output) + self.KLDivLoss(log_mean_output, q_output)) / 2

--------------------------------------------------------------------------------
/layers/module/reverse_grad.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.autograd import Function


class ReverseGradFunction(Function):

    @staticmethod
    def forward(ctx, data, alpha=1.0):
        ctx.alpha = alpha
        return data

    @staticmethod
    def backward(ctx, grad_outputs):
        grad = None

        if ctx.needs_input_grad[0]:
            grad = -ctx.alpha * grad_outputs

        return grad, None


class ReverseGrad(nn.Module):
    def __init__(self):
        super(ReverseGrad, self).__init__()

    def forward(self, x, alpha=1.0):
        return ReverseGradFunction.apply(x, alpha)

--------------------------------------------------------------------------------
/utils/calc_acc.py:
--------------------------------------------------------------------------------
import torch


def calc_acc(logits, label, ignore_index=-100, mode="multiclass"):
    if mode == "binary":
        indices = torch.round(logits).type(label.type())
    elif mode == "multiclass":
        indices = torch.max(logits, dim=1)[1]

    if label.size() == logits.size():
        ignore = 1 - torch.round(label.sum(dim=1))
        label = torch.max(label, dim=1)[1]
    else:
        ignore = torch.eq(label, ignore_index).view(-1)

    correct = torch.eq(indices, label).view(-1)
    num_correct = torch.sum(correct)
    num_examples = logits.shape[0] - ignore.sum()

    return num_correct.float() / num_examples.float()

--------------------------------------------------------------------------------
/layers/loss/am_softmax.py:
--------------------------------------------------------------------------------
import torch

import torch.nn as nn
import torch.nn.functional as F


class AMSoftmaxLoss(nn.Module):
    def __init__(self, scale, margin, weight=None, ignore_index=-100, reduction='mean'):
        super(AMSoftmaxLoss, self).__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.scale = scale
        self.margin = margin

    def forward(self, x, y):
        y_onehot = torch.zeros_like(x, device=x.device)
        y_onehot.scatter_(1, y.data.view(-1, 1), self.margin)

        out = self.scale * (x - y_onehot)
        loss = F.cross_entropy(out, y, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)

        return loss
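AMSoftmaxLoss subtracts the margin from the target-class score before scaling, which assumes the inputs are cosine logits; that is why it is typically paired with the NormalizeLinear head shown earlier. A minimal training-step sketch (the scale and margin values are illustrative, not taken from the configs):

```python
import torch
import torch.nn.functional as F
from layers import NormalizeLinear, AMSoftmaxLoss

head = NormalizeLinear(in_features=2048, num_class=395)
criterion = AMSoftmaxLoss(scale=30.0, margin=0.35)  # hypothetical hyper-parameters

feats = F.normalize(torch.randn(8, 2048), dim=1)
labels = torch.randint(0, 395, (8,))
loss = criterion(head(feats), labels)
loss.backward()
```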
--------------------------------------------------------------------------------
/configs/RegDB.yml:
--------------------------------------------------------------------------------
prefix: RegDB

fp16: true
rerank: false

# dataset
sample_method: identity_random
image_size: (256, 128)
p_size: 12
k_size: 10

dataset: regdb

# loss
classification: true
center_cluster: true
triplet: false
center: false

# parameters
margin: 0.7
# pattern attention
num_parts: 6
weight_sep: 0.5
# mutual learning
update_rate: 0.2
weight_sid: 0.5
weight_KL: 2.5

# architecture
drop_last_stride: true
pattern_attention: true
mutual_learning: true
modality_attention: 2

# optimizer
lr: 0.00035
optimizer: adam
num_epoch: 140
lr_step: [110]

# augmentation
random_flip: true
random_crop: true
random_erase: true
color_jitter: false
padding: 10

# log
log_period: 20
start_eval: 115
eval_interval: 5

--------------------------------------------------------------------------------
/configs/SYSU.yml:
--------------------------------------------------------------------------------
prefix: SYSU

fp16: true
rerank: false

# dataset
sample_method: identity_random
image_size: (384, 128)
p_size: 16
k_size: 8

dataset: sysu

# loss
classification: true
center_cluster: true
triplet: false
center: false

# parameters
margin: 0.7
# pattern attention
num_parts: 6
weight_sep: 0.5
# mutual learning
update_rate: 0.2
weight_sid: 0.5
weight_KL: 2.5

# architecture
drop_last_stride: true
pattern_attention: true
mutual_learning: true
modality_attention: 2

# optimizer
lr: 0.00035
optimizer: adam
num_epoch: 140
lr_step: [80, 120]

# augmentation
random_flip: true
random_crop: true
random_erase: true
color_jitter: false
padding: 10

# log
log_period: 150
start_eval: 115
eval_interval: 5
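These YAML files override the defaults defined in configs/default/strategy.py further below. A sketch of the usual yacs merge flow, run from the repo root (train.py is not part of this dump, so treating it as doing exactly this is an assumption):

```python
from configs.default import strategy_cfg

cfg = strategy_cfg.clone()               # start from the defaults
cfg.merge_from_file("configs/SYSU.yml")  # YAML values override matching keys
cfg.freeze()
print(cfg.lr, cfg.num_epoch, cfg.p_size)  # 0.00035 140 16
```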
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 DoubtedSteam

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Discover Cross-Modality Nuances for Visible-Infrared Person Re-Identification
[\[paper\]](https://openaccess.thecvf.com/content/CVPR2021/papers/Wu_Discover_Cross-Modality_Nuances_for_Visible-Infrared_Person_Re-Identification_CVPR_2021_paper.pdf)

This repository contains the PyTorch code for our proposed joint Modality and Pattern Alignment Network (MPANet).

![](figs/backbone.png)

## Requirements

PyTorch == 1.2.0

ignite == 0.2.1

torchvision == 0.4.0

apex == 0.1

## Quick start

1. Clone this repository:

```shell
git clone https://github.com/DoubtedSteam/MPANet.git
```

2. Modify the paths to the datasets:

The dataset paths can be modified in the following file:

```shell
./configs/default/dataset.py
```

3. Training:

To train the model, use one of the following commands:

SYSU-MM01:
```Shell
python train.py --cfg ./configs/SYSU.yml
```

RegDB:
```Shell
python train.py --cfg ./configs/RegDB.yml
```

## Trained model

The checkpoint can be found here:
https://pan.baidu.com/s/1TnjtfMFPnm5TEprgAhqz9A

Code: rfti

## Reference

[LuckyDC/RGB-IR-ReID-Baseline](https://github.com/LuckyDC/RGB-IR-ReID-Baseline)

--------------------------------------------------------------------------------
/layers/loss/triplet_loss.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

class TripletLoss(nn.Module):
    def __init__(self, margin=0):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        n = inputs.size(0)
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, inputs, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max())
            dist_an.append(dist[i][mask[i] == 0].min())
        dist_ap = torch.stack(dist_ap)
        dist_an = torch.stack(dist_an)

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss, dist_ap, dist_an
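A minimal usage sketch for the batch-hard TripletLoss above (the P×K batch layout mirrors the identity samplers used in this repo; the values are illustrative):

```python
import torch
from layers import TripletLoss

criterion = TripletLoss(margin=0.3)
feats = torch.randn(16, 2048)                  # P=4 identities, K=4 images each
labels = torch.arange(4).repeat_interleave(4)  # [0,0,0,0, 1,1,1,1, ...]
loss, dist_ap, dist_an = criterion(feats, labels)
```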
--------------------------------------------------------------------------------
/configs/default/strategy.py:
--------------------------------------------------------------------------------
from yacs.config import CfgNode

strategy_cfg = CfgNode()

strategy_cfg.prefix = "baseline"

# settings for loader
strategy_cfg.sample_method = "random"
strategy_cfg.batch_size = 128
strategy_cfg.p_size = 16
strategy_cfg.k_size = 8

# settings for loss
strategy_cfg.classification = True
strategy_cfg.triplet = False
strategy_cfg.center_cluster = False
strategy_cfg.center = False

# settings for metric learning
strategy_cfg.margin = 0.3
strategy_cfg.weight_KL = 3.0
strategy_cfg.weight_sid = 1.0
strategy_cfg.weight_sep = 1.0
strategy_cfg.update_rate = 1.0

# settings for optimizer
strategy_cfg.optimizer = "sgd"
strategy_cfg.lr = 0.1
strategy_cfg.wd = 5e-4
strategy_cfg.lr_step = [40]

strategy_cfg.fp16 = False

strategy_cfg.num_epoch = 60

# settings for dataset
strategy_cfg.dataset = "sysu"
strategy_cfg.image_size = (384, 128)

# settings for augmentation
strategy_cfg.random_flip = True
strategy_cfg.random_crop = True
strategy_cfg.random_erase = True
strategy_cfg.color_jitter = False
strategy_cfg.padding = 10

# settings for base architecture
strategy_cfg.drop_last_stride = False
strategy_cfg.pattern_attention = False
strategy_cfg.modality_attention = 0
strategy_cfg.mutual_learning = False
strategy_cfg.rerank = False
strategy_cfg.num_parts = 6

# logging
strategy_cfg.eval_interval = -1
strategy_cfg.start_eval = 60
strategy_cfg.log_period = 10

# testing
strategy_cfg.resume = ''

--------------------------------------------------------------------------------
/layers/loss/local_center_loss.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

class CenterTripletLoss(nn.Module):
    def __init__(self, k_size, margin=0):
        super(CenterTripletLoss, self).__init__()
        self.margin = margin
        self.k_size = k_size
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        n = inputs.size(0)

        # Compute the feature center of each sample's class
        centers = []
        for i in range(n):
            centers.append(inputs[targets == targets[i]].mean(0))
        centers = torch.stack(centers)

        dist_pc = (inputs - centers)**2
        dist_pc = dist_pc.sum(1)
        dist_pc = dist_pc.sqrt()

        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, centers, centers.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, push centers of other classes beyond the margin
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_an = []
        for i in range(0, n, self.k_size):
            dist_an.append((self.margin - dist[i][mask[i] == 0]).clamp(min=0.0).mean())
        dist_an = torch.stack(dist_an)

        loss = dist_pc.mean() + dist_an.mean()
        return loss, dist_pc.mean(), dist_an.mean()
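CenterTripletLoss (presumably what the center_cluster flag in the configs enables) pulls each feature toward its class center and pushes the centers of other classes beyond the margin; it assumes every consecutive k_size samples in the batch share an identity. A sketch with made-up shapes:

```python
import torch
from layers import CenterTripletLoss

criterion = CenterTripletLoss(k_size=4, margin=0.7)
feats = torch.randn(16, 2048)                  # 4 ids x k_size=4, grouped by identity
labels = torch.arange(4).repeat_interleave(4)
loss, dist_pc, dist_an = criterion(feats, labels)
```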
--------------------------------------------------------------------------------
/layers/loss/center_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, num_classes, feat_dim, reduction='mean'):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.reduction = reduction

        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).to(device=x.device, dtype=torch.long)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        loss = distmat * mask.float()

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss

--------------------------------------------------------------------------------
/layers/module/CBAM.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)


class cbam(nn.Module):
    def __init__(self, planes):
        super(cbam, self).__init__()
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x
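A quick sketch of applying the cbam block to a convolutional feature map (the tensor sizes are arbitrary):

```python
import torch
from layers import cbam

attn = cbam(planes=256)
x = torch.randn(2, 256, 24, 8)  # (batch, channels, height, width)
out = attn(x)                   # same shape, reweighted channel-wise then spatially
assert out.shape == x.shape
```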
--------------------------------------------------------------------------------
/utils/tsne.py:
--------------------------------------------------------------------------------
import os
import random
import numpy as np
import scipy.io as sio
import matplotlib as mpl

mpl.use('AGG')
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

if __name__ == '__main__':
    test_ids = [
        6, 10, 17, 21, 24, 25, 27, 28, 31, 34, 36, 37, 40, 41, 42, 43, 44, 45, 49, 50, 51, 54, 63, 69, 75, 80, 81, 82,
        83, 84, 85, 86, 87, 88, 89, 90, 93, 102, 104, 105, 106, 108, 112, 116, 117, 122, 125, 129, 130, 134, 138, 139,
        150, 152, 162, 166, 167, 170, 172, 176, 185, 190, 192, 202, 204, 207, 210, 215, 223, 229, 232, 237, 252, 253,
        257, 259, 263, 266, 269, 272, 273, 274, 275, 282, 285, 291, 300, 301, 302, 303, 307, 312, 315, 318, 331, 333
    ]
    random.seed(0)
    tsne = TSNE(n_components=2, init='pca')
    selected_ids = random.sample(test_ids, 20)
    plt.figure(figsize=(5, 5))

    # features without dual path
    q_mat_path = 'features/sysu/query-sysu-test-nodual-nore-adam-16x8-grey_model_150.mat'
    g_mat_path = 'features/sysu/gallery-sysu-test-nodual-nore-adam-16x8-grey_model_150.mat'

    mat = sio.loadmat(q_mat_path)
    q_feats = mat["feat"]
    q_ids = mat["ids"].squeeze()
    flag = np.in1d(q_ids, selected_ids)
    q_feats = q_feats[flag]

    mat = sio.loadmat(g_mat_path)
    g_feats = mat["feat"]
    g_ids = mat["ids"].squeeze()
    flag = np.in1d(g_ids, selected_ids)
    g_feats = g_feats[flag]

    embed = tsne.fit_transform(np.concatenate([q_feats, g_feats], axis=0))
    c = ['r'] * q_feats.shape[0] + ['b'] * g_feats.shape[0]
    # plt.subplot(1, 2, 1)
    plt.scatter(embed[:, 0], embed[:, 1], c=c)

    # # features with dual path
    # q_mat_path = 'features/sysu/query-sysu-test-dual-nore-separatelayer12-0.05_model_30.mat'
    # g_mat_path = 'features/sysu/gallery-sysu-test-dual-nore-separatelayer12-0.05_model_30.mat'
    #
    # mat = sio.loadmat(q_mat_path)
    # q_feats = mat["feat"]
    # q_ids = mat["ids"].squeeze()
    # flag = np.in1d(q_ids, selected_ids)
    # q_feats = q_feats[flag]
    #
    # mat = sio.loadmat(g_mat_path)
    # g_feats = mat["feat"]
    # g_ids = mat["ids"].squeeze()
    # flag = np.in1d(g_ids, selected_ids)
    # g_feats = g_feats[flag]
    #
    # embed = tsne.fit_transform(np.concatenate([q_feats, g_feats], axis=0))
    # c = ['r'] * q_feats.shape[0] + ['b'] * g_feats.shape[0]
    # plt.subplot(1, 2, 2)
    # plt.scatter(embed[:, 0], embed[:, 1], c=c)

    plt.tight_layout()
    plt.savefig('tsne-adv-layer2-separate-l2.jpg')
--------------------------------------------------------------------------------
/engine/metric.py:
--------------------------------------------------------------------------------
from collections import defaultdict

import torch
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric, Accuracy


class ScalarMetric(Metric):

    def update(self, value):
        self.sum_metric += value
        self.sum_inst += 1

    def reset(self):
        self.sum_inst = 0
        self.sum_metric = 0

    def compute(self):
        if self.sum_inst == 0:
            raise NotComputableError('ScalarMetric must have at least one example before it can be computed')
        return self.sum_metric / self.sum_inst


class IgnoreAccuracy(Accuracy):
    def __init__(self, ignore_index=-1):
        super(IgnoreAccuracy, self).__init__()

        self.ignore_index = ignore_index

    def reset(self):
        self._num_correct = 0
        self._num_examples = 0

    def update(self, output):

        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        if self._type == "binary":
            indices = torch.round(y_pred).type(y.type())
        elif self._type == "multiclass":
            indices = torch.max(y_pred, dim=1)[1]

        correct = torch.eq(indices, y).view(-1)
        ignore = torch.eq(y, self.ignore_index).view(-1)
        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0] - ignore.sum().item()

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('Accuracy must have at least one example before it can be computed')
        return self._num_correct / self._num_examples


class AutoKVMetric(Metric):
    def __init__(self):
        self.kv_sum_metric = defaultdict(lambda: torch.tensor(0., device="cuda"))
        self.kv_sum_inst = defaultdict(lambda: torch.tensor(0., device="cuda"))

        self.kv_metric = defaultdict(lambda: 0)

        super(AutoKVMetric, self).__init__()

    def update(self, output):
        if not isinstance(output, dict):
            raise TypeError('The output must be a key-value dict.')

        for k in output.keys():
            self.kv_sum_metric[k].add_(output[k])
            self.kv_sum_inst[k].add_(1)

    def reset(self):
        for k in self.kv_sum_metric.keys():
            self.kv_sum_metric[k].zero_()
            self.kv_sum_inst[k].zero_()
            self.kv_metric[k] = 0

    def compute(self):
        for k in self.kv_sum_metric.keys():
            if self.kv_sum_inst[k] == 0:
                continue

            metric_value = self.kv_sum_metric[k] / self.kv_sum_inst[k]
            self.kv_metric[k] = metric_value.item()

        return self.kv_metric
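AutoKVMetric keeps a running mean for every entry of the loss/metric dict returned by each training step. A standalone sketch (the metric accumulates on the GPU, so a CUDA device is assumed):

```python
import torch
from engine.metric import AutoKVMetric

metric = AutoKVMetric()
metric.reset()
for step in range(3):
    # each training iteration reports a dict of scalar tensors
    metric.update({"loss": torch.tensor(1.0 + step, device="cuda")})
print(metric.compute())  # -> {'loss': 2.0}, the running mean per key
```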
--------------------------------------------------------------------------------
/engine/engine.py:
--------------------------------------------------------------------------------
import torch
import numpy as np

from apex import amp
from ignite.engine import Engine
from ignite.engine import Events
from torch.autograd import no_grad
from utils.calc_acc import calc_acc
from torch.nn import functional as F


def create_train_engine(model, optimizer, non_blocking=False):
    device = torch.device("cuda", torch.cuda.current_device())

    def _process_func(engine, batch):
        model.train()

        data, labels, cam_ids, img_paths, img_ids = batch
        epoch = engine.state.epoch

        data = data.to(device, non_blocking=non_blocking)
        labels = labels.to(device, non_blocking=non_blocking)
        cam_ids = cam_ids.to(device, non_blocking=non_blocking)

        optimizer.zero_grad()

        loss, metric = model(data, labels,
                             cam_ids=cam_ids,
                             epoch=epoch)

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        optimizer.step()

        return metric

    return Engine(_process_func)


def create_eval_engine(model, non_blocking=False):
    device = torch.device("cuda", torch.cuda.current_device())

    def _process_func(engine, batch):
        model.eval()

        data, labels, cam_ids, img_paths = batch[:4]

        data = data.to(device, non_blocking=non_blocking)

        with no_grad():
            feat = model(data, cam_ids=cam_ids.to(device, non_blocking=non_blocking))

        return feat.data.float().cpu(), labels, cam_ids, np.array(img_paths)

    engine = Engine(_process_func)

    @engine.on(Events.EPOCH_STARTED)
    def clear_data(engine):
        # feat list
        if not hasattr(engine.state, "feat_list"):
            setattr(engine.state, "feat_list", [])
        else:
            engine.state.feat_list.clear()

        # id list
        if not hasattr(engine.state, "id_list"):
            setattr(engine.state, "id_list", [])
        else:
            engine.state.id_list.clear()

        # cam list
        if not hasattr(engine.state, "cam_list"):
            setattr(engine.state, "cam_list", [])
        else:
            engine.state.cam_list.clear()

        # img path list
        if not hasattr(engine.state, "img_path_list"):
            setattr(engine.state, "img_path_list", [])
        else:
            engine.state.img_path_list.clear()

    @engine.on(Events.ITERATION_COMPLETED)
    def store_data(engine):
        engine.state.feat_list.append(engine.state.output[0])
        engine.state.id_list.append(engine.state.output[1])
        engine.state.cam_list.append(engine.state.output[2])
        engine.state.img_path_list.append(engine.state.output[3])

    return engine
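A tiny sketch of the eval engine's contract (ToyModel is a stand-in for the real baseline, which lives in models/baseline.py; a CUDA device is required, and the batch layout follows collate_fn in data/__init__.py):

```python
import torch
from torch import nn
from engine.engine import create_eval_engine

class ToyModel(nn.Module):
    # forward(data, cam_ids=...) -> feature tensor, matching _process_func above
    def forward(self, x, cam_ids=None):
        return x.mean(dim=(2, 3))  # (B, C) "features"

model = ToyModel().cuda()
evaluator = create_eval_engine(model)
# each batch: (images, labels, cam_ids, img_paths)
batch = (torch.randn(2, 3, 8, 4), torch.tensor([0, 1]), torch.tensor([1, 2]), ("a.jpg", "b.jpg"))
evaluator.run([batch], max_epochs=1)
print(torch.cat(evaluator.state.feat_list).shape)  # torch.Size([2, 3])
```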
--------------------------------------------------------------------------------
/layers/module/NonLocal.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F


class NonLocalBlockND(nn.Module):
    """
    Usage:
        NONLocalBlock2D(in_channels=32)
    which forwards to:
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels,
                         out_channels=self.inter_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels,
                        out_channels=self.in_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0), bn(self.in_channels))
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels,
                             out_channels=self.in_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels,
                             out_channels=self.inter_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
        self.phi = conv_nd(in_channels=self.in_channels,
                           out_channels=self.inter_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, h, w)
        :return:
        '''

        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)  # [bs, c, w*h]
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        f = torch.matmul(theta_x, phi_x)

        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z
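A shape-preserving usage sketch for the 2D variant (the sizes are arbitrary):

```python
import torch
from layers import NonLocalBlockND

block = NonLocalBlockND(in_channels=256, dimension=2)  # 2D variant for CNN feature maps
x = torch.randn(2, 256, 24, 8)
y = block(x)  # residual output: z = W(attention(x)) + x
assert y.shape == x.shape
```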
--------------------------------------------------------------------------------
/utils/rerank.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


def k_reciprocal_neigh(initial_rank, i, k1):
    forward_k_neigh_index = initial_rank[i, :k1 + 1]
    backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
    fi = np.where(backward_k_neigh_index == i)[0]
    return forward_k_neigh_index[fi]


def pairwise_distance(query_features, gallery_features):
    x = query_features
    y = gallery_features
    m, n = x.size(0), y.size(0)
    x = x.view(m, -1)
    y = y.view(n, -1)
    dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
           torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist.addmm_(1, -2, x, y.t())
    return dist


def re_ranking(q_feat, g_feat, k1=20, k2=6, lambda_value=0.3, eval_type=True):
    # The following naming, e.g. gallery_num, is different from outer scope.
    # Don't care about it.
    feats = torch.cat([q_feat, g_feat], 0)
    dist = pairwise_distance(feats, feats)
    original_dist = dist.clone().numpy()
    all_num = original_dist.shape[0]
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)

    query_num = q_feat.size(0)
    all_num = original_dist.shape[0]
    if eval_type:
        dist[:, query_num:] = dist.max()
    dist = dist.numpy()
    initial_rank = np.argsort(dist).astype(np.int32)

    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
                                               :int(np.around(k1 / 2)) + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    invIndex = []
    for i in range(all_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)

    for i in range(query_num):
        temp_min = np.zeros(shape=[1, all_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                               V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:query_num, query_num:]
    return final_dist
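re_ranking consumes raw query/gallery features and returns a (num_query, num_gallery) distance matrix blending the euclidean and Jaccard distances; smaller values mean better matches. A small sketch (the feature sizes are arbitrary):

```python
import torch
from utils.rerank import re_ranking, pairwise_distance

q_feat, g_feat = torch.randn(4, 256), torch.randn(12, 256)
plain = pairwise_distance(q_feat, g_feat)                             # (4, 12) euclidean
reranked = re_ranking(q_feat, g_feat, k1=20, k2=6, lambda_value=0.3)  # (4, 12) blended
```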
--------------------------------------------------------------------------------
/utils/eval_regdb.py:
--------------------------------------------------------------------------------
import os
import logging
import numpy as np
import torch
from sklearn.preprocessing import normalize
from torch.nn import functional as F
from .rerank import re_ranking, pairwise_distance


def get_gallery_names(perm, cams, ids, trial_id, num_shots=1):
    names = []
    for cam in cams:
        cam_perm = perm[cam - 1][0].squeeze()
        for i in ids:
            instance_id = cam_perm[i - 1][trial_id][:num_shots]
            names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()])

    return names


def get_unique(array):
    _, idx = np.unique(array, return_index=True)
    return array[np.sort(idx)]


def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
    gallery_unique_count = get_unique(gallery_ids).shape[0]
    match_counter = np.zeros((gallery_unique_count,))

    result = gallery_ids[sorted_indices]
    cam_locations_result = gallery_cam_ids[sorted_indices]

    valid_probe_sample_count = 0

    for probe_index in range(sorted_indices.shape[0]):
        # remove gallery samples from the same camera as the probe
        result_i = result[probe_index, :]
        result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1

        # remove the -1 entries from the label result
        result_i = np.array([i for i in result_i if i != -1])

        # remove duplicated ids in a "stable" manner
        result_i_unique = get_unique(result_i)

        # match for probe i
        match_i = np.equal(result_i_unique, query_ids[probe_index])

        if np.sum(match_i) != 0:  # if there is a true match in the gallery
            valid_probe_sample_count += 1
            match_counter += match_i

    rank = match_counter / valid_probe_sample_count
    cmc = np.cumsum(rank)
    return cmc


def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
    result = gallery_ids[sorted_indices]
    cam_locations_result = gallery_cam_ids[sorted_indices]

    valid_probe_sample_count = 0
    avg_precision_sum = 0

    for probe_index in range(sorted_indices.shape[0]):
        # remove gallery samples from the same camera as the probe
        result_i = result[probe_index, :]
        result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1

        # remove the -1 entries from the label result
        result_i = np.array([i for i in result_i if i != -1])

        # match for probe i
        match_i = result_i == query_ids[probe_index]
        true_match_count = np.sum(match_i)

        if true_match_count != 0:  # if there is a true match in the gallery
            valid_probe_sample_count += 1
            true_match_rank = np.where(match_i)[0]

            ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1))
            avg_precision_sum += ap

    mAP = avg_precision_sum / valid_probe_sample_count
    return mAP

def eval_regdb(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, rerank=False):
    # gallery_feats = F.normalize(gallery_feats, dim=1)
    # query_feats = F.normalize(query_feats, dim=1)

    if rerank:
        dist_mat = re_ranking(query_feats, gallery_feats, eval_type=False)
    else:
        dist_mat = pairwise_distance(query_feats, gallery_feats)
        # dist_mat = -torch.mm(query_feats, gallery_feats.t())

    sorted_indices = np.argsort(dist_mat, axis=1)

    mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
    cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)

    r1 = cmc[0]
    r5 = cmc[4]
    r10 = cmc[9]
    r20 = cmc[19]

    r1 = r1 * 100
    r5 = r5 * 100
    r10 = r10 * 100
    r20 = r20 * 100
    mAP = mAP * 100

    perf = 'r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f}, mAP = {:.2f}'
    logging.info(perf.format(r1, r10, r20, mAP))

    return mAP, r1, r5, r10, r20
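The per-probe AP in get_mAP averages precision at each true-match rank. A tiny worked example (the ranking is invented):

```python
import numpy as np

# one probe with two true matches, ranked 1st and 4th (0-indexed ranks 0 and 3)
true_match_rank = np.array([0, 3])
true_match_count = len(true_match_rank)
ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1))
print(ap)  # (1/1 + 2/4) / 2 = 0.75
```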
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
import os

import torch
import torchvision.transforms as T

from torch.utils.data import DataLoader
from data.dataset import SYSUDataset
from data.dataset import RegDBDataset
from data.dataset import MarketDataset

from data.sampler import CrossModalityIdentitySampler
from data.sampler import CrossModalityRandomSampler
from data.sampler import RandomIdentitySampler
from data.sampler import NormTripletSampler


def collate_fn(batch):  # img, label, cam_id, img_path, img_id
    samples = list(zip(*batch))

    data = [torch.stack(x, 0) for i, x in enumerate(samples) if i != 3]
    data.insert(3, samples[3])
    return data


def get_train_loader(dataset, root, sample_method, batch_size, p_size, k_size, image_size, random_flip=False, random_crop=False,
                     random_erase=False, color_jitter=False, padding=0, num_workers=4):
    # data pre-processing
    t = [T.Resize(image_size)]

    if random_flip:
        t.append(T.RandomHorizontalFlip())

    if color_jitter:
        t.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0))

    if random_crop:
        t.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)])

    t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    if random_erase:
        t.append(T.RandomErasing())
        # t.append(Jigsaw())

    transform = T.Compose(t)

    # dataset
    if dataset == 'sysu':
        train_dataset = SYSUDataset(root, mode='train', transform=transform)
    elif dataset == 'regdb':
        train_dataset = RegDBDataset(root, mode='train', transform=transform)
    elif dataset == 'market':
        train_dataset = MarketDataset(root, mode='train', transform=transform)

    # sampler
    assert sample_method in ['random', 'identity_uniform', 'identity_random', 'norm_triplet']
    if sample_method == 'identity_uniform':
        batch_size = p_size * k_size
        sampler = CrossModalityIdentitySampler(train_dataset, p_size, k_size)
    elif sample_method == 'identity_random':
        batch_size = p_size * k_size
        sampler = RandomIdentitySampler(train_dataset, p_size * k_size, k_size)
    elif sample_method == 'norm_triplet':
        batch_size = p_size * k_size
        sampler = NormTripletSampler(train_dataset, p_size * k_size, k_size)
    else:
        sampler = CrossModalityRandomSampler(train_dataset, batch_size)

    # loader
    train_loader = DataLoader(train_dataset, batch_size, sampler=sampler, drop_last=True, pin_memory=True,
                              collate_fn=collate_fn, num_workers=num_workers)

    return train_loader


def get_test_loader(dataset, root, batch_size, image_size, num_workers=4):
    # transform
    transform = T.Compose([
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # dataset
    if dataset == 'sysu':
        gallery_dataset = SYSUDataset(root, mode='gallery', transform=transform)
        query_dataset = SYSUDataset(root, mode='query', transform=transform)
    elif dataset == 'regdb':
        gallery_dataset = RegDBDataset(root, mode='gallery', transform=transform)
        query_dataset = RegDBDataset(root, mode='query', transform=transform)
    elif dataset == 'market':
        gallery_dataset = MarketDataset(root, mode='gallery', transform=transform)
        query_dataset = MarketDataset(root, mode='query', transform=transform)

    # dataloader
    query_loader = DataLoader(dataset=query_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              pin_memory=True,
                              drop_last=False,
                              collate_fn=collate_fn,
                              num_workers=num_workers)

    gallery_loader = DataLoader(dataset=gallery_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                pin_memory=True,
                                drop_last=False,
                                collate_fn=collate_fn,
                                num_workers=num_workers)

    return gallery_loader, query_loader
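A sketch of building the loaders with values from the SYSU config (the data root follows configs/default/dataset.py and assumes the dataset is in place; the test batch size is an assumption). Note that with an identity-based sample_method, the effective batch size is recomputed internally as p_size * k_size:

```python
from data import get_train_loader, get_test_loader

train_loader = get_train_loader(dataset='sysu', root='../dataset/SYSU-MM01',
                                sample_method='identity_random',
                                batch_size=128, p_size=16, k_size=8,
                                image_size=(384, 128),
                                random_flip=True, random_crop=True,
                                random_erase=True, padding=10)

gallery_loader, query_loader = get_test_loader('sysu', '../dataset/SYSU-MM01',
                                               batch_size=512, image_size=(384, 128))
```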
39 | 40 | # remove the -1 entries from the label result 41 | result_i = np.array([i for i in result_i if i != -1]) 42 | 43 | # remove duplicated id in "stable" manner 44 | result_i_unique = get_unique(result_i) 45 | 46 | # match for probe i 47 | match_i = np.equal(result_i_unique, query_ids[probe_index]) 48 | 49 | if np.sum(match_i) != 0: # if there is true matching in gallery 50 | valid_probe_sample_count += 1 51 | match_counter += match_i 52 | 53 | rank = match_counter / valid_probe_sample_count 54 | cmc = np.cumsum(rank) 55 | return cmc 56 | 57 | 58 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 59 | result = gallery_ids[sorted_indices] 60 | cam_locations_result = gallery_cam_ids[sorted_indices] 61 | 62 | valid_probe_sample_count = 0 63 | avg_precision_sum = 0 64 | 65 | for probe_index in range(sorted_indices.shape[0]): 66 | # remove gallery samples from the same camera of the probe 67 | result_i = result[probe_index, :] 68 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1 69 | 70 | # remove the -1 entries from the label result 71 | result_i = np.array([i for i in result_i if i != -1]) 72 | 73 | # match for probe i 74 | match_i = result_i == query_ids[probe_index] 75 | true_match_count = np.sum(match_i) 76 | 77 | if true_match_count != 0: # if there is true matching in gallery 78 | valid_probe_sample_count += 1 79 | true_match_rank = np.where(match_i)[0] 80 | 81 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1)) 82 | avg_precision_sum += ap 83 | 84 | mAP = avg_precision_sum / valid_probe_sample_count 85 | return mAP 86 | 87 | 88 | def eval_sysu(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, 89 | perm, mode='all', num_shots=1, num_trials=10, rerank=False): 90 | assert mode in ['indoor', 'all'] 91 | 92 | gallery_cams = [1, 2] if mode == 'indoor' else [1, 2, 4, 5] 93 | 94 | # cam2 and cam3 are in the same location 95 | query_cam_ids[np.equal(query_cam_ids, 3)] = 2 96 | query_feats = F.normalize(query_feats, dim=1) 97 | 98 | gallery_indices = np.in1d(gallery_cam_ids, gallery_cams) 99 | 100 | gallery_feats = gallery_feats[gallery_indices] 101 | gallery_feats = F.normalize(gallery_feats, dim=1) 102 | gallery_cam_ids = gallery_cam_ids[gallery_indices] 103 | gallery_ids = gallery_ids[gallery_indices] 104 | gallery_img_paths = gallery_img_paths[gallery_indices] 105 | gallery_names = np.array(['/'.join(os.path.splitext(path)[0].split('/')[-3:]) for path in gallery_img_paths]) 106 | 107 | gallery_id_set = np.unique(gallery_ids) 108 | 109 | mAP, r1, r5, r10, r20 = 0, 0, 0, 0, 0 110 | for t in range(num_trials): 111 | names = get_gallery_names(perm, gallery_cams, gallery_id_set, t, num_shots) 112 | flag = np.in1d(gallery_names, names) 113 | 114 | g_feat = gallery_feats[flag] 115 | g_ids = gallery_ids[flag] 116 | g_cam_ids = gallery_cam_ids[flag] 117 | 118 | if rerank: 119 | dist_mat = re_ranking(query_feats, g_feat) 120 | else: 121 | dist_mat = pairwise_distance(query_feats, g_feat) 122 | # dist_mat = -torch.mm(query_feats, g_feat.permute(1,0)) 123 | 124 | sorted_indices = np.argsort(dist_mat, axis=1) 125 | 126 | mAP += get_mAP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 127 | cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 128 | 129 | r1 += cmc[0] 130 | r5 += cmc[4] 131 | r10 += cmc[9] 132 | r20 += cmc[19] 133 | 134 | r1 = r1 / num_trials * 100 135 | r5 = r5 / num_trials * 100 136 | r10 = r10 / num_trials 
* 100 137 | r20 = r20 / num_trials * 100 138 | mAP = mAP / num_trials * 100 139 | 140 | perf = '{} num-shot:{} r1 precision = {:.2f}, r5 precision = {:.2f}, r10 precision = {:.2f}, r20 precision = {:.2f}, mAP = {:.2f}' 141 | logging.info(perf.format(mode, num_shots, r1, r5, r10, r20, mAP)) 142 | 143 | return mAP, r1, r5, r10, r20 144 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pprint 4 | 5 | import torch 6 | import yaml 7 | from apex import amp 8 | from torch import optim 9 | 10 | from data import get_test_loader 11 | from data import get_train_loader 12 | from engine import get_trainer 13 | from models.baseline import Baseline 14 | 15 | # from WarmUpLR import WarmUpStepLR 16 | 17 | 18 | def train(cfg): 19 | # set logger 20 | log_dir = os.path.join("logs/", cfg.dataset, cfg.prefix) 21 | if not os.path.isdir(log_dir): 22 | os.makedirs(log_dir, exist_ok=True) 23 | 24 | logging.basicConfig(format="%(asctime)s %(message)s", 25 | filename=log_dir + "/" + "log.txt", 26 | filemode="w") 27 | 28 | logger = logging.getLogger() 29 | logger.setLevel(logging.INFO) 30 | stream_handler = logging.StreamHandler() 31 | stream_handler.setLevel(logging.INFO) 32 | logger.addHandler(stream_handler) 33 | 34 | logger.info(pprint.pformat(cfg)) 35 | 36 | # training data loader 37 | train_loader = get_train_loader(dataset=cfg.dataset, 38 | root=cfg.data_root, 39 | sample_method=cfg.sample_method, 40 | batch_size=cfg.batch_size, 41 | p_size=cfg.p_size, 42 | k_size=cfg.k_size, 43 | random_flip=cfg.random_flip, 44 | random_crop=cfg.random_crop, 45 | random_erase=cfg.random_erase, 46 | color_jitter=cfg.color_jitter, 47 | padding=cfg.padding, 48 | image_size=cfg.image_size, 49 | num_workers=8) 50 | 51 | # evaluation data loader 52 | gallery_loader, query_loader = None, None 53 | if cfg.eval_interval > 0: 54 | gallery_loader, query_loader = get_test_loader(dataset=cfg.dataset, 55 | root=cfg.data_root, 56 | batch_size=64, 57 | image_size=cfg.image_size, 58 | num_workers=4) 59 | 60 | # model 61 | model = Baseline(num_classes=cfg.num_id, 62 | pattern_attention=cfg.pattern_attention, 63 | modality_attention=cfg.modality_attention, 64 | mutual_learning=cfg.mutual_learning, 65 | drop_last_stride=cfg.drop_last_stride, 66 | triplet=cfg.triplet, 67 | k_size=cfg.k_size, 68 | center_cluster=cfg.center_cluster, 69 | center=cfg.center, 70 | margin=cfg.margin, 71 | num_parts=cfg.num_parts, 72 | weight_KL=cfg.weight_KL, 73 | weight_sid=cfg.weight_sid, 74 | weight_sep=cfg.weight_sep, 75 | update_rate=cfg.update_rate, 76 | classification=cfg.classification) 77 | 78 | def get_parameter_number(net): 79 | total_num = sum(p.numel() for p in net.parameters()) 80 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 81 | return {'Total': total_num, 'Trainable': trainable_num} 82 | 83 | print(get_parameter_number(model)) 84 | 85 | model.cuda() 86 | 87 | # optimizer 88 | assert cfg.optimizer in ['adam', 'sgd'] 89 | if cfg.optimizer == 'adam': 90 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) 91 | else: 92 | optimizer = optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=cfg.wd) 93 | 94 | # convert model for mixed precision training 95 | model, optimizer = amp.initialize(model, optimizer, enabled=cfg.fp16, opt_level="O1") 96 | if cfg.center: 97 | model.center_loss.centers = model.center_loss.centers.float() 98 | lr_scheduler =
optim.lr_scheduler.MultiStepLR(optimizer=optimizer, 99 | milestones=cfg.lr_step, 100 | gamma=0.1) 101 | 102 | if cfg.resume: 103 | checkpoint = torch.load(cfg.resume) 104 | model.load_state_dict(checkpoint) 105 | 106 | # engine 107 | checkpoint_dir = os.path.join("checkpoints", cfg.dataset, cfg.prefix) 108 | engine = get_trainer(dataset=cfg.dataset, 109 | model=model, 110 | optimizer=optimizer, 111 | lr_scheduler=lr_scheduler, 112 | logger=logger, 113 | non_blocking=True, 114 | log_period=cfg.log_period, 115 | save_dir=checkpoint_dir, 116 | prefix=cfg.prefix, 117 | eval_interval=cfg.eval_interval, 118 | start_eval=cfg.start_eval, 119 | gallery_loader=gallery_loader, 120 | query_loader=query_loader, 121 | rerank=cfg.rerank) 122 | 123 | # training 124 | engine.run(train_loader, max_epochs=cfg.num_epoch) 125 | 126 | 127 | if __name__ == '__main__': 128 | import argparse 129 | import random 130 | import numpy as np 131 | from configs.default import strategy_cfg 132 | from configs.default import dataset_cfg 133 | 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument("--cfg", type=str, default="configs/SYSU.yml") 136 | args = parser.parse_args() 137 | 138 | # set random seed 139 | seed = 1 140 | random.seed(seed) 141 | # seed the global numpy and torch RNGs (a bare np.random.RandomState(seed) here would be a no-op) 142 | np.random.seed(seed) 143 | torch.manual_seed(seed) 144 | torch.cuda.manual_seed(seed) 145 | 146 | # enable cudnn backend 147 | torch.backends.cudnn.benchmark = True 148 | # torch.backends.cudnn.benchmark = False 149 | # torch.backends.cudnn.deterministic = True 150 | 151 | # load configuration 152 | customized_cfg = yaml.load(open(args.cfg, "r"), Loader=yaml.SafeLoader) 153 | 154 | cfg = strategy_cfg 155 | cfg.merge_from_file(args.cfg) 156 | 157 | dataset_cfg = dataset_cfg.get(cfg.dataset) 158 | 159 | for k, v in dataset_cfg.items(): 160 | cfg[k] = v 161 | 162 | if cfg.sample_method == 'identity_uniform': 163 | cfg.batch_size = cfg.p_size * cfg.k_size 164 | 165 | cfg.freeze() 166 | 167 | train(cfg) 168 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import os.path as osp 4 | from glob import glob 5 | 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | ''' 11 | Dataset classes for the person re-identification datasets used here (SYSU-MM01, RegDB and Market-1501).
12 | ''' 13 | 14 | 15 | class SYSUDataset(Dataset): 16 | def __init__(self, root, mode='train', transform=None): 17 | assert os.path.isdir(root) 18 | assert mode in ['train', 'gallery', 'query'] 19 | 20 | if mode == 'train': 21 | train_ids = open(os.path.join(root, 'exp', 'train_id.txt')).readline() 22 | val_ids = open(os.path.join(root, 'exp', 'val_id.txt')).readline() 23 | 24 | train_ids = train_ids.strip('\n').split(',') 25 | val_ids = val_ids.strip('\n').split(',') 26 | selected_ids = train_ids + val_ids 27 | else: 28 | test_ids = open(os.path.join(root, 'exp', 'test_id.txt')).readline() 29 | selected_ids = test_ids.strip('\n').split(',') 30 | 31 | selected_ids = [int(i) for i in selected_ids] 32 | num_ids = len(selected_ids) 33 | 34 | img_paths = glob(os.path.join(root, '**/*.jpg'), recursive=True) 35 | img_paths = [path for path in img_paths if int(path.split('/')[-2]) in selected_ids] 36 | 37 | if mode == 'gallery': 38 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (1, 2, 4, 5)] 39 | elif mode == 'query': 40 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (3, 6)] 41 | 42 | img_paths = sorted(img_paths) 43 | self.img_paths = img_paths 44 | self.cam_ids = [int(path.split('/')[-3][-1]) for path in img_paths] 45 | self.num_ids = num_ids 46 | self.transform = transform 47 | 48 | if mode == 'train': 49 | id_map = dict(zip(selected_ids, range(num_ids))) 50 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths] 51 | else: 52 | self.ids = [int(path.split('/')[-2]) for path in img_paths] 53 | 54 | def __len__(self): 55 | return len(self.img_paths) 56 | 57 | def __getitem__(self, item): 58 | path = self.img_paths[item] 59 | img = Image.open(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | 63 | label = torch.tensor(self.ids[item], dtype=torch.long) 64 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 65 | item = torch.tensor(item, dtype=torch.long) 66 | 67 | return img, label, cam, path, item 68 | 69 | class RegDBDataset(Dataset): 70 | def __init__(self, root, mode='train', transform=None): 71 | assert os.path.isdir(root) 72 | assert mode in ['train', 'gallery', 'query'] 73 | 74 | def loadIdx(index): 75 | Lines = index.readlines() 76 | idx = [] 77 | for line in Lines: 78 | tmp = line.strip('\n') 79 | tmp = tmp.split(' ') 80 | idx.append(tmp) 81 | return idx 82 | 83 | num = '1' 84 | if mode == 'train': 85 | index_RGB = loadIdx(open(root + '/idx/train_visible_'+num+'.txt','r')) 86 | index_IR = loadIdx(open(root + '/idx/train_thermal_'+num+'.txt','r')) 87 | else: 88 | index_RGB = loadIdx(open(root + '/idx/test_visible_'+num+'.txt','r')) 89 | index_IR = loadIdx(open(root + '/idx/test_thermal_'+num+'.txt','r')) 90 | 91 | if mode == 'gallery': 92 | img_paths = [root + '/' + path for path, _ in index_RGB] 93 | elif mode == 'query': 94 | img_paths = [root + '/' + path for path, _ in index_IR] 95 | else: 96 | img_paths = [root + '/' + path for path, _ in index_RGB] + [root + '/' + path for path, _ in index_IR] 97 | 98 | selected_ids = [int(path.split('/')[-2]) for path in img_paths] 99 | selected_ids = list(set(selected_ids)) 100 | num_ids = len(selected_ids) 101 | 102 | img_paths = sorted(img_paths) 103 | self.img_paths = img_paths 104 | self.cam_ids = [int(path.split('/')[-3] == 'Thermal') + 2 for path in img_paths] 105 | # the visible cams are 1 2 4 5 and thermal cams are 3 6 in sysu 106 | # to simplify the code, visible cam is 2 and thermal cam is 3 in regdb 107 | self.num_ids = num_ids 
108 | self.transform = transform 109 | 110 | if mode == 'train': 111 | id_map = dict(zip(selected_ids, range(num_ids))) 112 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths] 113 | else: 114 | self.ids = [int(path.split('/')[-2]) for path in img_paths] 115 | 116 | def __len__(self): 117 | return len(self.img_paths) 118 | 119 | def __getitem__(self, item): 120 | path = self.img_paths[item] 121 | img = Image.open(path) 122 | if self.transform is not None: 123 | img = self.transform(img) 124 | 125 | label = torch.tensor(self.ids[item], dtype=torch.long) 126 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 127 | item = torch.tensor(item, dtype=torch.long) 128 | 129 | return img, label, cam, path, item 130 | 131 | 132 | class MarketDataset(Dataset): 133 | def __init__(self, root, mode='train', transform=None): 134 | assert os.path.isdir(root) 135 | assert mode in ['train', 'gallery', 'query'] 136 | 137 | self.transform = transform 138 | 139 | if mode == 'train': 140 | img_paths = glob(os.path.join(root, 'bounding_box_train/*.jpg'), recursive=True) 141 | elif mode == 'gallery': 142 | img_paths = glob(os.path.join(root, 'bounding_box_test/*.jpg'), recursive=True) 143 | elif mode == 'query': 144 | img_paths = glob(os.path.join(root, 'query/*.jpg'), recursive=True) 145 | 146 | pattern = re.compile(r'([-\d]+)_c(\d)') 147 | all_pids = {} 148 | relabel = mode == 'train' 149 | self.img_paths = [] 150 | self.cam_ids = [] 151 | self.ids = [] 152 | for fpath in img_paths: 153 | fname = osp.basename(fpath) 154 | pid, cam = map(int, pattern.search(fname).groups()) 155 | if pid == -1: continue 156 | if relabel: 157 | if pid not in all_pids: 158 | all_pids[pid] = len(all_pids) 159 | else: 160 | if pid not in all_pids: 161 | all_pids[pid] = pid 162 | self.img_paths.append(fpath) 163 | self.ids.append(all_pids[pid]) 164 | self.cam_ids.append(cam - 1) 165 | 166 | def __len__(self): 167 | return len(self.img_paths) 168 | 169 | def __getitem__(self, item): 170 | path = self.img_paths[item] 171 | img = Image.open(path) 172 | if self.transform is not None: 173 | img = self.transform(img) 174 | 175 | label = torch.tensor(self.ids[item], dtype=torch.long) 176 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 177 | item = torch.tensor(item, dtype=torch.long) 178 | 179 | return img, label, cam, path, item 180 | -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import numpy as np 4 | import torch 5 | import scipy.io as sio 6 | 7 | from ignite.engine import Events 8 | from ignite.handlers import ModelCheckpoint 9 | from ignite.handlers import Timer 10 | 11 | from engine.engine import create_eval_engine 12 | from engine.engine import create_train_engine 13 | from engine.metric import AutoKVMetric 14 | from utils.eval_sysu import eval_sysu 15 | from utils.eval_regdb import eval_regdb 16 | from configs.default.dataset import dataset_cfg 17 | from configs.default.strategy import strategy_cfg 18 | 19 | def get_trainer(dataset, model, optimizer, lr_scheduler=None, logger=None, writer=None, non_blocking=False, log_period=10, 20 | save_dir="checkpoints", prefix="model", gallery_loader=None, query_loader=None, 21 | eval_interval=None, start_eval=None, rerank=False): 22 | if logger is None: 23 | logger = logging.getLogger() 24 | logger.setLevel(logging.WARN) 25 | 26 | # trainer 27 | trainer = create_train_engine(model, 
optimizer, non_blocking) 28 | 29 | setattr(trainer, "rerank", rerank) 30 | 31 | # checkpoint handler 32 | handler = ModelCheckpoint(save_dir, prefix, save_interval=eval_interval, n_saved=3, create_dir=True, 33 | save_as_state_dict=True, require_empty=False) 34 | trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model}) 35 | 36 | # metric 37 | timer = Timer(average=True) 38 | 39 | kv_metric = AutoKVMetric() 40 | 41 | # evaluator 42 | evaluator = None 43 | if not isinstance(eval_interval, int): 44 | raise TypeError("The parameter 'eval_interval' must be of type int.") 45 | if not isinstance(start_eval, int): 46 | raise TypeError("The parameter 'start_eval' must be of type int.") 47 | if eval_interval > 0 and gallery_loader is not None and query_loader is not None: 48 | evaluator = create_eval_engine(model, non_blocking) 49 | 50 | @trainer.on(Events.STARTED) 51 | def train_start(engine): 52 | setattr(engine.state, "best_rank1", 0.0) 53 | 54 | @trainer.on(Events.COMPLETED) 55 | def train_completed(engine): 56 | torch.cuda.empty_cache() 57 | 58 | # extract query feature 59 | evaluator.run(query_loader) 60 | 61 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 62 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 63 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 64 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 65 | 66 | # extract gallery feature 67 | evaluator.run(gallery_loader) 68 | 69 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 70 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 71 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 72 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 73 | 74 | # print("best rank1={:.2f}%".format(engine.state.best_rank1)) 75 | 76 | if dataset == 'sysu': 77 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[ 78 | 'rand_perm_cam'] 79 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=1, rerank=engine.rerank) 80 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=10, rerank=engine.rerank) 81 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=1, rerank=engine.rerank) 82 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=10, rerank=engine.rerank) 83 | elif dataset == 'regdb': 84 | print('infrared to visible') 85 | eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank) 86 | print('visible to infrared') 87 | eval_regdb(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=engine.rerank) 88 | elif dataset == 'market': 89 | eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank) 90 | 91 | 92 | evaluator.state.feat_list.clear() 93 | evaluator.state.id_list.clear() 94 | evaluator.state.cam_list.clear() 95 | evaluator.state.img_path_list.clear() 96 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams 97 | 98 | torch.cuda.empty_cache() 99 | 100 | @trainer.on(Events.EPOCH_STARTED) 101 | def epoch_started_callback(engine): 102 | 103 | epoch = engine.state.epoch 104 | if model.mutual_learning: 105 | model.update_rate = min(100 / (epoch + 1), 1.0) * model.update_rate_ 106 | 107 | kv_metric.reset() 108 | timer.reset() 109 | 110 | @trainer.on(Events.EPOCH_COMPLETED) 111 | def epoch_completed_callback(engine): 112 | epoch =
engine.state.epoch 113 | 114 | if lr_scheduler is not None: 115 | lr_scheduler.step() 116 | 117 | if epoch % eval_interval == 0: 118 | logger.info("Model saved at {}/{}_model_{}.pth".format(save_dir, prefix, epoch)) 119 | 120 | if evaluator and epoch % eval_interval == 0 and epoch > start_eval: 121 | torch.cuda.empty_cache() 122 | 123 | # extract query feature 124 | evaluator.run(query_loader) 125 | 126 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 127 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 128 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 129 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 130 | 131 | # extract gallery feature 132 | evaluator.run(gallery_loader) 133 | 134 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 135 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 136 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 137 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 138 | 139 | if dataset == 'sysu': 140 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[ 141 | 'rand_perm_cam'] 142 | mAP, r1, r5, _, _ = eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=1, rerank=engine.rerank) 143 | elif dataset == 'regdb': 144 | print('infrared to visible') 145 | mAP, r1, r5, _, _ = eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank) 146 | print('visible to infrared') 147 | mAP, r1_, r5, _, _ = eval_regdb(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=engine.rerank) 148 | r1 = (r1 + r1_) / 2 149 | elif dataset == 'market': 150 | mAP, r1, r5, _, _ = eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank) 151 | 152 | if r1 > engine.state.best_rank1: 153 | engine.state.best_rank1 = r1 154 | torch.save(model.state_dict(), "{}/model_best.pth".format(save_dir)) 155 | 156 | if writer is not None: 157 | writer.add_scalar('eval/mAP', mAP, epoch) 158 | writer.add_scalar('eval/r1', r1, epoch) 159 | writer.add_scalar('eval/r5', r5, epoch) 160 | 161 | evaluator.state.feat_list.clear() 162 | evaluator.state.id_list.clear() 163 | evaluator.state.cam_list.clear() 164 | evaluator.state.img_path_list.clear() 165 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams 166 | 167 | torch.cuda.empty_cache() 168 | 169 | @trainer.on(Events.ITERATION_COMPLETED) 170 | def iteration_complete_callback(engine): 171 | timer.step() 172 | 173 | # print(engine.state.output) 174 | kv_metric.update(engine.state.output) 175 | 176 | epoch = engine.state.epoch 177 | iteration = engine.state.iteration 178 | iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader) 179 | 180 | if iter_in_epoch % log_period == 0 and iter_in_epoch > 0: 181 | batch_size = engine.state.batch[0].size(0) 182 | speed = batch_size / timer.value() 183 | 184 | msg = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec" % (epoch, iter_in_epoch, speed) 185 | 186 | metric_dict = kv_metric.compute() 187 | 188 | # log output information 189 | if logger is not None: 190 | for k in sorted(metric_dict.keys()): 191 | msg += "\t%s: %.4f" % (k, metric_dict[k]) 192 | if writer is not None: 193 | writer.add_scalar('metric/{}'.format(k), metric_dict[k], iteration) 194 | 195 | logger.info(msg) 196 | 197 | kv_metric.reset() 198 | timer.reset() 199 | 200 | return trainer 201 | -------------------------------------------------------------------------------- 
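Note: the eval_sysu / eval_regdb calls above bottom out in the CMC and mAP helpers defined in utils/eval_sysu.py. A minimal toy run of those two helpers — the function names are real, the data below is invented for illustration — assuming only numpy and the repo root on PYTHONPATH:

import numpy as np
from utils.eval_sysu import get_cmc, get_mAP

# toy gallery: identities 1/2/3, each seen by cameras 1 and 2
gallery_ids = np.array([1, 2, 3, 1, 2, 3])
gallery_cams = np.array([1, 1, 1, 2, 2, 2])
# two probes (identities 1 and 3), both captured by camera 2, so the
# cam-2 gallery entries are filtered out per probe (same-camera rule)
query_ids = np.array([1, 3])
query_cams = np.array([2, 2])
# row i ranks gallery indices for probe i, nearest first
# (in eval_sysu this comes from np.argsort of a distance matrix)
sorted_indices = np.array([[3, 0, 1, 2, 4, 5],
                           [1, 2, 5, 0, 3, 4]])

cmc = get_cmc(sorted_indices, query_ids, query_cams, gallery_ids, gallery_cams)
mAP = get_mAP(sorted_indices, query_ids, query_cams, gallery_ids, gallery_cams)
print(cmc)  # [0.5 1.  1. ] -> probe 0 hits at rank 1, probe 1 at rank 2
print(mAP)  # 0.75          -> average of per-probe APs 1.0 and 0.5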
/models/baseline.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | from torch.nn import functional as F 6 | from torch.nn import Parameter 7 | import numpy as np 8 | 9 | import cv2 10 | 11 | from models.resnet import resnet50 12 | from utils.calc_acc import calc_acc 13 | 14 | from layers import TripletLoss 15 | from layers import CenterTripletLoss 16 | from layers import CenterLoss 17 | from layers import cbam 18 | from layers import NonLocalBlockND 19 | 20 | class Baseline(nn.Module): 21 | def __init__(self, num_classes=None, drop_last_stride=False, pattern_attention=False, modality_attention=0, mutual_learning=False, **kwargs): 22 | super(Baseline, self).__init__() 23 | 24 | self.drop_last_stride = drop_last_stride 25 | self.pattern_attention = pattern_attention 26 | self.modality_attention = modality_attention 27 | self.mutual_learning = mutual_learning 28 | 29 | self.backbone = resnet50(pretrained=True, drop_last_stride=drop_last_stride, modality_attention=modality_attention) 30 | 31 | self.base_dim = 2048 32 | self.dim = 0 33 | self.part_num = kwargs.get('num_parts', 0) 34 | 35 | if pattern_attention: 36 | self.base_dim = 2048 37 | self.dim = 2048 38 | self.part_num = kwargs.get('num_parts', 6) 39 | self.spatial_attention = nn.Conv2d(self.base_dim, self.part_num, kernel_size=1, stride=1, padding=0, bias=True) 40 | torch.nn.init.constant_(self.spatial_attention.bias, 0.0) 41 | self.activation = nn.Sigmoid() 42 | self.weight_sep = kwargs.get('weight_sep', 0.1) 43 | 44 | if mutual_learning: 45 | self.visible_classifier = nn.Linear(self.base_dim + self.dim * self.part_num, num_classes, bias=False) 46 | self.infrared_classifier = nn.Linear(self.base_dim + self.dim * self.part_num, num_classes, bias=False) 47 | 48 | self.visible_classifier_ = nn.Linear(self.base_dim + self.dim * self.part_num, num_classes, bias=False) 49 | self.visible_classifier_.weight.requires_grad_(False) 50 | self.visible_classifier_.weight.data = self.visible_classifier.weight.data 51 | 52 | self.infrared_classifier_ = nn.Linear(self.base_dim + self.dim * self.part_num, num_classes, bias=False) 53 | self.infrared_classifier_.weight.requires_grad_(False) 54 | self.infrared_classifier_.weight.data = self.infrared_classifier.weight.data 55 | 56 | self.KLDivLoss = nn.KLDivLoss(reduction='batchmean') 57 | self.weight_sid = kwargs.get('weight_sid', 0.5) 58 | self.weight_KL = kwargs.get('weight_KL', 2.0) 59 | self.update_rate = kwargs.get('update_rate', 0.2) 60 | self.update_rate_ = self.update_rate 61 | 62 | print("output feat length:{}".format(self.base_dim + self.dim * self.part_num)) 63 | self.bn_neck = nn.BatchNorm1d(self.base_dim + self.dim * self.part_num) 64 | nn.init.constant_(self.bn_neck.bias, 0) 65 | self.bn_neck.bias.requires_grad_(False) 66 | 67 | if kwargs.get('eval', False): 68 | return 69 | 70 | self.classification = kwargs.get('classification', False) 71 | self.triplet = kwargs.get('triplet', False) 72 | self.center_cluster = kwargs.get('center_cluster', False) 73 | self.center_loss = kwargs.get('center', False) 74 | self.margin = kwargs.get('margin', 0.3) 75 | 76 | if self.classification: 77 | self.classifier = nn.Linear(self.base_dim + self.dim * self.part_num , num_classes, bias=False) 78 | if self.mutual_learning or self.classification: 79 | self.id_loss = nn.CrossEntropyLoss(ignore_index=-1) 80 | if self.triplet: 81 | self.triplet_loss = TripletLoss(margin=self.margin) 82 | if 
self.center_cluster: 83 | k_size = kwargs.get('k_size', 8) 84 | self.center_cluster_loss = CenterTripletLoss(k_size=k_size, margin=self.margin) 85 | if self.center_loss: 86 | self.center_loss = CenterLoss(num_classes, self.base_dim + self.dim * self.part_num) 87 | 88 | def forward(self, inputs, labels=None, **kwargs): 89 | loss_reg = 0 90 | loss_center = 0 91 | modality_logits = None 92 | modality_feat = None 93 | 94 | cam_ids = kwargs.get('cam_ids') 95 | sub = (cam_ids == 3) + (cam_ids == 6) 96 | # CNN 97 | global_feat = self.backbone(inputs) 98 | 99 | b, c, w, h = global_feat.shape 100 | 101 | if self.pattern_attention: 102 | masks = global_feat 103 | masks = self.spatial_attention(masks) 104 | masks = self.activation(masks) 105 | 106 | feats = [] 107 | for i in range(self.part_num): 108 | mask = masks[:, i:i+1, :, :] 109 | feat = mask * global_feat 110 | 111 | feat = F.avg_pool2d(feat, feat.size()[2:]) 112 | feat = feat.view(feat.size(0), -1) 113 | 114 | feats.append(feat) 115 | 116 | global_feat = F.avg_pool2d(global_feat, global_feat.size()[2:]) 117 | global_feat = global_feat.view(global_feat.size(0), -1) 118 | 119 | feats.append(global_feat) 120 | feats = torch.cat(feats, 1) 121 | 122 | if self.training: 123 | masks = masks.view(b, self.part_num, w*h) 124 | loss_reg = torch.bmm(masks, masks.permute(0, 2, 1)) 125 | loss_reg = torch.triu(loss_reg, diagonal = 1).sum() / (b * self.part_num * (self.part_num - 1) / 2) 126 | 127 | else: 128 | feats = F.avg_pool2d(global_feat, global_feat.size()[2:]) 129 | feats = feats.view(feats.size(0), -1) 130 | 131 | if not self.training: 132 | feats = self.bn_neck(feats) 133 | return feats 134 | else: 135 | return self.train_forward(feats, labels, loss_reg, sub, **kwargs) 136 | 137 | def train_forward(self, feat, labels, loss_reg, sub, **kwargs): 138 | epoch = kwargs.get('epoch') 139 | metric = {} 140 | if self.pattern_attention and loss_reg != 0 : 141 | loss = loss_reg.float() * self.weight_sep 142 | metric.update({'p-reg': loss_reg.data}) 143 | else: 144 | loss = 0 145 | 146 | if self.triplet: 147 | triplet_loss, _, _ = self.triplet_loss(feat.float(), labels) 148 | loss += triplet_loss 149 | metric.update({'tri': triplet_loss.data}) 150 | 151 | if self.center_loss: 152 | center_loss = self.center_loss(feat.float(), labels) 153 | loss += center_loss 154 | metric.update({'cen': center_loss.data}) 155 | 156 | if self.center_cluster: 157 | center_cluster_loss, _, _ = self.center_cluster_loss(feat.float(), labels) 158 | loss += center_cluster_loss 159 | metric.update({'cc': center_cluster_loss.data}) 160 | 161 | feat = self.bn_neck(feat) 162 | 163 | if self.classification: 164 | logits = self.classifier(feat) 165 | cls_loss = self.id_loss(logits.float(), labels) 166 | loss += cls_loss 167 | metric.update({'acc': calc_acc(logits.data, labels), 'ce': cls_loss.data}) 168 | 169 | if self.mutual_learning: 170 | # cam_ids = kwargs.get('cam_ids') 171 | # sub = (cam_ids == 3) + (cam_ids == 6) 172 | 173 | logits_v = self.visible_classifier(feat[sub == 0]) 174 | v_cls_loss = self.id_loss(logits_v.float(), labels[sub == 0]) 175 | loss += v_cls_loss * self.weight_sid 176 | logits_i = self.infrared_classifier(feat[sub == 1]) 177 | i_cls_loss = self.id_loss(logits_i.float(), labels[sub == 1]) 178 | loss += i_cls_loss * self.weight_sid 179 | 180 | logits_m = torch.cat([logits_v, logits_i], 0).float() 181 | with torch.no_grad(): 182 | self.infrared_classifier_.weight.data = self.infrared_classifier_.weight.data * (1 - self.update_rate) \ 183 | + 
self.infrared_classifier.weight.data * self.update_rate 184 | self.visible_classifier_.weight.data = self.visible_classifier_.weight.data * (1 - self.update_rate) \ 185 | + self.visible_classifier.weight.data * self.update_rate 186 | 187 | logits_v_ = self.infrared_classifier_(feat[sub == 0]) 188 | logits_i_ = self.visible_classifier_(feat[sub == 1]) 189 | 190 | logits_m_ = torch.cat([logits_v_, logits_i_], 0).float() 191 | logits_m = F.softmax(logits_m, 1) 192 | logits_m_ = F.log_softmax(logits_m_, 1) 193 | mod_loss = self.KLDivLoss(logits_m_, logits_m) 194 | 195 | loss += mod_loss * self.weight_KL + (v_cls_loss + i_cls_loss) * self.weight_sid 196 | metric.update({'ce-v': v_cls_loss.data}) 197 | metric.update({'ce-i': i_cls_loss.data}) 198 | metric.update({'KL': mod_loss.data}) 199 | 200 | return loss, metric 201 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import copy 4 | from torch.utils.data import Sampler 5 | from collections import defaultdict 6 | 7 | 8 | class CrossModalityRandomSampler(Sampler): 9 | def __init__(self, dataset, batch_size): 10 | self.dataset = dataset 11 | self.batch_size = batch_size 12 | 13 | self.rgb_list = [] 14 | self.ir_list = [] 15 | for i, cam in enumerate(dataset.cam_ids): 16 | if cam in [3, 6]: 17 | self.ir_list.append(i) 18 | else: 19 | self.rgb_list.append(i) 20 | 21 | def __len__(self): 22 | return max(len(self.rgb_list), len(self.ir_list)) * 2 23 | 24 | def __iter__(self): 25 | sample_list = [] 26 | rgb_list = np.random.permutation(self.rgb_list).tolist() 27 | ir_list = np.random.permutation(self.ir_list).tolist() 28 | 29 | rgb_size = len(self.rgb_list) 30 | ir_size = len(self.ir_list) 31 | if rgb_size >= ir_size: 32 | diff = rgb_size - ir_size 33 | reps = diff // ir_size 34 | pad_size = diff % ir_size 35 | for _ in range(reps): 36 | ir_list.extend(np.random.permutation(self.ir_list).tolist()) 37 | ir_list.extend(np.random.choice(self.ir_list, pad_size, replace=False).tolist()) 38 | else: 39 | diff = ir_size - rgb_size 40 | reps = diff // rgb_size 41 | pad_size = diff % rgb_size 42 | for _ in range(reps): 43 | rgb_list.extend(np.random.permutation(self.rgb_list).tolist()) 44 | rgb_list.extend(np.random.choice(self.rgb_list, pad_size, replace=False).tolist()) 45 | 46 | assert len(rgb_list) == len(ir_list) 47 | 48 | half_bs = self.batch_size // 2 49 | for start in range(0, len(rgb_list), half_bs): 50 | sample_list.extend(rgb_list[start:start + half_bs]) 51 | sample_list.extend(ir_list[start:start + half_bs]) 52 | 53 | return iter(sample_list) 54 | 55 | 56 | class CrossModalityIdentitySampler(Sampler): 57 | def __init__(self, dataset, p_size, k_size): 58 | self.dataset = dataset 59 | self.p_size = p_size 60 | self.k_size = k_size // 2 61 | self.batch_size = p_size * k_size * 2 62 | 63 | self.id2idx_rgb = defaultdict(list) 64 | self.id2idx_ir = defaultdict(list) 65 | for i, identity in enumerate(dataset.ids): 66 | if dataset.cam_ids[i] in [3, 6]: 67 | self.id2idx_ir[identity].append(i) 68 | else: 69 | self.id2idx_rgb[identity].append(i) 70 | 71 | def __len__(self): 72 | return self.dataset.num_ids * self.k_size * 2 73 | 74 | def __iter__(self): 75 | sample_list = [] 76 | 77 | id_perm = np.random.permutation(self.dataset.num_ids) 78 | for start in range(0, self.dataset.num_ids, self.p_size): 79 | selected_ids = id_perm[start:start + self.p_size] 80 | 81 | sample = [] 82 | for identity in
selected_ids: 83 | replace = len(self.id2idx_rgb[identity]) < self.k_size 84 | s = np.random.choice(self.id2idx_rgb[identity], size=self.k_size, replace=replace) 85 | sample.extend(s) 86 | 87 | sample_list.extend(sample) 88 | 89 | sample.clear() 90 | for identity in selected_ids: 91 | replace = len(self.id2idx_ir[identity]) < self.k_size 92 | s = np.random.choice(self.id2idx_ir[identity], size=self.k_size, replace=replace) 93 | sample.extend(s) 94 | 95 | sample_list.extend(sample) 96 | 97 | return iter(sample_list) 98 | 99 | 100 | class RandomIdentitySampler(Sampler): 101 | def __init__(self, data_source, batch_size, num_instances): 102 | self.data_source = data_source 103 | self.batch_size = batch_size 104 | self.num_instances = num_instances 105 | self.num_pids_per_batch = self.batch_size // self.num_instances 106 | self.index_dic_R = defaultdict(list) 107 | self.index_dic_I = defaultdict(list) 108 | for i, identity in enumerate(data_source.ids): 109 | if data_source.cam_ids[i] in [3, 6]: 110 | self.index_dic_I[identity].append(i) 111 | else: 112 | self.index_dic_R[identity].append(i) 113 | self.pids = list(self.index_dic_I.keys()) 114 | 115 | # estimate number of examples in an epoch 116 | self.length = 0 117 | for pid in self.pids: 118 | idxs = self.index_dic_I[pid] 119 | num = len(idxs) 120 | if num < self.num_instances: 121 | num = self.num_instances 122 | self.length += num - num % self.num_instances 123 | 124 | def __iter__(self): 125 | batch_idxs_dict = defaultdict(list) 126 | 127 | for pid in self.pids: 128 | idxs_I = copy.deepcopy(self.index_dic_I[pid]) 129 | idxs_R = copy.deepcopy(self.index_dic_R[pid]) 130 | if len(idxs_I) < self.num_instances // 2 and len(idxs_R) < self.num_instances // 2: 131 | idxs_I = np.random.choice(idxs_I, size=self.num_instances // 2, replace=True) 132 | idxs_R = np.random.choice(idxs_R, size=self.num_instances // 2, replace=True) 133 | if len(idxs_I) > len(idxs_R): 134 | idxs_I = np.random.choice(idxs_I, size=len(idxs_R), replace=False) 135 | if len(idxs_R) > len(idxs_I): 136 | idxs_R = np.random.choice(idxs_R, size=len(idxs_I), replace=False) 137 | np.random.shuffle(idxs_I) 138 | np.random.shuffle(idxs_R) 139 | batch_idxs = [] 140 | for idx_I, idx_R in zip(idxs_I, idxs_R): 141 | batch_idxs.append(idx_I) 142 | batch_idxs.append(idx_R) 143 | if len(batch_idxs) == self.num_instances: 144 | batch_idxs_dict[pid].append(batch_idxs) 145 | batch_idxs = [] 146 | 147 | avai_pids = copy.deepcopy(self.pids) 148 | final_idxs = [] 149 | 150 | while len(avai_pids) >= self.num_pids_per_batch: 151 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False) 152 | for pid in selected_pids: 153 | batch_idxs = batch_idxs_dict[pid].pop(0) 154 | final_idxs.extend(batch_idxs) 155 | if len(batch_idxs_dict[pid]) == 0: 156 | avai_pids.remove(pid) 157 | 158 | self.length = len(final_idxs) 159 | return iter(final_idxs) 160 | 161 | def __len__(self): 162 | return self.length 163 | 164 | 165 | class NormTripletSampler(Sampler): 166 | """ 167 | Randomly sample N identities, then for each identity, 168 | randomly sample K instances, therefore batch size is N*K. 169 | Args: 170 | - data_source (list): list of (img_path, pid, camid). 171 | - num_instances (int): number of instances per identity in a batch. 172 | - batch_size (int): number of examples in a batch. 
173 | """ 174 | 175 | def __init__(self, data_source, batch_size, num_instances): 176 | self.data_source = data_source 177 | self.batch_size = batch_size 178 | self.num_instances = num_instances 179 | self.num_pids_per_batch = self.batch_size // self.num_instances 180 | self.index_dic = defaultdict(list) 181 | for index, pid in enumerate(self.data_source.ids): 182 | self.index_dic[pid].append(index) 183 | self.pids = list(self.index_dic.keys()) 184 | 185 | # estimate number of examples in an epoch 186 | self.length = 0 187 | for pid in self.pids: 188 | idxs = self.index_dic[pid] 189 | num = len(idxs) 190 | if num < self.num_instances: 191 | num = self.num_instances 192 | self.length += num - num % self.num_instances 193 | 194 | def __iter__(self): 195 | batch_idxs_dict = defaultdict(list) 196 | 197 | for pid in self.pids: 198 | idxs = copy.deepcopy(self.index_dic[pid]) 199 | if len(idxs) < self.num_instances: 200 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 201 | np.random.shuffle(idxs) 202 | batch_idxs = [] 203 | for idx in idxs: 204 | batch_idxs.append(idx) 205 | if len(batch_idxs) == self.num_instances: 206 | batch_idxs_dict[pid].append(batch_idxs) 207 | batch_idxs = [] 208 | 209 | avai_pids = copy.deepcopy(self.pids) 210 | final_idxs = [] 211 | 212 | while len(avai_pids) >= self.num_pids_per_batch: 213 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False) 214 | for pid in selected_pids: 215 | batch_idxs = batch_idxs_dict[pid].pop(0) 216 | final_idxs.extend(batch_idxs) 217 | if len(batch_idxs_dict[pid]) == 0: 218 | avai_pids.remove(pid) 219 | 220 | self.length = len(final_idxs) 221 | return iter(final_idxs) 222 | 223 | def __len__(self): 224 | return self.length -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | from torchvision.models.utils import load_state_dict_from_url 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 15 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=dilation, groups=groups, bias=False, dilation=dilation) 23 | 24 | 25 | def conv1x1(in_planes, out_planes, stride=1): 26 | """1x1 convolution""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 28 | 29 | 30 | class MAM(nn.Module): 31 | def __init__(self, dim, r=16): 32 | super(MAM, self).__init__() 33 | 34 | self.channel_attention = nn.Sequential( 35 | nn.Conv2d(dim, dim // r, kernel_size=1, bias=False), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(dim // r, dim, kernel_size=1, bias=False), 38 
| nn.Sigmoid() 39 | ) 40 | self.IN = nn.InstanceNorm2d(dim, track_running_stats=False) 41 | 42 | def forward(self, x): 43 | pooled = F.avg_pool2d(x, x.size()[2:]) 44 | mask = self.channel_attention(pooled) 45 | x = x * mask + self.IN(x) * (1 - mask) 46 | 47 | return x 48 | 49 | class BasicBlock(nn.Module): 50 | expansion = 1 51 | 52 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 53 | base_width=64, dilation=1, norm_layer=None): 54 | super(BasicBlock, self).__init__() 55 | if norm_layer is None: 56 | norm_layer = nn.BatchNorm2d 57 | if groups != 1 or base_width != 64: 58 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 59 | if dilation > 1: 60 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 61 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 62 | self.conv1 = conv3x3(inplanes, planes, stride) 63 | self.bn1 = norm_layer(planes) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.conv2 = conv3x3(planes, planes) 66 | self.bn2 = norm_layer(planes) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | identity = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | 80 | if self.downsample is not None: 81 | identity = self.downsample(x) 82 | 83 | out += identity 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class Bottleneck(nn.Module): 90 | expansion = 4 91 | 92 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 93 | base_width=64, dilation=1, norm_layer=None): 94 | super(Bottleneck, self).__init__() 95 | if norm_layer is None: 96 | norm_layer = nn.BatchNorm2d 97 | width = int(planes * (base_width / 64.)) * groups 98 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 99 | self.conv1 = conv1x1(inplanes, width) 100 | self.bn1 = norm_layer(width) 101 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 102 | self.bn2 = norm_layer(width) 103 | self.conv3 = conv1x1(width, planes * self.expansion) 104 | self.bn3 = norm_layer(planes * self.expansion) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.downsample = downsample 107 | self.stride = stride 108 | 109 | def forward(self, x): 110 | identity = x 111 | 112 | out = self.conv1(x) 113 | out = self.bn1(out) 114 | out = self.relu(out) 115 | 116 | out = self.conv2(out) 117 | out = self.bn2(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv3(out) 121 | out = self.bn3(out) 122 | 123 | if self.downsample is not None: 124 | identity = self.downsample(x) 125 | 126 | out += identity 127 | out = self.relu(out) 128 | 129 | return out 130 | 131 | 132 | class ResNet(nn.Module): 133 | 134 | def __init__(self, block, layers, zero_init_residual=False, modality_attention=0, 135 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 136 | norm_layer=None, drop_last_stride=False): 137 | super(ResNet, self).__init__() 138 | if norm_layer is None: 139 | norm_layer = nn.BatchNorm2d 140 | self._norm_layer = norm_layer 141 | 142 | self.inplanes = 64 143 | self.dilation = 1 144 | if replace_stride_with_dilation is None: 145 | # each element in the tuple indicates if we should replace 146 | # the 2x2 stride with a dilated convolution instead 147 | replace_stride_with_dilation = [False, False, False] 148 | if len(replace_stride_with_dilation) != 3: 149 | raise ValueError("replace_stride_with_dilation should be None " 150 | "or a 
3-element tuple, got {}".format(replace_stride_with_dilation)) 151 | self.groups = groups 152 | self.base_width = width_per_group 153 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 154 | bias=False) 155 | self.bn1 = norm_layer(self.inplanes) 156 | self.relu = nn.ReLU(inplace=True) 157 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 158 | self.layer1 = self._make_layer(block, 64, layers[0]) 159 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 160 | dilate=replace_stride_with_dilation[0]) 161 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 162 | dilate=replace_stride_with_dilation[1]) 163 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1 if drop_last_stride else 2, 164 | dilate=replace_stride_with_dilation[2]) 165 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 166 | 167 | self.modality_attention = modality_attention 168 | if self.modality_attention > 1: 169 | self.MAM3 = MAM(1024) 170 | if self.modality_attention > 0: 171 | self.MAM4 = MAM(2048) 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 176 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 177 | nn.init.constant_(m.weight, 1) 178 | nn.init.constant_(m.bias, 0) 179 | 180 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 181 | norm_layer = self._norm_layer 182 | downsample = None 183 | previous_dilation = self.dilation 184 | if dilate: 185 | self.dilation *= stride 186 | stride = 1 187 | if stride != 1 or self.inplanes != planes * block.expansion: 188 | downsample = nn.Sequential( 189 | conv1x1(self.inplanes, planes * block.expansion, stride), 190 | norm_layer(planes * block.expansion), 191 | ) 192 | 193 | layers = [] 194 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 195 | self.base_width, previous_dilation, norm_layer)) 196 | self.inplanes = planes * block.expansion 197 | for _ in range(1, blocks): 198 | layers.append(block(self.inplanes, planes, groups=self.groups, 199 | base_width=self.base_width, dilation=self.dilation, 200 | norm_layer=norm_layer)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def forward(self, x): 205 | x = self.conv1(x) 206 | x = self.bn1(x) 207 | x = self.relu(x) 208 | x = self.maxpool(x) 209 | 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | 214 | if self.modality_attention > 1: 215 | x = self.MAM3(x) 216 | 217 | x = self.layer4(x) 218 | 219 | if self.modality_attention > 0: 220 | x = self.MAM4(x) 221 | 222 | return x 223 | 224 | 225 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 226 | model = ResNet(block, layers, **kwargs) 227 | if pretrained: 228 | state_dict = load_state_dict_from_url(model_urls[arch], 229 | progress=progress) 230 | model.load_state_dict(state_dict, strict=False) 231 | return model 232 | 233 | 234 | def resnet18(pretrained=False, progress=True, **kwargs): 235 | """Constructs a ResNet-18 model. 236 | 237 | Args: 238 | pretrained (bool): If True, returns a model pre-trained on ImageNet 239 | progress (bool): If True, displays a progress bar of the download to stderr 240 | """ 241 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 242 | **kwargs) 243 | 244 | 245 | def resnet34(pretrained=False, progress=True, **kwargs): 246 | """Constructs a ResNet-34 model. 
247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | progress (bool): If True, displays a progress bar of the download to stderr 251 | """ 252 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 253 | **kwargs) 254 | 255 | 256 | def resnet50(pretrained=False, progress=True, **kwargs): 257 | """Constructs a ResNet-50 model. 258 | 259 | Args: 260 | pretrained (bool): If True, returns a model pre-trained on ImageNet 261 | progress (bool): If True, displays a progress bar of the download to stderr 262 | """ 263 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 264 | **kwargs) 265 | 266 | 267 | def resnet101(pretrained=False, progress=True, **kwargs): 268 | """Constructs a ResNet-101 model. 269 | 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnet152(pretrained=False, progress=True, **kwargs): 279 | """Constructs a ResNet-152 model. 280 | 281 | Args: 282 | pretrained (bool): If True, returns a model pre-trained on ImageNet 283 | progress (bool): If True, displays a progress bar of the download to stderr 284 | """ 285 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 286 | **kwargs) 287 | 288 | 289 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 290 | """Constructs a ResNeXt-50 32x4d model. 291 | 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | progress (bool): If True, displays a progress bar of the download to stderr 295 | """ 296 | kwargs['groups'] = 32 297 | kwargs['width_per_group'] = 4 298 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 299 | pretrained, progress, **kwargs) 300 | 301 | 302 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 303 | """Constructs a ResNeXt-101 32x8d model. 304 | 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | progress (bool): If True, displays a progress bar of the download to stderr 308 | """ 309 | kwargs['groups'] = 32 310 | kwargs['width_per_group'] = 8 311 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 312 | pretrained, progress, **kwargs) 313 | --------------------------------------------------------------------------------
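A quick smoke test of the backbone factory above — a minimal sketch, with the 384x128 input resolution assumed for illustration (the repo takes image_size from its YAML configs) and pretrained=False so no weights are downloaded:

import torch
from models.resnet import resnet50

# drop_last_stride=True builds layer4 with stride 1, so the final feature
# map keeps twice the spatial resolution of a stock ResNet-50
net = resnet50(pretrained=False, drop_last_stride=True)
net.eval()
with torch.no_grad():
    feat = net(torch.randn(2, 3, 384, 128))
print(feat.shape)  # torch.Size([2, 2048, 24, 8]); a stride-2 layer4 would give [2, 2048, 12, 4]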