├── .gitignore ├── LICENSE ├── README.md ├── configs ├── c2dres50_ce_cal.yaml ├── default_img.py ├── default_vid.py ├── res50_cels_cal.yaml ├── res50_cels_cal_16x4.yaml └── res50_cels_cal_tri_16x4.yaml ├── data ├── __init__.py ├── dataloader.py ├── dataset_loader.py ├── datasets │ ├── ccvid.py │ ├── deepchange.py │ ├── last.py │ ├── ltcc.py │ ├── prcc.py │ └── vcclothes.py ├── img_transforms.py ├── samplers.py ├── spatial_transforms.py └── temporal_transforms.py ├── losses ├── __init__.py ├── arcface_loss.py ├── circle_loss.py ├── clothes_based_adversarial_loss.py ├── contrastive_loss.py ├── cosface_loss.py ├── cross_entropy_loss_with_label_smooth.py ├── gather.py └── triplet_loss.py ├── main.py ├── models ├── __init__.py ├── classifier.py ├── img_resnet.py ├── utils │ ├── c3d_blocks.py │ ├── inflate.py │ ├── nonlocal_blocks.py │ └── pooling.py └── vid_resnet.py ├── script.sh ├── test.py ├── tools ├── eval_metrics.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### A Simple Codebase for Clothes-Changing Person Re-identification. 
2 | #### [Clothes-Changing Person Re-identification with RGB Modality Only (CVPR, 2022)](https://arxiv.org/abs/2204.06890) 3 | 4 | #### Requirements 5 | - Python 3.6 6 | - PyTorch 1.6.0 7 | - yacs 8 | - apex 9 | 10 | #### CCVID Dataset 11 | - [[BaiduYun]](https://pan.baidu.com/s/1W9yjqxS9qxfPUSu76JpE1g) password: q0q2 12 | - [[GoogleDrive]](https://drive.google.com/file/d/1vkZxm5v-aBXa_JEi23MMeW4DgisGtS4W/view?usp=sharing) 13 | 14 | #### Get Started 15 | - Replace `_C.DATA.ROOT` and `_C.OUTPUT` in `configs/default_img.py` and `configs/default_vid.py` with your own data path and output path, respectively. 16 | - Run `script.sh` 17 | 18 | 19 | #### Citation 20 | 21 | If you use our code/dataset in your research or wish to refer to the baseline results, please use the following BibTeX entry. 22 | 23 | @inproceedings{gu2022CAL, 24 | title={Clothes-Changing Person Re-identification with RGB Modality Only}, 25 | author={Gu, Xinqian and Chang, Hong and Ma, Bingpeng and Bai, Shutao and Shan, Shiguang and Chen, Xilin}, 26 | booktitle={CVPR}, 27 | year={2022}, 28 | } 29 | 30 | #### Related Repos 31 | 32 | - [Simple-ReID](https://github.com/guxinqian/Simple-ReID) 33 | - [fast-reid](https://github.com/JDAI-CV/fast-reid) 34 | - [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid) 35 | - [Pytorch ReID](https://github.com/layumi/Person_reID_baseline_pytorch) 36 | 37 | -------------------------------------------------------------------------------- /configs/c2dres50_ce_cal.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: c2dres50 3 | LOSS: 4 | CLA_LOSS: crossentropy 5 | CAL: cal 6 | TAG: c2dres50-ce-cal -------------------------------------------------------------------------------- /configs/default_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yacs.config import CfgNode as CN 4 | 5 | 6 | _C = CN() 7 | # ----------------------------------------------------------------------------- 8 | # Data settings 9 | # ----------------------------------------------------------------------------- 10 | _C.DATA = CN() 11 | # Root path for dataset directory 12 | _C.DATA.ROOT = '/home/guxinqian/data' 13 | # Dataset for evaluation 14 | _C.DATA.DATASET = 'ltcc' 15 | # Workers for dataloader 16 | _C.DATA.NUM_WORKERS = 4 17 | # Height of input image 18 | _C.DATA.HEIGHT = 384 19 | # Width of input image 20 | _C.DATA.WIDTH = 192 21 | # Batch size for training 22 | _C.DATA.TRAIN_BATCH = 32 23 | # Batch size for testing 24 | _C.DATA.TEST_BATCH = 128 25 | # The number of instances per identity for training sampler 26 | _C.DATA.NUM_INSTANCES = 8 27 | # ----------------------------------------------------------------------------- 28 | # Augmentation settings 29 | # ----------------------------------------------------------------------------- 30 | _C.AUG = CN() 31 | # Random crop prob 32 | _C.AUG.RC_PROB = 0.5 33 | # Random erase prob 34 | _C.AUG.RE_PROB = 0.5 35 | # Random flip prob 36 | _C.AUG.RF_PROB = 0.5 37 | # ----------------------------------------------------------------------------- 38 | # Model settings 39 | # ----------------------------------------------------------------------------- 40 | _C.MODEL = CN() 41 | # Model name 42 | _C.MODEL.NAME = 'resnet50' 43 | # The stride for layer4 in resnet 44 | _C.MODEL.RES4_STRIDE = 1 45 | # feature dim 46 | _C.MODEL.FEATURE_DIM = 4096 47 | # Model path for resuming 48 | _C.MODEL.RESUME = '' 49 | # Global pooling after the backbone 50 |
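# ('maxavg' concatenates max- and avg-pooled features, which is presumably why FEATURE_DIM above is
# 4096 = 2 * 2048 for a ResNet-50 backbone; the power P below is only used by 'gem' pooling.)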
_C.MODEL.POOLING = CN() 51 | # Choose in ['avg', 'max', 'gem', 'maxavg'] 52 | _C.MODEL.POOLING.NAME = 'maxavg' 53 | # Initialized power for GeM pooling 54 | _C.MODEL.POOLING.P = 3 55 | # ----------------------------------------------------------------------------- 56 | # Losses for training 57 | # ----------------------------------------------------------------------------- 58 | _C.LOSS = CN() 59 | # Classification loss 60 | _C.LOSS.CLA_LOSS = 'crossentropy' 61 | # Clothes classification loss 62 | _C.LOSS.CLOTHES_CLA_LOSS = 'cosface' 63 | # Scale for classification loss 64 | _C.LOSS.CLA_S = 16. 65 | # Margin for classification loss 66 | _C.LOSS.CLA_M = 0. 67 | # Pairwise loss 68 | _C.LOSS.PAIR_LOSS = 'triplet' 69 | # The weight for pairwise loss 70 | _C.LOSS.PAIR_LOSS_WEIGHT = 0.0 71 | # Scale for pairwise loss 72 | _C.LOSS.PAIR_S = 16. 73 | # Margin for pairwise loss 74 | _C.LOSS.PAIR_M = 0.3 75 | # Clothes-based adversarial loss 76 | _C.LOSS.CAL = 'cal' 77 | # Epsilon for clothes-based adversarial loss 78 | _C.LOSS.EPSILON = 0.1 79 | # Momentum for clothes-based adversarial loss with memory bank 80 | _C.LOSS.MOMENTUM = 0. 81 | # ----------------------------------------------------------------------------- 82 | # Training settings 83 | # ----------------------------------------------------------------------------- 84 | _C.TRAIN = CN() 85 | _C.TRAIN.START_EPOCH = 0 86 | _C.TRAIN.MAX_EPOCH = 60 87 | # Start epoch for clothes classification 88 | _C.TRAIN.START_EPOCH_CC = 25 89 | # Start epoch for adversarial training 90 | _C.TRAIN.START_EPOCH_ADV = 25 91 | # Optimizer 92 | _C.TRAIN.OPTIMIZER = CN() 93 | _C.TRAIN.OPTIMIZER.NAME = 'adam' 94 | # Learning rate 95 | _C.TRAIN.OPTIMIZER.LR = 0.00035 96 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4 97 | # LR scheduler 98 | _C.TRAIN.LR_SCHEDULER = CN() 99 | # Stepsize to decay learning rate 100 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40] 101 | # LR decay rate, used in StepLRScheduler 102 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 103 | # Using amp for training 104 | _C.TRAIN.AMP = False 105 | # ----------------------------------------------------------------------------- 106 | # Testing settings 107 | # ----------------------------------------------------------------------------- 108 | _C.TEST = CN() 109 | # Perform evaluation after every N epochs (set to -1 to test after training) 110 | _C.TEST.EVAL_STEP = 5 111 | # Start to evaluate after specific epoch 112 | _C.TEST.START_EVAL = 0 113 | # ----------------------------------------------------------------------------- 114 | # Misc 115 | # ----------------------------------------------------------------------------- 116 | # Fixed random seed 117 | _C.SEED = 1 118 | # Perform evaluation only 119 | _C.EVAL_MODE = False 120 | # GPU device ids for CUDA_VISIBLE_DEVICES 121 | _C.GPU = '0' 122 | # Path to output folder, overwritten by command line argument 123 | _C.OUTPUT = '/data/guxinqian/logs/' 124 | # Tag of experiment, overwritten by command line argument 125 | _C.TAG = 'res50-ce-cal' 126 | 127 | 128 | def update_config(config, args): 129 | config.defrost() 130 | config.merge_from_file(args.cfg) 131 | 132 | # merge from specific arguments 133 | if args.root: 134 | config.DATA.ROOT = args.root 135 | if args.output: 136 | config.OUTPUT = args.output 137 | 138 | if args.resume: 139 | config.MODEL.RESUME = args.resume 140 | if args.eval: 141 | config.EVAL_MODE = True 142 | 143 | if args.tag: 144 | config.TAG = args.tag 145 | 146 | if args.dataset: 147 | config.DATA.DATASET = args.dataset 148 | if args.gpu: 149 | 
config.GPU = args.gpu 150 | if args.amp: 151 | config.TRAIN.AMP = True 152 | 153 | # output folder 154 | config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG) 155 | 156 | config.freeze() 157 | 158 | 159 | def get_img_config(args): 160 | """Get a yacs CfgNode object with default values.""" 161 | config = _C.clone() 162 | update_config(config, args) 163 | 164 | return config 165 | -------------------------------------------------------------------------------- /configs/default_vid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yacs.config import CfgNode as CN 4 | 5 | 6 | _C = CN() 7 | # ----------------------------------------------------------------------------- 8 | # Data settings 9 | # ----------------------------------------------------------------------------- 10 | _C.DATA = CN() 11 | # Root path for dataset directory 12 | _C.DATA.ROOT = '/home/guxinqian/data' 13 | # Dataset for evaluation 14 | _C.DATA.DATASET = 'ccvid' 15 | # Whether split each full-length video in the training set into some clips 16 | _C.DATA.DENSE_SAMPLING = True 17 | # Sampling step of dense sampling for training set 18 | _C.DATA.SAMPLING_STEP = 64 19 | # Workers for dataloader 20 | _C.DATA.NUM_WORKERS = 4 21 | # Height of input image 22 | _C.DATA.HEIGHT = 256 23 | # Width of input image 24 | _C.DATA.WIDTH = 128 25 | # Batch size for training 26 | _C.DATA.TRAIN_BATCH = 16 27 | # Batch size for testing 28 | _C.DATA.TEST_BATCH = 128 29 | # The number of instances per identity for training sampler 30 | _C.DATA.NUM_INSTANCES = 4 31 | # ----------------------------------------------------------------------------- 32 | # Augmentation settings 33 | # ----------------------------------------------------------------------------- 34 | _C.AUG = CN() 35 | # Random erase prob 36 | _C.AUG.RE_PROB = 0.0 37 | # Temporal sampling mode for training, 'tsn' or 'stride' 38 | _C.AUG.TEMPORAL_SAMPLING_MODE = 'stride' 39 | # Sequence length of each input video clip 40 | _C.AUG.SEQ_LEN = 8 41 | # Sampling stride of each input video clip 42 | _C.AUG.SAMPLING_STRIDE = 4 43 | # ----------------------------------------------------------------------------- 44 | # Model settings 45 | # ----------------------------------------------------------------------------- 46 | _C.MODEL = CN() 47 | # Model name. All supported model can be seen in models/__init__.py 48 | _C.MODEL.NAME = 'c2dres50' 49 | # The stride for laery4 in resnet 50 | _C.MODEL.RES4_STRIDE = 1 51 | # feature dim 52 | _C.MODEL.FEATURE_DIM = 2048 53 | # Model path for resuming 54 | _C.MODEL.RESUME = '' 55 | # Params for AP3D 56 | _C.MODEL.AP3D = CN() 57 | # Temperature for APM 58 | _C.MODEL.AP3D.TEMPERATURE = 4 59 | # Contrastive attention 60 | _C.MODEL.AP3D.CONTRACTIVE_ATT = True 61 | # ----------------------------------------------------------------------------- 62 | # Losses for training 63 | # ----------------------------------------------------------------------------- 64 | _C.LOSS = CN() 65 | # Classification loss 66 | _C.LOSS.CLA_LOSS = 'crossentropy' 67 | # Clothes classification loss 68 | _C.LOSS.CLOTHES_CLA_LOSS = 'cosface' 69 | # Scale for classification loss 70 | _C.LOSS.CLA_S = 16. 71 | # Margin for classification loss 72 | _C.LOSS.CLA_M = 0. 73 | # Pairwise loss 74 | _C.LOSS.PAIR_LOSS = 'triplet' 75 | # The weight for pairwise loss 76 | _C.LOSS.PAIR_LOSS_WEIGHT = 0.0 77 | # Scale for pairwise loss 78 | _C.LOSS.PAIR_S = 16. 
79 | # Margin for pairwise loss 80 | _C.LOSS.PAIR_M = 0.3 81 | # Clothes-based adversarial loss 82 | _C.LOSS.CAL = 'cal' 83 | # Epsilon for clothes-based adversarial loss 84 | _C.LOSS.EPSILON = 0.1 85 | # Momentum for clothes-based adversarial loss with memory bank 86 | _C.LOSS.MOMENTUM = 0. 87 | # ----------------------------------------------------------------------------- 88 | # Training settings 89 | # ----------------------------------------------------------------------------- 90 | _C.TRAIN = CN() 91 | _C.TRAIN.START_EPOCH = 0 92 | _C.TRAIN.MAX_EPOCH = 150 93 | # Start epoch for clothes classification 94 | _C.TRAIN.START_EPOCH_CC = 50 95 | # Start epoch for adversarial training 96 | _C.TRAIN.START_EPOCH_ADV = 50 97 | # Optimizer 98 | _C.TRAIN.OPTIMIZER = CN() 99 | _C.TRAIN.OPTIMIZER.NAME = 'adam' 100 | # Learning rate 101 | _C.TRAIN.OPTIMIZER.LR = 0.00035 102 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4 103 | # LR scheduler 104 | _C.TRAIN.LR_SCHEDULER = CN() 105 | # Stepsize to decay learning rate 106 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [40, 80, 120] 107 | # LR decay rate, used in StepLRScheduler 108 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 109 | # Using amp for training 110 | _C.TRAIN.AMP = False 111 | # ----------------------------------------------------------------------------- 112 | # Testing settings 113 | # ----------------------------------------------------------------------------- 114 | _C.TEST = CN() 115 | # Perform evaluation after every N epochs (set to -1 to test after training) 116 | _C.TEST.EVAL_STEP = 10 117 | # Start to evaluate after specific epoch 118 | _C.TEST.START_EVAL = 0 119 | # ----------------------------------------------------------------------------- 120 | # Misc 121 | # ----------------------------------------------------------------------------- 122 | # Fixed random seed 123 | _C.SEED = 1 124 | # Perform evaluation only 125 | _C.EVAL_MODE = False 126 | # GPU device ids for CUDA_VISIBLE_DEVICES 127 | _C.GPU = '0, 1' 128 | # Path to output folder, overwritten by command line argument 129 | _C.OUTPUT = '/data/guxinqian/logs/' 130 | # Tag of experiment, overwritten by command line argument 131 | _C.TAG = 'res50-ce-cal' 132 | 133 | 134 | def update_config(config, args): 135 | config.defrost() 136 | config.merge_from_file(args.cfg) 137 | 138 | # merge from specific arguments 139 | if args.root: 140 | config.DATA.ROOT = args.root 141 | if args.output: 142 | config.OUTPUT = args.output 143 | 144 | if args.resume: 145 | config.MODEL.RESUME = args.resume 146 | if args.eval: 147 | config.EVAL_MODE = True 148 | 149 | if args.tag: 150 | config.TAG = args.tag 151 | 152 | if args.dataset: 153 | config.DATA.DATASET = args.dataset 154 | if args.gpu: 155 | config.GPU = args.gpu 156 | if args.amp: 157 | config.TRAIN.AMP = True 158 | 159 | # output folder 160 | config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG) 161 | 162 | config.freeze() 163 | 164 | 165 | def get_vid_config(args): 166 | """Get a yacs CfgNode object with default values.""" 167 | config = _C.clone() 168 | update_config(config, args) 169 | 170 | return config 171 | -------------------------------------------------------------------------------- /configs/res50_cels_cal.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropylabelsmooth 5 | CAL: cal 6 | TAG: res50-cels-cal -------------------------------------------------------------------------------- /configs/res50_cels_cal_16x4.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | DATA: 4 | NUM_INSTANCES: 4 5 | TRAIN_BATCH: 32 6 | LOSS: 7 | CLA_LOSS: crossentropylabelsmooth 8 | CAL: cal 9 | TAG: res50-cels-cal-16x4 -------------------------------------------------------------------------------- /configs/res50_cels_cal_tri_16x4.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | DATA: 4 | NUM_INSTANCES: 4 5 | TRAIN_BATCH: 32 6 | LOSS: 7 | CLA_LOSS: crossentropylabelsmooth 8 | PAIR_LOSS: triplet 9 | CAL: cal 10 | PAIR_M: 0.3 11 | PAIR_LOSS_WEIGHT: 1.0 12 | TAG: res50-cels-cal-tri-16x4 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import data.img_transforms as T 2 | import data.spatial_transforms as ST 3 | import data.temporal_transforms as TT 4 | from torch.utils.data import DataLoader 5 | from data.dataloader import DataLoaderX 6 | from data.dataset_loader import ImageDataset, VideoDataset 7 | from data.samplers import DistributedRandomIdentitySampler, DistributedInferenceSampler 8 | from data.datasets.ltcc import LTCC 9 | from data.datasets.prcc import PRCC 10 | from data.datasets.last import LaST 11 | from data.datasets.ccvid import CCVID 12 | from data.datasets.deepchange import DeepChange 13 | from data.datasets.vcclothes import VCClothes, VCClothesSameClothes, VCClothesClothesChanging 14 | 15 | 16 | __factory = { 17 | 'ltcc': LTCC, 18 | 'prcc': PRCC, 19 | 'vcclothes': VCClothes, 20 | 'vcclothes_sc': VCClothesSameClothes, 21 | 'vcclothes_cc': VCClothesClothesChanging, 22 | 'last': LaST, 23 | 'ccvid': CCVID, 24 | 'deepchange': DeepChange, 25 | } 26 | 27 | VID_DATASET = ['ccvid'] 28 | 29 | 30 | def get_names(): 31 | return list(__factory.keys()) 32 | 33 | 34 | def build_dataset(config): 35 | if config.DATA.DATASET not in __factory.keys(): 36 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, __factory.keys())) 37 | 38 | if config.DATA.DATASET in VID_DATASET: 39 | dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT, 40 | sampling_step=config.DATA.SAMPLING_STEP, 41 | seq_len=config.AUG.SEQ_LEN, 42 | stride=config.AUG.SAMPLING_STRIDE) 43 | else: 44 | dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT) 45 | 46 | return dataset 47 | 48 | 49 | def build_img_transforms(config): 50 | transform_train = T.Compose([ 51 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)), 52 | T.RandomCroping(p=config.AUG.RC_PROB), 53 | T.RandomHorizontalFlip(p=config.AUG.RF_PROB), 54 | T.ToTensor(), 55 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 56 | T.RandomErasing(probability=config.AUG.RE_PROB) 57 | ]) 58 | transform_test = T.Compose([ 59 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)), 60 | T.ToTensor(), 61 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 62 | ]) 63 | 64 | return transform_train, transform_test 65 | 66 | 67 | def build_vid_transforms(config): 68 | spatial_transform_train = ST.Compose([ 69 | ST.Scale((config.DATA.HEIGHT, config.DATA.WIDTH), interpolation=3), 70 | ST.RandomHorizontalFlip(), 71 | ST.ToTensor(), 72 | ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 73 | ST.RandomErasing(height=config.DATA.HEIGHT, width=config.DATA.WIDTH, probability=config.AUG.RE_PROB) 74 | ]) 75 | spatial_transform_test = ST.Compose([ 76 | 
ST.Scale((config.DATA.HEIGHT, config.DATA.WIDTH), interpolation=3), 77 | ST.ToTensor(), 78 | ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 79 | ]) 80 | 81 | if config.AUG.TEMPORAL_SAMPLING_MODE == 'tsn': 82 | temporal_transform_train = TT.TemporalDivisionCrop(size=config.AUG.SEQ_LEN) 83 | elif config.AUG.TEMPORAL_SAMPLING_MODE == 'stride': 84 | temporal_transform_train = TT.TemporalRandomCrop(size=config.AUG.SEQ_LEN, 85 | stride=config.AUG.SAMPLING_STRIDE) 86 | else: 87 | raise KeyError("Invalid temporal sempling mode '{}'".format(config.AUG.TEMPORAL_SAMPLING_MODE)) 88 | 89 | temporal_transform_test = None 90 | 91 | return spatial_transform_train, spatial_transform_test, temporal_transform_train, temporal_transform_test 92 | 93 | 94 | def build_dataloader(config): 95 | dataset = build_dataset(config) 96 | # video dataset 97 | if config.DATA.DATASET in VID_DATASET: 98 | spatial_transform_train, spatial_transform_test, temporal_transform_train, temporal_transform_test = build_vid_transforms(config) 99 | 100 | if config.DATA.DENSE_SAMPLING: 101 | train_sampler = DistributedRandomIdentitySampler(dataset.train_dense, 102 | num_instances=config.DATA.NUM_INSTANCES, 103 | seed=config.SEED) 104 | # split each original training video into a series of short videos and sample one clip for each short video during training 105 | trainloader = DataLoaderX( 106 | dataset=VideoDataset(dataset.train_dense, spatial_transform_train, temporal_transform_train), 107 | sampler=train_sampler, 108 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS, 109 | pin_memory=True, drop_last=True) 110 | else: 111 | train_sampler = DistributedRandomIdentitySampler(dataset.train, 112 | num_instances=config.DATA.NUM_INSTANCES, 113 | seed=config.SEED) 114 | # sample one clip for each original training video during training 115 | trainloader = DataLoaderX( 116 | dataset=VideoDataset(dataset.train, spatial_transform_train, temporal_transform_train), 117 | sampler=train_sampler, 118 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS, 119 | pin_memory=True, drop_last=True) 120 | 121 | # split each original test video into a series of clips and use the averaged feature of all clips as its representation 122 | queryloader = DataLoaderX( 123 | dataset=VideoDataset(dataset.recombined_query, spatial_transform_test, temporal_transform_test), 124 | sampler=DistributedInferenceSampler(dataset.recombined_query), 125 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 126 | pin_memory=True, drop_last=False, shuffle=False) 127 | galleryloader = DataLoaderX( 128 | dataset=VideoDataset(dataset.recombined_gallery, spatial_transform_test, temporal_transform_test), 129 | sampler=DistributedInferenceSampler(dataset.recombined_gallery), 130 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 131 | pin_memory=True, drop_last=False, shuffle=False) 132 | 133 | return trainloader, queryloader, galleryloader, dataset, train_sampler 134 | # image dataset 135 | else: 136 | transform_train, transform_test = build_img_transforms(config) 137 | train_sampler = DistributedRandomIdentitySampler(dataset.train, 138 | num_instances=config.DATA.NUM_INSTANCES, 139 | seed=config.SEED) 140 | trainloader = DataLoaderX(dataset=ImageDataset(dataset.train, transform=transform_train), 141 | sampler=train_sampler, 142 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS, 143 | pin_memory=True, drop_last=True) 144 | 145 | galleryloader = 
DataLoaderX(dataset=ImageDataset(dataset.gallery, transform=transform_test), 146 | sampler=DistributedInferenceSampler(dataset.gallery), 147 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 148 | pin_memory=True, drop_last=False, shuffle=False) 149 | 150 | if config.DATA.DATASET == 'prcc': 151 | queryloader_same = DataLoaderX(dataset=ImageDataset(dataset.query_same, transform=transform_test), 152 | sampler=DistributedInferenceSampler(dataset.query_same), 153 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 154 | pin_memory=True, drop_last=False, shuffle=False) 155 | queryloader_diff = DataLoaderX(dataset=ImageDataset(dataset.query_diff, transform=transform_test), 156 | sampler=DistributedInferenceSampler(dataset.query_diff), 157 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 158 | pin_memory=True, drop_last=False, shuffle=False) 159 | 160 | return trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler 161 | else: 162 | queryloader = DataLoaderX(dataset=ImageDataset(dataset.query, transform=transform_test), 163 | sampler=DistributedInferenceSampler(dataset.query), 164 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 165 | pin_memory=True, drop_last=False, shuffle=False) 166 | 167 | return trainloader, queryloader, galleryloader, dataset, train_sampler 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | # refer to: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/data/data_utils.py 2 | 3 | import torch 4 | import threading 5 | import queue 6 | from torch.utils.data import DataLoader 7 | from torch import distributed as dist 8 | 9 | 10 | """ 11 | #based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | This is a single-function package that transforms arbitrary generator into a background-thead generator that 13 | prefetches several batches of data in a parallel background thead. 14 | 15 | This is useful if you have a computationally heavy process (CPU or GPU) that 16 | iteratively processes minibatches from the generator while the generator 17 | consumes some other resource (disk IO / loading from database / more CPU if you have unused cores). 18 | 19 | By default these two processes will constantly wait for one another to finish. If you make generator work in 20 | prefetch mode (see examples below), they will work in parallel, potentially saving you your GPU time. 21 | We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc. 22 | 23 | Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb 24 | This package contains this object 25 | - BackgroundGenerator(any_other_generator[,max_prefetch = something]) 26 | """ 27 | 28 | 29 | class BackgroundGenerator(threading.Thread): 30 | """ 31 | the usage is below 32 | >> for batch in BackgroundGenerator(my_minibatch_iterator): 33 | >> doit() 34 | More details are written in the BackgroundGenerator doc 35 | >> help(BackgroundGenerator) 36 | """ 37 | 38 | def __init__(self, generator, local_rank, max_prefetch=10): 39 | """ 40 | This function transforms generator into a background-thead generator. 41 | :param generator: generator or genexp or any 42 | It can be used with any minibatch generator. 
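        A minimal sketch of direct use (loader and train_step are illustrative names):
        >> prefetcher = BackgroundGenerator(iter(loader), local_rank=0, max_prefetch=10)
        >> for batch in prefetcher:
        >>     train_step(batch)
        In this codebase it is normally used indirectly: DataLoaderX below wraps the iterator returned by the
        standard PyTorch DataLoader in a BackgroundGenerator and additionally copies each batch to the GPU
        on a separate CUDA stream.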
43 | 44 | It is quite lightweight, but not entirely weightless. 45 | Using global variables inside generator is not recommended (may raise GIL and zero-out the 46 | benefit of having a background thread.) 47 | The ideal use case is when everything it requires is store inside it and everything it 48 | outputs is passed through queue. 49 | 50 | There's no restriction on doing weird stuff, reading/writing files, retrieving 51 | URLs [or whatever] wlilst iterating. 52 | 53 | :param max_prefetch: defines, how many iterations (at most) can background generator keep 54 | stored at any moment of time. 55 | Whenever there's already max_prefetch batches stored in queue, the background process will halt until 56 | one of these batches is dequeued. 57 | 58 | !Default max_prefetch=1 is okay unless you deal with some weird file IO in your generator! 59 | 60 | Setting max_prefetch to -1 lets it store as many batches as it can, which will work 61 | slightly (if any) faster, but will require storing 62 | all batches in memory. If you use infinite generator with max_prefetch=-1, it will exceed the RAM size 63 | unless dequeued quickly enough. 64 | """ 65 | super().__init__() 66 | self.queue = queue.Queue(max_prefetch) 67 | self.generator = generator 68 | self.local_rank = local_rank 69 | self.daemon = True 70 | self.exit_event = threading.Event() 71 | self.start() 72 | 73 | def run(self): 74 | torch.cuda.set_device(self.local_rank) 75 | for item in self.generator: 76 | if self.exit_event.is_set(): 77 | break 78 | self.queue.put(item) 79 | self.queue.put(None) 80 | 81 | def next(self): 82 | next_item = self.queue.get() 83 | if next_item is None: 84 | raise StopIteration 85 | return next_item 86 | 87 | # Python 3 compatibility 88 | def __next__(self): 89 | return self.next() 90 | 91 | def __iter__(self): 92 | return self 93 | 94 | 95 | class DataLoaderX(DataLoader): 96 | def __init__(self, **kwargs): 97 | super().__init__(**kwargs) 98 | local_rank = dist.get_rank() 99 | self.stream = torch.cuda.Stream(local_rank) # create a new cuda stream in each process 100 | self.local_rank = local_rank 101 | 102 | def __iter__(self): 103 | self.iter = super().__iter__() 104 | self.iter = BackgroundGenerator(self.iter, self.local_rank) 105 | self.preload() 106 | return self 107 | 108 | def _shutdown_background_thread(self): 109 | if not self.iter.is_alive(): 110 | # avoid re-entrance or ill-conditioned thread state 111 | return 112 | 113 | # Set exit event to True for background threading stopping 114 | self.iter.exit_event.set() 115 | 116 | # Exhaust all remaining elements, so that the queue becomes empty, 117 | # and the thread should quit 118 | for _ in self.iter: 119 | pass 120 | 121 | # Waiting for background thread to quit 122 | self.iter.join() 123 | 124 | def preload(self): 125 | self.batch = next(self.iter, None) 126 | if self.batch is None: 127 | return None 128 | with torch.cuda.stream(self.stream): 129 | # if isinstance(self.batch[0], torch.Tensor): 130 | # self.batch[0] = self.batch[0].to(device=self.local_rank, non_blocking=True) 131 | for k, v in enumerate(self.batch): 132 | if isinstance(self.batch[k], torch.Tensor): 133 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) 134 | 135 | def __next__(self): 136 | torch.cuda.current_stream().wait_stream( 137 | self.stream 138 | ) # wait tensor to put on GPU 139 | batch = self.batch 140 | if batch is None: 141 | raise StopIteration 142 | self.preload() 143 | return batch 144 | 145 | # Signal for shutting down background thread 146 | def 
shutdown(self): 147 | # If the dataloader is to be freed, shutdown its BackgroundGenerator 148 | self._shutdown_background_thread() 149 | -------------------------------------------------------------------------------- /data/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | import os.path as osp 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def read_image(img_path): 9 | """Keep reading image until succeed. 10 | This can avoid IOError incurred by heavy IO process.""" 11 | got_img = False 12 | if not osp.exists(img_path): 13 | raise IOError("{} does not exist".format(img_path)) 14 | while not got_img: 15 | try: 16 | img = Image.open(img_path).convert('RGB') 17 | got_img = True 18 | except IOError: 19 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 20 | pass 21 | return img 22 | 23 | 24 | class ImageDataset(Dataset): 25 | """Image Person ReID Dataset""" 26 | def __init__(self, dataset, transform=None): 27 | self.dataset = dataset 28 | self.transform = transform 29 | 30 | def __len__(self): 31 | return len(self.dataset) 32 | 33 | def __getitem__(self, index): 34 | img_path, pid, camid, clothes_id = self.dataset[index] 35 | img = read_image(img_path) 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | return img, pid, camid, clothes_id 39 | 40 | 41 | def pil_loader(path): 42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 43 | with open(path, 'rb') as f: 44 | with Image.open(f) as img: 45 | return img.convert('RGB') 46 | 47 | 48 | def accimage_loader(path): 49 | try: 50 | import accimage 51 | return accimage.Image(path) 52 | except IOError: 53 | # Potentially a decoding problem, fall back to PIL.Image 54 | return pil_loader(path) 55 | 56 | 57 | def get_default_image_loader(): 58 | from torchvision import get_image_backend 59 | if get_image_backend() == 'accimage': 60 | return accimage_loader 61 | else: 62 | return pil_loader 63 | 64 | 65 | def image_loader(path): 66 | from torchvision import get_image_backend 67 | if get_image_backend() == 'accimage': 68 | return accimage_loader(path) 69 | else: 70 | return pil_loader(path) 71 | 72 | 73 | def video_loader(img_paths, image_loader): 74 | video = [] 75 | for image_path in img_paths: 76 | if osp.exists(image_path): 77 | video.append(image_loader(image_path)) 78 | else: 79 | return video 80 | 81 | return video 82 | 83 | 84 | def get_default_video_loader(): 85 | image_loader = get_default_image_loader() 86 | return functools.partial(video_loader, image_loader=image_loader) 87 | 88 | 89 | class VideoDataset(Dataset): 90 | """Video Person ReID Dataset. 91 | Note: 92 | Batch data has shape N x C x T x H x W 93 | Args: 94 | dataset (list): List with items (img_paths, pid, camid) 95 | temporal_transform (callable, optional): A function/transform that takes in a list of frame indices 96 | and returns a transformed version 97 | target_transform (callable, optional): A function/transform that takes in the 98 | target and transforms it. 99 | loader (callable, optional): A function to load an video given its path and frame indices. 
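        spatial_transform (callable, optional): A function/transform applied to every frame of the loaded clip.
        cloth_changing (bool, optional): If True (default), each dataset item carries a clothes_id and
            __getitem__ returns (clip, pid, camid, clothes_id); otherwise it returns (clip, pid, camid).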
100 | """ 101 | 102 | def __init__(self, 103 | dataset, 104 | spatial_transform=None, 105 | temporal_transform=None, 106 | get_loader=get_default_video_loader, 107 | cloth_changing=True): 108 | self.dataset = dataset 109 | self.spatial_transform = spatial_transform 110 | self.temporal_transform = temporal_transform 111 | self.loader = get_loader() 112 | self.cloth_changing = cloth_changing 113 | 114 | def __len__(self): 115 | return len(self.dataset) 116 | 117 | def __getitem__(self, index): 118 | """ 119 | Args: 120 | index (int): Index 121 | 122 | Returns: 123 | tuple: (clip, pid, camid) where pid is identity of the clip. 124 | """ 125 | if self.cloth_changing: 126 | img_paths, pid, camid, clothes_id = self.dataset[index] 127 | else: 128 | img_paths, pid, camid = self.dataset[index] 129 | 130 | if self.temporal_transform is not None: 131 | img_paths = self.temporal_transform(img_paths) 132 | 133 | clip = self.loader(img_paths) 134 | 135 | if self.spatial_transform is not None: 136 | self.spatial_transform.randomize_parameters() 137 | clip = [self.spatial_transform(img) for img in clip] 138 | 139 | # trans T x C x H x W to C x T x H x W 140 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 141 | 142 | if self.cloth_changing: 143 | return clip, pid, camid, clothes_id 144 | else: 145 | return clip, pid, camid -------------------------------------------------------------------------------- /data/datasets/ccvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class CCVID(object): 15 | """ CCVID 16 | 17 | Reference: 18 | Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022. 19 | """ 20 | def __init__(self, root='/data/datasets/', sampling_step=64, seq_len=16, stride=4, **kwargs): 21 | self.root = osp.join(root, 'CCVID') 22 | self.train_path = osp.join(self.root, 'train.txt') 23 | self.query_path = osp.join(self.root, 'query.txt') 24 | self.gallery_path = osp.join(self.root, 'gallery.txt') 25 | self._check_before_run() 26 | 27 | train, num_train_tracklets, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes, _ = \ 28 | self._process_data(self.train_path, relabel=True) 29 | clothes2label = self._clothes2label_test(self.query_path, self.gallery_path) 30 | query, num_query_tracklets, num_query_pids, num_query_imgs, num_query_clothes, _, _ = \ 31 | self._process_data(self.query_path, relabel=False, clothes2label=clothes2label) 32 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs, num_gallery_clothes, _, _ = \ 33 | self._process_data(self.gallery_path, relabel=False, clothes2label=clothes2label) 34 | 35 | # slice each full-length video in the trainingset into more video clip 36 | train_dense = self._densesampling_for_trainingset(train, sampling_step) 37 | # In the test stage, each video sample is divided into a series of equilong video clips with a pre-defined stride. 
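        # For example, with seq_len=16 and stride=4 (so 64-frame blocks), a 140-frame tracklet yields 8
        # interleaved clips from its first 128 frames plus one loop-padded clip from the remaining 12 frames,
        # i.e. ceil(140/16) = 9 clips in total. The returned *_vid2clip_index arrays record the [start, end)
        # clip indices of each original video so that clip features can later be averaged back into a single
        # video-level feature.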
38 | recombined_query, query_vid2clip_index = self._recombination_for_testset(query, seq_len=seq_len, stride=stride) 39 | recombined_gallery, gallery_vid2clip_index = self._recombination_for_testset(gallery, seq_len=seq_len, stride=stride) 40 | 41 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 42 | min_num = np.min(num_imgs_per_tracklet) 43 | max_num = np.max(num_imgs_per_tracklet) 44 | avg_num = np.mean(num_imgs_per_tracklet) 45 | 46 | num_total_pids = num_train_pids + num_gallery_pids 47 | num_total_clothes = num_train_clothes + len(clothes2label) 48 | num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 49 | 50 | logger = logging.getLogger('reid.dataset') 51 | logger.info("=> CCVID loaded") 52 | logger.info("Dataset statistics:") 53 | logger.info(" ---------------------------------------------") 54 | logger.info(" subset | # ids | # tracklets | # clothes") 55 | logger.info(" ---------------------------------------------") 56 | logger.info(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_clothes)) 57 | logger.info(" train_dense | {:5d} | {:11d} | {:9d}".format(num_train_pids, len(train_dense), num_train_clothes)) 58 | logger.info(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_clothes)) 59 | logger.info(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_clothes)) 60 | logger.info(" ---------------------------------------------") 61 | logger.info(" total | {:5d} | {:11d} | {:9d}".format(num_total_pids, num_total_tracklets, num_total_clothes)) 62 | logger.info(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 63 | logger.info(" ---------------------------------------------") 64 | 65 | self.train = train 66 | self.train_dense = train_dense 67 | self.query = query 68 | self.gallery = gallery 69 | 70 | self.recombined_query = recombined_query 71 | self.recombined_gallery = recombined_gallery 72 | self.query_vid2clip_index = query_vid2clip_index 73 | self.gallery_vid2clip_index = gallery_vid2clip_index 74 | 75 | self.num_train_pids = num_train_pids 76 | self.num_train_clothes = num_train_clothes 77 | self.pid2clothes = pid2clothes 78 | 79 | def _check_before_run(self): 80 | """Check if all files are available before going deeper""" 81 | if not osp.exists(self.root): 82 | raise RuntimeError("'{}' is not available".format(self.root)) 83 | if not osp.exists(self.train_path): 84 | raise RuntimeError("'{}' is not available".format(self.train_path)) 85 | if not osp.exists(self.query_path): 86 | raise RuntimeError("'{}' is not available".format(self.query_path)) 87 | if not osp.exists(self.gallery_path): 88 | raise RuntimeError("'{}' is not available".format(self.gallery_path)) 89 | 90 | def _clothes2label_test(self, query_path, gallery_path): 91 | pid_container = set() 92 | clothes_container = set() 93 | with open(query_path, 'r') as f: 94 | for line in f: 95 | new_line = line.rstrip() 96 | tracklet_path, pid, clothes_label = new_line.split() 97 | clothes = '{}_{}'.format(pid, clothes_label) 98 | pid_container.add(pid) 99 | clothes_container.add(clothes) 100 | with open(gallery_path, 'r') as f: 101 | for line in f: 102 | new_line = line.rstrip() 103 | tracklet_path, pid, clothes_label = new_line.split() 104 | clothes = '{}_{}'.format(pid, clothes_label) 105 | pid_container.add(pid) 106 | clothes_container.add(clothes) 107 | pid_container = sorted(pid_container) 108 | 
clothes_container = sorted(clothes_container) 109 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 110 | clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)} 111 | 112 | return clothes2label 113 | 114 | def _process_data(self, data_path, relabel=False, clothes2label=None): 115 | tracklet_path_list = [] 116 | pid_container = set() 117 | clothes_container = set() 118 | with open(data_path, 'r') as f: 119 | for line in f: 120 | new_line = line.rstrip() 121 | tracklet_path, pid, clothes_label = new_line.split() 122 | tracklet_path_list.append((tracklet_path, pid, clothes_label)) 123 | clothes = '{}_{}'.format(pid, clothes_label) 124 | pid_container.add(pid) 125 | clothes_container.add(clothes) 126 | pid_container = sorted(pid_container) 127 | clothes_container = sorted(clothes_container) 128 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 129 | if clothes2label is None: 130 | clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)} 131 | 132 | num_tracklets = len(tracklet_path_list) 133 | num_pids = len(pid_container) 134 | num_clothes = len(clothes_container) 135 | 136 | tracklets = [] 137 | num_imgs_per_tracklet = [] 138 | pid2clothes = np.zeros((num_pids, len(clothes2label))) 139 | 140 | for tracklet_path, pid, clothes_label in tracklet_path_list: 141 | img_paths = glob.glob(osp.join(self.root, tracklet_path, '*')) 142 | img_paths.sort() 143 | 144 | clothes = '{}_{}'.format(pid, clothes_label) 145 | clothes_id = clothes2label[clothes] 146 | pid2clothes[pid2label[pid], clothes_id] = 1 147 | if relabel: 148 | pid = pid2label[pid] 149 | else: 150 | pid = int(pid) 151 | session = tracklet_path.split('/')[0] 152 | cam = tracklet_path.split('_')[1] 153 | if session == 'session3': 154 | camid = int(cam) + 12 155 | else: 156 | camid = int(cam) 157 | 158 | num_imgs_per_tracklet.append(len(img_paths)) 159 | tracklets.append((img_paths, pid, camid, clothes_id)) 160 | 161 | num_tracklets = len(tracklets) 162 | 163 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet, num_clothes, pid2clothes, clothes2label 164 | 165 | def _densesampling_for_trainingset(self, dataset, sampling_step=64): 166 | ''' Split all videos in training set into lots of clips for dense sampling. 167 | 168 | Args: 169 | dataset (list): input dataset, each video is organized as (img_paths, pid, camid, clothes_id) 170 | sampling_step (int): sampling step for dense sampling 171 | 172 | Returns: 173 | new_dataset (list): output dataset 174 | ''' 175 | new_dataset = [] 176 | for (img_paths, pid, camid, clothes_id) in dataset: 177 | if sampling_step != 0: 178 | num_sampling = len(img_paths)//sampling_step 179 | if num_sampling == 0: 180 | new_dataset.append((img_paths, pid, camid, clothes_id)) 181 | else: 182 | for idx in range(num_sampling): 183 | if idx == num_sampling - 1: 184 | new_dataset.append((img_paths[idx*sampling_step:], pid, camid, clothes_id)) 185 | else: 186 | new_dataset.append((img_paths[idx*sampling_step : (idx+1)*sampling_step], pid, camid, clothes_id)) 187 | else: 188 | new_dataset.append((img_paths, pid, camid, clothes_id)) 189 | 190 | return new_dataset 191 | 192 | def _recombination_for_testset(self, dataset, seq_len=16, stride=4): 193 | ''' Split all videos in test set into lots of equilong clips. 
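        Frames that fill whole (seq_len*stride)-sized blocks are split into `stride` interleaved clips per
        block; a remainder that still contains at least seq_len frames is sampled with a reduced stride;
        any final remainder shorter than seq_len is loop-padded into one last clip. Each video therefore
        yields exactly ceil(num_frames/seq_len) clips.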
194 | 195 | Args: 196 | dataset (list): input dataset, each video is organized as (img_paths, pid, camid, clothes_id) 197 | seq_len (int): sequence length of each output clip 198 | stride (int): temporal sampling stride 199 | 200 | Returns: 201 | new_dataset (list): output dataset with lots of equilong clips 202 | vid2clip_index (list): a list contains the start and end clip index of each original video 203 | ''' 204 | new_dataset = [] 205 | vid2clip_index = np.zeros((len(dataset), 2), dtype=int) 206 | for idx, (img_paths, pid, camid, clothes_id) in enumerate(dataset): 207 | # start index 208 | vid2clip_index[idx, 0] = len(new_dataset) 209 | # process the sequence that can be divisible by seq_len*stride 210 | for i in range(len(img_paths)//(seq_len*stride)): 211 | for j in range(stride): 212 | begin_idx = i * (seq_len * stride) + j 213 | end_idx = (i + 1) * (seq_len * stride) 214 | clip_paths = img_paths[begin_idx : end_idx : stride] 215 | assert(len(clip_paths) == seq_len) 216 | new_dataset.append((clip_paths, pid, camid, clothes_id)) 217 | # process the remaining sequence that can't be divisible by seq_len*stride 218 | if len(img_paths)%(seq_len*stride) != 0: 219 | # reducing stride 220 | new_stride = (len(img_paths)%(seq_len*stride)) // seq_len 221 | for i in range(new_stride): 222 | begin_idx = len(img_paths) // (seq_len*stride) * (seq_len*stride) + i 223 | end_idx = len(img_paths) // (seq_len*stride) * (seq_len*stride) + seq_len * new_stride 224 | clip_paths = img_paths[begin_idx : end_idx : new_stride] 225 | assert(len(clip_paths) == seq_len) 226 | new_dataset.append((clip_paths, pid, camid, clothes_id)) 227 | # process the remaining sequence that can't be divisible by seq_len 228 | if len(img_paths) % seq_len != 0: 229 | clip_paths = img_paths[len(img_paths)//seq_len*seq_len:] 230 | # loop padding 231 | while len(clip_paths) < seq_len: 232 | for index in clip_paths: 233 | if len(clip_paths) >= seq_len: 234 | break 235 | clip_paths.append(index) 236 | assert(len(clip_paths) == seq_len) 237 | new_dataset.append((clip_paths, pid, camid, clothes_id)) 238 | # end index 239 | vid2clip_index[idx, 1] = len(new_dataset) 240 | assert((vid2clip_index[idx, 1]-vid2clip_index[idx, 0]) == math.ceil(len(img_paths)/seq_len)) 241 | 242 | return new_dataset, vid2clip_index.tolist() 243 | 244 | -------------------------------------------------------------------------------- /data/datasets/deepchange.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class DeepChange(object): 15 | """ DeepChange 16 | 17 | Reference: 18 | Xu et al. DeepChange: A Long-Term Person Re-Identification Benchmark. arXiv:2105.14685, 2021. 
19 | 20 | URL: https://github.com/PengBoXiangShang/deepchange 21 | """ 22 | dataset_dir = 'DeepChangeDataset' 23 | def __init__(self, root='data', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'train-set') 26 | self.train_list = osp.join(self.dataset_dir, 'train-set-bbox.txt') 27 | self.val_query_dir = osp.join(self.dataset_dir, 'val-set-query') 28 | self.val_query_list = osp.join(self.dataset_dir, 'val-set-query-bbox.txt') 29 | self.val_gallery_dir = osp.join(self.dataset_dir, 'val-set-gallery') 30 | self.val_gallery_list = osp.join(self.dataset_dir, 'val-set-gallery-bbox.txt') 31 | self.test_query_dir = osp.join(self.dataset_dir, 'test-set-query') 32 | self.test_query_list = osp.join(self.dataset_dir, 'test-set-query-bbox.txt') 33 | self.test_gallery_dir = osp.join(self.dataset_dir, 'test-set-gallery') 34 | self.test_gallery_list = osp.join(self.dataset_dir, 'test-set-gallery-bbox.txt') 35 | self._check_before_run() 36 | 37 | train_names = self._get_names(self.train_list) 38 | val_query_names = self._get_names(self.val_query_list) 39 | val_gallery_names = self._get_names(self.val_gallery_list) 40 | test_query_names = self._get_names(self.test_query_list) 41 | test_gallery_names = self._get_names(self.test_gallery_list) 42 | 43 | pid2label, clothes2label, pid2clothes = self.get_pid2label_and_clothes2label(train_names) 44 | train, num_train_pids, num_train_clothes = self._process_dir(self.train_dir, train_names, clothes2label, pid2label=pid2label) 45 | 46 | pid2label, clothes2label = self.get_pid2label_and_clothes2label(val_query_names, val_gallery_names) 47 | val_query, num_val_query_pids, num_val_query_clothes = self._process_dir(self.val_query_dir, val_query_names, clothes2label) 48 | val_gallery, num_val_gallery_pids, num_val_gallery_clothes = self._process_dir(self.val_gallery_dir, val_gallery_names, clothes2label) 49 | num_val_pids = len(pid2label) 50 | num_val_clothes = len(clothes2label) 51 | 52 | pid2label, clothes2label = self.get_pid2label_and_clothes2label(test_query_names, test_gallery_names) 53 | test_query, num_test_query_pids, num_test_query_clothes = self._process_dir(self.test_query_dir, test_query_names, clothes2label) 54 | test_gallery, num_test_gallery_pids, num_test_gallery_clothes = self._process_dir(self.test_gallery_dir, test_gallery_names, clothes2label) 55 | num_test_pids = len(pid2label) 56 | num_test_clothes = len(clothes2label) 57 | 58 | num_total_pids = num_train_pids + num_val_pids + num_test_pids 59 | num_total_clothes = num_train_clothes + num_val_clothes + num_test_clothes 60 | num_total_imgs = len(train) + len(val_query) + len(val_gallery) + len(test_query) + len(test_gallery) 61 | 62 | logger = logging.getLogger('reid.dataset') 63 | logger.info("=> DeepChange loaded") 64 | logger.info("Dataset statistics:") 65 | logger.info(" --------------------------------------------") 66 | logger.info(" subset | # ids | # images | # clothes") 67 | logger.info(" ----------------------------------------") 68 | logger.info(" train | {:5d} | {:8d} | {:9d} ".format(num_train_pids, len(train), num_train_clothes)) 69 | logger.info(" query(val) | {:5d} | {:8d} | {:9d} ".format(num_val_query_pids, len(val_query), num_val_query_clothes)) 70 | logger.info(" gallery(val) | {:5d} | {:8d} | {:9d} ".format(num_val_gallery_pids, len(val_gallery), num_val_gallery_clothes)) 71 | logger.info(" query | {:5d} | {:8d} | {:9d} ".format(num_test_query_pids, len(test_query), num_test_query_clothes)) 72 | logger.info(" 
gallery | {:5d} | {:8d} | {:9d} ".format(num_test_gallery_pids, len(test_gallery), num_test_gallery_clothes)) 73 | logger.info(" --------------------------------------------") 74 | logger.info(" total | {:5d} | {:8d} | {:9d} ".format(num_total_pids, num_total_imgs, num_total_clothes)) 75 | logger.info(" --------------------------------------------") 76 | 77 | self.train = train 78 | self.val_query = val_query 79 | self.val_gallery = val_gallery 80 | self.query = test_query 81 | self.gallery = test_gallery 82 | 83 | self.num_train_pids = num_train_pids 84 | self.num_train_clothes = num_train_clothes 85 | self.pid2clothes = pid2clothes 86 | 87 | def _get_names(self, fpath): 88 | names = [] 89 | with open(fpath, 'r') as f: 90 | for line in f: 91 | new_line = line.rstrip() 92 | names.append(new_line) 93 | return names 94 | 95 | def get_pid2label_and_clothes2label(self, img_names1, img_names2=None): 96 | if img_names2 is not None: 97 | img_names = img_names1 + img_names2 98 | else: 99 | img_names = img_names1 100 | 101 | pid_container = set() 102 | clothes_container = set() 103 | for img_name in img_names: 104 | names = img_name.split('.')[0].split('_') 105 | clothes = names[0] + names[2] 106 | pid = int(names[0][1:]) 107 | pid_container.add(pid) 108 | clothes_container.add(clothes) 109 | pid_container = sorted(pid_container) 110 | clothes_container = sorted(clothes_container) 111 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 112 | clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)} 113 | 114 | if img_names2 is not None: 115 | return pid2label, clothes2label 116 | 117 | num_pids = len(pid_container) 118 | num_clothes = len(clothes_container) 119 | pid2clothes = np.zeros((num_pids, num_clothes)) 120 | for img_name in img_names: 121 | names = img_name.split('.')[0].split('_') 122 | clothes = names[0] + names[2] 123 | pid = int(names[0][1:]) 124 | pid = pid2label[pid] 125 | clothes_id = clothes2label[clothes] 126 | pid2clothes[pid, clothes_id] = 1 127 | 128 | return pid2label, clothes2label, pid2clothes 129 | 130 | def _check_before_run(self): 131 | """Check if all files are available before going deeper""" 132 | if not osp.exists(self.dataset_dir): 133 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 134 | if not osp.exists(self.train_dir): 135 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 136 | if not osp.exists(self.val_query_dir): 137 | raise RuntimeError("'{}' is not available".format(self.val_query_dir)) 138 | if not osp.exists(self.val_gallery_dir): 139 | raise RuntimeError("'{}' is not available".format(self.val_gallery_dir)) 140 | if not osp.exists(self.test_query_dir): 141 | raise RuntimeError("'{}' is not available".format(self.test_query_dir)) 142 | if not osp.exists(self.test_gallery_dir): 143 | raise RuntimeError("'{}' is not available".format(self.test_gallery_dir)) 144 | 145 | def _process_dir(self, home_dir, img_names, clothes2label, pid2label=None): 146 | dataset = [] 147 | pid_container = set() 148 | clothes_container = set() 149 | for img_name in img_names: 150 | img_path = osp.join(home_dir, img_name.split(',')[0]) 151 | names = img_name.split('.')[0].split('_') 152 | tracklet_id = int(img_name.split(',')[1]) 153 | clothes = names[0] + names[2] 154 | clothes_id = clothes2label[clothes] 155 | clothes_container.add(clothes_id) 156 | pid = int(names[0][1:]) 157 | pid_container.add(pid) 158 | camid = int(names[1][1:]) 159 | if pid2label is not None: 160 | pid = pid2label[pid] 161 | 
# on DeepChange, we allow the true matches coming from the same camera 162 | # but different tracklets as query following the original paper. 163 | # So we use tracklet_id to replace camid for each sample. 164 | dataset.append((img_path, pid, tracklet_id, clothes_id)) 165 | num_pids = len(pid_container) 166 | num_clothes = len(clothes_container) 167 | 168 | return dataset, num_pids, num_clothes -------------------------------------------------------------------------------- /data/datasets/last.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class LaST(object): 15 | """ LaST 16 | 17 | Reference: 18 | Shu et al. Large-Scale Spatio-Temporal Person Re-identification: Algorithm and Benchmark. arXiv:2105.15076, 2021. 19 | 20 | URL: https://github.com/shuxjweb/last 21 | 22 | Note that LaST does not provide the clothes label for val and test set. 23 | """ 24 | dataset_dir = "last" 25 | def __init__(self, root='data', **kwargs): 26 | super(LaST, self).__init__() 27 | self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'train') 29 | self.val_query_dir = osp.join(self.dataset_dir, 'val', 'query') 30 | self.val_gallery_dir = osp.join(self.dataset_dir, 'val', 'gallery') 31 | self.test_query_dir = osp.join(self.dataset_dir, 'test', 'query') 32 | self.test_gallery_dir = osp.join(self.dataset_dir, 'test', 'gallery') 33 | self._check_before_run() 34 | 35 | pid2label, clothes2label, pid2clothes = self.get_pid2label_and_clothes2label(self.train_dir) 36 | 37 | train, num_train_pids = self._process_dir(self.train_dir, pid2label=pid2label, clothes2label=clothes2label, relabel=True) 38 | val_query, num_val_query_pids = self._process_dir(self.val_query_dir, relabel=False) 39 | val_gallery, num_val_gallery_pids = self._process_dir(self.val_gallery_dir, relabel=False, recam=len(val_query)) 40 | test_query, num_test_query_pids = self._process_dir(self.test_query_dir, relabel=False) 41 | test_gallery, num_test_gallery_pids = self._process_dir(self.test_gallery_dir, relabel=False, recam=len(test_query)) 42 | 43 | num_total_pids = num_train_pids+num_val_gallery_pids+num_test_gallery_pids 44 | num_total_imgs = len(train) + len(val_query) + len(val_gallery) + len(test_query) + len(test_gallery) 45 | 46 | logger = logging.getLogger('reid.dataset') 47 | logger.info("=> LaST loaded") 48 | logger.info("Dataset statistics:") 49 | logger.info(" --------------------------------------------") 50 | logger.info(" subset | # ids | # images | # clothes") 51 | logger.info(" ----------------------------------------") 52 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, len(train), len(clothes2label))) 53 | logger.info(" query(val) | {:5d} | {:8d} |".format(num_val_query_pids, len(val_query))) 54 | logger.info(" gallery(val) | {:5d} | {:8d} |".format(num_val_gallery_pids, len(val_gallery))) 55 | logger.info(" query | {:5d} | {:8d} |".format(num_test_query_pids, len(test_query))) 56 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_gallery_pids, len(test_gallery))) 57 | logger.info(" --------------------------------------------") 58 | logger.info(" total | {:5d} | {:8d} | ".format(num_total_pids, num_total_imgs)) 59 | logger.info(" 
--------------------------------------------") 60 | 61 | self.train = train 62 | self.val_query = val_query 63 | self.val_gallery = val_gallery 64 | self.query = test_query 65 | self.gallery = test_gallery 66 | 67 | self.num_train_pids = num_train_pids 68 | self.num_train_clothes = len(clothes2label) 69 | self.pid2clothes = pid2clothes 70 | 71 | def get_pid2label_and_clothes2label(self, dir_path): 72 | img_paths = glob.glob(osp.join(dir_path, '*/*.jpg')) # [103367,] 73 | img_paths.sort() 74 | 75 | pid_container = set() 76 | clothes_container = set() 77 | for img_path in img_paths: 78 | names = osp.basename(img_path).split('.')[0].split('_') 79 | clothes = names[0] + '_' + names[-1] 80 | pid = int(names[0]) 81 | pid_container.add(pid) 82 | clothes_container.add(clothes) 83 | pid_container = sorted(pid_container) 84 | clothes_container = sorted(clothes_container) 85 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 86 | clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)} 87 | 88 | num_pids = len(pid_container) 89 | num_clothes = len(clothes_container) 90 | 91 | pid2clothes = np.zeros((num_pids, num_clothes)) 92 | for img_path in img_paths: 93 | names = osp.basename(img_path).split('.')[0].split('_') 94 | clothes = names[0] + '_' + names[-1] 95 | pid = int(names[0]) 96 | pid = pid2label[pid] 97 | clothes_id = clothes2label[clothes] 98 | pid2clothes[pid, clothes_id] = 1 99 | 100 | return pid2label, clothes2label, pid2clothes 101 | 102 | def _check_before_run(self): 103 | """Check if all files are available before going deeper""" 104 | if not osp.exists(self.dataset_dir): 105 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 106 | if not osp.exists(self.train_dir): 107 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 108 | if not osp.exists(self.val_query_dir): 109 | raise RuntimeError("'{}' is not available".format(self.val_query_dir)) 110 | if not osp.exists(self.val_gallery_dir): 111 | raise RuntimeError("'{}' is not available".format(self.val_gallery_dir)) 112 | if not osp.exists(self.test_query_dir): 113 | raise RuntimeError("'{}' is not available".format(self.test_query_dir)) 114 | if not osp.exists(self.test_gallery_dir): 115 | raise RuntimeError("'{}' is not available".format(self.test_gallery_dir)) 116 | 117 | def _process_dir(self, dir_path, pid2label=None, clothes2label=None, relabel=False, recam=0): 118 | if 'query' in dir_path: 119 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 120 | else: 121 | img_paths = glob.glob(osp.join(dir_path, '*/*.jpg')) 122 | img_paths.sort() 123 | 124 | dataset = [] 125 | pid_container = set() 126 | for ii, img_path in enumerate(img_paths): 127 | names = osp.basename(img_path).split('.')[0].split('_') 128 | clothes = names[0] + '_' + names[-1] 129 | pid = int(names[0]) 130 | pid_container.add(pid) 131 | camid = int(recam + ii) 132 | if relabel and pid2label is not None: 133 | pid = pid2label[pid] 134 | if relabel and clothes2label is not None: 135 | clothes_id = clothes2label[clothes] 136 | else: 137 | clothes_id = pid 138 | dataset.append((img_path, pid, camid, clothes_id)) 139 | num_pids = len(pid_container) 140 | 141 | return dataset, num_pids -------------------------------------------------------------------------------- /data/datasets/ltcc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy 
as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class LTCC(object): 15 | """ LTCC 16 | 17 | Reference: 18 | Qian et al. Long-Term Cloth-Changing Person Re-identification. arXiv:2005.12633, 2020. 19 | 20 | URL: https://naiq.github.io/LTCC_Perosn_ReID.html# 21 | """ 22 | dataset_dir = 'LTCC_ReID' 23 | def __init__(self, root='data', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'test') 28 | self._check_before_run() 29 | 30 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \ 31 | self._process_dir_train(self.train_dir) 32 | query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = \ 33 | self._process_dir_test(self.query_dir, self.gallery_dir) 34 | num_total_pids = num_train_pids + num_test_pids 35 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 36 | num_test_imgs = num_query_imgs + num_gallery_imgs 37 | num_total_clothes = num_train_clothes + num_test_clothes 38 | 39 | logger = logging.getLogger('reid.dataset') 40 | logger.info("=> LTCC loaded") 41 | logger.info("Dataset statistics:") 42 | logger.info(" ----------------------------------------") 43 | logger.info(" subset | # ids | # images | # clothes") 44 | logger.info(" ----------------------------------------") 45 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes)) 46 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes)) 47 | logger.info(" query | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs)) 48 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs)) 49 | logger.info(" ----------------------------------------") 50 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes)) 51 | logger.info(" ----------------------------------------") 52 | 53 | self.train = train 54 | self.query = query 55 | self.gallery = gallery 56 | 57 | self.num_train_pids = num_train_pids 58 | self.num_train_clothes = num_train_clothes 59 | self.pid2clothes = pid2clothes 60 | 61 | def _check_before_run(self): 62 | """Check if all files are available before going deeper""" 63 | if not osp.exists(self.dataset_dir): 64 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 65 | if not osp.exists(self.train_dir): 66 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 67 | if not osp.exists(self.query_dir): 68 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 69 | if not osp.exists(self.gallery_dir): 70 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 71 | 72 | def _process_dir_train(self, dir_path): 73 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 74 | img_paths.sort() 75 | pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)') 76 | pattern2 = re.compile(r'(\w+)_c') 77 | 78 | pid_container = set() 79 | clothes_container = set() 80 | for img_path in img_paths: 81 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 82 | clothes_id = pattern2.search(img_path).group(1) 83 | pid_container.add(pid) 84 | clothes_container.add(clothes_id) 85 | pid_container = sorted(pid_container) 86 | clothes_container = sorted(clothes_container) 87 | pid2label = {pid:label for 
label, pid in enumerate(pid_container)} 88 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 89 | 90 | num_pids = len(pid_container) 91 | num_clothes = len(clothes_container) 92 | 93 | dataset = [] 94 | pid2clothes = np.zeros((num_pids, num_clothes)) 95 | for img_path in img_paths: 96 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 97 | clothes = pattern2.search(img_path).group(1) 98 | camid -= 1 # index starts from 0 99 | pid = pid2label[pid] 100 | clothes_id = clothes2label[clothes] 101 | dataset.append((img_path, pid, camid, clothes_id)) 102 | pid2clothes[pid, clothes_id] = 1 103 | 104 | num_imgs = len(dataset) 105 | 106 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes 107 | 108 | def _process_dir_test(self, query_path, gallery_path): 109 | query_img_paths = glob.glob(osp.join(query_path, '*.png')) 110 | gallery_img_paths = glob.glob(osp.join(gallery_path, '*.png')) 111 | query_img_paths.sort() 112 | gallery_img_paths.sort() 113 | pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)') 114 | pattern2 = re.compile(r'(\w+)_c') 115 | 116 | pid_container = set() 117 | clothes_container = set() 118 | for img_path in query_img_paths: 119 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 120 | clothes_id = pattern2.search(img_path).group(1) 121 | pid_container.add(pid) 122 | clothes_container.add(clothes_id) 123 | for img_path in gallery_img_paths: 124 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 125 | clothes_id = pattern2.search(img_path).group(1) 126 | pid_container.add(pid) 127 | clothes_container.add(clothes_id) 128 | pid_container = sorted(pid_container) 129 | clothes_container = sorted(clothes_container) 130 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 131 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 132 | 133 | num_pids = len(pid_container) 134 | num_clothes = len(clothes_container) 135 | 136 | query_dataset = [] 137 | gallery_dataset = [] 138 | for img_path in query_img_paths: 139 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 140 | clothes_id = pattern2.search(img_path).group(1) 141 | camid -= 1 # index starts from 0 142 | clothes_id = clothes2label[clothes_id] 143 | query_dataset.append((img_path, pid, camid, clothes_id)) 144 | 145 | for img_path in gallery_img_paths: 146 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 147 | clothes_id = pattern2.search(img_path).group(1) 148 | camid -= 1 # index starts from 0 149 | clothes_id = clothes2label[clothes_id] 150 | gallery_dataset.append((img_path, pid, camid, clothes_id)) 151 | 152 | num_imgs_query = len(query_dataset) 153 | num_imgs_gallery = len(gallery_dataset) 154 | 155 | return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes 156 | 157 | -------------------------------------------------------------------------------- /data/datasets/prcc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class PRCC(object): 15 | """ PRCC 16 | 17 | Reference: 18 | Yang et al. Person Re-identification by Contour Sketch under Moderate Clothing Change. TPAMI, 2019. 
19 | 20 | URL: https://drive.google.com/file/d/1yTYawRm4ap3M-j0PjLQJ--xmZHseFDLz/view 21 | """ 22 | dataset_dir = 'prcc' 23 | def __init__(self, root='data', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'rgb/train') 26 | self.val_dir = osp.join(self.dataset_dir, 'rgb/val') 27 | self.test_dir = osp.join(self.dataset_dir, 'rgb/test') 28 | self._check_before_run() 29 | 30 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \ 31 | self._process_dir_train(self.train_dir) 32 | val, num_val_pids, num_val_imgs, num_val_clothes, _ = \ 33 | self._process_dir_train(self.val_dir) 34 | 35 | query_same, query_diff, gallery, num_test_pids, \ 36 | num_query_imgs_same, num_query_imgs_diff, num_gallery_imgs, \ 37 | num_test_clothes, gallery_idx = self._process_dir_test(self.test_dir) 38 | 39 | num_total_pids = num_train_pids + num_test_pids 40 | num_test_imgs = num_query_imgs_same + num_query_imgs_diff + num_gallery_imgs 41 | num_total_imgs = num_train_imgs + num_val_imgs + num_test_imgs 42 | num_total_clothes = num_train_clothes + num_test_clothes 43 | 44 | logger = logging.getLogger('reid.dataset') 45 | logger.info("=> PRCC loaded") 46 | logger.info("Dataset statistics:") 47 | logger.info(" --------------------------------------------") 48 | logger.info(" subset | # ids | # images | # clothes") 49 | logger.info(" --------------------------------------------") 50 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes)) 51 | logger.info(" val | {:5d} | {:8d} | {:9d}".format(num_val_pids, num_val_imgs, num_val_clothes)) 52 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes)) 53 | logger.info(" query(same) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_same)) 54 | logger.info(" query(diff) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_diff)) 55 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs)) 56 | logger.info(" --------------------------------------------") 57 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes)) 58 | logger.info(" --------------------------------------------") 59 | 60 | self.train = train 61 | self.val = val 62 | self.query_same = query_same 63 | self.query_diff = query_diff 64 | self.gallery = gallery 65 | 66 | self.num_train_pids = num_train_pids 67 | self.num_train_clothes = num_train_clothes 68 | self.pid2clothes = pid2clothes 69 | self.gallery_idx = gallery_idx 70 | 71 | def _check_before_run(self): 72 | """Check if all files are available before going deeper""" 73 | if not osp.exists(self.dataset_dir): 74 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 75 | if not osp.exists(self.train_dir): 76 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 77 | if not osp.exists(self.val_dir): 78 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 79 | if not osp.exists(self.test_dir): 80 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 81 | 82 | def _process_dir_train(self, dir_path): 83 | pdirs = glob.glob(osp.join(dir_path, '*')) 84 | pdirs.sort() 85 | 86 | pid_container = set() 87 | clothes_container = set() 88 | for pdir in pdirs: 89 | pid = int(osp.basename(pdir)) 90 | pid_container.add(pid) 91 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 92 | for img_dir in img_dirs: 93 | cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C' 
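# In PRCC, cameras 'A' and 'B' capture each person in the same outfit while camera 'C'
# captures a clothes change, so images from A/B share one clothes label per identity (the
# identity folder name) and images from C get a second label (folder name + 'C'), giving
# two clothes classes per identity, as the branch below encodes.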
94 | if cam in ['A', 'B']: 95 | clothes_container.add(osp.basename(pdir)) 96 | else: 97 | clothes_container.add(osp.basename(pdir)+osp.basename(img_dir)[0]) 98 | pid_container = sorted(pid_container) 99 | clothes_container = sorted(clothes_container) 100 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 101 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 102 | cam2label = {'A': 0, 'B': 1, 'C': 2} 103 | 104 | num_pids = len(pid_container) 105 | num_clothes = len(clothes_container) 106 | 107 | dataset = [] 108 | pid2clothes = np.zeros((num_pids, num_clothes)) 109 | for pdir in pdirs: 110 | pid = int(osp.basename(pdir)) 111 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 112 | for img_dir in img_dirs: 113 | cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C' 114 | label = pid2label[pid] 115 | camid = cam2label[cam] 116 | if cam in ['A', 'B']: 117 | clothes_id = clothes2label[osp.basename(pdir)] 118 | else: 119 | clothes_id = clothes2label[osp.basename(pdir)+osp.basename(img_dir)[0]] 120 | dataset.append((img_dir, label, camid, clothes_id)) 121 | pid2clothes[label, clothes_id] = 1 122 | 123 | num_imgs = len(dataset) 124 | 125 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes 126 | 127 | def _process_dir_test(self, test_path): 128 | pdirs = glob.glob(osp.join(test_path, '*')) 129 | pdirs.sort() 130 | 131 | pid_container = set() 132 | for pdir in glob.glob(osp.join(test_path, 'A', '*')): 133 | pid = int(osp.basename(pdir)) 134 | pid_container.add(pid) 135 | pid_container = sorted(pid_container) 136 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 137 | cam2label = {'A': 0, 'B': 1, 'C': 2} 138 | 139 | num_pids = len(pid_container) 140 | num_clothes = num_pids * 2 141 | 142 | query_dataset_same_clothes = [] 143 | query_dataset_diff_clothes = [] 144 | gallery_dataset = [] 145 | for cam in ['A', 'B', 'C']: 146 | pdirs = glob.glob(osp.join(test_path, cam, '*')) 147 | for pdir in pdirs: 148 | pid = int(osp.basename(pdir)) 149 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 150 | for img_dir in img_dirs: 151 | # pid = pid2label[pid] 152 | camid = cam2label[cam] 153 | if cam == 'A': 154 | clothes_id = pid2label[pid] * 2 155 | gallery_dataset.append((img_dir, pid, camid, clothes_id)) 156 | elif cam == 'B': 157 | clothes_id = pid2label[pid] * 2 158 | query_dataset_same_clothes.append((img_dir, pid, camid, clothes_id)) 159 | else: 160 | clothes_id = pid2label[pid] * 2 + 1 161 | query_dataset_diff_clothes.append((img_dir, pid, camid, clothes_id)) 162 | 163 | pid2imgidx = {} 164 | for idx, (img_dir, pid, camid, clothes_id) in enumerate(gallery_dataset): 165 | if pid not in pid2imgidx: 166 | pid2imgidx[pid] = [] 167 | pid2imgidx[pid].append(idx) 168 | 169 | # get 10 gallery index to perform single-shot test 170 | gallery_idx = {} 171 | random.seed(3) 172 | for idx in range(0, 10): 173 | gallery_idx[idx] = [] 174 | for pid in pid2imgidx: 175 | gallery_idx[idx].append(random.choice(pid2imgidx[pid])) 176 | 177 | num_imgs_query_same = len(query_dataset_same_clothes) 178 | num_imgs_query_diff = len(query_dataset_diff_clothes) 179 | num_imgs_gallery = len(gallery_dataset) 180 | 181 | return query_dataset_same_clothes, query_dataset_diff_clothes, gallery_dataset, \ 182 | num_pids, num_imgs_query_same, num_imgs_query_diff, num_imgs_gallery, \ 183 | num_clothes, gallery_idx 184 | -------------------------------------------------------------------------------- /data/datasets/vcclothes.py: 
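# A minimal usage sketch for the loaders defined in this file; the import path and data
# root below are assumptions based on the repository layout. The `mode` argument restricts
# evaluation to specific camera pairs, and the VCClothesSameClothes / VCClothesClothesChanging
# wrappers at the bottom of the file fix mode='sc' (cameras 2&3, same clothes) and
# mode='cc' (cameras 3&4, clothes change) respectively.
from data.datasets.vcclothes import VCClothes, VCClothesClothesChanging

dataset = VCClothes(root='data', mode='all')         # all cameras
cc_eval = VCClothesClothesChanging(root='data')      # clothes-changing protocol (cameras 3 & 4)
img_path, pid, camid, clothes_id = cc_eval.query[0]  # every sample is a (path, id, cam, clothes) tuple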
-------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class VCClothes(object): 15 | """ VC-Clothes 16 | 17 | Reference: 18 | Wang et al. When Person Re-identification Meets Changing Clothes. In CVPR Workshop, 2020. 19 | 20 | URL: https://wanfb.github.io/dataset.html 21 | """ 22 | dataset_dir = 'VC-Clothes' 23 | def __init__(self, root='data', mode='all', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 28 | # 'all' for all cameras; 'sc' for cam2&3; 'cc' for cam3&4 29 | self.mode = mode 30 | self._check_before_run() 31 | 32 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = self._process_dir_train() 33 | query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = self._process_dir_test() 34 | num_total_pids = num_train_pids + num_test_pids 35 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 36 | num_test_imgs = num_query_imgs + num_gallery_imgs 37 | num_total_clothes = num_train_clothes + num_test_clothes 38 | 39 | logger = logging.getLogger('reid.dataset') 40 | logger.info("=> VC-Clothes loaded") 41 | logger.info("Dataset statistics:") 42 | logger.info(" ----------------------------------------") 43 | logger.info(" subset | # ids | # images | # clothes") 44 | logger.info(" ----------------------------------------") 45 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes)) 46 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes)) 47 | logger.info(" query | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs)) 48 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs)) 49 | logger.info(" ----------------------------------------") 50 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes)) 51 | logger.info(" ----------------------------------------") 52 | 53 | self.train = train 54 | self.query = query 55 | self.gallery = gallery 56 | 57 | self.num_train_pids = num_train_pids 58 | self.num_train_clothes = num_train_clothes 59 | self.pid2clothes = pid2clothes 60 | 61 | def _check_before_run(self): 62 | """Check if all files are available before going deeper""" 63 | if not osp.exists(self.dataset_dir): 64 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 65 | if not osp.exists(self.train_dir): 66 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 67 | if not osp.exists(self.query_dir): 68 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 69 | if not osp.exists(self.gallery_dir): 70 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 71 | 72 | def _process_dir_train(self): 73 | img_paths = glob.glob(osp.join(self.train_dir, '*.jpg')) 74 | img_paths.sort() 75 | pattern = re.compile(r'(\d+)-(\d+)-(\d+)-(\d+)') 76 | 77 | pid_container = set() 78 | clothes_container = set() 79 | for img_path in img_paths: 80 | pid, camid, clothes, _ = pattern.search(img_path).groups() 81 | clothes_id = pid 
+ clothes 82 | pid, camid = int(pid), int(camid) 83 | pid_container.add(pid) 84 | clothes_container.add(clothes_id) 85 | pid_container = sorted(pid_container) 86 | clothes_container = sorted(clothes_container) 87 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 88 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 89 | 90 | num_pids = len(pid_container) 91 | num_clothes = len(clothes_container) 92 | 93 | dataset = [] 94 | pid2clothes = np.zeros((num_pids, num_clothes)) 95 | for img_path in img_paths: 96 | pid, camid, clothes, _ = pattern.search(img_path).groups() 97 | clothes_id = pid + clothes 98 | pid, camid = int(pid), int(camid) 99 | camid -= 1 # index starts from 0 100 | pid = pid2label[pid] 101 | clothes_id = clothes2label[clothes_id] 102 | dataset.append((img_path, pid, camid, clothes_id)) 103 | pid2clothes[pid, clothes_id] = 1 104 | 105 | num_imgs = len(dataset) 106 | 107 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes 108 | 109 | def _process_dir_test(self): 110 | query_img_paths = glob.glob(osp.join(self.query_dir, '*.jpg')) 111 | gallery_img_paths = glob.glob(osp.join(self.gallery_dir, '*.jpg')) 112 | query_img_paths.sort() 113 | gallery_img_paths.sort() 114 | pattern = re.compile(r'(\d+)-(\d+)-(\d+)-(\d+)') 115 | 116 | pid_container = set() 117 | clothes_container = set() 118 | for img_path in query_img_paths: 119 | pid, camid, clothes, _ = pattern.search(img_path).groups() 120 | clothes_id = pid + clothes 121 | pid, camid = int(pid), int(camid) 122 | if self.mode == 'sc' and camid not in [2, 3]: 123 | continue 124 | if self.mode == 'cc' and camid not in [3, 4]: 125 | continue 126 | pid_container.add(pid) 127 | clothes_container.add(clothes_id) 128 | for img_path in gallery_img_paths: 129 | pid, camid, clothes, _ = pattern.search(img_path).groups() 130 | clothes_id = pid + clothes 131 | pid, camid = int(pid), int(camid) 132 | if self.mode == 'sc' and camid not in [2, 3]: 133 | continue 134 | if self.mode == 'cc' and camid not in [3, 4]: 135 | continue 136 | pid_container.add(pid) 137 | clothes_container.add(clothes_id) 138 | pid_container = sorted(pid_container) 139 | clothes_container = sorted(clothes_container) 140 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 141 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 142 | 143 | num_pids = len(pid_container) 144 | num_clothes = len(clothes_container) 145 | 146 | query_dataset = [] 147 | gallery_dataset = [] 148 | for img_path in query_img_paths: 149 | pid, camid, clothes, _ = pattern.search(img_path).groups() 150 | clothes_id = pid + clothes 151 | pid, camid = int(pid), int(camid) 152 | if self.mode == 'sc' and camid not in [2, 3]: 153 | continue 154 | if self.mode == 'cc' and camid not in [3, 4]: 155 | continue 156 | camid -= 1 # index starts from 0 157 | clothes_id = clothes2label[clothes_id] 158 | query_dataset.append((img_path, pid, camid, clothes_id)) 159 | 160 | for img_path in gallery_img_paths: 161 | pid, camid, clothes, _ = pattern.search(img_path).groups() 162 | clothes_id = pid + clothes 163 | pid, camid = int(pid), int(camid) 164 | if self.mode == 'sc' and camid not in [2, 3]: 165 | continue 166 | if self.mode == 'cc' and camid not in [3, 4]: 167 | continue 168 | camid -= 1 # index starts from 0 169 | clothes_id = clothes2label[clothes_id] 170 | gallery_dataset.append((img_path, pid, camid, clothes_id)) 171 | 172 | num_imgs_query = len(query_dataset) 173 | num_imgs_gallery = 
len(gallery_dataset) 174 | 175 | return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes 176 | 177 | 178 | def VCClothesSameClothes(root='data', **kwargs): 179 | return VCClothes(root=root, mode='sc') 180 | 181 | 182 | def VCClothesClothesChanging(root='data', **kwargs): 183 | return VCClothes(root=root, mode='cc') 184 | -------------------------------------------------------------------------------- /data/img_transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import * 2 | from PIL import Image 3 | import random 4 | import math 5 | 6 | 7 | class ResizeWithEqualScale(object): 8 | """ 9 | Resize an image with equal scale as the original image. 10 | 11 | Args: 12 | height (int): resized height. 13 | width (int): resized width. 14 | interpolation: interpolation manner. 15 | fill_color (tuple): color for padding. 16 | """ 17 | def __init__(self, height, width, interpolation=Image.BILINEAR, fill_color=(0,0,0)): 18 | self.height = height 19 | self.width = width 20 | self.interpolation = interpolation 21 | self.fill_color = fill_color 22 | 23 | def __call__(self, img): 24 | width, height = img.size 25 | if self.height / self.width >= height / width: 26 | height = int(self.width * (height / width)) 27 | width = self.width 28 | else: 29 | width = int(self.height * (width / height)) 30 | height = self.height 31 | 32 | resized_img = img.resize((width, height), self.interpolation) 33 | new_img = Image.new('RGB', (self.width, self.height), self.fill_color) 34 | new_img.paste(resized_img, (int((self.width - width) / 2), int((self.height - height) / 2))) 35 | 36 | return new_img 37 | 38 | 39 | class RandomCroping(object): 40 | """ 41 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 42 | 43 | Args: 44 | p (float): probability of performing this transformation. Default: 0.5. 45 | """ 46 | def __init__(self, p=0.5, interpolation=Image.BILINEAR): 47 | self.p = p 48 | self.interpolation = interpolation 49 | 50 | def __call__(self, img): 51 | """ 52 | Args: 53 | img (PIL Image): Image to be cropped. 54 | 55 | Returns: 56 | PIL Image: Cropped image. 57 | """ 58 | width, height = img.size 59 | if random.uniform(0, 1) >= self.p: 60 | return img 61 | 62 | new_width, new_height = int(round(width * 1.125)), int(round(height * 1.125)) 63 | resized_img = img.resize((new_width, new_height), self.interpolation) 64 | x_maxrange = new_width - width 65 | y_maxrange = new_height - height 66 | x1 = int(round(random.uniform(0, x_maxrange))) 67 | y1 = int(round(random.uniform(0, y_maxrange))) 68 | croped_img = resized_img.crop((x1, y1, x1 + width, y1 + height)) 69 | 70 | return croped_img 71 | 72 | 73 | class RandomErasing(object): 74 | """ 75 | Randomly selects a rectangle region in an image and erases its pixels. 76 | 77 | Reference: 78 | Zhong et al. Random Erasing Data Augmentation. arxiv: 1708.04896, 2017. 79 | 80 | Args: 81 | probability: The probability that the Random Erasing operation will be performed. 82 | sl: Minimum proportion of erased area against input image. 83 | sh: Maximum proportion of erased area against input image. 84 | r1: Minimum aspect ratio of erased area. 85 | mean: Erasing value. 
86 | """ 87 | 88 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 89 | self.probability = probability 90 | self.mean = mean 91 | self.sl = sl 92 | self.sh = sh 93 | self.r1 = r1 94 | 95 | def __call__(self, img): 96 | 97 | if random.uniform(0, 1) >= self.probability: 98 | return img 99 | 100 | for attempt in range(100): 101 | area = img.size()[1] * img.size()[2] 102 | 103 | target_area = random.uniform(self.sl, self.sh) * area 104 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 105 | 106 | h = int(round(math.sqrt(target_area * aspect_ratio))) 107 | w = int(round(math.sqrt(target_area / aspect_ratio))) 108 | 109 | if w < img.size()[2] and h < img.size()[1]: 110 | x1 = random.randint(0, img.size()[1] - h) 111 | y1 = random.randint(0, img.size()[2] - w) 112 | if img.size()[0] == 3: 113 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 114 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 115 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 116 | else: 117 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 118 | return img 119 | 120 | return img -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import random 4 | import numpy as np 5 | from torch import distributed as dist 6 | from collections import defaultdict 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | class RandomIdentitySampler(Sampler): 11 | """ 12 | Randomly sample N identities, then for each identity, 13 | randomly sample K instances, therefore batch size is N*K. 14 | 15 | Args: 16 | data_source (Dataset): dataset to sample from. 17 | num_instances (int): number of instances per identity. 18 | """ 19 | def __init__(self, data_source, num_instances=4): 20 | self.data_source = data_source 21 | self.num_instances = num_instances 22 | self.index_dic = defaultdict(list) 23 | for index, (_, pid, _, _) in enumerate(data_source): 24 | self.index_dic[pid].append(index) 25 | self.pids = list(self.index_dic.keys()) 26 | self.num_identities = len(self.pids) 27 | 28 | # compute number of examples in an epoch 29 | self.length = 0 30 | for pid in self.pids: 31 | idxs = self.index_dic[pid] 32 | num = len(idxs) 33 | if num < self.num_instances: 34 | num = self.num_instances 35 | self.length += num - num % self.num_instances 36 | 37 | def __iter__(self): 38 | list_container = [] 39 | 40 | for pid in self.pids: 41 | idxs = copy.deepcopy(self.index_dic[pid]) 42 | if len(idxs) < self.num_instances: 43 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 44 | random.shuffle(idxs) 45 | batch_idxs = [] 46 | for idx in idxs: 47 | batch_idxs.append(idx) 48 | if len(batch_idxs) == self.num_instances: 49 | list_container.append(batch_idxs) 50 | batch_idxs = [] 51 | 52 | random.shuffle(list_container) 53 | 54 | ret = [] 55 | for batch_idxs in list_container: 56 | ret.extend(batch_idxs) 57 | 58 | return iter(ret) 59 | 60 | def __len__(self): 61 | return self.length 62 | 63 | 64 | class DistributedRandomIdentitySampler(Sampler): 65 | """ 66 | Randomly sample N identities, then for each identity, 67 | randomly sample K instances, therefore batch size is N*K. 68 | 69 | Args: 70 | - data_source (Dataset): dataset to sample from. 71 | - num_instances (int): number of instances per identity. 72 | - num_replicas (int, optional): Number of processes participating in 73 | distributed training. 
By default, :attr:`world_size` is retrieved from the 74 | current distributed group. 75 | - rank (int, optional): Rank of the current process within :attr:`num_replicas`. 76 | By default, :attr:`rank` is retrieved from the current distributed group. 77 | - seed (int, optional): random seed used to shuffle the sampler. 78 | This number should be identical across all 79 | processes in the distributed group. Default: ``0``. 80 | """ 81 | def __init__(self, data_source, num_instances=4, 82 | num_replicas=None, rank=None, seed=0): 83 | if num_replicas is None: 84 | if not dist.is_available(): 85 | raise RuntimeError("Requires distributed package to be available") 86 | num_replicas = dist.get_world_size() 87 | if rank is None: 88 | if not dist.is_available(): 89 | raise RuntimeError("Requires distributed package to be available") 90 | rank = dist.get_rank() 91 | if rank >= num_replicas or rank < 0: 92 | raise ValueError( 93 | "Invalid rank {}, rank should be in the interval" 94 | " [0, {}]".format(rank, num_replicas - 1)) 95 | self.num_replicas = num_replicas 96 | self.rank = rank 97 | self.seed = seed 98 | self.epoch = 0 99 | 100 | self.data_source = data_source 101 | self.num_instances = num_instances 102 | self.index_dic = defaultdict(list) 103 | for index, (_, pid, _, _) in enumerate(data_source): 104 | self.index_dic[pid].append(index) 105 | self.pids = list(self.index_dic.keys()) 106 | self.num_identities = len(self.pids) 107 | 108 | # compute number of examples in an epoch 109 | self.length = 0 110 | for pid in self.pids: 111 | idxs = self.index_dic[pid] 112 | num = len(idxs) 113 | if num < self.num_instances: 114 | num = self.num_instances 115 | self.length += num - num % self.num_instances 116 | assert self.length % self.num_instances == 0 117 | 118 | if self.length // self.num_instances % self.num_replicas != 0: 119 | self.num_samples = math.ceil((self.length // self.num_instances - self.num_replicas) / self.num_replicas) * self.num_instances 120 | else: 121 | self.num_samples = math.ceil(self.length / self.num_replicas) 122 | self.total_size = self.num_samples * self.num_replicas 123 | 124 | def __iter__(self): 125 | # deterministically shuffle based on epoch and seed 126 | random.seed(self.seed + self.epoch) 127 | np.random.seed(self.seed + self.epoch) 128 | 129 | list_container = [] 130 | for pid in self.pids: 131 | idxs = copy.deepcopy(self.index_dic[pid]) 132 | if len(idxs) < self.num_instances: 133 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 134 | random.shuffle(idxs) 135 | batch_idxs = [] 136 | for idx in idxs: 137 | batch_idxs.append(idx) 138 | if len(batch_idxs) == self.num_instances: 139 | list_container.append(batch_idxs) 140 | batch_idxs = [] 141 | random.shuffle(list_container) 142 | 143 | # remove tail of data to make it evenly divisible. 144 | list_container = list_container[:self.total_size//self.num_instances] 145 | assert len(list_container) == self.total_size//self.num_instances 146 | 147 | # subsample 148 | list_container = list_container[self.rank:self.total_size//self.num_instances:self.num_replicas] 149 | assert len(list_container) == self.num_samples//self.num_instances 150 | 151 | ret = [] 152 | for batch_idxs in list_container: 153 | ret.extend(batch_idxs) 154 | 155 | return iter(ret) 156 | 157 | def __len__(self): 158 | return self.num_samples 159 | 160 | def set_epoch(self, epoch): 161 | """ 162 | Sets the epoch for this sampler. This ensures all replicas 163 | use a different random ordering for each epoch. 
Otherwise, the next iteration of this 164 | sampler will yield the same ordering. 165 | 166 | Args: 167 | epoch (int): Epoch number. 168 | """ 169 | self.epoch = epoch 170 | 171 | 172 | class DistributedInferenceSampler(Sampler): 173 | """ 174 | refer to: https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py 175 | 176 | Distributed Sampler that subsamples indicies sequentially, 177 | making it easier to collate all results at the end. 178 | Even though we only use this sampler for eval and predict (no training), 179 | which means that the model params won't have to be synced (i.e. will not hang 180 | for synchronization even if varied number of forward passes), we still add extra 181 | samples to the sampler to make it evenly divisible (like in `DistributedSampler`) 182 | to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. 183 | """ 184 | def __init__(self, dataset, rank=None, num_replicas=None): 185 | if num_replicas is None: 186 | if not dist.is_available(): 187 | raise RuntimeError("Requires distributed package to be available") 188 | num_replicas = dist.get_world_size() 189 | if rank is None: 190 | if not dist.is_available(): 191 | raise RuntimeError("Requires distributed package to be available") 192 | rank = dist.get_rank() 193 | self.dataset = dataset 194 | self.num_replicas = num_replicas 195 | self.rank = rank 196 | 197 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 198 | self.total_size = self.num_samples * self.num_replicas 199 | 200 | def __iter__(self): 201 | indices = list(range(len(self.dataset))) 202 | # add extra samples to make it evenly divisible 203 | indices += [indices[-1]] * (self.total_size - len(indices)) 204 | # subsample 205 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] 206 | return iter(indices) 207 | 208 | def __len__(self): 209 | return self.num_samples -------------------------------------------------------------------------------- /data/spatial_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numbers 4 | import collections 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as T 8 | from PIL import Image, ImageOps 9 | try: 10 | import accimage 11 | except ImportError: 12 | accimage = None 13 | 14 | 15 | class Compose(object): 16 | """Composes several transforms together. 17 | 18 | Args: 19 | transforms (list of ``Transform`` objects): list of transforms to compose. 20 | 21 | Example: 22 | >>> transforms.Compose([ 23 | >>> transforms.CenterCrop(10), 24 | >>> transforms.ToTensor(), 25 | >>> ]) 26 | """ 27 | 28 | def __init__(self, transforms): 29 | self.transforms = transforms 30 | 31 | def __call__(self, img): 32 | for t in self.transforms: 33 | img = t(img) 34 | return img 35 | 36 | def randomize_parameters(self): 37 | for t in self.transforms: 38 | t.randomize_parameters() 39 | 40 | 41 | class ToTensor(object): 42 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. 43 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 44 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 45 | """ 46 | 47 | def __init__(self, norm_value=255): 48 | self.norm_value = norm_value 49 | 50 | def __call__(self, pic): 51 | """ 52 | Args: 53 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. 
54 | Returns: 55 | Tensor: Converted image. 56 | """ 57 | if isinstance(pic, np.ndarray): 58 | # handle numpy array 59 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 60 | # backward compatibility 61 | return img.float().div(self.norm_value) 62 | 63 | if accimage is not None and isinstance(pic, accimage.Image): 64 | nppic = np.zeros( 65 | [pic.channels, pic.height, pic.width], dtype=np.float32) 66 | pic.copyto(nppic) 67 | return torch.from_numpy(nppic) 68 | 69 | # handle PIL Image 70 | if pic.mode == 'I': 71 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 72 | elif pic.mode == 'I;16': 73 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 74 | else: 75 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 76 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 77 | if pic.mode == 'YCbCr': 78 | nchannel = 3 79 | elif pic.mode == 'I;16': 80 | nchannel = 1 81 | else: 82 | nchannel = len(pic.mode) 83 | img = img.view(pic.size[1], pic.size[0], nchannel) 84 | # put it from HWC to CHW format 85 | # yikes, this transpose takes 80% of the loading time/CPU 86 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 87 | if isinstance(img, torch.ByteTensor): 88 | return img.float().div(self.norm_value) 89 | else: 90 | return img 91 | 92 | def randomize_parameters(self): 93 | pass 94 | 95 | 96 | class Normalize(object): 97 | """Normalize an tensor image with mean and standard deviation. 98 | Given mean: (R, G, B) and std: (R, G, B), 99 | will normalize each channel of the torch.*Tensor, i.e. 100 | channel = (channel - mean) / std 101 | 102 | Args: 103 | mean (sequence): Sequence of means for R, G, B channels respecitvely. 104 | std (sequence): Sequence of standard deviations for R, G, B channels 105 | respecitvely. 106 | """ 107 | 108 | def __init__(self, mean, std): 109 | self.mean = mean 110 | self.std = std 111 | 112 | def __call__(self, tensor): 113 | """ 114 | Args: 115 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 116 | Returns: 117 | Tensor: Normalized image. 118 | """ 119 | # TODO: make efficient 120 | for t, m, s in zip(tensor, self.mean, self.std): 121 | t.sub_(m).div_(s) 122 | return tensor 123 | 124 | def randomize_parameters(self): 125 | pass 126 | 127 | 128 | class Scale(object): 129 | """Rescale the input PIL.Image to the given size. 130 | 131 | Args: 132 | size (sequence or int): Desired output size. If size is a sequence like 133 | (w, h), output size will be matched to this. If size is an int, 134 | smaller edge of the image will be matched to this number. 135 | i.e, if height > width, then image will be rescaled to 136 | (size * height / width, size) 137 | interpolation (int, optional): Desired interpolation. Default is 138 | ``PIL.Image.BILINEAR`` 139 | """ 140 | 141 | def __init__(self, size, interpolation=Image.BILINEAR): 142 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 143 | self.size = size 144 | self.interpolation = interpolation 145 | 146 | def __call__(self, img): 147 | """ 148 | Args: 149 | img (PIL.Image): Image to be scaled. 150 | Returns: 151 | PIL.Image: Rescaled image. 
152 | """ 153 | if isinstance(self.size, int): 154 | w, h = img.size 155 | if (w <= h and w == self.size) or (h <= w and h == self.size): 156 | return img 157 | if w < h: 158 | ow = self.size 159 | oh = int(self.size * h / w) 160 | return img.resize((ow, oh), self.interpolation) 161 | else: 162 | oh = self.size 163 | ow = int(self.size * w / h) 164 | return img.resize((ow, oh), self.interpolation) 165 | else: 166 | return img.resize(self.size[::-1], self.interpolation) 167 | 168 | def randomize_parameters(self): 169 | pass 170 | 171 | 172 | class RandomHorizontalFlip(object): 173 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 174 | 175 | def __call__(self, img): 176 | """ 177 | Args: 178 | img (PIL.Image): Image to be flipped. 179 | Returns: 180 | PIL.Image: Randomly flipped image. 181 | """ 182 | if self.p < 0.5: 183 | return img.transpose(Image.FLIP_LEFT_RIGHT) 184 | return img 185 | 186 | def randomize_parameters(self): 187 | self.p = random.random() 188 | 189 | 190 | class RandomCrop(object): 191 | """ 192 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 193 | 194 | Args: 195 | height (int): target height. 196 | width (int): target width. 197 | p (float): probability of performing this transformation. Default: 0.5. 198 | """ 199 | def __init__(self, size, p=0.5, interpolation=Image.BILINEAR): 200 | if isinstance(size, numbers.Number): 201 | self.size = (int(size), int(size)) 202 | else: 203 | self.size = size 204 | 205 | self.height, self.width = self.size 206 | self.p = p 207 | self.interpolation = interpolation 208 | 209 | def __call__(self, img): 210 | """ 211 | Args: 212 | img (PIL Image): Image to be cropped. 213 | 214 | Returns: 215 | PIL Image: Cropped image. 216 | """ 217 | if not self.cropping: 218 | return img.resize((self.width, self.height), self.interpolation) 219 | 220 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 221 | resized_img = img.resize((new_width, new_height), self.interpolation) 222 | x_maxrange = new_width - self.width 223 | y_maxrange = new_height - self.height 224 | x1 = int(round(self.tl_x * x_maxrange)) 225 | y1 = int(round(self.tl_y * y_maxrange)) 226 | return resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 227 | 228 | def randomize_parameters(self): 229 | self.cropping = random.uniform(0, 1) < self.p 230 | self.tl_x = random.random() 231 | self.tl_y = random.random() 232 | 233 | 234 | class RandomErasing(object): 235 | """ 236 | Randomly selects a rectangle region in an image and erases its pixels. 237 | 238 | Reference: 239 | Zhong et al. Random Erasing Data Augmentation. arxiv: 1708.04896, 2017. 240 | 241 | Args: 242 | probability: The probability that the Random Erasing operation will be performed. 243 | sl: Minimum proportion of erased area against input image. 244 | sh: Maximum proportion of erased area against input image. 245 | r1: Minimum aspect ratio of erased area. 246 | mean: Erasing value. 
247 | """ 248 | 249 | def __init__(self, height=256, width=128, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.485, 0.456, 0.406]): 250 | self.probability = probability 251 | self.mean = mean 252 | self.sl = sl 253 | self.sh = sh 254 | self.r1 = r1 255 | self.height = height 256 | self.width = width 257 | 258 | def __call__(self, img): 259 | if not self.re: 260 | return img 261 | 262 | if img.size()[0] == 3: 263 | img[0, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[0] 264 | img[1, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[1] 265 | img[2, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[2] 266 | else: 267 | img[0, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[0] 268 | return img 269 | 270 | def randomize_parameters(self): 271 | self.re = random.uniform(0, 1) < self.probability 272 | self.h, self.w, self.x1, self.y1 = 0, 0, 0, 0 273 | whether_re = False 274 | if self.re: 275 | for attempt in range(100): 276 | area = self.height*self.width 277 | 278 | target_area = random.uniform(self.sl, self.sh) * area 279 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 280 | 281 | self.h = int(round(math.sqrt(target_area * aspect_ratio))) 282 | self.w = int(round(math.sqrt(target_area / aspect_ratio))) 283 | if self.w < self.width and self.h < self.height: 284 | self.x1 = random.randint(0, self.height - self.h) 285 | self.y1 = random.randint(0, self.width - self.w) 286 | whether_re = True 287 | break 288 | 289 | self.re = whether_re -------------------------------------------------------------------------------- /data/temporal_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | class TemporalRandomCrop(object): 6 | """Temporally crop the given frame indices at a random location. 7 | 8 | If the number of frames is less than the size, 9 | loop the indices as many times as necessary to satisfy the size. 10 | 11 | Args: 12 | size (int): Desired output size of the crop. 13 | stride (int): Temporal sampling stride 14 | """ 15 | 16 | def __init__(self, size=4, stride=8): 17 | self.size = size 18 | self.stride = stride 19 | 20 | def __call__(self, frame_indices): 21 | """ 22 | Args: 23 | frame_indices (list): frame indices to be cropped. 24 | Returns: 25 | list: Cropped frame indices. 26 | """ 27 | frame_indices = list(frame_indices) 28 | 29 | if len(frame_indices) >= self.size * self.stride: 30 | rand_end = len(frame_indices) - (self.size - 1) * self.stride - 1 31 | begin_index = random.randint(0, rand_end) 32 | end_index = begin_index + (self.size - 1) * self.stride + 1 33 | out = frame_indices[begin_index:end_index:self.stride] 34 | elif len(frame_indices) >= self.size: 35 | clips = [] 36 | for i in range(self.size): 37 | clips.append(frame_indices[len(frame_indices)//self.size*i : len(frame_indices)//self.size*(i+1)]) 38 | out = [] 39 | for i in range(self.size): 40 | out.append(random.choice(clips[i])) 41 | else: 42 | index = np.random.choice(len(frame_indices), size=self.size, replace=True) 43 | index.sort() 44 | out = [frame_indices[index[i]] for i in range(self.size)] 45 | 46 | return out 47 | 48 | 49 | class TemporalBeginCrop(object): 50 | """Temporally crop the given frame indices at the beginning. 51 | 52 | If the number of frames is less than the size, 53 | loop the indices as many times as necessary to satisfy the size. 54 | 55 | Args: 56 | size (int): Desired output size of the crop.
57 | stride (int): Temporal sampling stride 58 | """ 59 | 60 | def __init__(self, size=8, stride=4): 61 | self.size = size 62 | self.stride = stride 63 | 64 | def __call__(self, frame_indices): 65 | frame_indices = list(frame_indices) 66 | 67 | if len(frame_indices) >= self.size * self.stride: 68 | out = frame_indices[0 : self.size * self.stride : self.stride] 69 | else: 70 | out = frame_indices[0 : self.size] 71 | while len(out) < self.size: 72 | for index in out: 73 | if len(out) >= self.size: 74 | break 75 | out.append(index) 76 | 77 | return out 78 | 79 | 80 | class TemporalDivisionCrop(object): 81 | """Temporally crop the given frame indices by TSN. 82 | 83 | Args: 84 | size (int): Desired output size of the crop. 85 | """ 86 | def __init__(self, size=4): 87 | self.size = size 88 | 89 | def __call__(self, frame_indices): 90 | """ 91 | Args: 92 | frame_indices (list): frame indices to be cropped. 93 | Returns: 94 | list: Cropped frame indices. 95 | """ 96 | frame_indices = list(frame_indices) 97 | 98 | if len(frame_indices) >= self.size: 99 | clips = [] 100 | for i in range(self.size): 101 | clips.append(frame_indices[len(frame_indices)//self.size*i : len(frame_indices)//self.size*(i+1)]) 102 | out = [] 103 | for i in range(self.size): 104 | out.append(random.choice(clips[i])) 105 | else: 106 | index = np.random.choice(len(frame_indices), size=self.size, replace=True) 107 | index.sort() 108 | out = [frame_indices[index[i]] for i in range(self.size)] 109 | 110 | return out 111 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from losses.cross_entropy_loss_with_label_smooth import CrossEntropyWithLabelSmooth 3 | from losses.triplet_loss import TripletLoss 4 | from losses.contrastive_loss import ContrastiveLoss 5 | from losses.arcface_loss import ArcFaceLoss 6 | from losses.cosface_loss import CosFaceLoss, PairwiseCosFaceLoss 7 | from losses.circle_loss import CircleLoss, PairwiseCircleLoss 8 | from losses.clothes_based_adversarial_loss import ClothesBasedAdversarialLoss, ClothesBasedAdversarialLossWithMemoryBank 9 | 10 | 11 | def build_losses(config, num_train_clothes): 12 | # Build identity classification loss 13 | if config.LOSS.CLA_LOSS == 'crossentropy': 14 | criterion_cla = nn.CrossEntropyLoss() 15 | elif config.LOSS.CLA_LOSS == 'crossentropylabelsmooth': 16 | criterion_cla = CrossEntropyWithLabelSmooth() 17 | elif config.LOSS.CLA_LOSS == 'arcface': 18 | criterion_cla = ArcFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 19 | elif config.LOSS.CLA_LOSS == 'cosface': 20 | criterion_cla = CosFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 21 | elif config.LOSS.CLA_LOSS == 'circle': 22 | criterion_cla = CircleLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 23 | else: 24 | raise KeyError("Invalid classification loss: '{}'".format(config.LOSS.CLA_LOSS)) 25 | 26 | # Build pairwise loss 27 | if config.LOSS.PAIR_LOSS == 'triplet': 28 | criterion_pair = TripletLoss(margin=config.LOSS.PAIR_M) 29 | elif config.LOSS.PAIR_LOSS == 'contrastive': 30 | criterion_pair = ContrastiveLoss(scale=config.LOSS.PAIR_S) 31 | elif config.LOSS.PAIR_LOSS == 'cosface': 32 | criterion_pair = PairwiseCosFaceLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M) 33 | elif config.LOSS.PAIR_LOSS == 'circle': 34 | criterion_pair = PairwiseCircleLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M) 35 | else: 36 | raise 
KeyError("Invalid pairwise loss: '{}'".format(config.LOSS.PAIR_LOSS)) 37 | 38 | # Build clothes classification loss 39 | if config.LOSS.CLOTHES_CLA_LOSS == 'crossentropy': 40 | criterion_clothes = nn.CrossEntropyLoss() 41 | elif config.LOSS.CLOTHES_CLA_LOSS == 'cosface': 42 | criterion_clothes = CosFaceLoss(scale=config.LOSS.CLA_S, margin=0) 43 | else: 44 | raise KeyError("Invalid clothes classification loss: '{}'".format(config.LOSS.CLOTHES_CLA_LOSS)) 45 | 46 | # Build clothes-based adversarial loss 47 | if config.LOSS.CAL == 'cal': 48 | criterion_cal = ClothesBasedAdversarialLoss(scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON) 49 | elif config.LOSS.CAL == 'calwithmemory': 50 | criterion_cal = ClothesBasedAdversarialLossWithMemoryBank(num_clothes=num_train_clothes, feat_dim=config.MODEL.FEATURE_DIM, 51 | momentum=config.LOSS.MOMENTUM, scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON) 52 | else: 53 | raise KeyError("Invalid clothing classification loss: '{}'".format(config.LOSS.CAL)) 54 | 55 | return criterion_cla, criterion_pair, criterion_clothes, criterion_cal 56 | -------------------------------------------------------------------------------- /losses/arcface_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class ArcFaceLoss(nn.Module): 8 | """ ArcFace loss. 9 | 10 | Reference: 11 | Deng et al. ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019. 12 | 13 | Args: 14 | scale (float): scaling factor. 15 | margin (float): pre-defined margin. 16 | """ 17 | def __init__(self, scale=16, margin=0.1): 18 | super().__init__() 19 | self.s = scale 20 | self.m = margin 21 | 22 | def forward(self, inputs, targets): 23 | """ 24 | Args: 25 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 26 | targets: ground truth labels with shape (batch_size) 27 | """ 28 | # get a one-hot index 29 | index = inputs.data * 0.0 30 | index.scatter_(1, targets.data.view(-1, 1), 1) 31 | index = index.bool() 32 | 33 | cos_m = math.cos(self.m) 34 | sin_m = math.sin(self.m) 35 | cos_t = inputs[index] 36 | sin_t = torch.sqrt(1.0 - cos_t * cos_t) 37 | cos_t_add_m = cos_t * cos_m - sin_t * sin_m 38 | 39 | cond_v = cos_t - math.cos(math.pi - self.m) 40 | cond = F.relu(cond_v) 41 | keep = cos_t - math.sin(math.pi - self.m) * self.m 42 | 43 | cos_t_add_m = torch.where(cond.bool(), cos_t_add_m, keep) 44 | 45 | output = inputs * 1.0 46 | output[index] = cos_t_add_m 47 | output = self.s * output 48 | 49 | return F.cross_entropy(output, targets) 50 | -------------------------------------------------------------------------------- /losses/circle_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch import distributed as dist 5 | from losses.gather import GatherLayer 6 | 7 | 8 | class CircleLoss(nn.Module): 9 | """ Circle Loss based on the predictions of classifier. 10 | 11 | Reference: 12 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020. 13 | 14 | Args: 15 | scale (float): scaling factor. 16 | margin (float): pre-defined margin. 
17 | """ 18 | def __init__(self, scale=96, margin=0.3, **kwargs): 19 | super().__init__() 20 | self.s = scale 21 | self.m = margin 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (batch_size) 28 | """ 29 | mask = torch.zeros_like(inputs).cuda() 30 | mask.scatter_(1, targets.view(-1, 1), 1.0) 31 | 32 | pos_scale = self.s * F.relu(1 + self.m - inputs.detach()) 33 | neg_scale = self.s * F.relu(inputs.detach() + self.m) 34 | scale_matrix = pos_scale * mask + neg_scale * (1 - mask) 35 | 36 | scores = (inputs - (1 - self.m) * mask - self.m * (1 - mask)) * scale_matrix 37 | 38 | loss = F.cross_entropy(scores, targets) 39 | 40 | return loss 41 | 42 | 43 | class PairwiseCircleLoss(nn.Module): 44 | """ Circle Loss among sample pairs. 45 | 46 | Reference: 47 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020. 48 | 49 | Args: 50 | scale (float): scaling factor. 51 | margin (float): pre-defined margin. 52 | """ 53 | def __init__(self, scale=48, margin=0.35, **kwargs): 54 | super().__init__() 55 | self.s = scale 56 | self.m = margin 57 | 58 | def forward(self, inputs, targets): 59 | """ 60 | Args: 61 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 62 | targets: ground truth labels with shape (batch_size) 63 | """ 64 | # l2-normalize 65 | inputs = F.normalize(inputs, p=2, dim=1) 66 | 67 | # gather all samples from different GPUs as gallery to compute pairwise loss. 68 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 69 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 70 | m, n = targets.size(0), gallery_targets.size(0) 71 | 72 | # compute cosine similarity 73 | similarities = torch.matmul(inputs, gallery_inputs.t()) 74 | 75 | # get mask for pos/neg pairs 76 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1) 77 | mask = torch.eq(targets, gallery_targets.T).float().cuda() 78 | mask_self = torch.zeros_like(mask) 79 | rank = dist.get_rank() 80 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda() 81 | mask_pos = mask - mask_self 82 | mask_neg = 1 - mask 83 | 84 | pos_scale = self.s * F.relu(1 + self.m - similarities.detach()) 85 | neg_scale = self.s * F.relu(similarities.detach() + self.m) 86 | scale_matrix = pos_scale * mask_pos + neg_scale * mask_neg 87 | 88 | scores = (similarities - self.m) * mask_neg + (1 - self.m - similarities) * mask_pos 89 | scores = scores * scale_matrix 90 | 91 | neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1) 92 | pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1) 93 | 94 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean() 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /losses/clothes_based_adversarial_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from losses.gather import GatherLayer 5 | 6 | 7 | class ClothesBasedAdversarialLoss(nn.Module): 8 | """ Clothes-based Adversarial Loss. 9 | 10 | Reference: 11 | Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022. 12 | 13 | Args: 14 | scale (float): scaling factor. 15 | epsilon (float): a trade-off hyper-parameter. 
16 | """ 17 | def __init__(self, scale=16, epsilon=0.1): 18 | super().__init__() 19 | self.scale = scale 20 | self.epsilon = epsilon 21 | 22 | def forward(self, inputs, targets, positive_mask): 23 | """ 24 | Args: 25 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 26 | targets: ground truth labels with shape (batch_size) 27 | positive_mask: positive mask matrix with shape (batch_size, num_classes). The clothes classes with 28 | the same identity as the anchor sample are defined as positive clothes classes and their mask 29 | values are 1. The clothes classes with different identities from the anchor sample are defined 30 | as negative clothes classes and their mask values in positive_mask are 0. 31 | """ 32 | inputs = self.scale * inputs 33 | negtive_mask = 1 - positive_mask 34 | identity_mask = torch.zeros(inputs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda() 35 | 36 | exp_logits = torch.exp(inputs) 37 | log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negtive_mask).sum(1, keepdim=True) + exp_logits) 38 | log_prob = inputs - log_sum_exp_pos_and_all_neg 39 | 40 | mask = (1 - self.epsilon) * identity_mask + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask 41 | loss = (- mask * log_prob).sum(1).mean() 42 | 43 | return loss 44 | 45 | 46 | class ClothesBasedAdversarialLossWithMemoryBank(nn.Module): 47 | """ Clothes-based Adversarial Loss between mini batch and the samples in memory. 48 | 49 | Reference: 50 | Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022. 51 | 52 | Args: 53 | num_clothes (int): the number of clothes classes. 54 | feat_dim (int): the dimensions of feature. 55 | momentum (float): momentum to update memory. 56 | scale (float): scaling factor. 57 | epsilon (float): a trade-off hyper-parameter. 58 | """ 59 | def __init__(self, num_clothes, feat_dim, momentum=0., scale=16, epsilon=0.1): 60 | super().__init__() 61 | self.num_clothes = num_clothes 62 | self.feat_dim = feat_dim 63 | self.momentum = momentum 64 | self.epsilon = epsilon 65 | self.scale = scale 66 | 67 | self.register_buffer('feature_memory', torch.zeros((num_clothes, feat_dim))) 68 | self.register_buffer('label_memory', torch.zeros(num_clothes, dtype=torch.int64) - 1) 69 | self.has_been_filled = False 70 | 71 | def forward(self, inputs, targets, positive_mask): 72 | """ 73 | Args: 74 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 75 | targets: ground truth labels with shape (batch_size) 76 | positive_mask: positive mask matrix with shape (batch_size, num_classes). 77 | """ 78 | # gather all samples from different GPUs to update memory. 
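# Note: the gathered features are detached before the _update_memory call below, so updating the memory bank does not backpropagate gradients through other processes' samples.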
79 | gathered_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 80 | gathered_targets = torch.cat(GatherLayer.apply(targets), dim=0) 81 | self._update_memory(gathered_inputs.detach(), gathered_targets) 82 | 83 | inputs_norm = F.normalize(inputs, p=2, dim=1) 84 | memory_norm = F.normalize(self.feature_memory.detach(), p=2, dim=1) 85 | similarities = torch.matmul(inputs_norm, memory_norm.t()) * self.scale 86 | 87 | negtive_mask = 1 - positive_mask 88 | mask_identity = torch.zeros(positive_mask.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda() 89 | 90 | if not self.has_been_filled: 91 | invalid_index = self.label_memory == -1 92 | positive_mask[:, invalid_index] = 0 93 | negtive_mask[:, invalid_index] = 0 94 | if sum(invalid_index.type(torch.int)) == 0: 95 | self.has_been_filled = True 96 | print('Memory bank is full') 97 | 98 | # compute log_prob 99 | exp_logits = torch.exp(similarities) 100 | log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negtive_mask).sum(1, keepdim=True) + exp_logits) 101 | log_prob = similarities - log_sum_exp_pos_and_all_neg 102 | 103 | # compute mean of log-likelihood over positive 104 | mask = (1 - self.epsilon) * mask_identity + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask 105 | loss = (- mask * log_prob).sum(1).mean() 106 | 107 | return loss 108 | 109 | def _update_memory(self, features, labels): 110 | label_to_feat = {} 111 | for x, y in zip(features, labels): 112 | if y not in label_to_feat: 113 | label_to_feat[y] = [x.unsqueeze(0)] 114 | else: 115 | label_to_feat[y].append(x.unsqueeze(0)) 116 | if not self.has_been_filled: 117 | for y in label_to_feat: 118 | feat = torch.mean(torch.cat(label_to_feat[y], dim=0), dim=0) 119 | self.feature_memory[y] = feat 120 | self.label_memory[y] = y 121 | else: 122 | for y in label_to_feat: 123 | feat = torch.mean(torch.cat(label_to_feat[y], dim=0), dim=0) 124 | self.feature_memory[y] = self.momentum * self.feature_memory[y] + (1. - self.momentum) * feat 125 | # self.embedding_memory[y] /= self.embedding_memory[y].norm() -------------------------------------------------------------------------------- /losses/contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch import distributed as dist 5 | from losses.gather import GatherLayer 6 | 7 | 8 | class ContrastiveLoss(nn.Module): 9 | """ Supervised Contrastive Learning Loss among sample pairs. 10 | 11 | Args: 12 | scale (float): scaling factor. 13 | """ 14 | def __init__(self, scale=16, **kwargs): 15 | super().__init__() 16 | self.s = scale 17 | 18 | def forward(self, inputs, targets): 19 | """ 20 | Args: 21 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 22 | targets: ground truth labels with shape (batch_size) 23 | """ 24 | # l2-normalize 25 | inputs = F.normalize(inputs, p=2, dim=1) 26 | 27 | # gather all samples from different GPUs as gallery to compute pairwise loss. 
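# (GatherLayer supports backpropagation to the local features; in a single-process run the gathered gallery is simply the local mini-batch.)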
28 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 29 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 30 | m, n = targets.size(0), gallery_targets.size(0) 31 | 32 | # compute cosine similarity 33 | similarities = torch.matmul(inputs, gallery_inputs.t()) * self.s 34 | 35 | # get mask for pos/neg pairs 36 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1) 37 | mask = torch.eq(targets, gallery_targets.T).float().cuda() 38 | mask_self = torch.zeros_like(mask) 39 | rank = dist.get_rank() 40 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda() 41 | mask_pos = mask - mask_self 42 | mask_neg = 1 - mask 43 | 44 | # compute log_prob 45 | exp_logits = torch.exp(similarities) * (1 - mask_self) 46 | # log_prob = similarities - torch.log(exp_logits.sum(1, keepdim=True)) 47 | log_sum_exp_pos_and_all_neg = torch.log((exp_logits * mask_neg).sum(1, keepdim=True) + exp_logits) 48 | log_prob = similarities - log_sum_exp_pos_and_all_neg 49 | 50 | # compute mean of log-likelihood over positive 51 | loss = (mask_pos * log_prob).sum(1) / mask_pos.sum(1) 52 | 53 | loss = - loss.mean() 54 | 55 | return loss -------------------------------------------------------------------------------- /losses/cosface_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch import distributed as dist 5 | from losses.gather import GatherLayer 6 | 7 | 8 | class CosFaceLoss(nn.Module): 9 | """ CosFace Loss based on the predictions of classifier. 10 | 11 | Reference: 12 | Wang et al. CosFace: Large Margin Cosine Loss for Deep Face Recognition. In CVPR, 2018. 13 | 14 | Args: 15 | scale (float): scaling factor. 16 | margin (float): pre-defined margin. 17 | """ 18 | def __init__(self, scale=16, margin=0.1, **kwargs): 19 | super().__init__() 20 | self.s = scale 21 | self.m = margin 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (batch_size) 28 | """ 29 | one_hot = torch.zeros_like(inputs) 30 | one_hot.scatter_(1, targets.view(-1, 1), 1.0) 31 | 32 | output = self.s * (inputs - one_hot * self.m) 33 | 34 | return F.cross_entropy(output, targets) 35 | 36 | 37 | class PairwiseCosFaceLoss(nn.Module): 38 | """ CosFace Loss among sample pairs. 39 | 40 | Reference: 41 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020. 42 | 43 | Args: 44 | scale (float): scaling factor. 45 | margin (float): pre-defined margin. 46 | """ 47 | def __init__(self, scale=16, margin=0): 48 | super().__init__() 49 | self.s = scale 50 | self.m = margin 51 | 52 | def forward(self, inputs, targets): 53 | """ 54 | Args: 55 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 56 | targets: ground truth labels with shape (batch_size) 57 | """ 58 | # l2-normalize 59 | inputs = F.normalize(inputs, p=2, dim=1) 60 | 61 | # gather all samples from different GPUs as gallery to compute pairwise loss. 
62 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 63 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 64 | m, n = targets.size(0), gallery_targets.size(0) 65 | 66 | # compute cosine similarity 67 | similarities = torch.matmul(inputs, gallery_inputs.t()) 68 | 69 | # get mask for pos/neg pairs 70 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1) 71 | mask = torch.eq(targets, gallery_targets.T).float().cuda() 72 | mask_self = torch.zeros_like(mask) 73 | rank = dist.get_rank() 74 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda() 75 | mask_pos = mask - mask_self 76 | mask_neg = 1 - mask 77 | 78 | scores = (similarities + self.m) * mask_neg - similarities * mask_pos 79 | scores = scores * self.s 80 | 81 | neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1) 82 | pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1) 83 | 84 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean() 85 | 86 | return loss -------------------------------------------------------------------------------- /losses/cross_entropy_loss_with_label_smooth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CrossEntropyWithLabelSmooth(nn.Module): 6 | """ Cross entropy loss with label smoothing regularization. 7 | 8 | Reference: 9 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. In CVPR, 2016. 10 | Equation: 11 | y = (1 - epsilon) * y + epsilon / K. 12 | 13 | Args: 14 | epsilon (float): a hyper-parameter in the above equation. 15 | """ 16 | def __init__(self, epsilon=0.1): 17 | super().__init__() 18 | self.epsilon = epsilon 19 | self.logsoftmax = nn.LogSoftmax(dim=1) 20 | 21 | def forward(self, inputs, targets): 22 | """ 23 | Args: 24 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 25 | targets: ground truth labels with shape (batch_size) 26 | """ 27 | _, num_classes = inputs.size() 28 | log_probs = self.logsoftmax(inputs) 29 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda() 30 | targets = (1 - self.epsilon) * targets + self.epsilon / num_classes 31 | loss = (- targets * log_probs).mean(0).sum() 32 | 33 | return loss 34 | -------------------------------------------------------------------------------- /losses/gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class GatherLayer(torch.autograd.Function): 6 | """Gather tensors from all process, supporting backward propagation.""" 7 | 8 | @staticmethod 9 | def forward(ctx, input): 10 | ctx.save_for_backward(input) 11 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 12 | dist.all_gather(output, input) 13 | 14 | return tuple(output) 15 | 16 | @staticmethod 17 | def backward(ctx, *grads): 18 | (input,) = ctx.saved_tensors 19 | grad_out = torch.zeros_like(input) 20 | 21 | # dist.reduce_scatter(grad_out, list(grads)) 22 | # grad_out.div_(dist.get_world_size()) 23 | 24 | grad_out[:] = grads[dist.get_rank()] 25 | 26 | return grad_out -------------------------------------------------------------------------------- /losses/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 
from losses.gather import GatherLayer 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | """ Triplet loss with hard example mining. 10 | 11 | Reference: 12 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 13 | 14 | Args: 15 | margin (float): pre-defined margin. 16 | 17 | Note that we use cosine similarity, rather than Euclidean distance in the original paper. 18 | """ 19 | def __init__(self, margin=0.3): 20 | super().__init__() 21 | self.m = margin 22 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 23 | 24 | def forward(self, inputs, targets): 25 | """ 26 | Args: 27 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 28 | targets: ground truth labels with shape (batch_size) 29 | """ 30 | # l2-normlize 31 | inputs = F.normalize(inputs, p=2, dim=1) 32 | 33 | # gather all samples from different GPUs as gallery to compute pairwise loss. 34 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 35 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 36 | 37 | # compute distance 38 | dist = 1 - torch.matmul(inputs, gallery_inputs.t()) # values in [0, 2] 39 | 40 | # get positive and negative masks 41 | targets, gallery_targets = targets.view(-1,1), gallery_targets.view(-1,1) 42 | mask_pos = torch.eq(targets, gallery_targets.T).float().cuda() 43 | mask_neg = 1 - mask_pos 44 | 45 | # For each anchor, find the hardest positive and negative pairs 46 | dist_ap, _ = torch.max((dist - mask_neg * 99999999.), dim=1) 47 | dist_an, _ = torch.min((dist + mask_pos * 99999999.), dim=1) 48 | 49 | # Compute ranking hinge loss 50 | y = torch.ones_like(dist_an) 51 | loss = self.ranking_loss(dist_an, dist_ap, y) 52 | 53 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import datetime 5 | import argparse 6 | import logging 7 | import os.path as osp 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.optim import lr_scheduler 14 | from torch import distributed as dist 15 | from apex import amp 16 | 17 | from configs.default_img import get_img_config 18 | from configs.default_vid import get_vid_config 19 | from data import build_dataloader 20 | from models import build_model 21 | from losses import build_losses 22 | from tools.utils import save_checkpoint, set_seed, get_logger 23 | from train import train_cal, train_cal_with_memory 24 | from test import test, test_prcc 25 | 26 | 27 | VID_DATASET = ['ccvid'] 28 | 29 | 30 | def parse_option(): 31 | parser = argparse.ArgumentParser(description='Train clothes-changing re-id model with clothes-based adversarial loss') 32 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file') 33 | # Datasets 34 | parser.add_argument('--root', type=str, help="your root path to data directory") 35 | parser.add_argument('--dataset', type=str, default='ltcc', help="ltcc, prcc, vcclothes, ccvid, last, deepchange") 36 | # Miscs 37 | parser.add_argument('--output', type=str, help="your output path to save model and logs") 38 | parser.add_argument('--resume', type=str, metavar='PATH') 39 | parser.add_argument('--amp', action='store_true', help="automatic mixed precision") 40 | parser.add_argument('--eval', action='store_true', help="evaluation only") 41 | parser.add_argument('--tag', type=str, help='tag 
for log file') 42 | parser.add_argument('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') 43 | 44 | args, unparsed = parser.parse_known_args() 45 | if args.dataset in VID_DATASET: 46 | config = get_vid_config(args) 47 | else: 48 | config = get_img_config(args) 49 | 50 | return config 51 | 52 | 53 | def main(config): 54 | # Build dataloader 55 | if config.DATA.DATASET == 'prcc': 56 | trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler = build_dataloader(config) 57 | else: 58 | trainloader, queryloader, galleryloader, dataset, train_sampler = build_dataloader(config) 59 | # Define a matrix pid2clothes with shape (num_pids, num_clothes). 60 | # pid2clothes[i, j] = 1 when j-th clothes belongs to i-th identity. Otherwise, pid2clothes[i, j] = 0. 61 | pid2clothes = torch.from_numpy(dataset.pid2clothes) 62 | 63 | # Build model 64 | model, classifier, clothes_classifier = build_model(config, dataset.num_train_pids, dataset.num_train_clothes) 65 | # Build identity classification loss, pairwise loss, clothes classification loss, and adversarial loss. 66 | criterion_cla, criterion_pair, criterion_clothes, criterion_adv = build_losses(config, dataset.num_train_clothes) 67 | # Build optimizer 68 | parameters = list(model.parameters()) + list(classifier.parameters()) 69 | if config.TRAIN.OPTIMIZER.NAME == 'adam': 70 | optimizer = optim.Adam(parameters, lr=config.TRAIN.OPTIMIZER.LR, 71 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 72 | optimizer_cc = optim.Adam(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, 73 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 74 | elif config.TRAIN.OPTIMIZER.NAME == 'adamw': 75 | optimizer = optim.AdamW(parameters, lr=config.TRAIN.OPTIMIZER.LR, 76 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 77 | optimizer_cc = optim.AdamW(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, 78 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 79 | elif config.TRAIN.OPTIMIZER.NAME == 'sgd': 80 | optimizer = optim.SGD(parameters, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, 81 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True) 82 | optimizer_cc = optim.SGD(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, 83 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True) 84 | else: 85 | raise KeyError("Unknown optimizer: {}".format(config.TRAIN.OPTIMIZER.NAME)) 86 | # Build lr_scheduler 87 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE, 88 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE) 89 | 90 | start_epoch = config.TRAIN.START_EPOCH 91 | if config.MODEL.RESUME: 92 | logger.info("Loading checkpoint from '{}'".format(config.MODEL.RESUME)) 93 | checkpoint = torch.load(config.MODEL.RESUME) 94 | model.load_state_dict(checkpoint['model_state_dict']) 95 | classifier.load_state_dict(checkpoint['classifier_state_dict']) 96 | if config.LOSS.CAL == 'calwithmemory': 97 | criterion_adv.load_state_dict(checkpoint['clothes_classifier_state_dict']) 98 | else: 99 | clothes_classifier.load_state_dict(checkpoint['clothes_classifier_state_dict']) 100 | start_epoch = checkpoint['epoch'] 101 | 102 | local_rank = dist.get_rank() 103 | model = model.cuda(local_rank) 104 | classifier = classifier.cuda(local_rank) 105 | if config.LOSS.CAL == 'calwithmemory': 106 | criterion_adv = criterion_adv.cuda(local_rank) 107 | else: 108 | clothes_classifier = clothes_classifier.cuda(local_rank) 109 |
torch.cuda.set_device(local_rank) 110 | 111 | if config.TRAIN.AMP: 112 | [model, classifier], optimizer = amp.initialize([model, classifier], optimizer, opt_level="O1") 113 | if config.LOSS.CAL != 'calwithmemory': 114 | clothes_classifier, optimizer_cc = amp.initialize(clothes_classifier, optimizer_cc, opt_level="O1") 115 | 116 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) 117 | classifier = nn.parallel.DistributedDataParallel(classifier, device_ids=[local_rank], output_device=local_rank) 118 | if config.LOSS.CAL != 'calwithmemory': 119 | clothes_classifier = nn.parallel.DistributedDataParallel(clothes_classifier, device_ids=[local_rank], output_device=local_rank) 120 | 121 | if config.EVAL_MODE: 122 | logger.info("Evaluate only") 123 | with torch.no_grad(): 124 | if config.DATA.DATASET == 'prcc': 125 | test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset) 126 | else: 127 | test(config, model, queryloader, galleryloader, dataset) 128 | return 129 | 130 | start_time = time.time() 131 | train_time = 0 132 | best_rank1 = -np.inf 133 | best_epoch = 0 134 | logger.info("==> Start training") 135 | for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH): 136 | train_sampler.set_epoch(epoch) 137 | start_train_time = time.time() 138 | if config.LOSS.CAL == 'calwithmemory': 139 | train_cal_with_memory(config, epoch, model, classifier, criterion_cla, criterion_pair, 140 | criterion_adv, optimizer, trainloader, pid2clothes) 141 | else: 142 | train_cal(config, epoch, model, classifier, clothes_classifier, criterion_cla, criterion_pair, 143 | criterion_clothes, criterion_adv, optimizer, optimizer_cc, trainloader, pid2clothes) 144 | train_time += round(time.time() - start_train_time) 145 | 146 | if (epoch+1) > config.TEST.START_EVAL and config.TEST.EVAL_STEP > 0 and \ 147 | (epoch+1) % config.TEST.EVAL_STEP == 0 or (epoch+1) == config.TRAIN.MAX_EPOCH: 148 | logger.info("==> Test") 149 | torch.cuda.empty_cache() 150 | if config.DATA.DATASET == 'prcc': 151 | rank1 = test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset) 152 | else: 153 | rank1 = test(config, model, queryloader, galleryloader, dataset) 154 | torch.cuda.empty_cache() 155 | is_best = rank1 > best_rank1 156 | if is_best: 157 | best_rank1 = rank1 158 | best_epoch = epoch + 1 159 | 160 | model_state_dict = model.module.state_dict() 161 | classifier_state_dict = classifier.module.state_dict() 162 | if config.LOSS.CAL == 'calwithmemory': 163 | clothes_classifier_state_dict = criterion_adv.state_dict() 164 | else: 165 | clothes_classifier_state_dict = clothes_classifier.module.state_dict() 166 | if local_rank == 0: 167 | save_checkpoint({ 168 | 'model_state_dict': model_state_dict, 169 | 'classifier_state_dict': classifier_state_dict, 170 | 'clothes_classifier_state_dict': clothes_classifier_state_dict, 171 | 'rank1': rank1, 172 | 'epoch': epoch, 173 | }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(epoch+1) + '.pth.tar')) 174 | scheduler.step() 175 | 176 | logger.info("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch)) 177 | 178 | elapsed = round(time.time() - start_time) 179 | elapsed = str(datetime.timedelta(seconds=elapsed)) 180 | train_time = str(datetime.timedelta(seconds=train_time)) 181 | logger.info("Finished. Total elapsed time (h:m:s): {}. 
Training time (h:m:s): {}.".format(elapsed, train_time)) 182 | 183 | 184 | if __name__ == '__main__': 185 | config = parse_option() 186 | # Set GPU 187 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 188 | # Init dist 189 | dist.init_process_group(backend="nccl", init_method='env://') 190 | local_rank = dist.get_rank() 191 | # Set random seed 192 | set_seed(config.SEED + local_rank) 193 | # get logger 194 | if not config.EVAL_MODE: 195 | output_file = osp.join(config.OUTPUT, 'log_train_.log') 196 | else: 197 | output_file = osp.join(config.OUTPUT, 'log_test.log') 198 | logger = get_logger(output_file, local_rank, 'reid') 199 | logger.info("Config:\n-----------------------------------------") 200 | logger.info(config) 201 | logger.info("-----------------------------------------") 202 | 203 | main(config) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from models.classifier import Classifier, NormalizedClassifier 3 | from models.img_resnet import ResNet50 4 | from models.vid_resnet import C2DResNet50, I3DResNet50, AP3DResNet50, NLResNet50, AP3DNLResNet50 5 | 6 | 7 | __factory = { 8 | 'resnet50': ResNet50, 9 | 'c2dres50': C2DResNet50, 10 | 'i3dres50': I3DResNet50, 11 | 'ap3dres50': AP3DResNet50, 12 | 'nlres50': NLResNet50, 13 | 'ap3dnlres50': AP3DNLResNet50, 14 | } 15 | 16 | 17 | def build_model(config, num_identities, num_clothes): 18 | logger = logging.getLogger('reid.model') 19 | # Build backbone 20 | logger.info("Initializing model: {}".format(config.MODEL.NAME)) 21 | if config.MODEL.NAME not in __factory.keys(): 22 | raise KeyError("Invalid model: '{}'".format(config.MODEL.NAME)) 23 | else: 24 | logger.info("Init model: '{}'".format(config.MODEL.NAME)) 25 | model = __factory[config.MODEL.NAME](config) 26 | logger.info("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0)) 27 | 28 | # Build classifier 29 | if config.LOSS.CLA_LOSS in ['crossentropy', 'crossentropylabelsmooth']: 30 | identity_classifier = Classifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities) 31 | else: 32 | identity_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities) 33 | 34 | clothes_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_clothes) 35 | 36 | return model, identity_classifier, clothes_classifier -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.nn import functional as F 5 | from torch.nn import Parameter 6 | 7 | 8 | __all__ = ['Classifier', 'NormalizedClassifier'] 9 | 10 | 11 | class Classifier(nn.Module): 12 | def __init__(self, feature_dim, num_classes): 13 | super().__init__() 14 | self.classifier = nn.Linear(feature_dim, num_classes) 15 | init.normal_(self.classifier.weight.data, std=0.001) 16 | init.constant_(self.classifier.bias.data, 0.0) 17 | 18 | def forward(self, x): 19 | y = self.classifier(x) 20 | 21 | return y 22 | 23 | 24 | class NormalizedClassifier(nn.Module): 25 | def __init__(self, feature_dim, num_classes): 26 | super().__init__() 27 | self.weight = Parameter(torch.Tensor(num_classes, feature_dim)) 28 | self.weight.data.uniform_(-1, 1).renorm_(2,0,1e-5).mul_(1e5) 29 | 30 | def 
forward(self, x): 31 | w = self.weight 32 | 33 | x = F.normalize(x, p=2, dim=1) 34 | w = F.normalize(w, p=2, dim=1) 35 | 36 | return F.linear(x, w) 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /models/img_resnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | from torch.nn import init 4 | from models.utils import pooling 5 | 6 | 7 | class ResNet50(nn.Module): 8 | def __init__(self, config, **kwargs): 9 | super().__init__() 10 | 11 | resnet50 = torchvision.models.resnet50(pretrained=True) 12 | if config.MODEL.RES4_STRIDE == 1: 13 | resnet50.layer4[0].conv2.stride=(1, 1) 14 | resnet50.layer4[0].downsample[0].stride=(1, 1) 15 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 16 | 17 | if config.MODEL.POOLING.NAME == 'avg': 18 | self.globalpooling = nn.AdaptiveAvgPool2d(1) 19 | elif config.MODEL.POOLING.NAME == 'max': 20 | self.globalpooling = nn.AdaptiveMaxPool2d(1) 21 | elif config.MODEL.POOLING.NAME == 'gem': 22 | self.globalpooling = pooling.GeMPooling(p=config.MODEL.POOLING.P) 23 | elif config.MODEL.POOLING.NAME == 'maxavg': 24 | self.globalpooling = pooling.MaxAvgPooling() 25 | else: 26 | raise KeyError("Invalid pooling: '{}'".format(config.MODEL.POOLING.NAME)) 27 | 28 | self.bn = nn.BatchNorm1d(config.MODEL.FEATURE_DIM) 29 | init.normal_(self.bn.weight.data, 1.0, 0.02) 30 | init.constant_(self.bn.bias.data, 0.0) 31 | 32 | def forward(self, x): 33 | x = self.base(x) 34 | x = self.globalpooling(x) 35 | x = x.view(x.size(0), -1) 36 | f = self.bn(x) 37 | 38 | return f -------------------------------------------------------------------------------- /models/utils/inflate.py: -------------------------------------------------------------------------------- 1 | # inflate 2D modules to 3D modules 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | 7 | def inflate_conv(conv2d, 8 | time_dim=1, 9 | time_padding=0, 10 | time_stride=1, 11 | time_dilation=1, 12 | center=False): 13 | # To preserve activations, padding should be by continuity and not zero, 14 | # or there should be no padding in the time dimension 15 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1]) 16 | padding = (time_padding, conv2d.padding[0], conv2d.padding[1]) 17 | stride = (time_stride, conv2d.stride[0], conv2d.stride[1]) 18 | dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1]) 19 | conv3d = nn.Conv3d( 20 | conv2d.in_channels, 21 | conv2d.out_channels, 22 | kernel_dim, 23 | padding=padding, 24 | dilation=dilation, 25 | stride=stride) 26 | # Repeat filter time_dim times along time dimension 27 | weight_2d = conv2d.weight.data 28 | if center: 29 | weight_3d = torch.zeros(*weight_2d.shape) 30 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 31 | middle_idx = time_dim // 2 32 | weight_3d[:, :, middle_idx, :, :] = weight_2d 33 | else: 34 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 35 | weight_3d = weight_3d / time_dim 36 | 37 | # Assign new params 38 | conv3d.weight = nn.Parameter(weight_3d) 39 | conv3d.bias = conv2d.bias 40 | return conv3d 41 | 42 | 43 | def inflate_linear(linear2d, time_dim): 44 | """ 45 | Args: 46 | time_dim: final time dimension of the features 47 | """ 48 | linear3d = nn.Linear(linear2d.in_features * time_dim, 49 | linear2d.out_features) 50 | weight3d = linear2d.weight.data.repeat(1, time_dim) 51 | weight3d = weight3d / time_dim 52 | 53 | linear3d.weight =
nn.Parameter(weight3d) 54 | linear3d.bias = linear2d.bias 55 | return linear3d 56 | 57 | 58 | def inflate_batch_norm(batch2d): 59 | # In pytorch 0.2.0 the 2d and 3d versions of batch norm 60 | # work identically except for the check that verifies the 61 | # input dimensions 62 | 63 | batch3d = nn.BatchNorm3d(batch2d.num_features) 64 | # retrieve 3d _check_input_dim function 65 | batch2d._check_input_dim = batch3d._check_input_dim 66 | return batch2d 67 | 68 | 69 | def inflate_pool(pool2d, 70 | time_dim=1, 71 | time_padding=0, 72 | time_stride=None, 73 | time_dilation=1): 74 | kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size) 75 | padding = (time_padding, pool2d.padding, pool2d.padding) 76 | if time_stride is None: 77 | time_stride = time_dim 78 | stride = (time_stride, pool2d.stride, pool2d.stride) 79 | if isinstance(pool2d, nn.MaxPool2d): 80 | dilation = (time_dilation, pool2d.dilation, pool2d.dilation) 81 | pool3d = nn.MaxPool3d( 82 | kernel_dim, 83 | padding=padding, 84 | dilation=dilation, 85 | stride=stride, 86 | ceil_mode=pool2d.ceil_mode) 87 | elif isinstance(pool2d, nn.AvgPool2d): 88 | pool3d = nn.AvgPool3d(kernel_dim, stride=stride) 89 | else: 90 | raise ValueError( 91 | '{} is not among known pooling classes'.format(type(pool2d))) 92 | return pool3d 93 | 94 | 95 | class MaxPool2dFor3dInput(nn.Module): 96 | """ 97 | Since nn.MaxPool3d is nondeterministic operation, using fixed random seeds can't get consistent results. 98 | So we attempt to use max_pool2d to implement MaxPool3d with kernelsize (1, kernel_size, kernel_size). 99 | """ 100 | def __init__(self, kernel_size, stride=None, padding=0, dilation=1): 101 | super().__init__() 102 | self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 103 | def forward(self, x): 104 | b, c, t, h, w = x.size() 105 | x = x.permute(0, 2, 1, 3, 4).contiguous() # b, t, c, h, w 106 | x = x.view(b*t, c, h, w) 107 | # max pooling 108 | x = self.maxpool(x) 109 | _, _, h, w = x.size() 110 | x = x.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous() 111 | 112 | return x -------------------------------------------------------------------------------- /models/utils/nonlocal_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from models.utils import inflate 6 | 7 | 8 | class NonLocalBlockND(nn.Module): 9 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 10 | super(NonLocalBlockND, self).__init__() 11 | 12 | assert dimension in [1, 2, 3] 13 | 14 | self.dimension = dimension 15 | self.sub_sample = sub_sample 16 | self.in_channels = in_channels 17 | self.inter_channels = inter_channels 18 | 19 | if self.inter_channels is None: 20 | self.inter_channels = in_channels // 2 21 | if self.inter_channels == 0: 22 | self.inter_channels = 1 23 | 24 | if dimension == 3: 25 | conv_nd = nn.Conv3d 26 | # max_pool = inflate.MaxPool2dFor3dInput 27 | max_pool = nn.MaxPool3d 28 | bn = nn.BatchNorm3d 29 | elif dimension == 2: 30 | conv_nd = nn.Conv2d 31 | max_pool = nn.MaxPool2d 32 | bn = nn.BatchNorm2d 33 | else: 34 | conv_nd = nn.Conv1d 35 | max_pool = nn.MaxPool1d 36 | bn = nn.BatchNorm1d 37 | 38 | self.g = conv_nd(self.in_channels, self.inter_channels, 39 | kernel_size=1, stride=1, padding=0, bias=True) 40 | self.theta = conv_nd(self.in_channels, self.inter_channels, 41 | kernel_size=1, stride=1, padding=0, 
bias=True) 42 | self.phi = conv_nd(self.in_channels, self.inter_channels, 43 | kernel_size=1, stride=1, padding=0, bias=True) 44 | # if sub_sample: 45 | # self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 46 | # self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 47 | if sub_sample: 48 | if dimension == 3: 49 | self.g = nn.Sequential(self.g, max_pool((1, 2, 2))) 50 | self.phi = nn.Sequential(self.phi, max_pool((1, 2, 2))) 51 | else: 52 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 53 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 54 | 55 | if bn_layer: 56 | self.W = nn.Sequential( 57 | conv_nd(self.inter_channels, self.in_channels, 58 | kernel_size=1, stride=1, padding=0, bias=True), 59 | bn(self.in_channels) 60 | ) 61 | else: 62 | self.W = conv_nd(self.inter_channels, self.in_channels, 63 | kernel_size=1, stride=1, padding=0, bias=True) 64 | 65 | # init 66 | for m in self.modules(): 67 | if isinstance(m, conv_nd): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. / n)) 70 | elif isinstance(m, bn): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | if bn_layer: 75 | nn.init.constant_(self.W[1].weight.data, 0.0) 76 | nn.init.constant_(self.W[1].bias.data, 0.0) 77 | else: 78 | nn.init.constant_(self.W.weight.data, 0.0) 79 | nn.init.constant_(self.W.bias.data, 0.0) 80 | 81 | 82 | def forward(self, x): 83 | ''' 84 | :param x: (b, c, t, h, w) 85 | :return: 86 | ''' 87 | batch_size = x.size(0) 88 | 89 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 90 | g_x = g_x.permute(0, 2, 1) 91 | 92 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 93 | theta_x = theta_x.permute(0, 2, 1) 94 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 95 | f = torch.matmul(theta_x, phi_x) 96 | f = F.softmax(f, dim=-1) 97 | 98 | y = torch.matmul(f, g_x) 99 | y = y.permute(0, 2, 1).contiguous() 100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 101 | y = self.W(y) 102 | z = y + x 103 | 104 | return z 105 | 106 | 107 | class NonLocalBlock1D(NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NonLocalBlock1D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=1, sub_sample=sub_sample, 112 | bn_layer=bn_layer) 113 | 114 | 115 | class NonLocalBlock2D(NonLocalBlockND): 116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 117 | super(NonLocalBlock2D, self).__init__(in_channels, 118 | inter_channels=inter_channels, 119 | dimension=2, sub_sample=sub_sample, 120 | bn_layer=bn_layer) 121 | 122 | 123 | class NonLocalBlock3D(NonLocalBlockND): 124 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 125 | super(NonLocalBlock3D, self).__init__(in_channels, 126 | inter_channels=inter_channels, 127 | dimension=3, sub_sample=sub_sample, 128 | bn_layer=bn_layer) 129 | -------------------------------------------------------------------------------- /models/utils/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class GeMPooling(nn.Module): 7 | def __init__(self, p=3, eps=1e-6): 8 | super().__init__() 9 | self.p = nn.Parameter(torch.ones(1) * p) 10 | self.eps = eps 11 | 12 | def forward(self, x): 13 | return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), 
x.size()[2:]).pow(1./self.p) 14 | 15 | 16 | class MaxAvgPooling(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.maxpooling = nn.AdaptiveMaxPool2d(1) 20 | self.avgpooling = nn.AdaptiveAvgPool2d(1) 21 | 22 | def forward(self, x): 23 | max_f = self.maxpooling(x) 24 | avg_f = self.avgpooling(x) 25 | 26 | return torch.cat((max_f, avg_f), 1) 27 | -------------------------------------------------------------------------------- /models/vid_resnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torch.nn import functional as F 5 | from models.utils import inflate 6 | from models.utils import c3d_blocks 7 | from models.utils import nonlocal_blocks 8 | 9 | 10 | __all__ = ['AP3DResNet50', 'AP3DNLResNet50', 'NLResNet50', 'C2DResNet50', 11 | 'I3DResNet50', 12 | ] 13 | 14 | 15 | class Bottleneck3D(nn.Module): 16 | def __init__(self, bottleneck2d, block, inflate_time=False, temperature=4, contrastive_att=True): 17 | super().__init__() 18 | self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1) 19 | self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1) 20 | if inflate_time == True: 21 | self.conv2 = block(bottleneck2d.conv2, temperature=temperature, contrastive_att=contrastive_att) 22 | else: 23 | self.conv2 = inflate.inflate_conv(bottleneck2d.conv2, time_dim=1) 24 | self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2) 25 | self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1) 26 | self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | if bottleneck2d.downsample is not None: 30 | self.downsample = self._inflate_downsample(bottleneck2d.downsample) 31 | else: 32 | self.downsample = None 33 | 34 | def _inflate_downsample(self, downsample2d, time_stride=1): 35 | downsample3d = nn.Sequential( 36 | inflate.inflate_conv(downsample2d[0], time_dim=1, 37 | time_stride=time_stride), 38 | inflate.inflate_batch_norm(downsample2d[1])) 39 | return downsample3d 40 | 41 | def forward(self, x): 42 | residual = x 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv3(out) 52 | out = self.bn3(out) 53 | 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | 57 | out += residual 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | 63 | class ResNet503D(nn.Module): 64 | def __init__(self, config, block, c3d_idx, nl_idx, **kwargs): 65 | super().__init__() 66 | self.block = block 67 | self.temperature = config.MODEL.AP3D.TEMPERATURE 68 | self.contrastive_att = config.MODEL.AP3D.CONTRACTIVE_ATT 69 | 70 | resnet2d = torchvision.models.resnet50(pretrained=True) 71 | if config.MODEL.RES4_STRIDE == 1: 72 | resnet2d.layer4[0].conv2.stride=(1, 1) 73 | resnet2d.layer4[0].downsample[0].stride=(1, 1) 74 | 75 | self.conv1 = inflate.inflate_conv(resnet2d.conv1, time_dim=1) 76 | self.bn1 = inflate.inflate_batch_norm(resnet2d.bn1) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.maxpool = inflate.inflate_pool(resnet2d.maxpool, time_dim=1) 79 | # self.maxpool = inflate.MaxPool2dFor3dInput(kernel_size=resnet2d.maxpool.kernel_size, 80 | # stride=resnet2d.maxpool.stride, 81 | # padding=resnet2d.maxpool.padding, 82 | # dilation=resnet2d.maxpool.dilation) 83 | 84 | self.layer1 = self._inflate_reslayer(resnet2d.layer1, c3d_idx=c3d_idx[0], \ 85 | nonlocal_idx=nl_idx[0], 
nonlocal_channels=256) 86 | self.layer2 = self._inflate_reslayer(resnet2d.layer2, c3d_idx=c3d_idx[1], \ 87 | nonlocal_idx=nl_idx[1], nonlocal_channels=512) 88 | self.layer3 = self._inflate_reslayer(resnet2d.layer3, c3d_idx=c3d_idx[2], \ 89 | nonlocal_idx=nl_idx[2], nonlocal_channels=1024) 90 | self.layer4 = self._inflate_reslayer(resnet2d.layer4, c3d_idx=c3d_idx[3], \ 91 | nonlocal_idx=nl_idx[3], nonlocal_channels=2048) 92 | 93 | self.bn = nn.BatchNorm1d(2048) 94 | init.normal_(self.bn.weight.data, 1.0, 0.02) 95 | init.constant_(self.bn.bias.data, 0.0) 96 | 97 | def _inflate_reslayer(self, reslayer2d, c3d_idx, nonlocal_idx=[], nonlocal_channels=0): 98 | reslayers3d = [] 99 | for i,layer2d in enumerate(reslayer2d): 100 | if i not in c3d_idx: 101 | layer3d = Bottleneck3D(layer2d, c3d_blocks.C2D, inflate_time=False) 102 | else: 103 | layer3d = Bottleneck3D(layer2d, self.block, inflate_time=True, \ 104 | temperature=self.temperature, contrastive_att=self.contrastive_att) 105 | reslayers3d.append(layer3d) 106 | 107 | if i in nonlocal_idx: 108 | non_local_block = nonlocal_blocks.NonLocalBlock3D(nonlocal_channels, sub_sample=True) 109 | reslayers3d.append(non_local_block) 110 | 111 | return nn.Sequential(*reslayers3d) 112 | 113 | def forward(self, x): 114 | x = self.conv1(x) 115 | x = self.bn1(x) 116 | x = self.relu(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | x = self.layer2(x) 121 | x = self.layer3(x) 122 | x = self.layer4(x) 123 | 124 | b, c, t, h, w = x.size() 125 | x = x.permute(0, 2, 1, 3, 4).contiguous() 126 | x = x.view(b*t, c, h, w) 127 | # spatial max pooling 128 | x = F.max_pool2d(x, x.size()[2:]) 129 | x = x.view(b, t, -1) 130 | # temporal avg pooling 131 | x = x.mean(1) 132 | f = self.bn(x) 133 | 134 | return f 135 | 136 | 137 | def C2DResNet50(config, **kwargs): 138 | c3d_idx = [[],[],[],[]] 139 | nl_idx = [[],[],[],[]] 140 | 141 | return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs) 142 | 143 | 144 | def AP3DResNet50(config, **kwargs): 145 | c3d_idx = [[],[0, 2],[0, 2, 4],[]] 146 | nl_idx = [[],[],[],[]] 147 | 148 | return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs) 149 | 150 | 151 | def I3DResNet50(config, **kwargs): 152 | c3d_idx = [[],[0, 2],[0, 2, 4],[]] 153 | nl_idx = [[],[],[],[]] 154 | 155 | return ResNet503D(config, c3d_blocks.I3D, c3d_idx, nl_idx, **kwargs) 156 | 157 | 158 | def AP3DNLResNet50(config, **kwargs): 159 | c3d_idx = [[],[0, 2],[0, 2, 4],[]] 160 | nl_idx = [[],[1, 3],[1, 3, 5],[]] 161 | 162 | return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs) 163 | 164 | 165 | def NLResNet50(config, **kwargs): 166 | c3d_idx = [[],[],[],[]] 167 | nl_idx = [[],[1, 3],[1, 3, 5],[]] 168 | 169 | return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs) 170 | -------------------------------------------------------------------------------- /script.sh: -------------------------------------------------------------------------------- 1 | # The code is built with DistributedDataParallel. 2 | # Reproducing the results in the paper requires training the model on 2 GPUs. 3 | # You can also train this model on a single GPU and double config.DATA.TRAIN_BATCH in the configs, as sketched below.
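#
# For reference, a single-GPU run is sketched below. It is not part of the original script and assumes
# you also double DATA.TRAIN_BATCH in the chosen yaml config, per the note above:
# python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --dataset ltcc --cfg configs/res50_cels_cal.yaml --gpu 0
#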
4 | # For LTCC dataset 5 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ltcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 # 6 | # For PRCC dataset 7 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset prcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 # 8 | # For VC-Clothes dataset. You should change the root path of '--resume' to your output path. 9 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes --cfg configs/res50_cels_cal.yaml --gpu 0,1 # 10 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes_cc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume /data/guxinqian/logs/vcclothes/res50-cels-cal/best_model.pth.tar # 11 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes_sc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume /data/guxinqian/logs/vcclothes/res50-cels-cal/best_model.pth.tar # 12 | # For DeepChange dataset. Using amp can accelerate training. 13 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset deepchange --cfg configs/res50_cels_cal_16x4.yaml --amp --gpu 0,1 # 14 | # For LaST dataset. Using amp can accelerate training. 15 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset last --cfg configs/res50_cels_cal_tri_16x4.yaml --amp --gpu 0,1 # 16 | # For CCVID dataset 17 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ccvid --cfg configs/c2dres50_ce_cal.yaml --gpu 0,1 # -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import logging 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import distributed as dist 8 | from tools.eval_metrics import evaluate, evaluate_with_clothes 9 | 10 | 11 | VID_DATASET = ['ccvid'] 12 | 13 | 14 | def concat_all_gather(tensors, num_total_examples): 15 | ''' 16 | Performs all_gather operation on the provided tensor list. 
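    Each tensor is moved to the current GPU, gathered from all ranks with
    torch.distributed.all_gather, and concatenated along dim 0; the result is then
    truncated to `num_total_examples` so that the dummy samples padded by
    DistributedInferenceSampler do not affect evaluation.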
17 | ''' 18 | outputs = [] 19 | for tensor in tensors: 20 | tensor = tensor.cuda() 21 | tensors_gather = [tensor.clone() for _ in range(dist.get_world_size())] 22 | dist.all_gather(tensors_gather, tensor) 23 | output = torch.cat(tensors_gather, dim=0).cpu() 24 | # truncate the dummy elements added by DistributedInferenceSampler 25 | outputs.append(output[:num_total_examples]) 26 | return outputs 27 | 28 | 29 | @torch.no_grad() 30 | def extract_img_feature(model, dataloader): 31 | features, pids, camids, clothes_ids = [], torch.tensor([]), torch.tensor([]), torch.tensor([]) 32 | for batch_idx, (imgs, batch_pids, batch_camids, batch_clothes_ids) in enumerate(dataloader): 33 | flip_imgs = torch.flip(imgs, [3]) 34 | imgs, flip_imgs = imgs.cuda(), flip_imgs.cuda() 35 | batch_features = model(imgs) 36 | batch_features_flip = model(flip_imgs) 37 | batch_features += batch_features_flip 38 | batch_features = F.normalize(batch_features, p=2, dim=1) 39 | 40 | features.append(batch_features.cpu()) 41 | pids = torch.cat((pids, batch_pids.cpu()), dim=0) 42 | camids = torch.cat((camids, batch_camids.cpu()), dim=0) 43 | clothes_ids = torch.cat((clothes_ids, batch_clothes_ids.cpu()), dim=0) 44 | features = torch.cat(features, 0) 45 | 46 | return features, pids, camids, clothes_ids 47 | 48 | 49 | @torch.no_grad() 50 | def extract_vid_feature(model, dataloader, vid2clip_index, data_length): 51 | # In build_dataloader, each original test video is split into a series of equal-length clips. 52 | # During test, we first extract features for all clips 53 | clip_features, clip_pids, clip_camids, clip_clothes_ids = [], torch.tensor([]), torch.tensor([]), torch.tensor([]) 54 | for batch_idx, (vids, batch_pids, batch_camids, batch_clothes_ids) in enumerate(dataloader): 55 | if (batch_idx + 1) % 200==0: 56 | logging.getLogger('reid.test').info("{}/{}".format(batch_idx+1, len(dataloader))) 57 | vids = vids.cuda() 58 | batch_features = model(vids) 59 | clip_features.append(batch_features.cpu()) 60 | clip_pids = torch.cat((clip_pids, batch_pids.cpu()), dim=0) 61 | clip_camids = torch.cat((clip_camids, batch_camids.cpu()), dim=0) 62 | clip_clothes_ids = torch.cat((clip_clothes_ids, batch_clothes_ids.cpu()), dim=0) 63 | clip_features = torch.cat(clip_features, 0) 64 | 65 | # Gather samples from different GPUs 66 | clip_features, clip_pids, clip_camids, clip_clothes_ids = \ 67 | concat_all_gather([clip_features, clip_pids, clip_camids, clip_clothes_ids], data_length) 68 | 69 | # Use the averaged feature of all clips split from a video as the representation of this original full-length video 70 | features = torch.zeros(len(vid2clip_index), clip_features.size(1)).cuda() 71 | clip_features = clip_features.cuda() 72 | pids = torch.zeros(len(vid2clip_index)) 73 | camids = torch.zeros(len(vid2clip_index)) 74 | clothes_ids = torch.zeros(len(vid2clip_index)) 75 | for i, idx in enumerate(vid2clip_index): 76 | features[i] = clip_features[idx[0] : idx[1], :].mean(0) 77 | features[i] = F.normalize(features[i], p=2, dim=0) 78 | pids[i] = clip_pids[idx[0]] 79 | camids[i] = clip_camids[idx[0]] 80 | clothes_ids[i] = clip_clothes_ids[idx[0]] 81 | features = features.cpu() 82 | 83 | return features, pids, camids, clothes_ids 84 | 85 | 86 | def test(config, model, queryloader, galleryloader, dataset): 87 | logger = logging.getLogger('reid.test') 88 | since = time.time() 89 | model.eval() 90 | local_rank = dist.get_rank() 91 | # Extract features 92 | if config.DATA.DATASET in VID_DATASET: 93 | qf, q_pids, q_camids, q_clothes_ids = extract_vid_feature(model, queryloader,
94 | dataset.query_vid2clip_index, 95 | len(dataset.recombined_query)) 96 | gf, g_pids, g_camids, g_clothes_ids = extract_vid_feature(model, galleryloader, 97 | dataset.gallery_vid2clip_index, 98 | len(dataset.recombined_gallery)) 99 | else: 100 | qf, q_pids, q_camids, q_clothes_ids = extract_img_feature(model, queryloader) 101 | gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader) 102 | # Gather samples from different GPUs 103 | torch.cuda.empty_cache() 104 | qf, q_pids, q_camids, q_clothes_ids = concat_all_gather([qf, q_pids, q_camids, q_clothes_ids], len(dataset.query)) 105 | gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery)) 106 | torch.cuda.empty_cache() 107 | time_elapsed = time.time() - since 108 | 109 | logger.info("Extracted features for query set, obtained {} matrix".format(qf.shape)) 110 | logger.info("Extracted features for gallery set, obtained {} matrix".format(gf.shape)) 111 | logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 112 | # Compute distance matrix between query and gallery 113 | since = time.time() 114 | m, n = qf.size(0), gf.size(0) 115 | distmat = torch.zeros((m,n)) 116 | qf, gf = qf.cuda(), gf.cuda() 117 | # Cosine similarity 118 | for i in range(m): 119 | distmat[i] = (- torch.mm(qf[i:i+1], gf.t())).cpu() 120 | distmat = distmat.numpy() 121 | q_pids, q_camids, q_clothes_ids = q_pids.numpy(), q_camids.numpy(), q_clothes_ids.numpy() 122 | g_pids, g_camids, g_clothes_ids = g_pids.numpy(), g_camids.numpy(), g_clothes_ids.numpy() 123 | time_elapsed = time.time() - since 124 | logger.info('Distance computing in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 125 | 126 | since = time.time() 127 | logger.info("Computing CMC and mAP") 128 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) 129 | logger.info("Results ---------------------------------------------------") 130 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP)) 131 | logger.info("-----------------------------------------------------------") 132 | time_elapsed = time.time() - since 133 | logger.info('Using {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 134 | 135 | if config.DATA.DATASET in ['last', 'deepchange', 'vcclothes_sc', 'vcclothes_cc']: return cmc[0] 136 | 137 | logger.info("Computing CMC and mAP only for the same clothes setting") 138 | cmc, mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothes_ids, g_clothes_ids, mode='SC') 139 | logger.info("Results ---------------------------------------------------") 140 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP)) 141 | logger.info("-----------------------------------------------------------") 142 | 143 | logger.info("Computing CMC and mAP only for clothes-changing") 144 | cmc, mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothes_ids, g_clothes_ids, mode='CC') 145 | logger.info("Results ---------------------------------------------------") 146 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP)) 147 | logger.info("-----------------------------------------------------------") 148 | 149 | return cmc[0] 150 | 151 | 152 | def test_prcc(model, queryloader_same, queryloader_diff, galleryloader, 
dataset): 153 | logger = logging.getLogger('reid.test') 154 | since = time.time() 155 | model.eval() 156 | local_rank = dist.get_rank() 157 | # Extract features for query set 158 | qsf, qs_pids, qs_camids, qs_clothes_ids = extract_img_feature(model, queryloader_same) 159 | qdf, qd_pids, qd_camids, qd_clothes_ids = extract_img_feature(model, queryloader_diff) 160 | # Extract features for gallery set 161 | gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader) 162 | # Gather samples from different GPUs 163 | torch.cuda.empty_cache() 164 | qsf, qs_pids, qs_camids, qs_clothes_ids = concat_all_gather([qsf, qs_pids, qs_camids, qs_clothes_ids], len(dataset.query_same)) 165 | qdf, qd_pids, qd_camids, qd_clothes_ids = concat_all_gather([qdf, qd_pids, qd_camids, qd_clothes_ids], len(dataset.query_diff)) 166 | gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery)) 167 | time_elapsed = time.time() - since 168 | 169 | logger.info("Extracted features for query set (with same clothes), obtained {} matrix".format(qsf.shape)) 170 | logger.info("Extracted features for query set (with different clothes), obtained {} matrix".format(qdf.shape)) 171 | logger.info("Extracted features for gallery set, obtained {} matrix".format(gf.shape)) 172 | logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 173 | # Compute distance matrix between query and gallery 174 | m, n, k = qsf.size(0), qdf.size(0), gf.size(0) 175 | distmat_same = torch.zeros((m, k)) 176 | distmat_diff = torch.zeros((n, k)) 177 | qsf, qdf, gf = qsf.cuda(), qdf.cuda(), gf.cuda() 178 | # Cosine similarity 179 | for i in range(m): 180 | distmat_same[i] = (- torch.mm(qsf[i:i+1], gf.t())).cpu() 181 | for i in range(n): 182 | distmat_diff[i] = (- torch.mm(qdf[i:i+1], gf.t())).cpu() 183 | distmat_same = distmat_same.numpy() 184 | distmat_diff = distmat_diff.numpy() 185 | qs_pids, qs_camids, qs_clothes_ids = qs_pids.numpy(), qs_camids.numpy(), qs_clothes_ids.numpy() 186 | qd_pids, qd_camids, qd_clothes_ids = qd_pids.numpy(), qd_camids.numpy(), qd_clothes_ids.numpy() 187 | g_pids, g_camids, g_clothes_ids = g_pids.numpy(), g_camids.numpy(), g_clothes_ids.numpy() 188 | 189 | logger.info("Computing CMC and mAP for the same clothes setting") 190 | cmc, mAP = evaluate(distmat_same, qs_pids, g_pids, qs_camids, g_camids) 191 | logger.info("Results ---------------------------------------------------") 192 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP)) 193 | logger.info("-----------------------------------------------------------") 194 | 195 | logger.info("Computing CMC and mAP only for clothes changing") 196 | cmc, mAP = evaluate(distmat_diff, qd_pids, g_pids, qd_camids, g_camids) 197 | logger.info("Results ---------------------------------------------------") 198 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP)) 199 | logger.info("-----------------------------------------------------------") 200 | 201 | return cmc[0] -------------------------------------------------------------------------------- /tools/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | 4 | 5 | def compute_ap_cmc(index, good_index, junk_index): 6 | """ Compute AP and CMC for each sample 7 | """ 8 | ap = 0 9 | cmc = 
np.zeros(len(index)) 10 | 11 | # remove junk_index 12 | mask = np.in1d(index, junk_index, invert=True) 13 | index = index[mask] 14 | 15 | # find good_index index 16 | ngood = len(good_index) 17 | mask = np.in1d(index, good_index) 18 | rows_good = np.argwhere(mask==True) 19 | rows_good = rows_good.flatten() 20 | 21 | cmc[rows_good[0]:] = 1.0 22 | for i in range(ngood): 23 | d_recall = 1.0/ngood 24 | precision = (i+1)*1.0/(rows_good[i]+1) 25 | ap = ap + d_recall*precision 26 | 27 | return ap, cmc 28 | 29 | 30 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids): 31 | """ Compute CMC and mAP 32 | 33 | Args: 34 | distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery). 35 | q_pids (numpy array): person IDs for query samples. 36 | g_pids (numpy array): person IDs for gallery samples. 37 | q_camids (numpy array): camera IDs for query samples. 38 | g_camids (numpy array): camera IDs for gallery samples. 39 | """ 40 | num_q, num_g = distmat.shape 41 | index = np.argsort(distmat, axis=1) # from small to large 42 | 43 | num_no_gt = 0 # num of query imgs without groundtruth 44 | num_r1 = 0 45 | CMC = np.zeros(len(g_pids)) 46 | AP = 0 47 | 48 | for i in range(num_q): 49 | # groundtruth index 50 | query_index = np.argwhere(g_pids==q_pids[i]) 51 | camera_index = np.argwhere(g_camids==q_camids[i]) 52 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 53 | if good_index.size == 0: 54 | num_no_gt += 1 55 | continue 56 | # remove gallery samples that have the same pid and camid with query 57 | junk_index = np.intersect1d(query_index, camera_index) 58 | 59 | ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index) 60 | if CMC_tmp[0]==1: 61 | num_r1 += 1 62 | CMC = CMC + CMC_tmp 63 | AP += ap_tmp 64 | 65 | if num_no_gt > 0: 66 | logger = logging.getLogger('reid.evaluate') 67 | logger.info("{} query samples do not have groundtruth.".format(num_no_gt)) 68 | 69 | CMC = CMC / (num_q - num_no_gt) 70 | mAP = AP / (num_q - num_no_gt) 71 | 72 | return CMC, mAP 73 | 74 | 75 | def evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothids, g_clothids, mode='CC'): 76 | """ Compute CMC and mAP with clothes 77 | 78 | Args: 79 | distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery). 80 | q_pids (numpy array): person IDs for query samples. 81 | g_pids (numpy array): person IDs for gallery samples. 82 | q_camids (numpy array): camera IDs for query samples. 83 | g_camids (numpy array): camera IDs for gallery samples. 84 | q_clothids (numpy array): clothes IDs for query samples. 85 | g_clothids (numpy array): clothes IDs for gallery samples. 86 | mode: 'CC' for clothes-changing; 'SC' for the same clothes. 
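        In 'CC' mode, only cross-camera gallery samples with the same identity but a
        different clothes ID count as correct matches, and same-identity samples wearing
        the same clothes are treated as junk. In 'SC' mode, only cross-camera samples
        with the same identity and the same clothes ID count as correct, and
        same-identity samples wearing different clothes are treated as junk.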
87 | """ 88 | assert mode in ['CC', 'SC'] 89 | 90 | num_q, num_g = distmat.shape 91 | index = np.argsort(distmat, axis=1) # from small to large 92 | 93 | num_no_gt = 0 # num of query imgs without groundtruth 94 | num_r1 = 0 95 | CMC = np.zeros(len(g_pids)) 96 | AP = 0 97 | 98 | for i in range(num_q): 99 | # groundtruth index 100 | query_index = np.argwhere(g_pids==q_pids[i]) 101 | camera_index = np.argwhere(g_camids==q_camids[i]) 102 | cloth_index = np.argwhere(g_clothids==q_clothids[i]) 103 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 104 | if mode == 'CC': 105 | good_index = np.setdiff1d(good_index, cloth_index, assume_unique=True) 106 | # remove gallery samples that have the same (pid, camid) or (pid, clothid) with query 107 | junk_index1 = np.intersect1d(query_index, camera_index) 108 | junk_index2 = np.intersect1d(query_index, cloth_index) 109 | junk_index = np.union1d(junk_index1, junk_index2) 110 | else: 111 | good_index = np.intersect1d(good_index, cloth_index) 112 | # remove gallery samples that have the same (pid, camid) or 113 | # (the same pid and different clothid) with query 114 | junk_index1 = np.intersect1d(query_index, camera_index) 115 | junk_index2 = np.setdiff1d(query_index, cloth_index) 116 | junk_index = np.union1d(junk_index1, junk_index2) 117 | 118 | if good_index.size == 0: 119 | num_no_gt += 1 120 | continue 121 | 122 | ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index) 123 | if CMC_tmp[0]==1: 124 | num_r1 += 1 125 | CMC = CMC + CMC_tmp 126 | AP += ap_tmp 127 | 128 | if num_no_gt > 0: 129 | logger = logging.getLogger('reid.evaluate') 130 | logger.info("{} query samples do not have groundtruth.".format(num_no_gt)) 131 | 132 | if (num_q - num_no_gt) != 0: 133 | CMC = CMC / (num_q - num_no_gt) 134 | mAP = AP / (num_q - num_no_gt) 135 | else: 136 | mAP = 0 137 | 138 | return CMC, mAP -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import errno 5 | import json 6 | import os.path as osp 7 | import torch 8 | import random 9 | import logging 10 | import numpy as np 11 | 12 | 13 | def set_seed(seed=None): 14 | if seed is None: 15 | return 16 | random.seed(seed) 17 | os.environ['PYTHONHASHSEED'] = ("%s" % seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | torch.backends.cudnn.benchmark = False 23 | torch.backends.cudnn.deterministic = True 24 | 25 | 26 | def mkdir_if_missing(directory): 27 | if not osp.exists(directory): 28 | try: 29 | os.makedirs(directory) 30 | except OSError as e: 31 | if e.errno != errno.EEXIST: 32 | raise 33 | 34 | 35 | def read_json(fpath): 36 | with open(fpath, 'r') as f: 37 | obj = json.load(f) 38 | return obj 39 | 40 | 41 | def write_json(obj, fpath): 42 | mkdir_if_missing(osp.dirname(fpath)) 43 | with open(fpath, 'w') as f: 44 | json.dump(obj, f, indent=4, separators=(',', ': ')) 45 | 46 | 47 | class AverageMeter(object): 48 | """Computes and stores the average and current value. 
49 | 50 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 51 | """ 52 | def __init__(self): 53 | self.reset() 54 | 55 | def reset(self): 56 | self.val = 0 57 | self.avg = 0 58 | self.sum = 0 59 | self.count = 0 60 | 61 | def update(self, val, n=1): 62 | self.val = val 63 | self.sum += val * n 64 | self.count += n 65 | self.avg = self.sum / self.count 66 | 67 | 68 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 69 | mkdir_if_missing(osp.dirname(fpath)) 70 | torch.save(state, fpath) 71 | if is_best: 72 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) 73 | 74 | ''' 75 | class Logger(object): 76 | """ 77 | Write console output to external text file. 78 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 79 | """ 80 | def __init__(self, fpath=None): 81 | self.console = sys.stdout 82 | self.file = None 83 | if fpath is not None: 84 | mkdir_if_missing(os.path.dirname(fpath)) 85 | self.file = open(fpath, 'w') 86 | 87 | def __del__(self): 88 | self.close() 89 | 90 | def __enter__(self): 91 | pass 92 | 93 | def __exit__(self, *args): 94 | self.close() 95 | 96 | def write(self, msg): 97 | self.console.write(msg) 98 | if self.file is not None: 99 | self.file.write(msg) 100 | 101 | def flush(self): 102 | self.console.flush() 103 | if self.file is not None: 104 | self.file.flush() 105 | os.fsync(self.file.fileno()) 106 | 107 | def close(self): 108 | self.console.close() 109 | if self.file is not None: 110 | self.file.close() 111 | ''' 112 | 113 | 114 | def get_logger(fpath, local_rank=0, name=''): 115 | # Creat logger 116 | logger = logging.getLogger(name) 117 | level = logging.INFO if local_rank in [-1, 0] else logging.WARN 118 | logger.setLevel(level=level) 119 | 120 | # Output to console 121 | console_handler = logging.StreamHandler(sys.stdout) 122 | console_handler.setLevel(level=level) 123 | console_handler.setFormatter(logging.Formatter('%(message)s')) 124 | logger.addHandler(console_handler) 125 | 126 | # Output to file 127 | if fpath is not None: 128 | mkdir_if_missing(os.path.dirname(fpath)) 129 | file_handler = logging.FileHandler(fpath, mode='w') 130 | file_handler.setLevel(level=level) 131 | file_handler.setFormatter(logging.Formatter('%(message)s')) 132 | logger.addHandler(file_handler) 133 | 134 | return logger -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import logging 4 | import torch 5 | from apex import amp 6 | from tools.utils import AverageMeter 7 | 8 | 9 | def train_cal(config, epoch, model, classifier, clothes_classifier, criterion_cla, criterion_pair, 10 | criterion_clothes, criterion_adv, optimizer, optimizer_cc, trainloader, pid2clothes): 11 | logger = logging.getLogger('reid.train') 12 | batch_cla_loss = AverageMeter() 13 | batch_pair_loss = AverageMeter() 14 | batch_clo_loss = AverageMeter() 15 | batch_adv_loss = AverageMeter() 16 | corrects = AverageMeter() 17 | clothes_corrects = AverageMeter() 18 | batch_time = AverageMeter() 19 | data_time = AverageMeter() 20 | 21 | model.train() 22 | classifier.train() 23 | clothes_classifier.train() 24 | 25 | end = time.time() 26 | for batch_idx, (imgs, pids, camids, clothes_ids) in enumerate(trainloader): 27 | # Get all positive clothes classes (belonging to the same identity) for each sample 28 | pos_mask = pid2clothes[pids] 29 | imgs, pids, 
clothes_ids, pos_mask = imgs.cuda(), pids.cuda(), clothes_ids.cuda(), pos_mask.float().cuda() 30 | # Measure data loading time 31 | data_time.update(time.time() - end) 32 | # Forward 33 | features = model(imgs) 34 | outputs = classifier(features) 35 | pred_clothes = clothes_classifier(features.detach()) 36 | _, preds = torch.max(outputs.data, 1) 37 | 38 | # Update the clothes discriminator 39 | clothes_loss = criterion_clothes(pred_clothes, clothes_ids) 40 | if epoch >= config.TRAIN.START_EPOCH_CC: 41 | optimizer_cc.zero_grad() 42 | if config.TRAIN.AMP: 43 | with amp.scale_loss(clothes_loss, optimizer_cc) as scaled_loss: 44 | scaled_loss.backward() 45 | else: 46 | clothes_loss.backward() 47 | optimizer_cc.step() 48 | 49 | # Update the backbone 50 | new_pred_clothes = clothes_classifier(features) 51 | _, clothes_preds = torch.max(new_pred_clothes.data, 1) 52 | 53 | # Compute loss 54 | cla_loss = criterion_cla(outputs, pids) 55 | pair_loss = criterion_pair(features, pids) 56 | adv_loss = criterion_adv(new_pred_clothes, clothes_ids, pos_mask) 57 | if epoch >= config.TRAIN.START_EPOCH_ADV: 58 | loss = cla_loss + adv_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss 59 | else: 60 | loss = cla_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss 61 | optimizer.zero_grad() 62 | if config.TRAIN.AMP: 63 | with amp.scale_loss(loss, optimizer) as scaled_loss: 64 | scaled_loss.backward() 65 | else: 66 | loss.backward() 67 | optimizer.step() 68 | 69 | # statistics 70 | corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0)) 71 | clothes_corrects.update(torch.sum(clothes_preds == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0)) 72 | batch_cla_loss.update(cla_loss.item(), pids.size(0)) 73 | batch_pair_loss.update(pair_loss.item(), pids.size(0)) 74 | batch_clo_loss.update(clothes_loss.item(), clothes_ids.size(0)) 75 | batch_adv_loss.update(adv_loss.item(), clothes_ids.size(0)) 76 | # measure elapsed time 77 | batch_time.update(time.time() - end) 78 | end = time.time() 79 | 80 | logger.info('Epoch{0} ' 81 | 'Time:{batch_time.sum:.1f}s ' 82 | 'Data:{data_time.sum:.1f}s ' 83 | 'ClaLoss:{cla_loss.avg:.4f} ' 84 | 'PairLoss:{pair_loss.avg:.4f} ' 85 | 'CloLoss:{clo_loss.avg:.4f} ' 86 | 'AdvLoss:{adv_loss.avg:.4f} ' 87 | 'Acc:{acc.avg:.2%} ' 88 | 'CloAcc:{clo_acc.avg:.2%} '.format( 89 | epoch+1, batch_time=batch_time, data_time=data_time, 90 | cla_loss=batch_cla_loss, pair_loss=batch_pair_loss, 91 | clo_loss=batch_clo_loss, adv_loss=batch_adv_loss, 92 | acc=corrects, clo_acc=clothes_corrects)) 93 | 94 | 95 | def train_cal_with_memory(config, epoch, model, classifier, criterion_cla, criterion_pair, 96 | criterion_adv, optimizer, trainloader, pid2clothes): 97 | logger = logging.getLogger('reid.train') 98 | batch_cla_loss = AverageMeter() 99 | batch_pair_loss = AverageMeter() 100 | batch_adv_loss = AverageMeter() 101 | corrects = AverageMeter() 102 | batch_time = AverageMeter() 103 | data_time = AverageMeter() 104 | 105 | model.train() 106 | classifier.train() 107 | 108 | end = time.time() 109 | for batch_idx, (imgs, pids, camids, clothes_ids) in enumerate(trainloader): 110 | # Get all positive clothes classes (belonging to the same identity) for each sample 111 | pos_mask = pid2clothes[pids] 112 | imgs, pids, clothes_ids, pos_mask = imgs.cuda(), pids.cuda(), clothes_ids.cuda(), pos_mask.float().cuda() 113 | # Measure data loading time 114 | data_time.update(time.time() - end) 115 | # Forward 116 | features = model(imgs) 117 | outputs = classifier(features) 118 | _, preds = 
torch.max(outputs.data, 1) 119 | 120 | # Compute loss 121 | cla_loss = criterion_cla(outputs, pids) 122 | pair_loss = criterion_pair(features, pids) 123 | 124 | if epoch >= config.TRAIN.START_EPOCH_ADV: 125 | adv_loss = criterion_adv(features, clothes_ids, pos_mask) 126 | loss = cla_loss + adv_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss 127 | else: 128 | loss = cla_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss 129 | 130 | optimizer.zero_grad() 131 | if config.TRAIN.AMP: 132 | with amp.scale_loss(loss, optimizer) as scaled_loss: 133 | scaled_loss.backward() 134 | else: 135 | loss.backward() 136 | optimizer.step() 137 | 138 | # statistics 139 | corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0)) 140 | batch_cla_loss.update(cla_loss.item(), pids.size(0)) 141 | batch_pair_loss.update(pair_loss.item(), pids.size(0)) 142 | if epoch >= config.TRAIN.START_EPOCH_ADV: 143 | batch_adv_loss.update(adv_loss.item(), clothes_ids.size(0)) 144 | # measure elapsed time 145 | batch_time.update(time.time() - end) 146 | end = time.time() 147 | 148 | logger.info('Epoch{0} ' 149 | 'Time:{batch_time.sum:.1f}s ' 150 | 'Data:{data_time.sum:.1f}s ' 151 | 'ClaLoss:{cla_loss.avg:.4f} ' 152 | 'PairLoss:{pair_loss.avg:.4f} ' 153 | 'AdvLoss:{adv_loss.avg:.4f} ' 154 | 'Acc:{acc.avg:.2%} '.format( 155 | epoch+1, batch_time=batch_time, data_time=data_time, 156 | cla_loss=batch_cla_loss, pair_loss=batch_pair_loss, 157 | adv_loss=batch_adv_loss, acc=corrects)) --------------------------------------------------------------------------------
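A minimal, self-contained sketch of how the pieces above fit together at evaluation time: random L2-normalized features stand in for the output of extract_img_feature, the distance matrix uses the same negative-cosine-similarity convention as test(), and the metrics come from tools/eval_metrics.py. The person, camera, and clothes IDs are made-up toy labels, and the snippet assumes it is run from the repository root so that the tools package is importable.

import numpy as np
import torch
import torch.nn.functional as F
from tools.eval_metrics import evaluate, evaluate_with_clothes

# Pretend these are descriptors already produced by extract_img_feature:
# 4 query and 6 gallery features, L2-normalized as in test.py.
torch.manual_seed(0)
qf = F.normalize(torch.randn(4, 2048), p=2, dim=1)
gf = F.normalize(torch.randn(6, 2048), p=2, dim=1)

# Same convention as test(): negative cosine similarity, so smaller means more similar.
distmat = (-torch.mm(qf, gf.t())).numpy()

# Toy person / camera / clothes labels for the query and gallery samples.
q_pids = np.array([0, 1, 2, 3]);    g_pids = np.array([0, 0, 1, 2, 3, 3])
q_camids = np.array([0, 0, 1, 1]);  g_camids = np.array([1, 2, 0, 0, 0, 2])
q_clothes = np.array([0, 1, 2, 3]); g_clothes = np.array([0, 9, 1, 2, 8, 3])

# Standard evaluation (same-camera true matches are excluded, as in the repo).
cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
print('standard top-1: {:.1%}  mAP: {:.1%}'.format(cmc[0], mAP))

# Clothes-changing evaluation: only cross-clothes matches count as correct.
cc_cmc, cc_mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids,
                                       q_clothes, g_clothes, mode='CC')
print('clothes-changing top-1: {:.1%}  mAP: {:.1%}'.format(cc_cmc[0], cc_mAP))

In a real run, qf, gf, and the label arrays come from extract_img_feature or extract_vid_feature, and the same evaluate_with_clothes call with mode='SC' gives the same-clothes numbers reported by test().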