├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── defaults.cpython-35.pyc
│   │   └── defaults.cpython-36.pyc
│   └── defaults.py
├── configs
│   └── video_baseline.yml
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── build.cpython-35.pyc
│   │   ├── build.cpython-36.pyc
│   │   ├── collate_batch.cpython-35.pyc
│   │   └── collate_batch.cpython-36.pyc
│   ├── build.py
│   ├── collate_batch.py
│   ├── datasets
│   │   ├── DukeV.py
│   │   ├── MARS.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── DukeV.cpython-36.pyc
│   │   │   ├── MARS.cpython-35.pyc
│   │   │   ├── MARS.cpython-36.pyc
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── bases.cpython-35.pyc
│   │   │   ├── bases.cpython-36.pyc
│   │   │   ├── dataset_loader.cpython-35.pyc
│   │   │   ├── dataset_loader.cpython-36.pyc
│   │   │   ├── dukemtmcreid.cpython-35.pyc
│   │   │   ├── dukemtmcreid.cpython-36.pyc
│   │   │   ├── eval_reid.cpython-36.pyc
│   │   │   ├── market1501.cpython-35.pyc
│   │   │   ├── market1501.cpython-36.pyc
│   │   │   ├── msmt17.cpython-35.pyc
│   │   │   ├── msmt17.cpython-36.pyc
│   │   │   ├── veri.cpython-35.pyc
│   │   │   └── veri.cpython-36.pyc
│   │   ├── bases.py
│   │   ├── dataset_loader.py
│   │   └── eval_reid.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── triplet_sampler.cpython-35.pyc
│   │   │   └── triplet_sampler.cpython-36.pyc
│   │   └── triplet_sampler.py
│   └── transforms
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-35.pyc
│       │   ├── __init__.cpython-36.pyc
│       │   ├── build.cpython-35.pyc
│       │   ├── build.cpython-36.pyc
│       │   ├── spatial_transforms.cpython-36.pyc
│       │   ├── temporal_transforms.cpython-36.pyc
│       │   ├── transforms.cpython-35.pyc
│       │   └── transforms.cpython-36.pyc
│       ├── build.py
│       ├── temporal_transforms.py
│       └── transforms.py
├── engine
│   ├── __pycache__
│   │   ├── data_parallel.cpython-36.pyc
│   │   ├── inference.cpython-36.pyc
│   │   ├── scatter_gather.cpython-36.pyc
│   │   ├── trainer.cpython-35.pyc
│   │   ├── trainer.cpython-36.pyc
│   │   └── vis.cpython-36.pyc
│   ├── data_parallel.py
│   ├── inference.py
│   ├── scatter_gather.py
│   └── trainer.py
├── imgs
│   ├── DL.png
│   └── DL_2.png
├── layers
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── center_loss.cpython-36.pyc
│   │   ├── old_triplet_loss.cpython-36.pyc
│   │   └── triplet_loss.cpython-36.pyc
│   ├── center_loss.py
│   └── triplet_loss.py
├── modeling
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── baseline.cpython-36.pyc
│   │   └── network.cpython-36.pyc
│   ├── backbones
│   │   ├── ResNet.py
│   │   ├── SA
│   │   │   ├── AP3D.py
│   │   │   ├── NonLocal.py
│   │   │   ├── SelfAttn.py
│   │   │   ├── __pycache__
│   │   │   │   ├── AP3D.cpython-36.pyc
│   │   │   │   ├── NonLocal.cpython-36.pyc
│   │   │   │   ├── SelfAttn.cpython-36.pyc
│   │   │   │   └── inflate.cpython-36.pyc
│   │   │   └── inflate.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── ResNet.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── non_local.cpython-36.pyc
│   │   │   ├── resnet.cpython-36.pyc
│   │   │   └── resnet_NL.cpython-36.pyc
│   │   ├── non_local.py
│   │   ├── resnet.py
│   │   └── resnet_NL.py
│   └── network.py
├── requirements.txt
├── scripts
│   ├── AA_D.sh
│   ├── AA_M.sh
│   ├── NL_D.sh
│   ├── NL_M.sh
│   ├── baseline_D.sh
│   ├── baseline_M.sh
│   └── test_M.sh
├── solver
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── build.cpython-36.pyc
│   │   └── lr_scheduler.cpython-36.pyc
│   ├── build.py
│   └── lr_scheduler.py
├── tests
│   ├── __init__.py
│   └── lr_scheduler_test.py
├── tools
│   ├── __init__.py
│   ├── test.py
│   └── train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── __init__.cpython-36.pyc
    │   ├── iotools.cpython-35.pyc
    │   ├── iotools.cpython-36.pyc
    │   ├── logger.cpython-36.pyc
    │   ├── re_ranking.cpython-36.pyc
    │   └── reid_metric.cpython-36.pyc
    ├── iotools.py
    ├── logger.py
    ├── re_ranking.py
    └── reid_metric.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Chih-Ting Liu (劉致廷)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Video-based Person Re-identification without Bells and Whistles
2 |
3 | [[Paper]](http://media.ee.ntu.edu.tw/research/CFAAN/paper/CVPRw21_VideoReID.pdf) [[arXiv]](https://arxiv.org/pdf/2105.10678.pdf) [[video]](https://youtu.be/RNssJNmq504)
4 |
5 | [Chih-Ting Liu](https://jackie840129.github.io/), [Jun-Cheng Chen](https://www.citi.sinica.edu.tw/pages/pullpull/contact_en.html), [Chu-Song Chen](https://imp.iis.sinica.edu.tw/) and [Shao-Yi Chien](http://www.ee.ntu.edu.tw/profile?id=101),
Analysis & Modeling of Faces & Gestures Workshop jointly with IEEE Conference on Computer Vision and Pattern Recognition (**CVPRw**), 2021
6 |
7 | This is the PyTorch implementation of Coarse-to-Fine Axial Attention Network **(CF-AAN)** for video-based person Re-ID.
8 |
It achieves **91.3%** in rank-1 accuracy and **86.5%** in mAP on our aligned MARS dataset.
9 |
10 | ## News
11 |
12 | **`2021-06-13`**:
13 | - We release the code and aligned dataset for our work.
14 | - We update the README sections related to our new dataset; the other sections will be updated gradually.
15 |
16 | **`2021-06-18`**:
17 | - We update the description for training and testing CF-AAN.
18 |
19 | ## Aligned dataset with our re-Detect and Link module
20 |
21 | ### Download Link :
22 |
23 | - MARS (DL) : [[Google Drive]](https://drive.google.com/file/d/1adP39y7xoKYX8Z4lyBtZiDTg9kZyK1Cx/view?usp=sharing)
24 | - For DukeV, we didn't perform DL on DukeMTMC-VideoReID because the bounding boxes are ground truth annotations.
25 |
26 | ### Results
27 | The video tracklets are re-detected, linked (tracked), and padded to the original image size, as shown below.
28 | (figure: example tracklets after re-Detect and Link)
29 |
30 | ### Folder Structure
31 | MARS dataset:
32 | ```
33 | MARS-DL/
34 | |-- bbox_train/
35 | |-- bbox_test/
36 | |-- info/
37 | |-- |-- mask_info.csv (for DL mask)
38 | |-- |-- mask_info_test.csv (for DL mask)
39 | |-- |-- clean_tracks_test_info.mat (for new evaluation protocol)
40 | |-- |-- .... (other original info files)
41 | ```
42 | DukeV dataset:
43 | ```
44 | DukeMTMC-VideoReID/
45 | |-- train/
46 | |-- gallery/
47 | |-- query/
48 | ```
49 | You can put these two folders under your root dataset directory.
50 | ```
51 | path to your root dir/
52 | |-- MARS-DL/
53 | |-- DukeMTMC-VideoReID/
54 | ```
55 | ## Coarse-to-Fine Axial Attention Network (CF-AAN)
56 |
57 | ### Requirement
58 | We use Python 3.6, PyTorch 1.5, and PyTorch-Ignite in this project. To install the required modules, run:
59 | ```
60 | pip3 install -r requirements.txt
61 | ```
62 | ### Training
63 | #### Train CF-AAN on MARS-DL
64 | You can alter the arguments in `scripts/AA_M.sh` and run it with:
65 | ```
66 | sh scripts/AA_M.sh
67 | ```
68 | Or, you can directly type:
69 | ```
70 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('mars',)" INPUT.SEQ_LEN 6 \
71 | OUTPUT_DIR "./ckpt_DL_M/MARS_DL_s6_resnet_axial_gap_rqkv_gran4" SOLVER.SOFT_MARGIN True \
72 | MODEL.NAME 'resnet50_axial' MODEL.TEMP 'Done' INPUT.IF_RE True \
73 | DATASETS.ROOT_DIR ''
74 | ```
75 | \* `` is the directory containing both MARS and DukeV dataset.
76 | #### Train Non-local or baseline on MARS
77 | You can alter the arguments in `scripts/NL_M.sh` & `scripts/baseline_M.sh` and run them with:
78 |
79 | `sh scripts/NL_M.sh` & `sh scripts/baseline_M.sh`
80 | #### Train models on DukeMTMC-VideoReID
81 | You can use the scripts `scripts/AA_D.sh`, `scripts/NL_D.sh`, & `scripts/baseline_D.sh`
82 |
83 | #### Notes
84 | If you want to train on the original MARS dataset, you just need to swap the commented lines in `data/datasets/MARS.py`:
85 | ```
86 | class MARS(BaseVideoDataset):
87 | dataset_dir = 'MARS'
88 | # dataset_dir = 'MARS-DL'
89 | info_dir = 'info'
90 | ```
91 |
92 | ### Testing
93 | You can alter the argument in `scripts/test_M.sh` and run it with:
94 | ```
95 | sh scripts/test_M.sh
96 | ```
97 | \* `TEST.WEIGHT` is the path to the saved PyTorch (.pth) model.
98 |
99 | \* There are four modes for `TEST.TEST_MODE`.
100 | 1. `TEST.TEST_MODE 'test'`
101 | * Use the RRS [3] testing mode, which samples the first image of each of the T snippets split from the tracklet.
102 | 2. `TEST.TEST_MODE 'test_0'`
103 | * Sample the first T images of the tracklet.
104 | 3. `TEST.TEST_MODE 'test_all_sampled'`
105 | * Create N/T tracklets (all 1st images from the T RRS snippets, all 2nd images, ...), and average the N/T features.
106 | 4. `TEST.TEST_MODE 'test_all_continuous'`
107 | * Continuously sample T frames to create N/T tracklets, and average the N/T features.
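
To make these modes concrete, below is a minimal sketch of the index selection for modes 1, 3, and 4, mirroring the logic in `data/datasets/dataset_loader.py` (`rrs_indices` is a hypothetical helper written for illustration; N is the tracklet length and T is `INPUT.SEQ_LEN`; `test_0` simply takes the first T frames):
```python
import numpy as np

def rrs_indices(num_frames, seq_len, mode):
    """Pad the tracklet to a multiple of seq_len, split it into seq_len
    contiguous snippets, then pick frame indices according to the test mode."""
    idx = np.arange(num_frames)
    pad = (-num_frames) % seq_len
    idx = np.concatenate([idx, np.full(pad, num_frames - 1)])  # repeat last frame
    pool = np.split(idx, seq_len)      # seq_len contiguous snippets
    if mode == 'test':                 # RRS: first frame of every snippet
        return np.array([part[0] for part in pool])
    if mode == 'test_all_sampled':     # all 1st frames, then all 2nd frames, ...
        return np.vstack(pool).T.flatten()
    if mode == 'test_all_continuous':  # continuous chunks of seq_len frames
        return np.vstack(pool).flatten()
    raise ValueError(mode)

print(rrs_indices(10, 4, 'test'))              # [0 3 6 9]
print(rrs_indices(10, 4, 'test_all_sampled'))  # [0 3 6 9 1 4 7 9 2 5 8 9]
```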
108 |
109 | If you want to test on DukeV, you can just alter the corresponding arguments in `scripts/test_M.sh`.
110 |
111 | ## New Evaluation Protocol
112 |
113 | Change `TEST.NEW_EVAL False` to `TEST.NEW_EVAL True`.
114 |
115 | The details will be introduced soon.
116 |
117 | ## Citation
118 | ```
119 | @InProceedings{Liu_2021_CVPR,
120 | author = {Liu, Chih-Ting and Chen, Jun-Cheng and Chen, Chu-Song and Chien, Shao-Yi},
121 | title = {Video-Based Person Re-Identification Without Bells and Whistles},
122 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
123 | month = {June},
124 | year = {2021},
125 | pages = {1491-1500}
126 | }
127 | ```
128 | ## Reference
129 |
130 | 1. The structure of our code is based on [reid-strong-baseline](https://github.com/michuanhaohao/reid-strong-baseline).
131 | 2. Some code of our CF-AAN is based on [axial-deeplab](https://github.com/csrhddlam/axial-deeplab).
132 | 3. Li, Shuang, et al. "Diversity regularized spatiotemporal attention for video-based person re-identification." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.
133 | ## Contact
134 |
135 | [Chih-Ting Liu](https://jackie840129.github.io/), [Media IC & System Lab](https://github.com/mediaic), National Taiwan University
136 |
137 | E-mail : jackieliu@media.ee.ntu.edu.tw
138 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .defaults import _C as cfg
8 |
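
The exported `cfg` is a yacs `CfgNode`. As a reference, here is a minimal sketch of the usual merge order (defaults from `config/defaults.py`, then a YAML file, then command-line style KEY VALUE pairs), using only the standard yacs API; the exact handling in `tools/train.py` may differ:

```python
from config import cfg

cfg.merge_from_file('configs/video_baseline.yml')     # YAML overrides defaults
cfg.merge_from_list(['MODEL.NAME', 'resnet50_axial',  # opts override YAML
                     'INPUT.SEQ_LEN', 6])
cfg.freeze()                                          # make the config read-only

print(cfg.MODEL.NAME, cfg.INPUT.SEQ_LEN)              # resnet50_axial 6
```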
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/config/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/config/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/config/__pycache__/defaults.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/config/__pycache__/defaults.cpython-35.pyc
--------------------------------------------------------------------------------
/config/__pycache__/defaults.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/config/__pycache__/defaults.cpython-36.pyc
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter,
8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images during training will be
10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
11 | # IMAGES_PER_BATCH_TEST
12 |
13 | # -----------------------------------------------------------------------------
14 | # Config definition
15 | # -----------------------------------------------------------------------------
16 |
17 | _C = CN()
18 |
19 | _C.MODEL = CN()
20 | # Using cuda or cpu for training
21 | _C.MODEL.DEVICE = "cuda"
22 | # ID number of GPU
23 | _C.MODEL.DEVICE_ID = '0'
24 | # Name of backbone
25 | _C.MODEL.NAME = 'resnet50'
26 | # Last stride of backbone
27 | _C.MODEL.LAST_STRIDE = 1
28 | # Path to pretrained model of backbone
29 | _C.MODEL.PRETRAIN_PATH = ''
30 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
31 | # Options: 'imagenet' or 'self'
32 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet'
33 | # If train with BNNeck, options: 'bnneck' or 'no'
34 | _C.MODEL.NECK = 'bnneck'
35 | # If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
36 | _C.MODEL.IF_WITH_CENTER = 'no'
37 | # The loss type of metric loss
38 | # options:['triplet'](without center loss) or ['center','triplet_center'](with center loss)
39 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet'
40 | # For example, if loss type is cross entropy loss + triplet loss + center loss
41 | # the setting should be: _C.MODEL.METRIC_LOSS_TYPE = 'triplet_center' and _C.MODEL.IF_WITH_CENTER = 'yes'
42 |
43 | # If train with label smooth, options: 'on', 'off'
44 | _C.MODEL.IF_LABELSMOOTH = 'on'
45 |
46 | # Video or Image based
47 | _C.MODEL.SETTING = 'image'
48 |
49 | _C.MODEL.TEMP = 'avg'
50 | _C.MODEL.NON_LAYERS = [0,0,0,0]
51 |
52 |
53 | # -----------------------------------------------------------------------------
54 | # INPUT
55 | # -----------------------------------------------------------------------------
56 | _C.INPUT = CN()
57 | # Size of the image during training
58 | _C.INPUT.SIZE_TRAIN = [384, 128]
59 | # Size of the image during test
60 | _C.INPUT.SIZE_TEST = [384, 128]
61 | # Random probability for image horizontal flip
62 | _C.INPUT.PROB = 0.5
63 | # Random probability for random erasing
64 | _C.INPUT.RE_PROB = 0.5
65 | # Values to be used for image normalization
66 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
67 | # Values to be used for image normalization
68 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
69 | # Value of padding size
70 | _C.INPUT.PADDING = 10
71 | # augmentation on/off
72 | _C.INPUT.IF_CROP = True
73 | _C.INPUT.IF_RE = True
74 | _C.INPUT.IF_FLIP = True
75 |
76 | # for video re-id
77 | _C.INPUT.MIN_SEQ_LEN = 0
78 | _C.INPUT.SAMPLE = 'RRS'
79 | _C.INPUT.SEQ_LEN = 8
80 |
81 | # -----------------------------------------------------------------------------
82 | # Dataset
83 | # -----------------------------------------------------------------------------
84 | _C.DATASETS = CN()
85 | # List of the dataset names for training, as present in paths_catalog.py
86 | _C.DATASETS.NAMES = ('market1501',)
87 | # Root directory where the datasets are stored (and downloaded to if not found)
88 | _C.DATASETS.ROOT_DIR = ('/home/mediax/Dataset/')
89 |
90 | # -----------------------------------------------------------------------------
91 | # DataLoader
92 | # -----------------------------------------------------------------------------
93 | _C.DATALOADER = CN()
94 | # Number of data loading threads
95 | _C.DATALOADER.NUM_WORKERS = 8
96 | # Sampler for data loading
97 | _C.DATALOADER.SAMPLER = 'softmax'
98 | # Number of instance for one batch
99 | _C.DATALOADER.NUM_INSTANCE = 16
100 |
101 | # ---------------------------------------------------------------------------- #
102 | # Solver
103 | # ---------------------------------------------------------------------------- #
104 | _C.SOLVER = CN()
105 | # Name of optimizer
106 | _C.SOLVER.OPTIMIZER_NAME = "Adam"
107 | # Maximum number of training epochs
108 | _C.SOLVER.MAX_EPOCHS = 50
109 | # Base learning rate
110 | _C.SOLVER.BASE_LR = 3e-4
111 | # Learning rate factor for bias parameters
112 | _C.SOLVER.BIAS_LR_FACTOR = 2
113 | # Momentum
114 | _C.SOLVER.MOMENTUM = 0.9
115 | # Margin of triplet loss
116 | _C.SOLVER.MARGIN = 0.3
117 | _C.SOLVER.SOFT_MARGIN = False
118 | # Margin of cluster loss
119 | _C.SOLVER.CLUSTER_MARGIN = 0.3
120 | # Learning rate of SGD to learn the centers of center loss
121 | _C.SOLVER.CENTER_LR = 0.5
122 | # Balanced weight of center loss
123 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005
124 | # Settings of range loss
125 | _C.SOLVER.RANGE_K = 2
126 | _C.SOLVER.RANGE_MARGIN = 0.3
127 | _C.SOLVER.RANGE_ALPHA = 0
128 | _C.SOLVER.RANGE_BETA = 1
129 | _C.SOLVER.RANGE_LOSS_WEIGHT = 1
130 |
131 | # Settings of weight decay
132 | _C.SOLVER.WEIGHT_DECAY = 0.0005
133 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
134 |
135 | # decay rate of learning rate
136 | _C.SOLVER.GAMMA = 0.1
137 | # decay step of learning rate
138 | _C.SOLVER.STEPS = (30, 55)
139 |
140 | # warm up factor
141 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3
142 | # iterations of warm up
143 | _C.SOLVER.WARMUP_ITERS = 500
144 | # method of warm up, option: 'constant','linear'
145 | _C.SOLVER.WARMUP_METHOD = "linear"
146 |
147 | # epoch number of saving checkpoints
148 | _C.SOLVER.CHECKPOINT_PERIOD = 50
149 | # iteration of display training log
150 | _C.SOLVER.LOG_PERIOD = 100
151 | # epoch number of validation
152 | _C.SOLVER.EVAL_PERIOD = 50
153 |
154 | # Number of images per batch
155 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
156 | # see 2 images per batch
157 | _C.SOLVER.IMS_PER_BATCH = 64
158 |
159 | # ---------------------------------------------------------------------------- #
160 | # Test
161 | # ---------------------------------------------------------------------------- #
162 | _C.TEST = CN()
163 | # Number of images per batch during test
164 | _C.TEST.IMS_PER_BATCH = 128
165 | # If test with re-ranking, options: 'yes','no'
166 | _C.TEST.RE_RANKING = 'no'
167 | # Path to trained model
168 | _C.TEST.WEIGHT = ""
169 | # Which feature of BNNeck to use for test, before or after BNNeck. Options: 'before' or 'after'
170 | _C.TEST.NECK_FEAT = 'after'
171 | # Whether the feature is normalized before test; if yes, it is equivalent to cosine distance
172 | _C.TEST.FEAT_NORM = 'yes'
173 |
174 | _C.TEST.TEST_MODE = 'test'
175 |
176 | _C.TEST.NEW_EVAL = False
177 |
178 |
179 | # ---------------------------------------------------------------------------- #
180 | # Misc options
181 | # ---------------------------------------------------------------------------- #
182 | # Path to checkpoint and saved log of trained model
183 | _C.OUTPUT_DIR = ""
184 |
--------------------------------------------------------------------------------
/configs/video_baseline.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | PRETRAIN_PATH: ''
4 | NAME : 'resnet50'
5 | METRIC_LOSS_TYPE: 'triplet'
6 | IF_LABELSMOOTH: 'no'
7 | IF_WITH_CENTER: 'no'
8 | SETTING : 'video'
9 | TEMP : 'avg'
10 |
11 | INPUT:
12 | SIZE_TRAIN: [256, 128]
13 | SIZE_TEST: [256, 128]
14 | PROB: 0.5 # random horizontal flip
15 | RE_PROB: 0.5 # random erasing
16 | PADDING: 10
17 | #video reid
18 | IF_FLIP : False
19 | IF_CROP : False
20 | IF_RE : False
21 | MIN_SEQ_LEN : 0
22 | SEQ_LEN : 6
23 | SAMPLE : 'RRS'
24 |
25 | DATASETS:
26 | NAMES: ('mars',)
27 |
28 | DATALOADER:
29 | SAMPLER: 'softmax_triplet'
30 | NUM_INSTANCE: 4
31 | NUM_WORKERS: 8
32 |
33 | SOLVER:
34 | OPTIMIZER_NAME: 'Adam'
35 | MAX_EPOCHS: 220
36 | BASE_LR: 0.0001
37 |
38 | MARGIN : 0.3
39 | SOFT_MARGIN : True
40 |
41 | CENTER_LR: 0.5
42 | CENTER_LOSS_WEIGHT: 0.0005
43 |
44 | BIAS_LR_FACTOR: 1
45 | WEIGHT_DECAY: 5e-5
46 | WEIGHT_DECAY_BIAS: 5e-5
47 | IMS_PER_BATCH: 32
48 |
49 | STEPS: [50, 100, 150, 200]
50 | GAMMA: 0.1
51 |
52 | WARMUP_FACTOR: 0.01
53 | WARMUP_ITERS: 10
54 | WARMUP_METHOD: 'linear'
55 |
56 | EVAL_PERIOD: 10
57 |
58 | TEST:
59 | IMS_PER_BATCH: 32
60 | RE_RANKING: 'no'
61 | WEIGHT: "path"
62 | NECK_FEAT: 'after'
63 | FEAT_NORM: 'yes'
64 | TEST_MODE: 'test'
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .build import make_data_loader
8 |
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/build.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/__pycache__/build.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/build.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/__pycache__/build.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/collate_batch.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/__pycache__/collate_batch.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/collate_batch.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/__pycache__/collate_batch.cpython-36.pyc
--------------------------------------------------------------------------------
/data/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy as np
3 | from torch.utils.data import DataLoader
4 |
5 | from .collate_batch import train_collate_fn, val_collate_fn
6 | from .datasets import init_dataset, ImageDataset,VideoDataset
7 | from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid
8 | from .transforms import build_transforms_ST
9 |
10 |
11 | def make_data_loader(cfg):
12 | ##### build transform #####
13 | train_spatial_transforms , _ = build_transforms_ST(cfg, is_train=True)
14 | val_spatial_transforms, val_temporal_transforms = build_transforms_ST(cfg, is_train=False)
15 | num_workers = cfg.DATALOADER.NUM_WORKERS
16 |
17 | ##### init dataset-specific object #####
18 | if cfg.MODEL.SETTING == 'video':
19 | dataset = init_dataset(cfg.DATASETS.NAMES[0], root=cfg.DATASETS.ROOT_DIR,min_seq_len=cfg.INPUT.MIN_SEQ_LEN,new_eval=cfg.TEST.NEW_EVAL)
20 | else:
21 | raise NotImplementedError()
22 |
23 | num_classes = dataset.num_train_pids
24 | ##### create real pytorch Dataset #####
25 | if cfg.MODEL.SETTING == 'video':
26 | train_set = VideoDataset(dataset.train,cfg.INPUT.SEQ_LEN, cfg.INPUT.SAMPLE, train_spatial_transforms, None, mode='train')
27 | else:
28 | raise NotImplementedError()
29 |
30 | ##### create dataloader #####
31 | if cfg.DATALOADER.SAMPLER == 'softmax':
32 | train_loader = DataLoader(
33 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
34 | collate_fn=train_collate_fn
35 | )
36 | else:
37 | train_loader = DataLoader(
38 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
39 | worker_init_fn= lambda _:np.random.seed(),
40 | sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE),
41 | num_workers=num_workers, collate_fn=train_collate_fn,drop_last=True
42 | )
43 | if cfg.MODEL.SETTING == 'video':
44 | val_set = VideoDataset(dataset.query + dataset.gallery, cfg.INPUT.SEQ_LEN,cfg.INPUT.SAMPLE, val_spatial_transforms,val_temporal_transforms, mode=cfg.TEST.TEST_MODE)
45 | else:
46 | raise NotImplementedError()
47 |
48 | val_loader = DataLoader(
49 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
50 | collate_fn=val_collate_fn
51 | )
52 | return train_loader, val_loader, len(dataset.query), num_classes
53 |
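As a reference, a minimal sketch of how the four returned values are typically consumed (paths are placeholders; the actual training loop lives in engine/trainer.py):

```python
from config import cfg
from data import make_data_loader

cfg.merge_from_file('configs/video_baseline.yml')
cfg.merge_from_list(['DATASETS.ROOT_DIR', '/path/to/root'])

train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
# num_classes sizes the classification head; num_query marks where the
# query tracklets end and the gallery begins inside val_loader.
```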
--------------------------------------------------------------------------------
/data/collate_batch.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import torch
3 |
4 |
5 | def train_collate_fn(batch):
6 | if len(batch[0]) == 4:
7 | imgs, pids, _,masks = zip(*batch)
8 | pids = torch.tensor(pids, dtype=torch.int64)
9 | return torch.stack(imgs, dim=0), pids ,torch.stack(masks,dim=0)
10 | imgs, pids, _, = zip(*batch)
11 | pids = torch.tensor(pids, dtype=torch.int64)
12 | return torch.stack(imgs, dim=0), pids
13 |
14 |
15 | def val_collate_fn(batch):
16 | if len(batch[0]) == 4:
17 | imgs, pids, camids ,masks = zip(*batch)
18 | return torch.stack(imgs, dim=0), pids , camids, torch.stack(masks,dim=0)
19 | elif len(batch[0]) == 5 :
20 | imgs, pids, ambi, camids ,masks = zip(*batch)
21 | return torch.stack(imgs, dim=0), pids , ambi, camids, torch.stack(masks,dim=0)
22 | imgs, pids, camids = zip(*batch)
23 | return torch.stack(imgs, dim=0), pids, camids
24 |
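Since VideoDataset returns one (T, C, H, W) clip per sample, these collate functions stack clips into a 5-D batch. A small self-contained check (shapes are illustrative only):

```python
import torch
from data.collate_batch import train_collate_fn

# three fake (imgs, pid, camid) samples: clips of 6 RGB frames at 256x128
batch = [(torch.zeros(6, 3, 256, 128), pid, 0) for pid in (5, 7, 9)]
imgs, pids = train_collate_fn(batch)
print(imgs.shape)  # torch.Size([3, 6, 3, 256, 128])
print(pids)        # tensor([5, 7, 9])
```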
--------------------------------------------------------------------------------
/data/datasets/DukeV.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import glob
3 | import re
4 | import json
5 | import pickle
6 | import os
7 | import os.path as osp
8 | from scipy.io import loadmat
9 | from .bases import BaseVideoDataset
10 | import pandas as pd
11 | import numpy as np
12 |
13 |
14 | class DukeV(BaseVideoDataset):
15 | dataset_dir = 'DukeMTMC-VideoReID'
16 |
17 | def __init__(self, root='/home/mediax/Dataset', verbose=True, min_seq_len =0,info_dir='./DukeV_info',new_eval=False):
18 | super(DukeV, self).__init__()
19 | self.dataset_dir = osp.join(root, self.dataset_dir)
20 |
21 | self.train_dir = osp.join(self.dataset_dir,'train')
22 | self.gallery_dir = osp.join(self.dataset_dir,'gallery')
23 | self.query_dir = osp.join(self.dataset_dir,'query')
24 | self.min_seq_len = min_seq_len
25 | #for self-created duke info
26 | if 'DL' in self.dataset_dir:
27 | info_dir = './DukeV_DL_info'
28 | self.train_pkl = osp.join(info_dir,'train.pkl')
29 | self.gallery_pkl = osp.join(info_dir,'gallery.pkl')
30 | self.query_pkl = osp.join(info_dir,'query.pkl')
31 | self.info_dir = info_dir
32 | self._check_before_run()
33 |
34 | if 'DL' in self.dataset_dir:
35 | train_mask_csv = pd.read_csv(osp.join(self.dataset_dir,'duke_mask_info.csv'),sep=',',header=None).values
36 | query_mask_csv = pd.read_csv(osp.join(self.dataset_dir,'duke_mask_info_query.csv'),sep=',',header=None).values
37 | gallery_mask_csv = pd.read_csv(osp.join(self.dataset_dir,'duke_mask_info_gallery.csv'),sep=',',header=None).values
38 | else:
39 | train_mask_csv,query_mask_csv, gallery_mask_csv = None,None,None
40 |
41 | train = self._process_dir(self.train_dir,self.train_pkl,relabel=True,mask_info=train_mask_csv)
42 | gallery = self._process_dir(self.gallery_dir,self.gallery_pkl,relabel=False,mask_info=gallery_mask_csv)
43 | query = self._process_dir(self.query_dir,self.query_pkl,relabel=False,mask_info=query_mask_csv)
44 | if verbose:
45 | print("=> DukeV loaded")
46 | self.print_dataset_statistics(train, query, gallery)
47 | self.train = train # list of tuple--(paths,id,cams)
48 | self.query = query
49 | self.gallery = gallery
50 |
51 | self.num_train_pids, self.num_train_tracklets, self.num_train_cams = self.get_videodata_info(self.train)
52 | self.num_query_pids, self.num_query_tracklets, self.num_query_cams = self.get_videodata_info(self.query)
53 | self.num_gallery_pids, self.num_gallery_tracklets, self.num_gallery_cams = self.get_videodata_info(self.gallery)
54 |
55 | def _process_dir(self,dir_path,pkl_path,relabel,mask_info=None):
56 |
57 | if osp.exists(pkl_path):
58 | print('==> %s exists. Loading...'%(pkl_path))
59 | with open(pkl_path,'rb') as f:
60 | pkl_file = pickle.load(f)
61 |
62 | if mask_info is None:
63 | return pkl_file
64 |
65 | tracklets = []
66 | start = 0
67 | for info in pkl_file:
68 | end = start + len(info[0])
69 | tracklets.append((info[0],info[1],info[2],mask_info[start:end,1:].astype('int16')//16))  # scale mask bbox coords down by 16
70 | start = end
71 | return tracklets
72 |
73 | pdirs = sorted(glob.glob(osp.join(dir_path, '*')))
74 | print("Processing {} with {} person identities".format(dir_path, len(pdirs)))
75 | pids = sorted(list(set([int(osp.basename(pdir)) for pdir in pdirs])))
76 | pid2label = {pid : label for label,pid in enumerate(pids)}
77 |
78 | tracklets = []
79 | for pdir in pdirs:
80 | pid = int(osp.basename(pdir))
81 | if relabel : pid = pid2label[pid]
82 | track_dirs = sorted(glob.glob(osp.join(pdir,'*')))
83 | for track_dir in track_dirs:
84 | img_paths = sorted(glob.glob(osp.join(track_dir,'*.jpg')))
85 | num_imgs = len(img_paths)
86 | if num_imgs < self.min_seq_len :
87 | continue
88 | img_name = osp.basename(img_paths[0])
89 | if img_name.find('_') == -1 :
90 | camid = int(img_name[5])-1
91 | else:
92 | camid = int(img_name[6])-1
93 | img_paths = tuple(img_paths)
94 | tracklets.append((img_paths,pid,camid))
95 | # save to pickle
96 | if not osp.isdir(self.info_dir):
97 | os.mkdir(self.info_dir)
98 | with open(pkl_path,'wb') as f:
99 | pickle.dump(tracklets,f)
100 | return tracklets
101 |
102 | def _check_before_run(self):
103 | """Check if all files are available before going deeper"""
104 | if not osp.exists(self.dataset_dir):
105 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
106 | if not osp.exists(self.train_dir):
107 | raise RuntimeError("'{}' is not available".format(self.train_dir))
108 | if not osp.exists(self.gallery_dir):
109 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
110 | if not osp.exists(self.query_dir):
111 | raise RuntimeError("'{}' is not available".format(self.query_dir))
112 |
--------------------------------------------------------------------------------
/data/datasets/MARS.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import glob
3 | import re
4 |
5 | import os.path as osp
6 | from scipy.io import loadmat
7 | from .bases import BaseVideoDataset
8 | import pandas as pd
9 |
10 |
11 | class MARS(BaseVideoDataset):
12 | # dataset_dir = 'MARS'
13 | dataset_dir = 'MARS-DL'
14 | info_dir = 'info'
15 |
16 | def __init__(self, root='/home/mediax/Dataset', verbose=True, min_seq_len =0,new_eval=False):
17 | super(MARS, self).__init__()
18 | self.dataset_dir = osp.join(root, self.dataset_dir)
19 | self.info_dir = osp.join(self.dataset_dir,self.info_dir)
20 | self.train_name_path = osp.join(self.info_dir,'train_name.txt')
21 | self.test_name_path = osp.join(self.info_dir,'test_name.txt')
22 | self.track_train_info_path = osp.join(self.info_dir,'tracks_train_info.mat')
23 | self.track_test_info_path = osp.join(self.info_dir,'tracks_test_info.mat')
24 | self.query_IDX_path = osp.join(self.info_dir,'query_IDX.mat')
25 | self.new_eval = new_eval
26 | if self.new_eval:
27 | self.track_test_info_path = osp.join(self.info_dir,'clean_tracks_test_info.mat')
28 |
29 | if 'DL' in self.dataset_dir:
30 | train_mask_csv = pd.read_csv(osp.join(self.info_dir,'mask_info.csv'),sep=',',header=None).values
31 | test_mask_csv = pd.read_csv(osp.join(self.info_dir,'mask_info_test.csv'),sep=',',header=None).values
32 | else:
33 | train_mask_csv,test_mask_csv = None,None
34 | self._check_before_run()
35 | # prepare meta data
36 | train_names = self._get_names(self.train_name_path)
37 | test_names = self._get_names(self.test_name_path)
38 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # (8298,4), columns: [start_idx, end_idx, pid, camid]
39 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # (12180,4), same columns
40 |
41 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze()-1 #(1980,) start from 0
42 | track_query = track_test[query_IDX,:]
43 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX]
44 | track_gallery = track_test[gallery_IDX,:]
45 | # track_gallery = track_test
46 |
47 | train = self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True,min_seq_len=min_seq_len,mask_info=train_mask_csv)
48 | query = self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False,mask_info=test_mask_csv,new_eval=self.new_eval)
49 | gallery = self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False,mask_info=test_mask_csv,new_eval=self.new_eval)
50 |
51 | if verbose:
52 | print("=> MARS loaded")
53 | self.print_dataset_statistics(train, query, gallery)
54 |
55 | self.train = train # list of tuple--(paths,id,cams)
56 | self.query = query
57 | self.gallery = gallery
58 |
59 | self.num_train_pids, self.num_train_tracklets, self.num_train_cams = self.get_videodata_info(self.train)
60 | self.num_query_pids, self.num_query_tracklets, self.num_query_cams = self.get_videodata_info(self.query)
61 | self.num_gallery_pids, self.num_gallery_tracklets, self.num_gallery_cams = self.get_videodata_info(self.gallery)
62 |
63 | def _check_before_run(self):
64 | """Check if all files are available before going deeper"""
65 | if not osp.exists(self.dataset_dir):
66 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
67 | if not osp.exists(self.train_name_path):
68 | raise RuntimeError("'{}' is not available".format(self.train_name_path))
69 | if not osp.exists(self.test_name_path):
70 | raise RuntimeError("'{}' is not available".format(self.test_name_path))
71 | if not osp.exists(self.track_train_info_path):
72 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path))
73 | if not osp.exists(self.track_test_info_path):
74 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path))
75 | if not osp.exists(self.query_IDX_path):
76 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path))
77 |
78 | def _get_names(self, fpath):
79 | names = []
80 | with open(fpath, 'r') as f:
81 | for line in f:
82 | new_line = line.rstrip()
83 | names.append(new_line)
84 | return names
85 |
86 | def _process_data(self,names, meta_data, home_dir=None,relabel=False,min_seq_len=0,mask_info=None,new_eval=False):
87 | assert home_dir in ['bbox_train','bbox_test']
88 |
89 | n_tracklets = len(meta_data)
90 | pid_list = list(set(meta_data[:,2].tolist()))
91 | num_pids = len(pid_list)
92 |
93 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)}
94 |
95 | tracklets = []
96 | num_imgs_per_tracklet = []
97 | for tracklet_idx in range(n_tracklets):
98 | data = meta_data[tracklet_idx,...]
99 | if new_eval == True:
100 | start_idx,end_idx,pid,cam, new_pid, new_ambi = data
101 | else:
102 | start_idx,end_idx,pid,cam = data
103 | if pid == -1 or pid == 0 : continue # junk index
104 | assert 1<= cam <=6
105 |
106 | if relabel : pid = pid2label[pid]
107 | cam -= 1
108 | img_names = names[start_idx-1:end_idx]
109 |
110 | if mask_info is not None:
111 | masks = mask_info[start_idx-1:end_idx,1:].astype('int16')//16
112 |
113 | # make sure image names correspond to the same person
114 | pnames = [img_name[:4] for img_name in img_names]
115 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images"
116 | camnames = [img_name[5] for img_name in img_names]
117 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
118 |
119 | # append image names with directory information
120 | img_paths = [osp.join(self.dataset_dir,home_dir,img_name[:4],img_name) for img_name in img_names]
121 | if len(img_paths) >= min_seq_len:
122 | img_paths = tuple(img_paths)
123 | if mask_info is not None:
124 | masks = mask_info[start_idx-1:end_idx,1:].astype('int16')//16
125 | if new_eval == True:
126 | tracklets.append((img_paths,pid,new_pid,new_ambi,cam, masks))
127 | else:
128 | tracklets.append((img_paths,pid,cam,masks))
129 | else:
130 | if new_eval == True:
131 | tracklets.append((img_paths,pid,new_pid,new_ambi,cam))
132 | else:
133 | tracklets.append((img_paths,pid,cam))
134 | # num_imgs_per_tracklet.append(len(img_paths))
135 | # n_tracklets = len(tracklets)
136 |
137 | return tracklets
138 |
--------------------------------------------------------------------------------
/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | from .MARS import MARS
3 | from .DukeV import DukeV
4 | from .dataset_loader import ImageDataset,VideoDataset
5 |
6 | __factory = {
7 | 'mars' : MARS,
8 | 'dukev':DukeV
9 | }
10 |
11 |
12 | def get_names():
13 | return __factory.keys()
14 |
15 |
16 | def init_dataset(name, *args, **kwargs):
17 | if name not in __factory.keys():
18 | raise KeyError("Unknown datasets: {}".format(name))
19 | return __factory[name](*args, **kwargs)
20 |
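A short usage sketch of the factory (the root path is a placeholder):

```python
from data.datasets import init_dataset, get_names

print(list(get_names()))  # ['mars', 'dukev']
dataset = init_dataset('mars', root='/path/to/root', min_seq_len=0, new_eval=False)
print(dataset.num_train_pids)  # number of training identities
```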
--------------------------------------------------------------------------------
/data/datasets/__pycache__/DukeV.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/DukeV.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/MARS.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/MARS.cpython-35.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/MARS.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/MARS.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/bases.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/bases.cpython-35.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/bases.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/bases.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/dataset_loader.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/dataset_loader.cpython-35.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/dataset_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/dataset_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/dukemtmcreid.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/dukemtmcreid.cpython-35.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/eval_reid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/eval_reid.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/market1501.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/market1501.cpython-35.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/market1501.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/market1501.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/msmt17.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/msmt17.cpython-35.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/msmt17.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/msmt17.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/veri.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/veri.cpython-35.pyc
--------------------------------------------------------------------------------
/data/datasets/__pycache__/veri.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/datasets/__pycache__/veri.cpython-36.pyc
--------------------------------------------------------------------------------
/data/datasets/bases.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy as np
3 |
4 |
5 | class BaseDataset(object):
6 | """
7 | Base class of reid dataset
8 | """
9 |
10 | def get_imagedata_info(self, data):
11 | pids, cams = [], []
12 | for _, pid, camid in data:
13 | pids += [pid]
14 | cams += [camid]
15 | pids = set(pids)
16 | cams = set(cams)
17 | num_pids = len(pids)
18 | num_cams = len(cams)
19 | num_imgs = len(data)
20 | return num_pids, num_imgs, num_cams
21 |
22 | def get_videodata_info(self, data, return_tracklet_stats=False):
23 | pids, cams, tracklet_stats = [], [], []
24 | is_mask = (len(data[0]) == 4) or (len(data[0]) == 6)
25 | if len(data[0]) == 3 :
26 | for img_paths, pid, camid in data:
27 | pids += [pid]
28 | cams += [camid]
29 | tracklet_stats += [len(img_paths)]
30 | elif len(data[0]) == 4 :
31 | for img_paths, pid, camid ,_ in data:
32 | pids += [pid]
33 | cams += [camid]
34 | tracklet_stats += [len(img_paths)]
35 | elif len(data[0]) == 5 :
36 | for img_paths, pid, new_pid, new_ambi, camid in data:
37 | pids += [new_pid]
38 | cams += [camid]
39 | tracklet_stats += [len(img_paths)]
40 | elif len(data[0]) == 6 :
41 | for img_paths, pid,new_pid,new_ambi, camid ,_ in data:
42 | pids += [new_pid]
43 | cams += [camid]
44 | tracklet_stats += [len(img_paths)]
45 |
46 | pids = set(pids)
47 | cams = set(cams)
48 | num_pids = len(pids)
49 | num_cams = len(cams)
50 | num_tracklets = len(data)
51 | if return_tracklet_stats:
52 | return num_pids, num_tracklets, num_cams, tracklet_stats
53 | return num_pids, num_tracklets, num_cams
54 |
55 | def print_dataset_statistics(self):
56 | raise NotImplementedError
57 |
58 |
59 | class BaseImageDataset(BaseDataset):
60 | """
61 | Base class of image reid dataset
62 | """
63 |
64 | def print_dataset_statistics(self, train, query, gallery):
65 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
66 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
67 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
68 |
69 | print("Dataset statistics:")
70 | print(" ----------------------------------------")
71 | print(" subset | # ids | # images | # cameras")
72 | print(" ----------------------------------------")
73 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
74 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
75 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
76 | print(" ----------------------------------------")
77 |
78 |
79 | class BaseVideoDataset(BaseDataset):
80 | """
81 | Base class of video reid dataset
82 | """
83 |
84 | def print_dataset_statistics(self, train, query, gallery):
85 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
86 | self.get_videodata_info(train, return_tracklet_stats=True)
87 |
88 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
89 | self.get_videodata_info(query, return_tracklet_stats=True)
90 |
91 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
92 | self.get_videodata_info(gallery, return_tracklet_stats=True)
93 |
94 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
95 | min_num = np.min(tracklet_stats)
96 | max_num = np.max(tracklet_stats)
97 | avg_num = np.mean(tracklet_stats)
98 |
99 | print("Dataset statistics:")
100 | print(" -------------------------------------------")
101 | print(" subset | # ids | # tracklets | # cameras")
102 | print(" -------------------------------------------")
103 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams))
104 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams))
105 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))
106 | print(" -------------------------------------------")
107 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num))
108 | print(" -------------------------------------------")
109 |
--------------------------------------------------------------------------------
/data/datasets/dataset_loader.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import os.path as osp
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | import numpy as np
6 | import torch
7 | import random
8 |
9 |
10 | def read_image(img_path):
11 | """Keep reading image until succeed.
12 | This can avoid IOError incurred by heavy IO process."""
13 | got_img = False
14 | if not osp.exists(img_path):
15 | raise IOError("{} does not exist".format(img_path))
16 | while not got_img:
17 | try:
18 | img = Image.open(img_path).convert('RGB')
19 | got_img = True
20 | except IOError:
21 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
22 | pass
23 | return img
24 |
25 |
26 | class ImageDataset(Dataset):
27 | """Image Person ReID Dataset"""
28 |
29 | def __init__(self, dataset, transform=None):
30 | self.dataset = dataset
31 | self.transform = transform
32 |
33 | def __len__(self):
34 | return len(self.dataset)
35 |
36 | def __getitem__(self, index):
37 | img_path, pid, camid = self.dataset[index]
38 | img = read_image(img_path)
39 |
40 | if self.transform is not None:
41 | img = self.transform(img)
42 |
43 | return img, pid, camid, img_path
44 |
45 | class VideoDataset(Dataset):
46 | """Video Person ReID Dataset"""
47 |
48 | def __init__(self,dataset,seq_len=8,sample='RRS',spatial_transform=None, temporal_transform=None,mode='test'):
49 | self.dataset = dataset
50 | self.mask = len(dataset[0]) == 4 or len(dataset[0])==6
51 | self.new_eval = len(dataset[0]) == 5 or len(dataset[0])== 6
52 | self.seq_len = seq_len
53 | self.sample = sample
54 | self.spatial_transform = spatial_transform
55 | self.temporal_transform = temporal_transform
56 | self.mode = mode
57 | def __len__(self):
58 | return len(self.dataset)
59 |
60 | def __getitem__(self,idx):
61 | if self.mask and not self.new_eval:
62 | img_paths, pid, cam ,mask = self.dataset[idx]
63 | elif self.mask and self.new_eval:
64 | img_paths, _, pid, ambi, cam ,mask = self.dataset[idx]
65 | elif not self.mask and self.new_eval:
66 | raise NotImplementedError
67 | else:
68 | img_paths, pid, cam = self.dataset[idx]
69 |
70 | num = len(img_paths)
71 | indices = np.arange(0,num).astype(np.int32)
72 |
73 | # Temporal Sample Methods #
74 | if self.sample == 'RRS' and self.mode != 'test_0':
75 |
76 | num_pads = 0 if num%self.seq_len==0 else self.seq_len - num%self.seq_len
77 | indices = np.concatenate([indices,np.ones(num_pads).astype(np.int32)*(num-1)])
78 | assert len(indices) %self.seq_len == 0
79 |
80 | indices_pool = np.split(indices,self.seq_len)
81 | sampled_indices = []
82 |
83 | if self.mode == 'train':
84 | for part in indices_pool:
85 | sampled_indices.append(np.random.choice(part,1)[0])
86 | elif self.mode == 'test_all_sampled':
87 | sampled_indices = np.vstack(indices_pool).T.flatten()
88 | elif self.mode == 'test_all_continuous':
89 | sampled_indices = np.vstack(indices_pool).flatten()
90 | else :
91 | for part in indices_pool:
92 | sampled_indices.append(part[0])
93 |
94 | elif self.mode == 'test_0':
95 | sampled_indices = self.temporal_transform(indices)
96 | ################################
97 |
98 | imgs = []
99 | for index in sampled_indices:
100 | img_path = img_paths[index]
101 | img = read_image(img_path)
102 | if self.spatial_transform is not None:
103 | img = self.spatial_transform(img)
104 | imgs.append(img)
105 | imgs = torch.stack(imgs,dim=0)
106 |
107 | if self.mode == 'train':
108 | flip_prob = random.random()
109 | if flip_prob > 0.5:
110 | imgs = torch.flip(imgs,dims=[3])
111 |
112 | if self.mask:
113 | sampled_mask = mask[sampled_indices,:]
114 | if self.mode == 'train' and flip_prob > 0.5:  # frames were flipped, so mirror the mask x-range on the 128/16 grid
115 | new_start = 128//16 - sampled_mask[:,3]
116 | new_end = 128//16 - sampled_mask[:,2]
117 | sampled_mask[:,2] = new_start
118 | sampled_mask[:,3] = new_end
119 |
120 | if self.new_eval:
121 | return imgs,pid,ambi,cam,torch.tensor(sampled_mask,dtype=torch.int16)
122 | else:
123 | return imgs,pid,cam,torch.tensor(sampled_mask,dtype=torch.int16)
124 | else:
125 | if self.new_eval:
126 | raise NotImplementedError
127 | else:
128 | return imgs,pid,cam
129 |
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/data/datasets/eval_reid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy as np
3 |
4 |
5 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50,q_ambis=None,g_ambis=None):
6 | num_q, num_g = distmat.shape
7 | if num_g < max_rank:
8 | max_rank = num_g
9 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
10 | indices = np.argsort(distmat, axis=1)
11 | matches = (g_pids[indices] == q_pids[:, np.newaxis])#.astype(np.int32)
12 |
13 | new_eval = (q_ambis is not None) and (g_ambis is not None)
14 | if new_eval:
15 | matches_am2id = (g_ambis[indices] == q_pids[:, np.newaxis])
16 | matches_id2am = (g_pids[indices] == q_ambis[:, np.newaxis])
17 | matches_am2am = (g_ambis[indices] == q_ambis[:, np.newaxis])
18 | matches = matches | matches_am2am | matches_am2id | matches_id2am
19 | matches = matches.astype(np.int32)
20 | # compute cmc curve for each query
21 | all_cmc = []
22 | all_AP = []
23 | num_valid_q = 0. # number of valid query
24 | for q_idx in range(num_q):
25 | # get query pid and camid
26 | q_pid = q_pids[q_idx]
27 | q_camid = q_camids[q_idx]
28 |
29 | # remove gallery samples that have the same pid and camid with query
30 | order = indices[q_idx]
31 | remove = (g_camids[order] == q_camid)
32 | if not new_eval:
33 | remove = remove & (g_pids[order] == q_pid)
34 | else:
35 | q_amb = q_ambis[q_idx]
36 | remove_dis = remove & (g_pids[order] == 0) # distractor with same cam
37 | remove_id2id = remove & (g_pids[order] == q_pid)
38 | remove_am2id = remove & (g_ambis[order] == q_pid)
39 | remove_am2am = remove & (g_ambis[order] == q_amb)
40 | remove_id2am = remove & (g_pids[order] == q_amb)
41 | remove = remove_dis | remove_id2id | remove_am2id | remove_am2am | remove_id2am
42 |
43 | # remove = remove | (g_pids[order] == -1)
44 | keep = np.invert(remove)
45 |
46 | # compute cmc curve
47 | # binary vector, positions with value 1 are correct matches
48 | orig_cmc = matches[q_idx][keep]
49 | if not np.any(orig_cmc):
50 | # this condition is true when query identity does not appear in gallery
51 | continue
52 | cmc = orig_cmc.cumsum()
53 | cmc[cmc > 1] = 1
54 | all_cmc.append(cmc[:max_rank])
55 | num_valid_q += 1.
56 |
57 | # compute average precision
58 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
59 | num_rel = orig_cmc.sum()
60 | tmp_cmc = orig_cmc.cumsum()
61 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
62 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
63 | AP = tmp_cmc.sum() / num_rel
64 | all_AP.append(AP)
65 | assert num_valid_q > 0, "Error: no query identity appears in the gallery"
66 |
67 | all_cmc = np.asarray(all_cmc).astype(np.float32)
68 | all_cmc = all_cmc.sum(0) / num_valid_q
69 | mAP = np.mean(all_AP)
70 | return all_cmc, mAP
71 |
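A quick way to sanity-check eval_func is to feed it a toy distance matrix where the correct ranking is known by construction. A minimal sketch (the import path assumes the repository root is on PYTHONPATH; the numbers are made up for illustration):

    import numpy as np
    from data.datasets.eval_reid import eval_func

    # two queries, four gallery items; smaller distance = better match
    distmat = np.array([[0.1, 0.9, 0.8, 0.7],
                        [0.9, 0.2, 0.8, 0.7]])
    q_pids = np.array([1, 2])
    g_pids = np.array([1, 2, 3, 4])
    # distinct camera ids, so no gallery sample is filtered out
    q_camids = np.array([0, 0])
    g_camids = np.array([1, 1, 1, 1])

    cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=4)
    print(cmc[0], mAP)  # expect rank-1 = 1.0 and mAP = 1.0 for this toy case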
--------------------------------------------------------------------------------
/data/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid # new add by gu
8 |
--------------------------------------------------------------------------------
/data/samplers/triplet_sampler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import copy
8 | import random
9 | import torch
10 | from collections import defaultdict
11 |
12 | import numpy as np
13 | from torch.utils.data.sampler import Sampler
14 |
15 |
16 | class RandomIdentitySampler(Sampler):
17 | """
18 | Randomly sample P identities, then for each identity,
19 | randomly sample K instances, therefore batch size is P*K.
20 | Args:
21 | - data_source (list): list of (img_path, pid, camid).
22 | - num_instances (int): number of instances per identity in a batch.
23 | - batch_size (int): number of examples in a batch.
24 | """
25 |
26 | def __init__(self, data_source, batch_size, num_instances):
27 | self.data_source = data_source
28 | self.batch_size = batch_size
29 | self.num_instances = num_instances
30 | self.num_pids_per_batch = self.batch_size // self.num_instances
31 | self.index_dic = defaultdict(list)
32 | # create a pid --> [img idx] mapping
33 | for index, (_, pid, _) in enumerate(self.data_source):
34 | self.index_dic[pid].append(index)
35 | self.pids = list(self.index_dic.keys())
36 |
37 | # estimate number of examples in an epoch
38 | self.length = 0
39 | for pid in self.pids:
40 | idxs = self.index_dic[pid]
41 | num = len(idxs)
42 | if num < self.num_instances:
43 | num = self.num_instances
44 | self.length += num - num % self.num_instances
45 | # for market : from 12936 to 11876
46 |
47 | def __iter__(self):
48 | batch_idxs_dict = defaultdict(list)
49 |
50 | for pid in self.pids:
51 | idxs = copy.deepcopy(self.index_dic[pid])
52 | if len(idxs) < self.num_instances:
53 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
54 | random.shuffle(idxs)
55 | batch_idxs = []
56 | for idx in idxs:
57 | batch_idxs.append(idx)
58 | if len(batch_idxs) == self.num_instances:
59 | batch_idxs_dict[pid].append(batch_idxs)
60 | batch_idxs = []
61 |
62 | avai_pids = copy.deepcopy(self.pids)
63 | final_idxs = []
64 |
65 | while len(avai_pids) >= self.num_pids_per_batch:
66 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
67 | for pid in selected_pids:
68 | batch_idxs = batch_idxs_dict[pid].pop(0)
69 | final_idxs.extend(batch_idxs)
70 | if len(batch_idxs_dict[pid]) == 0:
71 | avai_pids.remove(pid)
72 | self.length = len(final_idxs)
73 | return iter(final_idxs)
74 |
75 | def __len__(self):
76 | return self.length
77 |
78 |
79 | # New add by gu
80 | class RandomIdentitySampler_alignedreid(Sampler):
81 | """
82 | Randomly sample N identities, then for each identity,
83 | randomly sample K instances, therefore batch size is N*K.
84 |
85 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
86 |
87 | Args:
88 | data_source (Dataset): dataset to sample from.
89 | num_instances (int): number of instances per identity.
90 | """
91 | def __init__(self, data_source, num_instances):
92 | self.data_source = data_source
93 | self.num_instances = num_instances
94 | self.index_dic = defaultdict(list)
95 | for index, data in enumerate(data_source):
96 | pid = data[1]
97 | self.index_dic[pid].append(index)
98 | self.pids = list(self.index_dic.keys())
99 | self.num_identities = len(self.pids)
100 |
101 | def __iter__(self):
102 | indices = torch.randperm(self.num_identities)
103 | ret = []
104 | for i in indices:
105 | pid = self.pids[i]
106 | t = self.index_dic[pid]
107 | replace = len(t) < self.num_instances
108 | t = np.random.choice(t, size=self.num_instances, replace=replace)
109 | ret.extend(t)
110 | return iter(ret)
111 |
112 | def __len__(self):
113 | return self.num_identities * self.num_instances
114 |
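As a sanity check, the P x K behaviour of RandomIdentitySampler can be observed on a toy data_source (a sketch; the paths and ids are made up, and the repository root is assumed to be on PYTHONPATH):

    from data.samplers import RandomIdentitySampler

    # 4 identities with 6 samples each, as (img_path, pid, camid) tuples
    data_source = [('img_%d_%d.jpg' % (pid, i), pid, 0)
                   for pid in range(4) for i in range(6)]

    # batch_size=8, num_instances=4 -> P=2 identities x K=4 instances per batch
    sampler = RandomIdentitySampler(data_source, batch_size=8, num_instances=4)
    idxs = list(iter(sampler))
    print({data_source[i][1] for i in idxs[:8]})  # exactly 2 distinct pids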
--------------------------------------------------------------------------------
/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | from .build import build_transforms_ST
3 |
--------------------------------------------------------------------------------
/data/transforms/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import torchvision.transforms as T
3 |
4 | from .transforms import RandomErasing
5 | from .temporal_transforms import TemporalBeginCrop
6 |
7 |
8 | def build_transforms_ST(cfg,is_train=True):
9 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
10 | if is_train:
11 | transform_list = [T.Resize(cfg.INPUT.SIZE_TRAIN)]
12 | if cfg.INPUT.IF_FLIP:
13 | transform_list.append(T.RandomHorizontalFlip(p=cfg.INPUT.PROB))
14 | if cfg.INPUT.IF_CROP:
15 | transform_list.append(T.Pad(cfg.INPUT.PADDING))
16 | transform_list.append(T.RandomCrop(cfg.INPUT.SIZE_TRAIN))
17 | transform_list += [T.ToTensor(),normalize_transform]
18 | if cfg.INPUT.IF_RE:
19 | transform_list.append(RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN))
20 | spatial_transform = T.Compose(transform_list)
21 | temporal_transforms = None
22 | else:
23 | spatial_transform = T.Compose([
24 | T.Resize(cfg.INPUT.SIZE_TEST),
25 | T.ToTensor(),
26 | normalize_transform
27 | ])
28 | temporal_transforms = TemporalBeginCrop(size=cfg.INPUT.SEQ_LEN)
29 | return spatial_transform,temporal_transforms
30 |
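A usage sketch of the builder (assuming the config package exposes the yacs cfg node with the INPUT.* fields referenced above):

    from config import cfg
    from data.transforms import build_transforms_ST

    # train: Resize (+ optional flip / pad+crop / random erasing); no temporal transform
    train_spatial, train_temporal = build_transforms_ST(cfg, is_train=True)
    # test: Resize + ToTensor + Normalize, plus TemporalBeginCrop over SEQ_LEN frames
    test_spatial, test_temporal = build_transforms_ST(cfg, is_train=False)
    print(train_temporal, test_temporal)  # None vs. a TemporalBeginCrop instance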
--------------------------------------------------------------------------------
/data/transforms/temporal_transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import random
4 | import math
5 | import numpy as np
6 |
7 |
8 | class LoopPadding(object):
9 |
10 | def __init__(self, size):
11 | self.size = size
12 |
13 | def __call__(self, frame_indices):
14 | out = list(frame_indices)
15 |
16 | while len(out) < self.size:
17 | for index in out:
18 | if len(out) >= self.size:
19 | break
20 | out.append(index)
21 |
22 | return out
23 |
24 |
25 | class TemporalCenterCrop(object):
26 | """Temporally crop the given frame indices at a center.
27 |
28 | If the number of frames is less than the size,
29 | loop the indices as many times as necessary to satisfy the size.
30 |
31 | Args:
32 | size (int): Desired output size of the crop.
33 | """
34 |
35 | def __init__(self, size, padding=True, pad_method='loop'):
36 | self.size = size
37 | self.padding = padding
38 | self.pad_method = pad_method
39 |
40 | def __call__(self, frame_indices):
41 | """
42 | Args:
43 | frame_indices (list): frame indices to be cropped.
44 | Returns:
45 | list: Cropped frame indices.
46 | """
47 |
48 | center_index = len(frame_indices) // 2
49 | begin_index = max(0, center_index - (self.size // 2))
50 | end_index = min(begin_index + self.size, len(frame_indices))
51 |
52 | out = list(frame_indices[begin_index:end_index])
53 |
54 | if self.padding:
55 | if self.pad_method == 'loop':
56 | while len(out) < self.size:
57 | for index in out:
58 | if len(out) >= self.size:
59 | break
60 | out.append(index)
61 | else:
62 | while len(out) < self.size:
63 | for index in out:
64 | if len(out) >= self.size:
65 | break
66 | out.append(index)
67 | out.sort()
68 |
69 | return out
70 |
71 |
72 | class TemporalRandomCrop(object):
73 | """Temporally crop the given frame indices at a random location.
74 |
75 | If the number of frames is less than the size,
76 | indices are sampled with replacement to satisfy the size.
77 |
78 | Args:
79 | size (int): Desired output size of the crop.
80 | """
81 |
82 | def __init__(self, size=4, stride=8):
83 | self.size = size
84 | self.stride = stride
85 |
86 | def __call__(self, frame_indices):
87 | """
88 | Args:
89 | frame_indices (list): frame indices to be cropped.
90 | Returns:
91 | list: Cropped frame indices.
92 | """
93 | frame_indices = list(frame_indices)
94 |
95 | if len(frame_indices) >= self.size * self.stride:
96 | rand_end = len(frame_indices) - (self.size - 1) * self.stride - 1
97 | begin_index = random.randint(0, rand_end)
98 | end_index = begin_index + (self.size - 1) * self.stride + 1
99 | out = frame_indices[begin_index:end_index:self.stride]
100 | elif len(frame_indices) >= self.size:
101 | index = np.random.choice(len(frame_indices), size=self.size, replace=False)
102 | index.sort()
103 | out = [frame_indices[index[i]] for i in range(self.size)]
104 | else:
105 | index = np.random.choice(len(frame_indices), size=self.size, replace=True)
106 | index.sort()
107 | out = [frame_indices[index[i]] for i in range(self.size)]
108 |
109 | return out
110 |
111 |
112 | class TemporalBeginCrop(object):
113 | """Temporally crop the given frame indices at a beginning.
114 |
115 | If the number of frames is less than the size,
116 | loop the indices as many times as necessary to satisfy the size.
117 |
118 | Args:
119 | size (int): Desired output size of the crop.
120 | """
121 | def __init__(self, size=4):
122 | self.size = size
123 |
124 | def __call__(self, frame_indices):
125 | frame_indices = list(frame_indices)
126 | size = self.size
127 |
128 | if len(frame_indices) >= (size - 1) * 8 + 1:
129 | out = frame_indices[0: (size - 1) * 8 + 1: 8]
130 | elif len(frame_indices) >= (size - 1) * 4 + 1:
131 | out = frame_indices[0: (size - 1) * 4 + 1: 4]
132 | elif len(frame_indices) >= (size - 1) * 2 + 1:
133 | out = frame_indices[0: (size - 1) * 2 + 1: 2]
134 | elif len(frame_indices) >= size:
135 | out = frame_indices[0:size:1]
136 | else:
137 | out = frame_indices[0:size]
138 | while len(out) < size:
139 | for index in out:
140 | if len(out) >= size:
141 | break
142 | out.append(index)
143 |
144 | return out
145 | '''
146 | def __init__(self, size=4):
147 | self.size = size
148 |
149 | def __call__(self, frame_indices):
150 | frame_indices = list(frame_indices)
151 |
152 | if len(frame_indices) >= 25:
153 | out = frame_indices[0:25:8]
154 | elif len(frame_indices) >= 13:
155 | out = frame_indices[0:13:4]
156 | elif len(frame_indices) >= 7:
157 | out = frame_indices[0:7:2]
158 | elif len(frame_indices) >= 4:
159 | out = frame_indices[0:4:1]
160 | else:
161 | out = frame_indices[0:4]
162 | while len(out) < 4:
163 | for index in out:
164 | if len(out) >= 4:
165 | break
166 | out.append(index)
167 |
168 | return out
169 | '''
170 | # class TemporalBeginCrop(object):
171 | # """Temporally crop the given frame indices at a beginning.
172 |
173 | # If the number of frames is less than the size,
174 | # loop the indices as many times as necessary to satisfy the size.
175 |
176 | # Args:
177 | # size (int): Desired output size of the crop.
178 | # """
179 |
180 | # def __init__(self, size=4):
181 | # self.size = size
182 |
183 | # def __call__(self, frame_indices):
184 | # frame_indices = list(frame_indices)
185 |
186 | # if len(frame_indices) >= 4:
187 | # out = frame_indices[0:4:1]
188 | # else:
189 | # out = frame_indices[0:4]
190 | # while len(out) < 4:
191 | # for index in out:
192 | # if len(out) >= 4:
193 | # break
194 | # out.append(index)
195 |
196 | # return out
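The active TemporalBeginCrop above picks size frames from the start of a tracklet at the largest stride the tracklet length supports (8, then 4, then 2, then 1), and loop-pads only when fewer than size frames exist. A quick check with the default size=4:

    from data.transforms.temporal_transforms import TemporalBeginCrop

    crop = TemporalBeginCrop(size=4)
    print(crop(list(range(30))))  # 30 >= 25 frames -> stride 8: [0, 8, 16, 24]
    print(crop(list(range(14))))  # 14 >= 13 frames -> stride 4: [0, 4, 8, 12]
    print(crop(list(range(3))))   # too short -> loop padding: [0, 1, 2, 0]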
--------------------------------------------------------------------------------
/data/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import math
8 | import random
9 |
10 |
11 | class RandomErasing(object):
12 | """ Randomly selects a rectangle region in an image and erases its pixels.
13 | 'Random Erasing Data Augmentation' by Zhong et al.
14 | See https://arxiv.org/pdf/1708.04896.pdf
15 | Args:
16 | probability: The probability that the Random Erasing operation will be performed.
17 | sl: Minimum proportion of erased area against input image.
18 | sh: Maximum proportion of erased area against input image.
19 | r1: Minimum aspect ratio of erased area.
20 | mean: Erasing value.
21 | """
22 |
23 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
24 | self.probability = probability
25 | self.mean = mean
26 | self.sl = sl
27 | self.sh = sh
28 | self.r1 = r1
29 |
30 | def __call__(self, img):
31 |
32 | if random.uniform(0, 1) >= self.probability:
33 | return img
34 |
35 | for attempt in range(100):
36 | area = img.size()[1] * img.size()[2]
37 |
38 | target_area = random.uniform(self.sl, self.sh) * area
39 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
40 |
41 | h = int(round(math.sqrt(target_area * aspect_ratio)))
42 | w = int(round(math.sqrt(target_area / aspect_ratio)))
43 |
44 | if w < img.size()[2] and h < img.size()[1]:
45 | x1 = random.randint(0, img.size()[1] - h)
46 | y1 = random.randint(0, img.size()[2] - w)
47 | if img.size()[0] == 3:
48 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
49 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
50 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
51 | else:
52 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
53 | return img
54 |
55 | return img
56 |
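RandomErasing runs after ToTensor, so it operates on (and mutates in place) a CHW tensor. A minimal demonstration with the erase forced on:

    import torch
    from data.transforms.transforms import RandomErasing

    eraser = RandomErasing(probability=1.0)  # always erase, for demonstration
    img = torch.ones(3, 256, 128)            # CHW tensor, as produced by ToTensor
    out = eraser(img)                        # note: img itself is modified in place
    print((out != 1).any().item())           # True: some rectangle was overwritten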
--------------------------------------------------------------------------------
/engine/data_parallel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules import Module
3 | from torch.nn.parallel.scatter_gather import gather
4 | from torch.nn.parallel.replicate import replicate
5 | from torch.nn.parallel.parallel_apply import parallel_apply
6 |
7 | from .scatter_gather import scatter_kwargs
8 |
9 |
10 | class _DataParallel(Module):
11 | r"""Implements data parallelism at the module level.
12 |
13 | This container parallelizes the application of the given module by
14 | splitting the input across the specified devices by chunking in the batch
15 | dimension. In the forward pass, the module is replicated on each device,
16 | and each replica handles a portion of the input. During the backwards
17 | pass, gradients from each replica are summed into the original module.
18 |
19 | The batch size should be larger than the number of GPUs used. It should
20 | also be an integer multiple of the number of GPUs so that each chunk is the
21 | same size (so that each GPU processes the same number of samples).
22 |
23 | See also: :ref:`cuda-nn-dataparallel-instead`
24 |
25 | Arbitrary positional and keyword inputs are allowed to be passed into
26 | DataParallel EXCEPT Tensors. All variables will be scattered on dim
27 | specified (default 0). Primitive types will be broadcasted, but all
28 | other types will be a shallow copy and can be corrupted if written to in
29 | the model's forward pass.
30 |
31 | Args:
32 | module: module to be parallelized
33 | device_ids: CUDA devices (default: all devices)
34 | output_device: device location of output (default: device_ids[0])
35 |
36 | Example::
37 |
38 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
39 | >>> output = net(input_var)
40 | """
41 |
42 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
43 |
44 | def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None):
45 | super(_DataParallel, self).__init__()
46 |
47 | if not torch.cuda.is_available():
48 | self.module = module
49 | self.device_ids = []
50 | return
51 |
52 | if device_ids is None:
53 | device_ids = list(range(torch.cuda.device_count()))
54 | if output_device is None:
55 | output_device = device_ids[0]
56 | self.dim = dim
57 | self.module = module
58 | self.device_ids = device_ids
59 | self.chunk_sizes = chunk_sizes
60 | self.output_device = output_device
61 | if len(self.device_ids) == 1:
62 | self.module.cuda(device_ids[0])
63 |
64 | def forward(self, *inputs, **kwargs):
65 | if not self.device_ids:
66 | return self.module(*inputs, **kwargs)
67 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_sizes)
68 | if len(self.device_ids) == 1:
69 | return self.module(*inputs[0], **kwargs[0])
70 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
71 | outputs = self.parallel_apply(replicas, inputs, kwargs)
72 | return self.gather(outputs, self.output_device)
73 |
74 | def replicate(self, module, device_ids):
75 | return replicate(module, device_ids)
76 |
77 | def scatter(self, inputs, kwargs, device_ids, chunk_sizes):
78 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_sizes=self.chunk_sizes)
79 |
80 | def parallel_apply(self, replicas, inputs, kwargs):
81 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
82 |
83 | def gather(self, outputs, output_device):
84 | return gather(outputs, output_device, dim=self.dim)
85 |
86 |
87 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
88 | r"""Evaluates module(input) in parallel across the GPUs given in device_ids.
89 |
90 | This is the functional version of the DataParallel module.
91 |
92 | Args:
93 | module: the module to evaluate in parallel
94 | inputs: inputs to the module
95 | device_ids: GPU ids on which to replicate module
96 | output_device: GPU location of the output. Use -1 to indicate the CPU.
97 | (default: device_ids[0])
98 | Returns:
99 | a Variable containing the result of module(input) located on
100 | output_device
101 | """
102 | if not isinstance(inputs, tuple):
103 | inputs = (inputs,)
104 |
105 | if device_ids is None:
106 | device_ids = list(range(torch.cuda.device_count()))
107 |
108 | if output_device is None:
109 | output_device = device_ids[0]
110 |
111 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
112 | if len(device_ids) == 1:
113 | return module(*inputs[0], **module_kwargs[0])
114 | used_device_ids = device_ids[:len(inputs)]
115 | replicas = replicate(module, used_device_ids)
116 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
117 | return gather(outputs, output_device, dim)
118 |
119 | def DataParallel(module, device_ids=None, output_device=None, dim=0, chunk_sizes=None):
120 | if chunk_sizes is None:
121 | return torch.nn.DataParallel(module, device_ids, output_device, dim)
122 | # standard_size = True
123 | # for i in range(1, len(chunk_sizes)):
124 | # if chunk_sizes[i] != chunk_sizes[0]:
125 | # standard_size = False
126 | # if standard_size:
127 | # return torch.nn.DataParallel(module, device_ids, output_device, dim)
128 | return _DataParallel(module, device_ids, output_device, dim, chunk_sizes)
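The point of this custom _DataParallel is the chunk_sizes argument, which the stock torch.nn.DataParallel lacks: it lets a batch be split unevenly across GPUs with unbalanced memory, as the trainer's commented import notes. A sketch only (assumes two visible CUDA devices):

    import torch.nn as nn
    from engine.data_parallel import DataParallel

    # chunk_sizes=None falls back to torch.nn.DataParallel; with [24, 8] a
    # batch of 32 is split 24/8 between the two devices.
    model = DataParallel(nn.Linear(10, 10), device_ids=[0, 1], chunk_sizes=[24, 8])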
--------------------------------------------------------------------------------
/engine/inference.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 | from ignite.engine import Engine
7 |
8 | from utils.reid_metric import R1_mAP, R1_mAP_reranking
9 | from ignite.contrib.handlers.tqdm_logger import ProgressBar
10 |
11 |
12 | def create_supervised_evaluator(model, metrics,
13 | device=None):
14 | if device:
15 | if torch.cuda.device_count() > 1:
16 | model = nn.DataParallel(model)
17 | model.to(device)
18 |
19 | def _inference(engine, batch):
20 | model.eval()
21 | with torch.no_grad():
22 | data, pids, camids = batch
23 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
24 | feat = model(data)
25 | return feat, pids, camids
26 |
27 | engine = Engine(_inference)
28 |
29 | for name, metric in metrics.items():
30 | metric.attach(engine, name)
31 |
32 | return engine
33 |
34 | def create_supervised_evaluator_with_mask(model, metrics,
35 | device=None):
36 | if device:
37 | # if torch.cuda.device_count() > 1:
38 | # model = nn.DataParallel(model)
39 | model.to(device)
40 |
41 | def _inference(engine, batch):
42 | model.eval()
43 | with torch.no_grad():
44 | data, pids, camids ,masks = batch
45 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
46 | feat = model(data,masks)
47 | return feat, pids, camids
48 |
49 | engine = Engine(_inference)
50 |
51 | for name, metric in metrics.items():
52 | metric.attach(engine, name)
53 |
54 | return engine
55 |
56 | def create_supervised_evaluator_with_mask_new_eval(model, metrics,
57 | device=None):
58 |
59 | if device:
60 | # if torch.cuda.device_count() > 1:
61 | # model = nn.DataParallel(model)
62 | model.to(device)
63 |
64 | def _inference(engine, batch):
65 | model.eval()
66 | with torch.no_grad():
67 | data, pids, ambi, camids ,masks = batch
68 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
69 | feat = model(data,masks)
70 | return feat, pids, ambi, camids
71 |
72 | engine = Engine(_inference)
73 |
74 | for name, metric in metrics.items():
75 | metric.attach(engine, name)
76 |
77 | return engine
78 |
79 | def create_supervised_all_evaluator(model, metrics,seq_len,
80 | device=None):
81 | if device:
82 | if torch.cuda.device_count() > 1:
83 | model = nn.DataParallel(model)
84 | model.to(device)
85 |
86 | def _inference(engine, batch):
87 | model.eval()
88 | feats = []
89 | with torch.no_grad():
90 | data, pids, camids = batch
91 | iteration = data.shape[1]//seq_len
92 | for i in range(iteration):
93 | x = data[:,i*seq_len:(i+1)*seq_len,...]
94 | x = x.to(device) if torch.cuda.device_count() >= 1 else x
95 | feat = model(x)
96 | feats.append(feat)
97 | feats = torch.mean(torch.cat(feats,dim=0),dim=0,keepdim=True)
98 | return feats, pids, camids
99 |
100 | engine = Engine(_inference)
101 |
102 | for name, metric in metrics.items():
103 | metric.attach(engine, name)
104 |
105 | return engine
106 |
107 |
108 | def create_supervised_all_evaluator_with_mask(model, metrics,seq_len,
109 | device=None):
110 | if device:
111 | if torch.cuda.device_count() > 1:
112 | model = nn.DataParallel(model)
113 | model.to(device)
114 |
115 | def _inference(engine, batch):
116 | model.eval()
117 | feats = []
118 | with torch.no_grad():
119 | data, pids, camids, masks = batch
120 | iteration = data.shape[1]//seq_len
121 | for i in range(iteration):
122 | x = data[:,i*seq_len:(i+1)*seq_len,...]
123 | mask = masks[:,i*seq_len:(i+1)*seq_len,...]
124 | x = x.to(device) if torch.cuda.device_count() >= 1 else x
125 | feat = model(x,mask)
126 | feats.append(feat)
127 | feats = torch.mean(torch.cat(feats,dim=0),dim=0,keepdim=True)
128 | return feats, pids, camids
129 |
130 | engine = Engine(_inference)
131 |
132 | for name, metric in metrics.items():
133 | metric.attach(engine, name)
134 |
135 | return engine
136 |
137 | def create_supervised_all_evaluator_with_mask_new_eval(model, metrics,seq_len,
138 | device=None):
139 | if device:
140 | if torch.cuda.device_count() > 1:
141 | model = nn.DataParallel(model)
142 | model.to(device)
143 |
144 | def _inference(engine, batch):
145 | model.eval()
146 | feats = []
147 | with torch.no_grad():
148 | data, pids, ambi, camids, masks = batch
149 | iteration = data.shape[1]//seq_len
150 | for i in range(iteration):
151 | x = data[:,i*seq_len:(i+1)*seq_len,...]
152 | mask = masks[:,i*seq_len:(i+1)*seq_len,...]
153 | x = x.to(device) if torch.cuda.device_count() >= 1 else x
154 | feat = model(x,mask)
155 | feats.append(feat)
156 | feats = torch.mean(torch.cat(feats,dim=0),dim=0,keepdim=True)
157 | return feats, pids, ambi, camids
158 |
159 | engine = Engine(_inference)
160 |
161 | for name, metric in metrics.items():
162 | metric.attach(engine, name)
163 |
164 | return engine
165 |
166 |
167 | def inference(
168 | cfg,
169 | model,
170 | val_loader,
171 | num_query
172 | ):
173 | device = cfg.MODEL.DEVICE
174 |
175 | logger = logging.getLogger("reid_baseline.inference")
176 | logger.info("Enter inferencing")
177 | if cfg.TEST.RE_RANKING == 'no':
178 | print("Create evaluator")
179 | if 'test_all' in cfg.TEST.TEST_MODE:
180 | if len(val_loader.dataset.dataset[0]) == 4: # mask no new eval
181 | evaluator = create_supervised_all_evaluator_with_mask(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
182 | seq_len=cfg.INPUT.SEQ_LEN,device=device)
183 | elif len(val_loader.dataset.dataset[0]) == 6: # mask , new eval
184 | evaluator = create_supervised_all_evaluator_with_mask_new_eval(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM,new_eval=True)},
185 | seq_len=cfg.INPUT.SEQ_LEN,device=device)
186 | else:
187 | evaluator = create_supervised_all_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
188 | seq_len=cfg.INPUT.SEQ_LEN,device=device)
189 | else:
190 | if len(val_loader.dataset.dataset[0]) == 6: # mask , new eval
191 | evaluator = create_supervised_evaluator_with_mask_new_eval(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM,new_eval=True)},
192 | device=device)
193 | elif len(val_loader.dataset.dataset[0]) == 4 : # mask, no new eval
194 | evaluator = create_supervised_evaluator_with_mask(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
195 | device=device)
196 | else:
197 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
198 | device=device)
199 | elif cfg.TEST.RE_RANKING == 'yes': # not yet implemented with mask
200 | print("Create evaluator for reranking")
201 | if 'test_all' in cfg.TEST.TEST_MODE:
202 | evaluator = create_supervised_all_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
203 | seq_len=cfg.INPUT.SEQ_LEN,device=device)
204 | else:
205 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
206 | device=device)
207 | else:
208 | print("Unsupported re_ranking config. Only support for no or yes, but got {}.".format(cfg.TEST.RE_RANKING))
209 |
210 | pbar = ProgressBar(persist=True,ncols=120)
211 | pbar.attach(evaluator)
212 |
213 | evaluator.run(val_loader)
214 | cmc, mAP = evaluator.state.metrics['r1_mAP']
215 | logger.info('Validation Results')
216 | logger.info("mAP: {:.1%}".format(mAP))
217 | for r in [1, 5, 10]:
218 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
219 |
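The 'test_all' evaluators above embed a long tracklet window by window: the frame axis is cut into consecutive SEQ_LEN-sized slices, each slice is run through the model, and the per-window features are averaged (frames beyond the last full window are dropped, since iteration = T // seq_len). The slicing in isolation (shapes are made up for illustration):

    import torch

    T, seq_len = 20, 4
    data = torch.randn(1, T, 3, 256, 128)  # (batch, frames, C, H, W)
    windows = [data[:, i * seq_len:(i + 1) * seq_len] for i in range(T // seq_len)]
    print(len(windows), windows[0].shape)  # 5 windows of shape (1, 4, 3, 256, 128)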
--------------------------------------------------------------------------------
/engine/scatter_gather.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch.nn.parallel._functions import Scatter
4 |
5 |
6 | def scatter(inputs, target_gpus, dim=0, chunk_sizes=None):
7 | r"""
8 | Slices variables into approximately equal chunks and
9 | distributes them across given GPUs. Duplicates
10 | references to objects that are not variables. Does not
11 | support Tensors.
12 | """
13 | def scatter_map(obj):
14 | if isinstance(obj, Variable):
15 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
16 | assert not torch.is_tensor(obj), "Tensors not supported in scatter."
17 | if isinstance(obj, tuple):
18 | return list(zip(*map(scatter_map, obj)))
19 | if isinstance(obj, list):
20 | return list(map(list, zip(*map(scatter_map, obj))))
21 | if isinstance(obj, dict):
22 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
23 | return [obj for _ in target_gpus]
24 |
25 | return scatter_map(inputs)
26 |
27 |
28 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_sizes=None):
29 | r"""Scatter with support for kwargs dictionary"""
30 | inputs = scatter(inputs, target_gpus, dim, chunk_sizes) if inputs else []
31 | kwargs = scatter(kwargs, target_gpus, dim, chunk_sizes) if kwargs else []
32 | if len(inputs) < len(kwargs):
33 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
34 | elif len(kwargs) < len(inputs):
35 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
36 | inputs = tuple(inputs)
37 | kwargs = tuple(kwargs)
38 | return inputs, kwargs
39 |
--------------------------------------------------------------------------------
/engine/trainer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import DataParallel
7 | # from engine.data_parallel import DataParallel
8 | # #self create dataparallel for unbalance GPU memory size
9 | from ignite.engine import Engine, Events
10 | from ignite.handlers import ModelCheckpoint, Timer,global_step_from_engine
11 | from ignite.metrics import RunningAverage
12 | from ignite.contrib.handlers.tqdm_logger import ProgressBar
13 | from utils.reid_metric import R1_mAP
14 |
15 |
16 | def create_supervised_trainer(model, optimizer, loss_fn,
17 | device=None):
18 | if device:
19 | if torch.cuda.device_count() > 1:
20 | model = DataParallel(model)
21 | model.to(device)
22 | def _update(engine, batch):
23 | model.train()
24 | optimizer.zero_grad()
25 | img, target = batch
26 | img = img.to(device) if torch.cuda.device_count() >= 1 else img
27 | target = target.to(device) if torch.cuda.device_count() >= 1 else target
28 | score, feat = model(img)
29 | loss,loss_dict = loss_fn(score, feat, target)
30 | loss.backward()
31 | optimizer.step()
32 | # compute acc
33 | acc = (score.max(1)[1] == target).float().mean()
34 | loss_dict['loss'] = loss.item()
35 | return acc.item(),loss_dict
36 |
37 | return Engine(_update)
38 |
39 | def create_supervised_trainer_with_mask(model, optimizer, loss_fn,
40 | device=None):
41 | if device:
42 | if torch.cuda.device_count() > 1:
43 | model = DataParallel(model)
44 | model.to(device)
45 | def _update(engine, batch):
46 | model.train()
47 | optimizer.zero_grad()
48 | img, target ,masks = batch
49 | img = img.to(device) if torch.cuda.device_count() >= 1 else img
50 | target = target.to(device) if torch.cuda.device_count() >= 1 else target
51 | score, feat = model(img,masks)
52 | loss,loss_dict = loss_fn(score, feat, target)
53 | loss.backward()
54 | optimizer.step()
55 | # compute acc
56 | acc = (score.max(1)[1] == target).float().mean()
57 | loss_dict['loss'] = loss.item()
58 | return acc.item(),loss_dict
59 |
60 | return Engine(_update)
61 |
62 | def create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, center_loss_weight,
63 | device=None):
64 | if device:
65 | if torch.cuda.device_count() > 1:
66 | model = nn.DataParallel(model)
67 | model.to(device)
68 |
69 | def _update(engine, batch):
70 | model.train()
71 | optimizer.zero_grad()
72 | optimizer_center.zero_grad()
73 | img, target = batch
74 | img = img.to(device) if torch.cuda.device_count() >= 1 else img
75 | target = target.to(device) if torch.cuda.device_count() >= 1 else target
76 | score, feat = model(img)
77 | loss,loss_dict = loss_fn(score, feat, target)
78 | # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
79 | loss.backward()
80 | optimizer.step()
81 | for param in center_criterion.parameters():
82 | param.grad.data *= (1. / center_loss_weight)
83 | optimizer_center.step()
84 |
85 | # compute acc
86 | acc = (score.max(1)[1] == target).float().mean()
87 | loss_dict['loss'] = loss.item()
88 | return acc.item(),loss_dict
89 |
90 | return Engine(_update)
91 |
92 | # +
93 | def create_supervised_evaluator(model, metrics,
94 | device=None):
95 | if device:
96 | # if torch.cuda.device_count() > 1:
97 | # model = nn.DataParallel(model)
98 | model.to(device)
99 |
100 | def _inference(engine, batch):
101 | model.eval()
102 | with torch.no_grad():
103 | data, pids, camids = batch
104 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
105 | feat = model(data)
106 | return feat, pids, camids
107 |
108 | engine = Engine(_inference)
109 |
110 | for name, metric in metrics.items():
111 | metric.attach(engine, name)
112 |
113 | return engine
114 |
115 | def create_supervised_evaluator_with_mask(model, metrics,
116 | device=None):
117 | if device:
118 | # if torch.cuda.device_count() > 1:
119 | # model = nn.DataParallel(model)
120 | model.to(device)
121 |
122 | def _inference(engine, batch):
123 | model.eval()
124 | with torch.no_grad():
125 | data, pids, camids ,masks = batch
126 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
127 | feat = model(data,masks)
128 | return feat, pids, camids
129 |
130 | engine = Engine(_inference)
131 |
132 | for name, metric in metrics.items():
133 | metric.attach(engine, name)
134 |
135 | return engine
136 |
137 | def create_supervised_evaluator_with_mask_new_eval(model, metrics,
138 | device=None):
139 | if device:
140 | # if torch.cuda.device_count() > 1:
141 | # model = nn.DataParallel(model)
142 | model.to(device)
143 |
144 | def _inference(engine, batch):
145 | model.eval()
146 | with torch.no_grad():
147 | data, pids, ambi, camids ,masks = batch
148 | data = data.to(device) if torch.cuda.device_count() >= 1 else data
149 | feat = model(data,masks)
150 | return feat, pids, ambi, camids
151 |
152 | engine = Engine(_inference)
153 |
154 | for name, metric in metrics.items():
155 | metric.attach(engine, name)
156 |
157 | return engine
158 |
159 | # -
160 |
161 | def do_train(
162 | cfg,
163 | model,
164 | train_loader,
165 | val_loader,
166 | optimizer,
167 | scheduler,
168 | loss_fn,
169 | num_query,
170 | start_epoch
171 | ):
172 | # checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
173 | eval_period = cfg.SOLVER.EVAL_PERIOD
174 | output_dir = cfg.OUTPUT_DIR
175 | device = cfg.MODEL.DEVICE
176 | epochs = cfg.SOLVER.MAX_EPOCHS
177 |
178 | logger = logging.getLogger("reid_baseline.train")
179 | logger.info("Start training")
180 | # Create 1. trainer 2. evaluator 3. checkpointer 4. timer 5. pbar
181 | if len(train_loader.dataset.dataset[0]) == 4 : #train with mask
182 | trainer = create_supervised_trainer_with_mask(model, optimizer, loss_fn, device=device)
183 | if not cfg.TEST.NEW_EVAL:
184 | evaluator = create_supervised_evaluator_with_mask(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
185 | else:
186 | evaluator = create_supervised_evaluator_with_mask_new_eval(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM,new_eval=True)}, device=device)
187 | else: # no mask
188 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
189 | if not cfg.TEST.NEW_EVAL:
190 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
191 | else:
192 | raise NotImplementedError
193 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=1, require_empty=False,\
194 | score_function=lambda x : x.state.metrics['r1_mAP'][1],\
195 | global_step_transform=global_step_from_engine(trainer))
196 | timer = Timer(average=True)
197 | tpbar = ProgressBar(persist=True,ncols=120)
198 | epbar = ProgressBar(persist=True,ncols=120)
199 | #############################################################
200 | evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpointer, \
201 | {'model': model,'optimizer': optimizer})
202 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
203 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
204 | tpbar.attach(trainer)
205 | epbar.attach(evaluator)
206 |
207 | # average metric to attach on trainer
208 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_acc')
209 | RunningAverage(output_transform=lambda x: x[1]['loss']).attach(trainer, 'avg_loss')
210 | RunningAverage(output_transform=lambda x: x[1]['triplet']).attach(trainer, 'avg_trip')
211 |
212 |
213 | @trainer.on(Events.STARTED)
214 | def start_training(engine):
215 | engine.state.epoch = start_epoch
216 |
217 | @trainer.on(Events.EPOCH_COMPLETED)
218 | def adjust_learning_rate(engine):
219 | # if engine.state.epoch == 1:
220 | # scheduler.step()
221 | scheduler.step()
222 |
223 |
224 | # adding handlers using `trainer.on` decorator API
225 | @trainer.on(Events.EPOCH_COMPLETED)
226 | def print_times(engine):
227 | logger.info('Epoch {} done. Total Loss : {:.3f}, Triplet Loss : {:.3f}, Acc : {:.3f}, Base Lr : {:.2e}'
228 | .format(engine.state.epoch, engine.state.metrics['avg_loss'],engine.state.metrics['avg_trip'],
229 | engine.state.metrics['avg_acc'],scheduler.get_last_lr()[0]))
230 | timer.reset()
231 |
232 | @trainer.on(Events.EPOCH_COMPLETED)
233 | def log_validation_results(engine):
234 | if engine.state.epoch % eval_period == 0:
235 | # evaluator.state.epoch = trainer.state.epoch
236 | evaluator.run(val_loader)
237 | cmc, mAP = evaluator.state.metrics['r1_mAP']
238 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
239 | logger.info("mAP: {:.1%}".format(mAP))
240 | for r in [1, 5, 10]:
241 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
242 |
243 | trainer.run(train_loader, max_epochs=epochs)
244 |
245 |
246 | def do_train_with_center(
247 | cfg,
248 | model,
249 | center_criterion,
250 | train_loader,
251 | val_loader,
252 | optimizer,
253 | optimizer_center,
254 | scheduler,
255 | loss_fn,
256 | num_query,
257 | start_epoch
258 | ):
259 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD # used by the checkpointer below
260 | eval_period = cfg.SOLVER.EVAL_PERIOD
261 | output_dir = cfg.OUTPUT_DIR
262 | device = cfg.MODEL.DEVICE
263 | epochs = cfg.SOLVER.MAX_EPOCHS
264 |
265 | logger = logging.getLogger("reid_baseline.train")
266 | logger.info("Start training")
267 | trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
268 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
269 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, None, n_saved=10, require_empty=False)
270 | timer = Timer(average=True)
271 | pbar = ProgressBar(persist=True,ncols=120)
272 | trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_period), checkpointer, {'model': model,
273 | 'optimizer': optimizer,
274 | 'center_param': center_criterion,
275 | 'optimizer_center': optimizer_center})
276 |
277 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
278 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
279 | pbar.attach(trainer)
280 |
281 | # average metric to attach on trainer
282 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_acc')
283 | RunningAverage(output_transform=lambda x: x[1]['loss']).attach(trainer, 'avg_loss')
284 | RunningAverage(output_transform=lambda x: x[1]['triplet']).attach(trainer, 'avg_trip')
285 | RunningAverage(output_transform=lambda x: x[1]['center']).attach(trainer, 'avg_center')
286 |
287 | @trainer.on(Events.STARTED)
288 | def start_training(engine):
289 | engine.state.epoch = start_epoch
290 |
291 | @trainer.on(Events.EPOCH_COMPLETED)
292 | def adjust_learning_rate(engine):
293 | scheduler.step()
294 | # adding handlers using `trainer.on` decorator API
295 | @trainer.on(Events.EPOCH_COMPLETED)
296 | def print_times(engine):
297 | logger.info('Epoch {} done. Total Loss : {:.3f}, Triplet Loss : {:.3f}, Center Loss : {:.3f}, Acc : {:.3f}, Base Lr : {:.2e}'
298 | .format(engine.state.epoch, engine.state.metrics['avg_loss'],engine.state.metrics['avg_trip'],
299 | engine.state.metrics['avg_center'],engine.state.metrics['avg_acc'],scheduler.get_last_lr()[0]))
300 | timer.reset()
301 |
302 | @trainer.on(Events.EPOCH_COMPLETED)
303 | def log_validation_results(engine):
304 | if engine.state.epoch % eval_period == 0:
305 | evaluator.run(val_loader)
306 | cmc, mAP = evaluator.state.metrics['r1_mAP']
307 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
308 | logger.info("mAP: {:.1%}".format(mAP))
309 | for r in [1, 5, 10]:
310 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
311 |
312 | trainer.run(train_loader, max_epochs=epochs)
313 |
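The metric wiring above can be reproduced in isolation: the trainer's _update returns (acc, loss_dict), and each RunningAverage pulls one scalar out of that output every iteration. A self-contained sketch with a dummy process function:

    from ignite.engine import Engine
    from ignite.metrics import RunningAverage

    engine = Engine(lambda e, batch: (0.5, {'loss': 1.0, 'triplet': 0.4}))
    RunningAverage(output_transform=lambda x: x[1]['loss']).attach(engine, 'avg_loss')
    state = engine.run([None] * 10, max_epochs=1)
    print(state.metrics['avg_loss'])  # 1.0 (the running average of a constant)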
--------------------------------------------------------------------------------
/imgs/DL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/imgs/DL.png
--------------------------------------------------------------------------------
/imgs/DL_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/imgs/DL_2.png
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import torch.nn.functional as F
3 |
4 | from .triplet_loss import TripletLoss,CrossEntropyLabelSmooth
5 | from .center_loss import CenterLoss
6 |
7 |
8 | def make_loss(cfg, num_classes): # modified by gu
9 | sampler = cfg.DATALOADER.SAMPLER
10 | # Creating Triplet
11 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
12 | if cfg.SOLVER.SOFT_MARGIN: margin = None
13 | else: margin = cfg.SOLVER.MARGIN
14 | triplet = TripletLoss(margin) # triplet loss
15 | else:
16 | print('expected METRIC_LOSS_TYPE should be triplet, '
17 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
18 | # Whether Label Smoothing
19 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
20 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo
21 | print("label smooth on, numclasses:", num_classes)
22 |
23 | # Return loss_func
24 | loss_dict = {'triplet':0,'id_loss':0,'center':0} # for logging
25 | if sampler == 'softmax':
26 | def loss_func(score, feat, target):
27 | id_loss = F.cross_entropy(score,target)
28 | loss_dict['id_loss'] = id_loss.item()
29 | return id_loss,loss_dict
30 | elif cfg.DATALOADER.SAMPLER == 'triplet':
31 | def loss_func(score, feat, target):
32 | metric = triplet(feat,target)[0]
33 | loss_dict['triplet'] = metric.item()
34 | return metric,loss_dict
35 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
36 | def loss_func(score, feat, target):
37 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
38 | metric = triplet(feat,target)[0]
39 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
40 | id_loss = xent(score,target)
41 | else:
42 | id_loss = F.cross_entropy(score,target)
43 | loss_dict['triplet'] = metric.item()
44 | loss_dict['id_loss'] = id_loss.item()
45 | return metric+id_loss,loss_dict
46 | else:
47 | print('expected METRIC_LOSS_TYPE should be triplet, '
48 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
49 | else:
50 | print('expected sampler should be softmax, triplet or softmax_triplet, '
51 | 'but got {}'.format(cfg.DATALOADER.SAMPLER))
52 | return loss_func
53 |
54 |
55 | def make_loss_with_center(cfg, num_classes): # modified by gu
56 | if cfg.MODEL.NAME == 'resnet18' or cfg.MODEL.NAME == 'resnet34':
57 | feat_dim = 512
58 | else:
59 | feat_dim = 2048
60 |
61 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
62 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
63 |
64 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
65 | if cfg.SOLVER.SOFT_MARGIN: margin = None
66 | else: margin = cfg.SOLVER.MARGIN
67 | triplet = TripletLoss(margin) # triplet loss
68 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss
69 | else:
70 | print('expected METRIC_LOSS_TYPE with center should be center or triplet_center, '
71 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
72 |
73 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
74 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo
75 | print("label smooth on, numclasses:", num_classes)
76 |
77 | def loss_func(score, feat, target):
78 | loss_dict = {'triplet':0,'id_loss':0,'center':0}
79 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
80 | center = center_criterion(feat,target)
81 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
82 | id_loss = xent(score, target)
83 | else:
84 | id_loss = F.cross_entropy(score, target)
85 | loss = cfg.SOLVER.CENTER_LOSS_WEIGHT * center + id_loss
86 | loss_dict['id_loss'] = id_loss.item()
87 | loss_dict['center'] = center.item()
88 | return loss,loss_dict
89 |
90 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
91 | metric = triplet(feat,target)[0]
92 | center = center_criterion(feat,target)
93 | if cfg.MODEL.IF_LABELSMOOTH == 'on':
94 | id_loss = xent(score, target)
95 | else:
96 | id_loss = F.cross_entropy(score, target)
97 | loss = cfg.SOLVER.CENTER_LOSS_WEIGHT * center + id_loss + metric
98 | loss_dict['id_loss'] = id_loss.item()
99 | loss_dict['center'] = center.item()
100 | loss_dict['triplet'] = metric.item()
101 | return loss,loss_dict
102 |
103 | else:
104 | print('expected METRIC_LOSS_TYPE with center should be center or triplet_center, '
105 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
106 | return loss_func, center_criterion
107 |
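For the common 'softmax_triplet' branch above, the returned closure just sums a hard-mined triplet term and a cross-entropy term. A standalone re-creation with label smoothing off (the shapes, the 0.3 margin, and the 625-class head are made-up values; the batch is laid out as P=2 identities x K=4 instances because hard mining assumes each label appears equally often):

    import torch
    import torch.nn.functional as F
    from layers.triplet_loss import TripletLoss

    torch.manual_seed(0)
    score = torch.randn(8, 625)                  # classifier logits
    feat = torch.randn(8, 2048)                  # global features
    target = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

    triplet = TripletLoss(margin=0.3)
    loss = triplet(feat, target)[0] + F.cross_entropy(score, target)
    print(loss.item())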
--------------------------------------------------------------------------------
/layers/center_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class CenterLoss(nn.Module):
8 | """Center loss.
9 |
10 | Reference:
11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
12 |
13 | Args:
14 | num_classes (int): number of classes.
15 | feat_dim (int): feature dimension.
16 | """
17 |
18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
19 | super(CenterLoss, self).__init__()
20 | self.num_classes = num_classes
21 | self.feat_dim = feat_dim
22 | self.use_gpu = use_gpu
23 |
24 | if self.use_gpu:
25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
26 | else:
27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
28 |
29 | def forward(self, x, labels):
30 | """
31 | Args:
32 | x: feature matrix with shape (batch_size, feat_dim).
33 | labels: ground truth labels with shape (batch_size).
34 | """
35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
36 |
37 | batch_size = x.size(0)
38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
40 | distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
41 |
42 | classes = torch.arange(self.num_classes).long()
43 | if self.use_gpu: classes = classes.cuda()
44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
45 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
46 |
47 | dist = distmat * mask.float()
48 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
49 | #dist = []
50 | #for i in range(batch_size):
51 | # value = distmat[i][mask[i]]
52 | # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
53 | # dist.append(value)
54 | #dist = torch.cat(dist)
55 | #loss = dist.mean()
56 | return loss
57 |
58 |
59 | if __name__ == '__main__':
60 | use_gpu = False
61 | center_loss = CenterLoss(use_gpu=use_gpu)
62 | features = torch.rand(16, 2048)
63 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long()
64 | if use_gpu:
65 | features = torch.rand(16, 2048).cuda()
66 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda()
67 |
68 | loss = center_loss(features, targets)
69 | print(loss)
70 |
--------------------------------------------------------------------------------
/layers/triplet_loss.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import torch
7 | from torch import nn
8 |
9 |
10 | def normalize(x, axis=-1):
11 | """Normalizing to unit length along the specified dimension.
12 | Args:
13 | x: pytorch Variable
14 | Returns:
15 | x: pytorch Variable, same shape as input
16 | """
17 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
18 | return x
19 |
20 |
21 | def euclidean_dist(x, y):
22 | """
23 | Args:
24 | x: pytorch Variable, with shape [m, d]
25 | y: pytorch Variable, with shape [n, d]
26 | Returns:
27 | dist: pytorch Variable, with shape [m, n]
28 | """
29 | m, n = x.size(0), y.size(0)
30 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
31 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
32 | dist = xx + yy
33 | dist.addmm_(x,y.t(),beta=1, alpha=-2)
34 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
35 | return dist
36 |
37 |
38 | def hard_example_mining(dist_mat, labels, return_inds=False):
39 | """For each anchor, find the hardest positive and negative sample.
40 | Args:
41 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
42 | labels: pytorch LongTensor, with shape [N]
43 |       return_inds: whether to also return the indices; setting it `False` saves a little time.
44 | Returns:
45 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
46 | dist_an: pytorch Variable, distance(anchor, negative); shape [N]
47 | p_inds: pytorch LongTensor, with shape [N];
48 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
49 | n_inds: pytorch LongTensor, with shape [N];
50 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
51 |     NOTE: only the case in which every label has the same number of samples is
52 |       handled, so that all anchors can be processed in parallel.
53 | """
54 |
55 | assert len(dist_mat.size()) == 2
56 | assert dist_mat.size(0) == dist_mat.size(1)
57 | N = dist_mat.size(0)
58 |
59 | # shape [N, N]
60 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
61 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
62 |
63 | # `dist_ap` means distance(anchor, positive)
64 | # both `dist_ap` and `relative_p_inds` with shape [N, 1]
65 | dist_ap, relative_p_inds = torch.max(
66 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
67 | # `dist_an` means distance(anchor, negative)
68 | # both `dist_an` and `relative_n_inds` with shape [N, 1]
69 | dist_an, relative_n_inds = torch.min(
70 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
71 | # shape [N]
72 | dist_ap = dist_ap.squeeze(1)
73 | dist_an = dist_an.squeeze(1)
74 |
75 | if return_inds:
76 | # shape [N, N]
77 | ind = (labels.new().resize_as_(labels)
78 | .copy_(torch.arange(0, N).long())
79 | .unsqueeze(0).expand(N, N))
80 | # shape [N, 1]
81 | p_inds = torch.gather(
82 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
83 | n_inds = torch.gather(
84 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
85 | # shape [N]
86 | p_inds = p_inds.squeeze(1)
87 | n_inds = n_inds.squeeze(1)
88 | return dist_ap, dist_an, p_inds, n_inds
89 |
90 | return dist_ap, dist_an
91 |
92 |
93 | class TripletLoss(object):
94 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
95 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
96 | Loss for Person Re-Identification'."""
97 |
98 | def __init__(self, margin=-0.1):
99 | self.margin = margin
100 | if margin is not None:
101 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
102 | else:
103 | self.ranking_loss = nn.SoftMarginLoss()
104 |
105 | def __call__(self, global_feat, labels, normalize_feature=False):
106 | if normalize_feature:
107 | global_feat = normalize(global_feat, axis=-1)
108 | dist_mat = euclidean_dist(global_feat, global_feat)
109 | dist_ap, dist_an = hard_example_mining(
110 | dist_mat, labels)
111 | y = dist_an.new().resize_as_(dist_an).fill_(1)
112 | if self.margin is not None:
113 | loss = self.ranking_loss(dist_an, dist_ap, y)
114 | else:
115 | loss = self.ranking_loss(dist_an - dist_ap, y)
116 | return loss, dist_ap, dist_an
117 |
118 | class CrossEntropyLabelSmooth(nn.Module):
119 | """Cross entropy loss with label smoothing regularizer.
120 |
121 | Reference:
122 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
123 | Equation: y = (1 - epsilon) * y + epsilon / K.
124 |
125 | Args:
126 | num_classes (int): number of classes.
127 | epsilon (float): weight.
128 | """
129 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
130 | super(CrossEntropyLabelSmooth, self).__init__()
131 | self.num_classes = num_classes
132 | self.epsilon = epsilon
133 | self.use_gpu = use_gpu
134 | self.logsoftmax = nn.LogSoftmax(dim=1)
135 |
136 | def forward(self, inputs, targets):
137 | """
138 | Args:
139 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
140 |             targets: ground truth labels with shape (batch_size,)
141 | """
142 | log_probs = self.logsoftmax(inputs)
143 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
144 | if self.use_gpu: targets = targets.cuda()
145 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
146 | loss = (- targets * log_probs).mean(0).sum()
147 | return loss
148 |
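
A minimal usage sketch for the batch-hard triplet loss above (illustrative values; assumes the repo root is on PYTHONPATH). The batch follows the P x K layout produced by the triplet sampler, here P=4 identities with K=2 samples each:

import torch
from layers.triplet_loss import TripletLoss

features = torch.randn(8, 2048)                   # (P*K, feat_dim)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])   # P=4 ids, K=2 each
criterion = TripletLoss(margin=0.3)               # margin=None -> SoftMarginLoss
loss, dist_ap, dist_an = criterion(features, labels)
print(loss.item(), dist_ap.shape, dist_an.shape)  # scalar, [8], [8]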
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | from .network import VNetwork
3 |
4 |
5 | def build_model(cfg, num_classes):
6 | if cfg.MODEL.SETTING == 'video':
7 | model = VNetwork(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, \
8 | cfg.MODEL.PRETRAIN_CHOICE, cfg.MODEL.TEMP,cfg.MODEL.NON_LAYERS,cfg.INPUT.SEQ_LEN)
9 | return model
10 | else:
11 | raise NotImplementedError()
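
A hypothetical build sketch: it assumes `config/__init__.py` exposes the default yacs `cfg` object (as in similar re-id baselines) and that `configs/video_baseline.yml` sets `MODEL.SETTING` to 'video'; the class count is illustrative.

from config import cfg                       # assumption: default cfg exported here
from modeling import build_model

cfg.merge_from_file('configs/video_baseline.yml')
model = build_model(cfg, num_classes=625)    # e.g. the MARS training identities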
--------------------------------------------------------------------------------
/modeling/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/__pycache__/baseline.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/__pycache__/baseline.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/__pycache__/network.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/__pycache__/network.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/ResNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import copy
4 | import torchvision
5 | import torch.nn as nn
6 | from torch.nn import init
7 | from torch.autograd import Variable
8 | from torch.nn import functional as F
9 |
10 | from .SA import inflate
11 | from .SA import AP3D
12 | from .SA import NonLocal
13 | from .SA import SelfAttn
14 |
15 |
16 | def weights_init_kaiming(m):
17 | classname = m.__class__.__name__
18 | if classname.find('Conv') != -1:
19 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
20 |         if m.bias is not None: init.constant_(m.bias.data, 0.0)  # conv bias may be absent
21 | elif classname.find('Linear') != -1:
22 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
23 | init.constant_(m.bias.data, 0.0)
24 | elif classname.find('BatchNorm') != -1:
25 | init.normal_(m.weight.data, 1.0, 0.02)
26 | init.constant_(m.bias.data, 0.0)
27 |
28 |
29 | def weights_init_classifier(m):
30 | classname = m.__class__.__name__
31 | if classname.find('Linear') != -1:
32 | init.normal_(m.weight.data, std=0.001)
33 | init.constant_(m.bias.data, 0.0)
34 |
35 |
36 | class Bottleneck3D(nn.Module):
37 | def __init__(self, bottleneck2d, block, inflate_time=False, temperature=4, contrastive_att=True):
38 | super(Bottleneck3D, self).__init__()
39 |
40 | self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1)
41 | self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1)
42 | if inflate_time == True:
43 | self.conv2 = block(bottleneck2d.conv2, temperature=temperature, contrastive_att=contrastive_att)
44 | else:
45 | self.conv2 = inflate.inflate_conv(bottleneck2d.conv2, time_dim=1)
46 | self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2)
47 | self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1)
48 | self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3)
49 | self.relu = nn.ReLU(inplace=True)
50 |
51 | if bottleneck2d.downsample is not None:
52 | self.downsample = self._inflate_downsample(bottleneck2d.downsample)
53 | else:
54 | self.downsample = None
55 |
56 | def _inflate_downsample(self, downsample2d, time_stride=1):
57 | downsample3d = nn.Sequential(
58 | inflate.inflate_conv(downsample2d[0], time_dim=1,
59 | time_stride=time_stride),
60 | inflate.inflate_batch_norm(downsample2d[1]))
61 | return downsample3d
62 |
63 | def forward(self, x):
64 | residual = x
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 |
76 | if self.downsample is not None:
77 | residual = self.downsample(x)
78 |
79 | out += residual
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class ResNet503D(nn.Module):
86 | def __init__(self, block, c3d_idx, nl_idx, sa_idx, temperature=4, contrastive_att=True, seq_len=6,**kwargs):
87 | super(ResNet503D, self).__init__()
88 |
89 | self.block = block
90 | self.temperature = temperature
91 | self.contrastive_att = contrastive_att
92 | self.inplanes = 64
93 | self.seq_len = seq_len
94 |
95 | resnet2d = torchvision.models.resnet50(pretrained=True)
96 | resnet2d.layer4[0].conv2.stride=(1, 1)
97 | resnet2d.layer4[0].downsample[0].stride=(1, 1)
98 |
99 | ############ STEM ###################
100 | self.conv1 = inflate.inflate_conv(resnet2d.conv1, time_dim=1)
101 | self.bn1 = inflate.inflate_batch_norm(resnet2d.bn1)
102 | self.relu = nn.ReLU(inplace=True)
103 | self.maxpool = inflate.inflate_pool(resnet2d.maxpool, time_dim=1)
104 | #####################################
105 |
106 | self.layer1 = self._inflate_reslayer(resnet2d.layer1, c3d_idx=c3d_idx[0], \
107 | nl_idx=nl_idx[0], sa_idx= sa_idx[0],in_channels=256,ks=[64,32,seq_len])
108 | self.layer2 = self._inflate_reslayer(resnet2d.layer2, c3d_idx=c3d_idx[1], \
109 | nl_idx=nl_idx[1], sa_idx=sa_idx[1],in_channels=512,ks=[32,16,seq_len])
110 | self.layer3 = self._inflate_reslayer(resnet2d.layer3, c3d_idx=c3d_idx[2], \
111 | nl_idx=nl_idx[2], sa_idx=sa_idx[2],in_channels=1024,ks=[16,8,seq_len])
112 | self.layer4 = self._inflate_reslayer(resnet2d.layer4, c3d_idx=c3d_idx[3], \
113 | nl_idx=nl_idx[3], sa_idx=sa_idx[3],in_channels=2048,ks=[16,8,seq_len])
114 |
115 | def _inflate_reslayer(self, reslayer2d, c3d_idx, nl_idx=[], sa_idx=[],in_channels=0,ks=[64,32,1]):
116 | reslayers3d = []
117 | for i,layer2d in enumerate(reslayer2d):
118 | if i not in c3d_idx: # normal 2D convolution
119 | layer3d = Bottleneck3D(layer2d, AP3D.C2D, inflate_time=False)
120 | else: # (AP)I3D, (AP)P3D-A,B,C
121 | layer3d = Bottleneck3D(layer2d, self.block, inflate_time=True, \
122 | temperature=self.temperature, contrastive_att=self.contrastive_att)
123 | reslayers3d.append(layer3d)
124 |
125 | if (i in nl_idx) and (i not in sa_idx):
126 | non_local_block = NonLocal.NonLocalBlock3D(in_channels, sub_sample=True)
127 | reslayers3d.append(non_local_block)
128 | elif (i in sa_idx) and (i not in nl_idx):
129 |                 sa_block = SelfAttn.AxialBlock(in_channels, inter_channel=None,
130 |                                                kernel_size=ks, granularity=4,
131 |                                                groups=8, positional='r_qkv',
132 |                                                order='hwt')
133 | reslayers3d.append(sa_block)
134 | elif (i in sa_idx) and (i in nl_idx):
135 | raise ValueError("can not use nl and sa at the same time!")
136 | return nn.Sequential(*reslayers3d)
137 |
138 | def forward(self, x):
139 | x = self.conv1(x)
140 | x = self.bn1(x)
141 | x = self.relu(x)
142 | x = self.maxpool(x)
143 |
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 |
149 | return x
150 |
151 |
152 | def AP3DResNet50(**kwargs):
153 |     c3d_idx = [[],[0, 2],[0, 2, 4],[]]
154 |     nl_idx = [[],[],[],[]]
155 |     sa_idx = [[],[],[],[]]
156 |     return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs)
157 |
158 | def P3D_ResNet50(**kwargs):
159 | c3d_idx = [[],[0,2],[0,2,4],[]]
160 | nl_idx = [[],[],[],[]]
161 | sa_idx = [[],[],[],[]]
162 | return ResNet503D(AP3D.P3DC, c3d_idx, nl_idx, sa_idx, **kwargs)
163 |
164 | def P3D_Axial_ResNet50(**kwargs):
165 | c3d_idx = [[],[0,1],[0,1,2],[]]
166 | nl_idx = [[],[],[],[]]
167 | sa_idx = [[],[2,3],[3,4,5],[]]
168 | return ResNet503D(AP3D.P3DC, c3d_idx, nl_idx, sa_idx, **kwargs)
169 |
170 | def C2D_Axial_ResNet50(**kwargs):
171 | c3d_idx = [[],[],[],[]]
172 | nl_idx = [[],[],[],[]]
173 | sa_idx = [[],[2,3],[3,4,5],[]]
174 | return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs)
175 |
176 |
177 | def C2DResNet50(**kwargs):
178 |     c3d_idx = [[],[],[],[]]
179 |     nl_idx, sa_idx = [[],[],[],[]], [[],[],[],[]]
180 |     return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs)
181 |
182 | def C2DNLResNet50(**kwargs):
183 |     c3d_idx = [[],[],[],[]]
184 |     nl_idx = [[],[2, 3],[3, 4, 5],[]]
185 |     sa_idx = [[],[],[],[]]
186 |     return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs)
187 |
188 | def AP3DNLResNet50(**kwargs):
189 |     c3d_idx = [[],[0, 2],[0, 2, 4],[]]
190 |     nl_idx = [[],[1, 3],[1, 3, 5],[]]
191 |     sa_idx = [[],[],[],[]]
192 |     return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs)
193 |
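
A shape sketch for the axial variant that network.py actually instantiates (C2D_Axial_ResNet50). It assumes 256x128 frames, which is what the hard-coded ks=[64,32,T] / [32,16,T] / [16,8,T] kernel sizes correspond to at strides 4/8/16; torchvision's pretrained ResNet-50 weights are downloaded on first use.

import torch
from modeling.backbones.ResNet import C2D_Axial_ResNet50

net = C2D_Axial_ResNet50(seq_len=6)
clip = torch.randn(2, 3, 6, 256, 128)   # (b, c, t, h, w)
print(net(clip).shape)                  # torch.Size([2, 2048, 6, 16, 8])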
--------------------------------------------------------------------------------
/modeling/backbones/SA/NonLocal.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import math
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 |
9 | class NonLocalBlockND(nn.Module):
10 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
11 | super(NonLocalBlockND, self).__init__()
12 |
13 | assert dimension in [1, 2, 3]
14 |
15 | self.dimension = dimension
16 | self.sub_sample = sub_sample
17 | self.in_channels = in_channels
18 | self.inter_channels = inter_channels
19 |
20 | if self.inter_channels is None:
21 | self.inter_channels = in_channels // 2
22 | if self.inter_channels == 0:
23 | self.inter_channels = 1
24 |
25 | if dimension == 3:
26 | conv_nd = nn.Conv3d
27 | max_pool = nn.MaxPool3d
28 | bn = nn.BatchNorm3d
29 | elif dimension == 2:
30 | conv_nd = nn.Conv2d
31 | max_pool = nn.MaxPool2d
32 | bn = nn.BatchNorm2d
33 | else:
34 | conv_nd = nn.Conv1d
35 | max_pool = nn.MaxPool1d
36 | bn = nn.BatchNorm1d
37 |
38 | self.g = conv_nd(self.in_channels, self.inter_channels,
39 | kernel_size=1, stride=1, padding=0, bias=True)
40 | self.theta = conv_nd(self.in_channels, self.inter_channels,
41 | kernel_size=1, stride=1, padding=0, bias=True)
42 | self.phi = conv_nd(self.in_channels, self.inter_channels,
43 | kernel_size=1, stride=1, padding=0, bias=True)
44 | if sub_sample:
45 | if dimension == 3:
46 | self.g = nn.Sequential(self.g, max_pool((1, 2, 2)))
47 | self.phi = nn.Sequential(self.phi, max_pool((1, 2, 2)))
48 | else:
49 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
50 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))
51 |
52 | if bn_layer:
53 | self.W = nn.Sequential(
54 | conv_nd(self.inter_channels, self.in_channels,
55 | kernel_size=1, stride=1, padding=0, bias=True),
56 | bn(self.in_channels)
57 | )
58 | else:
59 | self.W = conv_nd(self.inter_channels, self.in_channels,
60 | kernel_size=1, stride=1, padding=0, bias=True)
61 |
62 | # init
63 | for m in self.modules():
64 | if isinstance(m, conv_nd):
65 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
66 | m.weight.data.normal_(0, math.sqrt(2. / n))
67 | elif isinstance(m, bn):
68 | m.weight.data.fill_(1)
69 | m.bias.data.zero_()
70 |
71 | if bn_layer:
72 | nn.init.constant_(self.W[1].weight.data, 0.0)
73 | nn.init.constant_(self.W[1].bias.data, 0.0)
74 | else:
75 | nn.init.constant_(self.W.weight.data, 0.0)
76 | nn.init.constant_(self.W.bias.data, 0.0)
77 |
78 |
79 | def forward(self, x):
80 | '''
81 | :param x: (b, c, t, h, w)
82 | :return:
83 | '''
84 |
85 | batch_size = x.size(0)
86 |
87 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
88 | g_x = g_x.permute(0, 2, 1)
89 |
90 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
91 | theta_x = theta_x.permute(0, 2, 1)
92 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
93 | f = torch.matmul(theta_x, phi_x)
94 | f = F.softmax(f, dim=-1)
95 |
96 | y = torch.matmul(f, g_x)
97 | y = y.permute(0, 2, 1).contiguous()
98 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
99 | y = self.W(y)
100 | z = y + x
101 |
102 | return z
103 |
104 |
105 | class NonLocalBlock1D(NonLocalBlockND):
106 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
107 | super(NonLocalBlock1D, self).__init__(in_channels,
108 | inter_channels=inter_channels,
109 | dimension=1, sub_sample=sub_sample,
110 | bn_layer=bn_layer)
111 |
112 |
113 | class NonLocalBlock2D(NonLocalBlockND):
114 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
115 | super(NonLocalBlock2D, self).__init__(in_channels,
116 | inter_channels=inter_channels,
117 | dimension=2, sub_sample=sub_sample,
118 | bn_layer=bn_layer)
119 |
120 |
121 | class NonLocalBlock3D(NonLocalBlockND):
122 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
123 | super(NonLocalBlock3D, self).__init__(in_channels,
124 | inter_channels=inter_channels,
125 | dimension=3, sub_sample=sub_sample,
126 | bn_layer=bn_layer)
127 |
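
A minimal usage sketch (illustrative sizes): the block is residual, so the output shape always matches the input.

import torch
from modeling.backbones.SA.NonLocal import NonLocalBlock3D

block = NonLocalBlock3D(in_channels=256, sub_sample=True)
x = torch.randn(1, 256, 6, 64, 32)   # (b, c, t, h, w)
assert block(x).shape == x.shape     # residual output, same shape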
--------------------------------------------------------------------------------
/modeling/backbones/SA/SelfAttn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | def conv1x1(in_planes,out_planes,nd=3,stride=1,bias=False):
7 | if nd == 3:
8 | return nn.Conv3d(in_planes,out_planes,kernel_size=1,stride=stride,bias=bias)
9 | elif nd == 2:
10 | return nn.Conv2d(in_planes,out_planes,kernel_size=1,stride=stride,bias=bias)
11 | else:
12 | raise NotImplementedError
13 |
14 | class AxialAttention(nn.Module):
15 | def __init__(self,in_channel,out_channels,groups=8, kernel_size=56,axial='height',
16 | bias=False,positional='no'):
17 | super(AxialAttention,self).__init__()
18 | self.in_channel = in_channel
19 | self.out_channels = out_channels
20 | self.groups = groups
21 | self.group_planes = out_channels // groups
22 | self.kernel_size = kernel_size
23 | self.axial = axial
24 | self.positional = positional
25 |
26 | self.qkv_transform = nn.Conv1d(in_channel,out_channels*2,kernel_size=1,stride=1,padding=0,bias=False)
27 | self.bn_qkv = nn.BatchNorm1d(out_channels*2)
28 | if self.positional == 'r_qkv':
29 | self.bn_similarity = nn.BatchNorm2d(groups*3)
30 | self.bn_output = nn.BatchNorm1d(out_channels*2)
31 | # positional embedding
32 | self.relative = nn.Parameter(torch.randn(self.group_planes*2,kernel_size*2-1),requires_grad=True)
33 | query_index = torch.arange(kernel_size).unsqueeze(0)
34 | key_index = torch.arange(kernel_size).unsqueeze(1)
35 | relative_index = key_index - query_index + kernel_size - 1
36 |
37 | self.register_buffer('flatten_index', relative_index.view(-1))
38 | elif self.positional == 'r_q':
39 | self.bn_similarity = nn.BatchNorm2d(groups*2)
40 | # positional embedding
41 | self.relative = nn.Parameter(torch.randn(self.group_planes//2,kernel_size*2-1),requires_grad=True)
42 | query_index = torch.arange(kernel_size).unsqueeze(0)
43 | key_index = torch.arange(kernel_size).unsqueeze(1)
44 | relative_index = key_index - query_index + kernel_size - 1
45 | self.register_buffer('flatten_index', relative_index.view(-1))
46 |
47 | self.bn_output = nn.BatchNorm1d(out_channels)
48 | else:
49 | self.bn_similarity = nn.BatchNorm2d(groups)
50 | self.bn_output = nn.BatchNorm1d(out_channels)
51 |
52 | self.reset_parameters()
53 |
54 | def reset_parameters(self):
55 | self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_channel))
56 | #nn.init.uniform_(self.relative, -0.1, 0.1)
57 | if 'r_' in self.positional:
58 | nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
59 |
60 | def forward(self,x,vis=False):
61 |         # x's shape: (b, c, t, h, w)
62 | if self.axial == 'width':
63 | x = x.permute(0,2,3,1,4) # b,t,h,c,w
64 | elif self.axial == 'temporal':
65 | x = x.permute(0,3,4,1,2) # b,h,w,c,t
66 | else:
67 | x = x.permute(0,2,4,1,3) # b,t,w,c,h
68 | B,D1,D2,C,H = x.shape
69 | x = x.contiguous().view(B*D1*D2,C,H)
70 |
71 |         # input positional embedding
72 | if self.positional == 'input_sine':
73 | dim = torch.arange(C,dtype=torch.float32,device=x.device)
74 | dim = 1000 ** (2 * (dim//2) / C).view(1,C,1)
75 | code = torch.arange(H,dtype=torch.float32,device=x.device).view(1,1,H).repeat(B*D1*D2,C,1) / dim
76 | code = torch.stack([code[:,0::2,:].sin(),code[:,1::2,:].cos()],dim=2).reshape(B*D1*D2,C,H)
77 | x = x + code
78 |
79 | # Transformations
80 | qkv = self.bn_qkv(self.qkv_transform(x))
81 | q,k,v = torch.split(qkv.reshape(B*D1*D2,self.groups,self.group_planes*2,H),\
82 | [self.group_planes//2,self.group_planes//2,self.group_planes],dim=2)
83 |
84 | qk = torch.einsum('bgci, bgcj->bgij', q, k)
85 |         # Calculate positional embedding
86 | if self.positional == 'r_qkv':
87 | all_embeddings = torch.index_select(self.relative,1,self.flatten_index).view(\
88 | self.group_planes*2,self.kernel_size,self.kernel_size)
89 | q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, \
90 | [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
91 |
92 | qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
93 | kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
94 | stacked_similarity = torch.cat([qk, qr, kr], dim=1)
95 | stacked_similarity = self.bn_similarity(stacked_similarity).view(B*D1*D2, 3, self.groups, H, H).sum(dim=1)
96 |
97 | elif self.positional == 'r_q':
98 | q_embedding = torch.index_select(self.relative,1,self.flatten_index).view(\
99 | self.group_planes//2,self.kernel_size,self.kernel_size)
100 | qr = torch.einsum('bgci,cij->bgij',q,q_embedding)
101 | stacked_similarity = torch.cat([qk,qr],dim=1)
102 | stacked_similarity = self.bn_similarity(stacked_similarity).view(B*D1*D2,2,self.groups,H,H).sum(dim=1)
103 | else:
104 | stacked_similarity = self.bn_similarity(qk)
105 |
106 | similarity = F.softmax(stacked_similarity, dim=3)
107 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
108 | if self.positional == 'r_qkv':
109 | sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
110 | stacked_output = torch.cat([sv, sve], dim=-1).view(B*D1*D2, self.out_channels * 2, H)
111 | output = self.bn_output(stacked_output).view(B, D1, D2 , self.out_channels, 2, H).sum(dim=-2)
112 | else:
113 | stacked_output = sv.reshape(B*D1*D2,self.out_channels,H)
114 | output = self.bn_output(stacked_output).view(B,D1,D2,self.out_channels,H)
115 |
116 |
117 | if self.axial == 'width':
118 | output = output.permute(0,3,1,2,4)
119 | elif self.axial == 'temporal':
120 | output = output.permute(0,3,4,1,2)
121 | else:
122 | output = output.permute(0,3,1,4,2)
123 |
124 | if vis == True:
125 | return output,similarity
126 | return output
127 |
128 |
129 | class AxialBlock(nn.Module):
130 | def __init__(self,in_channel,inter_channel=None,groups=8,granularity=1,kernel_size=[],positional='r_qkv',order='hwt'):
131 | super(AxialBlock,self).__init__()
132 | self.inter_channel = inter_channel
133 | self.relu = nn.ReLU(inplace=True)
134 | self.bn2 = nn.BatchNorm3d(in_channel)
135 | self.order = order
136 | self.granularity = granularity
137 | if inter_channel is not None:
138 | self.conv_down = conv1x1(in_channel,inter_channel)
139 | self.bn1 = nn.BatchNorm3d(inter_channel)
140 | self.conv_up = conv1x1(inter_channel,in_channel)
141 | self.axial_channel = inter_channel
142 | else:
143 | self.conv_up = conv1x1(in_channel,in_channel)
144 | self.axial_channel = in_channel
145 | self.in_gran_channel = self.axial_channel//self.granularity
146 | self.axial_gran = []
147 | for i in range(self.granularity):
148 | gran_group = groups // self.granularity
149 | spatial_ratio = 2**i
150 | height_block = AxialAttention(self.in_gran_channel,self.in_gran_channel,groups=gran_group,kernel_size=kernel_size[0]//spatial_ratio,positional=positional)
151 | width_block = AxialAttention(self.in_gran_channel,self.in_gran_channel,groups=gran_group,axial='width',kernel_size=kernel_size[1]//spatial_ratio,positional=positional)
152 | temporal_block = AxialAttention(self.in_gran_channel,self.in_gran_channel,groups=gran_group,axial='temporal',kernel_size=kernel_size[2],positional=positional)
153 | self.axial_gran.append(height_block)
154 | self.axial_gran.append(width_block)
155 | self.axial_gran.append(temporal_block)
156 | self.axial_gran = nn.ModuleList(self.axial_gran)
157 |
158 | nn.init.constant_(self.bn2.weight,0)
159 | nn.init.constant_(self.bn2.bias,0)
160 |
161 | def forward(self,x):
162 | identity = x
163 | if self.inter_channel is not None:
164 | x = self.relu(self.bn1(self.conv_down(x)))
165 | gran_tensor_list = []
166 | for i in range(self.granularity):
167 | gran_tensor = x[:, i*(self.in_gran_channel):(i+1)*(self.in_gran_channel),...]
168 | B,C,T,H,W = gran_tensor.shape
169 | gran_tensor = F.adaptive_max_pool3d(gran_tensor,(T,H//(2**i),W//(2**i)))
170 | if self.order == 'hwt':
171 | gran_tensor,h_vis = self.axial_gran[i*3+0](gran_tensor,True)
172 | gran_tensor,w_vis = self.axial_gran[i*3+1](gran_tensor,True)
173 | gran_tensor,t_vis = self.axial_gran[i*3+2](gran_tensor,True)
174 | elif self.order == 'wht':
175 | gran_tensor = self.axial_gran[i*3+1](gran_tensor)
176 | gran_tensor = self.axial_gran[i*3+0](gran_tensor)
177 | gran_tensor = self.axial_gran[i*3+2](gran_tensor)
178 | elif self.order == 'wth':
179 | gran_tensor = self.axial_gran[i*3+1](gran_tensor)
180 | gran_tensor = self.axial_gran[i*3+2](gran_tensor)
181 | gran_tensor = self.axial_gran[i*3+0](gran_tensor)
182 | elif self.order == 'twh':
183 | gran_tensor = self.axial_gran[i*3+2](gran_tensor)
184 | gran_tensor = self.axial_gran[i*3+1](gran_tensor)
185 | gran_tensor = self.axial_gran[i*3+0](gran_tensor)
186 | else:
187 | raise NotImplementedError
188 | gran_tensor = F.interpolate(gran_tensor,size=(T,H,W))
189 | gran_tensor_list.append(gran_tensor)
190 | x = torch.cat(gran_tensor_list,dim=1)
191 | x = self.bn2(self.conv_up(x))
192 |
193 | out = identity+x
194 | return out
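
A shape sketch for a single height-axis attention (illustrative sizes). The attended axis length must equal `kernel_size`, since the relative positional table is built for exactly that length:

import torch
from modeling.backbones.SA.SelfAttn import AxialAttention

attn = AxialAttention(in_channel=64, out_channels=64, groups=8,
                      kernel_size=16, axial='height', positional='r_qkv')
x = torch.randn(2, 64, 4, 16, 8)     # (b, c, t, h, w), with h == kernel_size
print(attn(x).shape)                 # torch.Size([2, 64, 4, 16, 8])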
--------------------------------------------------------------------------------
/modeling/backbones/SA/__pycache__/AP3D.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/SA/__pycache__/AP3D.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/SA/__pycache__/NonLocal.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/SA/__pycache__/NonLocal.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/SA/__pycache__/SelfAttn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/SA/__pycache__/SelfAttn.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/SA/__pycache__/inflate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/SA/__pycache__/inflate.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/SA/inflate.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | def inflate_conv(conv2d,
7 | time_dim=1,
8 | time_padding=0,
9 | time_stride=1,
10 | time_dilation=1,
11 | center=False):
12 |     # To preserve activations, temporal padding should replicate the border
13 |     # ("padding by continuity") or be omitted, rather than zero-pad.
14 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
15 | padding = (time_padding, conv2d.padding[0], conv2d.padding[1])
16 |     stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
17 | dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1])
18 | conv3d = nn.Conv3d(
19 | conv2d.in_channels,
20 | conv2d.out_channels,
21 | kernel_dim,
22 | padding=padding,
23 | dilation=dilation,
24 | stride=stride)
25 | # Repeat filter time_dim times along time dimension
26 | weight_2d = conv2d.weight.data
27 | if center:
28 | weight_3d = torch.zeros(*weight_2d.shape)
29 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
30 | middle_idx = time_dim // 2
31 | weight_3d[:, :, middle_idx, :, :] = weight_2d
32 | else:
33 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
34 | weight_3d = weight_3d / time_dim
35 |
36 | # Assign new params
37 | conv3d.weight = nn.Parameter(weight_3d)
38 | conv3d.bias = conv2d.bias
39 | return conv3d
40 |
41 |
42 | def inflate_linear(linear2d, time_dim):
43 | """
44 | Args:
45 | time_dim: final time dimension of the features
46 | """
47 | linear3d = nn.Linear(linear2d.in_features * time_dim,
48 | linear2d.out_features)
49 | weight3d = linear2d.weight.data.repeat(1, time_dim)
50 | weight3d = weight3d / time_dim
51 |
52 | linear3d.weight = nn.Parameter(weight3d)
53 | linear3d.bias = linear2d.bias
54 | return linear3d
55 |
56 |
57 | def inflate_batch_norm(batch2d):
58 | # In pytorch 0.2.0 the 2d and 3d versions of batch norm
59 | # work identically except for the check that verifies the
60 | # input dimensions
61 |
62 | batch3d = nn.BatchNorm3d(batch2d.num_features)
63 | # retrieve 3d _check_input_dim function
64 | batch2d._check_input_dim = batch3d._check_input_dim
65 | return batch2d
66 |
67 |
68 | def inflate_pool(pool2d,
69 | time_dim=1,
70 | time_padding=0,
71 | time_stride=None,
72 | time_dilation=1):
73 | kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size)
74 | padding = (time_padding, pool2d.padding, pool2d.padding)
75 | if time_stride is None:
76 | time_stride = time_dim
77 | stride = (time_stride, pool2d.stride, pool2d.stride)
78 | if isinstance(pool2d, nn.MaxPool2d):
79 | dilation = (time_dilation, pool2d.dilation, pool2d.dilation)
80 | pool3d = nn.MaxPool3d(
81 | kernel_dim,
82 | padding=padding,
83 | dilation=dilation,
84 | stride=stride,
85 | ceil_mode=pool2d.ceil_mode)
86 | elif isinstance(pool2d, nn.AvgPool2d):
87 | pool3d = nn.AvgPool3d(kernel_dim, stride=stride)
88 | else:
89 | raise ValueError(
90 | '{} is not among known pooling classes'.format(type(pool2d)))
91 | return pool3d
92 |
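
An inflation sketch (illustrative sizes): with time_dim=1 the 2D kernel is reused frame by frame, so spatial behaviour is unchanged and the time dimension passes through untouched.

import torch
import torch.nn as nn
from modeling.backbones.SA.inflate import inflate_conv

conv2d = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
conv3d = inflate_conv(conv2d, time_dim=1)
x = torch.randn(1, 3, 6, 256, 128)   # (b, c, t, h, w)
print(conv3d(x).shape)               # torch.Size([1, 64, 6, 128, 64])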
--------------------------------------------------------------------------------
/modeling/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
3 |
--------------------------------------------------------------------------------
/modeling/backbones/__pycache__/ResNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/ResNet.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/__pycache__/non_local.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/non_local.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/__pycache__/resnet_NL.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/resnet_NL.cpython-36.pyc
--------------------------------------------------------------------------------
/modeling/backbones/non_local.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.nn import functional as F
3 | import numpy as np
4 | import os
5 | import torch
6 | from torch import nn
7 |
8 | class NonLocalBlock(nn.Module):
9 | def __init__(self, in_channels, inter_channels=None,sub_sample=False, bn_layer=True,instance='soft',groups=1):
10 | super(NonLocalBlock, self).__init__()
11 | self.sub_sample = sub_sample
12 | self.instance = instance
13 | self.in_channels = in_channels
14 | self.inter_channels = inter_channels
15 |
16 | if self.inter_channels is None:
17 | self.inter_channels = in_channels // 2
18 | if self.inter_channels == 0:
19 | self.inter_channels = 1
20 | self.groups = groups
21 | self.group_plane = self.inter_channels//self.groups
22 | ##### temporal operation in video re-id #####
23 | conv_nd = nn.Conv3d
24 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
25 | bn = nn.BatchNorm3d
26 | ##############################################
27 |
28 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
29 | kernel_size=1, stride=1, padding=0)
30 | if bn_layer:
31 | self.W = nn.Sequential(
32 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
33 | kernel_size=1, stride=1, padding=0),
34 | bn(self.in_channels)
35 | )
36 | nn.init.constant_(self.W[1].weight, 0)
37 | nn.init.constant_(self.W[1].bias, 0)
38 | else:
39 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
40 | kernel_size=1, stride=1, padding=0)
41 | nn.init.constant_(self.W.weight, 0)
42 | nn.init.constant_(self.W.bias, 0)
43 |
44 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
45 | kernel_size=1, stride=1, padding=0)
46 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
47 | kernel_size=1, stride=1, padding=0)
48 | if sub_sample:
49 | self.g = nn.Sequential(self.g, max_pool_layer)
50 | self.phi = nn.Sequential(self.phi, max_pool_layer)
51 |
52 | def forward(self, x):
53 | '''
54 | :param x: (b, c, t, h, w)
55 | :return:
56 | '''
57 | batch_size = x.size(0)
58 |
59 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
60 | g_x = g_x.permute(0, 2, 1) # shape : (b , THW, c')
61 |
62 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
63 | theta_x = theta_x.permute(0, 2, 1) # shape : (b, THW , c')
64 |
65 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) # shape : (b , c', THW)
66 |
67 | f = torch.matmul(theta_x, phi_x)
68 |
69 | if self.instance == 'soft':
70 | f_div_C = F.softmax(f, dim=-1)
71 | elif self.instance == 'dot':
72 | f_div_C = f / f.shape[1]
73 |
74 | y = torch.matmul(f_div_C, g_x)
75 | y = y.permute(0, 2, 1).contiguous() # shape : (b, c', THW)
76 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) # shape : (b, c', T, H, W)
77 |
78 | W_y = self.W(y) # shape : (b, c, t, h, w)
79 | z = W_y + x
80 |
81 | return z
82 |
--------------------------------------------------------------------------------
/modeling/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import math
3 |
4 | import torch
5 | from torch import nn
6 | from torchvision import models
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1):
10 | """3x3 convolution with padding"""
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
12 | padding=1, bias=False)
13 |
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None):
19 | super(BasicBlock, self).__init__()
20 | self.conv1 = conv3x3(inplanes, planes, stride)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.relu = nn.ReLU(inplace=True)
23 | self.conv2 = conv3x3(planes, planes)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 | self.downsample = downsample
26 | self.stride = stride
27 |
28 | def forward(self, x):
29 | residual = x
30 |
31 | out = self.conv1(x)
32 | out = self.bn1(out)
33 | out = self.relu(out)
34 |
35 | out = self.conv2(out)
36 | out = self.bn2(out)
37 |
38 | if self.downsample is not None:
39 | residual = self.downsample(x)
40 |
41 | out += residual
42 | out = self.relu(out)
43 |
44 | return out
45 |
46 |
47 | class Bottleneck(nn.Module):
48 | expansion = 4
49 |
50 | def __init__(self, inplanes, planes, stride=1, downsample=None):
51 | super(Bottleneck, self).__init__()
52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
53 | self.bn1 = nn.BatchNorm2d(planes)
54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
55 | padding=1, bias=False)
56 | self.bn2 = nn.BatchNorm2d(planes)
57 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
58 | self.bn3 = nn.BatchNorm2d(planes * 4)
59 | self.relu = nn.ReLU(inplace=True)
60 | self.downsample = downsample
61 | self.stride = stride
62 |
63 | def forward(self, x):
64 | residual = x
65 |
66 | out = self.conv1(x)
67 | out = self.bn1(out)
68 | out = self.relu(out)
69 |
70 | out = self.conv2(out)
71 | out = self.bn2(out)
72 | out = self.relu(out)
73 |
74 | out = self.conv3(out)
75 | out = self.bn3(out)
76 |
77 | if self.downsample is not None:
78 | residual = self.downsample(x)
79 |
80 | out += residual
81 | out = self.relu(out)
82 |
83 | return out
84 |
85 |
86 | class ResNet(nn.Module):
87 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]):
88 | self.inplanes = 64
89 | super().__init__()
90 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
91 | bias=False)
92 | self.bn1 = nn.BatchNorm2d(64)
93 | self.relu = nn.ReLU(inplace=True) # add missed relu
94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
95 | self.layer1 = self._make_layer(block, 64, layers[0])
96 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
97 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
98 | self.layer4 = self._make_layer(
99 | block, 512, layers[3], stride=last_stride)
100 |
101 | def _make_layer(self, block, planes, blocks, stride=1):
102 | downsample = None
103 | if stride != 1 or self.inplanes != planes * block.expansion:
104 | downsample = nn.Sequential(
105 | nn.Conv2d(self.inplanes, planes * block.expansion,
106 | kernel_size=1, stride=stride, bias=False),
107 | nn.BatchNorm2d(planes * block.expansion),
108 | )
109 |
110 | layers = []
111 | layers.append(block(self.inplanes, planes, stride, downsample))
112 | self.inplanes = planes * block.expansion
113 | for i in range(1, blocks):
114 | layers.append(block(self.inplanes, planes))
115 |
116 | return nn.Sequential(*layers)
117 |
118 | def forward(self, x):
119 | x = self.conv1(x)
120 | x = self.bn1(x)
121 | # x = self.relu(x) # add missed relu
122 | x = self.maxpool(x)
123 |
124 | x = self.layer1(x)
125 | x = self.layer2(x)
126 | x = self.layer3(x)
127 | x = self.layer4(x)
128 |
129 | return x
130 |
131 | def load_param(self, model_path,autoload=None):
132 | if autoload == 'r50':
133 | param_dict = models.resnet50(pretrained=True).state_dict()
134 | else:
135 | param_dict = torch.load(model_path)
136 | for i in param_dict:
137 | if 'fc' in i:
138 | continue
139 | self.state_dict()[i].copy_(param_dict[i])
140 | def random_init(self):
141 | for m in self.modules():
142 | if isinstance(m, nn.Conv2d):
143 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
144 | m.weight.data.normal_(0, math.sqrt(2. / n))
145 | elif isinstance(m, nn.BatchNorm2d):
146 | m.weight.data.fill_(1)
147 | m.bias.data.zero_()
148 |
149 |
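
A quick shape sketch (random weights, illustrative input size): last_stride=1, the usual re-id setting, keeps layer4 at layer3's resolution for finer final feature maps.

import torch
from modeling.backbones.resnet import ResNet

net = ResNet(last_stride=1)
net.random_init()                    # skip pretrained weights for the sketch
x = torch.randn(1, 3, 256, 128)
print(net(x).shape)                  # torch.Size([1, 2048, 16, 8])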
--------------------------------------------------------------------------------
/modeling/backbones/resnet_NL.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import math
3 |
4 | import torch
5 | from torch import nn
6 | from torchvision import models
7 | from .non_local import NonLocalBlock
8 |
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | expansion = 1
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None):
20 | super(BasicBlock, self).__init__()
21 | self.conv1 = conv3x3(inplanes, planes, stride)
22 | self.bn1 = nn.BatchNorm2d(planes)
23 | self.relu = nn.ReLU(inplace=True)
24 | self.conv2 = conv3x3(planes, planes)
25 | self.bn2 = nn.BatchNorm2d(planes)
26 | self.downsample = downsample
27 | self.stride = stride
28 |
29 | def forward(self, x):
30 | residual = x
31 |
32 | out = self.conv1(x)
33 | out = self.bn1(out)
34 | out = self.relu(out)
35 |
36 | out = self.conv2(out)
37 | out = self.bn2(out)
38 |
39 | if self.downsample is not None:
40 | residual = self.downsample(x)
41 |
42 | out += residual
43 | out = self.relu(out)
44 |
45 | return out
46 |
47 |
48 | class Bottleneck(nn.Module):
49 | expansion = 4
50 |
51 | def __init__(self, inplanes, planes, stride=1, downsample=None):
52 | super(Bottleneck, self).__init__()
53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
54 | self.bn1 = nn.BatchNorm2d(planes)
55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
56 | padding=1, bias=False)
57 | self.bn2 = nn.BatchNorm2d(planes)
58 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
59 | self.bn3 = nn.BatchNorm2d(planes * 4)
60 | self.relu = nn.ReLU(inplace=True)
61 | self.downsample = downsample
62 | self.stride = stride
63 |
64 | def forward(self, x):
65 | residual = x
66 |
67 | out = self.conv1(x)
68 | out = self.bn1(out)
69 | out = self.relu(out)
70 |
71 | out = self.conv2(out)
72 | out = self.bn2(out)
73 | out = self.relu(out)
74 |
75 | out = self.conv3(out)
76 | out = self.bn3(out)
77 |
78 | if self.downsample is not None:
79 | residual = self.downsample(x)
80 |
81 | out += residual
82 | out = self.relu(out)
83 |
84 | return out
85 |
86 |
87 | class ResNet_NL(nn.Module):
88 | def __init__(self, last_stride=1, block=Bottleneck, layers=[3, 4, 6, 3],non_layers=[0,2,3,0]):
89 | self.inplanes = 64
90 | super().__init__()
91 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
92 | self.bn1 = nn.BatchNorm2d(64)
93 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
94 | #### layer 1 ####
95 | self.layer1 = self._make_layer(block, 64, layers[0])
96 | NL_1 = [NonLocalBlock(self.inplanes,sub_sample=True) for i in range(non_layers[0])]
97 | self.NL_1 = nn.ModuleList(NL_1)
98 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[0])])
99 | #### layer 2 ####
100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
101 | NL_2 = [NonLocalBlock(self.inplanes) for i in range(non_layers[1])]
102 | self.NL_2 = nn.ModuleList(NL_2)
103 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[1])])
104 | #### layer 3 ####
105 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
106 | NL_3 = [NonLocalBlock(self.inplanes) for i in range(non_layers[2])]
107 | self.NL_3 = nn.ModuleList(NL_3)
108 | self.NL_3_idx = sorted([layers[2]-(i+1) for i in range(non_layers[2])])
109 | #### layer 4 ####
110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)
111 | NL_4 = [NonLocalBlock(self.inplanes) for i in range(non_layers[3])]
112 | self.NL_4 = nn.ModuleList(NL_4)
113 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[3])])
114 |
115 | def _make_layer(self, block, planes, blocks, stride=1):
116 | downsample = None
117 | if stride != 1 or self.inplanes != planes * block.expansion:
118 | downsample = nn.Sequential(
119 | nn.Conv2d(self.inplanes, planes * block.expansion,
120 | kernel_size=1, stride=stride, bias=False),
121 | nn.BatchNorm2d(planes * block.expansion),
122 | )
123 |
124 | layers = []
125 | layers.append(block(self.inplanes, planes, stride, downsample))
126 | self.inplanes = planes * block.expansion
127 | for i in range(1, blocks):
128 | layers.append(block(self.inplanes, planes))
129 |
130 | return nn.ModuleList(layers)
131 |
132 | def forward(self, x):
133 | b,t,c,h,w = x.shape
134 | x = self.conv1(x.view(b*t,c,h,w))
135 | x = self.bn1(x)
136 | x = self.maxpool(x)
137 |
138 | NL1_counter = 0
139 | if len(self.NL_1_idx)== 0 : self.NL_1_idx=[-1]
140 | for i in range(len(self.layer1)):
141 | x = self.layer1[i](x)
142 | if i == self.NL_1_idx[NL1_counter]:
143 | _,c,h,w = x.shape
144 | x = self.NL_1[NL1_counter](x.view(b,t,c,h,w).permute(0,2,1,3,4))
145 | x = x.permute(0,2,1,3,4).reshape(b*t,c,h,w)
146 | NL1_counter += 1
147 |
148 | NL2_counter = 0
149 | if len(self.NL_2_idx)== 0 : self.NL_2_idx=[-1]
150 | for i in range(len(self.layer2)):
151 | x = self.layer2[i](x)
152 | if i == self.NL_2_idx[NL2_counter]:
153 | _,c,h,w = x.shape
154 | x = self.NL_2[NL2_counter](x.view(b,t,c,h,w).permute(0,2,1,3,4))
155 | x = x.permute(0,2,1,3,4).reshape(b*t,c,h,w)
156 | NL2_counter += 1
157 | NL3_counter = 0
158 | if len(self.NL_3_idx)== 0 : self.NL_3_idx=[-1]
159 | for i in range(len(self.layer3)):
160 | x = self.layer3[i](x)
161 | if i == self.NL_3_idx[NL3_counter]:
162 | _,c,h,w = x.shape
163 | x = self.NL_3[NL3_counter](x.view(b,t,c,h,w).permute(0,2,1,3,4))
164 | x = x.permute(0,2,1,3,4).reshape(b*t,c,h,w)
165 | NL3_counter += 1
166 | NL4_counter = 0
167 | if len(self.NL_4_idx)== 0 : self.NL_4_idx=[-1]
168 | for i in range(len(self.layer4)):
169 | x = self.layer4[i](x)
170 | if i == self.NL_4_idx[NL4_counter]:
171 | _,c,h,w = x.shape
172 | x = self.NL_4[NL4_counter](x.view(b,t,c,h,w).permute(0,2,1,3,4))
173 | x = x.permute(0,2,1,3,4).reshape(b*t,c,h,w)
174 | NL4_counter += 1
175 | return x
176 |
177 | def load_param(self, model_path,autoload=None):
178 | if autoload == 'r50':
179 | param_dict = models.resnet50(pretrained=True).state_dict()
180 | else:
181 | param_dict = torch.load(model_path)
182 | for i in param_dict:
183 | if 'fc' in i:
184 | continue
185 | self.state_dict()[i].copy_(param_dict[i])
186 |
187 | def random_init(self):
188 | for m in self.modules():
189 | if isinstance(m, nn.Conv2d):
190 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
191 | m.weight.data.normal_(0, math.sqrt(2. / n))
192 | elif isinstance(m, nn.BatchNorm2d):
193 | m.weight.data.fill_(1)
194 | m.bias.data.zero_()
195 |
196 |
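
Unlike resnet.py, this variant consumes stacked clips of shape (b, t, c, h, w) and reshapes to (b, c, t, h, w) only around each non-local block. A shape sketch (random weights, illustrative sizes):

import torch
from modeling.backbones.resnet_NL import ResNet_NL

net = ResNet_NL(last_stride=1, non_layers=[0, 2, 3, 0])
net.random_init()
x = torch.randn(2, 6, 3, 256, 128)   # (b, t, c, h, w)
print(net(x).shape)                  # torch.Size([12, 2048, 16, 8]) = (b*t, c, h, w)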
--------------------------------------------------------------------------------
/modeling/network.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck
7 | from .backbones.resnet_NL import ResNet_NL
8 | from .backbones.ResNet import C2D_Axial_ResNet50
9 |
10 |
11 | def weights_init_kaiming(m):
12 | classname = m.__class__.__name__
13 | if classname.find('Linear') != -1:
14 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
15 | nn.init.constant_(m.bias, 0.0)
16 | elif classname.find('Conv') != -1:
17 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
18 | if m.bias is not None:
19 | nn.init.constant_(m.bias, 0.0)
20 | elif classname.find('BatchNorm') != -1:
21 | if m.affine:
22 | nn.init.constant_(m.weight, 1.0)
23 | nn.init.constant_(m.bias, 0.0)
24 |
25 |
26 | def weights_init_classifier(m):
27 | classname = m.__class__.__name__
28 | if classname.find('Linear') != -1:
29 | nn.init.normal_(m.weight, std=0.001)
30 |         if m.bias is not None:
31 | nn.init.constant_(m.bias, 0.0)
32 |
33 |
34 | class VNetwork(nn.Module):
35 | in_planes = 2048
36 |
37 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice,temp,\
38 | non_layers=[0,0,0,0], seq_len=6):
39 | super(VNetwork, self).__init__()
40 | self.seq_len = seq_len
41 | if model_name == 'resnet50':
42 | self.base = ResNet(last_stride=last_stride,
43 | block=Bottleneck,
44 | layers=[3, 4, 6, 3])
45 | elif model_name == 'resnet50_NL':
46 | self.base = ResNet_NL(last_stride=last_stride,block=Bottleneck,
47 | layers=[3,4,6,3],non_layers=non_layers)
48 | elif model_name == 'resnet50_axial':
49 | self.base = C2D_Axial_ResNet50(seq_len=seq_len)
50 |
51 | if pretrain_choice == 'imagenet':
52 | if 'axial' not in model_name:
53 | self.base.load_param('',autoload='r50')
54 | print('Loading pretrained ImageNet model......')
55 |
56 | self.gap = nn.AdaptiveAvgPool2d(1)
57 | self.gmp = nn.AdaptiveMaxPool2d(1)
58 | self.num_classes = num_classes
59 | self.neck = neck
60 | self.neck_feat = neck_feat
61 | self.temp = temp
62 | self.model_name = model_name
63 |
64 | self.bottleneck = nn.BatchNorm1d(self.in_planes)
65 | self.bottleneck.bias.requires_grad_(False) # no shift
66 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
67 |
68 | self.bottleneck.apply(weights_init_kaiming)
69 | self.classifier.apply(weights_init_classifier)
70 |
71 | def forward(self, x,masks=None):
72 | b,t,c,h,w = x.shape
73 |
74 | if 'NL' in self.model_name:
75 | if self.temp == 'Done':
76 | x = self.base(x)
77 | _,c,h,w = x.shape
78 | if masks is not None:
79 | global_feat = []
80 | masks = masks.reshape(b*t,-1)
81 | for i in range(x.shape[0]):
82 | global_feat.append(self.gap(x[i,:,masks[i][0]:masks[i][1],masks[i][2]:masks[i][3]].unsqueeze(0)))
83 | global_feat = torch.cat(global_feat,dim=0)
84 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1)
85 | else:
86 | global_feat = F.adaptive_avg_pool3d(x.view(b,t,c,h,w).permute(0,2,1,3,4),1)
87 | global_feat = global_feat.view(b,-1)
88 | else:
89 | global_feat = self.gap(self.base(x))
90 | global_feat = global_feat.view(b*t,-1) # flatten to (b*t, 2048)
91 |
92 | elif 'axial' in self.model_name:
93 | if masks is not None:
94 | global_feat = []
95 | masks = masks.reshape(b*t,-1)
96 | output = self.base(x.permute(0,2,1,3,4).contiguous()).permute(0,2,1,3,4).contiguous()
97 | b,t,c,h,w = output.shape
98 | output = output.view(b*t,c,h,w)
99 | for i in range(len(output)):
100 | global_feat.append(self.gap(output[i,:,masks[i][0]:masks[i][1],masks[i][2]:masks[i][3]].unsqueeze(0)))
101 | global_feat = torch.cat(global_feat,dim=0)
102 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1)
103 | else:
104 | global_feat = self.base(x.permute(0,2,1,3,4).contiguous()).permute(0,2,1,3,4).contiguous()
105 | b,t,c,h,w = global_feat.shape
106 | global_feat = self.gap(global_feat.view(b*t,c,h,w))
107 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1)
108 | else:
109 | if masks is not None:
110 | global_feat = []
111 | masks = masks.reshape(b*t,-1)
112 | output = self.base(x.view(b*t,c,h,w))
113 | for i in range(len(output)):
114 | global_feat.append(self.gap(output[i,:,masks[i][0]:masks[i][1],masks[i][2]:masks[i][3]].unsqueeze(0)))
115 | global_feat = torch.cat(global_feat,dim=0)
116 | else:
117 | global_feat = self.gap(self.base(x.view(b*t,c,h,w))) # (b*t, 2048, 1, 1)
118 | global_feat = global_feat.view(b*t,-1) # flatten to (b*t, 2048)
119 |
120 | #### whether neck ####
121 | feat = self.bottleneck(global_feat) # normalize for angular softmax
122 |
123 | if self.training:
124 | cls_score = self.classifier(feat)
125 | if self.temp == 'avg':
126 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1)
127 | cls_score = torch.mean(cls_score.view(b,t,-1),dim=1)
128 |
129 | return cls_score, global_feat # global feature for triplet loss
130 | else:
131 | if self.temp == 'avg':
132 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1)
133 | feat = torch.mean(feat.view(b,t,-1),dim=1)
134 | if self.neck_feat == 'after':
135 | return feat
136 | else:
137 | return global_feat
138 |
139 | def load_param(self, trained_path,con=False):
140 | param_dict = torch.load(trained_path)['model']
141 | for i in param_dict:
142 | if 'classifier' in i and con == False:
143 | continue
144 | if 'bn_similarity' in i:
145 | if 'num' in i:
146 | self.state_dict()[i].copy_(param_dict[i])
147 | else:
148 | self.state_dict()[i][:param_dict[i].shape[0]].copy_(param_dict[i])
149 | elif 'bn_output' in i :
150 | if 'num' in i:
151 | self.state_dict()[i].copy_(param_dict[i])
152 | else:
153 | self.state_dict()[i][:param_dict[i].shape[0]].copy_(param_dict[i])
154 | else:
155 | self.state_dict()[i].copy_(param_dict[i])
156 |
157 |
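
A direct construction sketch for the baseline path (illustrative argument values, mirroring the cfg keys passed by build_model above; pretrained ResNet-50 weights are downloaded on first use):

import torch
from modeling.network import VNetwork

model = VNetwork(num_classes=625, last_stride=1, model_path='', neck='bnneck',
                 neck_feat='after', model_name='resnet50',
                 pretrain_choice='imagenet', temp='avg',
                 non_layers=[0, 0, 0, 0], seq_len=6)
model.eval()
clip = torch.randn(2, 6, 3, 256, 128)    # (b, t, c, h, w)
with torch.no_grad():
    feat = model(clip)                   # (2, 2048): BN-necked, temporally averaged
print(feat.shape)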
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.5.1
2 | torchvision==0.6.1
3 | scipy==1.5.2
4 | pandas
5 | numpy
6 | Pillow==8.0.0
7 | pytorch-ignite==0.4.2
8 | yacs
9 | tqdm
--------------------------------------------------------------------------------
/scripts/AA_D.sh:
--------------------------------------------------------------------------------
1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('dukev',)" INPUT.SEQ_LEN 6 \
2 | OUTPUT_DIR "./ckpt_DL_duke/Duke_DL_s6_resnet_axial_gap_sine_gran2" SOLVER.SOFT_MARGIN True \
3 | MODEL.NAME 'resnet50_axial' MODEL.TEMP 'Done' MODEL.IF_LABELSMOOTH 'no' INPUT.IF_RE True \
4 | DATASETS.ROOT_DIR '/home/mediax/Dataset/' TEST.NEW_EVAL False
5 |
--------------------------------------------------------------------------------
/scripts/AA_M.sh:
--------------------------------------------------------------------------------
1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('mars',)" INPUT.SEQ_LEN 6 \
2 | OUTPUT_DIR "./ckpt_DL_M/MARS_DL_s6_resnet_axial_gap_rqkv_gran4" SOLVER.SOFT_MARGIN True \
3 | MODEL.NAME 'resnet50_axial' MODEL.TEMP 'Done' MODEL.IF_LABELSMOOTH 'no' INPUT.IF_RE True \
4 | DATASETS.ROOT_DIR '/work/sychien421/Dataset/'
5 |
--------------------------------------------------------------------------------
/scripts/NL_D.sh:
--------------------------------------------------------------------------------
1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('dukev',)" \
2 | OUTPUT_DIR "./ckpt_DL_duke/Duke_DL_s6_NL0230_2gpu" SOLVER.SOFT_MARGIN True \
3 | MODEL.NON_LAYERS [0,2,3,0] INPUT.IF_RE True INPUT.IF_CROP False MODEL.IF_LABELSMOOTH 'no' \
4 | MODEL.NAME 'resnet50_NL' INPUT.SEQ_LEN 6 MODEL.TEMP 'Done' \
5 | DATASETS.ROOT_DIR '/work/sychien421/Dataset/'
6 |
--------------------------------------------------------------------------------
/scripts/NL_M.sh:
--------------------------------------------------------------------------------
1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('1,2')" DATASETS.NAMES "('mars',)" \
2 | OUTPUT_DIR "./ckpt_DL_M/MARS_DL_s6_NL0230" SOLVER.SOFT_MARGIN True \
3 | MODEL.NON_LAYERS [0,2,3,0] INPUT.IF_RE True INPUT.IF_CROP False MODEL.IF_LABELSMOOTH 'no' \
4 | MODEL.NAME 'resnet50_NL' INPUT.SEQ_LEN 6 MODEL.TEMP 'Done' \
5 | DATASETS.ROOT_DIR '/home/mediax/Dataset/'
6 |
--------------------------------------------------------------------------------
/scripts/baseline_D.sh:
--------------------------------------------------------------------------------
1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('dukev',)" \
2 | OUTPUT_DIR "./ckpt_DL_duke/Duke_avgpool" SOLVER.SOFT_MARGIN True \
3 | MODEL.IF_LABELSMOOTH 'no' INPUT.IF_RE True INPUT.SEQ_LEN 6 \
4 | DATASETS.ROOT_DIR '/work/sychien421/Dataset/'
5 |
--------------------------------------------------------------------------------
/scripts/baseline_M.sh:
--------------------------------------------------------------------------------
1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('mars',)" \
2 | OUTPUT_DIR "./ckpt_DL_M/MARS_DL_avgpool_s6" SOLVER.SOFT_MARGIN True \
3 | MODEL.IF_LABELSMOOTH 'no' INPUT.IF_RE True INPUT.SEQ_LEN 6 TEST.NEW_EVAL False \
4 | DATASETS.ROOT_DIR '/home/mediax/Dataset/'
5 |
--------------------------------------------------------------------------------
/scripts/test_M.sh:
--------------------------------------------------------------------------------
1 | # Testing modes:
2 | # (1) TEST.TEST_MODE 'test'                (use RRS to sample T frames)
3 | # (2) TEST.TEST_MODE 'test_0'              (use the first T frames)
4 | # (3) TEST.TEST_MODE 'test_all_sampled'    (use RRS to sample T frames, then average over the N/T tracklets)
5 | # (4) TEST.TEST_MODE 'test_all_continuous' (sample T consecutive frames, then average over the N/T tracklets)
6 |
7 | python3 tools/test.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('mars',)" MODEL.NON_LAYERS [0,2,3,0] \
8 | MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('/home/xxxx/xxxx.pth')" \
9 | MODEL.NAME 'resnet50_axial' INPUT.SEQ_LEN 6 MODEL.TEMP 'Done' TEST.TEST_MODE 'test_all_sampled' TEST.IMS_PER_BATCH 1 \
10 | DATASETS.ROOT_DIR '/work/sychien421/Dataset/' TEST.NEW_EVAL False
11 |
--------------------------------------------------------------------------------
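For intuition, a rough sketch of the sampling modes listed above. RRS (Restricted Random Sampling) is taken here to mean "split the tracklet into T equal chunks and draw one frame per chunk"; the actual logic lives in data/transforms/temporal_transforms.py, so treat this as illustrative only:

# Illustrative sketch of test-time frame sampling (not the repo's exact code).
import random

def rrs_indices(num_frames, T):
    # TEST_MODE 'test' / 'test_all_sampled': one random frame per chunk
    chunk = num_frames / T
    return [random.randrange(int(i * chunk), max(int(i * chunk) + 1, int((i + 1) * chunk)))
            for i in range(T)]

def first_T_indices(num_frames, T):
    # TEST_MODE 'test_0': just the first T frames
    return list(range(min(T, num_frames)))

print(rrs_indices(60, 6))      # e.g. [4, 12, 25, 33, 48, 57]
print(first_T_indices(60, 6))  # [0, 1, 2, 3, 4, 5]
--------------------------------------------------------------------------------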
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | from .build import make_optimizer, make_optimizer_with_center
3 | from .lr_scheduler import WarmupMultiStepLR
4 | from torch.optim.lr_scheduler import StepLR, MultiStepLR
--------------------------------------------------------------------------------
/solver/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/solver/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/build.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/solver/__pycache__/build.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/lr_scheduler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/solver/__pycache__/lr_scheduler.cpython-36.pyc
--------------------------------------------------------------------------------
/solver/build.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import torch
3 |
4 |
5 | def make_optimizer(cfg, model):
6 | params = []
7 | for key, value in model.named_parameters():
8 | if not value.requires_grad:
9 | continue
10 | lr = cfg.SOLVER.BASE_LR
11 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
12 | if "bias" in key:
13 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
14 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
15 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
16 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
17 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
18 | else:
19 |         # Adam (non-SGD): BASE_LR / WEIGHT_DECAY are applied uniformly, so the bias-specific groups above only take effect with SGD
20 |         optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(model.parameters(), lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
21 | # optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
22 | return optimizer
23 |
24 |
25 | def make_optimizer_with_center(cfg, model, center_criterion):
26 | params = []
27 | for key, value in model.named_parameters():
28 | if not value.requires_grad:
29 | continue
30 | lr = cfg.SOLVER.BASE_LR
31 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
32 | if "bias" in key:
33 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
34 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
35 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
36 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
37 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
38 | else:
39 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
40 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)
41 | return optimizer, optimizer_center
42 |
--------------------------------------------------------------------------------
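A quick illustration of what the SGD branch of make_optimizer() builds: one parameter group per tensor, with bias terms getting their learning rate scaled by BIAS_LR_FACTOR and their own weight decay. The numeric values below are hypothetical stand-ins for the cfg.SOLVER settings, not the repo's defaults:

# Sketch of the per-parameter groups make_optimizer() produces for SGD.
import torch
from torch import nn

net = nn.Linear(10, 10)
base_lr, bias_factor, wd = 3.5e-4, 2.0, 5e-4    # made-up cfg.SOLVER values
params = []
for key, value in net.named_parameters():
    lr = base_lr * bias_factor if 'bias' in key else base_lr
    params.append({'params': [value], 'lr': lr, 'weight_decay': wd})
opt = torch.optim.SGD(params, momentum=0.9)     # lr comes from each group
print([g['lr'] for g in opt.param_groups])      # [0.00035, 0.0007]
--------------------------------------------------------------------------------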
/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 |
9 |
10 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
11 | # separating MultiStepLR with WarmupLR
12 | # but the current LRScheduler design doesn't allow it
13 |
14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
15 | def __init__(
16 | self,
17 | optimizer,
18 | milestones,
19 | gamma=0.1,
20 | warmup_factor=1.0 / 3,
21 | warmup_iters=500,
22 | warmup_method="linear",
23 | last_epoch=-1,
24 | ):
25 |         if list(milestones) != sorted(milestones):
26 |             raise ValueError(
27 |                 "Milestones should be a list of increasing integers. "
28 |                 "Got {}".format(milestones)
29 |             )
30 |
31 |         if warmup_method not in ("constant", "linear"):
32 |             raise ValueError(
33 |                 "Only 'constant' or 'linear' warmup_method accepted, "
34 |                 "got {}".format(warmup_method)
35 |             )
36 | self.milestones = milestones
37 | self.gamma = gamma
38 | self.warmup_factor = warmup_factor
39 | self.warmup_iters = warmup_iters
40 | self.warmup_method = warmup_method
41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
42 |
43 | def get_lr(self):
44 | warmup_factor = 1
45 | if self.last_epoch < self.warmup_iters:
46 | if self.warmup_method == "constant":
47 | warmup_factor = self.warmup_factor
48 | elif self.warmup_method == "linear":
49 | alpha = self.last_epoch / self.warmup_iters
50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
51 | return [
52 | base_lr
53 | * warmup_factor
54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
55 | for base_lr in self.base_lrs
56 | ]
57 |
--------------------------------------------------------------------------------
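In other words, WarmupMultiStepLR computes lr = base_lr * warmup_factor(epoch) * gamma ** bisect_right(milestones, epoch). A short worked example, with hypothetical settings:

# Worked example of the WarmupMultiStepLR formula (settings are hypothetical).
from bisect import bisect_right

base_lr, gamma = 3.5e-4, 0.1
milestones, warmup_iters, warmup_factor = [40, 70], 10, 1.0 / 3

def lr_at(epoch):
    wf = 1.0
    if epoch < warmup_iters:                     # linear warmup
        alpha = epoch / warmup_iters
        wf = warmup_factor * (1 - alpha) + alpha
    return base_lr * wf * gamma ** bisect_right(milestones, epoch)

for e in (0, 5, 10, 40, 70):
    print(e, lr_at(e))
# 0 -> base_lr/3, 10 -> base_lr, 40 -> base_lr*0.1, 70 -> base_lr*0.01
--------------------------------------------------------------------------------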
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
--------------------------------------------------------------------------------
/tests/lr_scheduler_test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import unittest
3 |
4 | import torch
5 | from torch import nn
6 |
7 | sys.path.append('.')
8 | from solver.lr_scheduler import WarmupMultiStepLR
9 | from solver.build import make_optimizer
10 | from config import cfg
11 |
12 |
13 | class MyTestCase(unittest.TestCase):
14 | def test_something(self):
15 | net = nn.Linear(10, 10)
16 | optimizer = make_optimizer(cfg, net)
17 | lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10)
18 | for i in range(50):
19 | lr_scheduler.step()
20 | for j in range(3):
21 | print(i, lr_scheduler.get_lr()[0])
22 | optimizer.step()
23 |
24 |
25 | if __name__ == '__main__':
26 | unittest.main()
27 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import argparse
3 | import os
4 | import sys
5 | from os import makedirs
6 |
7 | import torch
8 | from torch.backends import cudnn
9 |
10 | sys.path.append('.')
11 | from config import cfg
12 | from data import make_data_loader
13 | from engine.inference import inference
14 | from modeling import build_model
15 | from utils.logger import setup_logger
16 |
17 |
18 | def main():
19 | parser = argparse.ArgumentParser(description="Video-based ReID Baseline Inference")
20 | parser.add_argument(
21 | "--config_file", default="", help="path to config file", type=str
22 | )
23 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
24 | nargs=argparse.REMAINDER)
25 |
26 | args = parser.parse_args()
27 |
28 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
29 |
30 | if args.config_file != "":
31 | cfg.merge_from_file(args.config_file)
32 | cfg.merge_from_list(args.opts)
33 | cfg.freeze()
34 |
35 | output_dir = cfg.OUTPUT_DIR
36 | if output_dir and not os.path.exists(output_dir):
37 |         makedirs(output_dir)  # handles nested OUTPUT_DIR paths (mkdir would fail)
38 |
39 | logger = setup_logger("reid_baseline", output_dir, 0)
40 | logger.info("Using {} GPUS".format(num_gpus))
41 | logger.info(args)
42 |
43 | if args.config_file != "":
44 | logger.info("Loaded configuration file {}".format(args.config_file))
45 | with open(args.config_file, 'r') as cf:
46 | config_str = "\n" + cf.read()
47 | logger.debug(config_str)
48 | logger.info("Running with config:\n{}".format(cfg))
49 |
50 | if cfg.MODEL.DEVICE == "cuda":
51 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
52 | cudnn.benchmark = True
53 |
54 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
55 | model = build_model(cfg, num_classes)
56 | model.load_param(cfg.TEST.WEIGHT)
57 |
58 | inference(cfg, model, val_loader, num_query)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import argparse
3 | import os
4 | import sys
5 | import torch
6 |
7 | from torch.backends import cudnn
8 |
9 | sys.path.append('.')
10 | from config import cfg
11 | from data import make_data_loader
12 | from engine.trainer import do_train, do_train_with_center
13 | from modeling import build_model
14 | from layers import make_loss, make_loss_with_center
15 | from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR, StepLR, MultiStepLR
16 |
17 | from utils.logger import setup_logger
18 |
19 |
20 | def train(cfg):
21 | # prepare dataset
22 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
23 | # prepare model
24 | model = build_model(cfg, num_classes)
25 |
26 | if cfg.MODEL.IF_WITH_CENTER == 'no':
27 | print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
28 | optimizer = make_optimizer(cfg, model)
29 |
30 | loss_func = make_loss(cfg, num_classes)
31 |
32 | # Add for using self trained model
33 | if cfg.MODEL.PRETRAIN_CHOICE == 'self':
34 |             raise NotImplementedError()  # resuming from a 'self' checkpoint is not wired up in this branch; the code below is unreachable
35 |             start_epoch = int(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])
36 | print('Start epoch:', start_epoch)
37 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
38 | print('Path to the checkpoint of optimizer:', path_to_optimizer)
39 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
40 | optimizer.load_state_dict(torch.load(path_to_optimizer))
41 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
42 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
43 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
44 | start_epoch = 0
45 | #scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
46 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
47 | #scheduler = StepLR(optimizer, 50, cfg.SOLVER.GAMMA)
48 |             scheduler = MultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA)
49 | else:
50 |             raise ValueError("Only pretrain_choice of 'imagenet' or 'self' is supported, but got {}".format(cfg.MODEL.PRETRAIN_CHOICE))
51 |
52 | arguments = {}
53 | do_train(
54 | cfg,
55 | model,
56 | train_loader,
57 | val_loader,
58 | optimizer,
59 | scheduler, # modify for using self trained model
60 | loss_func,
61 | num_query,
62 | start_epoch # add for using self trained model
63 | )
64 | elif cfg.MODEL.IF_WITH_CENTER == 'yes':
65 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
66 | loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu
67 | optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)
68 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
69 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
70 |
71 | arguments = {}
72 |
73 | # Add for using self trained model
74 | if cfg.MODEL.PRETRAIN_CHOICE == 'self':
75 |             start_epoch = int(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1])  # e.g. 'model_120.pth' -> 120
76 | print('Start epoch:', start_epoch)
77 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
78 | print('Path to the checkpoint of optimizer:', path_to_optimizer)
79 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param')
80 | print('Path to the checkpoint of center_param:', path_to_center_param)
81 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center')
82 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center)
83 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
84 | optimizer.load_state_dict(torch.load(path_to_optimizer))
85 | center_criterion.load_state_dict(torch.load(path_to_center_param))
86 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center))
87 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
88 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
89 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
90 | start_epoch = 0
91 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
92 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
93 | else:
94 |             raise ValueError("Only pretrain_choice of 'imagenet' or 'self' is supported, but got {}".format(cfg.MODEL.PRETRAIN_CHOICE))
95 |
96 | do_train_with_center(
97 | cfg,
98 | model,
99 | center_criterion,
100 | train_loader,
101 | val_loader,
102 | optimizer,
103 | optimizer_center,
104 | scheduler, # modify for using self trained model
105 | loss_func,
106 | num_query,
107 | start_epoch # add for using self trained model
108 | )
109 | else:
110 |         print("Unsupported value for cfg.MODEL.IF_WITH_CENTER: {} (only 'yes' or 'no' are supported)\n".format(cfg.MODEL.IF_WITH_CENTER))
111 |
112 |
113 | def main():
114 | parser = argparse.ArgumentParser(description="Video-based ReID Training")
115 | parser.add_argument(
116 | "--config_file", default="", help="path to config file", type=str
117 | )
118 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
119 | nargs=argparse.REMAINDER)
120 |
121 | args = parser.parse_args()
122 |
123 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
124 |
125 | if args.config_file != "":
126 | cfg.merge_from_file(args.config_file)
127 | cfg.merge_from_list(args.opts)
128 | cfg.freeze()
129 |
130 | output_dir = cfg.OUTPUT_DIR
131 | if output_dir and not os.path.exists(output_dir):
132 | os.makedirs(output_dir)
133 |
134 | logger = setup_logger("reid_baseline", output_dir, 0)
135 | logger.info("Using {} GPUS".format(num_gpus))
136 | logger.info(args)
137 |
138 | if args.config_file != "":
139 | logger.info("Loaded configuration file {}".format(args.config_file))
140 | with open(args.config_file, 'r') as cf:
141 | config_str = "\n" + cf.read()
142 | logger.debug(config_str)
143 | logger.info("Running with config:\n{}".format(cfg))
144 |
145 | if cfg.MODEL.DEVICE == "cuda":
146 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu
147 | cudnn.benchmark = True
148 | train(cfg)
149 |
150 |
151 | if __name__ == '__main__':
152 | main()
153 |
--------------------------------------------------------------------------------
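Note on the resume ('self') logic above: everything is derived from the checkpoint filename, so the start epoch comes from the trailing _<epoch> token and the sibling optimizer/center checkpoints are found by substring replacement. A tiny illustration with a hypothetical path:

# Filename conventions assumed by train.py's resume path (path is hypothetical).
path = './ckpt_DL_M/resnet50_model_120.pth'
start_epoch = int(path.split('/')[-1].split('.')[0].split('_')[-1])
print(start_epoch)                            # 120
print(path.replace('model', 'optimizer'))     # ./ckpt_DL_M/resnet50_optimizer_120.pth
print(path.replace('model', 'center_param'))  # ./ckpt_DL_M/resnet50_center_param_120.pth
--------------------------------------------------------------------------------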
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 |
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/iotools.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/utils/__pycache__/iotools.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/iotools.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/utils/__pycache__/iotools.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/re_ranking.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/utils/__pycache__/re_ranking.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/reid_metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/utils/__pycache__/reid_metric.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/iotools.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import errno
8 | import json
9 | import os
10 |
11 | import os.path as osp
12 |
13 |
14 | def mkdir_if_missing(directory):
15 | if not osp.exists(directory):
16 | try:
17 | os.makedirs(directory)
18 | except OSError as e:
19 | if e.errno != errno.EEXIST:
20 | raise
21 |
22 |
23 | def check_isfile(path):
24 | isfile = osp.isfile(path)
25 | if not isfile:
26 | print("=> Warning: no file found at '{}' (ignored)".format(path))
27 | return isfile
28 |
29 |
30 | def read_json(fpath):
31 | with open(fpath, 'r') as f:
32 | obj = json.load(f)
33 | return obj
34 |
35 |
36 | def write_json(obj, fpath):
37 | mkdir_if_missing(osp.dirname(fpath))
38 | with open(fpath, 'w') as f:
39 | json.dump(obj, f, indent=4, separators=(',', ': '))
40 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
3 | import logging
4 | import os
5 | import sys
6 |
7 |
8 | def setup_logger(name, save_dir, distributed_rank):
9 | logger = logging.getLogger(name)
10 | logger.setLevel(logging.DEBUG)
11 | # don't log results for the non-master process
12 | if distributed_rank > 0:
13 | return logger
14 | ch = logging.StreamHandler(stream=sys.stdout)
15 | ch.setLevel(logging.INFO)
16 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
17 | ch.setFormatter(formatter)
18 | logger.addHandler(ch)
19 |
20 | if save_dir:
21 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
22 | fh.setLevel(logging.DEBUG)
23 | fh.setFormatter(formatter)
24 | logger.addHandler(fh)
25 |
26 | return logger
27 |
--------------------------------------------------------------------------------
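Typical usage, mirroring tools/train.py: rank 0 logs to stdout and to <save_dir>/log.txt, while non-master ranks get a silent logger. The directory below is a placeholder and must already exist (tools/train.py creates it first):

# Usage sketch for setup_logger(); "./output" is a placeholder directory.
from utils.logger import setup_logger

logger = setup_logger("reid_baseline", "./output", 0)
logger.info("Using {} GPUS".format(1))
--------------------------------------------------------------------------------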
/utils/re_ranking.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri, 25 May 2018 20:29:09
5 |
6 | @author: luohao
7 | """
8 |
9 | """
10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
13 | """
14 |
15 | """
16 | API
17 |
18 | probFea: all feature vectors of the query set (torch tensor)
19 | galFea: all feature vectors of the gallery set (torch tensor)
20 | k1, k2, lambda_value: parameters; the original paper uses (k1=20, k2=6, lambda=0.3)
21 | MemorySave / Minibatch: options of the original implementation,
22 | not exposed by this function
23 | """
24 |
25 | import numpy as np
26 | import torch
27 |
28 |
29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
30 |     # if the features are numpy arrays, convert them to torch tensors (torch.tensor) first
31 | query_num = probFea.size(0)
32 | all_num = query_num + galFea.size(0)
33 | if only_local:
34 | original_dist = local_distmat
35 | else:
36 | feat = torch.cat([probFea,galFea])
37 | print('using GPU to compute original distance')
38 | distmat = torch.pow(feat,2).sum(dim=1, keepdim=True).expand(all_num,all_num) + \
39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
40 |         distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)  # keyword form of the deprecated addmm_(1, -2, ...) call
41 | original_dist = distmat.cpu().numpy()
42 | del feat
43 |         if local_distmat is not None:
44 | original_dist = original_dist + local_distmat
45 | gallery_num = original_dist.shape[0]
46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
47 | V = np.zeros_like(original_dist).astype(np.float16)
48 | initial_rank = np.argsort(original_dist).astype(np.int32)
49 |
50 | print('starting re_ranking')
51 | for i in range(all_num):
52 | # k-reciprocal neighbors
53 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
55 | fi = np.where(backward_k_neigh_index == i)[0]
56 | k_reciprocal_index = forward_k_neigh_index[fi]
57 | k_reciprocal_expansion_index = k_reciprocal_index
58 | for j in range(len(k_reciprocal_index)):
59 | candidate = k_reciprocal_index[j]
60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
62 | :int(np.around(k1 / 2)) + 1]
63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
66 | candidate_k_reciprocal_index):
67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
68 |
69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
72 | original_dist = original_dist[:query_num, ]
73 | if k2 != 1:
74 | V_qe = np.zeros_like(V, dtype=np.float16)
75 | for i in range(all_num):
76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
77 | V = V_qe
78 | del V_qe
79 | del initial_rank
80 | invIndex = []
81 | for i in range(gallery_num):
82 | invIndex.append(np.where(V[:, i] != 0)[0])
83 |
84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
85 |
86 | for i in range(query_num):
87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
88 | indNonZero = np.where(V[i, :] != 0)[0]
89 | indImages = [invIndex[ind] for ind in indNonZero]
90 | for j in range(len(indNonZero)):
91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
92 | V[indImages[j], indNonZero[j]])
93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
94 |
95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
96 | del original_dist
97 | del V
98 | del jaccard_dist
99 | final_dist = final_dist[:query_num, query_num:]
100 | return final_dist
101 |
102 |
--------------------------------------------------------------------------------
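utils/reid_metric.py calls this function with the paper's default parameters. A self-contained smoke test with random, L2-normalized features (shapes are arbitrary):

# Smoke test for re_ranking() using the defaults reid_metric.py passes.
import torch
from utils.re_ranking import re_ranking

qf = torch.nn.functional.normalize(torch.randn(8, 256), dim=1)
gf = torch.nn.functional.normalize(torch.randn(100, 256), dim=1)
dist = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
print(dist.shape)  # (8, 100): query-to-gallery distances after re-ranking
--------------------------------------------------------------------------------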
/utils/reid_metric.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy as np
3 | import torch
4 | from ignite.metrics import Metric
5 |
6 | from data.datasets.eval_reid import eval_func
7 | from .re_ranking import re_ranking
8 |
9 |
10 | class R1_mAP(Metric):
11 | def __init__(self, num_query, max_rank=50, feat_norm='yes',new_eval=False):
12 | super(R1_mAP, self).__init__()
13 | self.num_query = num_query
14 | self.max_rank = max_rank
15 | self.feat_norm = feat_norm
16 | self.new_eval = new_eval
17 |
18 | def reset(self):
19 | self.feats = []
20 | self.pids = []
21 | self.camids = []
22 | if self.new_eval:
23 | self.ambis = []
24 |
25 | def update(self, output):
26 | if self.new_eval:
27 | feat, pid, ambi, camid = output
28 | self.ambis.extend(np.asarray(ambi))
29 | else:
30 | feat, pid, camid = output
31 | self.feats.append(feat)
32 | self.pids.extend(np.asarray(pid))
33 | self.camids.extend(np.asarray(camid))
34 |
35 | def compute(self):
36 | feats = torch.cat(self.feats, dim=0)
37 | if self.feat_norm == 'yes':
38 | print("The test feature is normalized")
39 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
40 | # query
41 | qf = feats[:self.num_query]
42 | q_pids = np.asarray(self.pids[:self.num_query])
43 | q_camids = np.asarray(self.camids[:self.num_query])
44 | if self.new_eval:
45 | q_ambis = np.asarray(self.ambis[:self.num_query])
46 | else:
47 | q_ambis = None
48 |
49 | # gallery
50 | gf = feats[self.num_query:]
51 | g_pids = np.asarray(self.pids[self.num_query:])
52 | g_camids = np.asarray(self.camids[self.num_query:])
53 | if self.new_eval:
54 | g_ambis = np.asarray(self.ambis[self.num_query:])
55 | else:
56 | g_ambis = None
57 |
58 | m, n = qf.shape[0], gf.shape[0]
59 |
60 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
61 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
62 | # distmat.addmm_(qf, gf.t(),beta=1,alpha=-2)
63 |         distmat = -1 * torch.mm(qf, gf.t())  # negative inner product; with feat_norm 'yes' this ranks identically to cosine/Euclidean distance
64 | distmat = distmat.cpu().numpy()
65 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids,q_ambis=q_ambis,g_ambis=g_ambis)
66 |
67 | return cmc, mAP
68 |
69 | # Note: the re-ranking metric does not implement the new_eval protocol
70 | class R1_mAP_reranking(Metric):
71 | def __init__(self, num_query, max_rank=50, feat_norm='yes'):
72 | super(R1_mAP_reranking, self).__init__()
73 | self.num_query = num_query
74 | self.max_rank = max_rank
75 | self.feat_norm = feat_norm
76 |
77 | def reset(self):
78 | self.feats = []
79 | self.pids = []
80 | self.camids = []
81 |
82 | def update(self, output):
83 | feat, pid, camid = output
84 | self.feats.append(feat)
85 | self.pids.extend(np.asarray(pid))
86 | self.camids.extend(np.asarray(camid))
87 |
88 | def compute(self):
89 | feats = torch.cat(self.feats, dim=0)
90 | if self.feat_norm == 'yes':
91 | print("The test feature is normalized")
92 | feats = torch.nn.functional.normalize(feats, dim=1, p=2)
93 |
94 | # query
95 | qf = feats[:self.num_query]
96 | q_pids = np.asarray(self.pids[:self.num_query])
97 | q_camids = np.asarray(self.camids[:self.num_query])
98 | # gallery
99 | gf = feats[self.num_query:]
100 | g_pids = np.asarray(self.pids[self.num_query:])
101 | g_camids = np.asarray(self.camids[self.num_query:])
102 | # m, n = qf.shape[0], gf.shape[0]
103 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
104 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
105 | # distmat.addmm_(1, -2, qf, gf.t())
106 | # distmat = distmat.cpu().numpy()
107 | print("Enter reranking")
108 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
109 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
110 |
111 | return cmc, mAP
--------------------------------------------------------------------------------
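A closing note on the distance used in R1_mAP: for L2-normalized features, ||q - g||^2 = 2 - 2 * (q . g), so ranking by the negative inner product is equivalent to ranking by Euclidean (or cosine) distance. A tiny check of that identity:

# Why R1_mAP can use distmat = -1 * torch.mm(qf, gf.t()): for unit-norm
# features, squared Euclidean distance is an affine function of -q.g.
import torch

qf = torch.nn.functional.normalize(torch.randn(4, 16), dim=1)
gf = torch.nn.functional.normalize(torch.randn(9, 16), dim=1)
inner = -torch.mm(qf, gf.t())
eucl = torch.cdist(qf, gf) ** 2
print(torch.allclose(eucl, 2 + 2 * inner, atol=1e-5))          # True
print(torch.equal(inner.argsort(dim=1), eucl.argsort(dim=1)))  # same ranking (barring fp ties)
--------------------------------------------------------------------------------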