66 |
67 | Testing data (Occluded-REID): query (left, 100% occluded), gallery (right, ~ 100% non-occluded)
68 |
69 |

70 |

71 |
72 |
73 | ## Contact
74 |
75 | Please contact Zi Wang (email address: [ziwang1121@foxmail.com](mailto:ziwang1121@foxmail.com)). Feel free to drop me an email if you have any questions.
76 |
77 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .defaults import _C as cfg
8 | from .defaults import _C as cfg_test
9 |
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/config/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/config/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/config/__pycache__/defaults.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/config/__pycache__/defaults.cpython-36.pyc
--------------------------------------------------------------------------------
/config/__pycache__/defaults.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/config/__pycache__/defaults.cpython-37.pyc
--------------------------------------------------------------------------------
/configs/DukeMTMC/deit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('2')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.8 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('dukemtmc')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/dukemtmc_deit_transreid/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/dukemtmc_deit_transreid'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/DukeMTMC/vit_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('6')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('dukemtmc')
24 | ROOT_DIR: ('/data2/zi.wang/')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 64
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 120
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 120
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: '../logs/duke_vit_base'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/DukeMTMC/vit_jpm.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('1')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | JPM: True
13 | RE_ARRANGE: True
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 | PIXEL_MEAN: [0.5, 0.5, 0.5]
22 | PIXEL_STD: [0.5, 0.5, 0.5]
23 |
24 | DATASETS:
25 | NAMES: ('dukemtmc')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.008
37 | IMS_PER_BATCH: 64
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: '../logs/duke_vit_jpm/transformer_120.pth'
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/duke_vit_jpm'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/DukeMTMC/vit_sie.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('2')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 | PIXEL_MEAN: [0.5, 0.5, 0.5]
22 | PIXEL_STD: [0.5, 0.5, 0.5]
23 |
24 | DATASETS:
25 | NAMES: ('dukemtmc')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.008
37 | IMS_PER_BATCH: 64
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: '../logs/duke_vit_sie/transformer_120.pth'
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/duke_vit_sie'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/DukeMTMC/vit_transreid.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('dukemtmc')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/duke_vit_transreid/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/duke_vit_transreid'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/DukeMTMC/vit_transreid_384.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [384, 128]
19 | SIZE_TEST: [384, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('dukemtmc')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/duke_vit_transreid_384/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/duke_vit_transreid_384'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/DukeMTMC/vit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('4')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('dukemtmc')
28 | ROOT_DIR: ('/data2/zi.wang/')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 300
38 | BASE_LR: 0.01
39 | IMS_PER_BATCH: 32
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 10000
43 | LOG_PERIOD: 100
44 | EVAL_PERIOD: 10
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: './logs_duke/lr001_b32_Process1_Model12_loss1_AO'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/DukeMTMC/vit_transreid_stride_384.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [384, 128]
19 | SIZE_TEST: [384, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('dukemtmc')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/duke_vit_transreid_stride_384/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/duke_vit_transreid_stride_384'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/MSMT17/deit_small.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_small_distilled_patch16_224-649709d9.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('1')
10 | TRANSFORMER_TYPE: 'deit_small_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.8 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('msmt17')
24 | ROOT_DIR: ('../../data')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.005
35 | IMS_PER_BATCH: 64
36 | LARGE_FC_LR: False
37 | CHECKPOINT_PERIOD: 120
38 | LOG_PERIOD: 50
39 | EVAL_PERIOD: 120
40 | WEIGHT_DECAY: 1e-4
41 | WEIGHT_DECAY_BIAS: 1e-4
42 | BIAS_LR_FACTOR: 2
43 |
44 | TEST:
45 | EVAL: True
46 | IMS_PER_BATCH: 256
47 | RE_RANKING: False
48 | WEIGHT: ''
49 | NECK_FEAT: 'before'
50 | FEAT_NORM: 'yes'
51 |
52 | OUTPUT_DIR: '../logs/msmt17_deit_small_try'
53 |
54 |
55 |
--------------------------------------------------------------------------------
/configs/MSMT17/deit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.8 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('msmt17')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.005
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/msmt17_deit_transreid'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/MSMT17/vit_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('4')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('msmt17')
24 | ROOT_DIR: ('../../data')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 64
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 120
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 120
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: '../logs/msmt17_vit_base'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/MSMT17/vit_jpm.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('1')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | JPM: True
13 | RE_ARRANGE: True
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 | PIXEL_MEAN: [0.5, 0.5, 0.5]
22 | PIXEL_STD: [0.5, 0.5, 0.5]
23 |
24 | DATASETS:
25 | NAMES: ('msmt17')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.008
37 | IMS_PER_BATCH: 64
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: ''
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/msmt17_vit_jpm'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/MSMT17/vit_sie.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('2')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 | PIXEL_MEAN: [0.5, 0.5, 0.5]
22 | PIXEL_STD: [0.5, 0.5, 0.5]
23 |
24 | DATASETS:
25 | NAMES: ('msmt17')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.008
37 | IMS_PER_BATCH: 64
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: ''
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/msmt17_vit_sie'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/MSMT17/vit_small.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/vit_small_p16_224-15ec54c9.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.8 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('msmt17')
24 | ROOT_DIR: ('../../data')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.005
35 | IMS_PER_BATCH: 64
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 120
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 120
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: '../logs/msmt17_vit_small'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/MSMT17/vit_transreid.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('msmt17')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/MSMT17/vit_transreid_384.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [384, 128]
19 | SIZE_TEST: [384, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('msmt17')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid_384'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/MSMT17/vit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('6')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('msmt17')
28 | ROOT_DIR: ('/data/zi.wang/')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 32
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 1000
43 | LOG_PERIOD: 1000
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: './logs_msmt17/lr0008_b32_Process1_Model1_loss1'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/MSMT17/vit_transreid_stride_384.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [384, 128]
19 | SIZE_TEST: [384, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('msmt17')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid_stride_384'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/Market/deit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('4')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.8 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('market1501')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/0321_market_deit_transreie/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/0321_market_deit_transreie'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/Market/vit_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('7')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('market1501')
24 | ROOT_DIR: ('../../data')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 64
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 120
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 120
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: '../logs/0321_market_vit_base/transformer_120.pth'
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: '../logs/0321_market_vit_base'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/Market/vit_jpm.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('1')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | JPM: True
13 | RE_ARRANGE: True
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 | PIXEL_MEAN: [0.5, 0.5, 0.5]
22 | PIXEL_STD: [0.5, 0.5, 0.5]
23 |
24 | DATASETS:
25 | NAMES: ('market1501')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.008
37 | IMS_PER_BATCH: 64
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: '../logs/0321_market_vit_jpm/transformer_120.pth'
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/0321_market_vit_jpm'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/Market/vit_sie.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('7')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 | PIXEL_MEAN: [0.5, 0.5, 0.5]
22 | PIXEL_STD: [0.5, 0.5, 0.5]
23 |
24 | DATASETS:
25 | NAMES: ('market1501')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.008
37 | IMS_PER_BATCH: 64
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: ''
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/market_vit_sie'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/Market/vit_transreid.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('market1501')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/market_vit_transreid/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/market_vit_transreid'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/Market/vit_transreid_384.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [384, 128]
19 | SIZE_TEST: [384, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('market1501')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/market_vit_transreid_384/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/market_vit_transreid_384'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/Market/vit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('market1501')
28 | ROOT_DIR: ('/data2/zi.wang/')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 300
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 32
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 10000
43 | LOG_PERIOD: 100
44 | EVAL_PERIOD: 5
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: './logs_market/lr0008_b32_Process1_Model12_loss1_AO'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/Market/vit_transreid_stride_384.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [384, 128]
19 | SIZE_TEST: [384, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('market1501')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/0321_market_vit_transreid_stride_384'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/OCC_Duke/deit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('2')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [11, 11]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.8 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('occ_duke')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/occ_duke_deit_transreid_stride11'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/OCC_Duke/osnet.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/osnet_x0_5_imagenet.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'osnet'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('occ_duke')
24 | ROOT_DIR: ('/data2/zi.wang/')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 170
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 32
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 1000
39 | LOG_PERIOD: 200
40 | EVAL_PERIOD: 5
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: './logs_occ_duke/osnet/lr0008_b32_Process1_Model1_loss1'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/OCC_Duke/resnet.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'resnet50'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('occ_duke')
24 | ROOT_DIR: ('/data2/zi.wang/')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 170
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 32
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 1000
39 | LOG_PERIOD: 200
40 | EVAL_PERIOD: 5
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: './logs_occ_duke/resnet/lr0008_b32_Process1_Model1_loss1'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/OCC_Duke/vit_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('occ_duke')
24 | ROOT_DIR: ('/data2/zi.wang/')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 32
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 1000
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 5
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: './logs_occ_duke/vit_base/lr0008_b32_Process1_Model1_loss4'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/OCC_Duke/vit_jpm.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('1')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | JPM: True
13 | RE_ARRANGE: True
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 | PIXEL_MEAN: [0.5, 0.5, 0.5]
22 | PIXEL_STD: [0.5, 0.5, 0.5]
23 |
24 | DATASETS:
25 | NAMES: ('occ_duke')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.008
37 | IMS_PER_BATCH: 64
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: ''
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/occ_duke_vit_jpm'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/OCC_Duke/vit_sie.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('2')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 |
15 | INPUT:
16 | SIZE_TRAIN: [256, 128]
17 | SIZE_TEST: [256, 128]
18 | PROB: 0.5 # random horizontal flip
19 | RE_PROB: 0.5 # random erasing
20 | PADDING: 10
21 | PIXEL_MEAN: [0.5, 0.5, 0.5]
22 | PIXEL_STD: [0.5, 0.5, 0.5]
23 |
24 | DATASETS:
25 | NAMES: ('occ_duke')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.008
37 | IMS_PER_BATCH: 64
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: ''
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/occ_duke_vit_sie'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/OCC_Duke/vit_transreid.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('occ_duke')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/occ_duke_vit_transreid'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/OCC_Duke/vit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [11, 11]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('occ_duke')
28 | ROOT_DIR: ('/data2/zi.wang/')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 170
38 | BASE_LR: 0.008
39 | IMS_PER_BATCH: 32
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 1000
43 | LOG_PERIOD: 200
44 | EVAL_PERIOD: 5
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 | STEPS: (40, 70)
49 |
50 | TEST:
51 | EVAL: True
52 | IMS_PER_BATCH: 256
53 | RE_RANKING: False
54 | WEIGHT: ''
55 | NECK_FEAT: 'before'
56 | FEAT_NORM: 'yes'
57 |
58 | OUTPUT_DIR: './logs_occ_duke/lr0008_b32'
59 |
60 |
61 |
--------------------------------------------------------------------------------
/configs/OCC_ReID/vit_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('occ_reid')
24 | ROOT_DIR: ('/data/zi.wang/')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.004
35 | IMS_PER_BATCH: 32
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 1000
39 | LOG_PERIOD: 100
40 | EVAL_PERIOD: 1
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: './logs_occ_reid/vit_base_load_lr0004_b32_trainAll'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/OCC_ReID/vit_local.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [11, 11]
12 | JPM: True
13 |
14 | INPUT:
15 | SIZE_TRAIN: [256, 128]
16 | SIZE_TEST: [256, 128]
17 | PROB: 0.5 # random horizontal flip
18 | RE_PROB: 0.5 # random erasing
19 | PADDING: 10
20 | PIXEL_MEAN: [0.5, 0.5, 0.5]
21 | PIXEL_STD: [0.5, 0.5, 0.5]
22 |
23 | DATASETS:
24 | NAMES: ('occ_reid')
25 | ROOT_DIR: ('/data2/zi.wang/')
26 |
27 | DATALOADER:
28 | SAMPLER: 'softmax_triplet'
29 | NUM_INSTANCE: 4
30 | NUM_WORKERS: 8
31 |
32 | SOLVER:
33 | OPTIMIZER_NAME: 'SGD'
34 | MAX_EPOCHS: 120
35 | BASE_LR: 0.008
36 | IMS_PER_BATCH: 32
37 | WARMUP_METHOD: 'linear'
38 | LARGE_FC_LR: False
39 | CHECKPOINT_PERIOD: 120
40 | LOG_PERIOD: 100
41 | EVAL_PERIOD: 1
42 | WEIGHT_DECAY: 1e-4
43 | WEIGHT_DECAY_BIAS: 1e-4
44 | BIAS_LR_FACTOR: 2
45 |
46 | TEST:
47 | EVAL: True
48 | IMS_PER_BATCH: 256
49 | RE_RANKING: False
50 | WEIGHT: ''
51 | NECK_FEAT: 'before'
52 | FEAT_NORM: 'yes'
53 |
54 | OUTPUT_DIR: './logs_occ_reid/vit_local_lr0008_b32_trainAll'
55 |
56 |
57 |
--------------------------------------------------------------------------------
/configs/OCC_ReID/vit_small.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/vit_small_p16_224-15ec54c9.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('occ_reid')
24 | ROOT_DIR: ('/data/zi.wang/')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 80
34 | BASE_LR: 0.004
35 | IMS_PER_BATCH: 32
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 120
39 | LOG_PERIOD: 100
40 | EVAL_PERIOD: 1
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: './logs_occ_reid/vit_small_load_lr0004_b32_KL1_ori_and_eraser'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/OCC_ReID/vit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [11, 11]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('occ_reid')
28 | ROOT_DIR: ('/data2/zi.wang/')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 8
38 | BASE_LR: 0.0001
39 | IMS_PER_BATCH: 32
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 1000
43 | LOG_PERIOD: 100
44 | EVAL_PERIOD: 1
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: './logs_occ_reid/lr00001_b32_Process1_Model1_loss1'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/Partial_ReID/vit_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('5')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 128]
15 | SIZE_TEST: [256, 128]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('partial_reid')
24 | ROOT_DIR: ('/data/zi.wang/')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 80
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 32
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 1000
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 1
41 | WEIGHT_DECAY: 0.0001
42 | WEIGHT_DECAY_BIAS: 0.0001
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: './logs_partial_reid/vit_base_load_lr0008_b32_trainAll'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/Partial_ReID/vit_local.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [11, 11]
12 | JPM: True
13 |
14 | INPUT:
15 | SIZE_TRAIN: [256, 128]
16 | SIZE_TEST: [256, 128]
17 | PROB: 0.5 # random horizontal flip
18 | RE_PROB: 0.5 # random erasing
19 | PADDING: 10
20 | PIXEL_MEAN: [0.5, 0.5, 0.5]
21 | PIXEL_STD: [0.5, 0.5, 0.5]
22 |
23 | DATASETS:
24 | NAMES: ('partial_reid')
25 | ROOT_DIR: ('/data2/zi.wang/')
26 |
27 | DATALOADER:
28 | SAMPLER: 'softmax_triplet'
29 | NUM_INSTANCE: 4
30 | NUM_WORKERS: 8
31 |
32 | SOLVER:
33 | OPTIMIZER_NAME: 'SGD'
34 | MAX_EPOCHS: 80
35 | BASE_LR: 0.008
36 | IMS_PER_BATCH: 32
37 | WARMUP_METHOD: 'linear'
38 | LARGE_FC_LR: False
39 | CHECKPOINT_PERIOD: 120
40 | LOG_PERIOD: 100
41 | EVAL_PERIOD: 1
42 | WEIGHT_DECAY: 1e-4
43 | WEIGHT_DECAY_BIAS: 1e-4
44 | BIAS_LR_FACTOR: 2
45 |
46 | TEST:
47 | EVAL: True
48 | IMS_PER_BATCH: 256
49 | RE_RANKING: False
50 | WEIGHT: ''
51 | NECK_FEAT: 'before'
52 | FEAT_NORM: 'yes'
53 |
54 | OUTPUT_DIR: './logs_partial_reid/vit_local_stride'
55 |
56 |
57 |
--------------------------------------------------------------------------------
/configs/Partial_ReID/vit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('3')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [11, 11]
12 | SIE_CAMERA: True
13 | SIE_COE: 3.0
14 | JPM: True
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 128]
19 | SIZE_TEST: [256, 128]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('partial_reid')
28 | ROOT_DIR: ('/data2/zi.wang/')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 8
38 | BASE_LR: 0.0001
39 | IMS_PER_BATCH: 32
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 1000
43 | LOG_PERIOD: 100
44 | EVAL_PERIOD: 1
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: './logs_partial_reid/lr00001_b32_Process1_Model12_loss1'
58 |
59 |
--------------------------------------------------------------------------------
/configs/VeRi/deit_transreid.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('4')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_VIEW: True
14 | SIE_COE: 3.0
15 | JPM: True
16 | SHIFT_NUM: 8
17 | RE_ARRANGE: True
18 |
19 | INPUT:
20 | SIZE_TRAIN: [256, 256]
21 | SIZE_TEST: [256, 256]
22 | PROB: 0.5 # random horizontal flip
23 | RE_PROB: 0.8 # random erasing
24 | PADDING: 10
25 |
26 | DATASETS:
27 | NAMES: ('veri')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.01
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/veri_deit_transreid'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/VeRi/deit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_VIEW: True
14 | SIE_COE: 3.0
15 | JPM: True
16 | SHIFT_NUM: 8
17 | RE_ARRANGE: True
18 |
19 | INPUT:
20 | SIZE_TRAIN: [256, 256]
21 | SIZE_TEST: [256, 256]
22 | PROB: 0.5 # random horizontal flip
23 | RE_PROB: 0.8 # random erasing
24 | PADDING: 10
25 |
26 | DATASETS:
27 | NAMES: ('veri')
28 | ROOT_DIR: ('../../datasets')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.01
39 | IMS_PER_BATCH: 64
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: ''
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/veri_deit_transreid_stride'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/VeRi/vit_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('4')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 256]
15 | SIZE_TEST: [256, 256]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('veri')
24 | ROOT_DIR: ('../../data')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 64
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 120
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 120
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: '../logs/veri_vit_base'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/VeRi/vit_transreid.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('4')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | SIE_CAMERA: True
13 | SIE_VIEW: True
14 | SIE_COE: 3.0
15 | JPM: True
16 | SHIFT_NUM: 8
17 | RE_ARRANGE: True
18 |
19 | INPUT:
20 | SIZE_TRAIN: [256, 256]
21 | SIZE_TEST: [256, 256]
22 | PROB: 0.5 # random horizontal flip
23 | RE_PROB: 0.5 # random erasing
24 | PADDING: 10
25 | PIXEL_MEAN: [0.5, 0.5, 0.5]
26 | PIXEL_STD: [0.5, 0.5, 0.5]
27 |
28 | DATASETS:
29 | NAMES: ('veri')
30 | ROOT_DIR: ('../../data')
31 |
32 | DATALOADER:
33 | SAMPLER: 'softmax_triplet'
34 | NUM_INSTANCE: 4
35 | NUM_WORKERS: 8
36 |
37 | SOLVER:
38 | OPTIMIZER_NAME: 'SGD'
39 | MAX_EPOCHS: 120
40 | BASE_LR: 0.01
41 | IMS_PER_BATCH: 64
42 | WARMUP_METHOD: 'linear'
43 | LARGE_FC_LR: False
44 | CHECKPOINT_PERIOD: 120
45 | LOG_PERIOD: 50
46 | EVAL_PERIOD: 120
47 | WEIGHT_DECAY: 1e-4
48 | WEIGHT_DECAY_BIAS: 1e-4
49 | BIAS_LR_FACTOR: 2
50 |
51 | TEST:
52 | EVAL: True
53 | IMS_PER_BATCH: 256
54 | RE_RANKING: False
55 | WEIGHT: ''
56 | NECK_FEAT: 'before'
57 | FEAT_NORM: 'yes'
58 |
59 | OUTPUT_DIR: '../logs/veri_vit_transreid'
60 |
61 |
62 |
--------------------------------------------------------------------------------
/configs/VeRi/vit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('2')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | SIE_CAMERA: True
13 | SIE_VIEW: True
14 | SIE_COE: 3.0
15 | JPM: True
16 | SHIFT_NUM: 8
17 | RE_ARRANGE: True
18 |
19 | INPUT:
20 | SIZE_TRAIN: [256, 128]
21 | SIZE_TEST: [256, 128]
22 | PROB: 0.5 # random horizontal flip
23 | RE_PROB: 0.5 # random erasing
24 | PADDING: 10
25 | PIXEL_MEAN: [0.5, 0.5, 0.5]
26 | PIXEL_STD: [0.5, 0.5, 0.5]
27 |
28 | DATASETS:
29 | NAMES: ('veri')
30 | ROOT_DIR: ('/data/zi.wang/')
31 |
32 | DATALOADER:
33 | SAMPLER: 'softmax_triplet'
34 | NUM_INSTANCE: 4
35 | NUM_WORKERS: 8
36 |
37 | SOLVER:
38 | OPTIMIZER_NAME: 'SGD'
39 | MAX_EPOCHS: 150
40 | BASE_LR: 0.01
41 | IMS_PER_BATCH: 32
42 | WARMUP_METHOD: 'linear'
43 | LARGE_FC_LR: False
44 | CHECKPOINT_PERIOD: 1000
45 | LOG_PERIOD: 100
46 | EVAL_PERIOD: 10
47 | WEIGHT_DECAY: 1e-4
48 | WEIGHT_DECAY_BIAS: 1e-4
49 | BIAS_LR_FACTOR: 2
50 |
51 | TEST:
52 | EVAL: True
53 | IMS_PER_BATCH: 256
54 | RE_RANKING: False
55 | WEIGHT: ''
56 | NECK_FEAT: 'before'
57 | FEAT_NORM: 'yes'
58 |
59 | OUTPUT_DIR: './logs_veri/lr001_b32_Process1_Model1_loss1'
60 |
61 |
62 |
--------------------------------------------------------------------------------
/configs/VehicleID/deit_transreid.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | # DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | DIST_TRAIN: True
13 | JPM: True
14 | SHIFT_NUM: 8
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 256]
19 | SIZE_TEST: [256, 256]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.8 # random erasing
22 | PADDING: 10
23 |
24 | DATASETS:
25 | NAMES: ('VehicleID')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.03
37 | IMS_PER_BATCH: 256
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: ''
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/vehicleID_deit_transreid'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/VehicleID/deit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | # DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | DIST_TRAIN: True
13 | JPM: True
14 | SHIFT_NUM: 8
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 256]
19 | SIZE_TEST: [256, 256]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.8 # random erasing
22 | PADDING: 10
23 |
24 | DATASETS:
25 | NAMES: ('VehicleID')
26 | ROOT_DIR: ('../../data')
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'SGD'
35 | MAX_EPOCHS: 120
36 | BASE_LR: 0.03
37 | IMS_PER_BATCH: 256
38 | WARMUP_METHOD: 'linear'
39 | LARGE_FC_LR: False
40 | CHECKPOINT_PERIOD: 120
41 | LOG_PERIOD: 50
42 | EVAL_PERIOD: 120
43 | WEIGHT_DECAY: 1e-4
44 | WEIGHT_DECAY_BIAS: 1e-4
45 | BIAS_LR_FACTOR: 2
46 |
47 | TEST:
48 | EVAL: True
49 | IMS_PER_BATCH: 256
50 | RE_RANKING: False
51 | WEIGHT: ''
52 | NECK_FEAT: 'before'
53 | FEAT_NORM: 'yes'
54 |
55 | OUTPUT_DIR: '../logs/vehicleID_deit_transreid_stride'
56 |
57 |
58 |
--------------------------------------------------------------------------------
/configs/VehicleID/vit_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | # DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 256]
15 | SIZE_TEST: [256, 256]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('VehicleID')
24 | ROOT_DIR: ('../../data')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.04
35 | IMS_PER_BATCH: 224
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 120
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 120
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: '../logs/vehicleID_vit_base/transformer_120.pth'
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: '../logs/vehicleID_vit_base'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/VehicleID/vit_transreid.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | # DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 | # DIST_TRAIN: True
13 | JPM: True
14 | SHIFT_NUM: 8
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 256]
19 | SIZE_TEST: [256, 256]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('VehicleID')
28 | ROOT_DIR: ('../../data')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.045
39 | IMS_PER_BATCH: 224
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/vehicleID_vit_transreid/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: '../logs/vehicleID_vit_transreid'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/VehicleID/vit_transreid_stride.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/ziwang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | # DEVICE_ID: ('0')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [12, 12]
12 | # DIST_TRAIN: True
13 | JPM: True
14 | SHIFT_NUM: 8
15 | RE_ARRANGE: True
16 |
17 | INPUT:
18 | SIZE_TRAIN: [256, 256]
19 | SIZE_TEST: [256, 256]
20 | PROB: 0.5 # random horizontal flip
21 | RE_PROB: 0.5 # random erasing
22 | PADDING: 10
23 | PIXEL_MEAN: [0.5, 0.5, 0.5]
24 | PIXEL_STD: [0.5, 0.5, 0.5]
25 |
26 | DATASETS:
27 | NAMES: ('VehicleID')
28 | ROOT_DIR: ('/data/zi.wang/')
29 |
30 | DATALOADER:
31 | SAMPLER: 'softmax_triplet'
32 | NUM_INSTANCE: 4
33 | NUM_WORKERS: 8
34 |
35 | SOLVER:
36 | OPTIMIZER_NAME: 'SGD'
37 | MAX_EPOCHS: 120
38 | BASE_LR: 0.045
39 | IMS_PER_BATCH: 256
40 | WARMUP_METHOD: 'linear'
41 | LARGE_FC_LR: False
42 | CHECKPOINT_PERIOD: 120
43 | LOG_PERIOD: 50
44 | EVAL_PERIOD: 120
45 | WEIGHT_DECAY: 1e-4
46 | WEIGHT_DECAY_BIAS: 1e-4
47 | BIAS_LR_FACTOR: 2
48 |
49 | TEST:
50 | EVAL: True
51 | IMS_PER_BATCH: 256
52 | RE_RANKING: False
53 | WEIGHT: '../logs/vehicleID_vit_transreid_stride/transformer_120.pth'
54 | NECK_FEAT: 'before'
55 | FEAT_NORM: 'yes'
56 |
57 | OUTPUT_DIR: './logs_vehicleID/vit_transreid_stride'
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/transformer_base.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
4 | METRIC_LOSS_TYPE: 'triplet'
5 | IF_LABELSMOOTH: 'off'
6 | IF_WITH_CENTER: 'no'
7 | NAME: 'transformer'
8 | NO_MARGIN: True
9 | DEVICE_ID: ('7')
10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID'
11 | STRIDE_SIZE: [16, 16]
12 |
13 | INPUT:
14 | SIZE_TRAIN: [256, 256]
15 | SIZE_TEST: [256, 256]
16 | PROB: 0.5 # random horizontal flip
17 | RE_PROB: 0.5 # random erasing
18 | PADDING: 10
19 | PIXEL_MEAN: [0.5, 0.5, 0.5]
20 | PIXEL_STD: [0.5, 0.5, 0.5]
21 |
22 | DATASETS:
23 | NAMES: ('dukemtmc')
24 | ROOT_DIR: ('../../data')
25 |
26 | DATALOADER:
27 | SAMPLER: 'softmax_triplet'
28 | NUM_INSTANCE: 4
29 | NUM_WORKERS: 8
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'SGD'
33 | MAX_EPOCHS: 120
34 | BASE_LR: 0.008
35 | IMS_PER_BATCH: 64
36 | WARMUP_METHOD: 'linear'
37 | LARGE_FC_LR: False
38 | CHECKPOINT_PERIOD: 120
39 | LOG_PERIOD: 50
40 | EVAL_PERIOD: 120
41 | WEIGHT_DECAY: 1e-4
42 | WEIGHT_DECAY_BIAS: 1e-4
43 | BIAS_LR_FACTOR: 2
44 |
45 | TEST:
46 | EVAL: True
47 | IMS_PER_BATCH: 256
48 | RE_RANKING: False
49 | WEIGHT: ''
50 | NECK_FEAT: 'before'
51 | FEAT_NORM: 'yes'
52 |
53 | OUTPUT_DIR: '../logs/'
54 |
55 |
56 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_dataloader import make_dataloader
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/bases.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/bases.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/dukemtmcreid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/dukemtmcreid.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/make_dataloader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/make_dataloader.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/make_dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/make_dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/market1501.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/market1501.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/msmt17.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/msmt17.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/occ_duke.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/occ_duke.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/occ_reid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/occ_reid.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/partial_reid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/partial_reid.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/sampler_ddp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/sampler_ddp.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/vehicleid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/vehicleid.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/veri.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/veri.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/bases.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageFile
2 |
3 | from torch.utils.data import Dataset
4 | import os.path as osp
5 | import random
6 | import torch
7 | ImageFile.LOAD_TRUNCATED_IMAGES = True
8 |
9 |
def read_image(img_path):
    """Open ``img_path`` as an RGB PIL image, retrying on transient IOError.

    Retries indefinitely so a heavy-IO hiccup does not abort training.
    Raises IOError immediately if the path does not exist at all.
    """
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while True:
        try:
            return Image.open(img_path).convert('RGB')
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
24 |
25 |
class BaseDataset(object):
    """Base class of reid dataset."""

    def get_imagedata_info(self, data):
        """Summarize a split.

        ``data`` is a list of (img_path, pid, camid, trackid) tuples.
        Returns (num_pids, num_imgs, num_cams, num_views), where views
        are counted from the distinct track ids.
        """
        pid_set = {pid for _, pid, _, _ in data}
        cam_set = {camid for _, _, camid, _ in data}
        track_set = {trackid for _, _, _, trackid in data}
        return len(pid_set), len(data), len(cam_set), len(track_set)

    def print_dataset_statistics(self):
        raise NotImplementedError
49 |
50 |
class BaseImageDataset(BaseDataset):
    """Base class of image reid dataset."""

    def print_dataset_statistics(self, train, query, gallery):
        """Print an ids/images/cameras table for the three splits.

        Each split is a list of (img_path, pid, camid, trackid) tuples.
        """
        # Fix: the original reassigned `num_train_views` on all three calls
        # (copy-paste shadowing). View counts are never printed here, so
        # discard them explicitly with `_`.
        num_train_pids, num_train_imgs, num_train_cams, _ = self.get_imagedata_info(train)
        num_query_pids, num_query_imgs, num_query_cams, _ = self.get_imagedata_info(query)
        num_gallery_pids, num_gallery_imgs, num_gallery_cams, _ = self.get_imagedata_info(gallery)

        print("Dataset statistics:")
        print("  ----------------------------------------")
        print("  subset   | # ids | # images | # cameras")
        print("  ----------------------------------------")
        print("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
        print("  query    | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
        print("  gallery  | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
        print("  ----------------------------------------")
69 |
70 |
class ImageDataset(Dataset):
    """Torch Dataset yielding three views (full / cropped / erased) per image."""

    def __init__(self, dataset, transform=None, crop_transform=None, eraser_transform=None):
        # dataset: list of (img_path, pid, camid, trackid) tuples
        self.dataset = dataset
        self.transform = transform
        self.crop_transform = crop_transform
        self.eraser_transform = eraser_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return (img1, img2, img3, pid, camid, trackid, file_name).

        img2/img3 are the crop/eraser views when both transforms are set;
        otherwise all three are the same image.
        """
        img_path, pid, camid, trackid = self.dataset[index]
        img = read_image(img_path)

        # Fix: the original only assigned `img1` inside `if self.transform is
        # not None:`, so a None transform raised NameError on the next line.
        # Fall back to the raw PIL image instead.
        img1 = img if self.transform is None else self.transform(img)
        if self.crop_transform is not None and self.eraser_transform is not None:
            # Crop/eraser views are built from the untransformed image,
            # matching the original behavior.
            img2 = self.crop_transform(img)
            img3 = self.eraser_transform(img)
        else:
            img2 = img3 = img1
        return img1, img2, img3, pid, camid, trackid, img_path.split('/')[-1]
--------------------------------------------------------------------------------
/datasets/dukemtmcreid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import glob
8 | import re
9 | import urllib
10 | import zipfile
11 |
12 | import os.path as osp
13 |
14 | from utils.iotools import mkdir_if_missing
15 | from .bases import BaseImageDataset
16 |
17 |
class DukeMTMCreID(BaseImageDataset):
    """
    DukeMTMC-reID
    Reference:
    1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
    2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images:16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8
    """
    # Subdirectory name under `root`; joined into an absolute path in __init__.
    dataset_dir = 'DukeMTMC-reID'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        """Download (if absent), verify and index the dataset under ``root``.

        root: parent directory containing (or to contain) ``dataset_dir``.
        verbose: print the statistics table after loading.
        pid_begin: offset added to every person id (useful when
            concatenating multiple datasets into one label space).
        """
        super(DukeMTMCreID, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
        self.pid_begin = pid_begin
        self._download_data()
        self._check_before_run()

        # Only the train split is relabeled to a compact 0..N-1 id range;
        # query/gallery keep the original ids so they can be matched.
        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> DukeMTMC-reID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        # Per-split summary counts (pids, images, cameras, views).
        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _download_data(self):
        """Fetch and extract the dataset zip if ``dataset_dir`` is missing."""
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        print("Creating directory {}".format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        print("Downloading DukeMTMC-reID dataset")
        urllib.request.urlretrieve(self.dataset_url, fpath)

        print("Extracting files")
        zip_ref = zipfile.ZipFile(fpath, 'r')
        zip_ref.extractall(self.dataset_dir)
        zip_ref.close()

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Index every jpg in ``dir_path`` as (path, pid, camid, trackid).

        File names look like '0001_c2_f0046182.jpg'; the regex pulls out
        the person id and camera id. With relabel=True, pids are remapped
        to 0..N-1. trackid is fixed to 1 (no tracklet info in this dataset).
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # First pass: collect the distinct pids to build the relabel map.
        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        cam_container = set()
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1 # index starts from 0
            if relabel: pid = pid2label[pid]
            dataset.append((img_path, self.pid_begin + pid, camid, 1))
            cam_container.add(camid)
        print(cam_container, 'cam_container')
        return dataset
109 |
--------------------------------------------------------------------------------
/datasets/make_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as T
3 | from torch.utils.data import DataLoader
4 |
5 | from .bases import ImageDataset
6 | from timm.data.random_erasing import RandomErasing
7 | from .sampler import RandomIdentitySampler
8 | from .dukemtmcreid import DukeMTMCreID
9 | from .market1501 import Market1501
10 | from .msmt17 import MSMT17
11 | from .sampler_ddp import RandomIdentitySampler_DDP
12 | import torch.distributed as dist
13 | from .occ_duke import OCC_DukeMTMCreID
14 | from .vehicleid import VehicleID
15 | from .veri import VeRi
16 | # from .occ_reid import Occ_ReID
17 | # from .partial_reid import Partial_REID
18 |
# Maps cfg.DATASETS.NAMES values to their dataset classes; keys must match
# the names used in the YAML config files exactly (note 'VehicleID' keeps
# its mixed case). Commented entries correspond to the commented imports.
__factory = {
    'market1501': Market1501,
    'dukemtmc': DukeMTMCreID,
    'msmt17': MSMT17,
    'occ_duke': OCC_DukeMTMCreID,
    'veri': VeRi,
    'VehicleID': VehicleID,
    # 'partial_reid': Partial_REID,
    # 'occ_reid': Occ_ReID
}
29 |
def train_collate_fn(batch):
    """Collate training samples into batched tensors.

    The dataloader hands this a list of batch-size samples, each one the
    tuple returned by the dataset's __getitem__:
    (img1, img2, img3, pid, camid, viewid, img_path). The image path is
    dropped for training; the three image views are stacked separately.
    """
    imgs1, imgs2, imgs3, pids, camids, viewids, _ = zip(*batch)
    pids, camids, viewids = (
        torch.tensor(vals, dtype=torch.int64) for vals in (pids, camids, viewids)
    )
    stacked = tuple(torch.stack(group, dim=0) for group in (imgs1, imgs2, imgs3))
    return stacked + (pids, camids, viewids)
39 |
def val_collate_fn(batch):
    """Collate validation samples.

    Returns the three stacked image views, the raw pid and camid tuples
    (consumed by the evaluator), a camid tensor (consumed by the model),
    a viewid tensor, and the image paths.
    """
    imgs1, imgs2, imgs3, pids, camids, viewids, img_paths = zip(*batch)
    camids_batch = torch.tensor(camids, dtype=torch.int64)
    viewids = torch.tensor(viewids, dtype=torch.int64)
    views = [torch.stack(group, dim=0) for group in (imgs1, imgs2, imgs3)]
    return views[0], views[1], views[2], pids, camids, camids_batch, viewids, img_paths
45 |
def make_dataloader(cfg):
    """Build the train/val dataloaders for the dataset named in cfg.DATASETS.NAMES.

    Three training views are produced per sample: a plain resized view, a
    pad+random-crop view, and an always-erased view (occlusion simulation).

    Returns:
        tuple: (train_loader, train_loader_normal, val_loader,
                num_query, num_classes, cam_num, view_num)
    """
    # Base training view: resize + normalize only; flips/crops/erasing are
    # deliberately left to the two augmented views below.
    train_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
        # T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        # T.Pad(cfg.INPUT.PADDING),
        # T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
        # RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'),
        # RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
    ])
    # Second view: pad then random-resized-crop to simulate partial bodies.
    # NOTE(review): crop size (256, 128) is hard-coded rather than taken
    # from cfg.INPUT.SIZE_TRAIN -- confirm this is intended.
    crop_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
        # T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.Pad(30),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
        # T.RandomResizedCrop(size=(256, 128), scale=(0.3, 0.6)),
        T.RandomResizedCrop(size=(256, 128)),
    ])
    # Third view: random erasing with probability 1 simulates occlusion.
    eraser_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
        # T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
        RandomErasing(probability=1, mode='pixel', max_count=1, device='cpu'),
    ])

    val_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TEST),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])

    num_workers = cfg.DATALOADER.NUM_WORKERS

    dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)

    train_set = ImageDataset(dataset.train, train_transforms, crop_transform=crop_transforms, eraser_transform=eraser_transforms)
    train_set_normal = ImageDataset(dataset.train, val_transforms)
    num_classes = dataset.num_train_pids
    cam_num = dataset.num_train_cams
    view_num = dataset.num_train_vids

    if 'triplet' in cfg.DATALOADER.SAMPLER:
        if cfg.MODEL.DIST_TRAIN:
            print('DIST_TRAIN START')
            mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size()
            data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
            batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
            train_loader = torch.utils.data.DataLoader(
                train_set,
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                collate_fn=train_collate_fn,
                pin_memory=True,
            )
        else:
            train_loader = DataLoader(
                train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
                sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
                num_workers=num_workers, collate_fn=train_collate_fn
            )
    elif cfg.DATALOADER.SAMPLER == 'softmax':
        print('using softmax sampler')
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        # BUG FIX: the original read the nonexistent cfg.SAMPLER node (an
        # AttributeError) and then continued with train_loader unbound
        # (a NameError at return). Fail fast with the right config key.
        raise ValueError(
            'unsupported sampler! expected softmax or triplet but got {}'.format(cfg.DATALOADER.SAMPLER))

    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)

    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    train_loader_normal = DataLoader(
        train_set_normal, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    return train_loader, train_loader_normal, val_loader, len(dataset.query), num_classes, cam_num, view_num
131 |
--------------------------------------------------------------------------------
/datasets/market1501.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import glob
8 | import re
9 |
10 | import os.path as osp
11 |
12 | from .bases import BaseImageDataset
13 | from collections import defaultdict
14 | import pickle
class Market1501(BaseImageDataset):
    """Market-1501 person re-identification dataset.

    Reference:
        Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
        identities: 1501 (+1 for background)
        images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    dataset_dir = 'market1501'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(Market1501, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self._check_before_run()
        self.pid_begin = pid_begin
        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> Market1501 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Raise RuntimeError for the first missing dataset directory."""
        for path in (self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir):
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _process_dir(self, dir_path, relabel=False):
        """Collect (img_path, pid, camid, 1) tuples from a Market folder.

        Filenames look like '<pid>_c<camid>_...jpg'; pid == -1 marks junk
        images, which are skipped. When relabel is True, pids are remapped
        to contiguous labels for classifier training.
        """
        pattern = re.compile(r'([-\d]+)_c(\d)')
        img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))

        pid_container = set()
        for path in img_paths:
            pid, _ = map(int, pattern.search(path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for path in img_paths:
            pid, camid = map(int, pattern.search(path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # camera index starts from 0
            label = pid2label[pid] if relabel else pid
            dataset.append((path, self.pid_begin + label, camid, 1))
        return dataset
85 |
--------------------------------------------------------------------------------
/datasets/msmt17.py:
--------------------------------------------------------------------------------
1 |
2 | import glob
3 | import re
4 |
5 | import os.path as osp
6 |
7 | from .bases import BaseImageDataset
8 |
9 |
class MSMT17(BaseImageDataset):
    """MSMT17 person re-identification dataset.

    Reference:
        Wei et al. Person Transfer GAN to Bridge Domain Gap for
        Person Re-Identification. CVPR 2018.
    URL: http://www.pkuvmc.com/publications/msmt17.html

    Dataset statistics:
        identities: 4101
        images: 32621 (train) + 11659 (query) + 82161 (gallery)
        cameras: 15
    """
    dataset_dir = 'MSMT17'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(MSMT17, self).__init__()
        self.pid_begin = pid_begin
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.test_dir = osp.join(self.dataset_dir, 'test')
        # MSMT17 ships fixed split files rather than per-split folders.
        self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt')
        self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt')
        self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt')
        self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt')

        self._check_before_run()
        train = self._process_dir(self.train_dir, self.list_train_path)
        val = self._process_dir(self.train_dir, self.list_val_path)
        train += val  # the val split is merged into training data
        query = self._process_dir(self.test_dir, self.list_query_path)
        gallery = self._process_dir(self.test_dir, self.list_gallery_path)
        if verbose:
            print("=> MSMT17 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Raise RuntimeError for the first missing dataset directory."""
        for path in (self.dataset_dir, self.train_dir, self.test_dir):
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _process_dir(self, dir_path, list_path):
        """Read a split file of '<relative_path> <pid>' lines into tuples.

        The camera id is parsed from the third '_'-separated token of the
        relative path; camids are shifted to start from 0.
        """
        with open(list_path, 'r') as txt:
            lines = txt.readlines()
        dataset = []
        pid_container = set()
        cam_container = set()
        for img_info in lines:
            rel_path, pid = img_info.split(' ')
            pid = int(pid)  # no need to relabel
            camid = int(rel_path.split('_')[2])
            dataset.append((osp.join(dir_path, rel_path), self.pid_begin + pid, camid - 1, 1))
            pid_container.add(pid)
            cam_container.add(camid)
        print(cam_container, 'cam_container')
        # check if pid starts from 0 and increments with 1
        for idx, pid in enumerate(pid_container):
            assert idx == pid, "See code comment for explanation"
        return dataset
--------------------------------------------------------------------------------
/datasets/occ_duke.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import glob
8 | import re
9 | import urllib
10 | import zipfile
11 |
12 | import os.path as osp
13 |
14 | from utils.iotools import mkdir_if_missing
15 | from .bases import BaseImageDataset
16 |
17 |
class OCC_DukeMTMCreID(BaseImageDataset):
    """Occluded-DukeMTMC-reID dataset.

    Reference:
        1. Ristani et al. Performance Measures and a Data Set for Multi-Target,
           Multi-Camera Tracking. ECCVW 2016.
        2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person
           Re-identification Baseline in vitro. ICCV 2017.
    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
        identities: 1404 (train + query)
        images: 16522 (train) + 2228 (query) + 17661 (gallery)
        cameras: 8
    """
    dataset_dir = 'Occluded_Duke'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(OCC_DukeMTMCreID, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
        self.pid_begin = pid_begin
        self._download_data()
        self._check_before_run()

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> DukeMTMC-reID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _download_data(self):
        """Download and extract the dataset zip if it is not already present."""
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        print("Creating directory {}".format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        print("Downloading DukeMTMC-reID dataset")
        # BUG FIX: the file only does 'import urllib', which does NOT make
        # the urllib.request submodule available; import it explicitly here.
        import urllib.request
        urllib.request.urlretrieve(self.dataset_url, fpath)

        print("Extracting files")
        # Context manager closes the archive even if extraction raises.
        with zipfile.ZipFile(fpath, 'r') as zip_ref:
            zip_ref.extractall(self.dataset_dir)

    def _check_before_run(self):
        """Raise RuntimeError for the first missing dataset directory."""
        for path in (self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir):
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _process_dir(self, dir_path, relabel=False):
        """Parse a folder of '<pid>_c<camid>_*.jpg' images into tuples.

        Returns (img_path, self.pid_begin + pid, camid, 1) tuples; the
        trailing 1 is a placeholder view id.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in sorted(img_paths):
            pid, _ = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
        # BUG FIX: sort pids before enumerating so the pid -> label mapping
        # is deterministic across runs (set iteration order is unspecified).
        pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}

        dataset = []
        cam_container = set()
        for img_path in sorted(img_paths):
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1  # camera index starts from 0
            if relabel: pid = pid2label[pid]
            dataset.append((img_path, self.pid_begin + pid, camid, 1))
            cam_container.add(camid)
        print(cam_container, 'cam_container')
        return dataset
109 |
--------------------------------------------------------------------------------
/datasets/occ_reid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import glob
8 | import re
9 | import urllib
10 | import zipfile
11 |
12 | import os.path as osp
13 |
14 | from utils.iotools import mkdir_if_missing
15 | from .bases import BaseImageDataset
16 |
17 |
class Occ_ReID(BaseImageDataset):
    """Occluded-REID evaluation dataset.

    Trains on Market-1501's bounding_box_train and evaluates on the
    Occluded-REID query/gallery split (occluded queries vs. whole-body
    gallery), the standard occluded re-id protocol.
    """

    dataset_dir_train = 'market1501'
    dataset_dir_test = 'OccludedREID'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(Occ_ReID, self).__init__()
        self.dataset_dir_train = osp.join(root, self.dataset_dir_train)
        self.dataset_dir_test = osp.join(root, self.dataset_dir_test)

        self.train_dir = osp.join(self.dataset_dir_train, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir_test, 'query')
        self.gallery_dir = osp.join(self.dataset_dir_test, 'gallery')
        self.pid_begin = pid_begin
        self._check_before_run()

        train = self._process_dir_train(self.train_dir, relabel=True)
        # Distinct synthetic camera ids (1 vs 2) keep query/gallery pairs
        # from being discarded by the evaluator's same-camera rule.
        query = self._process_dir_test(self.query_dir, camera_id=1, relabel=False)
        gallery = self._process_dir_test(self.gallery_dir, camera_id=2, relabel=False)

        if verbose:
            print("=> Occ_ReID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Raise RuntimeError for the first missing dataset directory."""
        for path in (self.dataset_dir_train, self.train_dir, self.query_dir, self.gallery_dir):
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _process_dir_train(self, dir_path, relabel=False):
        """Parse a Market-1501-style training folder (pid == -1 is junk)."""
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in sorted(img_paths):
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        dataset = []
        for img_path in sorted(img_paths):
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]

            dataset.append((img_path, self.pid_begin + pid, camid, 1))
        return dataset

    def _process_dir_test(self, dir_path, camera_id=1, relabel=False):
        """Parse a test folder whose filenames start with '<pid>_'."""
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pid_container = set()
        for img_path in img_paths:
            # BUG FIX: use osp.basename rather than splitting on '/', which
            # breaks on Windows path separators.
            pid = int(osp.basename(img_path).split('_')[0])
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            pid = int(osp.basename(img_path).split('_')[0])
            camid = camera_id - 1  # camera ids are 0-based downstream
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid, 1))
        return data
102 |
--------------------------------------------------------------------------------
/datasets/partial_reid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import glob
8 | import re
9 | import urllib
10 | import zipfile
11 |
12 | import os.path as osp
13 |
14 | from utils.iotools import mkdir_if_missing
15 | from .bases import BaseImageDataset
16 |
17 |
class Partial_REID(BaseImageDataset):
    """Partial-REID evaluation dataset.

    Trains on Market-1501's bounding_box_train and evaluates on the
    Partial-REID split: partial-body images as queries against
    whole-body gallery images.
    """

    dataset_dir_train = 'market1501'
    dataset_dir_test = 'Partial_REID'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(Partial_REID, self).__init__()
        self.dataset_dir_train = osp.join(root, self.dataset_dir_train)
        self.dataset_dir_test = osp.join(root, self.dataset_dir_test)

        self.train_dir = osp.join(self.dataset_dir_train, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir_test, 'partial_body_images')
        self.gallery_dir = osp.join(self.dataset_dir_test, 'whole_body_images')
        self.pid_begin = pid_begin
        self._check_before_run()

        train = self._process_dir_train(self.train_dir, relabel=True)
        # Distinct synthetic camera ids (1 vs 2) keep query/gallery pairs
        # from being discarded by the evaluator's same-camera rule.
        query = self._process_dir_test(self.query_dir, camera_id=1, relabel=False)
        gallery = self._process_dir_test(self.gallery_dir, camera_id=2, relabel=False)

        if verbose:
            print("=> Partial REID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Raise RuntimeError for the first missing dataset directory."""
        for path in (self.dataset_dir_train, self.train_dir, self.query_dir, self.gallery_dir):
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _process_dir_train(self, dir_path, relabel=False):
        """Parse a Market-1501-style training folder (pid == -1 is junk)."""
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in sorted(img_paths):
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        dataset = []
        for img_path in sorted(img_paths):
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]

            dataset.append((img_path, self.pid_begin + pid, camid, 1))
        return dataset

    def _process_dir_test(self, dir_path, camera_id=1, relabel=False):
        """Parse a test folder whose filenames start with '<pid>_'."""
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pid_container = set()
        for img_path in img_paths:
            # BUG FIX: use osp.basename rather than splitting on '/', which
            # breaks on Windows path separators.
            pid = int(osp.basename(img_path).split('_')[0])
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            pid = int(osp.basename(img_path).split('_')[0])
            camid = camera_id - 1  # camera ids are 0-based downstream
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid, 1))
        return data
102 |
--------------------------------------------------------------------------------
/datasets/preprocessing.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 |
4 |
class RandomErasing(object):
    """Randomly erase a rectangular patch of a CHW image tensor in place.

    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf

    Args:
        probability: chance that erasing is applied at all.
        sl: minimum erased area as a fraction of the image area.
        sh: maximum erased area as a fraction of the image area.
        r1: minimum aspect ratio of the erased patch.
        mean: per-channel fill values for the erased region.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        # Skip erasing entirely with probability (1 - self.probability).
        if random.uniform(0, 1) >= self.probability:
            return img

        channels = img.size()[0]
        height = img.size()[1]
        width = img.size()[2]
        area = height * width

        # Sample up to 100 candidate patches until one fits inside the image.
        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            patch_h = int(round(math.sqrt(target_area * aspect_ratio)))
            patch_w = int(round(math.sqrt(target_area / aspect_ratio)))

            if patch_w < width and patch_h < height:
                top = random.randint(0, height - patch_h)
                left = random.randint(0, width - patch_w)
                if channels == 3:
                    img[0, top:top + patch_h, left:left + patch_w] = self.mean[0]
                    img[1, top:top + patch_h, left:left + patch_w] = self.mean[1]
                    img[2, top:top + patch_h, left:left + patch_w] = self.mean[2]
                else:
                    img[0, top:top + patch_h, left:left + patch_w] = self.mean[0]
                return img

        return img
50 |
51 |
--------------------------------------------------------------------------------
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.sampler import Sampler
2 | from collections import defaultdict
3 | import copy
4 | import random
5 | import numpy as np
6 |
class RandomIdentitySampler(Sampler):
    """Sample N identities, then K instances each (batch size = N * K).

    Args:
        data_source (list): list of (img_path, pid, camid, viewid) tuples.
        batch_size (int): number of examples in a batch.
        num_instances (int): number of instances per identity in a batch.
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances
        # pid -> list of dataset indices holding that identity,
        # e.g. {783: [0, 5, 116, 876, 1554, 2041], ...}
        self.index_dic = defaultdict(list)
        for index, (_, pid, _, _) in enumerate(self.data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # Estimated epoch length: for each pid, the largest multiple of
        # num_instances its (possibly oversampled) images can fill.
        self.length = 0
        for pid in self.pids:
            num = max(len(self.index_dic[pid]), self.num_instances)
            self.length += num - num % self.num_instances

    def __iter__(self):
        # pid -> list of shuffled index groups, each of size num_instances.
        batch_idxs_dict = defaultdict(list)

        for pid in self.pids:
            idxs = copy.deepcopy(self.index_dic[pid])
            if len(idxs) < self.num_instances:
                # Oversample with replacement so every pid fills one group.
                idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
            random.shuffle(idxs)
            group = []
            for idx in idxs:
                group.append(idx)
                if len(group) == self.num_instances:
                    batch_idxs_dict[pid].append(group)
                    group = []

        avai_pids = copy.deepcopy(self.pids)
        final_idxs = []

        # Keep drawing num_pids_per_batch identities and emitting one group
        # per identity until too few identities remain for a full batch.
        while len(avai_pids) >= self.num_pids_per_batch:
            for pid in random.sample(avai_pids, self.num_pids_per_batch):
                group = batch_idxs_dict[pid].pop(0)
                final_idxs.extend(group)
                if not batch_idxs_dict[pid]:
                    avai_pids.remove(pid)

        return iter(final_idxs)

    def __len__(self):
        return self.length
67 |
68 |
--------------------------------------------------------------------------------
/datasets/veri.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import re
3 | import os.path as osp
4 |
5 | from .bases import BaseImageDataset
6 |
7 |
class VeRi(BaseImageDataset):
    """
    VeRi-776
    Reference:
    Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016.

    URL:https://vehiclereid.github.io/VeRi/

    Dataset statistics:
    # identities: 776
    # images: 37778 (train) + 1678 (query) + 11579 (gallery)
    # cameras: 20
    """

    dataset_dir = 'VeRi'

    def __init__(self, root='', verbose=True, **kwargs):
        super(VeRi, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.query_dir = osp.join(self.dataset_dir, 'image_query')
        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')

        self._check_before_run()

        # Load viewpoint annotations: each line of the keypoint files is split
        # on spaces; the last field is parsed as an integer view id keyed by
        # the image's basename.
        # NOTE(review): these paths are relative to the current working
        # directory, not the dataset root -- running from another directory
        # will raise FileNotFoundError. Confirm this is intended.
        path_train = 'datasets/keypoint_train.txt'
        with open(path_train, 'r') as txt:
            lines = txt.readlines()
        self.image_map_view_train = {}
        for img_idx, img_info in enumerate(lines):
            content = img_info.split(' ')
            viewid = int(content[-1])
            self.image_map_view_train[osp.basename(content[0])] = viewid

        path_test = 'datasets/keypoint_test.txt'
        with open(path_test, 'r') as txt:
            lines = txt.readlines()
        self.image_map_view_test = {}
        for img_idx, img_info in enumerate(lines):
            content = img_info.split(' ')
            viewid = int(content[-1])
            self.image_map_view_test[osp.basename(content[0])] = viewid

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> VeRi-776 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(
            self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(
            self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(
            self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        # Parse '<pid>_c<camid>_*.jpg' filenames into
        # (img_path, pid, camid, viewid) tuples; viewids come from the
        # keypoint maps loaded in __init__. Images with no viewpoint
        # annotation in either map are silently dropped (counted below).
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d+)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        view_container = set()
        dataset = []
        count = 0  # number of images dropped for missing view annotation
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            assert 0 <= pid <= 776  # pid == 0 means background
            assert 1 <= camid <= 20
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]

            # Look up the view id: train map first, then test map; images
            # absent from both are skipped.
            if osp.basename(img_path) not in self.image_map_view_train.keys():
                try:
                    viewid = self.image_map_view_test[osp.basename(img_path)]
                except:
                    count += 1
                    # print(img_path, 'img_path')
                    continue
            else:
                viewid = self.image_map_view_train[osp.basename(img_path)]
            view_container.add(viewid)
            dataset.append((img_path, pid, camid, viewid))
        print(view_container, 'view_container')
        print(count, 'samples without viewpoint annotations')
        return dataset
118 |
--------------------------------------------------------------------------------
/dist_test.sh:
--------------------------------------------------------------------------------
1 |
2 | python test.py --config_file configs/DukeMTMC/vit_transreid_stride.yml MODEL.DEVICE_ID "('5')" TEST.WEIGHT '/data2/zi.wang/code/PartialReID-final/logs_duke/lr0008_b32_Process1_Model12_loss1/transformer_best.pth' OUTPUT_DIR './logs_duke/test_AO_0.2'
3 |
4 | python test.py --config_file configs/Market/vit_transreid_stride.yml MODEL.DEVICE_ID "('7')" TEST.WEIGHT '/data2/zi.wang/code/PartialReID-final/logs_market/lr0008_b32_Process1_Model12_loss1/transformer_best.pth' OUTPUT_DIR './logs_market/test_AO_0.2'
5 |
6 |
7 | python test.py --config_file configs/OCC_Duke/vit_transreid_stride.yml MODEL.DEVICE_ID "('5')" TEST.WEIGHT '../logs/occ_duke_vit_transreid_stride/transformer_120.pth'
--------------------------------------------------------------------------------
/dist_train.sh:
--------------------------------------------------------------------------------
1 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.001 OUTPUT_DIR './logs_partial_reid/lr0001_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 20
2 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.002 OUTPUT_DIR './logs_partial_reid/lr0002_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 20
3 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.004 OUTPUT_DIR './logs_partial_reid/lr0004_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 15
4 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.008 OUTPUT_DIR './logs_partial_reid/lr0008_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 15
5 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.0001 OUTPUT_DIR './logs_partial_reid/lr00001_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25
6 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.0002 OUTPUT_DIR './logs_partial_reid/lr00002_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25
7 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.0004 OUTPUT_DIR './logs_partial_reid/lr00004_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25
8 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.0008 OUTPUT_DIR './logs_partial_reid/lr00008_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25
9 |
--------------------------------------------------------------------------------
/dist_train_occReID.sh:
--------------------------------------------------------------------------------
1 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.001 OUTPUT_DIR './logs_occ_reid/lr0001_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 20
2 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.002 OUTPUT_DIR './logs_occ_reid/lr0002_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 20
3 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.004 OUTPUT_DIR './logs_occ_reid/lr0004_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 15
4 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.008 OUTPUT_DIR './logs_occ_reid/lr0008_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 15
5 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.0001 OUTPUT_DIR './logs_occ_reid/lr00001_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25
6 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.0002 OUTPUT_DIR './logs_occ_reid/lr00002_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25
7 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.0004 OUTPUT_DIR './logs_occ_reid/lr00004_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25
8 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.0008 OUTPUT_DIR './logs_occ_reid/lr00008_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25
9 |
--------------------------------------------------------------------------------
/fig/1:
--------------------------------------------------------------------------------
1 | 11
2 |
--------------------------------------------------------------------------------
/fig/OccludedREID_gallery.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/OccludedREID_gallery.jpg
--------------------------------------------------------------------------------
/fig/OccludedREID_query.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/OccludedREID_query.jpg
--------------------------------------------------------------------------------
/fig/RankingList-partial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/RankingList-partial.png
--------------------------------------------------------------------------------
/fig/image-20221018171750395.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/image-20221018171750395.png
--------------------------------------------------------------------------------
/fig/image-20221018171831853.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/image-20221018171831853.png
--------------------------------------------------------------------------------
/fig/image-20221018171840117.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/image-20221018171840117.png
--------------------------------------------------------------------------------
/fig/market_train.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/market_train.jpg
--------------------------------------------------------------------------------
/fig/partial_gallery.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/partial_gallery.jpg
--------------------------------------------------------------------------------
/fig/partial_query.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/partial_query.jpg
--------------------------------------------------------------------------------
/loss/HCloss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch import nn, tensor
3 | import torch
4 | from torch.autograd import Variable
5 | import pdb
6 |
class hetero_loss(nn.Module):
    """Hetero-center loss: pulls the per-identity centers of two feature sets
    (e.g. two modalities/branches) together, hinged at `margin`.

    Assumes feat1/feat2 are grouped by label in the same order, so chunking
    each batch into len(label1.unique()) equal parts separates identities.

    Args:
        margin: slack subtracted from each center distance before the hinge.
        dist_type: 'l2' (summed MSE), 'l1', or 'cos' (1 - cosine similarity).
    """

    def __init__(self, margin=0.1, dist_type='l2'):
        super(hetero_loss, self).__init__()
        self.margin = margin
        self.dist_type = dist_type
        if dist_type == 'l2':
            self.dist = nn.MSELoss(reduction='sum')
        elif dist_type == 'cos':
            self.dist = nn.CosineSimilarity(dim=0)
        elif dist_type == 'l1':
            self.dist = nn.L1Loss()
        else:
            # fail fast instead of leaving self.dist undefined and crashing
            # later inside forward() (the original silently fell through)
            raise ValueError("unsupported dist_type: {}".format(dist_type))

    def forward(self, feat1, feat2, label1):
        """Return the hinged sum of distances between per-label centers."""
        label_num = len(label1.unique())
        chunks1 = feat1.chunk(label_num, 0)
        chunks2 = feat2.chunk(label_num, 0)
        loss = 0
        for c1, c2 in zip(chunks1, chunks2):
            center1 = torch.mean(c1, dim=0)
            center2 = torch.mean(c2, dim=0)
            if self.dist_type == 'cos':
                gap = 1 - self.dist(center1, center2) - self.margin
            else:  # 'l2' or 'l1'
                gap = self.dist(center1, center2) - self.margin
            # hinge: only penalize centers farther apart than the margin
            loss = loss + gap.clamp(min=0)
        return loss
--------------------------------------------------------------------------------
/loss/KLloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 |
class KDLoss(nn.Module):
    """Temperature-scaled KL-divergence distillation loss (Hinton et al., 2015).

    Args:
        temp: softmax temperature applied to both logit sets.
        reduction: 'sum' sums over all elements; anything else sums over
            classes then averages over the batch (batchmean-style).
    """

    def __init__(self, temp: float, reduction: str):
        super(KDLoss, self).__init__()
        self.temp = temp
        self.reduction = reduction
        # Kept for interface compatibility; forward() reduces manually so that
        # the non-'sum' path averages over the batch after summing classes.
        self.kl_loss = nn.KLDivLoss(reduction=reduction)

    def forward(self, teacher_logits: torch.Tensor, student_logits: torch.Tensor) -> torch.Tensor:
        """Return KL(teacher || student) on temperature-softened distributions."""
        student_log_probs = F.log_softmax(student_logits / self.temp, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temp, dim=-1)

        # functional form instead of constructing a fresh nn.KLDivLoss module
        # on every call (the original never used self.kl_loss at all)
        kl = F.kl_div(student_log_probs, teacher_probs, reduction='none')
        kl = kl.sum() if self.reduction == 'sum' else kl.sum(1).mean()
        # temperature^2 rescales gradients back to the hard-label magnitude
        return kl * (self.temp ** 2)
28 |
29 |
class LogitsMatching(nn.Module):
    """Distillation by direct regression of student logits onto teacher logits.

    Args:
        reduction: passed through to nn.MSELoss ('mean', 'sum', ...).
    """

    def __init__(self, reduction: str):
        super(LogitsMatching, self).__init__()
        self.mse_loss = nn.MSELoss(reduction=reduction)

    def forward(self, teacher_logits: torch.Tensor, student_logits: torch.Tensor) -> torch.Tensor:
        """Return the MSE between student and (target) teacher logits."""
        return self.mse_loss(student_logits, teacher_logits)

    # The original also defined __call__ delegating verbatim to
    # super().__call__ — dead code, removed; nn.Module dispatch is unchanged.
--------------------------------------------------------------------------------
/loss/MSEloss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
class HintLoss(nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015.

    Plain mean-squared error between student and teacher feature maps.
    """

    def __init__(self):
        super(HintLoss, self).__init__()
        self.crit = nn.MSELoss()

    def forward(self, f_s, f_t):
        """Return MSE(f_s, f_t)."""
        return self.crit(f_s, f_t)
12 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_loss import make_loss
2 | from .arcface import ArcFace
--------------------------------------------------------------------------------
/loss/__pycache__/HCloss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/HCloss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/KLloss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/KLloss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/MSEloss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/MSEloss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/arcface.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/arcface.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/center_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/center_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/make_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/make_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/metric_learning.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/metric_learning.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/softmax_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/softmax_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/triplet_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/triplet_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/arcface.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Parameter
5 | import math
6 |
7 |
class ArcFace(nn.Module):
    """ArcFace head: cosine logits with an additive angular margin.

    Reference: Deng et al., "ArcFace: Additive Angular Margin Loss for Deep
    Face Recognition", CVPR 2019.

    Args:
        in_features: feature dimension of the input embeddings.
        out_features: number of classes.
        s: scale applied to the logits.
        m: additive angular margin (radians) applied to the target class.
        bias: unused by forward(); kept for interface compatibility.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

        # threshold/fallback keep cos(theta + m) monotonic once theta + m > pi
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input, label):
        # cos(theta) between L2-normalized embeddings and class weights
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # zeros_like inherits device AND dtype from the logits; the original
        # hard-coded device='cuda' and crashed on CPU inputs
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # apply the margin only to the target class, then rescale
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
51 |
class CircleLoss(nn.Module):
    """Circle-loss classification head (Sun et al., CVPR 2020).

    Produces per-class logits with pair-adaptive weighting: positives are
    pushed toward similarity 1 - m, negatives toward m, scaled by s.
    """

    def __init__(self, in_features, num_classes, s=256, m=0.25):
        super(CircleLoss, self).__init__()
        self.weight = Parameter(torch.Tensor(num_classes, in_features))
        self.s = s
        self.m = m
        self._num_classes = num_classes
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def __call__(self, bn_feat, targets):
        """Return margin-adjusted logits of shape (batch, num_classes)."""
        sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
        detached_sim = sim_mat.detach()

        # adaptive weights (no gradient through them, hence the detach)
        alpha_p = torch.clamp_min(-detached_sim + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(detached_sim + self.m, min=0.)
        # optimum margins for positive / negative similarities
        delta_p, delta_n = 1 - self.m, self.m

        logits_p = self.s * alpha_p * (sim_mat - delta_p)
        logits_n = self.s * alpha_n * (sim_mat - delta_n)

        one_hot = F.one_hot(targets, num_classes=self._num_classes)
        return one_hot * logits_p + (1.0 - one_hot) * logits_n
--------------------------------------------------------------------------------
/loss/center_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): keep the learnable class centers on the GPU.
    """

    def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        # one learnable center per class
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """Mean squared distance of each feature to its own class center.

        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size,).
                (The original docstring wrongly said (num_classes).)
        """
        assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"

        batch_size = x.size(0)
        # squared Euclidean distances: ||x||^2 + ||c||^2 - 2 * x . c
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # keyword beta/alpha form: the positional (beta, alpha, mat1, mat2)
        # overload is deprecated and removed in recent PyTorch releases
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu:
            classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        # mask[i, j] is True iff sample i belongs to class j (one True per row)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        # boolean-mask indexing replaces the original per-sample Python loop;
        # it selects the same one distance per row, in the same order
        dist = distmat[mask].clamp(min=1e-12, max=1e+12)  # for numerical stability
        loss = dist.mean()
        return loss
55 |
56 |
if __name__ == '__main__':
    # Smoke test: run CenterLoss on random features with the default
    # (num_classes=751, feat_dim=2048) configuration and print the loss.
    use_gpu = False
    center_loss = CenterLoss(use_gpu=use_gpu)
    features = torch.rand(16, 2048)
    targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long()
    if use_gpu:
        # Recreate the inputs on the GPU when enabled.
        # NOTE(review): this branch omits .long() on targets, unlike the CPU
        # path above — confirm labels.eq still behaves as intended on GPU.
        features = torch.rand(16, 2048).cuda()
        targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda()

    loss = center_loss(features, targets)
    print(loss)
68 |
--------------------------------------------------------------------------------
/loss/make_loss.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch.nn.functional as F
8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy
9 | from .triplet_loss import TripletLoss
10 | from .center_loss import CenterLoss
11 |
12 |
def make_loss(cfg, num_classes):    # modified by gu
    """Build the training criterion from the config.

    Args:
        cfg: yacs config node; reads DATALOADER.SAMPLER, MODEL.METRIC_LOSS_TYPE,
            MODEL.NO_MARGIN, MODEL.IF_LABELSMOOTH, MODEL.ID_LOSS_WEIGHT,
            MODEL.TRIPLET_LOSS_WEIGHT and SOLVER.MARGIN.
        num_classes: number of identities for the classification losses.

    Returns:
        (loss_func, center_criterion): a closure combining ID (cross-entropy)
        and triplet losses, plus a CenterLoss module (always constructed;
        the training loop decides whether to use it).

    NOTE(review): if SAMPLER is neither 'softmax' nor 'softmax_triplet' this
    only prints a warning and then raises UnboundLocalError at the return,
    because loss_func was never defined — confirm whether that is intended.
    """
    sampler = cfg.DATALOADER.SAMPLER
    feat_dim = 2048
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            # no explicit margin -> soft-margin (softplus) triplet loss
            triplet = TripletLoss()
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        # NOTE(review): adjacent string literals concatenate without a space
        # ("...tripletbut got...") — message typo, left untouched here.
        print('expected METRIC_LOSS_TYPE should be triplet'
              'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)

    if sampler == 'softmax':
        # plain classification; signature differs from the branch below
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)

    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam):
            if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
                if cfg.MODEL.IF_LABELSMOOTH == 'on':
                    # list-valued score/feat: element 0 is the global head,
                    # the rest are local/auxiliary heads averaged together
                    if isinstance(score, list):
                        ID_LOSS = [xent(scor, target) for scor in score[1:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target)
                    else:
                        ID_LOSS = xent(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
                else:
                    # NOTE(review): this branch treats the first THREE heads as
                    # primary (averaged) and the rest as auxiliary — assumes the
                    # model returns at least 4 score/feat entries; verify.
                    if isinstance(score, list):
                        ID_LOSS = [F.cross_entropy(scor, target) for scor in score[3:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * ((F.cross_entropy(score[0], target) + F.cross_entropy(score[1], target) + F.cross_entropy(score[2], target))/3)
                    else:
                        ID_LOSS = F.cross_entropy(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[3:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * (triplet(feat[0], target)[0] + triplet(feat[1], target)[0] + triplet(feat[2], target)[0])
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
            else:
                print('expected METRIC_LOSS_TYPE should be triplet'
                      'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    else:
        print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center'
              'but got {}'.format(cfg.DATALOADER.SAMPLER))
    return loss_func, center_criterion
81 |
82 |
83 |
--------------------------------------------------------------------------------
/loss/make_loss1_l2norm.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch.nn.functional as F
8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy
9 | from .triplet_loss import TripletLoss
10 | from .center_loss import CenterLoss
11 | import torch.nn as nn
class normalize(nn.Module):
    """Scale each row of the input to unit Lp norm (L2 by default)."""

    def __init__(self, power=2):
        super(normalize, self).__init__()
        self.power = power

    def forward(self, x):
        """Return x with every row divided by its Lp norm along dim 1."""
        row_norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        return x.div(row_norm)
21 |
def make_loss(cfg, num_classes):    # modified by gu
    """Variant of loss/make_loss.py that L2-normalizes part features before
    the triplet loss (see the l2_norm(...) calls in the non-label-smooth
    branch). Returns (loss_func, center_criterion) with the same contract.

    NOTE(review): an unsupported SAMPLER only prints a warning and then fails
    with UnboundLocalError at the return (loss_func never defined).
    """
    sampler = cfg.DATALOADER.SAMPLER
    feat_dim = 2048
    l2_norm = normalize()
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            triplet = TripletLoss()
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        print('expected METRIC_LOSS_TYPE should be triplet'
              'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)

    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam):
            if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
                if cfg.MODEL.IF_LABELSMOOTH == 'on':
                    if isinstance(score, list):
                        ID_LOSS = [xent(scor, target) for scor in score[1:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target)
                    else:
                        ID_LOSS = xent(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
                else:
                    # first three heads are primary, the rest auxiliary
                    # (assumes >= 4 score/feat entries — TODO confirm)
                    if isinstance(score, list):
                        ID_LOSS = [F.cross_entropy(scor, target) for scor in score[3:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * ((F.cross_entropy(score[0], target) + F.cross_entropy(score[1], target) + F.cross_entropy(score[2], target))/3)
                    else:
                        ID_LOSS = F.cross_entropy(score, target)

                    # only this variant L2-normalizes features for the triplet
                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(l2_norm(feats), target)[0] for feats in feat[3:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * (triplet(l2_norm(feat[0]), target)[0] + triplet(l2_norm(feat[1]), target)[0] + triplet(l2_norm(feat[2]), target)[0])
                    else:
                        # NOTE(review): non-list feats are NOT normalized here
                        # — confirm whether that asymmetry is intended
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
            else:
                print('expected METRIC_LOSS_TYPE should be triplet'
                      'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    else:
        print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center'
              'but got {}'.format(cfg.DATALOADER.SAMPLER))
    return loss_func, center_criterion
91 |
92 |
93 |
--------------------------------------------------------------------------------
/loss/make_loss1_vitbase_resnet.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch.nn.functional as F
8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy
9 | from .triplet_loss import TripletLoss
10 | from .center_loss import CenterLoss
11 |
12 |
def make_loss(cfg, num_classes):    # modified by gu
    """Variant of loss/make_loss.py used for the ViT-base + ResNet setup:
    in the non-label-smooth branch it averages the ID/triplet losses over ALL
    heads uniformly instead of weighting the first three separately.
    Returns (loss_func, center_criterion) with the same contract.

    NOTE(review): an unsupported SAMPLER only prints a warning and then fails
    with UnboundLocalError at the return (loss_func never defined).
    """
    sampler = cfg.DATALOADER.SAMPLER
    feat_dim = 2048
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            triplet = TripletLoss()
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        print('expected METRIC_LOSS_TYPE should be triplet'
              'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)

    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam):
            if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
                if cfg.MODEL.IF_LABELSMOOTH == 'on':
                    if isinstance(score, list):
                        ID_LOSS = [xent(scor, target) for scor in score[1:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target)
                    else:
                        ID_LOSS = xent(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
                else:
                    # uniform average over every head (no 3-head split here)
                    if isinstance(score, list):
                        ID_LOSS = [F.cross_entropy(scor, target) for scor in score]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        # ID_LOSS = 0.5 * ID_LOSS + 0.5 * ((F.cross_entropy(score[0], target) + F.cross_entropy(score[1], target) + F.cross_entropy(score[2], target))/3)
                    else:
                        ID_LOSS = F.cross_entropy(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        # TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * (triplet(feat[0], target)[0] + triplet(feat[1], target)[0] + triplet(feat[2], target)[0])
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
            else:
                print('expected METRIC_LOSS_TYPE should be triplet'
                      'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    else:
        print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center'
              'but got {}'.format(cfg.DATALOADER.SAMPLER))
    return loss_func, center_criterion
81 |
82 |
83 |
--------------------------------------------------------------------------------
/loss/make_loss_onlyOneAugmentation.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import torch.nn.functional as F
8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy
9 | from .triplet_loss import TripletLoss
10 | from .center_loss import CenterLoss
11 |
12 |
def make_loss(cfg, num_classes):    # modified by gu
    """Build the training loss closure and the center-loss criterion.

    Variant for a single extra augmentation branch: when scores/features come
    in as lists (and label smoothing is off), heads 0 and 1 are averaged as
    one 50% term and heads 2.. as the other 50% term (presumably the original
    vs. augmented branches -- TODO confirm against the model's output order).

    Args:
        cfg: experiment config node; reads DATALOADER.SAMPLER plus the
            MODEL.* / SOLVER.MARGIN loss-related entries.
        num_classes (int): number of identity classes for the ID losses.

    Returns:
        tuple: (loss_func, center_criterion). The CenterLoss module is
        returned for the caller to optimize separately; it is not used
        inside loss_func.
    """
    sampler = cfg.DATALOADER.SAMPLER
    feat_dim = 2048
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            # Soft-margin variant (SoftMarginLoss inside TripletLoss).
            triplet = TripletLoss()
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        # Fixed: the two adjacent string literals were concatenated without a
        # separator, printing "...tripletbut got ...".
        print('expected METRIC_LOSS_TYPE should be triplet '
              'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)

    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam):
            if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
                if cfg.MODEL.IF_LABELSMOOTH == 'on':
                    if isinstance(score, list):
                        # NOTE: this branch mixes head 0 with heads 1..,
                        # unlike the 'off' branch below which splits at 2.
                        ID_LOSS = [xent(scor, target) for scor in score[1:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target)
                    else:
                        ID_LOSS = xent(score, target)

                    if isinstance(feat, list):
                        # triplet(...) returns (loss, dist_ap, dist_an); keep loss only.
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
                else:
                    if isinstance(score, list):
                        ID_LOSS = [F.cross_entropy(scor, target) for scor in score[2:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * ((F.cross_entropy(score[0], target) + F.cross_entropy(score[1], target)) / 2)
                    else:
                        ID_LOSS = F.cross_entropy(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[2:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        # NOTE(review): the second term sums the two head
                        # losses without dividing by 2, unlike ID_LOSS above
                        # -- looks intentional but confirm.
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * (triplet(feat[0], target)[0] + triplet(feat[1], target)[0])
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
            else:
                print('expected METRIC_LOSS_TYPE should be triplet '
                      'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    else:
        print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center '
              'but got {}'.format(cfg.DATALOADER.SAMPLER))
    return loss_func, center_criterion
81 |
82 |
83 |
--------------------------------------------------------------------------------
/loss/softmax_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes.
        epsilon (float): smoothing weight.
        use_gpu (bool): kept for backward compatibility; the smoothed target
            is now built on the device of ``inputs``, so no explicit transfer
            is required.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size,)
        Returns:
            scalar loss tensor.
        """
        log_probs = self.logsoftmax(inputs)
        # Fixed device bug: the one-hot matrix used to be built on CPU via
        # deprecated `.data.cpu()` and then force-moved with `.cuda()`, which
        # crashed on CPU-only runs (and mismatched devices when use_gpu did
        # not match the inputs). zeros_like keeps everything on inputs'
        # device and dtype.
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        # Mean over the batch dim, summed over classes -> scalar.
        loss = (- targets * log_probs).mean(0).sum()
        return loss
35 |
class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with label smoothing.

    The loss is a convex mix of the standard NLL term (weight 1 - smoothing)
    and a uniform term over all classes (weight smoothing).
    """

    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor, must be < 1.0
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        # NLL of the true class, one scalar per sample.
        nll_term = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Uniform term: average negative log-prob over all classes.
        uniform_term = -logprobs.mean(dim=-1)
        per_sample = self.confidence * nll_term + self.smoothing * uniform_term
        return per_sample.mean()
--------------------------------------------------------------------------------
/loss/triplet_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
def normalize(x, axis=-1):
    """Scale ``x`` to unit L2 norm along ``axis``.

    Args:
        x: pytorch tensor of any shape
    Returns:
        tensor of the same shape; a 1e-12 term guards against division
        by zero for all-zero slices.
    """
    denom = x.norm(2, dim=axis, keepdim=True).expand_as(x) + 1e-12
    return 1. * x / denom
14 |
15 |
def euclidean_dist(x, y):
    """Pairwise Euclidean distances between the rows of two matrices.

    Args:
        x: pytorch tensor with shape [m, d]
        y: pytorch tensor with shape [n, d]
    Returns:
        tensor with shape [m, n]; entry (i, j) is ||x_i - y_j||_2.
    """
    m, n = x.size(0), y.size(0)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, computed for all pairs.
    x_sq = x.pow(2).sum(1, keepdim=True).expand(m, n)
    y_sq = y.pow(2).sum(1, keepdim=True).expand(n, m).t()
    squared = x_sq + y_sq - 2 * x.matmul(y.t())
    # Clamp before sqrt: tiny negatives from round-off would give NaN.
    return squared.clamp(min=1e-12).sqrt()
32 |
33 |
def cosine_dist(x, y):
    """Pairwise cosine distance, rescaled from [-1, 1] similarity to [0, 1].

    Args:
        x: pytorch tensor with shape [m, d]
        y: pytorch tensor with shape [n, d]
    Returns:
        tensor with shape [m, n]; 0 for parallel rows, 1 for anti-parallel.
    """
    m, n = x.size(0), y.size(0)
    x_norm = x.pow(2).sum(1, keepdim=True).sqrt().expand(m, n)
    y_norm = y.pow(2).sum(1, keepdim=True).sqrt().expand(n, m).t()
    cos_sim = torch.mm(x, y.t()) / (x_norm * y_norm)
    return (1. - cos_sim) / 2
49 |
50 |
def hard_example_mining(dist_mat, labels, return_inds=False):
    """For each anchor, find the hardest positive and negative sample.
    Args:
      dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
      labels: pytorch LongTensor, with shape [N]
      return_inds: whether to return the indices. Save time if `False`(?)
    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have same num of samples,
      thus we can cope with all anchors in parallel.
    """

    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # shape [N, N]; is_pos[i][j] is True iff samples i and j share a label
    # (each anchor counts itself as a positive -- its self-distance is 0,
    # so it never wins the max below unless it is the only positive).
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

    # `dist_ap` means distance(anchor, positive)
    # both `dist_ap` and `relative_p_inds` with shape [N, 1]
    # The boolean-mask + view(N, -1) trick only works because every identity
    # contributes the same number of samples (PK sampling, per the NOTE in
    # the docstring); with unequal counts the reshape would raise.
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # print(dist_mat[is_pos].shape)
    # `dist_an` means distance(anchor, negative)
    # both `dist_an` and `relative_n_inds` with shape [N, 1]
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # shape [N]
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if return_inds:
        # shape [N, N]: each row holds the absolute column indices 0..N-1;
        # used to map the relative argmax/argmin back to batch indices.
        ind = (labels.new().resize_as_(labels)
               .copy_(torch.arange(0, N).long())
               .unsqueeze(0).expand(N, N))
        # shape [N, 1]
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
        # shape [N]
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an
105 |
106 |
class TripletLoss(object):
    """
    Triplet loss using HARDER example mining,
    modified based on original triplet loss using hard example mining.

    With ``margin`` set, uses MarginRankingLoss; otherwise falls back to
    SoftMarginLoss on (dist_an - dist_ap). ``hard_factor`` inflates positive
    distances and deflates negative ones to make the mined examples harder.
    """

    def __init__(self, margin=None, hard_factor=0.0):
        self.margin = margin
        self.hard_factor = hard_factor
        self.ranking_loss = (nn.MarginRankingLoss(margin=margin)
                             if margin is not None
                             else nn.SoftMarginLoss())

    def __call__(self, global_feat, labels, normalize_feature=False):
        feats = normalize(global_feat, axis=-1) if normalize_feature else global_feat
        pair_dist = euclidean_dist(feats, feats)
        dist_ap, dist_an = hard_example_mining(pair_dist, labels)

        # Harder mining: pretend positives are farther and negatives closer.
        dist_ap = dist_ap * (1.0 + self.hard_factor)
        dist_an = dist_an * (1.0 - self.hard_factor)

        ones = dist_an.new_ones(dist_an.size())
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, ones)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, ones)
        return loss, dist_ap, dist_an
136 |
137 |
138 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_model import make_model
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/make_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/__pycache__/make_model.cpython-36.pyc
--------------------------------------------------------------------------------
/model/backbones/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__init__.py
--------------------------------------------------------------------------------
/model/backbones/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/backbones/__pycache__/osnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__pycache__/osnet.cpython-36.pyc
--------------------------------------------------------------------------------
/model/backbones/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/model/backbones/__pycache__/vit_pytorch.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__pycache__/vit_pytorch.cpython-36.pyc
--------------------------------------------------------------------------------
/model/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (bias-free; a BatchNorm follows it)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
11 |
12 |
class BasicBlock(nn.Module):
    """Two-conv (3x3 + 3x3) residual block, ResNet-18/34 style."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection when shape changes.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        y += shortcut
        return self.relu(y)
43 |
44 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block, ResNet-50+ style."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection when shape changes.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)
82 |
83 |
class ResNet(nn.Module):
    """Bare ResNet backbone: 7x7 stem + 4 residual stages.

    No global pooling or FC head -- forward returns the stage-4 feature map.
    `last_stride` sets the stride of the final stage (re-ID models typically
    pass 1 to keep a larger spatial output -- TODO confirm caller's choice).
    """
    def __init__(self, last_stride=2, block=Bottleneck,layers=[3, 4, 6, 3]):
        # NOTE(review): `layers` is a mutable default; harmless here because
        # it is only read, never mutated.
        self.inplanes = 64
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # self.relu = nn.ReLU(inplace=True) # add missed relu
        # NOTE(review): the stem ReLU is deliberately left disabled (the
        # matching call in forward is commented out too), and the maxpool
        # deviates from torchvision's (kernel 2 vs 3) -- confirm before
        # "fixing" either.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; the first block gets a
        1x1-conv+BN projection shortcut when stride or width changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, cam_label=None):
        # `cam_label` is accepted for interface parity with other backbones
        # but is unused here.
        x = self.conv1(x)
        x = self.bn1(x)
        # x = self.relu(x) # add missed relu
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def load_param(self, model_path):
        """Copy checkpoint weights in place, skipping any key containing 'fc'.
        Assumes checkpoint keys/shapes match this module's state_dict exactly
        (raises KeyError or a size-mismatch error otherwise) -- TODO confirm
        the expected checkpoint format."""
        param_dict = torch.load(model_path)
        for i in param_dict:
            if 'fc' in i:
                continue
            self.state_dict()[i].copy_(param_dict[i])

    def random_init(self):
        """Kaiming-style normal init (std = sqrt(2 / fan_out)) for convs;
        BatchNorm weights to 1 and biases to 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
--------------------------------------------------------------------------------
/processor/__init__.py:
--------------------------------------------------------------------------------
1 | from .processor import do_train, do_inference
--------------------------------------------------------------------------------
/processor/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/processor/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/processor/__pycache__/processor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/processor/__pycache__/processor.cpython-36.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | timm
4 | yacs
5 | opencv-python
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_scheduler import WarmupMultiStepLR
2 | from .make_optimizer import make_optimizer
--------------------------------------------------------------------------------
/solver/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/cosine_lr.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/cosine_lr.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/lr_scheduler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/lr_scheduler.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/make_optimizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/make_optimizer.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/scheduler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/scheduler.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/scheduler_factory.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/scheduler_factory.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/cosine_lr.py:
--------------------------------------------------------------------------------
1 | """ Cosine Scheduler
2 |
3 | Cosine LR schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import torch
10 |
11 | from .scheduler import Scheduler
12 |
13 |
14 | _logger = logging.getLogger(__name__)
15 |
16 |
class CosineLRScheduler(Scheduler):
    """
    Cosine decay with restarts.
    This is described in the paper https://arxiv.org/abs/1608.03983.

    Inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 t_mul: float = 1.,
                 lr_min: float = 0.,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 cycle_limit=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        # t_initial: length of the first cosine cycle; t_mul: multiplicative
        # growth factor of each successive cycle; decay_rate: per-cycle decay
        # applied to both the max LR and the floor.
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        assert t_initial > 0
        assert lr_min >= 0
        if t_initial == 1 and t_mul == 1 and decay_rate == 1:
            _logger.warning("Cosine annealing scheduler will have no effect on the learning "
                            "rate since t_initial = t_mul = eta_mul = 1.")
        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # Per-group linear ramp increments so warmup reaches each group's
            # base LR in exactly warmup_t steps.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return the list of per-group LRs for step/epoch index ``t``."""
        if t < self.warmup_t:
            # Linear warmup from warmup_lr_init toward each base value.
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                # Warmup time is excluded from the cosine timeline.
                t = t - self.warmup_t

            if self.t_mul != 1:
                # Geometrically growing cycles: closed forms (from the
                # geometric series sum) for the cycle index i, the cycle
                # length t_i, and the offset t_curr within the cycle.
                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
                t_i = self.t_mul ** i * self.t_initial
                t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
            else:
                # Equal-length cycles.
                i = t // self.t_initial
                t_i = self.t_initial
                t_curr = t - (self.t_initial * i)

            gamma = self.decay_rate ** i
            lr_min = self.lr_min * gamma
            lr_max_values = [v * gamma for v in self.base_values]

            if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
                # Standard cosine interpolation between lr_max and lr_min.
                lrs = [
                    lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
                ]
            else:
                # Past the cycle limit: hold at the (undecayed) floor.
                lrs = [self.lr_min for _ in self.base_values]

        return lrs

    def get_epoch_values(self, epoch: int):
        # Base Scheduler.step() consumes this; active only in epoch mode.
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        # Base Scheduler.step_update() consumes this; active only when
        # t_in_epochs is False.
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None

    def get_cycle_length(self, cycles=0):
        """Total number of steps spanned by ``cycles`` cycles
        (defaults to cycle_limit, with a minimum of one cycle)."""
        if not cycles:
            cycles = self.cycle_limit
        cycles = max(1, cycles)
        if self.t_mul == 1.0:
            return self.t_initial * cycles
        else:
            return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
116 |
--------------------------------------------------------------------------------
/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 |
9 |
10 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
11 | # separating MultiStepLR with WarmupLR
12 | # but the current LRScheduler design doesn't allow it
13 |
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """MultiStepLR with a constant or linear warmup phase.

    For the first ``warmup_iters`` epochs the base LR is scaled by a warmup
    factor (ramping linearly from ``warmup_factor`` up to 1 in "linear"
    mode); afterwards the usual multi-step ``gamma`` decay applies at each
    milestone.
    """

    def __init__(
        self,
        optimizer,
        milestones,  # steps
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # Fixed: the message and the milestones value used to be passed as
            # two separate ValueError args instead of being formatted together.
            raise ValueError(
                "Milestones should be a list of"
                " increasing integers. Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            # Fixed missing space between the two message fragments.
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Fixed: this hook was named `_get_lr`, which the torch base class
        # never calls -- `step()` (invoked from __init__) would hit
        # _LRScheduler.get_lr's NotImplementedError. The override must be
        # named `get_lr`.
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]
57 |
--------------------------------------------------------------------------------
/solver/make_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def make_optimizer(cfg, model, center_criterion):
    """Create the model optimizer and the center-loss optimizer.

    Each trainable parameter gets its own group so per-parameter lr and
    weight_decay apply: biases use BIAS_LR_FACTOR / WEIGHT_DECAY_BIAS, and
    classifier/arcface parameters optionally get a doubled LR (LARGE_FC_LR).

    Returns:
        (optimizer, optimizer_center): the main optimizer and a plain SGD
        optimizer over the center-loss parameters.
    """
    param_groups = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        group_lr = cfg.SOLVER.BASE_LR
        group_wd = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in name:
            group_lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            group_wd = cfg.SOLVER.WEIGHT_DECAY_BIAS
        if cfg.SOLVER.LARGE_FC_LR:
            if "classifier" in name or "arcface" in name:
                group_lr = cfg.SOLVER.BASE_LR * 2
                print('Using two times learning rate for fc ')

        param_groups.append({"params": [param], "lr": group_lr, "weight_decay": group_wd})

    opt_name = cfg.SOLVER.OPTIMIZER_NAME
    if opt_name == 'SGD':
        optimizer = getattr(torch.optim, opt_name)(param_groups, momentum=cfg.SOLVER.MOMENTUM)
    elif opt_name == 'AdamW':
        optimizer = torch.optim.AdamW(param_groups, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    else:
        optimizer = getattr(torch.optim, opt_name)(param_groups)
    optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)

    return optimizer, optimizer_center
30 |
--------------------------------------------------------------------------------
/solver/scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | import torch
4 |
5 |
class Scheduler:
    """ Parameter Scheduler Base Class
    A scheduler base class that can be used to schedule any optimizer parameter groups.

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called
    * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
    * At the END of each optimizer update, after incrementing the update count, to calculate next update's value

    The schedulers built on this should try to remain as stateless as possible (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
    and -1 values for special behaviour. All epoch and update counts must be tracked in the training
    code and explicitly passed in to the schedulers on the corresponding step or step_update call.

    Based on ideas from:
    * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
    * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        if initialize:
            # Fresh start: snapshot each group's current value under
            # "initial_<field>" (setdefault keeps an existing snapshot).
            for i, group in enumerate(self.optimizer.param_groups):
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                group.setdefault(self._initial_param_group_field, group[param_group_field])
        else:
            # Resuming: the snapshot must already be present in every group.
            for i, group in enumerate(self.optimizer.param_groups):
                if self._initial_param_group_field not in group:
                    raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
        self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
        self.metric = None  # any point to having this for all?
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = noise_seed if noise_seed is not None else 42
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        # Serialize everything except the optimizer reference itself.
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        # Subclass hook: return per-group values for this epoch, or None to
        # signal that epoch-based stepping is disabled.
        return None

    def get_update_values(self, num_updates: int):
        # Subclass hook: return per-group values for this update count, or
        # None to signal that update-based stepping is disabled.
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is not None:
            values = self._add_noise(values, epoch)
            self.update_groups(values)

    def step_update(self, num_updates: int, metric: float = None):
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is not None:
            values = self._add_noise(values, num_updates)
            self.update_groups(values)

    def update_groups(self, values):
        # Broadcast a scalar to every group, then write the scheduled field
        # into each param group in place.
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group[self.param_group_field] = value

    def _add_noise(self, lrs, t):
        """Optionally perturb the values multiplicatively; the generator is
        seeded with noise_seed + t so the noise at a given step is
        reproducible across runs."""
        if self.noise_range_t is not None:
            if isinstance(self.noise_range_t, (list, tuple)):
                apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
            else:
                apply_noise = t >= self.noise_range_t
            if apply_noise:
                g = torch.Generator()
                g.manual_seed(self.noise_seed + t)
                if self.noise_type == 'normal':
                    while True:
                        # resample if noise out of percent limit, brute force but shouldn't spin much
                        noise = torch.randn(1, generator=g).item()
                        if abs(noise) < self.noise_pct:
                            break
                else:
                    # Uniform noise in [-noise_pct, noise_pct].
                    noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
                lrs = [v + v * noise for v in lrs]
        return lrs
106 |
--------------------------------------------------------------------------------
/solver/scheduler_factory.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from .cosine_lr import CosineLRScheduler
5 |
6 |
def create_scheduler(cfg, optimizer):
    """Build the warmup + cosine-decay LR scheduler used for training.

    Uses the "type 2" ratios: the schedule decays to 0.002x BASE_LR and warms
    up from 0.01x BASE_LR over SOLVER.WARMUP_EPOCHS epochs.
    """
    total_epochs = cfg.SOLVER.MAX_EPOCHS
    base_lr = cfg.SOLVER.BASE_LR

    # Alternative ratio presets kept for reference:
    # type 1: lr_min = 0.01 * base_lr,  warmup_lr_init = 0.001 * base_lr
    # type 3: lr_min = 0.001 * base_lr, warmup_lr_init = 0.01 * base_lr
    return CosineLRScheduler(
        optimizer,
        t_initial=total_epochs,
        lr_min=0.002 * base_lr,
        t_mul=1.,
        decay_rate=0.1,
        warmup_lr_init=0.01 * base_lr,
        warmup_t=cfg.SOLVER.WARMUP_EPOCHS,
        cycle_limit=1,
        t_in_epochs=True,
        noise_range_t=None,
        noise_pct=0.67,
        noise_std=1.,
        noise_seed=42,
    )
39 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
from config import cfg
import argparse
from datasets import make_dataloader
from model import make_model
from processor import do_inference
from utils.logger import setup_logger


if __name__ == "__main__":
    # Command line: an optional YACS config file plus trailing KEY VALUE overrides.
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    # Merge file and CLI overrides into the global cfg, then make it immutable.
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Test-mode logger writes to test_log.txt inside OUTPUT_DIR.
    logger = setup_logger("transreid", output_dir, if_train=False)
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    # Restrict visible GPUs before any CUDA context is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)

    # Build the model and load the evaluation checkpoint named in TEST.WEIGHT.
    model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
    model.load_param(cfg.TEST.WEIGHT)

    if cfg.DATASETS.NAMES == 'VehicleID':
        # VehicleID protocol: average rank-1 / rank-5 over 10 random gallery splits
        # (make_dataloader re-samples the gallery each trial).
        for trial in range(10):
            train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
            rank_1, rank5 = do_inference(cfg,
                 model,
                 val_loader,
                 num_query)
            if trial == 0:
                all_rank_1 = rank_1
                all_rank_5 = rank5
            else:
                all_rank_1 = all_rank_1 + rank_1
                all_rank_5 = all_rank_5 + rank5

            logger.info("rank_1:{}, rank_5 {} : trial : {}".format(rank_1, rank5, trial))
        logger.info("sum_rank_1:{:.1%}, sum_rank_5 {:.1%}".format(all_rank_1.sum()/10.0, all_rank_5.sum()/10.0))
    else:
        # All other datasets: a single evaluation pass.
        do_inference(cfg,
         model,
         val_loader,
         num_query)
70 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from utils.logger import setup_logger
2 | from datasets import make_dataloader
3 | from model import make_model
4 | from solver import make_optimizer
5 | from solver.scheduler_factory import create_scheduler
6 | from loss import make_loss
7 | from processor import do_train
8 | import random
9 | import torch
10 | import numpy as np
11 | import os
12 | import argparse
13 | # from timm.scheduler import create_scheduler
14 | from config import cfg
15 |
def set_seed(seed):
    """Seed every RNG used in training so runs are reproducible.

    Args:
        seed (int): seed applied to torch (CPU and all CUDA devices),
            numpy and the stdlib ``random`` module.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark auto-tunes convolution algorithms at runtime and may pick
    # different (non-deterministic) kernels between runs, which defeats the
    # deterministic flag above — keep it off so seeding actually works.
    torch.backends.cudnn.benchmark = False
24 |
if __name__ == '__main__':

    # Command line: an optional YACS config file, KEY VALUE overrides, and the
    # local rank injected by torch.distributed launchers.
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )

    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument("--local_rank", default=0, type=int)
    args = parser.parse_args()

    # Merge file and CLI overrides into the global cfg, then make it immutable.
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # Seed all RNGs before any data loading / model init.
    set_seed(cfg.SOLVER.SEED)

    if cfg.MODEL.DIST_TRAIN:
        # One process per GPU: bind this process to its local device.
        torch.cuda.set_device(args.local_rank)

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Train-mode logger writes to train_log.txt inside OUTPUT_DIR.
    logger = setup_logger("transreid", output_dir, if_train=True)
    logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DIST_TRAIN:
        # NCCL backend; rank / world size are read from the launcher's env vars.
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
    train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)

    model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)

    loss_func, center_criterion = make_loss(cfg, num_classes=num_classes)

    # Center loss keeps its own optimizer alongside the model optimizer.
    optimizer, optimizer_center = make_optimizer(cfg, model, center_criterion)

    scheduler = create_scheduler(cfg, optimizer)

    do_train(
        cfg,
        model,
        center_criterion,
        train_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,
        loss_func,
        num_query, args.local_rank
    )
88 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/iotools.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/iotools.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/logger.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/meter.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/meter.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/reranking.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/reranking.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/iotools.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import errno
8 | import json
9 | import os
10 |
11 | import os.path as osp
12 |
13 |
def mkdir_if_missing(directory):
    """Create *directory* (and any missing parents) if nothing exists at that path.

    Preserves the original semantics: an existing path — even a regular file —
    is left untouched. ``exist_ok=True`` closes the check-then-create race the
    old ``errno.EEXIST`` dance worked around when another process creates the
    directory between the check and the call.
    """
    if not osp.exists(directory):
        os.makedirs(directory, exist_ok=True)
21 |
22 |
def check_isfile(path):
    """Return True if *path* is an existing regular file.

    Prints a warning instead of raising when the file is missing, so callers
    can treat an absent file as a soft failure.
    """
    found = osp.isfile(path)
    if not found:
        print("=> Warning: no file found at '{}' (ignored)".format(path))
    return found
28 |
29 |
def read_json(fpath):
    """Load and return the JSON document stored at *fpath*."""
    with open(fpath, 'r') as handle:
        return json.load(handle)
34 |
35 |
def write_json(obj, fpath):
    """Serialize *obj* as indented JSON at *fpath*, creating parent directories first."""
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w') as out:
        json.dump(obj, out, indent=4, separators=(',', ': '))
40 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import os.path as osp
def setup_logger(name, save_dir, if_train):
    """Create (or fetch) a DEBUG-level logger writing to stdout and, when
    *save_dir* is non-empty, to a log file inside it.

    Args:
        name: logger name passed to ``logging.getLogger``.
        save_dir: output directory for the log file; falsy disables file logging.
        if_train: selects ``train_log.txt`` (True) or ``test_log.txt`` (False).

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # getLogger returns the same object for the same name, so a second call
    # would stack duplicate handlers and double every log line — bail out.
    if logger.handlers:
        return logger

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        if not osp.exists(save_dir):
            os.makedirs(save_dir)
        if if_train:
            fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w')
        else:
            fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='w')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
--------------------------------------------------------------------------------
/utils/meter.py:
--------------------------------------------------------------------------------
class AverageMeter(object):
    """Tracks the most recent value plus a running (weighted) average."""

    def __init__(self):
        # reset() holds the canonical zero state; reuse it instead of duplicating.
        self.reset()

    def reset(self):
        """Zero the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from utils.reranking import re_ranking
5 |
6 |
def euclidean_distance(qf, gf):
    """Squared Euclidean distance matrix between query and gallery features.

    Args:
        qf: (m, d) query feature tensor.
        gf: (n, d) gallery feature tensor.

    Returns:
        (m, n) numpy array of squared L2 distances (no sqrt is applied).
    """
    m = qf.shape[0]
    n = gf.shape[0]
    # ||q||^2 + ||g||^2, broadcast to an (m, n) grid
    dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
               torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # The positional addmm_(beta, alpha, m1, m2) overload was removed from
    # modern PyTorch; the keyword form computes dist_mat += -2 * qf @ gf.T
    dist_mat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    return dist_mat.cpu().numpy()
14 |
def cosine_similarity(qf, gf):
    """Angular distance matrix: arccos of the (clipped) cosine similarity
    between every query/gallery feature pair.

    Returns an (m, n) numpy array; smaller angle means more similar.
    """
    epsilon = 0.00001
    qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True)  # m x 1
    gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True)  # n x 1
    raw = qf.mm(gf.t())
    scale = qf_norm.mm(gf_norm.t())
    sim = raw.mul(1 / scale).cpu().numpy()
    # clip away from +/-1 so arccos stays numerically stable
    sim = np.clip(sim, -1 + epsilon, 1 - epsilon)
    return np.arccos(sim)
