├── .gitignore ├── LICENSE ├── README.md ├── examples ├── test.py ├── train_pplr.py └── train_pplr_cam.py ├── figs └── overview.jpg ├── pplr ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dukemtmc.py │ ├── market1501.py │ ├── msmt17.py │ └── veri.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── evaluators.py ├── loss │ ├── __init__.py │ ├── crossentropy.py │ ├── loss.py │ └── triplet.py ├── models │ ├── __init__.py │ ├── resnet.py │ └── resnet_part.py ├── trainers.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── preprocessor.py │ ├── sampler.py │ └── transforms.py │ ├── faiss_rerank.py │ ├── faiss_utils.py │ ├── logging.py │ ├── meters.py │ ├── osutils.py │ ├── rerank.py │ └── serialization.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Git ### 2 | # Created by git for backups. To disable backups in Git: 3 | # $ git config --global mergetool.keepBackup false 4 | *.orig 5 | 6 | # Created by git when using merge tools for conflicts 7 | *.BACKUP.* 8 | *.BASE.* 9 | *.LOCAL.* 10 | *.REMOTE.* 11 | *_BACKUP_*.txt 12 | *_BASE_*.txt 13 | *_LOCAL_*.txt 14 | *_REMOTE_*.txt 15 | 16 | ### Intellij ### 17 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 18 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 19 | 20 | # User-specific stuff 21 | .idea/**/workspace.xml 22 | .idea/**/tasks.xml 23 | .idea/**/usage.statistics.xml 24 | .idea/**/dictionaries 25 | .idea/**/shelf 26 | 27 | # AWS User-specific 28 | .idea/**/aws.xml 29 | 30 | # Generated files 31 | .idea/**/contentModel.xml 32 | 33 | # Sensitive or high-churn files 34 | .idea/**/dataSources/ 35 | .idea/**/dataSources.ids 36 | .idea/**/dataSources.local.xml 37 | .idea/**/sqlDataSources.xml 38 | .idea/**/dynamic.xml 39 | .idea/**/uiDesigner.xml 40 | .idea/**/dbnavigator.xml 41 | 42 | # Gradle 43 | .idea/**/gradle.xml 44 | .idea/**/libraries 45 | 46 | # Gradle and Maven with auto-import 47 | # When using Gradle or Maven with auto-import, you should exclude module files, 48 | # since they will be recreated, and may cause churn. Uncomment if using 49 | # auto-import. 
50 | # .idea/artifacts 51 | # .idea/compiler.xml 52 | # .idea/jarRepositories.xml 53 | # .idea/modules.xml 54 | # .idea/*.iml 55 | # .idea/modules 56 | # *.iml 57 | # *.ipr 58 | 59 | # CMake 60 | cmake-build-*/ 61 | 62 | # Mongo Explorer plugin 63 | .idea/**/mongoSettings.xml 64 | 65 | # File-based project format 66 | *.iws 67 | 68 | # IntelliJ 69 | out/ 70 | 71 | # mpeltonen/sbt-idea plugin 72 | .idea_modules/ 73 | 74 | # JIRA plugin 75 | atlassian-ide-plugin.xml 76 | 77 | # Cursive Clojure plugin 78 | .idea/replstate.xml 79 | 80 | # SonarLint plugin 81 | .idea/sonarlint/ 82 | 83 | # Crashlytics plugin (for Android Studio and IntelliJ) 84 | com_crashlytics_export_strings.xml 85 | crashlytics.properties 86 | crashlytics-build.properties 87 | fabric.properties 88 | 89 | # Editor-based Rest Client 90 | .idea/httpRequests 91 | 92 | # Android studio 3.1+ serialized cache file 93 | .idea/caches/build_file_checksums.ser 94 | 95 | ### Intellij Patch ### 96 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 97 | 98 | # *.iml 99 | # modules.xml 100 | # .idea/misc.xml 101 | # *.ipr 102 | 103 | # Sonarlint plugin 104 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 105 | .idea/**/sonarlint/ 106 | 107 | # SonarQube Plugin 108 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 109 | .idea/**/sonarIssues.xml 110 | 111 | # Markdown Navigator plugin 112 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 113 | .idea/**/markdown-navigator.xml 114 | .idea/**/markdown-navigator-enh.xml 115 | .idea/**/markdown-navigator/ 116 | 117 | # Cache file creation bug 118 | # See https://youtrack.jetbrains.com/issue/JBR-2257 119 | .idea/$CACHE_FILE$ 120 | 121 | # CodeStream plugin 122 | # https://plugins.jetbrains.com/plugin/12206-codestream 123 | .idea/codestream.xml 124 | 125 | ### PyCharm+all ### 126 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 127 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 128 | 129 | # User-specific stuff 130 | 131 | # AWS User-specific 132 | 133 | # Generated files 134 | 135 | # Sensitive or high-churn files 136 | 137 | # Gradle 138 | 139 | # Gradle and Maven with auto-import 140 | # When using Gradle or Maven with auto-import, you should exclude module files, 141 | # since they will be recreated, and may cause churn. Uncomment if using 142 | # auto-import. 143 | # .idea/artifacts 144 | # .idea/compiler.xml 145 | # .idea/jarRepositories.xml 146 | # .idea/modules.xml 147 | # .idea/*.iml 148 | # .idea/modules 149 | # *.iml 150 | # *.ipr 151 | 152 | # CMake 153 | 154 | # Mongo Explorer plugin 155 | 156 | # File-based project format 157 | 158 | # IntelliJ 159 | 160 | # mpeltonen/sbt-idea plugin 161 | 162 | # JIRA plugin 163 | 164 | # Cursive Clojure plugin 165 | 166 | # SonarLint plugin 167 | 168 | # Crashlytics plugin (for Android Studio and IntelliJ) 169 | 170 | # Editor-based Rest Client 171 | 172 | # Android studio 3.1+ serialized cache file 173 | 174 | ### PyCharm+all Patch ### 175 | # Ignore everything but code style settings and run configurations 176 | # that are supposed to be shared within teams. 
177 | 178 | .idea/* 179 | 180 | !.idea/codeStyles 181 | !.idea/runConfigurations 182 | 183 | ### Python ### 184 | # Byte-compiled / optimized / DLL files 185 | __pycache__/ 186 | *.py[cod] 187 | *$py.class 188 | 189 | # C extensions 190 | *.so 191 | 192 | # Distribution / packaging 193 | .Python 194 | build/ 195 | develop-eggs/ 196 | dist/ 197 | downloads/ 198 | eggs/ 199 | .eggs/ 200 | lib/ 201 | lib64/ 202 | parts/ 203 | sdist/ 204 | var/ 205 | wheels/ 206 | share/python-wheels/ 207 | *.egg-info/ 208 | .installed.cfg 209 | *.egg 210 | MANIFEST 211 | 212 | # PyInstaller 213 | # Usually these files are written by a python script from a template 214 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 215 | *.manifest 216 | *.spec 217 | 218 | # Installer logs 219 | pip-log.txt 220 | pip-delete-this-directory.txt 221 | 222 | # Unit test / coverage reports 223 | htmlcov/ 224 | .tox/ 225 | .nox/ 226 | .coverage 227 | .coverage.* 228 | .cache 229 | nosetests.xml 230 | coverage.xml 231 | *.cover 232 | *.py,cover 233 | .hypothesis/ 234 | .pytest_cache/ 235 | cover/ 236 | 237 | # Translations 238 | *.mo 239 | *.pot 240 | 241 | # Django stuff: 242 | *.log 243 | local_settings.py 244 | db.sqlite3 245 | db.sqlite3-journal 246 | 247 | # Flask stuff: 248 | instance/ 249 | .webassets-cache 250 | 251 | # Scrapy stuff: 252 | .scrapy 253 | 254 | # Sphinx documentation 255 | docs/_build/ 256 | 257 | # PyBuilder 258 | .pybuilder/ 259 | target/ 260 | 261 | # Jupyter Notebook 262 | .ipynb_checkpoints 263 | 264 | # IPython 265 | profile_default/ 266 | ipython_config.py 267 | 268 | # pyenv 269 | # For a library or package, you might want to ignore these files since the code is 270 | # intended to run in multiple environments; otherwise, check them in: 271 | # .python-version 272 | 273 | # pipenv 274 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 275 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 276 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 277 | # install all needed dependencies. 278 | #Pipfile.lock 279 | 280 | # poetry 281 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 282 | # This is especially recommended for binary packages to ensure reproducibility, and is more 283 | # commonly ignored for libraries. 284 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 285 | #poetry.lock 286 | 287 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 288 | __pypackages__/ 289 | 290 | # Celery stuff 291 | celerybeat-schedule 292 | celerybeat.pid 293 | 294 | # SageMath parsed files 295 | *.sage.py 296 | 297 | # Environments 298 | .env 299 | .venv 300 | env/ 301 | venv/ 302 | ENV/ 303 | env.bak/ 304 | venv.bak/ 305 | 306 | # Spyder project settings 307 | .spyderproject 308 | .spyproject 309 | 310 | # Rope project settings 311 | .ropeproject 312 | 313 | # mkdocs documentation 314 | /site 315 | 316 | # mypy 317 | .mypy_cache/ 318 | .dmypy.json 319 | dmypy.json 320 | 321 | # Pyre type checker 322 | .pyre/ 323 | 324 | # pytype static type analyzer 325 | .pytype/ 326 | 327 | # Cython debug symbols 328 | cython_debug/ 329 | 330 | # PyCharm 331 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 332 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 333 | # and can be added to the global gitignore or merged into this file. For a more nuclear 334 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 335 | #.idea/ 336 | 337 | # project specific 338 | logs/* 339 | data/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yoonki Cho 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Part-based Pseudo Label Refinement (PPLR) 2 | Official PyTorch implementation of [Part-based Pseudo Label Refinement for Unsupervised Person Re-identification](https://arxiv.org/abs/2203.14675) (CVPR 2022). 3 | 4 | ## Updates 5 | - [07/2022] Pretrained weights are released. 6 | - [06/2022] Code is released. 7 | 8 | ## Overview 9 | ![overview](figs/overview.jpg) 10 | >We propose a Part-based Pseudo Label Refinement (PPLR) framework that reduces the label noise by employing the complementary relationship between global and part features. 11 | Specifically, we design a cross agreement score as the similarity of k-nearest neighbors between feature spaces to exploit the reliable complementary relationship. 
12 | Based on the cross agreement, we refine pseudo-labels of global features by ensembling the predictions of part features, which collectively alleviates the noise in global feature clustering.
13 | We further refine pseudo-labels of part features by applying label smoothing according to the suitability of the given labels for each part.
14 | Our PPLR learns discriminative representations with rich local contexts. Also, it operates in a self-ensemble manner without auxiliary teacher networks, which is computationally efficient.
15 | 
16 | ## Getting Started
17 | ### Installation
18 | ```shell
19 | git clone https://github.com/yoonkicho/PPLR
20 | cd PPLR
21 | python setup.py develop
22 | ```
23 | ### Preparing Datasets
24 | ```shell
25 | cd examples && mkdir data
26 | ```
27 | Download the object re-ID datasets [Market-1501](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view), [MSMT17](https://arxiv.org/abs/1711.08565), and [VeRi-776](https://github.com/JDAI-CV/VeRidataset) to `PPLR/examples/data`.
28 | The directory should look like:
29 | ```
30 | PPLR/examples/data
31 | ├── Market-1501-v15.09.15
32 | ├── MSMT17_V1
33 | └── VeRi
34 | ```
35 | ## Training
36 | We utilize 4 TITAN RTX GPUs for training.
37 | We use 384x128 images for Market-1501 and MSMT17, and 256x256 images for VeRi-776.
38 | 
39 | ### Training without camera labels
40 | For Market-1501:
41 | ```
42 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
43 | python examples/train_pplr.py \
44 | -d market1501 --logs-dir $PATH_FOR_LOGS
45 | ```
46 | For MSMT17:
47 | ```
48 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
49 | python examples/train_pplr.py \
50 | -d msmt17 --logs-dir $PATH_FOR_LOGS
51 | ```
52 | For VeRi-776:
53 | ```
54 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
55 | python examples/train_pplr.py \
56 | -d veri -n 8 --height 256 --width 256 --eps 0.7 --logs-dir $PATH_FOR_LOGS
57 | ```
58 | 
59 | ### Training with camera labels
60 | For Market-1501:
61 | ```
62 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
63 | python examples/train_pplr_cam.py \
64 | -d market1501 --eps 0.4 --logs-dir $PATH_FOR_LOGS
65 | ```
66 | For MSMT17:
67 | ```
68 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
69 | python examples/train_pplr_cam.py \
70 | -d msmt17 --eps 0.6 --lam-cam 1.0 --logs-dir $PATH_FOR_LOGS
71 | ```
72 | For VeRi-776:
73 | ```
74 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
75 | python examples/train_pplr_cam.py \
76 | -d veri -n 8 --height 256 --width 256 --eps 0.7 --logs-dir $PATH_FOR_LOGS
77 | ```
78 | 
79 | ## Testing
80 | We use a single TITAN RTX GPU for testing.
81 | 
82 | You can download pre-trained weights from this [link](https://drive.google.com/drive/folders/1m5wDOJG7qk62PjkoOpTspNmk0nhLc4Vi?usp=sharing).
83 | 
84 | For Market-1501:
85 | ```
86 | CUDA_VISIBLE_DEVICES=0 \
87 | python examples/test.py \
88 | -d market1501 --resume $PATH_FOR_MODEL
89 | ```
90 | For MSMT17:
91 | ```
92 | CUDA_VISIBLE_DEVICES=0 \
93 | python examples/test.py \
94 | -d msmt17 --resume $PATH_FOR_MODEL
95 | ```
96 | For VeRi-776:
97 | ```
98 | CUDA_VISIBLE_DEVICES=0 \
99 | python examples/test.py \
100 | -d veri --height 256 --width 256 --resume $PATH_FOR_MODEL
101 | ```
102 | 
103 | ## Acknowledgement
104 | Some parts of the code are borrowed from [SpCL](https://github.com/yxgeee/SpCL).
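
## Cross-Agreement Example

The cross-agreement score described in the overview is simply the Jaccard overlap between the k-nearest-neighbor list of the global feature space and that of each part feature space. Below is a small self-contained sketch of this computation for readers who want to experiment outside the training pipeline. Note that it is not part of the codebase: `knn_lists` and `cross_agreement` are illustrative stand-ins for the faiss-based `compute_ranked_list` and `compute_cross_agreement` used in `examples/train_pplr.py`, with brute-force cosine ranking in place of the faiss search. Shapes follow the training code: `feats_g` is `(N, D)` and `feats_p` is `(N, D, P)`.

```python
import numpy as np

def knn_lists(feats, k):
    # Rank by cosine similarity on L2-normalized features; the query itself
    # stays in its own ranked list, as in the faiss-based training code.
    feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sim = feats @ feats.T
    return np.argsort(-sim, axis=1)[:, :k]

def cross_agreement(feats_g, feats_p, k=20):
    # Jaccard overlap between the global k-NN list and each part's k-NN list.
    N, _, P = feats_p.shape
    ranked_g = knn_lists(feats_g, k)
    score = np.zeros((N, P), dtype=np.float32)
    for i in range(P):
        ranked_p = knn_lists(feats_p[:, :, i], k)
        for j in range(N):
            inter = np.intersect1d(ranked_g[j], ranked_p[j]).size
            union = np.union1d(ranked_g[j], ranked_p[j]).size
            score[j, i] = inter / union
    return score  # (N, P); higher = part i agrees with the global neighborhood

# Toy usage: 100 samples, 256-dim features, 3 parts.
score = cross_agreement(np.random.rand(100, 256), np.random.rand(100, 256, 3))
```

A score near 1 means a part's neighborhood matches the global one, so its predictions are trusted more during label refinement; a score near 0 down-weights that part.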
105 | 106 | ## Citation 107 | If you find this code useful for your research, please consider citing our paper: 108 | 109 | ````BibTex 110 | @inproceedings{cho2022part, 111 | title={Part-based Pseudo Label Refinement for Unsupervised Person Re-identification}, 112 | author={Cho, Yoonki and Kim, Woo Jae and Hong, Seunghoon and Yoon, Sung-Eui}, 113 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 114 | pages={7308--7318}, 115 | year={2022} 116 | } 117 | ```` 118 | -------------------------------------------------------------------------------- /examples/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import sys 7 | 8 | import torch 9 | from torch import nn 10 | from torch.backends import cudnn 11 | from torch.utils.data import DataLoader 12 | 13 | from pplr import datasets 14 | from pplr.models import resnet50part 15 | from pplr.evaluators import Evaluator 16 | from pplr.utils.data import transforms as T 17 | from pplr.utils.data.preprocessor import Preprocessor 18 | from pplr.utils.logging import Logger 19 | from pplr.utils.serialization import load_checkpoint, copy_state_dict 20 | 21 | 22 | def get_data(name, data_dir, height, width, batch_size, workers): 23 | root = data_dir 24 | 25 | dataset = datasets.create(name, root) 26 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 27 | std=[0.229, 0.224, 0.225]) 28 | test_transformer = T.Compose([ 29 | T.Resize((height, width), interpolation=3), 30 | T.ToTensor(), 31 | normalizer 32 | ]) 33 | test_loader = DataLoader(Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 34 | root=dataset.images_dir, transform=test_transformer), 35 | batch_size=batch_size, num_workers=workers, shuffle=False, pin_memory=True) 36 | 37 | return dataset, test_loader 38 | 39 | 40 | def main(): 41 | args = parser.parse_args() 42 | 43 | if args.seed is not None: 44 | random.seed(args.seed) 45 | np.random.seed(args.seed) 46 | torch.manual_seed(args.seed) 47 | torch.cuda.manual_seed(args.seed) 48 | torch.cuda.manual_seed_all(args.seed) 49 | cudnn.deterministic = True 50 | cudnn.benchmark = False 51 | 52 | main_worker(args) 53 | 54 | 55 | def main_worker(args): 56 | cudnn.benchmark = True 57 | 58 | log_dir = osp.dirname(args.resume) 59 | sys.stdout = Logger(osp.join(log_dir, 'log_test.txt')) 60 | print("==========\nArgs:{}\n==========".format(args)) 61 | 62 | # dataset 63 | dataset, test_loader = get_data(args.dataset, args.data_dir, args.height, args.width, args.batch_size, args.workers) 64 | 65 | # model 66 | model = resnet50part(num_parts=args.part, num_classes=3000) 67 | model.cuda() 68 | model = nn.DataParallel(model) 69 | 70 | # load a checkpoint 71 | checkpoint = load_checkpoint(args.resume) 72 | copy_state_dict(checkpoint, model) 73 | 74 | # evaluate 75 | evaluator = Evaluator(model) 76 | print("Test on {}:".format(args.dataset)) 77 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True, rerank=args.rerank) 78 | return 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser(description="Testing the model") 83 | # data 84 | parser.add_argument('-d', '--dataset', type=str, default='market1501') 85 | parser.add_argument('-b', '--batch-size', type=int, default=64) 86 | parser.add_argument('-j', '--workers', type=int, default=4) 87 | parser.add_argument('--height', 
type=int, default=384, help="input height")
88 |     parser.add_argument('--width', type=int, default=128, help="input width")
89 | 
90 |     # path
91 |     working_dir = osp.dirname(osp.abspath(__file__))
92 |     parser.add_argument('--data-dir', type=str, metavar='PATH', default=osp.join(working_dir, 'data'))
93 | 
94 |     # testing configs
95 |     parser.add_argument('--resume', type=str, required=True, metavar='PATH')
96 |     parser.add_argument('--rerank', action='store_true', help="use re-ranking in evaluation")
97 |     parser.add_argument('--seed', type=int, default=1)
98 | 
99 |     # model configs
100 |     parser.add_argument('--part', type=int, default=3, help="number of parts")
101 | 
102 |     main()
103 | 
--------------------------------------------------------------------------------
/examples/train_pplr.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os.path as osp
4 | import random
5 | import numpy as np
6 | import sys
7 | import time
8 | 
9 | from sklearn.cluster import DBSCAN
10 | 
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn
14 | from torch.backends import cudnn
15 | from torch.utils.data import DataLoader
16 | 
17 | from pplr import datasets
18 | from pplr.models import resnet50part
19 | from pplr.trainers import PPLRTrainer
20 | from pplr.evaluators import Evaluator, extract_all_features
21 | from pplr.utils.data import IterLoader
22 | from pplr.utils.data import transforms as T
23 | from pplr.utils.data.sampler import RandomMultipleGallerySampler
24 | from pplr.utils.data.preprocessor import Preprocessor
25 | from pplr.utils.logging import Logger
26 | from pplr.utils.faiss_rerank import compute_ranked_list, compute_jaccard_distance
27 | 
28 | best_mAP = 0
29 | 
30 | 
31 | def get_data(name, data_dir):
32 |     root = data_dir
33 |     dataset = datasets.create(name, root)
34 |     return dataset
35 | 
36 | 
37 | def get_train_loader(dataset, height, width, batch_size, workers,
38 |                      num_instances, iters, trainset=None):
39 | 
40 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
41 |                              std=[0.229, 0.224, 0.225])
42 |     train_transformer = T.Compose([
43 |         T.Resize((height, width), interpolation=3),
44 |         T.RandomHorizontalFlip(p=0.5),
45 |         T.Pad(10),
46 |         T.RandomCrop((height, width)),
47 |         T.ToTensor(),
48 |         normalizer,
49 |         T.RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406])
50 |     ])
51 | 
52 |     train_set = sorted(dataset.train) if trainset is None else sorted(trainset)
53 |     rmgs_flag = num_instances > 0
54 |     if rmgs_flag:
55 |         sampler = RandomMultipleGallerySampler(train_set, num_instances)
56 |     else:
57 |         sampler = None
58 |     train_loader = IterLoader(
59 |         DataLoader(Preprocessor(train_set, root=dataset.images_dir, transform=train_transformer),
60 |                    batch_size=batch_size, num_workers=workers, sampler=sampler,
61 |                    shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters)
62 | 
63 |     return train_loader
64 | 
65 | 
66 | def get_test_loader(dataset, height, width, batch_size, workers, testset=None):
67 |     normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
68 |                              std=[0.229, 0.224, 0.225])
69 | 
70 |     test_transformer = T.Compose([
71 |         T.Resize((height, width), interpolation=3),
72 |         T.ToTensor(),
73 |         normalizer
74 |     ])
75 | 
76 |     if (testset is None):
77 |         testset = list(set(dataset.query) | set(dataset.gallery))
78 | 
79 |     test_loader = DataLoader(
80 |         Preprocessor(testset, root=dataset.images_dir, transform=test_transformer),
81 |         batch_size=batch_size,
num_workers=workers, 82 | shuffle=False, pin_memory=True) 83 | 84 | return test_loader 85 | 86 | 87 | def compute_pseudo_labels(features, cluster, k1): 88 | mat_dist = compute_jaccard_distance(features, k1=k1, k2=6) 89 | ids = cluster.fit_predict(mat_dist) 90 | num_ids = len(set(ids)) - (1 if -1 in ids else 0) 91 | 92 | labels = [] 93 | outliers = 0 94 | for i, id in enumerate(ids): 95 | if id != -1: 96 | labels.append(id) 97 | else: 98 | labels.append(num_ids + outliers) 99 | outliers += 1 100 | 101 | return torch.Tensor(labels).long().detach(), num_ids 102 | 103 | 104 | def compute_cross_agreement(features_g, features_p, k, search_option=0): 105 | print("Compute cross agreement score...") 106 | N, D, P = features_p.size() 107 | score = torch.FloatTensor() 108 | end = time.time() 109 | ranked_list_g = compute_ranked_list(features_g, k=k, search_option=search_option, verbose=False) 110 | 111 | for i in range(P): 112 | ranked_list_p_i = compute_ranked_list(features_p[:, :, i], k=k, search_option=search_option, verbose=False) 113 | intersect_i = torch.FloatTensor( 114 | [len(np.intersect1d(ranked_list_g[j], ranked_list_p_i[j])) for j in range(N)]) 115 | union_i = torch.FloatTensor( 116 | [len(np.union1d(ranked_list_g[j], ranked_list_p_i[j])) for j in range(N)]) 117 | score_i = intersect_i / union_i 118 | score = torch.cat([score, score_i.unsqueeze(1)], dim=1) 119 | 120 | print("Cross agreement score time cost: {}".format(time.time() - end)) 121 | return score 122 | 123 | 124 | def main(): 125 | args = parser.parse_args() 126 | 127 | if args.seed is not None: 128 | random.seed(args.seed) 129 | np.random.seed(args.seed) 130 | torch.manual_seed(args.seed) 131 | torch.cuda.manual_seed(args.seed) 132 | torch.cuda.manual_seed_all(args.seed) 133 | cudnn.deterministic = True 134 | cudnn.benchmark = False 135 | 136 | main_worker(args) 137 | 138 | 139 | def main_worker(args): 140 | global best_mAP 141 | 142 | cudnn.benchmark = True 143 | 144 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 145 | print("==========\nArgs:{}\n==========".format(args)) 146 | 147 | # dataset 148 | dataset = get_data(args.dataset, args.data_dir) 149 | test_loader = get_test_loader(dataset, args.height, args.width, args.batch_size, args.workers) 150 | cluster_loader = get_test_loader(dataset, args.height, args.width, args.batch_size, args.workers, 151 | testset=sorted(dataset.train)) 152 | 153 | # model 154 | num_part = args.part 155 | model = resnet50part(num_parts=args.part, num_classes=3000) 156 | model.cuda() 157 | model = nn.DataParallel(model) 158 | 159 | # evaluator 160 | evaluator = Evaluator(model) 161 | 162 | # optimizer 163 | params = [] 164 | for key, value in model.named_parameters(): 165 | if not value.requires_grad: 166 | continue 167 | params += [{"params": [value], "lr": args.lr, "weight_decay": args.weight_decay}] 168 | optimizer = torch.optim.Adam(params) 169 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1) 170 | 171 | score_log = torch.FloatTensor([]) 172 | for epoch in range(args.epochs): 173 | features_g, features_p, _ = extract_all_features(model, cluster_loader) 174 | features_g = torch.cat([features_g[f].unsqueeze(0) for f, _, _ in sorted(dataset.train)], 0) 175 | features_p = torch.cat([features_p[f].unsqueeze(0) for f, _, _ in sorted(dataset.train)], 0) 176 | 177 | if epoch == 0: 178 | cluster = DBSCAN(eps=args.eps, min_samples=4, metric='precomputed', n_jobs=8) 179 | 180 | # assign pseudo-labels 181 | pseudo_labels, num_class = 
compute_pseudo_labels(features_g, cluster, args.k1) 182 | 183 | # Compute the cross-agreement 184 | score = compute_cross_agreement(features_g, features_p, k=args.k) 185 | score_log = torch.cat([score_log, score.unsqueeze(0)], dim=0) 186 | 187 | # generate new dataset with pseudo-labels 188 | num_outliers = 0 189 | new_dataset = [] 190 | 191 | idxs, pids = [], [] 192 | for i, ((fname, _, cid), label) in enumerate(zip(sorted(dataset.train), pseudo_labels)): 193 | pid = label.item() 194 | if pid >= num_class: # append data except outliers 195 | num_outliers += 1 196 | else: 197 | new_dataset.append((fname, pid, cid)) 198 | idxs.append(i) 199 | pids.append(pid) 200 | 201 | train_loader = get_train_loader(dataset, args.height, args.width, args.batch_size, 202 | args.workers, args.num_instances, args.iters, trainset=new_dataset) 203 | 204 | # statistics of clusters and un-clustered instances 205 | print('==> Statistics for epoch {}: {} clusters, {} un-clustered instances'.format(epoch, num_class, 206 | num_outliers)) 207 | 208 | # reindex 209 | idxs, pids = np.asarray(idxs), np.asarray(pids) 210 | features_g = features_g[idxs, :] 211 | features_p = features_p[idxs, :, :] 212 | score = score[idxs, :] 213 | 214 | # compute cluster centroids 215 | centroids_g, centroids_p = [], [] 216 | for pid in sorted(np.unique(pids)): # loop all pids 217 | idxs_p = np.where(pids == pid)[0] 218 | centroids_g.append(features_g[idxs_p].mean(0)) 219 | centroids_p.append(features_p[idxs_p].mean(0)) 220 | 221 | centroids_g = F.normalize(torch.stack(centroids_g), p=2, dim=1) 222 | model.module.classifier.weight.data[:num_class].copy_(centroids_g) 223 | for i in range(num_part): 224 | centroids_p_i = torch.stack(centroids_p)[:, :, i] 225 | centroids_p_i = F.normalize(centroids_p_i, p=2, dim=1) 226 | classifier_p_i = getattr(model.module, 'classifier' + str(i)) 227 | classifier_p_i.weight.data[:num_class].copy_(centroids_p_i) 228 | 229 | # training 230 | trainer = PPLRTrainer(model, score, num_class=num_class, num_part=num_part, beta=args.beta, 231 | aals_epoch=args.aals_epoch) 232 | 233 | trainer.train(epoch, train_loader, optimizer, print_freq=args.print_freq, train_iters=len(train_loader)) 234 | lr_scheduler.step() 235 | 236 | # evaluation 237 | if ((epoch+1) % args.eval_step == 0) or (epoch == args.epochs-1): 238 | mAP = evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=False) 239 | 240 | if mAP > best_mAP: 241 | best_mAP = mAP 242 | torch.save(model.state_dict(), osp.join(args.logs_dir, 'best.pth')) 243 | print('\n* Finished epoch {:3d} model mAP: {:5.1%} best: {:5.1%}\n'.format(epoch, mAP, best_mAP)) 244 | 245 | torch.save(model.state_dict(), osp.join(args.logs_dir, 'last.pth')) 246 | np.save(osp.join(args.logs_dir, 'scores.npy'), score_log.numpy()) 247 | 248 | # results 249 | model.load_state_dict(torch.load(osp.join(args.logs_dir, 'best.pth'))) 250 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True) 251 | 252 | 253 | if __name__ == '__main__': 254 | parser = argparse.ArgumentParser(description="Part-based Pseudo Label Refinement") 255 | # data 256 | parser.add_argument('-d', '--dataset', type=str, default='market1501') 257 | parser.add_argument('-b', '--batch-size', type=int, default=64) 258 | parser.add_argument('-j', '--workers', type=int, default=4) 259 | parser.add_argument('-n', '--num-instances', type=int, default=4, 260 | help="each minibatch consist of " 261 | "(batch_size // num_instances) identities, and " 262 | "each identity has num_instances 
instances, " 263 | "default: 0 (NOT USE)") 264 | parser.add_argument('--height', type=int, default=384, help="input height") 265 | parser.add_argument('--width', type=int, default=128, help="input width") 266 | 267 | # path 268 | working_dir = osp.dirname(osp.abspath(__file__)) 269 | parser.add_argument('--data-dir', type=str, metavar='PATH', default=osp.join(working_dir, 'data')) 270 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 271 | default=osp.join(working_dir, 'logs/test')) 272 | 273 | # training configs 274 | parser.add_argument('--seed', type=int, default=1) 275 | parser.add_argument('--print-freq', type=int, default=10) 276 | parser.add_argument('--eval-step', type=int, default=5) 277 | 278 | # PPLR 279 | parser.add_argument('--part', type=int, default=3, help="number of part") 280 | parser.add_argument('--k', type=int, default=20, 281 | help="hyperparameter for cross agreement score") 282 | parser.add_argument('--beta', type=float, default=0.5, 283 | help="weighting parameter for part-guided label refinement") 284 | parser.add_argument('--aals-epoch', type=int, default=5, 285 | help="starting epoch for agreement-aware label smoothing") 286 | 287 | # optimizer 288 | parser.add_argument('--lr', type=float, default=0.00035, help="learning rate") 289 | parser.add_argument('--weight-decay', type=float, default=5e-4) 290 | parser.add_argument('--epochs', type=int, default=50) 291 | parser.add_argument('--iters', type=int, default=400) 292 | parser.add_argument('--step-size', type=int, default=20) 293 | 294 | # cluster 295 | parser.add_argument('--k1', type=int, default=30, 296 | help="hyperparameter for jaccard distance") 297 | parser.add_argument('--k2', type=int, default=6, 298 | help="hyperparameter for jaccard distance") 299 | parser.add_argument('--eps', type=float, default=0.5, 300 | help="distance threshold for DBSCAN") 301 | 302 | main() 303 | -------------------------------------------------------------------------------- /examples/train_pplr_cam.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import sys 7 | import time 8 | 9 | from sklearn.cluster import DBSCAN 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.backends import cudnn 15 | from torch.utils.data import DataLoader 16 | 17 | from pplr import datasets 18 | from pplr.models import resnet50part 19 | from pplr.loss import InterCamProxy 20 | from pplr.trainers import PPLRTrainerCAM 21 | from pplr.evaluators import Evaluator, extract_all_features 22 | from pplr.utils.data import IterLoader 23 | from pplr.utils.data import transforms as T 24 | from pplr.utils.data.sampler import RandomMultipleGallerySampler 25 | from pplr.utils.data.preprocessor import Preprocessor 26 | from pplr.utils.logging import Logger 27 | from pplr.utils.faiss_rerank import compute_ranked_list, compute_jaccard_distance 28 | 29 | best_mAP = 0 30 | 31 | 32 | def get_data(name, data_dir): 33 | root = data_dir 34 | dataset = datasets.create(name, root) 35 | return dataset 36 | 37 | 38 | def get_train_loader(dataset, height, width, batch_size, workers, 39 | num_instances, iters, trainset=None): 40 | 41 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225]) 43 | train_transformer = T.Compose([ 44 | T.Resize((height, width), interpolation=3), 45 | 
T.RandomHorizontalFlip(p=0.5), 46 | T.Pad(10), 47 | T.RandomCrop((height, width)), 48 | T.ToTensor(), 49 | normalizer, 50 | T.RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406]) 51 | ]) 52 | 53 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 54 | rmgs_flag = num_instances > 0 55 | if rmgs_flag: 56 | sampler = RandomMultipleGallerySampler(train_set, num_instances) 57 | else: 58 | sampler = None 59 | train_loader = IterLoader( 60 | DataLoader(Preprocessor(train_set, root=dataset.images_dir, transform=train_transformer), 61 | batch_size=batch_size, num_workers=workers, sampler=sampler, 62 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters) 63 | 64 | return train_loader 65 | 66 | 67 | def get_test_loader(dataset, height, width, batch_size, workers, testset=None): 68 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 69 | std=[0.229, 0.224, 0.225]) 70 | 71 | test_transformer = T.Compose([ 72 | T.Resize((height, width), interpolation=3), 73 | T.ToTensor(), 74 | normalizer 75 | ]) 76 | 77 | if (testset is None): 78 | testset = list(set(dataset.query) | set(dataset.gallery)) 79 | 80 | test_loader = DataLoader( 81 | Preprocessor(testset, root=dataset.images_dir, transform=test_transformer), 82 | batch_size=batch_size, num_workers=workers, 83 | shuffle=False, pin_memory=True) 84 | 85 | return test_loader 86 | 87 | 88 | def compute_pseudo_labels(features, cluster, k1): 89 | mat_dist = compute_jaccard_distance(features, k1=k1, k2=6) 90 | ids = cluster.fit_predict(mat_dist) 91 | num_ids = len(set(ids)) - (1 if -1 in ids else 0) 92 | 93 | labels = [] 94 | outliers = 0 95 | for i, id in enumerate(ids): 96 | if id != -1: 97 | labels.append(id) 98 | else: 99 | labels.append(num_ids + outliers) 100 | outliers += 1 101 | 102 | return torch.Tensor(labels).long().detach(), num_ids 103 | 104 | 105 | def compute_cross_agreement(features_g, features_p, k, search_option=0): 106 | print("Compute cross-agreement...") 107 | N, D, P = features_p.size() 108 | score = torch.FloatTensor() 109 | end = time.time() 110 | ranked_list_g = compute_ranked_list(features_g, k=k, search_option=search_option, verbose=False) 111 | 112 | for i in range(P): 113 | ranked_list_p_i = compute_ranked_list(features_p[:, :, i], k=k, search_option=search_option, verbose=False) 114 | intersect_i = torch.FloatTensor( 115 | [len(np.intersect1d(ranked_list_g[j], ranked_list_p_i[j])) for j in range(N)]) 116 | union_i = torch.FloatTensor( 117 | [len(np.union1d(ranked_list_g[j], ranked_list_p_i[j])) for j in range(N)]) 118 | score_i = intersect_i / union_i 119 | score = torch.cat([score, score_i.unsqueeze(1)], dim=1) 120 | 121 | print("Cross agreement time cost: {}".format(time.time() - end)) 122 | return score 123 | 124 | 125 | def main(): 126 | args = parser.parse_args() 127 | 128 | if args.seed is not None: 129 | random.seed(args.seed) 130 | np.random.seed(args.seed) 131 | torch.manual_seed(args.seed) 132 | torch.cuda.manual_seed(args.seed) 133 | torch.cuda.manual_seed_all(args.seed) 134 | cudnn.deterministic = True 135 | cudnn.benchmark = False 136 | 137 | main_worker(args) 138 | 139 | 140 | def main_worker(args): 141 | global best_mAP 142 | 143 | cudnn.benchmark = True 144 | 145 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 146 | print("==========\nArgs:{}\n==========".format(args)) 147 | 148 | # dataset 149 | dataset = get_data(args.dataset, args.data_dir) 150 | test_loader = get_test_loader(dataset, args.height, args.width, args.batch_size, args.workers) 151 | 
cluster_loader = get_test_loader(dataset, args.height, args.width, args.batch_size, args.workers, 152 | testset=sorted(dataset.train)) 153 | 154 | # model 155 | num_part = args.part 156 | model = resnet50part(num_parts=args.part, num_classes=3000) 157 | model.cuda() 158 | model = nn.DataParallel(model) 159 | 160 | # evaluator 161 | evaluator = Evaluator(model) 162 | 163 | # optimizer 164 | params = [] 165 | for key, value in model.named_parameters(): 166 | if not value.requires_grad: 167 | continue 168 | params += [{"params": [value], "lr": args.lr, "weight_decay": args.weight_decay}] 169 | optimizer = torch.optim.Adam(params) 170 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1) 171 | 172 | score_log = torch.FloatTensor([]) 173 | for epoch in range(args.epochs): 174 | features_g, features_p, _ = extract_all_features(model, cluster_loader) 175 | features_g = torch.cat([features_g[f].unsqueeze(0) for f, _, _ in sorted(dataset.train)], 0) 176 | features_p = torch.cat([features_p[f].unsqueeze(0) for f, _, _ in sorted(dataset.train)], 0) 177 | 178 | if epoch == 0: 179 | cluster = DBSCAN(eps=args.eps, min_samples=4, metric='precomputed', n_jobs=8) 180 | 181 | # assign pseudo-labels 182 | pseudo_labels, num_class = compute_pseudo_labels(features_g, cluster, args.k1) 183 | 184 | # compute the cross-agreement 185 | score = compute_cross_agreement(features_g, features_p, k=args.k) 186 | score_log = torch.cat([score_log, score.unsqueeze(0)], dim=0) 187 | 188 | # generate new dataset with pseudo-labels 189 | num_outliers = 0 190 | new_dataset = [] 191 | 192 | idxs, cids, pids = [], [], [] 193 | for i, ((fname, _, cid), label) in enumerate(zip(sorted(dataset.train), pseudo_labels)): 194 | pid = label.item() 195 | if pid >= num_class: # append data except outliers 196 | num_outliers += 1 197 | else: 198 | new_dataset.append((fname, pid, cid)) 199 | idxs.append(i) 200 | cids.append(cid) 201 | pids.append(pid) 202 | 203 | train_loader = get_train_loader(dataset, args.height, args.width, args.batch_size, 204 | args.workers, args.num_instances, args.iters, trainset=new_dataset) 205 | 206 | # statistics of clusters and un-clustered instances 207 | print('==> Statistics for epoch {}: {} clusters, {} un-clustered instances'.format(epoch, num_class, 208 | num_outliers)) 209 | 210 | # reindex 211 | idxs, cids, pids = np.asarray(idxs), np.asarray(cids), np.asarray(pids) 212 | features_g = features_g[idxs, :] 213 | features_p = features_p[idxs, :, :] 214 | score = score[idxs, :] 215 | 216 | # compute cluster centroids and camera-aware proxies 217 | centroids_g, centroids_p = [], [] 218 | cam_proxy, cam_proxy_p, cam_proxy_pids, cam_proxy_cids = [], [], [], [] 219 | for pid in sorted(np.unique(pids)): # loop all pids 220 | idxs_p = np.where(pids == pid)[0] 221 | centroids_g.append(features_g[idxs_p].mean(0)) 222 | centroids_p.append(features_p[idxs_p].mean(0)) 223 | 224 | for cid in sorted(np.unique(cids[idxs_p])): # loop all cids for pid 225 | idxs_c = np.where(cids == cid)[0] 226 | idxs_cp = np.intersect1d(idxs_p, idxs_c) 227 | cam_proxy.append(features_g[idxs_cp].mean(0)) 228 | cam_proxy_p.append(features_p[idxs_cp].mean(0)) 229 | cam_proxy_pids.append(pid) 230 | cam_proxy_cids.append(cid) 231 | 232 | centroids_g = F.normalize(torch.stack(centroids_g), p=2, dim=1) 233 | model.module.classifier.weight.data[:num_class].copy_(centroids_g) 234 | memory = InterCamProxy(centroids_g.size(1), len(cam_proxy_pids)).cuda() 235 | memory.proxy = 
F.normalize(torch.stack(cam_proxy), p=2, dim=1).cuda() 236 | memory.pids = torch.Tensor(cam_proxy_pids).long().cuda() 237 | memory.cids = torch.Tensor(cam_proxy_cids).long().cuda() 238 | 239 | memory_p = [] 240 | for i in range(num_part): 241 | centroids_p_i = torch.stack(centroids_p)[:, :, i] 242 | centroids_p_i = F.normalize(centroids_p_i, p=2, dim=1) 243 | classifier_p_i = getattr(model.module, 'classifier' + str(i)) 244 | classifier_p_i.weight.data[:num_class].copy_(centroids_p_i) 245 | 246 | memory_p_i = InterCamProxy(centroids_g.size(1), len(cam_proxy_pids)).cuda() 247 | cam_proxy_p_i = torch.stack(cam_proxy_p)[:, :, i] 248 | memory_p_i.proxy = F.normalize(cam_proxy_p_i, p=2, dim=1).cuda() 249 | memory_p_i.pids = torch.Tensor(cam_proxy_pids).long().cuda() 250 | memory_p_i.cids = torch.Tensor(cam_proxy_cids).long().cuda() 251 | memory_p.append(memory_p_i) 252 | 253 | # training 254 | trainer = PPLRTrainerCAM(model, score, memory, memory_p, num_class=num_class, num_part=num_part, 255 | beta=args.beta, aals_epoch=args.aals_epoch, lam_cam=args.lam_cam) 256 | 257 | trainer.train(epoch, train_loader, optimizer, print_freq=args.print_freq, train_iters=len(train_loader)) 258 | lr_scheduler.step() 259 | 260 | # evaluation 261 | if ((epoch+1) % args.eval_step == 0) or (epoch == args.epochs-1): 262 | mAP = evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=False) 263 | 264 | if mAP > best_mAP: 265 | best_mAP = mAP 266 | torch.save(model.state_dict(), osp.join(args.logs_dir, 'best.pth')) 267 | print('\n* Finished epoch {:3d} model mAP: {:5.1%} best: {:5.1%}\n'.format(epoch, mAP, best_mAP)) 268 | 269 | torch.save(model.state_dict(), osp.join(args.logs_dir, 'last.pth')) 270 | np.save(osp.join(args.logs_dir, 'scores.npy'), score_log.numpy()) 271 | 272 | # results 273 | model.load_state_dict(torch.load(osp.join(args.logs_dir, 'best.pth'))) 274 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True) 275 | 276 | 277 | if __name__ == '__main__': 278 | parser = argparse.ArgumentParser(description="Part-based Pseudo Label Refinement with Camera-Aware Proxies") 279 | # data 280 | parser.add_argument('-d', '--dataset', type=str, default='market1501') 281 | parser.add_argument('-b', '--batch-size', type=int, default=64) 282 | parser.add_argument('-j', '--workers', type=int, default=4) 283 | parser.add_argument('-n', '--num-instances', type=int, default=4, 284 | help="each minibatch consist of " 285 | "(batch_size // num_instances) identities, and " 286 | "each identity has num_instances instances, " 287 | "default: 0 (NOT USE)") 288 | parser.add_argument('--height', type=int, default=384, help="input height") 289 | parser.add_argument('--width', type=int, default=128, help="input width") 290 | 291 | # path 292 | working_dir = osp.dirname(osp.abspath(__file__)) 293 | parser.add_argument('--data-dir', type=str, metavar='PATH', default=osp.join(working_dir, 'data')) 294 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 295 | default=osp.join(working_dir, 'logs/test')) 296 | 297 | # training configs 298 | parser.add_argument('--seed', type=int, default=1) 299 | parser.add_argument('--print-freq', type=int, default=10) 300 | parser.add_argument('--eval-step', type=int, default=5) 301 | 302 | # PPLR 303 | parser.add_argument('--part', type=int, default=3, help="number of part") 304 | parser.add_argument('--k', type=int, default=20, 305 | help="hyperparameter for cross agreement score") 306 | parser.add_argument('--beta', type=float, default=0.5, 307 | 
help="weighting parameter for part-guided label refinement") 308 | parser.add_argument('--aals-epoch', type=int, default=5, 309 | help="starting epoch for agreement-aware label smoothing") 310 | parser.add_argument('--lam-cam', type=float, default=0.5, 311 | help="weighting parameter of inter-camera contrastive loss") 312 | 313 | # optimizer 314 | parser.add_argument('--lr', type=float, default=0.00035, help="learning rate") 315 | parser.add_argument('--weight-decay', type=float, default=5e-4) 316 | parser.add_argument('--epochs', type=int, default=50) 317 | parser.add_argument('--iters', type=int, default=400) 318 | parser.add_argument('--step-size', type=int, default=20) 319 | 320 | # cluster 321 | parser.add_argument('--k1', type=int, default=30, 322 | help="hyperparameter for jaccard distance") 323 | parser.add_argument('--k2', type=int, default=6, 324 | help="hyperparameter for jaccard distance") 325 | parser.add_argument('--eps', type=float, default=0.5, 326 | help="distance threshold for DBSCAN") 327 | 328 | main() 329 | -------------------------------------------------------------------------------- /figs/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoonkicho/PPLR/5337193c4a2c7d0870c558a37d28634744d5b6e5/figs/overview.jpg -------------------------------------------------------------------------------- /pplr/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import models 6 | from . import utils 7 | from . import evaluators 8 | from . import trainers 9 | 10 | __version__ = '1.0.0' 11 | -------------------------------------------------------------------------------- /pplr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .market1501 import Market1501 5 | from .msmt17 import MSMT17 6 | from .veri import VeRi 7 | from .dukemtmc import DukeMTMCreID 8 | 9 | 10 | __factory = { 11 | 'market1501': Market1501, 12 | 'msmt17': MSMT17, 13 | 'veri': VeRi, 14 | 'dukemtmc': DukeMTMCreID, 15 | } 16 | 17 | 18 | def names(): 19 | return sorted(__factory.keys()) 20 | 21 | 22 | def create(name, root, *args, **kwargs): 23 | """ 24 | Create a dataset instance. 25 | 26 | Parameters 27 | ---------- 28 | name : str 29 | The dataset name. 30 | root : str 31 | The path to the dataset directory. 32 | split_id : int, optional 33 | The index of data split. Default: 0 34 | num_val : int or float, optional 35 | When int, it means the number of validation identities. When float, 36 | it means the proportion of validation to all the trainval. Default: 100 37 | download : bool, optional 38 | If True, will download the dataset. Default: False 39 | """ 40 | if name not in __factory: 41 | raise KeyError("Unknown dataset:", name) 42 | return __factory[name](root, *args, **kwargs) 43 | 44 | 45 | def get_dataset(name, root, *args, **kwargs): 46 | warnings.warn("get_dataset is deprecated. 
Use create instead.") 47 | return create(name, root, *args, **kwargs) 48 | -------------------------------------------------------------------------------- /pplr/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | 14 | from ..utils.osutils import mkdir_if_missing 15 | from ..utils.data import BaseImageDataset 16 | 17 | 18 | class DukeMTMCreID(BaseImageDataset): 19 | """ 20 | DukeMTMC-reID 21 | Reference: 22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 24 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 25 | 26 | Dataset statistics: 27 | # identities: 1404 (train + query) 28 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 29 | # cameras: 8 30 | """ 31 | dataset_dir = 'DukeMTMC-reID' 32 | 33 | def __init__(self, root='./dataset', verbose=True, **kwargs): 34 | super(DukeMTMCreID, self).__init__() 35 | self.dataset_dir = osp.join(root, self.dataset_dir) 36 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 37 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 38 | self.query_dir = osp.join(self.dataset_dir, 'query') 39 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 40 | 41 | self._download_data() 42 | self._check_before_run() 43 | 44 | train = self._process_dir(self.train_dir, relabel=True) 45 | query = self._process_dir(self.query_dir, relabel=False) 46 | gallery = self._process_dir(self.gallery_dir, relabel=False) 47 | 48 | if verbose: 49 | print("=> DukeMTMC-reID loaded") 50 | self.print_dataset_statistics(train, query, gallery) 51 | 52 | self.train = train 53 | self.query = query 54 | self.gallery = gallery 55 | 56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 59 | 60 | def _download_data(self): 61 | if osp.exists(self.dataset_dir): 62 | print("This dataset has been downloaded.") 63 | return 64 | 65 | print("Creating directory {}".format(self.dataset_dir)) 66 | mkdir_if_missing(self.dataset_dir) 67 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 68 | 69 | print("Downloading DukeMTMC-reID dataset") 70 | urllib.request.urlretrieve(self.dataset_url, fpath) 71 | 72 | print("Extracting files") 73 | zip_ref = zipfile.ZipFile(fpath, 'r') 74 | zip_ref.extractall(self.dataset_dir) 75 | zip_ref.close() 76 | 77 | def _check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 83 | if not osp.exists(self.query_dir): 84 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 87 | 88 | def _process_dir(self, 
dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | pid_container = set() 92 | for img_path in img_paths: 93 | pid, _ = map(int, pattern.search(img_path).groups()) 94 | pid_container.add(pid) 95 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 96 | 97 | dataset = [] 98 | for img_path in img_paths: 99 | pid, camid = map(int, pattern.search(img_path).groups()) 100 | assert 1 <= camid <= 8 101 | camid -= 1 # index starts from 0 102 | if relabel: pid = pid2label[pid] 103 | dataset.append((img_path, pid, camid)) 104 | 105 | return dataset 106 | -------------------------------------------------------------------------------- /pplr/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | 6 | from ..utils.data import BaseImageDataset 7 | 8 | 9 | class Market1501(BaseImageDataset): 10 | """ 11 | Market1501 12 | Reference: 13 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 14 | URL: http://www.liangzheng.org/Project/project_reid.html 15 | 16 | Dataset statistics: 17 | # identities: 1501 (+1 for background) 18 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 19 | """ 20 | dataset_dir = 'Market-1501-v15.09.15' 21 | 22 | def __init__(self, root, verbose=True, **kwargs): 23 | super(Market1501, self).__init__() 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 28 | 29 | self._check_before_run() 30 | 31 | train = self._process_dir(self.train_dir, relabel=True) 32 | query = self._process_dir(self.query_dir, relabel=False) 33 | gallery = self._process_dir(self.gallery_dir, relabel=False) 34 | 35 | if verbose: 36 | print("=> Market1501 loaded") 37 | self.print_dataset_statistics(train, query, gallery) 38 | 39 | self.train = train 40 | self.query = query 41 | self.gallery = gallery 42 | 43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 46 | 47 | def _check_before_run(self): 48 | """Check if all files are available before going deeper""" 49 | if not osp.exists(self.dataset_dir): 50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 51 | if not osp.exists(self.train_dir): 52 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 53 | if not osp.exists(self.query_dir): 54 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 55 | if not osp.exists(self.gallery_dir): 56 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 57 | 58 | def _process_dir(self, dir_path, relabel=False): 59 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 60 | pattern = re.compile(r'([-\d]+)_c(\d)') 61 | 62 | pid_container = set() 63 | for img_path in img_paths: 64 | pid, _ = map(int, pattern.search(img_path).groups()) 65 | if pid == -1: continue # junk images are just ignored 66 | pid_container.add(pid) 67 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 68 | 69 | dataset = [] 
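        # e.g. '0002_c1s1_000451_03.jpg' -> pid=2 (identity), camid=1, parsed by
        # the r'([-\d]+)_c(\d)' pattern above; pid == -1 marks junk images.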
70 | for img_path in img_paths: 71 | pid, camid = map(int, pattern.search(img_path).groups()) 72 | if pid == -1: continue # junk images are just ignored 73 | assert 0 <= pid <= 1501 # pid == 0 means background 74 | assert 1 <= camid <= 6 75 | camid -= 1 # index starts from 0 76 | if relabel: pid = pid2label[pid] 77 | dataset.append((img_path, pid, camid)) 78 | 79 | return dataset 80 | -------------------------------------------------------------------------------- /pplr/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import tarfile 4 | 5 | import glob 6 | import re 7 | import urllib 8 | import zipfile 9 | 10 | from ..utils.osutils import mkdir_if_missing 11 | from ..utils.serialization import write_json 12 | 13 | 14 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')): 15 | with open(list_file, 'r') as f: 16 | lines = f.readlines() 17 | ret = [] 18 | pids = [] 19 | for line in lines: 20 | line = line.strip() 21 | fname = line.split(' ')[0] 22 | pid, _, cam = map(int, pattern.search(osp.basename(fname)).groups()) 23 | if pid not in pids: 24 | pids.append(pid) 25 | ret.append((osp.join(subdir,fname), pid, cam)) 26 | return ret, pids 27 | 28 | 29 | class Dataset_MSMT(object): 30 | def __init__(self, root): 31 | self.root = root 32 | self.train, self.val, self.trainval = [], [], [] 33 | self.query, self.gallery = [], [] 34 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 35 | 36 | @property 37 | def images_dir(self): 38 | return osp.join(self.root, 'MSMT17_V1') 39 | 40 | def load(self, verbose=True): 41 | exdir = osp.join(self.root, 'MSMT17_V1') 42 | self.train, train_pids = _pluck_msmt(osp.join(exdir, 'list_train.txt'), 'train') 43 | self.val, val_pids = _pluck_msmt(osp.join(exdir, 'list_val.txt'), 'train') 44 | self.train = self.train + self.val 45 | self.query, query_pids = _pluck_msmt(osp.join(exdir, 'list_query.txt'), 'test') 46 | self.gallery, gallery_pids = _pluck_msmt(osp.join(exdir, 'list_gallery.txt'), 'test') 47 | self.num_train_pids = len(list(set(train_pids).union(set(val_pids)))) 48 | 49 | if verbose: 50 | print(self.__class__.__name__, "dataset loaded") 51 | print(" subset | # ids | # images") 52 | print(" ---------------------------") 53 | print(" train | {:5d} | {:8d}" 54 | .format(self.num_train_pids, len(self.train))) 55 | print(" query | {:5d} | {:8d}" 56 | .format(len(query_pids), len(self.query))) 57 | print(" gallery | {:5d} | {:8d}" 58 | .format(len(gallery_pids), len(self.gallery))) 59 | 60 | 61 | class MSMT17(Dataset_MSMT): 62 | 63 | def __init__(self, root, split_id=0, download=True): 64 | super(MSMT17, self).__init__(root) 65 | 66 | if download: 67 | self.download() 68 | 69 | self.load() 70 | 71 | def download(self): 72 | 73 | import re 74 | import hashlib 75 | import shutil 76 | from glob import glob 77 | from zipfile import ZipFile 78 | 79 | raw_dir = osp.join(self.root) 80 | mkdir_if_missing(raw_dir) 81 | 82 | # Download the raw zip file 83 | fpath = osp.join(raw_dir, 'MSMT17_V1') 84 | if osp.isdir(fpath): 85 | print("Using downloaded file: " + fpath) 86 | else: 87 | raise RuntimeError("Please download the dataset manually to {}".format(fpath)) 88 | -------------------------------------------------------------------------------- /pplr/datasets/veri.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import glob
6 | import re
7 | import os.path as osp
8 | 
9 | from ..utils.data import BaseImageDataset
10 | 
11 | 
12 | class VeRi(BaseImageDataset):
13 |     """
14 |     VeRi
15 |     Reference:
16 |     Liu, X., Liu, W., Ma, H., Fu, H.: Large-scale vehicle re-identification in urban surveillance videos. In: IEEE
17 |     International Conference on Multimedia and Expo (ICME), 2016.
18 |     Dataset statistics:
19 |     # identities: 776 vehicles (576 for training and 200 for testing)
20 |     # images: 37778 (train) + 1678 (query) + 11579 (gallery)
21 |     """
22 |     dataset_dir = 'VeRi'
23 | 
24 |     def __init__(self, root, verbose=True, **kwargs):
25 |         super(VeRi, self).__init__()
26 |         self.dataset_dir = osp.join(root, self.dataset_dir)
27 |         self.train_dir = osp.join(self.dataset_dir, 'image_train')
28 |         self.query_dir = osp.join(self.dataset_dir, 'image_query')
29 |         self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
30 | 
31 |         self.check_before_run()
32 | 
33 |         train = self.process_dir(self.train_dir, relabel=True)
34 |         query = self.process_dir(self.query_dir, relabel=False)
35 |         gallery = self.process_dir(self.gallery_dir, relabel=False)
36 | 
37 |         if verbose:
38 |             print('=> VeRi loaded')
39 |             self.print_dataset_statistics(train, query, gallery)
40 | 
41 |         self.train = train
42 |         self.query = query
43 |         self.gallery = gallery
44 | 
45 |         self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
46 |         self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
47 |         self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
48 | 
49 |     def check_before_run(self):
50 |         """Check if all files are available before going deeper"""
51 |         if not osp.exists(self.dataset_dir):
52 |             raise RuntimeError('"{}" is not available'.format(self.dataset_dir))
53 |         if not osp.exists(self.train_dir):
54 |             raise RuntimeError('"{}" is not available'.format(self.train_dir))
55 |         if not osp.exists(self.query_dir):
56 |             raise RuntimeError('"{}" is not available'.format(self.query_dir))
57 |         if not osp.exists(self.gallery_dir):
58 |             raise RuntimeError('"{}" is not available'.format(self.gallery_dir))
59 | 
60 |     def process_dir(self, dir_path, relabel=False):
61 |         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
62 |         pattern = re.compile(r'([-\d]+)_c([-\d]+)')
63 | 
64 |         pid_container = set()
65 |         for img_path in img_paths:
66 |             pid, _ = map(int, pattern.search(img_path).groups())
67 |             if pid == -1:
68 |                 continue  # junk images are just ignored
69 |             pid_container.add(pid)
70 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
71 | 
72 |         dataset = []
73 |         for img_path in img_paths:
74 |             pid, camid = map(int, pattern.search(img_path).groups())
75 |             if pid == -1:
76 |                 continue  # junk images are just ignored
77 |             assert 0 <= pid <= 776  # pid == 0 means background
78 |             assert 1 <= camid <= 20
79 |             camid -= 1  # index starts from 0
80 |             if relabel:
81 |                 pid = pid2label[pid]
82 |             dataset.append((img_path, pid, camid))
83 | 
84 |         return dataset
--------------------------------------------------------------------------------
/pplr/evaluation_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .classification import accuracy
4 | from .ranking import cmc, mean_ap
5 | 
6 | __all__ = [
7 |     'accuracy',
8 |     'cmc',
9 |     'mean_ap'
10 | ]
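The two ranking metrics exported above are defined in ranking.py further down; both take a query-by-gallery distance matrix plus id and camera arrays. A minimal usage sketch (the toy inputs are placeholders for illustration, not part of the repository):

```python
import numpy as np
from pplr.evaluation_metrics import cmc, mean_ap

# Toy setup: 4 queries, 10 gallery images, every query id present in the gallery.
distmat = np.random.rand(4, 10)
query_ids = np.array([0, 1, 2, 3])
gallery_ids = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1])
query_cams = np.zeros(4, dtype=np.int32)    # disjoint camera sets, so every
gallery_cams = np.ones(10, dtype=np.int32)  # gallery candidate is valid

mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
# Market-1501-style CMC: stop counting after the first correct match per query.
scores = cmc(distmat, query_ids, gallery_ids, query_cams, gallery_cams,
             topk=5, first_match_break=True)
print('mAP: {:.1%}, top-1: {:.1%}'.format(mAP, scores[0]))
```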
11 | 
--------------------------------------------------------------------------------
/pplr/evaluation_metrics/classification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | from ..utils import to_torch
5 | 
6 | 
7 | def accuracy(output, target, topk=(1,)):
8 |     with torch.no_grad():
9 |         output, target = to_torch(output), to_torch(target)
10 |         maxk = max(topk)
11 |         batch_size = target.size(0)
12 | 
13 |         _, pred = output.topk(maxk, 1, True, True)
14 |         pred = pred.t()
15 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
16 | 
17 |         ret = []
18 |         for k in topk:
19 |             correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
20 |             ret.append(correct_k.mul_(1. / batch_size))
21 |         return ret
22 | 
--------------------------------------------------------------------------------
/pplr/evaluation_metrics/ranking.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 | 
4 | import numpy as np
5 | from sklearn.metrics import average_precision_score
6 | 
7 | from ..utils import to_numpy
8 | 
9 | 
10 | def _unique_sample(ids_dict, num):
11 |     mask = np.zeros(num, dtype=bool)  # plain bool: the np.bool alias was removed in NumPy >= 1.24
12 |     for _, indices in ids_dict.items():
13 |         i = np.random.choice(indices)
14 |         mask[i] = True
15 |     return mask
16 | 
17 | 
18 | def cmc(distmat, query_ids=None, gallery_ids=None,
19 |         query_cams=None, gallery_cams=None, topk=100,
20 |         separate_camera_set=False,
21 |         single_gallery_shot=False,
22 |         first_match_break=False):
23 |     distmat = to_numpy(distmat)
24 |     m, n = distmat.shape
25 |     # Fill up default values
26 |     if query_ids is None:
27 |         query_ids = np.arange(m)
28 |     if gallery_ids is None:
29 |         gallery_ids = np.arange(n)
30 |     if query_cams is None:
31 |         query_cams = np.zeros(m).astype(np.int32)
32 |     if gallery_cams is None:
33 |         gallery_cams = np.ones(n).astype(np.int32)
34 |     # Ensure numpy array
35 |     query_ids = np.asarray(query_ids)
36 |     gallery_ids = np.asarray(gallery_ids)
37 |     query_cams = np.asarray(query_cams)
38 |     gallery_cams = np.asarray(gallery_cams)
39 |     # Sort and find correct matches
40 |     indices = np.argsort(distmat, axis=1)
41 |     matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
42 |     # Compute CMC for each query
43 |     ret = np.zeros(topk)
44 |     num_valid_queries = 0
45 |     for i in range(m):
46 |         # Filter out the same id and same camera
47 |         valid = ((gallery_ids[indices[i]] != query_ids[i]) |
48 |                  (gallery_cams[indices[i]] != query_cams[i]))
49 |         if separate_camera_set:
50 |             # Filter out samples from same camera
51 |             valid &= (gallery_cams[indices[i]] != query_cams[i])
52 |         if not np.any(matches[i, valid]): continue
53 |         if single_gallery_shot:
54 |             repeat = 10
55 |             gids = gallery_ids[indices[i][valid]]
56 |             inds = np.where(valid)[0]
57 |             ids_dict = defaultdict(list)
58 |             for j, x in zip(inds, gids):
59 |                 ids_dict[x].append(j)
60 |         else:
61 |             repeat = 1
62 |         for _ in range(repeat):
63 |             if single_gallery_shot:
64 |                 # Randomly choose one instance for each id
65 |                 sampled = (valid & _unique_sample(ids_dict, len(valid)))
66 |                 index = np.nonzero(matches[i, sampled])[0]
67 |             else:
68 |                 index = np.nonzero(matches[i, valid])[0]
69 |             delta = 1.
/ (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /pplr/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import collections 4 | from collections import OrderedDict 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import random 9 | import copy 10 | 11 | from .evaluation_metrics import cmc, mean_ap 12 | from .utils.meters import AverageMeter 13 | from .utils.rerank import re_ranking 14 | from .utils import to_torch 15 | 16 | 17 | def extract_cnn_feature(model, inputs): 18 | inputs = to_torch(inputs).cuda() 19 | outputs = model(inputs) 20 | outputs = outputs.data.cpu() 21 | return outputs 22 | 23 | 24 | def extract_features(model, data_loader, print_freq=50): 25 | model.eval() 26 | batch_time = AverageMeter() 27 | data_time = AverageMeter() 28 | 29 | features = OrderedDict() 30 | labels = OrderedDict() 31 | 32 | end = time.time() 33 | with torch.no_grad(): 34 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader): 35 | data_time.update(time.time() - end) 36 | 37 | outputs = extract_cnn_feature(model, imgs) 38 | for fname, output, pid in zip(fnames, outputs, pids): 39 | features[fname] = output 40 | labels[fname] = pid 41 | 42 | batch_time.update(time.time() - end) 43 | end = time.time() 44 | 45 | if (i + 1) % print_freq == 0: 46 | print('Extract Features: [{}/{}]\t' 47 | 'Time {:.3f} ({:.3f})\t' 48 | 'Data {:.3f} ({:.3f})\t' 49 | .format(i + 1, len(data_loader), 50 | batch_time.val, batch_time.avg, 51 | data_time.val, data_time.avg)) 52 | 53 | return features, labels 54 | 55 | 56 | def extract_all_features(model, data_loader, print_freq=50): 57 | model.eval() 58 | batch_time = AverageMeter() 59 | data_time = AverageMeter() 60 | 61 | features_g = OrderedDict() 62 | features_p = OrderedDict() 63 | labels 
= OrderedDict()
64 | 
65 |     end = time.time()
66 |     with torch.no_grad():
67 |         for i, (imgs, fnames, pids, _, _) in enumerate(data_loader):
68 |             data_time.update(time.time() - end)
69 |             inputs = to_torch(imgs).cuda()
70 |             if isinstance(model, nn.DataParallel):
71 |                 outputs_g, outputs_p = model.module.extract_all_features(inputs)
72 |             else:
73 |                 outputs_g, outputs_p = model.extract_all_features(inputs)
74 |             outputs_g, outputs_p = outputs_g.data.cpu(), outputs_p.data.cpu()
75 | 
76 |             for fname, output_g, output_p, pid in zip(fnames, outputs_g, outputs_p, pids):
77 |                 features_g[fname] = output_g
78 |                 features_p[fname] = output_p
79 |                 labels[fname] = pid
80 | 
81 |             batch_time.update(time.time() - end)
82 |             end = time.time()
83 | 
84 |             if (i + 1) % print_freq == 0:
85 |                 print('Extract Features: [{}/{}]\t'
86 |                       'Time {:.3f} ({:.3f})\t'
87 |                       'Data {:.3f} ({:.3f})\t'
88 |                       .format(i + 1, len(data_loader),
89 |                               batch_time.val, batch_time.avg,
90 |                               data_time.val, data_time.avg))
91 | 
92 |     return features_g, features_p, labels
93 | 
94 | 
95 | def pairwise_distance(features, query=None, gallery=None):
96 |     if query is None and gallery is None:
97 |         n = len(features)
98 |         x = torch.cat(list(features.values()))
99 |         x = x.view(n, -1)
100 |         dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2
101 |         dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t())  # squared Euclidean for L2-normalized features
102 |         return dist_m
103 | 
104 |     x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
105 |     y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
106 |     m, n = x.size(0), y.size(0)
107 |     x = x.view(m, -1)
108 |     y = y.view(n, -1)
109 |     dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
110 |              torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
111 |     dist_m.addmm_(x, y.t(), beta=1, alpha=-2)  # keyword form; the positional (beta, alpha, ...) overload was removed in recent PyTorch
112 |     return dist_m, x.numpy(), y.numpy()
113 | 
114 | 
115 | def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None,
116 |                  query_ids=None, gallery_ids=None,
117 |                  query_cams=None, gallery_cams=None,
118 |                  cmc_topk=(1, 5, 10), cmc_flag=False):
119 |     if query is not None and gallery is not None:
120 |         query_ids = [pid for _, pid, _ in query]
121 |         gallery_ids = [pid for _, pid, _ in gallery]
122 |         query_cams = [cam for _, _, cam in query]
123 |         gallery_cams = [cam for _, _, cam in gallery]
124 |     else:
125 |         assert (query_ids is not None and gallery_ids is not None
126 |                 and query_cams is not None and gallery_cams is not None)
127 | 
128 |     # Compute mean AP
129 |     mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
130 |     print('Mean AP: {:4.1%}'.format(mAP))
131 | 
132 |     if (not cmc_flag):
133 |         return mAP
134 | 
135 |     cmc_configs = {
136 |         'market1501': dict(separate_camera_set=False,
137 |                            single_gallery_shot=False,
138 |                            first_match_break=True),}
139 |     cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
140 |                             query_cams, gallery_cams, **params)
141 |                   for name, params in cmc_configs.items()}
142 | 
143 |     print('CMC Scores:')
144 |     for k in cmc_topk:
145 |         print('  top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1]))
146 |     return cmc_scores['market1501'], mAP
147 | 
148 | 
149 | class Evaluator(object):
150 |     def __init__(self, model):
151 |         super(Evaluator, self).__init__()
152 |         self.model = model
153 | 
154 |     def evaluate(self, data_loader, query, gallery, cmc_flag=False, rerank=False):
155 |         features, _ = extract_features(self.model, data_loader)
156 |         distmat, query_features, gallery_features = pairwise_distance(features, query, gallery)
157 |         results = evaluate_all(query_features,
gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag)
158 | 
159 |         if (not rerank):
160 |             return results
161 | 
162 |         print('Applying person re-ranking ...')
163 |         distmat_qq, _, _ = pairwise_distance(features, query, query)
164 |         distmat_gg, _, _ = pairwise_distance(features, gallery, gallery)
165 |         distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy())
166 |         return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag)
167 | 
--------------------------------------------------------------------------------
/pplr/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .triplet import TripletLoss, SoftTripletLoss
4 | from .crossentropy import CrossEntropyLabelSmooth, SoftEntropy
5 | from .loss import AALS, PGLR, InterCamProxy
6 | 
7 | __all__ = [
8 |     'TripletLoss',
9 |     'CrossEntropyLabelSmooth',
10 |     'SoftTripletLoss',
11 |     'SoftEntropy',
12 |     'AALS',
13 |     'PGLR',
14 |     'InterCamProxy'
15 | ]
--------------------------------------------------------------------------------
/pplr/loss/crossentropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import *
5 | 
6 | 
7 | class CrossEntropyLabelSmooth(nn.Module):
8 |     """Cross entropy loss with label smoothing regularizer.
9 | 
10 |     Reference:
11 |     Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
12 |     Equation: y = (1 - epsilon) * y + epsilon / K.
13 | 
14 |     Args:
15 |         num_classes (int): number of classes.
16 |         epsilon (float): weight.
17 |     """
18 | 
19 |     def __init__(self, num_classes, epsilon=0.1):
20 |         super(CrossEntropyLabelSmooth, self).__init__()
21 |         self.num_classes = num_classes
22 |         self.epsilon = epsilon
23 |         self.logsoftmax = nn.LogSoftmax(dim=1).cuda()
24 | 
25 |     def forward(self, inputs, targets):
26 |         """
27 |         Args:
28 |             inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
29 |             targets: ground truth labels with shape (batch_size)
30 |         """
31 |         log_probs = self.logsoftmax(inputs)
32 |         targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
33 |         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
34 |         loss = (- targets * log_probs).mean(0).sum()
35 |         return loss
36 | 
37 | 
38 | class SoftEntropy(nn.Module):
39 |     def __init__(self):
40 |         super(SoftEntropy, self).__init__()
41 |         self.logsoftmax = nn.LogSoftmax(dim=1).cuda()
42 | 
43 |     def forward(self, inputs, targets):
44 |         log_probs = self.logsoftmax(inputs)
45 |         loss = (- F.softmax(targets, dim=1).detach() * log_probs).mean(0).sum()
46 |         return loss
47 | 
--------------------------------------------------------------------------------
/pplr/loss/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class AALS(nn.Module):
7 |     """ Agreement-aware label smoothing """
8 |     def __init__(self):
9 |         super(AALS, self).__init__()
10 |         self.logsoftmax = nn.LogSoftmax(dim=1).cuda()
11 | 
12 |     def forward(self, logits, targets, ca):
13 |         log_preds = self.logsoftmax(logits)  # B * C
14 |         targets = torch.zeros_like(log_preds).scatter_(1, targets.unsqueeze(1), 1)
15 |         uni = (torch.ones_like(log_preds) / log_preds.size(-1)).cuda()
16 | 
17 |         loss_ce = (- targets *
log_preds).sum(1) 18 | loss_kld = F.kl_div(log_preds, uni, reduction='none').sum(1) 19 | loss = (ca * loss_ce + (1-ca) * loss_kld).mean() 20 | return loss 21 | 22 | 23 | class PGLR(nn.Module): 24 | """ Part-guided label refinement """ 25 | def __init__(self, lam=0.5): 26 | super(PGLR, self).__init__() 27 | self.softmax = nn.Softmax(dim=1) 28 | self.logsoftmax = nn.LogSoftmax(dim=1) 29 | self.lam = lam 30 | 31 | def forward(self, logits_g, logits_p, targets, ca): 32 | targets = torch.zeros_like(logits_g).scatter_(1, targets.unsqueeze(1), 1) 33 | w = torch.softmax(ca, dim=1) # B * P 34 | w = torch.unsqueeze(w, 1) # B * 1 * P 35 | preds_p = self.softmax(logits_p) # B * C * P 36 | ensembled_preds = (preds_p * w).sum(2).detach() # B * class_num 37 | refined_targets = self.lam * targets + (1-self.lam) * ensembled_preds 38 | 39 | log_preds_g = self.logsoftmax(logits_g) 40 | loss = (-refined_targets * log_preds_g).sum(1).mean() 41 | return loss 42 | 43 | 44 | class InterCamProxy(nn.Module): 45 | """ Camera-aware proxy with inter-camera contrastive learning """ 46 | def __init__(self, num_features, num_samples, num_hards=50, temp=0.07): 47 | super(InterCamProxy, self).__init__() 48 | self.num_features = num_features # D 49 | self.num_samples = num_samples # N 50 | self.num_hards = num_hards 51 | self.logsoftmax = nn.LogSoftmax(dim=0) 52 | self.temp = temp 53 | self.register_buffer('proxy', torch.zeros(num_samples, num_features)) 54 | self.register_buffer('pids', torch.zeros(num_samples).long()) 55 | self.register_buffer('cids', torch.zeros(num_samples).long()) 56 | 57 | """ Inter-camera contrastive loss """ 58 | def forward(self, inputs, targets, cams): 59 | B, D = inputs.shape 60 | inputs = F.normalize(inputs, dim=1).cuda() # B * D 61 | sims = inputs @ self.proxy.T # B * N 62 | sims /= self.temp 63 | temp_sims = sims.detach().clone() 64 | 65 | loss = torch.tensor([0.]).cuda() 66 | for i in range(B): 67 | pos_mask = (targets[i] == self.pids).float() * (cams[i] != self.cids).float() 68 | neg_mask = (targets[i] != self.pids).float() 69 | pos_idx = torch.nonzero(pos_mask > 0).squeeze(-1) 70 | if len(pos_idx) == 0: 71 | continue 72 | hard_neg_idx = torch.sort(temp_sims[i] + (-9999999.) 
* (1.-neg_mask), descending=True).indices[:self.num_hards]
73 |             sims_i = sims[i, torch.cat([pos_idx, hard_neg_idx])]
74 |             targets_i = torch.zeros(len(sims_i)).cuda()
75 |             targets_i[:len(pos_idx)] = 1.0 / len(pos_idx)
76 |             loss += - (targets_i * self.logsoftmax(sims_i)).sum()
77 | 
78 |         loss /= B
79 |         return loss
80 | 
--------------------------------------------------------------------------------
/pplr/loss/triplet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | 
7 | 
8 | def euclidean_dist(x, y):
9 |     m, n = x.size(0), y.size(0)
10 |     xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
11 |     yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
12 |     dist = xx + yy
13 |     dist.addmm_(x, y.t(), beta=1, alpha=-2)  # keyword form; the positional (beta, alpha, ...) overload was removed in recent PyTorch
14 |     dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
15 |     return dist
16 | 
17 | def cosine_dist(x, y):
18 |     bs1, bs2 = x.size(0), y.size(0)
19 |     frac_up = torch.matmul(x, y.transpose(0, 1))
20 |     frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
21 |                 (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
22 |     cosine = frac_up / frac_down
23 |     return 1-cosine
24 | 
25 | def _batch_hard(mat_distance, mat_similarity, indice=False):
26 |     sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1, descending=True)
27 |     hard_p = sorted_mat_distance[:, 0]
28 |     hard_p_indice = positive_indices[:, 0]
29 |     sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1, descending=False)
30 |     hard_n = sorted_mat_distance[:, 0]
31 |     hard_n_indice = negative_indices[:, 0]
32 |     if(indice):
33 |         return hard_p, hard_n, hard_p_indice, hard_n_indice
34 |     return hard_p, hard_n
35 | 
36 | class TripletLoss(nn.Module):
37 |     '''
38 |     Compute Triplet loss augmented with Batch Hard
39 |     Details can be seen in 'In Defense of the Triplet Loss for Person Re-Identification'
40 |     '''
41 | 
42 |     def __init__(self, margin, normalize_feature=False):
43 |         super(TripletLoss, self).__init__()
44 |         self.margin = margin
45 |         self.normalize_feature = normalize_feature
46 |         self.margin_loss = nn.MarginRankingLoss(margin=margin).cuda()
47 | 
48 |     def forward(self, emb, label):
49 |         if self.normalize_feature:
50 |             # equal to cosine similarity
51 |             emb = F.normalize(emb)
52 |         mat_dist = euclidean_dist(emb, emb)
53 |         # mat_dist = cosine_dist(emb, emb)
54 |         assert mat_dist.size(0) == mat_dist.size(1)
55 |         N = mat_dist.size(0)
56 |         mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()
57 | 
58 |         dist_ap, dist_an = _batch_hard(mat_dist, mat_sim)
59 |         assert dist_an.size(0)==dist_ap.size(0)
60 |         y = torch.ones_like(dist_ap)
61 |         loss = self.margin_loss(dist_an, dist_ap, y)
62 |         prec = (dist_an.data > dist_ap.data).sum() * 1.
/ y.size(0)
63 |         return loss, prec
64 | 
65 | 
66 | class SoftTripletLoss(nn.Module):
67 |     def __init__(self, margin=0.0):
68 |         super(SoftTripletLoss, self).__init__()
69 |         self.margin = margin
70 | 
71 |     def forward(self, emb, label):
72 |         mat_dist = euclidean_dist(emb, emb)
73 |         assert mat_dist.size(0) == mat_dist.size(1)
74 |         N = mat_dist.size(0)
75 |         mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()
76 | 
77 |         dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
78 |         assert dist_an.size(0) == dist_ap.size(0)
79 |         triple_dist = torch.stack((dist_ap, dist_an), dim=1)
80 |         triple_dist = F.log_softmax(triple_dist, dim=1)
81 |         loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean()
82 | 
83 |         return loss
--------------------------------------------------------------------------------
/pplr/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .resnet import *
4 | from .resnet_part import *
5 | 
6 | 
7 | __factory = {
8 |     'resnet18': resnet18,
9 |     'resnet34': resnet34,
10 |     'resnet50': resnet50,
11 |     'resnet101': resnet101,
12 |     'resnet152': resnet152,
13 |     'resnet18part': resnet18part,
14 |     'resnet34part': resnet34part,
15 |     'resnet50part': resnet50part,
16 |     'resnet101part': resnet101part,
17 |     'resnet152part': resnet152part,
18 | 
19 | }
20 | 
21 | 
22 | def names():
23 |     return sorted(__factory.keys())
24 | 
25 | 
26 | def create(name, *args, **kwargs):
27 |     """
28 |     Create a model instance.
29 | 
30 |     Parameters
31 |     ----------
32 |     name : str
33 |         Model name. Can be one of 'resnet18', 'resnet34', 'resnet50',
34 |         'resnet101', 'resnet152', or a part-based variant such as 'resnet50part'.
35 |     pretrained : bool, optional
36 |         Only applied for 'resnet*' models. If True, will use ImageNet pretrained
37 |         model. Default: True
38 |     cut_at_pooling : bool, optional
39 |         If True, will cut the model before the last global pooling layer and
40 |         ignore the remaining kwargs. Default: False
41 |     num_features : int, optional
42 |         If positive, will append a Linear layer after the global pooling layer,
43 |         with this number of output units, followed by a BatchNorm layer.
44 |         Otherwise these layers will not be appended. Default: 0 for
45 |         'resnet*' models.
46 |     norm : bool, optional
47 |         If True, will normalize the feature to be unit L2-norm for each sample.
48 |         Otherwise will append a ReLU layer after the above Linear layer if
49 |         num_features > 0. Default: False
50 |     dropout : float, optional
51 |         If positive, will append a Dropout layer with this dropout rate.
52 |         Default: 0
53 |     num_classes : int, optional
54 |         If positive, will append a Linear layer at the end as the classifier
55 |         with this number of output units. Default: 0
56 |     """
57 |     if name not in __factory:
58 |         raise KeyError("Unknown model:", name)
59 |     return __factory[name](*args, **kwargs)
60 | 
--------------------------------------------------------------------------------
/pplr/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn import init
6 | import torchvision
7 | import torch
8 | 
9 | 
10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
11 |            'resnet152']
12 | 
13 | 
14 | class ResNet(nn.Module):
15 |     __factory = {
16 |         18: torchvision.models.resnet18,
17 |         34: torchvision.models.resnet34,
18 |         50: torchvision.models.resnet50,
19 |         101: torchvision.models.resnet101,
20 |         152: torchvision.models.resnet152,
21 |     }
22 | 
23 |     def __init__(self, depth, pretrained=True, cut_at_pooling=False,
24 |                  num_features=0, norm=False, dropout=0, num_classes=0):
25 |         super(ResNet, self).__init__()
26 |         self.pretrained = pretrained
27 |         self.depth = depth
28 |         self.cut_at_pooling = cut_at_pooling
29 |         # Construct base (pretrained) resnet
30 |         if depth not in ResNet.__factory:
31 |             raise KeyError("Unsupported depth:", depth)
32 |         resnet = ResNet.__factory[depth](pretrained=pretrained)
33 |         resnet.layer4[0].conv2.stride = (1,1)
34 |         resnet.layer4[0].downsample[0].stride = (1,1)
35 |         self.base = nn.Sequential(
36 |             resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
37 |             resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)
38 |         self.gap = nn.AdaptiveAvgPool2d(1)
39 | 
40 |         if not self.cut_at_pooling:
41 |             self.num_features = num_features
42 |             self.norm = norm
43 |             self.dropout = dropout
44 |             self.has_embedding = num_features > 0
45 |             self.num_classes = num_classes
46 | 
47 |             out_planes = resnet.fc.in_features
48 | 
49 |             # Append new layers
50 |             if self.has_embedding:
51 |                 self.feat = nn.Linear(out_planes, self.num_features)
52 |                 self.feat_bn = nn.BatchNorm1d(self.num_features)
53 |                 init.kaiming_normal_(self.feat.weight, mode='fan_out')
54 |                 init.constant_(self.feat.bias, 0)
55 |             else:
56 |                 # Change the num_features to CNN output channels
57 |                 self.num_features = out_planes
58 |                 self.feat_bn = nn.BatchNorm1d(self.num_features)
59 |             self.feat_bn.bias.requires_grad_(False)
60 |             if self.dropout > 0:
61 |                 self.drop = nn.Dropout(self.dropout)
62 |             if self.num_classes > 0:
63 |                 self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False)
64 |                 init.normal_(self.classifier.weight, std=0.001)
65 |         init.constant_(self.feat_bn.weight, 1)
66 |         init.constant_(self.feat_bn.bias, 0)
67 | 
68 |         if not pretrained:
69 |             self.reset_params()
70 | 
71 |     def forward(self, x):
72 |         bs = x.size(0)
73 |         x = self.base(x)
74 | 
75 |         x = self.gap(x)
76 |         x = x.view(x.size(0), -1)
77 | 
78 |         if self.cut_at_pooling:
79 |             return x
80 | 
81 |         if self.has_embedding:
82 |             bn_x = self.feat_bn(self.feat(x))
83 |         else:
84 |             bn_x = self.feat_bn(x)
85 | 
86 |         if (self.training is False):
87 |             bn_x = F.normalize(bn_x)
88 |             return bn_x
89 | 
90 |         if self.norm:
91 |             bn_x = F.normalize(bn_x)
92 |         elif self.has_embedding:
93 |             bn_x = F.relu(bn_x)
94 | 
95 |         if self.dropout > 0:
96 |             bn_x = self.drop(bn_x)
97 | 
98 |         if self.num_classes > 0:
99 |             prob = self.classifier(bn_x)
100 |         else:
101 |             return bn_x
102 | 
103 |         return prob
104 | 
105 |     def reset_params(self):
106 |         for m in self.modules():
107 |             if isinstance(m, nn.Conv2d):
108 |                 init.kaiming_normal_(m.weight, mode='fan_out')
109 |                 if
m.bias is not None: 110 | init.constant_(m.bias, 0) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | init.constant_(m.weight, 1) 113 | init.constant_(m.bias, 0) 114 | elif isinstance(m, nn.BatchNorm1d): 115 | init.constant_(m.weight, 1) 116 | init.constant_(m.bias, 0) 117 | elif isinstance(m, nn.Linear): 118 | init.normal_(m.weight, std=0.001) 119 | if m.bias is not None: 120 | init.constant_(m.bias, 0) 121 | 122 | 123 | def resnet18(**kwargs): 124 | return ResNet(18, **kwargs) 125 | 126 | 127 | def resnet34(**kwargs): 128 | return ResNet(34, **kwargs) 129 | 130 | 131 | def resnet50(**kwargs): 132 | return ResNet(50, **kwargs) 133 | 134 | 135 | def resnet101(**kwargs): 136 | return ResNet(101, **kwargs) 137 | 138 | 139 | def resnet152(**kwargs): 140 | return ResNet(152, **kwargs) 141 | -------------------------------------------------------------------------------- /pplr/models/resnet_part.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | 9 | 10 | __all__ = ['ResNetPart', 'resnet18part', 'resnet34part', 'resnet50part', 'resnet101part', 11 | 'resnet152part'] 12 | 13 | 14 | class ResNetPart(nn.Module): 15 | """ ResNet with part features by uniform partitioning """ 16 | __factory = { 17 | 18: torchvision.models.resnet18, 18 | 34: torchvision.models.resnet34, 19 | 50: torchvision.models.resnet50, 20 | 101: torchvision.models.resnet101, 21 | 152: torchvision.models.resnet152, 22 | } 23 | 24 | def __init__(self, depth, pretrained=True, num_parts=3, num_classes=0): 25 | super(ResNetPart, self).__init__() 26 | self.pretrained = pretrained 27 | self.depth = depth 28 | # Construct base (pretrained) resnet 29 | if depth not in ResNetPart.__factory: 30 | raise KeyError("Unsupported depth:", depth) 31 | resnet = ResNetPart.__factory[depth](pretrained=pretrained) 32 | resnet.layer4[0].conv2.stride = (1,1) 33 | resnet.layer4[0].downsample[0].stride = (1,1) 34 | 35 | self.num_parts = num_parts 36 | self.num_classes = num_classes 37 | 38 | self.base = nn.Sequential( 39 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 40 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 41 | self.gap = nn.AdaptiveAvgPool2d(1) 42 | self.rap = nn.AdaptiveAvgPool2d((self.num_parts, 1)) 43 | 44 | # global feature classifiers 45 | self.bnneck = nn.BatchNorm1d(2048) 46 | init.constant_(self.bnneck.weight, 1) 47 | init.constant_(self.bnneck.bias, 0) 48 | self.bnneck.bias.requires_grad_(False) 49 | 50 | self.classifier = nn.Linear(2048, self.num_classes, bias=False) 51 | init.normal_(self.classifier.weight, std=0.001) 52 | 53 | # part feature classifiers 54 | for i in range(self.num_parts): 55 | name = 'bnneck' + str(i) 56 | setattr(self, name, nn.BatchNorm1d(2048)) 57 | init.constant_(getattr(self, name).weight, 1) 58 | init.constant_(getattr(self, name).bias, 0) 59 | getattr(self, name).bias.requires_grad_(False) 60 | 61 | name = 'classifier' + str(i) 62 | setattr(self, name, nn.Linear(2048, self.num_classes, bias=False)) 63 | 64 | if not pretrained: 65 | self.reset_params() 66 | 67 | def forward(self, x): 68 | x = self.base(x) 69 | 70 | f_g = self.gap(x) 71 | f_g = f_g.view(x.size(0), -1) 72 | f_g = self.bnneck(f_g) 73 | 74 | if self.training is False: 75 | f_g = F.normalize(f_g) 76 | return f_g 77 | 78 | logits_g = self.classifier(f_g) 79 | 80 | f_p = self.rap(x) 81 | f_p = 
f_p.view(f_p.size(0), f_p.size(1), -1) 82 | 83 | logits_p = [] 84 | fs_p = [] 85 | for i in range(self.num_parts): 86 | f_p_i = f_p[:, :, i] 87 | f_p_i = getattr(self, 'bnneck' + str(i))(f_p_i) 88 | logits_p_i = getattr(self, 'classifier' + str(i))(f_p_i) 89 | logits_p.append(logits_p_i) 90 | fs_p.append(f_p_i) 91 | 92 | fs_p = torch.stack(fs_p, dim=-1) 93 | logits_p = torch.stack(logits_p, dim=-1) 94 | 95 | return f_g, fs_p, logits_g, logits_p 96 | 97 | def reset_params(self): 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | init.kaiming_normal_(m.weight, mode='fan_out') 101 | if m.bias is not None: 102 | init.constant_(m.bias, 0) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | init.constant_(m.weight, 1) 105 | init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.BatchNorm1d): 107 | init.constant_(m.weight, 1) 108 | init.constant_(m.bias, 0) 109 | elif isinstance(m, nn.Linear): 110 | init.normal_(m.weight, std=0.001) 111 | if m.bias is not None: 112 | init.constant_(m.bias, 0) 113 | 114 | def extract_all_features(self, x): 115 | x = self.base(x) 116 | 117 | f_g = self.gap(x) 118 | f_g = f_g.view(x.size(0), -1) 119 | f_g = self.bnneck(f_g) 120 | f_g = F.normalize(f_g) 121 | 122 | f_p = self.rap(x) 123 | f_p = f_p.view(f_p.size(0), f_p.size(1), -1) 124 | 125 | fs_p = [] 126 | for i in range(self.num_parts): 127 | f_p_i = f_p[:, :, i] 128 | f_p_i = getattr(self, 'bnneck' + str(i))(f_p_i) 129 | f_p_i = F.normalize(f_p_i) 130 | fs_p.append(f_p_i) 131 | fs_p = torch.stack(fs_p, dim=-1) 132 | 133 | return f_g, fs_p 134 | 135 | 136 | def resnet18part(**kwargs): 137 | return ResNetPart(18, **kwargs) 138 | 139 | 140 | def resnet34part(**kwargs): 141 | return ResNetPart(34, **kwargs) 142 | 143 | 144 | def resnet50part(**kwargs): 145 | return ResNetPart(50, **kwargs) 146 | 147 | 148 | def resnet101part(**kwargs): 149 | return ResNetPart(101, **kwargs) 150 | 151 | 152 | def resnet152part(**kwargs): 153 | return ResNetPart(152, **kwargs) 154 | -------------------------------------------------------------------------------- /pplr/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | 4 | from .evaluation_metrics import accuracy 5 | from .loss import AALS, PGLR, SoftTripletLoss, CrossEntropyLabelSmooth 6 | from .utils.meters import AverageMeter 7 | 8 | 9 | class PPLRTrainer(object): 10 | def __init__(self, model, score, num_class=500, num_part=6, beta=0.5, aals_epoch=5): 11 | super(PPLRTrainer, self).__init__() 12 | self.model = model 13 | self.score = score 14 | 15 | self.num_class = num_class 16 | self.num_part = num_part 17 | self.aals_epoch = aals_epoch 18 | 19 | self.criterion_pglr = PGLR(lam=beta).cuda() 20 | self.criterion_aals = AALS().cuda() 21 | self.criterion_ce = CrossEntropyLabelSmooth(num_classes=num_class).cuda() 22 | self.criterion_tri = SoftTripletLoss().cuda() 23 | 24 | def train(self, epoch, train_dataloader, optimizer, print_freq=1, train_iters=200): 25 | self.model.train() 26 | 27 | batch_time = AverageMeter() 28 | losses_gce = AverageMeter() 29 | losses_tri = AverageMeter() 30 | losses_pce = AverageMeter() 31 | precisions = AverageMeter() 32 | 33 | time.sleep(1) 34 | end = time.time() 35 | for i in range(train_iters): 36 | data = train_dataloader.next() 37 | inputs, targets, ca = self._parse_data(data) 38 | 39 | # feedforward 40 | emb_g, emb_p, logits_g, logits_p = self.model(inputs) 41 | logits_g, logits_p = logits_g[:, :self.num_class], 
logits_p[:, :self.num_class, :] 42 | 43 | # loss 44 | loss_gce = self.criterion_pglr(logits_g, logits_p, targets, ca) 45 | loss_tri = self.criterion_tri(emb_g, targets) 46 | 47 | loss_pce = 0. 48 | if self.num_part > 0: 49 | if epoch >= self.aals_epoch: 50 | for part in range(self.num_part): 51 | loss_pce += self.criterion_aals(logits_p[:, :, part], targets, ca[:, part]) 52 | else: 53 | for part in range(self.num_part): 54 | loss_pce += self.criterion_ce(logits_p[:, :, part], targets) 55 | loss_pce /= self.num_part 56 | 57 | loss = loss_gce + loss_tri + loss_pce 58 | 59 | # update 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | 64 | # summing-up 65 | prec, = accuracy(logits_g.data, targets.data) 66 | 67 | losses_gce.update(loss_gce.item()) 68 | losses_tri.update(loss_tri.item()) 69 | losses_pce.update(loss_pce.item()) 70 | precisions.update(prec[0]) 71 | 72 | batch_time.update(time.time() - end) 73 | end = time.time() 74 | 75 | if (i + 1) % print_freq == 0: 76 | print('Epoch: [{}][{}/{}]\t' 77 | 'Time {:.3f} ({:.3f})\t' 78 | 'L_GCE {:.3f} ({:.3f})\t' 79 | 'L_PCE {:.3f} ({:.3f})\t' 80 | 'L_TRI {:.3f} ({:.3f})\t' 81 | 'Prec {:.2%} ({:.2%})\t' 82 | .format(epoch, i + 1, len(train_dataloader), 83 | batch_time.val, batch_time.avg, 84 | losses_gce.val, losses_gce.avg, 85 | losses_pce.val, losses_pce.avg, 86 | losses_tri.val, losses_tri.avg, 87 | precisions.val, precisions.avg)) 88 | 89 | def _parse_data(self, inputs): 90 | imgs, _, pids, _, idxs = inputs 91 | ca = self.score[idxs] 92 | return imgs.cuda(), pids.cuda(), ca.cuda() 93 | 94 | 95 | class PPLRTrainerCAM(object): 96 | def __init__(self, model, score, memory, memory_p, num_class=500, num_part=6, beta=0.5, aals_epoch=5, lam_cam=0.5): 97 | super(PPLRTrainerCAM, self).__init__() 98 | self.model = model 99 | self.score = score 100 | self.memory = memory 101 | self.memory_p = memory_p 102 | 103 | self.num_class = num_class 104 | self.num_part = num_part 105 | self.lam_cam = lam_cam 106 | self.aals_epoch = aals_epoch 107 | 108 | self.criterion_pglr = PGLR(lam=beta).cuda() 109 | self.criterion_aals = AALS().cuda() 110 | self.criterion_ce = CrossEntropyLabelSmooth(num_classes=num_class).cuda() 111 | self.criterion_tri = SoftTripletLoss().cuda() 112 | 113 | def train(self, epoch, train_dataloader, optimizer, print_freq=1, train_iters=200): 114 | self.model.train() 115 | 116 | batch_time = AverageMeter() 117 | losses_gce = AverageMeter() 118 | losses_tri = AverageMeter() 119 | losses_cam = AverageMeter() 120 | losses_pce = AverageMeter() 121 | 122 | precisions = AverageMeter() 123 | 124 | time.sleep(1) 125 | end = time.time() 126 | for i in range(train_iters): 127 | data = train_dataloader.next() 128 | inputs, targets, cams, ca = self._parse_data(data) 129 | 130 | # feedforward 131 | emb_g, emb_p, logits_g, logits_p = self.model(inputs) 132 | logits_g, logits_p = logits_g[:, :self.num_class], logits_p[:, :self.num_class, :] 133 | 134 | # loss 135 | loss_gce = self.criterion_pglr(logits_g, logits_p, targets, ca) 136 | loss_tri = self.criterion_tri(emb_g, targets) 137 | loss_gcam = self.memory(emb_g, targets, cams) 138 | 139 | loss_pce = 0. 140 | loss_pcam = 0. 
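            # Part branch: each horizontal stripe has its own classifier head and
            # its own camera-aware proxy memory (memory_p[part]). Before aals_epoch
            # the per-part identity loss is plain label-smoothed cross-entropy;
            # afterwards it switches to AALS, weighted by the part's cross-agreement
            # score ca[:, part]. Both losses are averaged over the parts below.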
141 | if self.num_part > 0: 142 | if epoch >= self.aals_epoch: 143 | for part in range(self.num_part): 144 | loss_pce += self.criterion_aals(logits_p[:, :, part], targets, ca[:, part]) 145 | loss_pcam += self.memory_p[part](emb_p[:, :, part], targets, cams) 146 | else: 147 | for part in range(self.num_part): 148 | loss_pce += self.criterion_ce(logits_p[:, :, part], targets) 149 | loss_pcam += self.memory_p[part](emb_p[:, :, part], targets, cams) 150 | loss_pce /= self.num_part 151 | loss_pcam /= self.num_part 152 | 153 | loss_cam = loss_pcam + loss_gcam 154 | loss = loss_gce + loss_pce + loss_tri + loss_cam * self.lam_cam 155 | 156 | # update 157 | optimizer.zero_grad() 158 | loss.backward() 159 | optimizer.step() 160 | 161 | # summing-up 162 | prec, = accuracy(logits_g.data, targets.data) 163 | 164 | losses_gce.update(loss_gce.item()) 165 | losses_tri.update(loss_tri.item()) 166 | losses_cam.update(loss_cam.item()) 167 | losses_pce.update(loss_pce.item()) 168 | precisions.update(prec[0]) 169 | 170 | batch_time.update(time.time() - end) 171 | end = time.time() 172 | 173 | if (i + 1) % print_freq == 0: 174 | print('Epoch: [{}][{}/{}]\t' 175 | 'Time {:.3f} ({:.3f})\t' 176 | 'L_GCE {:.3f} ({:.3f})\t' 177 | 'L_PCE {:.3f} ({:.3f})\t' 178 | 'L_TRI {:.3f} ({:.3f})\t' 179 | 'L_CAM {:.3f} ({:.3f})\t' 180 | 'Prec {:.2%} ({:.2%})\t' 181 | .format(epoch, i + 1, len(train_dataloader), 182 | batch_time.val, batch_time.avg, 183 | losses_gce.val, losses_gce.avg, 184 | losses_pce.val, losses_pce.avg, 185 | losses_tri.val, losses_tri.avg, 186 | losses_cam.val, losses_cam.avg, 187 | precisions.val, precisions.avg)) 188 | 189 | def _parse_data(self, inputs): 190 | imgs, _, pids, cids, idxs = inputs 191 | ca = self.score[idxs] 192 | return imgs.cuda(), pids.cuda(), cids.cuda(), ca.cuda() 193 | -------------------------------------------------------------------------------- /pplr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /pplr/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset 4 | from .preprocessor import Preprocessor 5 | 6 | class IterLoader: 7 | def __init__(self, loader, length=None): 8 | self.loader = loader 9 | self.length = length 10 | self.iter = None 11 | 12 | def __len__(self): 13 | if (self.length is not None): 14 | return self.length 15 | return len(self.loader) 16 | 17 | def new_epoch(self): 18 | self.iter = iter(self.loader) 19 | 20 | def next(self): 21 | try: 22 | return next(self.iter) 23 | except: 24 | self.iter = iter(self.loader) 25 | return next(self.iter) 26 | -------------------------------------------------------------------------------- /pplr/utils/data/base_dataset.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def print_dataset_statistics(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def images_dir(self): 27 | return None 28 | 29 | 30 | class BaseImageDataset(BaseDataset): 31 | """ 32 | Base class of image reid dataset 33 | """ 34 | 35 | def print_dataset_statistics(self, train, query, gallery): 36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 38 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 39 | 40 | print("Dataset statistics:") 41 | print(" ----------------------------------------") 42 | print(" subset | # ids | # images | # cameras") 43 | print(" ----------------------------------------") 44 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 45 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 46 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 47 | print(" ----------------------------------------") 48 | -------------------------------------------------------------------------------- /pplr/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import random 7 | import math 8 | from PIL import Image 9 | 10 | 11 | class Preprocessor(Dataset): 12 | def __init__(self, dataset, root=None, transform=None): 13 | super(Preprocessor, self).__init__() 14 | self.dataset = dataset 15 | self.root = root 16 | self.transform = transform 17 | 18 | def __len__(self): 19 | return len(self.dataset) 20 | 21 | def __getitem__(self, indices): 22 | return self._get_single_item(indices) 23 | 24 | def _get_single_item(self, index): 25 | fname, pid, camid = self.dataset[index] 26 | fpath = fname 27 | if self.root is not None: 28 | fpath = osp.join(self.root, fname) 29 | 30 | img = Image.open(fpath).convert('RGB') 31 | 32 | if self.transform is not None: 33 | img = self.transform(img) 34 | 35 | return img, fname, pid, camid, index 36 | -------------------------------------------------------------------------------- /pplr/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import math 4 | 5 | import numpy as np 6 | import copy 7 | import random 8 | import torch 9 | from torch.utils.data.sampler import ( 10 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 11 | WeightedRandomSampler) 12 | 13 | 14 | def No_index(a, b): 15 | assert isinstance(a, list) 16 | return [i for i, j in enumerate(a) if j != b] 17 | 18 | 19 | class RandomIdentitySampler(Sampler): 20 | def __init__(self, data_source, num_instances): 21 | 
self.data_source = data_source 22 | self.num_instances = num_instances 23 | self.index_dic = defaultdict(list) 24 | for index, (_, pid, _) in enumerate(data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | self.num_samples = len(self.pids) 28 | 29 | def __len__(self): 30 | return self.num_samples * self.num_instances 31 | 32 | def __iter__(self): 33 | indices = torch.randperm(self.num_samples).tolist() 34 | ret = [] 35 | for i in indices: 36 | pid = self.pids[i] 37 | t = self.index_dic[pid] 38 | if len(t) >= self.num_instances: 39 | t = np.random.choice(t, size=self.num_instances, replace=False) 40 | else: 41 | t = np.random.choice(t, size=self.num_instances, replace=True) 42 | ret.extend(t) 43 | return iter(ret) 44 | 45 | 46 | class RandomMultipleGallerySampler(Sampler): 47 | def __init__(self, data_source, num_instances=4): 48 | self.data_source = data_source 49 | self.index_pid = defaultdict(int) 50 | self.pid_cam = defaultdict(list) 51 | self.pid_index = defaultdict(list) 52 | self.num_instances = num_instances 53 | 54 | for index, (_, pid, cam) in enumerate(data_source): 55 | if (pid<0): continue 56 | self.index_pid[index] = pid 57 | self.pid_cam[pid].append(cam) 58 | self.pid_index[pid].append(index) 59 | 60 | self.pids = list(self.pid_index.keys()) 61 | self.num_samples = len(self.pids) 62 | 63 | def __len__(self): 64 | return self.num_samples * self.num_instances 65 | 66 | def __iter__(self): 67 | indices = torch.randperm(len(self.pids)).tolist() 68 | ret = [] 69 | 70 | for kid in indices: 71 | i = random.choice(self.pid_index[self.pids[kid]]) 72 | 73 | _, i_pid, i_cam = self.data_source[i] 74 | 75 | ret.append(i) 76 | 77 | pid_i = self.index_pid[i] 78 | cams = self.pid_cam[pid_i] 79 | index = self.pid_index[pid_i] 80 | select_cams = No_index(cams, i_cam) 81 | 82 | if select_cams: 83 | 84 | if len(select_cams) >= self.num_instances: 85 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False) 86 | else: 87 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True) 88 | 89 | for kk in cam_indexes: 90 | ret.append(index[kk]) 91 | 92 | else: 93 | select_indexes = No_index(index, i) 94 | if (not select_indexes): continue 95 | if len(select_indexes) >= self.num_instances: 96 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 97 | else: 98 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 99 | 100 | for kk in ind_indexes: 101 | ret.append(index[kk]) 102 | 103 | return iter(ret) 104 | -------------------------------------------------------------------------------- /pplr/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | 
self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | """ Randomly selects a rectangle region in an image and erases its pixels. 54 | 'Random Erasing Data Augmentation' by Zhong et al. 55 | See https://arxiv.org/pdf/1708.04896.pdf 56 | Args: 57 | probability: The probability that the Random Erasing operation will be performed. 58 | sl: Minimum proportion of erased area against input image. 59 | sh: Maximum proportion of erased area against input image. 60 | r1: Minimum aspect ratio of erased area. 61 | mean: Erasing value. 62 | """ 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) >= self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | -------------------------------------------------------------------------------- /pplr/utils/faiss_rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
5 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 7 | """ 8 | 9 | import os, sys 10 | import time 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | import gc 14 | import faiss 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \ 20 | index_init_gpu, index_init_cpu 21 | 22 | def k_reciprocal_neigh(initial_rank, i, k1): 23 | forward_k_neigh_index = initial_rank[i,:k1+1] 24 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 25 | fi = np.where(backward_k_neigh_index==i)[0] 26 | return forward_k_neigh_index[fi] 27 | 28 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False): 29 | end = time.time() 30 | if print_flag: 31 | print('Computing jaccard distance...') 32 | 33 | ngpus = faiss.get_num_gpus() 34 | N = target_features.size(0) 35 | mat_type = np.float16 if use_float16 else np.float32 36 | 37 | if (search_option==0): 38 | # GPU + PyTorch CUDA Tensors (1) 39 | res = faiss.StandardGpuResources() 40 | res.setDefaultNullStreamAllDevices() 41 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1) 42 | initial_rank = initial_rank.cpu().numpy() 43 | elif (search_option==1): 44 | # GPU + PyTorch CUDA Tensors (2) 45 | res = faiss.StandardGpuResources() 46 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1)) 47 | index.add(target_features.cpu().numpy()) 48 | _, initial_rank = search_index_pytorch(index, target_features, k1) 49 | res.syncDefaultStreamCurrentDevice() 50 | initial_rank = initial_rank.cpu().numpy() 51 | elif (search_option==2): 52 | # GPU 53 | index = index_init_gpu(ngpus, target_features.size(-1)) 54 | index.add(target_features.cpu().numpy()) 55 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 56 | else: 57 | # CPU 58 | index = index_init_cpu(target_features.size(-1)) 59 | index.add(target_features.cpu().numpy()) 60 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 61 | 62 | 63 | nn_k1 = [] 64 | nn_k1_half = [] 65 | for i in range(N): 66 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1)) 67 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2)))) 68 | 69 | V = np.zeros((N, N), dtype=mat_type) 70 | for i in range(N): 71 | k_reciprocal_index = nn_k1[i] 72 | k_reciprocal_expansion_index = k_reciprocal_index 73 | for candidate in k_reciprocal_index: 74 | candidate_k_reciprocal_index = nn_k1_half[candidate] 75 | if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)): 76 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 77 | 78 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 79 | dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t()) 80 | if use_float16: 81 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type) 82 | else: 83 | V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy() 84 | 85 | del nn_k1, nn_k1_half 86 | 87 | if k2 != 1: 88 | V_qe = np.zeros_like(V, dtype=mat_type) 89 | for i in range(N): 90 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0) 91 | V = 
V_qe 92 | del V_qe 93 | 94 | del initial_rank 95 | 96 | invIndex = [] 97 | for i in range(N): 98 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 99 | 100 | jaccard_dist = np.zeros((N, N), dtype=mat_type) 101 | for i in range(N): 102 | temp_min = np.zeros((1,N), dtype=mat_type) 103 | # temp_max = np.zeros((1,N), dtype=mat_type) 104 | indNonZero = np.where(V[i,:] != 0)[0] 105 | indImages = [] 106 | indImages = [invIndex[ind] for ind in indNonZero] 107 | for j in range(len(indNonZero)): 108 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 109 | # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 110 | 111 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 112 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6) 113 | 114 | del invIndex, V 115 | 116 | pos_bool = (jaccard_dist < 0) 117 | jaccard_dist[pos_bool] = 0.0 118 | if print_flag: 119 | print ("Jaccard distance computing time cost: {}".format(time.time()-end)) 120 | 121 | return jaccard_dist 122 | 123 | @torch.no_grad() 124 | def compute_ranked_list(features, k=20, search_option=0, fp16=False, verbose=True): 125 | 126 | end = time.time() 127 | if verbose: 128 | print("Computing ranked list...") 129 | 130 | if search_option < 3: 131 | torch.cuda.empty_cache() 132 | features = features.cuda().detach() 133 | 134 | ngpus = faiss.get_num_gpus() 135 | 136 | if search_option == 0: 137 | # Faiss Search + PyTorch CUDA Tensors (1) 138 | res = faiss.StandardGpuResources() 139 | res.setDefaultNullStreamAllDevices() 140 | _, initial_rank = search_raw_array_pytorch(res, features, features, k+1) 141 | initial_rank = initial_rank.cpu().numpy() 142 | 143 | elif search_option == 1: 144 | # Faiss Search + PyTorch CUDA Tensors (2) 145 | res = faiss.StandardGpuResources() 146 | index = faiss.GpuIndexFlatL2(res, features.size(-1)) 147 | index.add(features.cpu().numpy()) 148 | _, initial_rank = search_index_pytorch(index, features, k+1) 149 | res.syncDefaultStreamCurrentDevice() 150 | initial_rank = initial_rank.cpu().numpy() 151 | 152 | elif search_option == 2: 153 | # PyTorch Search + PyTorch CUDA Tensors 154 | torch.cuda.empty_cache() 155 | features = features.cuda().detach() 156 | dist_m = compute_euclidean_distance(features, cuda=True) 157 | initial_rank = torch.argsort(dist_m, dim=1) 158 | initial_rank = initial_rank.cpu().numpy() 159 | 160 | else: 161 | # Numpy Search (CPU) 162 | torch.cuda.empty_cache() 163 | features = features.cuda().detach() 164 | dist_m = compute_euclidean_distance(features, cuda=False) 165 | initial_rank = np.argsort(dist_m.cpu().numpy(), axis=1) 166 | features = features.cpu() 167 | 168 | features = features.cpu() 169 | if verbose: 170 | print("Ranked list computing time cost: {}".format(time.time() - end)) 171 | 172 | return initial_rank[:, 1:k+1] 173 | 174 | @torch.no_grad() 175 | def compute_euclidean_distance(features, others=None, cuda=False): 176 | if others is None: 177 | if cuda: 178 | features = features.cuda() 179 | 180 | n = features.size(0) 181 | x = features.view(n, -1) 182 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 183 | dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t()) 184 | del features 185 | 186 | else: 187 | if cuda: 188 | features = features.cuda() 189 | others = others.cuda() 190 | 191 | m, n = features.size(0), others.size(0) 192 | dist_m = torch.pow(features, 2).sum(dim=1, keepdim=True).expand(m, n) +\ 193 | torch.pow(others, 2).sum(dim=1, 
keepdim=True).expand(n, m).t() 194 | 195 | dist_m.addmm_(features, others.t(), beta=1, alpha=-2) 196 | del features, others 197 | 198 | return dist_m 199 | -------------------------------------------------------------------------------- /pplr/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | return faiss.cast_integer_to_long_ptr( 16 | x.storage().data_ptr() + x.storage_offset() * 8) 17 | 18 | def search_index_pytorch(index, x, k, D=None, I=None): 19 | """call the search function of an index with pytorch tensor I/O (CPU 20 | and GPU supported)""" 21 | assert x.is_contiguous() 22 | n, d = x.size() 23 | assert d == index.d 24 | 25 | if D is None: 26 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 27 | else: 28 | assert D.size() == (n, k) 29 | 30 | if I is None: 31 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 32 | else: 33 | assert I.size() == (n, k) 34 | torch.cuda.synchronize() 35 | xptr = swig_ptr_from_FloatTensor(x) 36 | Iptr = swig_ptr_from_LongTensor(I) 37 | Dptr = swig_ptr_from_FloatTensor(D) 38 | index.search_c(n, xptr, 39 | k, Dptr, Iptr) 40 | torch.cuda.synchronize() 41 | return D, I 42 | 43 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 44 | metric=faiss.METRIC_L2): 45 | assert xb.device == xq.device 46 | 47 | nq, d = xq.size() 48 | if xq.is_contiguous(): 49 | xq_row_major = True 50 | elif xq.t().is_contiguous(): 51 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 52 | xq_row_major = False 53 | else: 54 | raise TypeError('matrix should be row or column-major') 55 | 56 | xq_ptr = swig_ptr_from_FloatTensor(xq) 57 | 58 | nb, d2 = xb.size() 59 | assert d2 == d 60 | if xb.is_contiguous(): 61 | xb_row_major = True 62 | elif xb.t().is_contiguous(): 63 | xb = xb.t() 64 | xb_row_major = False 65 | else: 66 | raise TypeError('matrix should be row or column-major') 67 | xb_ptr = swig_ptr_from_FloatTensor(xb) 68 | 69 | if D is None: 70 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 71 | else: 72 | assert D.shape == (nq, k) 73 | assert D.device == xb.device 74 | 75 | if I is None: 76 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 77 | else: 78 | assert I.shape == (nq, k) 79 | assert I.device == xb.device 80 | 81 | D_ptr = swig_ptr_from_FloatTensor(D) 82 | I_ptr = swig_ptr_from_LongTensor(I) 83 | 84 | faiss.bruteForceKnn(res, metric, 85 | xb_ptr, xb_row_major, nb, 86 | xq_ptr, xq_row_major, nq, 87 | d, k, D_ptr, I_ptr) 88 | 89 | return D, I 90 | 91 | def index_init_gpu(ngpus, feat_dim): 92 | flat_config = [] 93 | for i in range(ngpus): 94 | cfg = faiss.GpuIndexFlatConfig() 95 | cfg.useFloat16 = False 96 | cfg.device = i 97 | flat_config.append(cfg) 98 | 99 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 100 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 101 | index = faiss.IndexShards(feat_dim) 102 | for sub_index in indexes: 103 | index.add_shard(sub_index) 104 | index.reset() 105 | return index 106 | 107 | def index_init_cpu(feat_dim): 108 | return faiss.IndexFlatL2(feat_dim) 109 | 
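A note before the logging utilities: the two modules above are easiest to read from the calling side. The sketch below is a minimal, illustrative use of compute_jaccard_distance and is not part of the repository; the toy tensor size, the k1/k2 values, and the clustering parameters are assumptions, and search_option=3 selects the pure-CPU FAISS path so it runs without a GPU.

import torch
import torch.nn.functional as F
from pplr.utils.faiss_rerank import compute_jaccard_distance

# Stand-in for CNN embeddings; the distance computation above assumes
# L2-normalized features when converting inner products to distances.
features = F.normalize(torch.randn(1000, 256), dim=1)

# k-reciprocal Jaccard distance over the whole unlabeled set (CPU path).
jaccard = compute_jaccard_distance(features, k1=30, k2=6, search_option=3)
print(jaccard.shape)  # (1000, 1000) numpy array with entries in [0, 1]

# A clustering step can consume the matrix directly, e.g. (hypothetical parameters):
# from sklearn.cluster import DBSCAN
# pseudo_labels = DBSCAN(eps=0.6, min_samples=4, metric='precomputed').fit_predict(jaccard)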
-------------------------------------------------------------------------------- /pplr/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | return self 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | # leave self.console (sys.stdout) open; only close the log file we own 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /pplr/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /pplr/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /pplr/utils/rerank.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | __all__ = ['re_ranking'] 6 | 7 | import numpy as np 8 | 9 | 10 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 11 | 12 | # The naming below (e.g. gallery_num) differs from the caller's scope; 13 | # it is kept to match the original implementation. 14 | 15 | original_dist = np.concatenate( 16 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 17 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 18 | axis=0) 19 | original_dist = np.power(original_dist, 2).astype(np.float32) 20 | original_dist = np.transpose(1. 
* original_dist/np.max(original_dist, axis=0)) 21 | V = np.zeros_like(original_dist).astype(np.float32) 22 | initial_rank = np.argsort(original_dist).astype(np.int32) 23 | 24 | query_num = q_g_dist.shape[0] 25 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 26 | all_num = gallery_num 27 | 28 | for i in range(all_num): 29 | # k-reciprocal neighbors 30 | forward_k_neigh_index = initial_rank[i,:k1+1] 31 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 32 | fi = np.where(backward_k_neigh_index==i)[0] 33 | k_reciprocal_index = forward_k_neigh_index[fi] 34 | k_reciprocal_expansion_index = k_reciprocal_index 35 | for j in range(len(k_reciprocal_index)): 36 | candidate = k_reciprocal_index[j] 37 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 38 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 39 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 40 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 41 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2./3*len(candidate_k_reciprocal_index): 42 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 43 | 44 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 45 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 46 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 47 | original_dist = original_dist[:query_num,]  # keep only the query-vs-all rows 48 | if k2 != 1:  # local query expansion, as in compute_jaccard_distance above 49 | V_qe = np.zeros_like(V,dtype=np.float32) 50 | for i in range(all_num): 51 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 52 | V = V_qe 53 | del V_qe 54 | del initial_rank 55 | invIndex = [] 56 | for i in range(gallery_num): 57 | invIndex.append(np.where(V[:,i] != 0)[0]) 58 | 59 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float32) 60 | 61 | 62 | for i in range(query_num): 63 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 64 | indNonZero = np.where(V[i,:] != 0)[0] 65 | # rows of V that also assign weight to each of i's nonzero neighbors 66 | indImages = [invIndex[ind] for ind in indNonZero] 67 | for j in range(len(indNonZero)): 68 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 69 | jaccard_dist[i] = 1-temp_min/(2.-temp_min)  # rows of V sum to 1, so the union is 2 minus the intersection 70 | 71 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 72 | del original_dist 73 | del V 74 | del jaccard_dist 75 | final_dist = final_dist[:query_num,query_num:] 76 | return final_dist 77 | -------------------------------------------------------------------------------- /pplr/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def 
load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | # checkpoint = torch.load(fpath) 34 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 35 | print("=> Loaded checkpoint '{}'".format(fpath)) 36 | return checkpoint 37 | else: 38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 39 | 40 | 41 | def copy_state_dict(state_dict, model, strip=None): 42 | tgt_state = model.state_dict() 43 | copied_names = set() 44 | for name, param in state_dict.items(): 45 | if strip is not None and name.startswith(strip): 46 | name = name[len(strip):] 47 | if name not in tgt_state: 48 | continue 49 | if isinstance(param, Parameter): 50 | param = param.data 51 | if param.size() != tgt_state[name].size(): 52 | print('mismatch:', name, param.size(), tgt_state[name].size()) 53 | continue 54 | tgt_state[name].copy_(param) 55 | copied_names.add(name) 56 | 57 | missing = set(tgt_state.keys()) - copied_names 58 | if len(missing) > 0: 59 | print("missing keys in state_dict:", missing) 60 | 61 | return model 62 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup(name='PPLR', 5 | version='1.0.0', 6 | description='Part-based Pseudo Label Refinement for Unsupervised Person Re-identification', 7 | author='Yoonki Cho', 8 | author_email='yoonki@kaist.ac.kr', 9 | url='https://github.com/yoonkicho/PPLR', 10 | install_requires=[ 11 | 'numpy', 'torch', 'torchvision', 12 | 'six', 'h5py', 'Pillow', 'scipy', 13 | 'scikit-learn', 'metric-learn', 'faiss_gpu==1.6.3'], 14 | packages=find_packages() 15 | ) --------------------------------------------------------------------------------
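To round out the packaging files: once the project is installed (for example with "pip install -e ." from the repository root, which also pulls in the pinned faiss_gpu==1.6.3 listed above), the checkpoint helpers in pplr/utils/serialization.py can be exercised on any module. A minimal sketch follows; the toy model and the logs/ paths are illustrative assumptions, not repository defaults.

import torch.nn as nn
from pplr.utils.serialization import save_checkpoint, load_checkpoint, copy_state_dict

model = nn.Linear(10, 5)  # toy stand-in for a re-ID backbone

# Writes logs/checkpoint.pth.tar and, because is_best=True, also copies it
# to logs/model_best.pth.tar in the same directory.
save_checkpoint({'state_dict': model.state_dict(), 'epoch': 0},
                is_best=True, fpath='logs/checkpoint.pth.tar')

ckpt = load_checkpoint('logs/model_best.pth.tar')  # always mapped onto the CPU

# Copies only parameters whose names and shapes match the target model,
# printing mismatched or missing keys instead of raising.
restored = copy_state_dict(ckpt['state_dict'], nn.Linear(10, 5))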