├── Experiment-all_tricks-tri_center-market.sh
├── LICENCE.md
├── README.md
├── Test-all_tricks-tri_center-feat_after_bn-cos-market.sh
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   ├── baseline.yml
│   ├── softmax.yml
│   ├── softmax_triplet.yml
│   └── softmax_triplet_with_center.yml
├── data
│   ├── __init__.py
│   ├── build.py
│   ├── collate_batch.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── bases.py
│   │   ├── cuhk03.py
│   │   ├── dataset_loader.py
│   │   ├── dukemtmcreid.py
│   │   ├── eval_reid.py
│   │   ├── market1501.py
│   │   ├── msmt17.py
│   │   ├── nformer.py
│   │   └── veri.py
│   ├── samplers
│   │   ├── __init__.py
│   │   └── triplet_sampler.py
│   └── transforms
│       ├── __init__.py
│       ├── build.py
│       └── transforms.py
├── engine
│   ├── inference.py
│   └── trainer.py
├── layers
│   ├── __init__.py
│   ├── center_loss.py
│   └── triplet_loss.py
├── modeling
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   ├── resnet_ibn_a.py
│   │   └── senet.py
│   ├── baseline.py
│   ├── model.py
│   └── nformer.py
├── pipeline.jpg
├── solver
│   ├── __init__.py
│   ├── build.py
│   └── lr_scheduler.py
├── tools
│   ├── __init__.py
│   ├── nformer_train.py
│   ├── test.py
│   └── train.py
└── utils
    ├── __init__.py
    ├── iotools.py
    ├── logger.py
    ├── re_ranking.py
    └── reid_metric.py
/Experiment-all_tricks-tri_center-market.sh:
--------------------------------------------------------------------------------
1 | # Experiment all tricks with center loss : 256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005
2 | # Dataset 1: market1501
3 | # imagesize: 256x128
4 | # batchsize: 16x4
5 | # warmup_step 10
6 | # random erase prob 0.5
7 | # labelsmooth: on
8 | # last stride 1
9 | # bnneck on
10 | # with center loss
11 | python3 tools/train.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haochen/workspace/project/NFORMER/')" OUTPUT_DIR "('test')"
12 |
--------------------------------------------------------------------------------
/LICENCE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2019] [HaoLuo]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NFormer
2 |
3 | Implementation of NFormer: Robust Person Re-identification with Neighbor Transformer (CVPR 2022).
4 |
5 | ## Pipeline
6 |
7 | ![pipeline](pipeline.jpg)
8 |
9 |
10 | ## Requirements
11 | - Python3
12 | - pytorch>=0.4
13 | - torchvision
14 | - pytorch-ignite=0.1.2 (Note: V0.2.0 may result in an error)
15 | - yacs
16 | ## Hardware
17 | - 1 NVIDIA 3090 Ti
18 |
19 | ## Dataset
20 | Create a directory to store re-ID datasets, either under this repo or outside it. Set the dataset root path in `config/defaults.py`, or set it in the scripts `Experiment-all_tricks-tri_center-market.sh` and `Test-all_tricks-tri_center-feat_after_bn-cos-market.sh`.
21 | #### Market1501
22 | * Download dataset to `data/` from https://zheng-lab.cecs.anu.edu.au/Project/project_reid.html
23 | * Extract the dataset and rename the folder to `market1501`. The data structure should look like:
24 |
25 | ```bash
26 | |- data
27 | |- market1501 # this folder contains 6 files.
28 | |- bounding_box_test/
29 | |- bounding_box_train/
30 | ......
31 | ```
32 |
33 |
34 |
35 | ## Training
36 | Download the pretrained [resnet50](https://download.pytorch.org/models/resnet50-19c8e357.pth) model and set its path at [line 3](configs/softmax_triplet_with_center.yml) (`PRETRAIN_PATH`) of `configs/softmax_triplet_with_center.yml`.
37 |
38 | Run `Experiment-all_tricks-tri_center-market.sh` to train NFormer on the Market-1501 dataset:
39 | ```bash
40 | sh Experiment-all_tricks-tri_center-market.sh
41 | ```
42 | or
43 | ```bash
44 | python3 tools/train.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haochen/workspace/project/NFORMER/')" OUTPUT_DIR "('work_dirs')"
45 | ```
46 |
47 | ## Evaluation
48 | Run `Test-all_tricks-tri_center-feat_after_bn-cos-market.sh` to evaluate NFormer on the Market-1501 dataset. Set `TEST.TEST_NFORMER` to `'yes'` to evaluate the NFormer, or `'no'` to evaluate the CNN encoder only.
49 |
50 | ```bash
51 | sh Test-all_tricks-tri_center-feat_after_bn-cos-market.sh
52 | ```
53 | or
54 | ```bash
55 | python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haochen/workspace/project/NFORMER')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('test/nformer_model.pth')" TEST.TEST_NFORMER "('no')"
56 | ```
57 |
58 |
59 |
60 | ## Acknowledgement
61 | This repo is largely based on [reid-strong-baseline](https://github.com/michuanhaohao/reid-strong-baseline); thanks for their excellent work.
62 |
63 | ## Citation
64 | ```
65 | @article{wang2022nformer,
66 | title={NFormer: Robust Person Re-identification with Neighbor Transformer},
67 | author={Wang, Haochen and Shen, Jiayi and Liu, Yongtuo and Gao, Yan and Gavves, Efstratios},
68 | journal={arXiv preprint arXiv:2204.09331},
69 | year={2022}
70 | }
71 |
72 | @InProceedings{Luo_2019_CVPR_Workshops,
73 | author = {Luo, Hao and Gu, Youzhi and Liao, Xingyu and Lai, Shenqi and Jiang, Wei},
74 | title = {Bag of Tricks and a Strong Baseline for Deep Person Re-Identification},
75 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
76 | month = {June},
77 | year = {2019}
78 | }
79 | ```
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
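
As a quick follow-up to the README's Dataset section: once `data/market1501` is in place, the dataset class can be loaded on its own to verify the folder layout. A minimal sketch, assuming the repo root is the current working directory and on `PYTHONPATH`:

```python
# Quick sanity check that the Market-1501 folders are found and parsed
# (assumes the repo root is the current directory / on PYTHONPATH).
from data.datasets import init_dataset

dataset = init_dataset('market1501', root='./data')  # prints the dataset statistics table
print(dataset.num_train_pids, len(dataset.query), len(dataset.gallery))
```
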
/Test-all_tricks-tri_center-feat_after_bn-cos-market.sh:
--------------------------------------------------------------------------------
1 | # Dataset 1: market1501
2 | # imagesize: 256x128
3 | # batchsize: 16x4
4 | # warmup_step 10
5 | # random erase prob 0.5
6 | # labelsmooth: on
7 | # last stride 1
8 | # bnneck on
9 | # with center loss
10 | # without re-ranking
11 | python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haochen/workspace/project/NFORMER')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('test/nformer_model.pth')" TEST.TEST_NFORMER "('no')"
12 |
13 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .defaults import _C as cfg
8 |
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter,
8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images during training will be
10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
11 | # IMAGES_PER_BATCH_TEST
12 |
13 | # -----------------------------------------------------------------------------
14 | # Config definition
15 | # -----------------------------------------------------------------------------
16 |
17 | _C = CN()
18 |
19 | _C.MODEL = CN()
20 | # Using cuda or cpu for training
21 | _C.MODEL.DEVICE = "cuda"
22 | # ID number of GPU
23 | _C.MODEL.DEVICE_ID = '0'
24 | # Name of backbone
25 | _C.MODEL.NAME = 'resnet50'
26 | # Last stride of backbone
27 | _C.MODEL.LAST_STRIDE = 1
28 | # Path to pretrained model of backbone
29 | _C.MODEL.PRETRAIN_PATH = ''
30 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
31 | # Options: 'imagenet' or 'self'
32 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet'
33 | # If train with BNNeck, options: 'bnneck' or 'no'
34 | _C.MODEL.NECK = 'bnneck'
35 | # If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
36 | _C.MODEL.IF_WITH_CENTER = 'no'
37 | # The loss type of metric loss
38 | # options:['triplet'](without center loss) or ['center','triplet_center'](with center loss)
39 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet'
40 | # For example, if loss type is cross entropy loss + triplet loss + center loss
41 | # the setting should be: _C.MODEL.METRIC_LOSS_TYPE = 'triplet_center' and _C.MODEL.IF_WITH_CENTER = 'yes'
42 |
43 | # If train with label smooth, options: 'on', 'off'
44 | _C.MODEL.IF_LABELSMOOTH = 'on'
45 | _C.MODEL.N_EMBD = 256
46 | _C.MODEL.N_HEAD = 2
47 | _C.MODEL.N_LAYER = 4
48 | _C.MODEL.EMBD_PDROP = 0.1
49 | _C.MODEL.ATTN_PDROP = 0.1
50 | _C.MODEL.RESID_PDROP = 0.1
51 | _C.MODEL.AFN = 'gelu'
52 | _C.MODEL.CLF_PDROP = 0.1
53 | _C.MODEL.TOPK = 20
54 | _C.MODEL.LANDMARK = 10
55 |
56 |
57 | # -----------------------------------------------------------------------------
58 | # INPUT
59 | # -----------------------------------------------------------------------------
60 | _C.INPUT = CN()
61 | # Size of the image during training
62 | _C.INPUT.SIZE_TRAIN = [384, 128]
63 | # Size of the image during test
64 | _C.INPUT.SIZE_TEST = [384, 128]
65 | # Random probability for image horizontal flip
66 | _C.INPUT.PROB = 0.5
67 | # Random probability for random erasing
68 | _C.INPUT.RE_PROB = 0.5
69 | # Values to be used for image normalization
70 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
71 | # Values to be used for image normalization
72 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
73 | # Value of padding size
74 | _C.INPUT.PADDING = 10
75 |
76 | # -----------------------------------------------------------------------------
77 | # Dataset
78 | # -----------------------------------------------------------------------------
79 | _C.DATASETS = CN()
80 | # List of the dataset names for training, as present in paths_catalog.py
81 | _C.DATASETS.NAMES = ('market1501')
82 | # Root directory where datasets are stored (and downloaded to if not found)
83 | _C.DATASETS.ROOT_DIR = ('./data')
84 |
85 | # -----------------------------------------------------------------------------
86 | # DataLoader
87 | # -----------------------------------------------------------------------------
88 | _C.DATALOADER = CN()
89 | # Number of data loading threads
90 | _C.DATALOADER.NUM_WORKERS = 8
91 | # Sampler for data loading
92 | _C.DATALOADER.SAMPLER = 'softmax'
93 | # Number of instance for one batch
94 | _C.DATALOADER.NUM_INSTANCE = 16
95 |
96 | # ---------------------------------------------------------------------------- #
97 | # Solver
98 | # ---------------------------------------------------------------------------- #
99 | _C.SOLVER = CN()
100 | # Name of optimizer
101 | _C.SOLVER.OPTIMIZER_NAME = "Adam"
102 | # Max number of training epochs
103 | _C.SOLVER.MAX_EPOCHS = 50
104 | # Max number of NFormer training epochs
105 | _C.SOLVER.NFORMER_MAX_EPOCHS = 20
106 | # Base learning rate
107 | _C.SOLVER.BASE_LR = 3e-4
109 | # Learning rate factor for bias parameters
109 | _C.SOLVER.BIAS_LR_FACTOR = 2
110 | # Momentum
111 | _C.SOLVER.MOMENTUM = 0.9
112 | # Margin of triplet loss
113 | _C.SOLVER.MARGIN = 0.3
114 | # Margin of cluster loss
115 | _C.SOLVER.CLUSTER_MARGIN = 0.3
116 | # Learning rate of SGD to learn the centers of center loss
117 | _C.SOLVER.CENTER_LR = 0.5
118 | # Balanced weight of center loss
119 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005
120 | # Settings of range loss
121 | _C.SOLVER.RANGE_K = 2
122 | _C.SOLVER.RANGE_MARGIN = 0.3
123 | _C.SOLVER.RANGE_ALPHA = 0
124 | _C.SOLVER.RANGE_BETA = 1
125 | _C.SOLVER.RANGE_LOSS_WEIGHT = 1
126 |
127 | # Settings of weight decay
128 | _C.SOLVER.WEIGHT_DECAY = 0.0005
129 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
130 |
131 | # decay rate of learning rate
132 | _C.SOLVER.GAMMA = 0.1
133 | # decay step of learning rate
134 | _C.SOLVER.STEPS = (30, 55)
135 |
136 | # warm up factor
137 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3
138 | # iterations of warm up
139 | _C.SOLVER.WARMUP_ITERS = 500
140 | # method of warm up, option: 'constant','linear'
141 | _C.SOLVER.WARMUP_METHOD = "linear"
142 |
143 | # epoch number of saving checkpoints
144 | _C.SOLVER.CHECKPOINT_PERIOD = 50
145 | # iteration of display training log
146 | _C.SOLVER.LOG_PERIOD = 100
147 | # epoch number of validation
148 | _C.SOLVER.EVAL_PERIOD = 50
149 |
150 | # Number of images per batch
151 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
152 | # see 2 images per batch
153 | _C.SOLVER.IMS_PER_BATCH = 64
154 |
155 |
156 | # Test options
157 | _C.TEST = CN()
158 | # Number of images per batch during test
159 | _C.TEST.IMS_PER_BATCH = 128
160 | # If test with re-ranking, options: 'yes','no'
161 | _C.TEST.RE_RANKING = 'no'
162 | # Path to trained model
163 | _C.TEST.WEIGHT = ""
164 | # Which feature of BNNeck to be used for test, before or after BNNeck, options: 'before' or 'after'
165 | _C.TEST.NECK_FEAT = 'after'
166 | # Whether the feature is normalized before test; if yes, it is equivalent to cosine distance
167 | _C.TEST.FEAT_NORM = 'yes'
168 | # Whether to test the NFormer or the CNN encoder only
169 | _C.TEST.TEST_NFORMER = 'yes'
170 |
171 | # Data type during test
172 |
173 | # ---------------------------------------------------------------------------- #
174 | # Misc options
175 | # ---------------------------------------------------------------------------- #
176 | # Path to checkpoint and saved log of trained model
177 | _C.OUTPUT_DIR = ""
178 |
--------------------------------------------------------------------------------
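
For context, a minimal sketch of how this yacs config is typically consumed, mirroring the command lines in the shell scripts above. The contents of `tools/train.py` are not shown in this dump, so the exact loading code is an assumption; the calls used are standard yacs `CfgNode` API. Defaults come from `config/defaults.py`, a YAML file from `configs/` overrides them, and key-value pairs appended on the command line override both:

```python
# Assumed yacs workflow for this repo's config (not the literal tools/train.py code).
from config import cfg  # cfg is the _C node defined in config/defaults.py

cfg.merge_from_file('configs/softmax_triplet_with_center.yml')
cfg.merge_from_list(['MODEL.DEVICE_ID', "('0')",
                     'DATASETS.NAMES', "('market1501')",
                     'OUTPUT_DIR', "('work_dirs')"])  # same opts the .sh scripts pass
cfg.freeze()
print(cfg.SOLVER.BASE_LR, cfg.MODEL.METRIC_LOSS_TYPE)  # 0.00035 triplet_center
```
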
/configs/baseline.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
4 | LAST_STRIDE: 2
5 | NECK: 'no'
6 | METRIC_LOSS_TYPE: 'triplet'
7 | IF_LABELSMOOTH: 'off'
8 | IF_WITH_CENTER: 'no'
9 |
10 |
11 | INPUT:
12 | SIZE_TRAIN: [256, 128]
13 | SIZE_TEST: [256, 128]
14 | PROB: 0.5 # random horizontal flip
15 | RE_PROB: 0.0 # random erasing
16 | PADDING: 10
17 |
18 | DATASETS:
19 | NAMES: ('market1501')
20 |
21 | DATALOADER:
22 | SAMPLER: 'softmax_triplet'
23 | NUM_INSTANCE: 4
24 | NUM_WORKERS: 8
25 |
26 | SOLVER:
27 | OPTIMIZER_NAME: 'Adam'
28 | MAX_EPOCHS: 120
29 | BASE_LR: 0.00035
30 |
31 | CLUSTER_MARGIN: 0.3
32 |
33 | CENTER_LR: 0.5
34 | CENTER_LOSS_WEIGHT: 0.0005
35 |
36 | RANGE_K: 2
37 | RANGE_MARGIN: 0.3
38 | RANGE_ALPHA: 0
39 | RANGE_BETA: 1
40 | RANGE_LOSS_WEIGHT: 1
41 |
42 | BIAS_LR_FACTOR: 1
43 | WEIGHT_DECAY: 0.0005
44 | WEIGHT_DECAY_BIAS: 0.0005
45 | IMS_PER_BATCH: 64
46 |
47 | STEPS: [40, 70]
48 | GAMMA: 0.1
49 |
50 | WARMUP_FACTOR: 0.01
51 | WARMUP_ITERS: 0
52 | WARMUP_METHOD: 'linear'
53 |
54 | CHECKPOINT_PERIOD: 40
55 | LOG_PERIOD: 20
56 | EVAL_PERIOD: 40
57 |
58 | TEST:
59 | IMS_PER_BATCH: 128
60 | RE_RANKING: 'no'
61 | WEIGHT: "path"
62 | NECK_FEAT: 'after'
63 | FEAT_NORM: 'yes'
64 |
65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on"
66 |
67 |
68 |
--------------------------------------------------------------------------------
/configs/softmax.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
3 |
4 |
5 | INPUT:
6 | SIZE_TRAIN: [256, 128]
7 | SIZE_TEST: [256, 128]
8 | PROB: 0.5 # random horizontal flip
9 | RE_PROB: 0.5 # random erasing
10 | PADDING: 10
11 |
12 | DATASETS:
13 | NAMES: ('market1501')
14 |
15 | DATALOADER:
16 | SAMPLER: 'softmax'
17 | NUM_WORKERS: 8
18 |
19 | SOLVER:
20 | OPTIMIZER_NAME: 'Adam'
21 | MAX_EPOCHS: 120
22 | BASE_LR: 0.00035
23 | BIAS_LR_FACTOR: 1
24 | WEIGHT_DECAY: 0.0005
25 | WEIGHT_DECAY_BIAS: 0.0005
26 | IMS_PER_BATCH: 64
27 |
28 | STEPS: [30, 55]
29 | GAMMA: 0.1
30 |
31 | WARMUP_FACTOR: 0.01
32 | WARMUP_ITERS: 5
33 | WARMUP_METHOD: 'linear'
34 |
35 | CHECKPOINT_PERIOD: 20
36 | LOG_PERIOD: 20
37 | EVAL_PERIOD: 20
38 |
39 | TEST:
40 | IMS_PER_BATCH: 128
41 |
42 | OUTPUT_DIR: "/home/haoluo/log/reid/market1501/softmax_bs64_256x128"
43 |
44 |
45 |
--------------------------------------------------------------------------------
/configs/softmax_triplet.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'on'
6 | IF_WITH_CENTER: 'no'
7 |
8 |
9 |
10 |
11 | INPUT:
12 | SIZE_TRAIN: [256, 128]
13 | SIZE_TEST: [256, 128]
14 | PROB: 0.5 # random horizontal flip
15 | RE_PROB: 0.5 # random erasing
16 | PADDING: 10
17 |
18 | DATASETS:
19 | NAMES: ('market1501')
20 |
21 | DATALOADER:
22 | SAMPLER: 'softmax_triplet'
23 | NUM_INSTANCE: 4
24 | NUM_WORKERS: 8
25 |
26 | SOLVER:
27 | OPTIMIZER_NAME: 'Adam'
28 | MAX_EPOCHS: 120
29 | BASE_LR: 0.00035
30 |
31 | CLUSTER_MARGIN: 0.3
32 |
33 | CENTER_LR: 0.5
34 | CENTER_LOSS_WEIGHT: 0.0005
35 |
36 | RANGE_K: 2
37 | RANGE_MARGIN: 0.3
38 | RANGE_ALPHA: 0
39 | RANGE_BETA: 1
40 | RANGE_LOSS_WEIGHT: 1
41 |
42 | BIAS_LR_FACTOR: 1
43 | WEIGHT_DECAY: 0.0005
44 | WEIGHT_DECAY_BIAS: 0.0005
45 | IMS_PER_BATCH: 64
46 |
47 | STEPS: [40, 70]
48 | GAMMA: 0.1
49 |
50 | WARMUP_FACTOR: 0.01
51 | WARMUP_ITERS: 10
52 | WARMUP_METHOD: 'linear'
53 |
54 | CHECKPOINT_PERIOD: 40
55 | LOG_PERIOD: 20
56 | EVAL_PERIOD: 40
57 |
58 | TEST:
59 | IMS_PER_BATCH: 128
60 | RE_RANKING: 'no'
61 | WEIGHT: "path"
62 | NECK_FEAT: 'after'
63 | FEAT_NORM: 'yes'
64 |
65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on"
66 |
67 |
68 |
--------------------------------------------------------------------------------
/configs/softmax_triplet_with_center.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/haochen/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth'
4 | METRIC_LOSS_TYPE: 'triplet_center'
5 | IF_LABELSMOOTH: 'on'
6 | IF_WITH_CENTER: 'yes'
7 | N_EMBD: 256
8 | N_HEAD: 2
9 | N_LAYER: 4
10 | EMBD_PDROP: 0.1
11 | ATTN_PDROP: 0.1
12 | RESID_PDROP: 0.1
13 | AFN: 'gelu'
14 | CLF_PDROP: 0.1
15 | TOPK: 20
16 | LANDMARK: 5
17 |
18 |
19 |
20 |
21 | INPUT:
22 | SIZE_TRAIN: [256, 128]
23 | SIZE_TEST: [256, 128]
24 | PROB: 0.5 # random horizontal flip
25 | RE_PROB: 0.5 # random erasing
26 | PADDING: 10
27 |
28 | DATASETS:
29 | NAMES: ('market1501')
30 |
31 | DATALOADER:
32 | SAMPLER: 'softmax_triplet'
33 | NUM_INSTANCE: 4
34 | NUM_WORKERS: 8
35 |
36 | SOLVER:
37 | OPTIMIZER_NAME: 'Adam'
38 | MAX_EPOCHS: 120
39 | NFORMER_MAX_EPOCHS: 20
40 | BASE_LR: 0.00035
41 |
42 | CLUSTER_MARGIN: 0.3
43 |
44 | CENTER_LR: 0.5
45 | CENTER_LOSS_WEIGHT: 0.0005
46 |
47 | RANGE_K: 2
48 | RANGE_MARGIN: 0.3
49 | RANGE_ALPHA: 0
50 | RANGE_BETA: 1
51 | RANGE_LOSS_WEIGHT: 1
52 |
53 | BIAS_LR_FACTOR: 1
54 | WEIGHT_DECAY: 0.0005
55 | WEIGHT_DECAY_BIAS: 0.0005
56 | IMS_PER_BATCH: 64
57 |
58 | STEPS: [40, 70]
59 | GAMMA: 0.1
60 |
61 | WARMUP_FACTOR: 0.01
62 | WARMUP_ITERS: 10
63 | WARMUP_METHOD: 'linear'
64 |
65 | CHECKPOINT_PERIOD: 40
66 | LOG_PERIOD: 20
67 | EVAL_PERIOD: 40
68 |
69 | TEST:
70 | IMS_PER_BATCH: 128
71 | RE_RANKING: 'no'
72 | WEIGHT: "path"
73 | NECK_FEAT: 'after'
74 | FEAT_NORM: 'yes'
75 | TEST_NFORMER: 'no'
76 |
77 | OUTPUT_DIR: "work_dirs"
78 |
79 |
80 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import make_data_loader
8 |
--------------------------------------------------------------------------------
/data/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from torch.utils.data import DataLoader
8 |
9 | from .collate_batch import train_collate_fn, val_collate_fn
10 | from .datasets import init_dataset, ImageDataset
11 | from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid # New add by gu
12 | from .transforms import build_transforms
13 |
14 |
15 | def make_data_loader(cfg):
16 | train_transforms = build_transforms(cfg, is_train=True)
17 | val_transforms = build_transforms(cfg, is_train=False)
18 | num_workers = cfg.DATALOADER.NUM_WORKERS
19 | if len(cfg.DATASETS.NAMES) == 1:
20 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
21 | else:
22 | # TODO: add multi dataset to train
23 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
24 |
25 | num_classes = dataset.num_train_pids
26 | train_set = ImageDataset(dataset.train, train_transforms)
27 | if cfg.DATALOADER.SAMPLER == 'softmax':
28 | train_loader = DataLoader(
29 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
30 | collate_fn=train_collate_fn
31 | )
32 | else:
33 | train_loader = DataLoader(
34 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
35 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
36 | # sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE), # new add by gu
37 | num_workers=num_workers, collate_fn=train_collate_fn
38 | )
39 |
40 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
41 | val_loader = DataLoader(
42 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
43 | collate_fn=val_collate_fn
44 | )
45 | return train_loader, val_loader, len(dataset.query), num_classes
46 |
--------------------------------------------------------------------------------
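
`make_data_loader` switches between plain shuffling (for the `'softmax'` sampler) and `RandomIdentitySampler`, which draws P identities with K instances each so that the triplet/center losses always see positive pairs within a batch. The repo's implementation lives in `data/samplers/triplet_sampler.py`; the sketch below only illustrates the P×K idea and is not that exact code:

```python
# Minimal sketch of P x K identity sampling (illustrative only, not the repo's
# RandomIdentitySampler implementation).
import random
from collections import defaultdict

def pk_batch(data_source, batch_size, num_instances):
    """data_source: list of (img_path, pid, camid) tuples, as produced by the datasets."""
    assert batch_size % num_instances == 0
    num_pids_per_batch = batch_size // num_instances       # P
    index_dic = defaultdict(list)                           # pid -> dataset indices
    for index, (_, pid, _) in enumerate(data_source):
        index_dic[pid].append(index)

    pids = list(index_dic.keys())
    random.shuffle(pids)
    batch = []
    for pid in pids[:num_pids_per_batch]:
        idxs = index_dic[pid]
        if len(idxs) >= num_instances:                       # K instances per identity
            batch += random.sample(idxs, num_instances)
        else:                                                # reuse images if an identity has too few
            batch += random.choices(idxs, k=num_instances)
    return batch  # one batch of P*K dataset indices

# With SOLVER.IMS_PER_BATCH=64 and DATALOADER.NUM_INSTANCE=4 (as in the configs),
# each batch contains 16 identities with 4 images each.
```
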
/data/collate_batch.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 |
9 |
10 | def train_collate_fn(batch):
11 | imgs, pids, _, _, = zip(*batch)
12 | pids = torch.tensor(pids, dtype=torch.int64)
13 | return torch.stack(imgs, dim=0), pids
14 |
15 |
16 | def val_collate_fn(batch):
17 | imgs, pids, camids, _ = zip(*batch)
18 | return torch.stack(imgs, dim=0), pids, camids
19 |
--------------------------------------------------------------------------------
/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | # from .cuhk03 import CUHK03
7 | from .dukemtmcreid import DukeMTMCreID
8 | from .market1501 import Market1501
9 | from .msmt17 import MSMT17
10 | from .veri import VeRi
11 | from .dataset_loader import ImageDataset
12 |
13 | __factory = {
14 | 'market1501': Market1501,
15 | # 'cuhk03': CUHK03,
16 | 'dukemtmc': DukeMTMCreID,
17 | 'msmt17': MSMT17,
18 | 'veri': VeRi,
19 | }
20 |
21 |
22 | def get_names():
23 | return __factory.keys()
24 |
25 |
26 | def init_dataset(name, *args, **kwargs):
27 | if name not in __factory.keys():
28 | raise KeyError("Unknown datasets: {}".format(name))
29 | return __factory[name](*args, **kwargs)
30 |
--------------------------------------------------------------------------------
/data/datasets/bases.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import numpy as np
8 |
9 |
10 | class BaseDataset(object):
11 | """
12 | Base class of reid dataset
13 | """
14 |
15 | def get_imagedata_info(self, data):
16 | pids, cams = [], []
17 | for _, pid, camid in data:
18 | pids += [pid]
19 | cams += [camid]
20 | pids = set(pids)
21 | cams = set(cams)
22 | num_pids = len(pids)
23 | num_cams = len(cams)
24 | num_imgs = len(data)
25 | return num_pids, num_imgs, num_cams
26 |
27 | def get_videodata_info(self, data, return_tracklet_stats=False):
28 | pids, cams, tracklet_stats = [], [], []
29 | for img_paths, pid, camid in data:
30 | pids += [pid]
31 | cams += [camid]
32 | tracklet_stats += [len(img_paths)]
33 | pids = set(pids)
34 | cams = set(cams)
35 | num_pids = len(pids)
36 | num_cams = len(cams)
37 | num_tracklets = len(data)
38 | if return_tracklet_stats:
39 | return num_pids, num_tracklets, num_cams, tracklet_stats
40 | return num_pids, num_tracklets, num_cams
41 |
42 | def print_dataset_statistics(self):
43 | raise NotImplementedError
44 |
45 |
46 | class BaseImageDataset(BaseDataset):
47 | """
48 | Base class of image reid dataset
49 | """
50 |
51 | def print_dataset_statistics(self, train, query, gallery):
52 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
53 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
54 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
55 |
56 | print("Dataset statistics:")
57 | print(" ----------------------------------------")
58 | print(" subset | # ids | # images | # cameras")
59 | print(" ----------------------------------------")
60 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
61 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
62 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
63 | print(" ----------------------------------------")
64 |
65 |
66 | class BaseVideoDataset(BaseDataset):
67 | """
68 | Base class of video reid dataset
69 | """
70 |
71 | def print_dataset_statistics(self, train, query, gallery):
72 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
73 | self.get_videodata_info(train, return_tracklet_stats=True)
74 |
75 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
76 | self.get_videodata_info(query, return_tracklet_stats=True)
77 |
78 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
79 | self.get_videodata_info(gallery, return_tracklet_stats=True)
80 |
81 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
82 | min_num = np.min(tracklet_stats)
83 | max_num = np.max(tracklet_stats)
84 | avg_num = np.mean(tracklet_stats)
85 |
86 | print("Dataset statistics:")
87 | print(" -------------------------------------------")
88 | print(" subset | # ids | # tracklets | # cameras")
89 | print(" -------------------------------------------")
90 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams))
91 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams))
92 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))
93 | print(" -------------------------------------------")
94 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num))
95 | print(" -------------------------------------------")
96 |
--------------------------------------------------------------------------------
/data/datasets/cuhk03.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import h5py
8 | import os.path as osp
9 | from scipy.io import loadmat
10 | from scipy.misc import imsave
11 |
12 | from utils.iotools import mkdir_if_missing, write_json, read_json
13 | from .bases import BaseImageDataset
14 |
15 |
16 | class CUHK03(BaseImageDataset):
17 | """
18 | CUHK03
19 | Reference:
20 | Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.
21 | URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!
22 |
23 | Dataset statistics:
24 | # identities: 1360
25 | # images: 13164
26 | # cameras: 6
27 | # splits: 20 (classic)
28 | Args:
29 | split_id (int): split index (default: 0)
30 | cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False)
31 | """
32 | dataset_dir = 'cuhk03'
33 |
34 | def __init__(self, root='/home/haoluo/data', split_id=0, cuhk03_labeled=False,
35 | cuhk03_classic_split=False, verbose=True,
36 | **kwargs):
37 | super(CUHK03, self).__init__()
38 | self.dataset_dir = osp.join(root, self.dataset_dir)
39 | self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
40 | self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')
41 |
42 | self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
43 | self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')
44 |
45 | self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
46 | self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')
47 |
48 | self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
49 | self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')
50 |
51 | self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
52 | self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')
53 |
54 | self._check_before_run()
55 | self._preprocess()
56 |
57 | if cuhk03_labeled:
58 | image_type = 'labeled'
59 | split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
60 | else:
61 | image_type = 'detected'
62 | split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path
63 |
64 | splits = read_json(split_path)
65 | assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id,
66 | len(splits))
67 | split = splits[split_id]
68 | print("Split index = {}".format(split_id))
69 |
70 | train = split['train']
71 | query = split['query']
72 | gallery = split['gallery']
73 |
74 | if verbose:
75 | print("=> CUHK03 ({}) loaded".format(image_type))
76 | self.print_dataset_statistics(train, query, gallery)
77 |
78 | self.train = train
79 | self.query = query
80 | self.gallery = gallery
81 |
82 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
83 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
84 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
85 |
86 | def _check_before_run(self):
87 | """Check if all files are available before going deeper"""
88 | if not osp.exists(self.dataset_dir):
89 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
90 | if not osp.exists(self.data_dir):
91 | raise RuntimeError("'{}' is not available".format(self.data_dir))
92 | if not osp.exists(self.raw_mat_path):
93 | raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
94 | if not osp.exists(self.split_new_det_mat_path):
95 | raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path))
96 | if not osp.exists(self.split_new_lab_mat_path):
97 | raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path))
98 |
99 | def _preprocess(self):
100 | """
101 | This function is a bit complex and ugly, what it does is
102 | 1. Extract data from cuhk-03.mat and save as png images.
103 | 2. Create 20 classic splits. (Li et al. CVPR'14)
104 | 3. Create new split. (Zhong et al. CVPR'17)
105 | """
106 | print(
107 | "Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)")
108 | if osp.exists(self.imgs_labeled_dir) and \
109 | osp.exists(self.imgs_detected_dir) and \
110 | osp.exists(self.split_classic_det_json_path) and \
111 | osp.exists(self.split_classic_lab_json_path) and \
112 | osp.exists(self.split_new_det_json_path) and \
113 | osp.exists(self.split_new_lab_json_path):
114 | return
115 |
116 | mkdir_if_missing(self.imgs_detected_dir)
117 | mkdir_if_missing(self.imgs_labeled_dir)
118 |
119 | print("Extract image data from {} and save as png".format(self.raw_mat_path))
120 | mat = h5py.File(self.raw_mat_path, 'r')
121 |
122 | def _deref(ref):
123 | return mat[ref][:].T
124 |
125 | def _process_images(img_refs, campid, pid, save_dir):
126 | img_paths = [] # Note: some persons only have images for one view
127 | for imgid, img_ref in enumerate(img_refs):
128 | img = _deref(img_ref)
129 | # skip empty cell
130 | if img.size == 0 or img.ndim < 3: continue
131 | # images are saved with the following format, index-1 (ensure uniqueness)
132 | # campid: index of camera pair (1-5)
133 | # pid: index of person in 'campid'-th camera pair
134 | # viewid: index of view, {1, 2}
135 | # imgid: index of image, (1-10)
136 | viewid = 1 if imgid < 5 else 2
137 | img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
138 | img_path = osp.join(save_dir, img_name)
139 | if not osp.isfile(img_path):
140 | imsave(img_path, img)
141 | img_paths.append(img_path)
142 | return img_paths
143 |
144 | def _extract_img(name):
145 | print("Processing {} images (extract and save) ...".format(name))
146 | meta_data = []
147 | imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir
148 | for campid, camp_ref in enumerate(mat[name][0]):
149 | camp = _deref(camp_ref)
150 | num_pids = camp.shape[0]
151 | for pid in range(num_pids):
152 | img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
153 | assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid)
154 | meta_data.append((campid + 1, pid + 1, img_paths))
155 | print("- done camera pair {} with {} identities".format(campid + 1, num_pids))
156 | return meta_data
157 |
158 | meta_detected = _extract_img('detected')
159 | meta_labeled = _extract_img('labeled')
160 |
161 | def _extract_classic_split(meta_data, test_split):
162 | train, test = [], []
163 | num_train_pids, num_test_pids = 0, 0
164 | num_train_imgs, num_test_imgs = 0, 0
165 | for i, (campid, pid, img_paths) in enumerate(meta_data):
166 |
167 | if [campid, pid] in test_split:
168 | for img_path in img_paths:
169 | camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based
170 | test.append((img_path, num_test_pids, camid))
171 | num_test_pids += 1
172 | num_test_imgs += len(img_paths)
173 | else:
174 | for img_path in img_paths:
175 | camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based
176 | train.append((img_path, num_train_pids, camid))
177 | num_train_pids += 1
178 | num_train_imgs += len(img_paths)
179 | return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs
180 |
181 | print("Creating classic splits (# = 20) ...")
182 | splits_classic_det, splits_classic_lab = [], []
183 | for split_ref in mat['testsets'][0]:
184 | test_split = _deref(split_ref).tolist()
185 |
186 | # create split for detected images
187 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
188 | _extract_classic_split(meta_detected, test_split)
189 | splits_classic_det.append({
190 | 'train': train, 'query': test, 'gallery': test,
191 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
192 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
193 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
194 | })
195 |
196 | # create split for labeled images
197 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
198 | _extract_classic_split(meta_labeled, test_split)
199 | splits_classic_lab.append({
200 | 'train': train, 'query': test, 'gallery': test,
201 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
202 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
203 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
204 | })
205 |
206 | write_json(splits_classic_det, self.split_classic_det_json_path)
207 | write_json(splits_classic_lab, self.split_classic_lab_json_path)
208 |
209 | def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
210 | tmp_set = []
211 | unique_pids = set()
212 | for idx in idxs:
213 | img_name = filelist[idx][0]
214 | camid = int(img_name.split('_')[2]) - 1 # make it 0-based
215 | pid = pids[idx]
216 | if relabel: pid = pid2label[pid]
217 | img_path = osp.join(img_dir, img_name)
218 | tmp_set.append((img_path, int(pid), camid))
219 | unique_pids.add(pid)
220 | return tmp_set, len(unique_pids), len(idxs)
221 |
222 | def _extract_new_split(split_dict, img_dir):
223 | train_idxs = split_dict['train_idx'].flatten() - 1 # index-0
224 | pids = split_dict['labels'].flatten()
225 | train_pids = set(pids[train_idxs])
226 | pid2label = {pid: label for label, pid in enumerate(train_pids)}
227 | query_idxs = split_dict['query_idx'].flatten() - 1
228 | gallery_idxs = split_dict['gallery_idx'].flatten() - 1
229 | filelist = split_dict['filelist'].flatten()
230 | train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True)
231 | query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False)
232 | gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
233 | return train_info, query_info, gallery_info
234 |
235 | print("Creating new splits for detected images (767/700) ...")
236 | train_info, query_info, gallery_info = _extract_new_split(
237 | loadmat(self.split_new_det_mat_path),
238 | self.imgs_detected_dir,
239 | )
240 | splits = [{
241 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
242 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
243 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
244 | 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
245 | }]
246 | write_json(splits, self.split_new_det_json_path)
247 |
248 | print("Creating new splits for labeled images (767/700) ...")
249 | train_info, query_info, gallery_info = _extract_new_split(
250 | loadmat(self.split_new_lab_mat_path),
251 | self.imgs_labeled_dir,
252 | )
253 | splits = [{
254 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
255 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
256 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
257 | 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
258 | }]
259 | write_json(splits, self.split_new_lab_json_path)
260 |
--------------------------------------------------------------------------------
/data/datasets/dataset_loader.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import os.path as osp
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 |
11 |
12 | def read_image(img_path):
13 | """Keep reading image until succeed.
14 | This can avoid IOError incurred by heavy IO process."""
15 | got_img = False
16 | if not osp.exists(img_path):
17 | raise IOError("{} does not exist".format(img_path))
18 | while not got_img:
19 | try:
20 | img = Image.open(img_path).convert('RGB')
21 | got_img = True
22 | except IOError:
23 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
24 | pass
25 | return img
26 |
27 |
28 | class ImageDataset(Dataset):
29 | """Image Person ReID Dataset"""
30 |
31 | def __init__(self, dataset, transform=None):
32 | self.dataset = dataset
33 | self.transform = transform
34 |
35 | def __len__(self):
36 | return len(self.dataset)
37 |
38 | def __getitem__(self, index):
39 | img_path, pid, camid = self.dataset[index]
40 | img = read_image(img_path)
41 |
42 | if self.transform is not None:
43 | img = self.transform(img)
44 |
45 | return img, pid, camid, img_path
46 |
--------------------------------------------------------------------------------
/data/datasets/dukemtmcreid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import glob
8 | import re
9 | import urllib.request
10 | import zipfile
11 |
12 | import os.path as osp
13 |
14 | from utils.iotools import mkdir_if_missing
15 | from .bases import BaseImageDataset
16 |
17 |
18 | class DukeMTMCreID(BaseImageDataset):
19 | """
20 | DukeMTMC-reID
21 | Reference:
22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
24 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation
25 |
26 | Dataset statistics:
27 | # identities: 1404 (train + query)
28 | # images:16522 (train) + 2228 (query) + 17661 (gallery)
29 | # cameras: 8
30 | """
31 | dataset_dir = 'dukemtmc-reid'
32 |
33 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs):
34 | super(DukeMTMCreID, self).__init__()
35 | self.dataset_dir = osp.join(root, self.dataset_dir)
36 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
37 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
38 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
39 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')
40 |
41 | self._download_data()
42 | self._check_before_run()
43 |
44 | train = self._process_dir(self.train_dir, relabel=True)
45 | query = self._process_dir(self.query_dir, relabel=False)
46 | gallery = self._process_dir(self.gallery_dir, relabel=False)
47 |
48 | if verbose:
49 | print("=> DukeMTMC-reID loaded")
50 | self.print_dataset_statistics(train, query, gallery)
51 |
52 | self.train = train
53 | self.query = query
54 | self.gallery = gallery
55 |
56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
59 |
60 | def _download_data(self):
61 | if osp.exists(self.dataset_dir):
62 | print("This dataset has been downloaded.")
63 | return
64 |
65 | print("Creating directory {}".format(self.dataset_dir))
66 | mkdir_if_missing(self.dataset_dir)
67 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
68 |
69 | print("Downloading DukeMTMC-reID dataset")
70 | urllib.request.urlretrieve(self.dataset_url, fpath)
71 |
72 | print("Extracting files")
73 | zip_ref = zipfile.ZipFile(fpath, 'r')
74 | zip_ref.extractall(self.dataset_dir)
75 | zip_ref.close()
76 |
77 | def _check_before_run(self):
78 | """Check if all files are available before going deeper"""
79 | if not osp.exists(self.dataset_dir):
80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
81 | if not osp.exists(self.train_dir):
82 | raise RuntimeError("'{}' is not available".format(self.train_dir))
83 | if not osp.exists(self.query_dir):
84 | raise RuntimeError("'{}' is not available".format(self.query_dir))
85 | if not osp.exists(self.gallery_dir):
86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
87 |
88 | def _process_dir(self, dir_path, relabel=False):
89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
90 | pattern = re.compile(r'([-\d]+)_c(\d)')
91 |
92 | pid_container = set()
93 | for img_path in img_paths:
94 | pid, _ = map(int, pattern.search(img_path).groups())
95 | pid_container.add(pid)
96 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
97 |
98 | dataset = []
99 | for img_path in img_paths:
100 | pid, camid = map(int, pattern.search(img_path).groups())
101 | assert 1 <= camid <= 8
102 | camid -= 1 # index starts from 0
103 | if relabel: pid = pid2label[pid]
104 | dataset.append((img_path, pid, camid))
105 |
106 | return dataset
107 |
--------------------------------------------------------------------------------
/data/datasets/eval_reid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import numpy as np
8 |
9 |
10 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
11 | """Evaluation with market1501 metric
12 | Key: for each query identity, its gallery images from the same camera view are discarded.
13 | """
14 | num_q, num_g = distmat.shape
15 | if num_g < max_rank:
16 | max_rank = num_g
17 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
18 | indices = np.argsort(distmat, axis=1)
19 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
20 |
21 | # compute cmc curve for each query
22 | all_cmc = []
23 | all_AP = []
24 | num_valid_q = 0. # number of valid query
25 | for q_idx in range(num_q):
26 | # get query pid and camid
27 | q_pid = q_pids[q_idx]
28 | q_camid = q_camids[q_idx]
29 |
30 | # remove gallery samples that have the same pid and camid with query
31 | order = indices[q_idx]
32 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
33 | keep = np.invert(remove)
34 |
35 | # compute cmc curve
36 | # binary vector, positions with value 1 are correct matches
37 | orig_cmc = matches[q_idx][keep]
38 | if not np.any(orig_cmc):
39 | # this condition is true when query identity does not appear in gallery
40 | continue
41 |
42 | cmc = orig_cmc.cumsum()
43 | cmc[cmc > 1] = 1
44 |
45 | all_cmc.append(cmc[:max_rank])
46 | num_valid_q += 1.
47 |
48 | # compute average precision
49 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
50 | num_rel = orig_cmc.sum()
51 | tmp_cmc = orig_cmc.cumsum()
52 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
53 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
54 | AP = tmp_cmc.sum() / num_rel
55 | all_AP.append(AP)
56 |
57 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
58 |
59 | all_cmc = np.asarray(all_cmc).astype(np.float32)
60 | all_cmc = all_cmc.sum(0) / num_valid_q
61 | mAP = np.mean(all_AP)
62 |
63 | return all_cmc, mAP
64 |
--------------------------------------------------------------------------------
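
To make the evaluation protocol concrete, here is a tiny worked example with toy numbers (assuming the repo root is on `PYTHONPATH`). For each query, gallery images with the same identity and the same camera are discarded, so ranking is measured only on cross-camera matches:

```python
# Toy example for eval_func: 2 queries, 4 gallery images.
import numpy as np
from data.datasets.eval_reid import eval_func  # assumes repo root on PYTHONPATH

q_pids, q_camids = np.array([0, 1]), np.array([0, 0])
g_pids, g_camids = np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1])
distmat = np.array([[0.05, 0.20, 0.70, 0.80],   # query 0: nearest gallery image is same pid+cam -> dropped
                    [0.80, 0.70, 0.05, 0.20]])  # query 1: likewise
cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
print(cmc[0], mAP)  # -> 1.0 1.0 (rank-1 and mAP computed on the remaining cross-camera matches)
```
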
/data/datasets/market1501.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import glob
8 | import re
9 |
10 | import os.path as osp
11 |
12 | from .bases import BaseImageDataset
13 |
14 |
15 | class Market1501(BaseImageDataset):
16 | """
17 | Market1501
18 | Reference:
19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
20 | URL: http://www.liangzheng.org/Project/project_reid.html
21 |
22 | Dataset statistics:
23 | # identities: 1501 (+1 for background)
24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery)
25 | """
26 | dataset_dir = 'market1501'
27 |
28 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs):
29 | super(Market1501, self).__init__()
30 | self.dataset_dir = osp.join(root, self.dataset_dir)
31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
32 | self.query_dir = osp.join(self.dataset_dir, 'query')
33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
34 |
35 | self._check_before_run()
36 |
37 | train = self._process_dir(self.train_dir, relabel=True)
38 | query = self._process_dir(self.query_dir, relabel=False)
39 | gallery = self._process_dir(self.gallery_dir, relabel=False)
40 |
41 | if verbose:
42 | print("=> Market1501 loaded")
43 | self.print_dataset_statistics(train, query, gallery)
44 |
45 | self.train = train
46 | self.query = query
47 | self.gallery = gallery
48 |
49 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
52 |
53 | def _check_before_run(self):
54 | """Check if all files are available before going deeper"""
55 | if not osp.exists(self.dataset_dir):
56 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
57 | if not osp.exists(self.train_dir):
58 | raise RuntimeError("'{}' is not available".format(self.train_dir))
59 | if not osp.exists(self.query_dir):
60 | raise RuntimeError("'{}' is not available".format(self.query_dir))
61 | if not osp.exists(self.gallery_dir):
62 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
63 |
64 | def _process_dir(self, dir_path, relabel=False):
65 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
66 | pattern = re.compile(r'([-\d]+)_c(\d)')
67 |
68 | pid_container = set()
69 | for img_path in img_paths:
70 | pid, _ = map(int, pattern.search(img_path).groups())
71 | if pid == -1: continue # junk images are just ignored
72 | pid_container.add(pid)
73 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
74 |
75 | dataset = []
76 | for img_path in img_paths:
77 | pid, camid = map(int, pattern.search(img_path).groups())
78 | if pid == -1: continue # junk images are just ignored
79 | assert 0 <= pid <= 1501 # pid == 0 means background
80 | assert 1 <= camid <= 6
81 | camid -= 1 # index starts from 0
82 | if relabel: pid = pid2label[pid]
83 | dataset.append((img_path, pid, camid))
84 |
85 | return dataset
86 |
--------------------------------------------------------------------------------
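
For reference, a quick illustration of how `_process_dir` parses Market-1501 file names (which follow the `<pid>_c<camid>s<seq>_<frame>_<bbox>.jpg` pattern):

```python
# How the regex in Market1501._process_dir extracts pid and camid from a file name.
import re

pattern = re.compile(r'([-\d]+)_c(\d)')
name = '0002_c3s1_000451_03.jpg'            # example name in the Market-1501 format
pid, camid = map(int, pattern.search(name).groups())
print(pid, camid)                           # -> 2 3
# pid == -1 marks junk/distractor images and is skipped; camid is later shifted
# to start from 0 (camid -= 1), and train pids are relabeled to 0..N-1.
```
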
/data/datasets/msmt17.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/1/17 15:00
4 | # @Author : Hao Luo
5 | # @File : msmt17.py
6 |
7 | import glob
8 | import re
9 |
10 | import os.path as osp
11 |
12 | from .bases import BaseImageDataset
13 |
14 |
15 | class MSMT17(BaseImageDataset):
16 | """
17 | MSMT17
18 |
19 | Reference:
20 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
21 |
22 | URL: http://www.pkuvmc.com/publications/msmt17.html
23 |
24 | Dataset statistics:
25 | # identities: 4101
26 | # images: 32621 (train) + 11659 (query) + 82161 (gallery)
27 | # cameras: 15
28 | """
29 | dataset_dir = 'msmt17'
30 |
31 | def __init__(self,root='/home/haoluo/data', verbose=True, **kwargs):
32 | super(MSMT17, self).__init__()
33 | self.dataset_dir = osp.join(root, self.dataset_dir)
34 | self.train_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_train_v2')
35 | self.test_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_test_v2')
36 | self.list_train_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_train.txt')
37 | self.list_val_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_val.txt')
38 | self.list_query_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_query.txt')
39 | self.list_gallery_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_gallery.txt')
40 |
41 | self._check_before_run()
42 | train = self._process_dir(self.train_dir, self.list_train_path)
43 | #val, num_val_pids, num_val_imgs = self._process_dir(self.train_dir, self.list_val_path)
44 | query = self._process_dir(self.test_dir, self.list_query_path)
45 | gallery = self._process_dir(self.test_dir, self.list_gallery_path)
46 | if verbose:
47 | print("=> MSMT17 loaded")
48 | self.print_dataset_statistics(train, query, gallery)
49 |
50 | self.train = train
51 | self.query = query
52 | self.gallery = gallery
53 |
54 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
55 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
56 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
57 |
58 | def _check_before_run(self):
59 | """Check if all files are available before going deeper"""
60 | if not osp.exists(self.dataset_dir):
61 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
62 | if not osp.exists(self.train_dir):
63 | raise RuntimeError("'{}' is not available".format(self.train_dir))
64 | if not osp.exists(self.test_dir):
65 | raise RuntimeError("'{}' is not available".format(self.test_dir))
66 |
67 | def _process_dir(self, dir_path, list_path):
68 | with open(list_path, 'r') as txt:
69 | lines = txt.readlines()
70 | dataset = []
71 | pid_container = set()
72 | for img_idx, img_info in enumerate(lines):
73 | img_path, pid = img_info.split(' ')
74 | pid = int(pid) # no need to relabel
75 | camid = int(img_path.split('_')[2])
76 | img_path = osp.join(dir_path, img_path)
77 | dataset.append((img_path, pid, camid))
78 | pid_container.add(pid)
79 |
80 | # check if pid starts from 0 and increments with 1
81 | for idx, pid in enumerate(pid_container):
82 | assert idx == pid, "See code comment for explanation"
83 | return dataset
--------------------------------------------------------------------------------
/data/datasets/nformer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 |
5 | import torch
6 | import torchvision
7 | from torch.utils import data
8 |
9 | import glob
10 | from sklearn.preprocessing import normalize
11 | import random
12 |
13 | class NFormerDataset(data.Dataset):
14 | def __init__(self, data, data_length = 7000):
15 | self.data_length = data_length
16 | self.feats = data[0]
17 | self.ids = data[1]
18 | self.data_num = self.feats.shape[0]
19 |
20 | def __len__(self):
21 | return self.feats.shape[0]//30
22 |
23 | def __getitem__(self, index):
24 | center_index = random.randint(0, self.data_num - 1)
25 | center_feat = self.feats[center_index].unsqueeze(0)
26 | center_pid = self.ids[center_index]
27 |
28 | selected_flags = torch.zeros(self.data_num)
29 | selected_flags[center_index] = 1
30 | distmat = 1 - torch.mm(center_feat, self.feats.transpose(0,1))
31 | indices = torch.argsort(distmat, dim=1).numpy()
32 | indices = indices[0,:int(self.data_length * (1 + random.random()))].tolist()
33 | indices = random.sample(indices,self.data_length)
34 |
35 | random.shuffle(indices)
36 | feat_ = self.feats[indices]
37 | id_ = self.ids[indices]
38 |
39 |
40 | return feat_, id_
41 |
42 |
43 |
--------------------------------------------------------------------------------
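
A hypothetical usage sketch of `NFormerDataset`: each item is a bag of `data_length` features lying close to a randomly chosen center feature, which is how NFormer sees neighbor sets during its training stage. The feature tensor and labels below are made up; in the actual pipeline they come from the trained CNN encoder (the sketch assumes the repo root is on `PYTHONPATH` and scikit-learn is installed, since the module imports it):

```python
# Hypothetical usage of NFormerDataset with fake features (real features come from the CNN encoder).
import torch
from torch.utils.data import DataLoader
from data.datasets.nformer import NFormerDataset  # assumes repo root on PYTHONPATH

feats = torch.nn.functional.normalize(torch.randn(9000, 2048), dim=1)  # fake L2-normalized features
pids = torch.randint(0, 751, (9000,))                                  # fake identity labels
dataset = NFormerDataset((feats, pids), data_length=7000)

loader = DataLoader(dataset, batch_size=2, num_workers=0)
feat_batch, pid_batch = next(iter(loader))
print(feat_batch.shape, pid_batch.shape)  # -> torch.Size([2, 7000, 2048]) torch.Size([2, 7000])
```
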
/data/datasets/veri.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import re
3 |
4 | import os.path as osp
5 |
6 | from .bases import BaseImageDataset
7 |
8 |
9 | class VeRi(BaseImageDataset):
10 | """
11 | VeRi-776
12 | Reference:
13 | Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016.
14 |
15 | URL:https://vehiclereid.github.io/VeRi/
16 |
17 | Dataset statistics:
18 | # identities: 776
19 | # images: 37778 (train) + 1678 (query) + 11579 (gallery)
20 | # cameras: 20
21 | """
22 |
23 | dataset_dir = 'veri'
24 |
25 | def __init__(self, root='../', verbose=True, **kwargs):
26 | super(VeRi, self).__init__()
27 | self.dataset_dir = osp.join(root, self.dataset_dir)
28 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
29 | self.query_dir = osp.join(self.dataset_dir, 'image_query')
30 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
31 |
32 | self._check_before_run()
33 |
34 | train = self._process_dir(self.train_dir, relabel=True)
35 | query = self._process_dir(self.query_dir, relabel=False)
36 | gallery = self._process_dir(self.gallery_dir, relabel=False)
37 |
38 | if verbose:
39 | print("=> VeRi-776 loaded")
40 | self.print_dataset_statistics(train, query, gallery)
41 |
42 | self.train = train
43 | self.query = query
44 | self.gallery = gallery
45 |
46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
49 |
50 | def _check_before_run(self):
51 | """Check if all files are available before going deeper"""
52 | if not osp.exists(self.dataset_dir):
53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
54 | if not osp.exists(self.train_dir):
55 | raise RuntimeError("'{}' is not available".format(self.train_dir))
56 | if not osp.exists(self.query_dir):
57 | raise RuntimeError("'{}' is not available".format(self.query_dir))
58 | if not osp.exists(self.gallery_dir):
59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
60 |
61 | def _process_dir(self, dir_path, relabel=False):
62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
63 | pattern = re.compile(r'([-\d]+)_c(\d+)')
64 |
65 | pid_container = set()
66 | for img_path in img_paths:
67 | pid, _ = map(int, pattern.search(img_path).groups())
68 | if pid == -1: continue # junk images are just ignored
69 | pid_container.add(pid)
70 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
71 |
72 | dataset = []
73 | for img_path in img_paths:
74 | pid, camid = map(int, pattern.search(img_path).groups())
75 | if pid == -1: continue # junk images are just ignored
76 | assert 0 <= pid <= 776 # pid == 0 means background
77 | assert 1 <= camid <= 20
78 | camid -= 1 # index starts from 0
79 | if relabel: pid = pid2label[pid]
80 | dataset.append((img_path, pid, camid))
81 |
82 | return dataset
83 |
84 |
--------------------------------------------------------------------------------
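The dataset classes in this folder share the same interface. A hypothetical usage sketch ('/path/to/data' is a placeholder; the root must contain the folders checked by `_check_before_run`):

```python
from data.datasets.veri import VeRi

dataset = VeRi(root='/path/to/data')      # prints dataset statistics when verbose=True
img_path, pid, camid = dataset.train[0]   # each split is a list of (img_path, pid, camid)
print(dataset.num_train_pids, dataset.num_query_pids, dataset.num_gallery_pids)
```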
/data/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid # new add by gu
8 |
--------------------------------------------------------------------------------
/data/samplers/triplet_sampler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import copy
8 | import random
9 | import torch
10 | from collections import defaultdict
11 |
12 | import numpy as np
13 | from torch.utils.data.sampler import Sampler
14 |
15 |
16 | class RandomIdentitySampler(Sampler):
17 | """
18 | Randomly sample N identities, then for each identity,
19 | randomly sample K instances, therefore batch size is N*K.
20 | Args:
21 | - data_source (list): list of (img_path, pid, camid).
22 |     - batch_size (int): number of examples in a batch.
23 |     - num_instances (int): number of instances per identity in a batch.
24 | """
25 |
26 | def __init__(self, data_source, batch_size, num_instances):
27 | self.data_source = data_source
28 | self.batch_size = batch_size
29 | self.num_instances = num_instances
30 | self.num_pids_per_batch = self.batch_size // self.num_instances
31 | self.index_dic = defaultdict(list)
32 | for index, (_, pid, _) in enumerate(self.data_source):
33 | self.index_dic[pid].append(index)
34 | self.pids = list(self.index_dic.keys())
35 |
36 | # estimate number of examples in an epoch
37 | self.length = 0
38 | for pid in self.pids:
39 | idxs = self.index_dic[pid]
40 | num = len(idxs)
41 | if num < self.num_instances:
42 | num = self.num_instances
43 | self.length += num - num % self.num_instances
44 |
45 | def __iter__(self):
46 | batch_idxs_dict = defaultdict(list)
47 |
48 | for pid in self.pids:
49 | idxs = copy.deepcopy(self.index_dic[pid])
50 | if len(idxs) < self.num_instances:
51 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
52 | random.shuffle(idxs)
53 | batch_idxs = []
54 | for idx in idxs:
55 | batch_idxs.append(idx)
56 | if len(batch_idxs) == self.num_instances:
57 | batch_idxs_dict[pid].append(batch_idxs)
58 | batch_idxs = []
59 |
60 | avai_pids = copy.deepcopy(self.pids)
61 | final_idxs = []
62 |
63 | while len(avai_pids) >= self.num_pids_per_batch:
64 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
65 | for pid in selected_pids:
66 | batch_idxs = batch_idxs_dict[pid].pop(0)
67 | final_idxs.extend(batch_idxs)
68 | if len(batch_idxs_dict[pid]) == 0:
69 | avai_pids.remove(pid)
70 |
71 | self.length = len(final_idxs)
72 | return iter(final_idxs)
73 |
74 | def __len__(self):
75 | return self.length
76 |
77 |
78 | # New add by gu
79 | class RandomIdentitySampler_alignedreid(Sampler):
80 | """
81 | Randomly sample N identities, then for each identity,
82 | randomly sample K instances, therefore batch size is N*K.
83 |
84 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
85 |
86 | Args:
87 | data_source (Dataset): dataset to sample from.
88 | num_instances (int): number of instances per identity.
89 | """
90 | def __init__(self, data_source, num_instances):
91 | self.data_source = data_source
92 | self.num_instances = num_instances
93 | self.index_dic = defaultdict(list)
94 | for index, (_, pid, _) in enumerate(data_source):
95 | self.index_dic[pid].append(index)
96 | self.pids = list(self.index_dic.keys())
97 | self.num_identities = len(self.pids)
98 |
99 | def __iter__(self):
100 | indices = torch.randperm(self.num_identities)
101 | ret = []
102 | for i in indices:
103 | pid = self.pids[i]
104 | t = self.index_dic[pid]
105 | replace = False if len(t) >= self.num_instances else True
106 | t = np.random.choice(t, size=self.num_instances, replace=replace)
107 | ret.extend(t)
108 | return iter(ret)
109 |
110 | def __len__(self):
111 | return self.num_identities * self.num_instances
112 |
--------------------------------------------------------------------------------
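A minimal sketch of RandomIdentitySampler on synthetic records, showing that every batch covers `batch_size // num_instances` identities with `num_instances` images each (the identity and instance counts below are assumptions):

```python
from data.samplers import RandomIdentitySampler

# Synthetic (img_path, pid, camid) records standing in for dataset.train.
train_list = [('img_%d.jpg' % i, i % 32, i % 6) for i in range(512)]

batch_size, num_instances = 64, 4      # 16 identities x 4 instances per batch
sampler = RandomIdentitySampler(train_list, batch_size, num_instances)
indices = list(sampler)                # every consecutive block of 64 indices covers 16 identities
# In the real pipeline this sampler is passed to the training DataLoader via `sampler=sampler`.
```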
/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import build_transforms
8 |
--------------------------------------------------------------------------------
/data/transforms/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import torchvision.transforms as T
8 |
9 | from .transforms import RandomErasing
10 |
11 |
12 | def build_transforms(cfg, is_train=True):
13 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
14 | if is_train:
15 | transform = T.Compose([
16 | T.Resize(cfg.INPUT.SIZE_TRAIN),
17 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
18 | T.Pad(cfg.INPUT.PADDING),
19 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
20 | T.ToTensor(),
21 | normalize_transform,
22 | RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
23 | ])
24 | else:
25 | transform = T.Compose([
26 | T.Resize(cfg.INPUT.SIZE_TEST),
27 | T.ToTensor(),
28 | normalize_transform
29 | ])
30 |
31 | return transform
32 |
--------------------------------------------------------------------------------
/data/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import math
8 | import random
9 |
10 |
11 | class RandomErasing(object):
12 | """ Randomly selects a rectangle region in an image and erases its pixels.
13 | 'Random Erasing Data Augmentation' by Zhong et al.
14 | See https://arxiv.org/pdf/1708.04896.pdf
15 | Args:
16 | probability: The probability that the Random Erasing operation will be performed.
17 | sl: Minimum proportion of erased area against input image.
18 | sh: Maximum proportion of erased area against input image.
19 | r1: Minimum aspect ratio of erased area.
20 | mean: Erasing value.
21 | """
22 |
23 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
24 | self.probability = probability
25 | self.mean = mean
26 | self.sl = sl
27 | self.sh = sh
28 | self.r1 = r1
29 |
30 | def __call__(self, img):
31 |
32 | if random.uniform(0, 1) >= self.probability:
33 | return img
34 |
35 | for attempt in range(100):
36 | area = img.size()[1] * img.size()[2]
37 |
38 | target_area = random.uniform(self.sl, self.sh) * area
39 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
40 |
41 | h = int(round(math.sqrt(target_area * aspect_ratio)))
42 | w = int(round(math.sqrt(target_area / aspect_ratio)))
43 |
44 | if w < img.size()[2] and h < img.size()[1]:
45 | x1 = random.randint(0, img.size()[1] - h)
46 | y1 = random.randint(0, img.size()[2] - w)
47 | if img.size()[0] == 3:
48 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
49 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
50 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
51 | else:
52 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
53 | return img
54 |
55 | return img
56 |
--------------------------------------------------------------------------------
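RandomErasing operates on an already-tensorised C x H x W image, which is why build.py applies it after ToTensor and Normalize. A minimal sketch (the image size and mean values below are assumptions):

```python
import torch
from data.transforms.transforms import RandomErasing

erase = RandomErasing(probability=0.5, mean=(0.485, 0.456, 0.406))
img = torch.randn(3, 256, 128)   # a normalized 256x128 training crop
img = erase(img)                 # with prob 0.5, one random rectangle is filled with `mean`
```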
/engine/inference.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import logging
7 |
8 | import torch
9 | import torch.nn as nn
10 | from ignite.engine import Engine
11 |
12 | from utils.reid_metric import R1_mAP, NFormer_R1_mAP
13 |
14 |
15 | def create_supervised_evaluator(model, metrics,
16 | device=None):
17 | """
18 | Factory function for creating an evaluator for supervised models
19 |
20 | Args:
21 | model (`torch.nn.Module`): the model to train
22 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
23 | device (str, optional): device type specification (default: None).
24 | Applies to both model and batches.
25 | Returns:
26 | Engine: an evaluator engine with supervised inference function
27 | """
28 | if device:
29 | if torch.cuda.device_count() > 1:
30 | model = nn.DataParallel(model)
31 | model.to(device)
32 |
33 | def _inference(engine, batch):
34 | model.eval()
35 | with torch.no_grad():
36 | data, pids, camids = batch
37 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
38 | feat = model(data)
39 | return feat, pids, camids
40 |
41 | engine = Engine(_inference)
42 |
43 | for name, metric in metrics.items():
44 | metric.attach(engine, name)
45 |
46 | return engine
47 |
48 |
49 | def inference(
50 | cfg,
51 | model,
52 | val_loader,
53 | num_query
54 | ):
55 | device = cfg.MODEL.DEVICE
56 |
57 | logger = logging.getLogger("reid_baseline.inference")
58 | logger.info("Enter inferencing")
59 | if cfg.TEST.TEST_NFORMER != 'yes':
60 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},device=device)
61 | else:
62 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': NFormer_R1_mAP(model, num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},device=device)
63 |
64 | evaluator.run(val_loader)
65 | cmc, mAP = evaluator.state.metrics['r1_mAP']
66 | logger.info('Validation Results')
67 | logger.info("mAP: {:.1%}".format(mAP))
68 | for r in [1, 5, 10]:
69 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
70 |
--------------------------------------------------------------------------------
/engine/trainer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import os
7 | import logging
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils import data
12 | from ignite.engine import Engine, Events
13 | from ignite.handlers import ModelCheckpoint, Timer
14 | from ignite.metrics import RunningAverage, Metric
15 |
16 | from utils.reid_metric import R1_mAP, NFormer_R1_mAP
17 | from data.datasets.nformer import NFormerDataset
18 |
19 | global ITER
20 | ITER = 0
21 |
22 | def create_supervised_trainer(model, optimizer, loss_fn,
23 | device=None):
24 | """
25 | Factory function for creating a trainer for supervised models
26 |
27 | Args:
28 | model (`torch.nn.Module`): the model to train
29 | optimizer (`torch.optim.Optimizer`): the optimizer to use
30 | loss_fn (torch.nn loss function): the loss function to use
31 | device (str, optional): device type specification (default: None).
32 | Applies to both model and batches.
33 |
34 | Returns:
35 | Engine: a trainer engine with supervised update function
36 | """
37 | if device:
38 | if torch.cuda.device_count() > 1:
39 | model = nn.DataParallel(model)
40 | model.to(device)
41 |
42 | def _update(engine, batch):
43 | model.train()
44 | optimizer.zero_grad()
45 | img, target = batch
46 | img = img.to(device) if torch.cuda.device_count() >= 1 else img
47 | target = target.to(device) if torch.cuda.device_count() >= 1 else target
48 | score, feat = model(img, stage='encoder')
49 | loss = loss_fn(score, feat, target)
50 | loss.backward()
51 | optimizer.step()
52 | # compute acc
53 | acc = (score.max(1)[1] == target).float().mean()
54 | return loss.item(), acc.item()
55 |
56 | return Engine(_update)
57 |
58 |
59 | def create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, center_loss_weight,
60 | device=None):
61 | """
62 | Factory function for creating a trainer for supervised models
63 |
64 | Args:
65 | model (`torch.nn.Module`): the model to train
66 | optimizer (`torch.optim.Optimizer`): the optimizer to use
67 | loss_fn (torch.nn loss function): the loss function to use
68 | device (str, optional): device type specification (default: None).
69 | Applies to both model and batches.
70 |
71 | Returns:
72 | Engine: a trainer engine with supervised update function
73 | """
74 | if device:
75 | if torch.cuda.device_count() > 1:
76 | model = nn.DataParallel(model)
77 | model.to(device)
78 |
79 | def _update(engine, batch):
80 | model.train()
81 | optimizer.zero_grad()
82 | optimizer_center.zero_grad()
83 | img, target = batch
84 | img = img.to(device) if torch.cuda.device_count() >= 1 else img
85 | target = target.to(device) if torch.cuda.device_count() >= 1 else target
86 | score, feat = model(img, stage='encoder')
87 | loss = loss_fn(score, feat, target)
88 | # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
89 | loss.backward()
90 | optimizer.step()
91 | for param in center_criterion.parameters():
92 |             param.grad.data *= (1. / center_loss_weight)
93 | optimizer_center.step()
94 |
95 | # compute acc
96 | acc = (score.max(1)[1] == target).float().mean()
97 | return loss.item(), acc.item()
98 |
99 | return Engine(_update)
100 |
101 | def create_supervised_nformer_trainer(model, nformer_center_criterion, optimizer, optimizer_nformer_center, nformer_loss_fn, center_loss_weight, device=None):
102 | if device:
103 | if torch.cuda.device_count() > 1:
104 | model = nn.DataParallel(model)
105 | model.to(device)
106 |
107 | def _update(engine, batch):
108 | model.train()
109 | optimizer.zero_grad()
110 | optimizer_nformer_center.zero_grad()
111 | feat, target = batch
112 | feat = feat.to(device) if torch.cuda.device_count() >= 1 else feat
113 | target = target.to(device) if torch.cuda.device_count() >= 1 else target
114 | score, feat = model(feat, stage='nformer')
115 | bs,dl,d = feat.shape
116 | score = score.reshape(bs * dl, -1)
117 | feat = feat.reshape(bs * dl, d)
118 | target = target.reshape(bs * dl)
119 | loss = nformer_loss_fn(score, feat, target)
120 | # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
121 | loss.backward()
122 | optimizer.step()
123 | for param in nformer_center_criterion.parameters():
124 |             param.grad.data *= (1. / center_loss_weight)
125 | optimizer_nformer_center.step()
126 |
127 | # compute acc
128 | acc = (score.max(1)[1] == target).float().mean()
129 | return loss.item(), acc.item()
130 |
131 |
132 | return Engine(_update)
133 |
134 | class data_collector(Metric):
135 | def reset(self):
136 | self.feat = []
137 | self.target = []
138 | def update(self,output):
139 | feat, target = output
140 | self.feat.append(feat)
141 | self.target.append(target)
142 | def compute(self):
143 | feat = torch.cat(self.feat, dim=0)
144 | #feat = torch.nn.functional.normalize(feat, dim=1, p=2)
145 | target = torch.cat(self.target, dim=0)
146 | return feat, target
147 |
148 | def create_nformer_data_generator(model, metrics, device=None):
149 | if device:
150 | if torch.cuda.device_count() > 1:
151 | model = nn.DataParallel(model)
152 | model.to(device)
153 | def _inference(engine, batch):
154 | model.eval()
155 | img, target = batch
156 | img = img.to(device) if torch.cuda.device_count() >= 1 else img
157 | with torch.no_grad():
158 | feat = model(img, stage='encoder')
159 | return feat.cpu(), target
160 |
161 | engine = Engine(_inference)
162 |
163 | for name, metric in metrics.items():
164 | metric.attach(engine, name)
165 |
166 | return engine
167 |
168 | def create_supervised_evaluator(model, metrics,
169 | device=None):
170 | """
171 | Factory function for creating an evaluator for supervised models
172 |
173 | Args:
174 | model (`torch.nn.Module`): the model to train
175 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
176 | device (str, optional): device type specification (default: None).
177 | Applies to both model and batches.
178 | Returns:
179 | Engine: an evaluator engine with supervised inference function
180 | """
181 | if device:
182 | if torch.cuda.device_count() > 1:
183 | model = nn.DataParallel(model)
184 | model.to(device)
185 |
186 | def _inference(engine, batch):
187 | model.eval()
188 | with torch.no_grad():
189 | data, pids, camids = batch
190 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
191 | feat = model(data)
192 | return feat, pids, camids
193 |
194 | engine = Engine(_inference)
195 |
196 | for name, metric in metrics.items():
197 | metric.attach(engine, name)
198 |
199 | return engine
200 |
201 |
202 | def do_train(
203 | cfg,
204 | model,
205 | train_loader,
206 | val_loader,
207 | optimizer,
208 | scheduler,
209 | loss_fn,
210 | num_query,
211 | start_epoch
212 | ):
213 | log_period = cfg.SOLVER.LOG_PERIOD
214 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
215 | eval_period = cfg.SOLVER.EVAL_PERIOD
216 | output_dir = cfg.OUTPUT_DIR
217 | device = cfg.MODEL.DEVICE
218 | epochs = cfg.SOLVER.MAX_EPOCHS
219 |
220 | logger = logging.getLogger("reid_baseline.train")
221 | logger.info("Start training")
222 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
223 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
224 | data_generator = create_nformer_data_generator(model, metrics={'data':data_collector()}, device=device)
225 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
226 | timer = Timer(average=True)
227 |
228 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
229 | 'optimizer': optimizer})
230 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
231 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
232 |
233 | # average metric to attach on trainer
234 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
235 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
236 |
237 | @trainer.on(Events.STARTED)
238 | def start_training(engine):
239 | engine.state.epoch = start_epoch
240 |
241 | @trainer.on(Events.EPOCH_STARTED)
242 | def adjust_learning_rate(engine):
243 | scheduler.step()
244 |
245 | @trainer.on(Events.ITERATION_COMPLETED)
246 | def log_training_loss(engine):
247 | global ITER
248 | ITER += 1
249 |
250 | if ITER % log_period == 0:
251 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
252 | .format(engine.state.epoch, ITER, len(train_loader),
253 | engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
254 | scheduler.get_lr()[0]))
255 | if len(train_loader) == ITER:
256 | ITER = 0
257 |
258 | @trainer.on(Events.EPOCH_COMPLETED)
259 |     def train_nformer(engine):
260 | data_generator.run(train_loader)
261 |
262 |
263 | # adding handlers using `trainer.on` decorator API
264 | @trainer.on(Events.EPOCH_COMPLETED)
265 | def print_times(engine):
266 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
267 | .format(engine.state.epoch, timer.value() * timer.step_count,
268 | train_loader.batch_size / timer.value()))
269 | logger.info('-' * 10)
270 | timer.reset()
271 |
272 | @trainer.on(Events.EPOCH_COMPLETED)
273 | def log_validation_results(engine):
274 | if engine.state.epoch % eval_period == 0:
275 | evaluator.run(val_loader)
276 | cmc, mAP = evaluator.state.metrics['r1_mAP']
277 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
278 | logger.info("mAP: {:.1%}".format(mAP))
279 | for r in [1, 5, 10]:
280 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
281 |
282 | trainer.run(train_loader, max_epochs=epochs)
283 |
284 |
285 | def do_train_with_center(
286 | cfg,
287 | model,
288 | center_criterion,
289 | nformer_center_criterion,
290 | train_loader,
291 | val_loader,
292 | optimizer,
293 | optimizer_center,
294 | optimizer_nformer,
295 | optimizer_nformer_center,
296 | scheduler,
297 | loss_fn,
298 | nformer_loss_fn,
299 | num_query,
300 | start_epoch
301 | ):
302 | log_period = cfg.SOLVER.LOG_PERIOD
303 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
304 | eval_period = cfg.SOLVER.EVAL_PERIOD
305 | output_dir = cfg.OUTPUT_DIR
306 | device = cfg.MODEL.DEVICE
307 | epochs = cfg.SOLVER.MAX_EPOCHS
308 | nformer_epochs = cfg.SOLVER.NFORMER_MAX_EPOCHS
309 |
310 | logger = logging.getLogger("reid_baseline.train")
311 | logger.info("Start training")
312 | trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
313 | nformer_trainer = create_supervised_nformer_trainer(model, nformer_center_criterion, optimizer_nformer, optimizer_nformer_center, nformer_loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
314 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
315 | nformer_evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': NFormer_R1_mAP(model, num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
316 | data_generator = create_nformer_data_generator(model, metrics={'data':data_collector()}, device=device)
317 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
318 | timer = Timer(average=True)
319 |
320 |
321 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
322 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
323 |
324 | # average metric to attach on trainer
325 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
326 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
327 |
328 | @trainer.on(Events.STARTED)
329 | def start_training(engine):
330 | engine.state.epoch = start_epoch
331 |
332 | @trainer.on(Events.EPOCH_STARTED)
333 | def adjust_learning_rate(engine):
334 | scheduler.step()
335 |
336 | @trainer.on(Events.ITERATION_COMPLETED)
337 | def log_training_loss(engine):
338 | global ITER
339 | ITER += 1
340 |
341 | if ITER % log_period == 0:
342 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
343 | .format(engine.state.epoch, ITER, len(train_loader),
344 | engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
345 | scheduler.get_lr()[0]))
346 | if len(train_loader) == ITER:
347 | ITER = 0
348 |
349 |
350 | # adding handlers using `trainer.on` decorator API
351 | @trainer.on(Events.EPOCH_COMPLETED)
352 | def print_times(engine):
353 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
354 | .format(engine.state.epoch, timer.value() * timer.step_count,
355 | train_loader.batch_size / timer.value()))
356 | logger.info('-' * 10)
357 | timer.reset()
358 |
359 | @trainer.on(Events.EPOCH_COMPLETED)
360 | def log_validation_results(engine):
361 | if engine.state.epoch % eval_period == 0:
362 | evaluator.run(val_loader)
363 | cmc, mAP = evaluator.state.metrics['r1_mAP']
364 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
365 | logger.info("mAP: {:.1%}".format(mAP))
366 | for r in [1, 5, 10]:
367 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
368 |
369 | @trainer.on(Events.EPOCH_COMPLETED)
370 |     def train_nformer(engine):
371 | if engine.state.epoch < epochs:
372 | return
373 | for n_epoch in range(nformer_epochs):
374 | data_generator.run(train_loader)
375 | nformer_dataset = NFormerDataset(data_generator.state.metrics['data'])
376 | nformer_trainloader = data.DataLoader(nformer_dataset, batch_size=2, num_workers=1,shuffle = True, pin_memory=True)
377 | nformer_trainer.run(nformer_trainloader, max_epochs=1)
378 | if (n_epoch+1)%5 == 0:
379 | print('evaluate nformer at epoch {}'.format(n_epoch))
380 | nformer_evaluator.run(val_loader)
381 | cmc, mAP = nformer_evaluator.state.metrics['r1_mAP']
382 | logger.info("mAP: {:.1%}".format(mAP))
383 | for r in [1, 5, 10]:
384 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
385 |
386 |
387 |
388 | trainer.run(train_loader, max_epochs=epochs)
389 | if not os.path.exists(output_dir):
390 | os.makedirs(output_dir)
391 | torch.save(model.state_dict(), os.path.join(output_dir, 'nformer_model.pth'))
392 |
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch.nn.functional as F
8 |
9 | from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth
10 | from .center_loss import CenterLoss
11 |
12 |
13 | def make_loss(cfg, num_classes): # modified by gu
14 | sampler = cfg.DATALOADER.SAMPLER
15 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
16 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
17 | else:
18 |         print('expected METRIC_LOSS_TYPE should be triplet, '
19 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
20 |
21 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
22 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo
23 | print("label smooth on, numclasses:", num_classes)
24 |
25 | if sampler == 'softmax':
26 | def loss_func(score, feat, target):
27 | return F.cross_entropy(score, target)
28 | elif cfg.DATALOADER.SAMPLER == 'triplet':
29 | def loss_func(score, feat, target):
30 | return triplet(feat, target)[0]
31 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
32 | def loss_func(score, feat, target):
33 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
34 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
35 | return xent(score, target) + triplet(feat, target)[0]
36 | else:
37 | return F.cross_entropy(score, target) + triplet(feat, target)[0]
38 | else:
39 |                 print('expected METRIC_LOSS_TYPE should be triplet, '
40 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
41 | else:
42 | print('expected sampler should be softmax, triplet or softmax_triplet, '
43 | 'but got {}'.format(cfg.DATALOADER.SAMPLER))
44 | return loss_func
45 |
46 |
47 | def make_loss_with_center(cfg, num_classes): # modified by gu
48 | feat_dim = 256
49 |
50 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
51 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
52 |
53 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
54 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
55 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
56 |
57 | else:
58 |         print('expected METRIC_LOSS_TYPE with center should be center, triplet_center, '
59 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
60 |
61 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
62 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo
63 | print("label smooth on, numclasses:", num_classes)
64 |
65 | def loss_func(score, feat, target):
66 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
67 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
68 | return xent(score, target) + \
69 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
70 | else:
71 | return F.cross_entropy(score, target) + \
72 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
73 |
74 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
75 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
76 | return xent(score, target) + \
77 | triplet(feat, target)[0] + \
78 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
79 | else:
80 | return F.cross_entropy(score, target) + \
81 | triplet(feat, target)[0] + \
82 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
83 |
84 | else:
85 |             print('expected METRIC_LOSS_TYPE with center should be center, triplet_center, '
86 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
87 | return loss_func, center_criterion
88 |
89 | def make_nformer_loss_with_center(cfg, num_classes):
90 | feat_dim = 256
91 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
92 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)
93 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo
94 | def loss_func(score, feat, target):
95 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
96 | return xent(score, target) + \
97 | triplet(feat, target)[0] + \
98 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
99 | else:
100 | return F.cross_entropy(score, target) + \
101 | triplet(feat, target)[0] + \
102 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
103 | return loss_func, center_criterion
104 |
105 |
--------------------------------------------------------------------------------
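For reference, a minimal CPU sketch of the 'triplet_center' combination assembled by make_loss_with_center; the class count, margin, center weight and batch size below are assumptions:

```python
import torch
from layers.triplet_loss import TripletLoss, CrossEntropyLabelSmooth
from layers.center_loss import CenterLoss

num_classes, margin, center_weight = 751, 0.3, 0.0005
xent = CrossEntropyLabelSmooth(num_classes=num_classes, use_gpu=False)
triplet = TripletLoss(margin)
center = CenterLoss(num_classes=num_classes, feat_dim=256, use_gpu=False)

score = torch.randn(64, num_classes)             # classifier logits
feat = torch.randn(64, 256)                      # bottleneck features (feat_dim=256 as above)
target = torch.randint(0, num_classes, (64,))

loss = xent(score, target) + triplet(feat, target)[0] + center_weight * center(feat, target)
```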
/layers/center_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class CenterLoss(nn.Module):
8 | """Center loss.
9 |
10 | Reference:
11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
12 |
13 | Args:
14 | num_classes (int): number of classes.
15 | feat_dim (int): feature dimension.
16 | """
17 |
18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
19 | super(CenterLoss, self).__init__()
20 | self.num_classes = num_classes
21 | self.feat_dim = feat_dim
22 | self.use_gpu = use_gpu
23 |
24 | if self.use_gpu:
25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
26 | else:
27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
28 |
29 | def forward(self, x, labels):
30 | """
31 | Args:
32 | x: feature matrix with shape (batch_size, feat_dim).
33 |             labels: ground truth labels with shape (batch_size).
34 | """
35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
36 |
37 | batch_size = x.size(0)
38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
40 | distmat.addmm_(1, -2, x, self.centers.t())
41 |
42 | classes = torch.arange(self.num_classes).long()
43 | if self.use_gpu: classes = classes.cuda()
44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
45 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
46 |
47 | dist = distmat * mask.float()
48 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
49 | #dist = []
50 | #for i in range(batch_size):
51 | # value = distmat[i][mask[i]]
52 | # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
53 | # dist.append(value)
54 | #dist = torch.cat(dist)
55 | #loss = dist.mean()
56 | return loss
57 |
58 |
59 | if __name__ == '__main__':
60 | use_gpu = False
61 | center_loss = CenterLoss(use_gpu=use_gpu)
62 | features = torch.rand(16, 2048)
63 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long()
64 | if use_gpu:
65 | features = torch.rand(16, 2048).cuda()
66 |         targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long().cuda()
67 |
68 | loss = center_loss(features, targets)
69 | print(loss)
70 |
--------------------------------------------------------------------------------
/layers/triplet_loss.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import torch
7 | from torch import nn
8 |
9 |
10 | def normalize(x, axis=-1):
11 | """Normalizing to unit length along the specified dimension.
12 | Args:
13 | x: pytorch Variable
14 | Returns:
15 | x: pytorch Variable, same shape as input
16 | """
17 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
18 | return x
19 |
20 |
21 | def euclidean_dist(x, y):
22 | """
23 | Args:
24 | x: pytorch Variable, with shape [m, d]
25 | y: pytorch Variable, with shape [n, d]
26 | Returns:
27 | dist: pytorch Variable, with shape [m, n]
28 | """
29 | m, n = x.size(0), y.size(0)
30 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
31 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
32 | dist = xx + yy
33 | dist.addmm_(1, -2, x, y.t())
34 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
35 | return dist
36 |
37 |
38 | def hard_example_mining(dist_mat, labels, return_inds=False):
39 | """For each anchor, find the hardest positive and negative sample.
40 | Args:
41 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
42 | labels: pytorch LongTensor, with shape [N]
43 | return_inds: whether to return the indices. Save time if `False`(?)
44 | Returns:
45 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
46 | dist_an: pytorch Variable, distance(anchor, negative); shape [N]
47 | p_inds: pytorch LongTensor, with shape [N];
48 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
49 | n_inds: pytorch LongTensor, with shape [N];
50 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
51 | NOTE: Only consider the case in which all labels have same num of samples,
52 | thus we can cope with all anchors in parallel.
53 | """
54 |
55 | assert len(dist_mat.size()) == 2
56 | assert dist_mat.size(0) == dist_mat.size(1)
57 | N = dist_mat.size(0)
58 |
59 | # shape [N, N]
60 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
61 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
62 |
63 | # `dist_ap` means distance(anchor, positive)
64 | # both `dist_ap` and `relative_p_inds` with shape [N, 1]
65 | dist_ap, relative_p_inds = torch.max(
66 | dist_mat * is_pos, 1, keepdim=True)
67 | #dist_ap, relative_p_inds = torch.max(
68 | # dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
69 | # `dist_an` means distance(anchor, negative)
70 | # both `dist_an` and `relative_n_inds` with shape [N, 1]
71 | dist_an, relative_n_inds = torch.min(
72 | dist_mat * is_neg + is_pos * 1e8, 1, keepdim=True)
73 | #dist_an, relative_n_inds = torch.min(
74 | # dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
75 | # shape [N]
76 | dist_ap = dist_ap.squeeze(1)
77 | dist_an = dist_an.squeeze(1)
78 |
79 | if return_inds:
80 | # shape [N, N]
81 | ind = (labels.new().resize_as_(labels)
82 | .copy_(torch.arange(0, N).long())
83 | .unsqueeze(0).expand(N, N))
84 | # shape [N, 1]
85 | p_inds = torch.gather(
86 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
87 | n_inds = torch.gather(
88 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
89 | # shape [N]
90 | p_inds = p_inds.squeeze(1)
91 | n_inds = n_inds.squeeze(1)
92 | return dist_ap, dist_an, p_inds, n_inds
93 |
94 | return dist_ap, dist_an
95 |
96 |
97 | class TripletLoss(object):
98 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
99 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
100 | Loss for Person Re-Identification'."""
101 |
102 | def __init__(self, margin=None):
103 | self.margin = margin
104 | if margin is not None:
105 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
106 | else:
107 | self.ranking_loss = nn.SoftMarginLoss()
108 |
109 | def __call__(self, global_feat, labels, normalize_feature=False):
110 | if normalize_feature:
111 | global_feat = normalize(global_feat, axis=-1)
112 | dist_mat = euclidean_dist(global_feat, global_feat)
113 | dist_ap, dist_an = hard_example_mining(
114 | dist_mat, labels)
115 | y = dist_an.new().resize_as_(dist_an).fill_(1)
116 | if self.margin is not None:
117 | loss = self.ranking_loss(dist_an, dist_ap, y)
118 | else:
119 | loss = self.ranking_loss(dist_an - dist_ap, y)
120 | return loss, dist_ap, dist_an
121 |
122 | class CrossEntropyLabelSmooth(nn.Module):
123 | """Cross entropy loss with label smoothing regularizer.
124 |
125 | Reference:
126 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
127 | Equation: y = (1 - epsilon) * y + epsilon / K.
128 |
129 | Args:
130 | num_classes (int): number of classes.
131 | epsilon (float): weight.
132 | """
133 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
134 | super(CrossEntropyLabelSmooth, self).__init__()
135 | self.num_classes = num_classes
136 | self.epsilon = epsilon
137 | self.use_gpu = use_gpu
138 | self.logsoftmax = nn.LogSoftmax(dim=1)
139 |
140 | def forward(self, inputs, targets):
141 | """
142 | Args:
143 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
144 |             targets: ground truth labels with shape (batch_size)
145 | """
146 | log_probs = self.logsoftmax(inputs)
147 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
148 | if self.use_gpu: targets = targets.cuda()
149 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
150 | loss = (- targets * log_probs).mean(0).sum()
151 | return loss
152 |
--------------------------------------------------------------------------------
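A minimal sketch of hard-example mining on a P x K batch (4 identities x 4 instances; the feature dimension is an assumption), matching the note above that every label should appear the same number of times:

```python
import torch
from layers.triplet_loss import euclidean_dist, hard_example_mining

feats = torch.randn(16, 2048)                    # 4 identities x 4 instances
labels = torch.arange(4).repeat_interleave(4)    # [0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3]
dist_mat = euclidean_dist(feats, feats)
dist_ap, dist_an = hard_example_mining(dist_mat, labels)
# dist_ap[i]: distance from anchor i to its hardest (farthest) positive
# dist_an[i]: distance from anchor i to its hardest (closest) negative
```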
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .baseline import Baseline
8 | from .model import nformer_model
9 |
10 |
11 | def build_model(cfg, num_classes):
12 | # if cfg.MODEL.NAME == 'resnet50':
13 | # model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT)
14 | model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE)
15 | return model
16 |
17 | def build_nformer_model(cfg, num_classes):
18 | model = nformer_model(cfg, num_classes)
19 | return model
20 |
--------------------------------------------------------------------------------
/modeling/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 |
--------------------------------------------------------------------------------
/modeling/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import math
8 |
9 | import torch
10 | from torch import nn
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | """3x3 convolution with padding"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None):
23 | super(BasicBlock, self).__init__()
24 | self.conv1 = conv3x3(inplanes, planes, stride)
25 | self.bn1 = nn.BatchNorm2d(planes)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 | self.downsample = downsample
30 | self.stride = stride
31 |
32 | def forward(self, x):
33 | residual = x
34 |
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 |
39 | out = self.conv2(out)
40 | out = self.bn2(out)
41 |
42 | if self.downsample is not None:
43 | residual = self.downsample(x)
44 |
45 | out += residual
46 | out = self.relu(out)
47 |
48 | return out
49 |
50 |
51 | class Bottleneck(nn.Module):
52 | expansion = 4
53 |
54 | def __init__(self, inplanes, planes, stride=1, downsample=None):
55 | super(Bottleneck, self).__init__()
56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
57 | self.bn1 = nn.BatchNorm2d(planes)
58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
59 | padding=1, bias=False)
60 | self.bn2 = nn.BatchNorm2d(planes)
61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
62 | self.bn3 = nn.BatchNorm2d(planes * 4)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.downsample = downsample
65 | self.stride = stride
66 |
67 | def forward(self, x):
68 | residual = x
69 |
70 | out = self.conv1(x)
71 | out = self.bn1(out)
72 | out = self.relu(out)
73 |
74 | out = self.conv2(out)
75 | out = self.bn2(out)
76 | out = self.relu(out)
77 |
78 | out = self.conv3(out)
79 | out = self.bn3(out)
80 |
81 | if self.downsample is not None:
82 | residual = self.downsample(x)
83 |
84 | out += residual
85 | out = self.relu(out)
86 |
87 | return out
88 |
89 |
90 | class ResNet(nn.Module):
91 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]):
92 | self.inplanes = 64
93 | super().__init__()
94 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
95 | bias=False)
96 | self.bn1 = nn.BatchNorm2d(64)
97 | # self.relu = nn.ReLU(inplace=True) # add missed relu
98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
99 | self.layer1 = self._make_layer(block, 64, layers[0])
100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
102 | self.layer4 = self._make_layer(
103 | block, 512, layers[3], stride=last_stride)
104 |
105 | def _make_layer(self, block, planes, blocks, stride=1):
106 | downsample = None
107 | if stride != 1 or self.inplanes != planes * block.expansion:
108 | downsample = nn.Sequential(
109 | nn.Conv2d(self.inplanes, planes * block.expansion,
110 | kernel_size=1, stride=stride, bias=False),
111 | nn.BatchNorm2d(planes * block.expansion),
112 | )
113 |
114 | layers = []
115 | layers.append(block(self.inplanes, planes, stride, downsample))
116 | self.inplanes = planes * block.expansion
117 | for i in range(1, blocks):
118 | layers.append(block(self.inplanes, planes))
119 |
120 | return nn.Sequential(*layers)
121 |
122 | def forward(self, x):
123 | x = self.conv1(x)
124 | x = self.bn1(x)
125 | # x = self.relu(x) # add missed relu
126 | x = self.maxpool(x)
127 |
128 | x = self.layer1(x)
129 | x = self.layer2(x)
130 | x = self.layer3(x)
131 | x = self.layer4(x)
132 |
133 | return x
134 |
135 | def load_param(self, model_path):
136 | param_dict = torch.load(model_path)
137 | for i in param_dict:
138 | if 'fc' in i:
139 | continue
140 | self.state_dict()[i].copy_(param_dict[i])
141 |
142 | def random_init(self):
143 | for m in self.modules():
144 | if isinstance(m, nn.Conv2d):
145 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
146 | m.weight.data.normal_(0, math.sqrt(2. / n))
147 | elif isinstance(m, nn.BatchNorm2d):
148 | m.weight.data.fill_(1)
149 | m.bias.data.zero_()
150 |
151 |
--------------------------------------------------------------------------------
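A quick shape check of this backbone with randomly initialised weights (the input size and last_stride=1 below are assumptions for illustration):

```python
import torch
from modeling.backbones.resnet import ResNet

backbone = ResNet(last_stride=1)     # Bottleneck with layers=[3, 4, 6, 3], i.e. a ResNet-50 layout
x = torch.randn(2, 3, 256, 128)      # a small batch of 256x128 person crops
fmap = backbone(x)
print(fmap.shape)                    # torch.Size([2, 2048, 16, 8]) when last_stride=1
```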
/modeling/backbones/resnet_ibn_a.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 |
6 |
7 | __all__ = ['ResNet_IBN', 'resnet50_ibn_a', 'resnet101_ibn_a',
8 | 'resnet152_ibn_a']
9 |
10 |
11 | model_urls = {
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 |
18 | class IBN(nn.Module):
19 | def __init__(self, planes):
20 | super(IBN, self).__init__()
21 | half1 = int(planes/2)
22 | self.half = half1
23 | half2 = planes - half1
24 | self.IN = nn.InstanceNorm2d(half1, affine=True)
25 | self.BN = nn.BatchNorm2d(half2)
26 |
27 | def forward(self, x):
28 | split = torch.split(x, self.half, 1)
29 | out1 = self.IN(split[0].contiguous())
30 | out2 = self.BN(split[1].contiguous())
31 | out = torch.cat((out1, out2), 1)
32 | return out
33 |
34 |
35 | class Bottleneck_IBN(nn.Module):
36 | expansion = 4
37 |
38 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None):
39 | super(Bottleneck_IBN, self).__init__()
40 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
41 | if ibn:
42 | self.bn1 = IBN(planes)
43 | else:
44 | self.bn1 = nn.BatchNorm2d(planes)
45 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
46 | padding=1, bias=False)
47 | self.bn2 = nn.BatchNorm2d(planes)
48 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
49 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.downsample = downsample
52 | self.stride = stride
53 |
54 | def forward(self, x):
55 | residual = x
56 |
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 |
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 | out = self.relu(out)
64 |
65 | out = self.conv3(out)
66 | out = self.bn3(out)
67 |
68 | if self.downsample is not None:
69 | residual = self.downsample(x)
70 |
71 | out += residual
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 |
77 | class ResNet_IBN(nn.Module):
78 |
79 | def __init__(self, last_stride, block, layers, num_classes=1000):
80 | scale = 64
81 | self.inplanes = scale
82 | super(ResNet_IBN, self).__init__()
83 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3,
84 | bias=False)
85 | self.bn1 = nn.BatchNorm2d(scale)
86 | self.relu = nn.ReLU(inplace=True)
87 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
88 | self.layer1 = self._make_layer(block, scale, layers[0])
89 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2)
90 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2)
91 | self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride)
92 | self.avgpool = nn.AvgPool2d(7)
93 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes)
94 |
95 | for m in self.modules():
96 | if isinstance(m, nn.Conv2d):
97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
98 | m.weight.data.normal_(0, math.sqrt(2. / n))
99 | elif isinstance(m, nn.BatchNorm2d):
100 | m.weight.data.fill_(1)
101 | m.bias.data.zero_()
102 | elif isinstance(m, nn.InstanceNorm2d):
103 | m.weight.data.fill_(1)
104 | m.bias.data.zero_()
105 |
106 | def _make_layer(self, block, planes, blocks, stride=1):
107 | downsample = None
108 | if stride != 1 or self.inplanes != planes * block.expansion:
109 | downsample = nn.Sequential(
110 | nn.Conv2d(self.inplanes, planes * block.expansion,
111 | kernel_size=1, stride=stride, bias=False),
112 | nn.BatchNorm2d(planes * block.expansion),
113 | )
114 |
115 | layers = []
116 | ibn = True
117 | if planes == 512:
118 | ibn = False
119 | layers.append(block(self.inplanes, planes, ibn, stride, downsample))
120 | self.inplanes = planes * block.expansion
121 | for i in range(1, blocks):
122 | layers.append(block(self.inplanes, planes, ibn))
123 |
124 | return nn.Sequential(*layers)
125 |
126 | def forward(self, x):
127 | x = self.conv1(x)
128 | x = self.bn1(x)
129 | x = self.relu(x)
130 | x = self.maxpool(x)
131 |
132 | x = self.layer1(x)
133 | x = self.layer2(x)
134 | x = self.layer3(x)
135 | x = self.layer4(x)
136 |
137 | # x = self.avgpool(x)
138 | # x = x.view(x.size(0), -1)
139 | # x = self.fc(x)
140 |
141 | return x
142 |
143 | def load_param(self, model_path):
144 | param_dict = torch.load(model_path)
145 | for i in param_dict:
146 | if 'fc' in i:
147 | continue
148 | self.state_dict()[i].copy_(param_dict[i])
149 |
150 |
151 | def resnet50_ibn_a(last_stride, pretrained=False, **kwargs):
152 | """Constructs a ResNet-50 model.
153 | Args:
154 | pretrained (bool): If True, returns a model pre-trained on ImageNet
155 | """
156 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 6, 3], **kwargs)
157 | if pretrained:
158 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
159 | return model
160 |
161 |
162 | def resnet101_ibn_a(last_stride, pretrained=False, **kwargs):
163 | """Constructs a ResNet-101 model.
164 | Args:
165 | pretrained (bool): If True, returns a model pre-trained on ImageNet
166 | """
167 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 23, 3], **kwargs)
168 | if pretrained:
169 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
170 | return model
171 |
172 |
173 | def resnet152_ibn_a(last_stride, pretrained=False, **kwargs):
174 | """Constructs a ResNet-152 model.
175 | Args:
176 | pretrained (bool): If True, returns a model pre-trained on ImageNet
177 | """
178 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 8, 36, 3], **kwargs)
179 | if pretrained:
180 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
181 | return model
--------------------------------------------------------------------------------
/modeling/backbones/senet.py:
--------------------------------------------------------------------------------
1 | """
2 | ResNet code gently borrowed from
3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
4 | """
5 | from __future__ import print_function, division, absolute_import
6 | from collections import OrderedDict
7 | import math
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils import model_zoo
11 |
12 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
13 | 'se_resnext50_32x4d', 'se_resnext101_32x4d']
14 |
15 | pretrained_settings = {
16 | 'senet154': {
17 | 'imagenet': {
18 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
19 | 'input_space': 'RGB',
20 | 'input_size': [3, 224, 224],
21 | 'input_range': [0, 1],
22 | 'mean': [0.485, 0.456, 0.406],
23 | 'std': [0.229, 0.224, 0.225],
24 | 'num_classes': 1000
25 | }
26 | },
27 | 'se_resnet50': {
28 | 'imagenet': {
29 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
30 | 'input_space': 'RGB',
31 | 'input_size': [3, 224, 224],
32 | 'input_range': [0, 1],
33 | 'mean': [0.485, 0.456, 0.406],
34 | 'std': [0.229, 0.224, 0.225],
35 | 'num_classes': 1000
36 | }
37 | },
38 | 'se_resnet101': {
39 | 'imagenet': {
40 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
41 | 'input_space': 'RGB',
42 | 'input_size': [3, 224, 224],
43 | 'input_range': [0, 1],
44 | 'mean': [0.485, 0.456, 0.406],
45 | 'std': [0.229, 0.224, 0.225],
46 | 'num_classes': 1000
47 | }
48 | },
49 | 'se_resnet152': {
50 | 'imagenet': {
51 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
52 | 'input_space': 'RGB',
53 | 'input_size': [3, 224, 224],
54 | 'input_range': [0, 1],
55 | 'mean': [0.485, 0.456, 0.406],
56 | 'std': [0.229, 0.224, 0.225],
57 | 'num_classes': 1000
58 | }
59 | },
60 | 'se_resnext50_32x4d': {
61 | 'imagenet': {
62 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
63 | 'input_space': 'RGB',
64 | 'input_size': [3, 224, 224],
65 | 'input_range': [0, 1],
66 | 'mean': [0.485, 0.456, 0.406],
67 | 'std': [0.229, 0.224, 0.225],
68 | 'num_classes': 1000
69 | }
70 | },
71 | 'se_resnext101_32x4d': {
72 | 'imagenet': {
73 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
74 | 'input_space': 'RGB',
75 | 'input_size': [3, 224, 224],
76 | 'input_range': [0, 1],
77 | 'mean': [0.485, 0.456, 0.406],
78 | 'std': [0.229, 0.224, 0.225],
79 | 'num_classes': 1000
80 | }
81 | },
82 | }
83 |
84 |
85 | class SEModule(nn.Module):
86 |
87 | def __init__(self, channels, reduction):
88 | super(SEModule, self).__init__()
89 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
90 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
91 | padding=0)
92 | self.relu = nn.ReLU(inplace=True)
93 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
94 | padding=0)
95 | self.sigmoid = nn.Sigmoid()
96 |
97 | def forward(self, x):
98 | module_input = x
99 | x = self.avg_pool(x)
100 | x = self.fc1(x)
101 | x = self.relu(x)
102 | x = self.fc2(x)
103 | x = self.sigmoid(x)
104 | return module_input * x
105 |
106 |
107 | class Bottleneck(nn.Module):
108 | """
109 | Base class for bottlenecks that implements `forward()` method.
110 | """
111 | def forward(self, x):
112 | residual = x
113 |
114 | out = self.conv1(x)
115 | out = self.bn1(out)
116 | out = self.relu(out)
117 |
118 | out = self.conv2(out)
119 | out = self.bn2(out)
120 | out = self.relu(out)
121 |
122 | out = self.conv3(out)
123 | out = self.bn3(out)
124 |
125 | if self.downsample is not None:
126 | residual = self.downsample(x)
127 |
128 | out = self.se_module(out) + residual
129 | out = self.relu(out)
130 |
131 | return out
132 |
133 |
134 | class SEBottleneck(Bottleneck):
135 | """
136 | Bottleneck for SENet154.
137 | """
138 | expansion = 4
139 |
140 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
141 | downsample=None):
142 | super(SEBottleneck, self).__init__()
143 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
144 | self.bn1 = nn.BatchNorm2d(planes * 2)
145 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
146 | stride=stride, padding=1, groups=groups,
147 | bias=False)
148 | self.bn2 = nn.BatchNorm2d(planes * 4)
149 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
150 | bias=False)
151 | self.bn3 = nn.BatchNorm2d(planes * 4)
152 | self.relu = nn.ReLU(inplace=True)
153 | self.se_module = SEModule(planes * 4, reduction=reduction)
154 | self.downsample = downsample
155 | self.stride = stride
156 |
157 |
158 | class SEResNetBottleneck(Bottleneck):
159 | """
160 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
161 | implementation and uses `stride=stride` in `conv1` and not in `conv2`
162 | (the latter is used in the torchvision implementation of ResNet).
163 | """
164 | expansion = 4
165 |
166 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
167 | downsample=None):
168 | super(SEResNetBottleneck, self).__init__()
169 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
170 | stride=stride)
171 | self.bn1 = nn.BatchNorm2d(planes)
172 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
173 | groups=groups, bias=False)
174 | self.bn2 = nn.BatchNorm2d(planes)
175 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
176 | self.bn3 = nn.BatchNorm2d(planes * 4)
177 | self.relu = nn.ReLU(inplace=True)
178 | self.se_module = SEModule(planes * 4, reduction=reduction)
179 | self.downsample = downsample
180 | self.stride = stride
181 |
182 |
183 | class SEResNeXtBottleneck(Bottleneck):
184 | """
185 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
186 | """
187 | expansion = 4
188 |
189 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
190 | downsample=None, base_width=4):
191 | super(SEResNeXtBottleneck, self).__init__()
192 | width = math.floor(planes * (base_width / 64)) * groups
193 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
194 | stride=1)
195 | self.bn1 = nn.BatchNorm2d(width)
196 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
197 | padding=1, groups=groups, bias=False)
198 | self.bn2 = nn.BatchNorm2d(width)
199 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
200 | self.bn3 = nn.BatchNorm2d(planes * 4)
201 | self.relu = nn.ReLU(inplace=True)
202 | self.se_module = SEModule(planes * 4, reduction=reduction)
203 | self.downsample = downsample
204 | self.stride = stride
205 |
206 |
207 | class SENet(nn.Module):
208 |
209 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
210 | inplanes=128, input_3x3=True, downsample_kernel_size=3,
211 | downsample_padding=1, last_stride=2):
212 | """
213 | Parameters
214 | ----------
215 | block (nn.Module): Bottleneck class.
216 | - For SENet154: SEBottleneck
217 | - For SE-ResNet models: SEResNetBottleneck
218 | - For SE-ResNeXt models: SEResNeXtBottleneck
219 | layers (list of ints): Number of residual blocks for 4 layers of the
220 | network (layer1...layer4).
221 | groups (int): Number of groups for the 3x3 convolution in each
222 | bottleneck block.
223 | - For SENet154: 64
224 | - For SE-ResNet models: 1
225 | - For SE-ResNeXt models: 32
226 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
227 | - For all models: 16
228 | dropout_p (float or None): Drop probability for the Dropout layer.
229 | If `None` the Dropout layer is not used.
230 | - For SENet154: 0.2
231 | - For SE-ResNet models: None
232 | - For SE-ResNeXt models: None
233 | inplanes (int): Number of input channels for layer1.
234 | - For SENet154: 128
235 | - For SE-ResNet models: 64
236 | - For SE-ResNeXt models: 64
237 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
238 | a single 7x7 convolution in layer0.
239 | - For SENet154: True
240 | - For SE-ResNet models: False
241 | - For SE-ResNeXt models: False
242 | downsample_kernel_size (int): Kernel size for downsampling convolutions
243 | in layer2, layer3 and layer4.
244 | - For SENet154: 3
245 | - For SE-ResNet models: 1
246 | - For SE-ResNeXt models: 1
247 | downsample_padding (int): Padding for downsampling convolutions in
248 | layer2, layer3 and layer4.
249 | - For SENet154: 1
250 | - For SE-ResNet models: 0
251 | - For SE-ResNeXt models: 0
252 |         last_stride (int): Stride used in layer4.
253 |             - For all models: 2 in the original ImageNet networks; 1 is used here for re-ID
254 | """
255 | super(SENet, self).__init__()
256 | self.inplanes = inplanes
257 | if input_3x3:
258 | layer0_modules = [
259 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
260 | bias=False)),
261 | ('bn1', nn.BatchNorm2d(64)),
262 | ('relu1', nn.ReLU(inplace=True)),
263 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
264 | bias=False)),
265 | ('bn2', nn.BatchNorm2d(64)),
266 | ('relu2', nn.ReLU(inplace=True)),
267 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
268 | bias=False)),
269 | ('bn3', nn.BatchNorm2d(inplanes)),
270 | ('relu3', nn.ReLU(inplace=True)),
271 | ]
272 | else:
273 | layer0_modules = [
274 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
275 | padding=3, bias=False)),
276 | ('bn1', nn.BatchNorm2d(inplanes)),
277 | ('relu1', nn.ReLU(inplace=True)),
278 | ]
279 | # To preserve compatibility with Caffe weights `ceil_mode=True`
280 | # is used instead of `padding=1`.
281 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
282 | ceil_mode=True)))
283 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
284 | self.layer1 = self._make_layer(
285 | block,
286 | planes=64,
287 | blocks=layers[0],
288 | groups=groups,
289 | reduction=reduction,
290 | downsample_kernel_size=1,
291 | downsample_padding=0
292 | )
293 | self.layer2 = self._make_layer(
294 | block,
295 | planes=128,
296 | blocks=layers[1],
297 | stride=2,
298 | groups=groups,
299 | reduction=reduction,
300 | downsample_kernel_size=downsample_kernel_size,
301 | downsample_padding=downsample_padding
302 | )
303 | self.layer3 = self._make_layer(
304 | block,
305 | planes=256,
306 | blocks=layers[2],
307 | stride=2,
308 | groups=groups,
309 | reduction=reduction,
310 | downsample_kernel_size=downsample_kernel_size,
311 | downsample_padding=downsample_padding
312 | )
313 | self.layer4 = self._make_layer(
314 | block,
315 | planes=512,
316 | blocks=layers[3],
317 | stride=last_stride,
318 | groups=groups,
319 | reduction=reduction,
320 | downsample_kernel_size=downsample_kernel_size,
321 | downsample_padding=downsample_padding
322 | )
323 | self.avg_pool = nn.AvgPool2d(7, stride=1)
324 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
325 |
326 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
327 | downsample_kernel_size=1, downsample_padding=0):
328 | downsample = None
329 | if stride != 1 or self.inplanes != planes * block.expansion:
330 | downsample = nn.Sequential(
331 | nn.Conv2d(self.inplanes, planes * block.expansion,
332 | kernel_size=downsample_kernel_size, stride=stride,
333 | padding=downsample_padding, bias=False),
334 | nn.BatchNorm2d(planes * block.expansion),
335 | )
336 |
337 | layers = []
338 | layers.append(block(self.inplanes, planes, groups, reduction, stride,
339 | downsample))
340 | self.inplanes = planes * block.expansion
341 | for i in range(1, blocks):
342 | layers.append(block(self.inplanes, planes, groups, reduction))
343 |
344 | return nn.Sequential(*layers)
345 |
346 | def load_param(self, model_path):
347 | param_dict = torch.load(model_path)
348 | for i in param_dict:
349 | if 'last_linear' in i:
350 | continue
351 | self.state_dict()[i].copy_(param_dict[i])
352 |
353 | def forward(self, x):
354 | x = self.layer0(x)
355 | x = self.layer1(x)
356 | x = self.layer2(x)
357 | x = self.layer3(x)
358 | x = self.layer4(x)
359 | return x
--------------------------------------------------------------------------------
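
For reference, a minimal sketch of building the trimmed SE-ResNeXt-50 backbone defined above; the constructor arguments mirror the `se_resnext50` branch in modeling/baseline.py, and the repo root is assumed to be on sys.path:

    import torch
    from modeling.backbones.senet import SENet, SEResNeXtBottleneck

    # SE-ResNeXt-50 (32x4d) without the classifier head; last_stride=1 keeps a larger feature map
    backbone = SENet(block=SEResNeXtBottleneck,
                     layers=[3, 4, 6, 3],        # residual blocks per stage
                     groups=32,                  # cardinality of the grouped 3x3 convolutions
                     reduction=16,               # SE channel-reduction ratio
                     dropout_p=None,
                     inplanes=64,
                     input_3x3=False,            # single 7x7 stem instead of three 3x3 convs
                     downsample_kernel_size=1,
                     downsample_padding=0,
                     last_stride=1)

    x = torch.randn(4, 3, 256, 128)              # a batch of person crops
    feat_map = backbone(x)                       # torch.Size([4, 2048, 16, 8])
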
/modeling/baseline.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 | from torch import nn
9 |
10 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck
11 | from .backbones.senet import SENet, SEResNetBottleneck, SEBottleneck, SEResNeXtBottleneck
12 | from .backbones.resnet_ibn_a import resnet50_ibn_a
13 |
14 |
15 | def weights_init_kaiming(m):
16 | classname = m.__class__.__name__
17 | if classname.find('Linear') != -1:
18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
19 | nn.init.constant_(m.bias, 0.0)
20 | elif classname.find('Conv') != -1:
21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
22 | if m.bias is not None:
23 | nn.init.constant_(m.bias, 0.0)
24 | elif classname.find('BatchNorm') != -1:
25 | if m.affine:
26 | nn.init.constant_(m.weight, 1.0)
27 | nn.init.constant_(m.bias, 0.0)
28 |
29 |
30 | def weights_init_classifier(m):
31 | classname = m.__class__.__name__
32 | if classname.find('Linear') != -1:
33 | nn.init.normal_(m.weight, std=0.001)
34 |         if m.bias is not None:
35 | nn.init.constant_(m.bias, 0.0)
36 |
37 |
38 | class Baseline(nn.Module):
39 | in_planes = 2048
40 |
41 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice):
42 | super(Baseline, self).__init__()
43 | self.feature_dim = 256
44 | if model_name == 'resnet18':
45 | self.in_planes = 512
46 | self.base = ResNet(last_stride=last_stride,
47 | block=BasicBlock,
48 | layers=[2, 2, 2, 2])
49 | elif model_name == 'resnet34':
50 | self.in_planes = 512
51 | self.base = ResNet(last_stride=last_stride,
52 | block=BasicBlock,
53 | layers=[3, 4, 6, 3])
54 | elif model_name == 'resnet50':
55 | self.base = ResNet(last_stride=last_stride,
56 | block=Bottleneck,
57 | layers=[3, 4, 6, 3])
58 | elif model_name == 'resnet101':
59 | self.base = ResNet(last_stride=last_stride,
60 | block=Bottleneck,
61 | layers=[3, 4, 23, 3])
62 | elif model_name == 'resnet152':
63 | self.base = ResNet(last_stride=last_stride,
64 | block=Bottleneck,
65 | layers=[3, 8, 36, 3])
66 |
67 | elif model_name == 'se_resnet50':
68 | self.base = SENet(block=SEResNetBottleneck,
69 | layers=[3, 4, 6, 3],
70 | groups=1,
71 | reduction=16,
72 | dropout_p=None,
73 | inplanes=64,
74 | input_3x3=False,
75 | downsample_kernel_size=1,
76 | downsample_padding=0,
77 | last_stride=last_stride)
78 | elif model_name == 'se_resnet101':
79 | self.base = SENet(block=SEResNetBottleneck,
80 | layers=[3, 4, 23, 3],
81 | groups=1,
82 | reduction=16,
83 | dropout_p=None,
84 | inplanes=64,
85 | input_3x3=False,
86 | downsample_kernel_size=1,
87 | downsample_padding=0,
88 | last_stride=last_stride)
89 | elif model_name == 'se_resnet152':
90 | self.base = SENet(block=SEResNetBottleneck,
91 | layers=[3, 8, 36, 3],
92 | groups=1,
93 | reduction=16,
94 | dropout_p=None,
95 | inplanes=64,
96 | input_3x3=False,
97 | downsample_kernel_size=1,
98 | downsample_padding=0,
99 | last_stride=last_stride)
100 | elif model_name == 'se_resnext50':
101 | self.base = SENet(block=SEResNeXtBottleneck,
102 | layers=[3, 4, 6, 3],
103 | groups=32,
104 | reduction=16,
105 | dropout_p=None,
106 | inplanes=64,
107 | input_3x3=False,
108 | downsample_kernel_size=1,
109 | downsample_padding=0,
110 | last_stride=last_stride)
111 | elif model_name == 'se_resnext101':
112 | self.base = SENet(block=SEResNeXtBottleneck,
113 | layers=[3, 4, 23, 3],
114 | groups=32,
115 | reduction=16,
116 | dropout_p=None,
117 | inplanes=64,
118 | input_3x3=False,
119 | downsample_kernel_size=1,
120 | downsample_padding=0,
121 | last_stride=last_stride)
122 | elif model_name == 'senet154':
123 | self.base = SENet(block=SEBottleneck,
124 | layers=[3, 8, 36, 3],
125 | groups=64,
126 | reduction=16,
127 | dropout_p=0.2,
128 | last_stride=last_stride)
129 | elif model_name == 'resnet50_ibn_a':
130 | self.base = resnet50_ibn_a(last_stride)
131 |
132 | if pretrain_choice == 'imagenet':
133 | self.base.load_param(model_path)
134 | print('Loading pretrained ImageNet model......')
135 |
136 | self.gap = nn.AdaptiveAvgPool2d(1)
137 | # self.gap = nn.AdaptiveMaxPool2d(1)
138 | self.num_classes = num_classes
139 | self.neck = neck
140 | self.neck_feat = neck_feat
141 | self.projection = nn.Linear(self.in_planes, self.feature_dim)
142 |
143 | if self.neck == 'no':
144 | self.classifier = nn.Linear(self.feature_dim, self.num_classes)
145 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo
146 | # self.classifier.apply(weights_init_classifier) # new add by luo
147 | elif self.neck == 'bnneck':
148 | self.bottleneck = nn.BatchNorm1d(self.feature_dim)
149 | self.bottleneck.bias.requires_grad_(False) # no shift
150 | self.classifier = nn.Linear(self.feature_dim, self.num_classes, bias=False)
151 |
152 | self.bottleneck.apply(weights_init_kaiming)
153 | self.classifier.apply(weights_init_classifier)
154 |
155 | def forward(self, x):
156 |
157 | global_feat = self.gap(self.base(x)) # (b, 2048, 1, 1)
158 | global_feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048)
159 | global_feat = self.projection(global_feat)
160 |
161 | if self.neck == 'no':
162 | feat = global_feat
163 | elif self.neck == 'bnneck':
164 | feat = self.bottleneck(global_feat) # normalize for angular softmax
165 |
166 | if self.training:
167 | cls_score = self.classifier(feat)
168 | return cls_score, global_feat # global feature for triplet loss
169 | else:
170 | if self.neck_feat == 'after':
171 | # print("Test with feature after BN")
172 | return feat
173 | else:
174 | # print("Test with feature before BN")
175 | return global_feat
176 |
177 | def load_param(self, trained_path):
178 | param_dict = torch.load(trained_path).state_dict()
179 | for i in param_dict:
180 | if 'classifier' in i:
181 | continue
182 | self.state_dict()[i].copy_(param_dict[i])
183 |
--------------------------------------------------------------------------------
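
A minimal usage sketch of the Baseline encoder above (a ResNet-50 variant; passing pretrain_choice='self' here simply skips ImageNet weight loading, so the path argument is unused): in training mode it returns (cls_score, global_feat) for the ID and triplet losses, while in eval mode it returns the 256-d projected feature, taken after the BNNeck when neck_feat='after':

    import torch
    from modeling.baseline import Baseline

    model = Baseline(num_classes=751,            # e.g. Market-1501 identities
                     last_stride=1,
                     model_path='',              # only read when pretrain_choice == 'imagenet'
                     neck='bnneck',
                     neck_feat='after',
                     model_name='resnet50',
                     pretrain_choice='self')

    x = torch.randn(8, 3, 256, 128)

    model.train()
    cls_score, global_feat = model(x)            # shapes (8, 751) and (8, 256)

    model.eval()
    with torch.no_grad():
        feat = model(x)                          # (8, 256), taken after the BNNeck
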
/modeling/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .baseline import Baseline
3 | from .nformer import NFormer
4 | import torch.nn as nn
5 | class nformer_model(nn.Module):
6 | def __init__(self, cfg, num_classes):
7 | super(nformer_model, self).__init__()
8 | self.backbone = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE)
9 |         self.nformer = NFormer(cfg, num_classes=num_classes)  # keyword, so it is not swallowed by the unused vocab argument
10 |
11 | def forward(self,x,stage = 'encoder'):
12 | if stage == 'encoder':
13 | if self.training:
14 | score, feat = self.backbone(x)
15 | return score, feat
16 | else:
17 | feat = self.backbone(x)
18 | return feat
19 |
20 | elif stage == 'nformer':
21 | feat = self.nformer(x)
22 | return feat
23 |
--------------------------------------------------------------------------------
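
At inference time the two stages above are called separately: the backbone encodes each image, and the NFormer then refines all encodings jointly (this mirrors NFormer_R1_mAP in utils/reid_metric.py). A sketch, assuming `model` is an nformer_model built from the config and `val_loader` yields (imgs, pids, camids) batches:

    import torch

    model.eval()
    feats = []
    with torch.no_grad():
        for imgs, pids, camids in val_loader:
            feats.append(model(imgs, stage='encoder'))            # per-image 256-d features
        feats = torch.cat(feats, dim=0)                           # (N, 256)
        feats = torch.nn.functional.normalize(feats, dim=1, p=2)
        # the NFormer attends over the whole set at once, so add a batch dim of 1
        refined = model(feats.unsqueeze(0), stage='nformer')[0]   # (N, 256)
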
/modeling/nformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import math
4 | import re
5 | import collections
6 |
7 | import random
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn.parameter import Parameter
13 |
14 |
15 | def weights_init_kaiming(m):
16 | classname = m.__class__.__name__
17 | if classname.find('Linear') != -1:
18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
19 | nn.init.constant_(m.bias, 0.0)
20 | elif classname.find('Conv') != -1:
21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
22 | if m.bias is not None:
23 | nn.init.constant_(m.bias, 0.0)
24 | elif classname.find('BatchNorm') != -1:
25 | if m.affine:
26 | nn.init.constant_(m.weight, 1.0)
27 | nn.init.constant_(m.bias, 0.0)
28 |
29 | def weights_init_classifier(m):
30 | classname = m.__class__.__name__
31 | if classname.find('Linear') != -1:
32 | nn.init.normal_(m.weight, std=0.001)
33 |         if m.bias is not None:
34 | nn.init.constant_(m.bias, 0.0)
35 |
36 | def gelu(x):
37 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
38 |
39 | def swish(x):
40 | return x * torch.sigmoid(x)
41 |
42 | ACT_FNS = {
43 | 'relu': nn.ReLU,
44 | 'swish': swish,
45 | 'gelu': gelu
46 | }
47 |
48 |
49 | class LayerNorm(nn.Module):
50 | "Construct a layernorm module in the OpenAI style (epsilon inside the square root)."
51 |
52 | def __init__(self, n_state, e=1e-5):
53 | super(LayerNorm, self).__init__()
54 | self.g = nn.Parameter(torch.ones(n_state))
55 | self.b = nn.Parameter(torch.zeros(n_state))
56 | self.e = e
57 |
58 | def forward(self, x):
59 | u = x.mean(-1, keepdim=True)
60 | s = (x - u).pow(2).mean(-1, keepdim=True)
61 | x = (x - u) / torch.sqrt(s + self.e)
62 | return self.g * x + self.b
63 |
64 |
65 | class Conv1D(nn.Module):
66 | def __init__(self, nf, rf, nx):
67 | super(Conv1D, self).__init__()
68 | self.rf = rf
69 | self.nf = nf
70 | if rf == 1: # faster 1x1 conv
71 | w = torch.empty(nx, nf)
72 | nn.init.normal_(w, std=0.02)
73 | self.w = Parameter(w)
74 | self.b = Parameter(torch.zeros(nf))
75 | else: # was used to train LM
76 | raise NotImplementedError
77 |
78 | def forward(self, x):
79 | if self.rf == 1:
80 | size_out = x.size()[:-1] + (self.nf,)
81 | x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
82 | x = x.view(*size_out)
83 | else:
84 | raise NotImplementedError
85 | return x
86 |
87 |
88 | class Attention(nn.Module):
89 | def __init__(self, nx, n_ctx, cfg, scale=False):
90 | super(Attention, self).__init__()
91 | n_state = nx
92 | assert n_state % cfg.MODEL.N_HEAD == 0
93 | self.n_head = cfg.MODEL.N_HEAD
94 | self.split_size = n_state
95 | self.scale = scale
96 | self.c_attn = Conv1D(n_state * 3, 1, nx)
97 | self.c_proj = Conv1D(n_state, 1, nx)
98 |
99 | self.resid_dropout = nn.Dropout(cfg.MODEL.RESID_PDROP)
100 |
101 | def _attn(self, q, k, v, num_landmark, rns_indices):
102 | data_length = q.shape[2]
103 | landmark = torch.Tensor(random.sample(range(data_length),num_landmark)).long()
104 |
105 | sq = q[:,:,landmark,:].contiguous()
106 | sk = k[:,:,:,landmark].contiguous()
107 |
108 | w1 = torch.matmul(q, sk)
109 | w2 = torch.matmul(sq, k)
110 | w = torch.matmul(w1, w2)
111 |
112 | if self.scale:
113 | w = w / math.sqrt(v.size(-1))
114 | return self.rns(w, v, rns_indices)
115 |
116 | def rns(self, w, v, rns_indices):
117 | bs,hn,dl,_ = w.shape
118 | rns_indices = rns_indices.unsqueeze(1).repeat(1,hn,1,1)
119 | mask = torch.zeros_like(w).scatter_(3, rns_indices,torch.ones_like(rns_indices, dtype=w.dtype))
120 | mask = mask * mask.transpose(2,3)
121 | if 'cuda' in str(w.device):
122 | mask = mask.cuda()
123 | else:
124 | mask = mask.cpu()
125 | if self.training:
126 | w = w * mask + -1e9 * (1 - mask)
127 | w = F.softmax(w,dim=3)
128 | a_v = torch.matmul(w, v)
129 | else:
130 | w = (w * mask).reshape(bs*hn,dl,dl).to_sparse()
131 | w = torch.sparse.softmax(w,2)
132 | v = v.reshape(bs*hn,dl,-1)
133 | a_v = torch.bmm(w,v).reshape(bs,hn,dl,-1)
134 | return a_v
135 |
136 |
137 | def merge_heads(self, x):
138 | x = x.permute(0, 2, 1, 3).contiguous()
139 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
140 | return x.view(*new_x_shape)
141 |
142 | def split_heads(self, x, k=False):
143 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
144 | x = x.view(*new_x_shape)
145 | if k:
146 | return x.permute(0, 2, 3, 1)
147 | else:
148 | return x.permute(0, 2, 1, 3)
149 |
150 |
151 | def forward(self, x, num_landmark, rns_indices):
152 | x = self.c_attn(x)
153 | query, key, value = x.split(self.split_size, dim=2)
154 | query = self.split_heads(query)
155 | key = self.split_heads(key, k=True)
156 | value = self.split_heads(value)
157 | mask = None
158 | a = self._attn(query, key, value, num_landmark, rns_indices)
159 | a = self.merge_heads(a)
160 | a = self.c_proj(a)
161 | a = self.resid_dropout(a)
162 | return a
163 |
164 |
165 | class MLP(nn.Module):
166 | def __init__(self, n_state, cfg):
167 | super(MLP, self).__init__()
168 | nx = cfg.MODEL.N_EMBD
169 | self.c_fc = Conv1D(n_state, 1, nx)
170 | self.c_proj = Conv1D(nx, 1, n_state)
171 | self.act = ACT_FNS[cfg.MODEL.AFN]
172 | self.dropout = nn.Dropout(cfg.MODEL.RESID_PDROP)
173 |
174 | def forward(self, x):
175 | h = self.act(self.c_fc(x))
176 | h2 = self.c_proj(h)
177 | return self.dropout(h2)
178 |
179 |
180 | class Block(nn.Module):
181 | def __init__(self, n_ctx, cfg, scale=False):
182 | super(Block, self).__init__()
183 | nx = cfg.MODEL.N_EMBD
184 | self.attn = Attention(nx, n_ctx, cfg, scale)
185 | self.ln_1 = LayerNorm(nx)
186 | self.mlp = MLP(4 * nx, cfg)
187 | self.ln_2 = LayerNorm(nx)
188 |
189 | def forward(self, x, num_landmark, rns_indices):
190 | a = self.attn(x, num_landmark, rns_indices)
191 | n = self.ln_1(x + a)
192 | m = self.mlp(n)
193 | h = self.ln_2(n + m)
194 | return h
195 |
196 |
197 | class NFormer(nn.Module):
198 | """ NFormer model """
199 |
200 | def __init__(self, cfg, vocab=40990, n_ctx=1024, num_classes = 751):
201 | super(NFormer, self).__init__()
202 | self.num_classes = num_classes
203 |
204 | block = Block(n_ctx, cfg, scale=True)
205 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.MODEL.N_LAYER)])
206 |
207 | self.bottleneck = nn.BatchNorm1d(cfg.MODEL.N_EMBD)
208 | self.bottleneck.bias.requires_grad_(False) # no shift
209 | self.bottleneck.apply(weights_init_kaiming)
210 |
211 | self.classifier = nn.Linear(cfg.MODEL.N_EMBD, self.num_classes, bias=False)
212 | self.classifier.apply(weights_init_classifier)
213 | self.topk = cfg.MODEL.TOPK
214 | self.num_landmark = cfg.MODEL.LANDMARK
215 |
216 | def forward(self, x):
217 | _, rns_indices = torch.topk(torch.bmm(x/torch.norm(x,p=2,dim=2,keepdim=True),(x/torch.norm(x,p=2,dim=2,keepdim=True)).transpose(1,2)), self.topk, dim=2)
218 | for block in self.h:
219 | x = block(x, self.num_landmark, rns_indices)
220 |
221 | bs,dl,d = x.shape
222 | x = x.reshape(bs*dl,d)
223 | feat = self.bottleneck(x)
224 | cls_score = self.classifier(feat)
225 | x = x.reshape(bs,dl,d)
226 | feat = feat.reshape(bs,dl,d)
227 | cls_score = cls_score.reshape(bs,dl,-1)
228 |
229 | if self.training:
230 | return cls_score, x
231 | else:
232 | return feat
233 |
234 |
235 |
236 |
--------------------------------------------------------------------------------
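
To see the tensor shapes involved in the landmark attention and the top-k neighbor mask, a small self-contained sketch; the cfg fields below are the ones this file reads (their values are illustrative, and 'gelu' is used because the 'relu' entry in ACT_FNS is a module class rather than a function):

    import torch
    from yacs.config import CfgNode as CN
    from modeling.nformer import NFormer

    cfg = CN()
    cfg.MODEL = CN()
    cfg.MODEL.N_EMBD = 256          # feature dimension of each person encoding
    cfg.MODEL.N_HEAD = 2            # attention heads
    cfg.MODEL.N_LAYER = 2           # number of Blocks
    cfg.MODEL.RESID_PDROP = 0.1
    cfg.MODEL.AFN = 'gelu'
    cfg.MODEL.TOPK = 20             # k for the reciprocal-neighbor mask in rns()
    cfg.MODEL.LANDMARK = 5          # landmark tokens for the low-rank attention map

    nformer = NFormer(cfg, num_classes=751).eval()
    feats = torch.randn(1, 100, 256)              # 100 image encodings treated as one set
    with torch.no_grad():
        refined = nformer(feats)                  # (1, 100, 256) refined features
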
/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haochenheheda/NFormer/c78cb848c6b8cf64e973a1ee0ce14488d4904f8f/pipeline.jpg
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import make_optimizer, make_optimizer_with_center, make_nformer_optimizer_with_center
8 | from .lr_scheduler import WarmupMultiStepLR
9 |
--------------------------------------------------------------------------------
/solver/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 |
9 |
10 | def make_optimizer(cfg, model):
11 | params = []
12 | for key, value in model.named_parameters():
13 | if not value.requires_grad:
14 | continue
15 | lr = cfg.SOLVER.BASE_LR
16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
17 | if "bias" in key:
18 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
19 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
20 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
23 | else:
24 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
25 | return optimizer
26 |
27 |
28 | def make_optimizer_with_center(cfg, model, center_criterion):
29 | params = []
30 | for key, value in model.named_parameters():
31 | if not value.requires_grad:
32 | continue
33 | lr = cfg.SOLVER.BASE_LR
34 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
35 | if "bias" in key:
36 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
37 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
38 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
39 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
40 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
41 | else:
42 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
43 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)
44 | return optimizer, optimizer_center
45 |
46 |
47 | def make_nformer_optimizer_with_center(cfg, model, center_criterion, nformer_center_criterion):
48 | params = []
49 | for key, value in model.named_parameters():
50 | if not value.requires_grad:
51 | continue
52 | lr = cfg.SOLVER.BASE_LR
53 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
54 | if "bias" in key:
55 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
56 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
57 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
58 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
59 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
60 | else:
61 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
62 | nformer_optimizer = torch.optim.Adam(model.nformer.parameters(),lr = 1e-5,eps=1e-8, betas=[0.9,0.999])
63 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)
64 | nformer_optimizer_center = torch.optim.SGD(nformer_center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)
65 | return optimizer, optimizer_center, nformer_optimizer, nformer_optimizer_center
66 |
--------------------------------------------------------------------------------
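
The parameter groups built above give bias terms their own learning rate (BASE_LR * BIAS_LR_FACTOR) and their own weight decay. A minimal sketch of make_optimizer with a hand-built config; the SOLVER field names are the ones read by this file, the values are placeholders:

    import torch.nn as nn
    from yacs.config import CfgNode as CN
    from solver import make_optimizer

    cfg = CN()
    cfg.SOLVER = CN()
    cfg.SOLVER.OPTIMIZER_NAME = 'Adam'
    cfg.SOLVER.BASE_LR = 3.5e-4
    cfg.SOLVER.BIAS_LR_FACTOR = 2        # biases train at twice the base LR
    cfg.SOLVER.WEIGHT_DECAY = 5e-4
    cfg.SOLVER.WEIGHT_DECAY_BIAS = 0.0   # and without weight decay
    cfg.SOLVER.MOMENTUM = 0.9            # only used when OPTIMIZER_NAME == 'SGD'

    model = nn.Linear(256, 751)
    optimizer = make_optimizer(cfg, model)
    for group in optimizer.param_groups:
        print(group['lr'], group['weight_decay'])   # weight: 3.5e-4 / 5e-4, bias: 7e-4 / 0.0
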
/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 |
9 |
10 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
11 | # separating MultiStepLR with WarmupLR
12 | # but the current LRScheduler design doesn't allow it
13 |
14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
15 | def __init__(
16 | self,
17 | optimizer,
18 | milestones,
19 | gamma=0.1,
20 | warmup_factor=1.0 / 3,
21 | warmup_iters=500,
22 | warmup_method="linear",
23 | last_epoch=-1,
24 | ):
25 | if not list(milestones) == sorted(milestones):
26 |             raise ValueError(
27 |                 "Milestones should be a list of increasing integers. "
28 |                 "Got {}".format(milestones)
29 |             )
30 |
31 | if warmup_method not in ("constant", "linear"):
32 | raise ValueError(
33 |                 "Only 'constant' or 'linear' warmup_method accepted, "
34 |                 "got {}".format(warmup_method)
35 | )
36 | self.milestones = milestones
37 | self.gamma = gamma
38 | self.warmup_factor = warmup_factor
39 | self.warmup_iters = warmup_iters
40 | self.warmup_method = warmup_method
41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
42 |
43 | def get_lr(self):
44 | warmup_factor = 1
45 | if self.last_epoch < self.warmup_iters:
46 | if self.warmup_method == "constant":
47 | warmup_factor = self.warmup_factor
48 | elif self.warmup_method == "linear":
49 | alpha = self.last_epoch / self.warmup_iters
50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
51 | return [
52 | base_lr
53 | * warmup_factor
54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
55 | for base_lr in self.base_lrs
56 | ]
57 |
--------------------------------------------------------------------------------
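
For intuition, a short sketch of the schedule this produces: the learning rate ramps linearly from warmup_factor * base_lr up to base_lr over warmup_iters epochs, then decays by gamma at each milestone (the concrete numbers below are illustrative, not the repo defaults):

    import torch
    from solver.lr_scheduler import WarmupMultiStepLR

    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = WarmupMultiStepLR(optimizer, milestones=[40, 70], gamma=0.1,
                                  warmup_factor=1.0 / 3, warmup_iters=10,
                                  warmup_method='linear')

    for epoch in range(80):
        lr = scheduler.get_lr()[0]
        # epoch 0: ~0.033, epoch >= 10: 0.1, epoch >= 40: 0.01, epoch >= 70: 0.001
        # ... train one epoch ...
        scheduler.step()
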
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
--------------------------------------------------------------------------------
/tools/nformer_train.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import argparse
8 | import os
9 | import sys
10 | import torch
11 |
12 | from torch.backends import cudnn
13 |
14 | sys.path.append('.')
15 | from config import cfg
16 | from data import make_data_loader
17 | from engine.trainer import do_train, do_train_with_center
18 | from modeling import build_model
19 | from layers import make_loss, make_loss_with_center
20 | from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR
21 |
22 | from utils.logger import setup_logger
23 |
24 |
25 | def train(cfg):
26 | # prepare dataset
27 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
28 |
29 | # prepare model
30 | model = build_model(cfg, num_classes)
31 |
32 | if cfg.MODEL.IF_WITH_CENTER == 'no':
33 | print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
34 | optimizer = make_optimizer(cfg, model)
35 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
36 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
37 |
38 | loss_func = make_loss(cfg, num_classes) # modified by gu
39 |
40 | # Add for using self trained model
41 | if cfg.MODEL.PRETRAIN_CHOICE == 'self':
42 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
43 | print('Start epoch:', start_epoch)
44 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
45 | print('Path to the checkpoint of optimizer:', path_to_optimizer)
46 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
47 | optimizer.load_state_dict(torch.load(path_to_optimizer))
48 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
49 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
50 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
51 | start_epoch = 0
52 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
53 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
54 | else:
55 |             print("Only pretrain_choice 'imagenet' or 'self' is supported, but got {}".format(cfg.MODEL.PRETRAIN_CHOICE))
56 |
57 | arguments = {}
58 |
59 | do_train(
60 | cfg,
61 | model,
62 | train_loader,
63 | val_loader,
64 | optimizer,
65 | scheduler, # modify for using self trained model
66 | loss_func,
67 | num_query,
68 | start_epoch # add for using self trained model
69 | )
70 | elif cfg.MODEL.IF_WITH_CENTER == 'yes':
71 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
72 | loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu
73 | optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)
74 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
75 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
76 |
77 | arguments = {}
78 |
79 | # Add for using self trained model
80 | if cfg.MODEL.PRETRAIN_CHOICE == 'self':
81 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
82 | print('Start epoch:', start_epoch)
83 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
84 | print('Path to the checkpoint of optimizer:', path_to_optimizer)
85 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param')
86 | print('Path to the checkpoint of center_param:', path_to_center_param)
87 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center')
88 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center)
89 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
90 | optimizer.load_state_dict(torch.load(path_to_optimizer))
91 | center_criterion.load_state_dict(torch.load(path_to_center_param))
92 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center))
93 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
94 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
95 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
96 | start_epoch = 0
97 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
98 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
99 | else:
100 |                 print("Only pretrain_choice 'imagenet' or 'self' is supported, but got {}".format(cfg.MODEL.PRETRAIN_CHOICE))
101 |
102 | do_train_with_center(
103 | cfg,
104 | model,
105 | center_criterion,
106 | train_loader,
107 | val_loader,
108 | optimizer,
109 | optimizer_center,
110 | scheduler, # modify for using self trained model
111 | loss_func,
112 | num_query,
113 | start_epoch # add for using self trained model
114 | )
115 | else:
116 |         print("Unsupported value '{}' for cfg.MODEL.IF_WITH_CENTER, only 'yes' or 'no' is supported!\n".format(cfg.MODEL.IF_WITH_CENTER))
117 |
118 |
119 | def main():
120 | parser = argparse.ArgumentParser(description="ReID Baseline Training")
121 | parser.add_argument(
122 | "--config_file", default="", help="path to config file", type=str
123 | )
124 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
125 | nargs=argparse.REMAINDER)
126 |
127 | args = parser.parse_args()
128 |
129 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
130 |
131 | if args.config_file != "":
132 | cfg.merge_from_file(args.config_file)
133 | cfg.merge_from_list(args.opts)
134 | cfg.freeze()
135 |
136 | output_dir = cfg.OUTPUT_DIR
137 | if output_dir and not os.path.exists(output_dir):
138 | os.makedirs(output_dir)
139 |
140 | logger = setup_logger("reid_baseline", output_dir, 0)
141 | logger.info("Using {} GPUS".format(num_gpus))
142 | logger.info(args)
143 |
144 | if args.config_file != "":
145 | logger.info("Loaded configuration file {}".format(args.config_file))
146 | with open(args.config_file, 'r') as cf:
147 | config_str = "\n" + cf.read()
148 | logger.info(config_str)
149 | logger.info("Running with config:\n{}".format(cfg))
150 |
151 | if cfg.MODEL.DEVICE == "cuda":
152 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu
153 | cudnn.benchmark = True
154 | train(cfg)
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
159 |
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import argparse
8 | import os
9 | import sys
10 | from os import mkdir
11 |
12 | import torch
13 | from torch.backends import cudnn
14 |
15 | sys.path.append('.')
16 | from config import cfg
17 | from data import make_data_loader
18 | from engine.inference import inference
19 | from modeling import build_nformer_model
20 | from utils.logger import setup_logger
21 |
22 |
23 | def main():
24 | parser = argparse.ArgumentParser(description="ReID Baseline Inference")
25 | parser.add_argument(
26 | "--config_file", default="", help="path to config file", type=str
27 | )
28 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
29 | nargs=argparse.REMAINDER)
30 |
31 | args = parser.parse_args()
32 |
33 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
34 |
35 | if args.config_file != "":
36 | cfg.merge_from_file(args.config_file)
37 | cfg.merge_from_list(args.opts)
38 | cfg.freeze()
39 |
40 | output_dir = cfg.OUTPUT_DIR
41 | if output_dir and not os.path.exists(output_dir):
42 | mkdir(output_dir)
43 |
44 | logger = setup_logger("reid_baseline", output_dir, 0)
45 | logger.info("Using {} GPUS".format(num_gpus))
46 | logger.info(args)
47 |
48 | if args.config_file != "":
49 | logger.info("Loaded configuration file {}".format(args.config_file))
50 | with open(args.config_file, 'r') as cf:
51 | config_str = "\n" + cf.read()
52 | logger.info(config_str)
53 | logger.info("Running with config:\n{}".format(cfg))
54 |
55 | if cfg.MODEL.DEVICE == "cuda":
56 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
57 | cudnn.benchmark = True
58 |
59 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
60 | model = build_nformer_model(cfg, num_classes)
61 | model.load_state_dict(torch.load(cfg.TEST.WEIGHT))
62 |
63 | inference(cfg, model, val_loader, num_query)
64 |
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import argparse
8 | import os
9 | import sys
10 | import torch
11 |
12 | from torch.backends import cudnn
13 |
14 | sys.path.append('.')
15 | from config import cfg
16 | from data import make_data_loader
17 | from engine.trainer import do_train, do_train_with_center
18 | from modeling import build_model, build_nformer_model
19 | from layers import make_loss, make_loss_with_center, make_nformer_loss_with_center
20 | from solver import make_optimizer, make_optimizer_with_center, make_nformer_optimizer_with_center, WarmupMultiStepLR
21 |
22 | from utils.logger import setup_logger
23 |
24 |
25 | def train(cfg):
26 | # prepare dataset
27 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
28 |
29 | # prepare model
30 | model = build_nformer_model(cfg, num_classes)
31 |
32 | if cfg.MODEL.IF_WITH_CENTER == 'no':
33 | print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
34 | optimizer = make_optimizer(cfg, model)
35 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
36 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
37 |
38 | loss_func = make_loss(cfg, num_classes) # modified by gu
39 |
40 | # Add for using self trained model
41 | if cfg.MODEL.PRETRAIN_CHOICE == 'self':
42 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
43 | print('Start epoch:', start_epoch)
44 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
45 | print('Path to the checkpoint of optimizer:', path_to_optimizer)
46 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
47 | optimizer.load_state_dict(torch.load(path_to_optimizer))
48 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
49 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
50 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
51 | start_epoch = 0
52 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
53 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
54 | else:
55 |             print("Only pretrain_choice 'imagenet' or 'self' is supported, but got {}".format(cfg.MODEL.PRETRAIN_CHOICE))
56 |
57 | arguments = {}
58 |
59 | do_train(
60 | cfg,
61 | model,
62 | train_loader,
63 | val_loader,
64 | optimizer,
65 | scheduler, # modify for using self trained model
66 | loss_func,
67 | num_query,
68 | start_epoch # add for using self trained model
69 | )
70 | elif cfg.MODEL.IF_WITH_CENTER == 'yes':
71 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
72 | loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu
73 | nformer_loss_func, nformer_center_criterion = make_nformer_loss_with_center(cfg, num_classes) # modified by gu
74 | optimizer, optimizer_center, optimizer_nformer, optimizer_nformer_center = make_nformer_optimizer_with_center(cfg, model, center_criterion, nformer_center_criterion)
75 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
76 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
77 |
78 | arguments = {}
79 |
80 | # Add for using self trained model
81 | if cfg.MODEL.PRETRAIN_CHOICE == 'self':
82 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
83 | print('Start epoch:', start_epoch)
84 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
85 | print('Path to the checkpoint of optimizer:', path_to_optimizer)
86 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param')
87 | print('Path to the checkpoint of center_param:', path_to_center_param)
88 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center')
89 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center)
90 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
91 | optimizer.load_state_dict(torch.load(path_to_optimizer))
92 | center_criterion.load_state_dict(torch.load(path_to_center_param))
93 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center))
94 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
95 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
96 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
97 | start_epoch = 0
98 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
99 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
100 | else:
101 |                 print("Only pretrain_choice 'imagenet' or 'self' is supported, but got {}".format(cfg.MODEL.PRETRAIN_CHOICE))
102 |
103 | do_train_with_center(
104 | cfg,
105 | model,
106 | center_criterion,
107 | nformer_center_criterion,
108 | train_loader,
109 | val_loader,
110 | optimizer,
111 | optimizer_center,
112 | optimizer_nformer,
113 | optimizer_nformer_center,
114 | scheduler, # modify for using self trained model
115 | loss_func,
116 | nformer_loss_func,
117 | num_query,
118 | start_epoch # add for using self trained model
119 | )
120 | else:
121 |         print("Unsupported value '{}' for cfg.MODEL.IF_WITH_CENTER, only 'yes' or 'no' is supported!\n".format(cfg.MODEL.IF_WITH_CENTER))
122 |
123 |
124 | def main():
125 | parser = argparse.ArgumentParser(description="ReID Baseline Training")
126 | parser.add_argument(
127 | "--config_file", default="", help="path to config file", type=str
128 | )
129 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
130 | nargs=argparse.REMAINDER)
131 |
132 | args = parser.parse_args()
133 |
134 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
135 |
136 | if args.config_file != "":
137 | cfg.merge_from_file(args.config_file)
138 | cfg.merge_from_list(args.opts)
139 | cfg.freeze()
140 |
141 | output_dir = cfg.OUTPUT_DIR
142 | if output_dir and not os.path.exists(output_dir):
143 | os.makedirs(output_dir)
144 |
145 | logger = setup_logger("reid_baseline", output_dir, 0)
146 | logger.info("Using {} GPUS".format(num_gpus))
147 | logger.info(args)
148 |
149 | if args.config_file != "":
150 | logger.info("Loaded configuration file {}".format(args.config_file))
151 | with open(args.config_file, 'r') as cf:
152 | config_str = "\n" + cf.read()
153 | logger.info(config_str)
154 | logger.info("Running with config:\n{}".format(cfg))
155 |
156 | if cfg.MODEL.DEVICE == "cuda":
157 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu
158 | cudnn.benchmark = True
159 | train(cfg)
160 |
161 |
162 | if __name__ == '__main__':
163 | main()
164 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 |
--------------------------------------------------------------------------------
/utils/iotools.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import errno
8 | import json
9 | import os
10 |
11 | import os.path as osp
12 |
13 |
14 | def mkdir_if_missing(directory):
15 | if not osp.exists(directory):
16 | try:
17 | os.makedirs(directory)
18 | except OSError as e:
19 | if e.errno != errno.EEXIST:
20 | raise
21 |
22 |
23 | def check_isfile(path):
24 | isfile = osp.isfile(path)
25 | if not isfile:
26 | print("=> Warning: no file found at '{}' (ignored)".format(path))
27 | return isfile
28 |
29 |
30 | def read_json(fpath):
31 | with open(fpath, 'r') as f:
32 | obj = json.load(f)
33 | return obj
34 |
35 |
36 | def write_json(obj, fpath):
37 | mkdir_if_missing(osp.dirname(fpath))
38 | with open(fpath, 'w') as f:
39 | json.dump(obj, f, indent=4, separators=(',', ': '))
40 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 |
12 | def setup_logger(name, save_dir, distributed_rank):
13 | logger = logging.getLogger(name)
14 | logger.setLevel(logging.DEBUG)
15 | # don't log results for the non-master process
16 | if distributed_rank > 0:
17 | return logger
18 | ch = logging.StreamHandler(stream=sys.stdout)
19 | ch.setLevel(logging.DEBUG)
20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
21 | ch.setFormatter(formatter)
22 | logger.addHandler(ch)
23 |
24 | if save_dir:
25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
26 | fh.setLevel(logging.DEBUG)
27 | fh.setFormatter(formatter)
28 | logger.addHandler(fh)
29 |
30 | return logger
31 |
--------------------------------------------------------------------------------
/utils/re_ranking.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri, 25 May 2018 20:29:09
5 |
6 | @author: luohao
7 | """
8 |
9 | """
10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
13 | """
14 |
15 | """
16 | API
17 |
18 | probFea: all feature vectors of the query set (torch tensor)
19 | galFea: all feature vectors of the gallery set (torch tensor)
20 | k1, k2, lambda: parameters; the original paper uses (k1=20, k2=6, lambda=0.3)
21 | MemorySave: set to 'True' to use the memory-saving mode
22 | Minibatch: available when 'MemorySave' is 'True'
23 | """
24 |
25 | import numpy as np
26 | import torch
27 |
28 |
29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
30 |     # if the feature vectors are numpy arrays, convert them to torch tensors first (e.g. with torch.from_numpy)
31 | query_num = probFea.size(0)
32 | all_num = query_num + galFea.size(0)
33 | if only_local:
34 | original_dist = local_distmat
35 | else:
36 | feat = torch.cat([probFea,galFea])
37 | print('using GPU to compute original distance')
38 | distmat = torch.pow(feat,2).sum(dim=1, keepdim=True).expand(all_num,all_num) + \
39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
40 | distmat.addmm_(1,-2,feat,feat.t())
41 | original_dist = distmat.cpu().numpy()
42 | del feat
43 | if not local_distmat is None:
44 | original_dist = original_dist + local_distmat
45 | gallery_num = original_dist.shape[0]
46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
47 | V = np.zeros_like(original_dist).astype(np.float16)
48 | initial_rank = np.argsort(original_dist).astype(np.int32)
49 |
50 | print('starting re_ranking')
51 | for i in range(all_num):
52 | # k-reciprocal neighbors
53 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
55 | fi = np.where(backward_k_neigh_index == i)[0]
56 | k_reciprocal_index = forward_k_neigh_index[fi]
57 | k_reciprocal_expansion_index = k_reciprocal_index
58 | for j in range(len(k_reciprocal_index)):
59 | candidate = k_reciprocal_index[j]
60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
62 | :int(np.around(k1 / 2)) + 1]
63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
66 | candidate_k_reciprocal_index):
67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
68 |
69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
72 | original_dist = original_dist[:query_num, ]
73 | if k2 != 1:
74 | V_qe = np.zeros_like(V, dtype=np.float16)
75 | for i in range(all_num):
76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
77 | V = V_qe
78 | del V_qe
79 | del initial_rank
80 | invIndex = []
81 | for i in range(gallery_num):
82 | invIndex.append(np.where(V[:, i] != 0)[0])
83 |
84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
85 |
86 | for i in range(query_num):
87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
88 | indNonZero = np.where(V[i, :] != 0)[0]
89 | indImages = [invIndex[ind] for ind in indNonZero]
90 | for j in range(len(indNonZero)):
91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
92 | V[indImages[j], indNonZero[j]])
93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
94 |
95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
96 | del original_dist
97 | del V
98 | del jaccard_dist
99 | final_dist = final_dist[:query_num, query_num:]
100 | return final_dist
101 |
102 |
--------------------------------------------------------------------------------
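
A minimal sketch of applying the k-reciprocal re-ranking above to query and gallery features; the shapes are illustrative, and it assumes a PyTorch version that still accepts the deprecated positional addmm_ signature used in this file:

    import torch
    from utils.re_ranking import re_ranking

    qf = torch.randn(10, 256)     # query features
    gf = torch.randn(50, 256)     # gallery features
    distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
    print(distmat.shape)          # (10, 50) numpy array of query-by-gallery distances
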
/utils/reid_metric.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | from ignite.metrics import Metric
10 |
11 | from data.datasets.eval_reid import eval_func
12 | from .re_ranking import re_ranking
13 |
14 |
15 | class R1_mAP(Metric):
16 | def __init__(self, num_query, max_rank=50, feat_norm='yes'):
17 | super(R1_mAP, self).__init__()
18 | self.num_query = num_query
19 | self.max_rank = max_rank
20 | self.feat_norm = feat_norm
21 |
22 | def reset(self):
23 | self.feats = []
24 | self.pids = []
25 | self.camids = []
26 |
27 | def update(self, output):
28 | feat, pid, camid = output
29 | self.feats.append(feat)
30 | self.pids.extend(np.asarray(pid))
31 | self.camids.extend(np.asarray(camid))
32 |
33 | def compute(self):
34 | feats = torch.cat(self.feats, dim=0)
35 | if self.feat_norm == 'yes':
36 | print("The test feature is normalized")
37 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
38 | # query
39 | qf = feats[:self.num_query]
40 | q_pids = np.asarray(self.pids[:self.num_query])
41 | q_camids = np.asarray(self.camids[:self.num_query])
42 | # gallery
43 | gf = feats[self.num_query:]
44 | g_pids = np.asarray(self.pids[self.num_query:])
45 | g_camids = np.asarray(self.camids[self.num_query:])
46 | m, n = qf.shape[0], gf.shape[0]
47 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
48 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
49 | distmat.addmm_(1, -2, qf, gf.t())
50 | distmat = distmat.cpu().numpy()
51 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
52 |
53 | return cmc, mAP
54 |
55 | class NFormer_R1_mAP(Metric):
56 | def __init__(self, model, num_query, max_rank=50, feat_norm='yes'):
57 | super(NFormer_R1_mAP, self).__init__()
58 | self.model = model
59 | self.num_query = num_query
60 | self.max_rank = max_rank
61 | self.feat_norm = feat_norm
62 |
63 | def reset(self):
64 | self.feats = []
65 | self.pids = []
66 | self.camids = []
67 |
68 | def update(self, output):
69 | feat, pid, camid = output
70 | self.feats.append(feat)
71 | self.pids.extend(np.asarray(pid))
72 | self.camids.extend(np.asarray(camid))
73 |
74 | def compute(self):
75 | feats = torch.cat(self.feats, dim=0)
76 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
77 | self.model.eval()
78 | with torch.no_grad():
79 | feats = self.model(feats.unsqueeze(0), stage='nformer')[0]
80 | self.model.train()
81 | if self.feat_norm == 'yes':
82 | print("The test feature is normalized")
83 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
84 | # query
85 | qf = feats[:self.num_query]
86 | q_pids = np.asarray(self.pids[:self.num_query])
87 | q_camids = np.asarray(self.camids[:self.num_query])
88 | # gallery
89 | gf = feats[self.num_query:]
90 | g_pids = np.asarray(self.pids[self.num_query:])
91 | g_camids = np.asarray(self.camids[self.num_query:])
92 | m, n = qf.shape[0], gf.shape[0]
93 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
94 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
95 | distmat.addmm_(1, -2, qf, gf.t())
96 | distmat = distmat.cpu().numpy()
97 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
98 |
99 | return cmc, mAP
100 |
101 |
102 | class R1_mAP_reranking(Metric):
103 | def __init__(self, num_query, max_rank=50, feat_norm='yes'):
104 | super(R1_mAP_reranking, self).__init__()
105 | self.num_query = num_query
106 | self.max_rank = max_rank
107 | self.feat_norm = feat_norm
108 |
109 | def reset(self):
110 | self.feats = []
111 | self.pids = []
112 | self.camids = []
113 |
114 | def update(self, output):
115 | feat, pid, camid = output
116 | self.feats.append(feat)
117 | self.pids.extend(np.asarray(pid))
118 | self.camids.extend(np.asarray(camid))
119 |
120 | def compute(self):
121 | feats = torch.cat(self.feats, dim=0)
122 | if self.feat_norm == 'yes':
123 | print("The test feature is normalized")
124 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
125 |
126 | # query
127 | qf = feats[:self.num_query]
128 | q_pids = np.asarray(self.pids[:self.num_query])
129 | q_camids = np.asarray(self.camids[:self.num_query])
130 | # gallery
131 | gf = feats[self.num_query:]
132 | g_pids = np.asarray(self.pids[self.num_query:])
133 | g_camids = np.asarray(self.camids[self.num_query:])
134 | # m, n = qf.shape[0], gf.shape[0]
135 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
136 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
137 | # distmat.addmm_(1, -2, qf, gf.t())
138 | # distmat = distmat.cpu().numpy()
139 | print("Enter reranking")
140 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
141 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
142 |
143 | return cmc, mAP
144 |
--------------------------------------------------------------------------------
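
These metrics follow the ignite Metric interface: attach one to an evaluation Engine and read the result from engine.state.metrics after a run. A sketch, assuming an eval-mode `model`, a `val_loader` yielding (imgs, pids, camids) batches, and the `num_query` returned by make_data_loader ('r1_mAP' is just an illustrative key):

    import torch
    from ignite.engine import Engine
    from utils.reid_metric import R1_mAP

    def _eval_step(engine, batch):
        imgs, pids, camids = batch
        with torch.no_grad():
            feats = model(imgs)                 # eval-mode Baseline returns features
        return feats, pids, camids

    evaluator = Engine(_eval_step)
    R1_mAP(num_query, max_rank=50, feat_norm='yes').attach(evaluator, 'r1_mAP')

    evaluator.run(val_loader)
    cmc, mAP = evaluator.state.metrics['r1_mAP']
    print('mAP: {:.1%}  Rank-1: {:.1%}'.format(mAP, cmc[0]))
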