├── LICENCE.md
├── README.md
├── config
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   └── defaults.cpython-36.pyc
│   └── defaults.py
├── configs
│   ├── baseline.yml
│   ├── softmax.yml
│   ├── softmax_triplet.yml
│   ├── softmax_triplet_ft.yml
│   ├── softmax_triplet_ftc.yml
│   └── softmax_triplet_with_center.yml
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── build.cpython-36.pyc
│   │   └── collate_batch.cpython-36.pyc
│   ├── build.py
│   ├── collate_batch.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── bases.cpython-36.pyc
│   │   │   ├── cuhk.cpython-36.pyc
│   │   │   ├── dataset_loader.cpython-36.pyc
│   │   │   ├── dukemtmcreid.cpython-36.pyc
│   │   │   ├── eval_reid.cpython-36.pyc
│   │   │   ├── market1501.cpython-36.pyc
│   │   │   ├── msmt17.cpython-36.pyc
│   │   │   ├── prw.cpython-36.pyc
│   │   │   └── veri.cpython-36.pyc
│   │   ├── bases.py
│   │   ├── cuhk.py
│   │   ├── cuhk03.py
│   │   ├── dataset_loader.py
│   │   ├── dukemtmcreid.py
│   │   ├── eval_reid.py
│   │   ├── market1501.py
│   │   ├── msmt17.py
│   │   ├── prw.py
│   │   └── veri.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── triplet_sampler.cpython-36.pyc
│   │   └── triplet_sampler.py
│   └── transforms
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── build.cpython-36.pyc
│       │   └── transforms.cpython-36.pyc
│       ├── build.py
│       └── transforms.py
├── engine
│   ├── __pycache__
│   │   ├── inference.cpython-36.pyc
│   │   └── trainer.cpython-36.pyc
│   ├── inference.py
│   └── trainer.py
├── image
│   └── examples.png
├── layers
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── center_loss.cpython-36.pyc
│   │   └── triplet_loss.cpython-36.pyc
│   ├── center_loss.py
│   └── triplet_loss.py
├── modeling
│   ├── PISNet.py
│   ├── Pre_Selection_Model.py
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-36.pyc
│   ├── backbones
│   │   ├── Query_Guided_Attention.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── resnet.cpython-36.pyc
│   │   ├── pisnet.py
│   │   └── resnet.py
│   └── baseline.py
├── pi_cuhk.sh
├── pi_prw.sh
├── pre_select_cuhk.sh
├── solver
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── build.cpython-36.pyc
│   │   └── lr_scheduler.cpython-36.pyc
│   ├── build.py
│   └── lr_scheduler.py
├── tests
│   ├── __init__.py
│   └── lr_scheduler_test.py
├── tools
│   ├── __init__.py
│   ├── pre_selection.py
│   └── train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   ├── iotools.cpython-36.pyc
    │   ├── logger.cpython-36.pyc
    │   ├── re_ranking.cpython-36.pyc
    │   └── reid_metric.cpython-36.pyc
    ├── iotools.py
    ├── logger.py
    ├── re_ranking.py
    └── reid_metric.py
/LICENCE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2019] [HaoLuo]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Do Not Disturb Me: Person Re-identification Under the Interference of Other Pedestrians (ECCV 2020)
2 |
3 | Official code for ECCV 2020 paper [Do Not Disturb Me: Person Re-identification Under the Interference of Other Pedestrians](https://arxiv.org/abs/2008.06963).
4 |
5 |
6 |
7 |
8 |
9 | ## Introduction
10 |
11 | Conventional person Re-ID assumes that each cropped image contains only the person inside the bounding box of one individual. In a crowded scene, however, off-the-shelf detectors may generate bounding boxes covering multiple people, with a large proportion of background pedestrians or heavy occlusion. Representations extracted from such crops mix the target with the interfering pedestrians and carry distractive information, which leads to wrong retrieval results. To address this problem, this paper presents a novel deep network termed Pedestrian-Interference Suppression Network (PISNet). PISNet leverages a Query-Guided Attention Block (QGAB) to enhance the feature of the target in the gallery image under the guidance of the query. Furthermore, the accompanying Guidance Reversed Attention Module and Multi-Person Separation Loss drive QGAB to suppress the interference of other pedestrians. The method is evaluated on two new pedestrian-interference datasets, and the results show that it performs favorably against existing Re-ID methods.
12 |
13 |
14 |
## Resources
15 |
16 | 1. Pretrained Models:
17 |
18 | [Baidu NetDisk](https://pan.baidu.com/s/1O08TssJcASsTh8veIBimzA), Password: 6x4x. The models are trained using the ground-truth boxes from [CUHK-SYSU](https://github.com/ShuangLI59/person_search) and [PRW](https://github.com/liangzheng06/PRW-baseline), respectively.
19 |
20 | 2. Datasets:
21 |
22 | Request the datasets from xbrainzsz@gmail.com (academic use only).
23 | Due to licensing issues, please send the request from your university email address.
24 |
25 | ## Citation
26 |
27 | If you find this code useful in your research, please consider citing:
28 | ```
29 | @inproceedings{zhao2020pireid,
30 | title={Do Not Disturb Me: Person Re-identification Under the Interference of Other Pedestrians},
31 | author={Zhao, Shizhen and Gao, Changxin and Zhang, Jun and Cheng, Hao and Han, Chuchu and Jiang, Xinyang and Guo, Xiaowei and Zheng, Wei-Shi and Sang, Nong and Sun, Xing},
32 | booktitle={European Conference on Computer Vision (ECCV)},
33 | year={2020}
34 | }
35 | ```
36 |
37 | ## Contact
38 |
39 | Shizhen Zhao: xbrainzsz@gmail.com
40 |
--------------------------------------------------------------------------------
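
The Query-Guided Attention idea described in the README can be sketched as follows. This is a minimal, hypothetical PyTorch illustration of reweighting gallery feature-map locations by their similarity to a query descriptor; it is not the repository's QGAB/PISNet implementation (see modeling/backbones/Query_Guided_Attention.py for that):

```python
import torch
import torch.nn as nn

class QueryGuidedAttention(nn.Module):
    """Toy query-guided attention: keep gallery locations that look like
    the query person, suppress the rest (illustrative sketch only)."""

    def __init__(self, channels: int):
        super().__init__()
        self.proj_q = nn.Linear(channels, channels)     # project query descriptor
        self.proj_g = nn.Conv2d(channels, channels, 1)  # project gallery feature map

    def forward(self, gallery_feat, query_feat):
        # gallery_feat: (B, C, H, W) feature map of the multi-person crop
        # query_feat:   (B, C) global descriptor of the query person
        q = self.proj_q(query_feat)                   # (B, C)
        g = self.proj_g(gallery_feat)                 # (B, C, H, W)
        attn = torch.einsum("bc,bchw->bhw", q, g)     # per-location similarity
        attn = torch.sigmoid(attn).unsqueeze(1)       # (B, 1, H, W) gate
        return gallery_feat * attn                    # interference suppressed

block = QueryGuidedAttention(256)
out = block(torch.randn(2, 256, 16, 8), torch.randn(2, 256))
print(out.shape)  # torch.Size([2, 256, 16, 8])
```
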
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .defaults import _C as cfg
8 |
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/config/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/config/__pycache__/defaults.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/config/__pycache__/defaults.cpython-36.pyc
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter,
8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images during training will be
10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
11 | # IMAGES_PER_BATCH_TEST
12 |
13 | # -----------------------------------------------------------------------------
14 | # Config definition
15 | # -----------------------------------------------------------------------------
16 |
17 | _C = CN()
18 |
19 | _C.MODEL = CN()
20 | # Using cuda or cpu for training
21 | _C.MODEL.DEVICE = "cuda"
22 | # ID number of GPU
23 | _C.MODEL.DEVICE_ID = '0'
24 | # Name of backbone
25 | _C.MODEL.NAME = 'resnet50'
26 | # Last stride of backbone
27 | _C.MODEL.LAST_STRIDE = 1
28 | # Path to pretrained model of backbone
29 | _C.MODEL.PRETRAIN_PATH = ''
30 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
31 | # Options: 'imagenet' or 'self'
32 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet'
33 | # If train with BNNeck, options: 'bnneck' or 'no'
34 | _C.MODEL.NECK = 'bnneck'
35 | # Whether the training loss includes center loss; options: 'yes' or 'no'. Center loss requires a different optimizer configuration
36 | _C.MODEL.IF_WITH_CENTER = 'no'
37 | # The loss type of metric loss
38 | # options:['triplet'](without center loss) or ['center','triplet_center'](with center loss)
39 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet'
40 | # For example, if loss type is cross entropy loss + triplet loss + center loss
41 | # the setting should be: _C.MODEL.METRIC_LOSS_TYPE = 'triplet_center' and _C.MODEL.IF_WITH_CENTER = 'yes'
42 |
43 | # If train with label smooth, options: 'on', 'off'
44 | _C.MODEL.IF_LABELSMOOTH = 'on'
45 |
46 | # HAS_NON_LOCAL
47 | _C.MODEL.HAS_NON_LOCAL = "no"
48 |
49 | #Whole model train
50 | _C.MODEL.WHOLE_MODEL_TRAIN = "no"
51 |
52 | #SIAMESE REGULARIZATION
53 | _C.MODEL.SIA_REG = "no"
54 |
55 | #Pyramid Attention
56 | _C.MODEL.PYRAMID = "no"
57 |
60 |
61 | #GAMMA
62 | _C.MODEL.GAMMA = 1.0
63 |
64 | #BETA
65 | _C.MODEL.BETA = 1.0
66 |
67 | # -----------------------------------------------------------------------------
68 | # INPUT
69 | # -----------------------------------------------------------------------------
70 | _C.INPUT = CN()
71 | # Size of the image during training
72 | _C.INPUT.SIZE_TRAIN = [384, 128]
73 | # Size of the image during test
74 | _C.INPUT.SIZE_TEST = [384, 128]
75 | # Random probability for image horizontal flip
76 | _C.INPUT.PROB = 0.5
77 | # Random probability for random erasing
78 | _C.INPUT.RE_PROB = 0.5
79 | # Values to be used for image normalization
80 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
81 | # Values to be used for image normalization
82 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
83 | # Value of padding size
84 | _C.INPUT.PADDING = 10
85 |
86 | # -----------------------------------------------------------------------------
87 | # Dataset
88 | # -----------------------------------------------------------------------------
89 | _C.DATASETS = CN()
90 | # List of the dataset names for training
91 | _C.DATASETS.NAMES = ('market1501')
92 | # Root directory where the datasets are stored
93 | _C.DATASETS.ROOT_DIR = '/root/person_search/dataset/multi_person'
94 | # Index suffix of the generated pair annotation file, pair_pos_unary<N>.json
95 | _C.DATASETS.TRAIN_ANNO = 1
96 |
97 | # -----------------------------------------------------------------------------
98 | # DataLoader
99 | # -----------------------------------------------------------------------------
100 | _C.DATALOADER = CN()
101 | # Number of data loading threads
102 | _C.DATALOADER.NUM_WORKERS = 8
103 | # Sampler for data loading
104 | _C.DATALOADER.SAMPLER = 'softmax'
105 | # Number of instance for one batch
106 | _C.DATALOADER.NUM_INSTANCE = 16
107 |
108 | # ---------------------------------------------------------------------------- #
109 | # Solver
110 | # ---------------------------------------------------------------------------- #
111 | _C.SOLVER = CN()
112 | # Name of optimizer
113 | _C.SOLVER.OPTIMIZER_NAME = "Adam"
114 | # Maximum number of training epochs
115 | _C.SOLVER.MAX_EPOCHS = 50
116 | # Base learning rate
117 | _C.SOLVER.BASE_LR = 3e-4
118 | # Learning-rate multiplier for bias parameters
119 | _C.SOLVER.BIAS_LR_FACTOR = 2
120 | # Momentum
121 | _C.SOLVER.MOMENTUM = 0.9
122 | # Margin of triplet loss
123 | _C.SOLVER.MARGIN = 0.3
124 | # Margin of cluster loss
125 | _C.SOLVER.CLUSTER_MARGIN = 0.3
126 | # Learning rate of SGD to learn the centers of center loss
127 | _C.SOLVER.CENTER_LR = 0.5
128 | # Balanced weight of center loss
129 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005
130 | # Settings of range loss
131 | _C.SOLVER.RANGE_K = 2
132 | _C.SOLVER.RANGE_MARGIN = 0.3
133 | _C.SOLVER.RANGE_ALPHA = 0
134 | _C.SOLVER.RANGE_BETA = 1
135 | _C.SOLVER.RANGE_LOSS_WEIGHT = 1
136 |
137 | # Settings of weight decay
138 | _C.SOLVER.WEIGHT_DECAY = 0.0005
139 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
140 |
141 | # decay rate of learning rate
142 | _C.SOLVER.GAMMA = 0.1
143 | # decay step of learning rate
144 | _C.SOLVER.STEPS = (30, 55)
145 |
146 | # warm up factor
147 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3
148 | # iterations of warm up
149 | _C.SOLVER.WARMUP_ITERS = 500
150 | # method of warm up, option: 'constant','linear'
151 | _C.SOLVER.WARMUP_METHOD = "linear"
152 |
153 | # Save a checkpoint every this many epochs
154 | _C.SOLVER.CHECKPOINT_PERIOD = 50
155 | # Log training status every this many iterations
156 | _C.SOLVER.LOG_PERIOD = 100
157 | # Run validation every this many epochs
158 | _C.SOLVER.EVAL_PERIOD = 50
159 |
160 | # Number of images per batch
161 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
162 | # see 2 images per batch
163 | _C.SOLVER.IMS_PER_BATCH = 64
164 |
165 |
166 | # ---------------------------------------------------------------------------- #
167 | # Test
168 | _C.TEST = CN()
169 | # Number of images per batch during test
170 | _C.TEST.IMS_PER_BATCH = 128
171 | # If test with re-ranking, options: 'yes','no'
172 | _C.TEST.RE_RANKING = 'no'
173 | # Path to trained model
174 | _C.TEST.WEIGHT = ""
175 | # Which feature of BNNeck to use at test time, before or after the BN layer; options: 'before' or 'after'
176 | _C.TEST.NECK_FEAT = 'after'
177 | # Whether the feature is normalized before test; if yes, this is equivalent to cosine distance
178 | _C.TEST.FEAT_NORM = 'yes'
179 | # Test pair
180 | _C.TEST.PAIR = "no"
181 |
182 | # ---------------------------------------------------------------------------- #
183 | # Misc options
184 | # ---------------------------------------------------------------------------- #
185 | # Path to checkpoint and saved log of trained model
186 | _C.OUTPUT_DIR = ""
187 | _C.Pre_Index_DIR = ""
188 |
--------------------------------------------------------------------------------
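
The option tree above is consumed through yacs; a YAML file from configs/ only overrides a subset of it. A typical yacs workflow, sketched under the assumption that a script imports `cfg` from `config` and merges one of the shipped YAML files (the yacs calls are standard API, the file name is just an example):

```python
from config import cfg  # the _C node defined in config/defaults.py

cfg.merge_from_file("configs/softmax_triplet_ft.yml")  # YAML overrides the defaults
cfg.merge_from_list(["MODEL.DEVICE_ID", "1",           # command-line style overrides
                     "TEST.RE_RANKING", "yes"])
cfg.freeze()                                           # lock the config before training
print(cfg.MODEL.PRETRAIN_CHOICE, cfg.SOLVER.BASE_LR)   # 'self' 0.00035
```
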
/configs/baseline.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
4 | LAST_STRIDE: 2
5 | NECK: 'no'
6 | METRIC_LOSS_TYPE: 'triplet'
7 | IF_LABELSMOOTH: 'off'
8 | IF_WITH_CENTER: 'no'
9 |
10 |
11 | INPUT:
12 | SIZE_TRAIN: [256, 128]
13 | SIZE_TEST: [256, 128]
14 | PROB: 0.5 # random horizontal flip
15 | RE_PROB: 0.0 # random erasing
16 | PADDING: 10
17 |
18 | DATASETS:
19 | NAMES: ('market1501')
20 |
21 | DATALOADER:
22 | SAMPLER: 'softmax_triplet'
23 | NUM_INSTANCE: 4
24 | NUM_WORKERS: 8
25 |
26 | SOLVER:
27 | OPTIMIZER_NAME: 'Adam'
28 | MAX_EPOCHS: 120
29 | BASE_LR: 0.00035
30 |
31 | CLUSTER_MARGIN: 0.3
32 |
33 | CENTER_LR: 0.5
34 | CENTER_LOSS_WEIGHT: 0.0005
35 |
36 | RANGE_K: 2
37 | RANGE_MARGIN: 0.3
38 | RANGE_ALPHA: 0
39 | RANGE_BETA: 1
40 | RANGE_LOSS_WEIGHT: 1
41 |
42 | BIAS_LR_FACTOR: 1
43 | WEIGHT_DECAY: 0.0005
44 | WEIGHT_DECAY_BIAS: 0.0005
45 | IMS_PER_BATCH: 64
46 |
47 | STEPS: [40, 70]
48 | GAMMA: 0.1
49 |
50 | WARMUP_FACTOR: 0.01
51 | WARMUP_ITERS: 0
52 | WARMUP_METHOD: 'linear'
53 |
54 | CHECKPOINT_PERIOD: 40
55 | LOG_PERIOD: 20
56 | EVAL_PERIOD: 40
57 |
58 | TEST:
59 | IMS_PER_BATCH: 128
60 | RE_RANKING: 'no'
61 | WEIGHT: "path"
62 | NECK_FEAT: 'after'
63 | FEAT_NORM: 'yes'
64 |
65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on"
66 |
67 |
68 |
--------------------------------------------------------------------------------
/configs/softmax.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
3 |
4 |
5 | INPUT:
6 | SIZE_TRAIN: [256, 128]
7 | SIZE_TEST: [256, 128]
8 | PROB: 0.5 # random horizontal flip
9 | RE_PROB: 0.5 # random erasing
10 | PADDING: 10
11 |
12 | DATASETS:
13 | NAMES: ('market1501')
14 |
15 | DATALOADER:
16 | SAMPLER: 'softmax'
17 | NUM_WORKERS: 8
18 |
19 | SOLVER:
20 | OPTIMIZER_NAME: 'Adam'
21 | MAX_EPOCHS: 120
22 | BASE_LR: 0.00035
23 | BIAS_LR_FACTOR: 1
24 | WEIGHT_DECAY: 0.0005
25 | WEIGHT_DECAY_BIAS: 0.0005
26 | IMS_PER_BATCH: 64
27 |
28 | STEPS: [30, 55]
29 | GAMMA: 0.1
30 |
31 | WARMUP_FACTOR: 0.01
32 | WARMUP_ITERS: 5
33 | WARMUP_METHOD: 'linear'
34 |
35 | CHECKPOINT_PERIOD: 20
36 | LOG_PERIOD: 20
37 | EVAL_PERIOD: 20
38 |
39 | TEST:
40 | IMS_PER_BATCH: 128
41 |
42 | OUTPUT_DIR: "/home/haoluo/log/reid/market1501/softmax_bs64_256x128"
43 |
44 |
45 |
--------------------------------------------------------------------------------
/configs/softmax_triplet.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'on'
6 | IF_WITH_CENTER: 'no'
7 |
8 |
9 |
10 |
11 | INPUT:
12 | SIZE_TRAIN: [256, 128]
13 | SIZE_TEST: [256, 128]
14 | PROB: 0.5 # random horizontal flip
15 | RE_PROB: 0.5 # random erasing
16 | PADDING: 10
17 |
18 | DATASETS:
19 | NAMES: ('market1501')
20 |
21 | DATALOADER:
22 | SAMPLER: 'softmax_triplet'
23 | NUM_INSTANCE: 4
24 | NUM_WORKERS: 8
25 |
26 | SOLVER:
27 | OPTIMIZER_NAME: 'Adam'
28 | MAX_EPOCHS: 120
29 | BASE_LR: 0.00035
30 |
31 | CLUSTER_MARGIN: 0.3
32 |
33 | CENTER_LR: 0.5
34 | CENTER_LOSS_WEIGHT: 0.0005
35 |
36 | RANGE_K: 2
37 | RANGE_MARGIN: 0.3
38 | RANGE_ALPHA: 0
39 | RANGE_BETA: 1
40 | RANGE_LOSS_WEIGHT: 1
41 |
42 | BIAS_LR_FACTOR: 1
43 | WEIGHT_DECAY: 0.0005
44 | WEIGHT_DECAY_BIAS: 0.0005
45 | IMS_PER_BATCH: 64
46 |
47 | STEPS: [40, 70]
48 | GAMMA: 0.1
49 |
50 | WARMUP_FACTOR: 0.01
51 | WARMUP_ITERS: 10
52 | WARMUP_METHOD: 'linear'
53 |
54 | CHECKPOINT_PERIOD: 40
55 | LOG_PERIOD: 20
56 | EVAL_PERIOD: 40
57 |
58 | TEST:
59 | IMS_PER_BATCH: 128
60 | RE_RANKING: 'no'
61 | WEIGHT: "path"
62 | NECK_FEAT: 'after'
63 | FEAT_NORM: 'yes'
64 |
65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on"
66 |
67 |
68 |
--------------------------------------------------------------------------------
/configs/softmax_triplet_ft.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'self'
3 | PRETRAIN_PATH: '/root/person_search/trained/strong_baseline/prw_all_trick_10/resnet50_model_120.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'no'
6 | IF_WITH_CENTER: 'no'
7 | HAS_NON_LOCAL: "yes"
8 | WHOLE_MODEL_TRAIN: "no"
9 | SIA_REG: "no"
10 | PYRAMID: "no"
11 | GAMMA: 1.0
12 | BETA: 1.0
13 |
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 |
22 | DATASETS:
23 | NAMES: ('market1501')
24 | TRAIN_ANNO: 1
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'Adam'
33 | MAX_EPOCHS: 50
34 | BASE_LR: 0.00035
35 |
36 | CLUSTER_MARGIN: 0.3
37 |
38 | CENTER_LR: 0.5
39 | CENTER_LOSS_WEIGHT: 0.0005
40 |
41 | RANGE_K: 2
42 | RANGE_MARGIN: 0.3
43 | RANGE_ALPHA: 0
44 | RANGE_BETA: 1
45 | RANGE_LOSS_WEIGHT: 1
46 |
47 | BIAS_LR_FACTOR: 1
48 | WEIGHT_DECAY: 0.0005
49 | WEIGHT_DECAY_BIAS: 0.0005
50 | IMS_PER_BATCH: 64
51 |
52 | STEPS: [20, 40]
53 | GAMMA: 0.1
54 |
55 | WARMUP_FACTOR: 0.01
56 | WARMUP_ITERS: 10
57 | WARMUP_METHOD: 'linear'
58 |
59 | CHECKPOINT_PERIOD: 20
60 | LOG_PERIOD: 20
61 | EVAL_PERIOD: 20
62 |
63 |
64 | TEST:
65 | IMS_PER_BATCH: 128
66 | RE_RANKING: 'no'
67 | WEIGHT: "path"
68 | NECK_FEAT: 'after'
69 | FEAT_NORM: 'yes'
70 | PAIR: "no"
71 |
72 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on"
73 | Pre_Index_DIR: "/root/person_search/multi-personReid/pre_index_dir/prw_pre_index.json"
74 |
75 |
76 |
--------------------------------------------------------------------------------
/configs/softmax_triplet_ftc.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'self'
3 | PRETRAIN_PATH: '/root/person_search/trained/strong_baseline/cuhk_all_trick_1/resnet50_model_120.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'no'
6 | IF_WITH_CENTER: 'no'
7 | HAS_NON_LOCAL: "yes"
8 | WHOLE_MODEL_TRAIN: "no"
9 | SIA_REG: "no"
10 | PYRAMID: "no"
11 | GAMMA: 1.0
12 | BETA: 1.0
13 |
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 |
22 | DATASETS:
23 | NAMES: ('market1501')
24 | TRAIN_ANNO: 1
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'Adam'
33 | MAX_EPOCHS: 50
34 | BASE_LR: 0.00035
35 |
36 | CLUSTER_MARGIN: 0.3
37 |
38 | CENTER_LR: 0.5
39 | CENTER_LOSS_WEIGHT: 0.0005
40 |
41 | RANGE_K: 2
42 | RANGE_MARGIN: 0.3
43 | RANGE_ALPHA: 0
44 | RANGE_BETA: 1
45 | RANGE_LOSS_WEIGHT: 1
46 |
47 | BIAS_LR_FACTOR: 1
48 | WEIGHT_DECAY: 0.0005
49 | WEIGHT_DECAY_BIAS: 0.0005
50 | IMS_PER_BATCH: 64
51 |
52 | STEPS: [20, 40]
53 | GAMMA: 0.1
54 |
55 | WARMUP_FACTOR: 0.01
56 | WARMUP_ITERS: 10
57 | WARMUP_METHOD: 'linear'
58 |
59 | CHECKPOINT_PERIOD: 20
60 | LOG_PERIOD: 20
61 | EVAL_PERIOD: 20
62 |
63 |
64 | TEST:
65 | IMS_PER_BATCH: 128
66 | RE_RANKING: 'no'
67 | WEIGHT: "path"
68 | NECK_FEAT: 'after'
69 | FEAT_NORM: 'yes'
70 | PAIR: "no"
71 |
72 |
73 |
74 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on"
75 |
76 | Pre_Index_DIR: "/root/person_search/multi-personReid/pre_index_dir/cuhk_pre_index.json"
77 |
78 |
79 |
--------------------------------------------------------------------------------
/configs/softmax_triplet_with_center.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth'
4 | METRIC_LOSS_TYPE: 'triplet_center'
5 | IF_LABELSMOOTH: 'on'
6 | IF_WITH_CENTER: 'yes'
7 |
8 |
9 |
10 |
11 | INPUT:
12 | SIZE_TRAIN: [256, 128]
13 | SIZE_TEST: [256, 128]
14 | PROB: 0.5 # random horizontal flip
15 | RE_PROB: 0.5 # random erasing
16 | PADDING: 10
17 |
18 | DATASETS:
19 | NAMES: ('market1501')
20 |
21 | DATALOADER:
22 | SAMPLER: 'softmax_triplet'
23 | NUM_INSTANCE: 4
24 | NUM_WORKERS: 8
25 |
26 | SOLVER:
27 | OPTIMIZER_NAME: 'Adam'
28 | MAX_EPOCHS: 120
29 | BASE_LR: 0.00035
30 |
31 | CLUSTER_MARGIN: 0.3
32 |
33 | CENTER_LR: 0.5
34 | CENTER_LOSS_WEIGHT: 0.0005
35 |
36 | RANGE_K: 2
37 | RANGE_MARGIN: 0.3
38 | RANGE_ALPHA: 0
39 | RANGE_BETA: 1
40 | RANGE_LOSS_WEIGHT: 1
41 |
42 | BIAS_LR_FACTOR: 1
43 | WEIGHT_DECAY: 0.0005
44 | WEIGHT_DECAY_BIAS: 0.0005
45 | IMS_PER_BATCH: 64
46 |
47 | STEPS: [40, 70]
48 | GAMMA: 0.1
49 |
50 | WARMUP_FACTOR: 0.01
51 | WARMUP_ITERS: 10
52 | WARMUP_METHOD: 'linear'
53 |
54 | CHECKPOINT_PERIOD: 40
55 | LOG_PERIOD: 20
56 | EVAL_PERIOD: 40
57 |
58 | TEST:
59 | IMS_PER_BATCH: 128
60 | RE_RANKING: 'no'
61 | WEIGHT: "path"
62 | NECK_FEAT: 'after'
63 | FEAT_NORM: 'yes'
64 |
65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005"
66 |
67 |
68 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import make_data_loader, make_data_loader_train, make_data_loader_val
8 |
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/build.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/__pycache__/build.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/collate_batch.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/__pycache__/collate_batch.cpython-36.pyc
--------------------------------------------------------------------------------
/data/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from torch.utils.data import DataLoader
8 |
9 | from .collate_batch import train_collate_fn, val_collate_fn, train_collate_fn_pair, val_collate_fn_pair, train_collate_fn_pair3
10 | from .datasets import init_dataset, ImageDataset, ImageDataset_pair, ImageDataset_pair_val, ImageDataset_pair3
11 | from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid # New add by gu
12 | from .transforms import build_transforms
13 | import json
14 | import numpy as np
15 | import random
16 | import os
17 |
18 | import time
19 |
20 |
21 | def multi_person_training_info_prw(train_anno, root):
22 | root = os.path.join(root, 'prw')
23 | path_gt = os.path.join(root, 'each_pid_info.json')
24 |
25 | with open(path_gt, 'r') as f:
26 | each_pid_info = json.load(f)
27 | # print(each_pid_info)
28 | path_hard = os.path.join(root, 'hard_gallery_train/gallery.json')
29 |
30 | with open(path_hard, 'r') as f:
31 | path_hard = json.load(f)
32 | # print(path_hard)
33 | path_hard_camera_id = os.path.join(root, 'hard_gallery_train/camera_id.json')
34 | with open(path_hard_camera_id, 'r') as f:
35 | path_hard_camera_id = json.load(f)
36 | # print(path_hard_camera_id)
37 |
38 | pairs_anno = []
39 | for img, pids in path_hard.items():
40 | camera_id = path_hard_camera_id[img]
41 | if len(pids) < 2:
42 | continue
43 | one_pair = [img]
44 | for index, pid in enumerate(pids):
45 | pid_info = each_pid_info[str(pid)]
46 | pid_info_camera_id = np.array(pid_info[0])
47 | pos_index = np.where(pid_info_camera_id != camera_id)[0]
48 | if len(pos_index) == 0:
49 | continue
50 | query_img = pid_info[1][random.choice(pos_index)]
51 | one_pair = one_pair + [query_img, pid]
52 |
53 | one_pair = one_pair + [camera_id]
54 | if len(one_pair) > 5:
55 | second_pair = [one_pair[0], one_pair[3], one_pair[4], one_pair[1], one_pair[2], one_pair[5]]
56 | pairs_anno.append(one_pair)
57 | pairs_anno.append(second_pair)
58 | # print(len(pairs_anno))
59 | anno_save_path = os.path.join(root, "pair_pos_unary" + str(train_anno) + ".json")
60 | with open(anno_save_path, 'w+') as f:
61 | json.dump(pairs_anno, f)
62 |
63 | def multi_person_training_info_cuhk(train_anno, root):
64 | root = os.path.join(root, 'cuhk')
65 | path_gt = os.path.join(root, 'each_pid_info.json')
66 | with open(path_gt, 'r') as f:
67 | each_pid_info = json.load(f)
68 | # print(each_pid_info)
69 |
70 | path_hard = os.path.join(root, 'hard_gallery_train/gallery.json')
71 | with open(path_hard, 'r') as f:
72 | path_hard = json.load(f)
73 | # print(path_hard)
74 |
75 | path_hard_camera_id = os.path.join(root, 'hard_gallery_train/camera_id.json')
76 | with open(path_hard_camera_id, 'r') as f:
77 | path_hard_camera_id = json.load(f)
78 | # print(path_hard_camera_id)
79 |
80 |
81 | pairs_anno = []
82 | count2 = 0
83 | for img, pids in path_hard.items():
84 | # camera_id = path_hard_camera_id[img]
85 | if len(pids) < 2:
86 | continue
87 | count2+=1
88 | # else:
89 | # continue
90 | one_pair = [img]
91 | camera_id = 0
92 | for index, pid in enumerate(pids):
93 | pid_info = each_pid_info[str(pid)]
94 | # pid_info_camera_id = np.array(pid_info[0])
95 | # pos_index = np.where(pid_info_camera_id != camera_id)[0]
96 | # if len(pos_index) == 0:
97 | # continue
98 | # query_img = pid_info[1][random.choice(pos_index)]
99 | query_img = random.choice(pid_info[1])
100 | one_pair = one_pair + [query_img, pid]
101 |
102 | one_pair = one_pair + [camera_id]
103 | if len(one_pair) > 5:
104 | second_pair = [one_pair[0], one_pair[3], one_pair[4], one_pair[1], one_pair[2], one_pair[5]]
105 | pairs_anno.append(one_pair)
106 | pairs_anno.append(second_pair)
107 |
108 | anno_save_path = os.path.join(root, "pair_pos_unary" + str(train_anno) + ".json")
109 | with open(anno_save_path, 'w+') as f:
110 | json.dump(pairs_anno, f)
111 |
112 | def make_data_loader(cfg):
113 | train_transforms = build_transforms(cfg, is_train=True)
114 | val_transforms = build_transforms(cfg, is_train=False)
115 | num_workers = cfg.DATALOADER.NUM_WORKERS
116 | if len(cfg.DATASETS.NAMES) == 1:
117 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
118 | else:
119 | # TODO: add multi dataset to train
120 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
121 |
122 | num_classes = dataset.num_train_pids
123 | train_set = ImageDataset(dataset.train, train_transforms)
124 | if cfg.DATALOADER.SAMPLER == 'softmax':
125 | train_loader = DataLoader(
126 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
127 | collate_fn=train_collate_fn
128 | )
129 | else:
130 | train_loader = DataLoader(
131 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
132 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
133 | # sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE), # new add by gu
134 | num_workers=num_workers, collate_fn=train_collate_fn
135 | )
136 |
137 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
138 | val_loader = DataLoader(
139 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
140 | collate_fn=val_collate_fn
141 | )
142 | return train_loader, val_loader, len(dataset.query), num_classes
143 |
144 | def make_data_loader_train(cfg):
145 | # multi_person_training_info2(cfg.DATASETS.TRAIN_ANNO)
146 |
147 | if "cuhk" in cfg.DATASETS.NAMES:
148 | multi_person_training_info_cuhk(cfg.DATASETS.TRAIN_ANNO, cfg.DATASETS.ROOT_DIR)
149 | else:
150 | multi_person_training_info_prw(cfg.DATASETS.TRAIN_ANNO, cfg.DATASETS.ROOT_DIR)
151 |
152 | train_transforms = build_transforms(cfg, is_train=True)
153 | val_transforms = build_transforms(cfg, is_train=False)
154 | num_workers = cfg.DATALOADER.NUM_WORKERS
155 | if len(cfg.DATASETS.NAMES) == 1:
156 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR, train_anno=cfg.DATASETS.TRAIN_ANNO)
157 | else:
158 | # TODO: add multi dataset to train
159 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR, train_anno=cfg.DATASETS.TRAIN_ANNO)
160 |
161 | train_set = ImageDataset_pair3(dataset.train, train_transforms)
162 | num_classes = dataset.num_train_pids
163 |
164 | if cfg.DATALOADER.SAMPLER == 'softmax':
165 | train_loader = DataLoader(
166 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
167 | collate_fn=train_collate_fn_pair3
168 | )
169 | else:
170 | train_loader = DataLoader(
171 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
172 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
173 | # sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE), # new add by gu
174 | num_workers=num_workers, collate_fn=train_collate_fn_pair3
175 | )
176 |
177 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
178 | val_loader = DataLoader(
179 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
180 | collate_fn=val_collate_fn
181 | )
182 |
183 | return train_loader, val_loader, len(dataset.query), num_classes
184 |
185 | def make_data_loader_val(cfg, index, dataset):
186 |
187 | indice_path = cfg.Pre_Index_DIR
188 | with open(indice_path, 'r') as f:
189 | indices = json.load(f)
190 | indice = indices[index][:100]
191 |
192 | val_transforms = build_transforms(cfg, is_train=False)
193 | num_workers = cfg.DATALOADER.NUM_WORKERS
194 |
195 | # if len(cfg.DATASETS.NAMES) == 1:
196 | # dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
197 | # else:
198 | # # TODO: add multi dataset to train
199 | # print(cfg.DATASETS.NAMES)
200 | # dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
201 |
202 | query = dataset.query[index]
203 | gallery = [dataset.gallery[ind] for ind in indice]
204 | gallery = [query] + gallery
205 |
206 | val_set = ImageDataset_pair_val(query, gallery, val_transforms)
207 |
208 | val_loader = DataLoader(
209 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
210 | collate_fn=val_collate_fn_pair
211 | )
212 |
213 | return val_loader
214 |
215 |
216 |
--------------------------------------------------------------------------------
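
Each record that `multi_person_training_info_prw` / `multi_person_training_info_cuhk` above writes into `pair_pos_unary<N>.json` has the layout below, inferred from the loop bodies; the file names and pids are invented for illustration (for CUHK the camera id is always 0):

```python
# [gallery crop, query img 1, pid 1, query img 2, pid 2, camera id]
one_pair = ["hard_0001.jpg", "query_a.jpg", 12, "query_b.jpg", 40, 3]

# For every such record a mirrored one is saved as well, so each of the two
# people in the crop serves once as the target and once as interference:
second_pair = [one_pair[0], one_pair[3], one_pair[4],
               one_pair[1], one_pair[2], one_pair[5]]
# -> ["hard_0001.jpg", "query_b.jpg", 40, "query_a.jpg", 12, 3]
```
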
/data/collate_batch.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 |
9 |
10 | def train_collate_fn(batch):
11 | imgs, pids, _, _, = zip(*batch)
12 | pids = torch.tensor(pids, dtype=torch.int64)
13 | return torch.stack(imgs, dim=0), pids
14 |
15 | def train_collate_fn_pair(batch):
16 | imgs_query, img_gallery, pids, _, _ , _, pids2, pos_neg = zip(*batch)
17 | pids = torch.tensor(pids, dtype=torch.int64)
18 | pids2 = torch.tensor(pids2, dtype=torch.int64)
19 | pos_neg = torch.FloatTensor(pos_neg)
20 | # pos_neg = torch.tensor(pos_neg)
21 |
22 | return torch.stack(imgs_query, dim=0), torch.stack(img_gallery, dim=0), pids, pids2, pos_neg
23 |
24 | def train_collate_fn_pair3(batch):
25 | img_gallery, imgs_query1, pids1, imgs_query2, pids2, _ = zip(*batch)
26 | pids1 = torch.tensor(pids1, dtype=torch.int64)
27 | pids2 = torch.tensor(pids2, dtype=torch.int64)
28 | return torch.stack(img_gallery, dim=0), torch.stack(imgs_query1, dim=0), torch.stack(imgs_query2, dim=0), pids1, pids2
29 |
30 | def val_collate_fn(batch):
31 | imgs, pids, camids, _ = zip(*batch)
32 | return torch.stack(imgs, dim=0), pids, camids
33 |
34 | def val_collate_fn_pair(batch):
35 | imgs_query, imgs_gallery, pids, camids, _ , _, is_first = zip(*batch)
36 | return torch.stack(imgs_query, dim=0), torch.stack(imgs_gallery, dim=0), pids, camids, is_first
37 |
--------------------------------------------------------------------------------
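
A minimal check of what `train_collate_fn_pair3` hands to the trainer, using dummy tensors with the configs' 256x128 input size (run from the repository root so that `data` is importable):

```python
import torch
from data.collate_batch import train_collate_fn_pair3

# Each dataset item: (gallery crop, query 1, pid 1, query 2, pid 2, camera id).
item = (torch.zeros(3, 256, 128), torch.zeros(3, 256, 128), 5,
        torch.zeros(3, 256, 128), 9, 0)
gallery, q1, q2, pids1, pids2 = train_collate_fn_pair3([item] * 4)
print(gallery.shape, q1.shape, pids1)
# torch.Size([4, 3, 256, 128]) torch.Size([4, 3, 256, 128]) tensor([5, 5, 5, 5])
```
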
/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | # from .cuhk03 import CUHK03
7 | from .dukemtmcreid import DukeMTMCreID
8 | from .market1501 import Market1501
9 | from .msmt17 import MSMT17
10 | from .veri import VeRi
11 | from .dataset_loader import ImageDataset, ImageDataset_pair, ImageDataset_pair_val, ImageDataset_pair3
12 | from .prw import PRW
13 | from .cuhk import CUHK
14 |
15 | __factory = {
16 | 'market1501': Market1501,
17 | # 'cuhk03': CUHK03,
18 | 'dukemtmc': DukeMTMCreID,
19 | 'msmt17': MSMT17,
20 | 'veri': VeRi,
21 | 'prw': PRW,
22 | 'cuhk':CUHK
23 | }
24 |
25 |
26 | def get_names():
27 | return __factory.keys()
28 |
29 |
30 | def init_dataset(name, *args, **kwargs):
31 | if name not in __factory.keys():
32 | raise KeyError("Unknown datasets: {}".format(name))
33 | return __factory[name](*args, **kwargs)
34 |
--------------------------------------------------------------------------------
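
Hypothetical usage of the dataset factory above (the root path is an assumption; `train_anno` is accepted by the CUHK and PRW loaders):

```python
from data.datasets import init_dataset, get_names

print(sorted(get_names()))  # ['cuhk', 'dukemtmc', 'market1501', 'msmt17', 'prw', 'veri']
dataset = init_dataset("prw", root="/path/to/multi_person", train_anno=1)
print(dataset.num_train_pids, len(dataset.query), len(dataset.gallery))
```
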
/data/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/bases.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/bases.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/cuhk.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/cuhk.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/dataset_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/dataset_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/eval_reid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/eval_reid.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/market1501.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/market1501.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/msmt17.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/msmt17.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/prw.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/prw.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/veri.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/veri.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/bases.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import numpy as np
8 |
9 |
10 | class BaseDataset(object):
11 | """
12 | Base class of reid dataset
13 | """
14 |
15 | def get_imagedata_info(self, data):
16 | pids, cams = [], []
17 | for _, pid, camid in data:
18 | pids += [pid]
19 | cams += [camid]
20 | pids = set(pids)
21 | cams = set(cams)
22 | num_pids = len(pids)
23 | num_cams = len(cams)
24 | num_imgs = len(data)
25 | return num_pids, num_imgs, num_cams
26 |
27 | def get_videodata_info(self, data, return_tracklet_stats=False):
28 | pids, cams, tracklet_stats = [], [], []
29 | for img_paths, pid, camid in data:
30 | pids += [pid]
31 | cams += [camid]
32 | tracklet_stats += [len(img_paths)]
33 | pids = set(pids)
34 | cams = set(cams)
35 | num_pids = len(pids)
36 | num_cams = len(cams)
37 | num_tracklets = len(data)
38 | if return_tracklet_stats:
39 | return num_pids, num_tracklets, num_cams, tracklet_stats
40 | return num_pids, num_tracklets, num_cams
41 |
42 | def print_dataset_statistics(self):
43 | raise NotImplementedError
44 |
45 |
46 | class BaseImageDataset(BaseDataset):
47 | """
48 | Base class of image reid dataset
49 | """
50 |
51 | def print_dataset_statistics(self, train, query, gallery):
52 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
53 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
54 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
55 |
56 | print("Dataset statistics:")
57 | print(" ----------------------------------------")
58 | print(" subset | # ids | # images | # cameras")
59 | print(" ----------------------------------------")
60 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
61 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
62 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
63 | print(" ----------------------------------------")
64 |
65 |
66 | class BaseVideoDataset(BaseDataset):
67 | """
68 | Base class of video reid dataset
69 | """
70 |
71 | def print_dataset_statistics(self, train, query, gallery):
72 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
73 | self.get_videodata_info(train, return_tracklet_stats=True)
74 |
75 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
76 | self.get_videodata_info(query, return_tracklet_stats=True)
77 |
78 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
79 | self.get_videodata_info(gallery, return_tracklet_stats=True)
80 |
81 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
82 | min_num = np.min(tracklet_stats)
83 | max_num = np.max(tracklet_stats)
84 | avg_num = np.mean(tracklet_stats)
85 |
86 | print("Dataset statistics:")
87 | print(" -------------------------------------------")
88 | print(" subset | # ids | # tracklets | # cameras")
89 | print(" -------------------------------------------")
90 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams))
91 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams))
92 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))
93 | print(" -------------------------------------------")
94 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num))
95 | print(" -------------------------------------------")
96 |
--------------------------------------------------------------------------------
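
A toy check of the statistics helper above, with invented `(img_path, pid, camid)` triplets:

```python
from data.datasets.bases import BaseImageDataset

toy = [("a.jpg", 0, 0), ("b.jpg", 0, 1), ("c.jpg", 1, 1)]
num_pids, num_imgs, num_cams = BaseImageDataset().get_imagedata_info(toy)
print(num_pids, num_imgs, num_cams)  # 2 3 2
```
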
/data/datasets/cuhk.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import re
3 |
4 | import os.path as osp
5 |
6 | from .bases import BaseImageDataset
7 | import warnings
8 | import json
9 | import cv2
10 | from tqdm import tqdm
11 | import json
12 | import random
13 | import numpy as np
14 | import os
15 |
16 | import time
17 |
18 |
19 | class CUHK(BaseImageDataset):
20 |     """Pedestrian-interference dataset built on CUHK-SYSU.
21 |
22 |     Reference:
23 |         Xiao et al. Joint Detection and Identification Feature Learning for
24 |         Person Search. CVPR 2017.
25 |
26 |     Loads the multi-person training pairs (pair_pos_unary<N>.json) and the
27 |     query / hard-gallery test splits built from CUHK-SYSU ground-truth boxes;
28 |     see process_dir_train and process_dir below.
29 |     """
30 |     _junk_pids = [0, -1]
31 |     dataset_dir = ''
32 |     # Leftover from the Market1501 loader this class was adapted from; unused here.
33 |     dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
34 |
35 | def __init__(self, root='datasets', market1501_500k=False, train_anno=1, **kwargs):
36 |
37 | # root = "/root/person_search/dataset/multi_person"
38 | self.root = os.path.join(root, 'cuhk')
39 | self.train_anno = train_anno
40 |
41 | self.pid_container = set()
42 |
43 | self.gallery_id = []
44 |
45 | # train = self.process_dir("train", relabel=True)
46 | train = self.process_dir_train(relabel=True)
47 | query = self.process_dir("query", relabel=False)
48 | gallery = self.process_dir("gallery", relabel=False)
49 |
50 | query = sorted(query)
51 | gallery = sorted(gallery)
52 |
53 | self.train = train
54 | self.query = query
55 | self.gallery = gallery
56 | #
57 |
58 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info_train(self.train)
59 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
60 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info_gallery(self.gallery)
61 |
62 | print("Dataset statistics:")
63 | print(" ----------------------------------------")
64 | print(" subset | # ids | # images | # cameras")
65 | print(" ----------------------------------------")
66 | print(" train | {:5d} | {:8d} | {:9d}".format(self.num_train_pids, self.num_train_imgs, self.num_train_cams))
67 | print(" query | {:5d} | {:8d} | {:9d}".format(self.num_query_pids, self.num_query_imgs, self.num_query_cams))
68 | print(" gallery | {:5d} | {:8d} | {:9d}".format(self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams))
69 | print(" ----------------------------------------")
70 |
71 |
72 | def get_imagedata_info_train(self, data):
73 |
74 | pids, cams = [], []
75 | for _, _, pid, camid, pid2, pos_neg in data:
76 | pids += [pid]
77 | pids += [pid2]
78 | cams += [camid]
79 | pids = set(pids)
80 | cams = set(cams)
81 | num_pids = len(pids)
82 | num_cams = len(cams)
83 | num_imgs = len(data)
84 | return num_pids, num_imgs, num_cams
85 |
86 | def get_imagedata_info_gallery(self, data):
87 | pids, cams = [], []
88 | for _, pid, camid in data:
89 | if isinstance(pid, list):
90 | for one_pid in pid:
91 | pids += [one_pid]
92 | cams += [camid]
93 | pids = set(pids)
94 | cams = set(cams)
95 | num_pids = len(pids)
96 | num_cams = len(cams)
97 | num_imgs = len(data)
98 | return num_pids, num_imgs, num_cams
99 |
100 | def process_dir_train(self, relabel=True):
101 | # root = "/root/person_search/dataset/person_search/cuhk"
102 | anno_path = osp.join(self.root, "gt_training_box.json")
103 | with open(anno_path, 'r+') as f:
104 | all_anno = json.load(f)
105 |
106 | pid_container = set()
107 | for img_name, pid in all_anno.items():
108 | pid_container.add(int(pid))
109 | # print(pid_container)
110 | # print("pid_container: " + str(len(pid_container)))
111 | pid2label = {int(pid): label for label, pid in enumerate(pid_container)}
112 | # print(pid2label)
113 | # print("pid_container: " + str(len(pid_container)))
114 |
115 |
116 | new_anno_path = osp.join(self.root, "pair_pos_unary" + str(self.train_anno) + ".json")
117 | with open(new_anno_path, 'r+') as f:
118 | all_anno = json.load(f)
119 | data = []
120 |
121 | # img_root1 = "/root/person_search/dataset/multi_person/cuhk/hard_gallery_train/image"
122 | # img_root2 = "/root/person_search/dataset/multi_person/cuhk/train_gt/image"
123 |
124 | img_root1 = os.path.join(self.root, 'hard_gallery_train/image')
125 | img_root2 = os.path.join(self.root, 'train_gt/image')
126 |
127 | file_index = 0
128 | for one_pair in all_anno:
129 | hard_imgname = one_pair[0]
130 | query_train_imgname1 = one_pair[1]
131 | pid1 = one_pair[2]
132 | query_train_imgname2 = one_pair[3]
133 | pid2 = one_pair[4]
134 | camera_id = one_pair[5]
135 | if relabel:
136 | pid1 = pid2label[pid1]
137 | pid2 = pid2label[pid2]
138 | hard_imgname_path = osp.join(img_root1, hard_imgname)
139 | query_train_path1 = osp.join(img_root2, query_train_imgname1)
140 | query_train_path2 = osp.join(img_root2, query_train_imgname2)
141 | new_anno = [hard_imgname_path, query_train_path1, pid1, query_train_path2, pid2, camera_id]
142 | # print(new_anno)
143 | data.append(new_anno)
144 |
145 | return data
146 |
147 | def process_dir(self, dataset, relabel=False):
148 |
149 | if dataset == "query":
150 | anno_path = osp.join(self.root, "query", "query.json")
151 | img_root = osp.join(self.root, "query", "query_image")
152 | elif dataset == "gallery":
153 | gallery_name = "hard_gallery_test"
154 | anno_path = osp.join(self.root, gallery_name, "gallery.json")
155 | img_root = osp.join(self.root, gallery_name, "image")
156 |
157 | with open(anno_path, 'r+') as f:
158 | all_anno = json.load(f)
159 |
160 | valid_pid_path = os.path.join(self.root, 'valid_q_pid.json')
161 | with open(valid_pid_path, 'r+') as f:
162 | valid_pid = json.load(f)
163 |
164 |
165 | data = []
166 | for img_name, pid in all_anno.items():
167 | image_path = osp.join(img_root, img_name)
168 | if dataset == "query":
169 | camid = 1
170 | elif dataset == "gallery":
171 | camid = 2
172 | if isinstance(pid, str):
173 | pid = int(pid)
174 | if dataset == "query":
175 | if pid not in valid_pid:
176 | continue
177 | data.append((image_path, pid, int(camid)))
178 | return data
179 |
180 |
--------------------------------------------------------------------------------
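
For reference, miniature versions of the annotation files the CUHK loader above expects; the file names come from the code, the contents are invented for illustration:

```python
import json
import os

root = "cuhk"  # hypothetical: <cfg.DATASETS.ROOT_DIR>/cuhk
os.makedirs(os.path.join(root, "query"), exist_ok=True)
os.makedirs(os.path.join(root, "hard_gallery_test"), exist_ok=True)

def dump(rel_path, obj):
    with open(os.path.join(root, rel_path), "w") as f:
        json.dump(obj, f)

dump("gt_training_box.json", {"train_0001.jpg": 12, "train_0002.jpg": 40})  # img -> pid
dump("hard_gallery_test/gallery.json", {"hard_0001.jpg": [12, 40]})         # img -> pids in crop
dump("query/query.json", {"query_0001.jpg": "12"})                          # img -> pid (str or int)
dump("valid_q_pid.json", [12])                                              # query pids to keep
```
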
/data/datasets/cuhk03.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import h5py
8 | import os.path as osp
9 | from scipy.io import loadmat
10 | from scipy.misc import imsave  # removed in SciPy >= 1.2; on newer SciPy use imageio.imwrite
11 |
12 | from utils.iotools import mkdir_if_missing, write_json, read_json
13 | from .bases import BaseImageDataset
14 |
15 |
16 | class CUHK03(BaseImageDataset):
17 | """
18 | CUHK03
19 | Reference:
20 | Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.
21 | URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!
22 |
23 | Dataset statistics:
24 | # identities: 1360
25 | # images: 13164
26 | # cameras: 6
27 | # splits: 20 (classic)
28 | Args:
29 | split_id (int): split index (default: 0)
30 | cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False)
31 | """
32 | dataset_dir = 'cuhk03'
33 |
34 | def __init__(self, root='/home/haoluo/data', split_id=0, cuhk03_labeled=False,
35 | cuhk03_classic_split=False, verbose=True,
36 | **kwargs):
37 | super(CUHK03, self).__init__()
38 | self.dataset_dir = osp.join(root, self.dataset_dir)
39 | self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
40 | self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')
41 |
42 | self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
43 | self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')
44 |
45 | self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
46 | self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')
47 |
48 | self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
49 | self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')
50 |
51 | self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
52 | self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')
53 |
54 | self._check_before_run()
55 | self._preprocess()
56 |
57 | if cuhk03_labeled:
58 | image_type = 'labeled'
59 | split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
60 | else:
61 | image_type = 'detected'
62 | split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path
63 |
64 | splits = read_json(split_path)
65 | assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id,
66 | len(splits))
67 | split = splits[split_id]
68 | print("Split index = {}".format(split_id))
69 |
70 | train = split['train']
71 | query = split['query']
72 | gallery = split['gallery']
73 |
74 | if verbose:
75 | print("=> CUHK03 ({}) loaded".format(image_type))
76 | self.print_dataset_statistics(train, query, gallery)
77 |
78 | self.train = train
79 | self.query = query
80 | self.gallery = gallery
81 |
82 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
83 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
84 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
85 |
86 | def _check_before_run(self):
87 | """Check if all files are available before going deeper"""
88 | if not osp.exists(self.dataset_dir):
89 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
90 | if not osp.exists(self.data_dir):
91 | raise RuntimeError("'{}' is not available".format(self.data_dir))
92 | if not osp.exists(self.raw_mat_path):
93 | raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
94 | if not osp.exists(self.split_new_det_mat_path):
95 | raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path))
96 | if not osp.exists(self.split_new_lab_mat_path):
97 | raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path))
98 |
99 | def _preprocess(self):
100 | """
101 | This function is a bit complex and ugly, what it does is
102 | 1. Extract data from cuhk-03.mat and save as png images.
103 | 2. Create 20 classic splits. (Li et al. CVPR'14)
104 | 3. Create new split. (Zhong et al. CVPR'17)
105 | """
106 | print(
107 | "Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)")
108 | if osp.exists(self.imgs_labeled_dir) and \
109 | osp.exists(self.imgs_detected_dir) and \
110 | osp.exists(self.split_classic_det_json_path) and \
111 | osp.exists(self.split_classic_lab_json_path) and \
112 | osp.exists(self.split_new_det_json_path) and \
113 | osp.exists(self.split_new_lab_json_path):
114 | return
115 |
116 | mkdir_if_missing(self.imgs_detected_dir)
117 | mkdir_if_missing(self.imgs_labeled_dir)
118 |
119 | print("Extract image data from {} and save as png".format(self.raw_mat_path))
120 | mat = h5py.File(self.raw_mat_path, 'r')
121 |
122 | def _deref(ref):
123 | return mat[ref][:].T
124 |
125 | def _process_images(img_refs, campid, pid, save_dir):
126 | img_paths = [] # Note: some persons only have images for one view
127 | for imgid, img_ref in enumerate(img_refs):
128 | img = _deref(img_ref)
129 | # skip empty cell
130 | if img.size == 0 or img.ndim < 3: continue
131 | # images are saved with the following format, index-1 (ensure uniqueness)
132 | # campid: index of camera pair (1-5)
133 | # pid: index of person in 'campid'-th camera pair
134 | # viewid: index of view, {1, 2}
135 | # imgid: index of image, (1-10)
136 | viewid = 1 if imgid < 5 else 2
137 | img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
138 | img_path = osp.join(save_dir, img_name)
139 | if not osp.isfile(img_path):
140 | imsave(img_path, img)
141 | img_paths.append(img_path)
142 | return img_paths
143 |
144 | def _extract_img(name):
145 | print("Processing {} images (extract and save) ...".format(name))
146 | meta_data = []
147 | imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir
148 | for campid, camp_ref in enumerate(mat[name][0]):
149 | camp = _deref(camp_ref)
150 | num_pids = camp.shape[0]
151 | for pid in range(num_pids):
152 | img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
153 | assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid)
154 | meta_data.append((campid + 1, pid + 1, img_paths))
155 | print("- done camera pair {} with {} identities".format(campid + 1, num_pids))
156 | return meta_data
157 |
158 | meta_detected = _extract_img('detected')
159 | meta_labeled = _extract_img('labeled')
160 |
161 | def _extract_classic_split(meta_data, test_split):
162 | train, test = [], []
163 | num_train_pids, num_test_pids = 0, 0
164 | num_train_imgs, num_test_imgs = 0, 0
165 | for i, (campid, pid, img_paths) in enumerate(meta_data):
166 |
167 | if [campid, pid] in test_split:
168 | for img_path in img_paths:
169 | camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based
170 | test.append((img_path, num_test_pids, camid))
171 | num_test_pids += 1
172 | num_test_imgs += len(img_paths)
173 | else:
174 | for img_path in img_paths:
175 | camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based
176 | train.append((img_path, num_train_pids, camid))
177 | num_train_pids += 1
178 | num_train_imgs += len(img_paths)
179 | return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs
180 |
181 | print("Creating classic splits (# = 20) ...")
182 | splits_classic_det, splits_classic_lab = [], []
183 | for split_ref in mat['testsets'][0]:
184 | test_split = _deref(split_ref).tolist()
185 |
186 | # create split for detected images
187 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
188 | _extract_classic_split(meta_detected, test_split)
189 | splits_classic_det.append({
190 | 'train': train, 'query': test, 'gallery': test,
191 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
192 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
193 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
194 | })
195 |
196 | # create split for labeled images
197 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
198 | _extract_classic_split(meta_labeled, test_split)
199 | splits_classic_lab.append({
200 | 'train': train, 'query': test, 'gallery': test,
201 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
202 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
203 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
204 | })
205 |
206 | write_json(splits_classic_det, self.split_classic_det_json_path)
207 | write_json(splits_classic_lab, self.split_classic_lab_json_path)
208 |
209 | def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
210 | tmp_set = []
211 | unique_pids = set()
212 | for idx in idxs:
213 | img_name = filelist[idx][0]
214 | camid = int(img_name.split('_')[2]) - 1 # make it 0-based
215 | pid = pids[idx]
216 | if relabel: pid = pid2label[pid]
217 | img_path = osp.join(img_dir, img_name)
218 | tmp_set.append((img_path, int(pid), camid))
219 | unique_pids.add(pid)
220 | return tmp_set, len(unique_pids), len(idxs)
221 |
222 | def _extract_new_split(split_dict, img_dir):
223 | train_idxs = split_dict['train_idx'].flatten() - 1 # index-0
224 | pids = split_dict['labels'].flatten()
225 | train_pids = set(pids[train_idxs])
226 | pid2label = {pid: label for label, pid in enumerate(train_pids)}
227 | query_idxs = split_dict['query_idx'].flatten() - 1
228 | gallery_idxs = split_dict['gallery_idx'].flatten() - 1
229 | filelist = split_dict['filelist'].flatten()
230 | train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True)
231 | query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False)
232 | gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
233 | return train_info, query_info, gallery_info
234 |
235 | print("Creating new splits for detected images (767/700) ...")
236 | train_info, query_info, gallery_info = _extract_new_split(
237 | loadmat(self.split_new_det_mat_path),
238 | self.imgs_detected_dir,
239 | )
240 | splits = [{
241 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
242 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
243 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
244 | 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
245 | }]
246 | write_json(splits, self.split_new_det_json_path)
247 |
248 | print("Creating new splits for labeled images (767/700) ...")
249 | train_info, query_info, gallery_info = _extract_new_split(
250 | loadmat(self.split_new_lab_mat_path),
251 | self.imgs_labeled_dir,
252 | )
253 | splits = [{
254 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
255 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
256 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
257 | 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
258 | }]
259 | write_json(splits, self.split_new_lab_json_path)
260 |
--------------------------------------------------------------------------------
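For orientation, a minimal sketch of consuming one of the split JSON files written above. It assumes only the structure serialized by write_json (a list of split dicts whose 'train'/'query'/'gallery' entries are [img_path, pid, camid] triples); the file path is illustrative:

import json

# illustrative path; use whichever split_*_json_path the loader generated
with open('cuhk03/splits_new_detected.json') as f:
    splits = json.load(f)

split = splits[0]
print(split['num_train_pids'], split['num_query_pids'], split['num_gallery_pids'])
for img_path, pid, camid in split['train'][:3]:  # tuples round-trip as lists
    print(img_path, pid, camid)
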
/data/datasets/dataset_loader.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import os.path as osp
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 |
11 |
12 | def read_image(img_path):
13 | """Keep reading image until succeed.
14 | This can avoid IOError incurred by heavy IO process."""
15 | got_img = False
16 | if not osp.exists(img_path):
17 | raise IOError("{} does not exist".format(img_path))
18 | while not got_img:
19 | try:
20 | img = Image.open(img_path).convert('RGB')
21 | got_img = True
22 | except IOError:
23 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
24 | pass
25 | return img
26 |
27 |
28 | class ImageDataset(Dataset):
29 | """Image Person ReID Dataset"""
30 |
31 | def __init__(self, dataset, transform=None):
32 | self.dataset = dataset
33 | self.transform = transform
34 |
35 | def __len__(self):
36 | return len(self.dataset)
37 |
38 | def __getitem__(self, index):
39 | img_path, pid, camid = self.dataset[index]
40 | img = read_image(img_path)
41 |
42 | if self.transform is not None:
43 | img = self.transform(img)
44 |
45 | return img, pid, camid, img_path
46 |
47 | class ImageDataset_pair(Dataset):
48 | """Image Person ReID Dataset"""
49 |
50 | def __init__(self, dataset, transform=None):
51 | self.dataset = dataset
52 | self.transform = transform
53 |
54 | def __len__(self):
55 | return len(self.dataset)
56 |
57 | def __getitem__(self, index):
58 | query_path, gallery_path, pid, camid, pid2, pos_neg = self.dataset[index]
59 | query_img = read_image(query_path)
60 | gallery_img = read_image(gallery_path)
61 |
62 | if self.transform is not None:
63 | query_img = self.transform(query_img)
64 | gallery_img = self.transform(gallery_img)
65 |
66 | return query_img, gallery_img, pid, camid, query_path, gallery_path, pid2, pos_neg
67 |
68 | class ImageDataset_pair3(Dataset):
69 | """Image Person ReID Dataset"""
70 |
71 | def __init__(self, dataset, transform=None):
72 | self.dataset = dataset
73 | self.transform = transform
74 |
75 | def __len__(self):
76 | return len(self.dataset)
77 |
78 | def __getitem__(self, index):
79 | gallery_path, query_path1, pid1, query_path2, pid2, camera_id = self.dataset[index]
80 | query_img1 = read_image(query_path1)
81 | query_img2 = read_image(query_path2)
82 | gallery_img = read_image(gallery_path)
83 |
84 | if self.transform is not None:
85 | query_img1 = self.transform(query_img1)
86 | query_img2 = self.transform(query_img2)
87 | gallery_img = self.transform(gallery_img)
88 |
89 | return gallery_img, query_img1, pid1, query_img2, pid2, camera_id
90 |
91 |
92 | class ImageDataset_pair_val(Dataset):
93 | """Image Person ReID Dataset"""
94 |
95 | def __init__(self, query, gallery, transform=None):
96 | self.query = query
97 | self.gallery = gallery
98 | self.transform = transform
99 |
100 | def __len__(self):
101 | return len(self.gallery)
102 |
103 | def __getitem__(self, index):
104 |
105 | query_path, _, _ = self.query  # the query's pid/camid are unused; the gallery item's are returned
106 | gallery_path, pid, camid = self.gallery[index]
107 |
108 | is_first = query_path == gallery_path
109 |
110 | query_img = read_image(query_path)
111 | gallery_img = read_image(gallery_path)
112 |
113 | if self.transform is not None:
114 | query_img = self.transform(query_img)
115 | gallery_img = self.transform(gallery_img)
116 |
117 | return query_img, gallery_img, pid, camid, query_path, gallery_path, is_first
118 |
--------------------------------------------------------------------------------
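For reference, a minimal usage sketch of ImageDataset with a torchvision transform and a DataLoader, assuming the repo root is on PYTHONPATH; the sample path is hypothetical:

import torchvision.transforms as T
from torch.utils.data import DataLoader
from data.datasets.dataset_loader import ImageDataset

transform = T.Compose([T.Resize((256, 128)), T.ToTensor()])
samples = [('/path/to/0001_c1_000001.jpg', 0, 0)]  # (img_path, pid, camid); hypothetical
dataset = ImageDataset(samples, transform=transform)
loader = DataLoader(dataset, batch_size=32)
for imgs, pids, camids, paths in loader:
    print(imgs.shape)  # (B, 3, 256, 128)
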
/data/datasets/dukemtmcreid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import glob
8 | import re
9 | import urllib
10 | import zipfile
11 |
12 | import os.path as osp
13 |
14 | from utils.iotools import mkdir_if_missing
15 | from .bases import BaseImageDataset
16 | import json
17 |
18 |
19 | class DukeMTMCreID(BaseImageDataset):
20 | """
21 | DukeMTMC-reID
22 | Reference:
23 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
24 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
25 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation
26 |
27 | Dataset statistics:
28 | # identities: 1404 (train + query)
29 | # images: 16522 (train) + 2228 (query) + 17661 (gallery)
30 | # cameras: 8
31 | """
32 | dataset_dir = 'dukemtmc-reid'
33 |
34 | def __init__(self, root='/home/haoluo/data', train_anno = 1, verbose=True, **kwargs):
35 | super(DukeMTMCreID, self).__init__()
36 | self.dataset_dir = root
37 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
38 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
39 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
40 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')
41 |
42 | # self._download_data()
43 | # self._check_before_run()
44 |
45 | self.root = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw"
46 | # self.multi_person_training_info2()
47 | self.train_anno = train_anno
48 |
49 | train = self.process_dir_train(relabel=True)
50 | query = self._process_dir(self.query_dir, relabel=False)
51 | gallery = self._process_dir(self.gallery_dir, relabel=False)
52 |
53 | # if verbose:
54 | # print("=> DukeMTMC-reID loaded")
55 | # self.print_dataset_statistics(train, query, gallery)
56 |
57 | self.train = train
58 | self.query = query
59 | self.gallery = gallery
60 |
61 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info_train(self.train)
62 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
63 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
64 |
65 | def _download_data(self):
66 | if osp.exists(self.dataset_dir):
67 | print("This dataset has been downloaded.")
68 | return
69 |
70 | print("Creating directory {}".format(self.dataset_dir))
71 | mkdir_if_missing(self.dataset_dir)
72 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
73 |
74 | print("Downloading DukeMTMC-reID dataset")
75 | urllib.request.urlretrieve(self.dataset_url, fpath)
76 |
77 | print("Extracting files")
78 | zip_ref = zipfile.ZipFile(fpath, 'r')
79 | zip_ref.extractall(self.dataset_dir)
80 | zip_ref.close()
81 |
82 | def _check_before_run(self):
83 | """Check if all files are available before going deeper"""
84 | if not osp.exists(self.dataset_dir):
85 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
86 | if not osp.exists(self.train_dir):
87 | raise RuntimeError("'{}' is not available".format(self.train_dir))
88 | if not osp.exists(self.query_dir):
89 | raise RuntimeError("'{}' is not available".format(self.query_dir))
90 | if not osp.exists(self.gallery_dir):
91 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
92 |
93 | def _process_dir(self, dir_path, relabel=False):
94 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
95 | pattern = re.compile(r'([-\d]+)_c(\d)')
96 |
97 | pid_container = set()
98 | for img_path in img_paths:
99 | pid, _ = map(int, pattern.search(img_path).groups())
100 | pid_container.add(pid)
101 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
102 |
103 | dataset = []
104 | for img_path in img_paths:
105 | pid, camid = map(int, pattern.search(img_path).groups())
106 | assert 1 <= camid <= 8
107 | camid -= 1 # index starts from 0
108 | if relabel: pid = pid2label[pid]
109 | dataset.append((img_path, pid, camid))
110 | if 'query' in dir_path and len(dataset) >= 300:
111 | break
112 |
113 | return dataset
114 |
115 | def get_imagedata_info_train(self, data):
116 |
117 | pids, cams = [], []
118 | for _, _, pid, camid, pid2, pos_neg in data:
119 | pids += [pid]
120 | pids += [pid2]
121 | cams += [camid]
122 | pids = set(pids)
123 | cams = set(cams)
124 | num_pids = len(pids)
125 | num_cams = len(cams)
126 | num_imgs = len(data)
127 | return num_pids, num_imgs, num_cams
128 |
129 | def process_dir_train(self, relabel=True):
130 | root = "/raid/home/henrayzhao/person_search/dataset/person_search/prw"
131 | anno_path = osp.join(root, "training_box", "training_box.json")
132 | with open(anno_path, 'r+') as f:
133 | all_anno = json.load(f)
134 |
135 | pid_container = set()
136 | for img_name, pid in all_anno.items():
137 | pid_container.add(pid)
138 | pid2label = {int(pid): label for label, pid in enumerate(pid_container)}
139 |
140 | new_anno_path = osp.join(self.root, "pair_pos_unary" + str(self.train_anno) + ".json")
141 | with open(new_anno_path, 'r+') as f:
142 | all_anno = json.load(f)
143 | data = []
144 |
145 | img_root1 = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw/hard_gallery_train/image"
146 | img_root2 = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw/train_gt/image"
147 |
148 | for one_pair in all_anno:
149 | # print(one_pair)
150 | hard_imgname = one_pair[0]
151 | query_train_imgname1 = one_pair[1]
152 | pid1 = one_pair[2]
153 | query_train_imgname2 = one_pair[3]
154 | pid2 = one_pair[4]
155 | camera_id = one_pair[5]
156 | if relabel:
157 | pid1 = pid2label[pid1]
158 | pid2 = pid2label[pid2]
159 | hard_imgname_path = osp.join(img_root1, hard_imgname)
160 | query_train_path1 = osp.join(img_root2, query_train_imgname1)
161 | query_train_path2 = osp.join(img_root2, query_train_imgname2)
162 | new_anno = [hard_imgname_path, query_train_path1, pid1, query_train_path2, pid2, camera_id]
163 | data.append(new_anno)
164 | return data
165 |
--------------------------------------------------------------------------------
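Note that, despite its name, this loader builds its training set from the PRW pair annotations at the hard-coded /raid/... paths above. A minimal sketch of one record in pair_pos_unary<N>.json as unpacked by process_dir_train (all values illustrative):

one_pair = [
    'hard_000123.jpg',    # gallery image containing multiple persons
    'query_a.jpg', 101,   # first query crop and its raw pid
    'query_b.jpg', 202,   # second query crop and its raw pid
    3,                    # camera id
]
hard_imgname, query1, pid1, query2, pid2, camera_id = one_pair
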
/data/datasets/eval_reid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import numpy as np
8 | import json
9 | import os
10 |
11 |
12 | def process_g_pids(q_pid, g_pid_lists):
13 | g_pids = []
14 | for g_pid_list in g_pid_lists:
15 | if len(g_pid_list) <= 1:
16 | g_pids.append(g_pid_list[0])
17 | else:
18 | if q_pid in g_pid_list:
19 | g_pids.append(q_pid)
20 | else:
21 | g_pids.append(g_pid_list[0])
22 | return np.array(g_pids)
23 |
24 |
25 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
26 | """Evaluation with market1501 metric
27 | Key: for each query identity, its gallery images from the same camera view are discarded.
28 | """
29 |
30 | # print(list(g_pids))
31 |
32 | num_q, num_g = distmat.shape
33 | if num_g < max_rank:
34 | max_rank = num_g
35 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
36 | indices = np.argsort(distmat, axis=1)
37 |
38 | # compute cmc curve for each query
39 | all_cmc = []
40 | all_AP = []
41 |
42 | flag = 0
43 |
44 | if not isinstance(g_pids[0], (int, str)):
45 | list_g_pids = g_pids
46 | flag = 1
47 |
48 | num_valid_q = 0. # number of valid query
49 |
50 | q_pid_return = -88
51 |
52 | for q_idx in range(num_q):
53 | # get query pid and camid
54 | q_pid = q_pids[q_idx]
55 | q_camid = q_camids[q_idx]
56 |
57 | # print(flag)
58 | if flag == 1:
59 | g_pids = process_g_pids(q_pid, list_g_pids)
60 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
61 |
62 | # remove gallery samples that have the same pid and camid with query
63 | order = indices[q_idx]
64 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
65 | keep = np.invert(remove)
66 |
67 | # compute cmc curve
68 | # binary vector, positions with value 1 are correct matches
69 | orig_cmc = matches[q_idx][keep]
70 | if not np.any(orig_cmc):
71 | # this condition is true when query identity does not appear in gallery
72 | continue
73 |
74 | cmc = orig_cmc.cumsum()
75 | cmc[cmc > 1] = 1
76 | all_cmc.append(cmc[:max_rank])
77 | num_valid_q += 1.
78 |
79 | # compute average precision
80 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
81 | num_rel = orig_cmc.sum()
82 | tmp_cmc = orig_cmc.cumsum()
83 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
84 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
85 | AP = tmp_cmc.sum() / num_rel
86 | all_AP.append(AP)
87 | q_pid_return = q_pid
88 |
89 | if num_valid_q == 0:
90 | # every query identity was filtered out of the gallery
91 | return -1, -1, q_pid_return
92 |
93 | all_cmc = np.asarray(all_cmc).astype(np.float32)
94 |
95 | all_cmc = all_cmc.sum(0) / num_valid_q
96 |
97 | mAP = np.mean(all_AP)
98 |
99 | return all_cmc, mAP, q_pid_return
100 |
--------------------------------------------------------------------------------
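process_g_pids collapses a multi-pid gallery entry (a list of pids for one image) to a single pid, preferring the query's pid when it appears; eval_func switches into this mode when g_pids[0] is neither an int nor a str. A quick worked example, assuming the repo root is on PYTHONPATH:

from data.datasets.eval_reid import process_g_pids

g_pid_lists = [[5], [5, 7], [9]]        # the middle gallery image contains two persons
print(process_g_pids(7, g_pid_lists))   # -> [5 7 9]: the multi-pid entry resolves to the query pid
print(process_g_pids(1, g_pid_lists))   # -> [5 5 9]: no match, falls back to the first pid
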
/data/datasets/market1501.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import glob
8 | import re
9 |
10 | import os.path as osp
11 |
12 | from .bases import BaseImageDataset
13 | import json
14 |
15 |
16 | class Market1501(BaseImageDataset):
17 | """
18 | Market1501
19 | Reference:
20 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
21 | URL: http://www.liangzheng.org/Project/project_reid.html
22 |
23 | Dataset statistics:
24 | # identities: 1501 (+1 for background)
25 | # images: 12936 (train) + 3368 (query) + 15913 (gallery)
26 | """
27 | dataset_dir = 'market1501'
28 |
29 | def __init__(self, root='/home/haoluo/data', train_anno=1, verbose=True, **kwargs):
30 | super(Market1501, self).__init__()
31 | self.dataset_dir = osp.join(root, self.dataset_dir)
32 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
33 | self.query_dir = osp.join(self.dataset_dir, 'query')
34 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
35 |
36 | # self._check_before_run()
37 |
38 | self.root = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw"
39 | # self.multi_person_training_info2()
40 | self.train_anno = train_anno
41 |
42 |
43 | train = self.process_dir_train(relabel=True)
44 | query = self._process_dir(self.query_dir, relabel=False)
45 | gallery = self._process_dir(self.gallery_dir, relabel=False)
46 |
47 | # if verbose:
48 | # print("=> Market1501 loaded")
49 | # self.print_dataset_statistics(train, query, gallery)
50 |
51 | self.train = train
52 | self.query = query
53 | self.gallery = gallery
54 |
55 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info_train(self.train)
56 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
57 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
58 | #
59 | # def _check_before_run(self):
60 | # """Check if all files are available before going deeper"""
61 | # if not osp.exists(self.dataset_dir):
62 | # raise RuntimeError("'{}' is not available".format(self.dataset_dir))
63 | # if not osp.exists(self.train_dir):
64 | # raise RuntimeError("'{}' is not available".format(self.train_dir))
65 | # if not osp.exists(self.query_dir):
66 | # raise RuntimeError("'{}' is not available".format(self.query_dir))
67 | # if not osp.exists(self.gallery_dir):
68 | # raise RuntimeError("'{}' is not available".format(self.gallery_dir))
69 |
70 |
71 | def get_imagedata_info_train(self, data):
72 |
73 | pids, cams = [], []
74 | for _, _, pid, camid, pid2, pos_neg in data:
75 | pids += [pid]
76 | pids += [pid2]
77 | cams += [camid]
78 | pids = set(pids)
79 | cams = set(cams)
80 | num_pids = len(pids)
81 | num_cams = len(cams)
82 | num_imgs = len(data)
83 | return num_pids, num_imgs, num_cams
84 |
85 | def _process_dir(self, dir_path, relabel=False):
86 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
87 | pattern = re.compile(r'([-\d]+)_c(\d)')
88 |
89 | pid_container = set()
90 | for img_path in img_paths:
91 | pid, _ = map(int, pattern.search(img_path).groups())
92 | if pid == -1: continue # junk images are just ignored
93 | pid_container.add(pid)
94 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
95 |
96 | dataset = []
97 | for img_path in img_paths:
98 | pid, camid = map(int, pattern.search(img_path).groups())
99 | if pid == -1: continue # junk images are just ignored
100 | assert 0 <= pid <= 1501 # pid == 0 means background
101 | assert 1 <= camid <= 6
102 | camid -= 1 # index starts from 0
103 | if relabel: pid = pid2label[pid]
104 | dataset.append((img_path, pid, camid))
105 | if 'query' in dir_path and len(dataset) >= 300:
106 | break
107 |
108 | return dataset
109 |
110 | def process_dir_train(self, relabel=True):
111 | root = "/raid/home/henrayzhao/person_search/dataset/person_search/prw"
112 | anno_path = osp.join(root, "training_box", "training_box.json")
113 | with open(anno_path, 'r+') as f:
114 | all_anno = json.load(f)
115 |
116 | pid_container = set()
117 | for img_name, pid in all_anno.items():
118 | pid_container.add(pid)
119 | pid2label = {int(pid): label for label, pid in enumerate(pid_container)}
120 |
121 | new_anno_path = osp.join(self.root, "pair_pos_unary" + str(self.train_anno) + ".json")
122 | with open(new_anno_path, 'r+') as f:
123 | all_anno = json.load(f)
124 | data = []
125 |
126 | img_root1 = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw/hard_gallery_train/image"
127 | img_root2 = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw/train_gt/image"
128 |
129 | for one_pair in all_anno:
130 | # print(one_pair)
131 | hard_imgname = one_pair[0]
132 | query_train_imgname1 = one_pair[1]
133 | pid1 = one_pair[2]
134 | query_train_imgname2 = one_pair[3]
135 | pid2 = one_pair[4]
136 | camera_id = one_pair[5]
137 | if relabel:
138 | pid1 = pid2label[pid1]
139 | pid2 = pid2label[pid2]
140 | hard_imgname_path = osp.join(img_root1, hard_imgname)
141 | query_train_path1 = osp.join(img_root2, query_train_imgname1)
142 | query_train_path2 = osp.join(img_root2, query_train_imgname2)
143 | new_anno = [hard_imgname_path, query_train_path1, pid1, query_train_path2, pid2, camera_id]
144 | data.append(new_anno)
145 | return data
146 |
--------------------------------------------------------------------------------
/data/datasets/msmt17.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/1/17 15:00
4 | # @Author : Hao Luo
5 | # @File : msmt17.py
6 |
7 | import glob
8 | import re
9 |
10 | import os.path as osp
11 |
12 | from .bases import BaseImageDataset
13 |
14 |
15 | class MSMT17(BaseImageDataset):
16 | """
17 | MSMT17
18 |
19 | Reference:
20 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
21 |
22 | URL: http://www.pkuvmc.com/publications/msmt17.html
23 |
24 | Dataset statistics:
25 | # identities: 4101
26 | # images: 32621 (train) + 11659 (query) + 82161 (gallery)
27 | # cameras: 15
28 | """
29 | dataset_dir = 'msmt17'
30 |
31 | def __init__(self,root='/home/haoluo/data', verbose=True, **kwargs):
32 | super(MSMT17, self).__init__()
33 | self.dataset_dir = osp.join(root, self.dataset_dir)
34 | self.train_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_train_v2')
35 | self.test_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_test_v2')
36 | self.list_train_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_train.txt')
37 | self.list_val_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_val.txt')
38 | self.list_query_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_query.txt')
39 | self.list_gallery_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_gallery.txt')
40 |
41 | self._check_before_run()
42 | train = self._process_dir(self.train_dir, self.list_train_path)
43 | #val, num_val_pids, num_val_imgs = self._process_dir(self.train_dir, self.list_val_path)
44 | query = self._process_dir(self.test_dir, self.list_query_path)
45 | gallery = self._process_dir(self.test_dir, self.list_gallery_path)
46 | if verbose:
47 | print("=> MSMT17 loaded")
48 | self.print_dataset_statistics(train, query, gallery)
49 |
50 | self.train = train
51 | self.query = query
52 | self.gallery = gallery
53 |
54 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
55 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
56 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
57 |
58 | def _check_before_run(self):
59 | """Check if all files are available before going deeper"""
60 | if not osp.exists(self.dataset_dir):
61 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
62 | if not osp.exists(self.train_dir):
63 | raise RuntimeError("'{}' is not available".format(self.train_dir))
64 | if not osp.exists(self.test_dir):
65 | raise RuntimeError("'{}' is not available".format(self.test_dir))
66 |
67 | def _process_dir(self, dir_path, list_path):
68 | with open(list_path, 'r') as txt:
69 | lines = txt.readlines()
70 | dataset = []
71 | pid_container = set()
72 | for img_idx, img_info in enumerate(lines):
73 | img_path, pid = img_info.split(' ')
74 | pid = int(pid) # no need to relabel
75 | camid = int(img_path.split('_')[2])
76 | img_path = osp.join(dir_path, img_path)
77 | dataset.append((img_path, pid, camid))
78 | pid_container.add(pid)
79 |
80 | # check that pid starts from 0 and increments by 1
81 | for idx, pid in enumerate(pid_container):
82 | assert idx == pid, "See code comment for explanation"
83 | return dataset
--------------------------------------------------------------------------------
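_process_dir expects one "relative_image_path pid" pair per line and derives the camera id from the third underscore-separated token of the path; a parsing sketch with a hypothetical list entry:

line = '0000/0000_000_01_0303morning_0008_0.jpg 0'  # hypothetical list_train.txt line
img_path, pid = line.split(' ')
camid = int(img_path.split('_')[2])  # third token '01' -> camera 1
print(img_path, int(pid), camid)
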
/data/datasets/prw.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import re
3 |
4 | import os.path as osp
5 |
6 | from .bases import BaseImageDataset
7 | import warnings
8 | import json
9 | import cv2
10 | from tqdm import tqdm
11 |
12 | import random
13 | import numpy as np
14 | import os
15 |
16 | class PRW(BaseImageDataset):
17 | """Market1501.
18 |
19 | Reference:
20 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
21 |
22 | URL: ``_
23 |
24 | Dataset statistics:
25 | - identities: 1501 (+1 for background).
26 | - images: 12936 (train) + 3368 (query) + 15913 (gallery).
27 | """
28 | _junk_pids = [0, -1]
29 | dataset_dir = ''
30 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'  # leftover from the Market1501 loader; apparently unused here
31 |
32 | def __init__(self, root='datasets', market1501_500k=False, train_anno=1, **kwargs):
33 |
34 | # root = "/root/person_search/dataset/multi_person"
35 | self.root = osp.join(root, 'prw')
36 |
37 | # self.root = "/root/person_search/dataset/multi_person/prw"
38 | self.train_anno = train_anno
39 | self.pid_container = set()
40 |
41 | self.gallery_id = []
42 |
43 | # train = self.process_dir("train", relabel=True)
44 | train = self.process_dir_train(relabel=True)
45 | query = self.process_dir("query", relabel=False)
46 | gallery = self.process_dir("gallery", relabel=False)
47 |
48 | query = sorted(query)
49 | gallery = sorted(gallery)
50 |
51 | # print(query)
52 | # print(len(query))
53 |
54 |
55 | self.train = train
56 | self.query = query
57 | self.gallery = gallery
58 | #
59 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info_train(self.train)
60 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
61 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info_gallery(self.gallery)
62 |
63 | print("Dataset statistics:")
64 | print(" ----------------------------------------")
65 | print(" subset | # ids | # images | # cameras")
66 | print(" ----------------------------------------")
67 | print(
68 | " train | {:5d} | {:8d} | {:9d}".format(self.num_train_pids, self.num_train_imgs, self.num_train_cams))
69 | print(
70 | " query | {:5d} | {:8d} | {:9d}".format(self.num_query_pids, self.num_query_imgs, self.num_query_cams))
71 | print(" gallery | {:5d} | {:8d} | {:9d}".format(self.num_gallery_pids, self.num_gallery_imgs,
72 | self.num_gallery_cams))
73 | print(" ----------------------------------------")
74 |
75 |
76 | def get_imagedata_info_train(self, data):
77 |
78 | pids, cams = [], []
79 | for _, _, pid, camid, pid2, pos_neg in data:
80 | pids += [pid]
81 | pids += [pid2]
82 | cams += [camid]
83 | pids = set(pids)
84 | cams = set(cams)
85 | num_pids = len(pids)
86 | num_cams = len(cams)
87 | num_imgs = len(data)
88 | return num_pids, num_imgs, num_cams
89 |
90 | def get_imagedata_info_gallery(self, data):
91 | pids, cams = [], []
92 | for _, pid, camid in data:
93 | if isinstance(pid, list):
94 | for one_pid in pid:
95 | pids += [one_pid]
96 | cams += [camid]
97 | pids = set(pids)
98 | cams = set(cams)
99 | num_pids = len(pids)
100 | num_cams = len(cams)
101 | num_imgs = len(data)
102 | return num_pids, num_imgs, num_cams
103 |
104 | def process_dir_train(self, relabel=True):
105 | # # root = "/root/person_search/dataset/person_search/prw"
106 | anno_path = osp.join(self.root, "gt_training_box.json")
107 |
108 | with open(anno_path, 'r+') as f:
109 | all_anno = json.load(f)
110 |
111 | pid_container = set()
112 | for img_name, pid in all_anno.items():
113 | pid_container.add(pid)
114 | pid2label = {int(pid): label for label, pid in enumerate(pid_container)}
115 | # print(pid2label)
116 |
117 |
118 | new_anno_path = osp.join(self.root, "pair_pos_unary" + str(self.train_anno) + ".json")
119 | with open(new_anno_path, 'r+') as f:
120 | all_anno = json.load(f)
121 | data = []
122 |
123 | # img_root1 = "/root/person_search/dataset/multi_person/prw/hard_gallery_train/image"
124 | # img_root2 = "/root/person_search/dataset/multi_person/prw/train_gt/image"
125 |
126 | img_root1 = os.path.join(self.root, 'hard_gallery_train/image')
127 | img_root2 = os.path.join(self.root, 'train_gt/image')
128 |
129 | for one_pair in all_anno:
130 | # print(one_pair)
131 | hard_imgname = one_pair[0]
132 | query_train_imgname1 = one_pair[1]
133 | pid1 = one_pair[2]
134 | query_train_imgname2 = one_pair[3]
135 | pid2 = one_pair[4]
136 | camera_id = one_pair[5]
137 | if relabel:
138 | pid1 = pid2label[pid1]
139 | pid2 = pid2label[pid2]
140 | hard_imgname_path = osp.join(img_root1, hard_imgname)
141 | query_train_path1 = osp.join(img_root2, query_train_imgname1)
142 | query_train_path2 = osp.join(img_root2, query_train_imgname2)
143 | new_anno = [hard_imgname_path, query_train_path1, pid1, query_train_path2, pid2, camera_id]
144 | data.append(new_anno)
145 | return data
146 |
147 | def process_dir(self, dataset, relabel=False):
148 |
149 | if dataset == "query":
150 | anno_path = osp.join(self.root, "query", "query.json")
151 | img_root = osp.join(self.root, "query", "query_image")
152 | camid_path = osp.join(self.root, "query", "camera_id.json")
153 | elif dataset == "gallery":
154 | gallery_name = "hard_gallery_test"
155 | anno_path = osp.join(self.root, gallery_name, "gallery.json")
156 | img_root = osp.join(self.root, gallery_name, "image")
157 | camid_path = osp.join(self.root, gallery_name, "camera_id.json")
158 |
159 | with open(anno_path, 'r+') as f:
160 | all_anno = json.load(f)
161 |
162 |
163 |
164 | if dataset == "query" or dataset == "gallery":
165 | with open(camid_path, 'r+') as f:
166 | camid_dic = json.load(f)
167 |
168 | # valid_pid_path = "/root/person_search/dataset/multi_person/prw/valid_q_pid_3.json"
169 | valid_pid_path = os.path.join(self.root, 'valid_q_pid.json')
170 |
171 | with open(valid_pid_path, 'r+') as f:
172 | valid_pid = json.load(f)
173 |
174 |
175 | data = []
176 | pid_set = set()
177 | for img_name, pid in all_anno.items():
178 | image_path = osp.join(img_root, img_name)
179 | if dataset == "query":
180 | print({img_name: pid})
181 | if dataset == "query" or dataset == "gallery":
182 | camid = camid_dic[img_name]
183 | if isinstance(pid, str):
184 | pid = int(pid)
185 | if dataset == "query":
186 | if pid in valid_pid:
187 | if pid in pid_set:
188 | continue
189 | else:
190 | pid_set.add(pid)
191 | else:
192 | continue
193 |
194 | data.append((image_path, pid, int(camid)))
195 |
196 | return data
197 |
198 |
--------------------------------------------------------------------------------
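Taken together, the paths opened above imply roughly this layout under root/prw (directory and file names verbatim from the code; <N> is the train_anno argument):

prw/
├── gt_training_box.json
├── pair_pos_unary<N>.json
├── valid_q_pid.json
├── hard_gallery_train/image/
├── train_gt/image/
├── query/             (query.json, camera_id.json, query_image/)
└── hard_gallery_test/ (gallery.json, camera_id.json, image/)
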
/data/datasets/veri.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import re
3 |
4 | import os.path as osp
5 |
6 | from .bases import BaseImageDataset
7 |
8 |
9 | class VeRi(BaseImageDataset):
10 | """
11 | VeRi-776
12 | Reference:
13 | Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016.
14 |
15 | URL:https://vehiclereid.github.io/VeRi/
16 |
17 | Dataset statistics:
18 | # identities: 776
19 | # images: 37778 (train) + 1678 (query) + 11579 (gallery)
20 | # cameras: 20
21 | """
22 |
23 | dataset_dir = 'veri'
24 |
25 | def __init__(self, root='../', verbose=True, **kwargs):
26 | super(VeRi, self).__init__()
27 | self.dataset_dir = osp.join(root, self.dataset_dir)
28 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
29 | self.query_dir = osp.join(self.dataset_dir, 'image_query')
30 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
31 |
32 | self._check_before_run()
33 |
34 | train = self._process_dir(self.train_dir, relabel=True)
35 | query = self._process_dir(self.query_dir, relabel=False)
36 | gallery = self._process_dir(self.gallery_dir, relabel=False)
37 |
38 | if verbose:
39 | print("=> VeRi-776 loaded")
40 | self.print_dataset_statistics(train, query, gallery)
41 |
42 | self.train = train
43 | self.query = query
44 | self.gallery = gallery
45 |
46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
49 |
50 | def _check_before_run(self):
51 | """Check if all files are available before going deeper"""
52 | if not osp.exists(self.dataset_dir):
53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
54 | if not osp.exists(self.train_dir):
55 | raise RuntimeError("'{}' is not available".format(self.train_dir))
56 | if not osp.exists(self.query_dir):
57 | raise RuntimeError("'{}' is not available".format(self.query_dir))
58 | if not osp.exists(self.gallery_dir):
59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
60 |
61 | def _process_dir(self, dir_path, relabel=False):
62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
63 | pattern = re.compile(r'([-\d]+)_c(\d+)')
64 |
65 | pid_container = set()
66 | for img_path in img_paths:
67 | pid, _ = map(int, pattern.search(img_path).groups())
68 | if pid == -1: continue # junk images are just ignored
69 | pid_container.add(pid)
70 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
71 |
72 | dataset = []
73 | for img_path in img_paths:
74 | pid, camid = map(int, pattern.search(img_path).groups())
75 | if pid == -1: continue # junk images are just ignored
76 | assert 0 <= pid <= 776 # pid == 0 means background
77 | assert 1 <= camid <= 20
78 | camid -= 1 # index starts from 0
79 | if relabel: pid = pid2label[pid]
80 | dataset.append((img_path, pid, camid))
81 |
82 | return dataset
83 |
84 |
--------------------------------------------------------------------------------
/data/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid # new add by gu
8 |
--------------------------------------------------------------------------------
/data/samplers/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/samplers/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/samplers/__pycache__/triplet_sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/samplers/__pycache__/triplet_sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/data/samplers/triplet_sampler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import copy
8 | import random
9 | import torch
10 | from collections import defaultdict
11 |
12 | import numpy as np
13 | from torch.utils.data.sampler import Sampler
14 |
15 |
16 | class RandomIdentitySampler(Sampler):
17 | """
18 | Randomly sample N identities, then for each identity,
19 | randomly sample K instances, therefore batch size is N*K.
20 | Args:
21 | - data_source (list): list of (img_path, pid, camid).
22 | - num_instances (int): number of instances per identity in a batch.
23 | - batch_size (int): number of examples in a batch.
24 | """
25 |
26 | def __init__(self, data_source, batch_size, num_instances):
27 | self.data_source = data_source
28 | self.batch_size = batch_size
29 | self.num_instances = num_instances
30 | self.num_pids_per_batch = self.batch_size // self.num_instances
31 | self.index_dic = defaultdict(list)
32 |
33 | for index, (_, _, pid, _, _, _) in enumerate(self.data_source):
34 | self.index_dic[pid].append(index)
35 | self.pids = list(self.index_dic.keys())
36 |
37 | # estimate number of examples in an epoch
38 | self.length = 0
39 | for pid in self.pids:
40 | idxs = self.index_dic[pid]
41 | num = len(idxs)
42 | if num < self.num_instances:
43 | num = self.num_instances
44 | self.length += num - num % self.num_instances
45 |
46 | def __iter__(self):
47 | batch_idxs_dict = defaultdict(list)
48 |
49 | for pid in self.pids:
50 | idxs = copy.deepcopy(self.index_dic[pid])
51 | if len(idxs) < self.num_instances:
52 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
53 | random.shuffle(idxs)
54 | batch_idxs = []
55 | for idx in idxs:
56 | batch_idxs.append(idx)
57 | if len(batch_idxs) == self.num_instances:
58 | batch_idxs_dict[pid].append(batch_idxs)
59 | batch_idxs = []
60 |
61 | avai_pids = copy.deepcopy(self.pids)
62 | final_idxs = []
63 |
64 | while len(avai_pids) >= self.num_pids_per_batch:
65 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
66 | for pid in selected_pids:
67 | batch_idxs = batch_idxs_dict[pid].pop(0)
68 | final_idxs.extend(batch_idxs)
69 | if len(batch_idxs_dict[pid]) == 0:
70 | avai_pids.remove(pid)
71 |
72 | self.length = len(final_idxs)
73 | return iter(final_idxs)
74 |
75 | def __len__(self):
76 | return self.length
77 |
78 |
79 | # New add by gu
80 | class RandomIdentitySampler_alignedreid(Sampler):
81 | """
82 | Randomly sample N identities, then for each identity,
83 | randomly sample K instances, therefore batch size is N*K.
84 |
85 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
86 |
87 | Args:
88 | data_source (Dataset): dataset to sample from.
89 | num_instances (int): number of instances per identity.
90 | """
91 | def __init__(self, data_source, num_instances):
92 | self.data_source = data_source
93 | self.num_instances = num_instances
94 | self.index_dic = defaultdict(list)
95 | for index, (_, pid, _) in enumerate(data_source):
96 | self.index_dic[pid].append(index)
97 | self.pids = list(self.index_dic.keys())
98 | self.num_identities = len(self.pids)
99 |
100 | def __iter__(self):
101 | indices = torch.randperm(self.num_identities)
102 | ret = []
103 | for i in indices:
104 | pid = self.pids[i]
105 | t = self.index_dic[pid]
106 | replace = len(t) < self.num_instances
107 | t = np.random.choice(t, size=self.num_instances, replace=replace)
108 | ret.extend(t)
109 | return iter(ret)
110 |
111 | def __len__(self):
112 | return self.num_identities * self.num_instances
113 |
--------------------------------------------------------------------------------
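A minimal sketch of driving a DataLoader with RandomIdentitySampler so that each batch holds num_instances images for each of batch_size // num_instances identities. The six-field records mirror this repo's pair annotations (the sampler only reads the pid in the third slot); the records and the commented dataset line are illustrative:

from torch.utils.data import DataLoader
from data.samplers import RandomIdentitySampler

# hypothetical records shaped like the pair annotations; pid is the third field
records = [('g.jpg', 'q.jpg', pid, 0, pid, 1) for pid in range(8) for _ in range(4)]
sampler = RandomIdentitySampler(records, batch_size=16, num_instances=4)
# each batch then holds 4 instances for each of 16 // 4 = 4 identities:
# loader = DataLoader(dataset, batch_size=16, sampler=sampler)
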
/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import build_transforms
8 |
--------------------------------------------------------------------------------
/data/transforms/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/transforms/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/transforms/__pycache__/build.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/transforms/__pycache__/build.cpython-36.pyc
--------------------------------------------------------------------------------
/data/transforms/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/transforms/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/data/transforms/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import torchvision.transforms as T
8 |
9 | from .transforms import RandomErasing
10 |
11 |
12 | def build_transforms(cfg, is_train=True):
13 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
14 | if is_train:
15 | transform = T.Compose([
16 | T.Resize(cfg.INPUT.SIZE_TRAIN),
17 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
18 | T.Pad(cfg.INPUT.PADDING),
19 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
20 | T.ToTensor(),
21 | normalize_transform,
22 | RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
23 | ])
24 | else:
25 | transform = T.Compose([
26 | T.Resize(cfg.INPUT.SIZE_TEST),
27 | T.ToTensor(),
28 | normalize_transform
29 | ])
30 |
31 | return transform
32 |
--------------------------------------------------------------------------------
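build_transforms only reads a handful of cfg.INPUT fields, so it can be exercised without the full config machinery; a sketch with a bare namespace standing in for the config node (values illustrative, using the usual ImageNet mean/std):

from types import SimpleNamespace
from data.transforms import build_transforms

cfg = SimpleNamespace(INPUT=SimpleNamespace(
    PIXEL_MEAN=[0.485, 0.456, 0.406], PIXEL_STD=[0.229, 0.224, 0.225],
    SIZE_TRAIN=[256, 128], SIZE_TEST=[256, 128],
    PROB=0.5, PADDING=10, RE_PROB=0.5))
train_tf = build_transforms(cfg, is_train=True)   # resize, flip, pad, crop, normalize, erase
test_tf = build_transforms(cfg, is_train=False)   # resize + normalize only
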
/data/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import math
8 | import random
9 |
10 |
11 | class RandomErasing(object):
12 | """ Randomly selects a rectangle region in an image and erases its pixels.
13 | 'Random Erasing Data Augmentation' by Zhong et al.
14 | See https://arxiv.org/pdf/1708.04896.pdf
15 | Args:
16 | probability: The probability that the Random Erasing operation will be performed.
17 | sl: Minimum proportion of erased area against input image.
18 | sh: Maximum proportion of erased area against input image.
19 | r1: Minimum aspect ratio of erased area.
20 | mean: Erasing value.
21 | """
22 |
23 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
24 | self.probability = probability
25 | self.mean = mean
26 | self.sl = sl
27 | self.sh = sh
28 | self.r1 = r1
29 |
30 | def __call__(self, img):
31 |
32 | if random.uniform(0, 1) >= self.probability:
33 | return img
34 |
35 | for attempt in range(100):
36 | area = img.size()[1] * img.size()[2]
37 |
38 | target_area = random.uniform(self.sl, self.sh) * area
39 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
40 |
41 | h = int(round(math.sqrt(target_area * aspect_ratio)))
42 | w = int(round(math.sqrt(target_area / aspect_ratio)))
43 |
44 | if w < img.size()[2] and h < img.size()[1]:
45 | x1 = random.randint(0, img.size()[1] - h)
46 | y1 = random.randint(0, img.size()[2] - w)
47 | if img.size()[0] == 3:
48 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
49 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
50 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
51 | else:
52 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
53 | return img
54 |
55 | return img
56 |
--------------------------------------------------------------------------------
/engine/__pycache__/inference.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/engine/__pycache__/inference.cpython-36.pyc
--------------------------------------------------------------------------------
/engine/__pycache__/trainer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/engine/__pycache__/trainer.cpython-36.pyc
--------------------------------------------------------------------------------
/engine/inference.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import logging
7 |
8 | import torch
9 | import torch.nn as nn
10 | from ignite.engine import Engine
11 |
12 | from utils.reid_metric import R1_mAP, R1_mAP_reranking
13 |
14 |
15 | def create_supervised_evaluator(model, metrics,
16 | device=None):
17 | """
18 | Factory function for creating an evaluator for supervised models
19 |
20 | Args:
21 | model (`torch.nn.Module`): the model to train
22 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
23 | device (str, optional): device type specification (default: None).
24 | Applies to both model and batches.
25 | Returns:
26 | Engine: an evaluator engine with supervised inference function
27 | """
28 | if device:
29 | if torch.cuda.device_count() > 1:
30 | model = nn.DataParallel(model)
31 | model.to(device)
32 |
33 | def _inference(engine, batch):
34 | model.eval()
35 | with torch.no_grad():
36 | data, pids, camids = batch
37 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
38 | feat = model(data)
39 | return feat, pids, camids
40 |
41 | engine = Engine(_inference)
42 |
43 | for name, metric in metrics.items():
44 | metric.attach(engine, name)
45 |
46 | return engine
47 |
48 |
49 | def inference(
50 | cfg,
51 | model,
52 | val_loader,
53 | num_query
54 | ):
55 | device = cfg.MODEL.DEVICE
56 |
57 | logger = logging.getLogger("reid_baseline.inference")
58 | logger.info("Enter inferencing")
59 | if cfg.TEST.RE_RANKING == 'no':
60 | print("Create evaluator")
61 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)},
62 | device=device)
63 | elif cfg.TEST.RE_RANKING == 'yes':
64 | print("Create evaluator for reranking")
65 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)},
66 | device=device)
67 | else:
68 | print("Unsupported re_ranking config. Only support for no or yes, but got {}.".format(cfg.TEST.RE_RANKING))
69 |
70 | evaluator.run(val_loader)
71 | cmc, mAP, _ = evaluator.state.metrics['r1_mAP']
72 | logger.info('Validation Results')
73 | logger.info("mAP: {:.1%}".format(mAP))
74 | for r in [1, 5, 10, 100]:
75 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
76 |
--------------------------------------------------------------------------------
/engine/trainer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 |
9 | import torch
10 | import torch.nn as nn
11 | from ignite.engine import Engine, Events
12 | from ignite.handlers import ModelCheckpoint, Timer
13 | from ignite.metrics import RunningAverage
14 | from data import make_data_loader_val
15 |
16 | from utils.reid_metric import R1_mAP, R1_mAP_pair
17 | from tqdm import tqdm
18 | import time
19 | import json
20 | import numpy as np
21 | from layers.triplet_loss import TripletLoss
22 |
23 | from data.datasets import init_dataset
24 |
25 | import copy
26 |
27 | import torch.nn.functional as F
28 |
29 | from data import make_data_loader, make_data_loader_train
30 | import random
31 | import os
32 |
33 | global ITER
34 | ITER = 0
35 |
36 |
37 | def euclidean_dist(gallery_feature1, gallery_feature2):
38 |
39 | xx = torch.pow(gallery_feature1, 2).sum(1, keepdim=True)
40 | yy = torch.pow(gallery_feature2, 2).sum(1, keepdim=True)
41 | dist1 = xx + yy
42 | dist2 = gallery_feature1 * gallery_feature2
43 | dist2 = dist2.sum(1, keepdim=True)
44 | dist = dist1 - 2 * dist2
45 | dist = dist.clamp(min=1e-12).sqrt()
46 | return dist
47 |
48 | def loss1(gallery_feature1, gallery_feature2, query_feature, margin=0.3):
49 |
50 | ranking_loss = nn.MarginRankingLoss(margin=margin)
51 | y = gallery_feature1.new((gallery_feature1.shape[0], 1)).fill_(1)
52 | dist_neg = euclidean_dist(gallery_feature1, gallery_feature2)
53 | dist_pos = euclidean_dist(gallery_feature1, query_feature)
54 | loss = ranking_loss(dist_neg, dist_pos, y)
55 |
56 | return loss
57 |
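# NOTE: euclidean_dist above computes the row-wise distance for paired rows (not an
# all-pairs matrix) via ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b, clamped before the
# sqrt for numerical stability. loss1 then uses it so that each gallery feature ends
# up at least `margin` closer to the query feature than to the second (negative)
# gallery feature. A quick illustrative check:
#   a = torch.tensor([[1., 2.], [0., 0.]]); b = torch.tensor([[1., 0.], [3., 4.]])
#   euclidean_dist(a, b).squeeze(1)  # tensor([2., 5.]) == (a - b).norm(dim=1)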
58 |
59 | def create_supervised_trainer(model, optimizer, loss_fn,
60 | device=None, gamma=1.0, margin=0.3, beta=1.0):
61 | """
62 | Factory function for creating a trainer for supervised models
63 |
64 | Args:
65 | model (`torch.nn.Module`): the model to train
66 | optimizer (`torch.optim.Optimizer`): the optimizer to use
67 | loss_fn (torch.nn loss function): the loss function to use
68 | device (str, optional): device type specification (default: None).
69 | Applies to both model and batches.
70 | gamma, margin, beta (float): gamma and margin configure the auxiliary ranking loss (loss1); beta weights the second loss_fn term.
71 | Returns:
72 | Engine: a trainer engine with supervised update function
73 | """
74 | if device:
75 | if torch.cuda.device_count() > 1:
76 | model = nn.DataParallel(model)
77 | model.to(device)
78 |
79 | def _update(engine, batch):
80 | model.train()
81 | optimizer.zero_grad()
82 | # guiding, img, target, target2, pos_neg = batch
83 |
84 | img, guiding1, guiding2, target1, target2 = batch
85 |
86 | img = img.to(device) if torch.cuda.device_count() >= 1 else img
87 |
88 | guiding1 = guiding1.to(device) if torch.cuda.device_count() >= 1 else guiding1
89 | target1 = target1.to(device) if torch.cuda.device_count() >= 1 else target1
90 |
91 | guiding2 = guiding2.to(device) if torch.cuda.device_count() >= 1 else guiding2
92 | target2 = target2.to(device) if torch.cuda.device_count() >= 1 else target2
93 |
94 | # score, feat, score_guiding, feature_guiding, gallery_attention, score_pos_neg = model(guiding1, img, x_g2=guiding2)
95 | score, feat, score1, feat1, feat_query, score2, feat2 = model(guiding1, img, x_g2=guiding2)
96 |
97 | loss = loss_fn(score, feat, target1) + gamma * loss1(feat, feat1.detach(), feat_query, margin=margin) + beta * loss_fn(score2, feat, target1)
98 |
99 | loss.backward()
100 | optimizer.step()
101 |
102 | acc = (score.max(1)[1] == target1).float().mean()
103 | return loss.item(), acc.item()
104 |
105 | return Engine(_update)
106 |
107 |
108 | def create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, center_loss_weight,
109 | device=None):
110 | """
111 | Factory function for creating a trainer for supervised models
112 |
113 | Args:
114 | model (`torch.nn.Module`): the model to train
115 | optimizer (`torch.optim.Optimizer`): the optimizer to use
116 | loss_fn (torch.nn loss function): the loss function to use
117 | device (str, optional): device type specification (default: None).
118 | Applies to both model and batches.
119 |
120 | Returns:
121 | Engine: a trainer engine with supervised update function
122 | """
123 | if device:
124 | if torch.cuda.device_count() > 1:
125 | model = nn.DataParallel(model)
126 | model.to(device)
127 |
128 | def _update(engine, batch):
129 | model.train()
130 | optimizer.zero_grad()
131 | optimizer_center.zero_grad()
132 | img, target = batch
133 | img = img.to(device) if torch.cuda.device_count() >= 1 else img
134 | target = target.to(device) if torch.cuda.device_count() >= 1 else target
135 | score, feat = model(img)
136 | loss = loss_fn(score, feat, target)
137 | # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
138 | loss.backward()
139 | optimizer.step()
140 | for param in center_criterion.parameters():
141 | param.grad.data *= (1. / center_loss_weight)
142 | optimizer_center.step()
143 |
144 | # compute acc
145 | acc = (score.max(1)[1] == target).float().mean()
146 | return loss.item(), acc.item()
147 |
148 | return Engine(_update)
149 |
150 |
151 | def create_supervised_evaluator(model, metrics,
152 | device=None):
153 | """
154 | Factory function for creating an evaluator for supervised models
155 |
156 | Args:
157 | model (`torch.nn.Module`): the model to train
158 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
159 | device (str, optional): device type specification (default: None).
160 | Applies to both model and batches.
161 | Returns:
162 | Engine: an evaluator engine with supervised inference function
163 | """
164 | if device:
165 | if torch.cuda.device_count() > 1:
166 | model = nn.DataParallel(model)
167 | model.to(device)
168 |
169 | def _inference(engine, batch):
170 | model.eval()
171 | with torch.no_grad():
172 | guiding, data, pids, camids, is_first = batch
173 |
174 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
175 | guiding = guiding.to(device) if torch.cuda.device_count() >= 1 else guiding
176 | feat = model(guiding, data, is_first=is_first)
177 |
178 | return feat, pids, camids
179 |
180 | engine = Engine(_inference)
181 |
182 | for name, metric in metrics.items():
183 | metric.attach(engine, name)
184 |
185 | return engine
186 |
187 |
188 | def do_train(
189 | cfg,
190 | model,
191 | train_loader,
192 | val_loader,
193 | optimizer,
194 | scheduler,
195 | loss_fn,
196 | num_query,
197 | start_epoch
198 | ):
199 | log_period = cfg.SOLVER.LOG_PERIOD
200 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
201 | eval_period = cfg.SOLVER.EVAL_PERIOD
202 | output_dir = cfg.OUTPUT_DIR
203 | device = cfg.MODEL.DEVICE
204 | epochs = cfg.SOLVER.MAX_EPOCHS
205 |
206 | logger = logging.getLogger("reid_baseline.train")
207 | logger.info("Start training")
208 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device, gamma=cfg.MODEL.GAMMA, margin=cfg.SOLVER.MARGIN, beta=cfg.MODEL.BETA)
209 | if cfg.TEST.PAIR == "no":
210 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(1, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
211 | elif cfg.TEST.PAIR == "yes":
212 | evaluator = create_supervised_evaluator(model, metrics={
213 | 'r1_mAP': R1_mAP_pair(1, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
214 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
215 | # checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=10, require_empty=False)
216 | timer = Timer(average=True)
217 |
218 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
219 | 'optimizer': optimizer})
220 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
221 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
222 |
223 | # average metric to attach on trainer
224 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
225 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
226 |
227 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
228 |
229 | @trainer.on(Events.STARTED)
230 | def start_training(engine):
231 | engine.state.epoch = start_epoch
232 |
233 | @trainer.on(Events.EPOCH_STARTED)
234 | def adjust_learning_rate(engine):
235 | scheduler.step()
236 |
237 | @trainer.on(Events.ITERATION_COMPLETED)
238 | def log_training_loss(engine):
239 | global ITER
240 | ITER += 1
241 | if ITER % log_period == 0:
242 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
243 | .format(engine.state.epoch, ITER, len(train_loader),
244 | engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
245 | scheduler.get_lr()[0]))
246 | if len(train_loader) == ITER:
247 | ITER = 0
248 |
249 | # adding handlers using `trainer.on` decorator API
250 | @trainer.on(Events.EPOCH_COMPLETED)
251 | def print_times(engine):
252 | # multi_person_training_info2()
253 | train_loader, val_loader, num_query, num_classes = make_data_loader_train(cfg)
254 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
255 | .format(engine.state.epoch, timer.value() * timer.step_count,
256 | train_loader.batch_size / timer.value()))
257 | logger.info('-' * 10)
258 | timer.reset()
259 |
260 | @trainer.on(Events.EPOCH_COMPLETED)
261 | def log_validation_results(engine):
262 | # if engine.state.epoch % eval_period == 0:
263 | if engine.state.epoch >= eval_period:
264 | all_cmc = []
265 | all_AP = []
266 | num_valid_q = 0
267 | q_pids = []
268 | for query_index in tqdm(range(num_query)):
269 |
270 | val_loader = make_data_loader_val(cfg, query_index, dataset)
271 | evaluator.run(val_loader)
272 | cmc, AP, q_pid = evaluator.state.metrics['r1_mAP']
273 |
274 | if AP >= 0:
275 | if cmc.shape[0] < 50:
276 | continue
277 | num_valid_q += 1
278 |
279 | all_cmc.append(cmc)
280 | all_AP.append(AP)
281 | q_pids.append(int(q_pid))
282 | else:
283 | continue
284 |
285 | all_cmc = np.asarray(all_cmc).astype(np.float32)
286 | cmc = all_cmc.sum(0) / num_valid_q
287 | mAP = np.mean(all_AP)
288 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
289 | logger.info("mAP: {:.1%}".format(mAP))
290 | for r in [1, 5, 10]:
291 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
292 |
293 |
294 | trainer.run(train_loader, max_epochs=epochs)
295 |
296 |
297 | def do_train_with_center(
298 | cfg,
299 | model,
300 | center_criterion,
301 | train_loader,
302 | val_loader,
303 | optimizer,
304 | optimizer_center,
305 | scheduler,
306 | loss_fn,
307 | num_query,
308 | start_epoch
309 | ):
310 | log_period = cfg.SOLVER.LOG_PERIOD
311 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
312 | eval_period = cfg.SOLVER.EVAL_PERIOD
313 | output_dir = cfg.OUTPUT_DIR
314 | device = cfg.MODEL.DEVICE
315 | epochs = cfg.SOLVER.MAX_EPOCHS
316 |
317 | logger = logging.getLogger("reid_baseline.train")
318 | logger.info("Start training")
319 | trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
320 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
321 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
322 | timer = Timer(average=True)
323 |
324 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
325 | 'optimizer': optimizer,
326 | 'center_param': center_criterion,
327 | 'optimizer_center': optimizer_center})
328 |
329 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
330 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
331 |
332 | # average metric to attach on trainer
333 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
334 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
335 |
336 | @trainer.on(Events.STARTED)
337 | def start_training(engine):
338 | engine.state.epoch = start_epoch
339 |
340 | @trainer.on(Events.EPOCH_STARTED)
341 | def adjust_learning_rate(engine):
342 | scheduler.step()
343 |
344 | @trainer.on(Events.ITERATION_COMPLETED)
345 | def log_training_loss(engine):
346 | global ITER
347 | ITER += 1
348 |
349 | if ITER % log_period == 0:
350 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
351 | .format(engine.state.epoch, ITER, len(train_loader),
352 | engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
353 | scheduler.get_lr()[0]))
354 | if len(train_loader) == ITER:
355 | ITER = 0
356 |
357 | # adding handlers using `trainer.on` decorator API
358 | @trainer.on(Events.EPOCH_COMPLETED)
359 | def print_times(engine):
360 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
361 | .format(engine.state.epoch, timer.value() * timer.step_count,
362 | train_loader.batch_size / timer.value()))
363 | logger.info('-' * 10)
364 | timer.reset()
365 |
366 | @trainer.on(Events.EPOCH_COMPLETED)
367 | def log_validation_results(engine):
368 | if engine.state.epoch % eval_period == 0:
369 | evaluator.run(val_loader)
370 | cmc, mAP = evaluator.state.metrics['r1_mAP']
371 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
372 | logger.info("mAP: {:.1%}".format(mAP))
373 | for r in [1, 5, 10]:
374 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
375 |
376 | trainer.run(train_loader, max_epochs=epochs)
377 |
--------------------------------------------------------------------------------
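The gradient rescaling in create_supervised_trainer_with_center deserves a gloss: the total loss contains center_loss_weight * center_loss, so the gradient reaching the centers is scaled by that weight, and dividing by it before optimizer_center.step() restores the nominal center learning rate. A minimal, self-contained sketch of the same trick (the toy shapes and the 0.0005 weight are illustrative assumptions, not values from this repo's configs):

import torch

w = 0.0005                                       # stand-in for cfg.SOLVER.CENTER_LOSS_WEIGHT
centers = torch.randn(5, 8, requires_grad=True)  # toy class centers
feat = torch.randn(4, 8)
labels = torch.randint(0, 5, (4,))

# weighted center loss, as it appears inside the total training loss
loss = w * ((feat - centers[labels]) ** 2).sum() / 4
loss.backward()

# same rescaling as in the trainer: undo the weight so the centers
# move at the learning rate configured for optimizer_center
centers.grad.data *= (1. / w)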
/image/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/image/examples.png
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch.nn.functional as F
8 |
9 | from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth
10 | from .center_loss import CenterLoss
11 |
12 |
13 | def make_loss(cfg, num_classes): # modified by gu
14 | sampler = cfg.DATALOADER.SAMPLER
15 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
16 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
17 | else:
18 |         print('expected METRIC_LOSS_TYPE should be triplet, '
19 |               'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
20 |
21 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
22 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo
23 | print("label smooth on, numclasses:", num_classes)
24 |
25 | if sampler == 'softmax':
26 | def loss_func(score, feat, target):
27 | return F.cross_entropy(score, target)
28 | elif cfg.DATALOADER.SAMPLER == 'triplet':
29 | def loss_func(score, feat, target):
30 | return triplet(feat, target)[0]
31 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
32 | def loss_func(score, feat, target):
33 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
34 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
35 | return xent(score, target) + triplet(feat, target)[0]
36 | else:
37 | return F.cross_entropy(score, target) + triplet(feat, target)[0]
38 | else:
39 |                 print('expected METRIC_LOSS_TYPE should be triplet, '
40 |                       'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
41 | else:
42 | print('expected sampler should be softmax, triplet or softmax_triplet, '
43 | 'but got {}'.format(cfg.DATALOADER.SAMPLER))
44 | return loss_func
45 |
46 |
47 | def make_loss_with_center(cfg, num_classes): # modified by gu
48 | if cfg.MODEL.NAME == 'resnet18' or cfg.MODEL.NAME == 'resnet34':
49 | feat_dim = 512
50 | else:
51 | feat_dim = 2048
52 |
53 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
54 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
55 |
56 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
57 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
58 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
59 |
60 | else:
61 |         print('expected METRIC_LOSS_TYPE with center should be center or triplet_center, '
62 |               'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
63 |
64 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
65 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo
66 | print("label smooth on, numclasses:", num_classes)
67 |
68 | def loss_func(score, feat, target):
69 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
70 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
71 | return xent(score, target) + \
72 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
73 | else:
74 | return F.cross_entropy(score, target) + \
75 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
76 |
77 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
78 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
79 | return xent(score, target) + \
80 | triplet(feat, target)[0] + \
81 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
82 | else:
83 | return F.cross_entropy(score, target) + \
84 | triplet(feat, target)[0] + \
85 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target)
86 |
87 | else:
88 |             print('expected METRIC_LOSS_TYPE with center should be center or triplet_center, '
89 |                   'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
90 | return loss_func, center_criterion
--------------------------------------------------------------------------------
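For orientation, a sketch of how the loss function returned by make_loss is consumed; the 751-class count and tensor shapes are assumptions for illustration, not verified defaults:

import torch

# Assuming cfg.DATALOADER.SAMPLER == 'softmax_triplet' with label smoothing off,
# loss_func(score, feat, target) evaluates
#     F.cross_entropy(score, target) + triplet(feat, target)[0]
# loss_func = make_loss(cfg, num_classes=751)
score = torch.randn(16, 751)               # classifier logits
feat = torch.randn(16, 2048)               # pooled features for the triplet term
target = torch.randint(0, 751, (16,))
# loss = loss_func(score, feat, target)    # scalar tensor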
/layers/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/layers/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/center_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/layers/__pycache__/center_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/triplet_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/layers/__pycache__/triplet_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/center_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class CenterLoss(nn.Module):
8 | """Center loss.
9 |
10 | Reference:
11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
12 |
13 | Args:
14 | num_classes (int): number of classes.
15 | feat_dim (int): feature dimension.
16 | """
17 |
18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
19 | super(CenterLoss, self).__init__()
20 | self.num_classes = num_classes
21 | self.feat_dim = feat_dim
22 | self.use_gpu = use_gpu
23 |
24 | if self.use_gpu:
25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
26 | else:
27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
28 |
29 | def forward(self, x, labels):
30 | """
31 | Args:
32 | x: feature matrix with shape (batch_size, feat_dim).
33 |             labels: ground truth labels with shape (batch_size,).
34 | """
35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
36 |
37 | batch_size = x.size(0)
38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
40 |         distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)  # distmat += -2 * x @ centers.t()
41 |
42 | classes = torch.arange(self.num_classes).long()
43 | if self.use_gpu: classes = classes.cuda()
44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
45 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
46 |
47 | dist = distmat * mask.float()
48 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
49 | #dist = []
50 | #for i in range(batch_size):
51 | # value = distmat[i][mask[i]]
52 | # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
53 | # dist.append(value)
54 | #dist = torch.cat(dist)
55 | #loss = dist.mean()
56 | return loss
57 |
58 |
59 | if __name__ == '__main__':
60 | use_gpu = False
61 | center_loss = CenterLoss(use_gpu=use_gpu)
62 | features = torch.rand(16, 2048)
63 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long()
64 | if use_gpu:
65 | features = torch.rand(16, 2048).cuda()
66 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda()
67 |
68 | loss = center_loss(features, targets)
69 | print(loss)
70 |
--------------------------------------------------------------------------------
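The distmat construction in CenterLoss.forward is the usual squared-distance expansion ||x_i - c_j||^2 = ||x_i||^2 + ||c_j||^2 - 2 * x_i . c_j. A quick sanity check against torch.cdist (cdist is from more recent PyTorch than this repo targets; shapes are illustrative):

import torch

x = torch.randn(4, 8)
centers = torch.randn(5, 8)

# expansion used in CenterLoss.forward
d1 = x.pow(2).sum(1, keepdim=True) + centers.pow(2).sum(1) - 2 * x @ centers.t()
# reference computation
d2 = torch.cdist(x, centers).pow(2)
assert torch.allclose(d1, d2, atol=1e-5)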
/layers/triplet_loss.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import torch
7 | from torch import nn
8 |
9 |
10 | def normalize(x, axis=-1):
11 | """Normalizing to unit length along the specified dimension.
12 | Args:
13 | x: pytorch Variable
14 | Returns:
15 | x: pytorch Variable, same shape as input
16 | """
17 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
18 | return x
19 |
20 |
21 | def euclidean_dist(x, y):
22 | """
23 | Args:
24 | x: pytorch Variable, with shape [m, d]
25 | y: pytorch Variable, with shape [n, d]
26 | Returns:
27 | dist: pytorch Variable, with shape [m, n]
28 | """
29 | m, n = x.size(0), y.size(0)
30 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
31 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
32 | dist = xx + yy
33 |     dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist += -2 * x @ y.t()
34 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
35 | return dist
36 |
37 |
38 | def hard_example_mining(dist_mat, labels, return_inds=False):
39 | """For each anchor, find the hardest positive and negative sample.
40 | Args:
41 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
42 | labels: pytorch LongTensor, with shape [N]
43 |         return_inds: whether to also return the indices of the mined samples; skipping them saves time.
44 | Returns:
45 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
46 | dist_an: pytorch Variable, distance(anchor, negative); shape [N]
47 | p_inds: pytorch LongTensor, with shape [N];
48 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
49 | n_inds: pytorch LongTensor, with shape [N];
50 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
51 |     NOTE: Only handles the case in which all labels have the same number of samples,
52 |     so that all anchors can be processed in parallel.
53 | """
54 |
55 | assert len(dist_mat.size()) == 2
56 | assert dist_mat.size(0) == dist_mat.size(1)
57 | N = dist_mat.size(0)
58 |
59 | # shape [N, N]
60 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
61 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
62 |
63 | # `dist_ap` means distance(anchor, positive)
64 | # both `dist_ap` and `relative_p_inds` with shape [N, 1]
65 |
66 | dist_ap, relative_p_inds = torch.max(
67 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
68 |
69 |
70 | # `dist_an` means distance(anchor, negative)
71 | # both `dist_an` and `relative_n_inds` with shape [N, 1]
72 | dist_an, relative_n_inds = torch.min(
73 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
74 | # shape [N]
75 | dist_ap = dist_ap.squeeze(1)
76 | dist_an = dist_an.squeeze(1)
77 |
78 | if return_inds:
79 | # shape [N, N]
80 | ind = (labels.new().resize_as_(labels)
81 | .copy_(torch.arange(0, N).long())
82 | .unsqueeze(0).expand(N, N))
83 | # shape [N, 1]
84 | p_inds = torch.gather(
85 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
86 | n_inds = torch.gather(
87 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
88 | # shape [N]
89 | p_inds = p_inds.squeeze(1)
90 | n_inds = n_inds.squeeze(1)
91 | return dist_ap, dist_an, p_inds, n_inds
92 |
93 | return dist_ap, dist_an
94 |
95 |
96 | class TripletLoss(object):
97 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
98 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
99 | Loss for Person Re-Identification'."""
100 |
101 | def __init__(self, margin=None):
102 | self.margin = margin
103 | if margin is not None:
104 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
105 | else:
106 | self.ranking_loss = nn.SoftMarginLoss()
107 |
108 | def __call__(self, global_feat, labels, normalize_feature=False):
109 | if normalize_feature:
110 | global_feat = normalize(global_feat, axis=-1)
111 | dist_mat = euclidean_dist(global_feat, global_feat)
112 | dist_ap, dist_an = hard_example_mining(
113 | dist_mat, labels)
114 |         y = torch.ones_like(dist_an)
115 | if self.margin is not None:
116 | loss = self.ranking_loss(dist_an, dist_ap, y)
117 | else:
118 | loss = self.ranking_loss(dist_an - dist_ap, y)
119 | return loss, dist_ap, dist_an
120 |
121 | class CrossEntropyLabelSmooth(nn.Module):
122 | """Cross entropy loss with label smoothing regularizer.
123 |
124 | Reference:
125 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
126 | Equation: y = (1 - epsilon) * y + epsilon / K.
127 |
128 | Args:
129 | num_classes (int): number of classes.
130 | epsilon (float): weight.
131 | """
132 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
133 | super(CrossEntropyLabelSmooth, self).__init__()
134 | self.num_classes = num_classes
135 | self.epsilon = epsilon
136 | self.use_gpu = use_gpu
137 | self.logsoftmax = nn.LogSoftmax(dim=1)
138 |
139 | def forward(self, inputs, targets):
140 | """
141 | Args:
142 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
143 |             targets: ground truth labels with shape (batch_size,)
144 | """
145 | log_probs = self.logsoftmax(inputs)
146 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
147 | if self.use_gpu: targets = targets.cuda()
148 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
149 | loss = (- targets * log_probs).mean(0).sum()
150 | return loss
--------------------------------------------------------------------------------
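A tiny worked example of hard_example_mining, using 1-D "features" so the distances can be checked by hand: each anchor's hardest positive is its farthest same-label sample (the self-distance of ~0 counts as a positive here) and its hardest negative is the nearest other-label sample. Run from the repo root:

import torch
from layers.triplet_loss import euclidean_dist, hard_example_mining

feats = torch.tensor([[0.0], [1.0], [4.0], [6.0]])
labels = torch.tensor([0, 0, 1, 1])

dist = euclidean_dist(feats, feats)               # pairwise |x_i - x_j|
dist_ap, dist_an = hard_example_mining(dist, labels)
# dist_ap -> [1., 1., 2., 2.]   (farthest positive per anchor)
# dist_an -> [4., 3., 3., 5.]   (nearest negative per anchor)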
/modeling/PISNet.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 | from torch import nn
9 |
10 | from .backbones.pisnet import pisnet, BasicBlock, Bottleneck
11 |
12 | def weights_init_kaiming(m):
13 | classname = m.__class__.__name__
14 | if classname.find('Linear') != -1:
15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
16 | nn.init.constant_(m.bias, 0.0)
17 | elif classname.find('Conv') != -1:
18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | if m.bias is not None:
20 | nn.init.constant_(m.bias, 0.0)
21 | elif classname.find('BatchNorm') != -1:
22 | if m.affine:
23 | nn.init.constant_(m.weight, 1.0)
24 | nn.init.constant_(m.bias, 0.0)
25 |
26 |
27 | def weights_init_classifier(m):
28 | classname = m.__class__.__name__
29 | if classname.find('Linear') != -1:
30 | nn.init.normal_(m.weight, std=0.001)
31 | if m.bias is not None:
32 | nn.init.constant_(m.bias, 0.0)
33 |
34 |
35 | class PISNet(nn.Module):
36 | in_planes = 2048
37 |
38 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice, has_non_local="no", sia_reg="no", pyramid="no", test_pair="no"):
39 | super(PISNet, self).__init__()
40 |
41 | self.base = pisnet(last_stride=last_stride,
42 | block=Bottleneck,
43 | layers=[3, 4, 6, 3], has_non_local=has_non_local, sia_reg=sia_reg, pyramid=pyramid)
44 |
45 | if pretrain_choice == 'imagenet':
46 | self.base.load_param(model_path)
47 | print('Loading pretrained ImageNet model......')
48 |
49 | self.gap = nn.AdaptiveAvgPool2d(1)
50 | # self.gap = nn.AdaptiveMaxPool2d(1)
51 | self.num_classes = num_classes
52 | self.neck = neck
53 | self.neck_feat = neck_feat
54 | self.test_pair = test_pair
55 | self.sia_reg = sia_reg
56 |
57 | if self.neck == 'no':
58 | self.classifier = nn.Linear(self.in_planes, self.num_classes)
59 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo
60 | # self.classifier.apply(weights_init_classifier) # new add by luo
61 | elif self.neck == 'bnneck':
62 | self.bottleneck = nn.BatchNorm1d(self.in_planes)
63 | self.bottleneck.bias.requires_grad_(False) # no shift
64 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
65 |
66 | self.bottleneck.apply(weights_init_kaiming)
67 | self.classifier.apply(weights_init_classifier)
68 |
69 | def forward(self, x_g, x, x_g2=[], is_first=False):
70 |
71 | feature_gallery, gallery_attention, feature_gallery1, gallery_attention1, feature_query, reg_feature_query, reg_query_attention = self.base(x_g, x, x_g2=x_g2, is_first=is_first)
72 |
73 | global_feat = self.gap(feature_gallery)
74 | global_feat = global_feat.view(global_feat.shape[0], -1)
75 | # gallery_attention = gallery_attention.view(gallery_attention.shape[0], -1)
76 |
77 | if self.training:
78 | global_feat1 = self.gap(feature_gallery1)
79 |             global_feat1 = global_feat1.view(global_feat1.shape[0], -1)
80 |             gallery_attention1 = gallery_attention1.view(gallery_attention1.shape[0], -1)
81 |
82 | global_feature_query = self.gap(feature_query)
83 |             global_feature_query = global_feature_query.view(global_feature_query.shape[0], -1)
84 |
85 | if self.sia_reg == "yes":
86 | global_reg_query = self.gap(reg_feature_query)
87 |                 global_reg_query = global_reg_query.view(global_reg_query.shape[0], -1)
88 |                 reg_query_attention = reg_query_attention.view(reg_query_attention.shape[0], -1)
89 |
90 | # cls_score_pos_neg = self.classifier_attention(gallery_attention)
91 | # cls_score_pos_neg = self.sigmoid(cls_score_pos_neg)
92 |
93 | if self.neck == 'no':
94 | feat = global_feat
95 | if self.training:
96 | feat1 = global_feat1
97 | if self.sia_reg == "yes":
98 | feat2 = global_reg_query
99 | # feat_query = global_feature_query
100 |
101 | # feat_guiding = global_feat_guiding
102 | elif self.neck == 'bnneck':
103 | feat = self.bottleneck(global_feat) # normalize for angular softmax
104 | if self.training:
105 | feat1 = self.bottleneck(global_feat1) # normalize for angular softmax
106 | if self.sia_reg == "yes":
107 | feat2 = self.bottleneck(global_reg_query)
108 | # feat_query = self.bottleneck(global_feature_query)
109 |
110 | # feat_guiding = self.bottleneck(global_feat_guiding)
111 | if self.training:
112 | cls_score = self.classifier(feat)
113 | cls_score1 = self.classifier(feat1)
114 | cls_score2 = self.classifier(feat2)
115 | # cls_score_guiding = self.classifier(feat_guiding)
116 | return cls_score, global_feat, cls_score1, global_feat1, global_feature_query, cls_score2, global_reg_query # global feature for triplet loss
117 | else:
118 | if self.neck_feat == 'after':
119 | # print("Test with feature after BN")
120 | return feat
121 | else:
122 | return global_feat
123 |
124 | def load_param(self, trained_path):
125 | param_dict = torch.load(trained_path).state_dict()
126 | for i in param_dict:
127 | if 'classifier' in i:
128 | continue
129 | self.state_dict()[i].copy_(param_dict[i])
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
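In training mode PISNet.forward returns a 7-tuple whose order matters: it is unpacked in engine/trainer.py as `score, feat, score1, feat1, feat_query, score2, feat2 = model(guiding1, img, x_g2=guiding2)`. A sketch of the shapes involved (the 2048-dim features follow from in_planes; batch size b and num_classes are placeholders):

# model = PISNet(num_classes, last_stride, model_path, neck, neck_feat,
#                model_name, pretrain_choice, sia_reg="yes")
# model.train()
# out = model(x_g=query_imgs, x=gallery_imgs, x_g2=second_query_imgs)
# out[0]  cls_score             (b, num_classes)   softmax branch
# out[1]  global_feat           (b, 2048)          triplet branch
# out[2]  cls_score1            (b, num_classes)   second-query branch
# out[3]  global_feat1          (b, 2048)
# out[4]  global_feature_query  (b, 2048)
# out[5]  cls_score2            (b, num_classes)   siamese-regularization branch
# out[6]  global_reg_query      (b, 2048)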
/modeling/Pre_Selection_Model.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 | from torch import nn
9 |
10 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck
11 |
12 |
13 | def weights_init_kaiming(m):
14 | classname = m.__class__.__name__
15 | if classname.find('Linear') != -1:
16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
17 | nn.init.constant_(m.bias, 0.0)
18 | elif classname.find('Conv') != -1:
19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
20 | if m.bias is not None:
21 | nn.init.constant_(m.bias, 0.0)
22 | elif classname.find('BatchNorm') != -1:
23 | if m.affine:
24 | nn.init.constant_(m.weight, 1.0)
25 | nn.init.constant_(m.bias, 0.0)
26 |
27 |
28 | def weights_init_classifier(m):
29 | classname = m.__class__.__name__
30 | if classname.find('Linear') != -1:
31 | nn.init.normal_(m.weight, std=0.001)
32 |         if m.bias is not None:
33 | nn.init.constant_(m.bias, 0.0)
34 |
35 |
36 | class Pre_Selection_Model(nn.Module):
37 | in_planes = 2048
38 |
39 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice):
40 | super(Pre_Selection_Model, self).__init__()
41 |
42 | self.base = ResNet(last_stride=last_stride,
43 | block=Bottleneck,
44 | layers=[3, 4, 6, 3])
45 |
46 | if pretrain_choice == 'imagenet':
47 | self.base.load_param(model_path)
48 | print('Loading pretrained ImageNet model......')
49 |
50 | self.gap = nn.AdaptiveAvgPool2d(1)
51 | # self.gap = nn.AdaptiveMaxPool2d(1)
52 | self.num_classes = num_classes
53 | self.neck = neck
54 | self.neck_feat = neck_feat
55 |
56 | if self.neck == 'no':
57 | self.classifier = nn.Linear(self.in_planes, self.num_classes)
58 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo
59 | # self.classifier.apply(weights_init_classifier) # new add by luo
60 | elif self.neck == 'bnneck':
61 | self.bottleneck = nn.BatchNorm1d(self.in_planes)
62 | self.bottleneck.bias.requires_grad_(False) # no shift
63 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
64 |
65 | self.bottleneck.apply(weights_init_kaiming)
66 | self.classifier.apply(weights_init_classifier)
67 |
68 | def forward(self, x):
69 |
70 | global_feat = self.gap(self.base(x)) # (b, 2048, 1, 1)
71 | global_feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048)
72 |
73 | if self.neck == 'no':
74 | feat = global_feat
75 | elif self.neck == 'bnneck':
76 | feat = self.bottleneck(global_feat) # normalize for angular softmax
77 |
78 | if self.training:
79 | cls_score = self.classifier(feat)
80 | return cls_score, global_feat # global feature for triplet loss
81 | else:
82 | if self.neck_feat == 'after':
83 | # print("Test with feature after BN")
84 | return feat
85 | else:
86 | # print("Test with feature before BN")
87 | return global_feat
88 |
89 | def load_param(self, trained_path):
90 | param_dict = torch.load(trained_path).state_dict()
91 | # param_dict = torch.load(trained_path)['model']
92 | for i in param_dict:
93 | if 'classifier' in i:
94 | continue
95 | self.state_dict()[i].copy_(param_dict[i])
96 |
97 |
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .PISNet import PISNet
8 | from .Pre_Selection_Model import Pre_Selection_Model
9 |
10 |
11 | def build_model(cfg, num_classes):
12 | model = PISNet(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE, has_non_local=cfg.MODEL.HAS_NON_LOCAL, sia_reg=cfg.MODEL.SIA_REG, pyramid=cfg.MODEL.PYRAMID, test_pair=cfg.TEST.PAIR)
13 | return model
14 |
15 | def build_model_pre(cfg, num_classes):
16 | model = Pre_Selection_Model(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE)
17 | return model
18 |
19 |
--------------------------------------------------------------------------------
/modeling/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/modeling/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/Query_Guided_Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from time import time
5 |
6 |
7 | class _Query_Guided_Attention(nn.Module):
8 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
9 | super(_Query_Guided_Attention, self).__init__()
10 |
11 | assert dimension in [1, 2, 3]
12 |
13 | self.dimension = dimension
14 | self.sub_sample = sub_sample
15 |
16 | self.in_channels = in_channels
17 | self.inter_channels = inter_channels
18 |
19 | if self.inter_channels is None:
20 | self.inter_channels = in_channels // 2
21 | if self.inter_channels == 0:
22 | self.inter_channels = 1
23 |
24 | if dimension == 3:
25 | conv_nd = nn.Conv3d
26 | self.max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
27 | bn = nn.BatchNorm3d
28 | elif dimension == 2:
29 | conv_nd = nn.Conv2d
30 | self.max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
31 | self.max_pool_layer1 = nn.MaxPool2d(kernel_size=(4, 4))
32 | self.max_pool_layer2 = nn.MaxPool2d(kernel_size=(8, 8))
33 | self.gmp = nn.AdaptiveMaxPool2d(1)
34 |
35 | bn = nn.BatchNorm2d
36 | else:
37 | conv_nd = nn.Conv1d
38 | self.max_pool_layer = nn.MaxPool1d(kernel_size=(2))
39 | bn = nn.BatchNorm1d
40 |
41 | # self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
42 | # kernel_size=1, stride=1, padding=0)
43 |
44 | if bn_layer:
45 | self.W = nn.Sequential(
46 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
47 | kernel_size=1, stride=1, padding=0),
48 | bn(self.in_channels)
49 | )
50 | nn.init.constant_(self.W[1].weight, 0)
51 | nn.init.constant_(self.W[1].bias, 0)
52 | else:
53 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
54 | kernel_size=1, stride=1, padding=0)
55 | nn.init.constant_(self.W.weight, 0)
56 | nn.init.constant_(self.W.bias, 0)
57 |
58 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
59 | kernel_size=1, stride=1, padding=0)
60 |
61 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
62 | kernel_size=1, stride=1, padding=0)
63 |
64 | # self.compress_attention = conv_nd(in_channels=3, out_channels=1,
65 | # kernel_size=1, stride=1, padding=0)
66 | # if sub_sample:
67 | # # self.g = nn.Sequential(self.g, max_pool_layer)
68 | # self.phi = nn.Sequential(self.phi, max_pool_layer)
69 |
70 | self.relu = nn.ReLU()
71 |
72 | # self.gmp = nn.AdaptiveMaxPool1d(1, return_indices=True)
73 |
74 | def forward(self, x, x_g, attention="x", pyramid="no"):
75 |         '''
76 |         :param x: gallery feature map (b, c, h, w); x_g: query feature map (b, c, h, w)
77 |         :return: re-weighted feature map and its attention map
78 |         '''
79 |
80 | batch_size = x.size(0)
81 | theta_x = self.theta(x)
82 | phi_x = self.phi(x_g)
83 |
84 | if attention == "x":
85 | theta_x = theta_x.view(batch_size, self.inter_channels, -1)
86 | theta_x = theta_x.permute(0, 2, 1)
87 |
88 | if pyramid == "yes":
89 | phi_x1 = self.max_pool_layer(phi_x).view(batch_size, self.inter_channels, -1)
90 | f = torch.matmul(theta_x, phi_x1)
91 | N = f.size(-1)
92 | f_div_C1 = f / N
93 |
94 | phi_x2 = phi_x.view(batch_size, self.inter_channels, -1)
95 | f = torch.matmul(theta_x, phi_x2)
96 | f_div_C2 = f / N
97 |
98 | phi_x3 = self.max_pool_layer1(phi_x).view(batch_size, self.inter_channels, -1)
99 | f = torch.matmul(theta_x, phi_x3)
100 | f_div_C3 = f / N
101 |
102 |                 phi_x4 = self.max_pool_layer2(phi_x).view(batch_size, self.inter_channels, -1)
103 | f = torch.matmul(theta_x, phi_x4)
104 | f_div_C4 = f / N
105 |
106 | phi_x5 = self.gmp(phi_x).view(batch_size, self.inter_channels, -1)
107 | f = torch.matmul(theta_x, phi_x5)
108 | f_div_C5 = f / N
109 |
110 | f_div_C = torch.cat((f_div_C1, f_div_C2, f_div_C3, f_div_C4, f_div_C5), 2)
111 | elif pyramid == "no":
112 | phi_x1 = phi_x.view(batch_size, self.inter_channels, -1)
113 | f = torch.matmul(theta_x, phi_x1)
114 | N = f.size(-1)
115 | f_div_C = f / N
116 | elif pyramid == "s2":
117 | phi_x1 = self.max_pool_layer(phi_x).view(batch_size, self.inter_channels, -1)
118 | f = torch.matmul(theta_x, phi_x1)
119 | N = f.size(-1)
120 | f_div_C = f / N
121 |
122 | f, max_index = torch.max(f_div_C, 2)
123 | f = f.view(batch_size, *x.size()[2:]).unsqueeze(1)
124 |
125 | W_y = x * f
126 | z = W_y + x
127 |
128 | return z, f.squeeze()
129 |
130 | elif attention == "x_g":
131 | phi_x = phi_x.view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
132 | theta_x = theta_x.view(batch_size, self.inter_channels, -1)
133 | f = torch.matmul(phi_x, theta_x)
134 | N = f.size(-1)
135 | f_div_C = f / N
136 | f, max_index = torch.max(f_div_C, 2)
137 | f = f.view(batch_size, *x_g.size()[2:]).unsqueeze(1)
138 |
139 | W_y = x_g * f
140 | z = W_y + x_g
141 |
142 | return z, f
143 |
144 | class Query_Guided_Attention(_Query_Guided_Attention):
145 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
146 | super(Query_Guided_Attention, self).__init__(in_channels,
147 | inter_channels=inter_channels,
148 | dimension=2, sub_sample=sub_sample,
149 | bn_layer=bn_layer)
--------------------------------------------------------------------------------
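The heart of the attention="x" branch, isolated: both maps are projected to inter_channels, every gallery position is dotted with every query position, the scores are divided by the number of query positions N, and each gallery position keeps its best match as its attention weight. A self-contained sketch with illustrative shapes (the theta/phi 1x1 convolutions are folded into random tensors here):

import torch

b, c, h, w = 2, 16, 4, 4
theta_x = torch.randn(b, c, h * w).permute(0, 2, 1)   # gallery positions (b, h*w, c)
phi_x = torch.randn(b, c, h * w)                      # query positions (b, c, h*w)

f = torch.matmul(theta_x, phi_x)                      # (b, h*w, h*w) similarity
f = f / f.size(-1)                                    # normalize by N query positions
attn, _ = torch.max(f, 2)                             # best query match per position
attn = attn.view(b, 1, h, w)

x = torch.randn(b, c, h, w)                           # gallery feature map
z = x * attn + x                                      # residual re-weighting, as in W_y + x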
/modeling/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 |
--------------------------------------------------------------------------------
/modeling/backbones/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/modeling/backbones/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/modeling/backbones/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/pisnet.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import math
8 |
9 | import torch
10 | from torch import nn
11 | from time import time
12 | from torch.nn import functional as F
13 |
14 | from modeling.backbones.Query_Guided_Attention import Query_Guided_Attention
15 | import numpy as np
16 |
17 | def weights_init_kaiming(m):
18 | classname = m.__class__.__name__
19 | if classname.find('Linear') != -1:
20 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
21 | nn.init.constant_(m.bias, 0.0)
22 | elif classname.find('Conv') != -1:
23 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
24 | if m.bias is not None:
25 | nn.init.constant_(m.bias, 0.0)
26 | elif classname.find('BatchNorm') != -1:
27 | if m.affine:
28 | nn.init.constant_(m.weight, 1.0)
29 | nn.init.constant_(m.bias, 0.0)
30 |
31 | def conv3x3(in_planes, out_planes, stride=1):
32 | """3x3 convolution with padding"""
33 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
34 | padding=1, bias=False)
35 |
36 |
37 | class BasicBlock(nn.Module):
38 | expansion = 1
39 |
40 | def __init__(self, inplanes, planes, stride=1, downsample=None):
41 | super(BasicBlock, self).__init__()
42 | self.conv1 = conv3x3(inplanes, planes, stride)
43 | self.bn1 = nn.BatchNorm2d(planes)
44 | self.relu = nn.ReLU(inplace=True)
45 | self.conv2 = conv3x3(planes, planes)
46 | self.bn2 = nn.BatchNorm2d(planes)
47 | self.downsample = downsample
48 | self.stride = stride
49 |
50 | def forward(self, x):
51 | residual = x
52 |
53 | out = self.conv1(x)
54 | out = self.bn1(out)
55 | out = self.relu(out)
56 |
57 | out = self.conv2(out)
58 | out = self.bn2(out)
59 |
60 | if self.downsample is not None:
61 | residual = self.downsample(x)
62 |
63 | out += residual
64 | out = self.relu(out)
65 |
66 | return out
67 |
68 |
69 | def feature_corruption(x_g, x_g2):
70 |     # We abandon the standard feature corruption described in the paper;
71 |     # a simple concatenation yields comparable performance.
72 | corrupted_x = torch.cat((x_g, x_g2), 3)
73 | return corrupted_x
74 |
75 |
76 | class Bottleneck(nn.Module):
77 | expansion = 4
78 |
79 | def __init__(self, inplanes, planes, stride=1, downsample=None):
80 | super(Bottleneck, self).__init__()
81 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
82 | self.bn1 = nn.BatchNorm2d(planes)
83 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
84 | padding=1, bias=False)
85 | self.bn2 = nn.BatchNorm2d(planes)
86 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
87 | self.bn3 = nn.BatchNorm2d(planes * 4)
88 | self.relu = nn.ReLU(inplace=True)
89 | self.downsample = downsample
90 | self.stride = stride
91 |
92 | def forward(self, x):
93 | residual = x
94 |
95 | out = self.conv1(x)
96 | out = self.bn1(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv2(out)
100 | out = self.bn2(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv3(out)
104 | out = self.bn3(out)
105 |
106 | if self.downsample is not None:
107 | residual = self.downsample(x)
108 |
109 | out += residual
110 | out = self.relu(out)
111 |
112 | return out
113 |
114 |
115 | class pisnet(nn.Module):
116 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3], has_non_local="no", sia_reg="no", pyramid="no"):
117 | self.inplanes = 64
118 | super().__init__()
119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
120 | bias=False)
121 | self.bn1 = nn.BatchNorm2d(64)
122 | # self.relu = nn.ReLU(inplace=True) # add missed relu
123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
124 | self.layer1 = self._make_layer(block, 64, layers[0])
125 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
127 | self.layer4 = self._make_layer(
128 | block, 512, layers[3], stride=last_stride)
129 | print("has_non_local:" + has_non_local)
130 | self.has_non_local = has_non_local
131 | self.pyramid = pyramid
132 | self.Query_Guided_Attention = Query_Guided_Attention(in_channels=2048)
133 | self.Query_Guided_Attention.apply(weights_init_kaiming)
134 | self.sia_reg = sia_reg
135 |
136 |
137 | def _make_layer(self, block, planes, blocks, stride=1):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | nn.Conv2d(self.inplanes, planes * block.expansion,
142 | kernel_size=1, stride=stride, bias=False),
143 | nn.BatchNorm2d(planes * block.expansion),
144 | )
145 |
146 | layers = []
147 | layers.append(block(self.inplanes, planes, stride, downsample))
148 | self.inplanes = planes * block.expansion
149 | for i in range(1, blocks):
150 | layers.append(block(self.inplanes, planes))
151 |
152 | return nn.Sequential(*layers)
153 |
154 | def forward(self, x_g, x, x_g2=[], is_first=False):
155 |
156 |
157 | x = self.conv1(x)
158 | x_g = self.conv1(x_g)
159 |
160 | x = self.bn1(x)
161 | x_g = self.bn1(x_g)
162 |
163 | x = self.maxpool(x)
164 | x_g = self.maxpool(x_g)
165 |
166 | x = self.layer1(x)
167 | x_g = self.layer1(x_g)
168 |
169 | x = self.layer2(x)
170 | x_g = self.layer2(x_g)
171 |
172 | x = self.layer3(x)
173 | x_g = self.layer3(x_g)
174 |
175 | x = self.layer4(x)
176 | x_g = self.layer4(x_g)
177 |
178 | if not isinstance(x_g2, list):
179 |
180 | x_g2 = self.conv1(x_g2)
181 | x_g2 = self.bn1(x_g2)
182 | x_g2 = self.maxpool(x_g2)
183 | x_g2 = self.layer1(x_g2)
184 | x_g2 = self.layer2(x_g2)
185 | x_g2 = self.layer3(x_g2)
186 | x_g2 = self.layer4(x_g2)
187 |
188 | x1, attention1 = self.Query_Guided_Attention(x, x_g, attention='x', pyramid=self.pyramid)
189 |
190 | if not isinstance(x_g2, list):
191 | x2, attention2 = self.Query_Guided_Attention(x, x_g2, attention='x', pyramid=self.pyramid)
192 | if self.sia_reg == "yes":
193 | rec_x_g = feature_corruption(x_g, x_g2.detach())
194 | x3, attention3 = self.Query_Guided_Attention(x1, rec_x_g, attention='x_g', pyramid=self.pyramid)
195 | else:
196 | x2 = []
197 | attention2 = []
198 | x3 = []
199 | attention3 = []
200 |
201 | if isinstance(is_first, tuple):
202 | x1[0, :, :, :] = x_g[0, :, :, :]
203 |
204 | return x1, attention1, x2, attention2, x_g, x3, attention3
205 |
206 | def load_param(self, model_path):
207 | param_dict = torch.load(model_path)
208 | for i in param_dict:
209 | if 'fc' in i:
210 | continue
211 | self.state_dict()[i].copy_(param_dict[i])
212 |
213 | def random_init(self):
214 | for m in self.modules():
215 | if isinstance(m, nn.Conv2d):
216 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
217 | m.weight.data.normal_(0, math.sqrt(2. / n))
218 | elif isinstance(m, nn.BatchNorm2d):
219 | m.weight.data.fill_(1)
220 | m.bias.data.zero_()
--------------------------------------------------------------------------------
/modeling/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import math
8 |
9 | import torch
10 | from torch import nn
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | """3x3 convolution with padding"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None):
23 | super(BasicBlock, self).__init__()
24 | self.conv1 = conv3x3(inplanes, planes, stride)
25 | self.bn1 = nn.BatchNorm2d(planes)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 | self.downsample = downsample
30 | self.stride = stride
31 |
32 | def forward(self, x):
33 | residual = x
34 |
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 |
39 | out = self.conv2(out)
40 | out = self.bn2(out)
41 |
42 | if self.downsample is not None:
43 | residual = self.downsample(x)
44 |
45 | out += residual
46 | out = self.relu(out)
47 |
48 | return out
49 |
50 |
51 | class Bottleneck(nn.Module):
52 | expansion = 4
53 |
54 | def __init__(self, inplanes, planes, stride=1, downsample=None):
55 | super(Bottleneck, self).__init__()
56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
57 | self.bn1 = nn.BatchNorm2d(planes)
58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
59 | padding=1, bias=False)
60 | self.bn2 = nn.BatchNorm2d(planes)
61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
62 | self.bn3 = nn.BatchNorm2d(planes * 4)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.downsample = downsample
65 | self.stride = stride
66 |
67 | def forward(self, x):
68 | residual = x
69 |
70 | out = self.conv1(x)
71 | out = self.bn1(out)
72 | out = self.relu(out)
73 |
74 | out = self.conv2(out)
75 | out = self.bn2(out)
76 | out = self.relu(out)
77 |
78 | out = self.conv3(out)
79 | out = self.bn3(out)
80 |
81 | if self.downsample is not None:
82 | residual = self.downsample(x)
83 |
84 | out += residual
85 | out = self.relu(out)
86 |
87 | return out
88 |
89 |
90 | class ResNet(nn.Module):
91 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]):
92 | self.inplanes = 64
93 | super().__init__()
94 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
95 | bias=False)
96 | self.bn1 = nn.BatchNorm2d(64)
97 | # self.relu = nn.ReLU(inplace=True) # add missed relu
98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
99 | self.layer1 = self._make_layer(block, 64, layers[0])
100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
102 | self.layer4 = self._make_layer(
103 | block, 512, layers[3], stride=last_stride)
104 |
105 | def _make_layer(self, block, planes, blocks, stride=1):
106 | downsample = None
107 | if stride != 1 or self.inplanes != planes * block.expansion:
108 | downsample = nn.Sequential(
109 | nn.Conv2d(self.inplanes, planes * block.expansion,
110 | kernel_size=1, stride=stride, bias=False),
111 | nn.BatchNorm2d(planes * block.expansion),
112 | )
113 |
114 | layers = []
115 | layers.append(block(self.inplanes, planes, stride, downsample))
116 | self.inplanes = planes * block.expansion
117 | for i in range(1, blocks):
118 | layers.append(block(self.inplanes, planes))
119 |
120 | return nn.Sequential(*layers)
121 |
122 | def forward(self, x):
123 | x = self.conv1(x)
124 | x = self.bn1(x)
125 | # x = self.relu(x) # add missed relu
126 | x = self.maxpool(x)
127 |
128 | x = self.layer1(x)
129 | x = self.layer2(x)
130 | x = self.layer3(x)
131 | x = self.layer4(x)
132 |
133 | return x
134 |
135 | def load_param(self, model_path):
136 | param_dict = torch.load(model_path)
137 | for i in param_dict:
138 | if 'fc' in i:
139 | continue
140 | self.state_dict()[i].copy_(param_dict[i])
141 |
142 | def random_init(self):
143 | for m in self.modules():
144 | if isinstance(m, nn.Conv2d):
145 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
146 | m.weight.data.normal_(0, math.sqrt(2. / n))
147 | elif isinstance(m, nn.BatchNorm2d):
148 | m.weight.data.fill_(1)
149 | m.bias.data.zero_()
150 |
151 |
--------------------------------------------------------------------------------
/modeling/baseline.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 | from torch import nn
9 |
10 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck
11 |
12 | def weights_init_kaiming(m):
13 | classname = m.__class__.__name__
14 | if classname.find('Linear') != -1:
15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
16 | nn.init.constant_(m.bias, 0.0)
17 | elif classname.find('Conv') != -1:
18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | if m.bias is not None:
20 | nn.init.constant_(m.bias, 0.0)
21 | elif classname.find('BatchNorm') != -1:
22 | if m.affine:
23 | nn.init.constant_(m.weight, 1.0)
24 | nn.init.constant_(m.bias, 0.0)
25 |
26 |
27 | def weights_init_classifier(m):
28 | classname = m.__class__.__name__
29 | if classname.find('Linear') != -1:
30 | nn.init.normal_(m.weight, std=0.001)
31 |         if m.bias is not None:
32 | nn.init.constant_(m.bias, 0.0)
33 |
34 |
35 | class Baseline(nn.Module):
36 | in_planes = 2048
37 |
38 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice):
39 | super(Baseline, self).__init__()
40 | #
41 |
--------------------------------------------------------------------------------
/pi_cuhk.sh:
--------------------------------------------------------------------------------
1 | PROJECT_ROOT_DIR=/root/PI-ReID
2 | DATASETS_ROOT_DIR=$PROJECT_ROOT_DIR/datasets
3 | PRETRAINED_PATH=$PROJECT_ROOT_DIR/pretrained/cuhk/resnet50_model_120.pth
4 | OUTPUT=$PROJECT_ROOT_DIR/output/cuhk
5 | Pre_Index_DIR=$PROJECT_ROOT_DIR/pre_index_dir/cuhk_pre_index.json
6 |
7 | python3 tools/pre_selection.py --config_file='configs/softmax_triplet_ftc.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('cuhk')" \
8 | DATASETS.ROOT_DIR $DATASETS_ROOT_DIR MODEL.PRETRAIN_CHOICE "('self')" \
9 | TEST.WEIGHT $PRETRAINED_PATH \
10 | OUTPUT_DIR $OUTPUT \
11 | Pre_Index_DIR $Pre_Index_DIR
12 |
13 | python3 tools/train.py --config_file='configs/softmax_triplet_ftc.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('cuhk')" DATASETS.ROOT_DIR $DATASETS_ROOT_DIR \
14 | OUTPUT_DIR $OUTPUT SOLVER.BASE_LR 0.000035 TEST.PAIR "no" SOLVER.IMS_PER_BATCH 64 \
15 | MODEL.WHOLE_MODEL_TRAIN "no" MODEL.PYRAMID "s2" MODEL.SIA_REG "yes" MODEL.GAMMA 1.0 SOLVER.MARGIN 0.1 MODEL.BETA 0.5 DATASETS.TRAIN_ANNO 1 SOLVER.EVAL_PERIOD 10 SOLVER.MAX_EPOCHS 50 \
16 | MODEL.PRETRAIN_PATH $PRETRAINED_PATH \
17 | Pre_Index_DIR $Pre_Index_DIR
18 |
19 |
--------------------------------------------------------------------------------
/pi_prw.sh:
--------------------------------------------------------------------------------
1 | PROJECT_ROOT_DIR=/root/PI-ReID
2 | DATASETS_ROOT_DIR=$PROJECT_ROOT_DIR/datasets
3 | PRETRAINED_PATH=$PROJECT_ROOT_DIR/pretrained/prw/resnet50_model_120.pth
4 | OUTPUT=$PROJECT_ROOT_DIR/output/prw
5 | Pre_Index_DIR=$PROJECT_ROOT_DIR/pre_index_dir/prw_pre_index.json
6 |
7 | python tools/pre_selection.py --config_file='configs/softmax_triplet_ft.yml' MODEL.DEVICE_ID "('3')" DATASETS.NAMES "('prw')" \
8 | DATASETS.ROOT_DIR $DATASETS_ROOT_DIR MODEL.PRETRAIN_CHOICE "('self')" \
9 | TEST.WEIGHT $PRETRAINED_PATH \
10 | OUTPUT_DIR $OUTPUT \
11 | Pre_Index_DIR $Pre_Index_DIR
12 |
13 | python3 tools/train.py --config_file='configs/softmax_triplet_ft.yml' MODEL.DEVICE_ID "('3')" DATASETS.NAMES "('prw')" DATASETS.ROOT_DIR $DATASETS_ROOT_DIR \
14 | OUTPUT_DIR $OUTPUT SOLVER.BASE_LR 0.00035 TEST.PAIR "no" SOLVER.IMS_PER_BATCH 64 \
15 | MODEL.WHOLE_MODEL_TRAIN "no" MODEL.PYRAMID "s2" MODEL.SIA_REG "yes" MODEL.GAMMA 1.0 SOLVER.MARGIN 0.1 MODEL.BETA 0.5 DATASETS.TRAIN_ANNO 1 SOLVER.EVAL_PERIOD 15 SOLVER.MAX_EPOCHS 50 \
16 | MODEL.PRETRAIN_PATH $PRETRAINED_PATH \
17 | Pre_Index_DIR $Pre_Index_DIR
18 |
--------------------------------------------------------------------------------
/pre_select_cuhk.sh:
--------------------------------------------------------------------------------
1 |
2 | python3 tools/pre_selection.py --config_file='configs/softmax_triplet_ftc.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('cuhk')" \
3 | DATASETS.ROOT_DIR "('/root/person_search/dataset')" MODEL.PRETRAIN_CHOICE "('self')" \
4 | TEST.WEIGHT "('/root/person_search/trained/strong_baseline/cuhk_all_trick_1/resnet50_model_120.pth')" \
5 | OUTPUT_DIR "('/root/person_search/trained/multi_person/cuhk_all_trick_1')"
6 |
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import make_optimizer, make_optimizer_with_center
8 | from .lr_scheduler import WarmupMultiStepLR
--------------------------------------------------------------------------------
/solver/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/solver/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/build.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/solver/__pycache__/build.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/lr_scheduler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/solver/__pycache__/lr_scheduler.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch
8 |
9 |
10 | def make_optimizer(cfg, model):
11 | params = []
12 | for key, value in model.named_parameters():
13 | if not value.requires_grad:
14 | continue
15 | lr = cfg.SOLVER.BASE_LR
16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
17 | if "bias" in key:
18 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
19 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
20 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
23 | else:
24 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
25 | return optimizer
26 |
27 |
28 | def make_optimizer_with_center(cfg, model, center_criterion):
29 | params = []
30 | for key, value in model.named_parameters():
31 | if not value.requires_grad:
32 | continue
33 | lr = cfg.SOLVER.BASE_LR
34 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
35 | if "bias" in key:
36 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
37 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
38 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
39 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
40 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
41 | else:
42 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
43 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)
44 | return optimizer, optimizer_center
45 |
--------------------------------------------------------------------------------
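make_optimizer puts every tensor in its own parameter group so that biases can receive a scaled learning rate and their own weight decay. A minimal reproduction with illustrative values (the bias LR factor of 2 and the decay values are assumptions, not this repo's defaults):

import torch
from torch import nn

net = nn.Linear(4, 2)
base_lr, bias_lr_factor, weight_decay = 3.5e-4, 2.0, 5e-4

params = []
for key, value in net.named_parameters():
    lr = base_lr * bias_lr_factor if "bias" in key else base_lr
    wd = 0.0 if "bias" in key else weight_decay
    params.append({"params": [value], "lr": lr, "weight_decay": wd})

optimizer = torch.optim.SGD(params, momentum=0.9)
# optimizer.param_groups[0]["lr"] == 3.5e-4 (weight)
# optimizer.param_groups[1]["lr"] == 7.0e-4 (bias, scaled by the factor)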
/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 |
9 |
10 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
11 | # separating MultiStepLR with WarmupLR
12 | # but the current LRScheduler design doesn't allow it
13 |
14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
15 | def __init__(
16 | self,
17 | optimizer,
18 | milestones,
19 | gamma=0.1,
20 | warmup_factor=1.0 / 3,
21 | warmup_iters=500,
22 | warmup_method="linear",
23 | last_epoch=-1,
24 | ):
25 | if not list(milestones) == sorted(milestones):
26 | raise ValueError(
27 |                 "Milestones should be a list of increasing integers. "
28 |                 "Got {}".format(milestones)
29 |             )
30 |
31 | if warmup_method not in ("constant", "linear"):
32 | raise ValueError(
33 |                 "Only 'constant' or 'linear' warmup_method accepted, "
34 |                 "got {}".format(warmup_method)
35 | )
36 | self.milestones = milestones
37 | self.gamma = gamma
38 | self.warmup_factor = warmup_factor
39 | self.warmup_iters = warmup_iters
40 | self.warmup_method = warmup_method
41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
42 |
43 | def get_lr(self):
44 | warmup_factor = 1
45 | if self.last_epoch < self.warmup_iters:
46 | if self.warmup_method == "constant":
47 | warmup_factor = self.warmup_factor
48 | elif self.warmup_method == "linear":
49 | alpha = self.last_epoch / self.warmup_iters
50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
51 | return [
52 | base_lr
53 | * warmup_factor
54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
55 | for base_lr in self.base_lrs
56 | ]
57 |
--------------------------------------------------------------------------------
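
To make the schedule concrete: with base_lr=0.01, warmup_factor=1/3, warmup_iters=10, milestones=[20, 40] and gamma=0.1, get_lr() ramps linearly from 0.01/3 up to 0.01 over the first 10 epochs, then drops the rate by 10x at epoch 20 and again at epoch 40 (bisect_right means the drop applies at the milestone epoch itself). A quick check, assuming a PyTorch version contemporary with this code (get_lr() read directly):

import sys
import torch

sys.path.append('.')
from solver.lr_scheduler import WarmupMultiStepLR

opt = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.01)
sched = WarmupMultiStepLR(opt, milestones=[20, 40], gamma=0.1,
                          warmup_factor=1.0 / 3, warmup_iters=10)
for epoch in range(50):
    lr = sched.get_lr()[0]
    if epoch in (0, 5, 10, 20, 40):
        print(epoch, lr)   # ~0.00333, ~0.00667, 0.01, 0.001, 0.0001
    opt.step()
    sched.step()
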
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
--------------------------------------------------------------------------------
/tests/lr_scheduler_test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import unittest
3 |
4 | import torch
5 | from torch import nn
6 |
7 | sys.path.append('.')
8 | from solver.lr_scheduler import WarmupMultiStepLR
9 | from solver.build import make_optimizer
10 | from config import cfg
11 |
12 |
13 | class MyTestCase(unittest.TestCase):
14 | def test_something(self):
15 | net = nn.Linear(10, 10)
16 | optimizer = make_optimizer(cfg, net)
17 | lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10)
18 | for i in range(50):
19 | lr_scheduler.step()
20 | for j in range(3):
21 | print(i, lr_scheduler.get_lr()[0])
22 | optimizer.step()
23 |
24 |
25 | if __name__ == '__main__':
26 | unittest.main()
27 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
--------------------------------------------------------------------------------
/tools/pre_selection.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import argparse
8 | import os
9 | import sys
10 | from os import mkdir
11 |
12 | import torch
13 | from torch.backends import cudnn
14 | import json
15 |
16 | sys.path.append('.')
17 | from config import cfg
18 | from data import make_data_loader
19 | from engine.inference import inference
20 | from modeling import build_model_pre
21 | from utils.logger import setup_logger
22 |
23 | import torch.nn as nn
24 | from ignite.engine import Engine
25 | 
26 |
27 | from utils.reid_metric import R1_mAP, pre_selection_index
28 |
29 | def create_supervised_evaluator(model, metrics,
30 | device=None):
31 | """
32 | Factory function for creating an evaluator for supervised models
33 |
34 | Args:
35 |         model (`torch.nn.Module`): the model to evaluate
36 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
37 | device (str, optional): device type specification (default: None).
38 | Applies to both model and batches.
39 | Returns:
40 | Engine: an evaluator engine with supervised inference function
41 | """
42 | if device:
43 | if torch.cuda.device_count() > 1:
44 | model = nn.DataParallel(model)
45 | model.to(device)
46 |
47 | def _inference(engine, batch):
48 | model.eval()
49 | with torch.no_grad():
50 | data, pids, camids = batch
51 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
52 | feat = model(data)
53 | return feat, pids, camids
54 |
55 | engine = Engine(_inference)
56 |
57 | for name, metric in metrics.items():
58 | metric.attach(engine, name)
59 |
60 | return engine
61 |
62 | # def inference(
63 | # cfg,
64 | # model,
65 | # val_loader,
66 | # num_query
67 | # ):
68 | # device = cfg.MODEL.DEVICE
69 | #
70 | # logger = logging.getLogger("reid_baseline.inference")
71 | # logger.info("Enter inferencing")
72 | # if cfg.TEST.RE_RANKING == 'no':
73 | # print("Create evaluator")
74 | # evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)},
75 | # device=device)
76 | # elif cfg.TEST.RE_RANKING == 'yes':
77 | # print("Create evaluator for reranking")
78 | # evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)},
79 | # device=device)
80 | # else:
81 | # print("Unsupported re_ranking config. Only support for no or yes, but got {}.".format(cfg.TEST.RE_RANKING))
82 | #
83 | # evaluator.run(val_loader)
84 | # cmc, mAP, _ = evaluator.state.metrics['r1_mAP']
85 | # logger.info('Validation Results')
86 | # logger.info("mAP: {:.1%}".format(mAP))
87 | # for r in [1, 5, 10, 100]:
88 | # logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
89 |
90 |
91 |
92 | def main():
93 | parser = argparse.ArgumentParser(description="ReID Baseline Inference")
94 | parser.add_argument(
95 | "--config_file", default="", help="path to config file", type=str
96 | )
97 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
98 | nargs=argparse.REMAINDER)
99 |
100 | args = parser.parse_args()
101 |
102 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
103 |
104 | if args.config_file != "":
105 | cfg.merge_from_file(args.config_file)
106 | cfg.merge_from_list(args.opts)
107 | cfg.freeze()
108 |
109 | output_dir = cfg.OUTPUT_DIR
110 | if output_dir and not os.path.exists(output_dir):
111 | mkdir(output_dir)
112 |
113 | if cfg.MODEL.DEVICE == "cuda":
114 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
115 | cudnn.benchmark = True
116 |
117 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
118 | model = build_model_pre(cfg, num_classes)
119 | model.load_param(cfg.TEST.WEIGHT)
120 |
121 | # inference(cfg, model, val_loader, num_query)
122 | device = cfg.MODEL.DEVICE
123 |
124 | evaluator = create_supervised_evaluator(model, metrics={
125 | 'pre_selection_index': pre_selection_index(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)},
126 | device=device)
127 |
128 | evaluator.run(val_loader)
129 |
130 | index = evaluator.state.metrics['pre_selection_index']
131 |
132 | with open(cfg.Pre_Index_DIR, 'w+') as f:
133 | json.dump(index.tolist(), f)
134 |
135 | print("Pre_Selection_Done")
136 |
137 | if __name__ == '__main__':
138 | main()
139 |
--------------------------------------------------------------------------------
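
The script above dumps the full argsort of the query-to-gallery distance matrix to cfg.Pre_Index_DIR as JSON, so downstream stages only need to look at each query's nearest gallery candidates. A hedged sketch of reading it back (the file name and the k=100 cut are illustrative; the code above fixes neither):

import json
import numpy as np

with open('pre_selection_index.json') as f:   # whatever cfg.Pre_Index_DIR pointed at
    index = np.asarray(json.load(f))          # shape: (num_query, num_gallery)

top_k = index[:, :100]                        # 100 closest gallery images per query
print(top_k.shape)
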
/tools/train.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import argparse
8 | import os
9 | import sys
10 | import torch
11 |
12 | from torch.backends import cudnn
13 |
14 | sys.path.append('.')
15 | from config import cfg
16 | from data import make_data_loader, make_data_loader_train
17 | from engine.trainer import do_train, do_train_with_center
18 | from modeling import build_model
19 | from layers import make_loss, make_loss_with_center
20 | from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR
21 |
22 | from utils.logger import setup_logger
23 |
24 |
25 |
26 |
27 |
28 |
29 | def train(cfg):
30 | # prepare dataset
31 | # train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
32 | train_loader, val_loader, num_query, num_classes = make_data_loader_train(cfg)
33 |
34 |
35 | # prepare model
36 | if 'prw' in cfg.DATASETS.NAMES:
37 | num_classes = 483
38 | elif "market1501" in cfg.DATASETS.NAMES:
39 | num_classes = 751
40 | elif "duke" in cfg.DATASETS.NAMES:
41 | num_classes = 702
42 | elif "cuhk" in cfg.DATASETS.NAMES:
43 | num_classes = 5532
44 |
45 |
46 | model = build_model(cfg, num_classes)
47 |
48 | if cfg.MODEL.IF_WITH_CENTER == 'no':
49 | print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
50 | optimizer = make_optimizer(cfg, model)
51 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
52 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
53 |
54 | loss_func = make_loss(cfg, num_classes) # modified by gu
55 |
56 | # Add for using self trained model
57 | if cfg.MODEL.PRETRAIN_CHOICE == 'self':
58 | # start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
59 | start_epoch = 0
60 | print('Start epoch:', start_epoch)
61 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
62 | print('Path to the checkpoint of optimizer:', path_to_optimizer)
63 |
64 |
65 | pretrained_dic = torch.load(cfg.MODEL.PRETRAIN_PATH).state_dict()
66 | model_dict = model.state_dict()
67 |
68 | model_dict.update(pretrained_dic)
69 | model.load_state_dict(model_dict)
70 |
71 | if cfg.MODEL.WHOLE_MODEL_TRAIN == "no":
72 | for name, value in model.named_parameters():
73 | if "Query_Guided_Attention" not in name and "non_local" not in name and "classifier_attention" not in name:
74 | value.requires_grad = False
75 | optimizer = make_optimizer(cfg, model)
76 | # else:
77 | # cfg.SOLVER.BASE_LR = 0.0000035
78 |
79 | # optimizer.load_state_dict(torch.load(path_to_optimizer))
80 | # #####
81 | # for state in optimizer.state.values():
82 | # for k, v in state.items():
83 | # if isinstance(v, torch.Tensor):
84 | # state[k] = v.cuda()
85 | # #####
86 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
87 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
88 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
89 | start_epoch = 0
90 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
91 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
92 | else:
93 |             print("Only pretrain_choice 'imagenet' or 'self' is supported, but got {}".format(cfg.MODEL.PRETRAIN_CHOICE))
94 |
95 | arguments = {}
96 |
97 | do_train(
98 | cfg,
99 | model,
100 | train_loader,
101 | val_loader,
102 | optimizer,
103 | scheduler, # modify for using self trained model
104 | loss_func,
105 | num_query,
106 | start_epoch # add for using self trained model
107 | )
108 | elif cfg.MODEL.IF_WITH_CENTER == 'yes':
109 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
110 | loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu
111 | optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)
112 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
113 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
114 |
115 | arguments = {}
116 |
117 | # Add for using self trained model
118 | if cfg.MODEL.PRETRAIN_CHOICE == 'self':
119 |             start_epoch = int(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
120 | print('Start epoch:', start_epoch)
121 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
122 | print('Path to the checkpoint of optimizer:', path_to_optimizer)
123 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param')
124 | print('Path to the checkpoint of center_param:', path_to_center_param)
125 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center')
126 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center)
127 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
128 | optimizer.load_state_dict(torch.load(path_to_optimizer))
129 | #####
130 | for state in optimizer.state.values():
131 | for k, v in state.items():
132 | if isinstance(v, torch.Tensor):
133 | state[k] = v.cuda()
134 | #####
135 | center_criterion.load_state_dict(torch.load(path_to_center_param))
136 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center))
137 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
138 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
139 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
140 | start_epoch = 0
141 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
142 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
143 | else:
144 |             print("Only pretrain_choice 'imagenet' or 'self' is supported, but got {}".format(cfg.MODEL.PRETRAIN_CHOICE))
145 |
146 | do_train_with_center(
147 | cfg,
148 | model,
149 | center_criterion,
150 | train_loader,
151 | val_loader,
152 | optimizer,
153 | optimizer_center,
154 | scheduler, # modify for using self trained model
155 | loss_func,
156 | num_query,
157 | start_epoch # add for using self trained model
158 | )
159 | else:
160 | print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(cfg.MODEL.IF_WITH_CENTER))
161 |
162 |
163 | def main():
164 | parser = argparse.ArgumentParser(description="ReID Baseline Training")
165 | parser.add_argument(
166 | "--config_file", default="", help="path to config file", type=str
167 | )
168 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
169 | nargs=argparse.REMAINDER)
170 |
171 | args = parser.parse_args()
172 |
173 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
174 |
175 | if args.config_file != "":
176 | cfg.merge_from_file(args.config_file)
177 | cfg.merge_from_list(args.opts)
178 | cfg.freeze()
179 |
180 | output_dir = cfg.OUTPUT_DIR
181 | if output_dir and not os.path.exists(output_dir):
182 | os.makedirs(output_dir)
183 |
184 | logger = setup_logger("reid_baseline", output_dir, 0)
185 | logger.info("Using {} GPUS".format(num_gpus))
186 | logger.info(args)
187 |
188 | if args.config_file != "":
189 | logger.info("Loaded configuration file {}".format(args.config_file))
190 | with open(args.config_file, 'r') as cf:
191 | config_str = "\n" + cf.read()
192 | logger.info(config_str)
193 | logger.info("Running with config:\n{}".format(cfg))
194 |
195 | if cfg.MODEL.DEVICE == "cuda":
196 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu
197 | cudnn.benchmark = True
198 | train(cfg)
199 |
200 |
201 | if __name__ == '__main__':
202 |
203 | # model_path = "/raid/home/henrayzhao/person_search/trained/strong_baseline/prw_non_local/resnet50_model_40.pth"
204 | # model_dic = torch.load(model_path)
205 | # print(model_dic)
206 | main()
207 |
--------------------------------------------------------------------------------
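
One detail of main() worth calling out: the YAML named by --config_file is merged first and the positional opts afterwards, so command-line KEY VALUE pairs always win over the file. A small sketch of that merge order using the yacs-style cfg this repo imports (the override value is illustrative):

import sys

sys.path.append('.')
from config import cfg

cfg.merge_from_file('configs/softmax_triplet.yml')     # file first
cfg.merge_from_list(['SOLVER.BASE_LR', 0.00035])       # then overrides win
cfg.freeze()
print(cfg.SOLVER.BASE_LR)                              # 0.00035 regardless of the YAML
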
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 |
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/iotools.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/iotools.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/re_ranking.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/re_ranking.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/reid_metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/reid_metric.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/iotools.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import errno
8 | import json
9 | import os
10 |
11 | import os.path as osp
12 |
13 |
14 | def mkdir_if_missing(directory):
15 | if not osp.exists(directory):
16 | try:
17 | os.makedirs(directory)
18 | except OSError as e:
19 | if e.errno != errno.EEXIST:
20 | raise
21 |
22 |
23 | def check_isfile(path):
24 | isfile = osp.isfile(path)
25 | if not isfile:
26 | print("=> Warning: no file found at '{}' (ignored)".format(path))
27 | return isfile
28 |
29 |
30 | def read_json(fpath):
31 | with open(fpath, 'r') as f:
32 | obj = json.load(f)
33 | return obj
34 |
35 |
36 | def write_json(obj, fpath):
37 | mkdir_if_missing(osp.dirname(fpath))
38 | with open(fpath, 'w') as f:
39 | json.dump(obj, f, indent=4, separators=(',', ': '))
40 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 |
12 | def setup_logger(name, save_dir, distributed_rank):
13 | logger = logging.getLogger(name)
14 | logger.setLevel(logging.DEBUG)
15 | # don't log results for the non-master process
16 | if distributed_rank > 0:
17 | return logger
18 | ch = logging.StreamHandler(stream=sys.stdout)
19 | ch.setLevel(logging.DEBUG)
20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
21 | ch.setFormatter(formatter)
22 | logger.addHandler(ch)
23 |
24 | if save_dir:
25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
26 | fh.setLevel(logging.DEBUG)
27 | fh.setFormatter(formatter)
28 | logger.addHandler(fh)
29 |
30 | return logger
31 |
--------------------------------------------------------------------------------
/utils/re_ranking.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri, 25 May 2018 20:29:09
5 |
6 | @author: luohao
7 | """
8 |
9 | """
10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
13 | """
14 |
15 | """
16 | API
17 |
18 | probFea: all feature vectors of the query set (torch tensor)
19 | probFea: all feature vectors of the gallery set (torch tensor)
20 | k1,k2,lambda: parameters, the original paper is (k1=20,k2=6,lambda=0.3)
21 | MemorySave: set to 'True' when using MemorySave mode
22 | Minibatch: avaliable when 'MemorySave' is 'True'
23 | """
24 |
25 | import numpy as np
26 | import torch
27 |
28 |
29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
30 |     # if the features are numpy arrays, convert them to torch tensors with torch.tensor first
31 | query_num = probFea.size(0)
32 | all_num = query_num + galFea.size(0)
33 | if only_local:
34 | original_dist = local_distmat
35 | else:
36 | feat = torch.cat([probFea,galFea])
37 | print('using GPU to compute original distance')
38 | distmat = torch.pow(feat,2).sum(dim=1, keepdim=True).expand(all_num,all_num) + \
39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
40 | distmat.addmm_(1,-2,feat,feat.t())
41 | original_dist = distmat.cpu().numpy()
42 | del feat
43 |         if local_distmat is not None:
44 | original_dist = original_dist + local_distmat
45 | gallery_num = original_dist.shape[0]
46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
47 | V = np.zeros_like(original_dist).astype(np.float16)
48 | initial_rank = np.argsort(original_dist).astype(np.int32)
49 |
50 | print('starting re_ranking')
51 | for i in range(all_num):
52 | # k-reciprocal neighbors
53 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
55 | fi = np.where(backward_k_neigh_index == i)[0]
56 | k_reciprocal_index = forward_k_neigh_index[fi]
57 | k_reciprocal_expansion_index = k_reciprocal_index
58 | for j in range(len(k_reciprocal_index)):
59 | candidate = k_reciprocal_index[j]
60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
62 | :int(np.around(k1 / 2)) + 1]
63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
66 | candidate_k_reciprocal_index):
67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
68 |
69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
72 | original_dist = original_dist[:query_num, ]
73 | if k2 != 1:
74 | V_qe = np.zeros_like(V, dtype=np.float16)
75 | for i in range(all_num):
76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
77 | V = V_qe
78 | del V_qe
79 | del initial_rank
80 | invIndex = []
81 | for i in range(gallery_num):
82 | invIndex.append(np.where(V[:, i] != 0)[0])
83 |
84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
85 |
86 | for i in range(query_num):
87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
88 | indNonZero = np.where(V[i, :] != 0)[0]
89 | indImages = [invIndex[ind] for ind in indNonZero]
90 | for j in range(len(indNonZero)):
91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
92 | V[indImages[j], indNonZero[j]])
93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
94 |
95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
96 | del original_dist
97 | del V
98 | del jaccard_dist
99 | final_dist = final_dist[:query_num, query_num:]
100 | return final_dist
101 |
102 |
--------------------------------------------------------------------------------
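
Standalone, re_ranking() is called exactly as R1_mAP_reranking does below, with the paper's defaults k1=20, k2=6, lambda=0.3. A shape-level sketch with random features (assuming a PyTorch version contemporary with the positional addmm_ call above):

import sys
import torch

sys.path.append('.')
from utils.re_ranking import re_ranking

qf = torch.randn(5, 256)     # 5 query feature vectors
gf = torch.randn(30, 256)    # 30 gallery feature vectors
dist = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
print(dist.shape)            # (5, 30); smaller values rank earlier
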
/utils/reid_metric.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | from ignite.metrics import Metric
10 |
11 | from data.datasets.eval_reid import eval_func
12 | from .re_ranking import re_ranking
13 |
14 | class pre_selection_index(Metric):
15 | def __init__(self, num_query, max_rank=100, feat_norm='yes'):
16 | super(pre_selection_index, self).__init__()
17 | self.num_query = num_query
18 | self.max_rank = max_rank
19 | self.feat_norm = feat_norm
20 |
21 | def reset(self):
22 | self.feats = []
23 | self.pids = []
24 | self.camids = []
25 |
26 | def update(self, output):
27 | feat, pid, camid = output
28 | self.feats.append(feat)
29 | self.pids.extend(np.asarray(pid))
30 | self.camids.extend(np.asarray(camid))
31 |
32 | def compute(self):
33 | feats = torch.cat(self.feats, dim=0)
34 | if self.feat_norm == 'yes':
35 | # print("The test feature is normalized")
36 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
37 | # query
38 | qf = feats[:self.num_query]
39 | # gallery
40 | gf = feats[self.num_query:]
41 | m, n = qf.shape[0], gf.shape[0]
42 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
43 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
44 | distmat.addmm_(1, -2, qf, gf.t())
45 | distmat = distmat.cpu().numpy()
46 | return np.argsort(distmat, axis=1)
47 |
48 |
49 | class R1_mAP(Metric):
50 | def __init__(self, num_query, max_rank=100, feat_norm='yes'):
51 | super(R1_mAP, self).__init__()
52 | self.num_query = num_query
53 | self.max_rank = max_rank
54 | self.feat_norm = feat_norm
55 |
56 | def reset(self):
57 | self.feats = []
58 | self.pids = []
59 | self.camids = []
60 |
61 | def update(self, output):
62 | feat, pid, camid = output
63 | self.feats.append(feat)
64 | self.pids.extend(np.asarray(pid))
65 | self.camids.extend(np.asarray(camid))
66 |
67 | def compute(self):
68 | feats = torch.cat(self.feats, dim=0)
69 | if self.feat_norm == 'yes':
70 | # print("The test feature is normalized")
71 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
72 | # query
73 | qf = feats[:self.num_query]
74 | q_pids = np.asarray(self.pids[:self.num_query])
75 | q_camids = np.asarray(self.camids[:self.num_query])
76 | # gallery
77 | gf = feats[self.num_query:]
78 | g_pids = np.asarray(self.pids[self.num_query:])
79 | g_camids = np.asarray(self.camids[self.num_query:])
80 | m, n = qf.shape[0], gf.shape[0]
81 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
82 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
83 | distmat.addmm_(1, -2, qf, gf.t())
84 | distmat = distmat.cpu().numpy()
85 |
86 | cmc, mAP, q_pid_return = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
87 |
88 | return cmc, mAP, q_pid_return
89 |
90 | class R1_mAP_pair(Metric):
91 | def __init__(self, num_query, max_rank=100, feat_norm='yes'):
92 | super(R1_mAP_pair, self).__init__()
93 | self.num_query = num_query
94 | self.max_rank = max_rank
95 | self.feat_norm = feat_norm
96 |
97 | def reset(self):
98 | self.scores = []
99 | self.pids = []
100 | self.camids = []
101 |
102 | def update(self, output):
103 | score, pid, camid = output
104 | self.scores.append(score)
105 | self.pids.extend(np.asarray(pid))
106 | self.camids.extend(np.asarray(camid))
107 |
108 | def compute(self):
109 | scores = torch.cat(self.scores, dim=0).view(1, -1)
110 | distmat = scores.cpu().numpy()
111 | # print(distmat.shape)
112 |
113 | if distmat.shape[1] == 101:
114 | distmat = distmat[:, 1:]
115 | # query
116 | q_pids = np.asarray(self.pids[:self.num_query])
117 | q_camids = np.asarray(self.camids[:self.num_query])
118 | # gallery
119 | g_pids = np.asarray(self.pids[self.num_query:])
120 | g_camids = np.asarray(self.camids[self.num_query:])
121 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
122 |
123 | return cmc, mAP
124 |
125 |
126 | class R1_mAP_reranking(Metric):
127 | def __init__(self, num_query, max_rank=100, feat_norm='yes'):
128 | super(R1_mAP_reranking, self).__init__()
129 | self.num_query = num_query
130 | self.max_rank = max_rank
131 | self.feat_norm = feat_norm
132 |
133 | def reset(self):
134 | self.feats = []
135 | self.pids = []
136 | self.camids = []
137 |
138 | def update(self, output):
139 | feat, pid, camid = output
140 | self.feats.append(feat)
141 | self.pids.extend(np.asarray(pid))
142 | self.camids.extend(np.asarray(camid))
143 |
144 | def compute(self):
145 | feats = torch.cat(self.feats, dim=0)
146 | if self.feat_norm == 'yes':
147 | # print("The test feature is normalized")
148 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
149 |
150 | # query
151 | qf = feats[:self.num_query]
152 | q_pids = np.asarray(self.pids[:self.num_query])
153 | q_camids = np.asarray(self.camids[:self.num_query])
154 | # gallery
155 | gf = feats[self.num_query:]
156 | g_pids = np.asarray(self.pids[self.num_query:])
157 | g_camids = np.asarray(self.camids[self.num_query:])
158 | # m, n = qf.shape[0], gf.shape[0]
159 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
160 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
161 | # distmat.addmm_(1, -2, qf, gf.t())
162 | # distmat = distmat.cpu().numpy()
163 | print("Enter reranking")
164 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
165 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
166 |
167 | return cmc, mAP
--------------------------------------------------------------------------------
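
All four metrics share the same ignite contract: reset() at the start of a run, update() once per batch with the evaluator's (feat, pid, camid) output, compute() at the end. Wired to a bare engine it looks roughly like this (a sketch; the identity process function stands in for real model inference):

import sys
from ignite.engine import Engine

sys.path.append('.')
from utils.reid_metric import R1_mAP

metric = R1_mAP(num_query=10, max_rank=100, feat_norm='yes')
evaluator = Engine(lambda engine, batch: batch)   # batch is already (feat, pid, camid)
metric.attach(evaluator, 'r1_mAP')
# evaluator.run(val_loader)                       # val_loader yields (feat, pid, camid)
# cmc, mAP, _ = evaluator.state.metrics['r1_mAP']
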