├── .idea ├── .gitignore ├── OTLA.iml ├── deployment.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── SpCL-master ├── .gitignore ├── LICENSE ├── README.md ├── examples │ ├── otla_tool.py │ ├── spcl_train_uda.py │ ├── spcl_train_usl.py │ └── test.py ├── figs │ ├── framework.png │ └── results.png ├── setup.cfg ├── setup.py └── spcl │ ├── __init__.py │ ├── datasets │ ├── __init__.py │ ├── dukemtmc.py │ ├── market1501.py │ ├── msmt17.py │ ├── personx.py │ ├── regdb.py │ ├── regdb_ir.py │ ├── regdb_rgb.py │ ├── sysumm01.py │ ├── sysumm01_ir.py │ ├── sysumm01_rgb.py │ ├── vehicleid.py │ ├── vehiclex.py │ └── veri.py │ ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py │ ├── evaluators.py │ ├── models │ ├── __init__.py │ ├── dsbn.py │ ├── hm.py │ ├── resnet.py │ ├── resnet_ibn.py │ └── resnet_ibn_a.py │ ├── trainers.py │ └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── preprocessor.py │ ├── sampler.py │ └── transforms.py │ ├── faiss_rerank.py │ ├── faiss_utils.py │ ├── logging.py │ ├── meters.py │ ├── osutils.py │ ├── rerank.py │ └── serialization.py ├── config ├── config_regdb.yaml └── config_sysu.yaml ├── data_loader.py ├── data_manager.py ├── engine.py ├── eval_metrics.py ├── image └── main_figure.png ├── loss.py ├── main_test.py ├── main_train.py ├── model ├── backbone │ └── resnet.py └── network.py ├── optimizer.py ├── otla_sk.py ├── utils.py └── video-poster ├── 0971.mp4 └── 0971.pdf /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /../../../../../../../:\Users\王蒋铭\pytorch\github\OTLA\.idea/dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | 
-------------------------------------------------------------------------------- /.idea/OTLA.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 16 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimal Transport for Label-Efficient Visible-Infrared Person Re-Identification (OTLA-ReID) 2 | This is the official repository for "Optimal Transport for Label-Efficient 3 | Visible-Infrared Person Re-Identification" ([PDF](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136840091.pdf), [Supplementary Material](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136840091-supp.pdf)), which has been accepted by *ECCV 2022*. This work was done at the DMCV Laboratory of East China Normal University. You can visit [DMCV-Lab](https://dmcv-ecnu.github.io/) to find the DMCV Laboratory website. 4 | 5 | ![main_figure](./image/main_figure.png) 6 | 7 | ### Update: 8 | **[2022-7-17]** The semi-supervised and supervised settings can be run with the current code. The unsupervised setting will be updated in a few days. 9 | 10 | **[2022-7-21]** Update some critical information in README.md. 11 | 12 | **[2022-9-22]** Update the code of SpCL-master, which can be used to generate pseudo labels of the visible modality for the unsupervised setting. 13 | 14 | **[2022-10-28]** Update the paper link.
15 | 16 | 17 | ## Requirements 18 | + python 3.7.11 19 | + numpy 1.21.4 20 | + torch 1.10.0 21 | + torchvision 0.11.0 22 | + easydict 1.9 23 | + PyYAML 6.0 24 | + tensorboardX 2.2 25 | 26 | 27 | ## Prepare Datasets 28 | Download the VI-ReID datasets [SYSU-MM01](https://github.com/wuancong/SYSU-MM01) (Email the author to get it) and [RegDB](http://dm.dongguk.edu/link.html) (Submit a copyright form). Follow the link to [DDAG](https://github.com/mangye16/DDAG) to obtain more information about the VI-ReID datasets. Download the visible ReID datasets [Market-1501](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view), [MSMT17](https://arxiv.org/abs/1711.08565) (Email the author to get it), and [DukeMTMC-reID](https://drive.google.com/file/d/1jjE85dRCMOgRtvJ5RQV9-Afs-2_5dY3O/view) if you want to run the unsupervised setting. Please follow the link to [OpenUnReID](https://github.com/open-mmlab/OpenUnReID/blob/master/docs/INSTALL.md) to obtain more information about the visible ReID datasets. 29 | 30 | 31 | ## Training 32 | You first need to choose the ```setting:``` in the config file corresponding to the VI-ReID dataset. 33 | 34 | + For the ```semi-supervised``` / ```supervised``` setting, if you want to train the model(s) in the paper, run the following command: 35 | ```shell 36 | cd OTLA-ReID/ 37 | python main_train.py --config config/config_sysu.yaml 38 | ``` 39 | + For the ```unsupervised``` setting, set ```train_visible_image_path:``` and ```train_visible_label_path:``` to the paths of the visible images and visible pseudo labels of the VI-ReID dataset, as produced by a well-established UDA-ReID or USL-ReID method (e.g. [SpCL](https://github.com/yxgeee/SpCL)). Then run the following command: 40 | ```shell 41 | cd OTLA-ReID/ 42 | python main_train.py --config config/config_sysu.yaml 43 | ``` 44 | 45 | Here, we give an example of running SpCL in SpCL-master to generate visible pseudo labels.
However, you first need to install the environment described in [SpCL](https://github.com/yxgeee/SpCL): 46 | + For SYSU-MM01: 47 | ```shell 48 | cd OTLA-ReID/SpCL-master/ 49 | CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/spcl_train_uda.py -ds market1501 -dt sysumm01_rgb --logs-dir logs/spcl_uda/market1501TOsysumm01_rgb_resnet50 --epochs 51 --iters 800 50 | ``` 51 | + For RegDB: 52 | ```shell 53 | cd OTLA-ReID/SpCL-master/ 54 | CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/spcl_train_uda.py -ds market1501 -dt regdb_rgb --logs-dir logs/spcl_uda/market1501TOregdb_rgb_resnet50 --epochs 51 --iters 50 55 | ``` 56 | The generated visible images and visible pseudo labels are both saved under the dataset directory. 57 | 58 | ## Testing 59 | If you want to test the trained model(s), run the following command: 60 | ```shell 61 | cd OTLA-ReID/ 62 | python main_test.py --config config/config_sysu.yaml --resume --resume_path ./sysu_semi-supervised_otla-reid/sysu_save_model/best_checkpoint.pth 63 | ``` 64 | 65 | ## Citation 66 | If you find this code useful for your research, please cite our paper: 67 | ``` 68 | @inproceedings{wang2022optimal, 69 | title={Optimal Transport for Label-Efficient Visible-Infrared Person Re-Identification}, 70 | author={Wang, Jiangming and Zhang, Zhizhong and Chen, Mingang and Zhang, Yi and Wang, Cong and Sheng, Bin and Qu, Yanyun and Xie, Yuan}, 71 | booktitle={European Conference on Computer Vision}, 72 | pages={93--109}, 73 | year={2022}, 74 | organization={Springer} 75 | } 76 | ``` 77 | 78 | ## Acknowledgements 79 | This work is built on the repositories of [SeLa(ICLR 2020)](https://github.com/yukimasano/self-label), [DDAG(ECCV 2020)](https://github.com/mangye16/DDAG), [SpCL(NIPS 2020)](https://github.com/yxgeee/SpCL), [MMT(ICLR 2020)](https://github.com/yxgeee/MMT), [HCD(ICCV 2021)](https://github.com/tangshixiang/HCD). We sincerely thank all developers of these high-quality repositories.
80 | -------------------------------------------------------------------------------- /SpCL-master/.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | logs/* 3 | scripts/* 4 | 5 | # temporary files which can be created if a process still has a handle open of a deleted file 6 | .fuse_hidden* 7 | 8 | # KDE directory preferences 9 | .directory 10 | 11 | # Linux trash folder which might appear on any partition or disk 12 | .Trash-* 13 | 14 | # .nfs files are created when an open file is removed but is still being accessed 15 | .nfs* 16 | 17 | 18 | *.DS_Store 19 | .AppleDouble 20 | .LSOverride 21 | 22 | # Icon must end with two \r 23 | Icon 24 | 25 | 26 | # Thumbnails 27 | ._* 28 | 29 | # Files that might appear in the root of a volume 30 | .DocumentRevisions-V100 31 | .fseventsd 32 | .Spotlight-V100 33 | .TemporaryItems 34 | .Trashes 35 | .VolumeIcon.icns 36 | .com.apple.timemachine.donotpresent 37 | 38 | # Directories potentially created on remote AFP share 39 | .AppleDB 40 | .AppleDesktop 41 | Network Trash Folder 42 | Temporary Items 43 | .apdisk 44 | 45 | 46 | # swap 47 | [._]*.s[a-v][a-z] 48 | [._]*.sw[a-p] 49 | [._]s[a-v][a-z] 50 | [._]sw[a-p] 51 | # session 52 | Session.vim 53 | # temporary 54 | .netrwhist 55 | *~ 56 | # auto-generated tag files 57 | tags 58 | 59 | 60 | # cache files for sublime text 61 | *.tmlanguage.cache 62 | *.tmPreferences.cache 63 | *.stTheme.cache 64 | 65 | # workspace files are user-specific 66 | *.sublime-workspace 67 | 68 | # project files should be checked into the repository, unless a significant 69 | # proportion of contributors will probably not be using SublimeText 70 | # *.sublime-project 71 | 72 | # sftp configuration file 73 | sftp-config.json 74 | 75 | # Package control specific files 76 | Package Control.last-run 77 | Package Control.ca-list 78 | Package Control.ca-bundle 79 | Package Control.system-ca-bundle 80 | Package Control.cache/ 81 | Package Control.ca-certs/ 82 | 
Package Control.merged-ca-bundle 83 | Package Control.user-ca-bundle 84 | oscrypto-ca-bundle.crt 85 | bh_unicode_properties.cache 86 | 87 | # Sublime-github package stores a github token in this file 88 | # https://packagecontrol.io/packages/sublime-github 89 | GitHub.sublime-settings 90 | 91 | 92 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 93 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 94 | 95 | # User-specific stuff: 96 | .idea 97 | .idea/**/workspace.xml 98 | .idea/**/tasks.xml 99 | 100 | # Sensitive or high-churn files: 101 | .idea/**/dataSources/ 102 | .idea/**/dataSources.ids 103 | .idea/**/dataSources.xml 104 | .idea/**/dataSources.local.xml 105 | .idea/**/sqlDataSources.xml 106 | .idea/**/dynamic.xml 107 | .idea/**/uiDesigner.xml 108 | 109 | # Gradle: 110 | .idea/**/gradle.xml 111 | .idea/**/libraries 112 | 113 | # Mongo Explorer plugin: 114 | .idea/**/mongoSettings.xml 115 | 116 | ## File-based project format: 117 | *.iws 118 | 119 | ## Plugin-specific files: 120 | 121 | # IntelliJ 122 | /out/ 123 | 124 | # mpeltonen/sbt-idea plugin 125 | .idea_modules/ 126 | 127 | # JIRA plugin 128 | atlassian-ide-plugin.xml 129 | 130 | # Crashlytics plugin (for Android Studio and IntelliJ) 131 | com_crashlytics_export_strings.xml 132 | crashlytics.properties 133 | crashlytics-build.properties 134 | fabric.properties 135 | 136 | 137 | # Byte-compiled / optimized / DLL files 138 | __pycache__/ 139 | *.py[cod] 140 | *$py.class 141 | 142 | # C extensions 143 | *.so 144 | 145 | # Distribution / packaging 146 | .Python 147 | env/ 148 | build/ 149 | develop-eggs/ 150 | dist/ 151 | downloads/ 152 | eggs/ 153 | .eggs/ 154 | lib/ 155 | lib64/ 156 | parts/ 157 | sdist/ 158 | var/ 159 | *.egg-info/ 160 | .installed.cfg 161 | *.egg 162 | 163 | # PyInstaller 164 | # Usually these files are written by a python script from a template 165 | # before PyInstaller builds the exe, so as to 
inject date/other infos into it. 166 | *.manifest 167 | *.spec 168 | 169 | # Installer logs 170 | pip-log.txt 171 | pip-delete-this-directory.txt 172 | 173 | # Unit test / coverage reports 174 | htmlcov/ 175 | .tox/ 176 | .coverage 177 | .coverage.* 178 | .cache 179 | nosetests.xml 180 | coverage.xml 181 | *,cover 182 | .hypothesis/ 183 | 184 | # Translations 185 | *.mo 186 | *.pot 187 | 188 | # Django stuff: 189 | *.log 190 | local_settings.py 191 | 192 | # Flask stuff: 193 | instance/ 194 | .webassets-cache 195 | 196 | # Scrapy stuff: 197 | .scrapy 198 | 199 | # Sphinx documentation 200 | docs/_build/ 201 | 202 | # PyBuilder 203 | target/ 204 | 205 | # IPython Notebook 206 | .ipynb_checkpoints 207 | 208 | # pyenv 209 | .python-version 210 | 211 | # celery beat schedule file 212 | celerybeat-schedule 213 | 214 | # dotenv 215 | .env 216 | 217 | # virtualenv 218 | venv/ 219 | ENV/ 220 | 221 | # Spyder project settings 222 | .spyderproject 223 | 224 | # Rope project settings 225 | .ropeproject 226 | 227 | 228 | # Project specific 229 | examples/data 230 | examples/logs 231 | -------------------------------------------------------------------------------- /SpCL-master/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yixiao Ge 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /SpCL-master/README.md: -------------------------------------------------------------------------------- 1 | ![Python >=3.5](https://img.shields.io/badge/Python->=3.5-blue.svg) 2 | ![PyTorch >=1.0](https://img.shields.io/badge/PyTorch->=1.0-yellow.svg) 3 | 4 | # Self-paced Contrastive Learning (SpCL) 5 | 6 | The *official* repository for [Self-paced Contrastive Learning with Hybrid Memory for Domain Adaptive Object Re-ID](https://arxiv.org/abs/2006.02713), which is accepted by [NeurIPS-2020](https://nips.cc/). `SpCL` achieves state-of-the-art performances on both **unsupervised domain adaptation** tasks and **unsupervised learning** tasks for object re-ID, including person re-ID and vehicle re-ID. 7 | 8 | ![framework](figs/framework.png) 9 | 10 | ### Updates 11 | 12 | [2020-10-13] All trained models for the camera-ready version have been updated, see [Trained Models](#trained-models) for details. 13 | 14 | [2020-09-25] `SpCL` has been accepted by NeurIPS on the condition that experiments on DukeMTMC-reID dataset should be removed, since the dataset has been taken down and should no longer be used. 15 | 16 | [2020-07-01] We did the code refactoring to support distributed training, stronger performances and more features. Please see [OpenUnReID](https://github.com/open-mmlab/OpenUnReID). 
17 | 18 | ## Requirements 19 | 20 | ### Installation 21 | 22 | ```shell 23 | git clone https://github.com/yxgeee/SpCL.git 24 | cd SpCL 25 | python setup.py develop 26 | ``` 27 | 28 | ### Prepare Datasets 29 | 30 | ```shell 31 | cd examples && mkdir data 32 | ``` 33 | Download the person datasets [Market-1501](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view), [MSMT17](https://arxiv.org/abs/1711.08565), [PersonX](https://github.com/sxzrt/Instructions-of-the-PersonX-dataset#data-for-visda2020-chanllenge), and the vehicle datasets [VehicleID](https://www.pkuml.org/resources/pku-vehicleid.html), [VeRi-776](https://github.com/JDAI-CV/VeRidataset), [VehicleX](https://www.aicitychallenge.org/2020-track2-download/). 34 | Then unzip them under the directory like 35 | ``` 36 | SpCL/examples/data 37 | ├── market1501 38 | │   └── Market-1501-v15.09.15 39 | ├── msmt17 40 | │   └── MSMT17_V1 41 | ├── personx 42 | │   └── PersonX 43 | ├── vehicleid 44 | │   └── VehicleID -> VehicleID_V1.0 45 | ├── vehiclex 46 | │   └── AIC20_ReID_Simulation -> AIC20_track2/AIC20_ReID_Simulation 47 | └── veri 48 | └── VeRi -> VeRi_with_plate 49 | ``` 50 | 51 | ### Prepare ImageNet Pre-trained Models for IBN-Net 52 | When training with the backbone of [IBN-ResNet](https://arxiv.org/abs/1807.09441), you need to download the ImageNet-pretrained model from this [link](https://drive.google.com/drive/folders/1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S) and save it under the path of `logs/pretrained/`. 53 | ```shell 54 | mkdir logs && cd logs 55 | mkdir pretrained 56 | ``` 57 | The file tree should be 58 | ``` 59 | SpCL/logs 60 | └── pretrained 61 |    └── resnet50_ibn_a.pth.tar 62 | ``` 63 | ImageNet-pretrained models for **ResNet-50** will be automatically downloaded in the python script. 64 | 65 | 66 | ## Training 67 | 68 | We utilize 4 GTX-1080TI GPUs for training. **Note that** 69 | 70 | + The training for `SpCL` is end-to-end, which means that no source-domain pre-training is required. 
71 | + use `--iters 400` (default) for Market-1501 and PersonX datasets, and `--iters 800` for MSMT17, VeRi-776, VehicleID and VehicleX datasets; 72 | + use `--width 128 --height 256` (default) for person datasets, and `--height 224 --width 224` for vehicle datasets; 73 | + use `-a resnet50` (default) for the backbone of ResNet-50, and `-a resnet_ibn50a` for the backbone of IBN-ResNet. 74 | 75 | ### Unsupervised Domain Adaptation 76 | To train the model(s) in the paper, run this command: 77 | ```shell 78 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 79 | python examples/spcl_train_uda.py \ 80 | -ds $SOURCE_DATASET -dt $TARGET_DATASET --logs-dir $PATH_OF_LOGS 81 | ``` 82 | 83 | **Some examples:** 84 | ```shell 85 | ### PersonX -> Market-1501 ### 86 | # use all default settings is ok 87 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 88 | python examples/spcl_train_uda.py \ 89 | -ds personx -dt market1501 --logs-dir logs/spcl_uda/personx2market_resnet50 90 | 91 | ### Market-1501 -> MSMT17 ### 92 | # use all default settings except for iters=800 93 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 94 | python examples/spcl_train_uda.py --iters 800 \ 95 | -ds market1501 -dt msmt17 --logs-dir logs/spcl_uda/market2msmt_resnet50 96 | 97 | ### VehicleID -> VeRi-776 ### 98 | # use all default settings except for iters=800, height=224 and width=224 99 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 100 | python examples/spcl_train_uda.py --iters 800 --height 224 --width 224 \ 101 | -ds vehicleid -dt veri --logs-dir logs/spcl_uda/vehicleid2veri_resnet50 102 | ``` 103 | 104 | 105 | ### Unsupervised Learning 106 | To train the model(s) in the paper, run this command: 107 | ```shell 108 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 109 | python examples/spcl_train_usl.py \ 110 | -d $DATASET --logs-dir $PATH_OF_LOGS 111 | ``` 112 | 113 | **Some examples:** 114 | ```shell 115 | ### Market-1501 ### 116 | # use all default settings is ok 117 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 118 | python examples/spcl_train_usl.py \ 119 | -d market1501 --logs-dir 
logs/spcl_usl/market_resnet50 120 | 121 | ### MSMT17 ### 122 | # use all default settings except for iters=800 123 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 124 | python examples/spcl_train_usl.py --iters 800 \ 125 | -d msmt17 --logs-dir logs/spcl_usl/msmt_resnet50 126 | 127 | ### VeRi-776 ### 128 | # use all default settings except for iters=800, height=224 and width=224 129 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 130 | python examples/spcl_train_usl.py --iters 800 --height 224 --width 224 \ 131 | -d veri --logs-dir logs/spcl_usl/veri_resnet50 132 | ``` 133 | 134 | 135 | ## Evaluation 136 | 137 | We utilize 1 GTX-1080TI GPU for testing. **Note that** 138 | 139 | + use `--width 128 --height 256` (default) for person datasets, and `--height 224 --width 224` for vehicle datasets; 140 | + use `--dsbn` for domain adaptive models, and add `--test-source` if you want to test on the source domain; 141 | + use `-a resnet50` (default) for the backbone of ResNet-50, and `-a resnet_ibn50a` for the backbone of IBN-ResNet. 
142 | 143 | ### Unsupervised Domain Adaptation 144 | 145 | To evaluate the domain adaptive model on the **target-domain** dataset, run: 146 | ```shell 147 | CUDA_VISIBLE_DEVICES=0 \ 148 | python examples/test.py --dsbn \ 149 | -d $DATASET --resume $PATH_OF_MODEL 150 | ``` 151 | 152 | To evaluate the domain adaptive model on the **source-domain** dataset, run: 153 | ```shell 154 | CUDA_VISIBLE_DEVICES=0 \ 155 | python examples/test.py --dsbn --test-source \ 156 | -d $DATASET --resume $PATH_OF_MODEL 157 | ``` 158 | 159 | **Some examples:** 160 | ```shell 161 | ### Market-1501 -> MSMT17 ### 162 | # test on the target domain 163 | CUDA_VISIBLE_DEVICES=0 \ 164 | python examples/test.py --dsbn \ 165 | -d msmt17 --resume logs/spcl_uda/market2msmt_resnet50/model_best.pth.tar 166 | # test on the source domain 167 | CUDA_VISIBLE_DEVICES=0 \ 168 | python examples/test.py --dsbn --test-source \ 169 | -d market1501 --resume logs/spcl_uda/market2msmt_resnet50/model_best.pth.tar 170 | ``` 171 | 172 | ### Unsupervised Learning 173 | To evaluate the model, run: 174 | ```shell 175 | CUDA_VISIBLE_DEVICES=0 \ 176 | python examples/test.py \ 177 | -d $DATASET --resume $PATH 178 | ``` 179 | 180 | **Some examples:** 181 | ```shell 182 | ### Market-1501 ### 183 | CUDA_VISIBLE_DEVICES=0 \ 184 | python examples/test.py \ 185 | -d market1501 --resume logs/spcl_usl/market_resnet50/model_best.pth.tar 186 | ``` 187 | 188 | ## Trained Models 189 | 190 | ![framework](figs/results.png) 191 | 192 | You can download the above models in the paper from [[Google Drive]](https://drive.google.com/drive/folders/1ryx-fPGjrexwm9ZP9QO3Qk4SKzNqbaXw?usp=sharing) or [[Baidu Yun]](https://pan.baidu.com/s/1FInOhEdQsOEk-1oMWWB0Ag)(password: w3l9). 
193 | 194 | 195 | ## Citation 196 | If you find this code useful for your research, please cite our paper 197 | ``` 198 | @inproceedings{ge2020selfpaced, 199 | title={Self-paced Contrastive Learning with Hybrid Memory for Domain Adaptive Object Re-ID}, 200 | author={Yixiao Ge and Feng Zhu and Dapeng Chen and Rui Zhao and Hongsheng Li}, 201 | booktitle={Advances in Neural Information Processing Systems}, 202 | year={2020} 203 | } 204 | ``` 205 | -------------------------------------------------------------------------------- /SpCL-master/examples/otla_tool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | from PIL import Image 4 | import numpy as np 5 | import collections 6 | import torch 7 | 8 | 9 | def mkdir_if_missing(dir_path): 10 | """ 11 | Create file if missing. 12 | """ 13 | try: 14 | os.makedirs(dir_path) 15 | except OSError as e: 16 | if e.errno != errno.EEXIST: 17 | raise 18 | 19 | 20 | def save_checkpoint_pseudo_label(state, fpath="checkpoint.pth.tar"): 21 | """ 22 | Save model for generating pseudo label. 23 | """ 24 | mkdir_if_missing(os.path.dirname(fpath)) 25 | torch.save(state, fpath) 26 | 27 | 28 | def mask_outlier(train_pseudo_label): 29 | """ 30 | Mask outlier data of clustering results. 31 | """ 32 | index2label = collections.defaultdict(int) 33 | for label in train_pseudo_label: 34 | index2label[label.item()] += 1 35 | nums = np.fromiter(index2label.values(), dtype=float) 36 | label = np.fromiter(index2label.keys(), dtype=float) 37 | train_label = label[nums > 1] 38 | 39 | return np.array([i in train_label for i in train_pseudo_label]) 40 | 41 | 42 | def R_gt(train_real_label, train_pseudo_label): 43 | ''' 44 | The Average Maximum Proportion of Ground-truth Classes (R_gt) in supplementary material. 
45 | ''' 46 | p = 0 47 | mask = mask_outlier(train_pseudo_label) 48 | train_real_label = train_real_label[mask] 49 | ids_container = list(np.unique(train_real_label)) 50 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 51 | for i, label in enumerate(train_real_label): 52 | train_real_label[i] = id2label[label] 53 | train_pseudo_label = train_pseudo_label[mask] 54 | ids_container = list(np.unique(train_pseudo_label)) 55 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 56 | for i, label in enumerate(train_pseudo_label): 57 | train_pseudo_label[i] = id2label[label] 58 | for i in range(np.unique(train_real_label).size): 59 | sample_id = (train_real_label == i) 60 | sample_label = train_pseudo_label[sample_id] 61 | sample_num_per_label = np.zeros(np.unique(train_pseudo_label).size) 62 | for j in sample_label: 63 | sample_num_per_label[j] += 1 64 | p_i = np.max(sample_num_per_label) / sample_label.size 65 | p += p_i 66 | p = p / np.unique(train_real_label).size 67 | print("R_gt: {:.4f}".format(p)) 68 | 69 | return p 70 | 71 | 72 | def R_ct(train_real_label, train_pseudo_label): 73 | ''' 74 | The Average Maximum Proportion of Pseudo Classes (R_ct) in supplementary material. 
75 | ''' 76 | p = 0 77 | mask = mask_outlier(train_pseudo_label) 78 | train_real_label = train_real_label[mask] 79 | ids_container = list(np.unique(train_real_label)) 80 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 81 | for i, label in enumerate(train_real_label): 82 | train_real_label[i] = id2label[label] 83 | train_pseudo_label = train_pseudo_label[mask] 84 | ids_container = list(np.unique(train_pseudo_label)) 85 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 86 | for i, label in enumerate(train_pseudo_label): 87 | train_pseudo_label[i] = id2label[label] 88 | for i in range(np.unique(train_pseudo_label).size): 89 | sample_id = (train_pseudo_label == i) 90 | sample_label = train_real_label[sample_id] 91 | sample_num_per_label = np.zeros(np.unique(train_real_label).size) 92 | for j in sample_label: 93 | sample_num_per_label[j] += 1 94 | p_i = np.max(sample_num_per_label) / sample_label.size 95 | p += p_i 96 | p = p / np.unique(train_pseudo_label).size 97 | print("R_ct: {:.4f}".format(p)) 98 | 99 | return p 100 | 101 | 102 | def P_v(train_real_label, train_pseudo_label): 103 | ''' 104 | The Proportion of Visible Training Samples (P_v). 
105 | ''' 106 | len_data = len(train_real_label) 107 | mask = mask_outlier(train_pseudo_label) 108 | len_mask_data = len(train_pseudo_label[mask]) 109 | p = len_mask_data / len_data 110 | print("P_v: {:.4f}, total samples: {}, total samples without outliers: {}".format(p, len_data, len_mask_data)) 111 | 112 | return p 113 | 114 | 115 | def Q_v(train_real_label, train_pseudo_label): 116 | mask = mask_outlier(train_pseudo_label) 117 | n_class = np.unique(train_real_label).size 118 | n_cluster_class = np.unique(train_pseudo_label[mask]).size 119 | p = np.min((n_class, n_cluster_class)) / np.max((n_class, n_cluster_class)) 120 | print("Q_v: {:.4f}, number of real classes: {}, number of pseudo classes: {}".format(p, n_class, n_cluster_class)) 121 | 122 | return p 123 | 124 | 125 | def R_plq(train_real_label, train_pseudo_label): 126 | ''' 127 | The Final Metric (R_plq) in supplementary material. 128 | ''' 129 | R_gt_p = R_gt(train_real_label, train_pseudo_label) 130 | R_ct_p = R_ct(train_real_label, train_pseudo_label) 131 | P_v_p = P_v(train_real_label, train_pseudo_label) 132 | Q_v_p = Q_v(train_real_label, train_pseudo_label) 133 | R_plq_p = (R_gt_p + R_ct_p) / 2 * P_v_p * Q_v_p 134 | print("R_plq: {:.4f}".format(R_plq_p)) 135 | 136 | return R_gt_p, R_ct_p, P_v_p, Q_v_p, R_plq_p 137 | 138 | 139 | def save_image_label(train_image_path, train_pseudo_label, train_real_label, model, epoch, logs_dir, save_path, 140 | img_size=(144, 288), source_domain="market1501", target_domain="sysumm01_rgb", method_name="spcl_uda"): 141 | train_image = [] 142 | for fname in train_image_path: 143 | img = Image.open(fname) 144 | img = img.resize(img_size, Image.ANTIALIAS) 145 | pix_array = np.array(img) 146 | train_image.append(pix_array) 147 | 148 | train_image = np.array(train_image) 149 | train_pseudo_label = np.array(train_pseudo_label) 150 | train_real_label = np.array(train_real_label) 151 | 152 | ids_container = list(np.unique(train_pseudo_label)) 153 | id2label = {id_: label for 
label, id_ in enumerate(ids_container)} 154 | for i, label in enumerate(train_pseudo_label): 155 | train_pseudo_label[i] = id2label[label] 156 | 157 | R_gt_p, R_ct_p, P_v_p, Q_v_p, R_plq_p = R_plq(train_real_label, train_pseudo_label) 158 | 159 | np.save(os.path.join(save_path, method_name+"_"+source_domain+"TO"+target_domain+"_"+"train_rgb_resized_img.npy"), train_image) 160 | np.save(os.path.join(save_path, method_name+"_"+source_domain+"TO"+target_domain+"_"+"train_rgb_resized_label.npy"), train_pseudo_label) 161 | 162 | save_checkpoint_pseudo_label({ 163 | "state_dict": model.state_dict(), 164 | "epoch": epoch, 165 | "R_gt": R_gt_p, 166 | "R_ct": R_ct_p, 167 | "P_v": P_v_p, 168 | "Q_v": Q_v_p, 169 | "R_plq": R_plq_p, 170 | }, fpath=os.path.join(logs_dir, "checkpoint_pseudo_label.pth.tar")) -------------------------------------------------------------------------------- /SpCL-master/examples/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import sys 7 | 8 | import torch 9 | from torch import nn 10 | from torch.backends import cudnn 11 | from torch.utils.data import DataLoader 12 | 13 | from spcl import datasets 14 | from spcl import models 15 | from spcl.models.dsbn import convert_dsbn, convert_bn 16 | from spcl.evaluators import Evaluator 17 | from spcl.utils.data import transforms as T 18 | from spcl.utils.data.preprocessor import Preprocessor 19 | from spcl.utils.logging import Logger 20 | from spcl.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict 21 | 22 | 23 | def get_data(name, data_dir, height, width, batch_size, workers): 24 | root = osp.join(data_dir, name) 25 | 26 | dataset = datasets.create(name, root) 27 | 28 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 29 | std=[0.229, 0.224, 0.225]) 30 | 31 | test_transformer = T.Compose([ 
32 | T.Resize((height, width), interpolation=3), 33 | T.ToTensor(), 34 | normalizer 35 | ]) 36 | 37 | test_loader = DataLoader( 38 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 39 | root=dataset.images_dir, transform=test_transformer), 40 | batch_size=batch_size, num_workers=workers, 41 | shuffle=False, pin_memory=True) 42 | 43 | return dataset, test_loader 44 | 45 | 46 | def main(): 47 | args = parser.parse_args() 48 | 49 | if args.seed is not None: 50 | random.seed(args.seed) 51 | np.random.seed(args.seed) 52 | torch.manual_seed(args.seed) 53 | cudnn.deterministic = True 54 | 55 | main_worker(args) 56 | 57 | 58 | def main_worker(args): 59 | cudnn.benchmark = True 60 | 61 | log_dir = osp.dirname(args.resume) 62 | sys.stdout = Logger(osp.join(log_dir, 'log_test.txt')) 63 | print("==========\nArgs:{}\n==========".format(args)) 64 | 65 | # Create data loaders 66 | dataset, test_loader = get_data(args.dataset, args.data_dir, args.height, 67 | args.width, args.batch_size, args.workers) 68 | 69 | # Create model 70 | model = models.create(args.arch, pretrained=False, num_features=args.features, dropout=args.dropout, num_classes=0) 71 | if args.dsbn: 72 | print("==> Load the model with domain-specific BNs") 73 | convert_dsbn(model) 74 | 75 | # Load from checkpoint 76 | checkpoint = load_checkpoint(args.resume) 77 | copy_state_dict(checkpoint['state_dict'], model, strip='module.') 78 | 79 | if args.dsbn: 80 | print("==> Test with {}-domain BNs".format("source" if args.test_source else "target")) 81 | convert_bn(model, use_target=(not args.test_source)) 82 | 83 | model.cuda() 84 | model = nn.DataParallel(model) 85 | 86 | # Evaluator 87 | model.eval() 88 | evaluator = Evaluator(model) 89 | print("Test on {}:".format(args.dataset)) 90 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True, rerank=args.rerank) 91 | return 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser(description="Testing the model") 95 | # 
data 96 | parser.add_argument('-d', '--dataset', type=str, required=True, 97 | choices=datasets.names()) 98 | parser.add_argument('-b', '--batch-size', type=int, default=256) 99 | parser.add_argument('-j', '--workers', type=int, default=4) 100 | parser.add_argument('--height', type=int, default=256, help="input height") 101 | parser.add_argument('--width', type=int, default=128, help="input width") 102 | # model 103 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 104 | choices=models.names()) 105 | parser.add_argument('--features', type=int, default=0) 106 | parser.add_argument('--dropout', type=float, default=0) 107 | parser.add_argument('--resume', type=str, required=True, metavar='PATH') 108 | # testing configs 109 | parser.add_argument('--rerank', action='store_true', 110 | help="evaluation only") 111 | parser.add_argument('--dsbn', action='store_true', 112 | help="test on the model with domain-specific BN") 113 | parser.add_argument('--test-source', action='store_true', 114 | help="test on the source domain") 115 | parser.add_argument('--seed', type=int, default=1) 116 | # path 117 | working_dir = osp.dirname(osp.abspath(__file__)) 118 | parser.add_argument('--data-dir', type=str, metavar='PATH', 119 | default=osp.join(working_dir, 'data')) 120 | main() 121 | -------------------------------------------------------------------------------- /SpCL-master/figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/SpCL-master/figs/framework.png -------------------------------------------------------------------------------- /SpCL-master/figs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/SpCL-master/figs/results.png 
-------------------------------------------------------------------------------- /SpCL-master/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /SpCL-master/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup(name='SpCL', 5 | version='1.0.0', 6 | description='Self-paced Contrastive Learning with Hybrid Memory for Domain Adaptive Object Re-ID', 7 | author='Yixiao Ge', 8 | author_email='geyixiao831@gmail.com', 9 | url='https://github.com/yxgeee/SpCL', 10 | install_requires=[ 11 | 'numpy', 'torch', 'torchvision', 12 | 'six', 'h5py', 'Pillow', 'scipy', 13 | 'scikit-learn', 'metric-learn', 'faiss_gpu==1.6.3'], 14 | packages=find_packages(), 15 | keywords=[ 16 | 'Unsupervised Learning', 17 | 'Unsupervised Domain Adaptation', 18 | 'Contrastive Learning', 19 | 'Object Re-identification' 20 | ]) 21 | -------------------------------------------------------------------------------- /SpCL-master/spcl/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import models 6 | from . import utils 7 | from . import evaluators 8 | from . 
import trainers 9 | 10 | __version__ = '0.1.0' 11 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .market1501 import Market1501 5 | from .msmt17 import MSMT17 6 | from .personx import PersonX 7 | from .veri import VeRi 8 | from .vehicleid import VehicleID 9 | from .vehiclex import VehicleX 10 | from .dukemtmc import DukeMTMC 11 | from .sysumm01 import SYSU_MM01 12 | from .sysumm01_rgb import SYSU_MM01_RGB 13 | from .sysumm01_ir import SYSU_MM01_IR 14 | from .regdb import RegDB 15 | from .regdb_rgb import RegDB_RGB 16 | from .regdb_ir import RegDB_IR 17 | 18 | 19 | __factory = { 20 | 'market1501': Market1501, 21 | 'msmt17': MSMT17, 22 | 'personx': PersonX, 23 | 'veri': VeRi, 24 | 'vehicleid': VehicleID, 25 | 'vehiclex': VehicleX, 26 | 'dukemtmc': DukeMTMC, 27 | 'sysumm01': SYSU_MM01, 28 | 'sysumm01_rgb': SYSU_MM01_RGB, 29 | 'sysumm01_ir': SYSU_MM01_IR, 30 | 'regdb': RegDB, 31 | 'regdb_rgb': RegDB_RGB, 32 | 'regdb_ir': RegDB_IR 33 | } 34 | 35 | 36 | def names(): 37 | return sorted(__factory.keys()) 38 | 39 | 40 | def create(name, root, *args, **kwargs): 41 | """ 42 | Create a dataset instance. 43 | 44 | Parameters 45 | ---------- 46 | name : str 47 | The dataset name. 48 | root : str 49 | The path to the dataset directory. 50 | split_id : int, optional 51 | The index of data split. Default: 0 52 | num_val : int or float, optional 53 | When int, it means the number of validation identities. When float, 54 | it means the proportion of validation to all the trainval. Default: 100 55 | download : bool, optional 56 | If True, will download the dataset. 
Default: False 57 | """ 58 | if name not in __factory: 59 | raise KeyError("Unknown dataset:", name) 60 | return __factory[name](root, *args, **kwargs) 61 | 62 | 63 | def get_dataset(name, root, *args, **kwargs): 64 | warnings.warn("get_dataset is deprecated. Use create instead.") 65 | return create(name, root, *args, **kwargs) 66 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | 13 | class DukeMTMC(BaseImageDataset): 14 | """ 15 | DukeMTMC-reID 16 | Reference: 17 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 18 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 19 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 20 | 21 | Dataset statistics: 22 | # identities: 1404 (train + query) 23 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 24 | # cameras: 8 25 | """ 26 | dataset_dir = '.' 
27 | 28 | def __init__(self, root, verbose=True, **kwargs): 29 | super(DukeMTMC, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 32 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 35 | 36 | self._download_data() 37 | self._check_before_run() 38 | 39 | train = self._process_dir(self.train_dir, relabel=True) 40 | query = self._process_dir(self.query_dir, relabel=False) 41 | gallery = self._process_dir(self.gallery_dir, relabel=False) 42 | 43 | if verbose: 44 | print("=> DukeMTMC-reID loaded") 45 | self.print_dataset_statistics(train, query, gallery) 46 | 47 | self.train = train 48 | self.query = query 49 | self.gallery = gallery 50 | 51 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 52 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 53 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 54 | 55 | def _download_data(self): 56 | if osp.exists(self.dataset_dir): 57 | print("This dataset has been downloaded.") 58 | return 59 | 60 | print("Creating directory {}".format(self.dataset_dir)) 61 | mkdir_if_missing(self.dataset_dir) 62 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 63 | 64 | print("Downloading DukeMTMC-reID dataset") 65 | urllib.request.urlretrieve(self.dataset_url, fpath) 66 | 67 | print("Extracting files") 68 | zip_ref = zipfile.ZipFile(fpath, 'r') 69 | zip_ref.extractall(self.dataset_dir) 70 | zip_ref.close() 71 | 72 | def _check_before_run(self): 73 | """Check if all files are available before going deeper""" 74 | if not osp.exists(self.dataset_dir): 75 | raise 
RuntimeError("'{}' is not available".format(self.dataset_dir)) 76 | if not osp.exists(self.train_dir): 77 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 78 | if not osp.exists(self.query_dir): 79 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 80 | if not osp.exists(self.gallery_dir): 81 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 82 | 83 | def _process_dir(self, dir_path, relabel=False): 84 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 85 | pattern = re.compile(r'([-\d]+)_c(\d)') 86 | 87 | pid_container = set() 88 | for img_path in img_paths: 89 | pid, _ = map(int, pattern.search(img_path).groups()) 90 | pid_container.add(pid) 91 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 92 | 93 | dataset = [] 94 | for img_path in img_paths: 95 | pid, camid = map(int, pattern.search(img_path).groups()) 96 | assert 1 <= camid <= 8 97 | camid -= 1 # index starts from 0 98 | if relabel: pid = pid2label[pid] 99 | dataset.append((img_path, pid, camid)) 100 | 101 | return dataset 102 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | class Market1501(BaseImageDataset): 13 | """ 14 | Market1501 15 | Reference: 16 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
17 | URL: http://www.liangzheng.org/Project/project_reid.html 18 | 19 | Dataset statistics: 20 | # identities: 1501 (+1 for background) 21 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 22 | """ 23 | dataset_dir = 'Market-1501-v15.09.15' 24 | 25 | def __init__(self, root, verbose=True, **kwargs): 26 | super(Market1501, self).__init__() 27 | self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 31 | 32 | self._check_before_run() 33 | 34 | train = self._process_dir(self.train_dir, relabel=True) 35 | query = self._process_dir(self.query_dir, relabel=False) 36 | gallery = self._process_dir(self.gallery_dir, relabel=False) 37 | 38 | if verbose: 39 | print("=> Market1501 loaded") 40 | self.print_dataset_statistics(train, query, gallery) 41 | 42 | self.train = train 43 | self.query = query 44 | self.gallery = gallery 45 | 46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 49 | 50 | def _check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 56 | if not osp.exists(self.query_dir): 57 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 58 | if not osp.exists(self.gallery_dir): 59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 60 | 61 | def _process_dir(self, dir_path, relabel=False): 62 | img_paths = 
glob.glob(osp.join(dir_path, '*.jpg')) 63 | pattern = re.compile(r'([-\d]+)_c(\d)') 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: continue # junk images are just ignored 76 | assert 0 <= pid <= 1501 # pid == 0 means background 77 | assert 1 <= camid <= 6 78 | camid -= 1 # index starts from 0 79 | if relabel: pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import tarfile 4 | 5 | import glob 6 | import re 7 | import urllib 8 | import zipfile 9 | 10 | from ..utils.osutils import mkdir_if_missing 11 | from ..utils.serialization import write_json 12 | 13 | 14 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')): 15 | with open(list_file, 'r') as f: 16 | lines = f.readlines() 17 | ret = [] 18 | pids = [] 19 | for line in lines: 20 | line = line.strip() 21 | fname = line.split(' ')[0] 22 | pid, _, cam = map(int, pattern.search(osp.basename(fname)).groups()) 23 | if pid not in pids: 24 | pids.append(pid) 25 | ret.append((osp.join(subdir,fname), pid, cam)) 26 | return ret, pids 27 | 28 | class Dataset_MSMT(object): 29 | def __init__(self, root): 30 | self.root = root 31 | self.train, self.val, self.trainval = [], [], [] 32 | self.query, self.gallery = [], [] 33 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 34 | 35 | 
@property 36 | def images_dir(self): 37 | return osp.join(self.root, 'MSMT17_V1') 38 | 39 | def load(self, verbose=True): 40 | exdir = osp.join(self.root, 'MSMT17_V1') 41 | self.train, train_pids = _pluck_msmt(osp.join(exdir, 'list_train.txt'), 'train') 42 | self.val, val_pids = _pluck_msmt(osp.join(exdir, 'list_val.txt'), 'train') 43 | self.train = self.train + self.val 44 | self.query, query_pids = _pluck_msmt(osp.join(exdir, 'list_query.txt'), 'test') 45 | self.gallery, gallery_pids = _pluck_msmt(osp.join(exdir, 'list_gallery.txt'), 'test') 46 | self.num_train_pids = len(list(set(train_pids).union(set(val_pids)))) 47 | 48 | if verbose: 49 | print(self.__class__.__name__, "dataset loaded") 50 | print(" subset | # ids | # images") 51 | print(" ---------------------------") 52 | print(" train | {:5d} | {:8d}" 53 | .format(self.num_train_pids, len(self.train))) 54 | print(" query | {:5d} | {:8d}" 55 | .format(len(query_pids), len(self.query))) 56 | print(" gallery | {:5d} | {:8d}" 57 | .format(len(gallery_pids), len(self.gallery))) 58 | 59 | class MSMT17(Dataset_MSMT): 60 | 61 | def __init__(self, root, split_id=0, download=True): 62 | super(MSMT17, self).__init__(root) 63 | 64 | if download: 65 | self.download() 66 | 67 | self.load() 68 | 69 | def download(self): 70 | 71 | import re 72 | import hashlib 73 | import shutil 74 | from glob import glob 75 | from zipfile import ZipFile 76 | 77 | raw_dir = osp.join(self.root) 78 | mkdir_if_missing(raw_dir) 79 | 80 | # Download the raw zip file 81 | fpath = osp.join(raw_dir, 'MSMT17_V1') 82 | if osp.isdir(fpath): 83 | print("Using downloaded file: " + fpath) 84 | else: 85 | raise RuntimeError("Please download the dataset manually to {}".format(fpath)) 86 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/personx.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import 
os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | class PersonX(BaseImageDataset): 13 | """ 14 | PersonX 15 | Reference: 16 | Sun et al. Dissecting Person Re-identification from the Viewpoint of Viewpoint. CVPR 2019. 17 | 18 | Dataset statistics: 19 | # identities: 1266 20 | # images: 9840 (train) + 5136 (query) + 30816 (gallery) 21 | """ 22 | dataset_dir = 'PersonX' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(PersonX, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 30 | 31 | self._check_before_run() 32 | 33 | train = self._process_dir(self.train_dir, relabel=True) 34 | query = self._process_dir(self.query_dir, relabel=False) 35 | gallery = self._process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print("=> PersonX loaded") 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 48 | 49 | def _check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 55 
| if not osp.exists(self.query_dir): 56 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 57 | if not osp.exists(self.gallery_dir): 58 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 59 | 60 | def _process_dir(self, dir_path, relabel=False): 61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 62 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 63 | cam2label = {3:1, 4:2, 8:3, 10:4, 11:5, 12:6} 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | pid_container.add(pid) 69 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 70 | 71 | dataset = [] 72 | for img_path in img_paths: 73 | pid, camid = map(int, pattern.search(img_path).groups()) 74 | assert (camid in cam2label.keys()) 75 | camid = cam2label[camid] 76 | camid -= 1 # index starts from 0 77 | if relabel: pid = pid2label[pid] 78 | dataset.append((img_path, pid, camid)) 79 | 80 | return dataset 81 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/regdb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Mar 31 23:02:42 2021 5 | 6 | @author: vision 7 | """ 8 | 9 | 10 | from __future__ import print_function, absolute_import 11 | import os.path as osp 12 | import os 13 | import random 14 | from glob import glob 15 | import re 16 | import urllib 17 | import zipfile 18 | 19 | from ..utils.data import BaseImageDataset 20 | from ..utils.osutils import mkdir_if_missing 21 | from ..utils.serialization import write_json 22 | 23 | class RegDB(BaseImageDataset): 24 | dataset_dir = "RegDB" 25 | 26 | def __init__(self, root, verbose=True, ii=1, mode='', **kwargs): 27 | super(RegDB, self).__init__() 28 | 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | self.ii = ii 31 | self.index_train_RGB = 
self.loadIdx(open((self.dataset_dir+'/idx/train_visible_{}.txt').format(self.ii),'r')) 32 | self.index_train_IR = self.loadIdx(open((self.dataset_dir+'/idx/train_thermal_{}.txt').format(self.ii),'r')) 33 | self.index_test_RGB = self.loadIdx(open((self.dataset_dir+'/idx/test_visible_{}.txt').format(self.ii),'r')) 34 | self.index_test_IR = self.loadIdx(open((self.dataset_dir+'/idx/test_thermal_{}.txt').format(self.ii),'r')) 35 | 36 | self.train = self._process_dir(self.index_train_RGB, 0, 0) + self._process_dir(self.index_train_IR, 1, 0) 37 | if mode == 't2v': 38 | self.query = self._process_dir(self.index_test_IR, 1, 206) 39 | self.gallery = self._process_dir(self.index_test_RGB, 0, 206) 40 | elif mode == 'v2t': 41 | self.query = self._process_dir(self.index_test_RGB, 0, 206) 42 | self.gallery = self._process_dir(self.index_test_IR, 1, 206) 43 | 44 | if verbose: 45 | print("=> RegDB loaded trial:{}".format(ii)) 46 | self.print_dataset_statistics(self.train, self.query, self.gallery) 47 | 48 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 49 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 50 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 51 | 52 | def _check_before_run(self): 53 | """Check if all files are available before going deeper""" 54 | if not osp.exists(self.dataset_dir): 55 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 56 | if not osp.exists(self.train_dir): 57 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 58 | if not osp.exists(self.val_dir): 59 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 60 | if not osp.exists(self.text_dir): 61 | raise RuntimeError("'{}' is not available".format(self.text_dir)) 62 | 63 | def loadIdx(self, index): 64 | Lines = index.readlines() 65 | idx = [] 66 | for line in Lines: 67 | tmp = 
line.strip('\n') 68 | tmp = tmp.split(' ') 69 | idx.append(tmp) 70 | return idx 71 | 72 | def _process_dir(self, index, cam, delta): 73 | dataset = [] 74 | for idx in index: 75 | fname = osp.join(self.dataset_dir, idx[0]) 76 | pid = int(idx[1]) + delta 77 | dataset.append((fname, pid, cam)) 78 | return dataset 79 | 80 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/regdb_ir.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Mar 31 23:02:42 2021 5 | 6 | @author: vision 7 | """ 8 | 9 | 10 | from __future__ import print_function, absolute_import 11 | import os.path as osp 12 | import os 13 | import random 14 | from glob import glob 15 | import re 16 | import urllib 17 | import zipfile 18 | 19 | from ..utils.data import BaseImageDataset 20 | from ..utils.osutils import mkdir_if_missing 21 | from ..utils.serialization import write_json 22 | 23 | class RegDB_IR(BaseImageDataset): 24 | dataset_dir = "RegDB" 25 | 26 | def __init__(self, root, verbose=True, ii=1, mode='', **kwargs): 27 | super(RegDB_IR, self).__init__() 28 | 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | self.ii = ii 31 | self.index_train_RGB = self.loadIdx(open((self.dataset_dir+'/idx/train_visible_{}.txt').format(self.ii),'r')) 32 | self.index_train_IR = self.loadIdx(open((self.dataset_dir+'/idx/train_thermal_{}.txt').format(self.ii),'r')) 33 | self.index_test_RGB = self.loadIdx(open((self.dataset_dir+'/idx/test_visible_{}.txt').format(self.ii),'r')) 34 | self.index_test_IR = self.loadIdx(open((self.dataset_dir+'/idx/test_thermal_{}.txt').format(self.ii),'r')) 35 | 36 | self.train = self._process_dir(self.index_train_IR, 1, 0) 37 | if mode == 't2v': 38 | self.query = self._process_dir(self.index_test_IR, 1, 206) 39 | self.gallery = self._process_dir(self.index_test_RGB, 0, 206) 40 | elif mode == 'v2t': 41 | 
self.query = self._process_dir(self.index_test_RGB, 0, 206) 42 | self.gallery = self._process_dir(self.index_test_IR, 1, 206) 43 | 44 | if verbose: 45 | print("=> RegDB IR loaded trial:{}".format(ii)) 46 | self.print_dataset_statistics(self.train, self.query, self.gallery) 47 | 48 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 49 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 50 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 51 | 52 | def _check_before_run(self): 53 | """Check if all files are available before going deeper""" 54 | if not osp.exists(self.dataset_dir): 55 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 56 | if not osp.exists(self.train_dir): 57 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 58 | if not osp.exists(self.val_dir): 59 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 60 | if not osp.exists(self.text_dir): 61 | raise RuntimeError("'{}' is not available".format(self.text_dir)) 62 | 63 | def loadIdx(self, index): 64 | Lines = index.readlines() 65 | idx = [] 66 | for line in Lines: 67 | tmp = line.strip('\n') 68 | tmp = tmp.split(' ') 69 | idx.append(tmp) 70 | return idx 71 | 72 | def _process_dir(self, index, cam, delta): 73 | dataset = [] 74 | for idx in index: 75 | fname = osp.join(self.dataset_dir, idx[0]) 76 | pid = int(idx[1]) + delta 77 | dataset.append((fname, pid, cam)) 78 | return dataset 79 | 80 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/regdb_rgb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Mar 31 23:02:42 2021 5 | 6 | @author: vision 7 | """ 8 | 9 | 10 | from __future__ import print_function, absolute_import 11 | 
import os.path as osp 12 | import os 13 | import random 14 | from glob import glob 15 | import re 16 | import urllib 17 | import zipfile 18 | 19 | from ..utils.data import BaseImageDataset 20 | from ..utils.osutils import mkdir_if_missing 21 | from ..utils.serialization import write_json 22 | 23 | class RegDB_RGB(BaseImageDataset): 24 | dataset_dir = "RegDB" 25 | 26 | def __init__(self, root, verbose=True, ii=1, mode='t2v', **kwargs): 27 | super(RegDB_RGB, self).__init__() 28 | 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | self.ii = ii 31 | self.index_train_RGB = self.loadIdx(open((self.dataset_dir+'/idx/train_visible_{}.txt').format(self.ii),'r')) 32 | self.index_train_IR = self.loadIdx(open((self.dataset_dir+'/idx/train_thermal_{}.txt').format(self.ii),'r')) 33 | self.index_test_RGB = self.loadIdx(open((self.dataset_dir+'/idx/test_visible_{}.txt').format(self.ii),'r')) 34 | self.index_test_IR = self.loadIdx(open((self.dataset_dir+'/idx/test_thermal_{}.txt').format(self.ii),'r')) 35 | 36 | self.train = self._process_dir(self.index_train_RGB, 0, 0) 37 | if mode == 't2v': 38 | self.query = self._process_dir(self.index_test_IR, 1, 206) 39 | self.gallery = self._process_dir(self.index_test_RGB, 0, 206) 40 | elif mode == 'v2t': 41 | self.query = self._process_dir(self.index_test_RGB, 0, 206) 42 | self.gallery = self._process_dir(self.index_test_IR, 1, 206) 43 | 44 | if verbose: 45 | print("=> RegDB RGB loaded trial:{}".format(ii)) 46 | self.print_dataset_statistics(self.train, self.query, self.gallery) 47 | 48 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 49 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 50 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 51 | 52 | def _check_before_run(self): 53 | """Check if all files are available before going deeper""" 54 | if not 
osp.exists(self.dataset_dir): 55 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 56 | if not osp.exists(self.train_dir): 57 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 58 | if not osp.exists(self.val_dir): 59 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 60 | if not osp.exists(self.text_dir): 61 | raise RuntimeError("'{}' is not available".format(self.text_dir)) 62 | 63 | def loadIdx(self, index): 64 | Lines = index.readlines() 65 | idx = [] 66 | for line in Lines: 67 | tmp = line.strip('\n') 68 | tmp = tmp.split(' ') 69 | idx.append(tmp) 70 | return idx 71 | 72 | def _process_dir(self, index, cam, delta): 73 | dataset = [] 74 | for idx in index: 75 | fname = osp.join(self.dataset_dir, idx[0]) 76 | pid = int(idx[1]) + delta 77 | dataset.append((fname, pid, cam)) 78 | return dataset 79 | 80 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/sysumm01.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Mar 20 20:21:24 2021 5 | 6 | @author: vision 7 | """ 8 | 9 | 10 | from __future__ import print_function, absolute_import 11 | import os.path as osp 12 | import re 13 | import random 14 | import numpy as np 15 | from glob import glob 16 | 17 | from ..utils.data import BaseImageDataset 18 | 19 | 20 | class SYSU_MM01(BaseImageDataset): 21 | dataset_dir = "SYSU-MM01" 22 | 23 | def __init__(self, root='', verbose=True, pid_begin=0, mode='all', **kwargs): 24 | super(SYSU_MM01, self).__init__() 25 | 26 | self.pid_begin = pid_begin 27 | self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'exp/train_id.txt') 29 | self.val_dir = osp.join(self.dataset_dir, 'exp/val_id.txt') 30 | self.text_dir = osp.join(self.dataset_dir, 'exp/test_id.txt') 31 | 32 | self._check_before_run() 33 | 34 | 
self.train_id = self._get_id(self.train_dir) + self._get_id(self.val_dir) 35 | self.query_id = self._get_id(self.text_dir) 36 | self.gallery_id = self.query_id 37 | 38 | self.rgb_cams = ['cam1', 'cam2', 'cam4', 'cam5'] 39 | self.ir_cams = ['cam3', 'cam6'] 40 | self.train = self._process_dir(self.train_id, self.rgb_cams + self.ir_cams) 41 | self.query = self._process_dir(self.query_id, self.ir_cams) 42 | if mode == 'all': 43 | # self.gallery = self._process_dir(self.gallery_id, self.rgb_cams) 44 | self.gallery = self._process_dir_gallery(self.gallery_id, self.rgb_cams) 45 | elif mode == 'indoor': 46 | # self.gallery = self._process_dir(self.gallery_id, ['cam1', 'cam2']) 47 | self.gallery = self._process_dir_gallery(self.gallery_id, ['cam1', 'cam2']) 48 | 49 | if verbose: 50 | print("=> SYSU-MM01 loaded") 51 | self.print_dataset_statistics(self.train, self.query, self.gallery) 52 | 53 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 54 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 55 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 56 | 57 | def _check_before_run(self): 58 | """Check if all files are available before going deeper""" 59 | if not osp.exists(self.dataset_dir): 60 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 61 | if not osp.exists(self.train_dir): 62 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 63 | if not osp.exists(self.val_dir): 64 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 65 | if not osp.exists(self.text_dir): 66 | raise RuntimeError("'{}' is not available".format(self.text_dir)) 67 | 68 | def _get_id(self, file_path): 69 | with open(file_path, 'r') as f: 70 | ids = f.read().splitlines() 71 | ids = [int(y) for y in ids[0].split(',')] 72 | ids = ["%04d" % x for x in ids] 73 | return ids 74 | 75 | def 
_process_dir(self, ids, cams): 76 | ids_container = list(np.unique(ids)) 77 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 78 | 79 | dataset = [] 80 | for id_ in sorted(ids): 81 | for cam in cams: 82 | img_dir = osp.join(self.dataset_dir, cam, id_) 83 | if osp.isdir(img_dir): 84 | img_list = glob(osp.join(img_dir, "*.jpg")) 85 | img_list.sort() 86 | for img_path in img_list: 87 | dataset.append((img_path, self.pid_begin + id2label[id_], int(cam[-1])-1)) 88 | return dataset 89 | 90 | def _process_dir_gallery(self, ids, cams): 91 | ids_container = list(np.unique(ids)) 92 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 93 | 94 | dataset = [] 95 | for id_ in sorted(ids): 96 | for cam in cams: 97 | img_dir = osp.join(self.dataset_dir, cam, id_) 98 | if osp.isdir(img_dir): 99 | img_list = glob(osp.join(img_dir, "*.jpg")) 100 | img_list.sort() 101 | dataset.append((random.choice(img_list), self.pid_begin + id2label[id_], int(cam[-1])-1)) 102 | return dataset 103 | 104 | # def _process_train(self, train_path): 105 | # data = [] 106 | 107 | # file_path_list = ['cam1', 'cam2', 'cam4', 'cam5'] 108 | 109 | # for file_path in file_path_list: 110 | # camid = self.dataset_name + "_" + file_path 111 | # pid_list = os.listdir(os.path.join(train_path, file_path)) 112 | # for pid_dir in pid_list: 113 | # pid = self.dataset_name + "_" + pid_dir 114 | # img_list = glob(os.path.join(train_path, file_path, pid_dir, "*.jpg")) 115 | # for img_path in img_list: 116 | # data.append([img_path, pid, camid]) 117 | # return data -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/sysumm01_ir.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import re 4 | import random 5 | import numpy as np 6 | from glob import glob 7 | 8 | from ..utils.data import BaseImageDataset 9 | 10 | 11 
| class SYSU_MM01_IR(BaseImageDataset): 12 | dataset_dir = "SYSU-MM01" 13 | 14 | def __init__(self, root='', verbose=True, ncl=1, mode='all', **kwargs): 15 | super(SYSU_MM01_IR, self).__init__() 16 | 17 | self.dataset_dir = osp.join(root, self.dataset_dir) 18 | self.train_dir = osp.join(self.dataset_dir, 'exp/train_id.txt') 19 | self.val_dir = osp.join(self.dataset_dir, 'exp/val_id.txt') 20 | self.text_dir = osp.join(self.dataset_dir, 'exp/test_id.txt') 21 | 22 | self._check_before_run() 23 | 24 | self.train_id = self._get_id(self.train_dir) + self._get_id(self.val_dir) 25 | self.query_id = self._get_id(self.text_dir) 26 | self.gallery_id = self.query_id 27 | 28 | self.rgb_cams = ['cam1', 'cam2', 'cam4', 'cam5'] 29 | self.ir_cams = ['cam3', 'cam6'] 30 | self.train = self._process_dir(self.train_id, self.ir_cams) 31 | self.query = self._process_dir(self.query_id, self.ir_cams) 32 | if mode == 'all': 33 | # self.gallery = self._process_dir(self.gallery_id, self.rgb_cams) 34 | self.gallery = self._process_dir_gallery(self.gallery_id, self.rgb_cams) 35 | elif mode == 'indoor': 36 | # self.gallery = self._process_dir(self.gallery_id, ['cam1', 'cam2']) 37 | self.gallery = self._process_dir_gallery(self.gallery_id, ['cam1', 'cam2']) 38 | 39 | if verbose: 40 | print("=> SYSU-MM01 IR loaded") 41 | self.print_dataset_statistics(self.train, self.query, self.gallery) 42 | 43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 46 | 47 | def _check_before_run(self): 48 | """Check if all files are available before going deeper""" 49 | if not osp.exists(self.dataset_dir): 50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 51 | if not osp.exists(self.train_dir): 52 | raise RuntimeError("'{}' is 
not available".format(self.train_dir)) 53 | if not osp.exists(self.val_dir): 54 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 55 | if not osp.exists(self.text_dir): 56 | raise RuntimeError("'{}' is not available".format(self.text_dir)) 57 | 58 | def _get_id(self, file_path): 59 | with open(file_path, 'r') as f: 60 | ids = f.read().splitlines() 61 | ids = [int(y) for y in ids[0].split(',')] 62 | ids = ["%04d" % x for x in ids] 63 | return ids 64 | 65 | def _process_dir(self, ids, cams): 66 | ids_container = list(np.unique(ids)) 67 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 68 | 69 | dataset = [] 70 | for id_ in sorted(ids): 71 | for cam in cams: 72 | img_dir = osp.join(self.dataset_dir, cam, id_) 73 | if osp.isdir(img_dir): 74 | img_list = glob(osp.join(img_dir, "*.jpg")) 75 | img_list.sort() 76 | for img_path in img_list: 77 | dataset.append((img_path, id2label[id_], int(cam[-1]) - 1)) 78 | return dataset 79 | 80 | def _process_dir_gallery(self, ids, cams): 81 | ids_container = list(np.unique(ids)) 82 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 83 | 84 | dataset = [] 85 | for id_ in sorted(ids): 86 | for cam in cams: 87 | img_dir = osp.join(self.dataset_dir, cam, id_) 88 | if osp.isdir(img_dir): 89 | img_list = glob(osp.join(img_dir, "*.jpg")) 90 | img_list.sort() 91 | dataset.append((random.choice(img_list), id2label[id_], int(cam[-1]) - 1)) 92 | return dataset -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/sysumm01_rgb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import re 4 | import random 5 | import numpy as np 6 | from glob import glob 7 | 8 | from ..utils.data import BaseImageDataset 9 | 10 | 11 | class SYSU_MM01_RGB(BaseImageDataset): 12 | dataset_dir = "SYSU-MM01" 13 | 14 | def __init__(self, 
root='', verbose=True, ncl=1, mode='all', **kwargs): 15 | super(SYSU_MM01_RGB, self).__init__() 16 | 17 | self.dataset_dir = osp.join(root, self.dataset_dir) 18 | self.train_dir = osp.join(self.dataset_dir, 'exp/train_id.txt') 19 | self.val_dir = osp.join(self.dataset_dir, 'exp/val_id.txt') 20 | self.text_dir = osp.join(self.dataset_dir, 'exp/test_id.txt') 21 | 22 | self._check_before_run() 23 | 24 | self.train_id = self._get_id(self.train_dir) + self._get_id(self.val_dir) 25 | self.query_id = self._get_id(self.text_dir) 26 | self.gallery_id = self.query_id 27 | 28 | self.rgb_cams = ['cam1', 'cam2', 'cam4', 'cam5'] 29 | self.ir_cams = ['cam3', 'cam6'] 30 | self.train = self._process_dir(self.train_id, self.rgb_cams) 31 | self.query = self._process_dir(self.query_id, self.ir_cams) 32 | if mode == 'all': 33 | # self.gallery = self._process_dir(self.gallery_id, self.rgb_cams) 34 | self.gallery = self._process_dir_gallery(self.gallery_id, self.rgb_cams) 35 | elif mode == 'indoor': 36 | # self.gallery = self._process_dir(self.gallery_id, ['cam1', 'cam2']) 37 | self.gallery = self._process_dir_gallery(self.gallery_id, ['cam1', 'cam2']) 38 | 39 | if verbose: 40 | print("=> SYSU-MM01 RGB loaded") 41 | self.print_dataset_statistics(self.train, self.query, self.gallery) 42 | 43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 46 | 47 | def _check_before_run(self): 48 | """Check if all files are available before going deeper""" 49 | if not osp.exists(self.dataset_dir): 50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 51 | if not osp.exists(self.train_dir): 52 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 53 | if not osp.exists(self.val_dir): 54 | raise 
RuntimeError("'{}' is not available".format(self.val_dir)) 55 | if not osp.exists(self.text_dir): 56 | raise RuntimeError("'{}' is not available".format(self.text_dir)) 57 | 58 | def _get_id(self, file_path): 59 | with open(file_path, 'r') as f: 60 | ids = f.read().splitlines() 61 | ids = [int(y) for y in ids[0].split(',')] 62 | ids = ["%04d" % x for x in ids] 63 | return ids 64 | 65 | def _process_dir(self, ids, cams): 66 | ids_container = list(np.unique(ids)) 67 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 68 | 69 | dataset = [] 70 | for id_ in sorted(ids): 71 | for cam in cams: 72 | img_dir = osp.join(self.dataset_dir, cam, id_) 73 | if osp.isdir(img_dir): 74 | img_list = glob(osp.join(img_dir, "*.jpg")) 75 | img_list.sort() 76 | for img_path in img_list: 77 | dataset.append((img_path, id2label[id_], int(cam[-1]) - 1)) 78 | return dataset 79 | 80 | def _process_dir_gallery(self, ids, cams): 81 | ids_container = list(np.unique(ids)) 82 | id2label = {id_: label for label, id_ in enumerate(ids_container)} 83 | 84 | dataset = [] 85 | for id_ in sorted(ids): 86 | for cam in cams: 87 | img_dir = osp.join(self.dataset_dir, cam, id_) 88 | if osp.isdir(img_dir): 89 | img_list = glob(osp.join(img_dir, "*.jpg")) 90 | img_list.sort() 91 | dataset.append((random.choice(img_list), id2label[id_], int(cam[-1]) - 1)) 92 | return dataset -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/vehicleid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | import os.path as osp 7 | 8 | from ..utils.data import BaseImageDataset 9 | from collections import defaultdict 10 | 11 | 12 | class VehicleID(BaseImageDataset): 13 | """ 14 | VehicleID 15 | Reference: 16 | Deep Relative Distance Learning: Tell the Difference Between Similar 
Vehicles 17 | 18 | Dataset statistics: 19 | # train_list: 13164 vehicles for model training 20 | # test_list_800: 800 vehicles for model testing(small test set in paper 21 | # test_list_1600: 1600 vehicles for model testing(medium test set in paper 22 | # test_list_2400: 2400 vehicles for model testing(large test set in paper 23 | # test_list_3200: 3200 vehicles for model testing 24 | # test_list_6000: 6000 vehicles for model testing 25 | # test_list_13164: 13164 vehicles for model testing 26 | """ 27 | dataset_dir = 'VehicleID' 28 | 29 | def __init__(self, root, verbose=True, test_size=800, **kwargs): 30 | super(VehicleID, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.img_dir = osp.join(self.dataset_dir, 'image') 33 | self.split_dir = osp.join(self.dataset_dir, 'train_test_split') 34 | self.train_list = osp.join(self.split_dir, 'train_list.txt') 35 | self.test_size = test_size 36 | 37 | if self.test_size == 800: 38 | self.test_list = osp.join(self.split_dir, 'test_list_800.txt') 39 | elif self.test_size == 1600: 40 | self.test_list = osp.join(self.split_dir, 'test_list_1600.txt') 41 | elif self.test_size == 2400: 42 | self.test_list = osp.join(self.split_dir, 'test_list_2400.txt') 43 | 44 | print(self.test_list) 45 | 46 | self.check_before_run() 47 | 48 | train, query, gallery = self.process_split(relabel=True) 49 | self.train = train 50 | self.query = query 51 | self.gallery = gallery 52 | 53 | if verbose: 54 | print('=> VehicleID loaded') 55 | self.print_dataset_statistics(train, query, gallery) 56 | 57 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 58 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 59 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 60 | 61 | def check_before_run(self): 62 | """Check if all files are available before going deeper""" 63 | if 
not osp.exists(self.dataset_dir): 64 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 65 | if not osp.exists(self.split_dir): 66 | raise RuntimeError('"{}" is not available'.format(self.split_dir)) 67 | if not osp.exists(self.train_list): 68 | raise RuntimeError('"{}" is not available'.format(self.train_list)) 69 | if self.test_size not in [800, 1600, 2400]: 70 | raise RuntimeError('"{}" is not available'.format(self.test_size)) 71 | if not osp.exists(self.test_list): 72 | raise RuntimeError('"{}" is not available'.format(self.test_list)) 73 | 74 | def get_pid2label(self, pids): 75 | pid_container = set(pids) 76 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 77 | return pid2label 78 | 79 | def parse_img_pids(self, nl_pairs, pid2label=None): 80 | # il_pair is the pairs of img name and label 81 | output = [] 82 | for info in nl_pairs: 83 | name = info[0] 84 | pid = info[1] 85 | if pid2label is not None: 86 | pid = pid2label[pid] 87 | camid = 0 # don't have camid information use 0 for all 88 | img_path = osp.join(self.img_dir, name+'.jpg') 89 | output.append((img_path, pid, camid)) 90 | return output 91 | 92 | def process_split(self, relabel=False): 93 | # read train paths 94 | train_pid_dict = defaultdict(list) 95 | 96 | # 'train_list.txt' format: 97 | # the first number is the number of image 98 | # the second number is the id of vehicle 99 | with open(self.train_list) as f_train: 100 | train_data = f_train.readlines() 101 | for data in train_data: 102 | name, pid = data.strip().split(' ') 103 | pid = int(pid) 104 | train_pid_dict[pid].append([name, pid]) 105 | train_pids = list(train_pid_dict.keys()) 106 | num_train_pids = len(train_pids) 107 | assert num_train_pids == 13164, 'There should be 13164 vehicles for training,' \ 108 | ' but but got {}, please check the data'\ 109 | .format(num_train_pids) 110 | # print('num of train ids: {}'.format(num_train_pids)) 111 | test_pid_dict = defaultdict(list) 112 | with 
open(self.test_list) as f_test: 113 | test_data = f_test.readlines() 114 | for data in test_data: 115 | name, pid = data.split(' ') 116 | pid = int(pid) 117 | test_pid_dict[pid].append([name, pid]) 118 | test_pids = list(test_pid_dict.keys()) 119 | num_test_pids = len(test_pids) 120 | assert num_test_pids == self.test_size, 'There should be {} vehicles for testing,' \ 121 | ' but but got {}, please check the data'\ 122 | .format(self.test_size, num_test_pids) 123 | 124 | train_data = [] 125 | query_data = [] 126 | gallery_data = [] 127 | 128 | # for train ids, all images are used in the train set. 129 | for pid in train_pids: 130 | imginfo = train_pid_dict[pid] # imginfo include image name and id 131 | train_data.extend(imginfo) 132 | 133 | # for each test id, random choose one image for gallery 134 | # and the other ones for query. 135 | for pid in test_pids: 136 | imginfo = test_pid_dict[pid] 137 | sample = random.choice(imginfo) 138 | imginfo.remove(sample) 139 | query_data.extend(imginfo) 140 | gallery_data.append(sample) 141 | 142 | if relabel: 143 | train_pid2label = self.get_pid2label(train_pids) 144 | else: 145 | train_pid2label = None 146 | # for key, value in train_pid2label.items(): 147 | # print('{key}:{value}'.format(key=key, value=value)) 148 | 149 | train = self.parse_img_pids(train_data, train_pid2label) 150 | query = self.parse_img_pids(query_data) 151 | gallery = self.parse_img_pids(gallery_data) 152 | return train, query, gallery 153 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/vehiclex.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | 9 | from ..utils.data import BaseDataset 10 | 11 | 12 | class VehicleX(BaseDataset): 13 | """ 14 | VeRi 15 | Reference: 16 | PAMTRI: 
Pose-Aware Multi-Task Learning for Vehicle Re-Identification Using Highly Randomized Synthetic Data. In: ICCV 2019 17 | """ 18 | dataset_dir = 'AIC20_ReID_Simulation' 19 | 20 | def __init__(self, root, verbose=True, **kwargs): 21 | super(VehicleX, self).__init__() 22 | self.dataset_dir = osp.join(root, self.dataset_dir) 23 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 24 | 25 | self.check_before_run() 26 | 27 | train = self.process_dir(self.train_dir, relabel=True) 28 | 29 | if verbose: 30 | print('=> VehicleX loaded') 31 | self.print_dataset_statistics(train) 32 | 33 | self.train = train 34 | 35 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 36 | 37 | def check_before_run(self): 38 | """Check if all files are available before going deeper""" 39 | if not osp.exists(self.dataset_dir): 40 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 41 | if not osp.exists(self.train_dir): 42 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 43 | 44 | def process_dir(self, dir_path, relabel=False): 45 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 46 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 47 | 48 | pid_container = set() 49 | for img_path in img_paths: 50 | pid, _ = map(int, pattern.search(img_path).groups()) 51 | if pid == -1: 52 | continue # junk images are just ignored 53 | pid_container.add(pid) 54 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 55 | 56 | dataset = [] 57 | for img_path in img_paths: 58 | pid, camid = map(int, pattern.search(img_path).groups()) 59 | if pid == -1: 60 | continue # junk images are just ignored 61 | assert 1 <= pid <= 1362 62 | assert 6 <= camid <= 36 63 | camid -= 6 # index starts from 0 64 | if relabel: 65 | pid = pid2label[pid] 66 | dataset.append((img_path, pid, camid)) 67 | return dataset 68 | 69 | def print_dataset_statistics(self, train): 70 | num_train_pids, num_train_imgs, num_train_cams = 
self.get_imagedata_info(train) 71 | 72 | print("Dataset statistics:") 73 | print(" ----------------------------------------") 74 | print(" subset | # ids | # images | # cameras") 75 | print(" ----------------------------------------") 76 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 77 | print(" ----------------------------------------") 78 | -------------------------------------------------------------------------------- /SpCL-master/spcl/datasets/veri.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | 9 | from ..utils.data import BaseImageDataset 10 | 11 | 12 | class VeRi(BaseImageDataset): 13 | """ 14 | VeRi 15 | Reference: 16 | Liu, X., Liu, W., Ma, H., Fu, H.: Large-scale vehicle re-identification in urban surveillance videos. In: IEEE % 17 | International Conference on Multimedia and Expo. (2016) accepted. 
18 | Dataset statistics: 19 | # identities: 776 vehicles(576 for training and 200 for testing) 20 | # images: 37778 (train) + 11579 (query) 21 | """ 22 | dataset_dir = 'VeRi' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(VeRi, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 30 | 31 | self.check_before_run() 32 | 33 | train = self.process_dir(self.train_dir, relabel=True) 34 | query = self.process_dir(self.query_dir, relabel=False) 35 | gallery = self.process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print('=> VeRi loaded') 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 48 | 49 | def check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 55 | if not osp.exists(self.query_dir): 56 | raise RuntimeError('"{}" is not available'.format(self.query_dir)) 57 | if not osp.exists(self.gallery_dir): 58 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) 59 | 60 | def process_dir(self, dir_path, relabel=False): 61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 62 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 63 | 64 | 
pid_container = set() 65 | for img_path in img_paths: 66 | pid, _ = map(int, pattern.search(img_path).groups()) 67 | if pid == -1: 68 | continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: 76 | continue # junk images are just ignored 77 | assert 0 <= pid <= 776 # pid == 0 means background 78 | assert 1 <= camid <= 20 79 | camid -= 1 # index starts from 0 80 | if relabel: 81 | pid = pid2label[pid] 82 | dataset.append((img_path, pid, camid)) 83 | 84 | return dataset 85 | -------------------------------------------------------------------------------- /SpCL-master/spcl/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap' 10 | ] 11 | -------------------------------------------------------------------------------- /SpCL-master/spcl/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from ..utils import to_torch 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | with torch.no_grad(): 9 | output, target = to_torch(output), to_torch(target) 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | ret = [] 18 | for k in topk: 19 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 20 | ret.append(correct_k.mul_(1. 
/ batch_size)) 21 | return ret 22 | -------------------------------------------------------------------------------- /SpCL-master/spcl/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): 
continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. / (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not 
np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /SpCL-master/spcl/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import collections 4 | from collections import OrderedDict 5 | import numpy as np 6 | import torch 7 | import random 8 | import copy 9 | 10 | from .evaluation_metrics import cmc, mean_ap 11 | from .utils.meters import AverageMeter 12 | from .utils.rerank import re_ranking 13 | from .utils import to_torch 14 | 15 | def extract_cnn_feature(model, inputs): 16 | inputs = to_torch(inputs).cuda() 17 | outputs = model(inputs) 18 | outputs = outputs.data.cpu() 19 | return outputs 20 | 21 | def extract_features(model, data_loader, print_freq=50): 22 | model.eval() 23 | batch_time = AverageMeter() 24 | data_time = AverageMeter() 25 | 26 | features = OrderedDict() 27 | labels = OrderedDict() 28 | 29 | end = time.time() 30 | with torch.no_grad(): 31 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader): 32 | data_time.update(time.time() - end) 33 | 34 | outputs = extract_cnn_feature(model, imgs) 35 | for fname, output, pid in zip(fnames, outputs, pids): 36 | features[fname] = output 37 | labels[fname] = pid 38 | 39 | batch_time.update(time.time() - end) 40 | end = time.time() 41 | 42 | if (i + 1) % print_freq == 0: 43 | print('Extract Features: [{}/{}]\t' 44 | 'Time {:.3f} ({:.3f})\t' 45 | 'Data {:.3f} ({:.3f})\t' 46 | .format(i + 1, len(data_loader), 47 | batch_time.val, batch_time.avg, 48 | data_time.val, data_time.avg)) 49 | 50 | return features, labels 51 | 52 | def pairwise_distance(features, query=None, gallery=None): 53 | if query is None and gallery is None: 54 | n = len(features) 55 | x = 
torch.cat(list(features.values())) 56 | x = x.view(n, -1) 57 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 58 | dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t()) 59 | return dist_m 60 | 61 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 62 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0) 63 | m, n = x.size(0), y.size(0) 64 | x = x.view(m, -1) 65 | y = y.view(n, -1) 66 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 67 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 68 | dist_m.addmm_(1, -2, x, y.t()) 69 | return dist_m, x.numpy(), y.numpy() 70 | 71 | def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None, 72 | query_ids=None, gallery_ids=None, 73 | query_cams=None, gallery_cams=None, 74 | cmc_topk=(1, 5, 10), cmc_flag=False): 75 | if query is not None and gallery is not None: 76 | query_ids = [pid for _, pid, _ in query] 77 | gallery_ids = [pid for _, pid, _ in gallery] 78 | query_cams = [cam for _, _, cam in query] 79 | gallery_cams = [cam for _, _, cam in gallery] 80 | else: 81 | assert (query_ids is not None and gallery_ids is not None 82 | and query_cams is not None and gallery_cams is not None) 83 | 84 | # Compute mean AP 85 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 86 | print('Mean AP: {:4.1%}'.format(mAP)) 87 | 88 | if (not cmc_flag): 89 | return mAP 90 | 91 | cmc_configs = { 92 | 'market1501': dict(separate_camera_set=False, 93 | single_gallery_shot=False, 94 | first_match_break=True),} 95 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 96 | query_cams, gallery_cams, **params) 97 | for name, params in cmc_configs.items()} 98 | 99 | print('CMC Scores:') 100 | for k in cmc_topk: 101 | print(' top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1])) 102 | return cmc_scores['market1501'], mAP 103 | 104 | 105 | class Evaluator(object): 106 | def __init__(self, model): 107 | super(Evaluator, 
self).__init__() 108 | self.model = model 109 | 110 | def evaluate(self, data_loader, query, gallery, cmc_flag=False, rerank=False): 111 | features, _ = extract_features(self.model, data_loader) 112 | distmat, query_features, gallery_features = pairwise_distance(features, query, gallery) 113 | results = evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 114 | 115 | if (not rerank): 116 | return results 117 | 118 | print('Applying person re-ranking ...') 119 | distmat_qq, _, _ = pairwise_distance(features, query, query) 120 | distmat_gg, _, _ = pairwise_distance(features, gallery, gallery) 121 | distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy()) 122 | return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 123 | -------------------------------------------------------------------------------- /SpCL-master/spcl/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnet_ibn import * 5 | 6 | 7 | __factory = { 8 | 'resnet18': resnet18, 9 | 'resnet34': resnet34, 10 | 'resnet50': resnet50, 11 | 'resnet101': resnet101, 12 | 'resnet152': resnet152, 13 | 'resnet_ibn50a': resnet_ibn50a, 14 | 'resnet_ibn101a': resnet_ibn101a 15 | } 16 | 17 | 18 | def names(): 19 | return sorted(__factory.keys()) 20 | 21 | 22 | def create(name, *args, **kwargs): 23 | """ 24 | Create a model instance. 25 | 26 | Parameters 27 | ---------- 28 | name : str 29 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 30 | 'resnet50', 'resnet101', and 'resnet152'. 31 | pretrained : bool, optional 32 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 33 | model. 
Default: True 34 | cut_at_pooling : bool, optional 35 | If True, will cut the model before the last global pooling layer and 36 | ignore the remaining kwargs. Default: False 37 | num_features : int, optional 38 | If positive, will append a Linear layer after the global pooling layer, 39 | with this number of output units, followed by a BatchNorm layer. 40 | Otherwise these layers will not be appended. Default: 256 for 41 | 'inception', 0 for 'resnet*' 42 | norm : bool, optional 43 | If True, will normalize the feature to be unit L2-norm for each sample. 44 | Otherwise will append a ReLU layer after the above Linear layer if 45 | num_features > 0. Default: False 46 | dropout : float, optional 47 | If positive, will append a Dropout layer with this dropout rate. 48 | Default: 0 49 | num_classes : int, optional 50 | If positive, will append a Linear layer at the end as the classifier 51 | with this number of output units. Default: 0 52 | """ 53 | if name not in __factory: 54 | raise KeyError("Unknown model:", name) 55 | return __factory[name](*args, **kwargs) 56 | -------------------------------------------------------------------------------- /SpCL-master/spcl/models/dsbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Domain-specific BatchNorm 5 | 6 | class DSBN2d(nn.Module): 7 | def __init__(self, planes): 8 | super(DSBN2d, self).__init__() 9 | self.num_features = planes 10 | self.BN_S = nn.BatchNorm2d(planes) 11 | self.BN_T = nn.BatchNorm2d(planes) 12 | 13 | def forward(self, x): 14 | if (not self.training): 15 | return self.BN_T(x) 16 | 17 | bs = x.size(0) 18 | assert (bs%2==0) 19 | split = torch.split(x, int(bs/2), 0) 20 | out1 = self.BN_S(split[0].contiguous()) 21 | out2 = self.BN_T(split[1].contiguous()) 22 | out = torch.cat((out1, out2), 0) 23 | return out 24 | 25 | class DSBN1d(nn.Module): 26 | def __init__(self, planes): 27 | super(DSBN1d, self).__init__() 28 | 
self.num_features = planes 29 | self.BN_S = nn.BatchNorm1d(planes) 30 | self.BN_T = nn.BatchNorm1d(planes) 31 | 32 | def forward(self, x): 33 | if (not self.training): 34 | return self.BN_T(x) 35 | 36 | bs = x.size(0) 37 | assert (bs%2==0) 38 | split = torch.split(x, int(bs/2), 0) 39 | out1 = self.BN_S(split[0].contiguous()) 40 | out2 = self.BN_T(split[1].contiguous()) 41 | out = torch.cat((out1, out2), 0) 42 | return out 43 | 44 | def convert_dsbn(model): 45 | for _, (child_name, child) in enumerate(model.named_children()): 46 | assert(not next(model.parameters()).is_cuda) 47 | if isinstance(child, nn.BatchNorm2d): 48 | m = DSBN2d(child.num_features) 49 | m.BN_S.load_state_dict(child.state_dict()) 50 | m.BN_T.load_state_dict(child.state_dict()) 51 | setattr(model, child_name, m) 52 | elif isinstance(child, nn.BatchNorm1d): 53 | m = DSBN1d(child.num_features) 54 | m.BN_S.load_state_dict(child.state_dict()) 55 | m.BN_T.load_state_dict(child.state_dict()) 56 | setattr(model, child_name, m) 57 | else: 58 | convert_dsbn(child) 59 | 60 | def convert_bn(model, use_target=True): 61 | for _, (child_name, child) in enumerate(model.named_children()): 62 | assert(not next(model.parameters()).is_cuda) 63 | if isinstance(child, DSBN2d): 64 | m = nn.BatchNorm2d(child.num_features) 65 | if use_target: 66 | m.load_state_dict(child.BN_T.state_dict()) 67 | else: 68 | m.load_state_dict(child.BN_S.state_dict()) 69 | setattr(model, child_name, m) 70 | elif isinstance(child, DSBN1d): 71 | m = nn.BatchNorm1d(child.num_features) 72 | if use_target: 73 | m.load_state_dict(child.BN_T.state_dict()) 74 | else: 75 | m.load_state_dict(child.BN_S.state_dict()) 76 | setattr(model, child_name, m) 77 | else: 78 | convert_bn(child, use_target=use_target) 79 | -------------------------------------------------------------------------------- /SpCL-master/spcl/models/hm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 
import torch.nn.functional as F 5 | from torch.nn import init 6 | from torch import nn, autograd 7 | 8 | 9 | class HM(autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, inputs, indexes, features, momentum): 13 | ctx.features = features 14 | ctx.momentum = momentum 15 | ctx.save_for_backward(inputs, indexes) 16 | outputs = inputs.mm(ctx.features.t()) 17 | 18 | return outputs 19 | 20 | @staticmethod 21 | def backward(ctx, grad_outputs): 22 | inputs, indexes = ctx.saved_tensors 23 | grad_inputs = None 24 | if ctx.needs_input_grad[0]: 25 | grad_inputs = grad_outputs.mm(ctx.features) 26 | 27 | # momentum update 28 | for x, y in zip(inputs, indexes): 29 | ctx.features[y] = ctx.momentum * ctx.features[y] + (1. - ctx.momentum) * x 30 | ctx.features[y] /= ctx.features[y].norm() 31 | 32 | return grad_inputs, None, None, None 33 | 34 | 35 | def hm(inputs, indexes, features, momentum=0.5): 36 | return HM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device)) 37 | 38 | 39 | class HybridMemory(nn.Module): 40 | def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2): 41 | super(HybridMemory, self).__init__() 42 | self.num_features = num_features 43 | self.num_samples = num_samples 44 | 45 | self.momentum = momentum 46 | self.temp = temp 47 | 48 | self.register_buffer('features', torch.zeros(num_samples, num_features)) 49 | self.register_buffer('labels', torch.zeros(num_samples).long()) 50 | 51 | def forward(self, inputs, indexes): 52 | # inputs: B*2048, features: L*2048 53 | inputs = hm(inputs, indexes, self.features, self.momentum) 54 | inputs /= self.temp 55 | B = inputs.size(0) 56 | 57 | def masked_softmax(vec, mask, dim=1, epsilon=1e-6): 58 | exps = torch.exp(vec) 59 | masked_exps = exps * mask.float().clone() 60 | masked_sums = masked_exps.sum(dim, keepdim=True) + epsilon 61 | return (masked_exps/masked_sums) 62 | 63 | targets = self.labels[indexes].clone() 64 | labels = self.labels.clone() 65 | 66 | sim = 
torch.zeros(labels.max()+1, B).float().cuda() 67 | sim.index_add_(0, labels, inputs.t().contiguous()) 68 | nums = torch.zeros(labels.max()+1, 1).float().cuda() 69 | nums.index_add_(0, labels, torch.ones(self.num_samples,1).float().cuda()) 70 | mask = (nums>0).float() 71 | sim /= (mask*nums+(1-mask)).clone().expand_as(sim) 72 | mask = mask.expand_as(sim) 73 | masked_sim = masked_softmax(sim.t().contiguous(), mask.t().contiguous()) 74 | return F.nll_loss(torch.log(masked_sim+1e-6), targets) 75 | -------------------------------------------------------------------------------- /SpCL-master/spcl/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152'] 12 | 13 | 14 | class ResNet(nn.Module): 15 | __factory = { 16 | 18: torchvision.models.resnet18, 17 | 34: torchvision.models.resnet34, 18 | 50: torchvision.models.resnet50, 19 | 101: torchvision.models.resnet101, 20 | 152: torchvision.models.resnet152, 21 | } 22 | 23 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 24 | num_features=0, norm=False, dropout=0, num_classes=0): 25 | super(ResNet, self).__init__() 26 | self.pretrained = pretrained 27 | self.depth = depth 28 | self.cut_at_pooling = cut_at_pooling 29 | # Construct base (pretrained) resnet 30 | if depth not in ResNet.__factory: 31 | raise KeyError("Unsupported depth:", depth) 32 | resnet = ResNet.__factory[depth](pretrained=pretrained) 33 | resnet.layer4[0].conv2.stride = (1,1) 34 | resnet.layer4[0].downsample[0].stride = (1,1) 35 | self.base = nn.Sequential( 36 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 37 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 38 | self.gap = nn.AdaptiveAvgPool2d(1) 39 | 
40 | if not self.cut_at_pooling: 41 | self.num_features = num_features 42 | self.norm = norm 43 | self.dropout = dropout 44 | self.has_embedding = num_features > 0 45 | self.num_classes = num_classes 46 | 47 | out_planes = resnet.fc.in_features 48 | 49 | # Append new layers 50 | if self.has_embedding: 51 | self.feat = nn.Linear(out_planes, self.num_features) 52 | self.feat_bn = nn.BatchNorm1d(self.num_features) 53 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 54 | init.constant_(self.feat.bias, 0) 55 | else: 56 | # Change the num_features to CNN output channels 57 | self.num_features = out_planes 58 | self.feat_bn = nn.BatchNorm1d(self.num_features) 59 | self.feat_bn.bias.requires_grad_(False) 60 | if self.dropout > 0: 61 | self.drop = nn.Dropout(self.dropout) 62 | if self.num_classes > 0: 63 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 64 | init.normal_(self.classifier.weight, std=0.001) 65 | init.constant_(self.feat_bn.weight, 1) 66 | init.constant_(self.feat_bn.bias, 0) 67 | 68 | if not pretrained: 69 | self.reset_params() 70 | 71 | def forward(self, x): 72 | bs = x.size(0) 73 | x = self.base(x) 74 | 75 | x = self.gap(x) 76 | x = x.view(x.size(0), -1) 77 | 78 | if self.cut_at_pooling: 79 | return x 80 | 81 | if self.has_embedding: 82 | bn_x = self.feat_bn(self.feat(x)) 83 | else: 84 | bn_x = self.feat_bn(x) 85 | 86 | if (self.training is False): 87 | bn_x = F.normalize(bn_x) 88 | return bn_x 89 | 90 | if self.norm: 91 | bn_x = F.normalize(bn_x) 92 | elif self.has_embedding: 93 | bn_x = F.relu(bn_x) 94 | 95 | if self.dropout > 0: 96 | bn_x = self.drop(bn_x) 97 | 98 | if self.num_classes > 0: 99 | prob = self.classifier(bn_x) 100 | else: 101 | return bn_x 102 | 103 | return prob 104 | 105 | def reset_params(self): 106 | for m in self.modules(): 107 | if isinstance(m, nn.Conv2d): 108 | init.kaiming_normal_(m.weight, mode='fan_out') 109 | if m.bias is not None: 110 | init.constant_(m.bias, 0) 111 | elif isinstance(m, 
nn.BatchNorm2d): 112 | init.constant_(m.weight, 1) 113 | init.constant_(m.bias, 0) 114 | elif isinstance(m, nn.BatchNorm1d): 115 | init.constant_(m.weight, 1) 116 | init.constant_(m.bias, 0) 117 | elif isinstance(m, nn.Linear): 118 | init.normal_(m.weight, std=0.001) 119 | if m.bias is not None: 120 | init.constant_(m.bias, 0) 121 | 122 | 123 | def resnet18(**kwargs): 124 | return ResNet(18, **kwargs) 125 | 126 | 127 | def resnet34(**kwargs): 128 | return ResNet(34, **kwargs) 129 | 130 | 131 | def resnet50(**kwargs): 132 | return ResNet(50, **kwargs) 133 | 134 | 135 | def resnet101(**kwargs): 136 | return ResNet(101, **kwargs) 137 | 138 | 139 | def resnet152(**kwargs): 140 | return ResNet(152, **kwargs) 141 | -------------------------------------------------------------------------------- /SpCL-master/spcl/models/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | 9 | from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a 10 | 11 | 12 | __all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a'] 13 | 14 | 15 | class ResNetIBN(nn.Module): 16 | __factory = { 17 | '50a': resnet50_ibn_a, 18 | '101a': resnet101_ibn_a 19 | } 20 | 21 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 22 | num_features=0, norm=False, dropout=0, num_classes=0): 23 | super(ResNetIBN, self).__init__() 24 | 25 | self.depth = depth 26 | self.pretrained = pretrained 27 | self.cut_at_pooling = cut_at_pooling 28 | 29 | resnet = ResNetIBN.__factory[depth](pretrained=pretrained) 30 | resnet.layer4[0].conv2.stride = (1,1) 31 | resnet.layer4[0].downsample[0].stride = (1,1) 32 | self.base = nn.Sequential( 33 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 34 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 35 | self.gap = 
nn.AdaptiveAvgPool2d(1) 36 | 37 | if not self.cut_at_pooling: 38 | self.num_features = num_features 39 | self.norm = norm 40 | self.dropout = dropout 41 | self.has_embedding = num_features > 0 42 | self.num_classes = num_classes 43 | 44 | out_planes = resnet.fc.in_features 45 | 46 | # Append new layers 47 | if self.has_embedding: 48 | self.feat = nn.Linear(out_planes, self.num_features) 49 | self.feat_bn = nn.BatchNorm1d(self.num_features) 50 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 51 | init.constant_(self.feat.bias, 0) 52 | else: 53 | # Change the num_features to CNN output channels 54 | self.num_features = out_planes 55 | self.feat_bn = nn.BatchNorm1d(self.num_features) 56 | self.feat_bn.bias.requires_grad_(False) 57 | if self.dropout > 0: 58 | self.drop = nn.Dropout(self.dropout) 59 | if self.num_classes > 0: 60 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 61 | init.normal_(self.classifier.weight, std=0.001) 62 | init.constant_(self.feat_bn.weight, 1) 63 | init.constant_(self.feat_bn.bias, 0) 64 | 65 | if not pretrained: 66 | self.reset_params() 67 | 68 | def forward(self, x): 69 | x = self.base(x) 70 | 71 | x = self.gap(x) 72 | x = x.view(x.size(0), -1) 73 | 74 | if self.cut_at_pooling: 75 | return x 76 | 77 | if self.has_embedding: 78 | bn_x = self.feat_bn(self.feat(x)) 79 | else: 80 | bn_x = self.feat_bn(x) 81 | 82 | if self.training is False: 83 | bn_x = F.normalize(bn_x) 84 | return bn_x 85 | 86 | if self.norm: 87 | bn_x = F.normalize(bn_x) 88 | elif self.has_embedding: 89 | bn_x = F.relu(bn_x) 90 | 91 | if self.dropout > 0: 92 | bn_x = self.drop(bn_x) 93 | 94 | if self.num_classes > 0: 95 | prob = self.classifier(bn_x) 96 | else: 97 | return bn_x 98 | 99 | return prob 100 | 101 | def reset_params(self): 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | init.kaiming_normal_(m.weight, mode='fan_out') 105 | if m.bias is not None: 106 | init.constant_(m.bias, 0) 107 | elif isinstance(m, 
nn.BatchNorm2d): 108 | init.constant_(m.weight, 1) 109 | init.constant_(m.bias, 0) 110 | elif isinstance(m, nn.BatchNorm1d): 111 | init.constant_(m.weight, 1) 112 | init.constant_(m.bias, 0) 113 | elif isinstance(m, nn.Linear): 114 | init.normal_(m.weight, std=0.001) 115 | if m.bias is not None: 116 | init.constant_(m.bias, 0) 117 | 118 | 119 | def resnet_ibn50a(**kwargs): 120 | return ResNetIBN('50a', **kwargs) 121 | 122 | 123 | def resnet_ibn101a(**kwargs): 124 | return ResNetIBN('101a', **kwargs) 125 | -------------------------------------------------------------------------------- /SpCL-master/spcl/models/resnet_ibn_a.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['ResNet', 'resnet50_ibn_a', 'resnet101_ibn_a'] 8 | 9 | 10 | model_urls = { 11 | 'ibn_resnet50a': './logs/pretrained/resnet50_ibn_a.pth.tar', 12 | 'ibn_resnet101a': './logs/pretrained/resnet101_ibn_a.pth.tar', 13 | } 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 
49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class IBN(nn.Module): 55 | def __init__(self, planes): 56 | super(IBN, self).__init__() 57 | half1 = int(planes/2) 58 | self.half = half1 59 | half2 = planes - half1 60 | self.IN = nn.InstanceNorm2d(half1, affine=True) 61 | self.BN = nn.BatchNorm2d(half2) 62 | 63 | def forward(self, x): 64 | split = torch.split(x, self.half, 1) 65 | out1 = self.IN(split[0].contiguous()) 66 | out2 = self.BN(split[1].contiguous()) 67 | out = torch.cat((out1, out2), 1) 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 77 | if ibn: 78 | self.bn1 = IBN(planes) 79 | else: 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 82 | padding=1, bias=False) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 85 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 86 | self.relu = nn.ReLU(inplace=True) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | residual = x 92 | 93 | out = self.conv1(x) 94 | out = self.bn1(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv2(out) 98 | out = self.bn2(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv3(out) 102 | out = self.bn3(out) 103 | 104 | if self.downsample is not None: 105 | residual = self.downsample(x) 106 | 107 | out += residual 108 | out = self.relu(out) 109 | 110 | return out 111 | 112 | 113 | class ResNet(nn.Module): 114 | 115 | def __init__(self, block, layers, num_classes=1000): 116 | scale = 64 117 | self.inplanes = scale 118 | super(ResNet, self).__init__() 119 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 120 | bias=False) 121 | self.bn1 
= nn.BatchNorm2d(scale) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 124 | self.layer1 = self._make_layer(block, scale, layers[0]) 125 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) 126 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) 127 | self.layer4 = self._make_layer(block, scale*8, layers[3], stride=2) 128 | self.avgpool = nn.AvgPool2d(7) 129 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 130 | 131 | for m in self.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 134 | m.weight.data.normal_(0, math.sqrt(2. / n)) 135 | elif isinstance(m, nn.BatchNorm2d): 136 | m.weight.data.fill_(1) 137 | m.bias.data.zero_() 138 | elif isinstance(m, nn.InstanceNorm2d): 139 | m.weight.data.fill_(1) 140 | m.bias.data.zero_() 141 | 142 | def _make_layer(self, block, planes, blocks, stride=1): 143 | downsample = None 144 | if stride != 1 or self.inplanes != planes * block.expansion: 145 | downsample = nn.Sequential( 146 | nn.Conv2d(self.inplanes, planes * block.expansion, 147 | kernel_size=1, stride=stride, bias=False), 148 | nn.BatchNorm2d(planes * block.expansion), 149 | ) 150 | 151 | layers = [] 152 | ibn = True 153 | if planes == 512: 154 | ibn = False 155 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 156 | self.inplanes = planes * block.expansion 157 | for i in range(1, blocks): 158 | layers.append(block(self.inplanes, planes, ibn)) 159 | 160 | return nn.Sequential(*layers) 161 | 162 | def forward(self, x): 163 | x = self.conv1(x) 164 | x = self.bn1(x) 165 | x = self.relu(x) 166 | x = self.maxpool(x) 167 | 168 | x = self.layer1(x) 169 | x = self.layer2(x) 170 | x = self.layer3(x) 171 | x = self.layer4(x) 172 | 173 | x = self.avgpool(x) 174 | x = x.view(x.size(0), -1) 175 | x = self.fc(x) 176 | 177 | return x 178 | 179 | 180 | def resnet50_ibn_a(pretrained=False, 
**kwargs): 181 | """Constructs a ResNet-50 model. 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 186 | if pretrained: 187 | state_dict = torch.load(model_urls['ibn_resnet50a'], map_location=torch.device('cpu'))['state_dict'] 188 | state_dict = remove_module_key(state_dict) 189 | model.load_state_dict(state_dict) 190 | return model 191 | 192 | 193 | def resnet101_ibn_a(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | state_dict = torch.load(model_urls['ibn_resnet101a'], map_location=torch.device('cpu'))['state_dict'] 201 | state_dict = remove_module_key(state_dict) 202 | model.load_state_dict(state_dict) 203 | return model 204 | 205 | 206 | def remove_module_key(state_dict): 207 | for key in list(state_dict.keys()): 208 | if 'module' in key: 209 | state_dict[key.replace('module.','')] = state_dict.pop(key) 210 | return state_dict 211 | -------------------------------------------------------------------------------- /SpCL-master/spcl/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import numpy as np 4 | import collections 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | 10 | from .utils.meters import AverageMeter 11 | 12 | 13 | class SpCLTrainer_UDA(object): 14 | def __init__(self, encoder, memory, source_classes): 15 | super(SpCLTrainer_UDA, self).__init__() 16 | self.encoder = encoder 17 | self.memory = memory 18 | self.source_classes = source_classes 19 | 20 | def train(self, epoch, data_loader_source, data_loader_target, 21 | optimizer, print_freq=10, train_iters=400): 22 | self.encoder.train() 23 
| 24 | batch_time = AverageMeter() 25 | data_time = AverageMeter() 26 | 27 | losses_s = AverageMeter() 28 | losses_t = AverageMeter() 29 | 30 | end = time.time() 31 | for i in range(train_iters): 32 | # load data 33 | source_inputs = data_loader_source.next() 34 | target_inputs = data_loader_target.next() 35 | data_time.update(time.time() - end) 36 | 37 | # process inputs 38 | s_inputs, s_targets, _ = self._parse_data(source_inputs) 39 | t_inputs, _, t_indexes = self._parse_data(target_inputs) 40 | 41 | # arrange batch for domain-specific BN 42 | device_num = torch.cuda.device_count() 43 | B, C, H, W = s_inputs.size() 44 | def reshape(inputs): 45 | return inputs.view(device_num, -1, C, H, W) 46 | s_inputs, t_inputs = reshape(s_inputs), reshape(t_inputs) 47 | inputs = torch.cat((s_inputs, t_inputs), 1).view(-1, C, H, W) 48 | 49 | # forward 50 | f_out = self._forward(inputs) 51 | 52 | # de-arrange batch 53 | f_out = f_out.view(device_num, -1, f_out.size(-1)) 54 | f_out_s, f_out_t = f_out.split(f_out.size(1)//2, dim=1) 55 | f_out_s, f_out_t = f_out_s.contiguous().view(-1, f_out.size(-1)), f_out_t.contiguous().view(-1, f_out.size(-1)) 56 | 57 | # compute loss with the hybrid memory 58 | loss_s = self.memory(f_out_s, s_targets) 59 | loss_t = self.memory(f_out_t, t_indexes+self.source_classes) 60 | 61 | loss = loss_s+loss_t 62 | optimizer.zero_grad() 63 | loss.backward() 64 | optimizer.step() 65 | 66 | losses_s.update(loss_s.item()) 67 | losses_t.update(loss_t.item()) 68 | 69 | # print log 70 | batch_time.update(time.time() - end) 71 | end = time.time() 72 | 73 | if (i + 1) % print_freq == 0: 74 | print('Epoch: [{}][{}/{}]\t' 75 | 'Time {:.3f} ({:.3f})\t' 76 | 'Data {:.3f} ({:.3f})\t' 77 | 'Loss_s {:.3f} ({:.3f})\t' 78 | 'Loss_t {:.3f} ({:.3f})' 79 | .format(epoch, i + 1, len(data_loader_target), 80 | batch_time.val, batch_time.avg, 81 | data_time.val, data_time.avg, 82 | losses_s.val, losses_s.avg, 83 | losses_t.val, losses_t.avg)) 84 | 85 | def _parse_data(self, 
inputs): 86 | imgs, _, pids, _, indexes = inputs 87 | return imgs.cuda(), pids.cuda(), indexes.cuda() 88 | 89 | def _forward(self, inputs): 90 | return self.encoder(inputs) 91 | 92 | 93 | class SpCLTrainer_USL(object): 94 | def __init__(self, encoder, memory): 95 | super(SpCLTrainer_USL, self).__init__() 96 | self.encoder = encoder 97 | self.memory = memory 98 | 99 | def train(self, epoch, data_loader, optimizer, print_freq=10, train_iters=400): 100 | self.encoder.train() 101 | 102 | batch_time = AverageMeter() 103 | data_time = AverageMeter() 104 | 105 | losses = AverageMeter() 106 | 107 | end = time.time() 108 | for i in range(train_iters): 109 | # load data 110 | inputs = data_loader.next() 111 | data_time.update(time.time() - end) 112 | 113 | # process inputs 114 | inputs, _, indexes = self._parse_data(inputs) 115 | 116 | # forward 117 | f_out = self._forward(inputs) 118 | 119 | # compute loss with the hybrid memory 120 | loss = self.memory(f_out, indexes) 121 | 122 | optimizer.zero_grad() 123 | loss.backward() 124 | optimizer.step() 125 | 126 | losses.update(loss.item()) 127 | 128 | # print log 129 | batch_time.update(time.time() - end) 130 | end = time.time() 131 | 132 | if (i + 1) % print_freq == 0: 133 | print('Epoch: [{}][{}/{}]\t' 134 | 'Time {:.3f} ({:.3f})\t' 135 | 'Data {:.3f} ({:.3f})\t' 136 | 'Loss {:.3f} ({:.3f})' 137 | .format(epoch, i + 1, len(data_loader), 138 | batch_time.val, batch_time.avg, 139 | data_time.val, data_time.avg, 140 | losses.val, losses.avg)) 141 | 142 | def _parse_data(self, inputs): 143 | imgs, _, pids, _, indexes = inputs 144 | return imgs.cuda(), pids.cuda(), indexes.cuda() 145 | 146 | def _forward(self, inputs): 147 | return self.encoder(inputs) 148 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def 
to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset 4 | from .preprocessor import Preprocessor 5 | 6 | class IterLoader: 7 | def __init__(self, loader, length=None): 8 | self.loader = loader 9 | self.length = length 10 | self.iter = None 11 | 12 | def __len__(self): 13 | if (self.length is not None): 14 | return self.length 15 | return len(self.loader) 16 | 17 | def new_epoch(self): 18 | self.iter = iter(self.loader) 19 | 20 | def next(self): 21 | try: 22 | return next(self.iter) 23 | except: 24 | self.iter = iter(self.loader) 25 | return next(self.iter) 26 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def 
print_dataset_statistics(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def images_dir(self): 27 | return None 28 | 29 | 30 | class BaseImageDataset(BaseDataset): 31 | """ 32 | Base class of image reid dataset 33 | """ 34 | 35 | def print_dataset_statistics(self, train, query, gallery): 36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 38 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 39 | 40 | print("Dataset statistics:") 41 | print(" ----------------------------------------") 42 | print(" subset | # ids | # images | # cameras") 43 | print(" ----------------------------------------") 44 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 45 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 46 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 47 | print(" ----------------------------------------") 48 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import random 7 | import math 8 | from PIL import Image 9 | 10 | class Preprocessor(Dataset): 11 | def __init__(self, dataset, root=None, transform=None): 12 | super(Preprocessor, self).__init__() 13 | self.dataset = dataset 14 | self.root = root 15 | self.transform = transform 16 | 17 | def __len__(self): 18 | return len(self.dataset) 19 | 20 | def __getitem__(self, indices): 21 | return self._get_single_item(indices) 22 | 23 | def _get_single_item(self, index): 24 | fname, pid, 
camid = self.dataset[index] 25 | fpath = fname 26 | if self.root is not None: 27 | fpath = osp.join(self.root, fname) 28 | 29 | img = Image.open(fpath).convert('RGB') 30 | 31 | if self.transform is not None: 32 | img = self.transform(img) 33 | 34 | return img, fname, pid, camid, index 35 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import math 4 | 5 | import numpy as np 6 | import copy 7 | import random 8 | import torch 9 | from torch.utils.data.sampler import ( 10 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 11 | WeightedRandomSampler) 12 | 13 | 14 | def No_index(a, b): 15 | assert isinstance(a, list) 16 | return [i for i, j in enumerate(a) if j != b] 17 | 18 | 19 | class RandomIdentitySampler(Sampler): 20 | def __init__(self, data_source, num_instances): 21 | self.data_source = data_source 22 | self.num_instances = num_instances 23 | self.index_dic = defaultdict(list) 24 | for index, (_, pid, _) in enumerate(data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | self.num_samples = len(self.pids) 28 | 29 | def __len__(self): 30 | return self.num_samples * self.num_instances 31 | 32 | def __iter__(self): 33 | indices = torch.randperm(self.num_samples).tolist() 34 | ret = [] 35 | for i in indices: 36 | pid = self.pids[i] 37 | t = self.index_dic[pid] 38 | if len(t) >= self.num_instances: 39 | t = np.random.choice(t, size=self.num_instances, replace=False) 40 | else: 41 | t = np.random.choice(t, size=self.num_instances, replace=True) 42 | ret.extend(t) 43 | return iter(ret) 44 | 45 | 46 | class RandomMultipleGallerySampler(Sampler): 47 | def __init__(self, data_source, num_instances=4): 48 | self.data_source = data_source 49 | self.index_pid = 
defaultdict(int) 50 | self.pid_cam = defaultdict(list) 51 | self.pid_index = defaultdict(list) 52 | self.num_instances = num_instances 53 | 54 | for index, (_, pid, cam) in enumerate(data_source): 55 | if (pid<0): continue 56 | self.index_pid[index] = pid 57 | self.pid_cam[pid].append(cam) 58 | self.pid_index[pid].append(index) 59 | 60 | self.pids = list(self.pid_index.keys()) 61 | self.num_samples = len(self.pids) 62 | 63 | def __len__(self): 64 | return self.num_samples * self.num_instances 65 | 66 | def __iter__(self): 67 | indices = torch.randperm(len(self.pids)).tolist() 68 | ret = [] 69 | 70 | for kid in indices: 71 | i = random.choice(self.pid_index[self.pids[kid]]) 72 | 73 | _, i_pid, i_cam = self.data_source[i] 74 | 75 | ret.append(i) 76 | 77 | pid_i = self.index_pid[i] 78 | cams = self.pid_cam[pid_i] 79 | index = self.pid_index[pid_i] 80 | select_cams = No_index(cams, i_cam) 81 | 82 | if select_cams: 83 | 84 | if len(select_cams) >= self.num_instances: 85 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False) 86 | else: 87 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True) 88 | 89 | for kk in cam_indexes: 90 | ret.append(index[kk]) 91 | 92 | else: 93 | select_indexes = No_index(index, i) 94 | if (not select_indexes): continue 95 | if len(select_indexes) >= self.num_instances: 96 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 97 | else: 98 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 99 | 100 | for kk in ind_indexes: 101 | ret.append(index[kk]) 102 | 103 | 104 | return iter(ret) 105 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 
| import random 6 | import math 7 | import numpy as np 8 |
class RectScale(object):
    """Resize a PIL image to exactly (height, width); no-op if it already has that size."""

    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.size  # PIL reports size as (width, height)
        if h == self.height and w == self.width:
            return img
        return img.resize((self.width, self.height), self.interpolation)


class RandomSizedRectCrop(object):
    """Random tall crop resized to (height, width).

    Up to 10 attempts sample a crop covering 64-100% of the image area with a
    height/width ratio in [2, 3]; the first crop that fits is resized and
    returned. If none fits, the whole image is resized with ``RectScale``.
    """

    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def __call__(self, img):
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.64, 1.0) * area
            aspect_ratio = random.uniform(2, 3)

            # h / w == aspect_ratio and h * w ~= target_area
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert(img.size == (w, h))

                return img.resize((self.width, self.height), self.interpolation)

        # Fallback
        scale = RectScale(self.height, self.width,
                          interpolation=self.interpolation)
        return scale(img)


class RandomErasing(object):
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf
    Operates in place on a (C, H, W) tensor-like ``img`` (``img.size()`` is
    called and the image is sliced as ``img[c, rows, cols]``) and returns it.
    Args:
         probability: The probability that the Random Erasing operation will be performed.
         sl: Minimum proportion of erased area against input image.
         sh: Maximum proportion of erased area against input image.
         r1: Minimum aspect ratio of erased area; the ratio is drawn from [r1, 1 / r1].
         mean: Per-channel erasing values; only mean[0] is used when the image
               does not have exactly 3 channels.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):

        if random.uniform(0, 1) >= self.probability:
            return img

        # Up to 100 attempts to find a rectangle that fits strictly inside the image.
        for attempt in range(100):
            area = img.size()[1] * img.size()[2]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                # x1 indexes rows (dim 1), y1 indexes columns (dim 2)
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
                else:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                return img

        return img
-------------------------------------------------------------------------------- /SpCL-master/spcl/utils/faiss_rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
5 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 7 | """ 8 | 9 | import os, sys 10 | import time 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | import gc 14 | import faiss 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \ 20 | index_init_gpu, index_init_cpu 21 | 22 | def k_reciprocal_neigh(initial_rank, i, k1): 23 | forward_k_neigh_index = initial_rank[i,:k1+1] 24 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 25 | fi = np.where(backward_k_neigh_index==i)[0] 26 | return forward_k_neigh_index[fi] 27 | 28 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False): 29 | end = time.time() 30 | if print_flag: 31 | print('Computing jaccard distance...') 32 | 33 | ngpus = faiss.get_num_gpus() 34 | N = target_features.size(0) 35 | mat_type = np.float16 if use_float16 else np.float32 36 | 37 | if (search_option==0): 38 | # GPU + PyTorch CUDA Tensors (1) 39 | res = faiss.StandardGpuResources() 40 | res.setDefaultNullStreamAllDevices() 41 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1) 42 | initial_rank = initial_rank.cpu().numpy() 43 | elif (search_option==1): 44 | # GPU + PyTorch CUDA Tensors (2) 45 | res = faiss.StandardGpuResources() 46 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1)) 47 | index.add(target_features.cpu().numpy()) 48 | _, initial_rank = search_index_pytorch(index, target_features, k1) 49 | res.syncDefaultStreamCurrentDevice() 50 | initial_rank = initial_rank.cpu().numpy() 51 | elif (search_option==2): 52 | # GPU 53 | index = index_init_gpu(ngpus, target_features.size(-1)) 54 | index.add(target_features.cpu().numpy()) 55 | _, initial_rank = 
index.search(target_features.cpu().numpy(), k1) 56 | else: 57 | # CPU 58 | index = index_init_cpu(target_features.size(-1)) 59 | index.add(target_features.cpu().numpy()) 60 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 61 | 62 | 63 | nn_k1 = [] 64 | nn_k1_half = [] 65 | for i in range(N): 66 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1)) 67 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2)))) 68 | 69 | V = np.zeros((N, N), dtype=mat_type) 70 | for i in range(N): 71 | k_reciprocal_index = nn_k1[i] 72 | k_reciprocal_expansion_index = k_reciprocal_index 73 | for candidate in k_reciprocal_index: 74 | candidate_k_reciprocal_index = nn_k1_half[candidate] 75 | if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)): 76 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 77 | 78 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 79 | dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t()) 80 | if use_float16: 81 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type) 82 | else: 83 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy() 84 | 85 | del nn_k1, nn_k1_half 86 | 87 | if k2 != 1: 88 | V_qe = np.zeros_like(V, dtype=mat_type) 89 | for i in range(N): 90 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0) 91 | V = V_qe 92 | del V_qe 93 | 94 | del initial_rank 95 | 96 | invIndex = [] 97 | for i in range(N): 98 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 99 | 100 | jaccard_dist = np.zeros((N, N), dtype=mat_type) 101 | for i in range(N): 102 | temp_min = np.zeros((1,N), dtype=mat_type) 103 | # temp_max = np.zeros((1,N), dtype=mat_type) 104 | indNonZero = np.where(V[i,:] != 0)[0] 105 | indImages = [] 106 | 
indImages = [invIndex[ind] for ind in indNonZero] 107 | for j in range(len(indNonZero)): 108 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 109 | # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 110 | 111 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 112 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6) 113 | 114 | del invIndex, V 115 | 116 | pos_bool = (jaccard_dist < 0) 117 | jaccard_dist[pos_bool] = 0.0 118 | if print_flag: 119 | print ("Jaccard distance computing time cost: {}".format(time.time()-end)) 120 | 121 | return jaccard_dist 122 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | # return faiss.cast_integer_to_long_ptr( 16 | # x.storage().data_ptr() + x.storage_offset() * 8) 17 | return faiss.cast_integer_to_idx_t_ptr( 18 | x.storage().data_ptr() + x.storage_offset() * 8) 19 | 20 | def search_index_pytorch(index, x, k, D=None, I=None): 21 | """call the search function of an index with pytorch tensor I/O (CPU 22 | and GPU supported)""" 23 | assert x.is_contiguous() 24 | n, d = x.size() 25 | assert d == index.d 26 | 27 | if D is None: 28 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 29 | else: 30 | assert D.size() == (n, k) 31 | 32 | if I is None: 33 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 34 | else: 35 | assert I.size() == (n, k) 36 
| torch.cuda.synchronize() 37 | xptr = swig_ptr_from_FloatTensor(x) 38 | Iptr = swig_ptr_from_LongTensor(I) 39 | Dptr = swig_ptr_from_FloatTensor(D) 40 | index.search_c(n, xptr, 41 | k, Dptr, Iptr) 42 | torch.cuda.synchronize() 43 | return D, I 44 | 45 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 46 | metric=faiss.METRIC_L2): 47 | assert xb.device == xq.device 48 | 49 | nq, d = xq.size() 50 | if xq.is_contiguous(): 51 | xq_row_major = True 52 | elif xq.t().is_contiguous(): 53 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 54 | xq_row_major = False 55 | else: 56 | raise TypeError('matrix should be row or column-major') 57 | 58 | xq_ptr = swig_ptr_from_FloatTensor(xq) 59 | 60 | nb, d2 = xb.size() 61 | assert d2 == d 62 | if xb.is_contiguous(): 63 | xb_row_major = True 64 | elif xb.t().is_contiguous(): 65 | xb = xb.t() 66 | xb_row_major = False 67 | else: 68 | raise TypeError('matrix should be row or column-major') 69 | xb_ptr = swig_ptr_from_FloatTensor(xb) 70 | 71 | if D is None: 72 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 73 | else: 74 | assert D.shape == (nq, k) 75 | assert D.device == xb.device 76 | 77 | if I is None: 78 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 79 | else: 80 | assert I.shape == (nq, k) 81 | assert I.device == xb.device 82 | 83 | D_ptr = swig_ptr_from_FloatTensor(D) 84 | I_ptr = swig_ptr_from_LongTensor(I) 85 | 86 | faiss.bruteForceKnn(res, metric, 87 | xb_ptr, xb_row_major, nb, 88 | xq_ptr, xq_row_major, nq, 89 | d, k, D_ptr, I_ptr) 90 | 91 | return D, I 92 | 93 | def index_init_gpu(ngpus, feat_dim): 94 | flat_config = [] 95 | for i in range(ngpus): 96 | cfg = faiss.GpuIndexFlatConfig() 97 | cfg.useFloat16 = False 98 | cfg.device = i 99 | flat_config.append(cfg) 100 | 101 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 102 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 103 | index = 
faiss.IndexShards(feat_dim) 104 | for sub_index in indexes: 105 | index.add_shard(sub_index) 106 | index.reset() 107 | return index 108 | 109 | def index_init_cpu(feat_dim): 110 | return faiss.IndexFlatL2(feat_dim) 111 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 
-------------------------------------------------------------------------------- /SpCL-master/spcl/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /SpCL-master/spcl/utils/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Source: https://github.com/zhunzhong07/person-re-ranking 5 | Created on Mon Jun 26 14:46:56 2017 6 | @author: luohao 7 | Modified by Houjing Huang, 2017-12-22. 8 | - This version accepts distance matrix instead of raw features. 9 | - The difference of `/` division between python 2 and 3 is handled. 10 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 11 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
12 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 14 | API 15 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 16 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 17 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 18 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 19 | Returns: 20 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 21 | """
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

__all__ = ['re_ranking']

import numpy as np


def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
    """Return k-reciprocal re-ranked query-gallery distances, shape [num_query, num_gallery].

    The result is ``(1 - lambda_value) * jaccard + lambda_value * original``,
    where ``original`` is the squared, column-max-normalised input distance.
    """

    # The following naming, e.g. gallery_num, is different from outer scope.
    # Don't care about it.

    # Full (Q+G) x (Q+G) matrix laid out as [[QQ, QG], [GQ, GG]].
    original_dist = np.concatenate(
      [np.concatenate([q_q_dist, q_g_dist], axis=1),
       np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
      axis=0)
    # Square, divide each column by its maximum, then transpose.
    original_dist = np.power(original_dist, 2).astype(np.float32)
    original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
    V = np.zeros_like(original_dist).astype(np.float32)
    # initial_rank[i]: every sample ordered by increasing distance from sample i
    initial_rank = np.argsort(original_dist).astype(np.int32)

    query_num = q_g_dist.shape[0]
    # NOTE: despite the name, gallery_num counts queries + gallery (every sample).
    gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
    all_num = gallery_num

    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i,:k1+1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
        fi = np.where(backward_k_neigh_index==i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            # the candidate's own k1/2-reciprocal set
            candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            # expand when more than 2/3 of that set already lies in the k1 set
            if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # exp(-d) weights normalised over the expanded set: each row of V sums to 1
        weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
        V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
    # only the query rows of the original distance are needed from here on
    original_dist = original_dist[:query_num,]
    if k2 != 1:
        # query expansion: average each row over its k2 nearest neighbours' rows
        V_qe = np.zeros_like(V,dtype=np.float32)
        for i in range(all_num):
            V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # inverted index: invIndex[i] lists the rows of V that are non-zero in column i
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:,i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)


    for i in range(query_num):
        temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32)
        indNonZero = np.where(V[i,:] != 0)[0]
        indImages = []
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
        # rows of V sum to 1, so sum(max) = 2 - sum(min): this is 1 - sum(min) / sum(max)
        jaccard_dist[i] = 1-temp_min/(2.-temp_min)

    final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
    del original_dist
    del V
    del jaccard_dist
    # keep the query x gallery block
    final_dist = final_dist[:query_num,query_num:]
    return final_dist
--------------------------------------------------------------------------------
/SpCL-master/spcl/utils/serialization.py:
--------------------------------------------------------------------------------
from __future__ import print_function, absolute_import
import json
import os.path as osp
import shutil

import torch
from torch.nn import Parameter

from .osutils import mkdir_if_missing


def read_json(fpath):
    """Load and return the JSON document stored at ``fpath``."""
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    """Write ``obj`` to ``fpath`` as indented JSON, creating the parent directory first."""
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    """``torch.save`` ``state`` to ``fpath``; when ``is_best``, also copy it to model_best.pth.tar beside it."""
    # NOTE(review): with the default fpath, osp.dirname(fpath) is '' -- this
    # relies on mkdir_if_missing tolerating an empty path.
    mkdir_if_missing(osp.dirname(fpath))
    torch.save(state, fpath)
    if is_best:
        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))


def load_checkpoint(fpath):
    """Load a checkpoint onto the CPU; raise ``ValueError`` if ``fpath`` is not a file."""
    if osp.isfile(fpath):
        # checkpoint = torch.load(fpath)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(fpath, map_location=torch.device('cpu'))
        print("=> Loaded checkpoint '{}'".format(fpath))
        return checkpoint
    else:
        raise ValueError("=> No checkpoint found at '{}'".format(fpath))


def copy_state_dict(state_dict, model, strip=None):
    """Copy matching entries of ``state_dict`` into ``model`` in place and return ``model``.

    A ``strip`` prefix is removed from source names first. Names absent from the
    model are skipped silently; shape mismatches and model keys that received
    no value are printed.
    """
    tgt_state = model.state_dict()
    copied_names = set()
    for name, param in state_dict.items():
        if strip is not None and name.startswith(strip):
            name = name[len(strip):]
        if name not in tgt_state:
            continue
        if isinstance(param, Parameter):
            param = param.data
        if param.size() != tgt_state[name].size():
            print('mismatch:', name, param.size(), tgt_state[name].size())
            continue
        tgt_state[name].copy_(param)
        copied_names.add(name)

    missing = set(tgt_state.keys()) - copied_names
    if len(missing) > 0:
        print("missing keys in state_dict:", missing)

    return model
-------------------------------------------------------------------------------- /config/config_regdb.yaml: -------------------------------------------------------------------------------- 1 | ## Note: color = rgb = visible, thermal = ir = infrared. 2 | 3 | ## dataset parameters 4 | dataset: regdb # sysu or regdb 5 | dataset_path: ../../dataset/ # dataset root path 6 | trial: 1 # only for regdb test 7 | mode: visibletothermal # all or indoor (sysu test), thermaltovisible or visibletothermal (regdb test) 8 | workers: 4 # number of data loading workers (default: 4) 9 | dataset_num_size: 2 # the multiple of dataset size per trainloader 10 | 11 | ## model parameters 12 | arch: resnet50 # network baseline 13 | pool_dim: 2048 # pooling dim: 2048 for resnet50 14 | per_add_iters: 5 # number of iters adding to coefficient of GRL for each training batch 15 | lambda_sk: 25 # hyperparameter for Sinkhorn-Knopp algorithm 16 | 17 | ## optimizer parameters 18 | optim: adam # optimizer: adam 19 | lr: 0.0035 # learning rate: 0.0035 for adam 20 | 21 | ## normal parameters 22 | file_name: otla-reid/ # log file name 23 | setting: semi-supervised # training setting: supervised or semi-supervised or unsupervised 24 | train_visible_image_path: ../../dataset/RegDB/spcl_uda_market1501TOregdb_rgb_train_rgb_resized_img.npy # the stored visible image path getting from USL-ReID or UDA-ReID
methods for unsupervised setting 25 | train_visible_label_path: ../../dataset/RegDB/spcl_uda_market1501TOregdb_rgb_train_rgb_resized_label.npy # the stored visible label path getting from USL-ReID or UDA-ReID methods for unsupervised setting 26 | seed: 0 # random seed 27 | gpu: 0 # gpu device ids for CUDA_VISIBLE_DEVICES 28 | model_path: save_model/ # model save path 29 | log_path: log/ # log save path 30 | vis_log_path: vis_log/ # tensorboard log save path 31 | save_epoch: 10 # save model every few epochs 32 | img_w: 144 # image width 33 | img_h: 288 # image height 34 | train_batch_size: 4 # training batch size: 4 35 | num_pos: 8 # number of pos per identity for each modality: 8 36 | test_batch_size: 64 # testing batch size 37 | start_epoch: 0 # start training epoch 38 | end_epoch: 81 # end training epoch 39 | eval_epoch: 1 # testing epochs 40 | 41 | ## loss parameters 42 | margin: 0.3 # triplet loss margin 43 | lambda_vr: 0.1 # coefficient of prediction alignment loss 44 | lambda_rv: 0.5 # coefficient of prediction alignment loss -------------------------------------------------------------------------------- /config/config_sysu.yaml: -------------------------------------------------------------------------------- 1 | ## Note: color = rgb = visible, thermal = ir = infrared. 
2 | 3 | ## dataset parameters 4 | dataset: sysu # sysu or regdb 5 | dataset_path: ../../dataset/ # dataset root path 6 | mode: all # all or indoor (sysu test), thermaltovisible or visibletothermal (regdb test) 7 | workers: 4 # number of data loading workers (default: 4) 8 | dataset_num_size: 1 # the multiple of dataset size per trainloader 9 | 10 | ## model parameters 11 | arch: resnet50 # network baseline 12 | pool_dim: 2048 # pooling dim: 2048 for resnet50 13 | per_add_iters: 1 # number of iters adding to coefficient of GRL for each training batch 14 | lambda_sk: 25 # hyperparameter for Sinkhorn-Knopp algorithm 15 | 16 | ## optimizer parameters 17 | optim: adam # optimizer: adam 18 | lr: 0.0035 # learning rate: 0.0035 for adam 19 | 20 | ## normal parameters 21 | file_name: otla-reid/ # log file name 22 | setting: semi-supervised # training setting: supervised or semi-supervised or unsupervised 23 | train_visible_image_path: ../../dataset/SYSU-MM01/spcl_uda_market1501TOsysumm01_rgb_train_rgb_resized_img.npy # the stored visible image path getting from USL-ReID or UDA-ReID methods for unsupervised setting 24 | train_visible_label_path: ../../dataset/SYSU-MM01/spcl_uda_market1501TOsysumm01_rgb_train_rgb_resized_label.npy # the stored visible label path getting from USL-ReID or UDA-ReID methods for unsupervised setting 25 | seed: 0 # random seed 26 | gpu: 0 # gpu device ids for CUDA_VISIBLE_DEVICES 27 | model_path: save_model/ # model save path 28 | log_path: log/ # log save path 29 | vis_log_path: vis_log/ # tensorboard log save path 30 | save_epoch: 10 # save model every few epochs 31 | img_w: 144 # image width 32 | img_h: 288 # image height 33 | train_batch_size: 4 # training batch size: 4 34 | num_pos: 8 # number of pos per identity for each modality: 8 35 | test_batch_size: 64 # testing batch size 36 | start_epoch: 0 # start training epoch 37 | end_epoch: 81 # end training epoch 38 | eval_epoch: 1 # testing epochs 39 | 40 | ## loss parameters 41 | margin: 0.3 # 
triplet loss margin 42 | lambda_vr: 0.1 # coefficient of prediction alignment loss 43 | lambda_rv: 0.5 # coefficient of prediction alignment loss -------------------------------------------------------------------------------- /data_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | 5 | 6 | def process_query_sysu(data_path, mode="all"): 7 | if mode == "all": 8 | ir_cameras = ["cam3", "cam6"] 9 | elif mode == "indoor": 10 | ir_cameras = ["cam3", "cam6"] 11 | 12 | file_path = os.path.join(data_path, "exp/test_id.txt") 13 | files_ir = [] 14 | 15 | with open(file_path, 'r') as file: 16 | ids = file.read().splitlines() 17 | ids = [int(y) for y in ids[0].split(',')] 18 | ids = ["%04d" % x for x in ids] 19 | 20 | for id in sorted(ids): 21 | for cam in ir_cameras: 22 | img_dir = os.path.join(data_path, cam, id) 23 | if os.path.isdir(img_dir): 24 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 25 | files_ir.extend(new_files) 26 | 27 | query_img = [] 28 | query_id = [] 29 | query_cam = [] 30 | for img_path in files_ir: 31 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 32 | query_img.append(img_path) 33 | query_id.append(pid) 34 | query_cam.append(camid) 35 | 36 | return query_img, np.array(query_id), np.array(query_cam) 37 | 38 | 39 | def process_gallery_sysu(data_path, mode="all", trial=0): 40 | random.seed(trial) 41 | 42 | if mode == "all": 43 | rgb_cameras = ["cam1", "cam2", "cam4", "cam5"] 44 | elif mode == "indoor": 45 | rgb_cameras = ["cam1", "cam2"] 46 | 47 | file_path = os.path.join(data_path, "exp/test_id.txt") 48 | files_rgb = [] 49 | with open(file_path, 'r') as file: 50 | ids = file.read().splitlines() 51 | ids = [int(y) for y in ids[0].split(',')] 52 | ids = ["%04d" % x for x in ids] 53 | 54 | for id in sorted(ids): 55 | for cam in rgb_cameras: 56 | img_dir = os.path.join(data_path, cam, id) 57 | if os.path.isdir(img_dir): 58 | 
new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 59 | files_rgb.append(random.choice(new_files)) 60 | 61 | gall_img = [] 62 | gall_id = [] 63 | gall_cam = [] 64 | for img_path in files_rgb: 65 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 66 | gall_img.append(img_path) 67 | gall_id.append(pid) 68 | gall_cam.append(camid) 69 | 70 | return gall_img, np.array(gall_id), np.array(gall_cam) 71 | 72 | 73 | def process_test_regdb(img_dir, trial=1, modality="visible"): 74 | if modality == "visible": 75 | input_data_path = os.path.join(img_dir, "idx/test_visible_{}".format(trial) + ".txt") 76 | elif modality == "thermal": 77 | input_data_path = os.path.join(img_dir, "idx/test_thermal_{}".format(trial) + ".txt") 78 | 79 | with open(input_data_path) as f: 80 | data_file_list = open(input_data_path, 'rt').read().splitlines() 81 | # Get full list of image and labels 82 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 83 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 84 | 85 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from utils import AverageMeter 6 | from eval_metrics import eval_regdb, eval_sysu 7 | 8 | 9 | def trainer(args, epoch, main_net, adjust_learning_rate, optimizer, trainloader, criterion, writer=None, print_freq=50): 10 | current_lr = adjust_learning_rate(args, optimizer, epoch) 11 | 12 | total_loss = AverageMeter() 13 | id_loss_rgb = AverageMeter() 14 | id_loss_ir = AverageMeter() 15 | tri_loss_rgb = AverageMeter() 16 | tri_loss_ir = AverageMeter() 17 | dis_loss = AverageMeter() 18 | pa_loss = AverageMeter() 19 | batch_time = AverageMeter() 20 | 21 | correct_tri_rgb = 0 22 | correct_tri_ir = 0 23 | pre_rgb = 0 # it is meaningful 
only in the case of semi supervised setting 24 | pre_ir = 0 # it is meaningful only in the case of semi supervised setting 25 | pre_rgb_ir = 0 # it is meaningful only in the case of semi supervised setting, whether labels of selected samples per batch are equal 26 | num_rgb = 0 27 | num_ir = 0 28 | 29 | main_net.train() # switch to train mode 30 | end = time.time() 31 | 32 | for batch_id, (input_rgb, input_ir, label_rgb, label_ir) in enumerate(trainloader): 33 | # label_ir is only used to calculate the prediction accuracy of pseudo infrared labels on semi-supervised setting 34 | # label_ir is meaningless on unsupervised setting 35 | # for supervised setting, we change "label_rgb" of "loss_id_ir" and "loss_tri_ir" into "label_ir" 36 | 37 | label_rgb = label_rgb.cuda() 38 | label_ir = label_ir.cuda() 39 | input_rgb = input_rgb.cuda() 40 | input_ir = input_ir.cuda() 41 | 42 | feat, output_cls, output_dis = main_net(input_rgb, input_ir, modal=0, train_set=True) 43 | 44 | loss_id_rgb = criterion[0](output_cls[:input_rgb.size(0)], label_rgb) 45 | loss_tri_rgb, correct_tri_batch_rgb = criterion[1](feat[:input_rgb.size(0)], label_rgb) 46 | 47 | if args.setting == "semi-supervised" or args.setting == "unsupervised": 48 | loss_id_ir = criterion[0](output_cls[input_rgb.size(0):], label_rgb) 49 | loss_tri_ir, correct_tri_batch_ir = criterion[1](feat[input_rgb.size(0):], label_rgb) 50 | elif args.setting == "supervised": 51 | loss_id_ir = criterion[0](output_cls[input_rgb.size(0):], label_ir) 52 | loss_tri_ir, correct_tri_batch_ir = criterion[1](feat[input_rgb.size(0):], label_ir) 53 | 54 | dis_label = torch.cat((torch.ones(input_rgb.size(0)), torch.zeros(input_ir.size(0))), dim=0).cuda() 55 | loss_dis = criterion[2](output_dis.view(-1), dis_label) 56 | 57 | loss_pa, sim_rgbtoir, sim_irtorgb = criterion[3](output_cls[:input_rgb.size(0)], output_cls[input_rgb.size(0):]) 58 | 59 | loss = loss_id_rgb + loss_tri_rgb + 0.1 * loss_id_ir + 0.5 * loss_tri_ir + loss_dis + loss_pa 60 | 
61 | optimizer.zero_grad() 62 | loss.backward() 63 | optimizer.step() 64 | 65 | correct_tri_rgb += correct_tri_batch_rgb 66 | correct_tri_ir += correct_tri_batch_ir 67 | _, pre_label = output_cls.max(1) 68 | pre_batch_rgb = (pre_label[:input_rgb.size(0)].eq(label_rgb).sum().item()) 69 | pre_batch_ir = (pre_label[input_rgb.size(0):].eq(label_ir).sum().item()) 70 | pre_batch_rgb_ir = (label_rgb.eq(label_ir).sum().item()) 71 | pre_rgb += pre_batch_rgb 72 | pre_ir += pre_batch_ir 73 | pre_rgb_ir += pre_batch_rgb_ir 74 | num_rgb += input_rgb.size(0) 75 | num_ir += input_ir.size(0) 76 | assert num_rgb == num_ir 77 | 78 | total_loss.update(loss.item(), input_rgb.size(0) + input_ir.size(0)) 79 | id_loss_rgb.update(loss_id_rgb.item(), input_rgb.size(0)) 80 | id_loss_ir.update(loss_id_ir.item(), input_ir.size(0)) 81 | tri_loss_rgb.update(loss_tri_rgb, input_rgb.size(0)) 82 | tri_loss_ir.update(loss_tri_ir, input_ir.size(0)) 83 | dis_loss.update(loss_dis, input_rgb.size(0) + input_ir.size(0)) 84 | pa_loss.update(loss_pa.item(), input_rgb.size(0) + input_ir.size(0)) 85 | 86 | # measure elapsed time 87 | batch_time.update(time.time() - end) 88 | end = time.time() 89 | 90 | if batch_id % print_freq == 0: 91 | print("Epoch: [{}][{}/{}] " 92 | "Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) " 93 | "Lr: {:.6f} " 94 | "Coeff: {:.3f} " 95 | "Total_Loss: {total_loss.val:.4f}({total_loss.avg:.4f}) " 96 | "ID_Loss_RGB: {id_loss_rgb.val:.4f}({id_loss_rgb.avg:.4f}) " 97 | "ID_Loss_IR: {id_loss_ir.val:.4f}({id_loss_ir.avg:.4f}) " 98 | "Tri_Loss_RGB: {tri_loss_rgb.val:.4f}({tri_loss_rgb.avg:.4f}) " 99 | "Tri_Loss_IR: {tri_loss_ir.val:.4f}({tri_loss_ir.avg:.4f}) " 100 | "Dis_Loss: {dis_loss.val:.4f}({dis_loss.avg:.4f}) " 101 | "Pa_Loss: {pa_loss.val:.4f}({pa_loss.avg:.4f}) " 102 | "Tri_RGB_Acc: {:.2f}% " 103 | "Tri_IR_Acc: {:.2f}% " 104 | "Pre_RGB_Acc: {:.2f}% " 105 | "Pre_IR_Acc: {:.2f}% " 106 | "Pre_RGB_IR_Acc: {:.2f}% ".format(epoch, batch_id, len(trainloader), current_lr, 
main_net.adnet.coeff, 107 | 100. * correct_tri_rgb / num_rgb, 108 | 100. * correct_tri_ir / num_ir, 109 | 100. * pre_rgb / num_rgb, 110 | 100. * pre_ir / num_ir, 111 | 100. * pre_rgb_ir / num_rgb, 112 | batch_time=batch_time, 113 | total_loss=total_loss, 114 | id_loss_rgb=id_loss_rgb, 115 | id_loss_ir=id_loss_ir, 116 | tri_loss_rgb=tri_loss_rgb, 117 | tri_loss_ir=tri_loss_ir, 118 | dis_loss=dis_loss, 119 | pa_loss=pa_loss)) 120 | 121 | if writer is not None: 122 | writer.add_scalar("Lr", current_lr, epoch) 123 | writer.add_scalar("Coeff", main_net.adnet.coeff, epoch) 124 | writer.add_scalar("Total_Loss", total_loss.avg, epoch) 125 | writer.add_scalar("ID_Loss_RGB", id_loss_rgb.avg, epoch) 126 | writer.add_scalar("ID_Loss_IR", id_loss_ir.avg, epoch) 127 | writer.add_scalar("Tri_Loss_RGB", tri_loss_rgb.avg, epoch) 128 | writer.add_scalar("Tri_Loss_IR", tri_loss_ir.avg, epoch) 129 | writer.add_scalar("Dis_Loss", dis_loss.avg, epoch) 130 | writer.add_scalar("Pa_Loss", pa_loss.avg, epoch) 131 | writer.add_scalar("Tri_RGB_Acc", 100. * correct_tri_rgb / num_rgb, epoch) 132 | writer.add_scalar("Tri_IR_Acc", 100. * correct_tri_ir / num_ir, epoch) 133 | writer.add_scalar("Pre_RGB_Acc", 100. * pre_rgb / num_rgb, epoch) 134 | writer.add_scalar("Pre_IR_Acc", 100. 
* pre_ir / num_ir, epoch) 135 | 136 | 137 | def tester(args, epoch, main_net, test_mode, gall_label, gall_loader, query_label, query_loader, feat_dim=2048, query_cam=None, gall_cam=None, writer=None): 138 | # switch to evaluation mode 139 | main_net.eval() 140 | 141 | print("Extracting Gallery Feature...") 142 | ngall = len(gall_label) 143 | start = time.time() 144 | ptr = 0 145 | gall_feat = np.zeros((ngall, feat_dim)) 146 | with torch.no_grad(): 147 | for batch_idx, (input, label) in enumerate(gall_loader): 148 | batch_num = input.size(0) 149 | input = Variable(input.cuda()) 150 | feat = main_net(input, input, modal=test_mode[0]) 151 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 152 | ptr = ptr + batch_num 153 | print("Extracting Time:\t {:.3f}".format(time.time() - start)) 154 | 155 | print("Extracting Query Feature...") 156 | nquery = len(query_label) 157 | start = time.time() 158 | ptr = 0 159 | query_feat = np.zeros((nquery, feat_dim)) 160 | with torch.no_grad(): 161 | for batch_idx, (input, label) in enumerate(query_loader): 162 | batch_num = input.size(0) 163 | input = Variable(input.cuda()) 164 | feat = main_net(input, input, modal=test_mode[1]) 165 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 166 | ptr = ptr + batch_num 167 | print("Extracting Time:\t {:.3f}".format(time.time() - start)) 168 | 169 | start = time.time() 170 | # compute the similarity 171 | distmat = -np.matmul(query_feat, np.transpose(gall_feat)) 172 | # evaluation 173 | if args.dataset == "sysu": 174 | cmc, mAP, mINP = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam) 175 | elif args.dataset == "regdb": 176 | cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label) 177 | print("Evaluation Time:\t {:.3f}".format(time.time() - start)) 178 | 179 | if writer is not None: 180 | writer.add_scalar("Rank1", cmc[0], epoch) 181 | writer.add_scalar("mAP", mAP, epoch) 182 | writer.add_scalar("mINP", mINP, epoch) 183 | 184 | return cmc, 
mAP, mINP -------------------------------------------------------------------------------- /eval_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20): 5 | """ 6 | Evaluation with SYSU-MM01 metric. 7 | Note: For each query identity, its gallery images from the same camera view are discarded, 8 | which follows the original setting in "RGB-Infrared Cross-Modality Person Re-Identificatio, ICCV 2017". 9 | """ 10 | num_q, num_g = distmat.shape 11 | if num_g < max_rank: 12 | max_rank = num_g 13 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 14 | indices = np.argsort(distmat, axis=1) 15 | pred_label = g_pids[indices] 16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 17 | 18 | # compute cmc curve for each query 19 | new_all_cmc = [] 20 | all_cmc = [] 21 | all_AP = [] 22 | all_INP = [] 23 | num_valid_q = 0. 
# number of valid query 24 | for q_idx in range(num_q): 25 | # get query pid and camid 26 | q_pid = q_pids[q_idx] 27 | q_camid = q_camids[q_idx] 28 | 29 | # remove gallery samples that have the same pid and camid with query 30 | order = indices[q_idx] 31 | remove = (q_camid == 3) & (g_camids[order] == 2) 32 | keep = np.invert(remove) 33 | 34 | # compute cmc curve 35 | # the cmc calculation is different from standard protocol 36 | # we follow the protocol of the author's released code 37 | new_cmc = pred_label[q_idx][keep] 38 | new_index = np.unique(new_cmc, return_index=True)[1] 39 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 40 | 41 | new_match = (new_cmc == q_pid).astype(np.int32) 42 | new_cmc = new_match.cumsum() 43 | new_all_cmc.append(new_cmc[:max_rank]) 44 | 45 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 46 | if not np.any(orig_cmc): 47 | # this condition is true when query identity does not appear in gallery 48 | continue 49 | 50 | cmc = orig_cmc.cumsum() 51 | 52 | # compute mINP 53 | # refernece Deep Learning for Person Re-identification: A Survey and Outlook 54 | pos_idx = np.where(orig_cmc == 1) 55 | pos_max_idx = np.max(pos_idx) 56 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 57 | all_INP.append(inp) 58 | 59 | cmc[cmc > 1] = 1 60 | 61 | all_cmc.append(cmc[:max_rank]) 62 | num_valid_q += 1. 63 | 64 | # compute average precision 65 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 66 | num_rel = orig_cmc.sum() 67 | tmp_cmc = orig_cmc.cumsum() 68 | tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] 69 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 70 | AP = tmp_cmc.sum() / num_rel 71 | all_AP.append(AP) 72 | 73 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 74 | 75 | all_cmc = np.asarray(all_cmc).astype(np.float32) 76 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 77 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 78 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 79 | mAP = np.mean(all_AP) 80 | mINP = np.mean(all_INP) 81 | 82 | return new_all_cmc, mAP, mINP 83 | 84 | 85 | def eval_regdb(distmat, q_pids, g_pids, max_rank=20): 86 | """ 87 | Evaluation with RegDB metric. 88 | """ 89 | num_q, num_g = distmat.shape 90 | if num_g < max_rank: 91 | max_rank = num_g 92 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 93 | indices = np.argsort(distmat, axis=1) 94 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 95 | 96 | # compute cmc curve for each query 97 | all_cmc = [] 98 | all_AP = [] 99 | all_INP = [] 100 | num_valid_q = 0. 
# number of valid query 101 | 102 | # only two cameras 103 | q_camids = np.ones(num_q).astype(np.int32) 104 | g_camids = 2 * np.ones(num_g).astype(np.int32) 105 | 106 | for q_idx in range(num_q): 107 | # get query pid and camid 108 | q_pid = q_pids[q_idx] 109 | q_camid = q_camids[q_idx] 110 | 111 | # remove gallery samples that have the same pid and camid with query 112 | order = indices[q_idx] 113 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 114 | keep = np.invert(remove) 115 | 116 | # compute cmc curve 117 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 118 | if not np.any(raw_cmc): 119 | # this condition is true when query identity does not appear in gallery 120 | continue 121 | 122 | cmc = raw_cmc.cumsum() 123 | 124 | # compute mINP 125 | # refernece Deep Learning for Person Re-identification: A Survey and Outlook 126 | pos_idx = np.where(raw_cmc == 1) 127 | pos_max_idx = np.max(pos_idx) 128 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 129 | all_INP.append(inp) 130 | 131 | cmc[cmc > 1] = 1 132 | 133 | all_cmc.append(cmc[:max_rank]) 134 | num_valid_q += 1. 135 | 136 | # compute average precision 137 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 138 | num_rel = raw_cmc.sum() 139 | tmp_cmc = raw_cmc.cumsum() 140 | tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] 141 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 142 | AP = tmp_cmc.sum() / num_rel 143 | all_AP.append(AP) 144 | 145 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 146 | 147 | all_cmc = np.asarray(all_cmc).astype(np.float32) 148 | all_cmc = all_cmc.sum(0) / num_valid_q 149 | mAP = np.mean(all_AP) 150 | mINP = np.mean(all_INP) 151 | 152 | return all_cmc, mAP, mINP -------------------------------------------------------------------------------- /image/main_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/image/main_figure.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """ 7 | Normalizing to unit length along the specified dimension. 8 | """ 9 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 10 | return x 11 | 12 | 13 | class TripletLoss(nn.Module): 14 | """ 15 | Triplet loss with hard positive/negative mining. 16 | Reference: Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 17 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 18 | Args: 19 | - margin (float): margin for triplet. 20 | - inputs: feature matrix with shape (batch_size, feat_dim). 21 | - targets: ground truth labels with shape (num_classes). 
22 | """ 23 | def __init__(self, margin=0.3): 24 | super(TripletLoss, self).__init__() 25 | self.margin = margin 26 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 27 | 28 | def forward(self, inputs, targets): 29 | n = inputs.size(0) 30 | 31 | # Compute pairwise distance, replace by the official when merged 32 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 33 | dist = dist + dist.t() 34 | dist.addmm_(1, -2, inputs, inputs.t()) 35 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 36 | 37 | # For each anchor, find the hardest positive and negative 38 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 39 | dist_ap, dist_an = [], [] 40 | for i in range(n): 41 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 42 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 43 | dist_ap = torch.cat(dist_ap) 44 | dist_an = torch.cat(dist_an) 45 | 46 | # Compute ranking hinge loss 47 | y = torch.ones_like(dist_an) 48 | loss = self.ranking_loss(dist_an, dist_ap, y) 49 | 50 | # compute accuracy 51 | correct = torch.ge(dist_an, dist_ap).sum().item() # torch.eq: greater than or equal to >= 52 | 53 | return loss, correct 54 | 55 | 56 | class PredictionAlignmentLoss(nn.Module): 57 | """ 58 | Proposed loss for Prediction Alignment Learning (PAL). 
59 | """ 60 | def __init__(self, lambda_vr=0.1, lambda_rv=0.5): 61 | super(PredictionAlignmentLoss, self).__init__() 62 | self.lambda_vr = lambda_vr 63 | self.lambda_rv = lambda_rv 64 | 65 | def forward(self, x_rgb, x_ir): 66 | sim_rgbtoir = torch.mm(normalize(x_rgb), normalize(x_ir).t()) 67 | sim_irtorgb = torch.mm(normalize(x_ir), normalize(x_rgb).t()) 68 | sim_irtoir = torch.mm(normalize(x_ir), normalize(x_ir).t()) 69 | 70 | sim_rgbtoir = nn.Softmax(1)(sim_rgbtoir) 71 | sim_irtorgb = nn.Softmax(1)(sim_irtorgb) 72 | sim_irtoir = nn.Softmax(1)(sim_irtoir) 73 | 74 | KL_criterion = nn.KLDivLoss(reduction="batchmean") 75 | 76 | x_rgbtoir = torch.mm(sim_rgbtoir, x_ir) 77 | x_irtorgb = torch.mm(sim_irtorgb, x_rgb) 78 | x_irtoir = torch.mm(sim_irtoir, x_ir) 79 | 80 | x_rgb_s = nn.Softmax(1)(x_rgb) 81 | x_rgbtoir_ls = nn.LogSoftmax(1)(x_rgbtoir) 82 | x_irtorgb_s = nn.Softmax(1)(x_irtorgb) 83 | x_irtoir_ls = nn.LogSoftmax(1)(x_irtoir) 84 | 85 | loss_rgbtoir = KL_criterion(x_rgbtoir_ls, x_rgb_s) 86 | loss_irtorgb = KL_criterion(x_irtoir_ls, x_irtorgb_s) 87 | 88 | loss = self.lambda_vr * loss_rgbtoir + self.lambda_rv * loss_irtorgb 89 | 90 | return loss, sim_rgbtoir, sim_irtorgb -------------------------------------------------------------------------------- /main_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import easydict 3 | import sys 4 | import os 5 | import time 6 | import yaml 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.data as data 11 | import torchvision.transforms as transforms 12 | from utils import Logger, set_seed, GenIdx 13 | from data_loader import TestData 14 | from data_manager import process_query_sysu, process_gallery_sysu, process_test_regdb 15 | from model.network import BaseResNet 16 | from engine import tester 17 | 18 | 19 | def main_worker(args, args_main): 20 | ## set gpu id and seed id 21 | os.environ["CUDA_VISIBLE_DEVICES"] = 
str(args.gpu) 22 | torch.backends.cudnn.benchmark = True # accelerate the running speed of convolution network 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | set_seed(args.seed, cuda=torch.cuda.is_available()) 25 | 26 | ## set file 27 | if not os.path.isdir(args.dataset + "_" + args.setting + "_" + args.file_name): 28 | os.makedirs(args.dataset + "_" + args.setting + "_" + args.file_name) 29 | file_name = args.dataset + "_" + args.setting + "_" + args.file_name 30 | 31 | if args.dataset == "sysu": 32 | data_path = args.dataset_path + "SYSU-MM01/" 33 | log_path = os.path.join(file_name, args.dataset + "_" + args.log_path) 34 | test_mode = [1, 2] 35 | elif args.dataset == "regdb": 36 | data_path = args.dataset_path + "RegDB/" 37 | log_path = os.path.join(file_name, args.dataset + "_" + args.log_path) 38 | if args.mode == "thermaltovisible": 39 | test_mode = [1, 2] 40 | elif args.mode == "visibletothermal": 41 | test_mode = [2, 1] 42 | 43 | if not os.path.isdir(log_path): 44 | os.makedirs(log_path) 45 | 46 | sys.stdout = Logger(os.path.join(log_path, "log_test.txt")) 47 | 48 | ## load data 49 | print("==========\nargs_main:{}\n==========".format(args_main)) 50 | print("==========\nargs:{}\n==========".format(args)) 51 | print("==> Loading data...") 52 | 53 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 54 | transform_test = transforms.Compose([ 55 | transforms.ToPILImage(), 56 | transforms.Resize((args.img_h, args.img_w)), 57 | transforms.ToTensor(), 58 | normalize, 59 | ]) 60 | 61 | end = time.time() 62 | if args.dataset == "sysu": 63 | # testing set 64 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 65 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode) 66 | elif args.dataset == "regdb": 67 | # testing set 68 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modality=args.mode.split("to")[0]) 69 | gall_img, 
gall_label = process_test_regdb(data_path, trial=args.trial, modality=args.mode.split("to")[1]) 70 | 71 | gallset = TestData(gall_img, gall_label, transform_test=transform_test, img_size=(args.img_w, args.img_h)) 72 | queryset = TestData(query_img, query_label, transform_test=transform_test, img_size=(args.img_w, args.img_h)) 73 | 74 | # testing data loader 75 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers) 76 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers) 77 | 78 | print("Dataset {} Statistics:".format(args.dataset)) 79 | print(" ----------------------------") 80 | print(" subset | # ids | # images") 81 | print(" ----------------------------") 82 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), len(query_label))) 83 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), len(gall_label))) 84 | print(" ----------------------------") 85 | print("Data loading time:\t {:.3f}".format(time.time() - end)) 86 | 87 | if args.dataset == "sysu": 88 | n_class = 395 # initial value 89 | elif args.dataset == "regdb": 90 | n_class = 206 # initial value 91 | else: 92 | n_class = 1000 # initial value 93 | epoch = 0 # initial value 94 | 95 | ## resume checkpoints 96 | if args_main.resume: 97 | resume_path = args_main.resume_path 98 | if os.path.isfile(resume_path): 99 | checkpoint = torch.load(resume_path) 100 | if "main_net" in checkpoint.keys(): 101 | n_class = checkpoint["main_net"]["classifier.weight"].size(0) 102 | elif "net" in checkpoint.keys(): 103 | n_class = checkpoint["net"]["classifier.weight"].size(0) 104 | epoch = checkpoint["epoch"] 105 | print("==> Loading checkpoint {} (epoch {}, number of classes {})".format(resume_path, epoch, n_class)) 106 | else: 107 | print("==> No checkpoint is found at {} (epoch {}, number of classes {})".format(resume_path, epoch, n_class)) 108 | else: 109 | print("==> No 
checkpont is loaded (epoch {}, number of classes {})".format(epoch, n_class)) 110 | 111 | ## build model 112 | main_net = BaseResNet(pool_dim=args.pool_dim, class_num=n_class, per_add_iters=args.per_add_iters, arch=args.arch) 113 | if args_main.resume and os.path.isfile(resume_path): 114 | if "main_net" in checkpoint.keys(): 115 | main_net.load_state_dict(checkpoint["main_net"]) 116 | elif "net" in checkpoint.keys(): 117 | main_net.load_state_dict(checkpoint["net"]) 118 | main_net.to(device) 119 | 120 | # start testing 121 | if args.dataset == "sysu": 122 | print("Testing Epoch: {}, Testing mode: {}".format(epoch, args.mode)) 123 | elif args.dataset == "regdb": 124 | print("Testing Epoch: {}, Testing mode: {}, Trial: {}".format(epoch, args.mode, args.trial)) 125 | 126 | end = time.time() 127 | if args.dataset == "sysu": 128 | cmc, mAP, mINP = tester(args, epoch, main_net, test_mode, gall_label, gall_loader, query_label, query_loader, feat_dim=args.pool_dim, query_cam=query_cam, gall_cam=gall_cam) 129 | elif args.dataset == "regdb": 130 | cmc, mAP, mINP = tester(args, epoch, main_net, test_mode, gall_label, gall_loader, query_label, query_loader, feat_dim=args.pool_dim) 131 | print("Testing time per epoch: {:.3f}".format(time.time() - end)) 132 | 133 | print("Performance: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}".format(cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser(description="OTLA-ReID for testing") 138 | parser.add_argument("--config", default="config/baseline.yaml", help="config file") 139 | parser.add_argument("--resume", action="store_true", help="resume from checkpoint") 140 | parser.add_argument("--resume_path", default="", help="checkpoint path") 141 | 142 | args_main = parser.parse_args() 143 | args = yaml.load(open(args_main.config), Loader=yaml.FullLoader) 144 | args = easydict.EasyDict(args) 145 | 146 | 
main_worker(args, args_main) -------------------------------------------------------------------------------- /model/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | def 
__init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | # original padding is 1; original dilation is 1 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * 4) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | 93 | class ResNet(nn.Module): 94 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 95 | self.inplanes = 64 96 | super(ResNet, self).__init__() 97 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 98 | self.bn1 = nn.BatchNorm2d(64) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 101 | self.layer1 = self._make_layer(block, 64, layers[0]) 102 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 103 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 104 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 105 | 106 | for m in self.modules(): 107 | if isinstance(m, nn.Conv2d): 108 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 109 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 110 | elif isinstance(m, nn.BatchNorm2d): 111 | m.weight.data.fill_(1) 112 | m.bias.data.zero_() 113 | 114 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 115 | downsample = None 116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | nn.Conv2d(self.inplanes, planes * block.expansion, 119 | kernel_size=1, stride=stride, bias=False), 120 | nn.BatchNorm2d(planes * block.expansion), 121 | ) 122 | 123 | layers = [] 124 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 125 | self.inplanes = planes * block.expansion 126 | for i in range(1, blocks): 127 | layers.append(block(self.inplanes, planes)) 128 | 129 | return nn.Sequential(*layers) 130 | 131 | def forward(self, x): 132 | x = self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu(x) 135 | x = self.maxpool(x) 136 | 137 | x = self.layer1(x) 138 | x = self.layer2(x) 139 | x = self.layer3(x) 140 | x = self.layer4(x) 141 | 142 | return x 143 | 144 | 145 | def remove_fc(state_dict): 146 | """Remove the fc layer parameters from state_dict.""" 147 | # for key, value in state_dict.items(): 148 | for key, value in list(state_dict.items()): 149 | if key.startswith('fc.'): 150 | del state_dict[key] 151 | 152 | return state_dict 153 | 154 | 155 | def resnet18(pretrained=False, **kwargs): 156 | """Constructs a ResNet-18 model. 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 161 | if pretrained: 162 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 163 | 164 | return model 165 | 166 | 167 | def resnet34(pretrained=False, **kwargs): 168 | """Constructs a ResNet-34 model. 
169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 175 | 176 | return model 177 | 178 | 179 | def resnet50(pretrained=False, **kwargs): 180 | """Constructs a ResNet-50 model. 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 187 | 188 | return model 189 | 190 | 191 | def resnet101(pretrained=False, **kwargs): 192 | """Constructs a ResNet-101 model. 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet101']))) 199 | 200 | return model 201 | 202 | 203 | def resnet152(pretrained=False, **kwargs): 204 | """Constructs a ResNet-152 model. 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet152']))) 211 | 212 | return model -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | from .backbone.resnet import resnet50 6 | 7 | 8 | class Normalize(nn.Module): 9 | def __init__(self, power=2): 10 | super(Normalize, self).__init__() 11 | self.power = power 12 | 13 | def forward(self, x): 14 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 15 | out = x.div(norm) 16 | return out 17 | 18 | 19 | def weights_init_kaiming(m): 20 | classname = m.__class__.__name__ 21 | if classname.find("Conv") != -1: 22 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 23 | elif classname.find("Linear") != -1: 24 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_out") 25 | init.zeros_(m.bias.data) 26 | elif classname.find("BatchNorm1d") != -1: 27 | init.normal_(m.weight.data, 1.0, 0.01) 28 | init.zeros_(m.bias.data) 29 | 30 | 31 | def weights_init_classifier(m): 32 | classname = m.__class__.__name__ 33 | if classname.find("Linear") != -1: 34 | init.normal_(m.weight.data, 0, 0.001) 35 | if m.bias is not None: 36 | init.zeros_(m.bias.data) 37 | 38 | 39 | class gradientreverselayer(torch.autograd.Function): 40 | @staticmethod 41 | def forward(ctx, coeff, input): 42 | ctx.coeff = coeff 43 | # this is necessary. if we just return "input", "backward" will not be called sometimes 44 | return input.view_as(input) 45 | 46 | @staticmethod 47 | def backward(ctx, grad_outputs): 48 | coeff = ctx.coeff 49 | return None, -coeff * grad_outputs 50 | 51 | 52 | class AdversarialLayer(nn.Module): 53 | def __init__(self, per_add_iters, iter_num=0, alpha=10.0, low_value=0.0, high_value=1.0, max_iter=10000.0): 54 | super(AdversarialLayer, self).__init__() 55 | self.per_add_iters = per_add_iters 56 | self.iter_num = iter_num 57 | self.alpha = alpha 58 | self.low_value = low_value 59 | self.high_value = high_value 60 | self.max_iter = max_iter 61 | self.grl = gradientreverselayer.apply 62 | 63 | def forward(self, input, train_set=True): 64 | if train_set: 65 | self.iter_num += self.per_add_iters 66 | self.coeff = np.float( 67 | 2.0 * (self.high_value - self.low_value) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iter)) - ( 68 | self.high_value - self.low_value) + self.low_value) 69 | 70 | return self.grl(self.coeff, input) 71 | 72 | 73 | class DiscriminateNet(nn.Module): 74 | def __init__(self, input_dim, 
class_num=1): 75 | super(DiscriminateNet, self).__init__() 76 | self.ad_layer1 = nn.Linear(input_dim, input_dim//2) 77 | self.ad_layer2 = nn.Linear(input_dim//2, input_dim//2) 78 | self.ad_layer3 = nn.Linear(input_dim//2, class_num) 79 | self.relu1 = nn.ReLU() 80 | self.relu2 = nn.ReLU() 81 | self.dropout1 = nn.Dropout(0.5) 82 | self.dropout2 = nn.Dropout(0.5) 83 | self.bn = nn.BatchNorm1d(class_num) 84 | self.bn2 = nn.BatchNorm1d(input_dim // 2) 85 | self.bn.bias.requires_grad_(False) 86 | self.bn2.bias.requires_grad_(False) 87 | self.sigmoid = nn.Sigmoid() 88 | 89 | self.ad_layer1.apply(weights_init_kaiming) 90 | self.ad_layer2.apply(weights_init_kaiming) 91 | self.ad_layer3.apply(weights_init_classifier) 92 | 93 | def forward(self, x): 94 | x = self.ad_layer1(x) 95 | x = self.relu1(x) 96 | x = self.dropout1(x) 97 | x = self.ad_layer2(x) 98 | x = self.relu2(x) 99 | x = self.dropout2(x) 100 | x = self.ad_layer3(x) 101 | x = self.bn(x) 102 | x = self.sigmoid(x) # binary classification 103 | 104 | return x 105 | 106 | 107 | class BaseResNet(nn.Module): 108 | def __init__(self, pool_dim, class_num, per_add_iters, arch="resnet50"): 109 | super(BaseResNet, self).__init__() 110 | 111 | if arch == "resnet50": 112 | network = resnet50(pretrained=True, last_conv_stride=1, last_conv_dilation=1) 113 | 114 | self.layer0 = nn.Sequential(network.conv1, 115 | network.bn1, 116 | network.relu, 117 | network.maxpool) 118 | self.layer1 = network.layer1 119 | self.layer2 = network.layer2 120 | self.layer3 = network.layer3 121 | self.layer4 = network.layer4 122 | 123 | self.bottleneck_0 = nn.BatchNorm1d(64) 124 | self.bottleneck_0.bias.requires_grad_(False) # no shift 125 | self.bottleneck_1 = nn.BatchNorm1d(256) 126 | self.bottleneck_1.bias.requires_grad_(False) # no shift 127 | self.bottleneck_2 = nn.BatchNorm1d(512) 128 | self.bottleneck_2.bias.requires_grad_(False) # no shift 129 | self.bottleneck_3 = nn.BatchNorm1d(1024) 130 | self.bottleneck_3.bias.requires_grad_(False) # no 
shift 131 | self.bottleneck = nn.BatchNorm1d(pool_dim) 132 | self.bottleneck.bias.requires_grad_(False) # no shift 133 | 134 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 135 | self.adnet = AdversarialLayer(per_add_iters=per_add_iters) 136 | self.disnet = DiscriminateNet(64 + 256 + 512 + 1024 + pool_dim, 1) 137 | 138 | self.bottleneck_0.apply(weights_init_kaiming) 139 | self.bottleneck_1.apply(weights_init_kaiming) 140 | self.bottleneck_2.apply(weights_init_kaiming) 141 | self.bottleneck_3.apply(weights_init_kaiming) 142 | self.bottleneck.apply(weights_init_kaiming) 143 | self.classifier.apply(weights_init_classifier) 144 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 145 | 146 | self.l2norm = Normalize(2) 147 | 148 | def forward(self, x_rgb, x_ir, modal=0, train_set=True): 149 | if modal == 0: 150 | x = torch.cat((x_rgb, x_ir), dim=0) 151 | elif modal == 1: 152 | x = x_rgb 153 | elif modal == 2: 154 | x = x_ir 155 | 156 | x_0 = self.layer0(x) 157 | x_1 = self.layer1(x_0) 158 | x_2 = self.layer2(x_1) 159 | x_3 = self.layer3(x_2) 160 | x_4 = self.layer4(x_3) 161 | 162 | x_pool_0 = self.avgpool(x_0) 163 | x_pool_0 = x_pool_0.view(x_pool_0.size(0), x_pool_0.size(1)) 164 | x_pool_1 = self.avgpool(x_1) 165 | x_pool_1 = x_pool_1.view(x_pool_1.size(0), x_pool_1.size(1)) 166 | x_pool_2 = self.avgpool(x_2) 167 | x_pool_2 = x_pool_2.view(x_pool_2.size(0), x_pool_2.size(1)) 168 | x_pool_3 = self.avgpool(x_3) 169 | x_pool_3 = x_pool_3.view(x_pool_3.size(0), x_pool_3.size(1)) 170 | x_pool_4 = self.avgpool(x_4) 171 | x_pool_4 = x_pool_4.view(x_pool_4.size(0), x_pool_4.size(1)) 172 | 173 | feat_0 = self.bottleneck_0(x_pool_0) 174 | feat_1 = self.bottleneck_1(x_pool_1) 175 | feat_2 = self.bottleneck_2(x_pool_2) 176 | feat_3 = self.bottleneck_3(x_pool_3) 177 | feat_4 = self.bottleneck(x_pool_4) 178 | 179 | if self.training: 180 | feat = torch.cat((feat_0, feat_1, feat_2, feat_3, feat_4), dim=1) 181 | x = self.adnet(feat, train_set=train_set) 182 | x_dis = 
self.disnet(x) 183 | p_4 = self.classifier(feat_4) 184 | 185 | return x_pool_4, p_4, x_dis 186 | else: 187 | return self.l2norm(feat_4) -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def adjust_learning_rate(args, optimizer, epoch): 5 | if epoch < 10: 6 | lr = args.lr * (epoch + 1) / 10 7 | elif epoch >= 10 and epoch < 20: 8 | lr = args.lr 9 | elif epoch >= 20 and epoch < 50: 10 | lr = args.lr * 0.1 11 | elif epoch >= 50: 12 | lr = args.lr * 0.01 13 | 14 | optimizer.param_groups[0]["lr"] = 0.1 * lr 15 | for i in range(len(optimizer.param_groups) - 1): 16 | optimizer.param_groups[i + 1]["lr"] = lr 17 | 18 | return lr 19 | 20 | 21 | def select_optimizer(args, main_net): 22 | if args.optim == "adam": 23 | ignored_params = list(map(id, main_net.bottleneck.parameters())) \ 24 | + list(map(id, main_net.classifier.parameters())) \ 25 | + list(map(id, main_net.adnet.parameters())) \ 26 | + list(map(id, main_net.disnet.parameters())) \ 27 | + list(map(id, main_net.bottleneck_0.parameters())) \ 28 | + list(map(id, main_net.bottleneck_1.parameters())) \ 29 | + list(map(id, main_net.bottleneck_2.parameters())) \ 30 | + list(map(id, main_net.bottleneck_3.parameters())) 31 | 32 | base_params = filter(lambda p: id(p) not in ignored_params, main_net.parameters()) 33 | optimizer = optim.Adam([ 34 | {"params": base_params, "lr": 0.1 * args.lr}, 35 | {"params": main_net.bottleneck_0.parameters(), "lr": args.lr}, 36 | {"params": main_net.bottleneck_1.parameters(), "lr": args.lr}, 37 | {"params": main_net.bottleneck_2.parameters(), "lr": args.lr}, 38 | {"params": main_net.bottleneck_3.parameters(), "lr": args.lr}, 39 | {"params": main_net.bottleneck.parameters(), "lr": args.lr}, 40 | {"params": main_net.classifier.parameters(), "lr": args.lr}, 41 | {"params": main_net.adnet.parameters(), "lr": args.lr}, 42 | 
{"params": main_net.disnet.parameters(), "lr": args.lr}], 43 | weight_decay=5e-4) 44 | 45 | return optimizer -------------------------------------------------------------------------------- /otla_sk.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from utils import sort_list_with_unique_index 6 | 7 | 8 | def cpu_sk_ir_trainloader(args, main_net, trainloader, tIndex, n_class, print_freq=50): 9 | main_net.train() 10 | 11 | n_ir = len(tIndex) 12 | P = np.zeros((n_ir, n_class)) 13 | 14 | with torch.no_grad(): 15 | for batch_idx, (input_rgb, input_ir, label_rgb, label_ir) in enumerate(trainloader): 16 | t = time.time() 17 | input_ir = input_ir.cuda() 18 | _, p, _ = main_net(input_ir, input_ir, modal=2, train_set=False) 19 | p_softmax = nn.Softmax(1)(p).cpu().numpy() 20 | P[batch_idx * args.train_batch_size * args.num_pos:(batch_idx + 1) * args.train_batch_size * args.num_pos, :] = p_softmax 21 | 22 | if batch_idx == 0: 23 | ir_real_label = label_ir 24 | else: 25 | ir_real_label = torch.cat((ir_real_label, label_ir), dim=0) 26 | 27 | if (batch_idx + 1) % print_freq == 0: 28 | print("Extract predictions: [{}/{}]\t" 29 | "Time consuming: {:.3f}\t" 30 | .format(batch_idx + 1, len(trainloader), time.time() - t)) 31 | 32 | # optimizer label using Sinkhorn-Knopp algorithm 33 | unique_tIndex_first_idx, unique_tIndex_last_idx, unique_tIndex_num, idx_order, unique_tIndex_list = sort_list_with_unique_index(tIndex) 34 | unique_tIndex_idx = unique_tIndex_last_idx # last 35 | ir_real_label = ir_real_label[unique_tIndex_idx] 36 | P_ = P[unique_tIndex_idx] 37 | for i, idx in enumerate(idx_order): 38 | P_[i] = (P[unique_tIndex_list[idx]].mean(axis=0)) 39 | PS = (P_.T) ** args.lambda_sk 40 | 41 | n_ir_unique = len(np.unique(tIndex)) 42 | alpha = np.ones((n_class, 1)) / n_class # initial value for alpha 43 | beta = np.ones((n_ir_unique, 1)) / n_ir_unique # initial 
value for beta 44 | 45 | inv_K = 1. / n_class 46 | inv_N = 1. / n_ir_unique 47 | 48 | err = 1e6 49 | step = 0 50 | tt = time.time() 51 | while err > 1e-1: 52 | alpha = inv_K / (PS @ beta) # (KxN) @ (N,1) = K x 1 53 | beta_new = inv_N / (alpha.T @ PS).T # ((1,K) @ (KxN)).t() = N x 1 54 | if step % 10 == 0: 55 | err = np.nansum(np.abs(beta / beta_new - 1)) 56 | beta = beta_new 57 | step += 1 58 | print("Sinkhorn-Knopp Error: {:.3f} Total step: {} Total time: {:.3f}".format(err, step, time.time() - tt)) 59 | PS *= np.squeeze(beta) 60 | PS = PS.T 61 | PS *= np.squeeze(alpha) 62 | PS = PS.T 63 | argmaxes = np.nanargmax(PS, 0) # size n_ir 64 | ir_pseudo_label_op = torch.LongTensor(argmaxes) 65 | 66 | # the max prediction of softmax 67 | argmaxes_ = np.nanargmax(P_, 1) 68 | ir_pseudo_label_mp = torch.LongTensor(argmaxes_) 69 | 70 | return ir_pseudo_label_op, ir_pseudo_label_mp, ir_real_label, tIndex[unique_tIndex_idx] -------------------------------------------------------------------------------- /video-poster/0971.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/video-poster/0971.mp4 -------------------------------------------------------------------------------- /video-poster/0971.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/video-poster/0971.pdf --------------------------------------------------------------------------------