├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   └── video_baseline.yml
├── data
│   ├── __init__.py
│   ├── build.py
│   ├── collate_batch.py
│   ├── datasets
│   │   ├── DukeV.py
│   │   ├── MARS.py
│   │   ├── __init__.py
│   │   ├── bases.py
│   │   ├── dataset_loader.py
│   │   └── eval_reid.py
│   ├── samplers
│   │   ├── __init__.py
│   │   └── triplet_sampler.py
│   └── transforms
│       ├── __init__.py
│       ├── build.py
│       ├── temporal_transforms.py
│       └── transforms.py
├── engine
│   ├── data_parallel.py
│   ├── inference.py
│   ├── scatter_gather.py
│   └── trainer.py
├── imgs
│   ├── DL.png
│   └── DL_2.png
├── layers
│   ├── __init__.py
│   ├── center_loss.py
│   └── triplet_loss.py
├── modeling
│   ├── __init__.py
│   ├── backbones
│   │   ├── ResNet.py
│   │   ├── SA
│   │   │   ├── AP3D.py
│   │   │   ├── NonLocal.py
│   │   │   ├── SelfAttn.py
│   │   │   └── inflate.py
│   │   ├── __init__.py
│   │   ├── non_local.py
│   │   ├── resnet.py
│   │   └── resnet_NL.py
│   └── network.py
├── requirements.txt
├── scripts
│   ├── AA_D.sh
│   ├── AA_M.sh
│   ├── NL_D.sh
│   ├── NL_M.sh
│   ├── baseline_D.sh
│   ├── baseline_M.sh
│   └── test_M.sh
├── solver
│   ├── __init__.py
│   ├── build.py
│   └── lr_scheduler.py
├── tests
│   ├── __init__.py
│   └── lr_scheduler_test.py
├── tools
│   ├── __init__.py
│   ├── test.py
│   └── train.py
└── utils
    ├── __init__.py
    ├── iotools.py
    ├── logger.py
    ├── re_ranking.py
    └── reid_metric.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Chih-Ting Liu (劉致廷)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Video-based Person Re-identification without Bells and Whistles
2 | 
3 | [[Paper]](http://media.ee.ntu.edu.tw/research/CFAAN/paper/CVPRw21_VideoReID.pdf) [[arXiv]](https://arxiv.org/pdf/2105.10678.pdf) [[video]](https://youtu.be/RNssJNmq504)
4 | 
5 | [Chih-Ting Liu](https://jackie840129.github.io/), [Jun-Cheng Chen](https://www.citi.sinica.edu.tw/pages/pullpull/contact_en.html), [Chu-Song Chen](https://imp.iis.sinica.edu.tw/) and [Shao-Yi Chien](http://www.ee.ntu.edu.tw/profile?id=101),
Analysis & Modeling of Faces & Gestures Workshop jointly with IEEE Conference on Computer Vision and Pattern Recognition (**CVPRw**), 2021
6 | 
7 | This is the PyTorch implementation of the Coarse-to-Fine Axial Attention Network **(CF-AAN)** for video-based person Re-ID.
8 | It achieves **91.3%** rank-1 accuracy and **86.5%** mAP on our aligned MARS dataset.
9 | 
10 | ## News
11 | 
12 | **`2021-06-13`**: 
13 | - We release the code and the aligned dataset for our work.
14 | - We have updated the README for our new dataset; the other sections will be updated gradually.
15 | 
16 | **`2021-06-18`**: 
17 | - We update the description for training and testing CF-AAN.
18 | 
19 | ## Aligned dataset with our re-Detect and Link module
20 | 
21 | ### Download Link :
22 | 
23 | - MARS (DL) : [[Google Drive]](https://drive.google.com/file/d/1adP39y7xoKYX8Z4lyBtZiDTg9kZyK1Cx/view?usp=sharing)
24 | - For DukeV, we did not perform DL on DukeMTMC-VideoReID because its bounding boxes are ground-truth annotations.
25 | 
26 | ### Results
27 | The video tracklet will be re-Detected, linked (tracked), and padded to the original image size, as shown below.
28 | 

29 | 
30 | ### Folder Structure
31 | MARS dataset:
32 | ```
33 | MARS-DL/
34 | |-- bbox_train/
35 | |-- bbox_test/
36 | |-- info/
37 | |-- |-- mask_info.csv (for DL mask)
38 | |-- |-- mask_info_test.csv (for DL mask)
39 | |-- |-- clean_tracks_test_info.mat (for new evaluation protocol)
40 | |-- |-- .... (other original info files)
41 | ```
42 | DukeV dataset:
43 | ```
44 | DukeMTMC-VideoReID/
45 | |-- train/
46 | |-- gallery/
47 | |-- query/
48 | ```
49 | You can put these two folders under your root dataset directory.
50 | ```
51 | path to your root dir/
52 | |-- MARS-DL/
53 | |-- DukeMTMC-VideoReID/
54 | ```
55 | ## Coarse-to-Fine Axial Attention Network (CF-AAN)
56 | 
57 | ### Requirements
58 | We use Python 3.6, PyTorch 1.5 and PyTorch-Ignite in this project. To install the required modules, run:
59 | ```
60 | pip3 install -r requirements.txt
61 | ```
62 | ### Training
63 | #### Train CF-AAN on MARS-DL
64 | You can alter the arguments in `scripts/AA_M.sh` and run it with:
65 | ```
66 | sh scripts/AA_M.sh
67 | ```
68 | Or, you can directly type:
69 | ```
70 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('mars',)" INPUT.SEQ_LEN 6 \
71 | OUTPUT_DIR "./ckpt_DL_M/MARS_DL_s6_resnet_axial_gap_rqkv_gran4" SOLVER.SOFT_MARGIN True \
72 | MODEL.NAME 'resnet50_axial' MODEL.TEMP 'Done' INPUT.IF_RE True \
73 | DATASETS.ROOT_DIR ''
74 | ```
75 | \* `` is the directory containing both the MARS and DukeV datasets.
76 | #### Train Non-local or baseline on MARS
77 | You can alter the arguments in `scripts/NL_M.sh` & `scripts/baseline_M.sh` and run them with:
78 | 
79 | `sh scripts/NL_M.sh` & `sh scripts/baseline_M.sh`
80 | #### Train models on DukeMTMC-VideoReID
81 | You can use the scripts `scripts/AA_D.sh`, `scripts/NL_D.sh`, & `scripts/baseline_D.sh`.
82 | 
83 | #### Notes
84 | If you want to train on the original MARS dataset, you just need to change the comment in `data/datasets/MARS.py`:
85 | ```
86 | class MARS(BaseVideoDataset):
87 |     dataset_dir = 'MARS'
88 |     # dataset_dir = 'MARS-DL'
89 |     info_dir = 'info'
90 | ```
91 | 
92 | ### Testing
93 | You can alter the arguments in `scripts/test_M.sh` and run it with:
94 | ```
95 | sh scripts/test_M.sh
96 | ```
97 | \* `TEST.WEIGHT` is the path to the saved PyTorch (.pth) model.
98 | 
99 | \* There are four modes for `TEST.TEST_MODE` (see the sampling sketch at the end of this section).
100 | 1. `TEST.TEST_MODE 'test'`
101 |    * Use the RRS [3] testing mode, which samples the first image of each of the T snippets split from the tracklet.
102 | 2. `TEST.TEST_MODE 'test_0'`
103 |    * Sample the first T images of the tracklet.
104 | 3. `TEST.TEST_MODE 'test_all_sampled'`
105 |    * Create N/T tracklets (all 1st images from the T RRS snippets, all 2nd images, and so on), and average the N/T features.
106 | 4. `TEST.TEST_MODE 'test_all_continuous'`
107 |    * Continuously sample T frames to create N/T tracklets, and average the N/T features.
108 | 
109 | If you want to test on DukeV, you can just alter the corresponding arguments in `scripts/test_M.sh`.
110 | 
111 | ## New Evaluation Protocol
112 | 
113 | Change `TEST.NEW_EVAL False` to `TEST.NEW_EVAL True`.
114 | 
115 | The details will be introduced soon.
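
For reference, the following minimal sketch (not part of the repo) mirrors how the four `TEST.TEST_MODE` options pick frame indices in `data/datasets/dataset_loader.py`; in the actual code, `test_0` goes through a temporal transform rather than the plain slice used here:
```python
import numpy as np

def sample_indices(num, seq_len, mode):
    """Pick frame indices from a tracklet of `num` frames (T = seq_len)."""
    indices = np.arange(num)
    # RRS: pad with the last frame so the tracklet splits into T equal snippets.
    num_pads = 0 if num % seq_len == 0 else seq_len - num % seq_len
    indices = np.concatenate([indices, np.full(num_pads, num - 1, dtype=np.int64)])
    pool = np.split(indices, seq_len)
    if mode == 'test':                 # first frame of each of the T snippets
        return np.array([part[0] for part in pool])
    if mode == 'test_0':               # simply the first T frames
        return indices[:seq_len]
    if mode == 'test_all_sampled':     # N/T tracklets: all 1st frames, all 2nd, ...
        return np.vstack(pool).T.flatten()
    if mode == 'test_all_continuous':  # N/T tracklets of continuous frames
        return np.vstack(pool).flatten()
    raise ValueError(mode)

print(sample_indices(10, 4, 'test'))  # -> [0 3 6 9]
```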
116 | 
117 | ## Citation
118 | ```
119 | @InProceedings{Liu_2021_CVPR,
120 |     author = {Liu, Chih-Ting and Chen, Jun-Cheng and Chen, Chu-Song and Chien, Shao-Yi},
121 |     title = {Video-Based Person Re-Identification Without Bells and Whistles},
122 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
123 |     month = {June},
124 |     year = {2021},
125 |     pages = {1491-1500}
126 | }
127 | ```
128 | ## Reference
129 | 
130 | 1. The structure of our code is based on [reid-strong-baseline](https://github.com/michuanhaohao/reid-strong-baseline).
131 | 2. Some code of our CF-AAN is based on [axial-deeplab](https://github.com/csrhddlam/axial-deeplab).
132 | 3. Li, Shuang, et al. "Diversity regularized spatiotemporal attention for video-based person re-identification." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.
133 | ## Contact
134 | 
135 | [Chih-Ting Liu](https://jackie840129.github.io/), [Media IC & System Lab](https://github.com/mediaic), National Taiwan University
136 | 
137 | E-mail: jackieliu@media.ee.ntu.edu.tw
138 | 
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | from .defaults import _C as cfg
8 | 
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | 
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter,
8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images during training will be 10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be 11 | # IMAGES_PER_BATCH_TEST 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Config definition 15 | # ----------------------------------------------------------------------------- 16 | 17 | _C = CN() 18 | 19 | _C.MODEL = CN() 20 | # Using cuda or cpu for training 21 | _C.MODEL.DEVICE = "cuda" 22 | # ID number of GPU 23 | _C.MODEL.DEVICE_ID = '0' 24 | # Name of backbone 25 | _C.MODEL.NAME = 'resnet50' 26 | # Last stride of backbone 27 | _C.MODEL.LAST_STRIDE = 1 28 | # Path to pretrained model of backbone 29 | _C.MODEL.PRETRAIN_PATH = '' 30 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model 31 | # Options: 'imagenet' or 'self' 32 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet' 33 | # If train with BNNeck, options: 'bnneck' or 'no' 34 | _C.MODEL.NECK = 'bnneck' 35 | # If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration 36 | _C.MODEL.IF_WITH_CENTER = 'no' 37 | # The loss type of metric loss 38 | # options:['triplet'](without center loss) or ['center','triplet_center'](with center loss) 39 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet' 40 | # For example, if loss type is cross entropy loss + triplet loss + center loss 41 | # the setting should be: _C.MODEL.METRIC_LOSS_TYPE = 'triplet_center' and _C.MODEL.IF_WITH_CENTER = 'yes' 42 | 43 | # If train with label smooth, options: 'on', 'off' 44 | _C.MODEL.IF_LABELSMOOTH = 'on' 45 | 46 | # Video or Image based 47 | _C.MODEL.SETTING = 'image' 48 | 49 | _C.MODEL.TEMP = 'avg' 50 | _C.MODEL.NON_LAYERS = [0,0,0,0] 51 | 52 | 53 | # ----------------------------------------------------------------------------- 54 | # INPUT 55 | # ----------------------------------------------------------------------------- 56 | _C.INPUT = CN() 57 | # Size of the image during training 58 | _C.INPUT.SIZE_TRAIN = [384, 128] 59 | # Size of the image during test 60 | _C.INPUT.SIZE_TEST = [384, 128] 61 | # Random probability for image horizontal flip 62 | _C.INPUT.PROB = 0.5 63 | # Random probability for random erasing 64 | _C.INPUT.RE_PROB = 0.5 65 | # Values to be used for image normalization 66 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 67 | # Values to be used for image normalization 68 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 69 | # Value of padding size 70 | _C.INPUT.PADDING = 10 71 | # augmentation on/off 72 | _C.INPUT.IF_CROP = True 73 | _C.INPUT.IF_RE = True 74 | _C.INPUT.IF_FLIP = True 75 | 76 | # for video re-id 77 | _C.INPUT.MIN_SEQ_LEN = 0 78 | _C.INPUT.SAMPLE = 'RRS' 79 | _C.INPUT.SEQ_LEN = 8 80 | 81 | # ----------------------------------------------------------------------------- 82 | # Dataset 83 | # ----------------------------------------------------------------------------- 84 | _C.DATASETS = CN() 85 | # List of the dataset names for training, as present in paths_catalog.py 86 | _C.DATASETS.NAMES = ('market1501',) 87 | # Root directory where datasets should be used (and downloaded if not found) 88 | _C.DATASETS.ROOT_DIR = ('/home/mediax/Dataset/') 89 | 90 | # ----------------------------------------------------------------------------- 91 | # DataLoader 92 | # ----------------------------------------------------------------------------- 93 | _C.DATALOADER = CN() 94 | # Number of data loading threads 95 | _C.DATALOADER.NUM_WORKERS = 8 96 | # Sampler for 
data loading
97 | _C.DATALOADER.SAMPLER = 'softmax'
98 | # Number of instances for one batch
99 | _C.DATALOADER.NUM_INSTANCE = 16
100 | 
101 | # ---------------------------------------------------------------------------- #
102 | # Solver
103 | # ---------------------------------------------------------------------------- #
104 | _C.SOLVER = CN()
105 | # Name of optimizer
106 | _C.SOLVER.OPTIMIZER_NAME = "Adam"
107 | # Number of max epochs
108 | _C.SOLVER.MAX_EPOCHS = 50
109 | # Base learning rate
110 | _C.SOLVER.BASE_LR = 3e-4
111 | # Factor of learning rate for bias parameters
112 | _C.SOLVER.BIAS_LR_FACTOR = 2
113 | # Momentum
114 | _C.SOLVER.MOMENTUM = 0.9
115 | # Margin of triplet loss
116 | _C.SOLVER.MARGIN = 0.3
117 | _C.SOLVER.SOFT_MARGIN = False
118 | # Margin of cluster loss
119 | _C.SOLVER.CLUSTER_MARGIN = 0.3
120 | # Learning rate of SGD to learn the centers of center loss
121 | _C.SOLVER.CENTER_LR = 0.5
122 | # Balanced weight of center loss
123 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005
124 | # Settings of range loss
125 | _C.SOLVER.RANGE_K = 2
126 | _C.SOLVER.RANGE_MARGIN = 0.3
127 | _C.SOLVER.RANGE_ALPHA = 0
128 | _C.SOLVER.RANGE_BETA = 1
129 | _C.SOLVER.RANGE_LOSS_WEIGHT = 1
130 | 
131 | # Settings of weight decay
132 | _C.SOLVER.WEIGHT_DECAY = 0.0005
133 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
134 | 
135 | # decay rate of learning rate
136 | _C.SOLVER.GAMMA = 0.1
137 | # decay step of learning rate
138 | _C.SOLVER.STEPS = (30, 55)
139 | 
140 | # warm up factor
141 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3
142 | # iterations of warm up
143 | _C.SOLVER.WARMUP_ITERS = 500
144 | # method of warm up, option: 'constant','linear'
145 | _C.SOLVER.WARMUP_METHOD = "linear"
146 | 
147 | # epoch number of saving checkpoints
148 | _C.SOLVER.CHECKPOINT_PERIOD = 50
149 | # iteration of display training log
150 | _C.SOLVER.LOG_PERIOD = 100
151 | # epoch number of validation
152 | _C.SOLVER.EVAL_PERIOD = 50
153 | 
154 | # Number of images per batch
155 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
156 | # see 2 images per batch
157 | _C.SOLVER.IMS_PER_BATCH = 64
158 | 
159 | # ---------------------------------------------------------------------------- #
160 | # Test
161 | # ---------------------------------------------------------------------------- #
162 | _C.TEST = CN()
163 | # Number of images per batch during test
164 | _C.TEST.IMS_PER_BATCH = 128
165 | # If test with re-ranking, options: 'yes','no'
166 | _C.TEST.RE_RANKING = 'no'
167 | # Path to trained model
168 | _C.TEST.WEIGHT = ""
169 | # Which feature of BNNeck to be used for test, before or after BNNeck, options: 'before' or 'after'
170 | _C.TEST.NECK_FEAT = 'after'
171 | # Whether the feature is normalized before test; if yes, it is equivalent to cosine distance
172 | _C.TEST.FEAT_NORM = 'yes'
173 | 
174 | _C.TEST.TEST_MODE = 'test'
175 | 
176 | _C.TEST.NEW_EVAL = False
177 | 
178 | # ---------------------------------------------------------------------------- #
179 | # Misc options
180 | # ---------------------------------------------------------------------------- #
181 | # Path to checkpoint and saved log of trained model
182 | _C.OUTPUT_DIR = ""
183 | 
--------------------------------------------------------------------------------
/configs/video_baseline.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   PRETRAIN_CHOICE: 'imagenet'
3 |   PRETRAIN_PATH: ''
4 |   NAME : 'resnet50'
5 |   METRIC_LOSS_TYPE: 'triplet'
6 |   IF_LABELSMOOTH: 'no'
7 |   IF_WITH_CENTER: 'no'
8 |   SETTING : 'video'
9 |   TEMP : 'avg'
10 | 
11 | INPUT:
12 |   SIZE_TRAIN: [256, 128]
13 |   SIZE_TEST: [256, 128]
14 |   PROB: 0.5 # random horizontal flip
15 |   RE_PROB: 0.5 # random erasing
16 |   PADDING: 10
17 |   # video reid
18 |   IF_FLIP : False
19 |   IF_CROP : False
20 |   IF_RE : False
21 |   MIN_SEQ_LEN : 0
22 |   SEQ_LEN : 6
23 |   SAMPLE : 'RRS'
24 | 
25 | DATASETS:
26 |   NAMES: ('mars',)
27 | 
28 | DATALOADER:
29 |   SAMPLER: 'softmax_triplet'
30 |   NUM_INSTANCE: 4
31 |   NUM_WORKERS: 8
32 | 
33 | SOLVER:
34 |   OPTIMIZER_NAME: 'Adam'
35 |   MAX_EPOCHS: 220
36 |   BASE_LR: 0.0001
37 | 
38 |   MARGIN : 0.3
39 |   SOFT_MARGIN : True
40 | 
41 |   CENTER_LR: 0.5
42 |   CENTER_LOSS_WEIGHT: 0.0005
43 | 
44 |   BIAS_LR_FACTOR: 1
45 |   WEIGHT_DECAY: 5e-5
46 |   WEIGHT_DECAY_BIAS: 5e-5
47 |   IMS_PER_BATCH: 32
48 | 
49 |   STEPS: [50, 100, 150, 200]
50 |   GAMMA: 0.1
51 | 
52 |   WARMUP_FACTOR: 0.01
53 |   WARMUP_ITERS: 10
54 |   WARMUP_METHOD: 'linear'
55 | 
56 |   EVAL_PERIOD: 10
57 | 
58 | TEST:
59 |   IMS_PER_BATCH: 32
60 |   RE_RANKING: 'no'
61 |   WEIGHT: "path"
62 |   NECK_FEAT: 'after'
63 |   FEAT_NORM: 'yes'
64 |   TEST_MODE: 'test'
65 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | from .build import make_data_loader
8 | 
--------------------------------------------------------------------------------
/data/build.py:
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | 5 | from .collate_batch import train_collate_fn, val_collate_fn 6 | from .datasets import init_dataset, ImageDataset,VideoDataset 7 | from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid 8 | from .transforms import build_transforms_ST 9 | 10 | 11 | def make_data_loader(cfg): 12 | ##### build transform ##### 13 | train_spatial_transforms , _ = build_transforms_ST(cfg, is_train=True) 14 | val_spatial_transforms, val_temporal_transforms = build_transforms_ST(cfg, is_train=False) 15 | num_workers = cfg.DATALOADER.NUM_WORKERS 16 | 17 | ##### init dataset-specific object ##### 18 | if cfg.MODEL.SETTING == 'video': 19 | dataset = init_dataset(cfg.DATASETS.NAMES[0], root=cfg.DATASETS.ROOT_DIR,min_seq_len=cfg.INPUT.MIN_SEQ_LEN,new_eval=cfg.TEST.NEW_EVAL) 20 | else: 21 | raise NotImplementedError() 22 | 23 | num_classes = dataset.num_train_pids 24 | ##### create real pytorch Dataset ##### 25 | if cfg.MODEL.SETTING == 'video': 26 | train_set = VideoDataset(dataset.train,cfg.INPUT.SEQ_LEN, cfg.INPUT.SAMPLE, train_spatial_transforms, None, mode='train') 27 | else: 28 | raise NotImplementedError() 29 | 30 | ##### create dataloader ##### 31 | if cfg.DATALOADER.SAMPLER == 'softmax': 32 | train_loader = DataLoader( 33 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 34 | collate_fn=train_collate_fn 35 | ) 36 | else: 37 | train_loader = DataLoader( 38 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 39 | worker_init_fn= lambda _:np.random.seed(), 40 | sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE), 41 | num_workers=num_workers, collate_fn=train_collate_fn,drop_last=True 42 | ) 43 | if cfg.MODEL.SETTING == 'video': 44 | val_set = VideoDataset(dataset.query + dataset.gallery, cfg.INPUT.SEQ_LEN,cfg.INPUT.SAMPLE, val_spatial_transforms,val_temporal_transforms, mode=cfg.TEST.TEST_MODE) 45 | else: 46 | raise NotImplementedError() 47 | 48 | val_loader = DataLoader( 49 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 50 | collate_fn=val_collate_fn 51 | ) 52 | return train_loader, val_loader, len(dataset.query), num_classes 53 | -------------------------------------------------------------------------------- /data/collate_batch.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import torch 3 | 4 | 5 | def train_collate_fn(batch): 6 | if len(batch[0]) == 4: 7 | imgs, pids, _,masks = zip(*batch) 8 | pids = torch.tensor(pids, dtype=torch.int64) 9 | return torch.stack(imgs, dim=0), pids ,torch.stack(masks,dim=0) 10 | imgs, pids, _, = zip(*batch) 11 | pids = torch.tensor(pids, dtype=torch.int64) 12 | return torch.stack(imgs, dim=0), pids 13 | 14 | 15 | def val_collate_fn(batch): 16 | if len(batch[0]) == 4: 17 | imgs, pids, camids ,masks = zip(*batch) 18 | return torch.stack(imgs, dim=0), pids , camids, torch.stack(masks,dim=0) 19 | elif len(batch[0]) == 5 : 20 | imgs, pids, ambi, camids ,masks = zip(*batch) 21 | return torch.stack(imgs, dim=0), pids , ambi, camids, torch.stack(masks,dim=0) 22 | imgs, pids, camids = zip(*batch) 23 | return torch.stack(imgs, dim=0), pids, camids 24 | -------------------------------------------------------------------------------- /data/datasets/DukeV.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import glob 3 | import re 4 | import json 5 | import pickle 6 | import os 7 | import os.path as osp 8 | from scipy.io import loadmat 9 | from .bases import BaseVideoDataset 10 | import pandas as pd 11 | import numpy as np 12 | 13 | 14 | class DukeV(BaseVideoDataset): 15 | dataset_dir = 'DukeMTMC-VideoReID' 16 | 17 | def __init__(self, root='/home/mediax/Dataset', verbose=True, min_seq_len =0,info_dir='./DukeV_info',new_eval=False): 18 | super(DukeV, self).__init__() 19 | self.dataset_dir = osp.join(root, self.dataset_dir) 20 | 21 | self.train_dir = osp.join(self.dataset_dir,'train') 22 | self.gallery_dir = osp.join(self.dataset_dir,'gallery') 23 | self.query_dir = osp.join(self.dataset_dir,'query') 24 | self.min_seq_len = min_seq_len 25 | #for self-created duke info 26 | if 'DL' in self.dataset_dir: 27 | info_dir = './DukeV_DL_info' 28 | self.train_pkl = osp.join(info_dir,'train.pkl') 29 | self.gallery_pkl = osp.join(info_dir,'gallery.pkl') 30 | self.query_pkl = osp.join(info_dir,'query.pkl') 31 | self.info_dir = info_dir 32 | self._check_before_run() 33 | 34 | if 'DL' in self.dataset_dir: 35 | train_mask_csv = pd.read_csv(osp.join(self.dataset_dir,'duke_mask_info.csv'),sep=',',header=None).values 36 | query_mask_csv = pd.read_csv(osp.join(self.dataset_dir,'duke_mask_info_query.csv'),sep=',',header=None).values 37 | gallery_mask_csv = pd.read_csv(osp.join(self.dataset_dir,'duke_mask_info_gallery.csv'),sep=',',header=None).values 38 | else: 39 | train_mask_csv,query_mask_csv, gallery_mask_csv = None,None,None 40 | 41 | train = self._process_dir(self.train_dir,self.train_pkl,relabel=True,mask_info=train_mask_csv) 42 | gallery = self._process_dir(self.gallery_dir,self.gallery_pkl,relabel=False,mask_info=gallery_mask_csv) 43 | query = self._process_dir(self.query_dir,self.query_pkl,relabel=False,mask_info=query_mask_csv) 44 | if verbose: 45 | print("=> DukeV loaded") 46 | self.print_dataset_statistics(train, query, gallery) 47 | self.train = train # list of tuple--(paths,id,cams) 48 | self.query = query 49 | self.gallery = gallery 50 | 51 | self.num_train_pids, self.num_train_tracklets, self.num_train_cams = self.get_videodata_info(self.train) 52 | self.num_query_pids, self.num_query_tracklets, self.num_query_cams = self.get_videodata_info(self.query) 53 | self.num_gallery_pids, self.num_gallery_tracklets, self.num_gallery_cams = self.get_videodata_info(self.gallery) 54 | 55 | def _process_dir(self,dir_path,pkl_path,relabel,mask_info=None): 56 | 57 | if osp.exists(pkl_path): 58 | print('==> %s exisit. 
Load...'%(pkl_path)) 59 | with open(pkl_path,'rb') as f: 60 | pkl_file = pickle.load(f) 61 | 62 | if mask_info is None: 63 | return pkl_file 64 | 65 | tracklets = [] 66 | start = 0 67 | for info in pkl_file: 68 | end = start + len(info[0]) 69 | tracklets.append((info[0],info[1],info[2],mask_info[start:end,1:].astype('int16')//16)) 70 | start = end 71 | return tracklets 72 | 73 | pdirs = sorted(glob.glob(osp.join(dir_path, '*'))) 74 | print("Processing {} with {} person identities".format(dir_path, len(pdirs))) 75 | pids = sorted(list(set([int(osp.basename(pdir)) for pdir in pdirs]))) 76 | pid2label = {pid : label for label,pid in enumerate(pids)} 77 | 78 | tracklets = [] 79 | for pdir in pdirs: 80 | pid = int(osp.basename(pdir)) 81 | if relabel : pid = pid2label[pid] 82 | track_dirs = sorted(glob.glob(osp.join(pdir,'*'))) 83 | for track_dir in track_dirs: 84 | img_paths = sorted(glob.glob(osp.join(track_dir,'*.jpg'))) 85 | num_imgs = len(img_paths) 86 | if num_imgs < self.min_seq_len : 87 | continue 88 | img_name = osp.basename(img_paths[0]) 89 | if img_name.find('_') == -1 : 90 | camid = int(img_name[5])-1 91 | else: 92 | camid = int(img_name[6])-1 93 | img_paths = tuple(img_paths) 94 | tracklets.append((img_paths,pid,camid)) 95 | # save to pickle 96 | if not osp.isdir(self.info_dir): 97 | os.mkdir(self.info_dir) 98 | with open(pkl_path,'wb') as f: 99 | pickle.dump(tracklets,f) 100 | return tracklets 101 | 102 | def _check_before_run(self): 103 | """Check if all files are available before going deeper""" 104 | if not osp.exists(self.dataset_dir): 105 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 106 | if not osp.exists(self.train_dir): 107 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 108 | if not osp.exists(self.gallery_dir): 109 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 110 | if not osp.exists(self.query_dir): 111 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 112 | -------------------------------------------------------------------------------- /data/datasets/MARS.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import glob 3 | import re 4 | 5 | import os.path as osp 6 | from scipy.io import loadmat 7 | from .bases import BaseVideoDataset 8 | import pandas as pd 9 | 10 | 11 | class MARS(BaseVideoDataset): 12 | # dataset_dir = 'MARS' 13 | dataset_dir = 'MARS-DL' 14 | info_dir = 'info' 15 | 16 | def __init__(self, root='/home/mediax/Dataset', verbose=True, min_seq_len =0,new_eval=False): 17 | super(MARS, self).__init__() 18 | self.dataset_dir = osp.join(root, self.dataset_dir) 19 | self.info_dir = osp.join(self.dataset_dir,self.info_dir) 20 | self.train_name_path = osp.join(self.info_dir,'train_name.txt') 21 | self.test_name_path = osp.join(self.info_dir,'test_name.txt') 22 | self.track_train_info_path = osp.join(self.info_dir,'tracks_train_info.mat') 23 | self.track_test_info_path = osp.join(self.info_dir,'tracks_test_info.mat') 24 | self.query_IDX_path = osp.join(self.info_dir,'query_IDX.mat') 25 | self.new_eval = new_eval 26 | if self.new_eval: 27 | self.track_test_info_path = osp.join(self.info_dir,'clean_tracks_test_info.mat') 28 | 29 | if 'DL' in self.dataset_dir: 30 | train_mask_csv = pd.read_csv(osp.join(self.info_dir,'mask_info.csv'),sep=',',header=None).values 31 | test_mask_csv = pd.read_csv(osp.join(self.info_dir,'mask_info_test.csv'),sep=',',header=None).values 32 | else: 33 | train_mask_csv,test_mask_csv = None,None 
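        # NOTE: each mask CSV row corresponds to one frame; column 0 is skipped
        # downstream, and the remaining columns are DL box coordinates in pixels,
        # which _process_data divides by 16 (presumably to land on the backbone's
        # 16x-downsampled feature grid): mask_info[start:end, 1:].astype('int16')//16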
34 | self._check_before_run() 35 | # prepare meta data 36 | train_names = self._get_names(self.train_name_path) 37 | test_names = self._get_names(self.test_name_path) 38 | track_train = loadmat(self.track_train_info_path)['track_train_info'] #(8298,4) 39 | track_test = loadmat(self.track_test_info_path)['track_test_info'] #(12180,4) 40 | 41 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze()-1 #(1980,) start from 0 42 | track_query = track_test[query_IDX,:] 43 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 44 | track_gallery = track_test[gallery_IDX,:] 45 | # track_gallery = track_test 46 | 47 | train = self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True,min_seq_len=min_seq_len,mask_info=train_mask_csv) 48 | query = self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False,mask_info=test_mask_csv,new_eval=self.new_eval) 49 | gallery = self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False,mask_info=test_mask_csv,new_eval=self.new_eval) 50 | 51 | if verbose: 52 | print("=> MARS loaded") 53 | self.print_dataset_statistics(train, query, gallery) 54 | 55 | self.train = train # list of tuple--(paths,id,cams) 56 | self.query = query 57 | self.gallery = gallery 58 | 59 | self.num_train_pids, self.num_train_tracklets, self.num_train_cams = self.get_videodata_info(self.train) 60 | self.num_query_pids, self.num_query_tracklets, self.num_query_cams = self.get_videodata_info(self.query) 61 | self.num_gallery_pids, self.num_gallery_tracklets, self.num_gallery_cams = self.get_videodata_info(self.gallery) 62 | 63 | def _check_before_run(self): 64 | """Check if all files are available before going deeper""" 65 | if not osp.exists(self.dataset_dir): 66 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 67 | if not osp.exists(self.train_name_path): 68 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 69 | if not osp.exists(self.test_name_path): 70 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 71 | if not osp.exists(self.track_train_info_path): 72 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 73 | if not osp.exists(self.track_test_info_path): 74 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 75 | if not osp.exists(self.query_IDX_path): 76 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 77 | 78 | def _get_names(self, fpath): 79 | names = [] 80 | with open(fpath, 'r') as f: 81 | for line in f: 82 | new_line = line.rstrip() 83 | names.append(new_line) 84 | return names 85 | 86 | def _process_data(self,names, meta_data, home_dir=None,relabel=False,min_seq_len=0,mask_info=None,new_eval=False): 87 | assert home_dir in ['bbox_train','bbox_test'] 88 | 89 | n_tracklets = len(meta_data) 90 | pid_list = list(set(meta_data[:,2].tolist())) 91 | num_pids = len(pid_list) 92 | 93 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 94 | 95 | tracklets = [] 96 | num_imgs_per_tracklet = [] 97 | for tracklet_idx in range(n_tracklets): 98 | data = meta_data[tracklet_idx,...] 
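            # Each row of tracks_*_info.mat is [start_idx, end_idx, pid, cam]
            # (indices are 1-based into the name list); clean_tracks_test_info.mat,
            # used by the new evaluation protocol, appends [new_pid, new_ambi].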
99 |             if new_eval:
100 |                 start_idx, end_idx, pid, cam, new_pid, new_ambi = data
101 |             else:
102 |                 start_idx, end_idx, pid, cam = data
103 |             if pid == -1 or pid == 0: continue  # junk index
104 |             assert 1 <= cam <= 6
105 | 
106 |             if relabel: pid = pid2label[pid]
107 |             cam -= 1
108 |             img_names = names[start_idx-1:end_idx]
109 | 
110 |             # make sure image names correspond to the same person
111 |             pnames = [img_name[:4] for img_name in img_names]
112 |             assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images"
113 |             camnames = [img_name[5] for img_name in img_names]
114 |             assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
115 | 
116 |             # append image names with directory information
117 |             img_paths = [osp.join(self.dataset_dir, home_dir, img_name[:4], img_name) for img_name in img_names]
118 |             if len(img_paths) >= min_seq_len:
119 |                 img_paths = tuple(img_paths)
120 |                 if mask_info is not None:
121 |                     masks = mask_info[start_idx-1:end_idx, 1:].astype('int16')//16
122 |                     if new_eval:
123 |                         tracklets.append((img_paths, pid, new_pid, new_ambi, cam, masks))
124 |                     else:
125 |                         tracklets.append((img_paths, pid, cam, masks))
126 |                 else:
127 |                     if new_eval:
128 |                         tracklets.append((img_paths, pid, new_pid, new_ambi, cam))
129 |                     else:
130 |                         tracklets.append((img_paths, pid, cam))
131 |             # num_imgs_per_tracklet.append(len(img_paths))
132 |         # n_tracklets = len(tracklets)
133 | 
134 |         return tracklets
135 | 
--------------------------------------------------------------------------------
/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | from .MARS import MARS
3 | from .DukeV import DukeV
4 | from .dataset_loader import ImageDataset, VideoDataset
5 | 
6 | __factory = {
7 |     'mars' : MARS,
8 |     'dukev': DukeV
9 | }
10 | 
11 | 
12 | def get_names():
13 |     return __factory.keys()
14 | 
15 | 
16 | def init_dataset(name, *args, **kwargs):
17 |     if name not in __factory.keys():
18 |         raise KeyError("Unknown datasets: {}".format(name))
19 |     return __factory[name](*args, **kwargs)
20 | 
--------------------------------------------------------------------------------
/data/datasets/bases.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy as np
3 | 
4 | 
5 | class BaseDataset(object):
6 |     """
7 |     Base class of reid dataset
8 |     """
9 | 
10 |     def get_imagedata_info(self, data):
11 |         pids, cams = [], []
12 |         for _, pid, camid in data:
13 |             pids += [pid]
14 |             cams += [camid]
15 |         pids = set(pids)
16 |         cams = set(cams)
17 |         num_pids = len(pids)
18 |         num_cams = len(cams)
19 |         num_imgs = len(data)
20 |         return num_pids, num_imgs, num_cams
21 | 
22 |     def get_videodata_info(self, data, return_tracklet_stats=False):
23 |         pids, cams, tracklet_stats = [], [], []
24 |         is_mask = (len(data[0]) == 4) or (len(data[0]) == 6)  # tuples of length 4/6 carry DL masks
25 |         if len(data[0]) == 3:
26 |             for img_paths, pid, camid in data:
27 |                 pids += [pid]
28 |                 cams += [camid]
29 |                 tracklet_stats += [len(img_paths)]
30 |         elif len(data[0]) == 4:
31 |             for img_paths, pid, camid, _ in data:
32 |                 pids += [pid]
33 |                 cams += [camid]
34 |                 tracklet_stats += [len(img_paths)]
35 |         elif len(data[0]) == 5:
36 |             for img_paths, pid, new_pid, new_ambi, camid in data:
37 |                 pids += [new_pid]
38 |                 cams += [camid]
39 |                 tracklet_stats += [len(img_paths)]
40 |         elif len(data[0]) == 6:
41 |             for img_paths, pid, new_pid, new_ambi, camid, _ in data:
42 |                 pids += [new_pid]
43 |                 cams += [camid]
44 |                 tracklet_stats += [len(img_paths)]
45 | 
46 |         pids = set(pids)
47 |         cams = set(cams)
48 |         num_pids = len(pids)
49 |         num_cams = len(cams)
50 |         num_tracklets = len(data)
51 |         if return_tracklet_stats:
52 |             return num_pids, num_tracklets, num_cams, tracklet_stats
53 |         return num_pids, num_tracklets, num_cams
54 | 
55 |     def print_dataset_statistics(self):
56 |         raise NotImplementedError
57 | 
58 | 
59 | class BaseImageDataset(BaseDataset):
60 |     """
61 |     Base class of image reid dataset
62 |     """
63 | 
64 |     def
print_dataset_statistics(self, train, query, gallery): 65 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 66 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 67 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 68 | 69 | print("Dataset statistics:") 70 | print(" ----------------------------------------") 71 | print(" subset | # ids | # images | # cameras") 72 | print(" ----------------------------------------") 73 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 74 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 75 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 76 | print(" ----------------------------------------") 77 | 78 | 79 | class BaseVideoDataset(BaseDataset): 80 | """ 81 | Base class of video reid dataset 82 | """ 83 | 84 | def print_dataset_statistics(self, train, query, gallery): 85 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 86 | self.get_videodata_info(train, return_tracklet_stats=True) 87 | 88 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 89 | self.get_videodata_info(query, return_tracklet_stats=True) 90 | 91 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 92 | self.get_videodata_info(gallery, return_tracklet_stats=True) 93 | 94 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 95 | min_num = np.min(tracklet_stats) 96 | max_num = np.max(tracklet_stats) 97 | avg_num = np.mean(tracklet_stats) 98 | 99 | print("Dataset statistics:") 100 | print(" -------------------------------------------") 101 | print(" subset | # ids | # tracklets | # cameras") 102 | print(" -------------------------------------------") 103 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 104 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 105 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 106 | print(" -------------------------------------------") 107 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 108 | print(" -------------------------------------------") 109 | -------------------------------------------------------------------------------- /data/datasets/dataset_loader.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import os.path as osp 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | import torch 7 | import random 8 | 9 | 10 | def read_image(img_path): 11 | """Keep reading image until succeed. 12 | This can avoid IOError incurred by heavy IO process.""" 13 | got_img = False 14 | if not osp.exists(img_path): 15 | raise IOError("{} does not exist".format(img_path)) 16 | while not got_img: 17 | try: 18 | img = Image.open(img_path).convert('RGB') 19 | got_img = True 20 | except IOError: 21 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 22 | pass 23 | return img 24 | 25 | 26 | class ImageDataset(Dataset): 27 | """Image Person ReID Dataset""" 28 | 29 | def __init__(self, dataset, transform=None): 30 | self.dataset = dataset 31 | self.transform = transform 32 | 33 | def __len__(self): 34 | return len(self.dataset) 35 | 36 | def __getitem__(self, index): 37 | img_path, pid, camid = self.dataset[index] 38 | img = read_image(img_path) 39 | 40 | if self.transform is not None: 41 | img = self.transform(img) 42 | 43 | return img, pid, camid, img_path 44 | 45 | class VideoDataset(Dataset): 46 | """Video Person ReID Dataset""" 47 | 48 | def __init__(self,dataset,seq_len=8,sample='RRS',spatial_transform=None, temporal_transform=None,mode='test'): 49 | self.dataset = dataset 50 | self.mask = len(dataset[0]) == 4 or len(dataset[0])==6 51 | self.new_eval = len(dataset[0]) == 5 or len(dataset[0])== 6 52 | self.seq_len = seq_len 53 | self.sample = sample 54 | self.spatial_transform = spatial_transform 55 | self.temporal_transform = temporal_transform 56 | self.mode = mode 57 | def __len__(self): 58 | return len(self.dataset) 59 | 60 | def __getitem__(self,idx): 61 | if self.mask and not self.new_eval: 62 | img_paths, pid, cam ,mask = self.dataset[idx] 63 | elif self.mask and self.new_eval: 64 | img_paths, _, pid, ambi, cam ,mask = self.dataset[idx] 65 | elif not self.mask and self.new_eval: 66 | raise NotImplementedError 67 | else: 68 | img_paths, pid, cam = self.dataset[idx] 69 | 70 | num = len(img_paths) 71 | indices = np.arange(0,num).astype(np.int32) 72 | 73 | # Temporal Sample Methods # 74 | if self.sample == 'RRS' and self.mode != 'test_0': 75 | 76 | num_pads = 0 if num%self.seq_len==0 else self.seq_len - num%self.seq_len 77 | indices = np.concatenate([indices,np.ones(num_pads).astype(np.int32)*(num-1)]) 78 | assert len(indices) %self.seq_len == 0 79 | 80 | indices_pool = np.split(indices,self.seq_len) 81 | sampled_indices = [] 82 | 83 | if self.mode == 'train': 84 | for part in indices_pool: 85 | sampled_indices.append(np.random.choice(part,1)[0]) 86 | elif self.mode == 'test_all_sampled': 87 | sampled_indices = np.vstack(indices_pool).T.flatten() 88 | elif self.mode == 'test_all_continuous': 89 | sampled_indices = np.vstack(indices_pool).flatten() 90 | else : 91 | for part in indices_pool: 92 | sampled_indices.append(part[0]) 93 | 94 | elif self.mode == 'test_0': 95 | sampled_indices = self.temporal_transform(indices) 96 | ################################ 97 | 98 | imgs = [] 99 | for index in sampled_indices: 100 | img_path = img_paths[index] 101 | img = read_image(img_path) 102 | if self.spatial_transform is not None: 103 | img = self.spatial_transform(img) 104 | imgs.append(img) 105 | imgs = torch.stack(imgs,dim=0) 106 | 107 | if self.mode == 'train': 108 | flip_prob = random.random() 109 | if flip_prob > 0.5: 110 | imgs = torch.flip(imgs,dims=[3]) 111 | 112 | if self.mask: 113 | sampled_mask = mask[sampled_indices,:] 114 | if self.mode == 'train' and flip_prob > 0.5: 115 | new_start = 128//16 - sampled_mask[:,3] 116 | new_end = 128//16 - sampled_mask[:,2] 117 | sampled_mask[:,2] = new_start 118 | sampled_mask[:,3] = new_end 119 | 120 | if self.new_eval: 121 | return imgs,pid,ambi,cam,torch.tensor(sampled_mask,dtype=torch.int16) 122 | else: 123 | return imgs,pid,cam,torch.tensor(sampled_mask,dtype=torch.int16) 124 | else: 125 | if self.new_eval: 126 | raise NotImplementedError 127 | else: 128 | return imgs,pid,cam 129 | 130 | 131 | 132 | 133 | 
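# Usage sketch (illustrative only, mirroring the call in data/build.py): wrap a
# list of (img_paths, pid, camid) tracklets and RRS-sample seq_len frames each:
#
#   train_set = VideoDataset(dataset.train, seq_len=6, sample='RRS',
#                            spatial_transform=train_tf, mode='train')
#   imgs, pid, camid = train_set[0]   # imgs: (seq_len, C, H, W) tensor
#
# `train_tf` stands in for the spatial transform built in data/transforms.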
-------------------------------------------------------------------------------- /data/datasets/eval_reid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50,q_ambis=None,g_ambis=None): 6 | num_q, num_g = distmat.shape 7 | if num_g < max_rank: 8 | max_rank = num_g 9 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 10 | indices = np.argsort(distmat, axis=1) 11 | matches = (g_pids[indices] == q_pids[:, np.newaxis])#.astype(np.int32) 12 | 13 | new_eval = (q_ambis is not None) and (g_ambis is not None) 14 | if new_eval: 15 | matches_am2id = (g_ambis[indices] == q_pids[:, np.newaxis]) 16 | matches_id2am = (g_pids[indices] == q_ambis[:, np.newaxis]) 17 | matches_am2am = (g_ambis[indices] == q_ambis[:, np.newaxis]) 18 | matches = matches | matches_am2am | matches_am2id | matches_id2am 19 | matches = matches.astype(np.int32) 20 | # compute cmc curve for each query 21 | all_cmc = [] 22 | all_AP = [] 23 | num_valid_q = 0. # number of valid query 24 | for q_idx in range(num_q): 25 | # get query pid and camid 26 | q_pid = q_pids[q_idx] 27 | q_camid = q_camids[q_idx] 28 | 29 | # remove gallery samples that have the same pid and camid with query 30 | order = indices[q_idx] 31 | remove = (g_camids[order] == q_camid) 32 | if not new_eval: 33 | remove = remove & (g_pids[order] == q_pid) 34 | else: 35 | q_amb = q_ambis[q_idx] 36 | remove_dis = remove & (g_pids[order] == 0) # distractor with same cam 37 | remove_id2id = remove & (g_pids[order] == q_pid) 38 | remove_am2id = remove & (g_ambis[order] == q_pid) 39 | remove_am2am = remove & (g_ambis[order] == q_amb) 40 | remove_id2am = remove & (g_pids[order] == q_amb) 41 | remove = remove_dis | remove_id2id | remove_am2id | remove_am2am | remove_id2am 42 | 43 | # remove = remove | (g_pids[order] == -1) 44 | keep = np.invert(remove) 45 | 46 | # compute cmc curve 47 | # binary vector, positions with value 1 are correct matches 48 | orig_cmc = matches[q_idx][keep] 49 | if not np.any(orig_cmc): 50 | # this condition is true when query identity does not appear in gallery 51 | continue 52 | cmc = orig_cmc.cumsum() 53 | cmc[cmc > 1] = 1 54 | all_cmc.append(cmc[:max_rank]) 55 | num_valid_q += 1. 56 | 57 | # compute average precision 58 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 59 | num_rel = orig_cmc.sum() 60 | tmp_cmc = orig_cmc.cumsum() 61 | tmp_cmc = [x / (i + 1.) 
-------------------------------------------------------------------------------- /data/samplers/__init__.py: --------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid  # added by gu
8 | 
-------------------------------------------------------------------------------- /data/samplers/__pycache__/__init__.cpython-35.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/samplers/__pycache__/__init__.cpython-35.pyc
-------------------------------------------------------------------------------- /data/samplers/__pycache__/__init__.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/samplers/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /data/samplers/__pycache__/triplet_sampler.cpython-35.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/samplers/__pycache__/triplet_sampler.cpython-35.pyc
-------------------------------------------------------------------------------- /data/samplers/__pycache__/triplet_sampler.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/samplers/__pycache__/triplet_sampler.cpython-36.pyc
-------------------------------------------------------------------------------- /data/samplers/triplet_sampler.py: --------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 | 
7 | import copy
8 | import random
9 | import torch
10 | from collections import defaultdict
11 | 
12 | import numpy as np
13 | from torch.utils.data.sampler import Sampler
14 | 
15 | 
16 | class RandomIdentitySampler(Sampler):
17 |     """
18 |     Randomly sample P identities, then for each identity,
19 |     randomly sample K instances, so the batch size is P*K.
20 |     Args:
21 |     - data_source (list): list of (img_path, pid, camid).
22 |     - batch_size (int): number of examples in a batch.
23 |     - num_instances (int): number of instances per identity in a batch.
24 | """ 25 | 26 | def __init__(self, data_source, batch_size, num_instances): 27 | self.data_source = data_source 28 | self.batch_size = batch_size 29 | self.num_instances = num_instances 30 | self.num_pids_per_batch = self.batch_size // self.num_instances 31 | self.index_dic = defaultdict(list) 32 | # create a pid --> [img idx] mapping 33 | for index, (_, pid, _) in enumerate(self.data_source): 34 | self.index_dic[pid].append(index) 35 | self.pids = list(self.index_dic.keys()) 36 | 37 | # estimate number of examples in an epoch 38 | self.length = 0 39 | for pid in self.pids: 40 | idxs = self.index_dic[pid] 41 | num = len(idxs) 42 | if num < self.num_instances: 43 | num = self.num_instances 44 | self.length += num - num % self.num_instances 45 | # for market : from 12936 to 11876 46 | 47 | def __iter__(self): 48 | batch_idxs_dict = defaultdict(list) 49 | 50 | for pid in self.pids: 51 | idxs = copy.deepcopy(self.index_dic[pid]) 52 | if len(idxs) < self.num_instances: 53 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 54 | random.shuffle(idxs) 55 | batch_idxs = [] 56 | for idx in idxs: 57 | batch_idxs.append(idx) 58 | if len(batch_idxs) == self.num_instances: 59 | batch_idxs_dict[pid].append(batch_idxs) 60 | batch_idxs = [] 61 | 62 | avai_pids = copy.deepcopy(self.pids) 63 | final_idxs = [] 64 | 65 | while len(avai_pids) >= self.num_pids_per_batch: 66 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 67 | for pid in selected_pids: 68 | batch_idxs = batch_idxs_dict[pid].pop(0) 69 | final_idxs.extend(batch_idxs) 70 | if len(batch_idxs_dict[pid]) == 0: 71 | avai_pids.remove(pid) 72 | self.length = len(final_idxs) 73 | return iter(final_idxs) 74 | 75 | def __len__(self): 76 | return self.length 77 | 78 | 79 | # New add by gu 80 | class RandomIdentitySampler_alignedreid(Sampler): 81 | """ 82 | Randomly sample N identities, then for each identity, 83 | randomly sample K instances, therefore batch size is N*K. 84 | 85 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 86 | 87 | Args: 88 | data_source (Dataset): dataset to sample from. 89 | num_instances (int): number of instances per identity. 
90 | """ 91 | def __init__(self, data_source, num_instances): 92 | self.data_source = data_source 93 | self.num_instances = num_instances 94 | self.index_dic = defaultdict(list) 95 | for index, data in enumerate(data_source): 96 | pid = data[1] 97 | self.index_dic[pid].append(index) 98 | self.pids = list(self.index_dic.keys()) 99 | self.num_identities = len(self.pids) 100 | 101 | def __iter__(self): 102 | indices = torch.randperm(self.num_identities) 103 | ret = [] 104 | for i in indices: 105 | pid = self.pids[i] 106 | t = self.index_dic[pid] 107 | replace = False if len(t) >= self.num_instances else True 108 | t = np.random.choice(t, size=self.num_instances, replace=replace) 109 | ret.extend(t) 110 | return iter(ret) 111 | 112 | def __len__(self): 113 | return self.num_identities * self.num_instances 114 | -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from .build import build_transforms_ST 3 | -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/transforms/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/transforms/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/transforms/__pycache__/build.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/transforms/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/spatial_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/transforms/__pycache__/spatial_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/temporal_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/transforms/__pycache__/temporal_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transforms.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/transforms/__pycache__/transforms.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/data/transforms/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import torchvision.transforms as T 3 | 4 | from .transforms import RandomErasing 5 | from .temporal_transforms import TemporalBeginCrop 6 | 7 | 8 | def build_transforms_ST(cfg,is_train=True): 9 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 10 | if is_train: 11 | transform_list = [T.Resize(cfg.INPUT.SIZE_TRAIN)] 12 | if cfg.INPUT.IF_FLIP == True: 13 | transform_list.append(T.RandomHorizontalFlip(p=cfg.INPUT.PROB)) 14 | if cfg.INPUT.IF_CROP == True: 15 | transform_list.append(T.Pad(cfg.INPUT.PADDING)) 16 | transform_list.append(T.RandomCrop(cfg.INPUT.SIZE_TRAIN)) 17 | transform_list += [T.ToTensor(),normalize_transform] 18 | if cfg.INPUT.IF_RE == True: 19 | transform_list.append(RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)) 20 | spatial_transform = T.Compose(transform_list) 21 | temporal_transforms = None 22 | else: 23 | spatial_transform = T.Compose([ 24 | T.Resize(cfg.INPUT.SIZE_TEST), 25 | T.ToTensor(), 26 | normalize_transform 27 | ]) 28 | temporal_transforms = TemporalBeginCrop(size=cfg.INPUT.SEQ_LEN) 29 | return spatial_transform,temporal_transforms 30 | -------------------------------------------------------------------------------- /data/transforms/temporal_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import random 4 | import math 5 | import numpy as np 6 | 7 | 8 | class LoopPadding(object): 9 | 10 | def __init__(self, size): 11 | self.size = size 12 | 13 | def __call__(self, frame_indices): 14 | out = list(frame_indices) 15 | 16 | while len(out) < self.size: 17 | for index in out: 18 | if len(out) >= self.size: 19 | break 20 | out.append(index) 21 | 22 | return out 23 | 24 | 25 | class TemporalCenterCrop(object): 26 | """Temporally crop the given frame indices at a center. 27 | 28 | If the number of frames is less than the size, 29 | loop the indices as many times as necessary to satisfy the size. 30 | 31 | Args: 32 | size (int): Desired output size of the crop. 33 | """ 34 | 35 | def __init__(self, size, padding=True, pad_method='loop'): 36 | self.size = size 37 | self.padding = padding 38 | self.pad_method = pad_method 39 | 40 | def __call__(self, frame_indices): 41 | """ 42 | Args: 43 | frame_indices (list): frame indices to be cropped. 44 | Returns: 45 | list: Cropped frame indices. 
46 | """ 47 | 48 | center_index = len(frame_indices) // 2 49 | begin_index = max(0, center_index - (self.size // 2)) 50 | end_index = min(begin_index + self.size, len(frame_indices)) 51 | 52 | out = list(frame_indices[begin_index:end_index]) 53 | 54 | if self.padding == True: 55 | if self.pad_method == 'loop': 56 | while len(out) < self.size: 57 | for index in out: 58 | if len(out) >= self.size: 59 | break 60 | out.append(index) 61 | else: 62 | while len(out) < self.size: 63 | for index in out: 64 | if len(out) >= self.size: 65 | break 66 | out.append(index) 67 | out.sort() 68 | 69 | return out 70 | 71 | 72 | class TemporalRandomCrop(object): 73 | """Temporally crop the given frame indices at a random location. 74 | 75 | If the number of frames is less than the size, 76 | loop the indices as many times as necessary to satisfy the size. 77 | 78 | Args: 79 | size (int): Desired output size of the crop. 80 | """ 81 | 82 | def __init__(self, size=4, stride=8): 83 | self.size = size 84 | self.stride = stride 85 | 86 | def __call__(self, frame_indices): 87 | """ 88 | Args: 89 | frame_indices (list): frame indices to be cropped. 90 | Returns: 91 | list: Cropped frame indices. 92 | """ 93 | frame_indices = list(frame_indices) 94 | 95 | if len(frame_indices) >= self.size * self.stride: 96 | rand_end = len(frame_indices) - (self.size - 1) * self.stride - 1 97 | begin_index = random.randint(0, rand_end) 98 | end_index = begin_index + (self.size - 1) * self.stride + 1 99 | out = frame_indices[begin_index:end_index:self.stride] 100 | elif len(frame_indices) >= self.size: 101 | index = np.random.choice(len(frame_indices), size=self.size, replace=False) 102 | index.sort() 103 | out = [frame_indices[index[i]] for i in range(self.size)] 104 | else: 105 | index = np.random.choice(len(frame_indices), size=self.size, replace=True) 106 | index.sort() 107 | out = [frame_indices[index[i]] for i in range(self.size)] 108 | 109 | return out 110 | 111 | 112 | class TemporalBeginCrop(object): 113 | """Temporally crop the given frame indices at a beginning. 114 | 115 | If the number of frames is less than the size, 116 | loop the indices as many times as necessary to satisfy the size. 117 | 118 | Args: 119 | size (int): Desired output size of the crop. 
120 | """ 121 | def __init__(self, size=4): 122 | self.size = size 123 | 124 | def __call__(self, frame_indices): 125 | frame_indices = list(frame_indices) 126 | size = self.size 127 | 128 | if len(frame_indices) >= (size - 1) * 8 + 1: 129 | out = frame_indices[0: (size - 1) * 8 + 1: 8] 130 | elif len(frame_indices) >= (size - 1) * 4 + 1: 131 | out = frame_indices[0: (size - 1) * 4 + 1: 4] 132 | elif len(frame_indices) >= (size - 1) * 2 + 1: 133 | out = frame_indices[0: (size - 1) * 2 + 1: 2] 134 | elif len(frame_indices) >= size: 135 | out = frame_indices[0:size:1] 136 | else: 137 | out = frame_indices[0:size] 138 | while len(out) < size: 139 | for index in out: 140 | if len(out) >= size: 141 | break 142 | out.append(index) 143 | 144 | return out 145 | ''' 146 | def __init__(self, size=4): 147 | self.size = size 148 | 149 | def __call__(self, frame_indices): 150 | frame_indices = list(frame_indices) 151 | 152 | if len(frame_indices) >= 25: 153 | out = frame_indices[0:25:8] 154 | elif len(frame_indices) >= 13: 155 | out = frame_indices[0:13:4] 156 | elif len(frame_indices) >= 7: 157 | out = frame_indices[0:7:2] 158 | elif len(frame_indices) >= 4: 159 | out = frame_indices[0:4:1] 160 | else: 161 | out = frame_indices[0:4] 162 | while len(out) < 4: 163 | for index in out: 164 | if len(out) >= 4: 165 | break 166 | out.append(index) 167 | 168 | return out 169 | ''' 170 | # class TemporalBeginCrop(object): 171 | # """Temporally crop the given frame indices at a beginning. 172 | 173 | # If the number of frames is less than the size, 174 | # loop the indices as many times as necessary to satisfy the size. 175 | 176 | # Args: 177 | # size (int): Desired output size of the crop. 178 | # """ 179 | 180 | # def __init__(self, size=4): 181 | # self.size = size 182 | 183 | # def __call__(self, frame_indices): 184 | # frame_indices = list(frame_indices) 185 | 186 | # if len(frame_indices) >= 4: 187 | # out = frame_indices[0:4:1] 188 | # else: 189 | # out = frame_indices[0:4] 190 | # while len(out) < 4: 191 | # for index in out: 192 | # if len(out) >= 4: 193 | # break 194 | # out.append(index) 195 | 196 | # return out -------------------------------------------------------------------------------- /data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import math 8 | import random 9 | 10 | 11 | class RandomErasing(object): 12 | """ Randomly selects a rectangle region in an image and erases its pixels. 13 | 'Random Erasing Data Augmentation' by Zhong et al. 14 | See https://arxiv.org/pdf/1708.04896.pdf 15 | Args: 16 | probability: The probability that the Random Erasing operation will be performed. 17 | sl: Minimum proportion of erased area against input image. 18 | sh: Maximum proportion of erased area against input image. 19 | r1: Minimum aspect ratio of erased area. 20 | mean: Erasing value. 
21 | """ 22 | 23 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 24 | self.probability = probability 25 | self.mean = mean 26 | self.sl = sl 27 | self.sh = sh 28 | self.r1 = r1 29 | 30 | def __call__(self, img): 31 | 32 | if random.uniform(0, 1) >= self.probability: 33 | return img 34 | 35 | for attempt in range(100): 36 | area = img.size()[1] * img.size()[2] 37 | 38 | target_area = random.uniform(self.sl, self.sh) * area 39 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 40 | 41 | h = int(round(math.sqrt(target_area * aspect_ratio))) 42 | w = int(round(math.sqrt(target_area / aspect_ratio))) 43 | 44 | if w < img.size()[2] and h < img.size()[1]: 45 | x1 = random.randint(0, img.size()[1] - h) 46 | y1 = random.randint(0, img.size()[2] - w) 47 | if img.size()[0] == 3: 48 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 49 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 50 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 51 | else: 52 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 53 | return img 54 | 55 | return img 56 | -------------------------------------------------------------------------------- /engine/__pycache__/data_parallel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/engine/__pycache__/data_parallel.cpython-36.pyc -------------------------------------------------------------------------------- /engine/__pycache__/inference.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/engine/__pycache__/inference.cpython-36.pyc -------------------------------------------------------------------------------- /engine/__pycache__/scatter_gather.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/engine/__pycache__/scatter_gather.cpython-36.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/engine/__pycache__/trainer.cpython-35.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/engine/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /engine/__pycache__/vis.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/engine/__pycache__/vis.cpython-36.pyc -------------------------------------------------------------------------------- /engine/data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules import Module 3 | from torch.nn.parallel.scatter_gather import gather 4 | from torch.nn.parallel.replicate import replicate 5 | from torch.nn.parallel.parallel_apply 
import parallel_apply 6 | 7 | from .scatter_gather import scatter_kwargs 8 | 9 | 10 | class _DataParallel(Module): 11 | r"""Implements data parallelism at the module level. 12 | 13 | This container parallelizes the application of the given module by 14 | splitting the input across the specified devices by chunking in the batch 15 | dimension. In the forward pass, the module is replicated on each device, 16 | and each replica handles a portion of the input. During the backwards 17 | pass, gradients from each replica are summed into the original module. 18 | 19 | The batch size should be larger than the number of GPUs used. It should 20 | also be an integer multiple of the number of GPUs so that each chunk is the 21 | same size (so that each GPU processes the same number of samples). 22 | 23 | See also: :ref:`cuda-nn-dataparallel-instead` 24 | 25 | Arbitrary positional and keyword inputs are allowed to be passed into 26 | DataParallel EXCEPT Tensors. All variables will be scattered on dim 27 | specified (default 0). Primitive types will be broadcasted, but all 28 | other types will be a shallow copy and can be corrupted if written to in 29 | the model's forward pass. 30 | 31 | Args: 32 | module: module to be parallelized 33 | device_ids: CUDA devices (default: all devices) 34 | output_device: device location of output (default: device_ids[0]) 35 | 36 | Example:: 37 | 38 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) 39 | >>> output = net(input_var) 40 | """ 41 | 42 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well 43 | 44 | def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None): 45 | super(_DataParallel, self).__init__() 46 | 47 | if not torch.cuda.is_available(): 48 | self.module = module 49 | self.device_ids = [] 50 | return 51 | 52 | if device_ids is None: 53 | device_ids = list(range(torch.cuda.device_count())) 54 | if output_device is None: 55 | output_device = device_ids[0] 56 | self.dim = dim 57 | self.module = module 58 | self.device_ids = device_ids 59 | self.chunk_sizes = chunk_sizes 60 | self.output_device = output_device 61 | if len(self.device_ids) == 1: 62 | self.module.cuda(device_ids[0]) 63 | 64 | def forward(self, *inputs, **kwargs): 65 | if not self.device_ids: 66 | return self.module(*inputs, **kwargs) 67 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_sizes) 68 | if len(self.device_ids) == 1: 69 | return self.module(*inputs[0], **kwargs[0]) 70 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 71 | outputs = self.parallel_apply(replicas, inputs, kwargs) 72 | return self.gather(outputs, self.output_device) 73 | 74 | def replicate(self, module, device_ids): 75 | return replicate(module, device_ids) 76 | 77 | def scatter(self, inputs, kwargs, device_ids, chunk_sizes): 78 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_sizes=self.chunk_sizes) 79 | 80 | def parallel_apply(self, replicas, inputs, kwargs): 81 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 82 | 83 | def gather(self, outputs, output_device): 84 | return gather(outputs, output_device, dim=self.dim) 85 | 86 | 87 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None): 88 | r"""Evaluates module(input) in parallel across the GPUs given in device_ids. 89 | 90 | This is the functional version of the DataParallel module. 
91 | 92 | Args: 93 | module: the module to evaluate in parallel 94 | inputs: inputs to the module 95 | device_ids: GPU ids on which to replicate module 96 | output_device: GPU location of the output Use -1 to indicate the CPU. 97 | (default: device_ids[0]) 98 | Returns: 99 | a Variable containing the result of module(input) located on 100 | output_device 101 | """ 102 | if not isinstance(inputs, tuple): 103 | inputs = (inputs,) 104 | 105 | if device_ids is None: 106 | device_ids = list(range(torch.cuda.device_count())) 107 | 108 | if output_device is None: 109 | output_device = device_ids[0] 110 | 111 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) 112 | if len(device_ids) == 1: 113 | return module(*inputs[0], **module_kwargs[0]) 114 | used_device_ids = device_ids[:len(inputs)] 115 | replicas = replicate(module, used_device_ids) 116 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) 117 | return gather(outputs, output_device, dim) 118 | 119 | def DataParallel(module, device_ids=None, output_device=None, dim=0, chunk_sizes=None): 120 | if chunk_sizes is None: 121 | return torch.nn.DataParallel(module, device_ids, output_device, dim) 122 | # standard_size = True 123 | # for i in range(1, len(chunk_sizes)): 124 | # if chunk_sizes[i] != chunk_sizes[0]: 125 | # standard_size = False 126 | # if standard_size: 127 | # return torch.nn.DataParallel(module, device_ids, output_device, dim) 128 | return _DataParallel(module, device_ids, output_device, dim, chunk_sizes) -------------------------------------------------------------------------------- /engine/inference.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from ignite.engine import Engine 7 | 8 | from utils.reid_metric import R1_mAP, R1_mAP_reranking 9 | from ignite.contrib.handlers.tqdm_logger import ProgressBar 10 | 11 | 12 | def create_supervised_evaluator(model, metrics, 13 | device=None): 14 | if device: 15 | if torch.cuda.device_count() > 1: 16 | model = nn.DataParallel(model) 17 | model.to(device) 18 | 19 | def _inference(engine, batch): 20 | model.eval() 21 | with torch.no_grad(): 22 | data, pids, camids = batch 23 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 24 | feat = model(data) 25 | return feat, pids, camids 26 | 27 | engine = Engine(_inference) 28 | 29 | for name, metric in metrics.items(): 30 | metric.attach(engine, name) 31 | 32 | return engine 33 | 34 | def create_supervised_evaluator_with_mask(model, metrics, 35 | device=None): 36 | if device: 37 | # if torch.cuda.device_count() > 1: 38 | # model = nn.DataParallel(model) 39 | model.to(device) 40 | 41 | def _inference(engine, batch): 42 | model.eval() 43 | with torch.no_grad(): 44 | data, pids, camids ,masks = batch 45 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 46 | feat = model(data,masks) 47 | return feat, pids, camids 48 | 49 | engine = Engine(_inference) 50 | 51 | for name, metric in metrics.items(): 52 | metric.attach(engine, name) 53 | 54 | return engine 55 | 56 | def create_supervised_evaluator_with_mask_new_eval(model, metrics, 57 | device=None): 58 | 59 | if device: 60 | # if torch.cuda.device_count() > 1: 61 | # model = nn.DataParallel(model) 62 | model.to(device) 63 | 64 | def _inference(engine, batch): 65 | model.eval() 66 | with torch.no_grad(): 67 | data, pids, ambi, camids ,masks = batch 68 | data = data.to(device) if 
torch.cuda.device_count() >= 1 else data 69 | feat = model(data,masks) 70 | return feat, pids, ambi, camids 71 | 72 | engine = Engine(_inference) 73 | 74 | for name, metric in metrics.items(): 75 | metric.attach(engine, name) 76 | 77 | return engine 78 | 79 | def create_supervised_all_evaluator(model, metrics,seq_len, 80 | device=None): 81 | if device: 82 | if torch.cuda.device_count() > 1: 83 | model = nn.DataParallel(model) 84 | model.to(device) 85 | 86 | def _inference(engine, batch): 87 | model.eval() 88 | feats = [] 89 | with torch.no_grad(): 90 | data, pids, camids = batch 91 | iteration = data.shape[1]//seq_len 92 | for i in range(iteration): 93 | x = data[:,i*seq_len:(i+1)*seq_len,...] 94 | x = x.to(device) if torch.cuda.device_count() >= 1 else x 95 | feat = model(x) 96 | feats.append(feat) 97 | feats = torch.mean(torch.cat(feats,dim=0),dim=0,keepdim=True) 98 | return feats, pids, camids 99 | 100 | engine = Engine(_inference) 101 | 102 | for name, metric in metrics.items(): 103 | metric.attach(engine, name) 104 | 105 | return engine 106 | 107 | 108 | def create_supervised_all_evaluator_with_mask(model, metrics,seq_len, 109 | device=None): 110 | if device: 111 | if torch.cuda.device_count() > 1: 112 | model = nn.DataParallel(model) 113 | model.to(device) 114 | 115 | def _inference(engine, batch): 116 | model.eval() 117 | feats = [] 118 | with torch.no_grad(): 119 | data, pids, camids, masks = batch 120 | iteration = data.shape[1]//seq_len 121 | for i in range(iteration): 122 | x = data[:,i*seq_len:(i+1)*seq_len,...] 123 | mask = masks[:,i*seq_len:(i+1)*seq_len,...] 124 | x = x.to(device) if torch.cuda.device_count() >= 1 else x 125 | feat = model(x,mask) 126 | feats.append(feat) 127 | feats = torch.mean(torch.cat(feats,dim=0),dim=0,keepdim=True) 128 | return feats, pids, camids 129 | 130 | engine = Engine(_inference) 131 | 132 | for name, metric in metrics.items(): 133 | metric.attach(engine, name) 134 | 135 | return engine 136 | 137 | def create_supervised_all_evaluator_with_mask_new_eval(model, metrics,seq_len, 138 | device=None): 139 | if device: 140 | if torch.cuda.device_count() > 1: 141 | model = nn.DataParallel(model) 142 | model.to(device) 143 | 144 | def _inference(engine, batch): 145 | model.eval() 146 | feats = [] 147 | with torch.no_grad(): 148 | data, pids, ambi, camids, masks = batch 149 | iteration = data.shape[1]//seq_len 150 | for i in range(iteration): 151 | x = data[:,i*seq_len:(i+1)*seq_len,...] 152 | mask = masks[:,i*seq_len:(i+1)*seq_len,...] 
153 | x = x.to(device) if torch.cuda.device_count() >= 1 else x 154 | feat = model(x,mask) 155 | feats.append(feat) 156 | feats = torch.mean(torch.cat(feats,dim=0),dim=0,keepdim=True) 157 | return feats, pids, ambi, camids 158 | 159 | engine = Engine(_inference) 160 | 161 | for name, metric in metrics.items(): 162 | metric.attach(engine, name) 163 | 164 | return engine 165 | 166 | 167 | def inference( 168 | cfg, 169 | model, 170 | val_loader, 171 | num_query 172 | ): 173 | device = cfg.MODEL.DEVICE 174 | 175 | logger = logging.getLogger("reid_baseline.inference") 176 | logger.info("Enter inferencing") 177 | if cfg.TEST.RE_RANKING == 'no': 178 | print("Create evaluator") 179 | if 'test_all' in cfg.TEST.TEST_MODE: 180 | if len(val_loader.dataset.dataset[0]) == 4: # mask no new eval 181 | evaluator = create_supervised_all_evaluator_with_mask(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 182 | seq_len=cfg.INPUT.SEQ_LEN,device=device) 183 | elif len(val_loader.dataset.dataset[0]) == 6: # mask , new eval 184 | evaluator = create_supervised_all_evaluator_with_mask_new_eval(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM,new_eval=True)}, 185 | seq_len=cfg.INPUT.SEQ_LEN,device=device) 186 | else: 187 | evaluator = create_supervised_all_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 188 | seq_len=cfg.INPUT.SEQ_LEN,device=device) 189 | else: 190 | if len(val_loader.dataset.dataset[0]) == 6: # mask , new eval 191 | evaluator = create_supervised_evaluator_with_mask_new_eval(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM,new_eval=True)}, 192 | device=device) 193 | elif len(val_loader.dataset.dataset[0]) == 4 : # mask, no new eval 194 | evaluator = create_supervised_evaluator_with_mask(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 195 | device=device) 196 | else: 197 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 198 | device=device) 199 | elif cfg.TEST.RE_RANKING == 'yes': # haven't implement with mask 200 | print("Create evaluator for reranking") 201 | if 'test_all' in cfg.TEST.TEST_MODE: 202 | evaluator = create_supervised_all_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 203 | seq_len=cfg.INPUT.SEQ_LEN,device=device) 204 | else: 205 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 206 | device=device) 207 | else: 208 | print("Unsupported re_ranking config. 
Only support for no or yes, but got {}.".format(cfg.TEST.RE_RANKING)) 209 | 210 | pbar = ProgressBar(persist=True,ncols=120) 211 | pbar.attach(evaluator) 212 | 213 | evaluator.run(val_loader) 214 | cmc, mAP = evaluator.state.metrics['r1_mAP'] 215 | logger.info('Validation Results') 216 | logger.info("mAP: {:.1%}".format(mAP)) 217 | for r in [1, 5, 10]: 218 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 219 | -------------------------------------------------------------------------------- /engine/scatter_gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.nn.parallel._functions import Scatter 4 | 5 | 6 | def scatter(inputs, target_gpus, dim=0, chunk_sizes=None): 7 | r""" 8 | Slices variables into approximately equal chunks and 9 | distributes them across given GPUs. Duplicates 10 | references to objects that are not variables. Does not 11 | support Tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, Variable): 15 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 16 | assert not torch.is_tensor(obj), "Tensors not supported in scatter." 17 | if isinstance(obj, tuple): 18 | return list(zip(*map(scatter_map, obj))) 19 | if isinstance(obj, list): 20 | return list(map(list, zip(*map(scatter_map, obj)))) 21 | if isinstance(obj, dict): 22 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 23 | return [obj for targets in target_gpus] 24 | 25 | return scatter_map(inputs) 26 | 27 | 28 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_sizes=None): 29 | r"""Scatter with support for kwargs dictionary""" 30 | inputs = scatter(inputs, target_gpus, dim, chunk_sizes) if inputs else [] 31 | kwargs = scatter(kwargs, target_gpus, dim, chunk_sizes) if kwargs else [] 32 | if len(inputs) < len(kwargs): 33 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 34 | elif len(kwargs) < len(inputs): 35 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 36 | inputs = tuple(inputs) 37 | kwargs = tuple(kwargs) 38 | return inputs, kwargs 39 | -------------------------------------------------------------------------------- /engine/trainer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import DataParallel 7 | # from engine.data_parallel import DataParallel 8 | # #self create dataparallel for unbalance GPU memory size 9 | from ignite.engine import Engine, Events 10 | from ignite.handlers import ModelCheckpoint, Timer,global_step_from_engine 11 | from ignite.metrics import RunningAverage 12 | from ignite.contrib.handlers.tqdm_logger import ProgressBar 13 | from utils.reid_metric import R1_mAP 14 | 15 | 16 | def create_supervised_trainer(model, optimizer, loss_fn, 17 | device=None): 18 | if device: 19 | if torch.cuda.device_count() > 1: 20 | model = DataParallel(model) 21 | model.to(device) 22 | def _update(engine, batch): 23 | model.train() 24 | optimizer.zero_grad() 25 | img, target = batch 26 | img = img.to(device) if torch.cuda.device_count() >= 1 else img 27 | target = target.to(device) if torch.cuda.device_count() >= 1 else target 28 | score, feat = model(img) 29 | loss,loss_dict = loss_fn(score, feat, target) 30 | loss.backward() 31 | optimizer.step() 32 | # compute acc 33 | acc = (score.max(1)[1] == target).float().mean() 34 | loss_dict['loss'] = loss.item() 35 | return 
acc.item(),loss_dict 36 | 37 | return Engine(_update) 38 | 39 | def create_supervised_trainer_with_mask(model, optimizer, loss_fn, 40 | device=None): 41 | if device: 42 | if torch.cuda.device_count() > 1: 43 | model = DataParallel(model) 44 | model.to(device) 45 | def _update(engine, batch): 46 | model.train() 47 | optimizer.zero_grad() 48 | img, target ,masks = batch 49 | img = img.to(device) if torch.cuda.device_count() >= 1 else img 50 | target = target.to(device) if torch.cuda.device_count() >= 1 else target 51 | score, feat = model(img,masks) 52 | loss,loss_dict = loss_fn(score, feat, target) 53 | loss.backward() 54 | optimizer.step() 55 | # compute acc 56 | acc = (score.max(1)[1] == target).float().mean() 57 | loss_dict['loss'] = loss.item() 58 | return acc.item(),loss_dict 59 | 60 | return Engine(_update) 61 | 62 | def create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cetner_loss_weight, 63 | device=None): 64 | if device: 65 | if torch.cuda.device_count() > 1: 66 | model = nn.DataParallel(model) 67 | model.to(device) 68 | 69 | def _update(engine, batch): 70 | model.train() 71 | optimizer.zero_grad() 72 | optimizer_center.zero_grad() 73 | img, target = batch 74 | img = img.to(device) if torch.cuda.device_count() >= 1 else img 75 | target = target.to(device) if torch.cuda.device_count() >= 1 else target 76 | score, feat = model(img) 77 | loss,loss_dict = loss_fn(score, feat, target) 78 | # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target))) 79 | loss.backward() 80 | optimizer.step() 81 | for param in center_criterion.parameters(): 82 | param.grad.data *= (1. / cetner_loss_weight) 83 | optimizer_center.step() 84 | 85 | # compute acc 86 | acc = (score.max(1)[1] == target).float().mean() 87 | loss_dict['loss'] = loss.item() 88 | return acc.item(),loss_dict 89 | 90 | return Engine(_update) 91 | 92 | # + 93 | def create_supervised_evaluator(model, metrics, 94 | device=None): 95 | if device: 96 | # if torch.cuda.device_count() > 1: 97 | # model = nn.DataParallel(model) 98 | model.to(device) 99 | 100 | def _inference(engine, batch): 101 | model.eval() 102 | with torch.no_grad(): 103 | data, pids, camids = batch 104 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 105 | feat = model(data) 106 | return feat, pids, camids 107 | 108 | engine = Engine(_inference) 109 | 110 | for name, metric in metrics.items(): 111 | metric.attach(engine, name) 112 | 113 | return engine 114 | 115 | def create_supervised_evaluator_with_mask(model, metrics, 116 | device=None): 117 | if device: 118 | # if torch.cuda.device_count() > 1: 119 | # model = nn.DataParallel(model) 120 | model.to(device) 121 | 122 | def _inference(engine, batch): 123 | model.eval() 124 | with torch.no_grad(): 125 | data, pids, camids ,masks = batch 126 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 127 | feat = model(data,masks) 128 | return feat, pids, camids 129 | 130 | engine = Engine(_inference) 131 | 132 | for name, metric in metrics.items(): 133 | metric.attach(engine, name) 134 | 135 | return engine 136 | 137 | def create_supervised_evaluator_with_mask_new_eval(model, metrics, 138 | device=None): 139 | if device: 140 | # if torch.cuda.device_count() > 1: 141 | # model = nn.DataParallel(model) 142 | model.to(device) 143 | 144 | def _inference(engine, batch): 145 | model.eval() 146 | with torch.no_grad(): 147 | data, pids, ambi, camids ,masks = batch 148 | data = data.to(device) if torch.cuda.device_count() 
>= 1 else data 149 | feat = model(data,masks) 150 | return feat, pids, ambi, camids 151 | 152 | engine = Engine(_inference) 153 | 154 | for name, metric in metrics.items(): 155 | metric.attach(engine, name) 156 | 157 | return engine 158 | 159 | # - 160 | 161 | def do_train( 162 | cfg, 163 | model, 164 | train_loader, 165 | val_loader, 166 | optimizer, 167 | scheduler, 168 | loss_fn, 169 | num_query, 170 | start_epoch 171 | ): 172 | # checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 173 | eval_period = cfg.SOLVER.EVAL_PERIOD 174 | output_dir = cfg.OUTPUT_DIR 175 | device = cfg.MODEL.DEVICE 176 | epochs = cfg.SOLVER.MAX_EPOCHS 177 | 178 | logger = logging.getLogger("reid_baseline.train") 179 | logger.info("Start training") 180 | # Create 1. trainer 2. evaluator 3. checkpointer 4. timer 5. pbar 181 | if len(train_loader.dataset.dataset[0]) == 4 : #train with mask 182 | trainer = create_supervised_trainer_with_mask(model, optimizer, loss_fn, device=device) 183 | if cfg.TEST.NEW_EVAL == False: 184 | evaluator = create_supervised_evaluator_with_mask(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device) 185 | else: 186 | evaluator = create_supervised_evaluator_with_mask_new_eval(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM,new_eval=True)}, device=device) 187 | else: # no mask 188 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) 189 | if cfg.TEST.NEW_EVAL == False: 190 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device) 191 | else: 192 | raise NotImplementedError 193 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=1, require_empty=False,\ 194 | score_function=lambda x : x.state.metrics['r1_mAP'][1],\ 195 | global_step_transform=global_step_from_engine(trainer)) 196 | timer = Timer(average=True) 197 | tpbar = ProgressBar(persist=True,ncols=120) 198 | epbar = ProgressBar(persist=True,ncols=120) 199 | ############################################################# 200 | evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpointer, \ 201 | {'model': model,'optimizer': optimizer}) 202 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 203 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 204 | tpbar.attach(trainer) 205 | epbar.attach(evaluator) 206 | 207 | # average metric to attach on trainer 208 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_acc') 209 | RunningAverage(output_transform=lambda x: x[1]['loss']).attach(trainer, 'avg_loss') 210 | RunningAverage(output_transform=lambda x: x[1]['triplet']).attach(trainer, 'avg_trip') 211 | 212 | 213 | @trainer.on(Events.STARTED) 214 | def start_training(engine): 215 | engine.state.epoch = start_epoch 216 | 217 | @trainer.on(Events.EPOCH_COMPLETED) 218 | def adjust_learning_rate(engine): 219 | # if engine.state.epoch == 1: 220 | # scheduler.step() 221 | scheduler.step() 222 | 223 | 224 | # adding handlers using `trainer.on` decorator API 225 | @trainer.on(Events.EPOCH_COMPLETED) 226 | def print_times(engine): 227 | logger.info('Epoch {} done. 
Total Loss : {:.3f}, Triplet Loss : {:.3f}, Acc : {:.3f}, Base Lr : {:.2e}'
228 |                     .format(engine.state.epoch, engine.state.metrics['avg_loss'], engine.state.metrics['avg_trip'],
229 |                             engine.state.metrics['avg_acc'], scheduler.get_last_lr()[0]))
230 |         timer.reset()
231 | 
232 |     @trainer.on(Events.EPOCH_COMPLETED)
233 |     def log_validation_results(engine):
234 |         if engine.state.epoch % eval_period == 0:
235 |             # evaluator.state.epoch = trainer.state.epoch
236 |             evaluator.run(val_loader)
237 |             cmc, mAP = evaluator.state.metrics['r1_mAP']
238 |             logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
239 |             logger.info("mAP: {:.1%}".format(mAP))
240 |             for r in [1, 5, 10]:
241 |                 logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
242 | 
243 |     trainer.run(train_loader, max_epochs=epochs)
244 | 
245 | 
246 | def do_train_with_center(
247 |         cfg,
248 |         model,
249 |         center_criterion,
250 |         train_loader,
251 |         val_loader,
252 |         optimizer,
253 |         optimizer_center,
254 |         scheduler,
255 |         loss_fn,
256 |         num_query,
257 |         start_epoch
258 | ):
259 |     checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD  # used by the checkpoint handler below
260 |     eval_period = cfg.SOLVER.EVAL_PERIOD
261 |     output_dir = cfg.OUTPUT_DIR
262 |     device = cfg.MODEL.DEVICE
263 |     epochs = cfg.SOLVER.MAX_EPOCHS
264 | 
265 |     logger = logging.getLogger("reid_baseline.train")
266 |     logger.info("Start training")
267 |     trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
268 |     evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
269 |     checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, None, n_saved=10, require_empty=False)
270 |     timer = Timer(average=True)
271 |     pbar = ProgressBar(persist=True, ncols=120)
272 |     trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_period), checkpointer, {'model': model,
273 |                                                                                               'optimizer': optimizer,
274 |                                                                                               'center_param': center_criterion,
275 |                                                                                               'optimizer_center': optimizer_center})
276 | 
277 |     timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
278 |                  pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
279 |     pbar.attach(trainer)
280 | 
281 |     # average metrics to attach on trainer
282 |     RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_acc')
283 |     RunningAverage(output_transform=lambda x: x[1]['loss']).attach(trainer, 'avg_loss')
284 |     RunningAverage(output_transform=lambda x: x[1]['triplet']).attach(trainer, 'avg_trip')
285 |     RunningAverage(output_transform=lambda x: x[1]['center']).attach(trainer, 'avg_center')
286 | 
287 |     @trainer.on(Events.STARTED)
288 |     def start_training(engine):
289 |         engine.state.epoch = start_epoch
290 | 
291 |     @trainer.on(Events.EPOCH_COMPLETED)
292 |     def adjust_learning_rate(engine):
293 |         scheduler.step()
294 |     # adding handlers using `trainer.on` decorator API
295 |     @trainer.on(Events.EPOCH_COMPLETED)
296 |     def print_times(engine):
297 |         logger.info('Epoch {} done. Total Loss : {:.3f}, Triplet Loss : {:.3f}, Center Loss : {:.3f}, Acc : {:.3f}, Base Lr : {:.2e}'
298 |                     .format(engine.state.epoch, engine.state.metrics['avg_loss'], engine.state.metrics['avg_trip'],
299 |                             engine.state.metrics['avg_center'], engine.state.metrics['avg_acc'], scheduler.get_last_lr()[0]))
300 |         timer.reset()
301 | 
302 |     @trainer.on(Events.EPOCH_COMPLETED)
303 |     def log_validation_results(engine):
304 |         if engine.state.epoch % eval_period == 0:
305 |             evaluator.run(val_loader)
306 |             cmc, mAP = evaluator.state.metrics['r1_mAP']
307 |             logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
308 |             logger.info("mAP: {:.1%}".format(mAP))
309 |             for r in [1, 5, 10]:
310 |                 logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
311 | 
312 |     trainer.run(train_loader, max_epochs=epochs)
313 | 
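The dual-optimizer pattern in create_supervised_trainer_with_center above rescales the center parameters' gradients by 1/CENTER_LOSS_WEIGHT before optimizer_center.step(), so the class centers are updated as if the center loss had not been down-weighted in the total objective. A minimal standalone sketch of just that step (all names and values here are illustrative, not from the repo):

import torch
centers = torch.nn.Parameter(torch.randn(10, 4))   # stand-in for CenterLoss.centers
opt_center = torch.optim.SGD([centers], lr=0.5)
weight = 0.0005                                    # plays the role of cfg.SOLVER.CENTER_LOSS_WEIGHT
feat = torch.randn(8, 4)
labels = torch.randint(0, 10, (8,))
center_loss = ((feat - centers[labels]) ** 2).sum(dim=1).mean()
(weight * center_loss).backward()                  # weighted term inside the total loss
for param in [centers]:
    param.grad.data *= (1. / weight)               # undo the weight for the center update
opt_center.step()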
-------------------------------------------------------------------------------- /imgs/DL.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/imgs/DL.png
-------------------------------------------------------------------------------- /imgs/DL_2.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/imgs/DL_2.png
-------------------------------------------------------------------------------- /layers/__init__.py: --------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import torch.nn.functional as F
3 | 
4 | from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth
5 | from .center_loss import CenterLoss
6 | 
7 | 
8 | def make_loss(cfg, num_classes):  # modified by gu
9 |     sampler = cfg.DATALOADER.SAMPLER
10 |     # create the triplet loss
11 |     if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
12 |         if cfg.SOLVER.SOFT_MARGIN: margin = None
13 |         else: margin = cfg.SOLVER.MARGIN
14 |         triplet = TripletLoss(margin)  # triplet loss
15 |     else:
16 |         print('expected METRIC_LOSS_TYPE should be triplet, '
17 |               'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
18 |     # whether to use label smoothing
19 |     if cfg.MODEL.IF_LABELSMOOTH == 'on':
20 |         xent = CrossEntropyLabelSmooth(num_classes=num_classes)  # added by luo
21 |         print("label smooth on, numclasses:", num_classes)
22 | 
23 |     # return loss_func
24 |     loss_dict = {'triplet': 0, 'id_loss': 0, 'center': 0}  # for logging
25 |     if sampler == 'softmax':
26 |         def loss_func(score, feat, target):
27 |             id_loss = F.cross_entropy(score, target)
28 |             loss_dict['id_loss'] = id_loss.item()
29 |             return id_loss, loss_dict
30 |     elif cfg.DATALOADER.SAMPLER == 'triplet':
31 |         def loss_func(score, feat, target):
32 |             metric = triplet(feat, target)[0]
33 |             loss_dict['triplet'] = metric.item()
34 |             return metric, loss_dict
35 |     elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
36 |         def loss_func(score, feat, target):
37 |             if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
38 |                 metric = triplet(feat, target)[0]
39 |                 if cfg.MODEL.IF_LABELSMOOTH == 'on':
40 |                     id_loss = xent(score, target)
41 |                 else:
42 |                     id_loss = F.cross_entropy(score, target)
43 |                 loss_dict['triplet'] = metric.item()
44 |                 loss_dict['id_loss'] = id_loss.item()
45 |                 return metric + id_loss, loss_dict
46 |             else:
47 |                 print('expected METRIC_LOSS_TYPE should be triplet, '
48 |                       'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
49 |     else:
50 |         print('expected sampler should be softmax, triplet or softmax_triplet, '
51 |               'but got {}'.format(cfg.DATALOADER.SAMPLER))
52 |     return loss_func
53 | 
54 | 
55 | def make_loss_with_center(cfg, num_classes):  # modified by gu
56 |     if cfg.MODEL.NAME == 'resnet18' or cfg.MODEL.NAME == 'resnet34':
57 |         feat_dim = 512
58 |     else:
59 |         feat_dim = 2048
60 | 
61 |     if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
62 |         center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
63 | 
64 |     elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
65 |         if cfg.SOLVER.SOFT_MARGIN: margin = None
66 |         else: margin = cfg.SOLVER.MARGIN
67 |         triplet = TripletLoss(margin)  # triplet loss
68 |         center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
69 |     else:
70 |         print('expected METRIC_LOSS_TYPE with center should be center or triplet_center, '
71 |               'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
72 | 
73 |     if cfg.MODEL.IF_LABELSMOOTH == 'on':
74 |         xent = CrossEntropyLabelSmooth(num_classes=num_classes)  # added by luo
75 |         print("label smooth on, numclasses:", num_classes)
76 | 
77 |     def loss_func(score, feat, target):
78 |         loss_dict = {'triplet': 0, 'id_loss': 0, 'center': 0}
79 |         if cfg.MODEL.METRIC_LOSS_TYPE == 'center':
80 |             center = center_criterion(feat, target)
81 |             if cfg.MODEL.IF_LABELSMOOTH == 'on':
82 |                 id_loss = xent(score, target)
83 |             else:
84 |                 id_loss = F.cross_entropy(score, target)
85 |             loss = cfg.SOLVER.CENTER_LOSS_WEIGHT * center + id_loss
86 |             loss_dict['id_loss'] = id_loss.item()
87 |             loss_dict['center'] = center.item()
88 |             return loss, loss_dict
89 | 
90 |         elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center':
91 |             metric = triplet(feat, target)[0]
92 |             center = center_criterion(feat, target)
93 |             if cfg.MODEL.IF_LABELSMOOTH == 'on':
94 |                 id_loss = xent(score, target)
95 |             else:
96 |                 id_loss = F.cross_entropy(score, target)
97 |             loss = cfg.SOLVER.CENTER_LOSS_WEIGHT * center + id_loss + metric
98 |             loss_dict['id_loss'] = id_loss.item()
99 |             loss_dict['center'] = center.item()
100 |             loss_dict['triplet'] = metric.item()
101 |             return loss, loss_dict
102 | 
103 |         else:
104 |             print('expected METRIC_LOSS_TYPE with center should be center or triplet_center, '
105 |                   'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
106 |     return loss_func, center_criterion
107 | 
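A standalone sanity check (not in the repo) of the expand/addmm pattern that CenterLoss.forward in layers/center_loss.py below uses to build its squared-distance matrix ||x||^2 + ||c||^2 - 2*x.c:

import torch
x = torch.randn(5, 4)                                # features
c = torch.randn(3, 4)                                # class centers
distmat = x.pow(2).sum(1, keepdim=True).expand(5, 3) \
        + c.pow(2).sum(1, keepdim=True).expand(3, 5).t()
distmat = distmat.addmm(x, c.t(), beta=1, alpha=-2)  # subtract 2 * x @ c.t()
assert torch.allclose(distmat, torch.cdist(x, c).pow(2), atol=1e-4)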
-------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/layers/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /layers/__pycache__/center_loss.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/layers/__pycache__/center_loss.cpython-36.pyc
-------------------------------------------------------------------------------- /layers/__pycache__/old_triplet_loss.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/layers/__pycache__/old_triplet_loss.cpython-36.pyc
-------------------------------------------------------------------------------- /layers/__pycache__/triplet_loss.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/layers/__pycache__/triplet_loss.cpython-36.pyc
-------------------------------------------------------------------------------- /layers/center_loss.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | from torch import nn
5 | 
6 | 
7 | class CenterLoss(nn.Module):
8 |     """Center loss.
9 | 
10 |     Reference:
11 |     Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
12 | 
13 |     Args:
14 |         num_classes (int): number of classes.
15 |         feat_dim (int): feature dimension.
16 |     """
17 | 
18 |     def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
19 |         super(CenterLoss, self).__init__()
20 |         self.num_classes = num_classes
21 |         self.feat_dim = feat_dim
22 |         self.use_gpu = use_gpu
23 | 
24 |         if self.use_gpu:
25 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
26 |         else:
27 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
28 | 
29 |     def forward(self, x, labels):
30 |         """
31 |         Args:
32 |             x: feature matrix with shape (batch_size, feat_dim).
33 |             labels: ground truth labels with shape (batch_size,).
34 |         """
35 |         assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
36 | 
37 |         batch_size = x.size(0)
38 |         distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
39 |                   torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
40 |         distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)  # distmat = distmat - 2 * x @ centers.T
41 | 
42 |         classes = torch.arange(self.num_classes).long()
43 |         if self.use_gpu: classes = classes.cuda()
44 |         labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
45 |         mask = labels.eq(classes.expand(batch_size, self.num_classes))
46 | 
47 |         dist = distmat * mask.float()
48 |         loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
49 |         #dist = []
50 |         #for i in range(batch_size):
51 |         #    value = distmat[i][mask[i]]
52 |         #    value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
53 |         #    dist.append(value)
54 |         #dist = torch.cat(dist)
55 |         #loss = dist.mean()
56 |         return loss
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     use_gpu = False
61 |     center_loss = CenterLoss(use_gpu=use_gpu)
62 |     features = torch.rand(16, 2048)
63 |     targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long()
64 |     if use_gpu:
65 |         features = torch.rand(16, 2048).cuda()
66 |         targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long().cuda()
67 | 
68 |     loss = center_loss(features, targets)
69 |     print(loss)
70 | 
-------------------------------------------------------------------------------- /layers/triplet_loss.py: --------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | import torch
7 | from torch import nn
8 | 
9 | 
10 | def normalize(x, axis=-1):
11 |     """Normalizing to unit length along the specified dimension.
12 |     Args:
13 |         x: pytorch Variable
14 |     Returns:
15 |         x: pytorch Variable, same shape as input
16 |     """
17 |     x = 1.
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 18 | return x 19 | 20 | 21 | def euclidean_dist(x, y): 22 | """ 23 | Args: 24 | x: pytorch Variable, with shape [m, d] 25 | y: pytorch Variable, with shape [n, d] 26 | Returns: 27 | dist: pytorch Variable, with shape [m, n] 28 | """ 29 | m, n = x.size(0), y.size(0) 30 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 31 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 32 | dist = xx + yy 33 | dist.addmm_(x,y.t(),beta=1, alpha=-2) 34 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 35 | return dist 36 | 37 | 38 | def hard_example_mining(dist_mat, labels, return_inds=False): 39 | """For each anchor, find the hardest positive and negative sample. 40 | Args: 41 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 42 | labels: pytorch LongTensor, with shape [N] 43 | return_inds: whether to return the indices. Save time if `False`(?) 44 | Returns: 45 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 46 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 47 | p_inds: pytorch LongTensor, with shape [N]; 48 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 49 | n_inds: pytorch LongTensor, with shape [N]; 50 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 51 | NOTE: Only consider the case in which all labels have same num of samples, 52 | thus we can cope with all anchors in parallel. 53 | """ 54 | 55 | assert len(dist_mat.size()) == 2 56 | assert dist_mat.size(0) == dist_mat.size(1) 57 | N = dist_mat.size(0) 58 | 59 | # shape [N, N] 60 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 61 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 62 | 63 | # `dist_ap` means distance(anchor, positive) 64 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 65 | dist_ap, relative_p_inds = torch.max( 66 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 67 | # `dist_an` means distance(anchor, negative) 68 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 69 | dist_an, relative_n_inds = torch.min( 70 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 71 | # shape [N] 72 | dist_ap = dist_ap.squeeze(1) 73 | dist_an = dist_an.squeeze(1) 74 | 75 | if return_inds: 76 | # shape [N, N] 77 | ind = (labels.new().resize_as_(labels) 78 | .copy_(torch.arange(0, N).long()) 79 | .unsqueeze(0).expand(N, N)) 80 | # shape [N, 1] 81 | p_inds = torch.gather( 82 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 83 | n_inds = torch.gather( 84 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 85 | # shape [N] 86 | p_inds = p_inds.squeeze(1) 87 | n_inds = n_inds.squeeze(1) 88 | return dist_ap, dist_an, p_inds, n_inds 89 | 90 | return dist_ap, dist_an 91 | 92 | 93 | class TripletLoss(object): 94 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 
93 | class TripletLoss(object): 94 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 95 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 96 | Loss for Person Re-Identification'.""" 97 | 98 | def __init__(self, margin=-0.1): 99 | self.margin = margin 100 | if margin is not None: 101 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 102 | else: 103 | self.ranking_loss = nn.SoftMarginLoss() 104 | 105 | def __call__(self, global_feat, labels, normalize_feature=False): 106 | if normalize_feature: 107 | global_feat = normalize(global_feat, axis=-1) 108 | dist_mat = euclidean_dist(global_feat, global_feat) 109 | dist_ap, dist_an = hard_example_mining( 110 | dist_mat, labels) 111 | y = dist_an.new().resize_as_(dist_an).fill_(1) 112 | if self.margin is not None: 113 | loss = self.ranking_loss(dist_an, dist_ap, y) 114 | else: 115 | loss = self.ranking_loss(dist_an - dist_ap, y) 116 | return loss, dist_ap, dist_an 117 | 118 | class CrossEntropyLabelSmooth(nn.Module): 119 | """Cross entropy loss with label smoothing regularizer. 120 | 121 | Reference: 122 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 123 | Equation: y = (1 - epsilon) * y + epsilon / K. 124 | 125 | Args: 126 | num_classes (int): number of classes. 127 | epsilon (float): weight. 128 | """ 129 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 130 | super(CrossEntropyLabelSmooth, self).__init__() 131 | self.num_classes = num_classes 132 | self.epsilon = epsilon 133 | self.use_gpu = use_gpu 134 | self.logsoftmax = nn.LogSoftmax(dim=1) 135 | 136 | def forward(self, inputs, targets): 137 | """ 138 | Args: 139 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 140 | targets: ground truth labels with shape (batch_size) 141 | """ 142 | log_probs = self.logsoftmax(inputs) 143 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 144 | if self.use_gpu: targets = targets.cuda() 145 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 146 | loss = (- targets * log_probs).mean(0).sum() 147 | return loss 148 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from .network import VNetwork 3 | 4 | 5 | def build_model(cfg, num_classes): 6 | if cfg.MODEL.SETTING == 'video': 7 | model = VNetwork(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, \ 8 | cfg.MODEL.PRETRAIN_CHOICE, cfg.MODEL.TEMP,cfg.MODEL.NON_LAYERS,cfg.INPUT.SEQ_LEN) 9 | return model 10 | else: 11 | raise NotImplementedError() -------------------------------------------------------------------------------- /modeling/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/baseline.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/__pycache__/baseline.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/network.cpython-36.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/__pycache__/network.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import copy 4 | import torchvision 5 | import torch.nn as nn 6 | from torch.nn import init 7 | from torch.autograd import Variable 8 | from torch.nn import functional as F 9 | 10 | from .SA import inflate 11 | from .SA import AP3D 12 | from .SA import NonLocal 13 | from .SA import SelfAttn 14 | 15 | 16 | def weights_init_kaiming(m): 17 | classname = m.__class__.__name__ 18 | if classname.find('Conv') != -1: 19 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 20 | init.constant_(m.bias.data, 0.0) 21 | elif classname.find('Linear') != -1: 22 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 23 | init.constant_(m.bias.data, 0.0) 24 | elif classname.find('BatchNorm') != -1: 25 | init.normal_(m.weight.data, 1.0, 0.02) 26 | init.constant_(m.bias.data, 0.0) 27 | 28 | 29 | def weights_init_classifier(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('Linear') != -1: 32 | init.normal_(m.weight.data, std=0.001) 33 | init.constant_(m.bias.data, 0.0) 34 | 35 | 36 | class Bottleneck3D(nn.Module): 37 | def __init__(self, bottleneck2d, block, inflate_time=False, temperature=4, contrastive_att=True): 38 | super(Bottleneck3D, self).__init__() 39 | 40 | self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1) 41 | self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1) 42 | if inflate_time == True: 43 | self.conv2 = block(bottleneck2d.conv2, temperature=temperature, contrastive_att=contrastive_att) 44 | else: 45 | self.conv2 = inflate.inflate_conv(bottleneck2d.conv2, time_dim=1) 46 | self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2) 47 | self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1) 48 | self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3) 49 | self.relu = nn.ReLU(inplace=True) 50 | 51 | if bottleneck2d.downsample is not None: 52 | self.downsample = self._inflate_downsample(bottleneck2d.downsample) 53 | else: 54 | self.downsample = None 55 | 56 | def _inflate_downsample(self, downsample2d, time_stride=1): 57 | downsample3d = nn.Sequential( 58 | inflate.inflate_conv(downsample2d[0], time_dim=1, 59 | time_stride=time_stride), 60 | inflate.inflate_batch_norm(downsample2d[1])) 61 | return downsample3d 62 | 63 | def forward(self, x): 64 | residual = x 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class ResNet503D(nn.Module): 86 | def __init__(self, block, c3d_idx, nl_idx, sa_idx, temperature=4, contrastive_att=True, seq_len=6,**kwargs): 87 | super(ResNet503D, self).__init__() 88 | 89 | self.block = block 90 | self.temperature = temperature 91 | self.contrastive_att = contrastive_att 92 | self.inplanes = 64 93 | self.seq_len = seq_len 94 | 95 | resnet2d = torchvision.models.resnet50(pretrained=True) 96 | resnet2d.layer4[0].conv2.stride=(1, 1) 97 | 
resnet2d.layer4[0].downsample[0].stride=(1, 1) 98 | 99 | ############ STEM ################### 100 | self.conv1 = inflate.inflate_conv(resnet2d.conv1, time_dim=1) 101 | self.bn1 = inflate.inflate_batch_norm(resnet2d.bn1) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.maxpool = inflate.inflate_pool(resnet2d.maxpool, time_dim=1) 104 | ##################################### 105 | 106 | self.layer1 = self._inflate_reslayer(resnet2d.layer1, c3d_idx=c3d_idx[0], \ 107 | nl_idx=nl_idx[0], sa_idx= sa_idx[0],in_channels=256,ks=[64,32,seq_len]) 108 | self.layer2 = self._inflate_reslayer(resnet2d.layer2, c3d_idx=c3d_idx[1], \ 109 | nl_idx=nl_idx[1], sa_idx=sa_idx[1],in_channels=512,ks=[32,16,seq_len]) 110 | self.layer3 = self._inflate_reslayer(resnet2d.layer3, c3d_idx=c3d_idx[2], \ 111 | nl_idx=nl_idx[2], sa_idx=sa_idx[2],in_channels=1024,ks=[16,8,seq_len]) 112 | self.layer4 = self._inflate_reslayer(resnet2d.layer4, c3d_idx=c3d_idx[3], \ 113 | nl_idx=nl_idx[3], sa_idx=sa_idx[3],in_channels=2048,ks=[16,8,seq_len]) 114 | 115 | def _inflate_reslayer(self, reslayer2d, c3d_idx, nl_idx=[], sa_idx=[],in_channels=0,ks=[64,32,1]): 116 | reslayers3d = [] 117 | for i,layer2d in enumerate(reslayer2d): 118 | if i not in c3d_idx: # normal 2D convolution 119 | layer3d = Bottleneck3D(layer2d, AP3D.C2D, inflate_time=False) 120 | else: # (AP)I3D, (AP)P3D-A,B,C 121 | layer3d = Bottleneck3D(layer2d, self.block, inflate_time=True, \ 122 | temperature=self.temperature, contrastive_att=self.contrastive_att) 123 | reslayers3d.append(layer3d) 124 | 125 | if (i in nl_idx) and (i not in sa_idx): 126 | non_local_block = NonLocal.NonLocalBlock3D(in_channels, sub_sample=True) 127 | reslayers3d.append(non_local_block) 128 | elif (i in sa_idx) and (i not in nl_idx): 129 | # the AxialBlock configuration is identical at every feature-map size 130 | sa_block = SelfAttn.AxialBlock(in_channels,inter_channel=None,kernel_size=ks,granularity=4,groups=8,positional='r_qkv',order='hwt') 131 | reslayers3d.append(sa_block) 132 | elif (i in sa_idx) and (i in nl_idx): 133 | raise ValueError("cannot use NL and SA at the same time!") 134 | return nn.Sequential(*reslayers3d) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def AP3DResNet50(**kwargs): 151 | c3d_idx = [[],[0, 2],[0, 2, 4],[]] 152 | nl_idx = [[],[],[],[]] 153 | sa_idx = [[],[],[],[]] 154 | return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs) 155 | 156 | def P3D_ResNet50(**kwargs): 157 | c3d_idx = [[],[0,2],[0,2,4],[]] 158 | nl_idx = [[],[],[],[]] 159 | sa_idx = [[],[],[],[]] 160 | return ResNet503D(AP3D.P3DC, c3d_idx, nl_idx, sa_idx, **kwargs) 161 | 162 | def P3D_Axial_ResNet50(**kwargs): 163 | c3d_idx = [[],[0,1],[0,1,2],[]] 164 | nl_idx = [[],[],[],[]] 165 | sa_idx = [[],[2,3],[3,4,5],[]] 166 | return ResNet503D(AP3D.P3DC, c3d_idx, nl_idx, sa_idx, **kwargs) 167 | 168 | def C2D_Axial_ResNet50(**kwargs): 169 | c3d_idx = [[],[],[],[]] 170 | nl_idx = [[],[],[],[]] 171 | sa_idx = [[],[2,3],[3,4,5],[]] 172 | return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs) 173 | 174 | 175 | def C2DResNet50(**kwargs): 176 | c3d_idx = [[],[],[],[]] 177 | nl_idx = [[],[],[],[]] 178 | sa_idx = [[],[],[],[]] 179 | return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs) 180 | 181 | def C2DNLResNet50(**kwargs): 182 | c3d_idx = [[],[],[],[]] 183 | nl_idx = [[],[2, 3],[3, 4, 5],[]] 184 | sa_idx = [[],[],[],[]] 185 | return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs) 186 | 187 | def AP3DNLResNet50(**kwargs): 188 | c3d_idx = [[],[0, 2],[0, 2, 4],[]] 189 | nl_idx = [[],[1, 3],[1, 3, 5],[]] 190 | sa_idx = [[],[],[],[]] 191 | return ResNet503D(AP3D.APP3DC, c3d_idx, nl_idx, sa_idx, **kwargs) 192 | --------------------------------------------------------------------------------
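[Editor's note] A minimal shape check for the factories above (an editor's sketch, not a repo file; it assumes torchvision can fetch the ImageNet ResNet-50 weights). Because layer4 keeps stride 1 and the time dimension is never downsampled, a 256x128 clip of 6 frames should come back as a (1, 2048, 6, 16, 8) feature map:

import torch
from modeling.backbones.ResNet import C2D_Axial_ResNet50

model = C2D_Axial_ResNet50(seq_len=6).eval()
clip = torch.randn(1, 3, 6, 256, 128)  # (b, c, t, h, w)
with torch.no_grad():
    feat = model(clip)
print(feat.shape)  # expected: torch.Size([1, 2048, 6, 16, 8])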
/modeling/backbones/SA/NonLocal.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import math 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class NonLocalBlockND(nn.Module): 10 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 11 | super(NonLocalBlockND, self).__init__() 12 | 13 | assert dimension in [1, 2, 3] 14 | 15 | self.dimension = dimension 16 | self.sub_sample = sub_sample 17 | self.in_channels = in_channels 18 | self.inter_channels = inter_channels 19 | 20 | if self.inter_channels is None: 21 | self.inter_channels = in_channels // 2 22 | if self.inter_channels == 0: 23 | self.inter_channels = 1 24 | 25 | if dimension == 3: 26 | conv_nd = nn.Conv3d 27 | max_pool = nn.MaxPool3d 28 | bn = nn.BatchNorm3d 29 | elif dimension == 2: 30 | conv_nd = nn.Conv2d 31 | max_pool = nn.MaxPool2d 32 | bn = nn.BatchNorm2d 33 | else: 34 | conv_nd = nn.Conv1d 35 | max_pool = nn.MaxPool1d 36 | bn = nn.BatchNorm1d 37 | 38 | self.g = conv_nd(self.in_channels, self.inter_channels, 39 | kernel_size=1, stride=1, padding=0, bias=True) 40 | self.theta = conv_nd(self.in_channels, self.inter_channels, 41 | kernel_size=1, stride=1, padding=0, bias=True) 42 | self.phi = conv_nd(self.in_channels, self.inter_channels, 43 | kernel_size=1, stride=1, padding=0, bias=True) 44 | if sub_sample: 45 | if dimension == 3: 46 | self.g = nn.Sequential(self.g, max_pool((1, 2, 2))) 47 | self.phi = nn.Sequential(self.phi, max_pool((1, 2, 2))) 48 | else: 49 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 50 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 51 | 52 | if bn_layer: 53 | self.W = nn.Sequential( 54 | conv_nd(self.inter_channels, self.in_channels, 55 | kernel_size=1, stride=1, padding=0, bias=True), 56 | bn(self.in_channels) 57 | ) 58 | else: 59 | self.W = conv_nd(self.inter_channels, self.in_channels, 60 | kernel_size=1, stride=1, padding=0, bias=True) 61 | 62 | # init 63 | for m in self.modules(): 64 | if isinstance(m, conv_nd): 65 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 66 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 67 | elif isinstance(m, bn): 68 | m.weight.data.fill_(1) 69 | m.bias.data.zero_() 70 | 71 | if bn_layer: 72 | nn.init.constant_(self.W[1].weight.data, 0.0) 73 | nn.init.constant_(self.W[1].bias.data, 0.0) 74 | else: 75 | nn.init.constant_(self.W.weight.data, 0.0) 76 | nn.init.constant_(self.W.bias.data, 0.0) 77 | 78 | 79 | def forward(self, x): 80 | ''' 81 | :param x: (b, c, t, h, w) 82 | :return: 83 | ''' 84 | 85 | batch_size = x.size(0) 86 | 87 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 88 | g_x = g_x.permute(0, 2, 1) 89 | 90 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 91 | theta_x = theta_x.permute(0, 2, 1) 92 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 93 | f = torch.matmul(theta_x, phi_x) 94 | f = F.softmax(f, dim=-1) 95 | 96 | y = torch.matmul(f, g_x) 97 | y = y.permute(0, 2, 1).contiguous() 98 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 99 | y = self.W(y) 100 | z = y + x 101 | 102 | return z 103 | 104 | 105 | class NonLocalBlock1D(NonLocalBlockND): 106 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 107 | super(NonLocalBlock1D, self).__init__(in_channels, 108 | inter_channels=inter_channels, 109 | dimension=1, sub_sample=sub_sample, 110 | bn_layer=bn_layer) 111 | 112 | 113 | class NonLocalBlock2D(NonLocalBlockND): 114 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 115 | super(NonLocalBlock2D, self).__init__(in_channels, 116 | inter_channels=inter_channels, 117 | dimension=2, sub_sample=sub_sample, 118 | bn_layer=bn_layer) 119 | 120 | 121 | class NonLocalBlock3D(NonLocalBlockND): 122 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 123 | super(NonLocalBlock3D, self).__init__(in_channels, 124 | inter_channels=inter_channels, 125 | dimension=3, sub_sample=sub_sample, 126 | bn_layer=bn_layer) 127 | -------------------------------------------------------------------------------- /modeling/backbones/SA/SelfAttn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def conv1x1(in_planes,out_planes,nd=3,stride=1,bias=False): 7 | if nd == 3: 8 | return nn.Conv3d(in_planes,out_planes,kernel_size=1,stride=stride,bias=bias) 9 | elif nd == 2: 10 | return nn.Conv2d(in_planes,out_planes,kernel_size=1,stride=stride,bias=bias) 11 | else: 12 | raise NotImplementedError 13 | 14 | class AxialAttention(nn.Module): 15 | def __init__(self,in_channel,out_channels,groups=8, kernel_size=56,axial='height', 16 | bias=False,positional='no'): 17 | super(AxialAttention,self).__init__() 18 | self.in_channel = in_channel 19 | self.out_channels = out_channels 20 | self.groups = groups 21 | self.group_planes = out_channels // groups 22 | self.kernel_size = kernel_size 23 | self.axial = axial 24 | self.positional = positional 25 | 26 | self.qkv_transform = nn.Conv1d(in_channel,out_channels*2,kernel_size=1,stride=1,padding=0,bias=False) 27 | self.bn_qkv = nn.BatchNorm1d(out_channels*2) 28 | if self.positional == 'r_qkv': 29 | self.bn_similarity = nn.BatchNorm2d(groups*3) 30 | self.bn_output = nn.BatchNorm1d(out_channels*2) 31 | # positional embedding 32 | self.relative = nn.Parameter(torch.randn(self.group_planes*2,kernel_size*2-1),requires_grad=True) 33 | query_index = torch.arange(kernel_size).unsqueeze(0) 34 | key_index = torch.arange(kernel_size).unsqueeze(1) 35 | relative_index = 
key_index - query_index + kernel_size - 1 36 | 37 | self.register_buffer('flatten_index', relative_index.view(-1)) 38 | elif self.positional == 'r_q': 39 | self.bn_similarity = nn.BatchNorm2d(groups*2) 40 | # positional embedding 41 | self.relative = nn.Parameter(torch.randn(self.group_planes//2,kernel_size*2-1),requires_grad=True) 42 | query_index = torch.arange(kernel_size).unsqueeze(0) 43 | key_index = torch.arange(kernel_size).unsqueeze(1) 44 | relative_index = key_index - query_index + kernel_size - 1 45 | self.register_buffer('flatten_index', relative_index.view(-1)) 46 | 47 | self.bn_output = nn.BatchNorm1d(out_channels) 48 | else: 49 | self.bn_similarity = nn.BatchNorm2d(groups) 50 | self.bn_output = nn.BatchNorm1d(out_channels) 51 | 52 | self.reset_parameters() 53 | 54 | def reset_parameters(self): 55 | self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_channel)) 56 | #nn.init.uniform_(self.relative, -0.1, 0.1) 57 | if 'r_' in self.positional: 58 | nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes)) 59 | 60 | def forward(self,x,vis=False): 61 | # x's shape: b,c,t,h,w 62 | if self.axial == 'width': 63 | x = x.permute(0,2,3,1,4) # b,t,h,c,w 64 | elif self.axial == 'temporal': 65 | x = x.permute(0,3,4,1,2) # b,h,w,c,t 66 | else: 67 | x = x.permute(0,2,4,1,3) # b,t,w,c,h 68 | B,D1,D2,C,H = x.shape 69 | x = x.contiguous().view(B*D1*D2,C,H) 70 | 71 | # input positional embedding 72 | if self.positional == 'input_sine': 73 | dim = torch.arange(C,dtype=torch.float32,device=x.device) 74 | dim = 1000 ** (2 * (dim//2) / C).view(1,C,1) 75 | code = torch.arange(H,dtype=torch.float32,device=x.device).view(1,1,H).repeat(B*D1*D2,C,1) / dim 76 | code = torch.stack([code[:,0::2,:].sin(),code[:,1::2,:].cos()],dim=2).reshape(B*D1*D2,C,H) 77 | x = x + code 78 | 79 | # Transformations 80 | qkv = self.bn_qkv(self.qkv_transform(x)) 81 | q,k,v = torch.split(qkv.reshape(B*D1*D2,self.groups,self.group_planes*2,H),\ 82 | [self.group_planes//2,self.group_planes//2,self.group_planes],dim=2) 83 | 84 | qk = torch.einsum('bgci, bgcj->bgij', q, k) 85 | # Calculate Positional Embedding 86 | if self.positional == 'r_qkv': 87 | all_embeddings = torch.index_select(self.relative,1,self.flatten_index).view(\ 88 | self.group_planes*2,self.kernel_size,self.kernel_size) 89 | q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, \ 90 | [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0) 91 | 92 | qr = torch.einsum('bgci,cij->bgij', q, q_embedding) 93 | kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3) 94 | stacked_similarity = torch.cat([qk, qr, kr], dim=1) 95 | stacked_similarity = self.bn_similarity(stacked_similarity).view(B*D1*D2, 3, self.groups, H, H).sum(dim=1) 96 | 97 | elif self.positional == 'r_q': 98 | q_embedding = torch.index_select(self.relative,1,self.flatten_index).view(\ 99 | self.group_planes//2,self.kernel_size,self.kernel_size) 100 | qr = torch.einsum('bgci,cij->bgij',q,q_embedding) 101 | stacked_similarity = torch.cat([qk,qr],dim=1) 102 | stacked_similarity = self.bn_similarity(stacked_similarity).view(B*D1*D2,2,self.groups,H,H).sum(dim=1) 103 | else: 104 | stacked_similarity = self.bn_similarity(qk) 105 | 106 | similarity = F.softmax(stacked_similarity, dim=3) 107 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v) 108 | if self.positional == 'r_qkv': 109 | sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding) 110 | stacked_output = torch.cat([sv, sve], dim=-1).view(B*D1*D2, self.out_channels * 2, H) 111 |
output = self.bn_output(stacked_output).view(B, D1, D2 , self.out_channels, 2, H).sum(dim=-2) 112 | else: 113 | stacked_output = sv.reshape(B*D1*D2,self.out_channels,H) 114 | output = self.bn_output(stacked_output).view(B,D1,D2,self.out_channels,H) 115 | 116 | 117 | if self.axial == 'width': 118 | output = output.permute(0,3,1,2,4) 119 | elif self.axial == 'temporal': 120 | output = output.permute(0,3,4,1,2) 121 | else: 122 | output = output.permute(0,3,1,4,2) 123 | 124 | if vis == True: 125 | return output,similarity 126 | return output 127 | 128 | 129 | class AxialBlock(nn.Module): 130 | def __init__(self,in_channel,inter_channel=None,groups=8,granularity=1,kernel_size=[],positional='r_qkv',order='hwt'): 131 | super(AxialBlock,self).__init__() 132 | self.inter_channel = inter_channel 133 | self.relu = nn.ReLU(inplace=True) 134 | self.bn2 = nn.BatchNorm3d(in_channel) 135 | self.order = order 136 | self.granularity = granularity 137 | if inter_channel is not None: 138 | self.conv_down = conv1x1(in_channel,inter_channel) 139 | self.bn1 = nn.BatchNorm3d(inter_channel) 140 | self.conv_up = conv1x1(inter_channel,in_channel) 141 | self.axial_channel = inter_channel 142 | else: 143 | self.conv_up = conv1x1(in_channel,in_channel) 144 | self.axial_channel = in_channel 145 | self.in_gran_channel = self.axial_channel//self.granularity 146 | self.axial_gran = [] 147 | for i in range(self.granularity): 148 | gran_group = groups // self.granularity 149 | spatial_ratio = 2**i 150 | height_block = AxialAttention(self.in_gran_channel,self.in_gran_channel,groups=gran_group,kernel_size=kernel_size[0]//spatial_ratio,positional=positional) 151 | width_block = AxialAttention(self.in_gran_channel,self.in_gran_channel,groups=gran_group,axial='width',kernel_size=kernel_size[1]//spatial_ratio,positional=positional) 152 | temporal_block = AxialAttention(self.in_gran_channel,self.in_gran_channel,groups=gran_group,axial='temporal',kernel_size=kernel_size[2],positional=positional) 153 | self.axial_gran.append(height_block) 154 | self.axial_gran.append(width_block) 155 | self.axial_gran.append(temporal_block) 156 | self.axial_gran = nn.ModuleList(self.axial_gran) 157 | 158 | nn.init.constant_(self.bn2.weight,0) 159 | nn.init.constant_(self.bn2.bias,0) 160 | 161 | def forward(self,x): 162 | identity = x 163 | if self.inter_channel is not None: 164 | x = self.relu(self.bn1(self.conv_down(x))) 165 | gran_tensor_list = [] 166 | for i in range(self.granularity): 167 | gran_tensor = x[:, i*(self.in_gran_channel):(i+1)*(self.in_gran_channel),...] 
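# [Editor's note] Each granularity level i below works on a grid pooled down by 2**i:
# adaptive max pooling shrinks (H, W), the three axial-attention blocks (height, width,
# temporal) run in the configured order on that coarser grid, and F.interpolate brings
# the result back to (T, H, W) so the per-level outputs can be concatenated on channels.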
168 | B,C,T,H,W = gran_tensor.shape 169 | gran_tensor = F.adaptive_max_pool3d(gran_tensor,(T,H//(2**i),W//(2**i))) 170 | if self.order == 'hwt': 171 | gran_tensor,h_vis = self.axial_gran[i*3+0](gran_tensor,True) 172 | gran_tensor,w_vis = self.axial_gran[i*3+1](gran_tensor,True) 173 | gran_tensor,t_vis = self.axial_gran[i*3+2](gran_tensor,True) 174 | elif self.order == 'wht': 175 | gran_tensor = self.axial_gran[i*3+1](gran_tensor) 176 | gran_tensor = self.axial_gran[i*3+0](gran_tensor) 177 | gran_tensor = self.axial_gran[i*3+2](gran_tensor) 178 | elif self.order == 'wth': 179 | gran_tensor = self.axial_gran[i*3+1](gran_tensor) 180 | gran_tensor = self.axial_gran[i*3+2](gran_tensor) 181 | gran_tensor = self.axial_gran[i*3+0](gran_tensor) 182 | elif self.order == 'twh': 183 | gran_tensor = self.axial_gran[i*3+2](gran_tensor) 184 | gran_tensor = self.axial_gran[i*3+1](gran_tensor) 185 | gran_tensor = self.axial_gran[i*3+0](gran_tensor) 186 | else: 187 | raise NotImplementedError 188 | gran_tensor = F.interpolate(gran_tensor,size=(T,H,W)) 189 | gran_tensor_list.append(gran_tensor) 190 | x = torch.cat(gran_tensor_list,dim=1) 191 | x = self.bn2(self.conv_up(x)) 192 | 193 | out = identity+x 194 | return out -------------------------------------------------------------------------------- /modeling/backbones/SA/__pycache__/AP3D.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/SA/__pycache__/AP3D.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/SA/__pycache__/NonLocal.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/SA/__pycache__/NonLocal.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/SA/__pycache__/SelfAttn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/SA/__pycache__/SelfAttn.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/SA/__pycache__/inflate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/SA/__pycache__/inflate.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/SA/inflate.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def inflate_conv(conv2d, 7 | time_dim=1, 8 | time_padding=0, 9 | time_stride=1, 10 | time_dilation=1, 11 | center=False): 12 | # To preserve activations, padding should be by continuity and not zero 13 | # or no padding in time dimension 14 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1]) 15 | padding = (time_padding, conv2d.padding[0], conv2d.padding[1]) 16 | stride = (time_stride, conv2d.stride[0], conv2d.stride[1]) 17 | dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1]) 18 | conv3d = nn.Conv3d( 19 |
conv2d.in_channels, 20 | conv2d.out_channels, 21 | kernel_dim, 22 | padding=padding, 23 | dilation=dilation, 24 | stride=stride) 25 | # Repeat filter time_dim times along time dimension 26 | weight_2d = conv2d.weight.data 27 | if center: 28 | weight_3d = torch.zeros(*weight_2d.shape) 29 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 30 | middle_idx = time_dim // 2 31 | weight_3d[:, :, middle_idx, :, :] = weight_2d 32 | else: 33 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 34 | weight_3d = weight_3d / time_dim 35 | 36 | # Assign new params 37 | conv3d.weight = nn.Parameter(weight_3d) 38 | conv3d.bias = conv2d.bias 39 | return conv3d 40 | 41 | 42 | def inflate_linear(linear2d, time_dim): 43 | """ 44 | Args: 45 | time_dim: final time dimension of the features 46 | """ 47 | linear3d = nn.Linear(linear2d.in_features * time_dim, 48 | linear2d.out_features) 49 | weight3d = linear2d.weight.data.repeat(1, time_dim) 50 | weight3d = weight3d / time_dim 51 | 52 | linear3d.weight = nn.Parameter(weight3d) 53 | linear3d.bias = linear2d.bias 54 | return linear3d 55 | 56 | 57 | def inflate_batch_norm(batch2d): 58 | # In pytorch 0.2.0 the 2d and 3d versions of batch norm 59 | # work identically except for the check that verifies the 60 | # input dimensions 61 | 62 | batch3d = nn.BatchNorm3d(batch2d.num_features) 63 | # retrieve 3d _check_input_dim function 64 | batch2d._check_input_dim = batch3d._check_input_dim 65 | return batch2d 66 | 67 | 68 | def inflate_pool(pool2d, 69 | time_dim=1, 70 | time_padding=0, 71 | time_stride=None, 72 | time_dilation=1): 73 | kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size) 74 | padding = (time_padding, pool2d.padding, pool2d.padding) 75 | if time_stride is None: 76 | time_stride = time_dim 77 | stride = (time_stride, pool2d.stride, pool2d.stride) 78 | if isinstance(pool2d, nn.MaxPool2d): 79 | dilation = (time_dilation, pool2d.dilation, pool2d.dilation) 80 | pool3d = nn.MaxPool3d( 81 | kernel_dim, 82 | padding=padding, 83 | dilation=dilation, 84 | stride=stride, 85 | ceil_mode=pool2d.ceil_mode) 86 | elif isinstance(pool2d, nn.AvgPool2d): 87 | pool3d = nn.AvgPool3d(kernel_dim, stride=stride) 88 | else: 89 | raise ValueError( 90 | '{} is not among known pooling classes'.format(type(pool2d))) 91 | return pool3d 92 | -------------------------------------------------------------------------------- /modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/ResNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/ResNet.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/non_local.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/non_local.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet_NL.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/modeling/backbones/__pycache__/resnet_NL.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/non_local.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.nn import functional as F 3 | import numpy as np 4 | import os 5 | import torch 6 | from torch import nn 7 | 8 | class NonLocalBlock(nn.Module): 9 | def __init__(self, in_channels, inter_channels=None,sub_sample=False, bn_layer=True,instance='soft',groups=1): 10 | super(NonLocalBlock, self).__init__() 11 | self.sub_sample = sub_sample 12 | self.instance = instance 13 | self.in_channels = in_channels 14 | self.inter_channels = inter_channels 15 | 16 | if self.inter_channels is None: 17 | self.inter_channels = in_channels // 2 18 | if self.inter_channels == 0: 19 | self.inter_channels = 1 20 | self.groups = groups 21 | self.group_plane = self.inter_channels//self.groups 22 | ##### temporal operation in video re-id ##### 23 | conv_nd = nn.Conv3d 24 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 25 | bn = nn.BatchNorm3d 26 | ############################################## 27 | 28 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 29 | kernel_size=1, stride=1, padding=0) 30 | if bn_layer: 31 | self.W = nn.Sequential( 32 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 33 | kernel_size=1, stride=1, padding=0), 34 | bn(self.in_channels) 35 | ) 36 | nn.init.constant_(self.W[1].weight, 0) 37 | nn.init.constant_(self.W[1].bias, 0) 38 | else: 39 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 40 | kernel_size=1, stride=1, padding=0) 41 | nn.init.constant_(self.W.weight, 0) 42 | nn.init.constant_(self.W.bias, 0) 43 | 44 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 47 | kernel_size=1, stride=1, padding=0) 48 | if sub_sample: 49 | self.g = nn.Sequential(self.g, max_pool_layer) 50 | self.phi = nn.Sequential(self.phi, max_pool_layer) 51 | 52 | def forward(self, x): 53 | ''' 54 | :param x: (b, c, t, h, w) 55 | :return: 56 | ''' 57 | batch_size = x.size(0) 58 | 59 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 60 | g_x = g_x.permute(0, 2, 1) # shape : (b , THW, c') 61 | 62 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 63 | theta_x = theta_x.permute(0, 2, 1) # shape : (b, THW , c') 64 | 65 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) # shape : (b , c', THW) 66 | 67 | f = torch.matmul(theta_x, 
phi_x) 68 | 69 | if self.instance == 'soft': 70 | f_div_C = F.softmax(f, dim=-1) 71 | elif self.instance == 'dot': 72 | f_div_C = f / f.shape[1] 73 | 74 | y = torch.matmul(f_div_C, g_x) 75 | y = y.permute(0, 2, 1).contiguous() # shape : (b, c', THW) 76 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) # shape : (b, c', T, H, W) 77 | 78 | W_y = self.W(y) # shape : (b, c, t, h, w) 79 | z = W_y + x 80 | 81 | return z 82 | -------------------------------------------------------------------------------- /modeling/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | from torchvision import models 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class Bottleneck(nn.Module): 48 | expansion = 4 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(Bottleneck, self).__init__() 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 55 | padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(planes * 4) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | 77 | if self.downsample is not None: 78 | residual = self.downsample(x) 79 | 80 | out += residual 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 88 | self.inplanes = 64 89 | super().__init__() 90 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 91 | bias=False) 92 | self.bn1 = nn.BatchNorm2d(64) 93 | self.relu = nn.ReLU(inplace=True) # add missed relu 94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 95 | self.layer1 = self._make_layer(block, 64, layers[0]) 96 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 97 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 98 | self.layer4 = self._make_layer( 99 | block, 512, layers[3], stride=last_stride) 100 | 101 | def _make_layer(self, block, planes, 
blocks, stride=1): 102 | downsample = None 103 | if stride != 1 or self.inplanes != planes * block.expansion: 104 | downsample = nn.Sequential( 105 | nn.Conv2d(self.inplanes, planes * block.expansion, 106 | kernel_size=1, stride=stride, bias=False), 107 | nn.BatchNorm2d(planes * block.expansion), 108 | ) 109 | 110 | layers = [] 111 | layers.append(block(self.inplanes, planes, stride, downsample)) 112 | self.inplanes = planes * block.expansion 113 | for i in range(1, blocks): 114 | layers.append(block(self.inplanes, planes)) 115 | 116 | return nn.Sequential(*layers) 117 | 118 | def forward(self, x): 119 | x = self.conv1(x) 120 | x = self.bn1(x) 121 | # x = self.relu(x) # add missed relu 122 | x = self.maxpool(x) 123 | 124 | x = self.layer1(x) 125 | x = self.layer2(x) 126 | x = self.layer3(x) 127 | x = self.layer4(x) 128 | 129 | return x 130 | 131 | def load_param(self, model_path,autoload=None): 132 | if autoload == 'r50': 133 | param_dict = models.resnet50(pretrained=True).state_dict() 134 | else: 135 | param_dict = torch.load(model_path) 136 | for i in param_dict: 137 | if 'fc' in i: 138 | continue 139 | self.state_dict()[i].copy_(param_dict[i]) 140 | def random_init(self): 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 144 | m.weight.data.normal_(0, math.sqrt(2. / n)) 145 | elif isinstance(m, nn.BatchNorm2d): 146 | m.weight.data.fill_(1) 147 | m.bias.data.zero_() 148 | 149 | -------------------------------------------------------------------------------- /modeling/backbones/resnet_NL.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | from torchvision import models 7 | from .non_local import NonLocalBlock 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes * 4) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 
68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv3(out) 76 | out = self.bn3(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class ResNet_NL(nn.Module): 88 | def __init__(self, last_stride=1, block=Bottleneck, layers=[3, 4, 6, 3],non_layers=[0,2,3,0]): 89 | self.inplanes = 64 90 | super().__init__() 91 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 92 | self.bn1 = nn.BatchNorm2d(64) 93 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 94 | #### layer 1 #### 95 | self.layer1 = self._make_layer(block, 64, layers[0]) 96 | NL_1 = [NonLocalBlock(self.inplanes,sub_sample=True) for i in range(non_layers[0])] 97 | self.NL_1 = nn.ModuleList(NL_1) 98 | self.NL_1_idx = sorted([layers[0]-(i+1) for i in range(non_layers[0])]) 99 | #### layer 2 #### 100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 101 | NL_2 = [NonLocalBlock(self.inplanes) for i in range(non_layers[1])] 102 | self.NL_2 = nn.ModuleList(NL_2) 103 | self.NL_2_idx = sorted([layers[1]-(i+1) for i in range(non_layers[1])]) 104 | #### layer 3 #### 105 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 106 | NL_3 = [NonLocalBlock(self.inplanes) for i in range(non_layers[2])] 107 | self.NL_3 = nn.ModuleList(NL_3) 108 | self.NL_3_idx = sorted([layers[2]-(i+1) for i in range(non_layers[2])]) 109 | #### layer 4 #### 110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 111 | NL_4 = [NonLocalBlock(self.inplanes) for i in range(non_layers[3])] 112 | self.NL_4 = nn.ModuleList(NL_4) 113 | self.NL_4_idx = sorted([layers[3]-(i+1) for i in range(non_layers[3])]) 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1): 116 | downsample = None 117 | if stride != 1 or self.inplanes != planes * block.expansion: 118 | downsample = nn.Sequential( 119 | nn.Conv2d(self.inplanes, planes * block.expansion, 120 | kernel_size=1, stride=stride, bias=False), 121 | nn.BatchNorm2d(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks): 128 | layers.append(block(self.inplanes, planes)) 129 | 130 | return nn.ModuleList(layers) 131 | 132 | def forward(self, x): 133 | b,t,c,h,w = x.shape 134 | x = self.conv1(x.view(b*t,c,h,w)) 135 | x = self.bn1(x) 136 | x = self.maxpool(x) 137 | 138 | NL1_counter = 0 139 | if len(self.NL_1_idx)== 0 : self.NL_1_idx=[-1] 140 | for i in range(len(self.layer1)): 141 | x = self.layer1[i](x) 142 | if i == self.NL_1_idx[NL1_counter]: 143 | _,c,h,w = x.shape 144 | x = self.NL_1[NL1_counter](x.view(b,t,c,h,w).permute(0,2,1,3,4)) 145 | x = x.permute(0,2,1,3,4).reshape(b*t,c,h,w) 146 | NL1_counter += 1 147 | 148 | NL2_counter = 0 149 | if len(self.NL_2_idx)== 0 : self.NL_2_idx=[-1] 150 | for i in range(len(self.layer2)): 151 | x = self.layer2[i](x) 152 | if i == self.NL_2_idx[NL2_counter]: 153 | _,c,h,w = x.shape 154 | x = self.NL_2[NL2_counter](x.view(b,t,c,h,w).permute(0,2,1,3,4)) 155 | x = x.permute(0,2,1,3,4).reshape(b*t,c,h,w) 156 | NL2_counter += 1 157 | NL3_counter = 0 158 | if len(self.NL_3_idx)== 0 : self.NL_3_idx=[-1] 159 | for i in range(len(self.layer3)): 160 | x = self.layer3[i](x) 161 | if i == self.NL_3_idx[NL3_counter]: 162 | _,c,h,w = 
x.shape 163 | x = self.NL_3[NL3_counter](x.view(b,t,c,h,w).permute(0,2,1,3,4)) 164 | x = x.permute(0,2,1,3,4).reshape(b*t,c,h,w) 165 | NL3_counter += 1 166 | NL4_counter = 0 167 | if len(self.NL_4_idx)== 0 : self.NL_4_idx=[-1] 168 | for i in range(len(self.layer4)): 169 | x = self.layer4[i](x) 170 | if i == self.NL_4_idx[NL4_counter]: 171 | _,c,h,w = x.shape 172 | x = self.NL_4[NL4_counter](x.view(b,t,c,h,w).permute(0,2,1,3,4)) 173 | x = x.permute(0,2,1,3,4).reshape(b*t,c,h,w) 174 | NL4_counter += 1 175 | return x 176 | 177 | def load_param(self, model_path,autoload=None): 178 | if autoload == 'r50': 179 | param_dict = models.resnet50(pretrained=True).state_dict() 180 | else: 181 | param_dict = torch.load(model_path) 182 | for i in param_dict: 183 | if 'fc' in i: 184 | continue 185 | self.state_dict()[i].copy_(param_dict[i]) 186 | 187 | def random_init(self): 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 191 | m.weight.data.normal_(0, math.sqrt(2. / n)) 192 | elif isinstance(m, nn.BatchNorm2d): 193 | m.weight.data.fill_(1) 194 | m.bias.data.zero_() 195 | 196 | -------------------------------------------------------------------------------- /modeling/network.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck 7 | from .backbones.resnet_NL import ResNet_NL 8 | from .backbones.ResNet import C2D_Axial_ResNet50 9 | 10 | 11 | def weights_init_kaiming(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('Linear') != -1: 14 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 15 | nn.init.constant_(m.bias, 0.0) 16 | elif classname.find('Conv') != -1: 17 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 18 | if m.bias is not None: 19 | nn.init.constant_(m.bias, 0.0) 20 | elif classname.find('BatchNorm') != -1: 21 | if m.affine: 22 | nn.init.constant_(m.weight, 1.0) 23 | nn.init.constant_(m.bias, 0.0) 24 | 25 | 26 | def weights_init_classifier(m): 27 | classname = m.__class__.__name__ 28 | if classname.find('Linear') != -1: 29 | nn.init.normal_(m.weight, std=0.001) 30 | if m.bias is not None: 31 | nn.init.constant_(m.bias, 0.0) 32 | 33 | 34 | class VNetwork(nn.Module): 35 | in_planes = 2048 36 | 37 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice,temp,\ 38 | non_layers=[0,0,0,0], seq_len=6): 39 | super(VNetwork, self).__init__() 40 | self.seq_len = seq_len 41 | if model_name == 'resnet50': 42 | self.base = ResNet(last_stride=last_stride, 43 | block=Bottleneck, 44 | layers=[3, 4, 6, 3]) 45 | elif model_name == 'resnet50_NL': 46 | self.base = ResNet_NL(last_stride=last_stride,block=Bottleneck, 47 | layers=[3,4,6,3],non_layers=non_layers) 48 | elif model_name == 'resnet50_axial': 49 | self.base = C2D_Axial_ResNet50(seq_len=seq_len) 50 | 51 | if pretrain_choice == 'imagenet': 52 | if 'axial' not in model_name: 53 | self.base.load_param('',autoload='r50') 54 | print('Loading pretrained ImageNet model......') 55 | 56 | self.gap = nn.AdaptiveAvgPool2d(1) 57 | self.gmp = nn.AdaptiveMaxPool2d(1) 58 | self.num_classes = num_classes 59 | self.neck = neck 60 | self.neck_feat = neck_feat 61 | self.temp = temp 62 | self.model_name = model_name 63 | 64 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 65 | self.bottleneck.bias.requires_grad_(False) # no shift 66 |
self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 67 | 68 | self.bottleneck.apply(weights_init_kaiming) 69 | self.classifier.apply(weights_init_classifier) 70 | 71 | def forward(self, x,masks=None): 72 | b,t,c,h,w = x.shape 73 | 74 | if 'NL' in self.model_name: 75 | if self.temp == 'Done': 76 | x = self.base(x) 77 | _,c,h,w = x.shape 78 | if masks is not None: 79 | global_feat = [] 80 | masks = masks.reshape(b*t,-1) 81 | for i in range(x.shape[0]): 82 | global_feat.append(self.gap(x[i,:,masks[i][0]:masks[i][1],masks[i][2]:masks[i][3]].unsqueeze(0))) 83 | global_feat = torch.cat(global_feat,dim=0) 84 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1) 85 | else: 86 | global_feat = F.adaptive_avg_pool3d(x.view(b,t,c,h,w).permute(0,2,1,3,4),1) 87 | global_feat = global_feat.view(b,-1) 88 | else: 89 | global_feat = self.gap(self.base(x)) 90 | global_feat = global_feat.view(b*t,-1) # flatten to (b*t, 2048) 91 | 92 | elif 'axial' in self.model_name: 93 | if masks is not None: 94 | global_feat = [] 95 | masks = masks.reshape(b*t,-1) 96 | output = self.base(x.permute(0,2,1,3,4).contiguous()).permute(0,2,1,3,4).contiguous() 97 | b,t,c,h,w = output.shape 98 | output = output.view(b*t,c,h,w) 99 | for i in range(len(output)): 100 | global_feat.append(self.gap(output[i,:,masks[i][0]:masks[i][1],masks[i][2]:masks[i][3]].unsqueeze(0))) 101 | global_feat = torch.cat(global_feat,dim=0) 102 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1) 103 | else: 104 | global_feat = self.base(x.permute(0,2,1,3,4).contiguous()).permute(0,2,1,3,4).contiguous() 105 | b,t,c,h,w = global_feat.shape 106 | global_feat = self.gap(global_feat.view(b*t,c,h,w)) 107 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1) 108 | else: 109 | if masks is not None: 110 | global_feat = [] 111 | masks = masks.reshape(b*t,-1) 112 | output = self.base(x.view(b*t,c,h,w)) 113 | for i in range(len(output)): 114 | global_feat.append(self.gap(output[i,:,masks[i][0]:masks[i][1],masks[i][2]:masks[i][3]].unsqueeze(0))) 115 | global_feat = torch.cat(global_feat,dim=0) 116 | else: 117 | global_feat = self.gap(self.base(x.view(b*t,c,h,w))) # (b*t, 2048, 1, 1) 118 | global_feat = global_feat.view(b*t,-1) # flatten to (b*t, 2048) 119 | 120 | #### whether neck #### 121 | feat = self.bottleneck(global_feat) # normalize for angular softmax 122 | 123 | if self.training: 124 | cls_score = self.classifier(feat) 125 | if self.temp == 'avg': 126 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1) 127 | cls_score = torch.mean(cls_score.view(b,t,-1),dim=1) 128 | 129 | return cls_score, global_feat # global feature for triplet loss 130 | else: 131 | if self.temp == 'avg': 132 | global_feat = torch.mean(global_feat.view(b,t,-1),dim=1) 133 | feat = torch.mean(feat.view(b,t,-1),dim=1) 134 | if self.neck_feat == 'after': 135 | return feat 136 | else: 137 | return global_feat 138 | 139 | def load_param(self, trained_path,con=False): 140 | param_dict = torch.load(trained_path)['model'] 141 | for i in param_dict: 142 | if 'classifier' in i and con == False: 143 | continue 144 | if 'bn_similarity' in i: 145 | if 'num' in i: 146 | self.state_dict()[i].copy_(param_dict[i]) 147 | else: 148 | self.state_dict()[i][:param_dict[i].shape[0]].copy_(param_dict[i]) 149 | elif 'bn_output' in i : 150 | if 'num' in i: 151 | self.state_dict()[i].copy_(param_dict[i]) 152 | else: 153 | self.state_dict()[i][:param_dict[i].shape[0]].copy_(param_dict[i]) 154 | else: 155 | self.state_dict()[i].copy_(param_dict[i]) 156 | 157 | 
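[Editor's note] A usage sketch for VNetwork (an editor's illustration, not a repo file; loading it downloads the torchvision ImageNet weights). The argument values mirror scripts/AA_M.sh and MARS's 625 training identities, and the dataloader is assumed to deliver (b, t, c, h, w) clips:

import torch
from modeling.network import VNetwork

model = VNetwork(num_classes=625, last_stride=1, model_path='', neck='bnneck',
                 neck_feat='after', model_name='resnet50_axial',
                 pretrain_choice='imagenet', temp='Done', seq_len=6)
clips = torch.randn(2, 6, 3, 256, 128)  # (b, t, c, h, w)
model.train()
cls_score, global_feat = model(clips)   # logits for the ID loss, (2, 2048) feature for the triplet loss
model.eval()
feat = model(clips)                     # (2, 2048) retrieval embedding (neck_feat == 'after')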
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.1 2 | torchvision==0.6.1 3 | scipy==1.5.2 4 | pandas 5 | numpy 6 | Pillow==8.0.0 7 | pytorch-ignite==0.4.2 8 | yacs 9 | tqdm -------------------------------------------------------------------------------- /scripts/AA_D.sh: -------------------------------------------------------------------------------- 1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('dukev',)" INPUT.SEQ_LEN 6 \ 2 | OUTPUT_DIR "./ckpt_DL_duke/Duke_DL_s6_resnet_axial_gap_sine_gran2" SOLVER.SOFT_MARGIN True \ 3 | MODEL.NAME 'resnet50_axial' MODEL.TEMP 'Done' MODEL.IF_LABELSMOOTH 'no' INPUT.IF_RE True \ 4 | DATASETS.ROOT_DIR '/home/mediax/Dataset/' TEST.NEW_EVAL False 5 | -------------------------------------------------------------------------------- /scripts/AA_M.sh: -------------------------------------------------------------------------------- 1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('mars',)" INPUT.SEQ_LEN 6 \ 2 | OUTPUT_DIR "./ckpt_DL_M/MARS_DL_s6_resnet_axial_gap_rqkv_gran4" SOLVER.SOFT_MARGIN True \ 3 | MODEL.NAME 'resnet50_axial' MODEL.TEMP 'Done' MODEL.IF_LABELSMOOTH 'no' INPUT.IF_RE True \ 4 | DATASETS.ROOT_DIR '/work/sychien421/Dataset/' 5 | -------------------------------------------------------------------------------- /scripts/NL_D.sh: -------------------------------------------------------------------------------- 1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('dukev',)" \ 2 | OUTPUT_DIR "./ckpt_DL_duke/Duke_DL_s6_NL0230_2gpu" SOLVER.SOFT_MARGIN True \ 3 | MODEL.NON_LAYERS [0,2,3,0] INPUT.IF_RE True INPUT.IF_CROP False MODEL.IF_LABELSMOOTH 'no' \ 4 | MODEL.NAME 'resnet50_NL' INPUT.SEQ_LEN 6 MODEL.TEMP 'Done' \ 5 | DATASETS.ROOT_DIR '/work/sychien421/Dataset/' 6 | -------------------------------------------------------------------------------- /scripts/NL_M.sh: -------------------------------------------------------------------------------- 1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('1,2')" DATASETS.NAMES "('mars',)" \ 2 | OUTPUT_DIR "./ckpt_DL_M/MARS_DL_s6_NL0230" SOLVER.SOFT_MARGIN True \ 3 | MODEL.NON_LAYERS [0,2,3,0] INPUT.IF_RE True INPUT.IF_CROP False MODEL.IF_LABELSMOOTH 'no' \ 4 | MODEL.NAME 'resnet50_NL' INPUT.SEQ_LEN 6 MODEL.TEMP 'Done' \ 5 | DATASETS.ROOT_DIR '/home/mediax/Dataset/' 6 | -------------------------------------------------------------------------------- /scripts/baseline_D.sh: -------------------------------------------------------------------------------- 1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('dukev',)" \ 2 | OUTPUT_DIR "./ckpt_DL_duke/Duke_avgpool" SOLVER.SOFT_MARGIN True \ 3 | MODEL.IF_LABELSMOOTH 'no' INPUT.IF_RE True INPUT.SEQ_LEN 6 \ 4 | DATASETS.ROOT_DIR '/work/sychien421/Dataset/' 5 | -------------------------------------------------------------------------------- /scripts/baseline_M.sh: -------------------------------------------------------------------------------- 1 | python3 tools/train.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0,1')" DATASETS.NAMES "('mars',)" \ 2 | OUTPUT_DIR "./ckpt_DL_M/MARS_DL_avgpool_s6" SOLVER.SOFT_MARGIN True \ 3 | 
MODEL.IF_LABELSMOOTH 'no' INPUT.IF_RE True INPUT.SEQ_LEN 6 TEST.NEW_EVAL False \ 4 | DATASETS.ROOT_DIR '/home/mediax/Dataset/' 5 | -------------------------------------------------------------------------------- /scripts/test_M.sh: -------------------------------------------------------------------------------- 1 | # Testing modes: 2 | # (1) TEST.TEST_MODE 'test' (using RRS to sample) 3 | # (2) TEST.TEST_MODE 'test_0' (first T images) 4 | # (3) TEST.TEST_MODE 'test_all_sampled' (using RRS to sample T, average the N/T tracklets) 5 | # (4) TEST.TEST_MODE 'test_all_continuous' (continuously sample T frames, average the N/T tracklets) 6 | 7 | python3 tools/test.py --config_file='configs/video_baseline.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('mars',)" MODEL.NON_LAYERS [0,2,3,0] \ 8 | MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('/home/xxxx/xxxx.pth')" \ 9 | MODEL.NAME 'resnet50_axial' INPUT.SEQ_LEN 6 MODEL.TEMP 'Done' TEST.TEST_MODE 'test_all_sampled' TEST.IMS_PER_BATCH 1 \ 10 | DATASETS.ROOT_DIR '/work/sychien421/Dataset/' TEST.NEW_EVAL False 11 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from .build import make_optimizer, make_optimizer_with_center 3 | from .lr_scheduler import WarmupMultiStepLR 4 | from torch.optim.lr_scheduler import StepLR,MultiStepLR -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/solver/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/solver/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackie840129/CF-AAN/f5be5eaf0946fc135ad55e394f759fd311003903/solver/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import torch 3 | 4 | 5 | def make_optimizer(cfg, model): 6 | params = [] 7 | for key, value in model.named_parameters(): 8 | if not value.requires_grad: 9 | continue 10 | lr = cfg.SOLVER.BASE_LR 11 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 12 | if "bias" in key: 13 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 14 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 15 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 16 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 17 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 18 | else: 19 | # We use Adam 20 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(model.parameters(),lr = cfg.SOLVER.BASE_LR,weight_decay=cfg.SOLVER.WEIGHT_DECAY) 21 | # optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 22 | return optimizer
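# [Editor's note] The non-SGD branch above passes model.parameters() instead of the
# per-parameter `params` groups built in the loop, so SOLVER.BIAS_LR_FACTOR and
# SOLVER.WEIGHT_DECAY_BIAS only take effect when SOLVER.OPTIMIZER_NAME is 'SGD';
# the commented-out line preserves the grouped alternative.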
23 | 24 | 25 | def make_optimizer_with_center(cfg, model, center_criterion): 26 | params = [] 27 | for key, value in model.named_parameters(): 28 | if not value.requires_grad: 29 | continue 30 | lr = cfg.SOLVER.BASE_LR 31 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 32 | if "bias" in key: 33 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 34 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 35 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 36 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 37 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 38 | else: 39 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 40 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 41 | return optimizer, optimizer_center 42 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR from WarmupLR, 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" 28 | " increasing integers. Got {}".format(milestones) 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted," 34 | " got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | --------------------------------------------------------------------------------
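To make the schedule concrete, a small sketch of the learning rates `WarmupMultiStepLR` yields. The base lr and milestones here are made up, and the exact `step()` ordering convention varies slightly across PyTorch versions:

```python
# Sketch: linear warmup followed by multi-step decay; values are illustrative.
import torch
from solver.lr_scheduler import WarmupMultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = WarmupMultiStepLR(optimizer, milestones=[30, 50], gamma=0.1,
                              warmup_factor=1.0 / 3, warmup_iters=10)

for epoch in range(60):
    lr = scheduler.get_lr()[0]
    # Roughly: epoch 0 -> 0.1/3, epoch 5 -> ~0.067, epoch 10+ -> 0.1,
    # epoch 30+ -> 0.01, epoch 50+ -> 0.001 (modulo step() conventions).
    print(epoch, lr)
    scheduler.step()
```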
/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /tests/lr_scheduler_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | 4 | import torch 5 | from torch import nn 6 | 7 | sys.path.append('.') 8 | from solver.lr_scheduler import WarmupMultiStepLR 9 | from solver.build import make_optimizer 10 | from config import cfg 11 | 12 | 13 | class MyTestCase(unittest.TestCase): 14 | def test_something(self): 15 | net = nn.Linear(10, 10) 16 | optimizer = make_optimizer(cfg, net) 17 | lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10) 18 | for i in range(50): 19 | lr_scheduler.step() 20 | for j in range(3): # prints the same lr three times; the value does not change within an epoch 21 | print(i, lr_scheduler.get_lr()[0]) 22 | optimizer.step() 23 | 24 | 25 | if __name__ == '__main__': 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import argparse 3 | import os 4 | import sys 5 | from os import mkdir 6 | 7 | import torch 8 | from torch.backends import cudnn 9 | 10 | sys.path.append('.') 11 | from config import cfg 12 | from data import make_data_loader 13 | from engine.inference import inference 14 | from modeling import build_model 15 | from utils.logger import setup_logger 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser(description="Video-based ReID Baseline Inference") 20 | parser.add_argument( 21 | "--config_file", default="", help="path to config file", type=str 22 | ) 23 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 24 | nargs=argparse.REMAINDER) 25 | 26 | args = parser.parse_args() 27 | 28 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 29 | 30 | if args.config_file != "": 31 | cfg.merge_from_file(args.config_file) 32 | cfg.merge_from_list(args.opts) 33 | cfg.freeze() 34 | 35 | output_dir = cfg.OUTPUT_DIR 36 | if output_dir and not os.path.exists(output_dir): 37 | mkdir(output_dir) 38 | 39 | logger = setup_logger("reid_baseline", output_dir, 0) 40 | logger.info("Using {} GPUS".format(num_gpus)) 41 | logger.info(args) 42 | 43 | if args.config_file != "": 44 | logger.info("Loaded configuration file {}".format(args.config_file)) 45 | with open(args.config_file, 'r') as cf: 46 | config_str = "\n" + cf.read() 47 | logger.debug(config_str) 48 | logger.info("Running with config:\n{}".format(cfg)) 49 | 50 | if cfg.MODEL.DEVICE == "cuda": 51 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 52 | cudnn.benchmark = True 53 | 54 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 55 | model = build_model(cfg, num_classes) 56 | model.load_param(cfg.TEST.WEIGHT) 57 | 58 | inference(cfg, model, val_loader, num_query) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | --------------------------------------------------------------------------------
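`tools/test.py` and `tools/train.py` (below) share the same configuration flow: yacs merges the YAML file first, then the trailing command-line `opts` pairs override it, and `freeze()` locks the result. A minimal sketch of that behavior, assuming the yacs package the config module builds on; the keys are illustrative, not the repo's full defaults:

```python
# Sketch of the yacs merge order used by the entry points; keys are illustrative.
from yacs.config import CfgNode as CN

cfg = CN()
cfg.INPUT = CN()
cfg.INPUT.SEQ_LEN = 4                          # default value

# cfg.merge_from_file('configs/video_baseline.yml')  # YAML overrides defaults
cfg.merge_from_list(["INPUT.SEQ_LEN", 6])      # CLI opts override the YAML
cfg.freeze()                                   # any later mutation raises

print(cfg.INPUT.SEQ_LEN)                       # -> 6
```

This is why the shell scripts can pass `KEY VALUE` pairs after `--config_file`: argparse collects them via `nargs=argparse.REMAINDER` and hands them to `merge_from_list`.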
/tools/train.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import argparse 3 | import os 4 | import sys 5 | import torch 6 | 7 | from torch.backends import cudnn 8 | 9 | sys.path.append('.') 10 | from config import cfg 11 | from data import make_data_loader 12 | from engine.trainer import do_train, do_train_with_center 13 | from modeling import build_model 14 | from layers import make_loss, make_loss_with_center 15 | from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR, StepLR, MultiStepLR 16 | 17 | from utils.logger import setup_logger 18 | 19 | 20 | def train(cfg): 21 | # prepare dataset 22 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 23 | # prepare model 24 | model = build_model(cfg, num_classes) 25 | 26 | if cfg.MODEL.IF_WITH_CENTER == 'no': 27 | print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 28 | optimizer = make_optimizer(cfg, model) 29 | 30 | loss_func = make_loss(cfg, num_classes) 31 | 32 | # Added for using a self-trained model 33 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 34 | raise NotImplementedError() # resuming is disabled in the no-center branch; the lines below are unreachable 35 | start_epoch = int(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 36 | print('Start epoch:', start_epoch) 37 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 38 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 39 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 40 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 41 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 42 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 43 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 44 | start_epoch = 0 45 | #scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 46 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 47 | #scheduler = StepLR(optimizer, 50, cfg.SOLVER.GAMMA) 48 | scheduler = MultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA) 49 | else: 50 | print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) # note: scheduler/start_epoch are undefined past this point 51 | 52 | arguments = {} 53 | do_train( 54 | cfg, 55 | model, 56 | train_loader, 57 | val_loader, 58 | optimizer, 59 | scheduler, # modified for using a self-trained model 60 | loss_func, 61 | num_query, 62 | start_epoch # added for using a self-trained model 63 | ) 64 | elif cfg.MODEL.IF_WITH_CENTER == 'yes': 65 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 66 | loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu 67 | optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion) 68 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 69 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 70 | 71 | arguments = {} 72 | 73 | # Added for using a self-trained model 74 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 75 | start_epoch = int(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 76 | print('Start epoch:', start_epoch) 77 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 78 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 79 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param') 80 | print('Path to the checkpoint of center_param:', path_to_center_param) 81 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center') 82 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center) 83 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 84 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 85 | center_criterion.load_state_dict(torch.load(path_to_center_param)) 86 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center)) 87 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 88 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 89 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 90 | start_epoch = 0 91 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 92 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 93 | else: 94 | print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) # note: scheduler/start_epoch are undefined past this point 95 | 96 | do_train_with_center( 97 | cfg, 98 | model, 99 | center_criterion, 100 | train_loader, 101 | val_loader, 102 | optimizer, 103 | optimizer_center, 104 | scheduler, # modified for using a self-trained model 105 | loss_func, 106 | num_query, 107 | start_epoch # added for using a self-trained model 108 | ) 109 | else: 110 | print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(cfg.MODEL.IF_WITH_CENTER)) 111 | 112 |
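The resume branches above rely on a filename convention: weights are saved as `..._model_<epoch>.pth`, and the sibling optimizer/center checkpoints are found by substituting `model` in the path. A tiny sketch of that derivation (the path is a made-up example):

```python
# Sketch of the checkpoint-name convention behind PRETRAIN_CHOICE == 'self'.
pretrain_path = '/logs/mars/resnet50_model_120.pth'   # hypothetical example

start_epoch = int(pretrain_path.split('/')[-1].split('.')[0].split('_')[-1])
path_to_optimizer = pretrain_path.replace('model', 'optimizer')

print(start_epoch)        # -> 120
print(path_to_optimizer)  # -> /logs/mars/resnet50_optimizer_120.pth
```

Note that `str.replace` substitutes every occurrence, so a path containing `model` elsewhere (e.g. a `models/` directory) would also be rewritten.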
113 | def main(): 114 | parser = argparse.ArgumentParser(description="Video-based ReID Training") 115 | parser.add_argument( 116 | "--config_file", default="", help="path to config file", type=str 117 | ) 118 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 119 | nargs=argparse.REMAINDER) 120 | 121 | args = parser.parse_args() 122 | 123 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 124 | 125 | if args.config_file != "": 126 | cfg.merge_from_file(args.config_file) 127 | cfg.merge_from_list(args.opts) 128 | cfg.freeze() 129 | 130 | output_dir = cfg.OUTPUT_DIR 131 | if output_dir and not os.path.exists(output_dir): 132 | os.makedirs(output_dir) 133 | 134 | logger = setup_logger("reid_baseline", output_dir, 0) 135 | logger.info("Using {} GPUS".format(num_gpus)) 136 | logger.info(args) 137 | 138 | if args.config_file != "": 139 | logger.info("Loaded configuration file {}".format(args.config_file)) 140 | with open(args.config_file, 'r') as cf: 141 | config_str = "\n" + cf.read() 142 | logger.debug(config_str) 143 | logger.info("Running with config:\n{}".format(cfg)) 144 | 145 | if cfg.MODEL.DEVICE == "cuda": 146 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # newly added by gu 147 | cudnn.benchmark = True 148 | train(cfg) 149 | 150 | 151 | if __name__ == '__main__': 152 | main() 153 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | --------------------------------------------------------------------------------
/utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import errno 8 | import json 9 | import os 10 | 11 | import os.path as osp 12 | 13 | 14 | def mkdir_if_missing(directory): 15 | if not osp.exists(directory): 16 | try: 17 | os.makedirs(directory) 18 | except OSError as e: 19 | if e.errno != errno.EEXIST: 20 | raise 21 | 22 | 23 | def check_isfile(path): 24 | isfile = osp.isfile(path) 25 | if not isfile: 26 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 27 | return isfile 28 | 29 | 30 | def read_json(fpath): 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | mkdir_if_missing(osp.dirname(fpath)) 38 | with open(fpath, 'w') as f: 39 | json.dump(obj, f, indent=4, separators=(',', ': ')) 40 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import logging 4 | import os 5 | import sys 6 | 7 | 8 | def setup_logger(name, save_dir, distributed_rank): 9 | logger = logging.getLogger(name) 10 | logger.setLevel(logging.DEBUG) 11 | # don't log results for the non-master process 12 | if distributed_rank > 0: 13 | return logger 14 | ch = logging.StreamHandler(stream=sys.stdout) 15 | ch.setLevel(logging.INFO) 16 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 17 | ch.setFormatter(formatter) 18 | logger.addHandler(ch) 19 | 20 | if save_dir: 21 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 22 | fh.setLevel(logging.DEBUG) 23 | fh.setFormatter(formatter) 24 | logger.addHandler(fh) 25 | 26 | return logger 27 | --------------------------------------------------------------------------------
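A short sketch of how the entry points use this helper (the output directory is illustrative and assumed to exist): rank 0 gets an INFO-level stream handler plus a DEBUG-level `log.txt`, while non-master ranks get a handler-less logger so they stay quiet.

```python
# Sketch: console + file logging as wired up by the entry points.
from utils.logger import setup_logger

logger = setup_logger("reid_baseline", "./output", 0)   # assumes ./output exists
logger.info("goes to stdout and to ./output/log.txt")
logger.debug("goes to the file only; the stream handler is INFO-level")

silent = setup_logger("reid_baseline_worker", "./output", 1)  # rank > 0: no handlers added
```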
/utils/re_ranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | @author: luohao 7 | """ 8 | 9 | """ 10 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 11 | url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 13 | """ 14 | 15 | """ 16 | API 17 | 18 | probFea: all feature vectors of the query set (torch tensor) 19 | galFea: all feature vectors of the gallery set (torch tensor) 20 | k1, k2, lambda_value: parameters; the original paper uses (k1=20, k2=6, lambda=0.3) 21 | local_distmat: optional precomputed local distance matrix (numpy array) 22 | only_local: if True, use local_distmat alone as the original distance 23 | """ 24 | 25 | import numpy as np 26 | import torch 27 | 28 | 29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False): 30 | # if the features are numpy arrays, convert them to tensors with torch.tensor first 31 | query_num = probFea.size(0) 32 | all_num = query_num + galFea.size(0) 33 | if only_local: 34 | original_dist = local_distmat 35 | else: 36 | feat = torch.cat([probFea, galFea]) 37 | print('using GPU to compute original distance') 38 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \ 39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t() 40 | distmat.addmm_(feat, feat.t(), beta=1, alpha=-2) 41 | original_dist = distmat.cpu().numpy() 42 | del feat 43 | if local_distmat is not None: 44 | original_dist = original_dist + local_distmat 45 | gallery_num = original_dist.shape[0] 46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 47 | V = np.zeros_like(original_dist).astype(np.float16) 48 | initial_rank = np.argsort(original_dist).astype(np.int32) 49 | 50 | print('starting re_ranking') 51 | for i in range(all_num): 52 | # k-reciprocal neighbors 53 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 55 | fi = np.where(backward_k_neigh_index == i)[0] 56 | k_reciprocal_index = forward_k_neigh_index[fi] 57 | k_reciprocal_expansion_index = k_reciprocal_index 58 | for j in range(len(k_reciprocal_index)): 59 | candidate = k_reciprocal_index[j] 60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 62 | :int(np.around(k1 / 2)) + 1] 63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 66 | candidate_k_reciprocal_index): 67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 68 | 69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 72 | original_dist = original_dist[:query_num, ] 73 | if k2 != 1: 74 | V_qe = np.zeros_like(V, dtype=np.float16) 75 | for i in range(all_num): 76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 77 | V = V_qe 78 | del V_qe 79 | del initial_rank 80 | invIndex = [] 81 | for i in range(gallery_num): 82 | invIndex.append(np.where(V[:, i] != 0)[0]) 83 | 84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 85 | 86 | for i in range(query_num): 87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 88 | indNonZero = 
np.where(V[i, :] != 0)[0] 89 | indImages = [invIndex[ind] for ind in indNonZero] 90 | for j in range(len(indNonZero)): 91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 92 | V[indImages[j], indNonZero[j]]) 93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 94 | 95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 96 | del original_dist 97 | del V 98 | del jaccard_dist 99 | final_dist = final_dist[:query_num, query_num:] 100 | return final_dist 101 | 102 | -------------------------------------------------------------------------------- /utils/reid_metric.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | import torch 4 | from ignite.metrics import Metric 5 | 6 | from data.datasets.eval_reid import eval_func 7 | from .re_ranking import re_ranking 8 | 9 | 10 | class R1_mAP(Metric): 11 | def __init__(self, num_query, max_rank=50, feat_norm='yes', new_eval=False): 12 | super(R1_mAP, self).__init__() 13 | self.num_query = num_query 14 | self.max_rank = max_rank 15 | self.feat_norm = feat_norm 16 | self.new_eval = new_eval 17 | 18 | def reset(self): 19 | self.feats = [] 20 | self.pids = [] 21 | self.camids = [] 22 | if self.new_eval: 23 | self.ambis = [] 24 | 25 | def update(self, output): 26 | if self.new_eval: 27 | feat, pid, ambi, camid = output 28 | self.ambis.extend(np.asarray(ambi)) 29 | else: 30 | feat, pid, camid = output 31 | self.feats.append(feat) 32 | self.pids.extend(np.asarray(pid)) 33 | self.camids.extend(np.asarray(camid)) 34 | 35 | def compute(self): 36 | feats = torch.cat(self.feats, dim=0) 37 | if self.feat_norm == 'yes': 38 | print("The test features are normalized") 39 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 40 | # query 41 | qf = feats[:self.num_query] 42 | q_pids = np.asarray(self.pids[:self.num_query]) 43 | q_camids = np.asarray(self.camids[:self.num_query]) 44 | if self.new_eval: 45 | q_ambis = np.asarray(self.ambis[:self.num_query]) 46 | else: 47 | q_ambis = None 48 | 49 | # gallery 50 | gf = feats[self.num_query:] 51 | g_pids = np.asarray(self.pids[self.num_query:]) 52 | g_camids = np.asarray(self.camids[self.num_query:]) 53 | if self.new_eval: 54 | g_ambis = np.asarray(self.ambis[self.num_query:]) 55 | else: 56 | g_ambis = None 57 | 58 | m, n = qf.shape[0], gf.shape[0] 59 | 60 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 61 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 62 | # distmat.addmm_(qf, gf.t(), beta=1, alpha=-2) 63 | distmat = -1 * torch.mm(qf, gf.t()) # negative inner product as distance (negative cosine similarity when feat_norm is 'yes') 64 | distmat = distmat.cpu().numpy() 65 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids, q_ambis=q_ambis, g_ambis=g_ambis) 66 | 67 | return cmc, mAP 68 |
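`R1_mAP` follows ignite's `Metric` contract: `reset()`, one `update()` per batch (with all query tracklets ordered before the gallery in the loader), then `compute()`. A hand-driven sketch with random features; the shapes and labels are arbitrary toy values, and it assumes `eval_func` accepts such a small gallery:

```python
# Sketch: driving R1_mAP outside an ignite Engine; data is random toy input.
import torch
from utils.reid_metric import R1_mAP

num_query = 4
metric = R1_mAP(num_query, max_rank=5, feat_norm='yes')
metric.reset()

feats = torch.randn(10, 2048)           # 4 query + 6 gallery embeddings
pids = [0, 1, 2, 3, 0, 1, 2, 3, 4, 5]   # query identities recur in the gallery
camids = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1] # different cameras, so matches are valid
metric.update((feats, pids, camids))    # normally called once per batch

cmc, mAP = metric.compute()             # cosine-style ranking + eval_func
```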
69 | # Note: the new evaluation protocol (new_eval) is not implemented for re-ranking 70 | class R1_mAP_reranking(Metric): 71 | def __init__(self, num_query, max_rank=50, feat_norm='yes'): 72 | super(R1_mAP_reranking, self).__init__() 73 | self.num_query = num_query 74 | self.max_rank = max_rank 75 | self.feat_norm = feat_norm 76 | 77 | def reset(self): 78 | self.feats = [] 79 | self.pids = [] 80 | self.camids = [] 81 | 82 | def update(self, output): 83 | feat, pid, camid = output 84 | self.feats.append(feat) 85 | self.pids.extend(np.asarray(pid)) 86 | self.camids.extend(np.asarray(camid)) 87 | 88 | def compute(self): 89 | feats = torch.cat(self.feats, dim=0) 90 | if self.feat_norm == 'yes': 91 | print("The test features are normalized") 92 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 93 | 94 | # query 95 | qf = feats[:self.num_query] 96 | q_pids = np.asarray(self.pids[:self.num_query]) 97 | q_camids = np.asarray(self.camids[:self.num_query]) 98 | # gallery 99 | gf = feats[self.num_query:] 100 | g_pids = np.asarray(self.pids[self.num_query:]) 101 | g_camids = np.asarray(self.camids[self.num_query:]) 102 | # m, n = qf.shape[0], gf.shape[0] 103 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 104 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 105 | # distmat.addmm_(1, -2, qf, gf.t()) 106 | # distmat = distmat.cpu().numpy() 107 | print("Entering re-ranking") 108 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 109 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 110 | 111 | return cmc, mAP --------------------------------------------------------------------------------
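To close, a standalone usage sketch for `re_ranking` with the paper's default parameters; the feature tensors are random placeholders:

```python
# Sketch: k-reciprocal re-ranking on dummy features (paper defaults: k1=20, k2=6, lambda=0.3).
import torch
from utils.re_ranking import re_ranking

qf = torch.randn(8, 256)     # query features
gf = torch.randn(32, 256)    # gallery features

dist = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
print(dist.shape)            # (8, 32): re-ranked query-to-gallery distances
```

Lower values mean more similar; `R1_mAP_reranking` above feeds this matrix directly into `eval_func`.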