├── .idea
│   ├── .gitignore
│   ├── OTLA.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── SpCL-master
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── examples
│   │   ├── otla_tool.py
│   │   ├── spcl_train_uda.py
│   │   ├── spcl_train_usl.py
│   │   └── test.py
│   ├── figs
│   │   ├── framework.png
│   │   └── results.png
│   ├── setup.cfg
│   ├── setup.py
│   └── spcl
│       ├── __init__.py
│       ├── datasets
│       │   ├── __init__.py
│       │   ├── dukemtmc.py
│       │   ├── market1501.py
│       │   ├── msmt17.py
│       │   ├── personx.py
│       │   ├── regdb.py
│       │   ├── regdb_ir.py
│       │   ├── regdb_rgb.py
│       │   ├── sysumm01.py
│       │   ├── sysumm01_ir.py
│       │   ├── sysumm01_rgb.py
│       │   ├── vehicleid.py
│       │   ├── vehiclex.py
│       │   └── veri.py
│       ├── evaluation_metrics
│       │   ├── __init__.py
│       │   ├── classification.py
│       │   └── ranking.py
│       ├── evaluators.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── dsbn.py
│       │   ├── hm.py
│       │   ├── resnet.py
│       │   ├── resnet_ibn.py
│       │   └── resnet_ibn_a.py
│       ├── trainers.py
│       └── utils
│           ├── __init__.py
│           ├── data
│           │   ├── __init__.py
│           │   ├── base_dataset.py
│           │   ├── preprocessor.py
│           │   ├── sampler.py
│           │   └── transforms.py
│           ├── faiss_rerank.py
│           ├── faiss_utils.py
│           ├── logging.py
│           ├── meters.py
│           ├── osutils.py
│           ├── rerank.py
│           └── serialization.py
├── config
│   ├── config_regdb.yaml
│   └── config_sysu.yaml
├── data_loader.py
├── data_manager.py
├── engine.py
├── eval_metrics.py
├── image
│   └── main_figure.png
├── loss.py
├── main_test.py
├── main_train.py
├── model
│   ├── backbone
│   │   └── resnet.py
│   └── network.py
├── optimizer.py
├── otla_sk.py
├── utils.py
└── video-poster
    ├── 0971.mp4
    └── 0971.pdf
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../../../../:\Users\王蒋铭\pytorch\github\OTLA\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/OTLA.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Optimal Transport for Label-Efficient Visible-Infrared Person Re-Identification (OTLA-ReID)
2 | This is the official repository for "Optimal Transport for Label-Efficient
3 | Visible-Infrared Person Re-Identification" ([PDF](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136840091.pdf), [Supplementary Material](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136840091-supp.pdf)), accepted at *ECCV 2022*. This work was done at the DMCV Laboratory of East China Normal University; see [DMCV-Lab](https://dmcv-ecnu.github.io/) for the laboratory's website.
4 |
5 | ![main figure](image/main_figure.png)
6 |
7 | ### Update:
8 | **[2022-7-17]** The semi-supervised and supervised settings can be run with the current code. The unsupervised setting will be updated within a few days.
9 |
10 | **[2022-7-21]** Update some critical information in README.md.
11 |
12 | **[2022-9-22]** Update the code of SpCL-master, which can be used to generate pseudo labels of the visible modality for the unsupervised setting.
13 |
14 | **[2022-10-28]** Update the paper link.
15 |
16 |
17 | ## Requirements
18 | + python 3.7.11
19 | + numpy 1.21.4
20 | + torch 1.10.0
21 | + torchvision 0.11.0
22 | + easydict 1.9
23 | + PyYAML 6.0
24 | + tensorboardX 2.2
25 |
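For convenience, the pinned versions above can be installed in one go. A minimal sketch, assuming a fresh Python 3.7 environment (the environment name `otla` is hypothetical):
```shell
conda create -n otla python=3.7.11
conda activate otla
pip install numpy==1.21.4 torch==1.10.0 torchvision==0.11.0 easydict==1.9 PyYAML==6.0 tensorboardX==2.2
```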
26 |
27 | ## Prepare Datasets
28 | Download the VI-ReID datasets [SYSU-MM01](https://github.com/wuancong/SYSU-MM01) (email the author to get it) and [RegDB](http://dm.dongguk.edu/link.html) (submit a copyright form); follow the link of [DDAG](https://github.com/mangye16/DDAG) for more information on VI-ReID datasets. If you want to run the unsupervised setting, also download the visible ReID datasets [Market-1501](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view), [MSMT17](https://arxiv.org/abs/1711.08565) (email the author to get it), and [DukeMTMC-reID](https://drive.google.com/file/d/1jjE85dRCMOgRtvJ5RQV9-Afs-2_5dY3O/view); please follow the link of [OpenUnReID](https://github.com/open-mmlab/OpenUnReID/blob/master/docs/INSTALL.md) for more information on visible ReID datasets.
29 |
30 |
31 | ## Training
32 | First, choose the ```setting:``` entry in the config file corresponding to the VI-ReID dataset.
33 |
34 | + For the ```semi-supervised``` / ```supervised``` setting, if you want to train the model(s) in the paper, run the following command:
35 | ```shell
36 | cd OTLA-ReID/
37 | python main_train.py --config config/config_sysu.yaml
38 | ```
39 | + For the ```unsupervised``` setting, you should set ```train_visible_image_path:``` and ```train_visible_label_path:``` to the visible images and pseudo labels of the VI-ReID dataset produced by a well-established UDA-ReID or USL-ReID method (e.g. [SpCL](https://github.com/yxgeee/SpCL)); see the config sketch after the command below. Then run the following command:
40 | ```shell
41 | cd OTLA-ReID/
42 | python main_train.py --config config/config_sysu.yaml
43 | ```
44 |
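For reference, a minimal sketch of the relevant config entries (the values are placeholders; the ```.npy``` file names follow the pattern written by ```SpCL-master/examples/otla_tool.py```, and the directory depends on its ```save_path``` argument):
```yaml
setting: unsupervised
train_visible_image_path: /path/to/spcl_uda_market1501TOsysumm01_rgb_train_rgb_resized_img.npy
train_visible_label_path: /path/to/spcl_uda_market1501TOsysumm01_rgb_train_rgb_resized_label.npy
```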
45 | Here, we give an example of running SpCL in SpCL-master to generate visible pseudo labels. Note that you first need to install the environment, which can be found in [SpCL](https://github.com/yxgeee/SpCL):
46 | + For SYSU-MM01:
47 | ```shell
48 | cd OTLA-ReID/SpCL-master/
49 | CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/spcl_train_uda.py -ds market1501 -dt sysumm01_rgb --logs-dir logs/spcl_uda/market1501TOsysumm01_rgb_resnet50 --epochs 51 --iters 800
50 | ```
51 | + For RegDB:
52 | ```shell
53 | cd OTLA-ReID/SpCL-master/
54 | CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/spcl_train_uda.py -ds market1501 -dt regdb_rgb --logs-dir logs/spcl_uda/market1501TOregdb_rgb_resnet50 --epochs 51 --iters 50
55 | ```
56 | The generated visible images and visible pseudo labels are both saved under the dataset directory.
57 |
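As a quick sanity check, the saved arrays can be inspected with numpy. A sketch with illustrative paths (the naming pattern is ```{method}_{source}TO{target}_train_rgb_resized_{img,label}.npy```):
```python
import numpy as np

# Paths are illustrative; save_image_label() in SpCL-master/examples/otla_tool.py
# writes both files under its save_path argument.
imgs = np.load("spcl_uda_market1501TOsysumm01_rgb_train_rgb_resized_img.npy")
labels = np.load("spcl_uda_market1501TOsysumm01_rgb_train_rgb_resized_label.npy")
print(imgs.shape)                            # (N, H, W, 3) resized RGB images
print(labels.shape, np.unique(labels).size)  # (N,) pseudo labels relabeled to 0..K-1
```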
58 | ## Testing
59 | If you want to test the trained model(s), run the following command:
60 | ```shell
61 | cd OTLA-ReID/
62 | python main_test.py --config config/config_sysu.yaml --resume --resume_path ./sysu_semi-supervised_otla-reid/sysu_save_model/best_checkpoint.pth
63 | ```
64 |
65 | ## Citation
66 | If you find this code useful for your research, please cite our paper:
67 | ```
68 | @inproceedings{wang2022optimal,
69 | title={Optimal Transport for Label-Efficient Visible-Infrared Person Re-Identification},
70 | author={Wang, Jiangming and Zhang, Zhizhong and Chen, Mingang and Zhang, Yi and Wang, Cong and Sheng, Bin and Qu, Yanyun and Xie, Yuan},
71 | booktitle={European Conference on Computer Vision},
72 | pages={93--109},
73 | year={2022},
74 | organization={Springer}
75 | }
76 | ```
77 |
78 | ## Acknowledgements
79 | This work is developed based on the repositories of [SeLa (ICLR 2020)](https://github.com/yukimasano/self-label), [DDAG (ECCV 2020)](https://github.com/mangye16/DDAG), [SpCL (NeurIPS 2020)](https://github.com/yxgeee/SpCL), [MMT (ICLR 2020)](https://github.com/yxgeee/MMT), and [HCD (ICCV 2021)](https://github.com/tangshixiang/HCD). We sincerely thank all developers of these high-quality repositories.
80 |
--------------------------------------------------------------------------------
/SpCL-master/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | logs/*
3 | scripts/*
4 |
5 | # temporary files which can be created if a process still has a handle open of a deleted file
6 | .fuse_hidden*
7 |
8 | # KDE directory preferences
9 | .directory
10 |
11 | # Linux trash folder which might appear on any partition or disk
12 | .Trash-*
13 |
14 | # .nfs files are created when an open file is removed but is still being accessed
15 | .nfs*
16 |
17 |
18 | *.DS_Store
19 | .AppleDouble
20 | .LSOverride
21 |
22 | # Icon must end with two \r
23 | Icon
24 |
25 |
26 | # Thumbnails
27 | ._*
28 |
29 | # Files that might appear in the root of a volume
30 | .DocumentRevisions-V100
31 | .fseventsd
32 | .Spotlight-V100
33 | .TemporaryItems
34 | .Trashes
35 | .VolumeIcon.icns
36 | .com.apple.timemachine.donotpresent
37 |
38 | # Directories potentially created on remote AFP share
39 | .AppleDB
40 | .AppleDesktop
41 | Network Trash Folder
42 | Temporary Items
43 | .apdisk
44 |
45 |
46 | # swap
47 | [._]*.s[a-v][a-z]
48 | [._]*.sw[a-p]
49 | [._]s[a-v][a-z]
50 | [._]sw[a-p]
51 | # session
52 | Session.vim
53 | # temporary
54 | .netrwhist
55 | *~
56 | # auto-generated tag files
57 | tags
58 |
59 |
60 | # cache files for sublime text
61 | *.tmlanguage.cache
62 | *.tmPreferences.cache
63 | *.stTheme.cache
64 |
65 | # workspace files are user-specific
66 | *.sublime-workspace
67 |
68 | # project files should be checked into the repository, unless a significant
69 | # proportion of contributors will probably not be using SublimeText
70 | # *.sublime-project
71 |
72 | # sftp configuration file
73 | sftp-config.json
74 |
75 | # Package control specific files
76 | Package Control.last-run
77 | Package Control.ca-list
78 | Package Control.ca-bundle
79 | Package Control.system-ca-bundle
80 | Package Control.cache/
81 | Package Control.ca-certs/
82 | Package Control.merged-ca-bundle
83 | Package Control.user-ca-bundle
84 | oscrypto-ca-bundle.crt
85 | bh_unicode_properties.cache
86 |
87 | # Sublime-github package stores a github token in this file
88 | # https://packagecontrol.io/packages/sublime-github
89 | GitHub.sublime-settings
90 |
91 |
92 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
93 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
94 |
95 | # User-specific stuff:
96 | .idea
97 | .idea/**/workspace.xml
98 | .idea/**/tasks.xml
99 |
100 | # Sensitive or high-churn files:
101 | .idea/**/dataSources/
102 | .idea/**/dataSources.ids
103 | .idea/**/dataSources.xml
104 | .idea/**/dataSources.local.xml
105 | .idea/**/sqlDataSources.xml
106 | .idea/**/dynamic.xml
107 | .idea/**/uiDesigner.xml
108 |
109 | # Gradle:
110 | .idea/**/gradle.xml
111 | .idea/**/libraries
112 |
113 | # Mongo Explorer plugin:
114 | .idea/**/mongoSettings.xml
115 |
116 | ## File-based project format:
117 | *.iws
118 |
119 | ## Plugin-specific files:
120 |
121 | # IntelliJ
122 | /out/
123 |
124 | # mpeltonen/sbt-idea plugin
125 | .idea_modules/
126 |
127 | # JIRA plugin
128 | atlassian-ide-plugin.xml
129 |
130 | # Crashlytics plugin (for Android Studio and IntelliJ)
131 | com_crashlytics_export_strings.xml
132 | crashlytics.properties
133 | crashlytics-build.properties
134 | fabric.properties
135 |
136 |
137 | # Byte-compiled / optimized / DLL files
138 | __pycache__/
139 | *.py[cod]
140 | *$py.class
141 |
142 | # C extensions
143 | *.so
144 |
145 | # Distribution / packaging
146 | .Python
147 | env/
148 | build/
149 | develop-eggs/
150 | dist/
151 | downloads/
152 | eggs/
153 | .eggs/
154 | lib/
155 | lib64/
156 | parts/
157 | sdist/
158 | var/
159 | *.egg-info/
160 | .installed.cfg
161 | *.egg
162 |
163 | # PyInstaller
164 | # Usually these files are written by a python script from a template
165 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
166 | *.manifest
167 | *.spec
168 |
169 | # Installer logs
170 | pip-log.txt
171 | pip-delete-this-directory.txt
172 |
173 | # Unit test / coverage reports
174 | htmlcov/
175 | .tox/
176 | .coverage
177 | .coverage.*
178 | .cache
179 | nosetests.xml
180 | coverage.xml
181 | *,cover
182 | .hypothesis/
183 |
184 | # Translations
185 | *.mo
186 | *.pot
187 |
188 | # Django stuff:
189 | *.log
190 | local_settings.py
191 |
192 | # Flask stuff:
193 | instance/
194 | .webassets-cache
195 |
196 | # Scrapy stuff:
197 | .scrapy
198 |
199 | # Sphinx documentation
200 | docs/_build/
201 |
202 | # PyBuilder
203 | target/
204 |
205 | # IPython Notebook
206 | .ipynb_checkpoints
207 |
208 | # pyenv
209 | .python-version
210 |
211 | # celery beat schedule file
212 | celerybeat-schedule
213 |
214 | # dotenv
215 | .env
216 |
217 | # virtualenv
218 | venv/
219 | ENV/
220 |
221 | # Spyder project settings
222 | .spyderproject
223 |
224 | # Rope project settings
225 | .ropeproject
226 |
227 |
228 | # Project specific
229 | examples/data
230 | examples/logs
231 |
--------------------------------------------------------------------------------
/SpCL-master/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yixiao Ge
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/SpCL-master/README.md:
--------------------------------------------------------------------------------
1 | ![Python >=3.5](https://img.shields.io/badge/Python->=3.5-yellow.svg)
2 | ![PyTorch >=1.0](https://img.shields.io/badge/PyTorch->=1.0-blue.svg)
3 |
4 | # Self-paced Contrastive Learning (SpCL)
5 |
6 | The *official* repository for [Self-paced Contrastive Learning with Hybrid Memory for Domain Adaptive Object Re-ID](https://arxiv.org/abs/2006.02713), accepted by [NeurIPS-2020](https://nips.cc/). `SpCL` achieves state-of-the-art performance on both **unsupervised domain adaptation** tasks and **unsupervised learning** tasks for object re-ID, including person re-ID and vehicle re-ID.
7 |
8 | ![framework](figs/framework.png)
9 |
10 | ### Updates
11 |
12 | [2020-10-13] All trained models for the camera-ready version have been updated, see [Trained Models](#trained-models) for details.
13 |
14 | [2020-09-25] `SpCL` has been accepted by NeurIPS on the condition that experiments on the DukeMTMC-reID dataset be removed, since the dataset has been taken down and should no longer be used.
15 |
16 | [2020-07-01] We did the code refactoring to support distributed training, stronger performances and more features. Please see [OpenUnReID](https://github.com/open-mmlab/OpenUnReID).
17 |
18 | ## Requirements
19 |
20 | ### Installation
21 |
22 | ```shell
23 | git clone https://github.com/yxgeee/SpCL.git
24 | cd SpCL
25 | python setup.py develop
26 | ```
27 |
28 | ### Prepare Datasets
29 |
30 | ```shell
31 | cd examples && mkdir data
32 | ```
33 | Download the person datasets [Market-1501](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view), [MSMT17](https://arxiv.org/abs/1711.08565), [PersonX](https://github.com/sxzrt/Instructions-of-the-PersonX-dataset#data-for-visda2020-chanllenge), and the vehicle datasets [VehicleID](https://www.pkuml.org/resources/pku-vehicleid.html), [VeRi-776](https://github.com/JDAI-CV/VeRidataset), [VehicleX](https://www.aicitychallenge.org/2020-track2-download/).
34 | Then unzip them under the directory like
35 | ```
36 | SpCL/examples/data
37 | ├── market1501
38 | │ └── Market-1501-v15.09.15
39 | ├── msmt17
40 | │ └── MSMT17_V1
41 | ├── personx
42 | │ └── PersonX
43 | ├── vehicleid
44 | │ └── VehicleID -> VehicleID_V1.0
45 | ├── vehiclex
46 | │ └── AIC20_ReID_Simulation -> AIC20_track2/AIC20_ReID_Simulation
47 | └── veri
48 | └── VeRi -> VeRi_with_plate
49 | ```
50 |
51 | ### Prepare ImageNet Pre-trained Models for IBN-Net
52 | When training with the backbone of [IBN-ResNet](https://arxiv.org/abs/1807.09441), you need to download the ImageNet-pretrained model from this [link](https://drive.google.com/drive/folders/1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S) and save it under the path of `logs/pretrained/`.
53 | ```shell
54 | mkdir logs && cd logs
55 | mkdir pretrained
56 | ```
57 | The file tree should be
58 | ```
59 | SpCL/logs
60 | └── pretrained
61 | └── resnet50_ibn_a.pth.tar
62 | ```
63 | ImageNet-pretrained models for **ResNet-50** will be automatically downloaded in the python script.
64 |
65 |
66 | ## Training
67 |
68 | We utilize 4 GTX-1080TI GPUs for training. **Note that**
69 |
70 | + The training for `SpCL` is end-to-end, which means that no source-domain pre-training is required.
71 | + use `--iters 400` (default) for Market-1501 and PersonX datasets, and `--iters 800` for MSMT17, VeRi-776, VehicleID and VehicleX datasets;
72 | + use `--width 128 --height 256` (default) for person datasets, and `--height 224 --width 224` for vehicle datasets;
73 | + use `-a resnet50` (default) for the backbone of ResNet-50, and `-a resnet_ibn50a` for the backbone of IBN-ResNet.
74 |
75 | ### Unsupervised Domain Adaptation
76 | To train the model(s) in the paper, run this command:
77 | ```shell
78 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
79 | python examples/spcl_train_uda.py \
80 | -ds $SOURCE_DATASET -dt $TARGET_DATASET --logs-dir $PATH_OF_LOGS
81 | ```
82 |
83 | **Some examples:**
84 | ```shell
85 | ### PersonX -> Market-1501 ###
86 | # using all default settings is ok
87 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
88 | python examples/spcl_train_uda.py \
89 | -ds personx -dt market1501 --logs-dir logs/spcl_uda/personx2market_resnet50
90 |
91 | ### Market-1501 -> MSMT17 ###
92 | # use all default settings except for iters=800
93 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
94 | python examples/spcl_train_uda.py --iters 800 \
95 | -ds market1501 -dt msmt17 --logs-dir logs/spcl_uda/market2msmt_resnet50
96 |
97 | ### VehicleID -> VeRi-776 ###
98 | # use all default settings except for iters=800, height=224 and width=224
99 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
100 | python examples/spcl_train_uda.py --iters 800 --height 224 --width 224 \
101 | -ds vehicleid -dt veri --logs-dir logs/spcl_uda/vehicleid2veri_resnet50
102 | ```
103 |
104 |
105 | ### Unsupervised Learning
106 | To train the model(s) in the paper, run this command:
107 | ```shell
108 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
109 | python examples/spcl_train_usl.py \
110 | -d $DATASET --logs-dir $PATH_OF_LOGS
111 | ```
112 |
113 | **Some examples:**
114 | ```shell
115 | ### Market-1501 ###
116 | # using all default settings is ok
117 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
118 | python examples/spcl_train_usl.py \
119 | -d market1501 --logs-dir logs/spcl_usl/market_resnet50
120 |
121 | ### MSMT17 ###
122 | # use all default settings except for iters=800
123 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
124 | python examples/spcl_train_usl.py --iters 800 \
125 | -d msmt17 --logs-dir logs/spcl_usl/msmt_resnet50
126 |
127 | ### VeRi-776 ###
128 | # use all default settings except for iters=800, height=224 and width=224
129 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
130 | python examples/spcl_train_usl.py --iters 800 --height 224 --width 224 \
131 | -d veri --logs-dir logs/spcl_usl/veri_resnet50
132 | ```
133 |
134 |
135 | ## Evaluation
136 |
137 | We utilize 1 GTX-1080TI GPU for testing. **Note that**
138 |
139 | + use `--width 128 --height 256` (default) for person datasets, and `--height 224 --width 224` for vehicle datasets;
140 | + use `--dsbn` for domain adaptive models, and add `--test-source` if you want to test on the source domain;
141 | + use `-a resnet50` (default) for the backbone of ResNet-50, and `-a resnet_ibn50a` for the backbone of IBN-ResNet.
142 |
143 | ### Unsupervised Domain Adaptation
144 |
145 | To evaluate the domain adaptive model on the **target-domain** dataset, run:
146 | ```shell
147 | CUDA_VISIBLE_DEVICES=0 \
148 | python examples/test.py --dsbn \
149 | -d $DATASET --resume $PATH_OF_MODEL
150 | ```
151 |
152 | To evaluate the domain adaptive model on the **source-domain** dataset, run:
153 | ```shell
154 | CUDA_VISIBLE_DEVICES=0 \
155 | python examples/test.py --dsbn --test-source \
156 | -d $DATASET --resume $PATH_OF_MODEL
157 | ```
158 |
159 | **Some examples:**
160 | ```shell
161 | ### Market-1501 -> MSMT17 ###
162 | # test on the target domain
163 | CUDA_VISIBLE_DEVICES=0 \
164 | python examples/test.py --dsbn \
165 | -d msmt17 --resume logs/spcl_uda/market2msmt_resnet50/model_best.pth.tar
166 | # test on the source domain
167 | CUDA_VISIBLE_DEVICES=0 \
168 | python examples/test.py --dsbn --test-source \
169 | -d market1501 --resume logs/spcl_uda/market2msmt_resnet50/model_best.pth.tar
170 | ```
171 |
172 | ### Unsupervised Learning
173 | To evaluate the model, run:
174 | ```shell
175 | CUDA_VISIBLE_DEVICES=0 \
176 | python examples/test.py \
177 | -d $DATASET --resume $PATH
178 | ```
179 |
180 | **Some examples:**
181 | ```shell
182 | ### Market-1501 ###
183 | CUDA_VISIBLE_DEVICES=0 \
184 | python examples/test.py \
185 | -d market1501 --resume logs/spcl_usl/market_resnet50/model_best.pth.tar
186 | ```
187 |
188 | ## Trained Models
189 |
190 | ![results](figs/results.png)
191 |
192 | You can download the above models in the paper from [[Google Drive]](https://drive.google.com/drive/folders/1ryx-fPGjrexwm9ZP9QO3Qk4SKzNqbaXw?usp=sharing) or [[Baidu Yun]](https://pan.baidu.com/s/1FInOhEdQsOEk-1oMWWB0Ag)(password: w3l9).
193 |
194 |
195 | ## Citation
196 | If you find this code useful for your research, please cite our paper
197 | ```
198 | @inproceedings{ge2020selfpaced,
199 | title={Self-paced Contrastive Learning with Hybrid Memory for Domain Adaptive Object Re-ID},
200 | author={Yixiao Ge and Feng Zhu and Dapeng Chen and Rui Zhao and Hongsheng Li},
201 | booktitle={Advances in Neural Information Processing Systems},
202 | year={2020}
203 | }
204 | ```
205 |
--------------------------------------------------------------------------------
/SpCL-master/examples/otla_tool.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | from PIL import Image
4 | import numpy as np
5 | import collections
6 | import torch
7 |
8 |
9 | def mkdir_if_missing(dir_path):
10 | """
11 | Create directory if missing.
12 | """
13 | try:
14 | os.makedirs(dir_path)
15 | except OSError as e:
16 | if e.errno != errno.EEXIST:
17 | raise
18 |
19 |
20 | def save_checkpoint_pseudo_label(state, fpath="checkpoint.pth.tar"):
21 | """
22 | Save model for generating pseudo label.
23 | """
24 | mkdir_if_missing(os.path.dirname(fpath))
25 | torch.save(state, fpath)
26 |
27 |
28 | def mask_outlier(train_pseudo_label):
29 | """
30 | Mask outliers in the clustering results, i.e. samples whose pseudo class contains only a single sample.
31 | """
32 | index2label = collections.defaultdict(int)
33 | for label in train_pseudo_label:
34 | index2label[label.item()] += 1
35 | nums = np.fromiter(index2label.values(), dtype=float)
36 | label = np.fromiter(index2label.keys(), dtype=float)
37 | train_label = label[nums > 1]
38 |
39 | return np.array([i in train_label for i in train_pseudo_label])
40 |
41 |
42 | def R_gt(train_real_label, train_pseudo_label):
43 | '''
44 | The Average Maximum Proportion of Ground-truth Classes (R_gt) in supplementary material.
45 | '''
46 | p = 0
47 | mask = mask_outlier(train_pseudo_label)
48 | train_real_label = train_real_label[mask]
49 | ids_container = list(np.unique(train_real_label))
50 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
51 | for i, label in enumerate(train_real_label):
52 | train_real_label[i] = id2label[label]
53 | train_pseudo_label = train_pseudo_label[mask]
54 | ids_container = list(np.unique(train_pseudo_label))
55 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
56 | for i, label in enumerate(train_pseudo_label):
57 | train_pseudo_label[i] = id2label[label]
58 | for i in range(np.unique(train_real_label).size):
59 | sample_id = (train_real_label == i)
60 | sample_label = train_pseudo_label[sample_id]
61 | sample_num_per_label = np.zeros(np.unique(train_pseudo_label).size)
62 | for j in sample_label:
63 | sample_num_per_label[j] += 1
64 | p_i = np.max(sample_num_per_label) / sample_label.size
65 | p += p_i
66 | p = p / np.unique(train_real_label).size
67 | print("R_gt: {:.4f}".format(p))
68 |
69 | return p
70 |
71 |
72 | def R_ct(train_real_label, train_pseudo_label):
73 | '''
74 | The Average Maximum Proportion of Pseudo Classes (R_ct) in supplementary material.
75 | '''
76 | p = 0
77 | mask = mask_outlier(train_pseudo_label)
78 | train_real_label = train_real_label[mask]
79 | ids_container = list(np.unique(train_real_label))
80 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
81 | for i, label in enumerate(train_real_label):
82 | train_real_label[i] = id2label[label]
83 | train_pseudo_label = train_pseudo_label[mask]
84 | ids_container = list(np.unique(train_pseudo_label))
85 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
86 | for i, label in enumerate(train_pseudo_label):
87 | train_pseudo_label[i] = id2label[label]
88 | for i in range(np.unique(train_pseudo_label).size):
89 | sample_id = (train_pseudo_label == i)
90 | sample_label = train_real_label[sample_id]
91 | sample_num_per_label = np.zeros(np.unique(train_real_label).size)
92 | for j in sample_label:
93 | sample_num_per_label[j] += 1
94 | p_i = np.max(sample_num_per_label) / sample_label.size
95 | p += p_i
96 | p = p / np.unique(train_pseudo_label).size
97 | print("R_ct: {:.4f}".format(p))
98 |
99 | return p
100 |
101 |
102 | def P_v(train_real_label, train_pseudo_label):
103 | '''
104 | The Proportion of Visible Training Samples (P_v).
105 | '''
106 | len_data = len(train_real_label)
107 | mask = mask_outlier(train_pseudo_label)
108 | len_mask_data = len(train_pseudo_label[mask])
109 | p = len_mask_data / len_data
110 | print("P_v: {:.4f}, total samples: {}, total samples without outliers: {}".format(p, len_data, len_mask_data))
111 |
112 | return p
113 |
114 |
115 | def Q_v(train_real_label, train_pseudo_label):
116 | mask = mask_outlier(train_pseudo_label)
117 | n_class = np.unique(train_real_label).size
118 | n_cluster_class = np.unique(train_pseudo_label[mask]).size
119 | p = np.min((n_class, n_cluster_class)) / np.max((n_class, n_cluster_class))
120 | print("Q_v: {:.4f}, number of real classes: {}, number of pseudo classes: {}".format(p, n_class, n_cluster_class))
121 |
122 | return p
123 |
124 |
125 | def R_plq(train_real_label, train_pseudo_label):
126 | '''
127 | The Final Metric (R_plq) in supplementary material.
128 | '''
129 | R_gt_p = R_gt(train_real_label, train_pseudo_label)
130 | R_ct_p = R_ct(train_real_label, train_pseudo_label)
131 | P_v_p = P_v(train_real_label, train_pseudo_label)
132 | Q_v_p = Q_v(train_real_label, train_pseudo_label)
133 | R_plq_p = (R_gt_p + R_ct_p) / 2 * P_v_p * Q_v_p
134 | print("R_plq: {:.4f}".format(R_plq_p))
135 |
136 | return R_gt_p, R_ct_p, P_v_p, Q_v_p, R_plq_p
137 |
138 |
139 | def save_image_label(train_image_path, train_pseudo_label, train_real_label, model, epoch, logs_dir, save_path,
140 | img_size=(144, 288), source_domain="market1501", target_domain="sysumm01_rgb", method_name="spcl_uda"):
141 | train_image = []
142 | for fname in train_image_path:
143 | img = Image.open(fname)
144 | img = img.resize(img_size, Image.ANTIALIAS)
145 | pix_array = np.array(img)
146 | train_image.append(pix_array)
147 |
148 | train_image = np.array(train_image)
149 | train_pseudo_label = np.array(train_pseudo_label)
150 | train_real_label = np.array(train_real_label)
151 |
152 | ids_container = list(np.unique(train_pseudo_label))
153 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
154 | for i, label in enumerate(train_pseudo_label):
155 | train_pseudo_label[i] = id2label[label]
156 |
157 | R_gt_p, R_ct_p, P_v_p, Q_v_p, R_plq_p = R_plq(train_real_label, train_pseudo_label)
158 |
159 | np.save(os.path.join(save_path, method_name+"_"+source_domain+"TO"+target_domain+"_"+"train_rgb_resized_img.npy"), train_image)
160 | np.save(os.path.join(save_path, method_name+"_"+source_domain+"TO"+target_domain+"_"+"train_rgb_resized_label.npy"), train_pseudo_label)
161 |
162 | save_checkpoint_pseudo_label({
163 | "state_dict": model.state_dict(),
164 | "epoch": epoch,
165 | "R_gt": R_gt_p,
166 | "R_ct": R_ct_p,
167 | "P_v": P_v_p,
168 | "Q_v": Q_v_p,
169 | "R_plq": R_plq_p,
170 | }, fpath=os.path.join(logs_dir, "checkpoint_pseudo_label.pth.tar"))
--------------------------------------------------------------------------------
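A toy usage sketch for the pseudo-label quality metrics defined above (synthetic labels; assumes it is run from `SpCL-master/examples/`):
```python
import numpy as np
from otla_tool import R_plq

# Three ground-truth identities; pseudo label 3 forms a singleton cluster,
# so mask_outlier() drops that sample before the metrics are computed.
real = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
pseudo = np.array([5, 5, 5, 7, 7, 9, 9, 9, 3])
R_gt_p, R_ct_p, P_v_p, Q_v_p, R_plq_p = R_plq(real, pseudo)  # prints each metric
```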
/SpCL-master/examples/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os.path as osp
4 | import random
5 | import numpy as np
6 | import sys
7 |
8 | import torch
9 | from torch import nn
10 | from torch.backends import cudnn
11 | from torch.utils.data import DataLoader
12 |
13 | from spcl import datasets
14 | from spcl import models
15 | from spcl.models.dsbn import convert_dsbn, convert_bn
16 | from spcl.evaluators import Evaluator
17 | from spcl.utils.data import transforms as T
18 | from spcl.utils.data.preprocessor import Preprocessor
19 | from spcl.utils.logging import Logger
20 | from spcl.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict
21 |
22 |
23 | def get_data(name, data_dir, height, width, batch_size, workers):
24 | root = osp.join(data_dir, name)
25 |
26 | dataset = datasets.create(name, root)
27 |
28 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
29 | std=[0.229, 0.224, 0.225])
30 |
31 | test_transformer = T.Compose([
32 | T.Resize((height, width), interpolation=3),
33 | T.ToTensor(),
34 | normalizer
35 | ])
36 |
37 | test_loader = DataLoader(
38 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
39 | root=dataset.images_dir, transform=test_transformer),
40 | batch_size=batch_size, num_workers=workers,
41 | shuffle=False, pin_memory=True)
42 |
43 | return dataset, test_loader
44 |
45 |
46 | def main():
47 | args = parser.parse_args()
48 |
49 | if args.seed is not None:
50 | random.seed(args.seed)
51 | np.random.seed(args.seed)
52 | torch.manual_seed(args.seed)
53 | cudnn.deterministic = True
54 |
55 | main_worker(args)
56 |
57 |
58 | def main_worker(args):
59 | cudnn.benchmark = True
60 |
61 | log_dir = osp.dirname(args.resume)
62 | sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))
63 | print("==========\nArgs:{}\n==========".format(args))
64 |
65 | # Create data loaders
66 | dataset, test_loader = get_data(args.dataset, args.data_dir, args.height,
67 | args.width, args.batch_size, args.workers)
68 |
69 | # Create model
70 | model = models.create(args.arch, pretrained=False, num_features=args.features, dropout=args.dropout, num_classes=0)
71 | if args.dsbn:
72 | print("==> Load the model with domain-specific BNs")
73 | convert_dsbn(model)
74 |
75 | # Load from checkpoint
76 | checkpoint = load_checkpoint(args.resume)
77 | copy_state_dict(checkpoint['state_dict'], model, strip='module.')
78 |
79 | if args.dsbn:
80 | print("==> Test with {}-domain BNs".format("source" if args.test_source else "target"))
81 | convert_bn(model, use_target=(not args.test_source))
82 |
83 | model.cuda()
84 | model = nn.DataParallel(model)
85 |
86 | # Evaluator
87 | model.eval()
88 | evaluator = Evaluator(model)
89 | print("Test on {}:".format(args.dataset))
90 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True, rerank=args.rerank)
91 | return
92 |
93 | if __name__ == '__main__':
94 | parser = argparse.ArgumentParser(description="Testing the model")
95 | # data
96 | parser.add_argument('-d', '--dataset', type=str, required=True,
97 | choices=datasets.names())
98 | parser.add_argument('-b', '--batch-size', type=int, default=256)
99 | parser.add_argument('-j', '--workers', type=int, default=4)
100 | parser.add_argument('--height', type=int, default=256, help="input height")
101 | parser.add_argument('--width', type=int, default=128, help="input width")
102 | # model
103 | parser.add_argument('-a', '--arch', type=str, default='resnet50',
104 | choices=models.names())
105 | parser.add_argument('--features', type=int, default=0)
106 | parser.add_argument('--dropout', type=float, default=0)
107 | parser.add_argument('--resume', type=str, required=True, metavar='PATH')
108 | # testing configs
109 | parser.add_argument('--rerank', action='store_true',
110 | help="evaluation only")
111 | parser.add_argument('--dsbn', action='store_true',
112 | help="test on the model with domain-specific BN")
113 | parser.add_argument('--test-source', action='store_true',
114 | help="test on the source domain")
115 | parser.add_argument('--seed', type=int, default=1)
116 | # path
117 | working_dir = osp.dirname(osp.abspath(__file__))
118 | parser.add_argument('--data-dir', type=str, metavar='PATH',
119 | default=osp.join(working_dir, 'data'))
120 | main()
121 |
--------------------------------------------------------------------------------
/SpCL-master/figs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/SpCL-master/figs/framework.png
--------------------------------------------------------------------------------
/SpCL-master/figs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/SpCL-master/figs/results.png
--------------------------------------------------------------------------------
/SpCL-master/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
--------------------------------------------------------------------------------
/SpCL-master/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | setup(name='SpCL',
5 | version='1.0.0',
6 | description='Self-paced Contrastive Learning with Hybrid Memory for Domain Adaptive Object Re-ID',
7 | author='Yixiao Ge',
8 | author_email='geyixiao831@gmail.com',
9 | url='https://github.com/yxgeee/SpCL',
10 | install_requires=[
11 | 'numpy', 'torch', 'torchvision',
12 | 'six', 'h5py', 'Pillow', 'scipy',
13 | 'scikit-learn', 'metric-learn', 'faiss_gpu==1.6.3'],
14 | packages=find_packages(),
15 | keywords=[
16 | 'Unsupervised Learning',
17 | 'Unsupervised Domain Adaptation',
18 | 'Contrastive Learning',
19 | 'Object Re-identification'
20 | ])
21 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from . import datasets
4 | from . import evaluation_metrics
5 | from . import models
6 | from . import utils
7 | from . import evaluators
8 | from . import trainers
9 |
10 | __version__ = '0.1.0'
11 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import warnings
3 |
4 | from .market1501 import Market1501
5 | from .msmt17 import MSMT17
6 | from .personx import PersonX
7 | from .veri import VeRi
8 | from .vehicleid import VehicleID
9 | from .vehiclex import VehicleX
10 | from .dukemtmc import DukeMTMC
11 | from .sysumm01 import SYSU_MM01
12 | from .sysumm01_rgb import SYSU_MM01_RGB
13 | from .sysumm01_ir import SYSU_MM01_IR
14 | from .regdb import RegDB
15 | from .regdb_rgb import RegDB_RGB
16 | from .regdb_ir import RegDB_IR
17 |
18 |
19 | __factory = {
20 | 'market1501': Market1501,
21 | 'msmt17': MSMT17,
22 | 'personx': PersonX,
23 | 'veri': VeRi,
24 | 'vehicleid': VehicleID,
25 | 'vehiclex': VehicleX,
26 | 'dukemtmc': DukeMTMC,
27 | 'sysumm01': SYSU_MM01,
28 | 'sysumm01_rgb': SYSU_MM01_RGB,
29 | 'sysumm01_ir': SYSU_MM01_IR,
30 | 'regdb': RegDB,
31 | 'regdb_rgb': RegDB_RGB,
32 | 'regdb_ir': RegDB_IR
33 | }
34 |
35 |
36 | def names():
37 | return sorted(__factory.keys())
38 |
39 |
40 | def create(name, root, *args, **kwargs):
41 | """
42 | Create a dataset instance.
43 |
44 | Parameters
45 | ----------
46 | name : str
47 | The dataset name.
48 | root : str
49 | The path to the dataset directory.
50 | split_id : int, optional
51 | The index of data split. Default: 0
52 | num_val : int or float, optional
53 | When int, it means the number of validation identities. When float,
54 | it means the proportion of validation to all the trainval. Default: 100
55 | download : bool, optional
56 | If True, will download the dataset. Default: False
57 | """
58 | if name not in __factory:
59 | raise KeyError("Unknown dataset:", name)
60 | return __factory[name](root, *args, **kwargs)
61 |
62 |
63 | def get_dataset(name, root, *args, **kwargs):
64 | warnings.warn("get_dataset is deprecated. Use create instead.")
65 | return create(name, root, *args, **kwargs)
66 |
--------------------------------------------------------------------------------
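A brief usage sketch of the factory above (the root path is illustrative and assumes the dataset has been prepared as described in the SpCL README):
```python
from spcl import datasets

print(datasets.names())  # sorted registry keys, e.g. 'market1501', 'regdb_rgb', ...
# root should contain the unzipped dataset, e.g. Market-1501-v15.09.15/
dataset = datasets.create("market1501", "examples/data/market1501")
print(dataset.num_train_pids, len(dataset.train))  # identities / training images
```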
/SpCL-master/spcl/datasets/dukemtmc.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import glob
4 | import re
5 | import urllib
6 | import zipfile
7 |
8 | from ..utils.data import BaseImageDataset
9 | from ..utils.osutils import mkdir_if_missing
10 | from ..utils.serialization import write_json
11 |
12 |
13 | class DukeMTMC(BaseImageDataset):
14 | """
15 | DukeMTMC-reID
16 | Reference:
17 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
18 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
19 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation
20 |
21 | Dataset statistics:
22 | # identities: 1404 (train + query)
23 | # images: 16522 (train) + 2228 (query) + 17661 (gallery)
24 | # cameras: 8
25 | """
26 | dataset_dir = '.'
27 |
28 | def __init__(self, root, verbose=True, **kwargs):
29 | super(DukeMTMC, self).__init__()
30 | self.dataset_dir = osp.join(root, self.dataset_dir)
31 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
32 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
33 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
34 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')
35 |
36 | self._download_data()
37 | self._check_before_run()
38 |
39 | train = self._process_dir(self.train_dir, relabel=True)
40 | query = self._process_dir(self.query_dir, relabel=False)
41 | gallery = self._process_dir(self.gallery_dir, relabel=False)
42 |
43 | if verbose:
44 | print("=> DukeMTMC-reID loaded")
45 | self.print_dataset_statistics(train, query, gallery)
46 |
47 | self.train = train
48 | self.query = query
49 | self.gallery = gallery
50 |
51 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
52 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
53 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
54 |
55 | def _download_data(self):
56 | if osp.exists(self.dataset_dir):
57 | print("This dataset has been downloaded.")
58 | return
59 |
60 | print("Creating directory {}".format(self.dataset_dir))
61 | mkdir_if_missing(self.dataset_dir)
62 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
63 |
64 | print("Downloading DukeMTMC-reID dataset")
65 | urllib.request.urlretrieve(self.dataset_url, fpath)
66 |
67 | print("Extracting files")
68 | zip_ref = zipfile.ZipFile(fpath, 'r')
69 | zip_ref.extractall(self.dataset_dir)
70 | zip_ref.close()
71 |
72 | def _check_before_run(self):
73 | """Check if all files are available before going deeper"""
74 | if not osp.exists(self.dataset_dir):
75 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
76 | if not osp.exists(self.train_dir):
77 | raise RuntimeError("'{}' is not available".format(self.train_dir))
78 | if not osp.exists(self.query_dir):
79 | raise RuntimeError("'{}' is not available".format(self.query_dir))
80 | if not osp.exists(self.gallery_dir):
81 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
82 |
83 | def _process_dir(self, dir_path, relabel=False):
84 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
85 | pattern = re.compile(r'([-\d]+)_c(\d)')
86 |
87 | pid_container = set()
88 | for img_path in img_paths:
89 | pid, _ = map(int, pattern.search(img_path).groups())
90 | pid_container.add(pid)
91 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
92 |
93 | dataset = []
94 | for img_path in img_paths:
95 | pid, camid = map(int, pattern.search(img_path).groups())
96 | assert 1 <= camid <= 8
97 | camid -= 1 # index starts from 0
98 | if relabel: pid = pid2label[pid]
99 | dataset.append((img_path, pid, camid))
100 |
101 | return dataset
102 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/market1501.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import glob
4 | import re
5 | import urllib
6 | import zipfile
7 |
8 | from ..utils.data import BaseImageDataset
9 | from ..utils.osutils import mkdir_if_missing
10 | from ..utils.serialization import write_json
11 |
12 | class Market1501(BaseImageDataset):
13 | """
14 | Market1501
15 | Reference:
16 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
17 | URL: http://www.liangzheng.org/Project/project_reid.html
18 |
19 | Dataset statistics:
20 | # identities: 1501 (+1 for background)
21 | # images: 12936 (train) + 3368 (query) + 15913 (gallery)
22 | """
23 | dataset_dir = 'Market-1501-v15.09.15'
24 |
25 | def __init__(self, root, verbose=True, **kwargs):
26 | super(Market1501, self).__init__()
27 | self.dataset_dir = osp.join(root, self.dataset_dir)
28 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
29 | self.query_dir = osp.join(self.dataset_dir, 'query')
30 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
31 |
32 | self._check_before_run()
33 |
34 | train = self._process_dir(self.train_dir, relabel=True)
35 | query = self._process_dir(self.query_dir, relabel=False)
36 | gallery = self._process_dir(self.gallery_dir, relabel=False)
37 |
38 | if verbose:
39 | print("=> Market1501 loaded")
40 | self.print_dataset_statistics(train, query, gallery)
41 |
42 | self.train = train
43 | self.query = query
44 | self.gallery = gallery
45 |
46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
49 |
50 | def _check_before_run(self):
51 | """Check if all files are available before going deeper"""
52 | if not osp.exists(self.dataset_dir):
53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
54 | if not osp.exists(self.train_dir):
55 | raise RuntimeError("'{}' is not available".format(self.train_dir))
56 | if not osp.exists(self.query_dir):
57 | raise RuntimeError("'{}' is not available".format(self.query_dir))
58 | if not osp.exists(self.gallery_dir):
59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
60 |
61 | def _process_dir(self, dir_path, relabel=False):
62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
63 | pattern = re.compile(r'([-\d]+)_c(\d)')
64 |
65 | pid_container = set()
66 | for img_path in img_paths:
67 | pid, _ = map(int, pattern.search(img_path).groups())
68 | if pid == -1: continue # junk images are just ignored
69 | pid_container.add(pid)
70 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
71 |
72 | dataset = []
73 | for img_path in img_paths:
74 | pid, camid = map(int, pattern.search(img_path).groups())
75 | if pid == -1: continue # junk images are just ignored
76 | assert 0 <= pid <= 1501 # pid == 0 means background
77 | assert 1 <= camid <= 6
78 | camid -= 1 # index starts from 0
79 | if relabel: pid = pid2label[pid]
80 | dataset.append((img_path, pid, camid))
81 |
82 | return dataset
83 |
--------------------------------------------------------------------------------
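To make the parsing in `_process_dir` concrete, a small sketch of the same regex on a typical Market-1501 file name (the file name itself is hypothetical):
```python
import re

pattern = re.compile(r'([-\d]+)_c(\d)')
pid, camid = map(int, pattern.search("0002_c1s1_000451_03.jpg").groups())
print(pid, camid)  # 2 1 -> pid == -1 marks junk images; camid is later shifted to 0-based
```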
/SpCL-master/spcl/datasets/msmt17.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import tarfile
4 |
5 | import glob
6 | import re
7 | import urllib
8 | import zipfile
9 |
10 | from ..utils.osutils import mkdir_if_missing
11 | from ..utils.serialization import write_json
12 |
13 |
14 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')):
15 | with open(list_file, 'r') as f:
16 | lines = f.readlines()
17 | ret = []
18 | pids = []
19 | for line in lines:
20 | line = line.strip()
21 | fname = line.split(' ')[0]
22 | pid, _, cam = map(int, pattern.search(osp.basename(fname)).groups())
23 | if pid not in pids:
24 | pids.append(pid)
25 | ret.append((osp.join(subdir,fname), pid, cam))
26 | return ret, pids
27 |
28 | class Dataset_MSMT(object):
29 | def __init__(self, root):
30 | self.root = root
31 | self.train, self.val, self.trainval = [], [], []
32 | self.query, self.gallery = [], []
33 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0
34 |
35 | @property
36 | def images_dir(self):
37 | return osp.join(self.root, 'MSMT17_V1')
38 |
39 | def load(self, verbose=True):
40 | exdir = osp.join(self.root, 'MSMT17_V1')
41 | self.train, train_pids = _pluck_msmt(osp.join(exdir, 'list_train.txt'), 'train')
42 | self.val, val_pids = _pluck_msmt(osp.join(exdir, 'list_val.txt'), 'train')
43 | self.train = self.train + self.val
44 | self.query, query_pids = _pluck_msmt(osp.join(exdir, 'list_query.txt'), 'test')
45 | self.gallery, gallery_pids = _pluck_msmt(osp.join(exdir, 'list_gallery.txt'), 'test')
46 | self.num_train_pids = len(list(set(train_pids).union(set(val_pids))))
47 |
48 | if verbose:
49 | print(self.__class__.__name__, "dataset loaded")
50 | print(" subset | # ids | # images")
51 | print(" ---------------------------")
52 | print(" train | {:5d} | {:8d}"
53 | .format(self.num_train_pids, len(self.train)))
54 | print(" query | {:5d} | {:8d}"
55 | .format(len(query_pids), len(self.query)))
56 | print(" gallery | {:5d} | {:8d}"
57 | .format(len(gallery_pids), len(self.gallery)))
58 |
59 | class MSMT17(Dataset_MSMT):
60 |
61 | def __init__(self, root, split_id=0, download=True):
62 | super(MSMT17, self).__init__(root)
63 |
64 | if download:
65 | self.download()
66 |
67 | self.load()
68 |
69 | def download(self):
70 |
71 | import re
72 | import hashlib
73 | import shutil
74 | from glob import glob
75 | from zipfile import ZipFile
76 |
77 | raw_dir = osp.join(self.root)
78 | mkdir_if_missing(raw_dir)
79 |
80 | # Download the raw zip file
81 | fpath = osp.join(raw_dir, 'MSMT17_V1')
82 | if osp.isdir(fpath):
83 | print("Using downloaded file: " + fpath)
84 | else:
85 | raise RuntimeError("Please download the dataset manually to {}".format(fpath))
86 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/personx.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import glob
4 | import re
5 | import urllib
6 | import zipfile
7 |
8 | from ..utils.data import BaseImageDataset
9 | from ..utils.osutils import mkdir_if_missing
10 | from ..utils.serialization import write_json
11 |
12 | class PersonX(BaseImageDataset):
13 | """
14 | PersonX
15 | Reference:
16 | Sun et al. Dissecting Person Re-identification from the Viewpoint of Viewpoint. CVPR 2019.
17 |
18 | Dataset statistics:
19 | # identities: 1266
20 | # images: 9840 (train) + 5136 (query) + 30816 (gallery)
21 | """
22 | dataset_dir = 'PersonX'
23 |
24 | def __init__(self, root, verbose=True, **kwargs):
25 | super(PersonX, self).__init__()
26 | self.dataset_dir = osp.join(root, self.dataset_dir)
27 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
28 | self.query_dir = osp.join(self.dataset_dir, 'query')
29 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
30 |
31 | self._check_before_run()
32 |
33 | train = self._process_dir(self.train_dir, relabel=True)
34 | query = self._process_dir(self.query_dir, relabel=False)
35 | gallery = self._process_dir(self.gallery_dir, relabel=False)
36 |
37 | if verbose:
38 | print("=> PersonX loaded")
39 | self.print_dataset_statistics(train, query, gallery)
40 |
41 | self.train = train
42 | self.query = query
43 | self.gallery = gallery
44 |
45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
48 |
49 | def _check_before_run(self):
50 | """Check if all files are available before going deeper"""
51 | if not osp.exists(self.dataset_dir):
52 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
53 | if not osp.exists(self.train_dir):
54 | raise RuntimeError("'{}' is not available".format(self.train_dir))
55 | if not osp.exists(self.query_dir):
56 | raise RuntimeError("'{}' is not available".format(self.query_dir))
57 | if not osp.exists(self.gallery_dir):
58 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
59 |
60 | def _process_dir(self, dir_path, relabel=False):
61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
62 | pattern = re.compile(r'([-\d]+)_c([-\d]+)')
63 | cam2label = {3:1, 4:2, 8:3, 10:4, 11:5, 12:6}
64 |
65 | pid_container = set()
66 | for img_path in img_paths:
67 | pid, _ = map(int, pattern.search(img_path).groups())
68 | pid_container.add(pid)
69 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
70 |
71 | dataset = []
72 | for img_path in img_paths:
73 | pid, camid = map(int, pattern.search(img_path).groups())
74 | assert (camid in cam2label.keys())
75 | camid = cam2label[camid]
76 | camid -= 1 # index starts from 0
77 | if relabel: pid = pid2label[pid]
78 | dataset.append((img_path, pid, camid))
79 |
80 | return dataset
81 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/regdb.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Mar 31 23:02:42 2021
5 |
6 | @author: vision
7 | """
8 |
9 |
10 | from __future__ import print_function, absolute_import
11 | import os.path as osp
12 | import os
13 | import random
14 | from glob import glob
15 | import re
16 | import urllib
17 | import zipfile
18 |
19 | from ..utils.data import BaseImageDataset
20 | from ..utils.osutils import mkdir_if_missing
21 | from ..utils.serialization import write_json
22 |
23 | class RegDB(BaseImageDataset):
24 | dataset_dir = "RegDB"
25 |
26 | def __init__(self, root, verbose=True, ii=1, mode='', **kwargs):
27 | super(RegDB, self).__init__()
28 |
29 | self.dataset_dir = osp.join(root, self.dataset_dir)
30 | self.ii = ii
31 | self.index_train_RGB = self.loadIdx(open((self.dataset_dir+'/idx/train_visible_{}.txt').format(self.ii),'r'))
32 | self.index_train_IR = self.loadIdx(open((self.dataset_dir+'/idx/train_thermal_{}.txt').format(self.ii),'r'))
33 | self.index_test_RGB = self.loadIdx(open((self.dataset_dir+'/idx/test_visible_{}.txt').format(self.ii),'r'))
34 | self.index_test_IR = self.loadIdx(open((self.dataset_dir+'/idx/test_thermal_{}.txt').format(self.ii),'r'))
35 |
36 | self.train = self._process_dir(self.index_train_RGB, 0, 0) + self._process_dir(self.index_train_IR, 1, 0)
37 | if mode == 't2v':
38 | self.query = self._process_dir(self.index_test_IR, 1, 206)
39 | self.gallery = self._process_dir(self.index_test_RGB, 0, 206)
40 | elif mode == 'v2t':
41 | self.query = self._process_dir(self.index_test_RGB, 0, 206)
42 | self.gallery = self._process_dir(self.index_test_IR, 1, 206)
43 |
44 | if verbose:
45 | print("=> RegDB loaded trial:{}".format(ii))
46 | self.print_dataset_statistics(self.train, self.query, self.gallery)
47 |
48 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
49 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
50 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
51 |
52 | def _check_before_run(self):
53 | """Check if all files are available before going deeper"""
54 | if not osp.exists(self.dataset_dir):
55 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
56 | if not osp.exists(self.train_dir):
57 | raise RuntimeError("'{}' is not available".format(self.train_dir))
58 | if not osp.exists(self.val_dir):
59 | raise RuntimeError("'{}' is not available".format(self.val_dir))
60 | if not osp.exists(self.text_dir):
61 | raise RuntimeError("'{}' is not available".format(self.text_dir))
62 |
63 | def loadIdx(self, index):
64 | Lines = index.readlines()
65 | idx = []
66 | for line in Lines:
67 | tmp = line.strip('\n')
68 | tmp = tmp.split(' ')
69 | idx.append(tmp)
70 | return idx
71 |
72 | def _process_dir(self, index, cam, delta):
73 | dataset = []
74 | for idx in index:
75 | fname = osp.join(self.dataset_dir, idx[0])
76 | pid = int(idx[1]) + delta
77 | dataset.append((fname, pid, cam))
78 | return dataset
79 |
80 |
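81 | # Usage sketch (the root path and trial index below are illustrative): RegDB is
82 | # conventionally evaluated over ten predefined trial splits, selected here via
83 | # `ii`, and `mode` picks the search direction (thermal-to-visible or
84 | # visible-to-thermal). The +206 delta in _process_dir appears to offset test
85 | # identities past the 206 training identities.
86 | #
87 | #   dataset = RegDB(root='/path/to/data', ii=1, mode='t2v')
88 | #   fname, pid, cam = dataset.train[0]   # cam: 0 = visible (RGB), 1 = thermal (IR)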
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/regdb_ir.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Mar 31 23:02:42 2021
5 |
6 | @author: vision
7 | """
8 |
9 |
10 | from __future__ import print_function, absolute_import
11 | import os.path as osp
12 | import os
13 | import random
14 | from glob import glob
15 | import re
16 | import urllib
17 | import zipfile
18 |
19 | from ..utils.data import BaseImageDataset
20 | from ..utils.osutils import mkdir_if_missing
21 | from ..utils.serialization import write_json
22 |
23 | class RegDB_IR(BaseImageDataset):
24 | dataset_dir = "RegDB"
25 |
26 | def __init__(self, root, verbose=True, ii=1, mode='', **kwargs):
27 | super(RegDB_IR, self).__init__()
28 |
29 | self.dataset_dir = osp.join(root, self.dataset_dir)
30 | self.ii = ii
31 | self.index_train_RGB = self.loadIdx(open((self.dataset_dir+'/idx/train_visible_{}.txt').format(self.ii),'r'))
32 | self.index_train_IR = self.loadIdx(open((self.dataset_dir+'/idx/train_thermal_{}.txt').format(self.ii),'r'))
33 | self.index_test_RGB = self.loadIdx(open((self.dataset_dir+'/idx/test_visible_{}.txt').format(self.ii),'r'))
34 | self.index_test_IR = self.loadIdx(open((self.dataset_dir+'/idx/test_thermal_{}.txt').format(self.ii),'r'))
35 |
36 | self.train = self._process_dir(self.index_train_IR, 1, 0)
37 | if mode == 't2v':
38 | self.query = self._process_dir(self.index_test_IR, 1, 206)
39 | self.gallery = self._process_dir(self.index_test_RGB, 0, 206)
40 | elif mode == 'v2t':
41 | self.query = self._process_dir(self.index_test_RGB, 0, 206)
42 | self.gallery = self._process_dir(self.index_test_IR, 1, 206)
43 |
44 | if verbose:
45 | print("=> RegDB IR loaded trial:{}".format(ii))
46 | self.print_dataset_statistics(self.train, self.query, self.gallery)
47 |
48 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
49 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
50 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
51 |
52 |     def _check_before_run(self):
53 |         """Check if all files are available before going deeper"""
54 |         # Note: __init__ never calls this check; the original version referenced
55 |         # train_dir/val_dir/text_dir attributes this class never defines (apparently
56 |         # copied from the SYSU-MM01 loader), so it checks RegDB's real paths instead.
57 |         if not osp.exists(self.dataset_dir):
58 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
59 |         idx_dir = osp.join(self.dataset_dir, 'idx')
60 |         if not osp.exists(idx_dir):
61 |             raise RuntimeError("'{}' is not available".format(idx_dir))
62 |
63 | def loadIdx(self, index):
64 | Lines = index.readlines()
65 | idx = []
66 | for line in Lines:
67 | tmp = line.strip('\n')
68 | tmp = tmp.split(' ')
69 | idx.append(tmp)
70 | return idx
71 |
72 | def _process_dir(self, index, cam, delta):
73 | dataset = []
74 | for idx in index:
75 | fname = osp.join(self.dataset_dir, idx[0])
76 | pid = int(idx[1]) + delta
77 | dataset.append((fname, pid, cam))
78 | return dataset
79 |
80 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/regdb_rgb.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Mar 31 23:02:42 2021
5 |
6 | @author: vision
7 | """
8 |
9 |
10 | from __future__ import print_function, absolute_import
11 | import os.path as osp
12 | import os
13 | import random
14 | from glob import glob
15 | import re
16 | import urllib
17 | import zipfile
18 |
19 | from ..utils.data import BaseImageDataset
20 | from ..utils.osutils import mkdir_if_missing
21 | from ..utils.serialization import write_json
22 |
23 | class RegDB_RGB(BaseImageDataset):
24 | dataset_dir = "RegDB"
25 |
26 | def __init__(self, root, verbose=True, ii=1, mode='t2v', **kwargs):
27 | super(RegDB_RGB, self).__init__()
28 |
29 | self.dataset_dir = osp.join(root, self.dataset_dir)
30 | self.ii = ii
31 | self.index_train_RGB = self.loadIdx(open((self.dataset_dir+'/idx/train_visible_{}.txt').format(self.ii),'r'))
32 | self.index_train_IR = self.loadIdx(open((self.dataset_dir+'/idx/train_thermal_{}.txt').format(self.ii),'r'))
33 | self.index_test_RGB = self.loadIdx(open((self.dataset_dir+'/idx/test_visible_{}.txt').format(self.ii),'r'))
34 | self.index_test_IR = self.loadIdx(open((self.dataset_dir+'/idx/test_thermal_{}.txt').format(self.ii),'r'))
35 |
36 | self.train = self._process_dir(self.index_train_RGB, 0, 0)
37 | if mode == 't2v':
38 | self.query = self._process_dir(self.index_test_IR, 1, 206)
39 | self.gallery = self._process_dir(self.index_test_RGB, 0, 206)
40 | elif mode == 'v2t':
41 | self.query = self._process_dir(self.index_test_RGB, 0, 206)
42 | self.gallery = self._process_dir(self.index_test_IR, 1, 206)
43 |
44 | if verbose:
45 | print("=> RegDB RGB loaded trial:{}".format(ii))
46 | self.print_dataset_statistics(self.train, self.query, self.gallery)
47 |
48 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
49 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
50 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
51 |
52 |     def _check_before_run(self):
53 |         """Check if all files are available before going deeper"""
54 |         # Note: __init__ never calls this check; the original version referenced
55 |         # train_dir/val_dir/text_dir attributes this class never defines (apparently
56 |         # copied from the SYSU-MM01 loader), so it checks RegDB's real paths instead.
57 |         if not osp.exists(self.dataset_dir):
58 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
59 |         idx_dir = osp.join(self.dataset_dir, 'idx')
60 |         if not osp.exists(idx_dir):
61 |             raise RuntimeError("'{}' is not available".format(idx_dir))
62 |
63 | def loadIdx(self, index):
64 | Lines = index.readlines()
65 | idx = []
66 | for line in Lines:
67 | tmp = line.strip('\n')
68 | tmp = tmp.split(' ')
69 | idx.append(tmp)
70 | return idx
71 |
72 | def _process_dir(self, index, cam, delta):
73 | dataset = []
74 | for idx in index:
75 | fname = osp.join(self.dataset_dir, idx[0])
76 | pid = int(idx[1]) + delta
77 | dataset.append((fname, pid, cam))
78 | return dataset
79 |
80 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/sysumm01.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Mar 20 20:21:24 2021
5 |
6 | @author: vision
7 | """
8 |
9 |
10 | from __future__ import print_function, absolute_import
11 | import os.path as osp
12 | import re
13 | import random
14 | import numpy as np
15 | from glob import glob
16 |
17 | from ..utils.data import BaseImageDataset
18 |
19 |
20 | class SYSU_MM01(BaseImageDataset):
21 | dataset_dir = "SYSU-MM01"
22 |
23 | def __init__(self, root='', verbose=True, pid_begin=0, mode='all', **kwargs):
24 | super(SYSU_MM01, self).__init__()
25 |
26 | self.pid_begin = pid_begin
27 | self.dataset_dir = osp.join(root, self.dataset_dir)
28 | self.train_dir = osp.join(self.dataset_dir, 'exp/train_id.txt')
29 | self.val_dir = osp.join(self.dataset_dir, 'exp/val_id.txt')
30 | self.text_dir = osp.join(self.dataset_dir, 'exp/test_id.txt')
31 |
32 | self._check_before_run()
33 |
34 | self.train_id = self._get_id(self.train_dir) + self._get_id(self.val_dir)
35 | self.query_id = self._get_id(self.text_dir)
36 | self.gallery_id = self.query_id
37 |
38 | self.rgb_cams = ['cam1', 'cam2', 'cam4', 'cam5']
39 | self.ir_cams = ['cam3', 'cam6']
40 | self.train = self._process_dir(self.train_id, self.rgb_cams + self.ir_cams)
41 | self.query = self._process_dir(self.query_id, self.ir_cams)
42 | if mode == 'all':
43 | # self.gallery = self._process_dir(self.gallery_id, self.rgb_cams)
44 | self.gallery = self._process_dir_gallery(self.gallery_id, self.rgb_cams)
45 | elif mode == 'indoor':
46 | # self.gallery = self._process_dir(self.gallery_id, ['cam1', 'cam2'])
47 | self.gallery = self._process_dir_gallery(self.gallery_id, ['cam1', 'cam2'])
48 |
49 | if verbose:
50 | print("=> SYSU-MM01 loaded")
51 | self.print_dataset_statistics(self.train, self.query, self.gallery)
52 |
53 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
54 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
55 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
56 |
57 | def _check_before_run(self):
58 | """Check if all files are available before going deeper"""
59 | if not osp.exists(self.dataset_dir):
60 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
61 | if not osp.exists(self.train_dir):
62 | raise RuntimeError("'{}' is not available".format(self.train_dir))
63 | if not osp.exists(self.val_dir):
64 | raise RuntimeError("'{}' is not available".format(self.val_dir))
65 | if not osp.exists(self.text_dir):
66 | raise RuntimeError("'{}' is not available".format(self.text_dir))
67 |
68 | def _get_id(self, file_path):
69 | with open(file_path, 'r') as f:
70 | ids = f.read().splitlines()
71 | ids = [int(y) for y in ids[0].split(',')]
72 | ids = ["%04d" % x for x in ids]
73 | return ids
74 |
75 | def _process_dir(self, ids, cams):
76 | ids_container = list(np.unique(ids))
77 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
78 |
79 | dataset = []
80 | for id_ in sorted(ids):
81 | for cam in cams:
82 | img_dir = osp.join(self.dataset_dir, cam, id_)
83 | if osp.isdir(img_dir):
84 | img_list = glob(osp.join(img_dir, "*.jpg"))
85 | img_list.sort()
86 | for img_path in img_list:
87 | dataset.append((img_path, self.pid_begin + id2label[id_], int(cam[-1])-1))
88 | return dataset
89 |
90 | def _process_dir_gallery(self, ids, cams):
91 | ids_container = list(np.unique(ids))
92 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
93 |
94 | dataset = []
95 | for id_ in sorted(ids):
96 | for cam in cams:
97 | img_dir = osp.join(self.dataset_dir, cam, id_)
98 | if osp.isdir(img_dir):
99 | img_list = glob(osp.join(img_dir, "*.jpg"))
100 | img_list.sort()
101 | dataset.append((random.choice(img_list), self.pid_begin + id2label[id_], int(cam[-1])-1))
102 | return dataset
103 |
104 | # def _process_train(self, train_path):
105 | # data = []
106 |
107 | # file_path_list = ['cam1', 'cam2', 'cam4', 'cam5']
108 |
109 | # for file_path in file_path_list:
110 | # camid = self.dataset_name + "_" + file_path
111 | # pid_list = os.listdir(os.path.join(train_path, file_path))
112 | # for pid_dir in pid_list:
113 | # pid = self.dataset_name + "_" + pid_dir
114 | # img_list = glob(os.path.join(train_path, file_path, pid_dir, "*.jpg"))
115 | # for img_path in img_list:
116 | # data.append([img_path, pid, camid])
117 | # return data
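118 | # Usage sketch (the root path is a placeholder): `mode` selects the standard
119 | # SYSU-MM01 search settings, 'all' (gallery from all four RGB cams) or
120 | # 'indoor' (cam1/cam2 only). _process_dir_gallery draws one random image per
121 | # (identity, camera), i.e. a single-shot gallery, so the gallery composition
122 | # changes across runs unless `random` is seeded beforehand.
123 | #
124 | #   dataset = SYSU_MM01(root='/path/to/data', mode='all')
125 | #   img_path, pid, camid = dataset.query[0]   # queries come from IR cams 3 and 6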
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/sysumm01_ir.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import re
4 | import random
5 | import numpy as np
6 | from glob import glob
7 |
8 | from ..utils.data import BaseImageDataset
9 |
10 |
11 | class SYSU_MM01_IR(BaseImageDataset):
12 | dataset_dir = "SYSU-MM01"
13 |
14 | def __init__(self, root='', verbose=True, ncl=1, mode='all', **kwargs):
15 | super(SYSU_MM01_IR, self).__init__()
16 |
17 | self.dataset_dir = osp.join(root, self.dataset_dir)
18 | self.train_dir = osp.join(self.dataset_dir, 'exp/train_id.txt')
19 | self.val_dir = osp.join(self.dataset_dir, 'exp/val_id.txt')
20 | self.text_dir = osp.join(self.dataset_dir, 'exp/test_id.txt')
21 |
22 | self._check_before_run()
23 |
24 | self.train_id = self._get_id(self.train_dir) + self._get_id(self.val_dir)
25 | self.query_id = self._get_id(self.text_dir)
26 | self.gallery_id = self.query_id
27 |
28 | self.rgb_cams = ['cam1', 'cam2', 'cam4', 'cam5']
29 | self.ir_cams = ['cam3', 'cam6']
30 | self.train = self._process_dir(self.train_id, self.ir_cams)
31 | self.query = self._process_dir(self.query_id, self.ir_cams)
32 | if mode == 'all':
33 | # self.gallery = self._process_dir(self.gallery_id, self.rgb_cams)
34 | self.gallery = self._process_dir_gallery(self.gallery_id, self.rgb_cams)
35 | elif mode == 'indoor':
36 | # self.gallery = self._process_dir(self.gallery_id, ['cam1', 'cam2'])
37 | self.gallery = self._process_dir_gallery(self.gallery_id, ['cam1', 'cam2'])
38 |
39 | if verbose:
40 | print("=> SYSU-MM01 IR loaded")
41 | self.print_dataset_statistics(self.train, self.query, self.gallery)
42 |
43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
46 |
47 | def _check_before_run(self):
48 | """Check if all files are available before going deeper"""
49 | if not osp.exists(self.dataset_dir):
50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
51 | if not osp.exists(self.train_dir):
52 | raise RuntimeError("'{}' is not available".format(self.train_dir))
53 | if not osp.exists(self.val_dir):
54 | raise RuntimeError("'{}' is not available".format(self.val_dir))
55 | if not osp.exists(self.text_dir):
56 | raise RuntimeError("'{}' is not available".format(self.text_dir))
57 |
58 | def _get_id(self, file_path):
59 | with open(file_path, 'r') as f:
60 | ids = f.read().splitlines()
61 | ids = [int(y) for y in ids[0].split(',')]
62 | ids = ["%04d" % x for x in ids]
63 | return ids
64 |
65 | def _process_dir(self, ids, cams):
66 | ids_container = list(np.unique(ids))
67 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
68 |
69 | dataset = []
70 | for id_ in sorted(ids):
71 | for cam in cams:
72 | img_dir = osp.join(self.dataset_dir, cam, id_)
73 | if osp.isdir(img_dir):
74 | img_list = glob(osp.join(img_dir, "*.jpg"))
75 | img_list.sort()
76 | for img_path in img_list:
77 | dataset.append((img_path, id2label[id_], int(cam[-1]) - 1))
78 | return dataset
79 |
80 | def _process_dir_gallery(self, ids, cams):
81 | ids_container = list(np.unique(ids))
82 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
83 |
84 | dataset = []
85 | for id_ in sorted(ids):
86 | for cam in cams:
87 | img_dir = osp.join(self.dataset_dir, cam, id_)
88 | if osp.isdir(img_dir):
89 | img_list = glob(osp.join(img_dir, "*.jpg"))
90 | img_list.sort()
91 | dataset.append((random.choice(img_list), id2label[id_], int(cam[-1]) - 1))
92 | return dataset
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/sysumm01_rgb.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import re
4 | import random
5 | import numpy as np
6 | from glob import glob
7 |
8 | from ..utils.data import BaseImageDataset
9 |
10 |
11 | class SYSU_MM01_RGB(BaseImageDataset):
12 | dataset_dir = "SYSU-MM01"
13 |
14 | def __init__(self, root='', verbose=True, ncl=1, mode='all', **kwargs):
15 | super(SYSU_MM01_RGB, self).__init__()
16 |
17 | self.dataset_dir = osp.join(root, self.dataset_dir)
18 | self.train_dir = osp.join(self.dataset_dir, 'exp/train_id.txt')
19 | self.val_dir = osp.join(self.dataset_dir, 'exp/val_id.txt')
20 | self.text_dir = osp.join(self.dataset_dir, 'exp/test_id.txt')
21 |
22 | self._check_before_run()
23 |
24 | self.train_id = self._get_id(self.train_dir) + self._get_id(self.val_dir)
25 | self.query_id = self._get_id(self.text_dir)
26 | self.gallery_id = self.query_id
27 |
28 | self.rgb_cams = ['cam1', 'cam2', 'cam4', 'cam5']
29 | self.ir_cams = ['cam3', 'cam6']
30 | self.train = self._process_dir(self.train_id, self.rgb_cams)
31 | self.query = self._process_dir(self.query_id, self.ir_cams)
32 | if mode == 'all':
33 | # self.gallery = self._process_dir(self.gallery_id, self.rgb_cams)
34 | self.gallery = self._process_dir_gallery(self.gallery_id, self.rgb_cams)
35 | elif mode == 'indoor':
36 | # self.gallery = self._process_dir(self.gallery_id, ['cam1', 'cam2'])
37 | self.gallery = self._process_dir_gallery(self.gallery_id, ['cam1', 'cam2'])
38 |
39 | if verbose:
40 | print("=> SYSU-MM01 RGB loaded")
41 | self.print_dataset_statistics(self.train, self.query, self.gallery)
42 |
43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
46 |
47 | def _check_before_run(self):
48 | """Check if all files are available before going deeper"""
49 | if not osp.exists(self.dataset_dir):
50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
51 | if not osp.exists(self.train_dir):
52 | raise RuntimeError("'{}' is not available".format(self.train_dir))
53 | if not osp.exists(self.val_dir):
54 | raise RuntimeError("'{}' is not available".format(self.val_dir))
55 | if not osp.exists(self.text_dir):
56 | raise RuntimeError("'{}' is not available".format(self.text_dir))
57 |
58 | def _get_id(self, file_path):
59 | with open(file_path, 'r') as f:
60 | ids = f.read().splitlines()
61 | ids = [int(y) for y in ids[0].split(',')]
62 | ids = ["%04d" % x for x in ids]
63 | return ids
64 |
65 | def _process_dir(self, ids, cams):
66 | ids_container = list(np.unique(ids))
67 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
68 |
69 | dataset = []
70 | for id_ in sorted(ids):
71 | for cam in cams:
72 | img_dir = osp.join(self.dataset_dir, cam, id_)
73 | if osp.isdir(img_dir):
74 | img_list = glob(osp.join(img_dir, "*.jpg"))
75 | img_list.sort()
76 | for img_path in img_list:
77 | dataset.append((img_path, id2label[id_], int(cam[-1]) - 1))
78 | return dataset
79 |
80 | def _process_dir_gallery(self, ids, cams):
81 | ids_container = list(np.unique(ids))
82 | id2label = {id_: label for label, id_ in enumerate(ids_container)}
83 |
84 | dataset = []
85 | for id_ in sorted(ids):
86 | for cam in cams:
87 | img_dir = osp.join(self.dataset_dir, cam, id_)
88 | if osp.isdir(img_dir):
89 | img_list = glob(osp.join(img_dir, "*.jpg"))
90 | img_list.sort()
91 | dataset.append((random.choice(img_list), id2label[id_], int(cam[-1]) - 1))
92 | return dataset
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/vehicleid.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import random
6 | import os.path as osp
7 |
8 | from ..utils.data import BaseImageDataset
9 | from collections import defaultdict
10 |
11 |
12 | class VehicleID(BaseImageDataset):
13 | """
14 | VehicleID
15 | Reference:
16 | Deep Relative Distance Learning: Tell the Difference Between Similar Vehicles
17 |
18 | Dataset statistics:
19 | # train_list: 13164 vehicles for model training
20 |     # test_list_800: 800 vehicles for model testing (small test set in paper)
21 |     # test_list_1600: 1600 vehicles for model testing (medium test set in paper)
22 |     # test_list_2400: 2400 vehicles for model testing (large test set in paper)
23 | # test_list_3200: 3200 vehicles for model testing
24 | # test_list_6000: 6000 vehicles for model testing
25 | # test_list_13164: 13164 vehicles for model testing
26 | """
27 | dataset_dir = 'VehicleID'
28 |
29 | def __init__(self, root, verbose=True, test_size=800, **kwargs):
30 | super(VehicleID, self).__init__()
31 | self.dataset_dir = osp.join(root, self.dataset_dir)
32 | self.img_dir = osp.join(self.dataset_dir, 'image')
33 | self.split_dir = osp.join(self.dataset_dir, 'train_test_split')
34 | self.train_list = osp.join(self.split_dir, 'train_list.txt')
35 | self.test_size = test_size
36 |
37 | if self.test_size == 800:
38 | self.test_list = osp.join(self.split_dir, 'test_list_800.txt')
39 | elif self.test_size == 1600:
40 | self.test_list = osp.join(self.split_dir, 'test_list_1600.txt')
41 | elif self.test_size == 2400:
42 | self.test_list = osp.join(self.split_dir, 'test_list_2400.txt')
43 |
44 | print(self.test_list)
45 |
46 | self.check_before_run()
47 |
48 | train, query, gallery = self.process_split(relabel=True)
49 | self.train = train
50 | self.query = query
51 | self.gallery = gallery
52 |
53 | if verbose:
54 | print('=> VehicleID loaded')
55 | self.print_dataset_statistics(train, query, gallery)
56 |
57 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
58 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
59 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
60 |
61 | def check_before_run(self):
62 | """Check if all files are available before going deeper"""
63 | if not osp.exists(self.dataset_dir):
64 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir))
65 | if not osp.exists(self.split_dir):
66 | raise RuntimeError('"{}" is not available'.format(self.split_dir))
67 | if not osp.exists(self.train_list):
68 | raise RuntimeError('"{}" is not available'.format(self.train_list))
69 | if self.test_size not in [800, 1600, 2400]:
70 | raise RuntimeError('"{}" is not available'.format(self.test_size))
71 | if not osp.exists(self.test_list):
72 | raise RuntimeError('"{}" is not available'.format(self.test_list))
73 |
74 | def get_pid2label(self, pids):
75 | pid_container = set(pids)
76 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
77 | return pid2label
78 |
79 | def parse_img_pids(self, nl_pairs, pid2label=None):
80 |         # nl_pairs is a list of (image name, label) pairs
81 | output = []
82 | for info in nl_pairs:
83 | name = info[0]
84 | pid = info[1]
85 | if pid2label is not None:
86 | pid = pid2label[pid]
87 |             camid = 0  # no camera information; use 0 for all
88 | img_path = osp.join(self.img_dir, name+'.jpg')
89 | output.append((img_path, pid, camid))
90 | return output
91 |
92 | def process_split(self, relabel=False):
93 | # read train paths
94 | train_pid_dict = defaultdict(list)
95 |
96 |         # 'train_list.txt' format: each line is "<image name> <vehicle id>";
97 |         # the first field is the image name
98 |         # and the second field is the vehicle id
99 | with open(self.train_list) as f_train:
100 | train_data = f_train.readlines()
101 | for data in train_data:
102 | name, pid = data.strip().split(' ')
103 | pid = int(pid)
104 | train_pid_dict[pid].append([name, pid])
105 | train_pids = list(train_pid_dict.keys())
106 | num_train_pids = len(train_pids)
107 | assert num_train_pids == 13164, 'There should be 13164 vehicles for training,' \
108 |                                         ' but got {}, please check the data'\
109 | .format(num_train_pids)
110 | # print('num of train ids: {}'.format(num_train_pids))
111 | test_pid_dict = defaultdict(list)
112 | with open(self.test_list) as f_test:
113 | test_data = f_test.readlines()
114 | for data in test_data:
115 |                 name, pid = data.strip().split(' ')  # strip newline, matching the train-list parsing
116 | pid = int(pid)
117 | test_pid_dict[pid].append([name, pid])
118 | test_pids = list(test_pid_dict.keys())
119 | num_test_pids = len(test_pids)
120 | assert num_test_pids == self.test_size, 'There should be {} vehicles for testing,' \
121 |                                                ' but got {}, please check the data'\
122 | .format(self.test_size, num_test_pids)
123 |
124 | train_data = []
125 | query_data = []
126 | gallery_data = []
127 |
128 | # for train ids, all images are used in the train set.
129 | for pid in train_pids:
130 | imginfo = train_pid_dict[pid] # imginfo include image name and id
131 | train_data.extend(imginfo)
132 |
133 | # for each test id, random choose one image for gallery
134 | # and the other ones for query.
135 | for pid in test_pids:
136 | imginfo = test_pid_dict[pid]
137 | sample = random.choice(imginfo)
138 | imginfo.remove(sample)
139 | query_data.extend(imginfo)
140 | gallery_data.append(sample)
141 |
142 | if relabel:
143 | train_pid2label = self.get_pid2label(train_pids)
144 | else:
145 | train_pid2label = None
146 | # for key, value in train_pid2label.items():
147 | # print('{key}:{value}'.format(key=key, value=value))
148 |
149 | train = self.parse_img_pids(train_data, train_pid2label)
150 | query = self.parse_img_pids(query_data)
151 | gallery = self.parse_img_pids(gallery_data)
152 | return train, query, gallery
153 |
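154 | # Usage sketch (root is a placeholder): `test_size` selects one of the
155 | # official 800/1600/2400 test lists. process_split() draws one random gallery
156 | # image per test id, so seed `random` first for a reproducible split:
157 | #
158 | #   import random
159 | #   random.seed(0)
160 | #   dataset = VehicleID(root='/path/to/data', test_size=800)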
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/vehiclex.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import glob
6 | import re
7 | import os.path as osp
8 |
9 | from ..utils.data import BaseDataset
10 |
11 |
12 | class VehicleX(BaseDataset):
13 | """
14 |     VehicleX
15 | Reference:
16 | PAMTRI: Pose-Aware Multi-Task Learning for Vehicle Re-Identification Using Highly Randomized Synthetic Data. In: ICCV 2019
17 | """
18 | dataset_dir = 'AIC20_ReID_Simulation'
19 |
20 | def __init__(self, root, verbose=True, **kwargs):
21 | super(VehicleX, self).__init__()
22 | self.dataset_dir = osp.join(root, self.dataset_dir)
23 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
24 |
25 | self.check_before_run()
26 |
27 | train = self.process_dir(self.train_dir, relabel=True)
28 |
29 | if verbose:
30 | print('=> VehicleX loaded')
31 | self.print_dataset_statistics(train)
32 |
33 | self.train = train
34 |
35 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
36 |
37 | def check_before_run(self):
38 | """Check if all files are available before going deeper"""
39 | if not osp.exists(self.dataset_dir):
40 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir))
41 | if not osp.exists(self.train_dir):
42 | raise RuntimeError('"{}" is not available'.format(self.train_dir))
43 |
44 | def process_dir(self, dir_path, relabel=False):
45 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
46 | pattern = re.compile(r'([-\d]+)_c([-\d]+)')
47 |
48 | pid_container = set()
49 | for img_path in img_paths:
50 | pid, _ = map(int, pattern.search(img_path).groups())
51 | if pid == -1:
52 | continue # junk images are just ignored
53 | pid_container.add(pid)
54 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
55 |
56 | dataset = []
57 | for img_path in img_paths:
58 | pid, camid = map(int, pattern.search(img_path).groups())
59 | if pid == -1:
60 | continue # junk images are just ignored
61 | assert 1 <= pid <= 1362
62 | assert 6 <= camid <= 36
63 | camid -= 6 # index starts from 0
64 | if relabel:
65 | pid = pid2label[pid]
66 | dataset.append((img_path, pid, camid))
67 | return dataset
68 |
69 | def print_dataset_statistics(self, train):
70 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
71 |
72 | print("Dataset statistics:")
73 | print(" ----------------------------------------")
74 | print(" subset | # ids | # images | # cameras")
75 | print(" ----------------------------------------")
76 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
77 | print(" ----------------------------------------")
78 |
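79 | # Worked example of the filename convention process_dir assumes (the sample
80 | # name is hypothetical): r'([-\d]+)_c([-\d]+)' pulls the vehicle id and camera
81 | # id out of names like '0001_c006_00001.jpg':
82 | #
83 | #   >>> re.compile(r'([-\d]+)_c([-\d]+)').search('0001_c006_00001.jpg').groups()
84 | #   ('0001', '006')   # -> pid=1, camid=6; camid -= 6 makes it 0-based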
--------------------------------------------------------------------------------
/SpCL-master/spcl/datasets/veri.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import glob
6 | import re
7 | import os.path as osp
8 |
9 | from ..utils.data import BaseImageDataset
10 |
11 |
12 | class VeRi(BaseImageDataset):
13 | """
14 | VeRi
15 | Reference:
16 |         Liu, X., Liu, W., Ma, H., Fu, H.: Large-scale vehicle re-identification in urban surveillance videos. In: IEEE
17 |         International Conference on Multimedia and Expo (2016).
18 | Dataset statistics:
19 |     # identities: 776 vehicles (576 for training and 200 for testing)
20 | # images: 37778 (train) + 11579 (query)
21 | """
22 | dataset_dir = 'VeRi'
23 |
24 | def __init__(self, root, verbose=True, **kwargs):
25 | super(VeRi, self).__init__()
26 | self.dataset_dir = osp.join(root, self.dataset_dir)
27 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
28 | self.query_dir = osp.join(self.dataset_dir, 'image_query')
29 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
30 |
31 | self.check_before_run()
32 |
33 | train = self.process_dir(self.train_dir, relabel=True)
34 | query = self.process_dir(self.query_dir, relabel=False)
35 | gallery = self.process_dir(self.gallery_dir, relabel=False)
36 |
37 | if verbose:
38 | print('=> VeRi loaded')
39 | self.print_dataset_statistics(train, query, gallery)
40 |
41 | self.train = train
42 | self.query = query
43 | self.gallery = gallery
44 |
45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
48 |
49 | def check_before_run(self):
50 | """Check if all files are available before going deeper"""
51 | if not osp.exists(self.dataset_dir):
52 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir))
53 | if not osp.exists(self.train_dir):
54 | raise RuntimeError('"{}" is not available'.format(self.train_dir))
55 | if not osp.exists(self.query_dir):
56 | raise RuntimeError('"{}" is not available'.format(self.query_dir))
57 | if not osp.exists(self.gallery_dir):
58 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir))
59 |
60 | def process_dir(self, dir_path, relabel=False):
61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
62 | pattern = re.compile(r'([-\d]+)_c([-\d]+)')
63 |
64 | pid_container = set()
65 | for img_path in img_paths:
66 | pid, _ = map(int, pattern.search(img_path).groups())
67 | if pid == -1:
68 | continue # junk images are just ignored
69 | pid_container.add(pid)
70 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
71 |
72 | dataset = []
73 | for img_path in img_paths:
74 | pid, camid = map(int, pattern.search(img_path).groups())
75 | if pid == -1:
76 | continue # junk images are just ignored
77 | assert 0 <= pid <= 776 # pid == 0 means background
78 | assert 1 <= camid <= 20
79 | camid -= 1 # index starts from 0
80 | if relabel:
81 | pid = pid2label[pid]
82 | dataset.append((img_path, pid, camid))
83 |
84 | return dataset
85 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/evaluation_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .classification import accuracy
4 | from .ranking import cmc, mean_ap
5 |
6 | __all__ = [
7 | 'accuracy',
8 | 'cmc',
9 | 'mean_ap'
10 | ]
11 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/evaluation_metrics/classification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from ..utils import to_torch
5 |
6 |
7 | def accuracy(output, target, topk=(1,)):
8 | with torch.no_grad():
9 | output, target = to_torch(output), to_torch(target)
10 | maxk = max(topk)
11 | batch_size = target.size(0)
12 |
13 | _, pred = output.topk(maxk, 1, True, True)
14 | pred = pred.t()
15 | correct = pred.eq(target.view(1, -1).expand_as(pred))
16 |
17 | ret = []
18 | for k in topk:
19 |             correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)  # reshape: correct[:k] may be non-contiguous
20 | ret.append(correct_k.mul_(1. / batch_size))
21 | return ret
22 |
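23 | # Worked example (the numbers are illustrative):
24 | #
25 | #   >>> output = torch.tensor([[0.1, 2.0, 0.3],
26 | #   ...                        [1.5, 0.2, 0.1]])   # 2 samples, 3 classes
27 | #   >>> target = torch.tensor([1, 0])
28 | #   >>> accuracy(output, target, topk=(1, 2))
29 | #   [tensor([1.]), tensor([1.])]   # both samples are top-1 hits, hence top-2 as well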
--------------------------------------------------------------------------------
/SpCL-master/spcl/evaluation_metrics/ranking.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | from sklearn.metrics import average_precision_score
6 |
7 | from ..utils import to_numpy
8 |
9 |
10 | def _unique_sample(ids_dict, num):
11 |     mask = np.zeros(num, dtype=bool)  # np.bool was removed in NumPy 1.24
12 | for _, indices in ids_dict.items():
13 | i = np.random.choice(indices)
14 | mask[i] = True
15 | return mask
16 |
17 |
18 | def cmc(distmat, query_ids=None, gallery_ids=None,
19 | query_cams=None, gallery_cams=None, topk=100,
20 | separate_camera_set=False,
21 | single_gallery_shot=False,
22 | first_match_break=False):
23 | distmat = to_numpy(distmat)
24 | m, n = distmat.shape
25 | # Fill up default values
26 | if query_ids is None:
27 | query_ids = np.arange(m)
28 | if gallery_ids is None:
29 | gallery_ids = np.arange(n)
30 | if query_cams is None:
31 | query_cams = np.zeros(m).astype(np.int32)
32 | if gallery_cams is None:
33 | gallery_cams = np.ones(n).astype(np.int32)
34 | # Ensure numpy array
35 | query_ids = np.asarray(query_ids)
36 | gallery_ids = np.asarray(gallery_ids)
37 | query_cams = np.asarray(query_cams)
38 | gallery_cams = np.asarray(gallery_cams)
39 | # Sort and find correct matches
40 | indices = np.argsort(distmat, axis=1)
41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
42 | # Compute CMC for each query
43 | ret = np.zeros(topk)
44 | num_valid_queries = 0
45 | for i in range(m):
46 | # Filter out the same id and same camera
47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) |
48 | (gallery_cams[indices[i]] != query_cams[i]))
49 | if separate_camera_set:
50 | # Filter out samples from same camera
51 | valid &= (gallery_cams[indices[i]] != query_cams[i])
52 | if not np.any(matches[i, valid]): continue
53 | if single_gallery_shot:
54 | repeat = 10
55 | gids = gallery_ids[indices[i][valid]]
56 | inds = np.where(valid)[0]
57 | ids_dict = defaultdict(list)
58 | for j, x in zip(inds, gids):
59 | ids_dict[x].append(j)
60 | else:
61 | repeat = 1
62 | for _ in range(repeat):
63 | if single_gallery_shot:
64 | # Randomly choose one instance for each id
65 | sampled = (valid & _unique_sample(ids_dict, len(valid)))
66 | index = np.nonzero(matches[i, sampled])[0]
67 | else:
68 | index = np.nonzero(matches[i, valid])[0]
69 | delta = 1. / (len(index) * repeat)
70 | for j, k in enumerate(index):
71 | if k - j >= topk: break
72 | if first_match_break:
73 | ret[k - j] += 1
74 | break
75 | ret[k - j] += delta
76 | num_valid_queries += 1
77 | if num_valid_queries == 0:
78 | raise RuntimeError("No valid query")
79 | return ret.cumsum() / num_valid_queries
80 |
81 |
82 | def mean_ap(distmat, query_ids=None, gallery_ids=None,
83 | query_cams=None, gallery_cams=None):
84 | distmat = to_numpy(distmat)
85 | m, n = distmat.shape
86 | # Fill up default values
87 | if query_ids is None:
88 | query_ids = np.arange(m)
89 | if gallery_ids is None:
90 | gallery_ids = np.arange(n)
91 | if query_cams is None:
92 | query_cams = np.zeros(m).astype(np.int32)
93 | if gallery_cams is None:
94 | gallery_cams = np.ones(n).astype(np.int32)
95 | # Ensure numpy array
96 | query_ids = np.asarray(query_ids)
97 | gallery_ids = np.asarray(gallery_ids)
98 | query_cams = np.asarray(query_cams)
99 | gallery_cams = np.asarray(gallery_cams)
100 | # Sort and find correct matches
101 | indices = np.argsort(distmat, axis=1)
102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
103 | # Compute AP for each query
104 | aps = []
105 | for i in range(m):
106 | # Filter out the same id and same camera
107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) |
108 | (gallery_cams[indices[i]] != query_cams[i]))
109 | y_true = matches[i, valid]
110 | y_score = -distmat[i][indices[i]][valid]
111 | if not np.any(y_true): continue
112 | aps.append(average_precision_score(y_true, y_score))
113 | if len(aps) == 0:
114 | raise RuntimeError("No valid query")
115 | return np.mean(aps)
116 |
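117 | # Tiny worked example (values are illustrative). With the default fill-ins,
118 | # query i matches gallery i and query/gallery cams differ (0 vs 1), so no
119 | # pair is filtered out:
120 | #
121 | #   >>> mean_ap(np.array([[0.1, 0.5, 0.9]]))   # the true match is nearest
122 | #   1.0
123 | #   >>> cmc(np.array([[0.1, 0.5, 0.9]]), topk=3, first_match_break=True)
124 | #   array([1., 1., 1.])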
--------------------------------------------------------------------------------
/SpCL-master/spcl/evaluators.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | import collections
4 | from collections import OrderedDict
5 | import numpy as np
6 | import torch
7 | import random
8 | import copy
9 |
10 | from .evaluation_metrics import cmc, mean_ap
11 | from .utils.meters import AverageMeter
12 | from .utils.rerank import re_ranking
13 | from .utils import to_torch
14 |
15 | def extract_cnn_feature(model, inputs):
16 | inputs = to_torch(inputs).cuda()
17 | outputs = model(inputs)
18 | outputs = outputs.data.cpu()
19 | return outputs
20 |
21 | def extract_features(model, data_loader, print_freq=50):
22 | model.eval()
23 | batch_time = AverageMeter()
24 | data_time = AverageMeter()
25 |
26 | features = OrderedDict()
27 | labels = OrderedDict()
28 |
29 | end = time.time()
30 | with torch.no_grad():
31 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader):
32 | data_time.update(time.time() - end)
33 |
34 | outputs = extract_cnn_feature(model, imgs)
35 | for fname, output, pid in zip(fnames, outputs, pids):
36 | features[fname] = output
37 | labels[fname] = pid
38 |
39 | batch_time.update(time.time() - end)
40 | end = time.time()
41 |
42 | if (i + 1) % print_freq == 0:
43 | print('Extract Features: [{}/{}]\t'
44 | 'Time {:.3f} ({:.3f})\t'
45 | 'Data {:.3f} ({:.3f})\t'
46 | .format(i + 1, len(data_loader),
47 | batch_time.val, batch_time.avg,
48 | data_time.val, data_time.avg))
49 |
50 | return features, labels
51 |
52 | def pairwise_distance(features, query=None, gallery=None):
53 | if query is None and gallery is None:
54 | n = len(features)
55 | x = torch.cat(list(features.values()))
56 | x = x.view(n, -1)
57 |         dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2  # assumes L2-normalized features
58 | dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t())
59 | return dist_m
60 |
61 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
62 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
63 | m, n = x.size(0), y.size(0)
64 | x = x.view(m, -1)
65 | y = y.view(n, -1)
66 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
67 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
68 |     dist_m.addmm_(x, y.t(), beta=1, alpha=-2)  # keyword form; the positional (beta, alpha) overload is removed in current PyTorch
69 | return dist_m, x.numpy(), y.numpy()
70 |
71 | def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None,
72 | query_ids=None, gallery_ids=None,
73 | query_cams=None, gallery_cams=None,
74 | cmc_topk=(1, 5, 10), cmc_flag=False):
75 | if query is not None and gallery is not None:
76 | query_ids = [pid for _, pid, _ in query]
77 | gallery_ids = [pid for _, pid, _ in gallery]
78 | query_cams = [cam for _, _, cam in query]
79 | gallery_cams = [cam for _, _, cam in gallery]
80 | else:
81 | assert (query_ids is not None and gallery_ids is not None
82 | and query_cams is not None and gallery_cams is not None)
83 |
84 | # Compute mean AP
85 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
86 | print('Mean AP: {:4.1%}'.format(mAP))
87 |
88 | if (not cmc_flag):
89 | return mAP
90 |
91 | cmc_configs = {
92 | 'market1501': dict(separate_camera_set=False,
93 | single_gallery_shot=False,
94 | first_match_break=True),}
95 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
96 | query_cams, gallery_cams, **params)
97 | for name, params in cmc_configs.items()}
98 |
99 | print('CMC Scores:')
100 | for k in cmc_topk:
101 | print(' top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1]))
102 | return cmc_scores['market1501'], mAP
103 |
104 |
105 | class Evaluator(object):
106 | def __init__(self, model):
107 | super(Evaluator, self).__init__()
108 | self.model = model
109 |
110 | def evaluate(self, data_loader, query, gallery, cmc_flag=False, rerank=False):
111 | features, _ = extract_features(self.model, data_loader)
112 | distmat, query_features, gallery_features = pairwise_distance(features, query, gallery)
113 | results = evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag)
114 |
115 | if (not rerank):
116 | return results
117 |
118 | print('Applying person re-ranking ...')
119 | distmat_qq, _, _ = pairwise_distance(features, query, query)
120 | distmat_gg, _, _ = pairwise_distance(features, gallery, gallery)
121 | distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy())
122 | return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag)
123 |
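124 | # Note on pairwise_distance: both branches expand the identity
125 | #   ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
126 | # The features-only branch shortcuts ||x_i||^2 + ||x_j||^2 to 2*||x_i||^2,
127 | # which is exact only for L2-normalized features (what the models emit in
128 | # eval mode); per-row rankings are unaffected either way.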
--------------------------------------------------------------------------------
/SpCL-master/spcl/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .resnet import *
4 | from .resnet_ibn import *
5 |
6 |
7 | __factory = {
8 | 'resnet18': resnet18,
9 | 'resnet34': resnet34,
10 | 'resnet50': resnet50,
11 | 'resnet101': resnet101,
12 | 'resnet152': resnet152,
13 | 'resnet_ibn50a': resnet_ibn50a,
14 | 'resnet_ibn101a': resnet_ibn101a
15 | }
16 |
17 |
18 | def names():
19 | return sorted(__factory.keys())
20 |
21 |
22 | def create(name, *args, **kwargs):
23 | """
24 | Create a model instance.
25 |
26 | Parameters
27 | ----------
28 | name : str
29 |         Model name. Can be one of 'resnet18', 'resnet34', 'resnet50',
30 |         'resnet101', 'resnet152', 'resnet_ibn50a', and 'resnet_ibn101a'.
31 | pretrained : bool, optional
32 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained
33 | model. Default: True
34 | cut_at_pooling : bool, optional
35 | If True, will cut the model before the last global pooling layer and
36 | ignore the remaining kwargs. Default: False
37 | num_features : int, optional
38 | If positive, will append a Linear layer after the global pooling layer,
39 | with this number of output units, followed by a BatchNorm layer.
40 |         Otherwise these layers will not be appended. Default: 0 for
41 |         all models in this factory.
42 | norm : bool, optional
43 | If True, will normalize the feature to be unit L2-norm for each sample.
44 | Otherwise will append a ReLU layer after the above Linear layer if
45 | num_features > 0. Default: False
46 | dropout : float, optional
47 | If positive, will append a Dropout layer with this dropout rate.
48 | Default: 0
49 | num_classes : int, optional
50 | If positive, will append a Linear layer at the end as the classifier
51 | with this number of output units. Default: 0
52 | """
53 | if name not in __factory:
54 | raise KeyError("Unknown model:", name)
55 | return __factory[name](*args, **kwargs)
56 |
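57 | # Usage sketch (arguments are illustrative): build a backbone that returns
58 | # pooled, batch-normalized features and no classifier head:
59 | #
60 | #   from spcl import models
61 | #   model = models.create('resnet_ibn50a', num_features=0, norm=True, num_classes=0)
62 | #   print(models.names())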
--------------------------------------------------------------------------------
/SpCL-master/spcl/models/dsbn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # Domain-specific BatchNorm
5 |
6 | class DSBN2d(nn.Module):
7 | def __init__(self, planes):
8 | super(DSBN2d, self).__init__()
9 | self.num_features = planes
10 | self.BN_S = nn.BatchNorm2d(planes)
11 | self.BN_T = nn.BatchNorm2d(planes)
12 |
13 | def forward(self, x):
14 | if (not self.training):
15 | return self.BN_T(x)
16 |
17 | bs = x.size(0)
18 | assert (bs%2==0)
19 | split = torch.split(x, int(bs/2), 0)
20 | out1 = self.BN_S(split[0].contiguous())
21 | out2 = self.BN_T(split[1].contiguous())
22 | out = torch.cat((out1, out2), 0)
23 | return out
24 |
25 | class DSBN1d(nn.Module):
26 | def __init__(self, planes):
27 | super(DSBN1d, self).__init__()
28 | self.num_features = planes
29 | self.BN_S = nn.BatchNorm1d(planes)
30 | self.BN_T = nn.BatchNorm1d(planes)
31 |
32 | def forward(self, x):
33 | if (not self.training):
34 | return self.BN_T(x)
35 |
36 | bs = x.size(0)
37 | assert (bs%2==0)
38 | split = torch.split(x, int(bs/2), 0)
39 | out1 = self.BN_S(split[0].contiguous())
40 | out2 = self.BN_T(split[1].contiguous())
41 | out = torch.cat((out1, out2), 0)
42 | return out
43 |
44 | def convert_dsbn(model):
45 |     for child_name, child in model.named_children():
46 | assert(not next(model.parameters()).is_cuda)
47 | if isinstance(child, nn.BatchNorm2d):
48 | m = DSBN2d(child.num_features)
49 | m.BN_S.load_state_dict(child.state_dict())
50 | m.BN_T.load_state_dict(child.state_dict())
51 | setattr(model, child_name, m)
52 | elif isinstance(child, nn.BatchNorm1d):
53 | m = DSBN1d(child.num_features)
54 | m.BN_S.load_state_dict(child.state_dict())
55 | m.BN_T.load_state_dict(child.state_dict())
56 | setattr(model, child_name, m)
57 | else:
58 | convert_dsbn(child)
59 |
60 | def convert_bn(model, use_target=True):
61 |     for child_name, child in model.named_children():
62 | assert(not next(model.parameters()).is_cuda)
63 | if isinstance(child, DSBN2d):
64 | m = nn.BatchNorm2d(child.num_features)
65 | if use_target:
66 | m.load_state_dict(child.BN_T.state_dict())
67 | else:
68 | m.load_state_dict(child.BN_S.state_dict())
69 | setattr(model, child_name, m)
70 | elif isinstance(child, DSBN1d):
71 | m = nn.BatchNorm1d(child.num_features)
72 | if use_target:
73 | m.load_state_dict(child.BN_T.state_dict())
74 | else:
75 | m.load_state_dict(child.BN_S.state_dict())
76 | setattr(model, child_name, m)
77 | else:
78 | convert_bn(child, use_target=use_target)
79 |
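80 | # Usage sketch: convert_dsbn must run while the model is still on CPU (the
81 | # asserts above enforce this), and DSBN layers expect every training batch to
82 | # stack source samples in the first half and target samples in the second:
83 | #
84 | #   convert_dsbn(model)                   # given a CPU-resident model: BatchNorm2d/1d -> DSBN2d/1d
85 | #   model = model.cuda()                  # train on [source | target] batches
86 | #   model = model.cpu()
87 | #   convert_bn(model, use_target=True)    # collapse back to plain BN for deployment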
--------------------------------------------------------------------------------
/SpCL-master/spcl/models/hm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.nn import init
6 | from torch import nn, autograd
7 |
8 |
9 | class HM(autograd.Function):
10 |
11 | @staticmethod
12 | def forward(ctx, inputs, indexes, features, momentum):
13 | ctx.features = features
14 | ctx.momentum = momentum
15 | ctx.save_for_backward(inputs, indexes)
16 | outputs = inputs.mm(ctx.features.t())
17 |
18 | return outputs
19 |
20 | @staticmethod
21 | def backward(ctx, grad_outputs):
22 | inputs, indexes = ctx.saved_tensors
23 | grad_inputs = None
24 | if ctx.needs_input_grad[0]:
25 | grad_inputs = grad_outputs.mm(ctx.features)
26 |
27 | # momentum update
28 | for x, y in zip(inputs, indexes):
29 | ctx.features[y] = ctx.momentum * ctx.features[y] + (1. - ctx.momentum) * x
30 | ctx.features[y] /= ctx.features[y].norm()
31 |
32 | return grad_inputs, None, None, None
33 |
34 |
35 | def hm(inputs, indexes, features, momentum=0.5):
36 | return HM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))
37 |
38 |
39 | class HybridMemory(nn.Module):
40 | def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2):
41 | super(HybridMemory, self).__init__()
42 | self.num_features = num_features
43 | self.num_samples = num_samples
44 |
45 | self.momentum = momentum
46 | self.temp = temp
47 |
48 | self.register_buffer('features', torch.zeros(num_samples, num_features))
49 | self.register_buffer('labels', torch.zeros(num_samples).long())
50 |
51 | def forward(self, inputs, indexes):
52 | # inputs: B*2048, features: L*2048
53 | inputs = hm(inputs, indexes, self.features, self.momentum)
54 | inputs /= self.temp
55 | B = inputs.size(0)
56 |
57 | def masked_softmax(vec, mask, dim=1, epsilon=1e-6):
58 | exps = torch.exp(vec)
59 | masked_exps = exps * mask.float().clone()
60 | masked_sums = masked_exps.sum(dim, keepdim=True) + epsilon
61 | return (masked_exps/masked_sums)
62 |
63 | targets = self.labels[indexes].clone()
64 | labels = self.labels.clone()
65 |
66 | sim = torch.zeros(labels.max()+1, B).float().cuda()
67 | sim.index_add_(0, labels, inputs.t().contiguous())
68 | nums = torch.zeros(labels.max()+1, 1).float().cuda()
69 | nums.index_add_(0, labels, torch.ones(self.num_samples,1).float().cuda())
70 | mask = (nums>0).float()
71 | sim /= (mask*nums+(1-mask)).clone().expand_as(sim)
72 | mask = mask.expand_as(sim)
73 | masked_sim = masked_softmax(sim.t().contiguous(), mask.t().contiguous())
74 | return F.nll_loss(torch.log(masked_sim+1e-6), targets)
75 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn import init
6 | import torchvision
7 | import torch
8 |
9 |
10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
11 | 'resnet152']
12 |
13 |
14 | class ResNet(nn.Module):
15 | __factory = {
16 | 18: torchvision.models.resnet18,
17 | 34: torchvision.models.resnet34,
18 | 50: torchvision.models.resnet50,
19 | 101: torchvision.models.resnet101,
20 | 152: torchvision.models.resnet152,
21 | }
22 |
23 | def __init__(self, depth, pretrained=True, cut_at_pooling=False,
24 | num_features=0, norm=False, dropout=0, num_classes=0):
25 | super(ResNet, self).__init__()
26 | self.pretrained = pretrained
27 | self.depth = depth
28 | self.cut_at_pooling = cut_at_pooling
29 | # Construct base (pretrained) resnet
30 | if depth not in ResNet.__factory:
31 | raise KeyError("Unsupported depth:", depth)
32 | resnet = ResNet.__factory[depth](pretrained=pretrained)
33 | resnet.layer4[0].conv2.stride = (1,1)
34 | resnet.layer4[0].downsample[0].stride = (1,1)
35 | self.base = nn.Sequential(
36 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
37 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)
38 | self.gap = nn.AdaptiveAvgPool2d(1)
39 |
40 | if not self.cut_at_pooling:
41 | self.num_features = num_features
42 | self.norm = norm
43 | self.dropout = dropout
44 | self.has_embedding = num_features > 0
45 | self.num_classes = num_classes
46 |
47 | out_planes = resnet.fc.in_features
48 |
49 | # Append new layers
50 | if self.has_embedding:
51 | self.feat = nn.Linear(out_planes, self.num_features)
52 | self.feat_bn = nn.BatchNorm1d(self.num_features)
53 | init.kaiming_normal_(self.feat.weight, mode='fan_out')
54 | init.constant_(self.feat.bias, 0)
55 | else:
56 | # Change the num_features to CNN output channels
57 | self.num_features = out_planes
58 | self.feat_bn = nn.BatchNorm1d(self.num_features)
59 | self.feat_bn.bias.requires_grad_(False)
60 | if self.dropout > 0:
61 | self.drop = nn.Dropout(self.dropout)
62 | if self.num_classes > 0:
63 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False)
64 | init.normal_(self.classifier.weight, std=0.001)
65 | init.constant_(self.feat_bn.weight, 1)
66 | init.constant_(self.feat_bn.bias, 0)
67 |
68 | if not pretrained:
69 | self.reset_params()
70 |
71 | def forward(self, x):
72 | bs = x.size(0)
73 | x = self.base(x)
74 |
75 | x = self.gap(x)
76 | x = x.view(x.size(0), -1)
77 |
78 | if self.cut_at_pooling:
79 | return x
80 |
81 | if self.has_embedding:
82 | bn_x = self.feat_bn(self.feat(x))
83 | else:
84 | bn_x = self.feat_bn(x)
85 |
86 | if (self.training is False):
87 | bn_x = F.normalize(bn_x)
88 | return bn_x
89 |
90 | if self.norm:
91 | bn_x = F.normalize(bn_x)
92 | elif self.has_embedding:
93 | bn_x = F.relu(bn_x)
94 |
95 | if self.dropout > 0:
96 | bn_x = self.drop(bn_x)
97 |
98 | if self.num_classes > 0:
99 | prob = self.classifier(bn_x)
100 | else:
101 | return bn_x
102 |
103 | return prob
104 |
105 | def reset_params(self):
106 | for m in self.modules():
107 | if isinstance(m, nn.Conv2d):
108 | init.kaiming_normal_(m.weight, mode='fan_out')
109 | if m.bias is not None:
110 | init.constant_(m.bias, 0)
111 | elif isinstance(m, nn.BatchNorm2d):
112 | init.constant_(m.weight, 1)
113 | init.constant_(m.bias, 0)
114 | elif isinstance(m, nn.BatchNorm1d):
115 | init.constant_(m.weight, 1)
116 | init.constant_(m.bias, 0)
117 | elif isinstance(m, nn.Linear):
118 | init.normal_(m.weight, std=0.001)
119 | if m.bias is not None:
120 | init.constant_(m.bias, 0)
121 |
122 |
123 | def resnet18(**kwargs):
124 | return ResNet(18, **kwargs)
125 |
126 |
127 | def resnet34(**kwargs):
128 | return ResNet(34, **kwargs)
129 |
130 |
131 | def resnet50(**kwargs):
132 | return ResNet(50, **kwargs)
133 |
134 |
135 | def resnet101(**kwargs):
136 | return ResNet(101, **kwargs)
137 |
138 |
139 | def resnet152(**kwargs):
140 | return ResNet(152, **kwargs)
141 |
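142 | # Usage sketch (input size is illustrative): in eval mode forward() returns
143 | # L2-normalized pooled features, which is what the re-ID evaluators consume;
144 | # pretrained=False here avoids the ImageNet weight download.
145 | #
146 | #   model = resnet50(pretrained=False, num_features=0, num_classes=0).eval()
147 | #   feats = model(torch.randn(2, 3, 256, 128))   # -> (2, 2048), unit-norm rows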
--------------------------------------------------------------------------------
/SpCL-master/spcl/models/resnet_ibn.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn import init
6 | import torchvision
7 | import torch
8 |
9 | from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a
10 |
11 |
12 | __all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a']
13 |
14 |
15 | class ResNetIBN(nn.Module):
16 | __factory = {
17 | '50a': resnet50_ibn_a,
18 | '101a': resnet101_ibn_a
19 | }
20 |
21 | def __init__(self, depth, pretrained=True, cut_at_pooling=False,
22 | num_features=0, norm=False, dropout=0, num_classes=0):
23 | super(ResNetIBN, self).__init__()
24 |
25 | self.depth = depth
26 | self.pretrained = pretrained
27 | self.cut_at_pooling = cut_at_pooling
28 |
29 | resnet = ResNetIBN.__factory[depth](pretrained=pretrained)
30 | resnet.layer4[0].conv2.stride = (1,1)
31 | resnet.layer4[0].downsample[0].stride = (1,1)
32 | self.base = nn.Sequential(
33 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
34 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)
35 | self.gap = nn.AdaptiveAvgPool2d(1)
36 |
37 | if not self.cut_at_pooling:
38 | self.num_features = num_features
39 | self.norm = norm
40 | self.dropout = dropout
41 | self.has_embedding = num_features > 0
42 | self.num_classes = num_classes
43 |
44 | out_planes = resnet.fc.in_features
45 |
46 | # Append new layers
47 | if self.has_embedding:
48 | self.feat = nn.Linear(out_planes, self.num_features)
49 | self.feat_bn = nn.BatchNorm1d(self.num_features)
50 | init.kaiming_normal_(self.feat.weight, mode='fan_out')
51 | init.constant_(self.feat.bias, 0)
52 | else:
53 | # Change the num_features to CNN output channels
54 | self.num_features = out_planes
55 | self.feat_bn = nn.BatchNorm1d(self.num_features)
56 | self.feat_bn.bias.requires_grad_(False)
57 | if self.dropout > 0:
58 | self.drop = nn.Dropout(self.dropout)
59 | if self.num_classes > 0:
60 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False)
61 | init.normal_(self.classifier.weight, std=0.001)
62 | init.constant_(self.feat_bn.weight, 1)
63 | init.constant_(self.feat_bn.bias, 0)
64 |
65 | if not pretrained:
66 | self.reset_params()
67 |
68 | def forward(self, x):
69 | x = self.base(x)
70 |
71 | x = self.gap(x)
72 | x = x.view(x.size(0), -1)
73 |
74 | if self.cut_at_pooling:
75 | return x
76 |
77 | if self.has_embedding:
78 | bn_x = self.feat_bn(self.feat(x))
79 | else:
80 | bn_x = self.feat_bn(x)
81 |
82 | if self.training is False:
83 | bn_x = F.normalize(bn_x)
84 | return bn_x
85 |
86 | if self.norm:
87 | bn_x = F.normalize(bn_x)
88 | elif self.has_embedding:
89 | bn_x = F.relu(bn_x)
90 |
91 | if self.dropout > 0:
92 | bn_x = self.drop(bn_x)
93 |
94 | if self.num_classes > 0:
95 | prob = self.classifier(bn_x)
96 | else:
97 | return bn_x
98 |
99 | return prob
100 |
101 | def reset_params(self):
102 | for m in self.modules():
103 | if isinstance(m, nn.Conv2d):
104 | init.kaiming_normal_(m.weight, mode='fan_out')
105 | if m.bias is not None:
106 | init.constant_(m.bias, 0)
107 | elif isinstance(m, nn.BatchNorm2d):
108 | init.constant_(m.weight, 1)
109 | init.constant_(m.bias, 0)
110 | elif isinstance(m, nn.BatchNorm1d):
111 | init.constant_(m.weight, 1)
112 | init.constant_(m.bias, 0)
113 | elif isinstance(m, nn.Linear):
114 | init.normal_(m.weight, std=0.001)
115 | if m.bias is not None:
116 | init.constant_(m.bias, 0)
117 |
118 |
119 | def resnet_ibn50a(**kwargs):
120 | return ResNetIBN('50a', **kwargs)
121 |
122 |
123 | def resnet_ibn101a(**kwargs):
124 | return ResNetIBN('101a', **kwargs)
125 |
--------------------------------------------------------------------------------
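A minimal usage sketch for the factory functions above (import path assumed from the tree; `pretrained=False` avoids the local checkpoint dependency, and the input size is illustrative):

import torch
from spcl.models.resnet_ibn import resnet_ibn50a  # assumed import path

model = resnet_ibn50a(pretrained=False, num_features=0, num_classes=0)
model.eval()  # in eval mode, forward() returns L2-normalized BN features

with torch.no_grad():
    x = torch.randn(4, 3, 256, 128)  # (batch, channels, height, width)
    feats = model(x)                 # (4, 2048): ResNet-50 output planes
print(feats.shape)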
/SpCL-master/spcl/models/resnet_ibn_a.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 |
6 |
7 | __all__ = ['ResNet', 'resnet50_ibn_a', 'resnet101_ibn_a']
8 |
9 |
10 | model_urls = {  # local checkpoint paths, despite the name
11 | 'ibn_resnet50a': './logs/pretrained/resnet50_ibn_a.pth.tar',
12 | 'ibn_resnet101a': './logs/pretrained/resnet101_ibn_a.pth.tar',
13 | }
14 |
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 |     """3x3 convolution with padding"""
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 | padding=1, bias=False)
20 |
21 |
22 | class BasicBlock(nn.Module):
23 | expansion = 1
24 |
25 | def __init__(self, inplanes, planes, stride=1, downsample=None):
26 | super(BasicBlock, self).__init__()
27 | self.conv1 = conv3x3(inplanes, planes, stride)
28 | self.bn1 = nn.BatchNorm2d(planes)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.conv2 = conv3x3(planes, planes)
31 | self.bn2 = nn.BatchNorm2d(planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | residual = x
37 |
38 | out = self.conv1(x)
39 | out = self.bn1(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv2(out)
43 | out = self.bn2(out)
44 |
45 | if self.downsample is not None:
46 | residual = self.downsample(x)
47 |
48 | out += residual
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class IBN(nn.Module):
55 | def __init__(self, planes):
56 | super(IBN, self).__init__()
57 | half1 = int(planes/2)
58 | self.half = half1
59 | half2 = planes - half1
60 | self.IN = nn.InstanceNorm2d(half1, affine=True)
61 | self.BN = nn.BatchNorm2d(half2)
62 |
63 | def forward(self, x):
64 | split = torch.split(x, self.half, 1)
65 | out1 = self.IN(split[0].contiguous())
66 | out2 = self.BN(split[1].contiguous())
67 | out = torch.cat((out1, out2), 1)
68 | return out
69 |
70 |
71 | class Bottleneck(nn.Module):
72 | expansion = 4
73 |
74 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None):
75 | super(Bottleneck, self).__init__()
76 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
77 | if ibn:
78 | self.bn1 = IBN(planes)
79 | else:
80 | self.bn1 = nn.BatchNorm2d(planes)
81 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
82 | padding=1, bias=False)
83 | self.bn2 = nn.BatchNorm2d(planes)
84 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
85 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
86 | self.relu = nn.ReLU(inplace=True)
87 | self.downsample = downsample
88 | self.stride = stride
89 |
90 | def forward(self, x):
91 | residual = x
92 |
93 | out = self.conv1(x)
94 | out = self.bn1(out)
95 | out = self.relu(out)
96 |
97 | out = self.conv2(out)
98 | out = self.bn2(out)
99 | out = self.relu(out)
100 |
101 | out = self.conv3(out)
102 | out = self.bn3(out)
103 |
104 | if self.downsample is not None:
105 | residual = self.downsample(x)
106 |
107 | out += residual
108 | out = self.relu(out)
109 |
110 | return out
111 |
112 |
113 | class ResNet(nn.Module):
114 |
115 | def __init__(self, block, layers, num_classes=1000):
116 | scale = 64
117 | self.inplanes = scale
118 | super(ResNet, self).__init__()
119 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3,
120 | bias=False)
121 | self.bn1 = nn.BatchNorm2d(scale)
122 | self.relu = nn.ReLU(inplace=True)
123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
124 | self.layer1 = self._make_layer(block, scale, layers[0])
125 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2)
126 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2)
127 | self.layer4 = self._make_layer(block, scale*8, layers[3], stride=2)
128 | self.avgpool = nn.AvgPool2d(7)
129 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes)
130 |
131 | for m in self.modules():
132 | if isinstance(m, nn.Conv2d):
133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
134 | m.weight.data.normal_(0, math.sqrt(2. / n))
135 | elif isinstance(m, nn.BatchNorm2d):
136 | m.weight.data.fill_(1)
137 | m.bias.data.zero_()
138 | elif isinstance(m, nn.InstanceNorm2d):
139 | m.weight.data.fill_(1)
140 | m.bias.data.zero_()
141 |
142 | def _make_layer(self, block, planes, blocks, stride=1):
143 | downsample = None
144 | if stride != 1 or self.inplanes != planes * block.expansion:
145 | downsample = nn.Sequential(
146 | nn.Conv2d(self.inplanes, planes * block.expansion,
147 | kernel_size=1, stride=stride, bias=False),
148 | nn.BatchNorm2d(planes * block.expansion),
149 | )
150 |
151 | layers = []
152 | ibn = True
153 | if planes == 512:
154 | ibn = False
155 | layers.append(block(self.inplanes, planes, ibn, stride, downsample))
156 | self.inplanes = planes * block.expansion
157 | for i in range(1, blocks):
158 | layers.append(block(self.inplanes, planes, ibn))
159 |
160 | return nn.Sequential(*layers)
161 |
162 | def forward(self, x):
163 | x = self.conv1(x)
164 | x = self.bn1(x)
165 | x = self.relu(x)
166 | x = self.maxpool(x)
167 |
168 | x = self.layer1(x)
169 | x = self.layer2(x)
170 | x = self.layer3(x)
171 | x = self.layer4(x)
172 |
173 | x = self.avgpool(x)
174 | x = x.view(x.size(0), -1)
175 | x = self.fc(x)
176 |
177 | return x
178 |
179 |
180 | def resnet50_ibn_a(pretrained=False, **kwargs):
181 | """Constructs a ResNet-50 model.
182 | Args:
183 | pretrained (bool): If True, returns a model pre-trained on ImageNet
184 | """
185 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
186 | if pretrained:
187 | state_dict = torch.load(model_urls['ibn_resnet50a'], map_location=torch.device('cpu'))['state_dict']
188 | state_dict = remove_module_key(state_dict)
189 | model.load_state_dict(state_dict)
190 | return model
191 |
192 |
193 | def resnet101_ibn_a(pretrained=False, **kwargs):
194 | """Constructs a ResNet-101 model.
195 | Args:
196 | pretrained (bool): If True, returns a model pre-trained on ImageNet
197 | """
198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
199 | if pretrained:
200 | state_dict = torch.load(model_urls['ibn_resnet101a'], map_location=torch.device('cpu'))['state_dict']
201 | state_dict = remove_module_key(state_dict)
202 | model.load_state_dict(state_dict)
203 | return model
204 |
205 |
206 | def remove_module_key(state_dict):
207 | for key in list(state_dict.keys()):
208 | if 'module' in key:
209 | state_dict[key.replace('module.','')] = state_dict.pop(key)
210 | return state_dict
211 |
--------------------------------------------------------------------------------
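A quick sanity check of the IBN block above: it routes the first half of the channels through InstanceNorm and the second half through BatchNorm, then concatenates them back (a sketch assuming the module is importable as spcl.models.resnet_ibn_a):

import torch
from spcl.models.resnet_ibn_a import IBN  # assumed import path

ibn = IBN(64)                  # channels 0..31 -> InstanceNorm, 32..63 -> BatchNorm
x = torch.randn(2, 64, 16, 8)
out = ibn(x)
assert out.shape == x.shape    # concatenation restores the full channel count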
/SpCL-master/spcl/trainers.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | import numpy as np
4 | import collections
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 |
10 | from .utils.meters import AverageMeter
11 |
12 |
13 | class SpCLTrainer_UDA(object):
14 | def __init__(self, encoder, memory, source_classes):
15 | super(SpCLTrainer_UDA, self).__init__()
16 | self.encoder = encoder
17 | self.memory = memory
18 | self.source_classes = source_classes
19 |
20 | def train(self, epoch, data_loader_source, data_loader_target,
21 | optimizer, print_freq=10, train_iters=400):
22 | self.encoder.train()
23 |
24 | batch_time = AverageMeter()
25 | data_time = AverageMeter()
26 |
27 | losses_s = AverageMeter()
28 | losses_t = AverageMeter()
29 |
30 | end = time.time()
31 | for i in range(train_iters):
32 | # load data
33 | source_inputs = data_loader_source.next()
34 | target_inputs = data_loader_target.next()
35 | data_time.update(time.time() - end)
36 |
37 | # process inputs
38 | s_inputs, s_targets, _ = self._parse_data(source_inputs)
39 | t_inputs, _, t_indexes = self._parse_data(target_inputs)
40 |
41 | # arrange batch for domain-specific BN
42 | device_num = torch.cuda.device_count()
43 | B, C, H, W = s_inputs.size()
44 | def reshape(inputs):
45 | return inputs.view(device_num, -1, C, H, W)
46 | s_inputs, t_inputs = reshape(s_inputs), reshape(t_inputs)
47 | inputs = torch.cat((s_inputs, t_inputs), 1).view(-1, C, H, W)
48 |
49 | # forward
50 | f_out = self._forward(inputs)
51 |
52 | # de-arrange batch
53 | f_out = f_out.view(device_num, -1, f_out.size(-1))
54 | f_out_s, f_out_t = f_out.split(f_out.size(1)//2, dim=1)
55 | f_out_s, f_out_t = f_out_s.contiguous().view(-1, f_out.size(-1)), f_out_t.contiguous().view(-1, f_out.size(-1))
56 |
57 | # compute loss with the hybrid memory
58 | loss_s = self.memory(f_out_s, s_targets)
59 | loss_t = self.memory(f_out_t, t_indexes+self.source_classes)
60 |
61 | loss = loss_s+loss_t
62 | optimizer.zero_grad()
63 | loss.backward()
64 | optimizer.step()
65 |
66 | losses_s.update(loss_s.item())
67 | losses_t.update(loss_t.item())
68 |
69 | # print log
70 | batch_time.update(time.time() - end)
71 | end = time.time()
72 |
73 | if (i + 1) % print_freq == 0:
74 | print('Epoch: [{}][{}/{}]\t'
75 | 'Time {:.3f} ({:.3f})\t'
76 | 'Data {:.3f} ({:.3f})\t'
77 | 'Loss_s {:.3f} ({:.3f})\t'
78 | 'Loss_t {:.3f} ({:.3f})'
79 | .format(epoch, i + 1, len(data_loader_target),
80 | batch_time.val, batch_time.avg,
81 | data_time.val, data_time.avg,
82 | losses_s.val, losses_s.avg,
83 | losses_t.val, losses_t.avg))
84 |
85 | def _parse_data(self, inputs):
86 | imgs, _, pids, _, indexes = inputs
87 | return imgs.cuda(), pids.cuda(), indexes.cuda()
88 |
89 | def _forward(self, inputs):
90 | return self.encoder(inputs)
91 |
92 |
93 | class SpCLTrainer_USL(object):
94 | def __init__(self, encoder, memory):
95 | super(SpCLTrainer_USL, self).__init__()
96 | self.encoder = encoder
97 | self.memory = memory
98 |
99 | def train(self, epoch, data_loader, optimizer, print_freq=10, train_iters=400):
100 | self.encoder.train()
101 |
102 | batch_time = AverageMeter()
103 | data_time = AverageMeter()
104 |
105 | losses = AverageMeter()
106 |
107 | end = time.time()
108 | for i in range(train_iters):
109 | # load data
110 | inputs = data_loader.next()
111 | data_time.update(time.time() - end)
112 |
113 | # process inputs
114 | inputs, _, indexes = self._parse_data(inputs)
115 |
116 | # forward
117 | f_out = self._forward(inputs)
118 |
119 | # compute loss with the hybrid memory
120 | loss = self.memory(f_out, indexes)
121 |
122 | optimizer.zero_grad()
123 | loss.backward()
124 | optimizer.step()
125 |
126 | losses.update(loss.item())
127 |
128 | # print log
129 | batch_time.update(time.time() - end)
130 | end = time.time()
131 |
132 | if (i + 1) % print_freq == 0:
133 | print('Epoch: [{}][{}/{}]\t'
134 | 'Time {:.3f} ({:.3f})\t'
135 | 'Data {:.3f} ({:.3f})\t'
136 | 'Loss {:.3f} ({:.3f})'
137 | .format(epoch, i + 1, len(data_loader),
138 | batch_time.val, batch_time.avg,
139 | data_time.val, data_time.avg,
140 | losses.val, losses.avg))
141 |
142 | def _parse_data(self, inputs):
143 | imgs, _, pids, _, indexes = inputs
144 | return imgs.cuda(), pids.cuda(), indexes.cuda()
145 |
146 | def _forward(self, inputs):
147 | return self.encoder(inputs)
148 |
--------------------------------------------------------------------------------
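A sketch of how the USL trainer above is driven. The encoder, memory, and loader here are toy stand-ins (the real hybrid memory lives in spcl/models/hm.py): any callable mapping (features, indexes) to a scalar loss, and any loader exposing next() and len(), satisfy the trainer's contract. A CUDA device is required because _parse_data calls .cuda().

import torch
import torch.nn as nn
from spcl.trainers import SpCLTrainer_USL  # assumed import path

class ToyMemory(nn.Module):
    # Stand-in for the hybrid memory: returns a differentiable scalar loss.
    def forward(self, features, indexes):
        return features.pow(2).mean()

class ToyLoader:
    # Mimics IterLoader: next() yields (imgs, fnames, pids, camids, indexes).
    def __len__(self):
        return 10

    def next(self):
        return (torch.randn(8, 3, 64, 32), None,
                torch.zeros(8, dtype=torch.long), None, torch.arange(8))

encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 32, 16)).cuda()
trainer = SpCLTrainer_USL(encoder, ToyMemory())
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.01)
trainer.train(epoch=0, data_loader=ToyLoader(), optimizer=optimizer,
              print_freq=5, train_iters=10)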
/SpCL-master/spcl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 |
5 |
6 | def to_numpy(tensor):
7 | if torch.is_tensor(tensor):
8 | return tensor.cpu().numpy()
9 | elif type(tensor).__module__ != 'numpy':
10 | raise ValueError("Cannot convert {} to numpy array"
11 | .format(type(tensor)))
12 | return tensor
13 |
14 |
15 | def to_torch(ndarray):
16 | if type(ndarray).__module__ == 'numpy':
17 | return torch.from_numpy(ndarray)
18 | elif not torch.is_tensor(ndarray):
19 | raise ValueError("Cannot convert {} to torch tensor"
20 | .format(type(ndarray)))
21 | return ndarray
22 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .base_dataset import BaseDataset, BaseImageDataset
4 | from .preprocessor import Preprocessor
5 |
6 | class IterLoader:
7 | def __init__(self, loader, length=None):
8 | self.loader = loader
9 | self.length = length
10 | self.iter = None
11 |
12 | def __len__(self):
13 | if (self.length is not None):
14 | return self.length
15 | return len(self.loader)
16 |
17 | def new_epoch(self):
18 | self.iter = iter(self.loader)
19 |
20 | def next(self):
21 | try:
22 | return next(self.iter)
23 |         except StopIteration:
24 | self.iter = iter(self.loader)
25 | return next(self.iter)
26 |
--------------------------------------------------------------------------------
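IterLoader decouples the number of optimizer steps per epoch from the dataset size: next() restarts the underlying iterator on exhaustion. A self-contained sketch with a toy dataset (import path assumed):

import torch
from torch.utils.data import DataLoader, TensorDataset
from spcl.utils.data import IterLoader  # assumed import path

dataset = TensorDataset(torch.arange(10))
loader = IterLoader(DataLoader(dataset, batch_size=4), length=7)

loader.new_epoch()
for _ in range(len(loader)):  # 7 steps, more than one pass over 3 batches
    batch = loader.next()     # wraps around transparently on exhaustion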
/SpCL-master/spcl/utils/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy as np
3 |
4 |
5 | class BaseDataset(object):
6 | """
7 | Base class of reid dataset
8 | """
9 |
10 | def get_imagedata_info(self, data):
11 | pids, cams = [], []
12 | for _, pid, camid in data:
13 | pids += [pid]
14 | cams += [camid]
15 | pids = set(pids)
16 | cams = set(cams)
17 | num_pids = len(pids)
18 | num_cams = len(cams)
19 | num_imgs = len(data)
20 | return num_pids, num_imgs, num_cams
21 |
22 | def print_dataset_statistics(self):
23 | raise NotImplementedError
24 |
25 | @property
26 | def images_dir(self):
27 | return None
28 |
29 |
30 | class BaseImageDataset(BaseDataset):
31 | """
32 | Base class of image reid dataset
33 | """
34 |
35 | def print_dataset_statistics(self, train, query, gallery):
36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
38 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
39 |
40 | print("Dataset statistics:")
41 | print(" ----------------------------------------")
42 | print(" subset | # ids | # images | # cameras")
43 | print(" ----------------------------------------")
44 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
45 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
46 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
47 | print(" ----------------------------------------")
48 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/utils/data/preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import os.path as osp
4 | from torch.utils.data import DataLoader, Dataset
5 | import numpy as np
6 | import random
7 | import math
8 | from PIL import Image
9 |
10 | class Preprocessor(Dataset):
11 | def __init__(self, dataset, root=None, transform=None):
12 | super(Preprocessor, self).__init__()
13 | self.dataset = dataset
14 | self.root = root
15 | self.transform = transform
16 |
17 | def __len__(self):
18 | return len(self.dataset)
19 |
20 | def __getitem__(self, indices):
21 | return self._get_single_item(indices)
22 |
23 | def _get_single_item(self, index):
24 | fname, pid, camid = self.dataset[index]
25 | fpath = fname
26 | if self.root is not None:
27 | fpath = osp.join(self.root, fname)
28 |
29 | img = Image.open(fpath).convert('RGB')
30 |
31 | if self.transform is not None:
32 | img = self.transform(img)
33 |
34 | return img, fname, pid, camid, index
35 |
--------------------------------------------------------------------------------
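A runnable sketch of Preprocessor: it adapts a list of (fname, pid, camid) records into a torch Dataset whose items are (img, fname, pid, camid, index). The temporary images below exist only to make the example self-contained (import path assumed):

import os
import tempfile
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms as T
from spcl.utils.data.preprocessor import Preprocessor  # assumed import path

root = tempfile.mkdtemp()
for name in ('a.jpg', 'b.jpg'):
    Image.new('RGB', (64, 128)).save(os.path.join(root, name))

records = [('a.jpg', 1, 1), ('b.jpg', 2, 2)]  # (fname, pid, camid)
dataset = Preprocessor(records, root=root,
                       transform=T.Compose([T.Resize((256, 128)), T.ToTensor()]))
imgs, fnames, pids, camids, indexes = next(iter(DataLoader(dataset, batch_size=2)))
print(imgs.shape)  # torch.Size([2, 3, 256, 128])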
/SpCL-master/spcl/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 | import math
4 |
5 | import numpy as np
6 | import copy
7 | import random
8 | import torch
9 | from torch.utils.data.sampler import (
10 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler,
11 | WeightedRandomSampler)
12 |
13 |
14 | def No_index(a, b):
15 | assert isinstance(a, list)
16 | return [i for i, j in enumerate(a) if j != b]
17 |
18 |
19 | class RandomIdentitySampler(Sampler):
20 | def __init__(self, data_source, num_instances):
21 | self.data_source = data_source
22 | self.num_instances = num_instances
23 | self.index_dic = defaultdict(list)
24 | for index, (_, pid, _) in enumerate(data_source):
25 | self.index_dic[pid].append(index)
26 | self.pids = list(self.index_dic.keys())
27 | self.num_samples = len(self.pids)
28 |
29 | def __len__(self):
30 | return self.num_samples * self.num_instances
31 |
32 | def __iter__(self):
33 | indices = torch.randperm(self.num_samples).tolist()
34 | ret = []
35 | for i in indices:
36 | pid = self.pids[i]
37 | t = self.index_dic[pid]
38 | if len(t) >= self.num_instances:
39 | t = np.random.choice(t, size=self.num_instances, replace=False)
40 | else:
41 | t = np.random.choice(t, size=self.num_instances, replace=True)
42 | ret.extend(t)
43 | return iter(ret)
44 |
45 |
46 | class RandomMultipleGallerySampler(Sampler):
47 | def __init__(self, data_source, num_instances=4):
48 | self.data_source = data_source
49 | self.index_pid = defaultdict(int)
50 | self.pid_cam = defaultdict(list)
51 | self.pid_index = defaultdict(list)
52 | self.num_instances = num_instances
53 |
54 | for index, (_, pid, cam) in enumerate(data_source):
55 | if (pid<0): continue
56 | self.index_pid[index] = pid
57 | self.pid_cam[pid].append(cam)
58 | self.pid_index[pid].append(index)
59 |
60 | self.pids = list(self.pid_index.keys())
61 | self.num_samples = len(self.pids)
62 |
63 | def __len__(self):
64 | return self.num_samples * self.num_instances
65 |
66 | def __iter__(self):
67 | indices = torch.randperm(len(self.pids)).tolist()
68 | ret = []
69 |
70 | for kid in indices:
71 | i = random.choice(self.pid_index[self.pids[kid]])
72 |
73 | _, i_pid, i_cam = self.data_source[i]
74 |
75 | ret.append(i)
76 |
77 | pid_i = self.index_pid[i]
78 | cams = self.pid_cam[pid_i]
79 | index = self.pid_index[pid_i]
80 | select_cams = No_index(cams, i_cam)
81 |
82 | if select_cams:
83 |
84 | if len(select_cams) >= self.num_instances:
85 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False)
86 | else:
87 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True)
88 |
89 | for kk in cam_indexes:
90 | ret.append(index[kk])
91 |
92 | else:
93 | select_indexes = No_index(index, i)
94 | if (not select_indexes): continue
95 | if len(select_indexes) >= self.num_instances:
96 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False)
97 | else:
98 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True)
99 |
100 | for kk in ind_indexes:
101 | ret.append(index[kk])
102 |
103 |
104 | return iter(ret)
105 |
--------------------------------------------------------------------------------
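A sketch of RandomIdentitySampler above: data_source only needs to be an indexable sequence of (fname, pid, camid) triples, and the sampler yields num_instances indices per identity in a random identity order (import path assumed):

from spcl.utils.data.sampler import RandomIdentitySampler  # assumed import path

data_source = [('img%02d.jpg' % i, i % 3, i % 2) for i in range(12)]  # 3 identities
sampler = RandomIdentitySampler(data_source, num_instances=4)

indices = list(iter(sampler))
assert len(indices) == len(sampler)  # 3 identities * 4 instances = 12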
/SpCL-master/spcl/utils/data/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torchvision.transforms import *
4 | from PIL import Image
5 | import random
6 | import math
7 | import numpy as np
8 |
9 | class RectScale(object):
10 | def __init__(self, height, width, interpolation=Image.BILINEAR):
11 | self.height = height
12 | self.width = width
13 | self.interpolation = interpolation
14 |
15 | def __call__(self, img):
16 | w, h = img.size
17 | if h == self.height and w == self.width:
18 | return img
19 | return img.resize((self.width, self.height), self.interpolation)
20 |
21 |
22 | class RandomSizedRectCrop(object):
23 | def __init__(self, height, width, interpolation=Image.BILINEAR):
24 | self.height = height
25 | self.width = width
26 | self.interpolation = interpolation
27 |
28 | def __call__(self, img):
29 | for attempt in range(10):
30 | area = img.size[0] * img.size[1]
31 | target_area = random.uniform(0.64, 1.0) * area
32 | aspect_ratio = random.uniform(2, 3)
33 |
34 | h = int(round(math.sqrt(target_area * aspect_ratio)))
35 | w = int(round(math.sqrt(target_area / aspect_ratio)))
36 |
37 | if w <= img.size[0] and h <= img.size[1]:
38 | x1 = random.randint(0, img.size[0] - w)
39 | y1 = random.randint(0, img.size[1] - h)
40 |
41 | img = img.crop((x1, y1, x1 + w, y1 + h))
42 | assert(img.size == (w, h))
43 |
44 | return img.resize((self.width, self.height), self.interpolation)
45 |
46 | # Fallback
47 | scale = RectScale(self.height, self.width,
48 | interpolation=self.interpolation)
49 | return scale(img)
50 |
51 |
52 | class RandomErasing(object):
53 | """ Randomly selects a rectangle region in an image and erases its pixels.
54 | 'Random Erasing Data Augmentation' by Zhong et al.
55 | See https://arxiv.org/pdf/1708.04896.pdf
56 | Args:
57 | probability: The probability that the Random Erasing operation will be performed.
58 | sl: Minimum proportion of erased area against input image.
59 | sh: Maximum proportion of erased area against input image.
60 | r1: Minimum aspect ratio of erased area.
61 | mean: Erasing value.
62 | """
63 |
64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
65 | self.probability = probability
66 | self.mean = mean
67 | self.sl = sl
68 | self.sh = sh
69 | self.r1 = r1
70 |
71 | def __call__(self, img):
72 |
73 | if random.uniform(0, 1) >= self.probability:
74 | return img
75 |
76 | for attempt in range(100):
77 | area = img.size()[1] * img.size()[2]
78 |
79 | target_area = random.uniform(self.sl, self.sh) * area
80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
81 |
82 | h = int(round(math.sqrt(target_area * aspect_ratio)))
83 | w = int(round(math.sqrt(target_area / aspect_ratio)))
84 |
85 | if w < img.size()[2] and h < img.size()[1]:
86 | x1 = random.randint(0, img.size()[1] - h)
87 | y1 = random.randint(0, img.size()[2] - w)
88 | if img.size()[0] == 3:
89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
92 | else:
93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
94 | return img
95 |
96 | return img
97 |
--------------------------------------------------------------------------------
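A sketch of a typical training pipeline built from the custom ops above together with torchvision transforms. The order matters: RandomErasing operates on a tensor, so it must follow ToTensor (the normalization statistics shown are the common ImageNet values, not necessarily this repo's defaults):

from torchvision import transforms as T
from spcl.utils.data.transforms import RandomSizedRectCrop, RandomErasing  # assumed path

train_transform = T.Compose([
    RandomSizedRectCrop(256, 128),   # random crop, resized to H=256, W=128
    T.RandomHorizontalFlip(),
    T.ToTensor(),                    # PIL image -> float tensor (C, H, W)
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    RandomErasing(probability=0.5),  # erases a random rectangle in tensor space
])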
/SpCL-master/spcl/utils/faiss_rerank.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | CVPR 2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
5 | URL: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
7 | """
8 |
9 | import os, sys
10 | import time
11 | import numpy as np
12 | from scipy.spatial.distance import cdist
13 | import gc
14 | import faiss
15 |
16 | import torch
17 | import torch.nn.functional as F
18 |
19 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \
20 | index_init_gpu, index_init_cpu
21 |
22 | def k_reciprocal_neigh(initial_rank, i, k1):
23 | forward_k_neigh_index = initial_rank[i,:k1+1]
24 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
25 | fi = np.where(backward_k_neigh_index==i)[0]
26 | return forward_k_neigh_index[fi]
27 |
28 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False):
29 | end = time.time()
30 | if print_flag:
31 | print('Computing jaccard distance...')
32 |
33 | ngpus = faiss.get_num_gpus()
34 | N = target_features.size(0)
35 | mat_type = np.float16 if use_float16 else np.float32
36 |
37 | if (search_option==0):
38 | # GPU + PyTorch CUDA Tensors (1)
39 | res = faiss.StandardGpuResources()
40 | res.setDefaultNullStreamAllDevices()
41 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1)
42 | initial_rank = initial_rank.cpu().numpy()
43 | elif (search_option==1):
44 | # GPU + PyTorch CUDA Tensors (2)
45 | res = faiss.StandardGpuResources()
46 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1))
47 | index.add(target_features.cpu().numpy())
48 | _, initial_rank = search_index_pytorch(index, target_features, k1)
49 | res.syncDefaultStreamCurrentDevice()
50 | initial_rank = initial_rank.cpu().numpy()
51 | elif (search_option==2):
52 | # GPU
53 | index = index_init_gpu(ngpus, target_features.size(-1))
54 | index.add(target_features.cpu().numpy())
55 | _, initial_rank = index.search(target_features.cpu().numpy(), k1)
56 | else:
57 | # CPU
58 | index = index_init_cpu(target_features.size(-1))
59 | index.add(target_features.cpu().numpy())
60 | _, initial_rank = index.search(target_features.cpu().numpy(), k1)
61 |
62 |
63 | nn_k1 = []
64 | nn_k1_half = []
65 | for i in range(N):
66 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))
67 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2))))
68 |
69 | V = np.zeros((N, N), dtype=mat_type)
70 | for i in range(N):
71 | k_reciprocal_index = nn_k1[i]
72 | k_reciprocal_expansion_index = k_reciprocal_index
73 | for candidate in k_reciprocal_index:
74 | candidate_k_reciprocal_index = nn_k1_half[candidate]
75 | if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)):
76 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
77 |
78 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique
79 | dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t())
80 | if use_float16:
81 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type)
82 | else:
83 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy()
84 |
85 | del nn_k1, nn_k1_half
86 |
87 | if k2 != 1:
88 | V_qe = np.zeros_like(V, dtype=mat_type)
89 | for i in range(N):
90 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0)
91 | V = V_qe
92 | del V_qe
93 |
94 | del initial_rank
95 |
96 | invIndex = []
97 | for i in range(N):
98 |         invIndex.append(np.where(V[:,i] != 0)[0])  # len(invIndex) == N
99 |
100 | jaccard_dist = np.zeros((N, N), dtype=mat_type)
101 | for i in range(N):
102 | temp_min = np.zeros((1,N), dtype=mat_type)
103 | # temp_max = np.zeros((1,N), dtype=mat_type)
104 | indNonZero = np.where(V[i,:] != 0)[0]
105 | indImages = []
106 | indImages = [invIndex[ind] for ind in indNonZero]
107 | for j in range(len(indNonZero)):
108 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
109 | # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
110 |
111 | jaccard_dist[i] = 1-temp_min/(2-temp_min)
112 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6)
113 |
114 | del invIndex, V
115 |
116 | pos_bool = (jaccard_dist < 0)
117 | jaccard_dist[pos_bool] = 0.0
118 | if print_flag:
119 | print ("Jaccard distance computing time cost: {}".format(time.time()-end))
120 |
121 | return jaccard_dist
122 |
--------------------------------------------------------------------------------
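A sketch of calling compute_jaccard_distance above through the CPU index path (search_option=3); options 0-2 go through the GPU faiss resources instead. The returned matrix is a numpy array of pairwise Jaccard distances clipped to [0, 1]:

import torch
import torch.nn.functional as F
from spcl.utils.faiss_rerank import compute_jaccard_distance  # assumed path

features = F.normalize(torch.randn(100, 64), dim=1)  # L2-normalized embeddings
dist = compute_jaccard_distance(features, k1=20, k2=6, search_option=3)
print(dist.shape, dist.min(), dist.max())  # (100, 100) distance matrix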
/SpCL-master/spcl/utils/faiss_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import faiss
4 | import torch
5 |
6 | def swig_ptr_from_FloatTensor(x):
7 | assert x.is_contiguous()
8 | assert x.dtype == torch.float32
9 | return faiss.cast_integer_to_float_ptr(
10 | x.storage().data_ptr() + x.storage_offset() * 4)
11 |
12 | def swig_ptr_from_LongTensor(x):
13 | assert x.is_contiguous()
14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
15 | # return faiss.cast_integer_to_long_ptr(
16 | # x.storage().data_ptr() + x.storage_offset() * 8)
17 | return faiss.cast_integer_to_idx_t_ptr(
18 | x.storage().data_ptr() + x.storage_offset() * 8)
19 |
20 | def search_index_pytorch(index, x, k, D=None, I=None):
21 | """call the search function of an index with pytorch tensor I/O (CPU
22 | and GPU supported)"""
23 | assert x.is_contiguous()
24 | n, d = x.size()
25 | assert d == index.d
26 |
27 | if D is None:
28 | D = torch.empty((n, k), dtype=torch.float32, device=x.device)
29 | else:
30 | assert D.size() == (n, k)
31 |
32 | if I is None:
33 | I = torch.empty((n, k), dtype=torch.int64, device=x.device)
34 | else:
35 | assert I.size() == (n, k)
36 | torch.cuda.synchronize()
37 | xptr = swig_ptr_from_FloatTensor(x)
38 | Iptr = swig_ptr_from_LongTensor(I)
39 | Dptr = swig_ptr_from_FloatTensor(D)
40 | index.search_c(n, xptr,
41 | k, Dptr, Iptr)
42 | torch.cuda.synchronize()
43 | return D, I
44 |
45 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
46 | metric=faiss.METRIC_L2):
47 | assert xb.device == xq.device
48 |
49 | nq, d = xq.size()
50 | if xq.is_contiguous():
51 | xq_row_major = True
52 | elif xq.t().is_contiguous():
53 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-)
54 | xq_row_major = False
55 | else:
56 | raise TypeError('matrix should be row or column-major')
57 |
58 | xq_ptr = swig_ptr_from_FloatTensor(xq)
59 |
60 | nb, d2 = xb.size()
61 | assert d2 == d
62 | if xb.is_contiguous():
63 | xb_row_major = True
64 | elif xb.t().is_contiguous():
65 | xb = xb.t()
66 | xb_row_major = False
67 | else:
68 | raise TypeError('matrix should be row or column-major')
69 | xb_ptr = swig_ptr_from_FloatTensor(xb)
70 |
71 | if D is None:
72 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
73 | else:
74 | assert D.shape == (nq, k)
75 | assert D.device == xb.device
76 |
77 | if I is None:
78 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
79 | else:
80 | assert I.shape == (nq, k)
81 | assert I.device == xb.device
82 |
83 | D_ptr = swig_ptr_from_FloatTensor(D)
84 | I_ptr = swig_ptr_from_LongTensor(I)
85 |
86 | faiss.bruteForceKnn(res, metric,
87 | xb_ptr, xb_row_major, nb,
88 | xq_ptr, xq_row_major, nq,
89 | d, k, D_ptr, I_ptr)
90 |
91 | return D, I
92 |
93 | def index_init_gpu(ngpus, feat_dim):
94 | flat_config = []
95 | for i in range(ngpus):
96 | cfg = faiss.GpuIndexFlatConfig()
97 | cfg.useFloat16 = False
98 | cfg.device = i
99 | flat_config.append(cfg)
100 |
101 | res = [faiss.StandardGpuResources() for i in range(ngpus)]
102 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)]
103 | index = faiss.IndexShards(feat_dim)
104 | for sub_index in indexes:
105 | index.add_shard(sub_index)
106 | index.reset()
107 | return index
108 |
109 | def index_init_cpu(feat_dim):
110 | return faiss.IndexFlatL2(feat_dim)
111 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 |
5 | from .osutils import mkdir_if_missing
6 |
7 |
8 | class Logger(object):
9 | def __init__(self, fpath=None):
10 | self.console = sys.stdout
11 | self.file = None
12 | if fpath is not None:
13 | mkdir_if_missing(os.path.dirname(fpath))
14 | self.file = open(fpath, 'w')
15 |
16 | def __del__(self):
17 | self.close()
18 |
19 |     def __enter__(self):
20 |         return self
21 |
22 | def __exit__(self, *args):
23 | self.close()
24 |
25 | def write(self, msg):
26 | self.console.write(msg)
27 | if self.file is not None:
28 | self.file.write(msg)
29 |
30 | def flush(self):
31 | self.console.flush()
32 | if self.file is not None:
33 | self.file.flush()
34 | os.fsync(self.file.fileno())
35 |
36 |     def close(self):
37 |         # leave the console stream (sys.stdout) open; only close the log file
38 |         if self.file is not None:
39 |             self.file.close()
40 |
--------------------------------------------------------------------------------
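The usual pattern in the training scripts is to tee stdout through the Logger above, so that everything printed also lands in a log file (the path below is illustrative; import path assumed):

import sys
from spcl.utils.logging import Logger  # assumed import path

sys.stdout = Logger('logs/example/log.txt')  # creates logs/example/ if missing
print('written to both the console and logs/example/log.txt')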
/SpCL-master/spcl/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
--------------------------------------------------------------------------------
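AverageMeter keeps both the latest value and a count-weighted running mean; for per-sample losses, n is typically the batch size (import path assumed):

from spcl.utils.meters import AverageMeter  # assumed import path

meter = AverageMeter()
meter.update(0.5, n=32)  # e.g. mean loss over a 32-sample batch
meter.update(0.3, n=16)
print(meter.val, meter.avg)  # 0.3, (0.5*32 + 0.3*16) / 48 ≈ 0.433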
/SpCL-master/spcl/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 |
5 |
6 | def mkdir_if_missing(dir_path):
7 | try:
8 | os.makedirs(dir_path)
9 | except OSError as e:
10 | if e.errno != errno.EEXIST:
11 | raise
12 |
--------------------------------------------------------------------------------
/SpCL-master/spcl/utils/rerank.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Source: https://github.com/zhunzhong07/person-re-ranking
5 | Created on Mon Jun 26 14:46:56 2017
6 | @author: luohao
7 | Modified by Houjing Huang, 2017-12-22.
8 | - This version accepts distance matrix instead of raw features.
9 | - The difference of `/` division between python 2 and 3 is handled.
10 | - numpy.float16 is replaced by numpy.float32 for numerical precision.
11 | CVPR 2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
12 | URL: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
14 | API
15 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
16 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
17 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
18 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
19 | Returns:
20 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
21 | """
22 | from __future__ import absolute_import
23 | from __future__ import print_function
24 | from __future__ import division
25 |
26 | __all__ = ['re_ranking']
27 |
28 | import numpy as np
29 |
30 |
31 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
32 |
33 | # The following naming, e.g. gallery_num, is different from outer scope.
34 | # Don't care about it.
35 |
36 | original_dist = np.concatenate(
37 | [np.concatenate([q_q_dist, q_g_dist], axis=1),
38 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
39 | axis=0)
40 | original_dist = np.power(original_dist, 2).astype(np.float32)
41 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
42 | V = np.zeros_like(original_dist).astype(np.float32)
43 | initial_rank = np.argsort(original_dist).astype(np.int32)
44 |
45 | query_num = q_g_dist.shape[0]
46 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
47 | all_num = gallery_num
48 |
49 | for i in range(all_num):
50 | # k-reciprocal neighbors
51 | forward_k_neigh_index = initial_rank[i,:k1+1]
52 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
53 | fi = np.where(backward_k_neigh_index==i)[0]
54 | k_reciprocal_index = forward_k_neigh_index[fi]
55 | k_reciprocal_expansion_index = k_reciprocal_index
56 | for j in range(len(k_reciprocal_index)):
57 | candidate = k_reciprocal_index[j]
58 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1]
59 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1]
60 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
61 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
62 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
63 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
64 |
65 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
66 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
67 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
68 | original_dist = original_dist[:query_num,]
69 | if k2 != 1:
70 | V_qe = np.zeros_like(V,dtype=np.float32)
71 | for i in range(all_num):
72 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
73 | V = V_qe
74 | del V_qe
75 | del initial_rank
76 | invIndex = []
77 | for i in range(gallery_num):
78 | invIndex.append(np.where(V[:,i] != 0)[0])
79 |
80 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
81 |
82 |
83 | for i in range(query_num):
84 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32)
85 | indNonZero = np.where(V[i,:] != 0)[0]
86 | indImages = []
87 | indImages = [invIndex[ind] for ind in indNonZero]
88 | for j in range(len(indNonZero)):
89 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
90 | jaccard_dist[i] = 1-temp_min/(2.-temp_min)
91 |
92 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
93 | del original_dist
94 | del V
95 | del jaccard_dist
96 | final_dist = final_dist[:query_num,query_num:]
97 | return final_dist
98 |
--------------------------------------------------------------------------------
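re_ranking consumes precomputed query-gallery, query-query, and gallery-gallery distance blocks. A sketch with synthetic L2-normalized features, for which squared Euclidean distance reduces to 2 - 2·cosine (import path assumed):

import numpy as np
from spcl.utils.rerank import re_ranking  # assumed import path

rng = np.random.default_rng(0)
q = rng.normal(size=(5, 16)); q /= np.linalg.norm(q, axis=1, keepdims=True)
g = rng.normal(size=(8, 16)); g /= np.linalg.norm(g, axis=1, keepdims=True)

q_g = 2 - 2 * (q @ g.T)  # query-gallery distances
q_q = 2 - 2 * (q @ q.T)  # query-query distances
g_g = 2 - 2 * (g @ g.T)  # gallery-gallery distances

final = re_ranking(q_g, q_q, g_g, k1=4, k2=2, lambda_value=0.3)
print(final.shape)  # (5, 8): re-ranked query-gallery distances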
/SpCL-master/spcl/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import json
3 | import os.path as osp
4 | import shutil
5 |
6 | import torch
7 | from torch.nn import Parameter
8 |
9 | from .osutils import mkdir_if_missing
10 |
11 |
12 | def read_json(fpath):
13 | with open(fpath, 'r') as f:
14 | obj = json.load(f)
15 | return obj
16 |
17 |
18 | def write_json(obj, fpath):
19 | mkdir_if_missing(osp.dirname(fpath))
20 | with open(fpath, 'w') as f:
21 | json.dump(obj, f, indent=4, separators=(',', ': '))
22 |
23 |
24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
25 | mkdir_if_missing(osp.dirname(fpath))
26 | torch.save(state, fpath)
27 | if is_best:
28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
29 |
30 |
31 | def load_checkpoint(fpath):
32 | if osp.isfile(fpath):
33 | # checkpoint = torch.load(fpath)
34 | checkpoint = torch.load(fpath, map_location=torch.device('cpu'))
35 | print("=> Loaded checkpoint '{}'".format(fpath))
36 | return checkpoint
37 | else:
38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath))
39 |
40 |
41 | def copy_state_dict(state_dict, model, strip=None):
42 | tgt_state = model.state_dict()
43 | copied_names = set()
44 | for name, param in state_dict.items():
45 | if strip is not None and name.startswith(strip):
46 | name = name[len(strip):]
47 | if name not in tgt_state:
48 | continue
49 | if isinstance(param, Parameter):
50 | param = param.data
51 | if param.size() != tgt_state[name].size():
52 | print('mismatch:', name, param.size(), tgt_state[name].size())
53 | continue
54 | tgt_state[name].copy_(param)
55 | copied_names.add(name)
56 |
57 | missing = set(tgt_state.keys()) - copied_names
58 | if len(missing) > 0:
59 | print("missing keys in state_dict:", missing)
60 |
61 | return model
62 |
--------------------------------------------------------------------------------
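A round-trip sketch for the checkpoint helpers above (paths are illustrative; import path assumed):

import torch.nn as nn
from spcl.utils.serialization import (  # assumed import path
    save_checkpoint, load_checkpoint, copy_state_dict)

model = nn.Linear(4, 2)
state = {'state_dict': model.state_dict(), 'epoch': 10, 'best_mAP': 0.5}
save_checkpoint(state, is_best=True, fpath='logs/example/checkpoint.pth.tar')
# is_best=True additionally writes logs/example/model_best.pth.tar

ckpt = load_checkpoint('logs/example/checkpoint.pth.tar')
model = copy_state_dict(ckpt['state_dict'], model)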
/config/config_regdb.yaml:
--------------------------------------------------------------------------------
1 | ## Note: color = rgb = visible, thermal = ir = infrared.
2 |
3 | ## dataset parameters
4 | dataset: regdb # sysu or regdb
5 | dataset_path: ../../dataset/ # dataset root path
6 | trial: 1 # only for regdb test
7 | mode: visibletothermal # all or indoor (sysu test), thermaltovisible or visibletothermal (regdb test)
8 | workers: 4 # number of data loading workers (default: 4)
9 | dataset_num_size: 2 # multiplier applied to the dataset size for each train loader epoch
10 |
11 | ## model parameters
12 | arch: resnet50 # network baseline
13 | pool_dim: 2048 # pooling dim: 2048 for resnet50
14 | per_add_iters: 5 # increments added to the GRL coefficient's iteration counter per training batch
15 | lambda_sk: 25 # hyperparameter for Sinkhorn-Knopp algorithm
16 |
17 | ## optimizer parameters
18 | optim: adam # optimizer: adam
19 | lr: 0.0035 # learning rate: 0.0035 for adam
20 |
21 | ## normal parameters
22 | file_name: otla-reid/ # log file name
23 | setting: semi-supervised # training setting: supervised or semi-supervised or unsupervised
24 | train_visible_image_path: ../../dataset/RegDB/spcl_uda_market1501TOregdb_rgb_train_rgb_resized_img.npy # path to stored visible images produced by USL-ReID or UDA-ReID methods (unsupervised setting only)
25 | train_visible_label_path: ../../dataset/RegDB/spcl_uda_market1501TOregdb_rgb_train_rgb_resized_label.npy # path to stored visible labels produced by USL-ReID or UDA-ReID methods (unsupervised setting only)
26 | seed: 0 # random seed
27 | gpu: 0 # gpu device ids for CUDA_VISIBLE_DEVICES
28 | model_path: save_model/ # model save path
29 | log_path: log/ # log save path
30 | vis_log_path: vis_log/ # tensorboard log save path
31 | save_epoch: 10 # save model every few epochs
32 | img_w: 144 # image width
33 | img_h: 288 # image height
34 | train_batch_size: 4 # training batch size: 4
35 | num_pos: 8 # number of pos per identity for each modality: 8
36 | test_batch_size: 64 # testing batch size
37 | start_epoch: 0 # start training epoch
38 | end_epoch: 81 # end training epoch
39 | eval_epoch: 1 # run evaluation every this many epochs
40 |
41 | ## loss parameters
42 | margin: 0.3 # triplet loss margin
43 | lambda_vr: 0.1 # coefficient of prediction alignment loss
44 | lambda_rv: 0.5 # coefficient of prediction alignment loss
--------------------------------------------------------------------------------
/config/config_sysu.yaml:
--------------------------------------------------------------------------------
1 | ## Note: color = rgb = visible, thermal = ir = infrared.
2 |
3 | ## dataset parameters
4 | dataset: sysu # sysu or regdb
5 | dataset_path: ../../dataset/ # dataset root path
6 | mode: all # all or indoor (sysu test), thermaltovisible or visibletothermal (regdb test)
7 | workers: 4 # number of data loading workers (default: 4)
8 | dataset_num_size: 1 # multiplier applied to the dataset size for each train loader epoch
9 |
10 | ## model parameters
11 | arch: resnet50 # network baseline
12 | pool_dim: 2048 # pooling dim: 2048 for resnet50
13 | per_add_iters: 1 # increments added to the GRL coefficient's iteration counter per training batch
14 | lambda_sk: 25 # hyperparameter for Sinkhorn-Knopp algorithm
15 |
16 | ## optimizer parameters
17 | optim: adam # optimizer: adam
18 | lr: 0.0035 # learning rate: 0.0035 for adam
19 |
20 | ## normal parameters
21 | file_name: otla-reid/ # log file name
22 | setting: semi-supervised # training setting: supervised or semi-supervised or unsupervised
23 | train_visible_image_path: ../../dataset/SYSU-MM01/spcl_uda_market1501TOsysumm01_rgb_train_rgb_resized_img.npy # path to stored visible images produced by USL-ReID or UDA-ReID methods (unsupervised setting only)
24 | train_visible_label_path: ../../dataset/SYSU-MM01/spcl_uda_market1501TOsysumm01_rgb_train_rgb_resized_label.npy # path to stored visible labels produced by USL-ReID or UDA-ReID methods (unsupervised setting only)
25 | seed: 0 # random seed
26 | gpu: 0 # gpu device ids for CUDA_VISIBLE_DEVICES
27 | model_path: save_model/ # model save path
28 | log_path: log/ # log save path
29 | vis_log_path: vis_log/ # tensorboard log save path
30 | save_epoch: 10 # save model every few epochs
31 | img_w: 144 # image width
32 | img_h: 288 # image height
33 | train_batch_size: 4 # training batch size: 4
34 | num_pos: 8 # number of pos per identity for each modality: 8
35 | test_batch_size: 64 # testing batch size
36 | start_epoch: 0 # start training epoch
37 | end_epoch: 81 # end training epoch
38 | eval_epoch: 1 # run evaluation every this many epochs
39 |
40 | ## loss parameters
41 | margin: 0.3 # triplet loss margin
42 | lambda_vr: 0.1 # coefficient of prediction alignment loss
43 | lambda_rv: 0.5 # coefficient of prediction alignment loss
--------------------------------------------------------------------------------
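The two YAML files above share a flat key layout, so they can be merged straight into an argparse-style namespace. A minimal loader sketch (the scripts' actual merging code may differ):

import yaml
from types import SimpleNamespace

with open('config/config_sysu.yaml') as f:
    args = SimpleNamespace(**yaml.safe_load(f))

print(args.dataset, args.lr, args.lambda_sk, args.setting)
# sysu 0.0035 25 semi-supervised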
/data_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import random
4 |
5 |
6 | def process_query_sysu(data_path, mode="all"):
7 |     # the infrared query set uses cam3 and cam6 under both "all" and "indoor"
8 |     ir_cameras = ["cam3", "cam6"]
11 |
12 | file_path = os.path.join(data_path, "exp/test_id.txt")
13 | files_ir = []
14 |
15 | with open(file_path, 'r') as file:
16 | ids = file.read().splitlines()
17 | ids = [int(y) for y in ids[0].split(',')]
18 | ids = ["%04d" % x for x in ids]
19 |
20 | for id in sorted(ids):
21 | for cam in ir_cameras:
22 | img_dir = os.path.join(data_path, cam, id)
23 | if os.path.isdir(img_dir):
24 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)])
25 | files_ir.extend(new_files)
26 |
27 | query_img = []
28 | query_id = []
29 | query_cam = []
30 | for img_path in files_ir:
31 | camid, pid = int(img_path[-15]), int(img_path[-13:-9])
32 | query_img.append(img_path)
33 | query_id.append(pid)
34 | query_cam.append(camid)
35 |
36 | return query_img, np.array(query_id), np.array(query_cam)
37 |
38 |
39 | def process_gallery_sysu(data_path, mode="all", trial=0):
40 | random.seed(trial)
41 |
42 | if mode == "all":
43 | rgb_cameras = ["cam1", "cam2", "cam4", "cam5"]
44 | elif mode == "indoor":
45 | rgb_cameras = ["cam1", "cam2"]
46 |
47 | file_path = os.path.join(data_path, "exp/test_id.txt")
48 | files_rgb = []
49 | with open(file_path, 'r') as file:
50 | ids = file.read().splitlines()
51 | ids = [int(y) for y in ids[0].split(',')]
52 | ids = ["%04d" % x for x in ids]
53 |
54 | for id in sorted(ids):
55 | for cam in rgb_cameras:
56 | img_dir = os.path.join(data_path, cam, id)
57 | if os.path.isdir(img_dir):
58 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)])
59 | files_rgb.append(random.choice(new_files))
60 |
61 | gall_img = []
62 | gall_id = []
63 | gall_cam = []
64 | for img_path in files_rgb:
65 | camid, pid = int(img_path[-15]), int(img_path[-13:-9])
66 | gall_img.append(img_path)
67 | gall_id.append(pid)
68 | gall_cam.append(camid)
69 |
70 | return gall_img, np.array(gall_id), np.array(gall_cam)
71 |
72 |
73 | def process_test_regdb(img_dir, trial=1, modality="visible"):
74 | if modality == "visible":
75 | input_data_path = os.path.join(img_dir, "idx/test_visible_{}".format(trial) + ".txt")
76 | elif modality == "thermal":
77 | input_data_path = os.path.join(img_dir, "idx/test_thermal_{}".format(trial) + ".txt")
78 |
79 |     with open(input_data_path, 'rt') as f:
80 |         data_file_list = f.read().splitlines()
81 | # Get full list of image and labels
82 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list]
83 | file_label = [int(s.split(' ')[1]) for s in data_file_list]
84 |
85 | return file_image, np.array(file_label)
--------------------------------------------------------------------------------
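The SYSU helpers above recover the camera and person ids from fixed character positions in the image path (layout .../camC/PPPP/NNNN.jpg). A quick check with a synthetic path:

# Sketch: the slicing used by process_query_sysu / process_gallery_sysu.
img_path = '/data/SYSU-MM01/cam3/0007/0001.jpg'
camid, pid = int(img_path[-15]), int(img_path[-13:-9])
assert (camid, pid) == (3, 7)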
/engine.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 | from utils import AverageMeter
6 | from eval_metrics import eval_regdb, eval_sysu
7 |
8 |
9 | def trainer(args, epoch, main_net, adjust_learning_rate, optimizer, trainloader, criterion, writer=None, print_freq=50):
10 | current_lr = adjust_learning_rate(args, optimizer, epoch)
11 |
12 | total_loss = AverageMeter()
13 | id_loss_rgb = AverageMeter()
14 | id_loss_ir = AverageMeter()
15 | tri_loss_rgb = AverageMeter()
16 | tri_loss_ir = AverageMeter()
17 | dis_loss = AverageMeter()
18 | pa_loss = AverageMeter()
19 | batch_time = AverageMeter()
20 |
21 | correct_tri_rgb = 0
22 | correct_tri_ir = 0
23 |     pre_rgb = 0  # meaningful only under the semi-supervised setting
24 |     pre_ir = 0  # meaningful only under the semi-supervised setting
25 |     pre_rgb_ir = 0  # meaningful only under the semi-supervised setting; counts RGB/IR label pairs that agree per batch
26 | num_rgb = 0
27 | num_ir = 0
28 |
29 | main_net.train() # switch to train mode
30 | end = time.time()
31 |
32 | for batch_id, (input_rgb, input_ir, label_rgb, label_ir) in enumerate(trainloader):
33 |         # label_ir is only used to measure the prediction accuracy of pseudo infrared labels under the semi-supervised setting
34 |         # label_ir is meaningless under the unsupervised setting
35 |         # under the supervised setting, "label_rgb" in "loss_id_ir" and "loss_tri_ir" is replaced by "label_ir"
36 |
37 | label_rgb = label_rgb.cuda()
38 | label_ir = label_ir.cuda()
39 | input_rgb = input_rgb.cuda()
40 | input_ir = input_ir.cuda()
41 |
42 | feat, output_cls, output_dis = main_net(input_rgb, input_ir, modal=0, train_set=True)
43 |
44 | loss_id_rgb = criterion[0](output_cls[:input_rgb.size(0)], label_rgb)
45 | loss_tri_rgb, correct_tri_batch_rgb = criterion[1](feat[:input_rgb.size(0)], label_rgb)
46 |
47 | if args.setting == "semi-supervised" or args.setting == "unsupervised":
48 | loss_id_ir = criterion[0](output_cls[input_rgb.size(0):], label_rgb)
49 | loss_tri_ir, correct_tri_batch_ir = criterion[1](feat[input_rgb.size(0):], label_rgb)
50 | elif args.setting == "supervised":
51 | loss_id_ir = criterion[0](output_cls[input_rgb.size(0):], label_ir)
52 | loss_tri_ir, correct_tri_batch_ir = criterion[1](feat[input_rgb.size(0):], label_ir)
53 |
54 | dis_label = torch.cat((torch.ones(input_rgb.size(0)), torch.zeros(input_ir.size(0))), dim=0).cuda()
55 | loss_dis = criterion[2](output_dis.view(-1), dis_label)
56 |
57 | loss_pa, sim_rgbtoir, sim_irtorgb = criterion[3](output_cls[:input_rgb.size(0)], output_cls[input_rgb.size(0):])
58 |
59 | loss = loss_id_rgb + loss_tri_rgb + 0.1 * loss_id_ir + 0.5 * loss_tri_ir + loss_dis + loss_pa
60 |
61 | optimizer.zero_grad()
62 | loss.backward()
63 | optimizer.step()
64 |
65 | correct_tri_rgb += correct_tri_batch_rgb
66 | correct_tri_ir += correct_tri_batch_ir
67 | _, pre_label = output_cls.max(1)
68 | pre_batch_rgb = (pre_label[:input_rgb.size(0)].eq(label_rgb).sum().item())
69 | pre_batch_ir = (pre_label[input_rgb.size(0):].eq(label_ir).sum().item())
70 | pre_batch_rgb_ir = (label_rgb.eq(label_ir).sum().item())
71 | pre_rgb += pre_batch_rgb
72 | pre_ir += pre_batch_ir
73 | pre_rgb_ir += pre_batch_rgb_ir
74 | num_rgb += input_rgb.size(0)
75 | num_ir += input_ir.size(0)
76 | assert num_rgb == num_ir
77 |
78 | total_loss.update(loss.item(), input_rgb.size(0) + input_ir.size(0))
79 | id_loss_rgb.update(loss_id_rgb.item(), input_rgb.size(0))
80 | id_loss_ir.update(loss_id_ir.item(), input_ir.size(0))
81 |         tri_loss_rgb.update(loss_tri_rgb.item(), input_rgb.size(0))  # .item() avoids storing CUDA tensors in the meters
82 |         tri_loss_ir.update(loss_tri_ir.item(), input_ir.size(0))
83 |         dis_loss.update(loss_dis.item(), input_rgb.size(0) + input_ir.size(0))
84 | pa_loss.update(loss_pa.item(), input_rgb.size(0) + input_ir.size(0))
85 |
86 | # measure elapsed time
87 | batch_time.update(time.time() - end)
88 | end = time.time()
89 |
90 | if batch_id % print_freq == 0:
91 | print("Epoch: [{}][{}/{}] "
92 | "Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) "
93 | "Lr: {:.6f} "
94 | "Coeff: {:.3f} "
95 | "Total_Loss: {total_loss.val:.4f}({total_loss.avg:.4f}) "
96 | "ID_Loss_RGB: {id_loss_rgb.val:.4f}({id_loss_rgb.avg:.4f}) "
97 | "ID_Loss_IR: {id_loss_ir.val:.4f}({id_loss_ir.avg:.4f}) "
98 | "Tri_Loss_RGB: {tri_loss_rgb.val:.4f}({tri_loss_rgb.avg:.4f}) "
99 | "Tri_Loss_IR: {tri_loss_ir.val:.4f}({tri_loss_ir.avg:.4f}) "
100 | "Dis_Loss: {dis_loss.val:.4f}({dis_loss.avg:.4f}) "
101 | "Pa_Loss: {pa_loss.val:.4f}({pa_loss.avg:.4f}) "
102 | "Tri_RGB_Acc: {:.2f}% "
103 | "Tri_IR_Acc: {:.2f}% "
104 | "Pre_RGB_Acc: {:.2f}% "
105 | "Pre_IR_Acc: {:.2f}% "
106 | "Pre_RGB_IR_Acc: {:.2f}% ".format(epoch, batch_id, len(trainloader), current_lr, main_net.adnet.coeff,
107 | 100. * correct_tri_rgb / num_rgb,
108 | 100. * correct_tri_ir / num_ir,
109 | 100. * pre_rgb / num_rgb,
110 | 100. * pre_ir / num_ir,
111 | 100. * pre_rgb_ir / num_rgb,
112 | batch_time=batch_time,
113 | total_loss=total_loss,
114 | id_loss_rgb=id_loss_rgb,
115 | id_loss_ir=id_loss_ir,
116 | tri_loss_rgb=tri_loss_rgb,
117 | tri_loss_ir=tri_loss_ir,
118 | dis_loss=dis_loss,
119 | pa_loss=pa_loss))
120 |
121 | if writer is not None:
122 | writer.add_scalar("Lr", current_lr, epoch)
123 | writer.add_scalar("Coeff", main_net.adnet.coeff, epoch)
124 | writer.add_scalar("Total_Loss", total_loss.avg, epoch)
125 | writer.add_scalar("ID_Loss_RGB", id_loss_rgb.avg, epoch)
126 | writer.add_scalar("ID_Loss_IR", id_loss_ir.avg, epoch)
127 | writer.add_scalar("Tri_Loss_RGB", tri_loss_rgb.avg, epoch)
128 | writer.add_scalar("Tri_Loss_IR", tri_loss_ir.avg, epoch)
129 | writer.add_scalar("Dis_Loss", dis_loss.avg, epoch)
130 | writer.add_scalar("Pa_Loss", pa_loss.avg, epoch)
131 | writer.add_scalar("Tri_RGB_Acc", 100. * correct_tri_rgb / num_rgb, epoch)
132 | writer.add_scalar("Tri_IR_Acc", 100. * correct_tri_ir / num_ir, epoch)
133 | writer.add_scalar("Pre_RGB_Acc", 100. * pre_rgb / num_rgb, epoch)
134 | writer.add_scalar("Pre_IR_Acc", 100. * pre_ir / num_ir, epoch)
135 |
136 |
137 | def tester(args, epoch, main_net, test_mode, gall_label, gall_loader, query_label, query_loader, feat_dim=2048, query_cam=None, gall_cam=None, writer=None):
138 | # switch to evaluation mode
139 | main_net.eval()
140 |
141 | print("Extracting Gallery Feature...")
142 | ngall = len(gall_label)
143 | start = time.time()
144 | ptr = 0
145 | gall_feat = np.zeros((ngall, feat_dim))
146 | with torch.no_grad():
147 | for batch_idx, (input, label) in enumerate(gall_loader):
148 | batch_num = input.size(0)
149 |                 input = input.cuda()  # Variable wrapper is deprecated and unnecessary
150 | feat = main_net(input, input, modal=test_mode[0])
151 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
152 | ptr = ptr + batch_num
153 | print("Extracting Time:\t {:.3f}".format(time.time() - start))
154 |
155 | print("Extracting Query Feature...")
156 | nquery = len(query_label)
157 | start = time.time()
158 | ptr = 0
159 | query_feat = np.zeros((nquery, feat_dim))
160 | with torch.no_grad():
161 | for batch_idx, (input, label) in enumerate(query_loader):
162 | batch_num = input.size(0)
163 |                 input = input.cuda()
164 | feat = main_net(input, input, modal=test_mode[1])
165 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
166 | ptr = ptr + batch_num
167 | print("Extracting Time:\t {:.3f}".format(time.time() - start))
168 |
169 | start = time.time()
170 | # compute the similarity
171 | distmat = -np.matmul(query_feat, np.transpose(gall_feat))
172 | # evaluation
173 | if args.dataset == "sysu":
174 | cmc, mAP, mINP = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam)
175 | elif args.dataset == "regdb":
176 | cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label)
177 | print("Evaluation Time:\t {:.3f}".format(time.time() - start))
178 |
179 | if writer is not None:
180 | writer.add_scalar("Rank1", cmc[0], epoch)
181 | writer.add_scalar("mAP", mAP, epoch)
182 | writer.add_scalar("mINP", mINP, epoch)
183 |
184 | return cmc, mAP, mINP
--------------------------------------------------------------------------------
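The tester above ranks gallery items by the negative inner product of the extracted features; when the network outputs L2-normalized features, this ordering coincides with cosine similarity. A small numpy sketch of that ranking step:

import numpy as np

rng = np.random.default_rng(0)
query_feat = rng.normal(size=(3, 2048))
gall_feat = rng.normal(size=(5, 2048))
query_feat /= np.linalg.norm(query_feat, axis=1, keepdims=True)
gall_feat /= np.linalg.norm(gall_feat, axis=1, keepdims=True)

distmat = -np.matmul(query_feat, np.transpose(gall_feat))  # smaller = closer
ranks = np.argsort(distmat, axis=1)  # per-query gallery indices, best first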
/eval_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20):
5 | """
6 | Evaluation with SYSU-MM01 metric.
7 | Note: For each query identity, its gallery images from the same camera view are discarded,
8 |     which follows the original setting in "RGB-Infrared Cross-Modality Person Re-Identification, ICCV 2017".
9 | """
10 | num_q, num_g = distmat.shape
11 | if num_g < max_rank:
12 | max_rank = num_g
13 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
14 | indices = np.argsort(distmat, axis=1)
15 | pred_label = g_pids[indices]
16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
17 |
18 | # compute cmc curve for each query
19 | new_all_cmc = []
20 | all_cmc = []
21 | all_AP = []
22 | all_INP = []
23 | num_valid_q = 0. # number of valid query
24 | for q_idx in range(num_q):
25 | # get query pid and camid
26 | q_pid = q_pids[q_idx]
27 | q_camid = q_camids[q_idx]
28 |
29 | # SYSU-MM01 protocol: cameras 2 (RGB) and 3 (IR) share the same indoor location, so discard cam-2 gallery samples for cam-3 queries
30 | order = indices[q_idx]
31 | remove = (q_camid == 3) & (g_camids[order] == 2)
32 | keep = np.invert(remove)
33 |
34 | # compute cmc curve
35 | # the cmc calculation is different from standard protocol
36 | # we follow the protocol of the author's released code
37 | new_cmc = pred_label[q_idx][keep]
38 | new_index = np.unique(new_cmc, return_index=True)[1]
39 | new_cmc = [new_cmc[index] for index in sorted(new_index)]
40 |
41 | new_match = (new_cmc == q_pid).astype(np.int32)
42 | new_cmc = new_match.cumsum()
43 | new_all_cmc.append(new_cmc[:max_rank])
44 |
45 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
46 | if not np.any(orig_cmc):
47 | # this condition is true when query identity does not appear in gallery
48 | continue
49 |
50 | cmc = orig_cmc.cumsum()
51 |
52 | # compute mINP
53 | # reference: "Deep Learning for Person Re-identification: A Survey and Outlook"
54 | pos_idx = np.where(orig_cmc == 1)
55 | pos_max_idx = np.max(pos_idx)
56 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0)
57 | all_INP.append(inp)
58 |
59 | cmc[cmc > 1] = 1
60 |
61 | all_cmc.append(cmc[:max_rank])
62 | num_valid_q += 1.
63 |
64 | # compute average precision
65 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
66 | num_rel = orig_cmc.sum()
67 | tmp_cmc = orig_cmc.cumsum()
68 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
69 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
70 | AP = tmp_cmc.sum() / num_rel
71 | all_AP.append(AP)
72 |
73 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
74 |
75 | all_cmc = np.asarray(all_cmc).astype(np.float32)
76 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC
77 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
78 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q
79 | mAP = np.mean(all_AP)
80 | mINP = np.mean(all_INP)
81 |
82 | return new_all_cmc, mAP, mINP
83 |
84 |
85 | def eval_regdb(distmat, q_pids, g_pids, max_rank=20):
86 | """
87 | Evaluation with RegDB metric.
88 | """
89 | num_q, num_g = distmat.shape
90 | if num_g < max_rank:
91 | max_rank = num_g
92 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
93 | indices = np.argsort(distmat, axis=1)
94 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
95 |
96 | # compute cmc curve for each query
97 | all_cmc = []
98 | all_AP = []
99 | all_INP = []
100 | num_valid_q = 0. # number of valid query
101 |
102 | # only two cameras
103 | q_camids = np.ones(num_q).astype(np.int32)
104 | g_camids = 2 * np.ones(num_g).astype(np.int32)
105 |
106 | for q_idx in range(num_q):
107 | # get query pid and camid
108 | q_pid = q_pids[q_idx]
109 | q_camid = q_camids[q_idx]
110 |
111 | # remove gallery samples that have the same pid and camid with query
112 | order = indices[q_idx]
113 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
114 | keep = np.invert(remove)
115 |
116 | # compute cmc curve
117 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
118 | if not np.any(raw_cmc):
119 | # this condition is true when query identity does not appear in gallery
120 | continue
121 |
122 | cmc = raw_cmc.cumsum()
123 |
124 | # compute mINP
125 | # reference: "Deep Learning for Person Re-identification: A Survey and Outlook"
126 | pos_idx = np.where(raw_cmc == 1)
127 | pos_max_idx = np.max(pos_idx)
128 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0)
129 | all_INP.append(inp)
130 |
131 | cmc[cmc > 1] = 1
132 |
133 | all_cmc.append(cmc[:max_rank])
134 | num_valid_q += 1.
135 |
136 | # compute average precision
137 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
138 | num_rel = raw_cmc.sum()
139 | tmp_cmc = raw_cmc.cumsum()
140 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
141 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
142 | AP = tmp_cmc.sum() / num_rel
143 | all_AP.append(AP)
144 |
145 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
146 |
147 | all_cmc = np.asarray(all_cmc).astype(np.float32)
148 | all_cmc = all_cmc.sum(0) / num_valid_q
149 | mAP = np.mean(all_AP)
150 | mINP = np.mean(all_INP)
151 |
152 | return all_cmc, mAP, mINP
--------------------------------------------------------------------------------
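A toy sanity check for eval_regdb, with made-up values (two queries, three gallery items, smaller distance = more similar); it assumes the repo root is on PYTHONPATH:

import numpy as np
from eval_metrics import eval_regdb

distmat = np.array([[0.1, 0.9, 0.8],   # query 0: its true match (pid 0) ranks first
                    [0.7, 0.2, 0.6]])  # query 1: its true match (pid 1) ranks first
q_pids = np.array([0, 1])
g_pids = np.array([0, 1, 0])

cmc, mAP, mINP = eval_regdb(distmat, q_pids, g_pids)
print(cmc[0], mAP, mINP)  # 1.0 1.0 1.0: both queries hit at rank 1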
/image/main_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/image/main_figure.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def normalize(x, axis=-1):
6 | """
7 | Normalizing to unit length along the specified dimension.
8 | """
9 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
10 | return x
11 |
12 |
13 | class TripletLoss(nn.Module):
14 | """
15 | Triplet loss with hard positive/negative mining.
16 | Reference: Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
17 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
18 | Args:
19 | - margin (float): margin for triplet.
20 | - inputs: feature matrix with shape (batch_size, feat_dim).
21 | - targets: ground truth labels with shape (batch_size).
22 | """
23 | def __init__(self, margin=0.3):
24 | super(TripletLoss, self).__init__()
25 | self.margin = margin
26 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
27 |
28 | def forward(self, inputs, targets):
29 | n = inputs.size(0)
30 |
31 | # Compute the pairwise Euclidean distance matrix
32 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
33 | dist = dist + dist.t()
34 | dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)  # modern addmm_ signature; the old positional form is deprecated
35 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
36 |
37 | # For each anchor, find the hardest positive and negative
38 | mask = targets.expand(n, n).eq(targets.expand(n, n).t())
39 | dist_ap, dist_an = [], []
40 | for i in range(n):
41 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
42 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
43 | dist_ap = torch.cat(dist_ap)
44 | dist_an = torch.cat(dist_an)
45 |
46 | # Compute ranking hinge loss
47 | y = torch.ones_like(dist_an)
48 | loss = self.ranking_loss(dist_an, dist_ap, y)
49 |
50 | # compute accuracy
51 | correct = torch.ge(dist_an, dist_ap).sum().item()  # torch.ge: greater than or equal to (>=)
52 |
53 | return loss, correct
54 |
55 |
56 | class PredictionAlignmentLoss(nn.Module):
57 | """
58 | Proposed loss for Prediction Alignment Learning (PAL).
59 | """
60 | def __init__(self, lambda_vr=0.1, lambda_rv=0.5):
61 | super(PredictionAlignmentLoss, self).__init__()
62 | self.lambda_vr = lambda_vr
63 | self.lambda_rv = lambda_rv
64 |
65 | def forward(self, x_rgb, x_ir):
66 | sim_rgbtoir = torch.mm(normalize(x_rgb), normalize(x_ir).t())
67 | sim_irtorgb = torch.mm(normalize(x_ir), normalize(x_rgb).t())
68 | sim_irtoir = torch.mm(normalize(x_ir), normalize(x_ir).t())
69 |
70 | sim_rgbtoir = nn.Softmax(1)(sim_rgbtoir)
71 | sim_irtorgb = nn.Softmax(1)(sim_irtorgb)
72 | sim_irtoir = nn.Softmax(1)(sim_irtoir)
73 |
74 | KL_criterion = nn.KLDivLoss(reduction="batchmean")
75 |
76 | x_rgbtoir = torch.mm(sim_rgbtoir, x_ir)
77 | x_irtorgb = torch.mm(sim_irtorgb, x_rgb)
78 | x_irtoir = torch.mm(sim_irtoir, x_ir)
79 |
80 | x_rgb_s = nn.Softmax(1)(x_rgb)
81 | x_rgbtoir_ls = nn.LogSoftmax(1)(x_rgbtoir)
82 | x_irtorgb_s = nn.Softmax(1)(x_irtorgb)
83 | x_irtoir_ls = nn.LogSoftmax(1)(x_irtoir)
84 |
85 | loss_rgbtoir = KL_criterion(x_rgbtoir_ls, x_rgb_s)
86 | loss_irtorgb = KL_criterion(x_irtoir_ls, x_irtorgb_s)
87 |
88 | loss = self.lambda_vr * loss_rgbtoir + self.lambda_rv * loss_irtorgb
89 |
90 | return loss, sim_rgbtoir, sim_irtorgb
--------------------------------------------------------------------------------
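A minimal usage sketch for the two losses above, with toy tensors and hypothetical sizes (it assumes the repo root is on PYTHONPATH and a CPU-only run is fine):

import torch
from loss import TripletLoss, PredictionAlignmentLoss

# TripletLoss: four embeddings, two identities, so every anchor has a positive and a negative
feats = torch.randn(4, 2048, requires_grad=True)
labels = torch.tensor([0, 0, 1, 1])
tri_loss, correct = TripletLoss(margin=0.3)(feats, labels)
tri_loss.backward()
print(tri_loss.item(), correct)  # correct counts anchors whose hardest negative is >= their hardest positive

# PredictionAlignmentLoss: RGB and IR class-logit batches of hypothetical shape (batch, 206 classes)
pal = PredictionAlignmentLoss(lambda_vr=0.1, lambda_rv=0.5)
pal_loss, sim_rgbtoir, sim_irtorgb = pal(torch.randn(8, 206), torch.randn(16, 206))
print(pal_loss.item(), sim_rgbtoir.shape)  # sim_rgbtoir: torch.Size([8, 16])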
/main_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import easydict
3 | import sys
4 | import os
5 | import time
6 | import yaml
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.data as data
11 | import torchvision.transforms as transforms
12 | from utils import Logger, set_seed, GenIdx
13 | from data_loader import TestData
14 | from data_manager import process_query_sysu, process_gallery_sysu, process_test_regdb
15 | from model.network import BaseResNet
16 | from engine import tester
17 |
18 |
19 | def main_worker(args, args_main):
20 | ## set gpu id and seed id
21 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
22 | torch.backends.cudnn.benchmark = True  # let cuDNN pick the fastest convolution algorithms for fixed-size inputs
23 | device = "cuda" if torch.cuda.is_available() else "cpu"
24 | set_seed(args.seed, cuda=torch.cuda.is_available())
25 |
26 | ## set file
27 | if not os.path.isdir(args.dataset + "_" + args.setting + "_" + args.file_name):
28 | os.makedirs(args.dataset + "_" + args.setting + "_" + args.file_name)
29 | file_name = args.dataset + "_" + args.setting + "_" + args.file_name
30 |
31 | if args.dataset == "sysu":
32 | data_path = args.dataset_path + "SYSU-MM01/"
33 | log_path = os.path.join(file_name, args.dataset + "_" + args.log_path)
34 | test_mode = [1, 2]
35 | elif args.dataset == "regdb":
36 | data_path = args.dataset_path + "RegDB/"
37 | log_path = os.path.join(file_name, args.dataset + "_" + args.log_path)
38 | if args.mode == "thermaltovisible":
39 | test_mode = [1, 2]
40 | elif args.mode == "visibletothermal":
41 | test_mode = [2, 1]
42 |
43 | if not os.path.isdir(log_path):
44 | os.makedirs(log_path)
45 |
46 | sys.stdout = Logger(os.path.join(log_path, "log_test.txt"))
47 |
48 | ## load data
49 | print("==========\nargs_main:{}\n==========".format(args_main))
50 | print("==========\nargs:{}\n==========".format(args))
51 | print("==> Loading data...")
52 |
53 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
54 | transform_test = transforms.Compose([
55 | transforms.ToPILImage(),
56 | transforms.Resize((args.img_h, args.img_w)),
57 | transforms.ToTensor(),
58 | normalize,
59 | ])
60 |
61 | end = time.time()
62 | if args.dataset == "sysu":
63 | # testing set
64 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode)
65 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode)
66 | elif args.dataset == "regdb":
67 | # testing set
68 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modality=args.mode.split("to")[0])
69 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modality=args.mode.split("to")[1])
70 |
71 | gallset = TestData(gall_img, gall_label, transform_test=transform_test, img_size=(args.img_w, args.img_h))
72 | queryset = TestData(query_img, query_label, transform_test=transform_test, img_size=(args.img_w, args.img_h))
73 |
74 | # testing data loader
75 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers)
76 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers)
77 |
78 | print("Dataset {} Statistics:".format(args.dataset))
79 | print(" ----------------------------")
80 | print(" subset | # ids | # images")
81 | print(" ----------------------------")
82 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), len(query_label)))
83 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), len(gall_label)))
84 | print(" ----------------------------")
85 | print("Data loading time:\t {:.3f}".format(time.time() - end))
86 |
87 | if args.dataset == "sysu":
88 | n_class = 395 # initial value
89 | elif args.dataset == "regdb":
90 | n_class = 206 # initial value
91 | else:
92 | n_class = 1000 # initial value
93 | epoch = 0 # initial value
94 |
95 | ## resume checkpoints
96 | if args_main.resume:
97 | resume_path = args_main.resume_path
98 | if os.path.isfile(resume_path):
99 | checkpoint = torch.load(resume_path)
100 | if "main_net" in checkpoint.keys():
101 | n_class = checkpoint["main_net"]["classifier.weight"].size(0)
102 | elif "net" in checkpoint.keys():
103 | n_class = checkpoint["net"]["classifier.weight"].size(0)
104 | epoch = checkpoint["epoch"]
105 | print("==> Loading checkpoint {} (epoch {}, number of classes {})".format(resume_path, epoch, n_class))
106 | else:
107 | print("==> No checkpoint is found at {} (epoch {}, number of classes {})".format(resume_path, epoch, n_class))
108 | else:
109 | print("==> No checkpont is loaded (epoch {}, number of classes {})".format(epoch, n_class))
110 |
111 | ## build model
112 | main_net = BaseResNet(pool_dim=args.pool_dim, class_num=n_class, per_add_iters=args.per_add_iters, arch=args.arch)
113 | if args_main.resume and os.path.isfile(resume_path):
114 | if "main_net" in checkpoint.keys():
115 | main_net.load_state_dict(checkpoint["main_net"])
116 | elif "net" in checkpoint.keys():
117 | main_net.load_state_dict(checkpoint["net"])
118 | main_net.to(device)
119 |
120 | # start testing
121 | if args.dataset == "sysu":
122 | print("Testing Epoch: {}, Testing mode: {}".format(epoch, args.mode))
123 | elif args.dataset == "regdb":
124 | print("Testing Epoch: {}, Testing mode: {}, Trial: {}".format(epoch, args.mode, args.trial))
125 |
126 | end = time.time()
127 | if args.dataset == "sysu":
128 | cmc, mAP, mINP = tester(args, epoch, main_net, test_mode, gall_label, gall_loader, query_label, query_loader, feat_dim=args.pool_dim, query_cam=query_cam, gall_cam=gall_cam)
129 | elif args.dataset == "regdb":
130 | cmc, mAP, mINP = tester(args, epoch, main_net, test_mode, gall_label, gall_loader, query_label, query_loader, feat_dim=args.pool_dim)
131 | print("Testing time per epoch: {:.3f}".format(time.time() - end))
132 |
133 | print("Performance: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}".format(cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
134 |
135 |
136 | if __name__ == "__main__":
137 | parser = argparse.ArgumentParser(description="OTLA-ReID for testing")
138 | parser.add_argument("--config", default="config/baseline.yaml", help="config file")
139 | parser.add_argument("--resume", action="store_true", help="resume from checkpoint")
140 | parser.add_argument("--resume_path", default="", help="checkpoint path")
141 |
142 | args_main = parser.parse_args()
143 | args = yaml.load(open(args_main.config), Loader=yaml.FullLoader)
144 | args = easydict.EasyDict(args)
145 |
146 | main_worker(args, args_main)
--------------------------------------------------------------------------------
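main_test.py reads all hyperparameters from a YAML file and wraps them in an EasyDict. A small sketch of that pattern with hypothetical keys and values (the real keys live in config/config_sysu.yaml and config/config_regdb.yaml):

import yaml
import easydict

raw = yaml.load("dataset: sysu\nmode: all\nimg_h: 288\nimg_w: 144\n", Loader=yaml.FullLoader)
args = easydict.EasyDict(raw)
print(args.dataset, args.img_h)  # attribute access instead of raw["dataset"]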
/model/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152']
7 |
8 | model_urls = {
9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14 | }
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
18 | """3x3 convolution with padding"""
19 | # original padding is 1; original dilation is 1
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=dilation, bias=False, dilation=dilation)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
27 | super(BasicBlock, self).__init__()
28 | self.conv1 = conv3x3(inplanes, planes, stride, dilation)
29 | self.bn1 = nn.BatchNorm2d(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.bn2 = nn.BatchNorm2d(planes)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | residual = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 |
46 | if self.downsample is not None:
47 | residual = self.downsample(x)
48 |
49 | out += residual
50 | out = self.relu(out)
51 |
52 | return out
53 |
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 4
57 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
58 | super(Bottleneck, self).__init__()
59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
60 | self.bn1 = nn.BatchNorm2d(planes)
61 | # original padding is 1; original dilation is 1
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
63 | self.bn2 = nn.BatchNorm2d(planes)
64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
65 | self.bn3 = nn.BatchNorm2d(planes * 4)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out = self.conv1(x)
74 | out = self.bn1(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv2(out)
78 | out = self.bn2(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv3(out)
82 | out = self.bn3(out)
83 |
84 | if self.downsample is not None:
85 | residual = self.downsample(x)
86 |
87 | out += residual
88 | out = self.relu(out)
89 |
90 | return out
91 |
92 |
93 | class ResNet(nn.Module):
94 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1):
95 | self.inplanes = 64
96 | super(ResNet, self).__init__()
97 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
98 | self.bn1 = nn.BatchNorm2d(64)
99 | self.relu = nn.ReLU(inplace=True)
100 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
101 | self.layer1 = self._make_layer(block, 64, layers[0])
102 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
103 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
104 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation)
105 |
106 | for m in self.modules():
107 | if isinstance(m, nn.Conv2d):
108 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
109 | m.weight.data.normal_(0, math.sqrt(2. / n))
110 | elif isinstance(m, nn.BatchNorm2d):
111 | m.weight.data.fill_(1)
112 | m.bias.data.zero_()
113 |
114 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
115 | downsample = None
116 | if stride != 1 or self.inplanes != planes * block.expansion:
117 | downsample = nn.Sequential(
118 | nn.Conv2d(self.inplanes, planes * block.expansion,
119 | kernel_size=1, stride=stride, bias=False),
120 | nn.BatchNorm2d(planes * block.expansion),
121 | )
122 |
123 | layers = []
124 | layers.append(block(self.inplanes, planes, stride, downsample, dilation))
125 | self.inplanes = planes * block.expansion
126 | for i in range(1, blocks):
127 | layers.append(block(self.inplanes, planes))
128 |
129 | return nn.Sequential(*layers)
130 |
131 | def forward(self, x):
132 | x = self.conv1(x)
133 | x = self.bn1(x)
134 | x = self.relu(x)
135 | x = self.maxpool(x)
136 |
137 | x = self.layer1(x)
138 | x = self.layer2(x)
139 | x = self.layer3(x)
140 | x = self.layer4(x)
141 |
142 | return x
143 |
144 |
145 | def remove_fc(state_dict):
146 | """Remove the fc layer parameters from state_dict."""
147 | # for key, value in state_dict.items():
148 | for key, value in list(state_dict.items()):
149 | if key.startswith('fc.'):
150 | del state_dict[key]
151 |
152 | return state_dict
153 |
154 |
155 | def resnet18(pretrained=False, **kwargs):
156 | """Constructs a ResNet-18 model.
157 | Args:
158 | pretrained (bool): If True, returns a model pre-trained on ImageNet
159 | """
160 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
161 | if pretrained:
162 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18'])))
163 |
164 | return model
165 |
166 |
167 | def resnet34(pretrained=False, **kwargs):
168 | """Constructs a ResNet-34 model.
169 | Args:
170 | pretrained (bool): If True, returns a model pre-trained on ImageNet
171 | """
172 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
173 | if pretrained:
174 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34'])))
175 |
176 | return model
177 |
178 |
179 | def resnet50(pretrained=False, **kwargs):
180 | """Constructs a ResNet-50 model.
181 | Args:
182 | pretrained (bool): If True, returns a model pre-trained on ImageNet
183 | """
184 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
185 | if pretrained:
186 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50'])))
187 |
188 | return model
189 |
190 |
191 | def resnet101(pretrained=False, **kwargs):
192 | """Constructs a ResNet-101 model.
193 | Args:
194 | pretrained (bool): If True, returns a model pre-trained on ImageNet
195 | """
196 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
197 | if pretrained:
198 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet101'])))
199 |
200 | return model
201 |
202 |
203 | def resnet152(pretrained=False, **kwargs):
204 | """Constructs a ResNet-152 model.
205 | Args:
206 | pretrained (bool): If True, returns a model pre-trained on ImageNet
207 | """
208 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
209 | if pretrained:
210 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet152'])))
211 |
212 | return model
--------------------------------------------------------------------------------
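A quick shape check (random weights, toy input) of the last_conv_stride option above: stride 1 in layer4 doubles the spatial resolution of the final feature map, a common re-ID trick, and it is what BaseResNet requests. The 288x144 input size is an assumption, not read from the repo's configs:

import torch
from model.backbone.resnet import resnet50

x = torch.randn(1, 3, 288, 144)
print(resnet50(last_conv_stride=2)(x).shape)  # torch.Size([1, 2048, 9, 5])
print(resnet50(last_conv_stride=1)(x).shape)  # torch.Size([1, 2048, 18, 9])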
/model/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | from .backbone.resnet import resnet50
6 |
7 |
8 | class Normalize(nn.Module):
9 | def __init__(self, power=2):
10 | super(Normalize, self).__init__()
11 | self.power = power
12 |
13 | def forward(self, x):
14 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
15 | out = x.div(norm)
16 | return out
17 |
18 |
19 | def weights_init_kaiming(m):
20 | classname = m.__class__.__name__
21 | if classname.find("Conv") != -1:
22 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
23 | elif classname.find("Linear") != -1:
24 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_out")
25 | init.zeros_(m.bias.data)
26 | elif classname.find("BatchNorm1d") != -1:
27 | init.normal_(m.weight.data, 1.0, 0.01)
28 | init.zeros_(m.bias.data)
29 |
30 |
31 | def weights_init_classifier(m):
32 | classname = m.__class__.__name__
33 | if classname.find("Linear") != -1:
34 | init.normal_(m.weight.data, 0, 0.001)
35 | if m.bias is not None:
36 | init.zeros_(m.bias.data)
37 |
38 |
39 | class GradientReverseLayer(torch.autograd.Function):
40 | @staticmethod
41 | def forward(ctx, coeff, input):
42 | ctx.coeff = coeff
43 | # returning input.view_as(input) is necessary; returning "input" directly can skip calling "backward"
44 | return input.view_as(input)
45 |
46 | @staticmethod
47 | def backward(ctx, grad_outputs):
48 | coeff = ctx.coeff
49 | return None, -coeff * grad_outputs
50 |
51 |
52 | class AdversarialLayer(nn.Module):
53 | def __init__(self, per_add_iters, iter_num=0, alpha=10.0, low_value=0.0, high_value=1.0, max_iter=10000.0):
54 | super(AdversarialLayer, self).__init__()
55 | self.per_add_iters = per_add_iters
56 | self.iter_num = iter_num
57 | self.alpha = alpha
58 | self.low_value = low_value
59 | self.high_value = high_value
60 | self.max_iter = max_iter
61 | self.grl = GradientReverseLayer.apply
62 |
63 | def forward(self, input, train_set=True):
64 | if train_set:
65 | self.iter_num += self.per_add_iters
66 | self.coeff = float(  # np.float was removed in NumPy 1.24; the builtin is equivalent
67 | 2.0 * (self.high_value - self.low_value) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iter)) - (
68 | self.high_value - self.low_value) + self.low_value)
69 |
70 | return self.grl(self.coeff, input)
71 |
72 |
73 | class DiscriminateNet(nn.Module):
74 | def __init__(self, input_dim, class_num=1):
75 | super(DiscriminateNet, self).__init__()
76 | self.ad_layer1 = nn.Linear(input_dim, input_dim//2)
77 | self.ad_layer2 = nn.Linear(input_dim//2, input_dim//2)
78 | self.ad_layer3 = nn.Linear(input_dim//2, class_num)
79 | self.relu1 = nn.ReLU()
80 | self.relu2 = nn.ReLU()
81 | self.dropout1 = nn.Dropout(0.5)
82 | self.dropout2 = nn.Dropout(0.5)
83 | self.bn = nn.BatchNorm1d(class_num)
84 | self.bn2 = nn.BatchNorm1d(input_dim // 2)
85 | self.bn.bias.requires_grad_(False)
86 | self.bn2.bias.requires_grad_(False)
87 | self.sigmoid = nn.Sigmoid()
88 |
89 | self.ad_layer1.apply(weights_init_kaiming)
90 | self.ad_layer2.apply(weights_init_kaiming)
91 | self.ad_layer3.apply(weights_init_classifier)
92 |
93 | def forward(self, x):
94 | x = self.ad_layer1(x)
95 | x = self.relu1(x)
96 | x = self.dropout1(x)
97 | x = self.ad_layer2(x)
98 | x = self.relu2(x)
99 | x = self.dropout2(x)
100 | x = self.ad_layer3(x)
101 | x = self.bn(x)
102 | x = self.sigmoid(x) # binary classification
103 |
104 | return x
105 |
106 |
107 | class BaseResNet(nn.Module):
108 | def __init__(self, pool_dim, class_num, per_add_iters, arch="resnet50"):
109 | super(BaseResNet, self).__init__()
110 |
111 | if arch == "resnet50":
112 | network = resnet50(pretrained=True, last_conv_stride=1, last_conv_dilation=1)
113 |
114 | self.layer0 = nn.Sequential(network.conv1,
115 | network.bn1,
116 | network.relu,
117 | network.maxpool)
118 | self.layer1 = network.layer1
119 | self.layer2 = network.layer2
120 | self.layer3 = network.layer3
121 | self.layer4 = network.layer4
122 |
123 | self.bottleneck_0 = nn.BatchNorm1d(64)
124 | self.bottleneck_0.bias.requires_grad_(False) # no shift
125 | self.bottleneck_1 = nn.BatchNorm1d(256)
126 | self.bottleneck_1.bias.requires_grad_(False) # no shift
127 | self.bottleneck_2 = nn.BatchNorm1d(512)
128 | self.bottleneck_2.bias.requires_grad_(False) # no shift
129 | self.bottleneck_3 = nn.BatchNorm1d(1024)
130 | self.bottleneck_3.bias.requires_grad_(False) # no shift
131 | self.bottleneck = nn.BatchNorm1d(pool_dim)
132 | self.bottleneck.bias.requires_grad_(False) # no shift
133 |
134 | self.classifier = nn.Linear(pool_dim, class_num, bias=False)
135 | self.adnet = AdversarialLayer(per_add_iters=per_add_iters)
136 | self.disnet = DiscriminateNet(64 + 256 + 512 + 1024 + pool_dim, 1)
137 |
138 | self.bottleneck_0.apply(weights_init_kaiming)
139 | self.bottleneck_1.apply(weights_init_kaiming)
140 | self.bottleneck_2.apply(weights_init_kaiming)
141 | self.bottleneck_3.apply(weights_init_kaiming)
142 | self.bottleneck.apply(weights_init_kaiming)
143 | self.classifier.apply(weights_init_classifier)
144 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
145 |
146 | self.l2norm = Normalize(2)
147 |
148 | def forward(self, x_rgb, x_ir, modal=0, train_set=True):
149 | if modal == 0:
150 | x = torch.cat((x_rgb, x_ir), dim=0)
151 | elif modal == 1:
152 | x = x_rgb
153 | elif modal == 2:
154 | x = x_ir
155 |
156 | x_0 = self.layer0(x)
157 | x_1 = self.layer1(x_0)
158 | x_2 = self.layer2(x_1)
159 | x_3 = self.layer3(x_2)
160 | x_4 = self.layer4(x_3)
161 |
162 | x_pool_0 = self.avgpool(x_0)
163 | x_pool_0 = x_pool_0.view(x_pool_0.size(0), x_pool_0.size(1))
164 | x_pool_1 = self.avgpool(x_1)
165 | x_pool_1 = x_pool_1.view(x_pool_1.size(0), x_pool_1.size(1))
166 | x_pool_2 = self.avgpool(x_2)
167 | x_pool_2 = x_pool_2.view(x_pool_2.size(0), x_pool_2.size(1))
168 | x_pool_3 = self.avgpool(x_3)
169 | x_pool_3 = x_pool_3.view(x_pool_3.size(0), x_pool_3.size(1))
170 | x_pool_4 = self.avgpool(x_4)
171 | x_pool_4 = x_pool_4.view(x_pool_4.size(0), x_pool_4.size(1))
172 |
173 | feat_0 = self.bottleneck_0(x_pool_0)
174 | feat_1 = self.bottleneck_1(x_pool_1)
175 | feat_2 = self.bottleneck_2(x_pool_2)
176 | feat_3 = self.bottleneck_3(x_pool_3)
177 | feat_4 = self.bottleneck(x_pool_4)
178 |
179 | if self.training:
180 | feat = torch.cat((feat_0, feat_1, feat_2, feat_3, feat_4), dim=1)
181 | x = self.adnet(feat, train_set=train_set)
182 | x_dis = self.disnet(x)
183 | p_4 = self.classifier(feat_4)
184 |
185 | return x_pool_4, p_4, x_dis
186 | else:
187 | return self.l2norm(feat_4)
--------------------------------------------------------------------------------
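A tiny check of the gradient-reversal behavior above: the forward pass is the identity, while the backward pass multiplies incoming gradients by -coeff. This sketch calls the autograd Function directly with coeff = 1.0:

import torch
from model.network import GradientReverseLayer

x = torch.ones(3, requires_grad=True)
y = GradientReverseLayer.apply(1.0, x)  # forward: y equals x
y.sum().backward()
print(x.grad)                           # tensor([-1., -1., -1.])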
/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 |
4 | def adjust_learning_rate(args, optimizer, epoch):
5 | if epoch < 10:
6 | lr = args.lr * (epoch + 1) / 10
7 | elif epoch >= 10 and epoch < 20:
8 | lr = args.lr
9 | elif epoch >= 20 and epoch < 50:
10 | lr = args.lr * 0.1
11 | elif epoch >= 50:
12 | lr = args.lr * 0.01
13 |
14 | optimizer.param_groups[0]["lr"] = 0.1 * lr
15 | for i in range(len(optimizer.param_groups) - 1):
16 | optimizer.param_groups[i + 1]["lr"] = lr
17 |
18 | return lr
19 |
20 |
21 | def select_optimizer(args, main_net):
22 | if args.optim == "adam":
23 | ignored_params = list(map(id, main_net.bottleneck.parameters())) \
24 | + list(map(id, main_net.classifier.parameters())) \
25 | + list(map(id, main_net.adnet.parameters())) \
26 | + list(map(id, main_net.disnet.parameters())) \
27 | + list(map(id, main_net.bottleneck_0.parameters())) \
28 | + list(map(id, main_net.bottleneck_1.parameters())) \
29 | + list(map(id, main_net.bottleneck_2.parameters())) \
30 | + list(map(id, main_net.bottleneck_3.parameters()))
31 |
32 | base_params = filter(lambda p: id(p) not in ignored_params, main_net.parameters())
33 | optimizer = optim.Adam([
34 | {"params": base_params, "lr": 0.1 * args.lr},
35 | {"params": main_net.bottleneck_0.parameters(), "lr": args.lr},
36 | {"params": main_net.bottleneck_1.parameters(), "lr": args.lr},
37 | {"params": main_net.bottleneck_2.parameters(), "lr": args.lr},
38 | {"params": main_net.bottleneck_3.parameters(), "lr": args.lr},
39 | {"params": main_net.bottleneck.parameters(), "lr": args.lr},
40 | {"params": main_net.classifier.parameters(), "lr": args.lr},
41 | {"params": main_net.adnet.parameters(), "lr": args.lr},
42 | {"params": main_net.disnet.parameters(), "lr": args.lr}],
43 | weight_decay=5e-4)
44 |
45 | return optimizer
--------------------------------------------------------------------------------
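The schedule in adjust_learning_rate is a linear warmup over epochs 0-9, the base lr for epochs 10-19, then 10x and 100x decays; the first param group (the backbone) always runs at one tenth of this value. A standalone sketch with a hypothetical base lr:

base_lr = 3.5e-4  # hypothetical; the real value comes from the yaml config

def scheduled_lr(epoch):
    if epoch < 10:
        return base_lr * (epoch + 1) / 10  # warmup
    if epoch < 20:
        return base_lr
    if epoch < 50:
        return base_lr * 0.1
    return base_lr * 0.01

for epoch in (0, 9, 10, 25, 60):
    print(epoch, scheduled_lr(epoch))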
/otla_sk.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from utils import sort_list_with_unique_index
6 |
7 |
8 | def cpu_sk_ir_trainloader(args, main_net, trainloader, tIndex, n_class, print_freq=50):
9 | main_net.train()
10 |
11 | n_ir = len(tIndex)
12 | P = np.zeros((n_ir, n_class))
13 |
14 | with torch.no_grad():
15 | for batch_idx, (input_rgb, input_ir, label_rgb, label_ir) in enumerate(trainloader):
16 | t = time.time()
17 | input_ir = input_ir.cuda()
18 | _, p, _ = main_net(input_ir, input_ir, modal=2, train_set=False)
19 | p_softmax = nn.Softmax(1)(p).cpu().numpy()
20 | P[batch_idx * args.train_batch_size * args.num_pos:(batch_idx + 1) * args.train_batch_size * args.num_pos, :] = p_softmax
21 |
22 | if batch_idx == 0:
23 | ir_real_label = label_ir
24 | else:
25 | ir_real_label = torch.cat((ir_real_label, label_ir), dim=0)
26 |
27 | if (batch_idx + 1) % print_freq == 0:
28 | print("Extract predictions: [{}/{}]\t"
29 | "Time consuming: {:.3f}\t"
30 | .format(batch_idx + 1, len(trainloader), time.time() - t))
31 |
32 | # optimize pseudo labels with the Sinkhorn-Knopp algorithm
33 | unique_tIndex_first_idx, unique_tIndex_last_idx, unique_tIndex_num, idx_order, unique_tIndex_list = sort_list_with_unique_index(tIndex)
34 | unique_tIndex_idx = unique_tIndex_last_idx # last
35 | ir_real_label = ir_real_label[unique_tIndex_idx]
36 | P_ = P[unique_tIndex_idx]
37 | for i, idx in enumerate(idx_order):
38 | P_[i] = (P[unique_tIndex_list[idx]].mean(axis=0))
39 | PS = (P_.T) ** args.lambda_sk
40 |
41 | n_ir_unique = len(np.unique(tIndex))
42 | alpha = np.ones((n_class, 1)) / n_class # initial value for alpha
43 | beta = np.ones((n_ir_unique, 1)) / n_ir_unique # initial value for beta
44 |
45 | inv_K = 1. / n_class
46 | inv_N = 1. / n_ir_unique
47 |
48 | err = 1e6
49 | step = 0
50 | tt = time.time()
51 | while err > 1e-1:
52 | alpha = inv_K / (PS @ beta) # (KxN) @ (N,1) = K x 1
53 | beta_new = inv_N / (alpha.T @ PS).T # ((1,K) @ (KxN)).t() = N x 1
54 | if step % 10 == 0:
55 | err = np.nansum(np.abs(beta / beta_new - 1))
56 | beta = beta_new
57 | step += 1
58 | print("Sinkhorn-Knopp Error: {:.3f} Total step: {} Total time: {:.3f}".format(err, step, time.time() - tt))
59 | PS *= np.squeeze(beta)
60 | PS = PS.T
61 | PS *= np.squeeze(alpha)
62 | PS = PS.T
63 | argmaxes = np.nanargmax(PS, 0)  # size n_ir_unique
64 | ir_pseudo_label_op = torch.LongTensor(argmaxes)
65 |
66 | # the max prediction of softmax
67 | argmaxes_ = np.nanargmax(P_, 1)
68 | ir_pseudo_label_mp = torch.LongTensor(argmaxes_)
69 |
70 | return ir_pseudo_label_op, ir_pseudo_label_mp, ir_real_label, tIndex[unique_tIndex_idx]
--------------------------------------------------------------------------------
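A toy Sinkhorn-Knopp run mirroring the loop in cpu_sk_ir_trainloader above, on a random score matrix rather than the network's predictions: alternate row/column scalings until the transport plan has (approximately) uniform marginals over classes and samples:

import numpy as np

rng = np.random.default_rng(0)
K, N = 4, 6                           # classes x samples
PS = rng.random((K, N)) ** 25         # sharpened scores, like PS = P_.T ** lambda_sk
alpha = np.ones((K, 1)) / K
beta = np.ones((N, 1)) / N

for _ in range(200):
    alpha = (1.0 / K) / (PS @ beta)       # row scaling
    beta = (1.0 / N) / (alpha.T @ PS).T   # column scaling

plan = PS * beta.squeeze()            # scale columns, as otla_sk.py does
plan = (plan.T * alpha.squeeze()).T   # then rows
print(plan.sum(axis=0))               # ~1/N per sample
print(plan.sum(axis=1))               # ~1/K per class
pseudo = plan.argmax(axis=0)          # hard pseudo label per sample, like np.nanargmax(PS, 0)
print(pseudo)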
/video-poster/0971.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/video-poster/0971.mp4
--------------------------------------------------------------------------------
/video-poster/0971.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjm-wjm/OTLA-ReID/1405e96bd8339deeebf61718346b24722770ac61/video-poster/0971.pdf
--------------------------------------------------------------------------------