├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── default.py
│   ├── res50_arcface_triplet.yaml
│   ├── res50_ce_contrastive.yaml
│   ├── res50_ce_cosface.yaml
│   ├── res50_ce_triplet.yaml
│   ├── res50_celabelsmooth_contrastive.yaml
│   ├── res50_celabelsmooth_cosface.yaml
│   ├── res50_celabelsmooth_triplet.yaml
│   ├── res50_circle_paircircle.yaml
│   ├── res50_cosface_cosface.yaml
│   ├── res50_cosface_triplet.yaml
│   └── res50_cosface_triplet_0.yaml
├── data
│   ├── __init__.py
│   ├── dataset_loader.py
│   ├── datasets.py
│   ├── samplers.py
│   └── transforms.py
├── losses
│   ├── __init__.py
│   ├── arcface_loss.py
│   ├── circle_loss.py
│   ├── contrastive_loss.py
│   ├── cosface_loss.py
│   └── triplet_loss.py
├── main.py
├── models
│   ├── Classifier.py
│   ├── ResNet.py
│   └── __init__.py
├── tools
│   ├── eval_metrics.py
│   ├── transforms.py
│   └── utils.py
└── train.sh
/.gitignore: --------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | ### A Simple Codebase for Image-based Person Re-identification
2 |
3 | #### Requirements: Python 3.6, PyTorch 1.6.0, yacs
4 |
5 | #### Supported losses
6 | ##### Classification Losses
7 | - [x] CrossEntropy Loss
8 | - [x] CrossEntropy Loss with Label Smoothing
9 | - [x] CosFace Loss
10 | - [x] ArcFace Loss
11 | - [x] Circle Loss
12 | ##### Pairwise Losses
13 | - [x] Triplet Loss
14 | - [x] Contrastive Loss
15 | - [x] Pairwise CosFace Loss
16 | - [x] Pairwise Circle Loss
17 |
18 | #### Supported models
19 | - [x] ResNet-50
20 | - [ ] ResNet-50-IBN
21 | - [ ] IANet
22 |
23 |
24 | #### Get Started
25 | - Replace `_C.DATA.ROOT` and `_C.OUTPUT` in `configs/default.py` with your own data path and output path, respectively.
26 | - Run `train.sh`
27 |
28 | #### Some Results
29 |
30 | ##### Market-1501
31 |
32 | | classification loss | pairwise loss | backbone | top-1 | mAP |
33 | |:---|:---|:---:|:---:|:---:|
34 | | CrossEntropy | Triplet | ResNet-50 | 94.5 | 86.6 |
35 | | CrossEntropy | Contrastive | ResNet-50 | 94.3 | 86.4 |
36 | | CrossEntropy | CosFace | ResNet-50 | 94.3 | 86.2 |
37 | | CELabelSmooth | Triplet | ResNet-50 | 95.0 | 87.4 |
38 | | CELabelSmooth | Contrastive | ResNet-50 | 94.5 | 87.1 |
39 | | CELabelSmooth | CosFace | ResNet-50 | 94.1 | 86.4 |
40 | | CosFace | Triplet | ResNet-50 | 95.1 | 86.7 |
41 | | CosFace | CosFace | ResNet-50 | 94.5 | 87.1 |
42 | | ArcFace | Triplet | ResNet-50 | 94.2 | 86.3 |
43 | | Circle | Circle | ResNet-50 | 94.7 | 87.3 |
44 |
45 | ##### MSMT17
46 |
47 | | classification loss | pairwise loss | backbone | top-1 | mAP |
48 | |:---|:---|:---:|:---:|:---:|
49 | | CrossEntropy | Triplet | ResNet-50 | 78.9 | 57.0 |
50 | | CrossEntropy | Contrastive | ResNet-50 | 79.3 | 56.7 |
51 | | CrossEntropy | CosFace | ResNet-50 | 78.2 | 55.2 |
52 | | CELabelSmooth | Triplet | ResNet-50 | 79.9 | 58.0 |
53 | | CELabelSmooth | Contrastive | ResNet-50 | 80.3 | 58.7 |
54 | | CELabelSmooth | CosFace | ResNet-50 | 79.2 | 56.6 |
55 | | CosFace | Triplet | ResNet-50 | 78.1 | 54.1 |
56 | | CosFace | CosFace | ResNet-50 | 78.8 | 55.9 |
57 | | ArcFace | Triplet | ResNet-50 | 78.2 | 54.2 |
58 | | Circle | Circle | ResNet-50 | 79.7 | 57.0 |
59 |
60 | #### Citation
61 |
62 | If you use our code in your research or wish to refer to the baseline results, please use the following BibTeX entry.
63 |
64 | @InProceedings{CVPR2019IANet,
65 | author = {Hou, Ruibing and Ma, Bingpeng and Chang, Hong and Gu, Xinqian and Shan, Shiguang and Chen, Xilin},
66 | title = {Interaction-And-Aggregation Network for Person Re-Identification},
67 | booktitle = {CVPR},
68 | year = {2019}
69 | }
70 |
-------------------------------------------------------------------------------- /configs/default.py: --------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | from yacs.config import CfgNode as CN
4 |
5 |
6 | _C = CN()
7 | # -----------------------------------------------------------------------------
8 | # Data settings
9 | # -----------------------------------------------------------------------------
10 | _C.DATA = CN()
11 | # Root path for dataset directory
12 | _C.DATA.ROOT = '/home/guxinqian/data'
13 | # Dataset for evaluation
14 | _C.DATA.DATASET = 'market1501'
15 | # Split index
16 | _C.DATA.SPLIT_ID = 0
17 | # Whether to use labeled images; if false, detected images are used
18 | _C.DATA.CUHK03_LABELED = False
19 | # Whether to use the classic split by Li et al. CVPR'14 (default: False)
20 | _C.DATA.CUHK03_CLASSIC_SPLIT = False
21 | # Workers for dataloader
22 | _C.DATA.NUM_WORKERS = 4
23 | # Height of input image
24 | _C.DATA.HEIGHT = 256
25 | # Width of input image
26 | _C.DATA.WIDTH = 128
27 | # Batch size for training
28 | _C.DATA.TRAIN_BATCH = 64
29 | # Batch size for testing
30 | _C.DATA.TEST_BATCH = 512
31 | # The number of instances per identity for the training sampler
32 | _C.DATA.NUM_INSTANCES = 4
33 | # -----------------------------------------------------------------------------
34 | # Augmentation settings
35 | # -----------------------------------------------------------------------------
36 | _C.AUG = CN()
37 | # Random crop prob
38 | _C.AUG.RC_PROB = 0.5
39 | # Random erase prob
40 | _C.AUG.RE_PROB = 0.5
41 | # -----------------------------------------------------------------------------
42 | # Model settings
43 | # -----------------------------------------------------------------------------
44 | _C.MODEL = CN()
45 | # Model name
46 | _C.MODEL.NAME = 'resnet50'
47 | # The stride for layer4 in resnet
48 | _C.MODEL.RES4_STRIDE = 1
49 | # Feature dimension
50 | _C.MODEL.FEATURE_DIM = 2048
51 | # Model path for resuming
52 | _C.MODEL.RESUME = ''
53 | # -----------------------------------------------------------------------------
54 | # Losses for training
55 | # -----------------------------------------------------------------------------
56 | _C.LOSS = CN()
57 | # Classification loss
58 | _C.LOSS.CLA_LOSS = 'crossentropy'
59 | # Scale
60 | _C.LOSS.CLA_S = 16.
61 | # Margin
62 | _C.LOSS.CLA_M = 0.
63 | # Pairwise loss
64 | _C.LOSS.PAIR_LOSS = 'triplet'
65 | # Scale
66 | _C.LOSS.PAIR_S = 16.
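# Note: the *_S / *_M pairs above and below plug directly into the margin-based losses,
# e.g. CosFaceLoss in losses/cosface_loss.py turns each target-class cosine logit into
# s * (cos_theta - m), so S controls how sharp the softmax is and M the enforced separation.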
67 | # Margin 68 | _C.LOSS.PAIR_M = 0.3 69 | # ----------------------------------------------------------------------------- 70 | # Training settings 71 | # ----------------------------------------------------------------------------- 72 | _C.TRAIN = CN() 73 | _C.TRAIN.START_EPOCH = 0 74 | _C.TRAIN.MAX_EPOCH = 60 75 | # Optimizer 76 | _C.TRAIN.OPTIMIZER = CN() 77 | _C.TRAIN.OPTIMIZER.NAME = 'adam' 78 | # Learning rate 79 | _C.TRAIN.OPTIMIZER.LR = 0.00035 80 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4 81 | # LR scheduler 82 | _C.TRAIN.LR_SCHEDULER = CN() 83 | # Stepsize to decay learning rate 84 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40] 85 | # LR decay rate, used in StepLRScheduler 86 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 87 | # ----------------------------------------------------------------------------- 88 | # Testing settings 89 | # ----------------------------------------------------------------------------- 90 | _C.TEST = CN() 91 | # Similarity for testing 92 | _C.TEST.DISTANCE = 'cosine' 93 | # Perform evaluation after every N epochs (set to -1 to test after training) 94 | _C.TEST.EVAL_STEP = 5 95 | # Start to evaluate after specific epoch 96 | _C.TEST.START_EVAL = 0 97 | # ----------------------------------------------------------------------------- 98 | # Misc 99 | # ----------------------------------------------------------------------------- 100 | # Fixed random seed 101 | _C.SEED = 0 102 | # Perform evaluation only 103 | _C.EVAL_MODE = False 104 | # GPU device ids for CUDA_VISIBLE_DEVICES 105 | _C.GPU = '0' 106 | # Path to output folder, overwritten by command line argument 107 | _C.OUTPUT = '/data/guxinqian/logs/' 108 | # Tag of experiment, overwritten by command line argument 109 | _C.TAG = 'res50-ce-tri' 110 | 111 | 112 | def update_config(config, args): 113 | config.defrost() 114 | 115 | print('=> merge config from {}'.format(args.cfg)) 116 | config.merge_from_file(args.cfg) 117 | 118 | # merge from specific arguments 119 | if args.root: 120 | config.DATA.ROOT = args.root 121 | if args.output: 122 | config.OUTPUT = args.output 123 | 124 | if args.resume: 125 | config.MODEL.RESUME = args.resume 126 | if args.eval: 127 | config.EVAL_MODE = True 128 | 129 | if args.tag: 130 | config.TAG = args.tag 131 | 132 | if args.dataset: 133 | config.DATA.DATASET = args.dataset 134 | if args.gpu: 135 | config.GPU = args.gpu 136 | 137 | # output folder 138 | config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG) 139 | 140 | config.freeze() 141 | 142 | 143 | def get_config(args): 144 | """Get a yacs CfgNode object with default values.""" 145 | config = _C.clone() 146 | update_config(config, args) 147 | 148 | return config 149 | -------------------------------------------------------------------------------- /configs/res50_arcface_triplet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: arcface 5 | CLA_S: 16. 6 | CLA_M: 0.1 7 | PAIR_LOSS: triplet 8 | PAIR_M: 0.3 9 | TAG: res50-arcface-triplet -------------------------------------------------------------------------------- /configs/res50_ce_contrastive.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropy 5 | PAIR_LOSS: contrastive 6 | PAIR_S: 16. 
7 | TAG: res50-crossentropy-contrastive -------------------------------------------------------------------------------- /configs/res50_ce_cosface.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropy 5 | PAIR_LOSS: cosface 6 | PAIR_S: 16. 7 | PAIR_M: 0.1 8 | TAG: res50-ce-cosface -------------------------------------------------------------------------------- /configs/res50_ce_triplet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropy 5 | PAIR_LOSS: triplet 6 | PAIR_M: 0.3 7 | TAG: res50-ce-triplet -------------------------------------------------------------------------------- /configs/res50_celabelsmooth_contrastive.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropylabelsmooth 5 | PAIR_LOSS: contrastive 6 | PAIR_S: 16. 7 | TAG: res50-celabelsmooth-contrastive -------------------------------------------------------------------------------- /configs/res50_celabelsmooth_cosface.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropylabelsmooth 5 | PAIR_LOSS: cosface 6 | PAIR_S: 16. 7 | PAIR_M: 0.1 8 | TAG: res50-celabelsmooth-cosface -------------------------------------------------------------------------------- /configs/res50_celabelsmooth_triplet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropylabelsmooth 5 | PAIR_LOSS: triplet 6 | PAIR_M: 0.3 7 | TAG: res50-celabelsmooth-triplet -------------------------------------------------------------------------------- /configs/res50_circle_paircircle.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: circle 5 | CLA_S: 96. 6 | CLA_M: 0.3 7 | PAIR_LOSS: circle 8 | PAIR_S: 48. 9 | PAIR_M: 0.35 10 | TAG: res50-circle-paircircle -------------------------------------------------------------------------------- /configs/res50_cosface_cosface.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: cosface 5 | CLA_S: 16. 6 | CLA_M: 0.1 7 | PAIR_LOSS: cosface 8 | PAIR_S: 16. 9 | PAIR_M: 0.1 10 | TAG: res50-cosface-cosface -------------------------------------------------------------------------------- /configs/res50_cosface_triplet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: cosface 5 | CLA_S: 16. 6 | CLA_M: 0.1 7 | PAIR_LOSS: triplet 8 | PAIR_M: 0.3 9 | TAG: res50-cosface-triplet -------------------------------------------------------------------------------- /configs/res50_cosface_triplet_0.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: cosface 5 | CLA_S: 16. 6 | CLA_M: 0. 
7 | PAIR_LOSS: triplet
8 | PAIR_M: 0.3
9 | TAG: res50-cosface-triplet-0
-------------------------------------------------------------------------------- /data/__init__.py: --------------------------------------------------------------------------------
1 | import data.transforms as T
2 | from torch.utils.data import DataLoader
3 | from data.datasets import Market1501, CUHK03, DukeMTMCreID, MSMT17
4 | from data.dataset_loader import ImageDataset
5 | from data.samplers import RandomIdentitySampler
6 |
7 |
8 | __factory = {
9 | 'market1501': Market1501,
10 | 'cuhk03': CUHK03,
11 | 'dukemtmcreid': DukeMTMCreID,
12 | 'msmt17': MSMT17,
13 | }
14 |
15 |
16 | def get_names():
17 | return list(__factory.keys())
18 |
19 |
20 | def build_dataset(config):
21 | if config.DATA.DATASET not in __factory.keys():
22 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(config.DATA.DATASET, list(__factory.keys())))
23 |
24 | print("Initializing dataset {}".format(config.DATA.DATASET))
25 | dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT, split_id=config.DATA.SPLIT_ID,
26 | cuhk03_labeled=config.DATA.CUHK03_LABELED,
27 | cuhk03_classic_split=config.DATA.CUHK03_CLASSIC_SPLIT)
28 |
29 | return dataset
30 |
31 |
32 | def build_transforms(config):
33 | transform_train = T.Compose([
34 | T.RandomCroping(config.DATA.HEIGHT, config.DATA.WIDTH, p=config.AUG.RC_PROB),
35 | T.RandomHorizontalFlip(),
36 | T.ToTensor(),
37 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
38 | T.RandomErasing(probability=config.AUG.RE_PROB)
39 | ])
40 |
41 | transform_test = T.Compose([
42 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),
43 | T.ToTensor(),
44 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
45 | ])
46 |
47 | return transform_train, transform_test
48 |
49 |
50 | def build_dataloader(config):
51 | dataset = build_dataset(config)
52 | transform_train, transform_test = build_transforms(config)
53 |
54 | trainloader = DataLoader(ImageDataset(dataset.train, transform=transform_train),
55 | sampler=RandomIdentitySampler(dataset.train, num_instances=config.DATA.NUM_INSTANCES),
56 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS,
57 | pin_memory=True, drop_last=True)
58 | queryloader = DataLoader(ImageDataset(dataset.query, transform=transform_test),
59 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
60 | pin_memory=True, drop_last=False, shuffle=False)
61 |
62 | galleryloader = DataLoader(ImageDataset(dataset.gallery, transform=transform_test),
63 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
64 | pin_memory=True, drop_last=False, shuffle=False)
65 |
66 | return trainloader, queryloader, galleryloader, dataset.num_train_pids
67 |
-------------------------------------------------------------------------------- /data/dataset_loader.py: --------------------------------------------------------------------------------
1 | import os.path as osp
2 | from PIL import Image
3 | from torch.utils.data import Dataset
4 |
5 |
6 | def read_image(img_path):
7 | """Keep reading the image until it succeeds.
8 | This can avoid IOError incurred by a heavy IO process."""
9 | got_img = False
10 | if not osp.exists(img_path):
11 | raise IOError("{} does not exist".format(img_path))
12 | while not got_img:
13 | try:
14 | img = Image.open(img_path).convert('RGB')
15 | got_img = True
16 | except IOError:
17 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
18 | pass
19 | return img
20 |
21 |
22 | class ImageDataset(Dataset):
23 | """Image Person ReID Dataset"""
24 | def __init__(self, dataset, transform=None):
25 | self.dataset = dataset
26 | self.transform = transform
27 |
28 | def __len__(self):
29 | return len(self.dataset)
30 |
31 | def __getitem__(self, index):
32 | img_path, pid, camid = self.dataset[index]
33 | img = read_image(img_path)
34 | if self.transform is not None:
35 | img = self.transform(img)
36 | return img, pid, camid
-------------------------------------------------------------------------------- /data/datasets.py: --------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import glob
4 | import h5py
5 | import numpy as np
6 | import os.path as osp
7 | from scipy.io import loadmat
8 | from tools.utils import mkdir_if_missing, write_json, read_json
9 | from imageio import imsave  # assumed provider for the imsave() call in CUHK03._preprocess below
10 |
11 | class Market1501(object):
12 | """
13 | Market1501
14 |
15 | Reference:
16 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
17 |
18 | URL: http://www.liangzheng.org/Project/project_reid.html
19 |
20 | Dataset statistics:
21 | # identities: 1501 (+1 for background)
22 | # images: 12936 (train) + 3368 (query) + 15913 (gallery)
23 | """
24 | dataset_dir = 'Market'
25 |
26 | def __init__(self, root='data', **kwargs):
27 | self.dataset_dir = osp.join(root, self.dataset_dir)
28 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
29 | self.query_dir = osp.join(self.dataset_dir, 'query')
30 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
31 |
32 | self._check_before_run()
33 |
34 | train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
35 | query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
36 | gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
37 | num_total_pids = num_train_pids + num_query_pids
38 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
39 |
40 | print("=> Market1501 loaded")
41 | print("Dataset statistics:")
42 | print(" ------------------------------")
43 | print(" subset | # ids | # images")
44 | print(" ------------------------------")
45 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
46 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
47 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
48 | print(" ------------------------------")
49 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
50 | print(" ------------------------------")
51 |
52 | self.train = train
53 | self.query = query
54 | self.gallery = gallery
55 |
56 | self.num_train_pids = num_train_pids
57 | self.num_query_pids = num_query_pids
58 | self.num_gallery_pids = num_gallery_pids
59 |
60 | def _check_before_run(self):
61 | """Check if all files are available before going deeper"""
62 | if not osp.exists(self.dataset_dir):
63 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
64 | if not osp.exists(self.train_dir):
65 | raise RuntimeError("'{}' is not available".format(self.train_dir))
66 | if not osp.exists(self.query_dir):
67 | raise RuntimeError("'{}' is not available".format(self.query_dir))
68 | if not osp.exists(self.gallery_dir):
69 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
70 |
71 | def _process_dir(self, dir_path, relabel=False, label_start=0):
72 |
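# Market-1501 file names look like 0001_c1s1_001051_00.jpg; the regex below captures the person id and camera id.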
img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 73 | pattern = re.compile(r'([-\d]+)_c(\d)') 74 | 75 | pid_container = set() 76 | for img_path in img_paths: 77 | pid, _ = map(int, pattern.search(img_path).groups()) 78 | if pid == -1: continue # junk images are just ignored 79 | pid_container.add(pid) 80 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 81 | 82 | dataset = [] 83 | for img_path in img_paths: 84 | pid, camid = map(int, pattern.search(img_path).groups()) 85 | if pid == -1: continue # junk images are just ignored 86 | if label_start == 0: 87 | assert 0 <= pid <= 1501 # pid == 0 means background 88 | assert 1 <= camid <= 6 89 | camid -= 1 # index starts from 0 90 | if relabel: pid = pid2label[pid] + label_start 91 | dataset.append((img_path, pid, camid)) 92 | 93 | num_pids = len(pid_container) 94 | num_imgs = len(dataset) 95 | return dataset, num_pids, num_imgs 96 | 97 | 98 | class CUHK03(object): 99 | """ 100 | CUHK03 101 | 102 | Reference: 103 | Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014. 104 | 105 | URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#! 106 | 107 | Dataset statistics: 108 | # identities: 1360 109 | # images: 13164 110 | # cameras: 6 111 | # splits: 20 (classic) 112 | 113 | Args: 114 | split_id (int): split index (default: 0) 115 | cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False) 116 | """ 117 | dataset_dir = 'cuhk03' 118 | 119 | def __init__(self, root='data', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs): 120 | self.dataset_dir = osp.join(root, self.dataset_dir) 121 | self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release') 122 | self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat') 123 | 124 | self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected') 125 | self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled') 126 | 127 | self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json') 128 | self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json') 129 | 130 | self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json') 131 | self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json') 132 | 133 | self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat') 134 | self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat') 135 | 136 | self._check_before_run() 137 | self._preprocess() 138 | 139 | if cuhk03_labeled: 140 | image_type = 'labeled' 141 | split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path 142 | else: 143 | image_type = 'detected' 144 | split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path 145 | 146 | splits = read_json(split_path) 147 | assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id, len(splits)) 148 | split = splits[split_id] 149 | print("Split index = {}".format(split_id)) 150 | 151 | train = split['train'] 152 | query = split['query'] 153 | gallery = split['gallery'] 154 | 155 | num_train_pids = split['num_train_pids'] 156 | num_query_pids = split['num_query_pids'] 157 | num_gallery_pids = split['num_gallery_pids'] 158 | num_total_pids = num_train_pids + num_query_pids 159 | 160 | num_train_imgs = 
split['num_train_imgs']
161 | num_query_imgs = split['num_query_imgs']
162 | num_gallery_imgs = split['num_gallery_imgs']
163 | num_total_imgs = num_train_imgs + num_query_imgs
164 |
165 | print("=> CUHK03 ({}) loaded".format(image_type))
166 | print("Dataset statistics:")
167 | print(" ------------------------------")
168 | print(" subset | # ids | # images")
169 | print(" ------------------------------")
170 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
171 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
172 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
173 | print(" ------------------------------")
174 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
175 | print(" ------------------------------")
176 |
177 | self.train = train
178 | self.query = query
179 | self.gallery = gallery
180 |
181 | self.num_train_pids = num_train_pids
182 | self.num_query_pids = num_query_pids
183 | self.num_gallery_pids = num_gallery_pids
184 |
185 | def _check_before_run(self):
186 | """Check if all files are available before going deeper"""
187 | if not osp.exists(self.dataset_dir):
188 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
189 | if not osp.exists(self.data_dir):
190 | raise RuntimeError("'{}' is not available".format(self.data_dir))
191 | if not osp.exists(self.raw_mat_path):
192 | raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
193 | if not osp.exists(self.split_new_det_mat_path):
194 | raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path))
195 | if not osp.exists(self.split_new_lab_mat_path):
196 | raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path))
197 |
198 | def _preprocess(self):
199 | """
200 | This function is a bit complex and ugly; what it does is:
201 | 1. Extract data from cuhk-03.mat and save them as png images.
202 | 2. Create 20 classic splits. (Li et al. CVPR'14)
203 | 3. Create the new split. (Zhong et al.
CVPR'17) 204 | """ 205 | print("Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)") 206 | if osp.exists(self.imgs_labeled_dir) and \ 207 | osp.exists(self.imgs_detected_dir) and \ 208 | osp.exists(self.split_classic_det_json_path) and \ 209 | osp.exists(self.split_classic_lab_json_path) and \ 210 | osp.exists(self.split_new_det_json_path) and \ 211 | osp.exists(self.split_new_lab_json_path): 212 | return 213 | 214 | mkdir_if_missing(self.imgs_detected_dir) 215 | mkdir_if_missing(self.imgs_labeled_dir) 216 | 217 | print("Extract image data from {} and save as png".format(self.raw_mat_path)) 218 | mat = h5py.File(self.raw_mat_path, 'r') 219 | 220 | def _deref(ref): 221 | return mat[ref][:].T 222 | 223 | def _process_images(img_refs, campid, pid, save_dir): 224 | img_paths = [] # Note: some persons only have images for one view 225 | for imgid, img_ref in enumerate(img_refs): 226 | img = _deref(img_ref) 227 | # skip empty cell 228 | if img.size == 0 or img.ndim < 3: continue 229 | # images are saved with the following format, index-1 (ensure uniqueness) 230 | # campid: index of camera pair (1-5) 231 | # pid: index of person in 'campid'-th camera pair 232 | # viewid: index of view, {1, 2} 233 | # imgid: index of image, (1-10) 234 | viewid = 1 if imgid < 5 else 2 235 | img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid+1, pid+1, viewid, imgid+1) 236 | img_path = osp.join(save_dir, img_name) 237 | imsave(img_path, img) 238 | img_paths.append(img_path) 239 | return img_paths 240 | 241 | def _extract_img(name): 242 | print("Processing {} images (extract and save) ...".format(name)) 243 | meta_data = [] 244 | imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir 245 | for campid, camp_ref in enumerate(mat[name][0]): 246 | camp = _deref(camp_ref) 247 | num_pids = camp.shape[0] 248 | for pid in range(num_pids): 249 | img_paths = _process_images(camp[pid,:], campid, pid, imgs_dir) 250 | assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid) 251 | meta_data.append((campid+1, pid+1, img_paths)) 252 | print("done camera pair {} with {} identities".format(campid+1, num_pids)) 253 | return meta_data 254 | 255 | meta_detected = _extract_img('detected') 256 | meta_labeled = _extract_img('labeled') 257 | 258 | def _extract_classic_split(meta_data, test_split): 259 | train, test = [], [] 260 | num_train_pids, num_test_pids = 0, 0 261 | num_train_imgs, num_test_imgs = 0, 0 262 | for i, (campid, pid, img_paths) in enumerate(meta_data): 263 | 264 | if [campid, pid] in test_split: 265 | for img_path in img_paths: 266 | camid = int(osp.basename(img_path).split('_')[2]) 267 | test.append((img_path, num_test_pids, camid)) 268 | num_test_pids += 1 269 | num_test_imgs += len(img_paths) 270 | else: 271 | for img_path in img_paths: 272 | camid = int(osp.basename(img_path).split('_')[2]) 273 | train.append((img_path, num_train_pids, camid)) 274 | num_train_pids += 1 275 | num_train_imgs += len(img_paths) 276 | return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs 277 | 278 | print("Creating classic splits (# = 20) ...") 279 | splits_classic_det, splits_classic_lab = [], [] 280 | for split_ref in mat['testsets'][0]: 281 | test_split = _deref(split_ref).tolist() 282 | 283 | # create split for detected images 284 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \ 285 | _extract_classic_split(meta_detected, test_split) 286 | splits_classic_det.append({ 
287 | 'train': train, 'query': test, 'gallery': test, 288 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs, 289 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs, 290 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs, 291 | }) 292 | 293 | # create split for labeled images 294 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \ 295 | _extract_classic_split(meta_labeled, test_split) 296 | splits_classic_lab.append({ 297 | 'train': train, 'query': test, 'gallery': test, 298 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs, 299 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs, 300 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs, 301 | }) 302 | 303 | write_json(splits_classic_det, self.split_classic_det_json_path) 304 | write_json(splits_classic_lab, self.split_classic_lab_json_path) 305 | 306 | def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel): 307 | tmp_set = [] 308 | unique_pids = set() 309 | for idx in idxs: 310 | img_name = filelist[idx][0] 311 | camid = int(img_name.split('_')[2]) 312 | pid = pids[idx] 313 | if relabel: pid = pid2label[pid] 314 | img_path = osp.join(img_dir, img_name) 315 | tmp_set.append((img_path, int(pid), camid)) 316 | unique_pids.add(pid) 317 | return tmp_set, len(unique_pids), len(idxs) 318 | 319 | def _extract_new_split(split_dict, img_dir): 320 | train_idxs = split_dict['train_idx'].flatten() - 1 # index-0 321 | pids = split_dict['labels'].flatten() 322 | train_pids = set(pids[train_idxs]) 323 | pid2label = {pid: label for label, pid in enumerate(train_pids)} 324 | query_idxs = split_dict['query_idx'].flatten() - 1 325 | gallery_idxs = split_dict['gallery_idx'].flatten() - 1 326 | filelist = split_dict['filelist'].flatten() 327 | train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True) 328 | query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False) 329 | gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False) 330 | return train_info, query_info, gallery_info 331 | 332 | print("Creating new splits for detected images (767/700) ...") 333 | train_info, query_info, gallery_info = _extract_new_split( 334 | loadmat(self.split_new_det_mat_path), 335 | self.imgs_detected_dir, 336 | ) 337 | splits = [{ 338 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0], 339 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2], 340 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2], 341 | 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2], 342 | }] 343 | write_json(splits, self.split_new_det_json_path) 344 | 345 | print("Creating new splits for labeled images (767/700) ...") 346 | train_info, query_info, gallery_info = _extract_new_split( 347 | loadmat(self.split_new_lab_mat_path), 348 | self.imgs_labeled_dir, 349 | ) 350 | splits = [{ 351 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0], 352 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2], 353 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2], 354 | 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2], 355 | }] 356 | write_json(splits, self.split_new_lab_json_path) 357 | 358 | 359 | class DukeMTMCreID(object): 360 | """ 361 | DukeMTMC-reID 362 | 363 | Reference: 364 | 1. Ristani et al. 
Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 365 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 366 | 367 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 368 | 369 | Dataset statistics: 370 | # identities: 1404 (train + query) 371 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 372 | # cameras: 8 373 | """ 374 | dataset_dir = 'dukemtmc-reid' 375 | 376 | def __init__(self, root='data', **kwargs): 377 | self.dataset_dir = osp.join(root, self.dataset_dir) 378 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 379 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 380 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 381 | 382 | self._check_before_run() 383 | 384 | train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True) 385 | query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False) 386 | gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False) 387 | num_total_pids = num_train_pids + num_query_pids 388 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 389 | 390 | print("=> DukeMTMC-reID loaded") 391 | print("Dataset statistics:") 392 | print(" ------------------------------") 393 | print(" subset | # ids | # images") 394 | print(" ------------------------------") 395 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 396 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 397 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 398 | print(" ------------------------------") 399 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 400 | print(" ------------------------------") 401 | 402 | self.train = train 403 | self.query = query 404 | self.gallery = gallery 405 | 406 | self.num_train_pids = num_train_pids 407 | self.num_query_pids = num_query_pids 408 | self.num_gallery_pids = num_gallery_pids 409 | 410 | def _check_before_run(self): 411 | """Check if all files are available before going deeper""" 412 | if not osp.exists(self.dataset_dir): 413 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 414 | if not osp.exists(self.train_dir): 415 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 416 | if not osp.exists(self.query_dir): 417 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 418 | if not osp.exists(self.gallery_dir): 419 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 420 | 421 | def _process_dir(self, dir_path, relabel=False): 422 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 423 | pattern = re.compile(r'([-\d]+)_c(\d)') 424 | 425 | pid_container = set() 426 | for img_path in img_paths: 427 | pid, _ = map(int, pattern.search(img_path).groups()) 428 | pid_container.add(pid) 429 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 430 | 431 | dataset = [] 432 | for img_path in img_paths: 433 | pid, camid = map(int, pattern.search(img_path).groups()) 434 | assert 1 <= camid <= 8 435 | camid -= 1 # index starts from 0 436 | if relabel: pid = pid2label[pid] 437 | dataset.append((img_path, pid, camid)) 438 | 439 | num_pids = len(pid_container) 440 | num_imgs = len(dataset) 441 | return dataset, num_pids, num_imgs 442 | 443 | 444 | class MSMT17(object): 445 | """ 446 
| MSMT17 447 | 448 | Reference: 449 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 450 | 451 | URL: http://www.pkuvmc.com/publications/msmt17.html 452 | 453 | Dataset statistics: 454 | # identities: 4101 455 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 456 | # cameras: 15 457 | """ 458 | dataset_dir = 'MSMT17_V1' 459 | 460 | def __init__(self, root='data', **kwargs): 461 | self.dataset_dir = osp.join(root, self.dataset_dir) 462 | self.train_dir = osp.join(self.dataset_dir, 'train') 463 | self.test_dir = osp.join(self.dataset_dir, 'test') 464 | self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt') 465 | self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt') 466 | self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt') 467 | self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt') 468 | 469 | self._check_before_run() 470 | train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, self.list_train_path) 471 | #val, num_val_pids, num_val_imgs = self._process_dir(self.train_dir, self.list_val_path) 472 | query, num_query_pids, num_query_imgs = self._process_dir(self.test_dir, self.list_query_path) 473 | gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.test_dir, self.list_gallery_path) 474 | 475 | #train += val 476 | #num_train_imgs += num_val_imgs 477 | 478 | num_total_pids = num_train_pids + num_query_pids 479 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 480 | 481 | print("=> MSMT17 loaded") 482 | print("Dataset statistics:") 483 | print(" ------------------------------") 484 | print(" subset | # ids | # images") 485 | print(" ------------------------------") 486 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 487 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 488 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 489 | print(" ------------------------------") 490 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 491 | print(" ------------------------------") 492 | 493 | self.train = train 494 | self.query = query 495 | self.gallery = gallery 496 | 497 | self.num_train_pids = num_train_pids 498 | self.num_query_pids = num_query_pids 499 | self.num_gallery_pids = num_gallery_pids 500 | 501 | def _check_before_run(self): 502 | """Check if all files are available before going deeper""" 503 | if not osp.exists(self.dataset_dir): 504 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 505 | if not osp.exists(self.train_dir): 506 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 507 | if not osp.exists(self.test_dir): 508 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 509 | 510 | def _process_dir(self, dir_path, list_path): 511 | with open(list_path, 'r') as txt: 512 | lines = txt.readlines() 513 | dataset = [] 514 | pid_container = set() 515 | for img_idx, img_info in enumerate(lines): 516 | img_path, pid = img_info.split(' ') 517 | pid = int(pid) # no need to relabel 518 | camid = int(img_path.split('_')[2]) 519 | img_path = osp.join(dir_path, img_path) 520 | dataset.append((img_path, pid, camid)) 521 | pid_container.add(pid) 522 | num_imgs = len(dataset) 523 | num_pids = len(pid_container) 524 | # check if pid starts from 0 and increments with 1 525 | for idx, pid in enumerate(pid_container): 526 | assert idx == pid, "See code comment for explanation" 527 | return 
dataset, num_pids, num_imgs 528 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import numpy as np 4 | from collections import defaultdict 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class RandomIdentitySampler(Sampler): 9 | """ 10 | Randomly sample N identities, then for each identity, 11 | randomly sample K instances, therefore batch size is N*K. 12 | 13 | Args: 14 | - data_source (Dataset): dataset to sample from. 15 | - num_instances (int): number of instances per identity. 16 | """ 17 | def __init__(self, data_source, num_instances=4): 18 | self.data_source = data_source 19 | self.num_instances = num_instances 20 | self.index_dic = defaultdict(list) 21 | for index, (_, pid, _) in enumerate(data_source): 22 | self.index_dic[pid].append(index) 23 | self.pids = list(self.index_dic.keys()) 24 | self.num_identities = len(self.pids) 25 | 26 | # compute number of examples in an epoch 27 | self.length = 0 28 | for pid in self.pids: 29 | idxs = self.index_dic[pid] 30 | num = len(idxs) 31 | if num < self.num_instances: 32 | num = self.num_instances 33 | self.length += num - num % self.num_instances 34 | 35 | def __iter__(self): 36 | list_container = [] 37 | 38 | for pid in self.pids: 39 | idxs = copy.deepcopy(self.index_dic[pid]) 40 | if len(idxs) < self.num_instances: 41 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 42 | random.shuffle(idxs) 43 | batch_idxs = [] 44 | for idx in idxs: 45 | batch_idxs.append(idx) 46 | if len(batch_idxs) == self.num_instances: 47 | list_container.append(batch_idxs) 48 | batch_idxs = [] 49 | 50 | random.shuffle(list_container) 51 | 52 | ret = [] 53 | for batch_idxs in list_container: 54 | ret.extend(batch_idxs) 55 | 56 | return iter(ret) 57 | 58 | def __len__(self): 59 | return self.length -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import * 2 | from PIL import Image 3 | import random 4 | import math 5 | 6 | 7 | class RandomCroping(object): 8 | """ 9 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 10 | 11 | Args: 12 | height (int): target height. 13 | width (int): target width. 14 | p (float): probability of performing this transformation. Default: 0.5. 15 | """ 16 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 17 | self.height = height 18 | self.width = width 19 | self.p = p 20 | self.interpolation = interpolation 21 | 22 | def __call__(self, img): 23 | """ 24 | Args: 25 | img (PIL Image): Image to be cropped. 26 | 27 | Returns: 28 | PIL Image: Cropped image. 
29 | """ 30 | if random.uniform(0, 1) >= self.p: 31 | return img.resize((self.width, self.height), self.interpolation) 32 | 33 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 34 | resized_img = img.resize((new_width, new_height), self.interpolation) 35 | x_maxrange = new_width - self.width 36 | y_maxrange = new_height - self.height 37 | x1 = int(round(random.uniform(0, x_maxrange))) 38 | y1 = int(round(random.uniform(0, y_maxrange))) 39 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 40 | 41 | return croped_img 42 | 43 | 44 | class RandomErasing(object): 45 | """ Randomly selects a rectangle region in an image and erases its pixels. 46 | 'Random Erasing Data Augmentation' by Zhong et al. 47 | See https://arxiv.org/pdf/1708.04896.pdf 48 | Args: 49 | probability: The probability that the Random Erasing operation will be performed. 50 | sl: Minimum proportion of erased area against input image. 51 | sh: Maximum proportion of erased area against input image. 52 | r1: Minimum aspect ratio of erased area. 53 | mean: Erasing value. 54 | """ 55 | 56 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 57 | self.probability = probability 58 | self.mean = mean 59 | self.sl = sl 60 | self.sh = sh 61 | self.r1 = r1 62 | 63 | def __call__(self, img): 64 | 65 | if random.uniform(0, 1) >= self.probability: 66 | return img 67 | 68 | for attempt in range(100): 69 | area = img.size()[1] * img.size()[2] 70 | 71 | target_area = random.uniform(self.sl, self.sh) * area 72 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 73 | 74 | h = int(round(math.sqrt(target_area * aspect_ratio))) 75 | w = int(round(math.sqrt(target_area / aspect_ratio))) 76 | 77 | if w < img.size()[2] and h < img.size()[1]: 78 | x1 = random.randint(0, img.size()[1] - h) 79 | y1 = random.randint(0, img.size()[2] - w) 80 | if img.size()[0] == 3: 81 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 82 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 83 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 84 | else: 85 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 86 | return img 87 | 88 | return img -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from losses.cross_entropy_label_smooth import CrossEntropyLabelSmooth 5 | from losses.triplet_loss import TripletLoss 6 | from losses.contrastive_loss import ContrastiveLoss 7 | from losses.arcface_loss import ArcFaceLoss 8 | from losses.cosface_loss import CosFaceLoss, PairwiseCosFaceLoss 9 | from losses.circle_loss import CircleLoss, PairwiseCircleLoss 10 | 11 | 12 | def build_losses(config): 13 | # Build classification loss 14 | if config.LOSS.CLA_LOSS == 'crossentropy': 15 | criterion_cla = nn.CrossEntropyLoss() 16 | elif config.LOSS.CLA_LOSS == 'crossentropylabelsmooth': 17 | criterion_cla = CrossEntropyLabelSmooth() 18 | elif config.LOSS.CLA_LOSS == 'arcface': 19 | criterion_cla = ArcFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 20 | elif config.LOSS.CLA_LOSS == 'cosface': 21 | criterion_cla = CosFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 22 | elif config.LOSS.CLA_LOSS == 'circle': 23 | criterion_cla = CircleLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 24 | else: 25 | raise KeyError("Invalid classification loss: '{}'".format(config.LOSS.CLA_LOSS)) 26 | 27 | # Build pairwise loss 28 
| if config.LOSS.PAIR_LOSS == 'triplet': 29 | criterion_pair = TripletLoss(margin=config.LOSS.PAIR_M, distance=config.TEST.DISTANCE) 30 | elif config.LOSS.PAIR_LOSS == 'contrastive': 31 | criterion_pair = ContrastiveLoss(scale=config.LOSS.PAIR_S) 32 | elif config.LOSS.PAIR_LOSS == 'cosface': 33 | criterion_pair = PairwiseCosFaceLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M) 34 | elif config.LOSS.PAIR_LOSS == 'circle': 35 | criterion_pair = PairwiseCircleLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M) 36 | else: 37 | raise KeyError("Invalid pairwise loss: '{}'".format(config.LOSS.PAIR_LOSS)) 38 | 39 | return criterion_cla, criterion_pair 40 | 41 | 42 | def DeepSupervision(criterion, xs, y): 43 | """ 44 | Args: 45 | criterion: loss function 46 | xs: tuple of inputs 47 | y: ground truth 48 | """ 49 | loss = 0. 50 | for x in xs: 51 | loss += criterion(x, y) 52 | # loss /= len(xs) 53 | return loss 54 | 55 | -------------------------------------------------------------------------------- /losses/arcface_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class ArcFaceLoss(nn.Module): 8 | def __init__(self, margin=0.1, scale=16, easy_margin=False): 9 | super(ArcFaceLoss, self).__init__() 10 | self.m = margin 11 | self.s = scale 12 | self.easy_margin = easy_margin 13 | 14 | def forward(self, input, target): 15 | 16 | # make a one-hot index 17 | index = input.data * 0.0 #size=(B,Classnum) 18 | index.scatter_(1,target.data.view(-1,1),1) 19 | index = index.bool() 20 | 21 | cos_m = math.cos(self.m) 22 | sin_m = math.sin(self.m) 23 | cos_t = input[index] 24 | sin_t = torch.sqrt(1.0 - cos_t * cos_t) 25 | cos_t_add_m = cos_t * cos_m - sin_t * sin_m 26 | 27 | if self.easy_margin: 28 | cond = F.relu(cos_t) 29 | keep = cos_t 30 | else: 31 | cond_v = cos_t - math.cos(math.pi - self.m) 32 | cond = F.relu(cond_v) 33 | keep = cos_t - math.sin(math.pi - self.m) * self.m 34 | 35 | cos_t_add_m = torch.where(cond.bool(), cos_t_add_m, keep) 36 | 37 | output = input * 1.0 #size=(B,Classnum) 38 | output[index] = cos_t_add_m 39 | output = self.s * output 40 | 41 | return F.cross_entropy(output, target) 42 | -------------------------------------------------------------------------------- /losses/circle_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class CircleLoss(nn.Module): 8 | def __init__(self, scale=96, margin=0.3, **kwargs): 9 | super(CircleLoss, self).__init__() 10 | self.s = scale 11 | self.m = margin 12 | 13 | def forward(self, inputs, targets): 14 | 15 | mask = torch.zeros_like(inputs).cuda() 16 | mask.scatter_(1, targets.view(-1, 1), 1.0) 17 | 18 | pos_scale = self.s * F.relu(1 + self.m - inputs.detach()) 19 | neg_scale = self.s * F.relu(inputs.detach() + self.m) 20 | scale_matrix = pos_scale * mask + neg_scale * (1 - mask) 21 | 22 | scores = (inputs - (1 - self.m) * mask - self.m * (1 - mask)) * scale_matrix 23 | 24 | loss = F.cross_entropy(scores, targets) 25 | 26 | return loss 27 | 28 | 29 | class PairwiseCircleLoss(nn.Module): 30 | def __init__(self, scale=48, margin=0.35, **kwargs): 31 | super(PairwiseCircleLoss, self).__init__() 32 | self.s = scale 33 | self.m = margin 34 | 35 | def forward(self, inputs, targets): 36 | 37 | inputs = F.normalize(inputs, p=2, dim=1) 38 | similarities = 
torch.matmul(inputs, inputs.t()) 39 | 40 | targets = targets.view(-1,1) 41 | mask = torch.eq(targets, targets.T).float().cuda() 42 | mask_self = torch.eye(targets.size(0)).float().cuda() 43 | mask_pos = mask - mask_self 44 | mask_neg = 1 - mask 45 | 46 | pos_scale = self.s * F.relu(1 + self.m - similarities.detach()) 47 | neg_scale = self.s * F.relu(similarities.detach() + self.m) 48 | scale_matrix = pos_scale * mask_pos + neg_scale * mask_neg 49 | 50 | scores = (similarities - self.m) * mask_neg + (1 - self.m - similarities) * mask_pos 51 | scores = scores * scale_matrix 52 | 53 | neg_scores_LSE = torch.logsumexp(scores*mask_neg - 99999999*(1-mask_neg), dim=1) 54 | pos_scores_LSE = torch.logsumexp(scores*mask_pos - 99999999*(1-mask_pos), dim=1) 55 | 56 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean() 57 | 58 | return loss 59 | -------------------------------------------------------------------------------- /losses/contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class ContrastiveLoss(nn.Module): 8 | def __init__(self, scale=16, **kwargs): 9 | super(ContrastiveLoss, self).__init__() 10 | self.scale = scale 11 | 12 | def forward(self, inputs, targets): 13 | inputs = F.normalize(inputs, p=2, dim=1) 14 | similarities = torch.matmul(inputs, inputs.t()) * self.scale 15 | 16 | targets = targets.view(-1,1) 17 | mask = torch.eq(targets, targets.T).float().cuda() 18 | mask_self = torch.eye(targets.size(0)).float().cuda() 19 | mask_pos = mask - mask_self 20 | mask_neg = 1 - mask 21 | 22 | # compute log_prob 23 | exp_logits = torch.exp(similarities) * (1 - mask_self) 24 | # log_prob = similarities - torch.log(exp_logits.sum(1, keepdim=True)) 25 | log_sum_exp_pos_and_all_neg = torch.log((exp_logits * mask_neg).sum(1, keepdim=True) + exp_logits) 26 | log_prob = similarities - log_sum_exp_pos_and_all_neg 27 | 28 | # compute mean of log-likelihood over positive 29 | loss = (mask_pos * log_prob).sum(1) / mask_pos.sum(1) 30 | 31 | loss = - loss.mean() 32 | 33 | return loss -------------------------------------------------------------------------------- /losses/cosface_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class CosFaceLoss(nn.Module): 8 | def __init__(self, scale=16, margin=0.1, **kwargs): 9 | super(CosFaceLoss, self).__init__() 10 | self.s = scale 11 | self.m = margin 12 | 13 | def forward(self, inputs, targets): 14 | 15 | one_hot = torch.zeros_like(inputs) 16 | one_hot.scatter_(1, targets.view(-1, 1), 1.0) 17 | 18 | output = self.s * (inputs - one_hot * self.m) 19 | 20 | return F.cross_entropy(output, targets) 21 | 22 | 23 | class PairwiseCosFaceLoss(nn.Module): 24 | def __init__(self, scale=16, margin=0): 25 | super(PairwiseCosFaceLoss, self).__init__() 26 | self.s = scale 27 | self.m = margin 28 | 29 | def forward(self, inputs, targets): 30 | inputs = F.normalize(inputs, p=2, dim=1) 31 | similarities = torch.matmul(inputs, inputs.t()) 32 | 33 | targets = targets.view(-1,1) 34 | mask = torch.eq(targets, targets.T).float().cuda() 35 | mask_self = torch.eye(targets.size(0)).float().cuda() 36 | mask_pos = mask - mask_self 37 | mask_neg = 1 - mask 38 | 39 | scores = (similarities + self.m) * mask_neg - similarities * mask_pos 40 | scores = scores * self.s 41 | 42 | neg_scores_LSE = 
torch.logsumexp(scores*mask_neg - 99999999*(1-mask_neg), dim=1) 43 | pos_scores_LSE = torch.logsumexp(scores*mask_pos - 99999999*(1-mask_pos), dim=1) # the large negative constant removes masked-out entries from the logsumexp 44 | 45 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean() # = log(1 + sum over (neg, pos) pairs of exp(scale * (s_neg - s_pos + margin))) 46 | 47 | return loss -------------------------------------------------------------------------------- /losses/cross_entropy_label_smooth.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class CrossEntropyLabelSmooth(nn.Module): 8 | """Cross entropy loss with label smoothing regularizer. 9 | 10 | Reference: 11 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 12 | Equation: y = (1 - epsilon) * y + epsilon / K. 13 | 14 | Args: 15 | epsilon (float): smoothing weight. 16 | The number of classes K is inferred from the logits at forward time. 17 | """ 18 | def __init__(self, epsilon=0.1): 19 | super().__init__() 20 | self.epsilon = epsilon 21 | self.logsoftmax = nn.LogSoftmax(dim=1) 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (batch_size) 28 | """ 29 | _, num_classes = inputs.size() 30 | log_probs = self.logsoftmax(inputs) 31 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda() # one-hot encoding; assumes a CUDA device, like the rest of this repo 32 | targets = (1 - self.epsilon) * targets + self.epsilon / num_classes 33 | loss = (- targets * log_probs).mean(0).sum() 34 | 35 | return loss 36 | -------------------------------------------------------------------------------- /losses/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class TripletLoss(nn.Module): 8 | """Triplet loss with hard positive/negative mining. 9 | 10 | Reference: 11 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 12 | 13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 14 | 15 | Args: 16 | margin (float): margin for triplet. distance (str): 'euclidean' or 'cosine'.
17 | """ 18 | def __init__(self, margin=0.3, distance='euclidean'): 19 | super(TripletLoss, self).__init__() 20 | if distance not in ['euclidean', 'cosine']: 21 | raise KeyError("Unsupported distance: {}".format(distance)) 22 | self.distance = distance 23 | self.margin = margin 24 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 25 | 26 | def forward(self, inputs, targets): 27 | """ 28 | Args: 29 | inputs: feature matrix with shape (batch_size, feat_dim) 30 | targets: ground truth labels with shape (num_classes) 31 | """ 32 | n = inputs.size(0) 33 | 34 | # Compute pairwise distance, replace by the official when merged 35 | if self.distance == 'euclidean': 36 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 37 | dist = dist + dist.t() 38 | dist.addmm_(1, -2, inputs, inputs.t()) 39 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 40 | elif self.distance == 'cosine': 41 | inputs = F.normalize(inputs, p=2, dim=1) 42 | dist = - torch.mm(inputs, inputs.t()) 43 | 44 | # For each anchor, find the hardest positive and negative 45 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 46 | dist_ap, dist_an = [], [] 47 | for i in range(n): 48 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 49 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 50 | dist_ap = torch.cat(dist_ap) 51 | dist_an = torch.cat(dist_an) 52 | 53 | # Compute ranking hinge loss 54 | y = torch.ones_like(dist_an) 55 | loss = self.ranking_loss(dist_an, dist_ap, y) 56 | 57 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import datetime 5 | import argparse 6 | import os.path as osp 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torch.optim import lr_scheduler 14 | 15 | from configs.default import get_config 16 | from data import build_dataloader 17 | from models import build_model 18 | from losses import build_losses 19 | from tools.eval_metrics import evaluate 20 | from tools.utils import AverageMeter, Logger, save_checkpoint, set_seed 21 | 22 | 23 | def parse_option(): 24 | parser = argparse.ArgumentParser(description='Train image-based re-id model') 25 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file') 26 | # Datasets 27 | parser.add_argument('--root', type=str, help="your root path to data directory") 28 | parser.add_argument('--dataset', type=str, help="market1501, cuhk03, dukemtmcreid, msmt17") 29 | # Miscs 30 | parser.add_argument('--output', type=str, help="your output path to save model and logs") 31 | parser.add_argument('--resume', type=str, metavar='PATH') 32 | parser.add_argument('--eval', action='store_true', help="evaluation only") 33 | parser.add_argument('--tag', type=str, help='tag for log file') 34 | parser.add_argument('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') 35 | 36 | args, unparsed = parser.parse_known_args() 37 | config = get_config(args) 38 | 39 | return config 40 | 41 | 42 | def main(config): 43 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 44 | 45 | if not config.EVAL_MODE: 46 | sys.stdout = Logger(osp.join(config.OUTPUT, 'log_train.txt')) 47 | else: 48 | sys.stdout = Logger(osp.join(config.OUTPUT, 'log_test.txt')) 49 | 
print("==========\nConfig:{}\n==========".format(config)) 50 | print("Currently using GPU {}".format(config.GPU)) 51 | # Set random seed 52 | set_seed(config.SEED) 53 | 54 | # Build dataloader 55 | trainloader, queryloader, galleryloader, num_classes = build_dataloader(config) 56 | # Build model 57 | model, classifier = build_model(config, num_classes) 58 | # Build classification and pairwise loss 59 | criterion_cla, criterion_pair = build_losses(config) 60 | # Build optimizer 61 | parameters = list(model.parameters()) + list(classifier.parameters()) 62 | if config.TRAIN.OPTIMIZER.NAME == 'adam': 63 | optimizer = optim.Adam(parameters, lr=config.TRAIN.OPTIMIZER.LR, 64 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 65 | elif config.TRAIN.OPTIMIZER.NAME == 'adamw': 66 | optimizer = optim.AdamW(parameters, lr=config.TRAIN.OPTIMIZER.LR, 67 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 68 | elif config.TRAIN.OPTIMIZER.NAME == 'sgd': 69 | optimizer = optim.SGD(parameters, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, 70 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True) 71 | else: 72 | raise KeyError("Unknown optimizer: {}".format(config.TRAIN.OPTIMIZER.NAME)) 73 | # Build lr_scheduler 74 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE, 75 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE) 76 | 77 | start_epoch = config.TRAIN.START_EPOCH 78 | if config.MODEL.RESUME: 79 | print("Loading checkpoint from '{}'".format(config.MODEL.RESUME)) 80 | checkpoint = torch.load(config.MODEL.RESUME) 81 | model.load_state_dict(checkpoint['state_dict']) 82 | start_epoch = checkpoint['epoch'] 83 | 84 | model = nn.DataParallel(model).cuda() 85 | classifier = nn.DataParallel(classifier).cuda() 86 | 87 | if config.EVAL_MODE: 88 | print("Evaluate only") 89 | test(model, queryloader, galleryloader) 90 | return 91 | 92 | start_time = time.time() 93 | train_time = 0 94 | best_rank1 = -np.inf 95 | best_epoch = 0 96 | print("==> Start training") 97 | for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH): 98 | start_train_time = time.time() 99 | train(epoch, model, classifier, criterion_cla, criterion_pair, optimizer, trainloader) 100 | train_time += round(time.time() - start_train_time) 101 | 102 | if (epoch+1) > config.TEST.START_EVAL and config.TEST.EVAL_STEP > 0 and \ 103 | (epoch+1) % config.TEST.EVAL_STEP == 0 or (epoch+1) == config.TRAIN.MAX_EPOCH: 104 | print("==> Test") 105 | rank1 = test(model, queryloader, galleryloader) 106 | is_best = rank1 > best_rank1 107 | if is_best: 108 | best_rank1 = rank1 109 | best_epoch = epoch + 1 110 | 111 | state_dict = model.module.state_dict() 112 | save_checkpoint({ 113 | 'state_dict': state_dict, 114 | 'rank1': rank1, 115 | 'epoch': epoch, 116 | }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(epoch+1) + '.pth.tar')) 117 | scheduler.step() 118 | 119 | print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch)) 120 | 121 | elapsed = round(time.time() - start_time) 122 | elapsed = str(datetime.timedelta(seconds=elapsed)) 123 | train_time = str(datetime.timedelta(seconds=train_time)) 124 | print("Finished. Total elapsed time (h:m:s): {}. 
Training time (h:m:s): {}.".format(elapsed, train_time)) 125 | 126 | 127 | def train(epoch, model, classifier, criterion_cla, criterion_pair, optimizer, trainloader): 128 | batch_cla_loss = AverageMeter() 129 | batch_pair_loss = AverageMeter() 130 | corrects = AverageMeter() 131 | batch_time = AverageMeter() 132 | data_time = AverageMeter() 133 | 134 | model.train() 135 | classifier.train() 136 | 137 | end = time.time() 138 | for batch_idx, (imgs, pids, _) in enumerate(trainloader): 139 | imgs, pids = imgs.cuda(), pids.cuda() 140 | # Measure data loading time 141 | data_time.update(time.time() - end) 142 | # Zero the parameter gradients 143 | optimizer.zero_grad() 144 | # Forward 145 | features = model(imgs) 146 | outputs = classifier(features) 147 | _, preds = torch.max(outputs.data, 1) 148 | # Compute loss 149 | cla_loss = criterion_cla(outputs, pids) 150 | pair_loss = criterion_pair(features, pids) 151 | loss = cla_loss + pair_loss 152 | # Backward + Optimize 153 | loss.backward() 154 | optimizer.step() 155 | # statistics 156 | corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0)) 157 | batch_cla_loss.update(cla_loss.item(), pids.size(0)) 158 | batch_pair_loss.update(pair_loss.item(), pids.size(0)) 159 | # measure elapsed time 160 | batch_time.update(time.time() - end) 161 | end = time.time() 162 | 163 | print('Epoch{0} ' 164 | 'Time:{batch_time.sum:.1f}s ' 165 | 'Data:{data_time.sum:.1f}s ' 166 | 'ClaLoss:{cla_loss.avg:.4f} ' 167 | 'PairLoss:{pair_loss.avg:.4f} ' 168 | 'Acc:{acc.avg:.2%} '.format( 169 | epoch+1, batch_time=batch_time, data_time=data_time, 170 | cla_loss=batch_cla_loss, pair_loss=batch_pair_loss, acc=corrects)) 171 | 172 | 173 | def fliplr(img): 174 | '''flip horizontal''' 175 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W 176 | img_flip = img.index_select(3,inv_idx) 177 | 178 | return img_flip 179 | 180 | 181 | @torch.no_grad() 182 | def extract_feature(model, dataloader): 183 | features, pids, camids = [], [], [] 184 | for batch_idx, (imgs, batch_pids, batch_camids) in enumerate(dataloader): 185 | flip_imgs = fliplr(imgs) 186 | imgs, flip_imgs = imgs.cuda(), flip_imgs.cuda() 187 | batch_features = model(imgs).data.cpu() 188 | batch_features_flip = model(flip_imgs).data.cpu() 189 | batch_features += batch_features_flip 190 | 191 | features.append(batch_features) 192 | pids.append(batch_pids) 193 | camids.append(batch_camids) 194 | features = torch.cat(features, 0) 195 | pids = torch.cat(pids, 0).numpy() 196 | camids = torch.cat(camids, 0).numpy() 197 | 198 | return features, pids, camids 199 | 200 | 201 | def test(model, queryloader, galleryloader): 202 | since = time.time() 203 | model.eval() 204 | # Extract features for query set 205 | qf, q_pids, q_camids = extract_feature(model, queryloader) 206 | print("Extracted features for query set, obtained {} matrix".format(qf.shape)) 207 | # Extract features for gallery set 208 | gf, g_pids, g_camids = extract_feature(model, galleryloader) 209 | print("Extracted features for gallery set, obtained {} matrix".format(gf.shape)) 210 | time_elapsed = time.time() - since 211 | print('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 212 | # Compute distance matrix between query and gallery 213 | m, n = qf.size(0), gf.size(0) 214 | distmat = torch.zeros((m,n)) 215 | if config.TEST.DISTANCE == 'euclidean': 216 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 217 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, 
m).t() 218 | for i in range(m): 219 | distmat[i:i+1].addmm_(qf[i:i+1], gf.t(), beta=1, alpha=-2) # distmat_i = ||qf_i||^2 + ||gf||^2 - 2 * qf_i @ gf.t() 220 | else: 221 | # Cosine similarity 222 | qf = F.normalize(qf, p=2, dim=1) 223 | gf = F.normalize(gf, p=2, dim=1) 224 | for i in range(m): 225 | distmat[i] = - torch.mm(qf[i:i+1], gf.t()) 226 | distmat = distmat.numpy() 227 | 228 | print("Computing CMC and mAP") 229 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) 230 | 231 | print("Results ----------------------------------------") 232 | print('top1:{:.1%} top5:{:.1%} top10:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], mAP)) 233 | print("------------------------------------------------") 234 | 235 | return cmc[0] 236 | 237 | 238 | if __name__ == '__main__': 239 | config = parse_option() 240 | main(config) -------------------------------------------------------------------------------- /models/Classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.nn import functional as F 5 | from torch.nn import Parameter 6 | 7 | 8 | __all__ = ['Classifier', 'NormalizedClassifier'] 9 | 10 | 11 | class Classifier(nn.Module): 12 | def __init__(self, feature_dim, num_classes): 13 | super().__init__() 14 | self.classifier = nn.Linear(feature_dim, num_classes) 15 | init.normal_(self.classifier.weight.data, std=0.001) 16 | init.constant_(self.classifier.bias.data, 0.0) 17 | 18 | def forward(self, x): 19 | y = self.classifier(x) 20 | 21 | return y 22 | 23 | 24 | class NormalizedClassifier(nn.Module): 25 | def __init__(self, feature_dim, num_classes): 26 | super().__init__() 27 | self.weight = Parameter(torch.Tensor(num_classes, feature_dim)) 28 | self.weight.data.uniform_(-1, 1).renorm_(2,0,1e-5).mul_(1e5) # initialize each class weight to unit L2 norm 29 | 30 | def forward(self, x): 31 | w = self.weight 32 | 33 | x = F.normalize(x, p=2, dim=1) 34 | w = F.normalize(w, p=2, dim=1) 35 | 36 | return F.linear(x, w) # logits are cosine similarities in [-1, 1] 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /models/ResNet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.nn import functional as F 5 | 6 | 7 | class ResNet50(nn.Module): 8 | def __init__(self, res4_stride=1, **kwargs): 9 | super().__init__() 10 | 11 | resnet50 = torchvision.models.resnet50(pretrained=True) 12 | if res4_stride == 1: # keep higher-resolution feature maps in the last stage 13 | resnet50.layer4[0].conv2.stride = (1, 1) 14 | resnet50.layer4[0].downsample[0].stride = (1, 1) 15 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 16 | 17 | self.bn = nn.BatchNorm1d(2048) 18 | init.normal_(self.bn.weight.data, 1.0, 0.02) 19 | init.constant_(self.bn.bias.data, 0.0) 20 | 21 | def forward(self, x): 22 | x = self.base(x) 23 | x = F.avg_pool2d(x, x.size()[2:]) 24 | x = x.view(x.size(0), -1) 25 | f = self.bn(x) 26 | 27 | return f -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.Classifier import Classifier, NormalizedClassifier 2 | from models.ResNet import ResNet50 3 | 4 | 5 | def build_model(config, num_classes): 6 | # Build backbone 7 | print("Initializing model: {}".format(config.MODEL.NAME)) 8 | if config.MODEL.NAME == 'resnet50': 9 | model = ResNet50(res4_stride=config.MODEL.RES4_STRIDE) 10 | else: 11 | raise KeyError("Invalid model: '{}'".format(config.MODEL.NAME)) 12 | print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0)) 13 | 14 | # Build classifier 15 | if config.LOSS.CLA_LOSS in ['crossentropy', 'crossentropylabelsmooth']: 16 | classifier = Classifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_classes) 17 | else: 18 | classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_classes) 19 | 20 | return model, classifier
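A short editorial note (not part of the repository): build_model pairs the plain Classifier with the softmax-style losses and NormalizedClassifier with the margin-based ones, because ArcFace/CosFace/Circle rescale cosine logits. A minimal sketch with hypothetical tensors:

import torch
from models.Classifier import Classifier, NormalizedClassifier

feats = torch.randn(4, 2048)  # hypothetical pooled features
plain = Classifier(feature_dim=2048, num_classes=10)
cosine = NormalizedClassifier(feature_dim=2048, num_classes=10)

print(plain(feats).shape)                 # torch.Size([4, 10]); unbounded affine logits
print(cosine(feats).abs().max().item())   # cosine-similarity logits, bounded by 1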
print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0)) 13 | 14 | # Build classifier 15 | if config.LOSS.CLA_LOSS in ['crossentropy', 'crossentropylabelsmooth']: 16 | classifier = Classifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_classes) 17 | else: 18 | classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_classes) 19 | 20 | return model, classifier -------------------------------------------------------------------------------- /tools/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_ap_cmc(index, good_index, junk_index): 5 | ap = 0 6 | cmc = np.zeros(len(index)) 7 | 8 | # remove junk_index 9 | mask = np.in1d(index, junk_index, invert=True) 10 | index = index[mask] 11 | 12 | # find good_index index 13 | ngood = len(good_index) 14 | mask = np.in1d(index, good_index) 15 | rows_good = np.argwhere(mask==True) 16 | rows_good = rows_good.flatten() 17 | 18 | cmc[rows_good[0]:] = 1.0 19 | for i in range(ngood): 20 | d_recall = 1.0/ngood 21 | precision = (i+1)*1.0/(rows_good[i]+1) 22 | # if rows_good[i]!=0: 23 | # old_precision = i*1.0/rows_good[i] 24 | # else: 25 | # old_precision=1.0 26 | # ap = ap + d_recall*(old_precision + precision)/2 27 | ap = ap + d_recall*precision 28 | 29 | return ap, cmc 30 | 31 | 32 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids): 33 | num_q, num_g = distmat.shape 34 | index = np.argsort(distmat, axis=1) # from small to large 35 | 36 | num_no_gt = 0 # num of query imgs without groundtruth 37 | num_r1 = 0 38 | CMC = np.zeros(len(g_pids)) 39 | AP = 0 40 | 41 | for i in range(num_q): 42 | # groundtruth index 43 | query_index = np.argwhere(g_pids==q_pids[i]) 44 | camera_index = np.argwhere(g_camids==q_camids[i]) 45 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 46 | if good_index.size == 0: 47 | num_no_gt += 1 48 | continue 49 | # remove gallery samples that have the same pid and camid with query 50 | junk_index = np.intersect1d(query_index, camera_index) 51 | 52 | ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index) 53 | if CMC_tmp[0]==1: 54 | num_r1 += 1 55 | CMC = CMC + CMC_tmp 56 | AP += ap_tmp 57 | 58 | if num_no_gt > 0: 59 | print("{} query imgs do not have groundtruth.".format(num_no_gt)) 60 | 61 | CMC = CMC / (num_q - num_no_gt) 62 | mAP = AP / (num_q - num_no_gt) 63 | 64 | return CMC, mAP 65 | -------------------------------------------------------------------------------- /tools/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class Random2DTranslation(object): 10 | """ 11 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 12 | 13 | Args: 14 | height (int): target height. 15 | width (int): target width. 16 | p (float): probability of performing this transformation. Default: 0.5. 17 | """ 18 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 19 | self.height = height 20 | self.width = width 21 | self.p = p 22 | self.interpolation = interpolation 23 | 24 | def __call__(self, img): 25 | """ 26 | Args: 27 | img (PIL Image): Image to be cropped. 28 | 29 | Returns: 30 | PIL Image: Cropped image. 
31 | """ 32 | if random.uniform(0, 1) > self.p: 33 | return img.resize((self.width, self.height), self.interpolation) 34 | 35 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 36 | resized_img = img.resize((new_width, new_height), self.interpolation) 37 | x_maxrange = new_width - self.width 38 | y_maxrange = new_height - self.height 39 | x1 = int(round(random.uniform(0, x_maxrange))) 40 | y1 = int(round(random.uniform(0, y_maxrange))) 41 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 42 | return croped_img 43 | 44 | class RandomErasing(object): 45 | """ Randomly selects a rectangle region in an image and erases its pixels. 46 | 'Random Erasing Data Augmentation' by Zhong et al. 47 | See https://arxiv.org/pdf/1708.04896.pdf 48 | Args: 49 | probability: The probability that the Random Erasing operation will be performed. 50 | sl: Minimum proportion of erased area against input image. 51 | sh: Maximum proportion of erased area against input image. 52 | r1: Minimum aspect ratio of erased area. 53 | mean: Erasing value. 54 | """ 55 | 56 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 57 | self.probability = probability 58 | self.mean = mean 59 | self.sl = sl 60 | self.sh = sh 61 | self.r1 = r1 62 | 63 | def __call__(self, img): 64 | 65 | if random.uniform(0, 1) > self.probability: 66 | return img 67 | 68 | for attempt in range(100): 69 | area = img.size()[1] * img.size()[2] 70 | 71 | target_area = random.uniform(self.sl, self.sh) * area 72 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 73 | 74 | h = int(round(math.sqrt(target_area * aspect_ratio))) 75 | w = int(round(math.sqrt(target_area / aspect_ratio))) 76 | 77 | if w < img.size()[2] and h < img.size()[1]: 78 | x1 = random.randint(0, img.size()[1] - h) 79 | y1 = random.randint(0, img.size()[2] - w) 80 | if img.size()[0] == 3: 81 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 82 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 83 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 84 | else: 85 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 86 | return img 87 | 88 | return img -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import errno 5 | import json 6 | import os.path as osp 7 | import torch 8 | import random 9 | import numpy as np 10 | 11 | 12 | def set_seed(seed=None): 13 | if seed is None: 14 | return 15 | random.seed(seed) 16 | os.environ['PYTHONHASHSEED'] = ("%s" % seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.benchmark = False 22 | torch.backends.cudnn.deterministic = True 23 | 24 | 25 | def mkdir_if_missing(directory): 26 | if not osp.exists(directory): 27 | try: 28 | os.makedirs(directory) 29 | except OSError as e: 30 | if e.errno != errno.EEXIST: 31 | raise 32 | 33 | 34 | def read_json(fpath): 35 | with open(fpath, 'r') as f: 36 | obj = json.load(f) 37 | return obj 38 | 39 | 40 | def write_json(obj, fpath): 41 | mkdir_if_missing(osp.dirname(fpath)) 42 | with open(fpath, 'w') as f: 43 | json.dump(obj, f, indent=4, separators=(',', ': ')) 44 | 45 | 46 | class AverageMeter(object): 47 | """Computes and stores the average and current value. 
48 | 49 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 50 | """ 51 | def __init__(self): 52 | self.reset() 53 | 54 | def reset(self): 55 | self.val = 0 56 | self.avg = 0 57 | self.sum = 0 58 | self.count = 0 59 | 60 | def update(self, val, n=1): 61 | self.val = val 62 | self.sum += val * n 63 | self.count += n 64 | self.avg = self.sum / self.count 65 | 66 | 67 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 68 | mkdir_if_missing(osp.dirname(fpath)) 69 | torch.save(state, fpath) 70 | if is_best: 71 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) 72 | 73 | 74 | class Logger(object): 75 | """ 76 | Write console output to external text file. 77 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 78 | """ 79 | def __init__(self, fpath=None): 80 | self.console = sys.stdout 81 | self.file = None 82 | if fpath is not None: 83 | mkdir_if_missing(os.path.dirname(fpath)) 84 | self.file = open(fpath, 'w') 85 | 86 | def __del__(self): 87 | self.close() 88 | 89 | def __enter__(self): 90 | return self 91 | 92 | def __exit__(self, *args): 93 | self.close() 94 | 95 | def write(self, msg): 96 | self.console.write(msg) 97 | if self.file is not None: 98 | self.file.write(msg) 99 | 100 | def flush(self): 101 | self.console.flush() 102 | if self.file is not None: 103 | self.file.flush() 104 | os.fsync(self.file.fileno()) 105 | 106 | def close(self): 107 | self.console.close() 108 | if self.file is not None: 109 | self.file.close() -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python main.py --dataset msmt17 --cfg configs/res50_ce_triplet.yaml --gpu 0,1 2 | --------------------------------------------------------------------------------
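Closing usage note (an editorial appendix, not a repository file): a sketch of how main.py wires build_losses into a training step, using a hypothetical SimpleNamespace in place of the YACS config and random tensors in place of a real batch; the crossentropy/triplet pair is chosen because it runs on CPU, unlike the losses that call .cuda() internally:

import torch
from types import SimpleNamespace
from losses import build_losses

# Hypothetical config stub; only the keys that build_losses actually reads.
config = SimpleNamespace(
    LOSS=SimpleNamespace(CLA_LOSS='crossentropy', PAIR_LOSS='triplet', PAIR_M=0.3, PAIR_S=16),
    TEST=SimpleNamespace(DISTANCE='euclidean'),
)
criterion_cla, criterion_pair = build_losses(config)

features = torch.randn(8, 2048, requires_grad=True)  # stand-in for model(imgs)
outputs = torch.randn(8, 751, requires_grad=True)    # stand-in for classifier(features)
pids = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])

loss = criterion_cla(outputs, pids) + criterion_pair(features, pids)  # as in train()
loss.backward()
print(loss.item())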