├── .gitignore
├── README.md
├── bibtex
├── docs
│   ├── AWESOME_REID.md
│   ├── MODEL_ZOO.md
│   ├── Makefile
│   ├── conf.py
│   ├── datasets.rst
│   ├── evaluation.rst
│   ├── figures
│   │   ├── deep-person-reid-logo.png
│   │   └── ranked_results.jpg
│   ├── index.rst
│   ├── pkg
│   │   ├── data.rst
│   │   ├── engine.rst
│   │   ├── losses.rst
│   │   ├── metrics.rst
│   │   ├── models.rst
│   │   ├── optim.rst
│   │   └── utils.rst
│   └── user_guide.rst
├── requirements.txt
├── scripts
│   ├── default_parser.py
│   ├── main.py
│   ├── openpose_PETHZ.sh
│   ├── openpose_market.sh
│   ├── openpose_occluded_duke.sh
│   └── openpose_occluded_reid.sh
├── setup.py
└── torchreid
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── datamanager.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── dataset.py
    │   │   ├── image
    │   │   │   ├── __init__.py
    │   │   │   ├── cuhk01.py
    │   │   │   ├── cuhk03.py
    │   │   │   ├── dukemtmcreid.py
    │   │   │   ├── grid.py
    │   │   │   ├── ilids.py
    │   │   │   ├── market1501.py
    │   │   │   ├── msmt17.py
    │   │   │   ├── occluded_duke.py
    │   │   │   ├── occlusion_reid.py
    │   │   │   ├── p_ETHZ.py
    │   │   │   ├── partial_reid.py
    │   │   │   ├── pduke_reid.py
    │   │   │   ├── prid.py
    │   │   │   ├── sensereid.py
    │   │   │   └── viper.py
    │   │   └── video
    │   │       ├── __init__.py
    │   │       ├── dukemtmcvidreid.py
    │   │       ├── ilidsvid.py
    │   │       ├── mars.py
    │   │       └── prid2011.py
    │   ├── sampler.py
    │   └── transforms.py
    ├── engine
    │   ├── __init__.py
    │   ├── engine.py
    │   ├── image
    │   │   ├── __init__.py
    │   │   ├── softmax.py
    │   │   └── triplet.py
    │   └── video
    │       ├── __init__.py
    │       ├── softmax.py
    │       └── triplet.py
    ├── losses
    │   ├── __init__.py
    │   ├── cross_entropy_loss.py
    │   └── hard_mine_triplet_loss.py
    ├── metrics
    │   ├── __init__.py
    │   ├── accuracy.py
    │   ├── distance.py
    │   ├── rank.py
    │   └── rank_cylib
    │       ├── Makefile
    │       ├── __init__.py
    │       ├── rank_cy.pyx
    │       ├── setup.py
    │       └── test_cython.py
    ├── models
    │   ├── __init__.py
    │   ├── densenet.py
    │   ├── hacnn.py
    │   ├── inceptionresnetv2.py
    │   ├── inceptionv4.py
    │   ├── mlfn.py
    │   ├── mobilenetv2.py
    │   ├── mudeep.py
    │   ├── nasnet.py
    │   ├── non_local_block.py
    │   ├── osnet.py
    │   ├── pcb.py
    │   ├── resnet.py
    │   ├── resnet_.py
    │   ├── resnetmid.py
    │   ├── senet.py
    │   ├── shufflenet.py
    │   ├── shufflenetv2.py
    │   ├── squeezenet.py
    │   └── xception.py
    ├── optim
    │   ├── __init__.py
    │   ├── lr_scheduler.py
    │   └── optimizer.py
    └── utils
        ├── __init__.py
        ├── avgmeter.py
        ├── loggers.py
        ├── model_complexity.py
        ├── reidtools.py
        ├── rerank.py
        ├── tools.py
        ├── torchtools.py
        └── vis_featmat_cluster.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | torchreid.egg-info/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 | 
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 | 
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 | 
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | 
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 | 
67 | # Scrapy stuff:
68 | .scrapy
69 | 
70 | # Sphinx documentation
71 | docs/_build/
72 | 
73 | # PyBuilder
74 | target/
75 | 
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 | 
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 | 
83 | # pyenv
84 | .python-version
85 | 
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 | 
93 | # celery beat schedule file
94 | celerybeat-schedule
95 | 
96 | # SageMath parsed files
97 | *.sage.py
98 | 
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 | 
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 | 
112 | # Rope project settings
113 | .ropeproject
114 | 
115 | # mkdocs documentation
116 | /site
117 | 
118 | # mypy
119 | .mypy_cache/
120 | 
121 | .idea/
122 | hhh/
123 | imgs/
124 | .dmypy.json
125 | dmypy.json
126 | 
127 | # Pyre type checker
128 | .pyre/
129 | 
130 | 
131 | # ReID
132 | reid-data/
133 | log/
134 | saved-models/
135 | model-zoo/
136 | debug.py
137 | 
138 | # Cython eval code
139 | *.c
140 | *.html
141 | 
142 | # OS X
143 | .DS_Store
144 | .Spotlight-V100
145 | .Trashes
146 | ._*
147 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CVPR2020 Pose-guided Visible Part Matching for Occluded Person ReID
2 | This is the PyTorch implementation of the CVPR 2020 paper *"Pose-guided Visible Part Matching for Occluded Person ReID"*.
3 | 
4 | ## Dependencies
5 | - Python 2.7
6 | - PyTorch 1.0
7 | - NumPy
8 | 
9 | ## Related Projects
10 | Our code is based on [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid). We adopt [openpose](https://github.com/CMU-Perceptual-Computing-Lab/openpose) to extract pose landmarks and part affinity fields.
11 | 
12 | ## Dataset Preparation
13 | Download the raw datasets [Occluded-REID, P-DukeMTMC-reID](https://github.com/tinajia2012/ICME2018_Occluded-Person-Reidentification_datasets) and [Partial-REID](https://pan.baidu.com/s/1VhPUVJOLvkhgbJiUoEnJWg) (code: zdl8), which was released with [Partial Person Re-identification](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Zheng_Partial_Person_Re-Identification_ICCV_2015_paper.html). Instructions on how to prepare the [Market1501](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zheng_Scalable_Person_Re-Identification_ICCV_2015_paper.pdf) dataset can be found [here](https://kaiyangzhou.github.io/deep-person-reid/datasets.html). Then organize them under the data directory as follows:
14 | 
15 | ```
16 | PVPM_experiments/data/
17 | ├── ICME2018_Occluded-Person-Reidentification_datasets
18 | │   ├── Occluded_Duke
19 | │   └── Occluded_REID
20 | ├── Market-1501-v15.09.15
21 | └── Partial-REID_Dataset
22 | ```
23 | 
24 | ## Pose extraction
25 | Install openpose as described [here](https://github.com/CMU-Perceptual-Computing-Lab/openpose).\
26 | Change the paths to your own dataset root and run the shell scripts in `/scripts`:
27 | ```
28 | sh openpose_occluded_reid.sh
29 | sh openpose_market.sh
30 | ```
31 | The extracted pose information can be found [here](https://pan.baidu.com/s/1Majze1iFo7FytREijmQO5A) (code: iwlz).
32 | 
33 | ## To train the PCB baseline
34 | 
35 | ```
36 | python scripts/main.py --root PATH_TO_DATAROOT \
37 | -s market1501 -t market1501 \
38 | --save-dir PATH_TO_EXPERIMENT_FOLDER/market_PCB \
39 | -a pcb_p6 --gpu-devices 0 --fixbase-epoch 0 \
40 | --open-layers classifier fc \
41 | --new-layers classifier em \
42 | --transforms random_flip \
43 | --optim sgd --lr 0.02 \
44 | --stepsize 25 50 \
45 | --staged-lr --height 384 --width 128 \
46 | --batch-size 32 --base-lr-mult 0.5
47 | ```
48 | ## To train PVPM
49 | ```
50 | python scripts/main.py --load-pose --root PATH_TO_DATAROOT \
51 | -s market1501 \
52 | -t occlusion_reid p_duke partial_reid \
53 | --save-dir PATH_TO_EXPERIMENT_FOLDER/PVPM \
54 | -a pose_p6s --gpu-devices 0 \
55 | --fixbase-epoch 30 \
56 | --open-layers pose_subnet \
57 | --new-layers pose_subnet \
58 | --transforms random_flip \
59 | --optim sgd --lr 0.02 \
60 | --stepsize 15 25 --staged-lr \
61 | --height 384 --width 128 \
62 | --batch-size 32 \
63 | --start-eval 20 \
64 | --eval-freq 10 \
65 | --load-weights PATH_TO_EXPERIMENT_FOLDER/market_PCB/model.pth.tar-60 \
66 | --train-sampler RandomIdentitySampler \
67 | --reg-matching-score-epoch 0 \
68 | --graph-matching \
69 | --max-epoch 30 \
70 | --part-score
71 | ```
72 | Trained PCB and PVPM models can be found [here](https://pan.baidu.com/s/16lr8m-wv-XOXACqIthC8lw) (code: 64zy).
73 | 
74 | ## Citation
75 | If you find this code useful in your research, please cite the following paper:
76 | >@inproceedings{gao2020pose,
77 | title={Pose-guided Visible Part Matching for Occluded Person ReID},
78 | author={Gao, Shang and Wang, Jingya and Lu, Huchuan and Liu, Zimo},
79 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
80 | pages={11744--11752},
81 | year={2020}
82 | }
83 | 
84 | 
85 | 
86 | 
87 | 
88 | 
--------------------------------------------------------------------------------
/bibtex:
--------------------------------------------------------------------------------
1 | @inproceedings{PVPM-CVPR2020,
2 | author = {Shang Gao and Jingya Wang and Huchuan Lu and Zimo Liu},
3 | title = {Pose-guided Visible Part Matching for Occluded Person ReID},
4 | booktitle = {CVPR},
5 | year = {2020}
6 | }
--------------------------------------------------------------------------------
/docs/AWESOME_REID.md:
--------------------------------------------------------------------------------
1 | # Awesome-ReID
2 | Here is a collection of ReID-related research with links to papers and code. You are welcome to submit [PRs](https://help.github.com/articles/creating-a-pull-request/) if you find something missing.
3 | 4 | ## Conferences 5 | - **[CVPR 2019](#cvpr-2019)** 6 | - **[AAAI 2019](#aaai-2019)** 7 | - **[NeurIPS 2018](#neurips-2018)** 8 | - **[ECCV 2018](#eccv-2018)** 9 | - **[CVPR 2018](#cvpr-2018)** 10 | - **[ArXiv](#arxiv)** 11 | 12 | ### CVPR 2019 13 | - Joint Discriminative and Generative Learning for Person Re-identification. [[paper](https://arxiv.org/abs/1904.07223)][[code](https://github.com/NVlabs/DG-Net)] 14 | - Invariance Matters: Exemplar Memory for Domain Adaptive Person Re-identification. [[paper](https://arxiv.org/abs/1904.01990)][[code](https://github.com/zhunzhong07/ECN)] 15 | - Dissecting Person Re-identification from the Viewpoint of Viewpoint. [[paper](https://arxiv.org/abs/1812.02162)][[code](https://github.com/sxzrt/Dissecting-Person-Re-ID-from-the-Viewpoint-of-Viewpoint)] 16 | - Unsupervised Person Re-identification by Soft Multilabel Learning. [[paper](https://arxiv.org/abs/1903.06325)][[code](https://github.com/KovenYu/MAR)] 17 | - Patch-based Discriminative Feature Learning for Unsupervised Person Re-identification. [[paper](https://kovenyu.com/publication/2019-cvpr-pedal/)][[code](https://github.com/QizeYang/PAUL)] 18 | 19 | 20 | ### AAAI 2019 21 | - Spatial and Temporal Mutual Promotion for Video-based Person Re-identification. [[paper](https://arxiv.org/abs/1812.10305)][[code](https://github.com/yolomax/person-reid-lib)] 22 | - Spatial-Temporal Person Re-identification. [[paper](https://arxiv.org/abs/1812.03282)][[code](https://github.com/Wanggcong/Spatial-Temporal-Re-identification)] 23 | - Horizontal Pyramid Matching for Person Re-identification. [[paper](https://arxiv.org/abs/1804.05275)][[code](https://github.com/OasisYang/HPM)] 24 | - Backbone Can Not be Trained at Once: Rolling Back to Pre-trained Network for Person Re-identification. [[paper](https://arxiv.org/abs/1901.06140)][[code](https://github.com/youngminPIL/rollback)] 25 | - A Bottom-Up Clustering Approach to Unsupervised Person Re-identification. [[paper](https://vana77.github.io/vana77.github.io/images/AAAI19.pdf)][[code](https://github.com/vana77/Bottom-up-Clustering-Person-Re-identification)] 26 | 27 | ### NeurIPS 2018 28 | - FD-GAN: Pose-guided Feature Distilling GAN for Robust Person Re-identification. [[paper](https://arxiv.org/abs/1810.02936)][[code](https://github.com/yxgeee/FD-GAN)] 29 | 30 | ### ECCV 2018 31 | - Generalizing A Person Retrieval Model Hetero- and Homogeneously. [[paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Zhun_Zhong_Generalizing_A_Person_ECCV_2018_paper.pdf)][[code](https://github.com/zhunzhong07/HHL)] 32 | - Pose-Normalized Image Generation for Person Re-identification. [[paper](https://arxiv.org/abs/1712.02225)][[code](https://github.com/naiq/PN_GAN)] 33 | 34 | ### CVPR 2018 35 | - Camera Style Adaptation for Person Re-Identification. [[paper](https://arxiv.org/abs/1711.10295)][[code](https://github.com/zhunzhong07/CamStyle)] 36 | - Deep Group-Shuffling Random Walk for Person Re-Identification. [[paper](https://arxiv.org/abs/1807.11178)][[code](https://github.com/YantaoShen/kpm_rw_person_reid)] 37 | - End-to-End Deep Kronecker-Product Matching for Person Re-identification. [[paper](https://arxiv.org/abs/1807.11182)][[code](https://github.com/YantaoShen/kpm_rw_person_reid)] 38 | - Features for Multi-Target Multi-Camera Tracking and Re-Identification. [[paper](https://arxiv.org/abs/1803.10859)][[code](https://github.com/ergysr/DeepCC)] 39 | - Group Consistent Similarity Learning via Deep CRF for Person Re-Identification. 
[[paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Chen_Group_Consistent_Similarity_CVPR_2018_paper.pdf)][[code](https://github.com/dapengchen123/crf_affinity)] 40 | - Harmonious Attention Network for Person Re-Identification. [[paper](https://arxiv.org/abs/1802.08122)][[code](https://github.com/KaiyangZhou/deep-person-reid)] 41 | - Human Semantic Parsing for Person Re-Identification. [[paper](https://arxiv.org/abs/1804.00216)][[code](https://github.com/emrahbasaran/SPReID)] 42 | - Multi-Level Factorisation Net for Person Re-Identification. [[paper](https://arxiv.org/abs/1803.09132)][[code](https://github.com/KaiyangZhou/deep-person-reid)] 43 | - Resource Aware Person Re-identification across Multiple Resolutions. [[paper](https://arxiv.org/abs/1805.08805)][[code](https://github.com/mileyan/DARENet)] 44 | - Exploit the Unknown Gradually: One-Shot Video-Based Person Re-Identification by Stepwise Learning. [[paper](https://yu-wu.net/pdf/CVPR2018_Exploit-Unknown-Gradually.pdf)][[code](https://github.com/Yu-Wu/Exploit-Unknown-Gradually)] 45 | 46 | ### ArXiv 47 | - Revisiting Temporal Modeling for Video-based Person ReID. [[paper](https://arxiv.org/abs/1805.02104)][[code](https://github.com/jiyanggao/Video-Person-ReID)] 48 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('..')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = u'torchreid' 23 | copyright = u'2019, Kaiyang Zhou' 24 | author = u'Kaiyang Zhou' 25 | 26 | version_file = '../torchreid/__init__.py' 27 | with open(version_file, 'r') as f: 28 | exec(compile(f.read(), version_file, 'exec')) 29 | __version__ = locals()['__version__'] 30 | 31 | # The short X.Y version 32 | version = __version__ 33 | # The full version, including alpha/beta/rc tags 34 | release = __version__ 35 | 36 | 37 | # -- General configuration --------------------------------------------------- 38 | 39 | # If your documentation needs a minimal Sphinx version, state it here. 40 | # 41 | # needs_sphinx = '1.0' 42 | 43 | # Add any Sphinx extension module names here, as strings. They can be 44 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 45 | # ones. 46 | extensions = [ 47 | 'sphinx.ext.autodoc', 48 | 'sphinxcontrib.napoleon', 49 | 'sphinx.ext.viewcode', 50 | 'sphinx.ext.githubpages', 51 | 'sphinx_markdown_tables', 52 | ] 53 | 54 | # Add any paths that contain templates here, relative to this directory. 55 | templates_path = ['_templates'] 56 | 57 | # The suffix(es) of source filenames. 58 | # You can specify multiple suffix as a list of string: 59 | # 60 | source_suffix = ['.rst', '.md'] 61 | #source_suffix = '.rst' 62 | source_parsers = {'.md': 'recommonmark.parser.CommonMarkParser'} 63 | 64 | # The master toctree document. 65 | master_doc = 'index' 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 69 | # 70 | # This is also used if you do content translation via gettext catalogs. 71 | # Usually you set "language" from the command line for these cases. 72 | language = None 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 76 | # This pattern also affects html_static_path and html_extra_path. 77 | exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store'] 78 | 79 | # The name of the Pygments (syntax highlighting) style to use. 80 | pygments_style = None 81 | 82 | 83 | # -- Options for HTML output ------------------------------------------------- 84 | 85 | # The theme to use for HTML and HTML Help pages. See the documentation for 86 | # a list of builtin themes. 87 | # 88 | html_theme = 'sphinx_rtd_theme' 89 | 90 | # Theme options are theme-specific and customize the look and feel of a theme 91 | # further. For a list of options available for each theme, see the 92 | # documentation. 93 | # 94 | # html_theme_options = {} 95 | 96 | # Add any paths that contain custom static files (such as style sheets) here, 97 | # relative to this directory. They are copied after the builtin static files, 98 | # so a file named "default.css" will overwrite the builtin "default.css". 99 | html_static_path = ['_static'] 100 | 101 | # Custom sidebar templates, must be a dictionary that maps document names 102 | # to template names. 103 | # 104 | # The default sidebars (for documents that don't match any pattern) are 105 | # defined by theme itself. Builtin themes are using these templates by 106 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 107 | # 'searchbox.html']``. 
108 | #
109 | # html_sidebars = {}
110 | 
111 | 
112 | # -- Options for HTMLHelp output ---------------------------------------------
113 | 
114 | # Output file base name for HTML help builder.
115 | htmlhelp_basename = 'torchreiddoc'
116 | 
117 | 
118 | # -- Options for LaTeX output ------------------------------------------------
119 | 
120 | latex_elements = {
121 |     # The paper size ('letterpaper' or 'a4paper').
122 |     #
123 |     # 'papersize': 'letterpaper',
124 | 
125 |     # The font size ('10pt', '11pt' or '12pt').
126 |     #
127 |     # 'pointsize': '10pt',
128 | 
129 |     # Additional stuff for the LaTeX preamble.
130 |     #
131 |     # 'preamble': '',
132 | 
133 |     # Latex figure (float) alignment
134 |     #
135 |     # 'figure_align': 'htbp',
136 | }
137 | 
138 | # Grouping the document tree into LaTeX files. List of tuples
139 | # (source start file, target name, title,
140 | #  author, documentclass [howto, manual, or own class]).
141 | latex_documents = [
142 |     (master_doc, 'torchreid.tex', u'torchreid Documentation',
143 |      u'Kaiyang Zhou', 'manual'),
144 | ]
145 | 
146 | 
147 | # -- Options for manual page output ------------------------------------------
148 | 
149 | # One entry per manual page. List of tuples
150 | # (source start file, name, description, authors, manual section).
151 | man_pages = [
152 |     (master_doc, 'torchreid', u'torchreid Documentation',
153 |      [author], 1)
154 | ]
155 | 
156 | 
157 | # -- Options for Texinfo output ----------------------------------------------
158 | 
159 | # Grouping the document tree into Texinfo files. List of tuples
160 | # (source start file, target name, title, author,
161 | #  dir menu entry, description, category)
162 | texinfo_documents = [
163 |     (master_doc, 'torchreid', u'torchreid Documentation',
164 |      author, 'torchreid', 'One line description of project.',
165 |      'Miscellaneous'),
166 | ]
167 | 
168 | 
169 | # -- Options for Epub output -------------------------------------------------
170 | 
171 | # Bibliographic Dublin Core info.
172 | epub_title = project
173 | 
174 | # The unique identifier of the text. This can be an ISBN number
175 | # or the project homepage.
176 | #
177 | # epub_identifier = ''
178 | 
179 | # A unique identification for the text.
180 | #
181 | # epub_uid = ''
182 | 
183 | # A list of files that should not be packed into the epub file.
184 | epub_exclude_files = ['search.html']
185 | 
186 | 
187 | # -- Extension configuration -------------------------------------------------
188 | 
--------------------------------------------------------------------------------
/docs/evaluation.rst:
--------------------------------------------------------------------------------
1 | Evaluation
2 | ==========
3 | 
4 | Image ReID
5 | -----------
6 | - **Market1501**, **DukeMTMC-reID**, **CUHK03 (767/700 split)** and **MSMT17** have a fixed split, so keeping ``split_id=0`` is fine.
7 | - **CUHK03 (classic split)** has 20 fixed splits, so evaluation should be done by varying ``split_id`` from 0 to 19.
8 | - **VIPeR** contains 632 identities, each with 2 images captured under two camera views. Evaluation should be done over 10 random splits. Each split randomly divides the 632 identities into 316 train ids (632 images) and 316 test ids (632 images). Note that each random split comes with two sub-splits: one uses camera-A as query and camera-B as gallery, while the other uses camera-B as query and camera-A as gallery. Thus, 20 splits in total are generated, with ``split_id`` ranging from 0 to 19. Models can be trained on ``split_id=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]`` (because ``split_id=0`` and ``split_id=1`` share the same train set, and so on). At test time, models trained on ``split_id=0`` can be directly evaluated on ``split_id=1``, models trained on ``split_id=2`` on ``split_id=3``, and so on.
9 | - **CUHK01** is similar to VIPeR in terms of split generation.
10 | - **GRID**, **iLIDS** and **PRID** have 10 random splits, so evaluation should be done by varying ``split_id`` from 0 to 9.
11 | - **SenseReID** has no training images and is used for evaluation only.
12 | 
13 | 
14 | .. note::
15 |     The ``split_id`` argument is defined in ``ImageDataManager`` and ``VideoDataManager``. Please refer to :ref:`torchreid_data`.
16 | 
17 | 
18 | Video ReID
19 | -----------
20 | - **MARS** and **DukeMTMC-VideoReID** have a fixed single split, so using ``split_id=0`` is fine.
21 | - **iLIDS-VID** and **PRID2011** have 10 predefined splits, so evaluation should be done by varying ``split_id`` from 0 to 9.
--------------------------------------------------------------------------------
/docs/figures/deep-person-reid-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hh23333/PVPM/9587276e13f497553bf6f801c3297ee0ce050c77/docs/figures/deep-person-reid-logo.png
--------------------------------------------------------------------------------
/docs/figures/ranked_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hh23333/PVPM/9587276e13f497553bf6f801c3297ee0ce050c77/docs/figures/ranked_results.jpg
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../README.rst
2 | 
3 | 
4 | .. toctree::
5 |    :hidden:
6 | 
7 |    user_guide
8 |    datasets
9 |    evaluation
10 | 
11 | .. toctree::
12 |    :caption: Package Reference
13 |    :hidden:
14 | 
15 |    pkg/data
16 |    pkg/engine
17 |    pkg/losses
18 |    pkg/metrics
19 |    pkg/models
20 |    pkg/optim
21 |    pkg/utils
22 | 
23 | .. toctree::
24 |    :caption: Resources
25 |    :hidden:
26 | 
27 |    AWESOME_REID.md
28 |    MODEL_ZOO.md
29 | 
30 | 
31 | Indices and tables
32 | ==================
33 | 
34 | * :ref:`genindex`
35 | * :ref:`modindex`
--------------------------------------------------------------------------------
/docs/pkg/data.rst:
--------------------------------------------------------------------------------
1 | .. _torchreid_data:
2 | 
3 | torchreid.data
4 | ==============
5 | 
6 | 
7 | Data Manager
8 | ---------------------------
9 | 
10 | .. automodule:: torchreid.data.datamanager
11 |     :members:
12 | 
13 | 
14 | Sampler
15 | -----------------------
16 | 
17 | .. automodule:: torchreid.data.sampler
18 |     :members:
19 | 
20 | 
21 | Transforms
22 | ---------------------------
23 | 
24 | .. automodule:: torchreid.data.transforms
25 |     :members:
26 | 
27 | 
28 | Dataset
29 | ---------------------------
30 | 
31 | .. automodule:: torchreid.data.datasets.dataset
32 |     :members:
33 | 
34 | 
35 | .. automodule:: torchreid.data.datasets.__init__
36 |     :members:
37 | 
38 | 
39 | Image Datasets
40 | ------------------------------
41 | 
42 | .. automodule:: torchreid.data.datasets.image.market1501
43 |     :members:
44 | 
45 | .. automodule:: torchreid.data.datasets.image.cuhk03
46 |     :members:
47 | 
48 | .. automodule:: torchreid.data.datasets.image.dukemtmcreid
49 |     :members:
50 | 
51 | ..
automodule:: torchreid.data.datasets.image.msmt17 52 | :members: 53 | 54 | .. automodule:: torchreid.data.datasets.image.viper 55 | :members: 56 | 57 | .. automodule:: torchreid.data.datasets.image.grid 58 | :members: 59 | 60 | .. automodule:: torchreid.data.datasets.image.cuhk01 61 | :members: 62 | 63 | .. automodule:: torchreid.data.datasets.image.ilids 64 | :members: 65 | 66 | .. automodule:: torchreid.data.datasets.image.sensereid 67 | :members: 68 | 69 | .. automodule:: torchreid.data.datasets.image.prid 70 | :members: 71 | 72 | 73 | Video Datasets 74 | ------------------------------ 75 | 76 | .. automodule:: torchreid.data.datasets.video.mars 77 | :members: 78 | 79 | .. automodule:: torchreid.data.datasets.video.ilidsvid 80 | :members: 81 | 82 | .. automodule:: torchreid.data.datasets.video.prid2011 83 | :members: 84 | 85 | .. automodule:: torchreid.data.datasets.video.dukemtmcvidreid 86 | :members: -------------------------------------------------------------------------------- /docs/pkg/engine.rst: -------------------------------------------------------------------------------- 1 | .. _torchreid_engine: 2 | 3 | torchreid.engine 4 | ================== 5 | 6 | 7 | Base Engine 8 | ------------ 9 | 10 | .. autoclass:: torchreid.engine.engine.Engine 11 | :members: 12 | 13 | 14 | Image Engines 15 | ------------- 16 | 17 | .. autoclass:: torchreid.engine.image.softmax.ImageSoftmaxEngine 18 | :members: 19 | 20 | 21 | .. autoclass:: torchreid.engine.image.triplet.ImageTripletEngine 22 | :members: 23 | 24 | 25 | Video Engines 26 | ------------- 27 | 28 | .. autoclass:: torchreid.engine.video.softmax.VideoSoftmaxEngine 29 | 30 | 31 | .. autoclass:: torchreid.engine.video.triplet.VideoTripletEngine -------------------------------------------------------------------------------- /docs/pkg/losses.rst: -------------------------------------------------------------------------------- 1 | .. _torchreid_losses: 2 | 3 | torchreid.losses 4 | ================= 5 | 6 | 7 | Softmax 8 | -------- 9 | 10 | .. automodule:: torchreid.losses.cross_entropy_loss 11 | :members: 12 | 13 | 14 | Triplet 15 | ------- 16 | 17 | .. automodule:: torchreid.losses.hard_mine_triplet_loss 18 | :members: -------------------------------------------------------------------------------- /docs/pkg/metrics.rst: -------------------------------------------------------------------------------- 1 | .. _torchreid_metrics: 2 | 3 | torchreid.metrics 4 | ================= 5 | 6 | 7 | Distance 8 | --------- 9 | 10 | .. automodule:: torchreid.metrics.distance 11 | :members: 12 | 13 | 14 | Accuracy 15 | -------- 16 | 17 | .. automodule:: torchreid.metrics.accuracy 18 | :members: 19 | 20 | 21 | Rank 22 | ----- 23 | 24 | .. automodule:: torchreid.metrics.rank 25 | :members: evaluate_rank -------------------------------------------------------------------------------- /docs/pkg/models.rst: -------------------------------------------------------------------------------- 1 | .. _torchreid_models: 2 | 3 | torchreid.models 4 | ================= 5 | 6 | Interface 7 | --------- 8 | 9 | .. automodule:: torchreid.models.__init__ 10 | :members: 11 | 12 | 13 | ImageNet Classification Models 14 | ------------------------------- 15 | 16 | .. autoclass:: torchreid.models.resnet.ResNet 17 | .. autoclass:: torchreid.models.senet.SENet 18 | .. autoclass:: torchreid.models.densenet.DenseNet 19 | .. autoclass:: torchreid.models.inceptionresnetv2.InceptionResNetV2 20 | .. autoclass:: torchreid.models.inceptionv4.InceptionV4 21 | .. 
autoclass:: torchreid.models.xception.Xception 22 | 23 | 24 | Lightweight Models 25 | ------------------ 26 | 27 | .. autoclass:: torchreid.models.nasnet.NASNetAMobile 28 | .. autoclass:: torchreid.models.mobilenetv2.MobileNetV2 29 | .. autoclass:: torchreid.models.shufflenet.ShuffleNet 30 | .. autoclass:: torchreid.models.squeezenet.SqueezeNet 31 | .. autoclass:: torchreid.models.shufflenetv2.ShuffleNetV2 32 | 33 | 34 | ReID-specific Models 35 | -------------------- 36 | 37 | .. autoclass:: torchreid.models.mudeep.MuDeep 38 | .. autoclass:: torchreid.models.resnetmid.ResNetMid 39 | .. autoclass:: torchreid.models.hacnn.HACNN 40 | .. autoclass:: torchreid.models.pcb.PCB 41 | .. autoclass:: torchreid.models.mlfn.MLFN 42 | .. autoclass:: torchreid.models.osnet.OSNet -------------------------------------------------------------------------------- /docs/pkg/optim.rst: -------------------------------------------------------------------------------- 1 | .. _torchreid_optim: 2 | 3 | torchreid.optim 4 | ================= 5 | 6 | 7 | Optimizer 8 | ---------- 9 | 10 | .. automodule:: torchreid.optim.optimizer 11 | :members: build_optimizer 12 | 13 | 14 | LR Scheduler 15 | ------------- 16 | 17 | .. automodule:: torchreid.optim.lr_scheduler 18 | :members: build_lr_scheduler -------------------------------------------------------------------------------- /docs/pkg/utils.rst: -------------------------------------------------------------------------------- 1 | .. _torchreid_utils: 2 | 3 | torchreid.utils 4 | ================= 5 | 6 | Average Meter 7 | -------------- 8 | 9 | .. automodule:: torchreid.utils.avgmeter 10 | :members: 11 | 12 | 13 | Loggers 14 | ------- 15 | 16 | .. automodule:: torchreid.utils.loggers 17 | :members: 18 | 19 | 20 | Generic Tools 21 | --------------- 22 | .. automodule:: torchreid.utils.tools 23 | :members: 24 | 25 | 26 | ReID Tools 27 | ---------- 28 | 29 | .. automodule:: torchreid.utils.reidtools 30 | :members: 31 | 32 | 33 | Torch Tools 34 | ------------ 35 | 36 | .. automodule:: torchreid.utils.torchtools 37 | :members: 38 | 39 | 40 | .. 
automodule:: torchreid.utils.model_complexity 41 | :members: 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Cython 3 | h5py 4 | Pillow 5 | six 6 | scipy 7 | -------------------------------------------------------------------------------- /scripts/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path as osp 4 | import warnings 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from default_parser import ( 11 | init_parser, imagedata_kwargs, videodata_kwargs, 12 | optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs 13 | ) 14 | import torchreid 15 | from torchreid.utils import ( 16 | Logger, set_random_seed, check_isfile, resume_from_checkpoint, 17 | load_pretrained_weights, compute_model_complexity, collect_env_info 18 | ) 19 | 20 | 21 | parser = init_parser() 22 | args = parser.parse_args() 23 | 24 | 25 | def build_datamanager(args): 26 | if args.app == 'image': 27 | return torchreid.data.ImageDataManager(**imagedata_kwargs(args)) 28 | else: 29 | return torchreid.data.VideoDataManager(**videodata_kwargs(args)) 30 | 31 | 32 | def build_engine(args, datamanager, model, optimizer, scheduler): 33 | if args.app == 'image': 34 | if args.loss == 'softmax': 35 | if args.load_pose: 36 | if args.graph_matching: 37 | engine = torchreid.engine.PoseSoftmaxEngine_wscorereg( 38 | # engine = torchreid.engine.PoseSoftmaxEngine( 39 | datamanager, 40 | model, 41 | optimizer, 42 | scheduler=scheduler, 43 | use_cpu=args.use_cpu, 44 | label_smooth=args.label_smooth, 45 | use_att_loss=args.use_att_loss, 46 | reg_matching_score_epoch = args.reg_matching_score_epoch, 47 | num_att=args.num_att 48 | ) 49 | else: 50 | engine = torchreid.engine.PoseSoftmaxEngine( 51 | datamanager, 52 | model, 53 | optimizer, 54 | scheduler=scheduler, 55 | use_cpu=args.use_cpu, 56 | label_smooth=args.label_smooth, 57 | ) 58 | else: 59 | engine = torchreid.engine.ImageSoftmaxEngine( 60 | datamanager, 61 | model, 62 | optimizer, 63 | scheduler=scheduler, 64 | use_cpu=args.use_cpu, 65 | label_smooth=args.label_smooth 66 | ) 67 | else: 68 | engine = torchreid.engine.ImageTripletEngine( 69 | datamanager, 70 | model, 71 | optimizer, 72 | margin=args.margin, 73 | weight_t=args.weight_t, 74 | weight_x=args.weight_x, 75 | scheduler=scheduler, 76 | use_cpu=args.use_cpu, 77 | label_smooth=args.label_smooth 78 | ) 79 | 80 | else: 81 | if args.loss == 'softmax': 82 | engine = torchreid.engine.VideoSoftmaxEngine( 83 | datamanager, 84 | model, 85 | optimizer, 86 | scheduler=scheduler, 87 | use_cpu=args.use_cpu, 88 | label_smooth=args.label_smooth, 89 | pooling_method=args.pooling_method 90 | ) 91 | else: 92 | engine = torchreid.engine.VideoTripletEngine( 93 | datamanager, 94 | model, 95 | optimizer, 96 | margin=args.margin, 97 | weight_t=args.weight_t, 98 | weight_x=args.weight_x, 99 | scheduler=scheduler, 100 | use_cpu=args.use_cpu, 101 | label_smooth=args.label_smooth 102 | ) 103 | 104 | return engine 105 | 106 | 107 | def main(): 108 | global args 109 | 110 | set_random_seed(args.seed) 111 | if not args.use_avai_gpus: 112 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 113 | use_gpu = torch.cuda.is_available() and not args.use_cpu 114 | log_name = 'test.log' if args.evaluate else 'train.log' 115 | log_name += time.strftime('-%Y-%m-%d-%H-%M-%S') 116 | sys.stdout = 
Logger(osp.join(args.save_dir, log_name))
117 |     print('** Arguments **')
118 |     arg_keys = list(args.__dict__.keys())
119 |     arg_keys.sort()
120 |     for key in arg_keys:
121 |         print('{}: {}'.format(key, args.__dict__[key]))
122 |     print('\n')
123 |     print('Collecting env info ...')
124 |     print('** System info **\n{}\n'.format(collect_env_info()))
125 |     if use_gpu:
126 |         torch.backends.cudnn.benchmark = True
127 |     else:
128 |         warnings.warn('Currently using CPU, however, GPU is highly recommended')
129 | 
130 |     datamanager = build_datamanager(args)
131 | 
132 |     print('Building model: {}'.format(args.arch))
133 |     model = torchreid.models.build_model(
134 |         name=args.arch,
135 |         num_classes=datamanager.num_train_pids,
136 |         loss=args.loss.lower(),
137 |         pretrained=(not args.no_pretrained),
138 |         use_gpu=use_gpu
139 |     )
140 |     # num_params, flops = compute_model_complexity(model, (1, 3, args.height, args.width))
141 |     # print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))
142 | 
143 |     if args.load_weights and check_isfile(args.load_weights):
144 |         load_pretrained_weights(model, args.load_weights)
145 | 
146 |     if use_gpu:
147 |         model = nn.DataParallel(model).cuda()
148 | 
149 |     optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(args))
150 | 
151 |     scheduler = torchreid.optim.build_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
152 | 
153 |     if args.resume and check_isfile(args.resume):
154 |         args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
155 | 
156 |     print('Building {}-engine for {}-reid'.format(args.loss, args.app))
157 |     engine = build_engine(args, datamanager, model, optimizer, scheduler)
158 | 
159 |     engine.run(**engine_run_kwargs(args))
160 | 
161 | 
162 | if __name__ == '__main__':
163 |     main()
--------------------------------------------------------------------------------
/scripts/openpose_PETHZ.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | cd /media/hh/disc_d/hh/code/openpose-master
3 | path=/media/hh/disc_d/datasets/ICME2018_Occluded-Person-Reidentification_datasets/P_ETHZ/
4 | image_dir=whole_body_images/
5 | output_dir=whole_body_pose
6 | # image_dir=occluded_body_images/
7 | # output_dir=occluded_body_pose
8 | files=$(ls ${path}${image_dir})
9 | chmod +xw ./build/examples/openpose/openpose.bin
10 | for dir in ${files}
11 | do
12 | # mogrify -path ${path}${image_dir}${dir} -format jpg ${path}${image_dir}${dir}/*.tif
13 | ./build/examples/openpose/openpose.bin \
14 | --image_dir ${path}${image_dir}${dir} \
15 | --write_images ${path}${output_dir} \
16 | --model_pose COCO \
17 | --write_json ${path}${output_dir} \
18 | --heatmaps_add_parts true \
19 | --heatmaps_add_PAFs true \
20 | --write_heatmaps ${path}${output_dir} \
21 | --net_resolution -1x384 \
22 | --display 0
23 | done
--------------------------------------------------------------------------------
/scripts/openpose_market.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | cd PATH_TO_openpose-master/
3 | path=PATH_TO_YOUR_DATASET_ROOT/Market-1501-v15.09.15/
4 | 
5 | image_dir=bounding_box_train/
6 | output_dir=bounding_box_pose_train
7 | 
8 | ./build/examples/openpose/openpose.bin \
9 | --image_dir ${path}${image_dir} \
10 | --write_images ${path}${output_dir} \
11 | --model_pose COCO \
12 | --write_json ${path}${output_dir} \
13 | --heatmaps_add_parts true \
14 | --heatmaps_add_PAFs true \
15 | --write_heatmaps ${path}${output_dir} \
16 | --net_resolution -1x384
17 | 
18 | image_dir=bounding_box_test/
19 | output_dir=bounding_box_pose_test
20 | 
21 | ./build/examples/openpose/openpose.bin \
22 | --image_dir ${path}${image_dir} \
23 | --write_images ${path}${output_dir} \
24 | --model_pose COCO \
25 | --write_json ${path}${output_dir} \
26 | --heatmaps_add_parts true \
27 | --heatmaps_add_PAFs true \
28 | --write_heatmaps ${path}${output_dir} \
29 | --net_resolution -1x384
30 | 
31 | image_dir=query/
32 | output_dir=query_pose
33 | 
34 | ./build/examples/openpose/openpose.bin \
35 | --image_dir ${path}${image_dir} \
36 | --write_images ${path}${output_dir} \
37 | --model_pose COCO \
38 | --write_json ${path}${output_dir} \
39 | --heatmaps_add_parts true \
40 | --heatmaps_add_PAFs true \
41 | --write_heatmaps ${path}${output_dir} \
42 | --net_resolution -1x384
43 | 
--------------------------------------------------------------------------------
/scripts/openpose_occluded_duke.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | cd /media/hh/disc_d/hh/code/openpose-master
3 | path=/media/hh/disc_d/datasets/ICME2018_Occluded-Person-Reidentification_datasets/Occluded_Duke/
4 | for split in 'bounding_box_train' 'bounding_box_test' 'query'
5 | do
6 | # image_dir = occluded_body_images
7 | output_dir=${split}'_pose'
8 | echo ${output_dir}
9 | chmod +xw ./build/examples/openpose/openpose.bin
10 | ./build/examples/openpose/openpose.bin \
11 | --image_dir ${path}${split} \
12 | --write_images ${path}${output_dir} \
13 | --model_pose COCO \
14 | --write_json ${path}${output_dir} \
15 | --heatmaps_add_parts true \
16 | --heatmaps_add_PAFs true \
17 | --write_heatmaps ${path}${output_dir} \
18 | --net_resolution -1x384 \
19 | --display 0
20 | done
21 | 
22 | 
23 | 
24 | 
25 | 
--------------------------------------------------------------------------------
/scripts/openpose_occluded_reid.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | cd PATH_TO_openpose-master/
3 | path=PATH_TO_YOUR_DATASET_ROOT/ICME2018_Occluded-Person-Reidentification_datasets/Occluded_REID/
4 | image_dir=whole_body_images/
5 | output_dir=whole_body_pose
6 | # image_dir = occluded_body_images
7 | files=$(ls ${path}${image_dir})
8 | chmod +xw ./build/examples/openpose/openpose.bin
9 | for dir in ${files}
10 | do
11 | mogrify -path ${path}${image_dir}${dir} -format jpg ${path}${image_dir}${dir}/*.tif
12 | ./build/examples/openpose/openpose.bin \
13 | --image_dir ${path}${image_dir}${dir} \
14 | --write_images ${path}${output_dir} \
15 | --model_pose COCO \
16 | --write_json ${path}${output_dir} \
17 | --heatmaps_add_parts true \
18 | --heatmaps_add_PAFs true \
19 | --write_heatmaps ${path}${output_dir} \
20 | --net_resolution -1x384
21 | done
22 | 
23 | image_dir=occluded_body_images/
24 | output_dir=occluded_body_pose
25 | files=$(ls ${path}${image_dir})
26 | chmod +xw ./build/examples/openpose/openpose.bin
27 | for dir in ${files}
28 | do
29 | mogrify -path ${path}${image_dir}${dir} -format jpg ${path}${image_dir}${dir}/*.tif
30 | ./build/examples/openpose/openpose.bin \
31 | --image_dir ${path}${image_dir}${dir} \
32 | --write_images ${path}${output_dir} \
33 | --model_pose COCO \
34 | --write_json ${path}${output_dir} \
35 | --heatmaps_add_parts true \
36 | --heatmaps_add_PAFs true \
37 | --write_heatmaps ${path}${output_dir} \
38 | --net_resolution -1x384
39 | done
40 | 
41 | 
42 | 
43 | 
44 | 
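For reference, each script above makes OpenPose write one `*_keypoints.json` file per image (along with rendered images and part/PAF heatmaps). The snippet below is a minimal, illustrative sketch of how such a JSON file can be read back; it assumes OpenPose's standard `--write_json` output layout with the COCO 18-keypoint model, and the function name `load_coco_keypoints` and the confidence threshold are hypothetical, not part of this repository.

```python
# Illustrative sketch only (not part of this repo): parse one JSON file
# produced by OpenPose's --write_json with --model_pose COCO.
import json
import numpy as np

def load_coco_keypoints(json_path, conf_thresh=0.1):
    """Return an (18, 3) array of [x, y, confidence] rows for the first
    detected person, or None if OpenPose found nobody in the image."""
    with open(json_path, 'r') as f:
        pose = json.load(f)
    if not pose['people']:
        return None
    # COCO model: 18 joints, stored flat as x0, y0, c0, x1, y1, c1, ...
    kpts = np.asarray(pose['people'][0]['pose_keypoints_2d']).reshape(18, 3)
    # low-confidence joints are typically the occluded ones; zero them
    # out so any downstream part-visibility logic can ignore them
    kpts[kpts[:, 2] < conf_thresh] = 0
    return kpts
```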
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | import numpy as np 5 | 6 | 7 | def readme(): 8 | with open('README.rst') as f: 9 | content = f.read() 10 | return content 11 | 12 | 13 | def find_version(): 14 | version_file = 'torchreid/__init__.py' 15 | with open(version_file, 'r') as f: 16 | exec(compile(f.read(), version_file, 'exec')) 17 | return locals()['__version__'] 18 | 19 | 20 | def numpy_include(): 21 | try: 22 | numpy_include = np.get_include() 23 | except AttributeError: 24 | numpy_include = np.get_numpy_include() 25 | return numpy_include 26 | 27 | 28 | ext_modules = [ 29 | Extension( 30 | 'torchreid.metrics.rank_cylib.rank_cy', 31 | ['torchreid/metrics/rank_cylib/rank_cy.pyx'], 32 | include_dirs=[numpy_include()], 33 | ) 34 | ] 35 | 36 | 37 | setup( 38 | name='torchreid', 39 | version=find_version(), 40 | description='Pytorch framework for deep-learning person re-identification', 41 | author='Kaiyang Zhou', 42 | author_email='k.zhou.vision@gmail.com', 43 | license='MIT', 44 | long_description=readme(), 45 | url='https://github.com/KaiyangZhou/deep-person-reid', 46 | packages=find_packages(), 47 | install_requires=[ 48 | 'numpy', 49 | 'Cython', 50 | 'h5py', 51 | 'Pillow', 52 | 'six', 53 | 'scipy>=1.0.0', 54 | 'torch>=0.4.1', 55 | 'torchvision>=0.2.1' 56 | ], 57 | keywords=[ 58 | 'Person Re-Identification', 59 | 'Deep Learning', 60 | 'Computer Vision' 61 | ], 62 | ext_modules=cythonize(ext_modules) 63 | ) -------------------------------------------------------------------------------- /torchreid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | __version__ = '0.8.1' 5 | __author__ = 'Kaiyang Zhou' 6 | __description__ = 'Deep learning person re-identification in PyTorch' 7 | 8 | from torchreid import ( 9 | engine, 10 | models, 11 | losses, 12 | metrics, 13 | data, 14 | optim, 15 | utils 16 | ) 17 | -------------------------------------------------------------------------------- /torchreid/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .datasets import Dataset, ImageDataset, VideoDataset 5 | from .datasets import register_image_dataset 6 | from .datasets import register_video_dataset 7 | from .datamanager import ImageDataManager, VideoDataManager -------------------------------------------------------------------------------- /torchreid/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .dataset import Dataset, ImageDataset, VideoDataset 5 | from .image import * 6 | from .video import * 7 | 8 | 9 | __image_datasets = { 10 | 'market1501': Market1501, 11 | 'market1501_oc': Market1501_simu_occluded, 12 | 'cuhk03': CUHK03, 13 | 'dukemtmcreid': DukeMTMCreID, 14 | 'msmt17': MSMT17, 15 | 'viper': VIPeR, 16 | 'grid': GRID, 17 | 'cuhk01': CUHK01, 18 | 'ilids': iLIDS, 19 | 'sensereid': SenseReID, 20 | 'prid': PRID, 21 | 'occlusion_reid': Occluded_REID, 22 | 'partial_reid':Paritial_REID, 23 | 
'p_ethz':P_ETHZ, 24 | 'p_duke':P_Dukereid, 25 | 'o_duke':Occluded_duke 26 | } 27 | 28 | 29 | __video_datasets = { 30 | 'mars': Mars, 31 | 'ilidsvid': iLIDSVID, 32 | 'prid2011': PRID2011, 33 | 'dukemtmcvidreid': DukeMTMCVidReID 34 | } 35 | 36 | 37 | def init_image_dataset(name, **kwargs): 38 | """Initializes an image dataset.""" 39 | avai_datasets = list(__image_datasets.keys()) 40 | if name not in avai_datasets: 41 | raise ValueError('Invalid dataset name. Received "{}", ' 42 | 'but expected to be one of {}'.format(name, avai_datasets)) 43 | return __image_datasets[name](**kwargs) 44 | 45 | 46 | def init_video_dataset(name, **kwargs): 47 | """Initializes a video dataset.""" 48 | avai_datasets = list(__video_datasets.keys()) 49 | if name not in avai_datasets: 50 | raise ValueError('Invalid dataset name. Received "{}", ' 51 | 'but expected to be one of {}'.format(name, avai_datasets)) 52 | return __video_datasets[name](**kwargs) 53 | 54 | 55 | def register_image_dataset(name, dataset): 56 | """Registers a new image dataset. 57 | 58 | Args: 59 | name (str): key corresponding to the new dataset. 60 | dataset (Dataset): the new dataset class. 61 | 62 | Examples:: 63 | 64 | import torchreid 65 | import NewDataset 66 | torchreid.data.register_image_dataset('new_dataset', NewDataset) 67 | # single dataset case 68 | datamanager = torchreid.data.ImageDataManager( 69 | root='reid-data', 70 | sources='new_dataset' 71 | ) 72 | # multiple dataset case 73 | datamanager = torchreid.data.ImageDataManager( 74 | root='reid-data', 75 | sources=['new_dataset', 'dukemtmcreid'] 76 | ) 77 | """ 78 | global __image_datasets 79 | curr_datasets = list(__image_datasets.keys()) 80 | if name in curr_datasets: 81 | raise ValueError('The given name already exists, please choose ' 82 | 'another name excluding {}'.format(curr_datasets)) 83 | __image_datasets[name] = dataset 84 | 85 | 86 | def register_video_dataset(name, dataset): 87 | """Registers a new video dataset. 88 | 89 | Args: 90 | name (str): key corresponding to the new dataset. 91 | dataset (Dataset): the new dataset class. 
92 | 93 | Examples:: 94 | 95 | import torchreid 96 | import NewDataset 97 | torchreid.data.register_video_dataset('new_dataset', NewDataset) 98 | # single dataset case 99 | datamanager = torchreid.data.VideoDataManager( 100 | root='reid-data', 101 | sources='new_dataset' 102 | ) 103 | # multiple dataset case 104 | datamanager = torchreid.data.VideoDataManager( 105 | root='reid-data', 106 | sources=['new_dataset', 'ilidsvid'] 107 | ) 108 | """ 109 | global __video_datasets 110 | curr_datasets = list(__video_datasets.keys()) 111 | if name in curr_datasets: 112 | raise ValueError('The given name already exists, please choose ' 113 | 'another name excluding {}'.format(curr_datasets)) 114 | __video_datasets[name] = dataset 115 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .market1501 import Market1501, Market1501_simu_occluded 5 | from .dukemtmcreid import DukeMTMCreID 6 | from .cuhk03 import CUHK03 7 | from .msmt17 import MSMT17 8 | from .viper import VIPeR 9 | from .grid import GRID 10 | from .cuhk01 import CUHK01 11 | from .ilids import iLIDS 12 | from .sensereid import SenseReID 13 | from .prid import PRID 14 | from .occlusion_reid import Occluded_REID 15 | from .p_ETHZ import P_ETHZ 16 | from .partial_reid import Paritial_REID 17 | from .pduke_reid import P_Dukereid 18 | from .occluded_duke import Occluded_duke -------------------------------------------------------------------------------- /torchreid/data/datasets/image/cuhk01.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import zipfile 10 | import numpy as np 11 | 12 | from torchreid.data.datasets import ImageDataset 13 | from torchreid.utils import read_json, write_json 14 | 15 | 16 | class CUHK01(ImageDataset): 17 | """CUHK01. 18 | 19 | Reference: 20 | Li et al. Human Reidentification with Transferred Metric Learning. ACCV 2012. 21 | 22 | URL: ``_ 23 | 24 | Dataset statistics: 25 | - identities: 971. 26 | - images: 3884. 27 | - cameras: 4. 
28 | """ 29 | dataset_dir = 'cuhk01' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.zip_path = osp.join(self.dataset_dir, 'CUHK01.zip') 38 | self.campus_dir = osp.join(self.dataset_dir, 'campus') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | 41 | self.extract_file() 42 | 43 | required_files = [ 44 | self.dataset_dir, 45 | self.campus_dir 46 | ] 47 | self.check_before_run(required_files) 48 | 49 | self.prepare_split() 50 | splits = read_json(self.split_path) 51 | if split_id >= len(splits): 52 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 53 | split = splits[split_id] 54 | 55 | train = split['train'] 56 | query = split['query'] 57 | gallery = split['gallery'] 58 | 59 | train = [tuple(item) for item in train] 60 | query = [tuple(item) for item in query] 61 | gallery = [tuple(item) for item in gallery] 62 | 63 | super(CUHK01, self).__init__(train, query, gallery, **kwargs) 64 | 65 | def extract_file(self): 66 | if not osp.exists(self.campus_dir): 67 | print('Extracting files') 68 | zip_ref = zipfile.ZipFile(self.zip_path, 'r') 69 | zip_ref.extractall(self.dataset_dir) 70 | zip_ref.close() 71 | 72 | def prepare_split(self): 73 | """ 74 | Image name format: 0001001.png, where first four digits represent identity 75 | and last four digits represent cameras. Camera 1&2 are considered the same 76 | view and camera 3&4 are considered the same view. 77 | """ 78 | if not osp.exists(self.split_path): 79 | print('Creating 10 random splits of train ids and test ids') 80 | img_paths = sorted(glob.glob(osp.join(self.campus_dir, '*.png'))) 81 | img_list = [] 82 | pid_container = set() 83 | for img_path in img_paths: 84 | img_name = osp.basename(img_path) 85 | pid = int(img_name[:4]) - 1 86 | camid = (int(img_name[4:7]) - 1) // 2 # result is either 0 or 1 87 | img_list.append((img_path, pid, camid)) 88 | pid_container.add(pid) 89 | 90 | num_pids = len(pid_container) 91 | num_train_pids = num_pids // 2 92 | 93 | splits = [] 94 | for _ in range(10): 95 | order = np.arange(num_pids) 96 | np.random.shuffle(order) 97 | train_idxs = order[:num_train_pids] 98 | train_idxs = np.sort(train_idxs) 99 | idx2label = {idx: label for label, idx in enumerate(train_idxs)} 100 | 101 | train, test_a, test_b = [], [], [] 102 | for img_path, pid, camid in img_list: 103 | if pid in train_idxs: 104 | train.append((img_path, idx2label[pid], camid)) 105 | else: 106 | if camid == 0: 107 | test_a.append((img_path, pid, camid)) 108 | else: 109 | test_b.append((img_path, pid, camid)) 110 | 111 | # use cameraA as query and cameraB as gallery 112 | split = { 113 | 'train': train, 114 | 'query': test_a, 115 | 'gallery': test_b, 116 | 'num_train_pids': num_train_pids, 117 | 'num_query_pids': num_pids - num_train_pids, 118 | 'num_gallery_pids': num_pids - num_train_pids 119 | } 120 | splits.append(split) 121 | 122 | # use cameraB as query and cameraA as gallery 123 | split = { 124 | 'train': train, 125 | 'query': test_b, 126 | 'gallery': test_a, 127 | 'num_train_pids': num_train_pids, 128 | 'num_query_pids': num_pids - num_train_pids, 129 | 'num_gallery_pids': num_pids - num_train_pids 130 | } 131 | splits.append(split) 132 | 133 | print('Totally {} splits are created'.format(len(splits))) 134 | write_json(splits, 
self.split_path) 135 | print('Split file saved to {}'.format(self.split_path)) 136 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | 13 | 14 | class DukeMTMCreID(ImageDataset): 15 | """DukeMTMC-reID. 16 | 17 | Reference: 18 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1404 (train + query). 25 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 26 | - cameras: 8. 27 | """ 28 | dataset_dir = 'DukeMTMC-reID' 29 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 30 | 31 | def __init__(self, root='', **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 36 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 37 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 38 | 39 | required_files = [ 40 | self.dataset_dir, 41 | self.train_dir, 42 | self.query_dir, 43 | self.gallery_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | train = self.process_dir(self.train_dir, relabel=True) 48 | query = self.process_dir(self.query_dir, relabel=False) 49 | gallery = self.process_dir(self.gallery_dir, relabel=False) 50 | 51 | super(DukeMTMCreID, self).__init__(train, query, gallery, **kwargs) 52 | 53 | def process_dir(self, dir_path, relabel=False): 54 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 55 | pattern = re.compile(r'([-\d]+)_c(\d)') 56 | 57 | pid_container = set() 58 | for img_path in img_paths: 59 | pid, _ = map(int, pattern.search(img_path).groups()) 60 | pid_container.add(pid) 61 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 62 | 63 | data = [] 64 | for img_path in img_paths: 65 | pid, camid = map(int, pattern.search(img_path).groups()) 66 | assert 1 <= camid <= 8 67 | camid -= 1 # index starts from 0 68 | if relabel: pid = pid2label[pid] 69 | data.append((img_path, pid, camid)) 70 | 71 | return data -------------------------------------------------------------------------------- /torchreid/data/datasets/image/grid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | from scipy.io import loadmat 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | from torchreid.utils import read_json, write_json 13 | 14 | 15 | class GRID(ImageDataset): 16 | """GRID. 17 | 18 | Reference: 19 | Loy et al. Multi-camera activity correlation analysis. CVPR 2009. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 250. 25 | - images: 1275. 
26 | - cameras: 8. 27 | """ 28 | dataset_dir = 'grid' 29 | dataset_url = 'http://personal.ie.cuhk.edu.hk/~ccloy/files/datasets/underground_reid.zip' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.probe_path = osp.join(self.dataset_dir, 'underground_reid', 'probe') 37 | self.gallery_path = osp.join(self.dataset_dir, 'underground_reid', 'gallery') 38 | self.split_mat_path = osp.join(self.dataset_dir, 'underground_reid', 'features_and_partitions.mat') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.probe_path, 44 | self.gallery_path, 45 | self.split_mat_path 46 | ] 47 | self.check_before_run(required_files) 48 | 49 | self.prepare_split() 50 | splits = read_json(self.split_path) 51 | if split_id >= len(splits): 52 | raise ValueError('split_id exceeds range, received {}, ' 53 | 'but expected between 0 and {}'.format(split_id, len(splits)-1)) 54 | split = splits[split_id] 55 | 56 | train = split['train'] 57 | query = split['query'] 58 | gallery = split['gallery'] 59 | 60 | train = [tuple(item) for item in train] 61 | query = [tuple(item) for item in query] 62 | gallery = [tuple(item) for item in gallery] 63 | 64 | super(GRID, self).__init__(train, query, gallery, **kwargs) 65 | 66 | def prepare_split(self): 67 | if not osp.exists(self.split_path): 68 | print('Creating 10 random splits') 69 | split_mat = loadmat(self.split_mat_path) 70 | trainIdxAll = split_mat['trainIdxAll'][0] # length = 10 71 | probe_img_paths = sorted(glob.glob(osp.join(self.probe_path, '*.jpeg'))) 72 | gallery_img_paths = sorted(glob.glob(osp.join(self.gallery_path, '*.jpeg'))) 73 | 74 | splits = [] 75 | for split_idx in range(10): 76 | train_idxs = trainIdxAll[split_idx][0][0][2][0].tolist() 77 | assert len(train_idxs) == 125 78 | idx2label = {idx: label for label, idx in enumerate(train_idxs)} 79 | 80 | train, query, gallery = [], [], [] 81 | 82 | # processing probe folder 83 | for img_path in probe_img_paths: 84 | img_name = osp.basename(img_path) 85 | img_idx = int(img_name.split('_')[0]) 86 | camid = int(img_name.split('_')[1]) - 1 # index starts from 0 87 | if img_idx in train_idxs: 88 | train.append((img_path, idx2label[img_idx], camid)) 89 | else: 90 | query.append((img_path, img_idx, camid)) 91 | 92 | # process gallery folder 93 | for img_path in gallery_img_paths: 94 | img_name = osp.basename(img_path) 95 | img_idx = int(img_name.split('_')[0]) 96 | camid = int(img_name.split('_')[1]) - 1 # index starts from 0 97 | if img_idx in train_idxs: 98 | train.append((img_path, idx2label[img_idx], camid)) 99 | else: 100 | gallery.append((img_path, img_idx, camid)) 101 | 102 | split = { 103 | 'train': train, 104 | 'query': query, 105 | 'gallery': gallery, 106 | 'num_train_pids': 125, 107 | 'num_query_pids': 125, 108 | 'num_gallery_pids': 900 109 | } 110 | splits.append(split) 111 | 112 | print('Totally {} splits are created'.format(len(splits))) 113 | write_json(splits, self.split_path) 114 | print('Split file saved to {}'.format(self.split_path)) -------------------------------------------------------------------------------- /torchreid/data/datasets/image/ilids.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from 
__future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import numpy as np 10 | import copy 11 | import random 12 | from collections import defaultdict 13 | 14 | from torchreid.data.datasets import ImageDataset 15 | from torchreid.utils import read_json, write_json 16 | 17 | 18 | class iLIDS(ImageDataset): 19 | """QMUL-iLIDS. 20 | 21 | Reference: 22 | Zheng et al. Associating Groups of People. BMVC 2009. 23 | 24 | Dataset statistics: 25 | - identities: 119. 26 | - images: 476. 27 | - cameras: 8 (not explicitly provided). 28 | """ 29 | dataset_dir = 'ilids' 30 | dataset_url = 'http://www.eecs.qmul.ac.uk/~jason/data/i-LIDS_Pedestrian.tgz' 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.data_dir = osp.join(self.dataset_dir, 'i-LIDS_Pedestrian/Persons') 38 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.data_dir 43 | ] 44 | self.check_before_run(required_files) 45 | 46 | self.prepare_split() 47 | splits = read_json(self.split_path) 48 | if split_id >= len(splits): 49 | raise ValueError('split_id exceeds range, received {}, but ' 50 | 'expected between 0 and {}'.format(split_id, len(splits)-1)) 51 | split = splits[split_id] 52 | 53 | train, query, gallery = self.process_split(split) 54 | 55 | super(iLIDS, self).__init__(train, query, gallery, **kwargs) 56 | 57 | def prepare_split(self): 58 | if not osp.exists(self.split_path): 59 | print('Creating splits ...') 60 | 61 | paths = glob.glob(osp.join(self.data_dir, '*.jpg')) 62 | img_names = [osp.basename(path) for path in paths] 63 | num_imgs = len(img_names) 64 | assert num_imgs == 476, 'There should be 476 images, but ' \ 65 | 'got {}, please check the data'.format(num_imgs) 66 | 67 | # store image names 68 | # image naming format: 69 | # the first four digits denote the person ID 70 | # the last four digits denote the sequence index 71 | pid_dict = defaultdict(list) 72 | for img_name in img_names: 73 | pid = int(img_name[:4]) 74 | pid_dict[pid].append(img_name) 75 | pids = list(pid_dict.keys()) 76 | num_pids = len(pids) 77 | assert num_pids == 119, 'There should be 119 identities, ' \ 78 | 'but got {}, please check the data'.format(num_pids) 79 | 80 | num_train_pids = int(num_pids * 0.5) 81 | num_test_pids = num_pids - num_train_pids # supposed to be 60 82 | 83 | splits = [] 84 | for _ in range(10): 85 | # randomly choose num_train_pids train IDs and num_test_pids test IDs 86 | pids_copy = copy.deepcopy(pids) 87 | random.shuffle(pids_copy) 88 | train_pids = pids_copy[:num_train_pids] 89 | test_pids = pids_copy[num_train_pids:] 90 | 91 | train = [] 92 | query = [] 93 | gallery = [] 94 | 95 | # for train IDs, all images are used in the train set. 96 | for pid in train_pids: 97 | img_names = pid_dict[pid] 98 | train.extend(img_names) 99 | 100 | # for each test ID, randomly choose two images, one for 101 | # query and the other one for gallery. 
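# (single-shot protocol: every test identity contributes exactly one query image and one gallery image per split)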
102 | for pid in test_pids: 103 | img_names = pid_dict[pid] 104 | samples = random.sample(img_names, 2) 105 | query.append(samples[0]) 106 | gallery.append(samples[1]) 107 | 108 | split = {'train': train, 'query': query, 'gallery': gallery} 109 | splits.append(split) 110 | 111 | print('Totally {} splits are created'.format(len(splits))) 112 | write_json(splits, self.split_path) 113 | print('Split file is saved to {}'.format(self.split_path)) 114 | 115 | def get_pid2label(self, img_names): 116 | pid_container = set() 117 | for img_name in img_names: 118 | pid = int(img_name[:4]) 119 | pid_container.add(pid) 120 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 121 | return pid2label 122 | 123 | def parse_img_names(self, img_names, pid2label=None): 124 | data = [] 125 | 126 | for img_name in img_names: 127 | pid = int(img_name[:4]) 128 | if pid2label is not None: 129 | pid = pid2label[pid] 130 | camid = int(img_name[4:7]) - 1 # 0-based 131 | img_path = osp.join(self.data_dir, img_name) 132 | data.append((img_path, pid, camid)) 133 | 134 | return data 135 | 136 | def process_split(self, split): 137 | train, query, gallery = [], [], [] 138 | train_pid2label = self.get_pid2label(split['train']) 139 | train = self.parse_img_names(split['train'], train_pid2label) 140 | query = self.parse_img_names(split['query']) 141 | gallery = self.parse_img_names(split['gallery']) 142 | return train, query, gallery -------------------------------------------------------------------------------- /torchreid/data/datasets/image/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | 9 | from torchreid.data.datasets import ImageDataset 10 | 11 | 12 | ##### Log ##### 13 | # 22.01.2019 14 | # - add v2 15 | # - v1 and v2 differ in dir names 16 | # - note that faces in v2 are blurred 17 | TRAIN_DIR_KEY = 'train_dir' 18 | TEST_DIR_KEY = 'test_dir' 19 | VERSION_DICT = { 20 | 'MSMT17_V1': { 21 | TRAIN_DIR_KEY: 'train', 22 | TEST_DIR_KEY: 'test', 23 | }, 24 | 'MSMT17_V2': { 25 | TRAIN_DIR_KEY: 'mask_train_v2', 26 | TEST_DIR_KEY: 'mask_test_v2', 27 | } 28 | } 29 | 30 | 31 | class MSMT17(ImageDataset): 32 | """MSMT17. 33 | 34 | Reference: 35 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 36 | 37 | URL: ``_ 38 | 39 | Dataset statistics: 40 | - identities: 4101. 41 | - images: 32621 (train) + 11659 (query) + 82161 (gallery). 42 | - cameras: 15. 
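Example (a minimal usage sketch; it assumes MSMT17_V1 or MSMT17_V2 has been
downloaded manually and extracted under ``<root>/msmt17``, since
``dataset_url`` is None; 'reid-data' is a hypothetical root)::

    from torchreid.data.datasets.image.msmt17 import MSMT17
    dataset = MSMT17(root='reid-data')
    img_path, pid, camid = dataset.train[0]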
43 | """ 44 | dataset_dir = 'msmt17' 45 | dataset_url = None 46 | 47 | def __init__(self, root='', **kwargs): 48 | self.root = osp.abspath(osp.expanduser(root)) 49 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 50 | self.download_dataset(self.dataset_dir, self.dataset_url) 51 | 52 | has_main_dir = False 53 | for main_dir in VERSION_DICT: 54 | if osp.exists(osp.join(self.dataset_dir, main_dir)): 55 | train_dir = VERSION_DICT[main_dir][TRAIN_DIR_KEY] 56 | test_dir = VERSION_DICT[main_dir][TEST_DIR_KEY] 57 | has_main_dir = True 58 | break 59 | assert has_main_dir, 'Dataset folder not found' 60 | 61 | self.train_dir = osp.join(self.dataset_dir, main_dir, train_dir) 62 | self.test_dir = osp.join(self.dataset_dir, main_dir, test_dir) 63 | self.list_train_path = osp.join(self.dataset_dir, main_dir, 'list_train.txt') 64 | self.list_val_path = osp.join(self.dataset_dir, main_dir, 'list_val.txt') 65 | self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt') 66 | self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt') 67 | 68 | required_files = [ 69 | self.dataset_dir, 70 | self.train_dir, 71 | self.test_dir 72 | ] 73 | self.check_before_run(required_files) 74 | 75 | train = self.process_dir(self.train_dir, self.list_train_path) 76 | val = self.process_dir(self.train_dir, self.list_val_path) 77 | query = self.process_dir(self.test_dir, self.list_query_path) 78 | gallery = self.process_dir(self.test_dir, self.list_gallery_path) 79 | 80 | # Note: to fairly compare with published methods on the conventional ReID setting, 81 | # do not add val images to the training set. 82 | if 'combineall' in kwargs and kwargs['combineall']: 83 | train += val 84 | 85 | super(MSMT17, self).__init__(train, query, gallery, **kwargs) 86 | 87 | def process_dir(self, dir_path, list_path): 88 | with open(list_path, 'r') as txt: 89 | lines = txt.readlines() 90 | 91 | data = [] 92 | 93 | for img_idx, img_info in enumerate(lines): 94 | img_path, pid = img_info.split(' ') 95 | pid = int(pid) # no need to relabel 96 | camid = int(img_path.split('_')[2]) - 1 # index starts from 0 97 | img_path = osp.join(dir_path, img_path) 98 | data.append((img_path, pid, camid)) 99 | 100 | return data -------------------------------------------------------------------------------- /torchreid/data/datasets/image/occluded_duke.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | import warnings 11 | 12 | from torchreid.data.datasets import ImageDataset 13 | from torchreid.utils import read_image 14 | import cv2 15 | import numpy as np 16 | 17 | class Occluded_duke(ImageDataset): 18 | 19 | def __init__(self, root='', **kwargs): 20 | dataset_dir = 'ICME2018_Occluded-Person-Reidentification_datasets/Occluded_Duke' 21 | self.root=osp.abspath(osp.expanduser(root)) 22 | # self.dataset_dir = self.root 23 | data_dir = osp.join(self.root, dataset_dir) 24 | if osp.isdir(data_dir): 25 | self.data_dir = data_dir 26 | else: 27 | warnings.warn('The current data structure is deprecated.') 28 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train') 29 | self.query_dir = osp.join(self.data_dir, 'query') 30 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test') 31 | 32 | train = self.process_dir(self.train_dir, relabel=True) 33 | query = 
self.process_dir(self.query_dir, relabel=False) 34 | gallery = self.process_dir(self.gallery_dir, relabel=False) 35 | super(Occluded_duke, self).__init__(train, query, gallery, **kwargs) 36 | self.load_pose = isinstance(self.transform, tuple) # a tuple of transforms signals that pose heatmaps should be loaded as well 37 | if self.load_pose: 38 | self.train_pose_dir = osp.join(self.data_dir, 'bounding_box_train_pose') 39 | self.gallery_pose_dir = osp.join(self.data_dir, 'bounding_box_test_pose') 40 | self.query_pose_dir = osp.join(self.data_dir, 'query_pose') 41 | if self.mode == 'train': 42 | self.pose_dir = self.train_pose_dir 43 | elif self.mode == 'query': 44 | self.pose_dir = self.query_pose_dir 45 | elif self.mode == 'gallery': 46 | self.pose_dir = self.gallery_pose_dir 47 | else: 48 | raise ValueError('Invalid mode. Got {}, but expected to be ' 49 | 'one of [train | query | gallery]'.format(self.mode)) 50 | 51 | def process_dir(self, dir_path, relabel=False): 52 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 53 | pattern = re.compile(r'([-\d]+)_c(\d)') 54 | 55 | pid_container = set() 56 | for img_path in img_paths: 57 | pid, _ = map(int, pattern.search(img_path).groups()) 58 | pid_container.add(pid) 59 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 60 | 61 | data = [] 62 | for img_path in img_paths: 63 | pid, camid = map(int, pattern.search(img_path).groups()) 64 | assert 1 <= camid <= 8 65 | camid -= 1 # index starts from 0 66 | if relabel: pid = pid2label[pid] 67 | data.append((img_path, pid, camid)) 68 | 69 | return data 70 | 71 | def __getitem__(self, index): 72 | img_path, pid, camid = self.data[index] 73 | img = read_image(img_path) 74 | 75 | if self.load_pose: 76 | img_name = '.'.join(img_path.split('/')[-1].split('.')[:-1]) 77 | pose_pic_name = img_name + '_pose_heatmaps.png' 78 | pose_pic_path = os.path.join(self.pose_dir, pose_pic_name) 79 | pose = cv2.imread(pose_pic_path, cv2.IMREAD_GRAYSCALE) # heatmaps are stored as one wide grayscale strip 80 | pose = pose.reshape((pose.shape[0], 56, -1)).transpose((0,2,1)).astype('float32') # -> (h, w, 56): 18 keypoint heatmaps + 38 part affinity field channels 81 | pose[:,:,18:] = np.abs(pose[:,:,18:]-128) # PAF values are saved shifted by +128; recover their magnitude 82 | img, pose = self.transform[1](img, pose) # joint spatial transform on (img, pose) 83 | img = self.transform[0](img) # image-only transform (e.g. to-tensor + normalize) 84 | return img, pid, camid, img_path, pose 85 | else: 86 | if self.transform is not None: 87 | img = self.transform(img) 88 | return img, pid, camid, img_path -------------------------------------------------------------------------------- /torchreid/data/datasets/image/occlusion_reid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | import warnings 11 | 12 | from torchreid.data.datasets import ImageDataset 13 | from torchreid.utils import read_image 14 | import cv2 15 | import numpy as np 16 | 17 | class Occluded_REID(ImageDataset): 18 | 19 | def __init__(self, root='', **kwargs): 20 | dataset_dir = 'ICME2018_Occluded-Person-Reidentification_datasets/Occluded_REID' 21 | self.root = osp.abspath(osp.expanduser(root)) 22 | self.data_dir = self.root # fall back to the old, flat data structure by default 23 | data_dir = osp.join(self.root, dataset_dir) 24 | if osp.isdir(data_dir): 25 | self.data_dir = data_dir 26 | else: 27 | warnings.warn('The current data structure is deprecated.') 28 | self.query_dir = osp.join(self.data_dir, 'occluded_body_images') 29 | self.gallery_dir = osp.join(self.data_dir, 'whole_body_images') 30 | 31 | train = [] 32 | query = self.process_dir(self.query_dir, relabel=False) 33 | gallery = 
self.process_dir(self.gallery_dir, relabel=False, is_query=False) 34 | super(Occluded_REID, self).__init__(train, query, gallery, **kwargs) 35 | self.load_pose = isinstance(self.transform, tuple) # a tuple of transforms signals that pose heatmaps should be loaded as well 36 | if self.load_pose: 37 | if self.mode == 'query': 38 | self.pose_dir = osp.join(self.data_dir, 'occluded_body_pose') 39 | elif self.mode == 'gallery': 40 | self.pose_dir = osp.join(self.data_dir, 'whole_body_pose') 41 | else: 42 | self.pose_dir = '' # the train split is empty for this test-only dataset 43 | 44 | 45 | def process_dir(self, dir_path, relabel=False, is_query=True): 46 | img_paths = glob.glob(osp.join(dir_path, '*', '*.jpg')) 47 | if is_query: 48 | camid = 0 49 | else: 50 | camid = 1 51 | pid_container = set() 52 | for img_path in img_paths: 53 | img_name = img_path.split('/')[-1] 54 | pid = int(img_name.split('_')[0]) 55 | pid_container.add(pid) 56 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 57 | 58 | data = [] 59 | for img_path in img_paths: 60 | img_name = img_path.split('/')[-1] 61 | pid = int(img_name.split('_')[0]) 62 | if relabel: 63 | pid = pid2label[pid] 64 | data.append((img_path, pid, camid)) 65 | return data 66 | 67 | def __getitem__(self, index): 68 | img_path, pid, camid = self.data[index] 69 | img = read_image(img_path) 70 | 71 | if self.load_pose: 72 | img_name = '.'.join(img_path.split('/')[-1].split('.')[:-1]) 73 | pose_pic_name = img_name + '_pose_heatmaps.png' 74 | pose_pic_path = os.path.join(self.pose_dir, pose_pic_name) 75 | pose = cv2.imread(pose_pic_path, cv2.IMREAD_GRAYSCALE) # heatmaps are stored as one wide grayscale strip 76 | pose = pose.reshape((pose.shape[0], 56, -1)).transpose((0,2,1)).astype('float32') # -> (h, w, 56): 18 keypoint heatmaps + 38 part affinity field channels 77 | pose[:,:,18:] = np.abs(pose[:,:,18:]-128) # PAF values are saved shifted by +128; recover their magnitude 78 | img, pose = self.transform[1](img, pose) 79 | img = self.transform[0](img) 80 | return img, pid, camid, img_path, pose 81 | else: 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | return img, pid, camid, img_path 85 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/p_ETHZ.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | import warnings 11 | 12 | from torchreid.data.datasets import ImageDataset 13 | from torchreid.utils import read_image 14 | import cv2 15 | import numpy as np 16 | 17 | class P_ETHZ(ImageDataset): 18 | 19 | def __init__(self, root='', **kwargs): 20 | dataset_dir = 'ICME2018_Occluded-Person-Reidentification_datasets/P_ETHZ' 21 | self.root = osp.abspath(osp.expanduser(root)) 22 | self.data_dir = self.root # fall back to the old, flat data structure by default 23 | data_dir = osp.join(self.root, dataset_dir) 24 | if osp.isdir(data_dir): 25 | self.data_dir = data_dir 26 | else: 27 | warnings.warn('The current data structure is deprecated.') 28 | self.query_dir = osp.join(self.data_dir, 'occluded_body_images') 29 | self.gallery_dir = osp.join(self.data_dir, 'whole_body_images') 30 | 31 | train = [] 32 | query = self.process_dir(self.query_dir, relabel=False) 33 | gallery = self.process_dir(self.gallery_dir, relabel=False, is_query=False) 34 | super(P_ETHZ, self).__init__(train, query, gallery, **kwargs) 35 | self.load_pose = isinstance(self.transform, tuple) # a tuple of transforms signals that pose heatmaps should be loaded as well 36 | if self.load_pose: 37 | if self.mode == 'query': 38 | self.pose_dir = osp.join(self.data_dir, 'occluded_body_pose') 39 | elif self.mode == 'gallery': 40 | self.pose_dir = osp.join(self.data_dir, 'whole_body_pose') 41 | else: 42 | 
self.pose_dir = '' # the train split is empty for this test-only dataset 43 | 44 | 45 | def process_dir(self, dir_path, relabel=False, is_query=True): 46 | img_paths = glob.glob(osp.join(dir_path, '*', '*.png')) 47 | if is_query: 48 | camid = 0 49 | else: 50 | camid = 1 51 | pid_container = set() 52 | for img_path in img_paths: 53 | img_name = img_path.split('/')[-1] 54 | pid = int(img_name.split('_')[0]) 55 | pid_container.add(pid) 56 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 57 | 58 | data = [] 59 | for img_path in img_paths: 60 | img_name = img_path.split('/')[-1] 61 | pid = int(img_name.split('_')[0]) 62 | if relabel: 63 | pid = pid2label[pid] 64 | data.append((img_path, pid, camid)) 65 | return data 66 | 67 | def __getitem__(self, index): 68 | img_path, pid, camid = self.data[index] 69 | img = read_image(img_path) 70 | 71 | if self.load_pose: 72 | img_name = '.'.join(img_path.split('/')[-1].split('.')[:-1]) 73 | pose_pic_name = img_name + '_pose_heatmaps.png' 74 | pose_pic_path = os.path.join(self.pose_dir, pose_pic_name) 75 | pose = cv2.imread(pose_pic_path, cv2.IMREAD_GRAYSCALE) 76 | if pose is None: # cv2.imread returns None when the heatmap file is missing; fail loudly instead of the old bare try/except 77 | raise IOError('Pose heatmap not found: {}'.format(pose_pic_path)) 78 | pose = pose.reshape((pose.shape[0], 56, -1)).transpose((0,2,1)).astype('float32') # -> (h, w, 56): 18 keypoint heatmaps + 38 part affinity field channels 79 | pose[:,:,18:] = np.abs(pose[:,:,18:]-128) # PAF values are saved shifted by +128; recover their magnitude 80 | img, pose = self.transform[1](img, pose) 81 | img = self.transform[0](img) 82 | return img, pid, camid, img_path, pose 83 | else: 84 | if self.transform is not None: 85 | img = self.transform(img) 86 | return img, pid, camid, img_path 87 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/partial_reid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | import warnings 11 | 12 | from torchreid.data.datasets import ImageDataset 13 | from torchreid.utils import read_image 14 | import cv2 15 | import numpy as np 16 | 17 | class Paritial_REID(ImageDataset): 18 | 19 | def __init__(self, root='', **kwargs): 20 | dataset_dir = 'Partial-REID_Dataset' 21 | self.root = osp.abspath(osp.expanduser(root)) 22 | self.data_dir = self.root # fall back to the old, flat data structure by default 23 | data_dir = osp.join(self.root, dataset_dir) 24 | if osp.isdir(data_dir): 25 | self.data_dir = data_dir 26 | else: 27 | warnings.warn('The current data structure is deprecated.') 28 | self.query_dir = osp.join(self.data_dir, 'occluded_body_images') 29 | self.gallery_dir = osp.join(self.data_dir, 'whole_body_images') 30 | 31 | train = [] 32 | query = self.process_dir(self.query_dir, relabel=False) 33 | gallery = self.process_dir(self.gallery_dir, relabel=False, is_query=False) 34 | super(Paritial_REID, self).__init__(train, query, gallery, **kwargs) 35 | self.load_pose = isinstance(self.transform, tuple) # a tuple of transforms signals that pose heatmaps should be loaded as well 36 | if self.load_pose: 37 | if self.mode == 'query': 38 | self.pose_dir = osp.join(self.data_dir, 'occluded_body_pose') 39 | elif self.mode == 'gallery': 40 | self.pose_dir = osp.join(self.data_dir, 'whole_body_pose') 41 | else: 42 | self.pose_dir = '' 43 | 44 | 45 | def process_dir(self, dir_path, relabel=False, is_query=True): 46 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 47 | if is_query: 48 | camid = 0 49 | else: 50 | camid = 1 51 | pid_container = set() 52 | for img_path in img_paths: 53 | img_name = img_path.split('/')[-1] 54 | pid = int(img_name.split('_')[0]) 55 | pid_container.add(pid) 56 | pid2label = {pid:label for label, 
pid in enumerate(pid_container)} 57 | 58 | data = [] 59 | for img_path in img_paths: 60 | img_name = img_path.split('/')[-1] 61 | pid = int(img_name.split('_')[0]) 62 | if relabel: 63 | pid = pid2label[pid] 64 | data.append((img_path, pid, camid)) 65 | return data 66 | 67 | def __getitem__(self, index): 68 | img_path, pid, camid = self.data[index] 69 | img = read_image(img_path) 70 | 71 | if self.load_pose: 72 | img_name = '.'.join(img_path.split('/')[-1].split('.')[:-1]) 73 | pose_pic_name = img_name + '_pose_heatmaps.png' 74 | pose_pic_path = os.path.join(self.pose_dir, pose_pic_name) 75 | pose = cv2.imread(pose_pic_path, cv2.IMREAD_GRAYSCALE) # heatmaps are stored as one wide grayscale strip 76 | pose = pose.reshape((pose.shape[0], 56, -1)).transpose((0,2,1)).astype('float32') # -> (h, w, 56): 18 keypoint heatmaps + 38 part affinity field channels 77 | pose[:,:,18:] = np.abs(pose[:,:,18:]-128) # PAF values are saved shifted by +128; recover their magnitude 78 | img, pose = self.transform[1](img, pose) 79 | img = self.transform[0](img) 80 | return img, pid, camid, img_path, pose 81 | else: 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | return img, pid, camid, img_path 85 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/pduke_reid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | import warnings 11 | 12 | from torchreid.data.datasets import ImageDataset 13 | from torchreid.utils import read_image 14 | import cv2 15 | import numpy as np 16 | 17 | class P_Dukereid(ImageDataset): 18 | 19 | def __init__(self, root='', **kwargs): 20 | dataset_dir = 'ICME2018_Occluded-Person-Reidentification_datasets/P-DukeMTMC-reid' 21 | self.root = osp.abspath(osp.expanduser(root)) 22 | self.data_dir = self.root # fall back to the old, flat data structure by default 23 | data_dir = osp.join(self.root, dataset_dir) 24 | if osp.isdir(data_dir): 25 | self.data_dir = data_dir 26 | else: 27 | warnings.warn('The current data structure is deprecated.') 28 | self.train_dir = osp.join(self.data_dir, 'train') 29 | self.query_dir = osp.join(self.data_dir, 'test', 'occluded_body_images') 30 | self.gallery_dir = osp.join(self.data_dir, 'test', 'whole_body_images') 31 | 32 | train = self.process_train_dir(self.train_dir, relabel=True) 33 | query = self.process_dir(self.query_dir, relabel=False) 34 | gallery = self.process_dir(self.gallery_dir, relabel=False, is_query=False) 35 | super(P_Dukereid, self).__init__(train, query, gallery, **kwargs) 36 | self.load_pose = isinstance(self.transform, tuple) # a tuple of transforms signals that pose heatmaps should be loaded as well 37 | if self.load_pose: 38 | if self.mode == 'query': 39 | self.pose_dir = osp.join(self.data_dir, 'test') 40 | elif self.mode == 'gallery': 41 | self.pose_dir = osp.join(self.data_dir, 'test') 42 | else: 43 | self.pose_dir = osp.join(self.data_dir, 'train') 44 | 45 | def process_train_dir(self, dir_path, relabel=True): 46 | img_paths = glob.glob(osp.join(dir_path, 'whole_body_images', '*', '*.jpg')) 47 | camid = 1 48 | pid_container = set() 49 | for img_path in img_paths: 50 | img_name = img_path.split('/')[-1] 51 | pid = int(img_name.split('_')[0]) 52 | pid_container.add(pid) 53 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 54 | data = [] 55 | for img_path in img_paths: 56 | img_name = img_path.split('/')[-1] 57 | pid = int(img_name.split('_')[0]) 58 | if relabel: 59 | pid = pid2label[pid] 60 | data.append((img_path, pid, camid)) 61 | img_paths = glob.glob(osp.join(dir_path, 'occluded_body_images', '*', '*.jpg')) 62 | 
camid=0 63 | for img_path in img_paths: 64 | img_name = img_path.split('/')[-1] 65 | pid = int(img_name.split('_')[0]) 66 | if relabel: 67 | pid = pid2label[pid] 68 | data.append((img_path, pid, camid)) 69 | return data 70 | 71 | def process_dir(self, dir_path, relabel=False, is_query=True): 72 | img_paths = glob.glob(osp.join(dir_path,'*','*.jpg')) 73 | if is_query: 74 | camid = 0 75 | else: 76 | camid = 1 77 | pid_container = set() 78 | for img_path in img_paths: 79 | img_name = img_path.split('/')[-1] 80 | pid = int(img_name.split('_')[0]) 81 | pid_container.add(pid) 82 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 83 | 84 | data = [] 85 | for img_path in img_paths: 86 | img_name = img_path.split('/')[-1] 87 | pid = int(img_name.split('_')[0]) 88 | if relabel: 89 | pid = pid2label[pid] 90 | data.append((img_path, pid, camid)) 91 | return data 92 | 93 | def __getitem__(self, index): 94 | img_path, pid, camid = self.data[index] 95 | img = read_image(img_path) 96 | 97 | if self.load_pose: 98 | img_name = '.'.join(img_path.split('/')[-1].split('.')[:-1]) 99 | pose_pic_name = img_name + '_pose_heatmaps.png' 100 | if 'whole_body' in img_path: 101 | pose_pic_path = os.path.join(self.pose_dir,'whole_body_pose', pose_pic_name) 102 | else: 103 | pose_pic_path = os.path.join(self.pose_dir,'occluded_body_pose', pose_pic_name) 104 | pose = cv2.imread(pose_pic_path, cv2.IMREAD_GRAYSCALE) 105 | pose = pose.reshape((pose.shape[0], 56, -1)).transpose((0,2,1)).astype('float32') 106 | pose[:,:,18:] = np.abs(pose[:,:,18:]-128) 107 | img, pose = self.transform[1](img, pose) 108 | img = self.transform[0](img) 109 | return img, pid, camid, img_path, pose 110 | else: 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | return img, pid, camid, img_path 114 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/prid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import random 9 | 10 | from torchreid.data.datasets import ImageDataset 11 | from torchreid.utils import read_json, write_json 12 | 13 | 14 | class PRID(ImageDataset): 15 | """PRID (single-shot version of prid-2011) 16 | 17 | Reference: 18 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative 19 | Classification. SCIA 2011. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - Two views. 25 | - View A captures 385 identities. 26 | - View B captures 749 identities. 27 | - 200 identities appear in both views. 
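Example (a minimal usage sketch; it assumes the single-shot data has been
placed manually under ``<root>/prid2011``, since ``dataset_url`` is None;
'reid-data' is a hypothetical root)::

    from torchreid.data.datasets.image.prid import PRID
    # split_id selects one of the 10 randomly generated train/test splits
    dataset = PRID(root='reid-data', split_id=0)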
28 | """ 29 | dataset_dir = 'prid2011' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.cam_a_dir = osp.join(self.dataset_dir, 'prid_2011', 'single_shot', 'cam_a') 38 | self.cam_b_dir = osp.join(self.dataset_dir, 'prid_2011', 'single_shot', 'cam_b') 39 | self.split_path = osp.join(self.dataset_dir, 'splits_single_shot.json') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.cam_a_dir, 44 | self.cam_b_dir 45 | ] 46 | self.check_before_run(required_files) 47 | 48 | self.prepare_split() 49 | splits = read_json(self.split_path) 50 | if split_id >= len(splits): 51 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 52 | split = splits[split_id] 53 | 54 | train, query, gallery = self.process_split(split) 55 | 56 | super(PRID, self).__init__(train, query, gallery, **kwargs) 57 | 58 | def prepare_split(self): 59 | if not osp.exists(self.split_path): 60 | print('Creating splits ...') 61 | 62 | splits = [] 63 | for _ in range(10): 64 | # randomly sample 100 IDs for train and use the rest 100 IDs for test 65 | # (note: there are only 200 IDs appearing in both views) 66 | pids = [i for i in range(1, 201)] 67 | train_pids = random.sample(pids, 100) 68 | train_pids.sort() 69 | test_pids = [i for i in pids if i not in train_pids] 70 | split = {'train': train_pids, 'test': test_pids} 71 | splits.append(split) 72 | 73 | print('Totally {} splits are created'.format(len(splits))) 74 | write_json(splits, self.split_path) 75 | print('Split file is saved to {}'.format(self.split_path)) 76 | 77 | def process_split(self, split): 78 | train, query, gallery = [], [], [] 79 | train_pids = split['train'] 80 | test_pids = split['test'] 81 | 82 | train_pid2label = {pid: label for label, pid in enumerate(train_pids)} 83 | 84 | # train 85 | train = [] 86 | for pid in train_pids: 87 | img_name = 'person_' + str(pid).zfill(4) + '.png' 88 | pid = train_pid2label[pid] 89 | img_a_path = osp.join(self.cam_a_dir, img_name) 90 | train.append((img_a_path, pid, 0)) 91 | img_b_path = osp.join(self.cam_b_dir, img_name) 92 | train.append((img_b_path, pid, 1)) 93 | 94 | # query and gallery 95 | query, gallery = [], [] 96 | for pid in test_pids: 97 | img_name = 'person_' + str(pid).zfill(4) + '.png' 98 | img_a_path = osp.join(self.cam_a_dir, img_name) 99 | query.append((img_a_path, pid, 0)) 100 | img_b_path = osp.join(self.cam_b_dir, img_name) 101 | gallery.append((img_b_path, pid, 1)) 102 | for pid in range(201, 750): 103 | img_name = 'person_' + str(pid).zfill(4) + '.png' 104 | img_b_path = osp.join(self.cam_b_dir, img_name) 105 | gallery.append((img_b_path, pid, 1)) 106 | 107 | return train, query, gallery -------------------------------------------------------------------------------- /torchreid/data/datasets/image/sensereid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import copy 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | 13 | 14 | class SenseReID(ImageDataset): 15 | """SenseReID. 16 | 17 | This dataset is used for test purpose only. 18 | 19 | Reference: 20 | Zhao et al. 
Spindle Net: Person Re-identification with Human Body 21 | Region Guided Feature Decomposition and Fusion. CVPR 2017. 22 | 23 | URL: ``_ 24 | 25 | Dataset statistics: 26 | - query: 522 ids, 1040 images. 27 | - gallery: 1717 ids, 3388 images. 28 | """ 29 | dataset_dir = 'sensereid' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.query_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_probe') 38 | self.gallery_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_gallery') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.query_dir, 43 | self.gallery_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | query = self.process_dir(self.query_dir) 48 | gallery = self.process_dir(self.gallery_dir) 49 | 50 | # relabel 51 | g_pids = set() 52 | for _, pid, _ in gallery: 53 | g_pids.add(pid) 54 | pid2label = {pid: i for i, pid in enumerate(g_pids)} 55 | 56 | query = [(img_path, pid2label[pid], camid) for img_path, pid, camid in query] 57 | gallery = [(img_path, pid2label[pid], camid) for img_path, pid, camid in gallery] 58 | train = copy.deepcopy(query) + copy.deepcopy(gallery) # dummy variable 59 | 60 | super(SenseReID, self).__init__(train, query, gallery, **kwargs) 61 | 62 | def process_dir(self, dir_path): 63 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 64 | data = [] 65 | 66 | for img_path in img_paths: 67 | img_name = osp.splitext(osp.basename(img_path))[0] 68 | pid, camid = img_name.split('_') 69 | pid, camid = int(pid), int(camid) 70 | data.append((img_path, pid, camid)) 71 | 72 | return data -------------------------------------------------------------------------------- /torchreid/data/datasets/image/viper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import numpy as np 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | from torchreid.utils import read_json, write_json 13 | 14 | 15 | class VIPeR(ImageDataset): 16 | """VIPeR. 17 | 18 | Reference: 19 | Gray et al. Evaluating appearance models for recognition, reacquisition, and tracking. PETS 2007. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 632. 25 | - images: 632 x 2 = 1264. 26 | - cameras: 2. 
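Example (a minimal usage sketch; the archive is fetched from ``dataset_url``
on first use; 'reid-data' is a hypothetical root)::

    from torchreid.data.datasets.image.viper import VIPeR
    # split_id ranges over 0~19: each random split yields two sub-splits,
    # camA->camB and camB->camA (see prepare_split below)
    dataset = VIPeR(root='reid-data', split_id=0)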
27 | """ 28 | dataset_dir = 'viper' 29 | dataset_url = 'http://users.soe.ucsc.edu/~manduchi/VIPeR.v1.0.zip' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.cam_a_dir = osp.join(self.dataset_dir, 'VIPeR', 'cam_a') 37 | self.cam_b_dir = osp.join(self.dataset_dir, 'VIPeR', 'cam_b') 38 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.cam_a_dir, 43 | self.cam_b_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | self.prepare_split() 48 | splits = read_json(self.split_path) 49 | if split_id >= len(splits): 50 | raise ValueError('split_id exceeds range, received {}, ' 51 | 'but expected between 0 and {}'.format(split_id, len(splits)-1)) 52 | split = splits[split_id] 53 | 54 | train = split['train'] 55 | query = split['query'] # query and gallery share the same images 56 | gallery = split['gallery'] 57 | 58 | train = [tuple(item) for item in train] 59 | query = [tuple(item) for item in query] 60 | gallery = [tuple(item) for item in gallery] 61 | 62 | super(VIPeR, self).__init__(train, query, gallery, **kwargs) 63 | 64 | def prepare_split(self): 65 | if not osp.exists(self.split_path): 66 | print('Creating 10 random splits of train ids and test ids') 67 | 68 | cam_a_imgs = sorted(glob.glob(osp.join(self.cam_a_dir, '*.bmp'))) 69 | cam_b_imgs = sorted(glob.glob(osp.join(self.cam_b_dir, '*.bmp'))) 70 | assert len(cam_a_imgs) == len(cam_b_imgs) 71 | num_pids = len(cam_a_imgs) 72 | print('Number of identities: {}'.format(num_pids)) 73 | num_train_pids = num_pids // 2 74 | 75 | """ 76 | In total, there will be 20 splits because each random split creates two 77 | sub-splits, one using cameraA as query and cameraB as gallery 78 | while the other using cameraB as query and cameraA as gallery. 79 | Therefore, results should be averaged over 20 splits (split_id=0~19). 80 | 81 | In practice, a model trained on split_id=0 can be applied to split_id=0&1 82 | as split_id=0&1 share the same training data (so on and so forth). 
83 | """ 84 | splits = [] 85 | for _ in range(10): 86 | order = np.arange(num_pids) 87 | np.random.shuffle(order) 88 | train_idxs = order[:num_train_pids] 89 | test_idxs = order[num_train_pids:] 90 | assert not bool(set(train_idxs) & set(test_idxs)), 'Error: train and test overlap' 91 | 92 | train = [] 93 | for pid, idx in enumerate(train_idxs): 94 | cam_a_img = cam_a_imgs[idx] 95 | cam_b_img = cam_b_imgs[idx] 96 | train.append((cam_a_img, pid, 0)) 97 | train.append((cam_b_img, pid, 1)) 98 | 99 | test_a = [] 100 | test_b = [] 101 | for pid, idx in enumerate(test_idxs): 102 | cam_a_img = cam_a_imgs[idx] 103 | cam_b_img = cam_b_imgs[idx] 104 | test_a.append((cam_a_img, pid, 0)) 105 | test_b.append((cam_b_img, pid, 1)) 106 | 107 | # use cameraA as query and cameraB as gallery 108 | split = { 109 | 'train': train, 110 | 'query': test_a, 111 | 'gallery': test_b, 112 | 'num_train_pids': num_train_pids, 113 | 'num_query_pids': num_pids - num_train_pids, 114 | 'num_gallery_pids': num_pids - num_train_pids 115 | } 116 | splits.append(split) 117 | 118 | # use cameraB as query and cameraA as gallery 119 | split = { 120 | 'train': train, 121 | 'query': test_b, 122 | 'gallery': test_a, 123 | 'num_train_pids': num_train_pids, 124 | 'num_query_pids': num_pids - num_train_pids, 125 | 'num_gallery_pids': num_pids - num_train_pids 126 | } 127 | splits.append(split) 128 | 129 | print('Totally {} splits are created'.format(len(splits))) 130 | write_json(splits, self.split_path) 131 | print('Split file saved to {}'.format(self.split_path)) -------------------------------------------------------------------------------- /torchreid/data/datasets/video/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .mars import Mars 5 | from .ilidsvid import iLIDSVID 6 | from .prid2011 import PRID2011 7 | from .dukemtmcvidreid import DukeMTMCVidReID -------------------------------------------------------------------------------- /torchreid/data/datasets/video/dukemtmcvidreid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import warnings 10 | 11 | from torchreid.data.datasets import VideoDataset 12 | from torchreid.utils import read_json, write_json 13 | 14 | 15 | class DukeMTMCVidReID(VideoDataset): 16 | """DukeMTMCVidReID. 17 | 18 | Reference: 19 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, 20 | Multi-Camera Tracking. ECCVW 2016. 21 | - Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person 22 | Re-Identification by Stepwise Learning. CVPR 2018. 23 | 24 | URL: ``_ 25 | 26 | Dataset statistics: 27 | - identities: 702 (train) + 702 (test). 28 | - tracklets: 2196 (train) + 2636 (test). 
29 | """ 30 | dataset_dir = 'dukemtmc-vidreid' 31 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip' 32 | 33 | def __init__(self, root='', min_seq_len=0, **kwargs): 34 | self.root = osp.abspath(osp.expanduser(root)) 35 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 36 | self.download_dataset(self.dataset_dir, self.dataset_url) 37 | 38 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/train') 39 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/query') 40 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/gallery') 41 | self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json') 42 | self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json') 43 | self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json') 44 | self.min_seq_len = min_seq_len 45 | 46 | required_files = [ 47 | self.dataset_dir, 48 | self.train_dir, 49 | self.query_dir, 50 | self.gallery_dir 51 | ] 52 | self.check_before_run(required_files) 53 | 54 | train = self.process_dir(self.train_dir, self.split_train_json_path, relabel=True) 55 | query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False) 56 | gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) 57 | 58 | super(DukeMTMCVidReID, self).__init__(train, query, gallery, **kwargs) 59 | 60 | def process_dir(self, dir_path, json_path, relabel): 61 | if osp.exists(json_path): 62 | split = read_json(json_path) 63 | return split['tracklets'] 64 | 65 | print('=> Generating split json file (** this might take a while **)') 66 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store 67 | print('Processing "{}" with {} person identities'.format(dir_path, len(pdirs))) 68 | 69 | pid_container = set() 70 | for pdir in pdirs: 71 | pid = int(osp.basename(pdir)) 72 | pid_container.add(pid) 73 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 74 | 75 | tracklets = [] 76 | for pdir in pdirs: 77 | pid = int(osp.basename(pdir)) 78 | if relabel: 79 | pid = pid2label[pid] 80 | tdirs = glob.glob(osp.join(pdir, '*')) 81 | for tdir in tdirs: 82 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg')) 83 | num_imgs = len(raw_img_paths) 84 | 85 | if num_imgs < self.min_seq_len: 86 | continue 87 | 88 | img_paths = [] 89 | for img_idx in range(num_imgs): 90 | # some tracklet starts from 0002 instead of 0001 91 | img_idx_name = 'F' + str(img_idx+1).zfill(4) 92 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')) 93 | if len(res) == 0: 94 | warnings.warn('Index name {} in {} is missing, skip'.format(img_idx_name, tdir)) 95 | continue 96 | img_paths.append(res[0]) 97 | img_name = osp.basename(img_paths[0]) 98 | if img_name.find('_') == -1: 99 | # old naming format: 0001C6F0099X30823.jpg 100 | camid = int(img_name[5]) - 1 101 | else: 102 | # new naming format: 0001_C6_F0099_X30823.jpg 103 | camid = int(img_name[6]) - 1 104 | img_paths = tuple(img_paths) 105 | tracklets.append((img_paths, pid, camid)) 106 | 107 | print('Saving split to {}'.format(json_path)) 108 | split_dict = {'tracklets': tracklets} 109 | write_json(split_dict, json_path) 110 | 111 | return tracklets -------------------------------------------------------------------------------- /torchreid/data/datasets/video/ilidsvid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from 
__future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | from scipy.io import loadmat 10 | 11 | from torchreid.data.datasets import VideoDataset 12 | from torchreid.utils import read_json, write_json 13 | 14 | 15 | class iLIDSVID(VideoDataset): 16 | """iLIDS-VID. 17 | 18 | Reference: 19 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 300. 25 | - tracklets: 600. 26 | - cameras: 2. 27 | """ 28 | dataset_dir = 'ilids-vid' 29 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.data_dir = osp.join(self.dataset_dir, 'i-LIDS-VID') 37 | self.split_dir = osp.join(self.dataset_dir, 'train-test people splits') 38 | self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | self.cam_1_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam1') 41 | self.cam_2_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam2') 42 | 43 | required_files = [ 44 | self.dataset_dir, 45 | self.data_dir, 46 | self.split_dir 47 | ] 48 | self.check_before_run(required_files) 49 | 50 | self.prepare_split() 51 | splits = read_json(self.split_path) 52 | if split_id >= len(splits): 53 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 54 | split = splits[split_id] 55 | train_dirs, test_dirs = split['train'], split['test'] 56 | 57 | train = self.process_data(train_dirs, cam1=True, cam2=True) 58 | query = self.process_data(test_dirs, cam1=True, cam2=False) 59 | gallery = self.process_data(test_dirs, cam1=False, cam2=True) 60 | 61 | super(iLIDSVID, self).__init__(train, query, gallery, **kwargs) 62 | 63 | def prepare_split(self): 64 | if not osp.exists(self.split_path): 65 | print('Creating splits ...') 66 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 67 | 68 | num_splits = mat_split_data.shape[0] 69 | num_total_ids = mat_split_data.shape[1] 70 | assert num_splits == 10 71 | assert num_total_ids == 300 72 | num_ids_each = num_total_ids // 2 73 | 74 | # pids in mat_split_data are indices, so we need to transform them 75 | # to real pids 76 | person_cam1_dirs = sorted(glob.glob(osp.join(self.cam_1_path, '*'))) 77 | person_cam2_dirs = sorted(glob.glob(osp.join(self.cam_2_path, '*'))) 78 | 79 | person_cam1_dirs = [osp.basename(item) for item in person_cam1_dirs] 80 | person_cam2_dirs = [osp.basename(item) for item in person_cam2_dirs] 81 | 82 | # make sure persons in one camera view can be found in the other camera view 83 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 84 | 85 | splits = [] 86 | for i_split in range(num_splits): 87 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 
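# (note: 'ls_set' holds 1-based identity indices; they are shifted to 0-based a few lines below)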
88 | train_idxs = sorted(list(mat_split_data[i_split, num_ids_each:])) 89 | test_idxs = sorted(list(mat_split_data[i_split, :num_ids_each])) 90 | 91 | train_idxs = [int(i)-1 for i in train_idxs] 92 | test_idxs = [int(i)-1 for i in test_idxs] 93 | 94 | # transform pids to person dir names 95 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 96 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 97 | 98 | split = {'train': train_dirs, 'test': test_dirs} 99 | splits.append(split) 100 | 101 | print('Totally {} splits are created, following Wang et al. ECCV\'14'.format(len(splits))) 102 | print('Split file is saved to {}'.format(self.split_path)) 103 | write_json(splits, self.split_path) 104 | 105 | def process_data(self, dirnames, cam1=True, cam2=True): 106 | tracklets = [] 107 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 108 | 109 | for dirname in dirnames: 110 | if cam1: 111 | person_dir = osp.join(self.cam_1_path, dirname) 112 | img_names = glob.glob(osp.join(person_dir, '*.png')) 113 | assert len(img_names) > 0 114 | img_names = tuple(img_names) 115 | pid = dirname2pid[dirname] 116 | tracklets.append((img_names, pid, 0)) 117 | 118 | if cam2: 119 | person_dir = osp.join(self.cam_2_path, dirname) 120 | img_names = glob.glob(osp.join(person_dir, '*.png')) 121 | assert len(img_names) > 0 122 | img_names = tuple(img_names) 123 | pid = dirname2pid[dirname] 124 | tracklets.append((img_names, pid, 1)) 125 | 126 | return tracklets -------------------------------------------------------------------------------- /torchreid/data/datasets/video/mars.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | from scipy.io import loadmat 9 | import warnings 10 | 11 | from torchreid.data.datasets import VideoDataset 12 | 13 | 14 | class Mars(VideoDataset): 15 | """MARS. 16 | 17 | Reference: 18 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 19 | 20 | URL: ``_ 21 | 22 | Dataset statistics: 23 | - identities: 1261. 24 | - tracklets: 8298 (train) + 1980 (query) + 9330 (gallery). 25 | - cameras: 6. 
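Example (a minimal usage sketch; the data must be downloaded manually, since
``dataset_url`` is None; 'reid-data' is a hypothetical root)::

    from torchreid.data.datasets.video.mars import Mars
    dataset = Mars(root='reid-data')
    img_paths, pid, camid = dataset.train[0] # one tracklet: a tuple of frame paths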
26 | """ 27 | dataset_dir = 'mars' 28 | dataset_url = None 29 | 30 | def __init__(self, root='', **kwargs): 31 | self.root = osp.abspath(osp.expanduser(root)) 32 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 33 | self.download_dataset(self.dataset_dir, self.dataset_url) 34 | 35 | self.train_name_path = osp.join(self.dataset_dir, 'info/train_name.txt') 36 | self.test_name_path = osp.join(self.dataset_dir, 'info/test_name.txt') 37 | self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat') 38 | self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat') 39 | self.query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.train_name_path, 44 | self.test_name_path, 45 | self.track_train_info_path, 46 | self.track_test_info_path, 47 | self.query_IDX_path 48 | ] 49 | self.check_before_run(required_files) 50 | 51 | train_names = self.get_names(self.train_name_path) 52 | test_names = self.get_names(self.test_name_path) 53 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 54 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 55 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 56 | query_IDX -= 1 # index from 0 57 | track_query = track_test[query_IDX,:] 58 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 59 | track_gallery = track_test[gallery_IDX,:] 60 | 61 | train = self.process_data(train_names, track_train, home_dir='bbox_train', relabel=True) 62 | query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False) 63 | gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False) 64 | 65 | super(Mars, self).__init__(train, query, gallery, **kwargs) 66 | 67 | def get_names(self, fpath): 68 | names = [] 69 | with open(fpath, 'r') as f: 70 | for line in f: 71 | new_line = line.rstrip() 72 | names.append(new_line) 73 | return names 74 | 75 | def process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 76 | assert home_dir in ['bbox_train', 'bbox_test'] 77 | num_tracklets = meta_data.shape[0] 78 | pid_list = list(set(meta_data[:,2].tolist())) 79 | num_pids = len(pid_list) 80 | 81 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 82 | tracklets = [] 83 | 84 | for tracklet_idx in range(num_tracklets): 85 | data = meta_data[tracklet_idx,...] 86 | start_index, end_index, pid, camid = data 87 | if pid == -1: 88 | continue # junk images are just ignored 89 | assert 1 <= camid <= 6 90 | if relabel: pid = pid2label[pid] 91 | camid -= 1 # index starts from 0 92 | img_names = names[start_index - 1:end_index] 93 | 94 | # make sure image names correspond to the same person 95 | pnames = [img_name[:4] for img_name in img_names] 96 | assert len(set(pnames)) == 1, 'Error: a single tracklet contains different person images' 97 | 98 | # make sure all images are captured under the same camera 99 | camnames = [img_name[5] for img_name in img_names] 100 | assert len(set(camnames)) == 1, 'Error: images are captured under different cameras!' 
101 | 102 | # append image names with directory information 103 | img_paths = [osp.join(self.dataset_dir, home_dir, img_name[:4], img_name) for img_name in img_names] 104 | if len(img_paths) >= min_seq_len: 105 | img_paths = tuple(img_paths) 106 | tracklets.append((img_paths, pid, camid)) 107 | 108 | return tracklets 109 | 110 | def combine_all(self): 111 | warnings.warn('Some query IDs do not appear in gallery. Therefore, combineall ' 112 | 'does not make any difference to Mars') -------------------------------------------------------------------------------- /torchreid/data/datasets/video/prid2011.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | 10 | from torchreid.data.datasets import VideoDataset 11 | from torchreid.utils import read_json, write_json 12 | 13 | 14 | class PRID2011(VideoDataset): 15 | """PRID2011. 16 | 17 | Reference: 18 | Hirzer et al. Person Re-Identification by Descriptive and 19 | Discriminative Classification. SCIA 2011. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 200. 25 | - tracklets: 400. 26 | - cameras: 2. 27 | """ 28 | dataset_dir = 'prid2011' 29 | dataset_url = None 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json') 37 | self.cam_a_dir = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a') 38 | self.cam_b_dir = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.cam_a_dir, 43 | self.cam_b_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | splits = read_json(self.split_path) 48 | if split_id >= len(splits): 49 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 50 | split = splits[split_id] 51 | train_dirs, test_dirs = split['train'], split['test'] 52 | 53 | train = self.process_dir(train_dirs, cam1=True, cam2=True) 54 | query = self.process_dir(test_dirs, cam1=True, cam2=False) 55 | gallery = self.process_dir(test_dirs, cam1=False, cam2=True) 56 | 57 | super(PRID2011, self).__init__(train, query, gallery, **kwargs) 58 | 59 | def process_dir(self, dirnames, cam1=True, cam2=True): 60 | tracklets = [] 61 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 62 | 63 | for dirname in dirnames: 64 | if cam1: 65 | person_dir = osp.join(self.cam_a_dir, dirname) 66 | img_names = glob.glob(osp.join(person_dir, '*.png')) 67 | assert len(img_names) > 0 68 | img_names = tuple(img_names) 69 | pid = dirname2pid[dirname] 70 | tracklets.append((img_names, pid, 0)) 71 | 72 | if cam2: 73 | person_dir = osp.join(self.cam_b_dir, dirname) 74 | img_names = glob.glob(osp.join(person_dir, '*.png')) 75 | assert len(img_names) > 0 76 | img_names = tuple(img_names) 77 | pid = dirname2pid[dirname] 78 | tracklets.append((img_names, pid, 1)) 79 | 80 | return tracklets -------------------------------------------------------------------------------- /torchreid/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 
from __future__ import division 3 | 4 | from collections import defaultdict 5 | import numpy as np 6 | import copy 7 | import random 8 | 9 | import torch 10 | from torch.utils.data.sampler import Sampler, RandomSampler 11 | 12 | 13 | class RandomIdentitySampler(Sampler): 14 | """Randomly samples N identities each with K instances. 15 | 16 | Args: 17 | data_source (list): contains tuples of (img_path(s), pid, camid). 18 | batch_size (int): batch size. 19 | num_instances (int): number of instances per identity in a batch. 20 | """ 21 | def __init__(self, data_source, batch_size, num_instances): 22 | if batch_size < num_instances: 23 | raise ValueError('batch_size={} must be no less ' 24 | 'than num_instances={}'.format(batch_size, num_instances)) 25 | 26 | self.data_source = data_source 27 | self.batch_size = batch_size 28 | self.num_instances = num_instances 29 | self.num_pids_per_batch = self.batch_size // self.num_instances 30 | self.index_dic = defaultdict(list) 31 | for index, (_, pid, _) in enumerate(self.data_source): 32 | self.index_dic[pid].append(index) 33 | self.pids = list(self.index_dic.keys()) 34 | 35 | # estimate number of examples in an epoch 36 | # TODO: improve precision 37 | self.length = 0 38 | for pid in self.pids: 39 | idxs = self.index_dic[pid] 40 | num = len(idxs) 41 | if num < self.num_instances: 42 | num = self.num_instances 43 | self.length += num - num % self.num_instances 44 | 45 | def __iter__(self): 46 | batch_idxs_dict = defaultdict(list) 47 | 48 | for pid in self.pids: 49 | idxs = copy.deepcopy(self.index_dic[pid]) 50 | if len(idxs) < self.num_instances: 51 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 52 | random.shuffle(idxs) 53 | batch_idxs = [] 54 | for idx in idxs: 55 | batch_idxs.append(idx) 56 | if len(batch_idxs) == self.num_instances: 57 | batch_idxs_dict[pid].append(batch_idxs) 58 | batch_idxs = [] 59 | 60 | avai_pids = copy.deepcopy(self.pids) 61 | final_idxs = [] 62 | 63 | while len(avai_pids) >= self.num_pids_per_batch: 64 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 65 | for pid in selected_pids: 66 | batch_idxs = batch_idxs_dict[pid].pop(0) 67 | final_idxs.extend(batch_idxs) 68 | if len(batch_idxs_dict[pid]) == 0: 69 | avai_pids.remove(pid) 70 | 71 | return iter(final_idxs) 72 | 73 | def __len__(self): 74 | return self.length 75 | 76 | 77 | def build_train_sampler(data_source, train_sampler, batch_size=32, num_instances=4, **kwargs): 78 | """Builds a training sampler. 79 | 80 | Args: 81 | data_source (list): contains tuples of (img_path(s), pid, camid). 82 | train_sampler (str): sampler name (default: ``RandomSampler``). 83 | batch_size (int, optional): batch size. Default is 32. 84 | num_instances (int, optional): number of instances per identity in a 85 | batch (for ``RandomIdentitySampler``). Default is 4. 
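    Examples (a minimal sketch; the ``data_source`` entries here are illustrative)::
        >>> data_source = [('a.jpg', 0, 0), ('b.jpg', 0, 1),
        ...                ('c.jpg', 1, 0), ('d.jpg', 1, 1)]
        >>> sampler = build_train_sampler(
        ...     data_source, 'RandomIdentitySampler',
        ...     batch_size=4, num_instances=2
        ... )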
86 | """ 87 | if train_sampler == 'RandomIdentitySampler': 88 | sampler = RandomIdentitySampler(data_source, batch_size, num_instances) 89 | 90 | else: 91 | sampler = RandomSampler(data_source) 92 | 93 | return sampler -------------------------------------------------------------------------------- /torchreid/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .engine import Engine 5 | 6 | from .image import ImageSoftmaxEngine 7 | from .image import ImageTripletEngine 8 | from .image import PoseSoftmaxEngine, PoseSoftmaxEngine_wscorereg 9 | 10 | from .video import VideoSoftmaxEngine 11 | from .video import VideoTripletEngine -------------------------------------------------------------------------------- /torchreid/engine/image/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .softmax import ImageSoftmaxEngine, PoseSoftmaxEngine, PoseSoftmaxEngine_wscorereg 4 | from .triplet import ImageTripletEngine -------------------------------------------------------------------------------- /torchreid/engine/video/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .softmax import VideoSoftmaxEngine 4 | from .triplet import VideoTripletEngine -------------------------------------------------------------------------------- /torchreid/engine/video/softmax.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import time 6 | import datetime 7 | 8 | import torch 9 | 10 | import torchreid 11 | from torchreid.engine.image import ImageSoftmaxEngine 12 | 13 | 14 | class VideoSoftmaxEngine(ImageSoftmaxEngine): 15 | """Softmax-loss engine for video-reid. 16 | 17 | Args: 18 | datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager`` 19 | or ``torchreid.data.VideoDataManager``. 20 | model (nn.Module): model instance. 21 | optimizer (Optimizer): an Optimizer. 22 | scheduler (LRScheduler, optional): if None, no learning rate decay will be performed. 23 | use_cpu (bool, optional): use cpu. Default is False. 24 | label_smooth (bool, optional): use label smoothing regularizer. Default is True. 25 | pooling_method (str, optional): how to pool features for a tracklet. 26 | Default is "avg" (average). Choices are ["avg", "max"]. 
27 | 28 | Examples:: 29 | 30 | import torch 31 | import torchreid 32 | # Each batch contains batch_size*seq_len images 33 | datamanager = torchreid.data.VideoDataManager( 34 | root='path/to/reid-data', 35 | sources='mars', 36 | height=256, 37 | width=128, 38 | combineall=False, 39 | batch_size=8, # number of tracklets 40 | seq_len=15 # number of images in each tracklet 41 | ) 42 | model = torchreid.models.build_model( 43 | name='resnet50', 44 | num_classes=datamanager.num_train_pids, 45 | loss='softmax' 46 | ) 47 | model = model.cuda() 48 | optimizer = torchreid.optim.build_optimizer( 49 | model, optim='adam', lr=0.0003 50 | ) 51 | scheduler = torchreid.optim.build_lr_scheduler( 52 | optimizer, 53 | lr_scheduler='single_step', 54 | stepsize=20 55 | ) 56 | engine = torchreid.engine.VideoSoftmaxEngine( 57 | datamanager, model, optimizer, scheduler=scheduler, 58 | pooling_method='avg' 59 | ) 60 | engine.run( 61 | max_epoch=60, 62 | save_dir='log/resnet50-softmax-mars', 63 | print_freq=10 64 | ) 65 | """ 66 | 67 | def __init__(self, datamanager, model, optimizer, scheduler=None, 68 | use_cpu=False, label_smooth=True, pooling_method='avg'): 69 | super(VideoSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler=scheduler, 70 | use_cpu=use_cpu, label_smooth=label_smooth) 71 | self.pooling_method = pooling_method 72 | 73 | def _parse_data_for_train(self, data): 74 | imgs = data[0] 75 | pids = data[1] 76 | if imgs.dim() == 5: 77 | # b: batch size 78 | # s: sequence length 79 | # c: channel depth 80 | # h: height 81 | # w: width 82 | b, s, c, h, w = imgs.size() 83 | imgs = imgs.view(b*s, c, h, w) 84 | pids = pids.view(b, 1).expand(b, s) 85 | pids = pids.contiguous().view(b*s) 86 | return imgs, pids 87 | 88 | def _extract_features(self, input): 89 | self.model.eval() 90 | # b: batch size 91 | # s: sequence length 92 | # c: channel depth 93 | # h: height 94 | # w: width 95 | b, s, c, h, w = input.size() 96 | input = input.view(b*s, c, h, w) 97 | features = self.model(input) 98 | features = features.view(b, s, -1) 99 | if self.pooling_method == 'avg': 100 | features = torch.mean(features, 1) 101 | else: 102 | features = torch.max(features, 1)[0] 103 | return features -------------------------------------------------------------------------------- /torchreid/engine/video/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import time 6 | import datetime 7 | 8 | import torch 9 | 10 | import torchreid 11 | from torchreid.engine.image import ImageTripletEngine 12 | from torchreid.engine.video import VideoSoftmaxEngine 13 | 14 | 15 | class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine): 16 | """Triplet-loss engine for video-reid. 17 | 18 | Args: 19 | datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager`` 20 | or ``torchreid.data.VideoDataManager``. 21 | model (nn.Module): model instance. 22 | optimizer (Optimizer): an Optimizer. 23 | margin (float, optional): margin for triplet loss. Default is 0.3. 24 | weight_t (float, optional): weight for triplet loss. Default is 1. 25 | weight_x (float, optional): weight for softmax loss. Default is 1. 26 | scheduler (LRScheduler, optional): if None, no learning rate decay will be performed. 27 | use_cpu (bool, optional): use cpu. Default is False. 28 | label_smooth (bool, optional): use label smoothing regularizer. Default is True.
29 | pooling_method (str, optional): how to pool features for a tracklet. 30 | Default is "avg" (average). Choices are ["avg", "max"]. 31 | 32 | Examples:: 33 | 34 | import torch 35 | import torchreid 36 | # Each batch contains batch_size*seq_len images 37 | # Each identity is sampled with num_instances tracklets 38 | datamanager = torchreid.data.VideoDataManager( 39 | root='path/to/reid-data', 40 | sources='mars', 41 | height=256, 42 | width=128, 43 | combineall=False, 44 | num_instances=4, 45 | train_sampler='RandomIdentitySampler', 46 | batch_size=8, # number of tracklets 47 | seq_len=15 # number of images in each tracklet 48 | ) 49 | model = torchreid.models.build_model( 50 | name='resnet50', 51 | num_classes=datamanager.num_train_pids, 52 | loss='triplet' 53 | ) 54 | model = model.cuda() 55 | optimizer = torchreid.optim.build_optimizer( 56 | model, optim='adam', lr=0.0003 57 | ) 58 | scheduler = torchreid.optim.build_lr_scheduler( 59 | optimizer, 60 | lr_scheduler='single_step', 61 | stepsize=20 62 | ) 63 | engine = torchreid.engine.VideoTripletEngine( 64 | datamanager, model, optimizer, margin=0.3, 65 | weight_t=0.7, weight_x=1, scheduler=scheduler, 66 | pooling_method='avg' 67 | ) 68 | engine.run( 69 | max_epoch=60, 70 | save_dir='log/resnet50-triplet-mars', 71 | print_freq=10 72 | ) 73 | """ 74 | 75 | def __init__(self, datamanager, model, optimizer, margin=0.3, 76 | weight_t=1, weight_x=1, scheduler=None, use_cpu=False, 77 | label_smooth=True, pooling_method='avg'): 78 | super(VideoTripletEngine, self).__init__(datamanager, model, optimizer, margin=margin, 79 | weight_t=weight_t, weight_x=weight_x, 80 | scheduler=scheduler, use_cpu=use_cpu, 81 | label_smooth=label_smooth) 82 | self.pooling_method = pooling_method -------------------------------------------------------------------------------- /torchreid/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cross_entropy_loss import CrossEntropyLoss, Isolate_loss 6 | from .hard_mine_triplet_loss import TripletLoss, Part_similarity_constrain 7 | 8 | 9 | def DeepSupervision(criterion, xs, y, part_weights): 10 | """DeepSupervision 11 | 12 | Applies criterion to each element in a list. 13 | 14 | Args: 15 | criterion: loss function 16 | xs: tuple of inputs 17 | y: ground truth 18 | part_weights: optional per-part weights; if not None, the i-th 19 | element of ``xs`` is weighted by ``part_weights[:,i]`` 20 | """ 21 | loss = 0. 22 | for i, x in enumerate(xs): 23 | if part_weights is not None: 24 | loss += criterion(x, y, part_weights[:,i]) 25 | else: 26 | loss += criterion(x, y) 27 | # loss /= len(xs) 28 | return loss -------------------------------------------------------------------------------- /torchreid/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class CrossEntropyLoss(nn.Module): 9 | r"""Cross entropy loss with label smoothing regularizer. 10 | 11 | Reference: 12 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 13 | 14 | With label smoothing, the label :math:`y` for a class is computed by 15 | 16 | .. math:: 17 | \begin{equation} 18 | (1 - \epsilon) \times y + \frac{\epsilon}{K}, 19 | \end{equation} 20 | 21 | where :math:`K` denotes the number of classes and :math:`\epsilon` is a weight.
When 22 | :math:`\epsilon = 0`, the loss function reduces to the normal cross entropy. 23 | 24 | Args: 25 | num_classes (int): number of classes. 26 | epsilon (float, optional): weight. Default is 0.1. 27 | use_gpu (bool, optional): whether to use gpu devices. Default is True. 28 | label_smooth (bool, optional): whether to apply label smoothing. Default is True. 29 | """ 30 | 31 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True): 32 | super(CrossEntropyLoss, self).__init__() 33 | self.num_classes = num_classes 34 | self.epsilon = epsilon if label_smooth else 0 35 | self.use_gpu = use_gpu 36 | self.logsoftmax = nn.LogSoftmax(dim=1) 37 | 38 | def forward(self, inputs, targets, part_weight=None): 39 | """ 40 | Args: 41 | inputs (torch.Tensor): prediction matrix (before softmax) with 42 | shape (batch_size, num_classes). 43 | targets (torch.LongTensor): ground truth labels with shape (batch_size). 44 | Each position contains the label index. 45 | """ 46 | log_probs = self.logsoftmax(inputs) 47 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 48 | if self.use_gpu: targets = targets.cuda() 49 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes # e.g. K=10, epsilon=0.1: true class -> 0.91, every other class -> 0.01 50 | if part_weight is None: 51 | return (- targets * log_probs).mean(0).sum() 52 | else: 53 | return ((- targets * log_probs).sum(1) * part_weight / (part_weight.sum()+1e-6)).sum() 54 | 55 | class Isolate_loss(nn.Module): # encourages different attention maps to be mutually dissimilar 56 | def forward(self, inputs): 57 | # att_flatten=nn.functional.softmax(inputs.view(inputs.size(0), inputs.size(1), -1), dim=2) 58 | att_flatten=nn.functional.normalize(inputs.view(inputs.size(0), inputs.size(1), -1), dim=2) 59 | att_sim_matrix = att_flatten.matmul(att_flatten.transpose(1,2)) 60 | diag_element_mean = (att_flatten*att_flatten).sum(2).sum(1).mean() 61 | # att_sim_matrix = torch.triu(att_flatten.matmul(att_flatten.transpose(1,2)), diagonal=1) 62 | loss = att_sim_matrix.sum(1).sum(1).mean()-diag_element_mean 63 | return loss 64 | -------------------------------------------------------------------------------- /torchreid/losses/hard_mine_triplet_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | # import miosqp 8 | 9 | class TripletLoss(nn.Module): 10 | """Triplet loss with hard positive/negative mining. 11 | 12 | Reference: 13 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 14 | 15 | Imported from ``_. 16 | 17 | Args: 18 | margin (float, optional): margin for triplet. Default is 0.3. 19 | """ 20 | 21 | def __init__(self, margin=0.3): 22 | super(TripletLoss, self).__init__() 23 | self.margin = margin 24 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 25 | 26 | def forward(self, inputs, targets): 27 | """ 28 | Args: 29 | inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim). 30 | targets (torch.LongTensor): ground truth labels with shape (batch_size).
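        Example (a minimal sketch; the shapes below are made up for illustration)::
            >>> criterion = TripletLoss(margin=0.3)
            >>> features = torch.rand(32, 2048)
            >>> labels = torch.randint(0, 8, (32,))
            >>> loss = criterion(features, labels)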
31 | """ 32 | n = inputs.size(0) 33 | 34 | # Compute pairwise distance, replace by the official when merged 35 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 36 | dist = dist + dist.t() 37 | dist.addmm_(1, -2, inputs, inputs.t()) 38 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 39 | 40 | # For each anchor, find the hardest positive and negative 41 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 42 | dist_ap, dist_an = [], [] 43 | for i in range(n): 44 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 45 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 46 | dist_ap = torch.cat(dist_ap) 47 | dist_an = torch.cat(dist_an) 48 | 49 | # Compute ranking hinge loss 50 | y = torch.ones_like(dist_an) 51 | return self.ranking_loss(dist_an, dist_ap, y) 52 | 53 | class Part_similarity_constrain(nn.Module): 54 | #output a binary vector:b which indicate the matching label of each part 55 | #output the cost function:l for updating the pose subnet 56 | def __init__(self, momentum=0.5, ss_init=0.9, lambd_rate=0.9, part_num=6): 57 | super(Part_similarity_constrain, self).__init__() 58 | self.lambd_rate=lambd_rate 59 | self.momentum = momentum #sue for update lambd and ss 60 | self.matching_criterion = nn.BCELoss() # visibility verification loss in the paper 61 | #part-part similarity between gallery and query images 62 | self.register_buffer('lambd', lambd_rate*torch.ones(part_num)) 63 | ##self_similarity between each part 64 | self.register_buffer('ss_mean', torch.zeros((part_num, part_num))) 65 | 66 | def forward(self, inputs, targets, matching_inputs, use_matching_loss=False): 67 | #input_size = [N,c,p,1] 68 | #matching_inputs [1] N*P which is the output of PVP module 69 | loss = 0 70 | inputs = torch.cat(inputs, dim=2) 71 | batchsize = inputs.size(0) 72 | num_part = matching_inputs.size(1) 73 | ss_diff_=inputs.expand((batchsize,inputs.size(1),num_part, num_part)) 74 | ss_diff = ss_diff_-ss_diff_.permute(0,1,3,2) 75 | # normalized difference vectors between different part: N*c*p*p 76 | ss_diff_norm=torch.nn.functional.normalize(ss_diff,p=2,dim=1) 77 | device_id = inputs.get_device() 78 | mask_s = (torch.ones((num_part,num_part))-torch.eye(num_part)).cuda(device_id) 79 | matching_targets = [] 80 | matching_logit = [] 81 | cs_batch = 0 82 | ss_batch = 0 83 | inputs = torch.nn.functional.normalize(inputs, p=2, dim=1).squeeze() 84 | ss = torch.matmul(inputs.transpose(1, 2), inputs) # similarity matrix between different part 85 | lambd = (ss*mask_s).mean(1)*num_part/(num_part-1) # lambda used for matching loss 86 | for i in range(batchsize): 87 | l_q = targets[i] 88 | for j in range(i+1,batchsize): 89 | l_g = targets[j] 90 | if i!=j and l_g==l_q: 91 | cs_ij_=(inputs[i]*inputs[j]).sum(0) #cross_similarity P*P 92 | cs_batch = cs_batch+cs_ij_.detach() 93 | cs_ij=torch.diag(cs_ij_) #cross_similarity P*P 94 | s_constr_= (ss_diff_norm[i]*ss_diff_norm[j]).sum(0) 95 | ss_batch=ss_batch+s_constr_.detach() 96 | W = cs_ij+(s_constr_-self.ss_mean) 97 | if use_matching_loss: 98 | x_optim, _= self.IQP_solver(W,self.lambd) 99 | x_optim = torch.from_numpy(x_optim).cuda(device_id) 100 | else: 101 | x_optim = torch.from_numpy(np.ones(W.size(0))).cuda(device_id) 102 | matching_targets.append(x_optim) 103 | matching_logit.append(matching_inputs[i]*matching_inputs[j]) 104 | loss += -torch.matmul(torch.matmul(x_optim,W),x_optim) + ((lambd[i]+lambd[j])*x_optim).sum()/2 105 | 106 | loss = loss/len(matching_targets) 107 | cs_batch=cs_batch/len(matching_targets) 108 | 
ss_batch=ss_batch/len(matching_targets) 109 | self.ss_mean=self.momentum*self.ss_mean+(1-self.momentum)*ss_batch 110 | self.lambd=self.momentum*self.lambd+(1-self.momentum)*self.lambd_rate*cs_batch 111 | if use_matching_loss: 112 | matching_targets = torch.stack(matching_targets) 113 | matching_logit = torch.stack(matching_logit) 114 | matching_loss = self.matching_criterion(matching_logit, matching_targets) 115 | else: 116 | matching_loss = 0 117 | return matching_loss, loss 118 | 119 | def IQP_solver(self, W, lambd): 120 | W = W.data.cpu().numpy() 121 | lambd = lambd.data.cpu().numpy() 122 | value = 0 123 | x_result = np.zeros(W.shape[0], dtype='float32') 124 | for i in range(2**(W.shape[0])): 125 | x = np.asarray(list('{:b}'.format(i).zfill(W.shape[0])), dtype='float32') 126 | value_i = np.matmul(np.matmul(x,W),x)-(lambd*x).sum() 127 | if value_i>value: 128 | value = value_i 129 | x_result = x 130 | return x_result, value 131 | -------------------------------------------------------------------------------- /torchreid/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .accuracy import accuracy 4 | from .rank import evaluate_rank 5 | from .distance import compute_distance_matrix, compute_weight_distance_matrix -------------------------------------------------------------------------------- /torchreid/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | """Computes the accuracy over the k top predictions for 8 | the specified values of k. 9 | 10 | Args: 11 | output (torch.Tensor): prediction matrix with shape (batch_size, num_classes). 12 | target (torch.LongTensor): ground truth labels with shape (batch_size). 13 | topk (tuple, optional): accuracy at top-k will be computed. For example, 14 | topk=(1, 5) means accuracy at top-1 and top-5 will be computed. 15 | 16 | Returns: 17 | list: accuracy at top-k. 18 | 19 | Examples:: 20 | >>> from torchreid import metrics 21 | >>> metrics.accuracy(output, target) 22 | """ 23 | maxk = max(topk) 24 | batch_size = target.size(0) 25 | 26 | if isinstance(output, (tuple, list)): 27 | output = output[0] 28 | 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 32 | 33 | res = [] 34 | for k in topk: 35 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 36 | acc = correct_k.mul_(100.0 / batch_size) 37 | res.append(acc) 38 | 39 | return res -------------------------------------------------------------------------------- /torchreid/metrics/distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | 11 | def compute_distance_matrix(input1, input2, metric='euclidean'): 12 | """A wrapper function for computing distance matrix. 13 | 14 | Args: 15 | input1 (torch.Tensor): 2-D feature matrix. 16 | input2 (torch.Tensor): 2-D feature matrix. 17 | metric (str, optional): "euclidean" or "cosine". 18 | Default is "euclidean". 19 | 20 | Returns: 21 | torch.Tensor: distance matrix. 
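    Note that with ``metric='euclidean'`` the returned matrix contains
    *squared* Euclidean distances (see ``euclidean_squared_distance`` below).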
22 | 23 | Examples:: 24 | >>> from torchreid import metrics 25 | >>> input1 = torch.rand(10, 2048) 26 | >>> input2 = torch.rand(100, 2048) 27 | >>> distmat = metrics.compute_distance_matrix(input1, input2) 28 | >>> distmat.size() # (10, 100) 29 | """ 30 | # check input 31 | assert isinstance(input1, torch.Tensor) 32 | assert isinstance(input2, torch.Tensor) 33 | assert input1.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(input1.dim()) 34 | assert input2.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(input2.dim()) 35 | assert input1.size(1) == input2.size(1) 36 | 37 | if metric == 'euclidean': 38 | distmat = euclidean_squared_distance(input1, input2) 39 | elif metric == 'cosine': 40 | distmat = cosine_distance(input1, input2) 41 | else: 42 | raise ValueError( 43 | 'Unknown distance metric: {}. ' 44 | 'Please choose either "euclidean" or "cosine"'.format(metric) 45 | ) 46 | 47 | return distmat 48 | 49 | 50 | def compute_weight_distance_matrix(input1, input2, input1_score, input2_score, metric='euclidean'): 51 | """A wrapper function for computing a part-weighted distance matrix. 52 | 53 | Args: 54 | input1 (torch.Tensor): 3-D feature matrix. N1*C*P 55 | input2 (torch.Tensor): 3-D feature matrix. N2*C*P 56 | input1_score (torch.Tensor): part visibility scores. N1*P 57 | input2_score (torch.Tensor): part visibility scores. N2*P 58 | metric (str, optional): "euclidean" or "cosine". 59 | Default is "euclidean". 60 | 61 | Returns: 62 | torch.Tensor: distance matrix. 63 | 64 | Examples:: 65 | >>> from torchreid import metrics 66 | >>> input1 = torch.rand(10, 2048, 6) 67 | >>> input2 = torch.rand(100, 2048, 6) 68 | >>> score1, score2 = torch.rand(10, 6), torch.rand(100, 6) 69 | >>> distmat = metrics.compute_weight_distance_matrix(input1, input2, score1, score2) # (10, 100) 70 | """ 71 | # per-part distances are weighted by the product of the two part visibility scores 72 | 73 | input1_score = input1_score.unsqueeze(1).expand(input1.size(0),input2.size(0),input1.size(2)) 74 | input2_score = input2_score.unsqueeze(0).expand(input1.size(0),input2.size(0),input1.size(2)) 75 | score_map = input1_score*input2_score 76 | distmat = 0 77 | for i in range(input1.size(2)): 78 | if metric == 'euclidean': 79 | distmat += score_map[:,:,i]*euclidean_squared_distance(input1[:,:,i], input2[:,:,i]) 80 | elif metric == 'cosine': 81 | distmat += score_map[:,:,i]*cosine_distance(input1[:,:,i], input2[:,:,i]) 82 | else: 83 | raise ValueError( 84 | 'Unknown distance metric: {}. ' 85 | 'Please choose either "euclidean" or "cosine"'.format(metric) 86 | ) 87 | distmat = distmat/(score_map.sum(2)) 88 | 89 | return distmat 90 | 91 | def euclidean_squared_distance(input1, input2): 92 | """Computes euclidean squared distance. 93 | 94 | Args: 95 | input1 (torch.Tensor): 2-D feature matrix. 96 | input2 (torch.Tensor): 2-D feature matrix. 97 | 98 | Returns: 99 | torch.Tensor: distance matrix. 100 | """ 101 | m, n = input1.size(0), input2.size(0) 102 | distmat = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 103 | torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 104 | distmat.addmm_(1, -2, input1, input2.t()) 105 | return distmat 106 | 107 | 108 | def cosine_distance(input1, input2): 109 | """Computes cosine distance. 110 | 111 | Args: 112 | input1 (torch.Tensor): 2-D feature matrix. 113 | input2 (torch.Tensor): 2-D feature matrix. 114 | 115 | Returns: 116 | torch.Tensor: distance matrix.
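    Example (an illustrative sketch)::
        >>> a = torch.tensor([[1., 0.], [0., 1.]])
        >>> cosine_distance(a, a)   # 0. on the diagonal, 1. off it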
117 | """ 118 | input1_normed = F.normalize(input1, p=2, dim=1) 119 | input2_normed = F.normalize(input2, p=2, dim=1) 120 | distmat = 1 - torch.mm(input1_normed, input2_normed.t()) 121 | return distmat -------------------------------------------------------------------------------- /torchreid/metrics/rank.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import copy 7 | from collections import defaultdict 8 | import sys 9 | import warnings 10 | 11 | try: 12 | from torchreid.metrics.rank_cylib.rank_cy import evaluate_cy 13 | IS_CYTHON_AVAI = True 14 | except ImportError: 15 | IS_CYTHON_AVAI = False 16 | warnings.warn( 17 | 'Cython evaluation (very fast so highly recommended) is ' 18 | 'unavailable, now use python evaluation.' 19 | ) 20 | 21 | 22 | def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 23 | """Evaluation with cuhk03 metric 24 | Key: one image for each gallery identity is randomly sampled for each query identity. 25 | Random sampling is performed num_repeats times. 26 | """ 27 | num_repeats = 10 28 | num_q, num_g = distmat.shape 29 | 30 | if num_g < max_rank: 31 | max_rank = num_g 32 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 33 | 34 | indices = np.argsort(distmat, axis=1) 35 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 36 | 37 | # compute cmc curve for each query 38 | all_cmc = [] 39 | all_AP = [] 40 | num_valid_q = 0. # number of valid query 41 | 42 | for q_idx in range(num_q): 43 | # get query pid and camid 44 | q_pid = q_pids[q_idx] 45 | q_camid = q_camids[q_idx] 46 | 47 | # remove gallery samples that have the same pid and camid with query 48 | order = indices[q_idx] 49 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 50 | keep = np.invert(remove) 51 | 52 | # compute cmc curve 53 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 54 | if not np.any(raw_cmc): 55 | # this condition is true when query identity does not appear in gallery 56 | continue 57 | 58 | kept_g_pids = g_pids[order][keep] 59 | g_pids_dict = defaultdict(list) 60 | for idx, pid in enumerate(kept_g_pids): 61 | g_pids_dict[pid].append(idx) 62 | 63 | cmc = 0. 64 | for repeat_idx in range(num_repeats): 65 | mask = np.zeros(len(raw_cmc), dtype=np.bool) 66 | for _, idxs in g_pids_dict.items(): 67 | # randomly sample one image for each gallery person 68 | rnd_idx = np.random.choice(idxs) 69 | mask[rnd_idx] = True 70 | masked_raw_cmc = raw_cmc[mask] 71 | _cmc = masked_raw_cmc.cumsum() 72 | _cmc[_cmc > 1] = 1 73 | cmc += _cmc[:max_rank].astype(np.float32) 74 | 75 | cmc /= num_repeats 76 | all_cmc.append(cmc) 77 | # compute AP 78 | num_rel = raw_cmc.sum() 79 | tmp_cmc = raw_cmc.cumsum() 80 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 81 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 82 | AP = tmp_cmc.sum() / num_rel 83 | all_AP.append(AP) 84 | num_valid_q += 1. 
85 | 86 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 87 | 88 | all_cmc = np.asarray(all_cmc).astype(np.float32) 89 | all_cmc = all_cmc.sum(0) / num_valid_q 90 | mAP = np.mean(all_AP) 91 | 92 | return all_cmc, mAP 93 | 94 | 95 | def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 96 | """Evaluation with market1501 metric 97 | Key: for each query identity, its gallery images from the same camera view are discarded. 98 | """ 99 | num_q, num_g = distmat.shape 100 | 101 | if num_g < max_rank: 102 | max_rank = num_g 103 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 104 | 105 | indices = np.argsort(distmat, axis=1) 106 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 107 | 108 | # compute cmc curve for each query 109 | all_cmc = [] 110 | all_AP = [] 111 | num_valid_q = 0. # number of valid query 112 | 113 | for q_idx in range(num_q): 114 | # get query pid and camid 115 | q_pid = q_pids[q_idx] 116 | q_camid = q_camids[q_idx] 117 | 118 | # remove gallery samples that have the same pid and camid with query 119 | order = indices[q_idx] 120 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 121 | keep = np.invert(remove) 122 | 123 | # compute cmc curve 124 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 125 | if not np.any(raw_cmc): 126 | # this condition is true when query identity does not appear in gallery 127 | continue 128 | 129 | cmc = raw_cmc.cumsum() 130 | cmc[cmc > 1] = 1 131 | 132 | all_cmc.append(cmc[:max_rank]) 133 | num_valid_q += 1. 134 | 135 | # compute average precision 136 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 137 | num_rel = raw_cmc.sum() 138 | tmp_cmc = raw_cmc.cumsum() 139 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 140 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 141 | AP = tmp_cmc.sum() / num_rel 142 | all_AP.append(AP) 143 | 144 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 145 | 146 | all_cmc = np.asarray(all_cmc).astype(np.float32) 147 | all_cmc = all_cmc.sum(0) / num_valid_q 148 | mAP = np.mean(all_AP) 149 | 150 | return all_cmc, mAP 151 | 152 | 153 | def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03): 154 | if use_metric_cuhk03: 155 | return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 156 | else: 157 | return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 158 | 159 | 160 | def evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, 161 | use_metric_cuhk03=False, use_cython=True): 162 | """Evaluates CMC rank. 163 | 164 | Args: 165 | distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 166 | q_pids (numpy.ndarray): 1-D array containing person identities 167 | of each query instance. 168 | g_pids (numpy.ndarray): 1-D array containing person identities 169 | of each gallery instance. 170 | q_camids (numpy.ndarray): 1-D array containing camera views under 171 | which each query instance is captured. 172 | g_camids (numpy.ndarray): 1-D array containing camera views under 173 | which each gallery instance is captured. 174 | max_rank (int, optional): maximum CMC rank to be computed. Default is 50. 175 | use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03. 176 | Default is False. This should be enabled when using cuhk03 classic split. 
177 | use_cython (bool, optional): use cython code for evaluation. Default is True. 178 | This is highly recommended as the cython code can speed up the cmc computation 179 | by more than 10x. This requires Cython to be installed. 180 | """ 181 | if use_cython and IS_CYTHON_AVAI: 182 | return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) 183 | else: 184 | return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | $(PYTHON) setup.py build_ext --inplace 3 | rm -rf build 4 | clean: 5 | rm -rf build 6 | rm -f rank_cy.c *.so -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hh23333/PVPM/9587276e13f497553bf6f801c3297ee0ce050c77/torchreid/metrics/rank_cylib/__init__.py -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | import numpy as np 5 | 6 | 7 | def numpy_include(): 8 | try: 9 | numpy_include = np.get_include() 10 | except AttributeError: 11 | numpy_include = np.get_numpy_include() 12 | return numpy_include 13 | 14 | 15 | ext_modules = [ 16 | Extension( 17 | 'rank_cy', 18 | ['rank_cy.pyx'], 19 | include_dirs=[numpy_include()], 20 | ) 21 | ] 22 | 23 | setup( 24 | name='Cython-based reid evaluation code', 25 | ext_modules=cythonize(ext_modules) 26 | ) -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/test_cython.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import os.path as osp 5 | import timeit 6 | import numpy as np 7 | 8 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 9 | from torchreid import metrics 10 | 11 | """ 12 | Test the speed of cython-based evaluation code. The speed improvements 13 | can be much bigger when using the real reid data, which contains a larger 14 | amount of query and gallery images. 15 | 16 | Note: you might encounter the following error: 17 | 'AssertionError: Error: all query identities do not appear in gallery'. 18 | This is normal because the inputs are random numbers. Just try again. 
19 | """ 20 | 21 | print('*** Compare running time ***') 22 | 23 | setup = ''' 24 | import sys 25 | import os.path as osp 26 | import numpy as np 27 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 28 | from torchreid import metrics 29 | num_q = 30 30 | num_g = 300 31 | max_rank = 5 32 | distmat = np.random.rand(num_q, num_g) * 20 33 | q_pids = np.random.randint(0, num_q, size=num_q) 34 | g_pids = np.random.randint(0, num_g, size=num_g) 35 | q_camids = np.random.randint(0, 5, size=num_q) 36 | g_camids = np.random.randint(0, 5, size=num_g) 37 | ''' 38 | 39 | print('=> Using market1501\'s metric') 40 | pytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', setup=setup, number=20) 41 | cytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', setup=setup, number=20) 42 | print('Python time: {} s'.format(pytime)) 43 | print('Cython time: {} s'.format(cytime)) 44 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 45 | 46 | print('=> Using cuhk03\'s metric') 47 | pytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)', setup=setup, number=20) 48 | cytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)', setup=setup, number=20) 49 | print('Python time: {} s'.format(pytime)) 50 | print('Cython time: {} s'.format(cytime)) 51 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 52 | 53 | """ 54 | print("=> Check precision") 55 | 56 | num_q = 30 57 | num_g = 300 58 | max_rank = 5 59 | distmat = np.random.rand(num_q, num_g) * 20 60 | q_pids = np.random.randint(0, num_q, size=num_q) 61 | g_pids = np.random.randint(0, num_g, size=num_g) 62 | q_camids = np.random.randint(0, 5, size=num_q) 63 | g_camids = np.random.randint(0, 5, size=num_g) 64 | 65 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 66 | print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 67 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 68 | print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 69 | """ -------------------------------------------------------------------------------- /torchreid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from .resnet import * 6 | from .resnet_ import resnet50_, resnet50_fc512_ 7 | from .resnetmid import * 8 | from .senet import * 9 | from .densenet import * 10 | from .inceptionresnetv2 import * 11 | from .inceptionv4 import * 12 | from .xception import * 13 | 14 | from .nasnet import * 15 | from .mobilenetv2 import * 16 | from .shufflenet import * 17 | from .squeezenet import * 18 | from .shufflenetv2 import * 19 | 20 | from .mudeep import * 21 | from .hacnn import * 22 | from .pcb import * 23 | from .mlfn import * 24 | from .osnet import * 25 | 26 | 27 | __model_factory = { 28 | # image classification models 29 | 'resnet18': resnet18, 30 | 'resnet34': resnet34, 31 | 'resnet50': resnet50, 32 | 'resnet50_': resnet50_, 33 | 'resnet50_fc512_': resnet50_fc512_, 34 | # 'pose_resnet50_fc512': pose_resnet50_fc512, 35 | 'resnet101': resnet101, 36 | 'resnet152': resnet152, 37 | 'resnext50_32x4d': resnext50_32x4d, 38 | 
'resnext101_32x8d': resnext101_32x8d, 39 | 'resnet50_fc512': resnet50_fc512, 40 | 'se_resnet50': se_resnet50, 41 | 'se_resnet50_fc512': se_resnet50_fc512, 42 | 'se_resnet101': se_resnet101, 43 | 'se_resnext50_32x4d': se_resnext50_32x4d, 44 | 'se_resnext101_32x4d': se_resnext101_32x4d, 45 | 'densenet121': densenet121, 46 | 'densenet169': densenet169, 47 | 'densenet201': densenet201, 48 | 'densenet161': densenet161, 49 | 'densenet121_fc512': densenet121_fc512, 50 | 'inceptionresnetv2': inceptionresnetv2, 51 | 'inceptionv4': inceptionv4, 52 | 'xception': xception, 53 | # lightweight models 54 | 'nasnsetmobile': nasnetamobile, 55 | 'mobilenetv2_x1_0': mobilenetv2_x1_0, 56 | 'mobilenetv2_x1_4': mobilenetv2_x1_4, 57 | 'shufflenet': shufflenet, 58 | 'squeezenet1_0': squeezenet1_0, 59 | 'squeezenet1_0_fc512': squeezenet1_0_fc512, 60 | 'squeezenet1_1': squeezenet1_1, 61 | 'shufflenet_v2_x0_5': shufflenet_v2_x0_5, 62 | 'shufflenet_v2_x1_0': shufflenet_v2_x1_0, 63 | 'shufflenet_v2_x1_5': shufflenet_v2_x1_5, 64 | 'shufflenet_v2_x2_0': shufflenet_v2_x2_0, 65 | # reid-specific models 66 | 'mudeep': MuDeep, 67 | 'resnet50mid': resnet50mid, 68 | 'hacnn': HACNN, 69 | 'pcb_p6': pcb_p6, 70 | 'pcb_p4': pcb_p4, 71 | 'pose_p4':pose_resnet50_256_p4, 72 | 'pose_p6':pose_resnet50_256_p6, 73 | 'pose_p6s':pose_resnet50_256_p6_pscore_reg, 74 | 'pose_p4s':pose_resnet50_256_p4_pscore_reg, 75 | 'mlfn': mlfn, 76 | 'osnet_x1_0': osnet_x1_0, 77 | 'osnet_x0_75': osnet_x0_75, 78 | 'osnet_x0_5': osnet_x0_5, 79 | 'osnet_x0_25': osnet_x0_25, 80 | 'osnet_ibn_x1_0': osnet_ibn_x1_0 81 | } 82 | 83 | 84 | def show_avai_models(): 85 | """Displays available models. 86 | 87 | Examples:: 88 | >>> from torchreid import models 89 | >>> models.show_avai_models() 90 | """ 91 | print(list(__model_factory.keys())) 92 | 93 | 94 | def build_model(name, num_classes, loss='softmax', pretrained=True, use_gpu=True): 95 | """A function wrapper for building a model. 96 | 97 | Args: 98 | name (str): model name. 99 | num_classes (int): number of training identities. 100 | loss (str, optional): loss function to optimize the model. Currently 101 | supports "softmax" and "triplet". Default is "softmax". 102 | pretrained (bool, optional): whether to load ImageNet-pretrained weights. 103 | Default is True. 104 | use_gpu (bool, optional): whether to use gpu. Default is True. 105 | 106 | Returns: 107 | nn.Module 108 | 109 | Examples:: 110 | >>> from torchreid import models 111 | >>> model = models.build_model('resnet50', 751, loss='softmax') 112 | """ 113 | avai_models = list(__model_factory.keys()) 114 | if name not in avai_models: 115 | raise KeyError('Unknown model: {}. 
Must be one of {}'.format(name, avai_models)) 116 | return __model_factory[name]( 117 | num_classes=num_classes, 118 | loss=loss, 119 | pretrained=pretrained, 120 | use_gpu=use_gpu 121 | ) -------------------------------------------------------------------------------- /torchreid/models/mlfn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['mlfn'] 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | import torchvision 10 | import torch.utils.model_zoo as model_zoo 11 | 12 | 13 | model_urls = { 14 | # training epoch = 5, top1 = 51.6 15 | 'imagenet': 'https://mega.nz/#!YHxAhaxC!yu9E6zWl0x5zscSouTdbZu8gdFFytDdl-RAdD2DEfpk', 16 | } 17 | 18 | 19 | class MLFNBlock(nn.Module): 20 | 21 | def __init__(self, in_channels, out_channels, stride, fsm_channels, groups=32): 22 | super(MLFNBlock, self).__init__() 23 | self.groups = groups 24 | mid_channels = out_channels // 2 25 | 26 | # Factor Modules 27 | self.fm_conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False) 28 | self.fm_bn1 = nn.BatchNorm2d(mid_channels) 29 | self.fm_conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1, bias=False, groups=self.groups) 30 | self.fm_bn2 = nn.BatchNorm2d(mid_channels) 31 | self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False) 32 | self.fm_bn3 = nn.BatchNorm2d(out_channels) 33 | 34 | # Factor Selection Module 35 | self.fsm = nn.Sequential( 36 | nn.AdaptiveAvgPool2d(1), 37 | nn.Conv2d(in_channels, fsm_channels[0], 1), 38 | nn.BatchNorm2d(fsm_channels[0]), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(fsm_channels[0], fsm_channels[1], 1), 41 | nn.BatchNorm2d(fsm_channels[1]), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(fsm_channels[1], self.groups, 1), 44 | nn.BatchNorm2d(self.groups), 45 | nn.Sigmoid(), 46 | ) 47 | 48 | self.downsample = None 49 | if in_channels != out_channels or stride > 1: 50 | self.downsample = nn.Sequential( 51 | nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False), 52 | nn.BatchNorm2d(out_channels), 53 | ) 54 | 55 | def forward(self, x): 56 | residual = x 57 | s = self.fsm(x) 58 | 59 | # reduce dimension 60 | x = self.fm_conv1(x) 61 | x = self.fm_bn1(x) 62 | x = F.relu(x, inplace=True) 63 | 64 | # group convolution 65 | x = self.fm_conv2(x) 66 | x = self.fm_bn2(x) 67 | x = F.relu(x, inplace=True) 68 | 69 | # factor selection 70 | b, c = x.size(0), x.size(1) 71 | n = c // self.groups 72 | ss = s.repeat(1, n, 1, 1) # from (b, g, 1, 1) to (b, g*n=c, 1, 1) 73 | ss = ss.view(b, n, self.groups, 1, 1) 74 | ss = ss.permute(0, 2, 1, 3, 4).contiguous() 75 | ss = ss.view(b, c, 1, 1) 76 | x = ss * x 77 | 78 | # recover dimension 79 | x = self.fm_conv3(x) 80 | x = self.fm_bn3(x) 81 | x = F.relu(x, inplace=True) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(residual) 85 | 86 | return F.relu(residual + x, inplace=True), s 87 | 88 | 89 | class MLFN(nn.Module): 90 | """Multi-Level Factorisation Net. 91 | 92 | Reference: 93 | Chang et al. Multi-Level Factorisation Net for 94 | Person Re-Identification. CVPR 2018. 95 | 96 | Public keys: 97 | - ``mlfn``: MLFN (Multi-Level Factorisation Net). 
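    Example (an illustrative sketch; the class can also be built via ``mlfn(...)`` below)::
        >>> m = MLFN(num_classes=751)
        >>> x = torch.randn(2, 3, 256, 128)
        >>> y = m(x)   # logits while training, 1024-d features in eval mode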
98 | """ 99 | 100 | def __init__(self, num_classes, loss='softmax', groups=32, channels=[64, 256, 512, 1024, 2048], embed_dim=1024, **kwargs): 101 | super(MLFN, self).__init__() 102 | self.loss = loss 103 | self.groups = groups 104 | 105 | # first convolutional layer 106 | self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3) 107 | self.bn1 = nn.BatchNorm2d(channels[0]) 108 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 109 | 110 | # main body 111 | self.feature = nn.ModuleList([ 112 | # layer 1-3 113 | MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups), 114 | MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups), 115 | MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups), 116 | # layer 4-7 117 | MLFNBlock(channels[1], channels[2], 2, [256, 128], self.groups), 118 | MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups), 119 | MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups), 120 | MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups), 121 | # layer 8-13 122 | MLFNBlock(channels[2], channels[3], 2, [512, 128], self.groups), 123 | MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), 124 | MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), 125 | MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), 126 | MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), 127 | MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), 128 | # layer 14-16 129 | MLFNBlock(channels[3], channels[4], 2, [512, 128], self.groups), 130 | MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups), 131 | MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups), 132 | ]) 133 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 134 | 135 | # projection functions 136 | self.fc_x = nn.Sequential( 137 | nn.Conv2d(channels[4], embed_dim, 1, bias=False), 138 | nn.BatchNorm2d(embed_dim), 139 | nn.ReLU(inplace=True), 140 | ) 141 | self.fc_s = nn.Sequential( 142 | nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False), 143 | nn.BatchNorm2d(embed_dim), 144 | nn.ReLU(inplace=True), 145 | ) 146 | 147 | self.classifier = nn.Linear(embed_dim, num_classes) 148 | 149 | self.init_params() 150 | 151 | def init_params(self): 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 155 | if m.bias is not None: 156 | nn.init.constant_(m.bias, 0) 157 | elif isinstance(m, nn.BatchNorm2d): 158 | nn.init.constant_(m.weight, 1) 159 | nn.init.constant_(m.bias, 0) 160 | elif isinstance(m, nn.Linear): 161 | nn.init.normal_(m.weight, 0, 0.01) 162 | if m.bias is not None: 163 | nn.init.constant_(m.bias, 0) 164 | 165 | def forward(self, x): 166 | x = self.conv1(x) 167 | x = self.bn1(x) 168 | x = F.relu(x, inplace=True) 169 | x = self.maxpool(x) 170 | 171 | s_hat = [] 172 | for block in self.feature: 173 | x, s = block(x) 174 | s_hat.append(s) 175 | s_hat = torch.cat(s_hat, 1) 176 | 177 | x = self.global_avgpool(x) 178 | x = self.fc_x(x) 179 | s_hat = self.fc_s(s_hat) 180 | 181 | v = (x + s_hat) * 0.5 182 | v = v.view(v.size(0), -1) 183 | 184 | if not self.training: 185 | return v 186 | 187 | y = self.classifier(v) 188 | 189 | if self.loss == 'softmax': 190 | return y 191 | elif self.loss == 'triplet': 192 | return y, v 193 | else: 194 | raise KeyError('Unsupported loss: {}'.format(self.loss)) 195 | 196 | 197 | def init_pretrained_weights(model, model_url): 198 | """Initializes model with pretrained weights. 
199 | 200 | Layers that don't match with pretrained layers in name or size are kept unchanged. 201 | """ 202 | pretrain_dict = model_zoo.load_url(model_url) 203 | model_dict = model.state_dict() 204 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 205 | model_dict.update(pretrain_dict) 206 | model.load_state_dict(model_dict) 207 | 208 | 209 | def mlfn(num_classes, loss='softmax', pretrained=True, **kwargs): 210 | model = MLFN(num_classes, loss, **kwargs) 211 | if pretrained: 212 | #init_pretrained_weights(model, model_urls['imagenet']) 213 | import warnings 214 | warnings.warn('The imagenet pretrained weights need to be manually downloaded from {}'.format(model_urls['imagenet'])) 215 | return model -------------------------------------------------------------------------------- /torchreid/models/mudeep.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['MuDeep'] 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | import torchvision 10 | 11 | 12 | class ConvBlock(nn.Module): 13 | """Basic convolutional block. 14 | 15 | convolution + batch normalization + relu. 16 | 17 | Args: 18 | in_c (int): number of input channels. 19 | out_c (int): number of output channels. 20 | k (int or tuple): kernel size. 21 | s (int or tuple): stride. 22 | p (int or tuple): padding. 23 | """ 24 | 25 | def __init__(self, in_c, out_c, k, s, p): 26 | super(ConvBlock, self).__init__() 27 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) 28 | self.bn = nn.BatchNorm2d(out_c) 29 | 30 | def forward(self, x): 31 | return F.relu(self.bn(self.conv(x))) 32 | 33 | 34 | class ConvLayers(nn.Module): 35 | """Preprocessing layers.""" 36 | 37 | def __init__(self): 38 | super(ConvLayers, self).__init__() 39 | self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1) 40 | self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1) 41 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 42 | 43 | def forward(self, x): 44 | x = self.conv1(x) 45 | x = self.conv2(x) 46 | x = self.maxpool(x) 47 | return x 48 | 49 | 50 | class MultiScaleA(nn.Module): 51 | """Multi-scale stream layer A (Sec.3.1)""" 52 | 53 | def __init__(self): 54 | super(MultiScaleA, self).__init__() 55 | self.stream1 = nn.Sequential( 56 | ConvBlock(96, 96, k=1, s=1, p=0), 57 | ConvBlock(96, 24, k=3, s=1, p=1), 58 | ) 59 | self.stream2 = nn.Sequential( 60 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 61 | ConvBlock(96, 24, k=1, s=1, p=0), 62 | ) 63 | self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0) 64 | self.stream4 = nn.Sequential( 65 | ConvBlock(96, 16, k=1, s=1, p=0), 66 | ConvBlock(16, 24, k=3, s=1, p=1), 67 | ConvBlock(24, 24, k=3, s=1, p=1), 68 | ) 69 | 70 | def forward(self, x): 71 | s1 = self.stream1(x) 72 | s2 = self.stream2(x) 73 | s3 = self.stream3(x) 74 | s4 = self.stream4(x) 75 | y = torch.cat([s1, s2, s3, s4], dim=1) 76 | return y 77 | 78 | 79 | class Reduction(nn.Module): 80 | """Reduction layer (Sec.3.1)""" 81 | 82 | def __init__(self): 83 | super(Reduction, self).__init__() 84 | self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 85 | self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1) 86 | self.stream3 = nn.Sequential( 87 | ConvBlock(96, 48, k=1, s=1, p=0), 88 | ConvBlock(48, 56, k=3, s=1, p=1), 89 | ConvBlock(56, 64, k=3, s=2, p=1), 90 | ) 91 | 92 | def forward(self, x): 93 | s1 = self.stream1(x) 94 | s2 = 
self.stream2(x) 95 | s3 = self.stream3(x) 96 | y = torch.cat([s1, s2, s3], dim=1) 97 | return y 98 | 99 | 100 | class MultiScaleB(nn.Module): 101 | """Multi-scale stream layer B (Sec.3.1)""" 102 | 103 | def __init__(self): 104 | super(MultiScaleB, self).__init__() 105 | self.stream1 = nn.Sequential( 106 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 107 | ConvBlock(256, 256, k=1, s=1, p=0), 108 | ) 109 | self.stream2 = nn.Sequential( 110 | ConvBlock(256, 64, k=1, s=1, p=0), 111 | ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)), 112 | ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)), 113 | ) 114 | self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0) 115 | self.stream4 = nn.Sequential( 116 | ConvBlock(256, 64, k=1, s=1, p=0), 117 | ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)), 118 | ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)), 119 | ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)), 120 | ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)), 121 | ) 122 | 123 | def forward(self, x): 124 | s1 = self.stream1(x) 125 | s2 = self.stream2(x) 126 | s3 = self.stream3(x) 127 | s4 = self.stream4(x) 128 | return s1, s2, s3, s4 129 | 130 | 131 | class Fusion(nn.Module): 132 | """Saliency-based learning fusion layer (Sec.3.2)""" 133 | 134 | def __init__(self): 135 | super(Fusion, self).__init__() 136 | self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1)) 137 | self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1)) 138 | self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1)) 139 | self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1)) 140 | 141 | # We add an average pooling layer to reduce the spatial dimension 142 | # of feature maps, which differs from the original paper. 143 | self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0) 144 | 145 | def forward(self, x1, x2, x3, x4): 146 | s1 = self.a1.expand_as(x1) * x1 147 | s2 = self.a2.expand_as(x2) * x2 148 | s3 = self.a3.expand_as(x3) * x3 149 | s4 = self.a4.expand_as(x4) * x4 150 | y = self.avgpool(s1 + s2 + s3 + s4) 151 | return y 152 | 153 | 154 | class MuDeep(nn.Module): 155 | """Multiscale deep neural network. 156 | 157 | Reference: 158 | Qian et al. Multi-scale Deep Learning Architectures 159 | for Person Re-identification. ICCV 2017. 160 | 161 | Public keys: 162 | - ``mudeep``: Multiscale deep neural network. 163 | """ 164 | 165 | def __init__(self, num_classes, loss='softmax', **kwargs): 166 | super(MuDeep, self).__init__() 167 | self.loss = loss 168 | 169 | self.block1 = ConvLayers() 170 | self.block2 = MultiScaleA() 171 | self.block3 = Reduction() 172 | self.block4 = MultiScaleB() 173 | self.block5 = Fusion() 174 | 175 | # Due to this fully connected layer, input image has to be fixed 176 | # in shape, i.e. (3, 256, 128), such that the last convolutional feature 177 | # maps are of shape (256, 16, 8). If input shape is changed, 178 | # the input dimension of this layer has to be changed accordingly. 
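        # A quick trace of that arithmetic (a sketch for the default input):
        #   ConvLayers maxpool (s=2):  (3, 256, 128) -> (96, 128, 64)
        #   MultiScaleA (s=1):         (96, 128, 64) -> (96, 128, 64)
        #   Reduction (s=2):           (96, 128, 64) -> (256, 64, 32)
        #   MultiScaleB (s=1):         (256, 64, 32) -> four (256, 64, 32) maps
        #   Fusion avgpool (k=4, s=4): (256, 64, 32) -> (256, 16, 8)
        # hence the 256*16*8 input features below.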
179 | self.fc = nn.Sequential( 180 | nn.Linear(256*16*8, 4096), 181 | nn.BatchNorm1d(4096), 182 | nn.ReLU(), 183 | ) 184 | self.classifier = nn.Linear(4096, num_classes) 185 | self.feat_dim = 4096 186 | 187 | def featuremaps(self, x): 188 | x = self.block1(x) 189 | x = self.block2(x) 190 | x = self.block3(x) 191 | x = self.block4(x) 192 | x = self.block5(*x) 193 | return x 194 | 195 | def forward(self, x): 196 | x = self.featuremaps(x) 197 | x = x.view(x.size(0), -1) 198 | x = self.fc(x) 199 | y = self.classifier(x) 200 | 201 | if self.loss == 'softmax': 202 | return y 203 | elif self.loss == 'triplet': 204 | return y, x 205 | else: 206 | raise KeyError('Unsupported loss: {}'.format(self.loss)) -------------------------------------------------------------------------------- /torchreid/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['shufflenet'] 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | import torchvision 10 | import torch.utils.model_zoo as model_zoo 11 | 12 | 13 | model_urls = { 14 | # training epoch = 90, top1 = 61.8 15 | 'imagenet': 'https://mega.nz/#!RDpUlQCY!tr_5xBEkelzDjveIYBBcGcovNCOrgfiJO9kiidz9fZM', 16 | } 17 | 18 | 19 | class ChannelShuffle(nn.Module): 20 | 21 | def __init__(self, num_groups): 22 | super(ChannelShuffle, self).__init__() 23 | self.g = num_groups 24 | 25 | def forward(self, x): 26 | b, c, h, w = x.size() 27 | n = c // self.g 28 | # reshape 29 | x = x.view(b, self.g, n, h, w) 30 | # transpose 31 | x = x.permute(0, 2, 1, 3, 4).contiguous() 32 | # flatten 33 | x = x.view(b, c, h, w) 34 | return x 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | 39 | def __init__(self, in_channels, out_channels, stride, num_groups, group_conv1x1=True): 40 | super(Bottleneck, self).__init__() 41 | assert stride in [1, 2], 'Warning: stride must be either 1 or 2' 42 | self.stride = stride 43 | mid_channels = out_channels // 4 44 | if stride == 2: out_channels -= in_channels 45 | # group conv is not applied to first conv1x1 at stage 2 46 | num_groups_conv1x1 = num_groups if group_conv1x1 else 1 47 | self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, groups=num_groups_conv1x1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(mid_channels) 49 | self.shuffle1 = ChannelShuffle(num_groups) 50 | self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1, groups=mid_channels, bias=False) 51 | self.bn2 = nn.BatchNorm2d(mid_channels) 52 | self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, groups=num_groups, bias=False) 53 | self.bn3 = nn.BatchNorm2d(out_channels) 54 | if stride == 2: self.shortcut = nn.AvgPool2d(3, stride=2, padding=1) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = self.shuffle1(out) 59 | out = self.bn2(self.conv2(out)) 60 | out = self.bn3(self.conv3(out)) 61 | if self.stride == 2: 62 | res = self.shortcut(x) 63 | out = F.relu(torch.cat([res, out], 1)) 64 | else: 65 | out = F.relu(x + out) 66 | return out 67 | 68 | 69 | # configuration of (num_groups: #out_channels) based on Table 1 in the paper 70 | cfg = { 71 | 1: [144, 288, 576], 72 | 2: [200, 400, 800], 73 | 3: [240, 480, 960], 74 | 4: [272, 544, 1088], 75 | 8: [384, 768, 1536], 76 | } 77 | 78 | 79 | class ShuffleNet(nn.Module): 80 | """ShuffleNet. 81 | 82 | Reference: 83 | Zhang et al. 
ShuffleNet: An Extremely Efficient Convolutional Neural 84 | Network for Mobile Devices. CVPR 2018. 85 | 86 | Public keys: 87 | - ``shufflenet``: ShuffleNet (groups=3). 88 | """ 89 | 90 | def __init__(self, num_classes, loss='softmax', num_groups=3, **kwargs): 91 | super(ShuffleNet, self).__init__() 92 | self.loss = loss 93 | 94 | self.conv1 = nn.Sequential( 95 | nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False), 96 | nn.BatchNorm2d(24), 97 | nn.ReLU(), 98 | nn.MaxPool2d(3, stride=2, padding=1), 99 | ) 100 | 101 | self.stage2 = nn.Sequential( 102 | Bottleneck(24, cfg[num_groups][0], 2, num_groups, group_conv1x1=False), 103 | Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), 104 | Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), 105 | Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), 106 | ) 107 | 108 | self.stage3 = nn.Sequential( 109 | Bottleneck(cfg[num_groups][0], cfg[num_groups][1], 2, num_groups), 110 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 111 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 112 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 113 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 114 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 115 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 116 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 117 | ) 118 | 119 | self.stage4 = nn.Sequential( 120 | Bottleneck(cfg[num_groups][1], cfg[num_groups][2], 2, num_groups), 121 | Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), 122 | Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), 123 | Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), 124 | ) 125 | 126 | self.classifier = nn.Linear(cfg[num_groups][2], num_classes) 127 | self.feat_dim = cfg[num_groups][2] 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | x = self.stage2(x) 132 | x = self.stage3(x) 133 | x = self.stage4(x) 134 | x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1) 135 | 136 | if not self.training: 137 | return x 138 | 139 | y = self.classifier(x) 140 | 141 | if self.loss == 'softmax': 142 | return y 143 | elif self.loss == 'triplet': 144 | return y, x 145 | else: 146 | raise KeyError('Unsupported loss: {}'.format(self.loss)) 147 | 148 | 149 | def init_pretrained_weights(model, model_url): 150 | """Initializes model with pretrained weights. 151 | 152 | Layers that don't match with pretrained layers in name or size are kept unchanged. 
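Args:
    model (nn.Module): the network to be initialized.
    model_url (str): URL to a state dict loadable via ``torch.utils.model_zoo.load_url``.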
153 | """ 154 | pretrain_dict = model_zoo.load_url(model_url) 155 | model_dict = model.state_dict() 156 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 157 | model_dict.update(pretrain_dict) 158 | model.load_state_dict(model_dict) 159 | 160 | 161 | def shufflenet(num_classes, loss='softmax', pretrained=True, **kwargs): 162 | model = ShuffleNet(num_classes, loss, **kwargs) 163 | if pretrained: 164 | #init_pretrained_weights(model, model_urls['imagenet']) 165 | import warnings 166 | warnings.warn('The imagenet pretrained weights need to be manually downloaded from {}'.format(model_urls['imagenet'])) 167 | return model -------------------------------------------------------------------------------- /torchreid/models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code source: https://github.com/pytorch/vision 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | 7 | __all__ = ['shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'] 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | import torchvision 13 | import torch.utils.model_zoo as model_zoo 14 | 15 | 16 | model_urls = { 17 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 18 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 19 | 'shufflenetv2_x1.5': None, 20 | 'shufflenetv2_x2.0': None, 21 | } 22 | 23 | 24 | def channel_shuffle(x, groups): 25 | batchsize, num_channels, height, width = x.data.size() 26 | channels_per_group = num_channels // groups 27 | 28 | # reshape 29 | x = x.view(batchsize, groups, 30 | channels_per_group, height, width) 31 | 32 | x = torch.transpose(x, 1, 2).contiguous() 33 | 34 | # flatten 35 | x = x.view(batchsize, -1, height, width) 36 | 37 | return x 38 | 39 | 40 | class InvertedResidual(nn.Module): 41 | 42 | def __init__(self, inp, oup, stride): 43 | super(InvertedResidual, self).__init__() 44 | 45 | if not (1 <= stride <= 3): 46 | raise ValueError('illegal stride value') 47 | self.stride = stride 48 | 49 | branch_features = oup // 2 50 | assert (self.stride != 1) or (inp == branch_features << 1) 51 | 52 | if self.stride > 1: 53 | self.branch1 = nn.Sequential( 54 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 55 | nn.BatchNorm2d(inp), 56 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 57 | nn.BatchNorm2d(branch_features), 58 | nn.ReLU(inplace=True), 59 | ) 60 | 61 | self.branch2 = nn.Sequential( 62 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 63 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 64 | nn.BatchNorm2d(branch_features), 65 | nn.ReLU(inplace=True), 66 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 67 | nn.BatchNorm2d(branch_features), 68 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 69 | nn.BatchNorm2d(branch_features), 70 | nn.ReLU(inplace=True), 71 | ) 72 | 73 | @staticmethod 74 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 75 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 76 | 77 | def forward(self, x): 78 | if self.stride == 1: 79 | x1, x2 = x.chunk(2, dim=1) 80 | out = torch.cat((x1, self.branch2(x2)), dim=1) 81 | 
else: 82 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 83 | 84 | out = channel_shuffle(out, 2) 85 | 86 | return out 87 | 88 | 89 | class ShuffleNetV2(nn.Module): 90 | """ShuffleNetV2. 91 | 92 | Reference: 93 | Ma et al. ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design. ECCV 2018. 94 | 95 | Public keys: 96 | - ``shufflenet_v2_x0_5``: ShuffleNetV2 x0.5. 97 | - ``shufflenet_v2_x1_0``: ShuffleNetV2 x1.0. 98 | - ``shufflenet_v2_x1_5``: ShuffleNetV2 x1.5. 99 | - ``shufflenet_v2_x2_0``: ShuffleNetV2 x2.0. 100 | """ 101 | 102 | def __init__(self, num_classes, loss, stages_repeats, stages_out_channels, **kwargs): 103 | super(ShuffleNetV2, self).__init__() 104 | self.loss = loss 105 | 106 | if len(stages_repeats) != 3: 107 | raise ValueError('expected stages_repeats as list of 3 positive ints') 108 | if len(stages_out_channels) != 5: 109 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 110 | self._stage_out_channels = stages_out_channels 111 | 112 | input_channels = 3 113 | output_channels = self._stage_out_channels[0] 114 | self.conv1 = nn.Sequential( 115 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 116 | nn.BatchNorm2d(output_channels), 117 | nn.ReLU(inplace=True), 118 | ) 119 | input_channels = output_channels 120 | 121 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 122 | 123 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 124 | for name, repeats, output_channels in zip( 125 | stage_names, stages_repeats, self._stage_out_channels[1:]): 126 | seq = [InvertedResidual(input_channels, output_channels, 2)] 127 | for i in range(repeats - 1): 128 | seq.append(InvertedResidual(output_channels, output_channels, 1)) 129 | setattr(self, name, nn.Sequential(*seq)) 130 | input_channels = output_channels 131 | 132 | output_channels = self._stage_out_channels[-1] 133 | self.conv5 = nn.Sequential( 134 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 135 | nn.BatchNorm2d(output_channels), 136 | nn.ReLU(inplace=True), 137 | ) 138 | self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 139 | 140 | self.classifier = nn.Linear(output_channels, num_classes) 141 | 142 | def featuremaps(self, x): 143 | x = self.conv1(x) 144 | x = self.maxpool(x) 145 | x = self.stage2(x) 146 | x = self.stage3(x) 147 | x = self.stage4(x) 148 | x = self.conv5(x) 149 | return x 150 | 151 | def forward(self, x): 152 | f = self.featuremaps(x) 153 | v = self.global_avgpool(f) 154 | v = v.view(v.size(0), -1) 155 | 156 | if not self.training: 157 | return v 158 | 159 | y = self.classifier(v) 160 | 161 | if self.loss == 'softmax': 162 | return y 163 | elif self.loss == 'triplet': 164 | return y, v 165 | else: 166 | raise KeyError("Unsupported loss: {}".format(self.loss)) 167 | 168 | 169 | def init_pretrained_weights(model, model_url): 170 | """Initializes model with pretrained weights. 171 | 172 | Layers that don't match with pretrained layers in name or size are kept unchanged. 
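Args:
    model (nn.Module): the network to be initialized.
    model_url (str or None): URL understood by ``torch.utils.model_zoo.load_url``;
        if None, a warning is raised and the randomly initialized weights are kept.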
173 | """ 174 | if model_url is None: 175 | import warnings 176 | warnings.warn('ImageNet pretrained weights are unavailable for this model') 177 | return 178 | pretrain_dict = model_zoo.load_url(model_url) 179 | model_dict = model.state_dict() 180 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 181 | model_dict.update(pretrain_dict) 182 | model.load_state_dict(model_dict) 183 | 184 | 185 | def shufflenet_v2_x0_5(num_classes, loss='softmax', pretrained=True, **kwargs): 186 | model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 187 | if pretrained: 188 | init_pretrained_weights(model, model_urls['shufflenetv2_x0.5']) 189 | return model 190 | 191 | 192 | def shufflenet_v2_x1_0(num_classes, loss='softmax', pretrained=True, **kwargs): 193 | model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 194 | if pretrained: 195 | init_pretrained_weights(model, model_urls['shufflenetv2_x1.0']) 196 | return model 197 | 198 | 199 | def shufflenet_v2_x1_5(num_classes, loss='softmax', pretrained=True, **kwargs): 200 | model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 201 | if pretrained: 202 | init_pretrained_weights(model, model_urls['shufflenetv2_x1.5']) 203 | return model 204 | 205 | 206 | def shufflenet_v2_x2_0(num_classes, loss='softmax', pretrained=True, **kwargs): 207 | model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 208 | if pretrained: 209 | init_pretrained_weights(model, model_urls['shufflenetv2_x2.0']) 210 | return model 211 | -------------------------------------------------------------------------------- /torchreid/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .optimizer import build_optimizer 4 | from .lr_scheduler import build_lr_scheduler -------------------------------------------------------------------------------- /torchreid/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import torch 5 | 6 | 7 | AVAI_SCH = ['single_step', 'multi_step'] 8 | 9 | 10 | def build_lr_scheduler(optimizer, lr_scheduler, stepsize, gamma=0.1): 11 | """A function wrapper for building a learning rate scheduler. 12 | 13 | Args: 14 | optimizer (Optimizer): an Optimizer. 15 | lr_scheduler (str): learning rate scheduler method. Currently supports 16 | "single_step" and "multi_step". 17 | stepsize (int or list): step size to decay learning rate. When ``lr_scheduler`` is 18 | "single_step", ``stepsize`` should be an integer. When ``lr_scheduler`` is 19 | "multi_step", ``stepsize`` is a list. 20 | gamma (float, optional): decay rate. Default is 0.1. 21 | 22 | Examples:: 23 | >>> # Decay learning rate by every 20 epochs. 24 | >>> scheduler = torchreid.optim.build_lr_scheduler( 25 | >>> optimizer, lr_scheduler='single_step', stepsize=20 26 | >>> ) 27 | >>> # Decay learning rate at 30, 50 and 55 epochs. 28 | >>> scheduler = torchreid.optim.build_lr_scheduler( 29 | >>> optimizer, lr_scheduler='multi_step', stepsize=[30, 50, 55] 30 | >>> ) 31 | """ 32 | if lr_scheduler not in AVAI_SCH: 33 | raise ValueError('Unsupported scheduler: {}. 
Must be one of {}'.format(lr_scheduler, AVAI_SCH)) 34 | 35 | if lr_scheduler == 'single_step': 36 | if isinstance(stepsize, list): 37 | stepsize = stepsize[-1] 38 | 39 | if not isinstance(stepsize, int): 40 | raise TypeError( 41 | 'For single_step lr_scheduler, stepsize must ' 42 | 'be an integer, but got {}'.format(type(stepsize)) 43 | ) 44 | 45 | scheduler = torch.optim.lr_scheduler.StepLR( 46 | optimizer, step_size=stepsize, gamma=gamma 47 | ) 48 | 49 | elif lr_scheduler == 'multi_step': 50 | if not isinstance(stepsize, list): 51 | raise TypeError( 52 | 'For multi_step lr_scheduler, stepsize must ' 53 | 'be a list, but got {}'.format(type(stepsize)) 54 | ) 55 | 56 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 57 | optimizer, milestones=stepsize, gamma=gamma 58 | ) 59 | 60 | return scheduler -------------------------------------------------------------------------------- /torchreid/optim/optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import warnings 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | AVAI_OPTIMS = ['adam', 'amsgrad', 'sgd', 'rmsprop'] 11 | 12 | 13 | def build_optimizer( 14 | model, 15 | optim='adam', 16 | lr=0.0003, 17 | weight_decay=5e-04, 18 | momentum=0.9, 19 | sgd_dampening=0, 20 | sgd_nesterov=False, 21 | rmsprop_alpha=0.99, 22 | adam_beta1=0.9, 23 | adam_beta2=0.99, 24 | staged_lr=False, 25 | new_layers='', 26 | base_lr_mult=0.1 27 | ): 28 | """A function wrapper for building an optimizer. 29 | 30 | Args: 31 | model (nn.Module): model. 32 | optim (str, optional): optimizer. Default is "adam". 33 | lr (float, optional): learning rate. Default is 0.0003. 34 | weight_decay (float, optional): weight decay (L2 penalty). Default is 5e-04. 35 | momentum (float, optional): momentum factor in sgd. Default is 0.9, 36 | sgd_dampening (float, optional): dampening for momentum. Default is 0. 37 | sgd_nesterov (bool, optional): enables Nesterov momentum. Default is False. 38 | rmsprop_alpha (float, optional): smoothing constant for rmsprop. Default is 0.99. 39 | adam_beta1 (float, optional): beta-1 value in adam. Default is 0.9. 40 | adam_beta2 (float, optional): beta-2 value in adam. Default is 0.99, 41 | staged_lr (bool, optional): uses different learning rates for base and new layers. Base 42 | layers are pretrained layers while new layers are randomly initialized, e.g. the 43 | identity classification layer. Enabling ``staged_lr`` can allow the base layers to 44 | be trained with a smaller learning rate determined by ``base_lr_mult``, while the new 45 | layers will take the ``lr``. Default is False. 46 | new_layers (str or list): attribute names in ``model``. Default is empty. 47 | base_lr_mult (float, optional): learning rate multiplier for base layers. Default is 0.1. 48 | 49 | Examples:: 50 | >>> # A normal optimizer can be built by 51 | >>> optimizer = torchreid.optim.build_optimizer(model, optim='sgd', lr=0.01) 52 | >>> # If you want to use a smaller learning rate for pretrained layers 53 | >>> # and the attribute name for the randomly initialized layer is 'classifier', 54 | >>> # you can do 55 | >>> optimizer = torchreid.optim.build_optimizer( 56 | >>> model, optim='sgd', lr=0.01, staged_lr=True, 57 | >>> new_layers='classifier', base_lr_mult=0.1 58 | >>> ) 59 | >>> # Now the `classifier` has learning rate 0.01 but the base layers 60 | >>> # have learning rate 0.01 * 0.1. 
61 | >>> # new_layers can also take multiple attribute names. Say the new layers 62 | >>> # are 'fc' and 'classifier', you can do 63 | >>> optimizer = torchreid.optim.build_optimizer( 64 | >>> model, optim='sgd', lr=0.01, staged_lr=True, 65 | >>> new_layers=['fc', 'classifier'], base_lr_mult=0.1 66 | >>> ) 67 | """ 68 | if optim not in AVAI_OPTIMS: 69 | raise ValueError('Unsupported optim: {}. Must be one of {}'.format(optim, AVAI_OPTIMS)) 70 | 71 | if not isinstance(model, nn.Module): 72 | raise TypeError('model given to build_optimizer must be an instance of nn.Module') 73 | 74 | if staged_lr: 75 | if not new_layers: 76 | warnings.warn('new_layers is empty, therefore, staged_lr is useless') 77 | if isinstance(new_layers, str): 78 | new_layers = [new_layers] 79 | 80 | if isinstance(model, nn.DataParallel): 81 | model = model.module 82 | 83 | base_params = [] 84 | base_layers = [] 85 | new_params = [] 86 | 87 | for name, module in model.named_children(): 88 | if name in new_layers: 89 | new_params += [p for p in module.parameters()] 90 | else: 91 | base_params += [p for p in module.parameters()] 92 | base_layers.append(name) 93 | 94 | param_groups = [ 95 | {'params': base_params, 'lr': lr * base_lr_mult}, 96 | {'params': new_params}, 97 | ] 98 | 99 | else: 100 | param_groups = model.parameters() 101 | 102 | if optim == 'adam': 103 | optimizer = torch.optim.Adam( 104 | param_groups, 105 | lr=lr, 106 | weight_decay=weight_decay, 107 | betas=(adam_beta1, adam_beta2), 108 | ) 109 | 110 | elif optim == 'amsgrad': 111 | optimizer = torch.optim.Adam( 112 | param_groups, 113 | lr=lr, 114 | weight_decay=weight_decay, 115 | betas=(adam_beta1, adam_beta2), 116 | amsgrad=True, 117 | ) 118 | 119 | elif optim == 'sgd': 120 | optimizer = torch.optim.SGD( 121 | param_groups, 122 | lr=lr, 123 | momentum=momentum, 124 | weight_decay=weight_decay, 125 | dampening=sgd_dampening, 126 | nesterov=sgd_nesterov, 127 | ) 128 | 129 | elif optim == 'rmsprop': 130 | optimizer = torch.optim.RMSprop( 131 | param_groups, 132 | lr=lr, 133 | momentum=momentum, 134 | weight_decay=weight_decay, 135 | alpha=rmsprop_alpha, 136 | ) 137 | 138 | return optimizer -------------------------------------------------------------------------------- /torchreid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .avgmeter import * 4 | from .loggers import * 5 | from .tools import * 6 | from .reidtools import * 7 | from .torchtools import * 8 | from .rerank import re_ranking 9 | from .model_complexity import compute_model_complexity 10 | -------------------------------------------------------------------------------- /torchreid/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['AverageMeter'] 5 | 6 | 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value.
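``update(val, n)`` interprets ``val`` as an average over ``n`` samples, so ``avg`` is the sample-weighted running mean ``sum / count``.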
9 | 10 | Examples:: 11 | >>> # Initialize a meter to record loss 12 | >>> losses = AverageMeter() 13 | >>> # Update meter after every minibatch update 14 | >>> losses.update(loss_value, batch_size) 15 | """ 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /torchreid/utils/loggers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | __all__ = ['Logger', 'RankLogger'] 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | 9 | from .tools import mkdir_if_missing 10 | 11 | 12 | class Logger(object): 13 | """Writes console output to external text file. 14 | 15 | Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py>`_ 16 | 17 | Args: 18 | fpath (str): path to the log file. 19 | 20 | Examples:: 21 | >>> import sys 22 | >>> import os 23 | >>> import os.path as osp 24 | >>> from torchreid.utils import Logger 25 | >>> save_dir = 'log/resnet50-softmax-market1501' 26 | >>> log_name = 'train.log' 27 | >>> sys.stdout = Logger(osp.join(save_dir, log_name)) 28 | """ 29 | def __init__(self, fpath=None): 30 | self.console = sys.stdout 31 | self.file = None 32 | if fpath is not None: 33 | mkdir_if_missing(osp.dirname(fpath)) 34 | self.file = open(fpath, 'w') 35 | 36 | def __del__(self): 37 | self.close() 38 | 39 | def __enter__(self): 40 | return self 41 | 42 | def __exit__(self, *args): 43 | self.close() 44 | 45 | def write(self, msg): 46 | self.console.write(msg) 47 | if self.file is not None: 48 | self.file.write(msg) 49 | 50 | def flush(self): 51 | self.console.flush() 52 | if self.file is not None: 53 | self.file.flush() 54 | os.fsync(self.file.fileno()) 55 | 56 | def close(self): 57 | self.console.close() 58 | if self.file is not None: 59 | self.file.close() 60 | 61 | 62 | class RankLogger(object): 63 | """Records the rank1 matching accuracy obtained for each 64 | test dataset at specified evaluation steps and provides a function 65 | to show the summarized results, which are convenient for analysis. 66 | 67 | Args: 68 | sources (str or list): source dataset name(s). 69 | targets (str or list): target dataset name(s).
70 | 71 | Examples:: 72 | >>> from torchreid.utils import RankLogger 73 | >>> s = 'market1501' 74 | >>> t = 'market1501' 75 | >>> ranklogger = RankLogger(s, t) 76 | >>> ranklogger.write(t, 10, 0.5) 77 | >>> ranklogger.write(t, 20, 0.7) 78 | >>> ranklogger.write(t, 30, 0.9) 79 | >>> ranklogger.show_summary() 80 | >>> # You will see: 81 | >>> # => Show performance summary 82 | >>> # market1501 (source) 83 | >>> # - epoch 10 rank1 50.0% 84 | >>> # - epoch 20 rank1 70.0% 85 | >>> # - epoch 30 rank1 90.0% 86 | >>> # If there are multiple test datasets 87 | >>> t = ['market1501', 'dukemtmcreid'] 88 | >>> ranklogger = RankLogger(s, t) 89 | >>> ranklogger.write(t[0], 10, 0.5) 90 | >>> ranklogger.write(t[0], 20, 0.7) 91 | >>> ranklogger.write(t[0], 30, 0.9) 92 | >>> ranklogger.write(t[1], 10, 0.1) 93 | >>> ranklogger.write(t[1], 20, 0.2) 94 | >>> ranklogger.write(t[1], 30, 0.3) 95 | >>> ranklogger.show_summary() 96 | >>> # You will see: 97 | >>> # => Show performance summary 98 | >>> # market1501 (source) 99 | >>> # - epoch 10 rank1 50.0% 100 | >>> # - epoch 20 rank1 70.0% 101 | >>> # - epoch 30 rank1 90.0% 102 | >>> # dukemtmcreid (target) 103 | >>> # - epoch 10 rank1 10.0% 104 | >>> # - epoch 20 rank1 20.0% 105 | >>> # - epoch 30 rank1 30.0% 106 | """ 107 | def __init__(self, sources, targets): 108 | self.sources = sources 109 | self.targets = targets 110 | 111 | if isinstance(self.sources, str): 112 | self.sources = [self.sources] 113 | 114 | if isinstance(self.targets, str): 115 | self.targets = [self.targets] 116 | 117 | self.logger = {name: {'epoch': [], 'rank1': []} for name in self.targets} 118 | 119 | def write(self, name, epoch, rank1): 120 | """Writes result. 121 | 122 | Args: 123 | name (str): dataset name. 124 | epoch (int): current epoch. 125 | rank1 (float): rank1 result. 126 | """ 127 | self.logger[name]['epoch'].append(epoch) 128 | self.logger[name]['rank1'].append(rank1) 129 | 130 | def show_summary(self): 131 | """Shows saved results.""" 132 | print('=> Show performance summary') 133 | for name in self.targets: 134 | from_where = 'source' if name in self.sources else 'target' 135 | print('{} ({})'.format(name, from_where)) 136 | for epoch, rank1 in zip(self.logger[name]['epoch'], self.logger[name]['rank1']): 137 | print('- epoch {}\t rank1 {:.1%}'.format(epoch, rank1)) -------------------------------------------------------------------------------- /torchreid/utils/reidtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | __all__ = ['visualize_ranked_results'] 5 | 6 | import numpy as np 7 | import os 8 | import os.path as osp 9 | import shutil 10 | # from os import listdir 11 | from PIL import Image 12 | from PIL import ImageDraw 13 | 14 | from .tools import mkdir_if_missing 15 | 16 | 17 | def visualize_ranked_results(distmat, dataset, save_dir='', topk=20): 18 | """Visualizes ranked results. 19 | 20 | Stitches each query image and its top-k gallery matches into a single row image, drawing each gallery pid in green (correct match) or red (mismatch). Only image-reid (a single image path per entry) is supported by this modified version. 21 | 22 | Args: 23 | distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 24 | dataset (tuple): a 2-tuple containing (query, gallery), each of which contains 25 | tuples of (img_path, pid, camid). 26 | save_dir (str): directory to save output images. 27 | topk (int, optional): denoting top-k images in the rank list to be visualized.
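Example (an illustrative sketch; assumes ``distmat``, ``query`` and ``gallery``
have already been computed, e.g. during evaluation)::
    >>> # distmat: numpy array of shape (num_query, num_gallery)
    >>> # query/gallery: lists of (img_path, pid, camid) tuples
    >>> visualize_ranked_results(
    >>>     distmat, (query, gallery),
    >>>     save_dir='log/ranked_results', topk=10
    >>> )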
28 | """ 29 | num_q, num_g = distmat.shape 30 | 31 | print('Visualizing top-{} ranks'.format(topk)) 32 | print('# query: {}\n# gallery {}'.format(num_q, num_g)) 33 | print('Saving images to "{}"'.format(save_dir)) 34 | 35 | query, gallery = dataset 36 | assert num_q == len(query) 37 | assert num_g == len(gallery) 38 | 39 | indices = np.argsort(distmat, axis=1) 40 | mkdir_if_missing(save_dir) 41 | 42 | def _cp_img_to(src, dst, rank, prefix): 43 | """ 44 | Args: 45 | src: image path or tuple (for vidreid) 46 | dst: target directory 47 | rank: int, denoting ranked position, starting from 1 48 | prefix: string 49 | """ 50 | if isinstance(src, tuple) or isinstance(src, list): 51 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) 52 | mkdir_if_missing(dst) 53 | for img_path in src: 54 | shutil.copy(img_path, dst) 55 | else: 56 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src)) 57 | shutil.copy(src, dst) 58 | HEIGHT = 256 59 | WIDTH = 128 60 | for q_idx in range(num_q): 61 | ims = [] 62 | qimg_path, qpid, qcamid = query[q_idx] 63 | 64 | save_img_path = os.path.join(save_dir, qimg_path.split('/')[-1]) 65 | q_im = Image.open(qimg_path).resize((WIDTH, HEIGHT), Image.BILINEAR) 66 | ims.append(q_im) 67 | 68 | # if isinstance(qimg_path, tuple) or isinstance(qimg_path, list): 69 | # qdir = osp.join(save_dir, osp.basename(qimg_path[0])) 70 | # else: 71 | # qdir = osp.join(save_dir, osp.basename(qimg_path)) 72 | # mkdir_if_missing(qdir) 73 | # _cp_img_to(qimg_path, qdir, rank=0, prefix='query') 74 | 75 | rank_idx = 1 76 | for g_idx in indices[q_idx,:]: 77 | gimg_path, gpid, gcamid = gallery[g_idx] 78 | invalid = (qpid == gpid) & (qcamid == gcamid) 79 | if not invalid: 80 | g_im = Image.open(gimg_path).resize((WIDTH, HEIGHT), Image.BILINEAR) 81 | draw = ImageDraw.Draw(g_im) 82 | if gpid==qpid: 83 | color = (0,255,0) 84 | else: 85 | color = (255,0,0) 86 | draw.text((8, 8), str(gpid), fill=color) 87 | ims.append(g_im) 88 | rank_idx += 1 89 | if rank_idx > topk: 90 | break 91 | img_ = Image.new(ims[0].mode, (WIDTH*len(ims), HEIGHT)) 92 | for i, im in enumerate(ims): 93 | img_.paste(im, box=(i*WIDTH,0)) 94 | img_.save(save_img_path) 95 | 96 | print("Done") -------------------------------------------------------------------------------- /torchreid/utils/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Source: https://github.com/zhunzhong07/person-re-ranking 5 | 6 | Created on Mon Jun 26 14:46:56 2017 7 | @author: luohao 8 | Modified by Houjing Huang, 2017-12-22. 9 | - This version accepts distance matrix instead of raw features. 10 | - The difference of `/` division between python 2 and 3 is handled. 11 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 12 | 13 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
14 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 15 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 16 | 17 | API 18 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 19 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 20 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 21 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 22 | Returns: 23 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 24 | """ 25 | from __future__ import absolute_import 26 | from __future__ import print_function 27 | from __future__ import division 28 | 29 | __all__ = ['re_ranking'] 30 | 31 | import numpy as np 32 | 33 | 34 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 35 | 36 | # The following naming, e.g. gallery_num, is different from outer scope. 37 | # Don't care about it. 38 | 39 | original_dist = np.concatenate( 40 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 41 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 42 | axis=0) 43 | original_dist = np.power(original_dist, 2).astype(np.float32) 44 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0)) 45 | V = np.zeros_like(original_dist).astype(np.float32) 46 | initial_rank = np.argsort(original_dist).astype(np.int32) 47 | 48 | query_num = q_g_dist.shape[0] 49 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 50 | all_num = gallery_num 51 | 52 | for i in range(all_num): 53 | # k-reciprocal neighbors 54 | forward_k_neigh_index = initial_rank[i,:k1+1] 55 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 56 | fi = np.where(backward_k_neigh_index==i)[0] 57 | k_reciprocal_index = forward_k_neigh_index[fi] 58 | k_reciprocal_expansion_index = k_reciprocal_index 59 | for j in range(len(k_reciprocal_index)): 60 | candidate = k_reciprocal_index[j] 61 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 62 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 65 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 66 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 67 | 68 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 69 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 70 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 71 | original_dist = original_dist[:query_num,] 72 | if k2 != 1: 73 | V_qe = np.zeros_like(V,dtype=np.float32) 74 | for i in range(all_num): 75 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 76 | V = V_qe 77 | del V_qe 78 | del initial_rank 79 | invIndex = [] 80 | for i in range(gallery_num): 81 | invIndex.append(np.where(V[:,i] != 0)[0]) 82 | 83 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 84 | 85 | 86 | for i in range(query_num): 87 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 88 | indNonZero = np.where(V[i,:] != 0)[0] 89 | indImages = [] 90 | indImages = [invIndex[ind] for ind in indNonZero] 91 | for j in range(len(indNonZero)): 92 | 
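# Accumulate the fuzzy-set intersection min(V[query], V[gallery]) over the query's expanded k-reciprocal neighbors; each row of V sums to 1, so union = 2 - intersection, which gives the Jaccard distance computed right after this loop.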
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 93 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 94 | 95 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 96 | del original_dist 97 | del V 98 | del jaccard_dist 99 | final_dist = final_dist[:query_num,query_num:] 100 | return final_dist 101 | -------------------------------------------------------------------------------- /torchreid/utils/tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | __all__ = ['mkdir_if_missing', 'check_isfile', 'read_json', 'write_json', 6 | 'set_random_seed', 'download_url', 'read_image', 'collect_env_info'] 7 | 8 | import sys 9 | import os 10 | import os.path as osp 11 | import time 12 | import errno 13 | import json 14 | from collections import OrderedDict 15 | import warnings 16 | import random 17 | import numpy as np 18 | import PIL 19 | from PIL import Image 20 | 21 | import torch 22 | 23 | 24 | def mkdir_if_missing(dirname): 25 | """Creates dirname if it is missing.""" 26 | if not osp.exists(dirname): 27 | try: 28 | os.makedirs(dirname) 29 | except OSError as e: 30 | if e.errno != errno.EEXIST: 31 | raise 32 | 33 | 34 | def check_isfile(fpath): 35 | """Checks if the given path is a file. 36 | 37 | Args: 38 | fpath (str): file path. 39 | 40 | Returns: 41 | bool 42 | """ 43 | isfile = osp.isfile(fpath) 44 | if not isfile: 45 | warnings.warn('No file found at "{}"'.format(fpath)) 46 | return isfile 47 | 48 | 49 | def read_json(fpath): 50 | """Reads json file from a path.""" 51 | with open(fpath, 'r') as f: 52 | obj = json.load(f) 53 | return obj 54 | 55 | 56 | def write_json(obj, fpath): 57 | """Writes to a json file.""" 58 | mkdir_if_missing(osp.dirname(fpath)) 59 | with open(fpath, 'w') as f: 60 | json.dump(obj, f, indent=4, separators=(',', ': ')) 61 | 62 | 63 | def set_random_seed(seed): 64 | random.seed(seed) 65 | np.random.seed(seed) 66 | torch.manual_seed(seed) 67 | torch.cuda.manual_seed_all(seed) 68 | 69 | 70 | def download_url(url, dst): 71 | """Downloads file from a url to a destination. 72 | 73 | Args: 74 | url (str): url to download file. 75 | dst (str): destination path. 76 | """ 77 | from six.moves import urllib 78 | print('* url="{}"'.format(url)) 79 | print('* destination="{}"'.format(dst)) 80 | 81 | def _reporthook(count, block_size, total_size): 82 | global start_time 83 | if count == 0: 84 | start_time = time.time() 85 | return 86 | duration = time.time() - start_time 87 | progress_size = int(count * block_size) 88 | speed = int(progress_size / (1024 * duration)) 89 | percent = int(count * block_size * 100 / total_size) 90 | sys.stdout.write('\r...%d%%, %d MB, %d KB/s, %d seconds passed' % 91 | (percent, progress_size / (1024 * 1024), speed, duration)) 92 | sys.stdout.flush() 93 | 94 | urllib.request.urlretrieve(url, dst, _reporthook) 95 | sys.stdout.write('\n') 96 | 97 | 98 | def read_image(path): 99 | """Reads image from path using ``PIL.Image``. 100 | 101 | Args: 102 | path (str): path to an image. 103 | 104 | Returns: 105 | PIL image 106 | """ 107 | got_img = False 108 | if not osp.exists(path): 109 | raise IOError('"{}" does not exist'.format(path)) 110 | while not got_img: 111 | try: 112 | img = Image.open(path).convert('RGB') 113 | got_img = True 114 | except IOError: 115 | print('IOError incurred when reading "{}". Will redo. 
Don\'t worry. Just chill.'.format(path)) 116 | pass 117 | return img 118 | 119 | 120 | def collect_env_info(): 121 | """Returns env info as a string. 122 | 123 | Code source: github.com/facebookresearch/maskrcnn-benchmark 124 | """ 125 | from torch.utils.collect_env import get_pretty_env_info 126 | env_str = get_pretty_env_info() 127 | env_str += '\n Pillow ({})'.format(PIL.__version__) 128 | return env_str 129 | -------------------------------------------------------------------------------- /torchreid/utils/vis_featmat_cluster.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.cluster import DBSCAN, KMeans 3 | import torch 4 | # from PIL import Image 5 | import os 6 | import errno 7 | from matplotlib import pyplot as plt 8 | 9 | def mkdir_if_missing(dir_path): 10 | try: 11 | os.makedirs(dir_path) 12 | except OSError as e: 13 | if e.errno != errno.EEXIST: 14 | raise 15 | 16 | def euclidean_distance(inputs): 17 | n = inputs.size(0) 18 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 19 | dist = dist + dist.t() 20 | dist.addmm_(1, -2, inputs, inputs.t()) 21 | dist = dist.clamp(min=0).sqrt() 22 | return dist 23 | 24 | def vis_featmat_DBSCAN(feat_map, top_per=0.05, save_path='/home/hh/tmp/featmap_DBSCAN'): 25 | H,W = feat_map.size(2), feat_map.size(3) 26 | mkdir_if_missing(save_path) 27 | for i in range(feat_map.size(0)): 28 | feat = feat_map[i].view(feat_map.size(1), -1).transpose(0,1) 29 | dist_mat = euclidean_distance(feat).data.cpu().numpy() 30 | tri_mat = np.triu(dist_mat, 1) # tri_mat.dim=2 31 | tri_mat = tri_mat[np.nonzero(tri_mat)] # tri_mat.dim=1 32 | tri_mat = np.sort(tri_mat, axis=None) 33 | top_num = np.round(top_per*tri_mat.size).astype(int) 34 | eps = tri_mat[:top_num].mean() 35 | cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=8) 36 | labels = cluster.fit_predict(dist_mat) 37 | labels = labels.reshape(H,W) 38 | # labels = labels.reshape(H,W).astype(float) 39 | # labels = (labels/labels.max())*255 40 | # im = Image.fromarray(labels).convert('L') 41 | # im.save(os.path.join(save_path,'Dbscan_{}.jpg'.format(i))) 42 | plt.imshow(labels, cmap=plt.cm.hot_r) 43 | fig = plt.gcf() 44 | fig.savefig(os.path.join(save_path,'Dbscan_{}.jpg'.format(i))) 45 | 46 | def vis_featmat_Kmeans(feat_map, num_cluster=4, save_path='/home/hh/tmp/featmap_kmeans'): 47 | H,W = feat_map.size(2), feat_map.size(3) 48 | mkdir_if_missing(save_path) 49 | for i in range(feat_map.size(0)): 50 | feat = feat_map[i].view(feat_map.size(1), -1).transpose(0,1) 51 | dist_mat = euclidean_distance(feat).data.cpu().numpy() 52 | cluster = KMeans(n_clusters=num_cluster) 53 | labels = cluster.fit_predict(dist_mat) 54 | labels = labels.reshape(H,W) 55 | # labels = labels.reshape(H,W).astype(float) 56 | # labels = (labels/labels.max())*255 57 | # im = Image.fromarray(labels).convert('L') 58 | # im.save(os.path.join(save_path,'kmeans_{}.jpg'.format(i))) 59 | plt.imshow(labels, cmap=plt.cm.hot_r) 60 | fig = plt.gcf() 61 | fig.savefig(os.path.join(save_path,'kmeans_{}_{}.jpg'.format(i,num_cluster))) --------------------------------------------------------------------------------