26 |
27 |
def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.

    Args:
        distmat: (num_q, num_g) distance matrix; smaller means more similar.
        q_pids, g_pids: person ids of query / gallery samples.
        q_camids, g_camids: camera ids of query / gallery samples.
        max_rank: CMC curve is truncated at this rank (clamped to num_g).

    Returns:
        (all_cmc, mAP): CMC curve averaged over valid queries, and mean
        average precision over the same queries.
    """
    num_q, num_g = distmat.shape
    # distmat g
    # q 1 3 2 4
    # 4 1 2 3
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    # rank the gallery for every query by increasing distance
    indices = np.argsort(distmat, axis=1)
    # 0 2 1 3
    # 1 2 3 0
    # matches[q, k] == 1 iff the k-th ranked gallery item shares the query's pid
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]  # select one row
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        # binary vector, positions with value 1 are correct matches
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = orig_cmc.cumsum()
        # CMC is a step function: stays 1 once the first match has been seen
        cmc[cmc > 1] = 1

        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        tmp_cmc = orig_cmc.cumsum()
        #tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
        y = np.arange(1, tmp_cmc.shape[0] + 1) * 1.0
        # precision at each rank, then kept only at hit positions
        tmp_cmc = tmp_cmc / y
        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
        AP = tmp_cmc.sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    # average the per-query CMC curves over queries that had a gallery match
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP
88 |
89 |
class R1_mAP_eval():
    """Accumulates features over validation batches, then computes CMC / mAP.

    Usage: reset() once per evaluation pass, update() per batch, compute() at
    the end. The first ``num_query`` accumulated samples are queries, the rest
    form the gallery.
    """

    def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False):
        super(R1_mAP_eval, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank
        self.feat_norm = feat_norm
        self.reranking = reranking

    def reset(self):
        """Clear accumulated features / ids before a new evaluation pass."""
        self.feats, self.pids, self.camids = [], [], []

    def update(self, output):  # called once for each batch
        """Append one batch of (features, pids, camids); features moved to CPU."""
        feat, pid, camid = output
        self.feats.append(feat.cpu())
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):  # called after each epoch
        """Split query/gallery, build the distance matrix and run eval_func."""
        all_feats = torch.cat(self.feats, dim=0)
        if self.feat_norm:
            print("The test feature is normalized")
            all_feats = torch.nn.functional.normalize(all_feats, dim=1, p=2)  # along channel
        split = self.num_query
        qf = all_feats[:split]
        q_pids = np.asarray(self.pids[:split])
        q_camids = np.asarray(self.camids[:split])
        gf = all_feats[split:]
        g_pids = np.asarray(self.pids[split:])
        g_camids = np.asarray(self.camids[split:])
        if self.reranking:
            print('=> Enter reranking')
            distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3)
        else:
            print('=> Computing DistMat with euclidean_distance')
            distmat = euclidean_distance(qf, gf)
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        return cmc, mAP, distmat, self.pids, self.camids, qf, gf
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/utils/reranking.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri, 25 May 2018 20:29:09
5 |
6 |
7 | """
8 |
9 | """
10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
13 | """
14 |
15 | """
16 | API
17 |
18 | probFea: all feature vectors of the query set (torch tensor)
galFea: all feature vectors of the gallery set (torch tensor)
20 | k1,k2,lambda: parameters, the original paper is (k1=20,k2=6,lambda=0.3)
21 | MemorySave: set to 'True' when using MemorySave mode
Minibatch: available when 'MemorySave' is 'True'
23 | """
24 |
25 | import numpy as np
26 | import torch
27 |
28 |
def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
    """k-reciprocal encoding re-ranking (Zhong et al., CVPR 2017).

    Args:
        probFea: (num_query, d) query features (torch tensor).
        galFea: (num_gallery, d) gallery features (torch tensor).
        k1, k2: k-reciprocal neighborhood sizes (the paper uses 20 / 6).
        lambda_value: weight of the original distance in the final blend.
        local_distmat: optional precomputed (all, all) distance added to the
            Euclidean distance (or used alone when only_local is True).
        only_local: if True, skip the Euclidean distance entirely.

    Returns:
        (num_query, num_gallery) numpy array of re-ranked distances.
    """
    # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor
    query_num = probFea.size(0)
    all_num = query_num + galFea.size(0)
    if only_local:
        original_dist = local_distmat
    else:
        feat = torch.cat([probFea, galFea])
        # squared Euclidean distance over the joint (query + gallery) set
        distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \
                  torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
        # The positional addmm_(beta, alpha, m1, m2) overload was removed from
        # modern PyTorch; the keyword form computes distmat += -2 * feat @ feat.T
        distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)
        original_dist = distmat.cpu().numpy()
        del feat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat
    gallery_num = original_dist.shape[0]  # == all_num (name kept from the original code)
    # normalize each column by its max, then transpose so rows are in [0, 1]
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    for i in range(all_num):
        # k-reciprocal neighbors: i's top-k1 neighbors that also rank i in their top-k1
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
                                               :int(np.around(k1 / 2)) + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            # expand with the candidate's own (half-size) reciprocal set when
            # it overlaps the current set by more than 2/3
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # Gaussian-kernel weights over the expanded reciprocal neighborhood
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # local query expansion: average each row's encoding over its k2 nearest neighbors
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # inverted index: for each sample, which rows assign it non-zero weight
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)

    for i in range(query_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            # accumulate the fuzzy-set intersection min(V_i, V_j) per sample
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                              V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    # blend Jaccard and original distances, keep only the query-vs-gallery block
    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:query_num, query_num:]
    return final_dist
101 |
102 |
--------------------------------------------------------------------------------