├── .gitignore ├── DATASETS.md ├── LICENSE ├── README.md ├── requirements.txt ├── scripts ├── test_vidreid_xent_htri_vmgn_dukev.sh ├── test_vidreid_xent_htri_vmgn_ilidsvid.sh ├── test_vidreid_xent_htri_vmgn_mars.sh ├── test_vidreid_xent_htri_vmgn_prid2011.sh ├── train_vidreid_xent_htri_vmgn_dukev.sh ├── train_vidreid_xent_htri_vmgn_ilidsvid.sh ├── train_vidreid_xent_htri_vmgn_mars.sh └── train_vidreid_xent_htri_vmgn_prid2011.sh ├── torchreid ├── __init__.py ├── data_manager │ ├── __init__.py │ ├── dukemtmcvidreid.py │ ├── ilidsvid.py │ ├── mars.py │ └── prid2011.py ├── dataset_loader.py ├── losses │ ├── __init__.py │ ├── cross_entropy_loss.py │ └── hard_mine_triplet_loss.py ├── lr_scheduler.py ├── metrics │ ├── __init__.py │ ├── accuracy.py │ ├── distance.py │ ├── rank.py │ └── rank_cylib │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── rank_cy.c │ │ ├── rank_cy.pyx │ │ ├── setup.py │ │ └── test_cython.py ├── models │ ├── __init__.py │ ├── ganet.py │ ├── graphnet.py │ ├── gsta.py │ ├── res50tp.py │ ├── resnet.py │ ├── resnet3d.py │ ├── resnet3dt.py │ ├── resnet50_s1.py │ ├── resnet_temporal.py │ ├── simple_sta.py │ ├── sta.py │ └── vmgn.py ├── optimizers.py ├── samplers.py ├── transforms.py └── utils │ ├── __init__.py │ ├── avgmeter.py │ ├── iotools.py │ ├── logger.py │ ├── model_complexity.py │ ├── re_ranking.py │ ├── reidtools.py │ └── torchtools.py └── train_vidreid_xent_htri.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | data/ 3 | log/ 4 | saved-models/ 5 | torchreid/eval_lib/eval.c 6 | .idea 7 | pretrained 8 | 9 | # OS X 10 | .DS_Store 11 | .Spotlight-V100 12 | .Trashes 13 | ._* 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | .static_storage/ 71 | .media/ 72 | local_settings.py 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | ## How to prepare data 2 | 3 | Create a directory to store reid datasets under this repo via 4 | ```bash 5 | cd AGRL.pytorch/ 6 | mkdir data/ 7 | ``` 8 | 9 | If you wanna store datasets in another directory, you need to specify `--root path_to_your/data` when running the training code. Please follow the instructions below to prepare each dataset. After that, you can simply do `-d the_dataset` when running the training code. 10 | 11 | Please do not call image dataset when running video reid scripts, otherwise error would occur, and vice versa. 12 | 13 | ### Video ReID 14 | 15 | **MARS** [8]: 16 | 1. Create a directory named `mars/` under `data/`. 17 | 2. Download dataset to `data/mars/` from http://www.liangzheng.com.cn/Project/project_mars.html. 18 | 3. Extract `bbox_train.zip` and `bbox_test.zip`. 19 | 4. Download split information from https://github.com/liangzheng06/MARS-evaluation/tree/master/info and put `info/` in `data/mars` (we want to follow the standard split in [8]). The data structure would look like: 20 | ``` 21 | mars/ 22 | bbox_test/ 23 | bbox_train/ 24 | info/ 25 | ``` 26 | 5. Use `-d mars` when running the training code. 27 | 28 | **iLIDS-VID** [11]: 29 | 1. The code supports automatic download and formatting. Simple use `-d ilidsvid` when running the training code. The data structure would look like: 30 | ``` 31 | ilids-vid/ 32 | i-LIDS-VID/ 33 | train-test people splits/ 34 | splits.json 35 | ``` 36 | 37 | **PRID** [12]: 38 | 1. Under `data/`, do `mkdir prid2011` to create a directory. 39 | 2. Download dataset from https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/ and extract it under `data/prid2011`. 40 | 3. Download the split created by [iLIDS-VID](http://www.eecs.qmul.ac.uk/~xiatian/downloads_qmul_iLIDS-VID_ReID_dataset.html) from [here](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/datasets/prid2011/splits_prid2011.json), and put it in `data/prid2011/`. We follow [11] and use 178 persons whose sequences are more than a threshold so that results on this dataset can be fairly compared with other approaches. 
The data structure would look like: 41 | ``` 42 | prid2011/ 43 | splits_prid2011.json 44 | prid_2011/ 45 | multi_shot/ 46 | single_shot/ 47 | readme.txt 48 | ``` 49 | 4. Use `-d prid2011` when running the training code. 50 | 51 | **DukeMTMC-VideoReID** [16, 23]: 52 | 1. Use `-d dukemtmcvidreid` directly. 53 | 2. If you wanna download the dataset manually, get `DukeMTMC-VideoReID.zip` from https://github.com/Yu-Wu/DukeMTMC-VideoReID. Unzip the file to `data/dukemtmc-vidreid`. Ultimately, you need to have 54 | ``` 55 | dukemtmc-vidreid/ 56 | DukeMTMC-VideoReID/ 57 | train/ # essential 58 | query/ # essential 59 | gallery/ # essential 60 | ... (and license files) 61 | ``` 62 | 63 | 64 | ## Dataset loaders 65 | These are implemented in `dataset_loader.py` where we have two main classes that subclass [torch.utils.data.Dataset](http://pytorch.org/docs/master/_modules/torch/utils/data/dataset.html#Dataset): 66 | * [VideoDataset](https://github.com/KaiyangZhou/deep-person-reid/blob/master/dataset_loader.py#L38): processes video-based person reid datasets. 67 | 68 | These two classes are used for [torch.utils.data.DataLoader](http://pytorch.org/docs/master/_modules/torch/utils/data/dataloader.html#DataLoader) that can provide batched data. Data loader with `VideoDataset` outputs batch data of `(batch, sequence, channel, height, width)`. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yiming Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | This is the PyTorch implementation for **Adaptive Graph Representation Learning for Video Person Re-identification**. 
3 | 4 | ## Get started 5 | ```shell script 6 | git clone https://github.com/weleen/AGRL.pytorch /path/to/save 7 | pip install -r requirements.txt 8 | cd torchreid/metrics/rank_cylib && make 9 | ``` 10 | 11 | ## Dataset 12 | Create the dataset directory: 13 | ```shell script 14 | mkdir data 15 | ``` 16 | Prepare datasets: 17 | ```shell script 18 | ├── dukemtmc-vidreid 19 | │ ├── DukeMTMC-VideoReID 20 | │ ├── pose.json 21 | │ ├── split_gallery.json 22 | │ ├── split_query.json 23 | │ └── split_train.json 24 | │ 25 | ├── ilids-vid 26 | │ ├── i-LIDS-VID 27 | │ ├── pose.json 28 | │ ├── splits.json 29 | │ └── train-test people splits 30 | │ 31 | ├── mars 32 | │ ├── bbox_test 33 | │ ├── bbox_train 34 | │ ├── info 35 | │ ├── pose.json 36 | │ └── train-test people splits 37 | │ 38 | ├── prid2011 39 | ├── pose.json 40 | ├── prid_2011 41 | ├── prid_2011.zip 42 | ├── splits_prid2011.json 43 | └── train_test_splits_prid.mat 44 | ``` 45 | `pose.json` is obtained by running [AlphaPose](https://github.com/MVIG-SJTU/AlphaPose). We put the files on [Baidu Netdisk](https://pan.baidu.com/s/1RduGEbq-tmfLAHM0k3xa4A) (code: luxr) and 46 | [Google Drive](https://drive.google.com/drive/folders/1BVEjMava3UQh4jC2bp-tcFo1rOZDB8MS?usp=sharing) (only pose.json; please download the model weights from Baidu Netdisk). 47 | 48 | More details can be found in [DATASETS.md](DATASETS.md). 49 | 50 | 51 | ## Train 52 | ```shell script 53 | bash scripts/train_vidreid_xent_htri_vmgn_mars.sh 54 | ``` 55 | 56 | To use multiple GPUs, you can set `--gpu-devices 0,1,2,3`. 57 | 58 | **Note:** To resume training, use `--resume path/to/model` to load a checkpoint, from which the saved model weights and `start_epoch` will be restored. The learning rate needs to be re-initialized carefully in that case. If you just want to load a pretrained model while discarding layers that do not match in size (e.g. the classification layer), use `--load-weights path/to/model` instead. 59 | 60 | Please refer to the code for more details. 61 | 62 | 63 | ## Test 64 | Create a directory to store model weights (`mkdir saved-models/`) beforehand. Then run the following command to test: 65 | ```shell script 66 | bash scripts/test_vidreid_xent_htri_vmgn_mars.sh 67 | ``` 68 | All the model weights are available from the Baidu Netdisk link above. 69 | 70 | ## Model 71 | 72 | All results were obtained with 4 TITAN X GPUs and 64GB of memory. 73 | 74 | | Dataset | Rank-1 | mAP | 75 | | :---: | :---: | :---: | 76 | | iLIDS-VID | 83.7% | - | 77 | | PRID2011 | 93.1% | - | 78 | | MARS | 89.8% | 81.1% | 79 | | DukeMTMC-vidreid | 96.7% | 94.2% | 80 | 81 | 82 | ## Citation 83 | Please kindly cite this project in your paper if it is helpful😊: 84 | ``` 85 | @article{wu2020adaptive, 86 | title={Adaptive graph representation learning for video person re-identification}, 87 | author={Wu, Yiming and Bourahla, Omar El Farouk and Li, Xi and Wu, Fei and Tian, Qi and Zhou, Xue}, 88 | journal={IEEE Transactions on Image Processing}, 89 | year={2020}, 90 | publisher={IEEE} 91 | } 92 | ``` 93 | 94 | This project is developed based on [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid) and [STE-NVAN](https://github.com/jackie840129/STE-NVAN/).
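As a follow-up to the `--load-weights` note in the Train section: partial loading of this kind typically works by copying only the checkpoint entries whose names and shapes match the current model and skipping the rest (usually a classifier sized for a different number of identities). The sketch below is illustrative only — `load_matching_weights` is a made-up name, not the exact function this repo uses:

```python
import torch


def load_matching_weights(model, weight_path):
    """Keep only the checkpoint entries whose name and shape match `model`."""
    checkpoint = torch.load(weight_path, map_location='cpu')
    # some checkpoints wrap the weights under a 'state_dict' key
    state_dict = checkpoint.get('state_dict', checkpoint)
    model_dict = model.state_dict()
    matched = {k: v for k, v in state_dict.items()
               if k in model_dict and v.size() == model_dict[k].size()}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    # return the names that were discarded, e.g. the classification layer
    return sorted(set(state_dict) - set(matched))
```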
95 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.28.3 2 | h5py==2.8.0 3 | numpy==1.14.5 4 | Pillow>=7.1.0 5 | scipy>=1.0.0 6 | torch==1.2.0 7 | torchvision==0.2.1 8 | tensorboardX -------------------------------------------------------------------------------- /scripts/test_vidreid_xent_htri_vmgn_dukev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_vidreid_xent_htri.py -d dukemtmcvidreid \ 3 | -a vmgn \ 4 | --evaluate \ 5 | --seq-len 8 \ 6 | --test-sample evenly \ 7 | --num-split 4 \ 8 | --pyramid-part \ 9 | --num-gb 2 \ 10 | --use-pose \ 11 | --learn-graph \ 12 | --gpu-devices 0 \ 13 | --dist-metric cosine \ 14 | --load-weights saved-models/dukemtmc-vidreid/model_dukev.pth.tar \ 15 | --save-dir log/dukev 16 | -------------------------------------------------------------------------------- /scripts/test_vidreid_xent_htri_vmgn_ilidsvid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | i=0 3 | while((i<10)) 4 | do 5 | python train_vidreid_xent_htri.py -d ilidsvid \ 6 | --evaluate \ 7 | --seq-len 8 \ 8 | --test-batch 16 \ 9 | --test-sample evenly \ 10 | -a vmgn \ 11 | --num-split 4 \ 12 | --pyramid-part \ 13 | --num-gb 2 \ 14 | --use-pose \ 15 | --learn-graph \ 16 | --gpu-devices 0 \ 17 | --dist-metric cosine \ 18 | --split-id $i \ 19 | --load-weights saved-models/ilidsvid/split"$i"/model_ilidsvid.pth.tar \ 20 | --save-dir log/ilidsvid/split"$i" 21 | let i=$i+1 22 | done 23 | -------------------------------------------------------------------------------- /scripts/test_vidreid_xent_htri_vmgn_mars.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_vidreid_xent_htri.py -d mars \ 3 | -a vmgn \ 4 | --evaluate \ 5 | --seq-len 8 \ 6 | --test-sample evenly \ 7 | --num-split 4 \ 8 | --pyramid-part \ 9 | --num-gb 2 \ 10 | --use-pose \ 11 | --learn-graph \ 12 | --gpu-devices 0 \ 13 | --dist-metric cosine \ 14 | --load-weights saved-models/mars/model_mars.pth.tar \ 15 | --save-dir log/mars 16 | -------------------------------------------------------------------------------- /scripts/test_vidreid_xent_htri_vmgn_prid2011.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | i=0 3 | while((i<10)) 4 | do 5 | python train_vidreid_xent_htri.py -d prid2011 \ 6 | --evaluate \ 7 | --seq-len 8 \ 8 | --test-batch 16 \ 9 | --test-sample evenly \ 10 | -a vmgn \ 11 | --num-split 4 \ 12 | --pyramid-part \ 13 | --num-gb 2 \ 14 | --use-pose \ 15 | --learn-graph \ 16 | --gpu-devices 0 \ 17 | --dist-metric cosine \ 18 | --split-id $i \ 19 | --load-weights saved-models/prid2011/split"$i"/model_prid2011.pth.tar \ 20 | --save-dir log/prid2011/split"$i" 21 | let i=$i+1 22 | done 23 | -------------------------------------------------------------------------------- /scripts/train_vidreid_xent_htri_vmgn_dukev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_vidreid_xent_htri.py -d dukemtmcvidreid \ 3 | -a vmgn \ 4 | --seq-len 8 \ 5 | --train-batch 16 \ 6 | --test-batch 16 \ 7 | --num-instances 4 \ 8 | --train-sample restricted \ 9 | --train-sampler RandomIdentitySamplerV1 \ 10 | --test-sample evenly \ 11 | --optim adam \ 12 | --soft-margin \ 13 | --lr 1e-4 \ 14 | 
--max-epoch 200 \ 15 | --stepsize 50 100 150 \ 16 | --num-split 4 \ 17 | --pyramid-part \ 18 | --num-gb 2 \ 19 | --use-pose \ 20 | --learn-graph \ 21 | --flip-aug \ 22 | --gpu-devices 0 \ 23 | --eval-step 5 \ 24 | --print-last \ 25 | --dist-metric cosine \ 26 | --consistent-loss \ 27 | --save-dir log/video/vmgn/duke-ngb2-consistent -------------------------------------------------------------------------------- /scripts/train_vidreid_xent_htri_vmgn_ilidsvid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | i=0 3 | while((i<10)) 4 | do 5 | python train_vidreid_xent_htri.py -d ilidsvid \ 6 | --seq-len 8 \ 7 | --train-batch 16 \ 8 | --test-batch 16 \ 9 | --num-instances 4 \ 10 | --train-sample restricted \ 11 | --test-sample evenly \ 12 | --train-sampler RandomIdentitySamplerV1 \ 13 | --optim adam \ 14 | --soft-margin \ 15 | --max-epoch 400 \ 16 | --lr 1e-4 \ 17 | --stepsize 100 200 300 \ 18 | -a vmgn \ 19 | --num-split 4 \ 20 | --pyramid-part \ 21 | --num-gb 2 \ 22 | --use-pose \ 23 | --learn-graph \ 24 | --flip-aug \ 25 | --print-last \ 26 | --gpu-devices 0 \ 27 | --eval-step 1 \ 28 | --dist-metric cosine \ 29 | --consistent-loss \ 30 | --split-id $i \ 31 | --save-dir log/video/vmgn/ilidsvid-ngb2-consistent/split"$i" 32 | let i=$i+1 33 | done 34 | -------------------------------------------------------------------------------- /scripts/train_vidreid_xent_htri_vmgn_mars.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_vidreid_xent_htri.py -d mars \ 3 | -a vmgn \ 4 | --seq-len 8 \ 5 | --train-batch 16 \ 6 | --test-batch 16 \ 7 | --num-instances 4 \ 8 | --train-sample restricted \ 9 | --train-sampler RandomIdentitySamplerV1 \ 10 | --test-sample evenly \ 11 | --optim adam \ 12 | --soft-margin \ 13 | --lr 1e-4 \ 14 | --max-epoch 200 \ 15 | --stepsize 50 100 150 \ 16 | --num-split 4 \ 17 | --pyramid-part \ 18 | --num-gb 2 \ 19 | --use-pose \ 20 | --learn-graph \ 21 | --flip-aug \ 22 | --gpu-devices 0 \ 23 | --eval-step 5 \ 24 | --print-last \ 25 | --dist-metric cosine \ 26 | --consistent-loss \ 27 | --save-dir log/video/vmgn/mars-ngb2-consistent 28 | -------------------------------------------------------------------------------- /scripts/train_vidreid_xent_htri_vmgn_prid2011.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | i=0 3 | while((i<10)) 4 | do 5 | python train_vidreid_xent_htri.py -d prid2011 \ 6 | --seq-len 8 \ 7 | --train-batch 16 \ 8 | --test-batch 16 \ 9 | --num-instances 4 \ 10 | --train-sample restricted \ 11 | --test-sample evenly \ 12 | --train-sampler RandomIdentitySamplerV1 \ 13 | --optim adam \ 14 | --soft-margin \ 15 | --max-epoch 400 \ 16 | --lr 1e-4 \ 17 | --stepsize 100 200 300 \ 18 | -a vmgn \ 19 | --num-split 4 \ 20 | --pyramid-part \ 21 | --num-gb 2 \ 22 | --use-pose \ 23 | --learn-graph \ 24 | --flip-aug \ 25 | --print-last \ 26 | --gpu-devices 0 \ 27 | --eval-step 1 \ 28 | --dist-metric cosine \ 29 | --split-id $i \ 30 | --consistent-loss \ 31 | --save-dir log/video/vmgn/prid2011-ngb2-consistent/split"$i" 32 | let i=$i+1 33 | done 34 | -------------------------------------------------------------------------------- /torchreid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weleen/AGRL.pytorch/83a37eea19365bc8e74c3b3c5fe5ad5d00d04f2c/torchreid/__init__.py 
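The shell scripts above all pass their `-d` argument to the dataset factory defined in `torchreid/data_manager/__init__.py` (next file). A minimal usage sketch, assuming the chosen dataset has already been prepared under `data/` and that `pose.json` is present when poses are requested:

```python
from torchreid.data_manager import get_names, init_vidreid_dataset

print(get_names())  # ['mars', 'ilidsvid', 'prid2011', 'dukemtmcvidreid']

# 'root' and 'use_pose' mirror the --root and --use-pose command-line options;
# an unknown name raises a KeyError listing the valid choices.
dataset = init_vidreid_dataset(name='mars', root='data', use_pose=True)
print(dataset.num_train_pids)                    # e.g. 625 training identities on MARS
print(len(dataset.query), len(dataset.gallery))  # lists of (img_paths, pid, camid) tuples
```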
-------------------------------------------------------------------------------- /torchreid/data_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .mars import Mars 6 | from .ilidsvid import iLIDSVID 7 | from .prid2011 import PRID2011 8 | from .dukemtmcvidreid import DukeMTMCVidReID 9 | 10 | __vidreid_factory = { 11 | 'mars': Mars, 12 | 'ilidsvid': iLIDSVID, 13 | 'prid2011': PRID2011, 14 | 'dukemtmcvidreid': DukeMTMCVidReID, 15 | } 16 | 17 | 18 | def get_names(): 19 | return list(__vidreid_factory.keys()) 20 | 21 | 22 | def init_vidreid_dataset(name, **kwargs): 23 | if name not in list(__vidreid_factory.keys()): 24 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, list(__vidreid_factory.keys()))) 25 | return __vidreid_factory[name](**kwargs) -------------------------------------------------------------------------------- /torchreid/data_manager/dukemtmcvidreid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import glob 7 | import re 8 | import sys 9 | import urllib 10 | import tarfile 11 | import zipfile 12 | import os.path as osp 13 | from scipy.io import loadmat 14 | import numpy as np 15 | import json 16 | import h5py 17 | from scipy.misc import imsave 18 | 19 | from torchreid.utils.iotools import mkdir_if_missing, write_json, read_json 20 | 21 | 22 | class DukeMTMCVidReID(object): 23 | """ 24 | DukeMTMCVidReID 25 | 26 | Reference: 27 | Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person 28 | Re-Identification by Stepwise Learning. CVPR 2018. 
29 | 30 | URL: https://github.com/Yu-Wu/DukeMTMC-VideoReID 31 | 32 | Dataset statistics: 33 | # identities: 702 (train) + 702 (test) 34 | # tracklets: 2196 (train) + 2636 (test) 35 | """ 36 | dataset_dir = 'dukemtmc-vidreid' 37 | 38 | def __init__(self, root='data', min_seq_len=0, verbose=True, **kwargs): 39 | self.dataset_dir = osp.join(root, self.dataset_dir) 40 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip' 41 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/train') 42 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/query') 43 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/gallery') 44 | self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json') 45 | self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json') 46 | self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json') 47 | self.pose_file = osp.join(self.dataset_dir, 'pose.json') 48 | 49 | self.min_seq_len = min_seq_len 50 | self._download_data() 51 | self._check_before_run() 52 | print("Note: if root path is changed, the previously generated json files need to be re-generated (so delete them first)") 53 | 54 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 55 | self._process_dir(self.train_dir, self.split_train_json_path, relabel=True) 56 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 57 | self._process_dir(self.query_dir, self.split_query_json_path, relabel=False) 58 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 59 | self._process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) 60 | 61 | if 'use_pose' in kwargs and kwargs['use_pose']: 62 | # process the pose information 63 | with open(self.pose_file, 'r') as f: 64 | self.poses = json.load(f) 65 | self.process_poses = dict() 66 | for key in self.poses: 67 | # save only one body 68 | maxidx = -1 69 | maxarea = -1 70 | maxscore = -1 71 | assert len(self.poses[key]['bodies']) >= 1, 'pose of {} is empty'.format(key) 72 | if len(self.poses[key]['bodies']) == 1: 73 | self.process_poses[key] = np.array(self.poses[key]['bodies'][0]['joints']).reshape((-1, 3)) 74 | else: 75 | for idx, body in enumerate(self.poses[key]['bodies']): 76 | tmp_kps = np.array(body['joints']).reshape((-1, 3)) 77 | tmp_area = (max(tmp_kps[:, 0]) - min(tmp_kps[:, 0])) * (max(tmp_kps[:, 1]) - min(tmp_kps[:, 1])) 78 | tmp_score = body['score'] 79 | if tmp_score > maxscore: 80 | if tmp_area > maxarea and tmp_score > 1.1 * maxscore: 81 | maxscore = tmp_score 82 | maxidx = idx 83 | self.process_poses[key] = np.array(self.poses[key]['bodies'][maxidx]['joints']).reshape((-1, 3)) 84 | else: 85 | self.process_poses = dict() 86 | 87 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 88 | min_num = np.min(num_imgs_per_tracklet) 89 | max_num = np.max(num_imgs_per_tracklet) 90 | avg_num = np.mean(num_imgs_per_tracklet) 91 | 92 | num_total_pids = num_train_pids + num_query_pids 93 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 94 | 95 | if verbose: 96 | print("=> DukeMTMC-VideoReID loaded") 97 | print("Dataset statistics:") 98 | print(" ------------------------------") 99 | print(" subset | # ids | # tracklets") 100 | print(" ------------------------------") 101 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 102 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 103 | print(" 
gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 104 | print(" ------------------------------") 105 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 106 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 107 | print(" ------------------------------") 108 | 109 | self.train = train 110 | self.query = query 111 | self.gallery = gallery 112 | 113 | self.num_train_pids = num_train_pids 114 | self.num_query_pids = num_query_pids 115 | self.num_gallery_pids = num_gallery_pids 116 | 117 | def _download_data(self): 118 | if osp.exists(self.dataset_dir): 119 | print("This dataset has been downloaded.") 120 | return 121 | 122 | print("Creating directory {}".format(self.dataset_dir)) 123 | mkdir_if_missing(self.dataset_dir) 124 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 125 | 126 | print("Downloading DukeMTMC-VideoReID dataset") 127 | urllib.urlretrieve(self.dataset_url, fpath) 128 | 129 | print("Extracting files") 130 | zip_ref = zipfile.ZipFile(fpath, 'r') 131 | zip_ref.extractall(self.dataset_dir) 132 | zip_ref.close() 133 | 134 | def _check_before_run(self): 135 | """Check if all files are available before going deeper""" 136 | if not osp.exists(self.dataset_dir): 137 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 138 | if not osp.exists(self.train_dir): 139 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 140 | if not osp.exists(self.query_dir): 141 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 142 | if not osp.exists(self.gallery_dir): 143 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 144 | 145 | def _process_dir(self, dir_path, json_path, relabel): 146 | if osp.exists(json_path): 147 | print("=> {} generated before, awesome!".format(json_path)) 148 | split = read_json(json_path) 149 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet'] 150 | 151 | print("=> Automatically generating split (might take a while for the first time, have a coffe)") 152 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store 153 | print("Processing {} with {} person identities".format(dir_path, len(pdirs))) 154 | 155 | pid_container = set() 156 | for pdir in pdirs: 157 | pid = int(osp.basename(pdir)) 158 | pid_container.add(pid) 159 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 160 | 161 | tracklets = [] 162 | num_imgs_per_tracklet = [] 163 | for pdir in pdirs: 164 | pid = int(osp.basename(pdir)) 165 | if relabel: pid = pid2label[pid] 166 | tdirs = glob.glob(osp.join(pdir, '*')) 167 | for tdir in tdirs: 168 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg')) 169 | num_imgs = len(raw_img_paths) 170 | 171 | if num_imgs < self.min_seq_len: 172 | continue 173 | 174 | num_imgs_per_tracklet.append(num_imgs) 175 | img_paths = [] 176 | for img_idx in range(num_imgs): 177 | # some tracklet starts from 0002 instead of 0001 178 | img_idx_name = 'F' + str(img_idx+1).zfill(4) 179 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')) 180 | if len(res) == 0: 181 | print("Warn: index name {} in {} is missing, jump to next".format(img_idx_name, tdir)) 182 | continue 183 | img_paths.append(res[0]) 184 | img_name = osp.basename(img_paths[0]) 185 | if img_name.find('_') == -1: 186 | # old naming format: 0001C6F0099X30823.jpg 187 | camid = int(img_name[5]) - 1 188 | else: 189 | # new naming format: 0001_C6_F0099_X30823.jpg 190 | 
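# note (added for clarity): the camera digit always follows 'C' -- index 5 in the
# old '0001C6F0099X30823.jpg' style and index 6 in the underscore-separated
# '0001_C6_F0099_X30823.jpg' style; both branches convert it to a 0-based camid.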
camid = int(img_name[6]) - 1 191 | img_paths = tuple(img_paths) 192 | tracklets.append((img_paths, pid, camid)) 193 | 194 | num_pids = len(pid_container) 195 | num_tracklets = len(tracklets) 196 | 197 | print("Saving split to {}".format(json_path)) 198 | split_dict = { 199 | 'tracklets': tracklets, 200 | 'num_tracklets': num_tracklets, 201 | 'num_pids': num_pids, 202 | 'num_imgs_per_tracklet': num_imgs_per_tracklet, 203 | } 204 | write_json(split_dict, json_path) 205 | 206 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 207 | -------------------------------------------------------------------------------- /torchreid/data_manager/ilidsvid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import glob 7 | import re 8 | import sys 9 | import urllib 10 | import tarfile 11 | import zipfile 12 | import os.path as osp 13 | from scipy.io import loadmat 14 | import numpy as np 15 | import h5py 16 | from scipy.misc import imsave 17 | import json 18 | 19 | from torchreid.utils.iotools import mkdir_if_missing, write_json, read_json 20 | 21 | 22 | class iLIDSVID(object): 23 | """ 24 | iLIDS-VID 25 | 26 | Reference: 27 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 28 | 29 | URL: http://www.eecs.qmul.ac.uk/~xiatian/downloads_qmul_iLIDS-VID_ReID_dataset.html 30 | 31 | Dataset statistics: 32 | # identities: 300 33 | # tracklets: 600 34 | # cameras: 2 35 | """ 36 | dataset_dir = 'ilids-vid' 37 | 38 | def __init__(self, root='data', split_id=0, verbose=True, **kwargs): 39 | self.dataset_dir = osp.join(root, self.dataset_dir) 40 | self.dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 41 | self.data_dir = osp.join(self.dataset_dir, 'i-LIDS-VID') 42 | self.split_dir = osp.join(self.dataset_dir, 'train-test people splits') 43 | self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat') 44 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 45 | self.cam_1_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam1') 46 | self.cam_2_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam2') 47 | self.pose_file = osp.join(self.dataset_dir, 'pose.json') 48 | 49 | self._download_data() 50 | self._check_before_run() 51 | with open(self.pose_file, 'r') as f: 52 | self.poses = json.load(f) 53 | # process the pose information 54 | self.process_poses = dict() 55 | for key in self.poses: 56 | # save only one body 57 | maxidx = -1 58 | maxarea = -1 59 | maxscore = -1 60 | assert len(self.poses[key]['bodies']) >= 1, 'pose of {} is empty'.format(key) 61 | if len(self.poses[key]['bodies']) == 1: 62 | self.process_poses[key] = np.array(self.poses[key]['bodies'][0]['joints']).reshape((-1, 3)) 63 | else: 64 | for idx, body in enumerate(self.poses[key]['bodies']): 65 | tmp_kps = np.array(body['joints']).reshape((-1, 3)) 66 | tmp_area = (max(tmp_kps[:, 0]) - min(tmp_kps[:, 0])) * (max(tmp_kps[:, 1]) - min(tmp_kps[:, 1])) 67 | tmp_score = body['score'] 68 | if tmp_score > maxscore: 69 | if tmp_area > maxarea and tmp_score > 1.1 * maxscore: 70 | maxscore = tmp_score 71 | maxidx = idx 72 | self.process_poses[key] = np.array(self.poses[key]['bodies'][maxidx]['joints']).reshape((-1, 3)) 73 | 74 | self._prepare_split() 75 | splits = read_json(self.split_path) 76 | if split_id >= len(splits): 77 | raise ValueError("split_id exceeds range, received {}, but 
expected between 0 and {}".format(split_id, len(splits)-1)) 78 | split = splits[split_id] 79 | train_dirs, test_dirs = split['train'], split['test'] 80 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 81 | 82 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 83 | self._process_data(train_dirs, cam1=True, cam2=True) 84 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 85 | self._process_data(test_dirs, cam1=True, cam2=False) 86 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 87 | self._process_data(test_dirs, cam1=False, cam2=True) 88 | 89 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 90 | min_num = np.min(num_imgs_per_tracklet) 91 | max_num = np.max(num_imgs_per_tracklet) 92 | avg_num = np.mean(num_imgs_per_tracklet) 93 | 94 | num_total_pids = num_train_pids + num_query_pids 95 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 96 | 97 | if verbose: 98 | print("=> iLIDS-VID loaded") 99 | print("Dataset statistics:") 100 | print(" ------------------------------") 101 | print(" subset | # ids | # tracklets") 102 | print(" ------------------------------") 103 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 104 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 105 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 106 | print(" ------------------------------") 107 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 108 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 109 | print(" ------------------------------") 110 | 111 | self.train = train 112 | self.query = query 113 | self.gallery = gallery 114 | 115 | self.num_train_pids = num_train_pids 116 | self.num_query_pids = num_query_pids 117 | self.num_gallery_pids = num_gallery_pids 118 | 119 | def _download_data(self): 120 | if osp.exists(self.dataset_dir): 121 | print("This dataset has been downloaded.") 122 | return 123 | 124 | mkdir_if_missing(self.dataset_dir) 125 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 126 | 127 | print("Downloading iLIDS-VID dataset") 128 | urllib.urlretrieve(self.dataset_url, fpath) 129 | 130 | print("Extracting files") 131 | tar = tarfile.open(fpath) 132 | tar.extractall(path=self.dataset_dir) 133 | tar.close() 134 | 135 | def _check_before_run(self): 136 | """Check if all files are available before going deeper""" 137 | if not osp.exists(self.dataset_dir): 138 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 139 | if not osp.exists(self.data_dir): 140 | raise RuntimeError("'{}' is not available".format(self.data_dir)) 141 | if not osp.exists(self.split_dir): 142 | raise RuntimeError("'{}' is not available".format(self.split_dir)) 143 | 144 | def _prepare_split(self): 145 | if not osp.exists(self.split_path): 146 | print("Creating splits ...") 147 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 148 | 149 | num_splits = mat_split_data.shape[0] 150 | num_total_ids = mat_split_data.shape[1] 151 | assert num_splits == 10 152 | assert num_total_ids == 300 153 | num_ids_each = num_total_ids // 2 154 | 155 | # pids in mat_split_data are indices, so we need to transform them 156 | # to real pids 157 | person_cam1_dirs = sorted(glob.glob(osp.join(self.cam_1_path, '*'))) 158 | person_cam2_dirs = 
sorted(glob.glob(osp.join(self.cam_2_path, '*'))) 159 | 160 | person_cam1_dirs = [osp.basename(item) for item in person_cam1_dirs] 161 | person_cam2_dirs = [osp.basename(item) for item in person_cam2_dirs] 162 | 163 | # make sure persons in one camera view can be found in the other camera view 164 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 165 | 166 | splits = [] 167 | for i_split in range(num_splits): 168 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 169 | train_idxs = sorted(list(mat_split_data[i_split, num_ids_each:])) 170 | test_idxs = sorted(list(mat_split_data[i_split, :num_ids_each])) 171 | 172 | train_idxs = [int(i)-1 for i in train_idxs] 173 | test_idxs = [int(i)-1 for i in test_idxs] 174 | 175 | # transform pids to person dir names 176 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 177 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 178 | 179 | split = {'train': train_dirs, 'test': test_dirs} 180 | splits.append(split) 181 | 182 | print("Totally {} splits are created, following Wang et al. ECCV'14".format(len(splits))) 183 | print("Split file is saved to {}".format(self.split_path)) 184 | write_json(splits, self.split_path) 185 | 186 | print("Splits created") 187 | 188 | def _process_data(self, dirnames, cam1=True, cam2=True): 189 | tracklets = [] 190 | num_imgs_per_tracklet = [] 191 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 192 | 193 | for dirname in dirnames: 194 | if cam1: 195 | person_dir = osp.join(self.cam_1_path, dirname) 196 | img_names = sorted(glob.glob(osp.join(person_dir, '*.png'))) 197 | assert len(img_names) > 0 198 | img_names = tuple(img_names) 199 | pid = dirname2pid[dirname] 200 | tracklets.append((img_names, pid, 0)) 201 | num_imgs_per_tracklet.append(len(img_names)) 202 | 203 | if cam2: 204 | person_dir = osp.join(self.cam_2_path, dirname) 205 | img_names = sorted(glob.glob(osp.join(person_dir, '*.png'))) 206 | assert len(img_names) > 0 207 | img_names = tuple(img_names) 208 | pid = dirname2pid[dirname] 209 | tracklets.append((img_names, pid, 1)) 210 | num_imgs_per_tracklet.append(len(img_names)) 211 | 212 | num_tracklets = len(tracklets) 213 | num_pids = len(dirnames) 214 | 215 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 216 | -------------------------------------------------------------------------------- /torchreid/data_manager/mars.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import glob 7 | import re 8 | import sys 9 | import urllib 10 | import tarfile 11 | import zipfile 12 | import os.path as osp 13 | from scipy.io import loadmat 14 | import numpy as np 15 | import h5py 16 | from scipy.misc import imsave 17 | import json 18 | 19 | 20 | class Mars(object): 21 | """ 22 | MARS 23 | 24 | Reference: 25 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 
26 | 27 | URL: http://www.liangzheng.com.cn/Project/project_mars.html 28 | 29 | Dataset statistics: 30 | # identities: 1261 31 | # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery) 32 | # cameras: 6 33 | """ 34 | dataset_dir = 'mars' 35 | 36 | def __init__(self, root='data', min_seq_len=0, verbose=True, **kwargs): 37 | self.dataset_dir = osp.join(root, self.dataset_dir) 38 | self.train_name_path = osp.join(self.dataset_dir, 'info/train_name.txt') 39 | self.test_name_path = osp.join(self.dataset_dir, 'info/test_name.txt') 40 | self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat') 41 | self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat') 42 | self.query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat') 43 | self.pose_file = osp.join(self.dataset_dir, 'pose.json') 44 | 45 | self._check_before_run() 46 | if 'use_pose' in kwargs and kwargs['use_pose']: 47 | with open(self.pose_file, 'r') as f: 48 | self.poses = json.load(f) 49 | # process the pose information 50 | self.process_poses = dict() 51 | for key in self.poses: 52 | # save only one body 53 | maxidx = -1 54 | maxarea = -1 55 | maxscore = -1 56 | assert len(self.poses[key]['bodies']) >= 1, 'pose of {} is empty'.format(key) 57 | if len(self.poses[key]['bodies']) == 1: 58 | self.process_poses[key] = np.array(self.poses[key]['bodies'][0]['joints']).reshape((-1, 3)) 59 | else: 60 | for idx, body in enumerate(self.poses[key]['bodies']): 61 | tmp_kps = np.array(body['joints']).reshape((-1, 3)) 62 | tmp_area = (max(tmp_kps[:, 0]) - min(tmp_kps[:, 0])) * (max(tmp_kps[:, 1]) - min(tmp_kps[:, 1])) 63 | tmp_score = body['score'] 64 | if tmp_score > maxscore: 65 | if tmp_area > maxarea and tmp_score > 1.1 * maxscore: 66 | maxscore = tmp_score 67 | maxidx = idx 68 | self.process_poses[key] = np.array(self.poses[key]['bodies'][maxidx]['joints']).reshape((-1, 3)) 69 | else: 70 | self.process_poses = dict() 71 | 72 | # prepare meta data 73 | train_names = self._get_names(self.train_name_path) 74 | test_names = self._get_names(self.test_name_path) 75 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 76 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 77 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 78 | query_IDX -= 1 # index from 0 79 | track_query = track_test[query_IDX,:] 80 | track_gallery = track_test 81 | 82 | train, num_train_tracklets, num_train_pids, num_train_imgs = \ 83 | self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 84 | 85 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 86 | self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 87 | 88 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 89 | self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 90 | 91 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs 92 | min_num = np.min(num_imgs_per_tracklet) 93 | max_num = np.max(num_imgs_per_tracklet) 94 | avg_num = np.mean(num_imgs_per_tracklet) 95 | 96 | num_total_pids = num_train_pids + num_gallery_pids 97 | num_total_tracklets = num_train_tracklets + num_gallery_tracklets 98 | 99 | if verbose: 100 | print("=> MARS loaded") 101 | print("Dataset statistics:") 102 | print(" ------------------------------") 103 | print(" subset | 
# ids | # tracklets") 104 | print(" ------------------------------") 105 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 106 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 107 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 108 | print(" ------------------------------") 109 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 110 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 111 | print(" ------------------------------") 112 | 113 | self.train = train 114 | self.query = query 115 | self.gallery = gallery 116 | 117 | self.num_train_pids = num_train_pids 118 | self.num_query_pids = num_query_pids 119 | self.num_gallery_pids = num_gallery_pids 120 | 121 | def _check_before_run(self): 122 | """Check if all files are available before going deeper""" 123 | if not osp.exists(self.dataset_dir): 124 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 125 | if not osp.exists(self.train_name_path): 126 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 127 | if not osp.exists(self.test_name_path): 128 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 129 | if not osp.exists(self.track_train_info_path): 130 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 131 | if not osp.exists(self.track_test_info_path): 132 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 133 | if not osp.exists(self.query_IDX_path): 134 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 135 | 136 | def _get_names(self, fpath): 137 | names = [] 138 | with open(fpath, 'r') as f: 139 | for line in f: 140 | new_line = line.rstrip() 141 | names.append(new_line) 142 | return names 143 | 144 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 145 | assert home_dir in ['bbox_train', 'bbox_test'] 146 | num_tracklets = meta_data.shape[0] 147 | pid_list = list(set(meta_data[:,2].tolist())) 148 | num_pids = len(pid_list) 149 | 150 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 151 | tracklets = [] 152 | num_imgs_per_tracklet = [] 153 | 154 | for tracklet_idx in range(num_tracklets): 155 | data = meta_data[tracklet_idx,...] 156 | start_index, end_index, pid, camid = data 157 | # if pid == -1: continue # junk images are just ignored 158 | assert 1 <= camid <= 6 159 | if relabel: pid = pid2label[pid] 160 | camid -= 1 # index starts from 0 161 | img_names = names[start_index-1:end_index] 162 | 163 | # make sure image names correspond to the same person 164 | pnames = [img_name[:4] for img_name in img_names] 165 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 166 | 167 | # make sure all images are captured under the same camera 168 | camnames = [img_name[5] for img_name in img_names] 169 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 
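# note (added for clarity): MARS frame names look like '0065C1T0002F0056.jpg' --
# the first four characters are the zero-padded person id checked above and the
# character at index 5 is the camera id (1-6), which is what the [:4] and [5]
# slices rely on.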
170 | 171 | # append image names with directory information 172 | img_paths = [osp.join(self.dataset_dir, home_dir, img_name[:4], img_name) for img_name in img_names] 173 | if len(img_paths) >= min_seq_len: 174 | img_paths = tuple(img_paths) 175 | tracklets.append((img_paths, pid, camid)) 176 | num_imgs_per_tracklet.append(len(img_paths)) 177 | 178 | num_tracklets = len(tracklets) 179 | 180 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet -------------------------------------------------------------------------------- /torchreid/data_manager/prid2011.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import glob 7 | import re 8 | import sys 9 | import urllib 10 | import tarfile 11 | import zipfile 12 | import os.path as osp 13 | from scipy.io import loadmat 14 | import numpy as np 15 | import h5py 16 | from scipy.misc import imsave 17 | import json 18 | 19 | from torchreid.utils.iotools import mkdir_if_missing, write_json, read_json 20 | 21 | 22 | class PRID2011(object): 23 | """ 24 | PRID2011 25 | 26 | Reference: 27 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011. 28 | 29 | URL: https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/ 30 | 31 | Dataset statistics: 32 | # identities: 200 33 | # tracklets: 400 34 | # cameras: 2 35 | """ 36 | dataset_dir = 'prid2011' 37 | 38 | def __init__(self, root='data', split_id=0, min_seq_len=0, verbose=True, **kwargs): 39 | self.dataset_dir = osp.join(root, self.dataset_dir) 40 | self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json') 41 | self.cam_a_path = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a') 42 | self.cam_b_path = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b') 43 | self.pose_file = osp.join(self.dataset_dir, 'pose.json') 44 | 45 | self._check_before_run() 46 | with open(self.pose_file, 'r') as f: 47 | self.poses = json.load(f) 48 | # process the pose information 49 | self.process_poses = dict() 50 | for key in self.poses: 51 | # save only one body 52 | maxidx = -1 53 | maxarea = -1 54 | maxscore = -1 55 | assert len(self.poses[key]['bodies']) >= 1, 'pose of {} is empty'.format(key) 56 | if len(self.poses[key]['bodies']) == 1: 57 | self.process_poses[key] = np.array(self.poses[key]['bodies'][0]['joints']).reshape((-1, 3)) 58 | else: 59 | for idx, body in enumerate(self.poses[key]['bodies']): 60 | tmp_kps = np.array(body['joints']).reshape((-1, 3)) 61 | tmp_area = (max(tmp_kps[:, 0]) - min(tmp_kps[:, 0])) * (max(tmp_kps[:, 1]) - min(tmp_kps[:, 1])) 62 | tmp_score = body['score'] 63 | if tmp_score > maxscore: 64 | if tmp_area > maxarea and tmp_score > 1.1 * maxscore: 65 | maxscore = tmp_score 66 | maxidx = idx 67 | self.process_poses[key] = np.array(self.poses[key]['bodies'][maxidx]['joints']).reshape((-1, 3)) 68 | splits = read_json(self.split_path) 69 | if split_id >= len(splits): 70 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 71 | split = splits[split_id] 72 | train_dirs, test_dirs = split['train'], split['test'] 73 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 74 | 75 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 76 | self._process_data(train_dirs, cam1=True, cam2=True) 77 | query, 
num_query_tracklets, num_query_pids, num_imgs_query = \ 78 | self._process_data(test_dirs, cam1=True, cam2=False) 79 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 80 | self._process_data(test_dirs, cam1=False, cam2=True) 81 | 82 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 83 | min_num = np.min(num_imgs_per_tracklet) 84 | max_num = np.max(num_imgs_per_tracklet) 85 | avg_num = np.mean(num_imgs_per_tracklet) 86 | 87 | num_total_pids = num_train_pids + num_query_pids 88 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 89 | 90 | if verbose: 91 | print("=> PRID2011 loaded") 92 | print("Dataset statistics:") 93 | print(" ------------------------------") 94 | print(" subset | # ids | # tracklets") 95 | print(" ------------------------------") 96 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 97 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 98 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 99 | print(" ------------------------------") 100 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 101 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 102 | print(" ------------------------------") 103 | 104 | self.train = train 105 | self.query = query 106 | self.gallery = gallery 107 | 108 | self.num_train_pids = num_train_pids 109 | self.num_query_pids = num_query_pids 110 | self.num_gallery_pids = num_gallery_pids 111 | 112 | def _check_before_run(self): 113 | """Check if all files are available before going deeper""" 114 | if not osp.exists(self.dataset_dir): 115 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 116 | 117 | def _process_data(self, dirnames, cam1=True, cam2=True): 118 | tracklets = [] 119 | num_imgs_per_tracklet = [] 120 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 121 | 122 | for dirname in dirnames: 123 | if cam1: 124 | person_dir = osp.join(self.cam_a_path, dirname) 125 | img_names = sorted(glob.glob(osp.join(person_dir, '*.png'))) 126 | assert len(img_names) > 0 127 | img_names = tuple(img_names) 128 | pid = dirname2pid[dirname] 129 | tracklets.append((img_names, pid, 0)) 130 | num_imgs_per_tracklet.append(len(img_names)) 131 | 132 | if cam2: 133 | person_dir = osp.join(self.cam_b_path, dirname) 134 | img_names = sorted(glob.glob(osp.join(person_dir, '*.png'))) 135 | assert len(img_names) > 0 136 | img_names = tuple(img_names) 137 | pid = dirname2pid[dirname] 138 | tracklets.append((img_names, pid, 1)) 139 | num_imgs_per_tracklet.append(len(img_names)) 140 | 141 | num_tracklets = len(tracklets) 142 | num_pids = len(dirnames) 143 | 144 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 145 | -------------------------------------------------------------------------------- /torchreid/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cross_entropy_loss import CrossEntropyLabelSmooth 6 | from .hard_mine_triplet_loss import TripletLoss 7 | 8 | 9 | def DeepSupervision(criterion, xs, y): 10 | """ 11 | Args: 12 | - criterion: loss function 13 | - xs: tuple of inputs 14 | - y: ground truth 15 | """ 16 | loss = 0. 
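# note (added for clarity): the same criterion is applied to every output in xs
# (one per supervised branch) and the sum is averaged below, so the result stays
# on the same scale as a single-output loss.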
17 | for x in xs: 18 | loss += criterion(x, y) 19 | loss /= len(xs) 20 | return loss -------------------------------------------------------------------------------- /torchreid/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class CrossEntropyLabelSmooth(nn.Module): 9 | """Cross entropy loss with label smoothing regularizer. 10 | 11 | Reference: 12 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 13 | Equation: y = (1 - epsilon) * y + epsilon / K. 14 | 15 | Args: 16 | - num_classes (int): number of classes. 17 | - epsilon (float): weight. 18 | """ 19 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 20 | super(CrossEntropyLabelSmooth, self).__init__() 21 | self.num_classes = num_classes 22 | self.epsilon = epsilon 23 | self.use_gpu = use_gpu 24 | self.logsoftmax = nn.LogSoftmax(dim=1) 25 | 26 | def forward(self, inputs, targets): 27 | """ 28 | Args: 29 | - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 30 | - targets: ground truth labels with shape (num_classes) 31 | """ 32 | log_probs = self.logsoftmax(inputs) 33 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 34 | if self.use_gpu: targets = targets.cuda() 35 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 36 | loss = (- targets * log_probs).mean(0).sum() 37 | return loss -------------------------------------------------------------------------------- /torchreid/losses/hard_mine_triplet_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | """Triplet loss with hard positive/negative mining. 10 | 11 | Reference: 12 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 14 | 15 | Args: 16 | margin (float, optional): margin for triplet. Default is 0.3. 17 | """ 18 | def __init__(self, margin=0.3, soft=True): 19 | super(TripletLoss, self).__init__() 20 | self.margin = margin 21 | self.soft = soft 22 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 23 | 24 | def forward(self, inputs, targets): 25 | """ 26 | Args: 27 | inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim). 28 | targets (torch.LongTensor): ground truth labels with shape (num_classes). 
29 | """ 30 | n = inputs.size(0) 31 | 32 | # Compute pairwise distance, replace by the official when merged 33 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 34 | dist = dist + dist.t() 35 | dist.addmm_(1, -2, inputs, inputs.t()) 36 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 37 | 38 | # For each anchor, find the hardest positive and negative 39 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 40 | dist_ap, dist_an = [], [] 41 | for i in range(n): 42 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 43 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 44 | dist_ap = torch.cat(dist_ap) 45 | dist_an = torch.cat(dist_an) 46 | 47 | if self.soft: 48 | return torch.log(1 + torch.exp(dist_ap - dist_an)).mean() 49 | else: 50 | return self.ranking_loss(dist_an, dist_ap, torch.ones_like(dist_an)) 51 | -------------------------------------------------------------------------------- /torchreid/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import torch 5 | from bisect import bisect_right 6 | from torch.optim.lr_scheduler import * 7 | 8 | 9 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 10 | def __init__( 11 | self, 12 | optimizer, 13 | milestones, 14 | gamma=0.1, 15 | warmup_factor=1.0 / 3, 16 | warmup_iters=500, 17 | warmup_method="linear", 18 | last_epoch=-1, 19 | ): 20 | if not list(milestones) == sorted(milestones): 21 | raise ValueError( 22 | "Milestones should be a list of" " increasing integers. Got {}", 23 | milestones, 24 | ) 25 | 26 | if warmup_method not in ("constant", "linear"): 27 | raise ValueError( 28 | "Only 'constant' or 'linear' warmup_method accepted" 29 | "got {}".format(warmup_method) 30 | ) 31 | self.milestones = milestones 32 | self.gamma = gamma 33 | self.warmup_factor = warmup_factor 34 | self.warmup_iters = warmup_iters 35 | self.warmup_method = warmup_method 36 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 37 | 38 | def get_lr(self): 39 | warmup_factor = 1 40 | if self.last_epoch < self.warmup_iters: 41 | if self.warmup_method == "constant": 42 | warmup_factor = self.warmup_factor 43 | elif self.warmup_method == "linear": 44 | alpha = self.last_epoch / self.warmup_iters 45 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 46 | return [ 47 | base_lr 48 | * warmup_factor 49 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 50 | for base_lr in self.base_lrs 51 | ] -------------------------------------------------------------------------------- /torchreid/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .accuracy import accuracy 4 | from .rank import evaluate_rank 5 | from .distance import compute_distance_matrix -------------------------------------------------------------------------------- /torchreid/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def accuracy(output, target, topk=(1,)): 10 | """Computes the accuracy over the k top predictions for the specified values of k""" 11 | with torch.no_grad(): 12 | maxk = max(topk) 13 | batch_size = target.size(0) 14 | 15 | def 
calc_acc(output, target): 16 | _, pred = output.topk(maxk, 1, True, True) 17 | pred = pred.t() 18 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 19 | 20 | res = [] 21 | for k in topk: 22 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 23 | acc = correct_k.mul_(1.0 / batch_size) 24 | res.append(acc.item()) 25 | return res 26 | 27 | all_res = [] 28 | if isinstance(output, (tuple, list)): 29 | for out in output: 30 | all_res.append(calc_acc(out, target)) 31 | else: 32 | all_res.append(calc_acc(output, target)) 33 | return np.array(all_res) -------------------------------------------------------------------------------- /torchreid/metrics/distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | 11 | def compute_distance_matrix(input1, input2, metric='euclidean'): 12 | """A wrapper function for computing distance matrix. 13 | 14 | Args: 15 | input1 (torch.Tensor): 2-D feature matrix. 16 | input2 (torch.Tensor): 2-D feature matrix. 17 | metric (str, optional): "euclidean" or "cosine". 18 | Default is "euclidean". 19 | 20 | Returns: 30 | torch.Tensor: distance matrix. 31 | 32 | Examples:: 33 | >>> from torchreid import metrics 34 | >>> input1 = torch.rand(10, 2048) 35 | >>> input2 = torch.rand(100, 2048) 36 | >>> distmat = metrics.compute_distance_matrix(input1, input2) 37 | >>> distmat.size() # (10, 100) 38 | """ 39 | # check input 40 | assert isinstance(input1, torch.Tensor) 41 | assert isinstance(input2, torch.Tensor) 42 | assert input1.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(input1.dim()) 43 | assert input2.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(input2.dim()) 44 | assert input1.size(1) == input2.size(1) 45 | 46 | if metric == 'euclidean': 47 | distmat = euclidean_squared_distance(input1, input2) 48 | elif metric == 'cosine': 49 | distmat = cosine_distance(input1, input2) 50 | else: 51 | raise ValueError( 52 | 'Unknown distance metric: {}. ' 53 | 'Please choose either "euclidean" or "cosine"'.format(metric) 54 | ) 55 | 56 | return distmat 57 | 58 | 59 | def euclidean_squared_distance(input1, input2): 60 | """Computes euclidean squared distance. 61 | 62 | Args: 63 | input1 (torch.Tensor): 2-D feature matrix. 64 | input2 (torch.Tensor): 2-D feature matrix. 65 | 66 | Returns: 67 | torch.Tensor: distance matrix.
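    Note (added for clarity): the returned values are *squared* Euclidean
    distances, computed via the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b,
    with the cross term accumulated by the single addmm_ call below; no square
    root is taken.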
68 | """ 69 | m, n = input1.size(0), input2.size(0) 70 | distmat = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 71 | torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 72 | distmat.addmm_(1, -2, input1, input2.t()) 73 | return distmat 74 | 75 | 76 | def cosine_distance(input1, input2): 77 | """Computes cosine distance. 78 | 79 | Args: 80 | input1 (torch.Tensor): 2-D feature matrix. 81 | input2 (torch.Tensor): 2-D feature matrix. 82 | 83 | Returns: 84 | torch.Tensor: distance matrix. 85 | """ 86 | input1_normed = F.normalize(input1, p=2, dim=1) 87 | input2_normed = F.normalize(input2, p=2, dim=1) 88 | distmat = 1 - torch.mm(input1_normed, input2_normed.t()) 89 | return distmat -------------------------------------------------------------------------------- /torchreid/metrics/rank.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import copy 7 | from collections import defaultdict 8 | import sys 9 | import warnings 10 | 11 | try: 12 | from torchreid.metrics.rank_cylib.rank_cy import evaluate_cy 13 | IS_CYTHON_AVAI = True 14 | except ImportError: 15 | IS_CYTHON_AVAI = False 16 | warnings.warn( 17 | 'Cython evaluation (very fast so highly recommended) is ' 18 | 'unavailable, now use python evaluation.' 19 | ) 20 | 21 | 22 | def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 23 | """Evaluation with cuhk03 metric 24 | Key: one image for each gallery identity is randomly sampled for each query identity. 25 | Random sampling is performed num_repeats times. 26 | """ 27 | num_repeats = 10 28 | num_q, num_g = distmat.shape 29 | 30 | if num_g < max_rank: 31 | max_rank = num_g 32 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 33 | 34 | indices = np.argsort(distmat, axis=1) 35 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 36 | 37 | # compute cmc curve for each query 38 | all_cmc = [] 39 | all_AP = [] 40 | num_valid_q = 0. # number of valid query 41 | 42 | for q_idx in range(num_q): 43 | # get query pid and camid 44 | q_pid = q_pids[q_idx] 45 | q_camid = q_camids[q_idx] 46 | 47 | # remove gallery samples that have the same pid and camid with query 48 | order = indices[q_idx] 49 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 50 | keep = np.invert(remove) 51 | 52 | # compute cmc curve 53 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 54 | if not np.any(raw_cmc): 55 | # this condition is true when query identity does not appear in gallery 56 | continue 57 | 58 | kept_g_pids = g_pids[order][keep] 59 | g_pids_dict = defaultdict(list) 60 | for idx, pid in enumerate(kept_g_pids): 61 | g_pids_dict[pid].append(idx) 62 | 63 | cmc = 0. 64 | for repeat_idx in range(num_repeats): 65 | mask = np.zeros(len(raw_cmc), dtype=np.bool) 66 | for _, idxs in g_pids_dict.items(): 67 | # randomly sample one image for each gallery person 68 | rnd_idx = np.random.choice(idxs) 69 | mask[rnd_idx] = True 70 | masked_raw_cmc = raw_cmc[mask] 71 | _cmc = masked_raw_cmc.cumsum() 72 | _cmc[_cmc > 1] = 1 73 | cmc += _cmc[:max_rank].astype(np.float32) 74 | 75 | cmc /= num_repeats 76 | all_cmc.append(cmc) 77 | # compute AP 78 | num_rel = raw_cmc.sum() 79 | tmp_cmc = raw_cmc.cumsum() 80 | tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] 81 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 82 | AP = tmp_cmc.sum() / num_rel 83 | all_AP.append(AP) 84 | num_valid_q += 1. 85 | 86 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 87 | 88 | all_cmc = np.asarray(all_cmc).astype(np.float32) 89 | all_cmc = all_cmc.sum(0) / num_valid_q 90 | mAP = np.mean(all_AP) 91 | 92 | return all_cmc, mAP 93 | 94 | 95 | def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 96 | """Evaluation with market1501 metric 97 | Key: for each query identity, its gallery images from the same camera view are discarded. 98 | """ 99 | num_q, num_g = distmat.shape 100 | 101 | if num_g < max_rank: 102 | max_rank = num_g 103 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 104 | 105 | indices = np.argsort(distmat, axis=1) 106 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 107 | 108 | # compute cmc curve for each query 109 | all_cmc = [] 110 | all_AP = [] 111 | num_valid_q = 0. # number of valid query 112 | 113 | for q_idx in range(num_q): 114 | # get query pid and camid 115 | q_pid = q_pids[q_idx] 116 | q_camid = q_camids[q_idx] 117 | 118 | # remove gallery samples that have the same pid and camid with query 119 | order = indices[q_idx] 120 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 121 | keep = np.invert(remove) 122 | 123 | # compute cmc curve 124 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 125 | if not np.any(raw_cmc): 126 | # this condition is true when query identity does not appear in gallery 127 | continue 128 | 129 | cmc = raw_cmc.cumsum() 130 | cmc[cmc > 1] = 1 131 | 132 | all_cmc.append(cmc[:max_rank]) 133 | num_valid_q += 1. 134 | 135 | # compute average precision 136 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 137 | num_rel = raw_cmc.sum() 138 | tmp_cmc = raw_cmc.cumsum() 139 | tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] 140 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 141 | AP = tmp_cmc.sum() / num_rel 142 | all_AP.append(AP) 143 | 144 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 145 | 146 | all_cmc = np.asarray(all_cmc).astype(np.float32) 147 | all_cmc = all_cmc.sum(0) / num_valid_q 148 | mAP = np.mean(all_AP) 149 | 150 | return all_cmc, mAP 151 | 152 | 153 | def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03): 154 | if use_metric_cuhk03: 155 | return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 156 | else: 157 | return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 158 | 159 | 160 | def evaluate_mars(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 161 | num_q, num_g = distmat.shape 162 | cmc = np.zeros((num_q, max_rank)) 163 | ap = np.zeros(num_q) 164 | 165 | for k in range(num_q): 166 | good_idx = np.where((q_pids[k] == g_pids) & (q_camids[k] != g_camids))[0] 167 | junk_mask1 = (g_pids == -1) 168 | junk_mask2 = (q_pids[k] == g_pids) & (q_camids[k] == g_camids) 169 | junk_idx = np.where(junk_mask1 | junk_mask2)[0] 170 | score = distmat[k, :] 171 | sort_idx = np.argsort(score) 172 | sort_idx = sort_idx[:max_rank] 173 | 174 | ap[k], cmc[k, :] = Compute_AP(good_idx, junk_idx, sort_idx) 175 | CMC = np.mean(cmc, axis=0) 176 | mAP = np.mean(ap) 177 | return CMC, mAP 178 | 179 | 180 | def Compute_AP(good_image, junk_image, index): 181 | cmc = np.zeros((len(index),)) 182 | ngood = len(good_image) 183 | 184 | old_recall = 0 185 | old_precision = 1. 186 | ap = 0 187 | intersect_size = 0 188 | j = 0 189 | good_now = 0 190 | njunk = 0 191 | for n in range(len(index)): 192 | flag = 0 193 | if np.any(good_image == index[n]): 194 | cmc[n - njunk:] = 1 195 | flag = 1 # good image 196 | good_now += 1 197 | if np.any(junk_image == index[n]): 198 | njunk += 1 199 | continue # junk image 200 | 201 | if flag == 1: 202 | intersect_size += 1 203 | recall = intersect_size / ngood 204 | precision = intersect_size / (j + 1) 205 | ap += (recall - old_recall) * (old_precision + precision) / 2 206 | old_recall = recall 207 | old_precision = precision 208 | j += 1 209 | 210 | if good_now == ngood: 211 | return ap, cmc 212 | return ap, cmc 213 | 214 | 215 | def evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False, 216 | use_metric_market1501=False, use_metric_mars=False, use_cython=True): 217 | """ 218 | Evaluate CMC and mAP. 219 | :param distmat: (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 220 | :param q_pids: (numpy.ndarray): 1-D array containing person identities of each query instance. 221 | :param g_pids: (numpy.ndarray): 1-D array containing person identities of each gallery instance. 222 | :param q_camids: 1-D array containing camera views under which each query instance is captured. 223 | :param g_camids: 1-D array containing camera views under which each gallery instance is captured. 224 | :param max_rank: maximum CMC rank to be computed. Default is 50. 
225 | :param use_metric_cuhk03: 226 | :param use_metric_market1501: 227 | :param use_metric_mars: 228 | :param use_metric_dukev: same as use_metric_mars 229 | :param use_cython: 230 | :return: 231 | """ 232 | if use_metric_market1501 or use_metric_cuhk03: 233 | if use_cython and IS_CYTHON_AVAI: 234 | return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) 235 | else: 236 | return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) 237 | elif use_metric_mars: 238 | return evaluate_mars(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 239 | 240 | 241 | # add from Dukemtmc video reid 242 | from sklearn.metrics.base import _average_binary_score 243 | from sklearn.metrics import precision_recall_curve, auc 244 | 245 | def _unique_sample(ids_dict, num): 246 | mask = np.zeros(num, dtype=np.bool) 247 | for _, indices in ids_dict.items(): 248 | i = np.random.choice(indices) 249 | mask[i] = True 250 | return mask 251 | 252 | 253 | def average_precision_score(y_true, y_score, average="macro", 254 | sample_weight=None): 255 | def _binary_average_precision(y_true, y_score, sample_weight=None): 256 | precision, recall, thresholds = precision_recall_curve( 257 | y_true, y_score, sample_weight=sample_weight) 258 | return auc(recall, precision) 259 | 260 | return _average_binary_score(_binary_average_precision, y_true, y_score, 261 | average, sample_weight=sample_weight) 262 | 263 | 264 | def cmc(distmat, query_ids, gallery_ids, query_cams, gallery_cams, topk=100, 265 | separate_camera_set=False, single_gallery_shot=False, first_match_break=False): 266 | m, n = distmat.shape 267 | # Sort and find correct matches 268 | indices = np.argsort(distmat, axis=1) 269 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 270 | # Compute CMC for each query 271 | ret = np.zeros(topk) 272 | num_valid_queries = 0 273 | for i in range(m): 274 | # Filter out the same id and same camera 275 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 276 | (gallery_cams[indices[i]] != query_cams[i])) 277 | if separate_camera_set: 278 | # Filter out samples from same camera 279 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 280 | if not np.any(matches[i, valid]): continue 281 | if single_gallery_shot: 282 | repeat = 10 283 | gids = gallery_ids[indices[i][valid]] 284 | inds = np.where(valid)[0] 285 | ids_dict = defaultdict(list) 286 | for j, x in zip(inds, gids): 287 | ids_dict[x].append(j) 288 | else: 289 | repeat = 1 290 | for _ in range(repeat): 291 | if single_gallery_shot: 292 | # Randomly choose one instance for each id 293 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 294 | index = np.nonzero(matches[i, sampled])[0] 295 | else: 296 | index = np.nonzero(matches[i, valid])[0] 297 | delta = 1. 
/ (len(index) * repeat) 298 | for j, k in enumerate(index): 299 | if k - j >= topk: break 300 | if first_match_break: 301 | ret[k - j] += 1 302 | break 303 | ret[k - j] += delta 304 | num_valid_queries += 1 305 | if num_valid_queries == 0: 306 | raise RuntimeError("No valid query") 307 | return ret.cumsum() / num_valid_queries 308 | 309 | 310 | def mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams): 311 | m, n = distmat.shape 312 | # Sort and find correct matches 313 | indices = np.argsort(distmat, axis=1) 314 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 315 | # Compute AP for each query 316 | aps = [] 317 | for i in range(m): 318 | # Filter out the same id and same camera 319 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 320 | (gallery_cams[indices[i]] != query_cams[i])) 321 | y_true = matches[i, valid] 322 | y_score = -distmat[i][indices[i]][valid] 323 | if not np.any(y_true): continue 324 | aps.append(average_precision_score(y_true, y_score)) 325 | if len(aps) == 0: 326 | raise RuntimeError("No valid query") 327 | return np.mean(aps) 328 | 329 | 330 | def evaluate_dukev(distmat, query_ids, gallery_ids, query_cams, gallery_cams, max_rank=50): 331 | # Compute mean AP 332 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 333 | 334 | # Compute all kinds of CMC scores 335 | cmc_configs = { 336 | 'market1501': dict(separate_camera_set=False, 337 | single_gallery_shot=False, 338 | first_match_break=True)} 339 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 340 | query_cams, gallery_cams, **params) 341 | for name, params in cmc_configs.items()} 342 | 343 | return cmc_scores['market1501'], mAP -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python setup.py build_ext --inplace 3 | rm -rf build 4 | clean: 5 | rm -rf build 6 | rm -f rank_cy.c *.so -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weleen/AGRL.pytorch/83a37eea19365bc8e74c3b3c5fe5ad5d00d04f2c/torchreid/metrics/rank_cylib/__init__.py -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/rank_cy.pyx: -------------------------------------------------------------------------------- 1 | # cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True 2 | 3 | from __future__ import print_function 4 | 5 | import cython 6 | import numpy as np 7 | cimport numpy as np 8 | from collections import defaultdict 9 | import random 10 | 11 | 12 | """ 13 | Compiler directives: 14 | https://github.com/cython/cython/wiki/enhancements-compilerdirectives 15 | 16 | Cython tutorial: 17 | https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html 18 | 19 | Credit to https://github.com/luzai 20 | """ 21 | 22 | 23 | # Main interface 24 | cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False): 25 | distmat = np.asarray(distmat, dtype=np.float32) 26 | q_pids = np.asarray(q_pids, dtype=np.int64) 27 | g_pids = np.asarray(g_pids, dtype=np.int64) 28 | q_camids = np.asarray(q_camids, dtype=np.int64) 29 | g_camids = np.asarray(g_camids, dtype=np.int64) 30 | if use_metric_cuhk03: 31 | return eval_cuhk03_cy(distmat, 
q_pids, g_pids, q_camids, g_camids, max_rank) 32 | return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 33 | 34 | 35 | cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 36 | long[:]q_camids, long[:]g_camids, long max_rank): 37 | 38 | cdef long num_q = distmat.shape[0] 39 | cdef long num_g = distmat.shape[1] 40 | 41 | if num_g < max_rank: 42 | max_rank = num_g 43 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 44 | 45 | cdef: 46 | long num_repeats = 10 47 | long[:,:] indices = np.argsort(distmat, axis=1) 48 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 49 | 50 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 51 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 52 | float num_valid_q = 0. # number of valid query 53 | 54 | long q_idx, q_pid, q_camid, g_idx 55 | long[:] order = np.zeros(num_g, dtype=np.int64) 56 | long keep 57 | 58 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 59 | float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32) 60 | float[:] cmc, masked_cmc 61 | long num_g_real, num_g_real_masked, rank_idx, rnd_idx 62 | unsigned long meet_condition 63 | float AP 64 | long[:] kept_g_pids, mask 65 | 66 | float num_rel 67 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 68 | float tmp_cmc_sum 69 | 70 | for q_idx in range(num_q): 71 | # get query pid and camid 72 | q_pid = q_pids[q_idx] 73 | q_camid = q_camids[q_idx] 74 | 75 | # remove gallery samples that have the same pid and camid with query 76 | for g_idx in range(num_g): 77 | order[g_idx] = indices[q_idx, g_idx] 78 | num_g_real = 0 79 | meet_condition = 0 80 | kept_g_pids = np.zeros(num_g, dtype=np.int64) 81 | 82 | for g_idx in range(num_g): 83 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 84 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 85 | kept_g_pids[num_g_real] = g_pids[order[g_idx]] 86 | num_g_real += 1 87 | if matches[q_idx][g_idx] > 1e-31: 88 | meet_condition = 1 89 | 90 | if not meet_condition: 91 | # this condition is true when query identity does not appear in gallery 92 | continue 93 | 94 | # cuhk03-specific setting 95 | g_pids_dict = defaultdict(list) # overhead! 
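        # Comment added for clarity: this is the CUHK03 single-gallery-shot protocol.
        # The kept gallery entries are grouped by identity below; then, num_repeats
        # times, one instance per identity is sampled at random, the CMC curve is
        # recomputed on that reduced gallery, and the num_repeats curves are averaged.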
96 | for g_idx in range(num_g_real): 97 | g_pids_dict[kept_g_pids[g_idx]].append(g_idx) 98 | 99 | cmc = np.zeros(max_rank, dtype=np.float32) 100 | for _ in range(num_repeats): 101 | mask = np.zeros(num_g_real, dtype=np.int64) 102 | 103 | for _, idxs in g_pids_dict.items(): 104 | # randomly sample one image for each gallery person 105 | rnd_idx = np.random.choice(idxs) 106 | #rnd_idx = idxs[0] # use deterministic for debugging 107 | mask[rnd_idx] = 1 108 | 109 | num_g_real_masked = 0 110 | for g_idx in range(num_g_real): 111 | if mask[g_idx] == 1: 112 | masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx] 113 | num_g_real_masked += 1 114 | 115 | masked_cmc = np.zeros(num_g, dtype=np.float32) 116 | function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked) 117 | for g_idx in range(num_g_real_masked): 118 | if masked_cmc[g_idx] > 1: 119 | masked_cmc[g_idx] = 1 120 | 121 | for rank_idx in range(max_rank): 122 | cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats 123 | 124 | for rank_idx in range(max_rank): 125 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 126 | # compute average precision 127 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 128 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 129 | num_rel = 0 130 | tmp_cmc_sum = 0 131 | for g_idx in range(num_g_real): 132 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 133 | num_rel += raw_cmc[g_idx] 134 | all_AP[q_idx] = tmp_cmc_sum / num_rel 135 | num_valid_q += 1. 136 | 137 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 138 | 139 | # compute averaged cmc 140 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 141 | for rank_idx in range(max_rank): 142 | for q_idx in range(num_q): 143 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 144 | avg_cmc[rank_idx] /= num_valid_q 145 | 146 | cdef float mAP = 0 147 | for q_idx in range(num_q): 148 | mAP += all_AP[q_idx] 149 | mAP /= num_valid_q 150 | 151 | return np.asarray(avg_cmc).astype(np.float32), mAP 152 | 153 | 154 | cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 155 | long[:]q_camids, long[:]g_camids, long max_rank): 156 | 157 | cdef long num_q = distmat.shape[0] 158 | cdef long num_g = distmat.shape[1] 159 | 160 | if num_g < max_rank: 161 | max_rank = num_g 162 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 163 | 164 | cdef: 165 | long[:,:] indices = np.argsort(distmat, axis=1) 166 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 167 | 168 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 169 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 170 | float num_valid_q = 0. 
# number of valid query 171 | 172 | long q_idx, q_pid, q_camid, g_idx 173 | long[:] order = np.zeros(num_g, dtype=np.int64) 174 | long keep 175 | 176 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 177 | float[:] cmc = np.zeros(num_g, dtype=np.float32) 178 | long num_g_real, rank_idx 179 | unsigned long meet_condition 180 | 181 | float num_rel 182 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 183 | float tmp_cmc_sum 184 | 185 | for q_idx in range(num_q): 186 | # get query pid and camid 187 | q_pid = q_pids[q_idx] 188 | q_camid = q_camids[q_idx] 189 | 190 | # remove gallery samples that have the same pid and camid with query 191 | for g_idx in range(num_g): 192 | order[g_idx] = indices[q_idx, g_idx] 193 | num_g_real = 0 194 | meet_condition = 0 195 | 196 | for g_idx in range(num_g): 197 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 198 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 199 | num_g_real += 1 200 | if matches[q_idx][g_idx] > 1e-31: 201 | meet_condition = 1 202 | 203 | if not meet_condition: 204 | # this condition is true when query identity does not appear in gallery 205 | continue 206 | 207 | # compute cmc 208 | function_cumsum(raw_cmc, cmc, num_g_real) 209 | for g_idx in range(num_g_real): 210 | if cmc[g_idx] > 1: 211 | cmc[g_idx] = 1 212 | 213 | for rank_idx in range(max_rank): 214 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 215 | num_valid_q += 1. 216 | 217 | # compute average precision 218 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 219 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 220 | num_rel = 0 221 | tmp_cmc_sum = 0 222 | for g_idx in range(num_g_real): 223 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 224 | num_rel += raw_cmc[g_idx] 225 | all_AP[q_idx] = tmp_cmc_sum / num_rel 226 | 227 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 228 | 229 | # compute averaged cmc 230 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 231 | for rank_idx in range(max_rank): 232 | for q_idx in range(num_q): 233 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 234 | avg_cmc[rank_idx] /= num_valid_q 235 | 236 | cdef float mAP = 0 237 | for q_idx in range(num_q): 238 | mAP += all_AP[q_idx] 239 | mAP /= num_valid_q 240 | 241 | return np.asarray(avg_cmc).astype(np.float32), mAP 242 | 243 | 244 | # Compute the cumulative sum 245 | cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n): 246 | cdef long i 247 | dst[0] = src[0] 248 | for i in range(1, n): 249 | dst[i] = src[i] + dst[i - 1] -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | import numpy as np 5 | 6 | 7 | def numpy_include(): 8 | try: 9 | numpy_include = np.get_include() 10 | except AttributeError: 11 | numpy_include = np.get_numpy_include() 12 | return numpy_include 13 | 14 | 15 | ext_modules = [ 16 | Extension( 17 | 'rank_cy', 18 | ['rank_cy.pyx'], 19 | include_dirs=[numpy_include()], 20 | ) 21 | ] 22 | 23 | setup( 24 | name='Cython-based reid evaluation code', 25 | ext_modules=cythonize(ext_modules) 26 | ) -------------------------------------------------------------------------------- 
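Editor's note: the snippet below is an illustrative sketch (not part of the repository) of how the metrics above fit together at test time; the feature dimension, set sizes, and random inputs are made up. Also note that evaluate_rank(), as defined above, falls through and returns None unless one of use_metric_market1501 / use_metric_cuhk03 / use_metric_mars is set, so the first benchmark in the timing script that follows exercises that early-return path unless a metric flag is added.

import numpy as np
import torch
from torchreid import metrics

# Made-up shapes for illustration only.
qf = torch.rand(100, 512)                    # query features (num_query, feat_dim)
gf = torch.rand(1000, 512)                   # gallery features (num_gallery, feat_dim)
q_pids = np.random.randint(0, 50, size=100)
g_pids = np.random.randint(0, 50, size=1000)
q_camids = np.random.randint(0, 6, size=100)
g_camids = np.random.randint(0, 6, size=1000)

distmat = metrics.compute_distance_matrix(qf, gf, metric='euclidean').numpy()
# One of the metric flags must be set, otherwise evaluate_rank() returns None.
cmc, mAP = metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids,
                                 max_rank=20, use_metric_market1501=True)
print('mAP: {:.1%}  rank-1: {:.1%}'.format(mAP, cmc[0]))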
/torchreid/metrics/rank_cylib/test_cython.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import os.path as osp 5 | import timeit 6 | import numpy as np 7 | 8 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 9 | from torchreid import metrics 10 | 11 | """ 12 | Test the speed of cython-based evaluation code. The speed improvements 13 | can be much bigger when using the real reid data, which contains a larger 14 | amount of query and gallery images. 15 | 16 | Note: you might encounter the following error: 17 | 'AssertionError: Error: all query identities do not appear in gallery'. 18 | This is normal because the inputs are random numbers. Just try again. 19 | """ 20 | 21 | print('*** Compare running time ***') 22 | 23 | setup = ''' 24 | import sys 25 | import os.path as osp 26 | import numpy as np 27 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 28 | from torchreid import metrics 29 | num_q = 30 30 | num_g = 300 31 | max_rank = 5 32 | distmat = np.random.rand(num_q, num_g) * 20 33 | q_pids = np.random.randint(0, num_q, size=num_q) 34 | g_pids = np.random.randint(0, num_g, size=num_g) 35 | q_camids = np.random.randint(0, 5, size=num_q) 36 | g_camids = np.random.randint(0, 5, size=num_g) 37 | ''' 38 | 39 | print('=> Using market1501\'s metric') 40 | pytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', setup=setup, number=20) 41 | cytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', setup=setup, number=20) 42 | print('Python time: {} s'.format(pytime)) 43 | print('Cython time: {} s'.format(cytime)) 44 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 45 | 46 | print('=> Using cuhk03\'s metric') 47 | pytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)', setup=setup, number=20) 48 | cytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)', setup=setup, number=20) 49 | print('Python time: {} s'.format(pytime)) 50 | print('Cython time: {} s'.format(cytime)) 51 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 52 | 53 | """ 54 | print("=> Check precision") 55 | 56 | num_q = 30 57 | num_g = 300 58 | max_rank = 5 59 | distmat = np.random.rand(num_q, num_g) * 20 60 | q_pids = np.random.randint(0, num_q, size=num_q) 61 | g_pids = np.random.randint(0, num_g, size=num_g) 62 | q_camids = np.random.randint(0, 5, size=num_q) 63 | g_camids = np.random.randint(0, 5, size=num_g) 64 | 65 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 66 | print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 67 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 68 | print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 69 | """ -------------------------------------------------------------------------------- /torchreid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import shutil 5 | import inspect 6 | 7 | # video 8 | from .res50tp import * 9 | from .sta import * 10 | from .simple_sta import * 11 | from .gsta import * 
12 | from .resnet50_s1 import * 13 | from .graphnet import * 14 | from .vmgn import * 15 | from .ganet import * 16 | 17 | __model_factory = { 18 | 'res50tp': res50tp, 19 | 'resnet50_s1': resnet50_s1, 20 | 'sta': sta_p4, 21 | 'simple_sta': simple_sta_p4, 22 | 'gsta': gsta, 23 | 'msppn': MSPyraPartNet, 24 | 'msppgn': MSPyraPartGraphNet, 25 | 'vmgn': vmgn, 26 | 'ganet': ganet 27 | } 28 | 29 | 30 | def get_names(): 31 | return list(__model_factory.keys()) 32 | 33 | 34 | def init_model(name, *args, **kwargs): 35 | if name not in list(__model_factory.keys()): 36 | raise KeyError("Unknown model: {}".format(name)) 37 | if 'save_dir' in kwargs: 38 | # XXX: shutil.copy and shutil.copy2 raise PermissionError, so use copyfile here. 39 | model_file = inspect.getfile(__model_factory[name]) 40 | shutil.copyfile(model_file, os.path.join(os.path.abspath(kwargs['save_dir']), os.path.basename(model_file))) 41 | return __model_factory[name](*args, **kwargs) 42 | -------------------------------------------------------------------------------- /torchreid/models/graphnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from torchreid.utils.torchtools import weights_init_xavier 7 | from torchreid.utils.reidtools import calc_splits 8 | 9 | 10 | def pose_aggregation(features, locs, num_scale=1, total_parts=[3], seq_len=4): 11 | """ 12 | :param features: (b, num_scale * total_split, seq_len, c) 13 | :param locs: (b, seq_len, num_parts) 14 | :param num_scale: int 15 | :param total_parts: list(int) 16 | :param seq_len: int 17 | :return: 18 | """ 19 | locs_t = locs.transpose(1, 2) 20 | num_parts = total_parts[-1] 21 | total_split = sum(total_parts) 22 | batch, _, _, channel = features.size() 23 | if total_split in [3, 4]: 24 | locs_tmp = locs_t * seq_len + (torch.arange(seq_len).unsqueeze(0).repeat(num_parts, 1)).float().unsqueeze(0).to( 25 | locs_t.device) 26 | features = features.view(batch, num_scale, num_parts * seq_len, channel) 27 | fused_features = torch.zeros((batch, num_scale, num_parts, channel)).to(features.device) 28 | for b in range(batch): 29 | loc = locs_tmp[b].data.cpu().numpy().reshape(num_parts * seq_len).astype(int) 30 | for idx_p in range(num_parts): 31 | index = loc[idx_p * seq_len: (idx_p + 1) * seq_len] 32 | fused = [] 33 | for idx in index: 34 | fused.append(features[b, :, idx]) 35 | fused_features[b, :, idx_p] = torch.stack(fused, dim=1).mean(dim=1) 36 | elif total_split in [6, 7]: 37 | head_num_parts = 3 38 | features = features.view(batch, num_scale, total_split, seq_len, channel) 39 | fused_features_tail = torch.zeros((batch, num_scale, num_parts, channel)).to(features.device) 40 | fused_features_head = features[:, :, :head_num_parts].mean(dim=3) 41 | features_tail = features[:, :, head_num_parts:].view(batch, num_scale, num_parts * seq_len, channel) 42 | locs_tmp = locs_t * seq_len + (torch.arange(seq_len).unsqueeze(0).repeat(num_parts, 1)).float().unsqueeze(0).to( 43 | locs_t.device) 44 | for b in range(batch): 45 | loc = locs_tmp[b].data.cpu().numpy().reshape(num_parts * seq_len).astype(int) 46 | for idx_p in range(num_parts): 47 | index = loc[idx_p * seq_len: (idx_p + 1) * seq_len] 48 | fused = [] 49 | for idx in index: 50 | fused.append(features_tail[b, :, idx]) 51 | fused_features_tail[b, :, idx_p] = torch.stack(fused, dim=1).mean(dim=1) 52 | fused_features = torch.cat([fused_features_head, fused_features_tail], dim=2) 53 | else: 54 | raise 
NotImplementedError 55 | return fused_features.view(batch, num_scale * total_split, channel) 56 | 57 | 58 | class GraphBlock(nn.Module): 59 | def __init__(self, in_features, out_features, dropout=0., alpha=1, gamma=1, 60 | learn_graph=True, use_pose=True, self_loop=False, **kwargs): 61 | super(GraphBlock, self).__init__() 62 | self.in_features = in_features 63 | self.out_features = out_features 64 | self.dropout = dropout 65 | self.alpha = alpha 66 | self.gamma = gamma 67 | self.learn_graph = learn_graph 68 | self.use_pose = use_pose 69 | self.self_loop = self_loop 70 | 71 | self.linear = nn.Linear(in_features, out_features, bias=False) 72 | nn.init.normal_(self.linear.weight, mean=0, std=0.001) 73 | 74 | if self.learn_graph: 75 | num_hid = 128 76 | self.emb_q = nn.Linear(out_features, num_hid) 77 | self.emb_k = nn.Linear(out_features, num_hid) 78 | nn.init.normal_(self.emb_q.weight, std=0.001) 79 | nn.init.constant_(self.emb_q.bias, 0) 80 | nn.init.normal_(self.emb_k.weight, std=0.001) 81 | nn.init.constant_(self.emb_k.bias, 0) 82 | 83 | self.bn = nn.BatchNorm1d(out_features) 84 | 85 | def forward(self, input, adj): 86 | h = self.linear(input) 87 | N, V, C = h.size() 88 | 89 | if self.use_pose: 90 | adj = F.normalize(adj, p=1, dim=2) 91 | 92 | if self.learn_graph: 93 | emb_q = self.emb_q(h) 94 | emb_k = self.emb_k(h) 95 | graph = torch.bmm(emb_q, emb_k.transpose(1, 2)) 96 | graph = F.softmax(graph, dim=2) 97 | if self.self_loop: 98 | I = torch.eye(V, device=input.device).view(1, V, V).repeat(N, 1, 1) 99 | graph = F.softmax(graph + I, dim=2) 100 | if self.use_pose: 101 | graph = (adj + self.gamma * graph) / (1 + self.gamma) 102 | else: 103 | graph = adj 104 | 105 | h_prime = torch.bmm(graph, h) 106 | h_prime = F.dropout(h_prime, p=self.dropout, training=self.training) 107 | h_prime = F.relu(h_prime) 108 | 109 | h_prime = h_prime.view(N * V, self.out_features) 110 | h_prime = self.bn(h_prime) 111 | h_prime = h_prime.view(N, V, self.out_features) 112 | 113 | assert input.size() == h_prime.size(), 'when use skip connection, input size must equal to output size.' 
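        # Comment added for clarity: `h` is the linear projection of the input vertex
        # features, and `graph` is the row-normalized mixture of the pose-driven
        # adjacency and/or the learned affinity computed above. One step of graph
        # propagation yields h_prime = BN(ReLU(dropout(graph @ h))), which is returned
        # as a residual update scaled by alpha: output = input + alpha * h_prime.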
114 | return input + self.alpha * h_prime 115 | 116 | def __repr__(self): 117 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 118 | 119 | 120 | class MSPyraPartNet(nn.Module): 121 | """use layer2 layer3 layer4.""" 122 | def __init__(self, num_classes=100, loss={'xent', 'htri'}, num_split=4, 123 | **kwargs): 124 | super(MSPyraPartNet, self).__init__() 125 | self.num_classes = num_classes 126 | self.loss = loss 127 | self.num_parts = num_split 128 | 129 | resnet50 = torchvision.models.resnet50(pretrained=True) 130 | self.conv1 = resnet50.conv1 131 | self.bn1 = resnet50.bn1 132 | self.relu = resnet50.relu 133 | self.maxpool = resnet50.maxpool 134 | self.layer1 = resnet50.layer1 135 | self.layer2 = resnet50.layer2 136 | self.layer3 = resnet50.layer3 137 | self.layer4 = resnet50.layer4 138 | 139 | self.num_scale = 3 140 | self.total_parts = calc_splits(self.num_parts) 141 | self.total_split = sum(self.total_parts) 142 | self.num_hid = 512 143 | 144 | self.avg_pool = nn.ModuleList() 145 | self.max_pool = nn.ModuleList() 146 | for n in self.total_parts: 147 | self.avg_pool.append(nn.AdaptiveAvgPool2d((n, 1))) 148 | self.max_pool.append(nn.AdaptiveMaxPool2d((n, 1))) 149 | 150 | self.reduce_f1 = nn.Linear(512, self.num_hid) 151 | self.bn_f1 = nn.BatchNorm1d(self.num_hid) 152 | self.reduce_f2 = nn.Linear(1024, self.num_hid) 153 | self.bn_f2 = nn.BatchNorm1d(self.num_hid) 154 | self.reduce_f3 = nn.Linear(2048, self.num_hid) 155 | self.bn_f3 = nn.BatchNorm1d(self.num_hid) 156 | weights_init_xavier(self.reduce_f1) 157 | weights_init_xavier(self.reduce_f2) 158 | weights_init_xavier(self.reduce_f3) 159 | weights_init_xavier(self.bn_f1) 160 | weights_init_xavier(self.bn_f2) 161 | weights_init_xavier(self.bn_f3) 162 | 163 | self.fusion_conv = nn.Conv1d(self.num_scale * self.total_split, 1, 1, bias=False) 164 | self.classifier = nn.ModuleList() 165 | for i in range(self.num_scale * self.total_split + 1): 166 | self.classifier.append(nn.Linear(self.num_hid, num_classes)) 167 | weights_init_xavier(self.fusion_conv) 168 | weights_init_xavier(self.classifier) 169 | 170 | def forward(self, x, adj=None): 171 | b, s, c, h, w = x.size() 172 | x = x.view(b * s, c, h, w) 173 | 174 | x = self.conv1(x) 175 | x = self.bn1(x) 176 | x = self.relu(x) 177 | x = self.maxpool(x) 178 | x = self.layer1(x) 179 | f1 = self.layer2(x) 180 | f2 = self.layer3(f1) 181 | f3 = self.layer4(f2) 182 | 183 | # global and local feature 184 | l_f1 = [] 185 | l_f2 = [] 186 | l_f3 = [] 187 | for idx, n in enumerate(self.total_parts): 188 | l_f1.append((self.avg_pool[idx](f1) + self.max_pool[idx](f1)).view(b, s, 512, -1)) 189 | l_f2.append((self.avg_pool[idx](f2) + self.max_pool[idx](f2)).view(b, s, 1024, -1)) 190 | l_f3.append((self.avg_pool[idx](f3) + self.max_pool[idx](f3)).view(b, s, 2048, -1)) 191 | l_f1 = torch.cat(l_f1, dim=3).permute(0, 3, 1, 2).contiguous() 192 | l_f1 = self.bn_f1(self.reduce_f1(l_f1).view(b * self.total_split * s, -1)).view(b, self.total_split * s, -1) 193 | l_f2 = torch.cat(l_f2, dim=3).permute(0, 3, 1, 2).contiguous() 194 | l_f2 = self.bn_f2(self.reduce_f2(l_f2).view(b * self.total_split * s, -1)).view(b, self.total_split * s, -1) 195 | l_f3 = torch.cat(l_f3, dim=3).permute(0, 3, 1, 2).contiguous() 196 | l_f3 = self.bn_f3(self.reduce_f3(l_f3).view(b * self.total_split * s, -1)).view(b, self.total_split * s, -1) 197 | f = torch.cat([l_f1, l_f2, l_f3], dim=1).view(b, self.num_scale * self.total_split, s, -1) 198 | 199 | vf = f.mean(dim=2) 200 | allf = [vf[:, i] 
for i in range(self.num_scale * self.total_split)] 201 | 202 | fused_f = self.fusion_conv(vf).view(b, self.num_hid) 203 | if not self.training: 204 | return fused_f 205 | 206 | allf.append(fused_f) 207 | y = [self.classifier[i](vf[:, i]) for i in range(self.num_scale * self.total_split)] 208 | y.append(self.classifier[-1](fused_f)) 209 | if self.loss == {'xent'}: 210 | return y 211 | elif self.loss == {'xent', 'htri'}: 212 | return y, allf 213 | else: 214 | raise KeyError("Unsupported loss: {}".format(self.loss)) 215 | 216 | 217 | class MSPyraPartGraphNet(nn.Module): 218 | def __init__(self, num_classes=100, loss={'xent', 'htri'}, num_split=3, use_pose=True, 219 | learn_graph=True, num_gb=3, **kwargs): 220 | super(MSPyraPartGraphNet, self).__init__() 221 | self.num_classes = num_classes 222 | self.loss = loss 223 | resnet50 = torchvision.models.resnet50(pretrained=True) 224 | # 0: conv2d 1: bn 2: relu 3: maxpool 4:layer1 5: layer2 6:layer3 7: layer4 225 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 226 | self.num_scale = 3 # number of layers for feature extraction 227 | self.num_split = num_split # number of feature split in spatial pooling 228 | self.total_split = sum(calc_splits(num_split)) 229 | self.num_hid = 512 230 | self.num_gb = num_gb 231 | self.use_pose = use_pose 232 | 233 | self.avg_pool = nn.ModuleList() 234 | self.max_pool = nn.ModuleList() 235 | for n in calc_splits(self.num_split): 236 | self.avg_pool.append(nn.AdaptiveAvgPool2d((n, 1))) 237 | self.max_pool.append(nn.AdaptiveMaxPool2d((n, 1))) 238 | 239 | self.reduce1 = nn.Linear(512, self.num_hid) 240 | self.bn1 = nn.BatchNorm1d(self.num_hid) 241 | self.reduce2 = nn.Linear(1024, self.num_hid) 242 | self.bn2 = nn.BatchNorm1d(self.num_hid) 243 | self.reduce3 = nn.Linear(2048, self.num_hid) 244 | self.bn3 = nn.BatchNorm1d(self.num_hid) 245 | weights_init_xavier(self.reduce1) 246 | weights_init_xavier(self.reduce2) 247 | weights_init_xavier(self.reduce3) 248 | weights_init_xavier(self.bn1) 249 | weights_init_xavier(self.bn2) 250 | weights_init_xavier(self.bn3) 251 | 252 | self.gbs = nn.ModuleList() 253 | for j in range(self.num_gb): 254 | self.gbs.append( 255 | GraphBlock(in_features=self.num_hid, out_features=self.num_hid, 256 | learn_graph=learn_graph, use_pose=use_pose)) 257 | 258 | self.fusion_conv = nn.Conv1d(self.num_scale * self.total_split, 1, 1, bias=False) 259 | self.classifiers = nn.ModuleList() 260 | for i in range(self.num_scale * self.total_split + 1): 261 | self.classifiers.append(nn.Linear(self.num_hid * (self.num_gb + 1), num_classes)) 262 | weights_init_xavier(self.fusion_conv) 263 | weights_init_xavier(self.classifiers) 264 | 265 | def forward(self, x, adj): 266 | b, s, c, h, w = x.size() 267 | x = x.view(b * s, c, h, w) 268 | for name, module in self.base._modules.items(): 269 | if name == '3': # maxpool 270 | x = module(x) 271 | elif name == '5': # layer2 272 | layer2 = module(x) 273 | elif name == '6': # layer3 274 | layer3 = module(layer2) 275 | elif name == '7': # layer4 276 | layer4 = module(layer3) 277 | else: 278 | x = module(x) 279 | # global and local feature 280 | l2_f = [] 281 | l3_f = [] 282 | l4_f = [] 283 | for idx, n in enumerate(calc_splits(self.num_split)): 284 | l2_f.append((self.avg_pool[idx](layer2) + self.max_pool[idx](layer2)).view(b, s, 512, -1)) 285 | l3_f.append((self.avg_pool[idx](layer3) + self.max_pool[idx](layer3)).view(b, s, 1024, -1)) 286 | l4_f.append((self.avg_pool[idx](layer4) + self.max_pool[idx](layer4)).view(b, s, 2048, -1)) 287 | l2_f = torch.cat(l2_f, 
dim=3).permute(0, 3, 1, 2).contiguous() 288 | l3_f = torch.cat(l3_f, dim=3).permute(0, 3, 1, 2).contiguous() 289 | l4_f = torch.cat(l4_f, dim=3).permute(0, 3, 1, 2).contiguous() 290 | 291 | l2_f = self.reduce1(l2_f).view(b * self.total_split * s, self.num_hid) 292 | l2_f = self.bn1(l2_f).view(b, self.total_split * s, self.num_hid) 293 | l3_f = self.reduce2(l3_f).view(b * self.total_split * s, self.num_hid) 294 | l3_f = self.bn2(l3_f).view(b, self.total_split * s, self.num_hid) 295 | l4_f = self.reduce3(l4_f).view(b * self.total_split * s, self.num_hid) 296 | l4_f = self.bn3(l4_f).view(b, self.total_split * s, self.num_hid) 297 | f = torch.cat([l2_f, l3_f, l4_f], dim=1) 298 | 299 | gb_out = [f] 300 | for i in range(self.num_gb): 301 | gb_out.append(self.gbs[i](gb_out[-1], adj)) 302 | f = torch.stack(gb_out, dim=2).view(b, self.num_scale * self.total_split, s, (self.num_gb + 1) * self.num_hid) 303 | 304 | vf = f.mean(dim=2) 305 | allf = [vf[:, i] for i in range(self.num_scale * self.total_split)] 306 | 307 | fused_f = self.fusion_conv(vf).view(b, (self.num_gb + 1) * self.num_hid) 308 | allf.append(fused_f) 309 | 310 | if not self.training: 311 | return fused_f 312 | y = [self.classifiers[i](vf[:, i]) for i in range(self.num_scale * self.total_split)] 313 | y.append(self.classifiers[-1](fused_f)) 314 | 315 | if self.loss == {'xent'}: 316 | return y 317 | elif self.loss == {'xent', 'htri'}: 318 | return y, allf 319 | else: 320 | raise KeyError("Unsupported loss: {}".format(self.loss)) 321 | -------------------------------------------------------------------------------- /torchreid/models/gsta.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['gsta'] 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | import torch.utils.model_zoo as model_zoo 11 | 12 | from torchreid.utils.reidtools import calc_splits 13 | 14 | model_urls = { 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class Bottleneck(nn.Module): 28 | expansion = 4 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(Bottleneck, self).__init__() 32 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 35 | padding=1, bias=False) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 38 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv3(out) 55 | out = self.bn3(out) 56 | 57 | if self.downsample is not None: 58 | residual = self.downsample(x) 59 | 60 | out += residual 61 | out 
= self.relu(out) 62 | 63 | return out 64 | 65 | 66 | class GraphLayer(nn.Module): 67 | """ 68 | graph block with residual learning. 69 | """ 70 | 71 | def __init__(self, in_features, out_features, learn_graph=True, use_pose=True, 72 | dist_method='l2', gamma=0.1, k=4, **kwargs): 73 | """ 74 | :param in_features: input feature size. 75 | :param out_features: output feature size. 76 | :param learn_graph: learn a affinity graph or not. 77 | :param use_pose: use graph from pose estimation or not. 78 | :param dist_method: calculate the similarity between the vertex. 79 | :param k: nearest neighbor size. 80 | :param kwargs: 81 | """ 82 | super(GraphLayer, self).__init__() 83 | self.in_features = in_features 84 | self.out_features = out_features 85 | self.learn_graph = learn_graph 86 | self.use_pose = use_pose 87 | self.dist_method = dist_method 88 | self.gamma = gamma 89 | 90 | assert use_pose or learn_graph 91 | self.linear = nn.Linear(in_features, out_features, bias=False) 92 | self.bn = nn.BatchNorm1d(out_features) 93 | self.relu = nn.LeakyReLU(0.1) 94 | 95 | if self.learn_graph and dist_method == 'dot': 96 | num_hid = self.in_features // 8 97 | self.emb_q = nn.Linear(out_features, num_hid) 98 | self.emb_k = nn.Linear(out_features, num_hid) 99 | 100 | self._init_params() 101 | 102 | def get_sim_matrix(self, v_feats): 103 | """ 104 | generate similarity matrix 105 | :param v_feats: (batch, num_vertex, num_hid) 106 | :return: sim_matrix: (batch, num_vertex, num_vertex) 107 | """ 108 | if self.dist_method == 'dot': 109 | emb_q = self.emb_q(v_feats) 110 | emb_k = self.emb_k(v_feats) 111 | sim_matrix = torch.bmm(emb_q, emb_k.transpose(1, 2)) 112 | elif self.dist_method == 'l2': 113 | # calculate the pairwise distance with exp(x) - 1 / exp(x) + 1 114 | distmat = torch.pow(v_feats, 2).sum(dim=2).unsqueeze(1) + \ 115 | torch.pow(v_feats, 2).sum(dim=2).unsqueeze(2) 116 | distmat -= 2 * torch.bmm(v_feats, v_feats.transpose(1, 2)) 117 | distmat = distmat.clamp(1e-12).sqrt() # numerical stability 118 | sim_matrix = 2 / (distmat.exp() + 1) 119 | else: 120 | raise NotImplementedError 121 | return sim_matrix 122 | 123 | def _init_params(self): 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 127 | if m.bias is not None: 128 | nn.init.constant_(m.bias, 0) 129 | elif isinstance(m, nn.BatchNorm2d): 130 | nn.init.constant_(m.weight, 1) 131 | nn.init.constant_(m.bias, 0) 132 | elif isinstance(m, nn.BatchNorm1d): 133 | nn.init.constant_(m.weight, 1) 134 | nn.init.constant_(m.bias, 0) 135 | elif isinstance(m, nn.Linear): 136 | nn.init.normal_(m.weight, 0, 0.01) 137 | if m.bias is not None: 138 | nn.init.constant_(m.bias, 0) 139 | 140 | def forward(self, input, adj): 141 | """ 142 | :param input: (b, num_vertex, num_hid), where num_vertex = num_scale * seq_len * num_splits 143 | :param adj: (b, num_vertex, num_vertex), the pose-driven graph 144 | :return: 145 | """ 146 | h = self.linear(input) 147 | N, V, C = h.size() 148 | 149 | # mask = torch.ones((N, V, V)).to(h.device) 150 | # for i in range(mask.size(1)): 151 | # mask[:, i, i] = 0 152 | 153 | if self.use_pose: 154 | # adj = mask * adj 155 | adj = F.normalize(adj, p=1, dim=2) 156 | 157 | if self.learn_graph: 158 | graph = self.get_sim_matrix(input) 159 | # graph = mask * graph 160 | graph = F.normalize(graph, p=1, dim=2) 161 | if self.use_pose: 162 | graph = (adj + graph) / 2 163 | else: 164 | graph = adj 165 | 166 | h_prime = torch.bmm(graph, h) 167 | h_prime = 
self.bn(h_prime.view(N * V, -1)).view(N, V, -1) 168 | h_prime = self.relu(h_prime) 169 | 170 | return (1 - self.gamma) * input + self.gamma * h_prime 171 | 172 | 173 | class GSTA(nn.Module): 174 | def __init__(self, num_classes, loss, block, layers, 175 | num_split, pyramid_part, num_gb, use_pose, learn_graph, 176 | consistent_loss, nonlinear='relu', **kwargs): 177 | self.inplanes = 64 178 | super(GSTA, self).__init__() 179 | self.loss = loss 180 | self.feature_dim = 512 * block.expansion 181 | 182 | # backbone network 183 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 184 | self.bn1 = nn.BatchNorm2d(64) 185 | self.relu = nn.ReLU(inplace=True) 186 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 187 | self.layer1 = self._make_layer(block, 64, layers[0]) 188 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 189 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 190 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 191 | 192 | # sta layers 193 | self.num_split = num_split 194 | self.total_split_list = calc_splits(num_split) if pyramid_part else [num_split] 195 | self.total_split = sum(self.total_split_list) 196 | 197 | self.parts_avgpool = nn.ModuleList() 198 | for n in self.total_split_list: 199 | self.parts_avgpool.append(nn.AdaptiveAvgPool2d((n, 1))) 200 | 201 | # graph layers 202 | self.num_gb = num_gb 203 | self.graph_layers = nn.ModuleList() 204 | for i in range(num_gb): 205 | self.graph_layers.append(GraphLayer(in_features=self.feature_dim, 206 | out_features=self.feature_dim, 207 | use_pose=use_pose, 208 | learn_graph=learn_graph)) 209 | 210 | self.consistent_loss = consistent_loss 211 | 212 | self.bottleneck = nn.BatchNorm1d(self.feature_dim) 213 | self.bottleneck.bias.requires_grad_(False) 214 | self.classifier = nn.Linear(self.feature_dim, num_classes, bias=False) 215 | 216 | self._init_params() 217 | 218 | def _make_layer(self, block, planes, blocks, stride=1): 219 | downsample = None 220 | if stride != 1 or self.inplanes != planes * block.expansion: 221 | downsample = nn.Sequential( 222 | nn.Conv2d(self.inplanes, planes * block.expansion, 223 | kernel_size=1, stride=stride, bias=False), 224 | nn.BatchNorm2d(planes * block.expansion), 225 | ) 226 | 227 | layers = [] 228 | layers.append(block(self.inplanes, planes, stride, downsample)) 229 | self.inplanes = planes * block.expansion 230 | for i in range(1, blocks): 231 | layers.append(block(self.inplanes, planes)) 232 | 233 | return nn.Sequential(*layers) 234 | 235 | def _init_params(self): 236 | for m in self.modules(): 237 | if isinstance(m, nn.Conv2d): 238 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 239 | if m.bias is not None: 240 | nn.init.constant_(m.bias, 0) 241 | elif isinstance(m, nn.BatchNorm2d): 242 | nn.init.constant_(m.weight, 1) 243 | nn.init.constant_(m.bias, 0) 244 | elif isinstance(m, nn.BatchNorm1d): 245 | nn.init.constant_(m.weight, 1) 246 | nn.init.constant_(m.bias, 0) 247 | elif isinstance(m, nn.Linear): 248 | nn.init.normal_(m.weight, 0, 0.01) 249 | if m.bias is not None: 250 | nn.init.constant_(m.bias, 0) 251 | 252 | def featuremaps(self, x): 253 | x = self.conv1(x) 254 | x = self.bn1(x) 255 | x = self.relu(x) 256 | x = self.maxpool(x) 257 | x = self.layer1(x) 258 | x = self.layer2(x) 259 | x = self.layer3(x) 260 | x = self.layer4(x) 261 | return x 262 | 263 | def _attention_op(self, feat): 264 | """ 265 | do attention fusion 266 | :param feat: (batch, seq_len, num_split, c) 267 | 
:return: feat: (batch, num_split, c) 268 | """ 269 | att = F.normalize(feat.norm(p=2, dim=3, keepdim=True), p=1, dim=1) 270 | f = feat.mul(att).sum(dim=1) 271 | return f 272 | 273 | def forward(self, x, adj, *args): 274 | B, S, C, H, W = x.size() 275 | x = x.view(B * S, C, H, W) 276 | f = self.featuremaps(x) 277 | _, c, h, w = f.shape 278 | 279 | v_g = list() 280 | for idx, n in enumerate(self.total_split_list): 281 | v_g.append(self.parts_avgpool[idx](f).view(B, S, c, n)) 282 | v_g = torch.cat(v_g, dim=3) 283 | f = v_g.transpose(2, 3).contiguous().view(B, S * self.total_split, c) 284 | # graph propagation 285 | for i in range(self.num_gb): 286 | f = self.graph_layers[i](f, adj) 287 | f = f.view(B, S, self.total_split, c) 288 | 289 | f_fuse = self._attention_op(f) 290 | 291 | f_g = f_fuse.mean(dim=1).view(B, -1) 292 | bn = self.bottleneck(f_g) 293 | 294 | # consistent 295 | if self.consistent_loss and self.training: 296 | # random select sub frames 297 | sub_index = list() 298 | for i in range(B): 299 | tmp_ind = list(range(0, S)) 300 | tmp_ind.remove(np.random.randint(S)) 301 | sub_index.append(tmp_ind) 302 | sub_index = torch.LongTensor(sub_index).to(f_fuse.device) 303 | sf = torch.gather(f, dim=1, index=sub_index.view(B, S - 1, 1, 1).repeat(1, 1, f.size(2), f.size(3))) 304 | sf_fuse = self._attention_op(sf) 305 | sf_g = sf_fuse.mean(dim=1).view(B, -1) 306 | sbn = self.bottleneck(sf_g) 307 | sy = self.classifier(sbn) 308 | 309 | if not self.training: 310 | return bn 311 | 312 | y = self.classifier(bn) 313 | 314 | if self.loss == {'xent'}: 315 | if self.consistent_loss: 316 | return [y, sy] 317 | else: 318 | return y 319 | elif self.loss == {'xent', 'htri'}: 320 | if self.consistent_loss: 321 | return [y, sy], [f_g, sf_g] 322 | else: 323 | return y, f_g 324 | else: 325 | raise KeyError('Unsupported loss: {}'.format(self.loss)) 326 | 327 | 328 | def init_pretrained_weights(model, model_url): 329 | """Initializes model with pretrained weights. 330 | 331 | Layers that don't match with pretrained layers in name or size are kept unchanged. 
332 | """ 333 | pretrain_dict = model_zoo.load_url(model_url) 334 | model_dict = model.state_dict() 335 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 336 | model_dict.update(pretrain_dict) 337 | model.load_state_dict(model_dict) 338 | 339 | 340 | def gsta(num_classes, loss, last_stride, num_split, num_gb, num_scale, pyramid_part, use_pose, learn_graph, 341 | pretrained=True, consistent_loss=False, **kwargs): 342 | model = GSTA( 343 | num_classes=num_classes, 344 | loss=loss, 345 | block=Bottleneck, 346 | layers=[3, 4, 6, 3], 347 | last_stride=last_stride, 348 | num_split=num_split, 349 | pyramid_part=pyramid_part, 350 | num_gb=num_gb, 351 | use_pose=use_pose, 352 | learn_graph=learn_graph, 353 | consistent_loss=consistent_loss, 354 | nonlinear='relu', 355 | **kwargs 356 | ) 357 | if pretrained: 358 | print('init pretrained weights from {}'.format(model_urls['resnet50'])) 359 | init_pretrained_weights(model, model_urls['resnet50']) 360 | return model -------------------------------------------------------------------------------- /torchreid/models/res50tp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | 6 | model_urls = { 7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 8 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 11 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 12 | } 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, downsample=None): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = conv3x3(inplanes, planes, stride) 27 | self.bn1 = nn.BatchNorm2d(planes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class Bottleneck(nn.Module): 54 | expansion = 4 55 | 56 | def __init__(self, inplanes, planes, stride=1, downsample=None): 57 | super(Bottleneck, self).__init__() 58 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(planes) 60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 61 | padding=1, bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 64 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.downsample = downsample 67 | self.stride = stride 68 | 69 | def forward(self, x): 70 | residual = x 71 | 72 | out = self.conv1(x) 73 | out = self.bn1(out) 74 | out = 
self.relu(out) 75 | 76 | out = self.conv2(out) 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv3(out) 81 | out = self.bn3(out) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | 86 | out += residual 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class DimReduceLayer(nn.Module): 93 | 94 | def __init__(self, in_channels, out_channels, nonlinear): 95 | super(DimReduceLayer, self).__init__() 96 | layers = [] 97 | layers.append(nn.Linear(in_channels, out_channels, bias=False)) 98 | layers.append(nn.BatchNorm1d(out_channels)) 99 | 100 | if nonlinear == 'relu': 101 | layers.append(nn.ReLU(inplace=True)) 102 | elif nonlinear == 'leakyrelu': 103 | layers.append(nn.LeakyReLU(0.1)) 104 | 105 | self.layers = nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | return self.layers(x) 109 | 110 | 111 | class ResNet50TP(nn.Module): 112 | def __init__(self, num_classes, loss, block, layers, last_stride=1, 113 | bnneck=True, **kwargs): 114 | self.inplanes = 64 115 | super(ResNet50TP, self).__init__() 116 | self.num_classes = num_classes 117 | self.loss = loss 118 | self.feature_dim = 2048 119 | self.num_scale = 3 # number of layers for feature extraction 120 | 121 | # backbone network 122 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 128 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 129 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 130 | 131 | self.part = 4 132 | self.avg_pool = nn.AdaptiveAvgPool2d((self.part, 1)) 133 | 134 | self.bottleneck = nn.BatchNorm1d(self.feature_dim) 135 | self.bottleneck.bias.requires_grad_(False) 136 | self.classifier = nn.Linear(self.feature_dim, num_classes, bias=False) 137 | 138 | self._init_params() 139 | 140 | def _make_layer(self, block, planes, blocks, stride=1): 141 | downsample = None 142 | if stride != 1 or self.inplanes != planes * block.expansion: 143 | downsample = nn.Sequential( 144 | nn.Conv2d(self.inplanes, planes * block.expansion, 145 | kernel_size=1, stride=stride, bias=False), 146 | nn.BatchNorm2d(planes * block.expansion), 147 | ) 148 | 149 | layers = [] 150 | layers.append(block(self.inplanes, planes, stride, downsample)) 151 | self.inplanes = planes * block.expansion 152 | for i in range(1, blocks): 153 | layers.append(block(self.inplanes, planes)) 154 | 155 | return nn.Sequential(*layers) 156 | 157 | def _init_params(self): 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 161 | if m.bias is not None: 162 | nn.init.constant_(m.bias, 0) 163 | elif isinstance(m, nn.BatchNorm2d): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | elif isinstance(m, nn.BatchNorm1d): 167 | nn.init.constant_(m.weight, 1) 168 | nn.init.constant_(m.bias, 0) 169 | elif isinstance(m, nn.Linear): 170 | nn.init.normal_(m.weight, 0, 0.01) 171 | if m.bias is not None: 172 | nn.init.constant_(m.bias, 0) 173 | 174 | def _extract_feat(self, x): 175 | x = self.conv1(x) 176 | x = self.bn1(x) 177 | x = self.relu(x) 178 | x = self.maxpool(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | x = 
self.layer4(x) 184 | return x 185 | 186 | def forward(self, x, adj): 187 | b, s, c, h, w = x.size() 188 | 189 | x = x.view(b * s, c, h, w) 190 | x = self._extract_feat(x) 191 | # global feature 192 | v_g = self.avg_pool(x).view(b, s, self.feature_dim, self.part) 193 | t_a = F.normalize(v_g.norm(p=2, dim=2, keepdim=True), p=1, dim=1) 194 | f = v_g.mul(t_a).sum(dim=1) 195 | f = F.adaptive_avg_pool1d(f, 1).view(b, -1) 196 | 197 | bn = self.bottleneck(f) 198 | 199 | if not self.training: 200 | return bn 201 | 202 | y = self.classifier(bn) 203 | 204 | if self.loss == {'xent'}: 205 | return y 206 | elif self.loss == {'xent', 'htri'}: 207 | return y, f 208 | else: 209 | raise KeyError("Unsupported loss: {}".format(self.loss)) 210 | 211 | 212 | def init_pretrained_weights(model, model_url): 213 | """Initializes model with pretrained weights. 214 | 215 | Layers that don't match with pretrained layers in name or size are kept unchanged. 216 | """ 217 | pretrain_dict = model_zoo.load_url(model_url) 218 | model_dict = model.state_dict() 219 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 220 | model_dict.update(pretrain_dict) 221 | model.load_state_dict(model_dict) 222 | 223 | 224 | def res50tp(num_classes=100, loss={'xent', 'htri'}, pretrain=True, bnneck=True, last_stride=1, **kwargs): 225 | model = ResNet50TP( 226 | num_classes=num_classes, 227 | loss=loss, 228 | block=Bottleneck, 229 | layers=[3, 4, 6, 3], 230 | bnneck=bnneck, 231 | last_stride=last_stride, 232 | **kwargs 233 | ) 234 | if pretrain: 235 | init_pretrained_weights(model, model_urls['resnet50']) 236 | return model -------------------------------------------------------------------------------- /torchreid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | import torchvision 8 | 9 | from ..utils.torchtools import weights_init_kaiming 10 | 11 | __all__ = ['ResNet50', 'ResNet101', 'ResNet50M', 'ResNet50B'] 12 | 13 | 14 | class ResNet50(nn.Module): 15 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 16 | super(ResNet50, self).__init__() 17 | self.loss = loss 18 | resnet50 = torchvision.models.resnet50(pretrained=True) 19 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 20 | self.classifier = nn.Linear(2048, num_classes) 21 | self.feat_dim = 2048 22 | 23 | def forward(self, x): 24 | x = self.base(x) 25 | x = F.avg_pool2d(x, x.size()[2:]) 26 | f = x.view(x.size(0), -1) 27 | if not self.training: 28 | return f 29 | y = self.classifier(f) 30 | 31 | if self.loss == {'xent'}: 32 | return y 33 | elif self.loss == {'xent', 'htri'}: 34 | return y, f 35 | else: 36 | raise KeyError("Unsupported loss: {}".format(self.loss)) 37 | 38 | 39 | class ResNet101(nn.Module): 40 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 41 | super(ResNet101, self).__init__() 42 | self.loss = loss 43 | resnet101 = torchvision.models.resnet101(pretrained=True) 44 | self.base = nn.Sequential(*list(resnet101.children())[:-2]) 45 | self.classifier = nn.Linear(2048, num_classes) 46 | self.feat_dim = 2048 # feature dimension 47 | 48 | def forward(self, x): 49 | x = self.base(x) 50 | x = F.avg_pool2d(x, x.size()[2:]) 51 | f = x.view(x.size(0), -1) 52 | if not self.training: 53 | return f 54 | y = self.classifier(f) 55 | 56 | if self.loss == {'xent'}: 57 | return y 58 
| elif self.loss == {'xent', 'htri'}: 59 | return y, f 60 | else: 61 | raise KeyError("Unsupported loss: {}".format(self.loss)) 62 | 63 | 64 | class ResNet50M(nn.Module): 65 | """ResNet50 + mid-level features. 66 | 67 | Reference: 68 | Yu et al. The Devil is in the Middle: Exploiting Mid-level Representations for 69 | Cross-Domain Instance Matching. arXiv:1711.08106. 70 | """ 71 | def __init__(self, num_classes=0, loss={'xent'}, **kwargs): 72 | super(ResNet50M, self).__init__() 73 | self.loss = loss 74 | resnet50 = torchvision.models.resnet50(pretrained=True) 75 | base = nn.Sequential(*list(resnet50.children())[:-2]) 76 | self.layers1 = nn.Sequential(base[0], base[1], base[2]) 77 | self.layers2 = nn.Sequential(base[3], base[4]) 78 | self.layers3 = base[5] 79 | self.layers4 = base[6] 80 | self.layers5a = base[7][0] 81 | self.layers5b = base[7][1] 82 | self.layers5c = base[7][2] 83 | self.fc_fuse = nn.Sequential(nn.Linear(4096, 1024), nn.BatchNorm1d(1024), nn.ReLU()) 84 | self.classifier = nn.Linear(3072, num_classes) 85 | self.feat_dim = 3072 # feature dimension 86 | 87 | def forward(self, x): 88 | x1 = self.layers1(x) 89 | x2 = self.layers2(x1) 90 | x3 = self.layers3(x2) 91 | x4 = self.layers4(x3) 92 | x5a = self.layers5a(x4) 93 | x5b = self.layers5b(x5a) 94 | x5c = self.layers5c(x5b) 95 | 96 | x5a_feat = F.avg_pool2d(x5a, x5a.size()[2:]).view(x5a.size(0), x5a.size(1)) 97 | x5b_feat = F.avg_pool2d(x5b, x5b.size()[2:]).view(x5b.size(0), x5b.size(1)) 98 | x5c_feat = F.avg_pool2d(x5c, x5c.size()[2:]).view(x5c.size(0), x5c.size(1)) 99 | 100 | midfeat = torch.cat((x5a_feat, x5b_feat), dim=1) 101 | midfeat = self.fc_fuse(midfeat) 102 | 103 | combofeat = torch.cat((x5c_feat, midfeat), dim=1) 104 | 105 | if not self.training: 106 | return combofeat 107 | 108 | prelogits = self.classifier(combofeat) 109 | 110 | if self.loss == {'xent'}: 111 | return prelogits 112 | elif self.loss == {'xent', 'htri'}: 113 | return prelogits, combofeat 114 | else: 115 | raise KeyError("Unsupported loss: {}".format(self.loss)) 116 | 117 | 118 | class ResNet50B(nn.Module): 119 | """Resnet50+bottleneck 120 | 121 | Reference: 122 | https://github.com/L1aoXingyu/reid_baseline 123 | """ 124 | def __init__(self, num_classes=0, loss={'xent'}, **kwargs): 125 | super(ResNet50B, self).__init__() 126 | self.loss = loss 127 | resnet50 = torchvision.models.resnet50(pretrained=True) 128 | resnet50.layer4[0].conv2.stride = (1, 1) 129 | resnet50.layer4[0].downsample[0].stride = (1, 1) 130 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 131 | 132 | self.in_planes = 2048 133 | self.bottleneck = nn.Sequential( 134 | nn.Linear(self.in_planes, 512), 135 | nn.BatchNorm1d(512), 136 | nn.LeakyReLU(0.1), 137 | nn.Dropout(p=0.5)) 138 | self.bottleneck.apply(weights_init_kaiming) 139 | 140 | self.classifier = nn.Linear(512, num_classes) 141 | self.classifier.apply(weights_init_kaiming) 142 | 143 | def forward(self, x): 144 | global_feat = self.base(x) 145 | global_feat = F.avg_pool2d(global_feat, global_feat.size()[-2:]) 146 | global_feat = global_feat.view(global_feat.size(0), -1) 147 | if not self.training: 148 | return global_feat 149 | else: 150 | feat = self.bottleneck(global_feat) 151 | y = self.classifier(feat) 152 | 153 | if self.loss == {'xent'}: 154 | return y 155 | elif self.loss == {'xent', 'htri'}: 156 | return y, global_feat 157 | else: 158 | raise KeyError("Unsupported loss: {}".format(self.loss)) -------------------------------------------------------------------------------- /torchreid/models/resnet3d.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import math 7 | from functools import partial 8 | 9 | __all__ = [ 10 | 'ResNet3d', 'resnet3d10', 'resnet3d18', 'resnet3d34', 'resnet3d50', 'resnet3d101', 11 | 'resnet3d152', 'resnet3d200' 12 | ] 13 | 14 | 15 | def conv3x3x3(in_planes, out_planes, stride=1): 16 | # 3x3x3 convolution with padding 17 | return nn.Conv3d( 18 | in_planes, 19 | out_planes, 20 | kernel_size=3, 21 | stride=stride, 22 | padding=1, 23 | bias=False) 24 | 25 | 26 | def downsample_basic_block(x, planes, stride): 27 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 28 | zero_pads = torch.Tensor( 29 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 30 | out.size(4)).zero_() 31 | if isinstance(out.data, torch.cuda.FloatTensor): 32 | zero_pads = zero_pads.cuda() 33 | 34 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 35 | 36 | return out 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None): 43 | super(BasicBlock, self).__init__() 44 | self.conv1 = conv3x3x3(inplanes, planes, stride) 45 | self.bn1 = nn.BatchNorm3d(planes) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.conv2 = conv3x3x3(planes, planes) 48 | self.bn2 = nn.BatchNorm3d(planes) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | residual = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | 62 | if self.downsample is not None: 63 | residual = self.downsample(x) 64 | 65 | out += residual 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, stride=1, downsample=None): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 77 | self.bn1 = nn.BatchNorm3d(planes) 78 | self.conv2 = nn.Conv3d( 79 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 80 | self.bn2 = nn.BatchNorm3d(planes) 81 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 82 | self.bn3 = nn.BatchNorm3d(planes * 4) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | residual = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | residual = self.downsample(x) 103 | 104 | out += residual 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet3d(nn.Module): 111 | def __init__(self, block, layers, shortcut_type='B', num_classes=400): 112 | self.inplanes = 64 113 | super(ResNet3d, self).__init__() 114 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False) 115 | self.bn1 = nn.BatchNorm3d(64) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 118 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 119 | self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=2) 120 | self.layer3 = 
self._make_layer(block, 256, layers[2], shortcut_type, stride=2) 121 | self.layer4 = self._make_layer(block, 512, layers[3], shortcut_type, stride=2) 122 | self.fc = nn.Linear(512 * block.expansion, num_classes) 123 | 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv3d): 126 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 127 | elif isinstance(m, nn.BatchNorm3d): 128 | nn.init.constant_(m.weight, 1) 129 | nn.init.constant_(m.bias, 0) 130 | 131 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 132 | downsample = None 133 | if stride != 1 or self.inplanes != planes * block.expansion: 134 | if shortcut_type == 'A': 135 | downsample = partial( 136 | downsample_basic_block, 137 | planes=planes * block.expansion, 138 | stride=stride) 139 | else: 140 | downsample = nn.Sequential( 141 | nn.Conv3d( 142 | self.inplanes, 143 | planes * block.expansion, 144 | kernel_size=1, 145 | stride=stride, 146 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 147 | 148 | layers = [] 149 | layers.append(block(self.inplanes, planes, stride, downsample)) 150 | self.inplanes = planes * block.expansion 151 | for i in range(1, blocks): 152 | layers.append(block(self.inplanes, planes)) 153 | 154 | return nn.Sequential(*layers) 155 | 156 | def load_matched_state_dict(self, state_dict): 157 | 158 | own_state = self.state_dict() 159 | for name, param in state_dict.items(): 160 | if name not in own_state: 161 | continue 162 | # if isinstance(param, Parameter): 163 | # backwards compatibility for serialized parameters 164 | param = param.data 165 | print("loading " + name) 166 | own_state[name].copy_(param) 167 | 168 | def forward(self, x): 169 | # default size is (b, s, c, w, h), s for seq_len, c for channel 170 | # convert for 3d cnn, (b, c, s, w, h) 171 | x = x.permute(0, 2, 1, 3, 4) 172 | x = self.conv1(x) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | x = self.maxpool(x) 176 | 177 | x = self.layer1(x) 178 | x = self.layer2(x) 179 | x = self.layer3(x) 180 | x = self.layer4(x) 181 | x = F.adaptive_avg_pool3d(x, (1, 1, 1)) 182 | x = x.view(x.size(0), -1) 183 | y = self.fc(x) 184 | 185 | return y 186 | 187 | 188 | def get_fine_tuning_parameters(model, ft_begin_index): 189 | if ft_begin_index == 0: 190 | return model.parameters() 191 | 192 | ft_module_names = [] 193 | for i in range(ft_begin_index, 5): 194 | ft_module_names.append('layer{}'.format(i)) 195 | ft_module_names.append('fc') 196 | 197 | parameters = [] 198 | for k, v in model.named_parameters(): 199 | for ft_module in ft_module_names: 200 | if ft_module in k: 201 | parameters.append({'params': v}) 202 | break 203 | else: 204 | parameters.append({'params': v, 'lr': 0.0}) 205 | 206 | return parameters 207 | 208 | 209 | def resnet3d10(pretrained='', **kwargs): 210 | """Constructs a ResNet3d-10 model. 211 | """ 212 | model = ResNet3d(BasicBlock, [1, 1, 1, 1], **kwargs) 213 | if pretrained: 214 | model = load_state_dict(model, pretrained) 215 | return model 216 | 217 | 218 | def resnet3d18(pretrained='', **kwargs): 219 | """Constructs a ResNet3d-18 model. 220 | """ 221 | model = ResNet3d(BasicBlock, [2, 2, 2, 2], **kwargs) 222 | if pretrained: 223 | model = load_state_dict(model, pretrained) 224 | return model 225 | 226 | 227 | def resnet3d34(pretrained='', **kwargs): 228 | """Constructs a ResNet3d-34 model. 
229 | """ 230 | model = ResNet3d(BasicBlock, [3, 4, 6, 3], **kwargs) 231 | if pretrained: 232 | model = load_state_dict(model, pretrained) 233 | return model 234 | 235 | 236 | def resnet3d50(pretrained='', **kwargs): 237 | """Constructs a ResNet3d-50 model. 238 | """ 239 | model = ResNet3d(Bottleneck, [3, 4, 6, 3], **kwargs) 240 | if pretrained: 241 | model = load_state_dict(model, pretrained) 242 | return model 243 | 244 | 245 | def resnet3d101(pretrained='', **kwargs): 246 | """Constructs a ResNet3d-101 model. 247 | """ 248 | model = ResNet3d(Bottleneck, [3, 4, 23, 3], **kwargs) 249 | if pretrained: 250 | model = load_state_dict(model, pretrained) 251 | return model 252 | 253 | 254 | def resnet3d152(pretrained='', **kwargs): 255 | """Constructs a ResNet3d-101 model. 256 | """ 257 | model = ResNet3d(Bottleneck, [3, 8, 36, 3], **kwargs) 258 | if pretrained: 259 | model = load_state_dict(model, pretrained) 260 | return model 261 | 262 | 263 | def resnet3d200(pretrained='', **kwargs): 264 | """Constructs a ResNet3d-101 model. 265 | """ 266 | model = ResNet3d(Bottleneck, [3, 24, 36, 3], **kwargs) 267 | if pretrained: 268 | model = load_state_dict(model, pretrained) 269 | return model 270 | 271 | def load_state_dict(model, pretrained): 272 | assert os.path.exists(pretrained), '{} is not exists'.format(os.path.abspath(pretrained)) 273 | checkpoint = torch.load(pretrained, map_location='cpu') 274 | pretrain_dict = checkpoint['state_dict'] 275 | state_dict = {} 276 | for key in pretrain_dict: 277 | state_dict[key.partition("module.")[2]] = pretrain_dict[key] 278 | pretrain_dict = state_dict 279 | 280 | model_dict = model.state_dict() 281 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 282 | assert len(pretrain_dict) == len(model_dict) 283 | model_dict.update(pretrain_dict) 284 | model.load_state_dict(model_dict) 285 | return model 286 | -------------------------------------------------------------------------------- /torchreid/models/resnet3dt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | if __package__: 6 | from .resnet3d import * 7 | else: 8 | from resnet3d import * 9 | 10 | __all__ = ['ResNet3dT', 'resnet3dt50'] 11 | 12 | class ResNet3dT(nn.Module): 13 | networks = {'resnet3d10': None, 14 | 'resnet3d18': None, 15 | 'resnet3d34': None, 16 | 'resnet3d50': './pretrained/resnet-50-kinetics.pth'} 17 | def __init__(self, network, num_classes, loss={'xent', 'htri'}, pretrained='', **kwargs): 18 | super(ResNet3dT, self).__init__(**kwargs) 19 | assert network in self.networks, '{} is not supported'.format(network) 20 | if not pretrained: 21 | pretrained = self.networks[network] 22 | 23 | resnet3d = eval(network)(pretrained=pretrained) 24 | self.base = nn.Sequential(*list(resnet3d.children())[:-1]) 25 | self.num_classes = num_classes 26 | self.loss = loss 27 | self.fc = nn.Linear(resnet3d.fc.in_features, num_classes) 28 | 29 | def forward(self, x): 30 | # default size is (b, s, c, w, h), s for seq_len, c for channel 31 | # convert for 3d cnn, (b, c, s, w, h) 32 | x = x.permute(0, 2, 1, 3, 4) 33 | x = self.base(x) 34 | x = F.adaptive_avg_pool3d(x, (1, 1, 1)) 35 | x = x.view(x.size(0), -1) 36 | if not self.training: 37 | return x 38 | 39 | y = self.fc(x) 40 | if self.loss == {'xent'}: 41 | return y 42 | elif self.loss == {'xent', 'htri'}: 43 | return y, x 44 | else: 45 | raise KeyError("Unsupported loss: 
{}".format(self.loss)) 46 | 47 | def resnet3dt50(**kwargs): 48 | return ResNet3dT(network='resnet3d50', **kwargs) 49 | 50 | if __name__ == '__main__': 51 | model = resnet3dt50(num_classes=625, pretrained='../../pretrained/resnet-50-kinetics.pth') 52 | model(torch.randn((1, 4, 3, 224, 112))) -------------------------------------------------------------------------------- /torchreid/models/resnet50_s1.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | model_urls = { 5 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 6 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 7 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 8 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 9 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 10 | } 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 59 | padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class DimReduceLayer(nn.Module): 91 | 92 | def __init__(self, in_channels, out_channels, nonlinear): 93 | super(DimReduceLayer, self).__init__() 94 | layers = [] 95 | layers.append(nn.Linear(in_channels, out_channels, bias=False)) 96 | layers.append(nn.BatchNorm1d(out_channels)) 97 | 98 | if nonlinear == 'relu': 99 | layers.append(nn.ReLU(inplace=True)) 100 | elif nonlinear == 'leakyrelu': 101 | layers.append(nn.LeakyReLU(0.1)) 102 | 103 | self.layers = 
nn.Sequential(*layers) 104 | 105 | def forward(self, x): 106 | return self.layers(x) 107 | 108 | 109 | class ResNet50TP(nn.Module): 110 | def __init__(self, num_classes, loss, block, layers, last_stride=2, 111 | bnneck=True, **kwargs): 112 | self.inplanes = 64 113 | super(ResNet50TP, self).__init__() 114 | self.num_classes = num_classes 115 | self.loss = loss 116 | self.feature_dim = 2048 117 | self.num_scale = 3 # number of layers for feature extraction 118 | 119 | # backbone network 120 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 121 | self.bn1 = nn.BatchNorm2d(64) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 124 | self.layer1 = self._make_layer(block, 64, layers[0]) 125 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 127 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 128 | 129 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 130 | 131 | self.bottleneck = nn.BatchNorm1d(self.feature_dim) 132 | self.bottleneck.bias.requires_grad_(False) 133 | self.classifier = nn.Linear(self.feature_dim, num_classes, bias=False) 134 | 135 | self._init_params() 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride, bias=False), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = [] 147 | layers.append(block(self.inplanes, planes, stride, downsample)) 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes)) 151 | 152 | return nn.Sequential(*layers) 153 | 154 | def _init_params(self): 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | if m.bias is not None: 159 | nn.init.constant_(m.bias, 0) 160 | elif isinstance(m, nn.BatchNorm2d): 161 | nn.init.constant_(m.weight, 1) 162 | nn.init.constant_(m.bias, 0) 163 | elif isinstance(m, nn.BatchNorm1d): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | elif isinstance(m, nn.Linear): 167 | nn.init.normal_(m.weight, 0, 0.01) 168 | if m.bias is not None: 169 | nn.init.constant_(m.bias, 0) 170 | 171 | def _extract_feat(self, x): 172 | x = self.conv1(x) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | x = self.maxpool(x) 176 | 177 | x = self.layer1(x) 178 | x = self.layer2(x) 179 | x = self.layer3(x) 180 | x = self.layer4(x) 181 | return x 182 | 183 | def forward(self, x, adj): 184 | b, s, c, h, w = x.size() 185 | 186 | x = x.view(b * s, c, h, w) 187 | x = self._extract_feat(x) 188 | # global feature 189 | f = self.avg_pool(x).view(b * s, self.feature_dim) 190 | bn = self.bottleneck(f) 191 | 192 | if not self.training: 193 | return bn.view(b, s, -1).mean(dim=1) 194 | f = f.view(b, s, -1).mean(dim=1) 195 | y = self.classifier(bn).view(b, s, -1).mean(dim=1) 196 | 197 | if self.loss == {'xent'}: 198 | return y 199 | elif self.loss == {'xent', 'htri'}: 200 | return y, f 201 | else: 202 | raise KeyError("Unsupported loss: {}".format(self.loss)) 203 | 204 | 205 | def init_pretrained_weights(model, model_url): 206 | """Initializes model with pretrained weights. 
207 | 208 | Layers that don't match with pretrained layers in name or size are kept unchanged. 209 | """ 210 | pretrain_dict = model_zoo.load_url(model_url) 211 | model_dict = model.state_dict() 212 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 213 | model_dict.update(pretrain_dict) 214 | model.load_state_dict(model_dict) 215 | 216 | 217 | def resnet50_s1(num_classes=100, loss={'xent', 'htri'}, pretrain=True, bnneck=True, last_stride=1, **kwargs): 218 | model = ResNet50TP( 219 | num_classes=num_classes, 220 | loss=loss, 221 | block=Bottleneck, 222 | layers=[3, 4, 6, 3], 223 | bnneck=bnneck, 224 | last_stride=last_stride, 225 | **kwargs 226 | ) 227 | if pretrain: 228 | init_pretrained_weights(model, model_urls['resnet50']) 229 | return model -------------------------------------------------------------------------------- /torchreid/models/resnet_temporal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import torchvision 5 | 6 | if __package__: 7 | pass 8 | else: 9 | import sys 10 | sys.path.insert(0, '..') 11 | 12 | __all__ = ['ResNet50TP', 'ResNet50TA', 'ResNet50RNN'] 13 | 14 | class ResNet50TP(nn.Module): 15 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 16 | super(ResNet50TP, self).__init__() 17 | self.loss = loss 18 | resnet50 = torchvision.models.resnet50(pretrained=True) 19 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 20 | self.feat_dim = 2048 21 | self.classifier = nn.Linear(self.feat_dim, num_classes) 22 | 23 | def forward(self, x): 24 | b = x.size(0) 25 | t = x.size(1) 26 | x = x.view(b * t, x.size(2), x.size(3), x.size(4)) 27 | x = self.base(x) 28 | x = F.avg_pool2d(x, x.size()[2:]) 29 | x = x.view(b, t, -1) 30 | x = x.permute(0, 2, 1) 31 | f = F.avg_pool1d(x, t) 32 | f = f.view(b, self.feat_dim) 33 | if not self.training: 34 | return f 35 | y = self.classifier(f) 36 | 37 | if self.loss == {'xent'}: 38 | return y 39 | elif self.loss == {'xent', 'htri'}: 40 | return y, f 41 | elif self.loss == {'cent'}: 42 | return y, f 43 | else: 44 | raise KeyError("Unsupported loss: {}".format(self.loss)) 45 | 46 | 47 | class ResNet50TA(nn.Module): 48 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 49 | super(ResNet50TA, self).__init__() 50 | self.loss = loss 51 | resnet50 = torchvision.models.resnet50(pretrained=True) 52 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 53 | self.att_gen = 'softmax' # method for attention generation: softmax or sigmoid 54 | self.feat_dim = 2048 # feature dimension 55 | self.middle_dim = 256 # middle layer dimension 56 | self.classifier = nn.Linear(self.feat_dim, num_classes) 57 | # 7,4 cooresponds to 224, 112 input image size 58 | self.attention_conv = nn.Conv2d(self.feat_dim, self.middle_dim, [7, 4]) 59 | self.attention_tconv = nn.Conv1d(self.middle_dim, 1, 3, padding=1) 60 | 61 | def forward(self, x): 62 | b = x.size(0) 63 | t = x.size(1) 64 | x = x.view(b * t, x.size(2), x.size(3), x.size(4)) 65 | x = self.base(x) 66 | a = F.relu(self.attention_conv(x)) 67 | a = a.view(b, t, self.middle_dim) 68 | a = a.permute(0, 2, 1) 69 | a = F.relu(self.attention_tconv(a)) 70 | a = a.view(b, t) 71 | x = F.avg_pool2d(x, x.size()[2:]) 72 | if self.att_gen == 'softmax': 73 | a = F.softmax(a, dim=1) 74 | elif self.att_gen == 'sigmoid': 75 | a = F.sigmoid(a) 76 | a = F.normalize(a, p=1, dim=1) 77 | else: 78 | raise KeyError("Unsupported 
attention generation function: {}".format(self.att_gen)) 79 | x = x.view(b, t, -1) 80 | a = torch.unsqueeze(a, -1) 81 | a = a.expand(b, t, self.feat_dim) 82 | att_x = torch.mul(x, a) 83 | att_x = torch.sum(att_x, 1) 84 | 85 | f = att_x.view(b, self.feat_dim) 86 | if not self.training: 87 | return f 88 | y = self.classifier(f) 89 | 90 | if self.loss == {'xent'}: 91 | return y 92 | elif self.loss == {'xent', 'htri'}: 93 | return y, f 94 | elif self.loss == {'cent'}: 95 | return y, f 96 | else: 97 | raise KeyError("Unsupported loss: {}".format(self.loss)) 98 | 99 | 100 | class ResNet50RNN(nn.Module): 101 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 102 | super(ResNet50RNN, self).__init__() 103 | self.loss = loss 104 | resnet50 = torchvision.models.resnet50(pretrained=True) 105 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 106 | self.hidden_dim = 512 107 | self.feat_dim = 2048 108 | self.classifier = nn.Linear(self.hidden_dim, num_classes) 109 | self.lstm = nn.LSTM(input_size=self.feat_dim, hidden_size=self.hidden_dim, num_layers=1, batch_first=True) 110 | 111 | def forward(self, x): 112 | b = x.size(0) 113 | t = x.size(1) 114 | x = x.view(b * t, x.size(2), x.size(3), x.size(4)) 115 | x = self.base(x) 116 | x = F.avg_pool2d(x, x.size()[2:]) 117 | x = x.view(b, t, -1) 118 | output, (h_n, c_n) = self.lstm(x) 119 | output = output.permute(0, 2, 1) 120 | f = F.avg_pool1d(output, t) 121 | f = f.view(b, self.hidden_dim) 122 | if not self.training: 123 | return f 124 | y = self.classifier(f) 125 | 126 | if self.loss == {'xent'}: 127 | return y 128 | elif self.loss == {'xent', 'htri'}: 129 | return y, f 130 | elif self.loss == {'cent'}: 131 | return y, f 132 | else: 133 | raise KeyError("Unsupported loss: {}".format(self.loss)) 134 | -------------------------------------------------------------------------------- /torchreid/models/simple_sta.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['simple_sta_p4'] 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | import torch.utils.model_zoo as model_zoo 10 | 11 | model_urls = { 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 
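    # Note (added for clarity): ``expansion = 4`` means each bottleneck block outputs planes * 4 channels,
    # so ``layer4`` below produces 512 * 4 = 2048-channel feature maps; the STA head concatenates two such
    # 2048-dim part descriptors, which is why ``fc1`` further down expects 4096 inputs.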
59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class DimReduceLayer(nn.Module): 96 | 97 | def __init__(self, in_channels, out_channels, nonlinear): 98 | super(DimReduceLayer, self).__init__() 99 | layers = [] 100 | layers.append(nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)) 101 | layers.append(nn.BatchNorm2d(out_channels)) 102 | 103 | if nonlinear == 'relu': 104 | layers.append(nn.ReLU(inplace=True)) 105 | elif nonlinear == 'leakyrelu': 106 | layers.append(nn.LeakyReLU(0.1)) 107 | 108 | self.layers = nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | return self.layers(x) 112 | 113 | 114 | class STA(nn.Module): 115 | """Part-based Convolutional Baseline. 116 | 117 | Reference: 118 | STA: Spatial-Temporal Attention for Large-Scale Video-based Person Re-Identification 119 | 120 | Public keys: 121 | - ``sta``. 
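    Shape notes (added for clarity): ``forward`` expects a clip tensor of shape (B, S, C, H, W), i.e. batch,
    sequence length, channels, height, width. In eval mode it returns the fused clip embedding produced by
    ``fc1`` (1024-dim when built via ``simple_sta_p4``); in training mode it returns the classification logits,
    plus that embedding when the loss set is {'xent', 'htri'}.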
122 | """ 123 | 124 | def __init__(self, num_classes, loss, block, layers, 125 | reduced_dim=512, 126 | nonlinear='relu', 127 | **kwargs): 128 | self.inplanes = 64 129 | super(STA, self).__init__() 130 | self.loss = loss 131 | self.feature_dim = 512 * block.expansion 132 | self.parts = 4 133 | 134 | # backbone network 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 136 | self.bn1 = nn.BatchNorm2d(64) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 139 | self.layer1 = self._make_layer(block, 64, layers[0]) 140 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 141 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 142 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 143 | 144 | # sta layers 145 | self.parts_avgpool = nn.AdaptiveAvgPool2d((self.parts, 1)) 146 | self.dropout = nn.Dropout(p=0.5) 147 | self.fc1 = nn.Sequential( 148 | nn.Linear(4096, reduced_dim, bias=False), 149 | nn.BatchNorm1d(reduced_dim), 150 | nn.ReLU() if nonlinear == 'relu' else nn.LeakyReLU(0.1) 151 | ) 152 | 153 | self.classifier = nn.Linear(reduced_dim, num_classes) 154 | 155 | self._init_params() 156 | 157 | def _make_layer(self, block, planes, blocks, stride=1): 158 | downsample = None 159 | if stride != 1 or self.inplanes != planes * block.expansion: 160 | downsample = nn.Sequential( 161 | nn.Conv2d(self.inplanes, planes * block.expansion, 162 | kernel_size=1, stride=stride, bias=False), 163 | nn.BatchNorm2d(planes * block.expansion), 164 | ) 165 | 166 | layers = [] 167 | layers.append(block(self.inplanes, planes, stride, downsample)) 168 | self.inplanes = planes * block.expansion 169 | for i in range(1, blocks): 170 | layers.append(block(self.inplanes, planes)) 171 | 172 | return nn.Sequential(*layers) 173 | 174 | def _init_params(self): 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 178 | if m.bias is not None: 179 | nn.init.constant_(m.bias, 0) 180 | elif isinstance(m, nn.BatchNorm2d): 181 | nn.init.constant_(m.weight, 1) 182 | nn.init.constant_(m.bias, 0) 183 | elif isinstance(m, nn.BatchNorm1d): 184 | nn.init.constant_(m.weight, 1) 185 | nn.init.constant_(m.bias, 0) 186 | elif isinstance(m, nn.Linear): 187 | nn.init.normal_(m.weight, 0, 0.01) 188 | if m.bias is not None: 189 | nn.init.constant_(m.bias, 0) 190 | 191 | def featuremaps(self, x): 192 | x = self.conv1(x) 193 | x = self.bn1(x) 194 | x = self.relu(x) 195 | x = self.maxpool(x) 196 | x = self.layer1(x) 197 | x = self.layer2(x) 198 | x = self.layer3(x) 199 | x = self.layer4(x) 200 | return x 201 | 202 | def forward(self, x, *args): 203 | B, S, C, H, W = x.size() 204 | x = x.view(B * S, C, H, W) 205 | f = self.featuremaps(x) 206 | _, c, h, w = f.shape 207 | 208 | v_g = self.parts_avgpool(f).view(B, S, c, self.parts) 209 | t_a = F.normalize(v_g.norm(p=2, dim=2, keepdim=True), p=1, dim=1) 210 | h_index = t_a.argmax(dim=1, keepdim=True) 211 | f_1 = v_g.gather(dim=1, index=h_index.expand((B, 1, c, self.parts))).view(B, c, self.parts) 212 | f_2 = v_g.mul(t_a).sum(dim=1) 213 | f_fuse = torch.cat([f_1, f_2], dim=1) 214 | 215 | f_g = F.adaptive_avg_pool1d(f_fuse, 1).view(B, -1) 216 | f_t = self.fc1(f_g) 217 | 218 | if not self.training: 219 | return f_t 220 | 221 | y = self.classifier(f_t) 222 | 223 | if self.loss == {'xent'}: 224 | return y 225 | elif self.loss == {'xent', 'htri'}: 226 | return y, f_t 227 | else: 228 | raise 
KeyError('Unsupported loss: {}'.format(self.loss)) 229 | 230 | 231 | def init_pretrained_weights(model, model_url): 232 | """Initializes model with pretrained weights. 233 | 234 | Layers that don't match with pretrained layers in name or size are kept unchanged. 235 | """ 236 | pretrain_dict = model_zoo.load_url(model_url) 237 | model_dict = model.state_dict() 238 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 239 | model_dict.update(pretrain_dict) 240 | model.load_state_dict(model_dict) 241 | 242 | 243 | def simple_sta_p4(num_classes, loss={'xent', 'htri'}, last_stride=1, pretrained=True, **kwargs): 244 | model = STA( 245 | num_classes=num_classes, 246 | loss=loss, 247 | block=Bottleneck, 248 | layers=[3, 4, 6, 3], 249 | last_stride=last_stride, 250 | reduced_dim=1024, 251 | nonlinear='relu', 252 | **kwargs 253 | ) 254 | if pretrained: 255 | print('init pretrained weights from {}'.format(model_urls['resnet50'])) 256 | init_pretrained_weights(model, model_urls['resnet50']) 257 | return model -------------------------------------------------------------------------------- /torchreid/models/sta.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['sta_p4'] 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | import torchvision 10 | import numpy as np 11 | import torch.utils.model_zoo as model_zoo 12 | 13 | model_urls = { 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = 
self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class DimReduceLayer(nn.Module): 98 | 99 | def __init__(self, in_channels, out_channels, nonlinear): 100 | super(DimReduceLayer, self).__init__() 101 | layers = [] 102 | layers.append(nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)) 103 | layers.append(nn.BatchNorm2d(out_channels)) 104 | 105 | if nonlinear == 'relu': 106 | layers.append(nn.ReLU(inplace=True)) 107 | elif nonlinear == 'leakyrelu': 108 | layers.append(nn.LeakyReLU(0.1)) 109 | 110 | self.layers = nn.Sequential(*layers) 111 | 112 | def forward(self, x): 113 | return self.layers(x) 114 | 115 | 116 | class STA(nn.Module): 117 | """Part-based Convolutional Baseline. 118 | 119 | Reference: 120 | STA: Spatial-Temporal Attention for Large-Scale Video-based Person Re-Identification 121 | 122 | Public keys: 123 | - ``sta``. 124 | """ 125 | 126 | def __init__(self, num_classes, loss, block, layers, 127 | reduced_dim=512, 128 | nonlinear='relu', 129 | enable_reg=False, 130 | **kwargs): 131 | self.inplanes = 64 132 | super(STA, self).__init__() 133 | self.loss = loss 134 | self.feature_dim = 512 * block.expansion 135 | self.parts = 4 136 | 137 | # backbone network 138 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 139 | self.bn1 = nn.BatchNorm2d(64) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 146 | 147 | # sta layers 148 | self.parts_avgpool = nn.AdaptiveAvgPool2d((self.parts, 1)) 149 | self.dropout = nn.Dropout(p=0.5) 150 | self.fc1 = nn.Sequential( 151 | nn.Linear(4096, reduced_dim, bias=False), 152 | nn.BatchNorm1d(reduced_dim), 153 | nn.ReLU() 154 | ) 155 | 156 | self.feature_dim = reduced_dim 157 | self.classifier = nn.Linear(self.feature_dim, num_classes) 158 | 159 | self._init_params() 160 | 161 | def _make_layer(self, block, planes, blocks, stride=1): 162 | downsample = None 163 | if stride != 1 or self.inplanes != planes * block.expansion: 164 | downsample = nn.Sequential( 165 | nn.Conv2d(self.inplanes, planes * block.expansion, 166 | kernel_size=1, stride=stride, bias=False), 167 | nn.BatchNorm2d(planes * block.expansion), 168 | ) 169 | 170 | layers = [] 171 | layers.append(block(self.inplanes, planes, stride, downsample)) 172 | self.inplanes = planes * block.expansion 173 | for i in range(1, blocks): 174 | layers.append(block(self.inplanes, planes)) 175 | 176 | return nn.Sequential(*layers) 177 | 178 | def _init_params(self): 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 182 | if m.bias is not None: 183 | nn.init.constant_(m.bias, 0) 184 | elif isinstance(m, nn.BatchNorm2d): 185 | nn.init.constant_(m.weight, 1) 186 | nn.init.constant_(m.bias, 0) 187 | elif isinstance(m, nn.BatchNorm1d): 188 | nn.init.constant_(m.weight, 1) 189 | nn.init.constant_(m.bias, 0) 190 | elif isinstance(m, 
nn.Linear): 191 | nn.init.normal_(m.weight, 0, 0.01) 192 | if m.bias is not None: 193 | nn.init.constant_(m.bias, 0) 194 | 195 | def featuremaps(self, x): 196 | x = self.conv1(x) 197 | x = self.bn1(x) 198 | x = self.relu(x) 199 | x = self.maxpool(x) 200 | x = self.layer1(x) 201 | x = self.layer2(x) 202 | x = self.layer3(x) 203 | x = self.layer4(x) 204 | return x 205 | 206 | def forward(self, x, *args): 207 | B, S, C, H, W = x.size() 208 | x = x.view(B * S, C, H, W) 209 | f = self.featuremaps(x) 210 | _, c, h, w = f.shape 211 | 212 | # attention map, first l2 normalization 213 | g_a = f.norm(p=2, dim=1, keepdim=True).view(B * S, 1, h * w) 214 | g_a = F.normalize(g_a, p=2, dim=2).view(B * S, 1, h, w) 215 | 216 | # spatial attention map, second l1 norm 217 | s_a = self.parts_avgpool(g_a).view(B, S, self.parts) 218 | 219 | # temporal attention map, third l1 norm 220 | t_a = F.normalize(s_a, p=1, dim=1) 221 | 222 | v_g = self.parts_avgpool(f).view(B, S, c, self.parts) 223 | 224 | # highest score index 225 | h_index = t_a.argmax(dim=1, keepdim=True).unsqueeze(2) 226 | 227 | # f_1 228 | f_1 = v_g.gather(dim=1, index=h_index.expand((B, 1, c, self.parts))).view(B, c, self.parts) 229 | 230 | # f_2 231 | f_2 = v_g.mul(t_a.unsqueeze(2)).sum(dim=1) 232 | 233 | # fusion 234 | f_fuse = torch.cat([f_1, f_2], dim=1) 235 | # f_fuse = f_2 236 | 237 | # GAP 238 | f_g = F.adaptive_avg_pool1d(f_fuse, 1).view(B, -1) 239 | 240 | f_t = self.fc1(f_g) 241 | 242 | if not self.training: 243 | return f_t 244 | 245 | y = self.classifier(f_t) 246 | 247 | if self.loss == {'xent'}: 248 | return y 249 | elif self.loss == {'xent', 'htri'}: 250 | # v_g = F.normalize(v_g, p=2, dim=1) 251 | return y, f_t 252 | else: 253 | raise KeyError('Unsupported loss: {}'.format(self.loss)) 254 | 255 | 256 | def init_pretrained_weights(model, model_url): 257 | """Initializes model with pretrained weights. 258 | 259 | Layers that don't match with pretrained layers in name or size are kept unchanged. 
260 | """ 261 | pretrain_dict = model_zoo.load_url(model_url) 262 | model_dict = model.state_dict() 263 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 264 | model_dict.update(pretrain_dict) 265 | model.load_state_dict(model_dict) 266 | 267 | 268 | def sta_p4(num_classes, loss={'xent', 'htri'}, last_stride=1, pretrained=True, **kwargs): 269 | model = STA( 270 | num_classes=num_classes, 271 | loss=loss, 272 | block=Bottleneck, 273 | layers=[3, 4, 6, 3], 274 | last_stride=last_stride, 275 | reduced_dim=1024, 276 | nonlinear='relu', 277 | **kwargs 278 | ) 279 | if pretrained: 280 | print('init pretrained weights from {}'.format(model_urls['resnet50'])) 281 | init_pretrained_weights(model, model_urls['resnet50']) 282 | return model -------------------------------------------------------------------------------- /torchreid/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import math 5 | 6 | 7 | def init_optim(optim, params, lr, weight_decay): 8 | if optim == 'adam': 9 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay) 10 | elif optim == 'amsgrad': 11 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, amsgrad=True) 12 | elif optim == 'sgd': 13 | return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay) 14 | elif optim == 'nesterov': 15 | return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=True) 16 | elif optim == 'rmsprop': 17 | return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay) 18 | elif optim == 'adabound': 19 | return AdaBound(params, lr=lr, final_lr=100 * lr, weight_decay=weight_decay) 20 | elif optim == 'radam': 21 | return RAdam(params, lr=lr, weight_decay=weight_decay) 22 | else: 23 | raise KeyError("Unsupported optimizer: {}".format(optim)) 24 | 25 | 26 | class AdaBound(torch.optim.Optimizer): 27 | """Implements AdaBound algorithm. 28 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 29 | Arguments: 30 | params (iterable): iterable of parameters to optimize or dicts defining 31 | parameter groups 32 | lr (float, optional): Adam learning rate (default: 1e-3) 33 | betas (Tuple[float, float], optional): coefficients used for computing 34 | running averages of gradient and its square (default: (0.9, 0.999)) 35 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 36 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-8) 39 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 40 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm 41 | .. 
Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 42 | https://openreview.net/forum?id=Bkg3g2R9FX 43 | """ 44 | 45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, 46 | eps=1e-8, weight_decay=0, amsbound=False): 47 | if not 0.0 <= lr: 48 | raise ValueError("Invalid learning rate: {}".format(lr)) 49 | if not 0.0 <= eps: 50 | raise ValueError("Invalid epsilon value: {}".format(eps)) 51 | if not 0.0 <= betas[0] < 1.0: 52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 55 | if not 0.0 <= final_lr: 56 | raise ValueError("Invalid final learning rate: {}".format(final_lr)) 57 | if not 0.0 <= gamma < 1.0: 58 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 59 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, 60 | weight_decay=weight_decay, amsbound=amsbound) 61 | super(AdaBound, self).__init__(params, defaults) 62 | 63 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) 64 | 65 | def __setstate__(self, state): 66 | super(AdaBound, self).__setstate__(state) 67 | for group in self.param_groups: 68 | group.setdefault('amsbound', False) 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | Arguments: 73 | closure (callable, optional): A closure that reevaluates the model 74 | and returns the loss. 75 | """ 76 | loss = None 77 | if closure is not None: 78 | loss = closure() 79 | 80 | for group, base_lr in zip(self.param_groups, self.base_lrs): 81 | for p in group['params']: 82 | if p.grad is None: 83 | continue 84 | grad = p.grad.data 85 | if grad.is_sparse: 86 | raise RuntimeError( 87 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 88 | amsbound = group['amsbound'] 89 | 90 | state = self.state[p] 91 | 92 | # State initialization 93 | if len(state) == 0: 94 | state['step'] = 0 95 | # Exponential moving average of gradient values 96 | state['exp_avg'] = torch.zeros_like(p.data) 97 | # Exponential moving average of squared gradient values 98 | state['exp_avg_sq'] = torch.zeros_like(p.data) 99 | if amsbound: 100 | # Maintains max of all exp. moving avg. of sq. grad. values 101 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 102 | 103 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 104 | if amsbound: 105 | max_exp_avg_sq = state['max_exp_avg_sq'] 106 | beta1, beta2 = group['betas'] 107 | 108 | state['step'] += 1 109 | 110 | if group['weight_decay'] != 0: 111 | grad = grad.add(group['weight_decay'], p.data) 112 | 113 | # Decay the first and second moment running average coefficient 114 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 115 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 116 | if amsbound: 117 | # Maintains the maximum of all 2nd moment running avg. till now 118 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 119 | # Use the max. for normalizing running avg. 
of gradient 120 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 121 | else: 122 | denom = exp_avg_sq.sqrt().add_(group['eps']) 123 | 124 | bias_correction1 = 1 - beta1 ** state['step'] 125 | bias_correction2 = 1 - beta2 ** state['step'] 126 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 127 | 128 | # Applies bounds on actual learning rate 129 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay 130 | final_lr = group['final_lr'] * group['lr'] / base_lr 131 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1)) 132 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step'])) 133 | step_size = torch.full_like(denom, step_size) 134 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) 135 | 136 | p.data.add_(-step_size) 137 | 138 | return loss 139 | 140 | 141 | class RAdam(torch.optim.Optimizer): 142 | 143 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 144 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 145 | self.buffer = [[None, None, None] for ind in range(10)] 146 | super(RAdam, self).__init__(params, defaults) 147 | 148 | def __setstate__(self, state): 149 | super(RAdam, self).__setstate__(state) 150 | 151 | def step(self, closure=None): 152 | 153 | loss = None 154 | if closure is not None: 155 | loss = closure() 156 | 157 | for group in self.param_groups: 158 | 159 | for p in group['params']: 160 | if p.grad is None: 161 | continue 162 | grad = p.grad.data.float() 163 | if grad.is_sparse: 164 | raise RuntimeError('RAdam does not support sparse gradients') 165 | 166 | p_data_fp32 = p.data.float() 167 | 168 | state = self.state[p] 169 | 170 | if len(state) == 0: 171 | state['step'] = 0 172 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 173 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 174 | else: 175 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 176 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 177 | 178 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 179 | beta1, beta2 = group['betas'] 180 | 181 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 182 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 183 | 184 | state['step'] += 1 185 | buffered = self.buffer[int(state['step'] % 10)] 186 | if state['step'] == buffered[0]: 187 | N_sma, step_size = buffered[1], buffered[2] 188 | else: 189 | buffered[0] = state['step'] 190 | beta2_t = beta2 ** state['step'] 191 | N_sma_max = 2 / (1 - beta2) - 1 192 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 193 | buffered[1] = N_sma 194 | if N_sma > 5: 195 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 196 | else: 197 | step_size = group['lr'] / (1 - beta1 ** state['step']) 198 | buffered[2] = step_size 199 | 200 | if group['weight_decay'] != 0: 201 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 202 | 203 | if N_sma > 4: 204 | denom = exp_avg_sq.sqrt().add_(group['eps']) 205 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 206 | else: 207 | p_data_fp32.add_(-step_size, exp_avg) 208 | 209 | p.data.copy_(p_data_fp32) 210 | 211 | return loss 212 | -------------------------------------------------------------------------------- /torchreid/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 
from __future__ import division 3 | 4 | from collections import defaultdict 5 | import numpy as np 6 | import copy 7 | import random 8 | 9 | import torch 10 | from torch.utils.data.sampler import * 11 | 12 | 13 | class RandomSampler(RandomSampler): 14 | def __init__(self, data_source, batch_size, num_instances): 15 | super(RandomSampler, self).__init__(data_source) 16 | 17 | 18 | class RandomIdentitySampler(Sampler): 19 | """ 20 | Randomly sample N identities, then for each identity, 21 | randomly sample K instances, therefore batch size is N*K. 22 | 23 | Args: 24 | - data_source (Dataset): dataset to sample from. 25 | - num_instances (int): number of instances per identity in a batch. 26 | - batch_size (int): number of examples in a batch. 27 | """ 28 | def __init__(self, data_source, batch_size, num_instances): 29 | self.data_source = data_source 30 | self.batch_size = batch_size 31 | self.num_instances = num_instances 32 | self.num_pids_per_batch = self.batch_size // self.num_instances 33 | self.index_dic = defaultdict(list) 34 | for index, (_, pid, _) in enumerate(self.data_source): 35 | self.index_dic[pid].append(index) 36 | self.pids = list(self.index_dic.keys()) 37 | 38 | # estimate number of examples in an epoch 39 | self.length = 0 40 | for pid in self.pids: 41 | idxs = self.index_dic[pid] 42 | num = len(idxs) 43 | if num < self.num_instances: 44 | num = self.num_instances 45 | self.length += num - num % self.num_instances 46 | 47 | def __iter__(self): 48 | batch_idxs_dict = defaultdict(list) 49 | 50 | for pid in self.pids: 51 | idxs = copy.deepcopy(self.index_dic[pid]) 52 | if len(idxs) < self.num_instances: 53 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 54 | random.shuffle(idxs) 55 | batch_idxs = [] 56 | for idx in idxs: 57 | batch_idxs.append(idx) 58 | if len(batch_idxs) == self.num_instances: 59 | batch_idxs_dict[pid].append(batch_idxs) 60 | batch_idxs = [] 61 | 62 | avai_pids = copy.deepcopy(self.pids) 63 | final_idxs = [] 64 | 65 | while len(avai_pids) >= self.num_pids_per_batch: 66 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 67 | for pid in selected_pids: 68 | batch_idxs = batch_idxs_dict[pid].pop(0) 69 | final_idxs.extend(batch_idxs) 70 | if len(batch_idxs_dict[pid]) == 0: 71 | avai_pids.remove(pid) 72 | 73 | return iter(final_idxs) 74 | 75 | def __len__(self): 76 | return self.length 77 | 78 | 79 | class RandomIdentitySamplerV1(Sampler): 80 | """ 81 | Randomly sample N identities, then for each identity, 82 | randomly sample K instances, therefore batch size is N*K. 83 | 84 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 85 | 86 | Args: 87 | data_source (Dataset): dataset to sample from. 88 | num_instances (int): number of instances per identity. 
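Example (a minimal usage sketch; ``train_set`` is assumed to be a dataset whose items are ``(img, pid, camid)`` tuples, as produced by the data managers in this repo): >>> from torch.utils.data import DataLoader >>> sampler = RandomIdentitySamplerV1(train_set, num_instances=4) >>> train_loader = DataLoader(train_set, batch_size=32, sampler=sampler) # batch_size should be a multiple of num_instances so identities stay grouped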
89 | """ 90 | def __init__(self, data_source, num_instances=4, **kwargs): 91 | self.data_source = data_source 92 | self.num_instances = num_instances 93 | self.index_dic = defaultdict(list) 94 | for index, (_, pid, _) in enumerate(data_source): 95 | self.index_dic[pid].append(index) 96 | self.pids = list(self.index_dic.keys()) 97 | self.num_identities = len(self.pids) 98 | 99 | def __iter__(self): 100 | indices = torch.randperm(self.num_identities) 101 | ret = [] 102 | for i in indices: 103 | pid = self.pids[i] 104 | t = self.index_dic[pid] 105 | replace = False if len(t) >= self.num_instances else True 106 | t = np.random.choice(t, size=self.num_instances, replace=replace) 107 | ret.extend(t) 108 | return iter(ret) 109 | 110 | def __len__(self): 111 | return self.num_identities * self.num_instances -------------------------------------------------------------------------------- /torchreid/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weleen/AGRL.pytorch/83a37eea19365bc8e74c3b3c5fe5ad5d00d04f2c/torchreid/utils/__init__.py -------------------------------------------------------------------------------- /torchreid/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value. 7 | 8 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 9 | """ 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /torchreid/utils/iotools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import os.path as osp 5 | import errno 6 | import json 7 | import shutil 8 | from collections import OrderedDict 9 | 10 | import torch 11 | 12 | 13 | def mkdir_if_missing(directory): 14 | if not osp.exists(directory): 15 | try: 16 | os.makedirs(directory) 17 | except OSError as e: 18 | if e.errno != errno.EEXIST: 19 | raise 20 | 21 | 22 | def check_isfile(path): 23 | isfile = osp.isfile(path) 24 | if not isfile: 25 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 26 | return isfile 27 | 28 | 29 | def read_json(fpath): 30 | with open(fpath, 'r') as f: 31 | obj = json.load(f) 32 | return obj 33 | 34 | 35 | def write_json(obj, fpath): 36 | mkdir_if_missing(osp.dirname(fpath)) 37 | with open(fpath, 'w') as f: 38 | json.dump(obj, f, indent=4, separators=(',', ': ')) 39 | 40 | 41 | def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar', remove_module_from_keys=False): 42 | if len(osp.dirname(fpath)) != 0: 43 | mkdir_if_missing(osp.dirname(fpath)) 44 | if remove_module_from_keys: 45 | # remove 'module.' 
in state_dict's keys 46 | state_dict = state['state_dict'] 47 | new_state_dict = OrderedDict() 48 | for k, v in state_dict.items(): 49 | if k.startswith('module.'): 50 | k = k[7:] 51 | new_state_dict[k] = v 52 | state['state_dict'] = new_state_dict 53 | torch.save(state, fpath) 54 | if is_best: 55 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) -------------------------------------------------------------------------------- /torchreid/utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | import os 5 | import os.path as osp 6 | import time 7 | 8 | from .iotools import mkdir_if_missing 9 | 10 | 11 | class Logger(object): 12 | """ 13 | Write console output to external text file. 14 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 15 | """ 16 | def __init__(self, fpath=None): 17 | self.console = sys.stdout 18 | self.file = None 19 | if fpath is not None: 20 | mkdir_if_missing(osp.dirname(fpath)) 21 | self.file = open(fpath, 'w') 22 | 23 | def __del__(self): 24 | self.close() 25 | 26 | def __enter__(self): 27 | pass 28 | 29 | def __exit__(self, *args): 30 | self.close() 31 | 32 | def write(self, msg): 33 | if msg.isprintable(): 34 | msg = time.asctime() + ' '+ msg 35 | self.console.write(msg) 36 | if self.file is not None: 37 | self.file.write(msg) 38 | 39 | def flush(self): 40 | self.console.flush() 41 | if self.file is not None: 42 | self.file.flush() 43 | os.fsync(self.file.fileno()) 44 | 45 | def close(self): 46 | self.console.close() 47 | if self.file is not None: 48 | self.file.close() -------------------------------------------------------------------------------- /torchreid/utils/model_complexity.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | __all__ = ['compute_model_complexity'] 3 | 4 | from collections import namedtuple, defaultdict 5 | import numpy as np 6 | import math 7 | from itertools import repeat 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | """ 14 | Utility 15 | """ 16 | def _ntuple(n): 17 | def parse(x): 18 | if isinstance(x, int): 19 | return tuple(repeat(x, n)) 20 | return x 21 | return parse 22 | 23 | _single = _ntuple(1) 24 | _pair = _ntuple(2) 25 | _triple = _ntuple(3) 26 | 27 | 28 | """ 29 | Convolution 30 | """ 31 | def hook_convNd(m, x, y): 32 | k = torch.prod(torch.Tensor(m.kernel_size)).item() 33 | cin = m.in_channels 34 | flops_per_ele = k*cin #+ (k*cin-1) 35 | if m.bias is not None: 36 | flops_per_ele += 1 37 | flops = flops_per_ele * y.numel() / m.groups 38 | return int(flops) 39 | 40 | 41 | """ 42 | Pooling 43 | """ 44 | def hook_maxpool1d(m, x, y): 45 | flops_per_ele = m.kernel_size - 1 46 | flops = flops_per_ele * y.numel() 47 | return int(flops) 48 | 49 | 50 | def hook_maxpool2d(m, x, y): 51 | k = _pair(m.kernel_size) 52 | k = torch.prod(torch.Tensor(k)).item() 53 | # ops: compare 54 | flops_per_ele = k - 1 55 | flops = flops_per_ele * y.numel() 56 | return int(flops) 57 | 58 | 59 | def hook_maxpool3d(m, x, y): 60 | k = _triple(m.kernel_size) 61 | k = torch.prod(torch.Tensor(k)).item() 62 | flops_per_ele = k - 1 63 | flops = flops_per_ele * y.numel() 64 | return int(flops) 65 | 66 | 67 | def hook_avgpool1d(m, x, y): 68 | flops_per_ele = m.kernel_size 69 | flops = flops_per_ele * y.numel() 70 | return int(flops) 71 | 72 | 73 | def hook_avgpool2d(m, x, y): 74 | 
k = _pair(m.kernel_size) 75 | k = torch.prod(torch.Tensor(k)).item() 76 | flops_per_ele = k 77 | flops = flops_per_ele * y.numel() 78 | return int(flops) 79 | 80 | 81 | def hook_avgpool3d(m, x, y): 82 | k = _triple(m.kernel_size) 83 | k = torch.prod(torch.Tensor(k)).item() 84 | flops_per_ele = k 85 | flops = flops_per_ele * y.numel() 86 | return int(flops) 87 | 88 | 89 | def hook_adapmaxpool1d(m, x, y): 90 | x = x[0] 91 | out_size = m.output_size 92 | k = math.ceil(x.size(2) / out_size) 93 | flops_per_ele = k - 1 94 | flops = flops_per_ele * y.numel() 95 | return int(flops) 96 | 97 | 98 | def hook_adapmaxpool2d(m, x, y): 99 | x = x[0] 100 | out_size = _pair(m.output_size) 101 | k = torch.Tensor(list(x.size()[2:])) / torch.Tensor(out_size) 102 | k = torch.prod(torch.ceil(k)).item() 103 | flops_per_ele = k - 1 104 | flops = flops_per_ele * y.numel() 105 | return int(flops) 106 | 107 | 108 | def hook_adapmaxpool3d(m, x, y): 109 | x = x[0] 110 | out_size = _triple(m.output_size) 111 | k = torch.Tensor(list(x.size()[2:])) / torch.Tensor(out_size) 112 | k = torch.prod(torch.ceil(k)).item() 113 | flops_per_ele = k - 1 114 | flops = flops_per_ele * y.numel() 115 | return int(flops) 116 | 117 | 118 | def hook_adapavgpool1d(m, x, y): 119 | x = x[0] 120 | out_size = m.output_size 121 | k = math.ceil(x.size(2) / out_size) 122 | flops_per_ele = k 123 | flops = flops_per_ele * y.numel() 124 | return int(flops) 125 | 126 | 127 | def hook_adapavgpool2d(m, x, y): 128 | x = x[0] 129 | out_size = _pair(m.output_size) 130 | k = torch.Tensor(list(x.size()[2:])) / torch.Tensor(out_size) 131 | k = torch.prod(torch.ceil(k)).item() 132 | flops_per_ele = k 133 | flops = flops_per_ele * y.numel() 134 | return int(flops) 135 | 136 | 137 | def hook_adapavgpool3d(m, x, y): 138 | x = x[0] 139 | out_size = _triple(m.output_size) 140 | k = torch.Tensor(list(x.size()[2:])) / torch.Tensor(out_size) 141 | k = torch.prod(torch.ceil(k)).item() 142 | flops_per_ele = k 143 | flops = flops_per_ele * y.numel() 144 | return int(flops) 145 | 146 | 147 | """ 148 | Non-linear activations 149 | """ 150 | def hook_relu(m, x, y): 151 | # eq: max(0, x) 152 | num_ele = y.numel() 153 | return int(num_ele) 154 | 155 | 156 | def hook_leakyrelu(m, x, y): 157 | # eq: max(0, x) + negative_slope*min(0, x) 158 | num_ele = y.numel() 159 | flops = 3 * num_ele 160 | return int(flops) 161 | 162 | 163 | """ 164 | Normalization 165 | """ 166 | def hook_batchnormNd(m, x, y): 167 | num_ele = y.numel() 168 | flops = 2 * num_ele # mean and std 169 | if m.affine: 170 | flops += 2 * num_ele # gamma and beta 171 | return int(flops) 172 | 173 | 174 | def hook_instancenormNd(m, x, y): 175 | return hook_batchnormNd(m, x, y) 176 | 177 | 178 | def hook_groupnorm(m, x, y): 179 | return hook_batchnormNd(m, x, y) 180 | 181 | 182 | def hook_layernorm(m, x, y): 183 | num_ele = y.numel() 184 | flops = 2 * num_ele # mean and std 185 | if m.elementwise_affine: 186 | flops += 2 * num_ele # gamma and beta 187 | return int(flops) 188 | 189 | 190 | """ 191 | Linear 192 | """ 193 | def hook_linear(m, x, y): 194 | flops_per_ele = m.in_features #+ (m.in_features-1) 195 | if m.bias is not None: 196 | flops_per_ele += 1 197 | flops = flops_per_ele * y.numel() 198 | return int(flops) 199 | 200 | __generic_flops_counter = { 201 | # Convolution 202 | 'Conv1d': hook_convNd, 203 | 'Conv2d': hook_convNd, 204 | 'Conv3d': hook_convNd, 205 | # Pooling 206 | 'MaxPool1d': hook_maxpool1d, 207 | 'MaxPool2d': hook_maxpool2d, 208 | 'MaxPool3d': hook_maxpool3d, 209 | 'AvgPool1d': hook_avgpool1d, 
210 | 'AvgPool2d': hook_avgpool2d, 211 | 'AvgPool3d': hook_avgpool3d, 212 | 'AdaptiveMaxPool1d': hook_adapmaxpool1d, 213 | 'AdaptiveMaxPool2d': hook_adapmaxpool2d, 214 | 'AdaptiveMaxPool3d': hook_adapmaxpool3d, 215 | 'AdaptiveAvgPool1d': hook_adapavgpool1d, 216 | 'AdaptiveAvgPool2d': hook_adapavgpool2d, 217 | 'AdaptiveAvgPool3d': hook_adapavgpool3d, 218 | # Non-linear activations 219 | 'ReLU': hook_relu, 220 | 'ReLU6': hook_relu, 221 | 'LeakyReLU': hook_leakyrelu, 222 | # Normalization 223 | 'BatchNorm1d': hook_batchnormNd, 224 | 'BatchNorm2d': hook_batchnormNd, 225 | 'BatchNorm3d': hook_batchnormNd, 226 | 'InstanceNorm1d': hook_instancenormNd, 227 | 'InstanceNorm2d': hook_instancenormNd, 228 | 'InstanceNorm3d': hook_instancenormNd, 229 | 'GroupNorm': hook_groupnorm, 230 | 'LayerNorm': hook_layernorm, 231 | # Linear 232 | 'Linear': hook_linear, 233 | } 234 | 235 | 236 | __conv_linear_flops_counter = { 237 | # Convolution 238 | 'Conv1d': hook_convNd, 239 | 'Conv2d': hook_convNd, 240 | 'Conv3d': hook_convNd, 241 | # Linear 242 | 'Linear': hook_linear, 243 | } 244 | 245 | 246 | def _get_flops_counter(only_conv_linear): 247 | if only_conv_linear: 248 | return __conv_linear_flops_counter 249 | return __generic_flops_counter 250 | 251 | 252 | def compute_model_complexity(model, input, verbose=False, only_conv_linear=True): 253 | """Returns number of parameters and FLOPs. 254 | 255 | .. note:: 256 | (1) this function only provides an estimate of the theoretical time complexity 257 | rather than the actual running time which depends on implementations and hardware, 258 | and (2) the FLOPs is only counted for layers that are used at test time. This means 259 | that redundant layers such as person ID classification layer will be ignored as it 260 | is discarded when doing feature extraction. Note that the inference graph depends on 261 | how you construct the computations in ``forward()``. 262 | 263 | Args: 264 | model (nn.Module): network model. 265 | input (torch.Tensor or list of torch.Tensor): input 266 | verbose (bool, optional): shows detailed complexity of 267 | each module. Default is False. 268 | only_conv_linear (bool, optional): only considers convolution 269 | and linear layers when counting flops. Default is True. 270 | If set to False, flops of all layers will be counted. 
271 | 272 | Examples:: 273 | >>> from torchreid import models, utils 274 | >>> model = models.build_model(name='resnet50', num_classes=1000) 275 | >>> num_params, flops = utils.compute_model_complexity(model, torch.rand(1, 3, 256, 128), verbose=True) 276 | """ 277 | registered_handles = [] 278 | layer_list = [] 279 | layer = namedtuple('layer', ['class_name', 'params', 'flops']) 280 | 281 | def _add_hooks(m): 282 | def _has_submodule(m): 283 | return len(list(m.children()))>0 284 | 285 | def _hook(m, x, y): 286 | params = sum(p.numel() for p in m.parameters()) 287 | class_name = str(m.__class__.__name__) 288 | flops_counter = _get_flops_counter(only_conv_linear) 289 | if class_name in flops_counter: 290 | flops = flops_counter[class_name](m, x, y) 291 | else: 292 | flops = 0 293 | layer_list.append( 294 | layer( 295 | class_name=class_name, 296 | params=params, 297 | flops=flops 298 | ) 299 | ) 300 | 301 | # only consider the very basic nn layer 302 | if _has_submodule(m): 303 | return 304 | 305 | handle = m.register_forward_hook(_hook) 306 | registered_handles.append(handle) 307 | 308 | default_train_mode = model.training 309 | 310 | model.eval().apply(_add_hooks) 311 | # wrap the input so that a single tensor and a list of tensors are handled uniformly 312 | if not isinstance(input, list): 313 | input = [input] 314 | if next(model.parameters()).is_cuda: 315 | input = [inp.cuda() for inp in input] 316 | 317 | model(*input) # forward 318 | 319 | for handle in registered_handles: 320 | handle.remove() 321 | 322 | model.train(default_train_mode) 323 | 324 | if verbose: 325 | per_module_params = defaultdict(list) 326 | per_module_flops = defaultdict(list) 327 | 328 | total_params, total_flops = 0, 0 329 | 330 | for layer in layer_list: 331 | total_params += layer.params 332 | total_flops += layer.flops 333 | if verbose: 334 | per_module_params[layer.class_name].append(layer.params) 335 | per_module_flops[layer.class_name].append(layer.flops) 336 | 337 | if verbose: 338 | num_udscore = 55 339 | print(' {}'.format('-'*num_udscore)) 340 | print(' Model complexity with input size {}'.format(input[0].size())) 341 | print(' {}'.format('-'*num_udscore)) 342 | for class_name in per_module_params: 343 | params = int(np.sum(per_module_params[class_name])) 344 | flops = int(np.sum(per_module_flops[class_name])) 345 | print(' {} (params={:,}, flops={:,})'.format(class_name, params, flops)) 346 | print(' {}'.format('-'*num_udscore)) 347 | print(' Total (params={:,}, flops={:,})'.format(total_params, total_flops)) 348 | print(' {}'.format('-'*num_udscore)) 349 | 350 | return total_params, total_flops 351 | -------------------------------------------------------------------------------- /torchreid/utils/re_ranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | """ 5 | Created on Mon Jun 26 14:46:56 2017 6 | @author: luohao 7 | Modified by Houjing Huang, 2017-12-22. 8 | - This version accepts distance matrix instead of raw features. 9 | - The difference of `/` division between python 2 and 3 is handled. 10 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 11 | """ 12 | 13 | """ 14 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
15 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 16 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 17 | """ 18 | 19 | """ 20 | API 21 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 22 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 23 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 24 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 25 | Returns: 26 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 27 | """ 28 | 29 | 30 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 31 | # The following naming, e.g. gallery_num, is different from outer scope. 32 | # Don't care about it. 33 | 34 | original_dist = np.concatenate( 35 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 36 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 37 | axis=0) 38 | original_dist = np.power(original_dist, 2).astype(np.float32) 39 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0)) 40 | V = np.zeros_like(original_dist).astype(np.float32) 41 | initial_rank = np.argsort(original_dist).astype(np.int32) 42 | 43 | query_num = q_g_dist.shape[0] 44 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 45 | all_num = gallery_num 46 | 47 | for i in range(all_num): 48 | # k-reciprocal neighbors 49 | forward_k_neigh_index = initial_rank[i,:k1+1] 50 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 51 | fi = np.where(backward_k_neigh_index==i)[0] 52 | k_reciprocal_index = forward_k_neigh_index[fi] 53 | k_reciprocal_expansion_index = k_reciprocal_index 54 | for j in range(len(k_reciprocal_index)): 55 | candidate = k_reciprocal_index[j] 56 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 57 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 58 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 59 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 60 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 61 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 62 | 63 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 64 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 65 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 66 | original_dist = original_dist[:query_num,] 67 | if k2 != 1: 68 | V_qe = np.zeros_like(V,dtype=np.float32) 69 | for i in range(all_num): 70 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 71 | V = V_qe 72 | del V_qe 73 | del initial_rank 74 | invIndex = [] 75 | for i in range(gallery_num): 76 | invIndex.append(np.where(V[:,i] != 0)[0]) 77 | 78 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 79 | 80 | 81 | for i in range(query_num): 82 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 83 | indNonZero = np.where(V[i,:] != 0)[0] 84 | indImages = [] 85 | indImages = [invIndex[ind] for ind in indNonZero] 86 | for j in range(len(indNonZero)): 87 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 88 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 89 | 90 | final_dist = 
jaccard_dist*(1-lambda_value) + original_dist*lambda_value 91 | del original_dist 92 | del V 93 | del jaccard_dist 94 | final_dist = final_dist[:query_num,query_num:] 95 | return final_dist -------------------------------------------------------------------------------- /torchreid/utils/reidtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | import os 6 | import torch 7 | import os.path as osp 8 | import shutil 9 | 10 | from .iotools import mkdir_if_missing 11 | 12 | 13 | def calc_splits(num_split): 14 | assert (num_split & (num_split - 1)) == 0, 'num_split must be the power of 2, {} is not supported'.format(num_split) 15 | return [i for i in range(num_split, 0, -1) if num_split % i == 0] 16 | 17 | 18 | def visualize_ranked_results(distmat, dataset, save_dir='log/ranked_results', topk=20): 19 | """ 20 | Visualize ranked results 21 | 22 | Support both imgreid and vidreid 23 | 24 | Args: 25 | - distmat: distance matrix of shape (num_query, num_gallery). 26 | - dataset: has dataset.query and dataset.gallery, both are lists of (img_path, pid, camid); 27 | for imgreid, img_path is a string, while for vidreid, img_path is a tuple containing 28 | a sequence of strings. 29 | - save_dir: directory to save output images. 30 | - topk: int, denoting top-k images in the rank list to be visualized. 31 | """ 32 | num_q, num_g = distmat.shape 33 | 34 | print("Visualizing top-{} ranks".format(topk)) 35 | print("# query: {}\n# gallery {}".format(num_q, num_g)) 36 | print("Saving images to '{}'".format(save_dir)) 37 | 38 | assert num_q == len(dataset.query) 39 | assert num_g == len(dataset.gallery) 40 | 41 | indices = np.argsort(distmat, axis=1) 42 | mkdir_if_missing(save_dir) 43 | 44 | def _cp_img_to(src, dst, rank, prefix): 45 | """ 46 | - src: image path or tuple (for vidreid) 47 | - dst: target directory 48 | - rank: int, denoting ranked position, starting from 1 49 | - prefix: string 50 | """ 51 | if isinstance(src, tuple) or isinstance(src, list): 52 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) 53 | mkdir_if_missing(dst) 54 | for img_path in src: 55 | shutil.copy(img_path, dst) 56 | else: 57 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src)) 58 | shutil.copy(src, dst) 59 | 60 | for q_idx in range(num_q): 61 | qimg_path, qpid, qcamid = dataset.query[q_idx] 62 | if isinstance(qimg_path, list) or isinstance(qimg_path, tuple): 63 | qimg_index = qimg_path[0].split('/')[-2] 64 | else: 65 | qimg_index = osp.basename(qimg_path) 66 | qdir = osp.join(save_dir, 'id' + qimg_index + '_cam' + str(qcamid)) 67 | mkdir_if_missing(qdir) 68 | _cp_img_to(qimg_path, qdir, rank=0, prefix='query') 69 | 70 | rank_idx = 1 71 | for g_idx in indices[q_idx,:]: 72 | gimg_path, gpid, gcamid = dataset.gallery[g_idx] 73 | invalid = (qpid == gpid) & (qcamid == gcamid) 74 | if not invalid: 75 | _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery') 76 | rank_idx += 1 77 | if rank_idx > topk: 78 | break 79 | 80 | print("Done") 81 | -------------------------------------------------------------------------------- /torchreid/utils/torchtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import gc 7 | import time 8 | 9 | 10 | def cur_time(): 11 | return 
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) 12 | 13 | 14 | def adjust_learning_rate(optimizer, base_lr, epoch, stepsize, gamma=0.1): 15 | # decay learning rate by 'gamma' for every 'stepsize' 16 | lr = base_lr * (gamma ** (epoch // stepsize)) 17 | for param_group in optimizer.param_groups: 18 | param_group['lr'] = lr 19 | 20 | 21 | def set_bn_to_eval(m): 22 | # 1. no update for running mean and var 23 | # 2. scale and shift parameters are still trainable 24 | classname = m.__class__.__name__ 25 | if classname.find('BatchNorm') != -1: 26 | m.eval() 27 | 28 | 29 | def set_wd(optim, num): 30 | assert isinstance(num, (int, float)), '{} is not int or float'.format(num) 31 | for group in optim.param_groups: 32 | if group['weight_decay'] != num: 33 | group['weight_decay'] = num 34 | 35 | 36 | def count_num_param(model): 37 | num_param = sum(p.numel() for p in model.parameters()) / 1e+06 38 | if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module): 39 | # we ignore the classifier because it is unused at test time 40 | num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06 41 | return num_param 42 | 43 | 44 | def flip_tensor(x, dim): 45 | indices = [slice(None)] * x.dim() 46 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, 47 | dtype=torch.long, device=x.device) 48 | return x[tuple(indices)] 49 | 50 | 51 | def weights_init_kaiming(m): 52 | classname = m.__class__.__name__ 53 | if classname.find('Linear') != -1: 54 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 55 | if m.bias is not None: 56 | nn.init.constant_(m.bias, 0) 57 | elif classname.find('Conv') != -1: 58 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 59 | if m.bias is not None: 60 | nn.init.constant_(m.bias, 0) 61 | elif classname.find('BatchNorm') != -1: 62 | if m.affine: 63 | nn.init.normal_(m.weight, 1.0, 0.001) 64 | nn.init.constant_(m.bias, 0.0) 65 | 66 | 67 | def weights_init_xavier(m): 68 | classname = m.__class__.__name__ 69 | if classname.find('Linear') != -1: 70 | nn.init.xavier_normal_(m.weight) 71 | if m.bias is not None: 72 | nn.init.constant_(m.bias, 0) 73 | elif classname.find('Conv') != -1: 74 | nn.init.xavier_normal_(m.weight) 75 | if m.bias is not None: 76 | nn.init.constant_(m.bias, 0) 77 | elif classname.find('BatchNorm') != -1: 78 | if m.affine: 79 | nn.init.normal_(m.weight, 1.0, 0.001) 80 | nn.init.constant_(m.bias, 0.0) 81 | 82 | 83 | def weights_init_classifier(m): 84 | classname = m.__class__.__name__ 85 | if classname.find('Linear') != -1: 86 | nn.init.normal_(m.weight.data, std=0.001) 87 | if m.bias is not None: 88 | nn.init.constant_(m.bias.data, 0.0) 89 | 90 | 91 | def mem_report(): 92 | """Report the memory usage of the tensor.storage in pytorch 93 | Both on CPUs and GPUs are reported""" 94 | 95 | def _mem_report(tensors, mem_type): 96 | '''Print the selected tensors of type 97 | There are two major storage types in our major concern: 98 | - GPU: tensors transferred to CUDA devices 99 | - CPU: tensors remaining on the system memory (usually unimportant) 100 | Args: 101 | - tensors: the tensors of specified type 102 | - mem_type: 'CPU' or 'GPU' in current implementation ''' 103 | print('Storage on %s' %(mem_type)) 104 | print('-'*LEN) 105 | total_numel = 0 106 | total_mem = 0 107 | visited_data = [] 108 | for tensor in tensors: 109 | if tensor.is_sparse: 110 | continue 111 | # a data_ptr indicates a memory block allocated 112 | data_ptr = tensor.storage().data_ptr() 113 | if data_ptr in visited_data: 114 | continue 115 | 
visited_data.append(data_ptr) 116 | 117 | numel = tensor.storage().size() 118 | total_numel += numel 119 | element_size = tensor.storage().element_size() 120 | mem = numel*element_size /1024/1024 # 32bit=4Byte, MByte 121 | total_mem += mem 122 | element_type = type(tensor).__name__ 123 | size = tuple(tensor.size()) 124 | 125 | print('%s\t\t%s\t\t%.2f' % ( 126 | element_type, 127 | size, 128 | mem) ) 129 | print('-'*LEN) 130 | print('Total Tensors: %d \tUsed Memory Space: %.2f MBytes' % (total_numel, total_mem) ) 131 | print('-'*LEN) 132 | 133 | LEN = 65 134 | print('='*LEN) 135 | objects = gc.get_objects() 136 | print('%s\t%s\t\t\t%s' %('Element type', 'Size', 'Used MEM(MBytes)') ) 137 | tensors = [obj for obj in objects if torch.is_tensor(obj)] 138 | cuda_tensors = [t for t in tensors if t.is_cuda] 139 | host_tensors = [t for t in tensors if not t.is_cuda] 140 | _mem_report(cuda_tensors, 'GPU') 141 | _mem_report(host_tensors, 'CPU') 142 | print('='*LEN) --------------------------------------------------------------------------------
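A minimal sketch of how the helpers in torchreid/utils/torchtools.py are typically combined when setting up a model. Here ``build_my_model`` is a hypothetical constructor standing in for one of the models in torchreid/models, and the model is assumed to expose a ``classifier`` attribute, which ``count_num_param`` also expects:

```python
from torchreid.utils.torchtools import (
    weights_init_kaiming, weights_init_classifier, set_bn_to_eval, count_num_param)

model = build_my_model()                         # hypothetical: any model from torchreid/models
model.apply(weights_init_kaiming)                # Kaiming init for Conv/Linear, near-identity init for BatchNorm
model.classifier.apply(weights_init_classifier)  # small-std normal init for the final ID classifier
model.apply(set_bn_to_eval)                      # optionally freeze BatchNorm running statistics (reapply after model.train())
print('Model size: {:.3f} M'.format(count_num_param(model)))  # parameter count in millions, classifier excluded
```

During training, ``adjust_learning_rate(optimizer, base_lr, epoch, stepsize)`` can then be called at the start of each epoch to apply the step decay defined above.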