├── .gitignore ├── LICENSE ├── README.md ├── build_env.sh ├── configs ├── mutualrefine.yml └── pretrain.yml ├── figs ├── framework.png └── results.png ├── main.py ├── scripts ├── duke2market │ ├── eval.sh │ ├── pretrain.sh │ └── train.sh ├── market2duke │ ├── eval.sh │ ├── pretrain.sh │ └── train.sh └── market2msmt │ ├── eval.sh │ ├── pretrain.sh │ └── train.sh ├── secret ├── __init__.py ├── cluster │ ├── RefineCluster.py │ ├── __init__.py │ └── faiss_utils.py ├── config │ ├── __init__.py │ ├── config.py │ └── defaults.py ├── data │ ├── __init__.py │ ├── build.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── dukemtmc.py │ │ ├── market1501.py │ │ └── msmt17.py │ ├── preprocessor │ │ ├── __init__.py │ │ └── preprocessor.py │ ├── samplers │ │ ├── PartRandomMultipleGallerySampler.py │ │ ├── RandomMultipleGallerySampler.py │ │ └── __init__.py │ └── transforms │ │ ├── __init__.py │ │ ├── build.py │ │ └── transforms.py ├── engine │ ├── __init__.py │ ├── mutualrefine.py │ └── pretrain.py ├── loss │ ├── __init__.py │ ├── crossentropy.py │ └── triplet.py ├── metrics │ ├── Partevaluators.py │ ├── __init__.py │ ├── evaluators.py │ ├── rank_c.py │ ├── rank_cylib │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── rank_cy.c │ │ ├── rank_cy.pyx │ │ ├── setup.py │ │ └── test_cython.py │ ├── ranking.py │ └── rerank.py ├── models │ ├── __init__.py │ └── resnet.py ├── optim │ ├── __init__.py │ ├── lr_scheduler.py │ └── optimizer.py └── utils │ ├── __init__.py │ ├── defaults.py │ ├── logger.py │ ├── meters.py │ ├── osutils.py │ └── serialization.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # log file 2 | logs/* 3 | log/* 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | 
parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 LunarShen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AAAI2022] SECRET 2 | The official repository for SECRET: Self-Consistent Pseudo Label Refinement for Unsupervised Domain Adaptive Person Re-Identification. 3 | 4 | 5 | 6 | ![framework](figs/framework.png) 7 | 8 | ## Installation 9 | 10 | ```shell 11 | git clone https://github.com/LunarShen/SECRET.git 12 | cd SECRET 13 | conda create -n secret python=3.8 14 | conda activate secret 15 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch 16 | pip install tqdm numpy six h5py Pillow scipy scikit-learn metric-learn pyyaml yacs termcolor faiss-gpu==1.6.3 opencv-python Cython 17 | python setup.py develop 18 | ``` 19 | 20 | ## Prepare Datasets 21 | 22 | ```shell 23 | mkdir Data 24 | ``` 25 | Download the raw datasets [DukeMTMC-reID](https://arxiv.org/abs/1609.01775), [Market-1501](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zheng_Scalable_Person_Re-Identification_ICCV_2015_paper.pdf), [MSMT17](https://arxiv.org/abs/1711.08565), 26 | and then unzip them under the directory like 27 | ``` 28 | SECRET/Data 29 | ├── dukemtmc 30 | │   └── DukeMTMC-reID 31 | ├── market1501 32 | │   └── Market-1501-v15.09.15 33 | └── msmt17 34 | └── MSMT17_V1 35 | ``` 36 | 37 | ## Training 38 | 39 | We utilize 4 GTX-1080TI GPUs for training. 
40 | 41 | ### Stage I: Pre-training on the source domain 42 | 43 | ```shell 44 | # duke-to-market 45 | sh scripts/duke2market/pretrain.sh 46 | # market-to-duke 47 | sh scripts/market2duke/pretrain.sh 48 | # market-to-msmt 49 | sh scripts/market2msmt/pretrain.sh 50 | ``` 51 | 52 | ### Stage II: fine-tuning with SECRET 53 | 54 | ```shell 55 | # duke-to-market 56 | sh scripts/duke2market/train.sh 57 | # market-to-duke 58 | sh scripts/market2duke/train.sh 59 | # market-to-msmt 60 | sh scripts/market2msmt/train.sh 61 | ``` 62 | 63 | ## Evaluation 64 | 65 | ```shell 66 | # duke-to-market 67 | sh scripts/duke2market/eval.sh 68 | # market-to-duke 69 | sh scripts/market2duke/eval.sh 70 | # market-to-msmt 71 | sh scripts/market2msmt/eval.sh 72 | ``` 73 | 74 | ## Results 75 | ![results](figs/results.png) 76 | 77 | ## Acknowledgements 78 | Codebase from [MMT](https://github.com/yxgeee/MMT), [fast-reid](https://github.com/JDAI-CV/fast-reid) 79 | 80 | ## Citation 81 | If you find this project useful for your research, please cite our paper. 
82 | ```bibtex 83 | @inproceedings{he2022secret, 84 | title={SECRET: Self-Consistent Pseudo Label Refinement for Unsupervised Domain Adaptive Person Re-identification}, 85 | author={He, Tao and Shen, Leqi and Guo, Yuchen and Ding, Guiguang and Guo, Zhenhua}, 86 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 87 | volume={36}, 88 | number={1}, 89 | pages={879--887}, 90 | year={2022} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /build_env.sh: -------------------------------------------------------------------------------- 1 | conda create -n secret python=3.8 2 | conda activate secret 3 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch 4 | pip install tqdm numpy six h5py Pillow scipy scikit-learn metric-learn pyyaml yacs termcolor faiss-gpu==1.6.3 opencv-python Cython 5 | python setup.py develop 6 | -------------------------------------------------------------------------------- /configs/mutualrefine.yml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | SOURCE: "dukemtmc" 3 | TARGET: "market1501" 4 | DIR: "data" 5 | 6 | OUTPUT_DIR: "log/duke2market/mutualrefine" 7 | 8 | GPU_Device: [0,1,2,3] 9 | 10 | MODE: "mutualrefine" 11 | 12 | CLUSTER: 13 | REFINE_K: 0.4 14 | 15 | INPUT: 16 | REA: 17 | ENABLED: True 18 | 19 | DATALOADER: 20 | BATCH_SIZE: 64 21 | ITERS: 400 22 | 23 | CHECKPOING: 24 | REMAIN_CLASSIFIER: False 25 | SAVE_STEP: [-1] 26 | PRETRAIN_PATH: "log/duke2market/pretrain/checkpoint_new.pth.tar" 27 | 28 | OPTIM: 29 | SCHED: "single_step" 30 | STEPS: [50] 31 | EPOCHS: 50 32 | -------------------------------------------------------------------------------- /configs/pretrain.yml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | SOURCE: "dukemtmc" 3 | TARGET: "market1501" 4 | DIR: "data" 5 | 6 | OUTPUT_DIR: "log/duke2market/pretrain" 7 | 
def setup(args):
    """
    Build the run configuration and perform the basic setup.

    The default config is overlaid first with the YAML file named by
    ``args.config_file`` and then with the key/value overrides in
    ``args.opts``; the result is frozen before ``default_setup`` is called
    with it.

    Returns:
        the frozen config node.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config


def main(args):
    """
    Create the engine selected by the configuration and run it.
    """
    engine = create_engine(setup(args))
    engine.run()
    return


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    main(args)
-------------------------------------------------------------------------------- /scripts/market2duke/eval.sh: -------------------------------------------------------------------------------- 1 | python main.py --config-file configs/mutualrefine.yml \ 2 | DATASETS.DIR "Data" \ 3 | DATASETS.TARGET "dukemtmc" \ 4 | CHECKPOING.EVAL "log/market2duke/mutualrefine/model_mAP_best.pth.tar" \ 5 | OUTPUT_DIR "log/market2duke/eval" \ 6 | GPU_Device [0,1,2,3] \ 7 | MODE 'mutualrefine' \ 8 | MODEL.ARCH "resnet50" 9 | -------------------------------------------------------------------------------- /scripts/market2duke/pretrain.sh: -------------------------------------------------------------------------------- 1 | python main.py --config-file configs/pretrain.yml \ 2 | DATASETS.DIR "Data" \ 3 | DATASETS.SOURCE "market1501" \ 4 | DATASETS.TARGET "dukemtmc" \ 5 | OUTPUT_DIR "log/market2duke/pretrain" \ 6 | GPU_Device [0,1,2,3] \ 7 | MODE 'pretrain' \ 8 | MODEL.ARCH "resnet50" 9 | -------------------------------------------------------------------------------- /scripts/market2duke/train.sh: -------------------------------------------------------------------------------- 1 | python main.py --config-file configs/mutualrefine.yml \ 2 | DATASETS.DIR "Data" \ 3 | DATASETS.SOURCE "market1501" \ 4 | DATASETS.TARGET "dukemtmc" \ 5 | CHECKPOING.PRETRAIN_PATH "log/market2duke/pretrain/checkpoint_new.pth.tar" \ 6 | OUTPUT_DIR "log/market2duke/mutualrefine" \ 7 | GPU_Device [0,1,2,3] OPTIM.EPOCHS 50 \ 8 | MODE 'mutualrefine' \ 9 | MODEL.ARCH "resnet50" 10 | -------------------------------------------------------------------------------- /scripts/market2msmt/eval.sh: -------------------------------------------------------------------------------- 1 | python main.py --config-file configs/mutualrefine.yml \ 2 | DATASETS.DIR "Data" \ 3 | DATASETS.TARGET "msmt17" \ 4 | CHECKPOING.EVAL "log/market2msmt/mutualrefine/model_mAP_best.pth.tar" \ 5 | OUTPUT_DIR "log/market2msmt/eval" \ 6 | GPU_Device 
def RefineClusterProcess(Reference_Cluster_result, Target_Cluster_result, divide_ratio):
    """
    Refine one clustering with the help of a second ("reference") clustering.

    For every target cluster the samples whose reference label is -1 are
    dropped, and the reference labels of the remaining members are counted.
    Going from the most frequent reference label downwards, at most
    ``int(1 / divide_ratio)`` reference clusters are accepted, each of which
    must hold strictly more than ``divide_ratio`` of the remaining members.
    Members that fall into an accepted reference cluster keep the target
    label; everything else is marked -1.  Target clusters with no accepted
    reference cluster are discarded and the labels of later clusters are
    shifted down so that the surviving labels stay contiguous.

    Args:
        Reference_Cluster_result: 1-D sequence of integer labels, -1 meaning
            "unassigned".
        Target_Cluster_result: 1-D sequence of integer labels of the same
            length, -1 meaning "unassigned".  The loop below visits labels
            ``0 .. K-1``, so the non-negative labels are expected to be
            contiguous.
        divide_ratio (float): minimum fraction of a target cluster that a
            reference cluster must cover to be accepted.

    Returns:
        np.ndarray: int64 array of refined labels, -1 for rejected samples.

    Raises:
        AssertionError: if the two label sequences differ in length.
    """
    # Coerce once: with plain lists the `==` comparisons below would compare
    # the whole list to a scalar and silently match nothing.
    reference = np.asarray(Reference_Cluster_result)
    target = np.asarray(Target_Cluster_result)

    L = len(reference)
    assert L == len(target)

    # Number of distinct target labels, not counting the -1 "noise" label.
    Target_Cluster_nums = int(np.count_nonzero(np.unique(target) != -1))

    Final_Cluster = np.full(L, -1, dtype=np.int64)

    # Samples without a reference label never vote; this set does not depend
    # on the target cluster, so compute it once.
    reference_noise_index = np.where(reference == -1)[0]

    ban_cluster = 0  # number of target clusters discarded so far
    for Target_Cluster in range(Target_Cluster_nums):
        Target_Cluster_index = np.setdiff1d(
            np.where(target == Target_Cluster)[0], reference_noise_index
        )

        if Target_Cluster_index.size == 0:
            ban_cluster += 1
            continue

        num_ID = len(Target_Cluster_index)
        num_Part = np.bincount(reference[Target_Cluster_index])
        accepted_any = False

        for _ in range(int(1 / divide_ratio)):
            _max = np.argmax(num_Part)
            if num_Part[_max] > 0 and num_Part[_max] > num_ID * divide_ratio:
                Reference_Cluster_index = np.where(reference == _max)[0]
                fit_condition = np.intersect1d(
                    Target_Cluster_index, Reference_Cluster_index
                )
                # Subtracting ban_cluster keeps the output labels contiguous.
                Final_Cluster[fit_condition] = Target_Cluster - ban_cluster
                num_Part[_max] = 0  # do not pick the same reference cluster twice
                accepted_any = True
            else:
                break

        if not accepted_any:
            ban_cluster += 1

    return Final_Cluster
def swig_ptr_from_FloatTensor(x):
    """Return a faiss float pointer to the data of a contiguous float32 tensor."""
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    # storage_offset() is counted in elements; a float32 element is 4 bytes.
    return faiss.cast_integer_to_float_ptr(
        x.storage().data_ptr() + x.storage_offset() * 4)

def swig_ptr_from_LongTensor(x):
    """Return a faiss long pointer to the data of a contiguous int64 tensor."""
    assert x.is_contiguous()
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    # storage_offset() is counted in elements; an int64 element is 8 bytes.
    return faiss.cast_integer_to_long_ptr(
        x.storage().data_ptr() + x.storage_offset() * 8)

def search_index_pytorch(index, x, k, D=None, I=None):
    """call the search function of an index with pytorch tensor I/O (CPU
    and GPU supported)

    Args:
        index: faiss index; its ``d`` attribute must equal the width of ``x``.
        x: contiguous 2-D float32 query tensor.
        k: number of neighbours to return per query row.
        D: optional pre-allocated ``(n, k)`` float32 output for distances.
        I: optional pre-allocated ``(n, k)`` int64 output for indices.
    Returns:
        the tuple ``(D, I)``, filled in place by ``index.search_c``.
    """
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)
    # Synchronise before handing raw pointers to faiss and again afterwards,
    # before the outputs are returned to the caller.
    torch.cuda.synchronize()
    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)
    index.search_c(n, xptr,
                   k, Dptr, Iptr)
    torch.cuda.synchronize()
    return D, I

def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
                             metric=faiss.METRIC_L2):
    """Run ``faiss.bruteForceKnn`` directly on two torch tensors.

    Args:
        res: faiss resources object passed through to ``bruteForceKnn``.
        xb: 2-D float32 database tensor, row- or column-major.
        xq: 2-D float32 query tensor on the same device as ``xb``.
        k: number of neighbours per query row.
        D: optional ``(nq, k)`` float32 output tensor on ``xb``'s device.
        I: optional ``(nq, k)`` int64 output tensor on ``xb``'s device.
        metric: faiss metric constant, ``faiss.METRIC_L2`` by default.
    Returns:
        the tuple ``(D, I)``.
    Raises:
        TypeError: if ``xq`` or ``xb`` is neither row- nor column-major
            contiguous.
    """
    assert xb.device == xq.device

    nq, d = xq.size()
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        # Hand faiss the contiguous transposed view and flag it as column-major.
        xq = xq.t()
        xq_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')

    xq_ptr = swig_ptr_from_FloatTensor(xq)

    nb, d2 = xb.size()
    assert d2 == d
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        # Same treatment as for xq above.
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    faiss.bruteForceKnn(res, metric,
                        xb_ptr, xb_row_major, nb,
                        xq_ptr, xq_row_major, nq,
                        d, k, D_ptr, I_ptr)

    return D, I

def index_init_gpu(ngpus, feat_dim):
    """Build an ``IndexShards`` made of one float32 ``GpuIndexFlatL2`` per GPU.

    Args:
        ngpus: number of GPUs; device ids ``0 .. ngpus-1`` are used.
        feat_dim: feature dimensionality of the index.
    Returns:
        the (reset) sharded index.
    """
    flat_config = []
    for i in range(ngpus):
        cfg = faiss.GpuIndexFlatConfig()
        cfg.useFloat16 = False
        cfg.device = i
        flat_config.append(cfg)

    res = [faiss.StandardGpuResources() for i in range(ngpus)]
    indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)]
    index = faiss.IndexShards(feat_dim)
    for sub_index in indexes:
        index.add_shard(sub_index)
    index.reset()
    return index

def index_init_cpu(feat_dim):
    """Return a CPU ``faiss.IndexFlatL2`` of dimensionality ``feat_dim``."""
    return faiss.IndexFlatL2(feat_dim)

def k_reciprocal_neigh(initial_rank, i, k1):
    """Return the k-reciprocal neighbours of sample ``i``.

    A sample ``j`` is returned when it is among the first ``k1 + 1`` entries
    of row ``i`` of ``initial_rank`` and ``i`` is in turn among the first
    ``k1 + 1`` entries of row ``j``.

    Args:
        initial_rank: 2-D integer array of neighbour indices, one row per sample.
        i: index of the probe sample.
        k1: neighbourhood size.
    """
    forward_k_neigh_index = initial_rank[i,:k1+1]
    backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
    # Row positions (into the forward list) whose own neighbour list contains i.
    fi = np.where(backward_k_neigh_index==i)[0]
    return forward_k_neigh_index[fi]

def compute_jaccard_distance(target_features, k1=30, k2=6, print_flag=True, search_option=0, use_float16=False):
    """Compute the pairwise k-reciprocal Jaccard distance between all samples.

    Args:
        target_features: 2-D torch tensor with one feature row per sample.
            For ``search_option`` 0 and 1 it is handed to the faiss GPU
            helpers as is — presumably a CUDA float32 tensor; confirm against
            the caller.
        k1: number of neighbours retrieved per sample and neighbourhood size
            for the reciprocal test.
        k2: number of top neighbours whose rows of ``V`` are averaged; the
            averaging step is skipped when ``k2 == 1``.
        print_flag: when true, progress and timing are logged on the
            ``'UnReID'`` logger.
        search_option: 0 = brute-force kNN on torch tensors, 1 = single
            ``GpuIndexFlatL2``, 2 = sharded multi-GPU index, anything else =
            CPU ``IndexFlatL2``.
        use_float16: store the intermediate and output matrices as float16
            instead of float32.
    Returns:
        ``(N, N)`` numpy array of distances, negative values clipped to 0.
    """
    end = time.time()
    if print_flag:
        logger = logging.getLogger('UnReID')
        logger.info('Computing jaccard distance...')

    ngpus = faiss.get_num_gpus()
    N = target_features.size(0)
    mat_type = np.float16 if use_float16 else np.float32

    # Step 1: neighbour ranking of every sample against all samples.
    # NOTE(review): every branch searches with k = k1, so initial_rank has
    # k1 columns and the `[:k1+1]` slices in k_reciprocal_neigh are capped
    # at k1 — confirm this is intended.
    if (search_option==0):
        # GPU + PyTorch CUDA Tensors (1)
        res = faiss.StandardGpuResources()
        res.setDefaultNullStreamAllDevices()
        _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1)
        initial_rank = initial_rank.cpu().numpy()
    elif (search_option==1):
        # GPU + PyTorch CUDA Tensors (2)
        res = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatL2(res, target_features.size(-1))
        index.add(target_features.cpu().numpy())
        _, initial_rank = search_index_pytorch(index, target_features, k1)
        res.syncDefaultStreamCurrentDevice()
        initial_rank = initial_rank.cpu().numpy()
    elif (search_option==2):
        # GPU
        index = index_init_gpu(ngpus, target_features.size(-1))
        index.add(target_features.cpu().numpy())
        _, initial_rank = index.search(target_features.cpu().numpy(), k1)
    else:
        # CPU
        index = index_init_cpu(target_features.size(-1))
        index.add(target_features.cpu().numpy())
        _, initial_rank = index.search(target_features.cpu().numpy(), k1)


    # Step 2: k-reciprocal neighbour sets for k1 and for round(k1 / 2).
    nn_k1 = []
    nn_k1_half = []
    for i in range(N):
        nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))
        nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2))))

    # Step 3: build the sparse weight matrix V, one softmax-normalised row per sample.
    V = np.zeros((N, N), dtype=mat_type)
    for i in range(N):
        k_reciprocal_index = nn_k1[i]
        k_reciprocal_expansion_index = k_reciprocal_index
        for candidate in k_reciprocal_index:
            candidate_k_reciprocal_index = nn_k1_half[candidate]
            # Absorb the candidate's half-size set when more than 2/3 of it
            # already lies inside the probe's own reciprocal set.
            if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)  # drop duplicate indices
        # 2 - 2 * <f_i, f_j>; equals the squared Euclidean distance only if
        # the feature rows are L2-normalised — TODO confirm with the caller.
        dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t())
        if use_float16:
            V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type)
        else:
            V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy()

    del nn_k1, nn_k1_half

    # Step 4: replace every row of V by the mean of the rows of its first k2
    # ranked neighbours.
    if k2 != 1:
        V_qe = np.zeros_like(V, dtype=mat_type)
        for i in range(N):
            V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0)
        V = V_qe
        del V_qe

    del initial_rank

    # invIndex[c] lists the rows of V that are non-zero in column c; the list has N entries.
    invIndex = []
    for i in range(N):
        invIndex.append(np.where(V[:,i] != 0)[0])

    # Step 5: Jaccard distance from the element-wise minima of pairs of rows.
    jaccard_dist = np.zeros((N, N), dtype=mat_type)
    for i in range(N):
        temp_min = np.zeros((1,N), dtype=mat_type)
        # temp_max = np.zeros((1,N), dtype=mat_type)
        indNonZero = np.where(V[i,:] != 0)[0]
        indImages = []
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            # Accumulate min(V[i, c], V[r, c]) over the shared non-zero column c
            # for every row r that is non-zero in that column.
            temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
            # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])

        # Each row of V sums to 1 (softmax, then row averaging), and
        # min + max = a + b, so the sum of maxima is 2 - temp_min.
        jaccard_dist[i] = 1-temp_min/(2-temp_min)
        # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6)

    del invIndex, V

    # Clip small negative values that can arise from rounding.
    pos_bool = (jaccard_dist < 0)
    jaccard_dist[pos_bool] = 0.0
    if print_flag:
        logger.info("Jaccard distance computing time cost: {}".format(time.time()-end))

    return jaccard_dist
BASE_KEY = "_BASE_"


class CfgNode(_CfgNode):
    """
    Project flavour of :class:`yacs.config.CfgNode` with three additions:

    1. :meth:`merge_from_file` understands the "_BASE_" key: the file named
       there is loaded first and the current file is overlaid on top of it.
    2. Attributes whose name starts with "COMPUTED_" are stored through item
       assignment instead of the parent ``__setattr__``; re-assigning one
       with a different value raises ``KeyError``.
    3. With "allow_unsafe=True" a config that ``yaml.safe_load`` rejects is
       re-read with ``yaml.unsafe_load``.  This can execute arbitrary code,
       so never enable it for files from untrusted sources.
    """

    @staticmethod
    def load_yaml_with_base(filename: str, allow_unsafe: bool = False):
        """
        Load a yaml file and resolve its `_BASE_` chain.

        Args:
            filename (str): the file name of the current config. Relative
                `_BASE_` entries are resolved against its directory.
            allow_unsafe (bool): whether to fall back to `yaml.unsafe_load`
                when `yaml.safe_load` raises a constructor error.
        Returns:
            (dict): the loaded yaml, with base values filled in
        """
        with PathManager.open(filename, "r") as f:
            try:
                loaded = yaml.safe_load(f)
            except yaml.constructor.ConstructorError:
                if not allow_unsafe:
                    raise
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Loading config {} with yaml.unsafe_load. Your machine may "
                    "be at risk if the file contains malicious content.".format(
                        filename
                    )
                )
                f.close()
                with open(filename, "r") as f:
                    loaded = yaml.unsafe_load(f)

        def merge_a_into_b(a, b):
            # Recursively overlay dict a onto dict b; values in a win.
            for key, value in a.items():
                if isinstance(value, dict) and key in b:
                    assert isinstance(
                        b[key], dict
                    ), "Cannot inherit key '{}' from base!".format(key)
                    merge_a_into_b(value, b[key])
                else:
                    b[key] = value

        if BASE_KEY not in loaded:
            return loaded

        base_path = loaded[BASE_KEY]
        if base_path.startswith("~"):
            base_path = os.path.expanduser(base_path)
        if not base_path.startswith(("/", "https://", "http://")):
            # A relative base path is relative to the config file itself.
            base_path = os.path.join(os.path.dirname(filename), base_path)
        base_cfg = CfgNode.load_yaml_with_base(
            base_path, allow_unsafe=allow_unsafe
        )
        del loaded[BASE_KEY]

        merge_a_into_b(loaded, base_cfg)
        return base_cfg

    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False):
        """
        Merge the content of a yaml file (including its bases) into this node.
        Args:
            cfg_filename: the file name of the yaml config.
            allow_unsafe: whether to allow loading the config file with
                `yaml.unsafe_load`.
        """
        as_dict = CfgNode.load_yaml_with_base(
            cfg_filename, allow_unsafe=allow_unsafe
        )
        self.merge_from_other_cfg(type(self)(as_dict))

    # The two methods below defer to the parent after rejecting BASE_KEY.
    def merge_from_other_cfg(self, cfg_other):
        """
        Args:
            cfg_other (CfgNode): configs to merge from.
        """
        assert (
            BASE_KEY not in cfg_other
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_other_cfg(cfg_other)

    def merge_from_list(self, cfg_list: list):
        """
        Args:
            cfg_list (list): flat list of alternating keys and values.
        """
        # Keys sit at the even positions of the flat list.
        keys = set(cfg_list[0::2])
        assert (
            BASE_KEY not in keys
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_list(cfg_list)

    def __setattr__(self, name: str, val: Any):
        if not name.startswith("COMPUTED_"):
            super().__setattr__(name, val)
            return

        if name in self:
            old_val = self[name]
            if old_val == val:
                # Setting the same value again is a no-op.
                return
            raise KeyError(
                "Computed attributed '{}' already exists "
                "with a different value! old={}, new={}.".format(
                    name, old_val, val
                )
            )
        self[name] = val


def get_cfg() -> CfgNode:
    """
    Get a copy of the default config.
    Returns:
        a clone of the `_C` node defined in `.defaults`.
    """
    from .defaults import _C

    return _C.clone()
Entropy Loss options 40 | _C.MODEL.LOSSES.CE = CN() 41 | _C.MODEL.LOSSES.CE.EPSILON = 0.1 42 | 43 | # Triplet Loss options 44 | _C.MODEL.LOSSES.TRI = CN() 45 | 46 | _C.MEAN_TEACH = CN() 47 | _C.MEAN_TEACH.CE_SOFT_WRIGHT = 0.5 48 | _C.MEAN_TEACH.TRI_SOFT_WRIGHT = 0.8 49 | _C.MEAN_TEACH.ALPHA = 0.999 50 | 51 | _C.CLUSTER = CN() 52 | _C.CLUSTER.K1 = 30 53 | _C.CLUSTER.K2 = 6 54 | _C.CLUSTER.EPS = 0.600 55 | _C.CLUSTER.REFINE_K = 0.4 56 | # ----------------------------------------------------------------------------- 57 | # INPUT 58 | # ----------------------------------------------------------------------------- 59 | _C.INPUT = CN() 60 | # Size of the image during training 61 | _C.INPUT.SIZE_TRAIN = [256, 128] 62 | # Size of the image during test 63 | _C.INPUT.SIZE_TEST = [256, 128] 64 | # Values to be used for image normalization 65 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 66 | # Values to be used for image normalization 67 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 68 | 69 | # Random probability for image horizontal flip 70 | _C.INPUT.DO_FLIP = True 71 | _C.INPUT.FLIP_PROB = 0.5 72 | 73 | # Value of padding size 74 | _C.INPUT.DO_PAD = True 75 | _C.INPUT.PADDING = 10 76 | 77 | # Random Erasing 78 | _C.INPUT.REA = CN() 79 | _C.INPUT.REA.ENABLED = False 80 | _C.INPUT.REA.PROB = 0.5 81 | _C.INPUT.REA.MEAN = [0.485, 0.456, 0.406] 82 | 83 | # ----------------------------------------------------------------------------- 84 | # Dataset 85 | # ----------------------------------------------------------------------------- 86 | _C.DATASETS = CN() 87 | 88 | _C.DATASETS.SOURCE = "dukemtmc" 89 | 90 | _C.DATASETS.TARGET = "market1501" 91 | 92 | _C.DATASETS.DIR = "Data" 93 | 94 | # ----------------------------------------------------------------------------- 95 | # DataLoader 96 | # ----------------------------------------------------------------------------- 97 | _C.DATALOADER = CN() 98 | # Number of instances for each person 99 | _C.DATALOADER.NUM_INSTANCES = 4 100 |
_C.DATALOADER.NUM_WORKERS = 4 101 | 102 | _C.DATALOADER.BATCH_SIZE = 64 103 | _C.DATALOADER.ITER_MODE = True 104 | _C.DATALOADER.ITERS = 100 105 | 106 | # ---------------------------------------------------------------------------- # 107 | # OPTIM 108 | # ---------------------------------------------------------------------------- # 109 | _C.OPTIM = CN() 110 | _C.OPTIM.OPT = 'adam' 111 | _C.OPTIM.LR = 0.00035 112 | _C.OPTIM.WEIGHT_DECAY = 5e-04 113 | _C.OPTIM.MOMENTUM = 0.9 114 | 115 | _C.OPTIM.SGD_DAMPENING = 0 116 | _C.OPTIM.SGD_NESTEROV = False 117 | 118 | _C.OPTIM.RMSPROP_ALPHA = 0.99 119 | 120 | _C.OPTIM.ADAM_BETA1 = 0.9 121 | _C.OPTIM.ADAM_BETA2 = 0.999 122 | 123 | # Multi-step learning rate options 124 | _C.OPTIM.SCHED = "warmupmultisteplr" 125 | _C.OPTIM.GAMMA = 0.1 126 | _C.OPTIM.STEPS = [40, 70] 127 | 128 | _C.OPTIM.WARMUP_ITERS = 10 129 | _C.OPTIM.WARMUP_FACTOR = 0.01 130 | _C.OPTIM.WARMUP_METHOD = "linear" 131 | 132 | _C.OPTIM.EPOCHS = 80 133 | 134 | _C.TEST = CN() 135 | _C.TEST.PRINT_PERIOD = 200 136 | 137 | # Re-rank 138 | _C.TEST.RERANK = CN() 139 | _C.TEST.RERANK.ENABLED = False 140 | _C.TEST.RERANK.K1 = 20 141 | _C.TEST.RERANK.K2 = 6 142 | _C.TEST.RERANK.LAMBDA = 0.3 143 | 144 | # ---------------------------------------------------------------------------- # 145 | # Misc options 146 | # ---------------------------------------------------------------------------- # 147 | _C.MODE = "USL" 148 | _C.OUTPUT_DIR = "log/test" 149 | _C.RESUME = "" 150 | _C.PRINT_PERIOD = 100 151 | _C.SEED = 1 152 | _C.GPU_Device = [0,1,2,3] 153 | 154 | _C.CHECKPOING = CN() 155 | _C.CHECKPOING.REMAIN_CLASSIFIER = True 156 | _C.CHECKPOING.SAVE_STEP = [10] 157 | _C.CHECKPOING.PRETRAIN_PATH = '' 158 | _C.CHECKPOING.EVAL = '' 159 | 160 | _C.CUDNN_BENCHMARK = True 161 | -------------------------------------------------------------------------------- /secret/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ 
class IterLoader:
    """Endless, restartable cursor over a re-iterable loader.

    Args:
        loader: any object that supports ``iter()`` (and ``len()`` when
            ``length`` is not given).
        length: optional value reported by ``len()`` instead of
            ``len(loader)``.
    """

    def __init__(self, loader, length=None):
        self.loader = loader
        self.length = length
        self.iter = None

    def __len__(self):
        """Return ``length`` when it was given, otherwise ``len(loader)``."""
        if self.length is not None:
            return self.length
        return len(self.loader)

    def new_epoch(self):
        """Restart iteration from the beginning of the wrapped loader."""
        self.iter = iter(self.loader)

    def next(self):
        """Return the next item, restarting the loader once it is exhausted.

        Only exhaustion (``StopIteration``) triggers a restart.  Any other
        exception raised by the loader propagates to the caller instead of
        being hidden behind a silent restart.
        """
        if self.iter is None:
            # First use without an explicit new_epoch() call.
            self.iter = iter(self.loader)
        try:
            return next(self.iter)
        except StopIteration:
            self.iter = iter(self.loader)
            return next(self.iter)
def names():
    """Return the keys of the dataset registry as a sorted list."""
    registered = list(__factory.keys())
    registered.sort()
    return registered
class BaseDataset(object):
    """
    Base class of reid dataset
    """

    def get_imagedata_info(self, data):
        """Summarise a list of ``(item, pid, camid)`` records.

        Args:
            data: sized iterable of 3-tuples; the first field is ignored.

        Returns:
            ``(num_pids, num_imgs, num_cams)``: the number of distinct pids,
            the number of records and the number of distinct camids.
        """
        unique_pids, unique_cams = set(), set()
        for _unused, person_id, camera_id in data:
            unique_pids.add(person_id)
            unique_cams.add(camera_id)
        return len(unique_pids), len(data), len(unique_cams)

    def print_dataset_statistics(self):
        """Placeholder that subclasses are expected to override."""
        raise NotImplementedError

    @property
    def images_dir(self):
        """Always ``None`` in this base class."""
        return None
"\n query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams) + 47 | "\n gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams) + 48 | "\n ----------------------------------------") 49 | -------------------------------------------------------------------------------- /secret/data/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | import logging 8 | 9 | from .base_dataset import BaseImageDataset 10 | 11 | class DukeMTMC(BaseImageDataset): 12 | """ 13 | DukeMTMC-reID 14 | Reference: 15 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 16 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 17 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 18 | 19 | Dataset statistics: 20 | # identities: 1404 (train + query) 21 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 22 | # cameras: 8 23 | """ 24 | dataset_dir = '.' 
25 | 26 | def __init__(self, root, verbose=True, **kwargs): 27 | super(DukeMTMC, self).__init__() 28 | self.dataset_dir = osp.join(root, self.dataset_dir) 29 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 30 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 31 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 32 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 33 | 34 | self._check_before_run() 35 | 36 | train = self._process_dir(self.train_dir, relabel=True) 37 | query = self._process_dir(self.query_dir, relabel=False) 38 | gallery = self._process_dir(self.gallery_dir, relabel=False) 39 | 40 | if verbose: 41 | logger = logging.getLogger('UnReID') 42 | logger.info("=> DukeMTMC-reID loaded") 43 | self.print_dataset_statistics(train, query, gallery) 44 | 45 | self.train = train 46 | self.query = query 47 | self.gallery = gallery 48 | 49 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 52 | 53 | def _check_before_run(self): 54 | """Check if all files are available before going deeper""" 55 | if not osp.exists(self.dataset_dir): 56 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 57 | if not osp.exists(self.train_dir): 58 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 59 | if not osp.exists(self.query_dir): 60 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 61 | if not osp.exists(self.gallery_dir): 62 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 63 | 64 | def _process_dir(self, dir_path, relabel=False): 65 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 66 | pattern = 
re.compile(r'([-\d]+)_c(\d)') 67 | 68 | pid_container = set() 69 | for img_path in img_paths: 70 | pid, _ = map(int, pattern.search(img_path).groups()) 71 | pid_container.add(pid) 72 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 73 | 74 | dataset = [] 75 | for img_path in img_paths: 76 | pid, camid = map(int, pattern.search(img_path).groups()) 77 | assert 1 <= camid <= 8 78 | camid -= 1 # index starts from 0 79 | if relabel: pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /secret/data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | import logging 8 | 9 | from .base_dataset import BaseImageDataset 10 | 11 | class Market1501(BaseImageDataset): 12 | """ 13 | Market1501 14 | Reference: 15 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
16 | URL: http://www.liangzheng.org/Project/project_reid.html 17 | 18 | Dataset statistics: 19 | # identities: 1501 (+1 for background) 20 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 21 | """ 22 | dataset_dir = 'Market-1501-v15.09.15' 23 | 24 | def __init__(self, root, verbose=True, **kwargs): 25 | super(Market1501, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 30 | 31 | self._check_before_run() 32 | 33 | train = self._process_dir(self.train_dir, relabel=True) 34 | query = self._process_dir(self.query_dir, relabel=False) 35 | gallery = self._process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | logger = logging.getLogger('UnReID') 39 | logger.info("=> Market1501 loaded") 40 | self.print_dataset_statistics(train, query, gallery) 41 | 42 | self.train = train 43 | self.query = query 44 | self.gallery = gallery 45 | 46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 49 | 50 | def _check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 56 | if not osp.exists(self.query_dir): 57 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 58 | if not osp.exists(self.gallery_dir): 59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 60 | 61 | def 
_process_dir(self, dir_path, relabel=False): 62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 63 | pattern = re.compile(r'([-\d]+)_c(\d)') 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: continue # junk images are just ignored 76 | assert 0 <= pid <= 1501 # pid == 0 means background 77 | assert 1 <= camid <= 6 78 | camid -= 1 # index starts from 0 79 | if relabel: pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /secret/data/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import logging 4 | 5 | import re 6 | 7 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')): 8 | with open(list_file, 'r') as f: 9 | lines = f.readlines() 10 | ret = [] 11 | pids = [] 12 | for line in lines: 13 | line = line.strip() 14 | fname = line.split(' ')[0] 15 | pid, _, cam = map(int, pattern.search(osp.basename(fname)).groups()) 16 | if pid not in pids: 17 | pids.append(pid) 18 | ret.append((osp.join(subdir,fname), pid, cam)) 19 | return ret, pids 20 | 21 | class Dataset_MSMT(object): 22 | def __init__(self, root): 23 | self.root = root 24 | self.train, self.val, self.trainval = [], [], [] 25 | self.query, self.gallery = [], [] 26 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 27 | 28 | @property 29 | def images_dir(self): 30 | return osp.join(self.root, 'MSMT17_V1') 31 | 32 | def load(self, 
class MSMT17(Dataset_MSMT):
    """MSMT17 dataset that loads its splits as soon as it is constructed.

    ``split_id`` and ``download`` are accepted by the constructor but are
    not used by it.
    """

    def __init__(self, root, split_id=0, download=True):
        super().__init__(root)

        # Populate the splits right away so the instance is ready to use.
        self.load()
class Preprocessor(Dataset):
    """Dataset wrapper that opens the image behind each ``(fname, pid, camid)`` record.

    Args:
        dataset: indexable collection of ``(fname, pid, camid)`` records.
        transform: optional callable applied to the loaded RGB image.
    """

    def __init__(self, dataset, transform=None):
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, indices):
        return self._get_single_item(indices)

    def _get_single_item(self, index):
        """Load one record and return ``(img, fname, pid, index)``.

        The camera id stored in the record is not part of the result; the
        fourth element is the index that was requested.
        """
        record = self.dataset[index]
        image_path, person_id, _camera_id = record

        picture = Image.open(image_path).convert('RGB')
        if self.transform is None:
            return picture, image_path, person_id, index

        return self.transform(picture), image_path, person_id, index
def No_index(a, b):
    """Return the positions in list ``a`` whose element is not equal to ``b``."""
    assert isinstance(a, list)
    return [i for i, j in enumerate(a) if j != b]


class RandomMultipleGallerySampler(Sampler):
    """Identity-grouped sampler that prefers images from other cameras.

    For every pid (visited in random order) one anchor index is drawn, then
    ``num_instances - 1`` further indices of the same pid: from records whose
    camera differs from the anchor's when such records exist, otherwise from
    the pid's other records.  A pid with a single record contributes only its
    anchor.

    Args:
        data_source: indexable collection of ``(item, pid, cam)`` records.
            Records with ``pid < 0`` are ignored.
        num_instances: number of indices to emit per pid.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.index_pid = defaultdict(int)
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)
        self.num_instances = num_instances

        for index, (_, pid, cam) in enumerate(data_source):
            if pid < 0:
                continue
            self.index_pid[index] = pid
            self.pid_cam[pid].append(cam)
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.num_samples = len(self.pids)

    def __len__(self):
        # Nominal length; pids with a single record yield fewer indices.
        return self.num_samples * self.num_instances

    def _choose_positions(self, candidates):
        """Draw ``num_instances - 1`` entries from ``candidates``.

        Sampling is done without replacement whenever enough candidates
        exist, i.e. when there are at least ``num_instances - 1`` of them.
        """
        needed = self.num_instances - 1
        return np.random.choice(
            candidates, size=needed, replace=len(candidates) < needed
        )

    def __iter__(self):
        indices = torch.randperm(len(self.pids)).tolist()
        ret = []

        for kid in indices:
            i = random.choice(self.pid_index[self.pids[kid]])

            _, i_pid, i_cam = self.data_source[i]

            ret.append(i)

            pid_i = self.index_pid[i]
            cams = self.pid_cam[pid_i]
            index = self.pid_index[pid_i]
            # Positions (within this pid's lists) seen by a different camera.
            select_cams = No_index(cams, i_cam)

            if select_cams:
                chosen = self._choose_positions(select_cams)
            else:
                # Single-camera pid: fall back to any record but the anchor.
                select_indexes = No_index(index, i)
                if not select_indexes:
                    continue
                chosen = self._choose_positions(select_indexes)

            for kk in chosen:
                ret.append(index[kk])

        return iter(ret)
def build_transforms(cfg, is_train=True):
    """Assemble the image transform pipeline described by ``cfg.INPUT``.

    Training: bicubic resize to ``SIZE_TRAIN``, optional horizontal flip,
    optional pad followed by a random crop back to ``SIZE_TRAIN``, tensor
    conversion, normalisation and, when enabled, random erasing as the
    last step.
    Testing: bicubic resize to ``SIZE_TEST``, tensor conversion and
    normalisation.

    Args:
        cfg: config node exposing the ``INPUT`` options read below.
        is_train: selects the training pipeline when true.

    Returns:
        A ``Compose`` of the selected transforms.
    """
    settings = cfg.INPUT
    erasing_args = None

    if not is_train:
        target_size = settings.SIZE_TEST
        steps = [Resize(target_size, torchvision.transforms.InterpolationMode.BICUBIC)]
    else:
        target_size = settings.SIZE_TRAIN

        # horizontal flip
        flip_enabled, flip_chance = settings.DO_FLIP, settings.FLIP_PROB

        # padding
        pad_enabled, pad_amount = settings.DO_PAD, settings.PADDING

        # random erasing
        erasing = settings.REA
        erasing_enabled, erasing_chance, erasing_fill = erasing.ENABLED, erasing.PROB, erasing.MEAN

        steps = [Resize(target_size, torchvision.transforms.InterpolationMode.BICUBIC)]
        if flip_enabled:
            steps.append(RandomHorizontalFlip(p=flip_chance))
        if pad_enabled:
            steps += [Pad(pad_amount), RandomCrop(target_size)]
        if erasing_enabled:
            erasing_args = {'probability': erasing_chance, 'mean': erasing_fill}

    steps.append(ToTensor())
    steps.append(Normalize(mean=cfg.INPUT.PIXEL_MEAN,
                           std=cfg.INPUT.PIXEL_STD))

    # Random erasing works on the normalised tensor, so it goes last.
    if erasing_args is not None:
        steps.append(RandomErasing(**erasing_args))

    return Compose(steps)
class RandomErasing(object):
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf
    Args:
         probability: The probability that the Random Erasing operation will be performed.
         sl: Minimum proportion of erased area against input image.
         sh: Maximum proportion of erased area against input image.
         r1: Minimum aspect ratio of erased area.
         mean: Erasing value.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        self.mean = mean
        self.sl, self.sh = sl, sh
        self.r1 = r1

    def __call__(self, img):
        """Erase one random rectangle of ``img`` in place and return ``img``.

        ``img`` is indexed as ``[channel, row, column]``.  Up to 100
        rectangles are proposed; the first one that is strictly smaller than
        the image in both dimensions is filled with ``mean`` (all three
        channels for a 3-channel input, otherwise channel 0 only).  If the
        operation is skipped or no proposal fits, ``img`` is returned
        untouched.
        """
        if random.uniform(0, 1) >= self.probability:
            return img

        shape = img.size()
        channels, rows, cols = shape[0], shape[1], shape[2]
        total_area = rows * cols

        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * total_area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            box_h = int(round(math.sqrt(target_area * aspect_ratio)))
            box_w = int(round(math.sqrt(target_area / aspect_ratio)))

            if not (box_w < cols and box_h < rows):
                continue

            top = random.randint(0, rows - box_h)
            left = random.randint(0, cols - box_w)
            filled_channels = 3 if channels == 3 else 1
            for channel in range(filled_channels):
                img[channel, top:top + box_h, left:left + box_w] = self.mean[channel]
            return img

        return img
def create_engine(cfg):
    """Instantiate the engine registered under ``cfg.MODE``.

    Args:
        cfg: config node; ``cfg.MODE`` selects the engine and the whole
            ``cfg`` is handed to the engine's constructor.

    Returns:
        The constructed engine object.

    Raises:
        KeyError: if ``cfg.MODE`` is not a registered engine name.
    """
    mode = cfg.MODE
    if mode in __factory:
        return __factory[mode](cfg)
    raise KeyError("Unknown Engine:", mode)
def _build_optim(self, epoch):
    """Rebuild the optimizer for ``epoch`` with a step-decayed learning rate.

    The optimizer is recreated on every call with
    ``cfg.OPTIM.LR * scale`` as its learning rate:

    * with a pre-trained checkpoint (``cfg.CHECKPOING.PRETRAIN_PATH`` set)
      the scale is always 1.0;
    * otherwise the scale is 1.0 for epochs < 40, 0.3 for epochs < 60 and
      0.1 from epoch 60 onwards.

    Args:
        epoch: zero-based index of the epoch about to be trained.

    Raises:
        NotImplementedError: if ``cfg.OPTIM.SCHED`` is not ``'single_step'``.
    """
    sched = self.cfg.OPTIM.SCHED
    if sched != 'single_step':
        raise NotImplementedError("NO {} for UDA".format(sched))

    if self.cfg.CHECKPOING.PRETRAIN_PATH or epoch < 40:
        scale = 1.0
    elif epoch < 60:
        scale = 0.3
    else:
        # Previously epochs >= 80 matched no branch and silently fell back
        # to the initial scale of 1.0, resetting the learning rate to its
        # full value after it had been decayed twice. Keep the last
        # decayed value instead.
        scale = 0.1

    LR = self.cfg.OPTIM.LR * scale
    self.optimizer = build_optimizer(self.cfg, self.model, LR = LR)
    self.logger.info('lr: {:.8f}'.format(self.optimizer.param_groups[0]['lr']))
def generate_pseudo_dataset(self, epoch):
    """Cluster target features per branch, refine the labels and rebuild the train loader.

    Steps performed for the current ``epoch``:

    1. Extract one feature per branch for every target training image
       with the EMA model.
    2. For each branch, compute a distance matrix and run DBSCAN on it
       to obtain raw pseudo labels (``-1`` marks samples DBSCAN left
       unclustered).
    3. Refine each branch's labels against the other branches listed in
       ``self.refine_index`` and drop every sample that ends up ``-1``
       in the branch itself or in any refined partner labelling.
    4. Re-index the surviving cluster ids to ``0..K-1`` and copy the
       L2-normalised cluster centres into the first ``K`` rows of the
       matching classifier weights of both the student and EMA model.
    5. Build ``self.Target_train_loader`` from the samples that have a
       valid label in every branch.

    Side effects: sets ``self.num_clusters_list``, ``self.labels_list``,
    ``self.new_num_clusters_list``, ``self.new_labels_list``,
    ``self.refine_index`` and ``self.Target_train_loader``; appends to
    ``self.cluster_list`` on the first epoch; overwrites classifier
    weights of ``self.model`` and ``self.model_ema``.

    Args:
        epoch: zero-based epoch index; DBSCAN estimators are only created
            when it equals 0 or ``self.start_epoch``.

    Raises:
        AssertionError: if the rebuilt loader's length differs from
            ``cfg.DATALOADER.ITERS``.
    """

    self.logger.info('Extract feat and Calculate dist...')
    # Maps each file name to a list of per-branch feature tensors.
    dict_f = extract_features(self.model_ema, self.Target_cluster_loader, print_freq = self.cfg.TEST.PRINT_PERIOD, GenLabel = True)
    # Number of feature branches, read off the first training sample.
    part_num = len(dict_f[self.Target_dataset.train[0][0]])
    # Classifier attribute names, indexed by branch.
    # NOTE(review): hard-coded for three branches; assumes part_num == 3 - confirm.
    model_FC = ['classifier', 'classifier_partup', 'classifier_partdown']
    # cf[i] stacks branch i's features in sorted(train) order; every later
    # index into labels relies on this same ordering.
    cf = [torch.cat([dict_f[f][i].unsqueeze(0) for f, _, _ in sorted(self.Target_dataset.train)], 0) for i in range(part_num)]
    self.num_clusters_list = []
    self.labels_list = []

    for i in range(part_num):
        # Pairwise distance matrix handed to DBSCAN as a precomputed metric.
        rerank_dist = compute_jaccard_distance(cf[i])

        if (epoch==0 or epoch == self.start_epoch):
            # Create the DBSCAN estimator once; later epochs reuse it
            # (and therefore the eps chosen here).
            if self.cfg.CHECKPOING.PRETRAIN_PATH:
                # Data-driven eps: mean of the smallest `rho` fraction of
                # the non-zero pairwise distances (upper triangle only).
                tri_mat = np.triu(rerank_dist, 1)
                tri_mat = tri_mat[np.nonzero(tri_mat)]
                tri_mat = np.sort(tri_mat,axis=None)
                rho = 1.6e-3
                top_num = np.round(rho*tri_mat.size).astype(int)
                eps = tri_mat[:top_num].mean()
            else:
                # Without a pre-trained checkpoint use the configured eps.
                eps = self.cfg.CLUSTER.EPS

            self.logger.info('eps for cluster: {:.3f}'.format(eps))
            self.cluster_list.append(DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=-1))

        self.logger.info('Clustering and labeling...')
        labels = self.cluster_list[i].fit_predict(rerank_dist)
        # Count clusters, not counting the -1 "unclustered" label.
        num_ids = len(set(labels)) - (1 if -1 in labels else 0)
        self.num_clusters_list.append(num_ids)
        self.labels_list.append(labels)
        self.logger.info('Clustered into {} classes'.format(num_ids))

    self.new_num_clusters_list = []
    self.new_labels_list = []
    # refine_index[i] lists the branches whose labels are used to refine
    # branch i: branch 0 is checked against 1 and 2, branches 1 and 2
    # against branch 0.
    self.refine_index = [[1,2],[0],[0]]
    self.logger.info(self.refine_index)
    for i in range(part_num):
        target_labels = self.labels_list[i]

        # One refined labelling of branch i per reference branch.
        # NOTE(review): RefineClusterProcess is defined elsewhere; it
        # presumably returns a per-sample label array with -1 for
        # rejected samples - confirm against its implementation.
        new_labels_list = []
        tmp_refine_index = self.refine_index[i]
        for j in tmp_refine_index:
            new_labels_list.append(RefineClusterProcess(self.labels_list[j], target_labels, divide_ratio = self.cfg.CLUSTER.REFINE_K))

        labels = []
        dictReindex = dict()  # original cluster id -> contiguous id
        label_index = 0       # next free contiguous id == clusters kept so far
        useful_nums = 0       # samples that keep a valid label
        for j in range(len(target_labels)):
            # A sample survives only if neither the raw label nor any
            # refined label is -1.
            tmp_label = [int(target_labels[j])]
            tmp_label += [int(x[j]) for x in new_labels_list]
            if -1 in tmp_label:
                labels.append(-1)
            else:
                labels.append(int(target_labels[j]))
                useful_nums += 1
                # Assign contiguous ids in order of first appearance.
                if int(target_labels[j]) not in dictReindex.keys():
                    dictReindex[int(target_labels[j])] = label_index
                    label_index += 1

        # Second pass: rewrite surviving labels with their contiguous ids.
        for index in range(len(labels)):
            if labels[index]==-1: continue
            labels[index] = dictReindex[labels[index]]

        # Samples removed by refinement alone (excludes those DBSCAN
        # already marked -1).
        del_samples = (len(labels) - useful_nums)-len(np.where(target_labels == -1)[0])
        self.logger.info('useful samples {}, mutual refine del {} samples, del {} noisy cluster'.format(useful_nums, del_samples, self.num_clusters_list[i] - label_index))
        self.new_labels_list.append(labels)
        self.new_num_clusters_list.append(label_index)

        # Group branch-i features by refined label to compute cluster centres.
        cluster_centers_dict = collections.defaultdict(list)
        for index, ((fname, _, cid), label) in enumerate(zip(sorted(self.Target_dataset.train), labels)):
            if label==-1: continue
            cluster_centers_dict[label].append(cf[i][index])

        # Centres ordered by label id so row k corresponds to class k.
        cluster_centers = [torch.stack(cluster_centers_dict[idx]).mean(0) for idx in sorted(cluster_centers_dict.keys())]
        cluster_centers = torch.stack(cluster_centers)
        # Initialise the first `label_index` classifier rows of both
        # networks with the L2-normalised centres.
        model_param = getattr(self.model.module, model_FC[i])
        model_param.weight.data[:label_index].copy_(F.normalize(cluster_centers, dim=1).float().cuda())
        model_ema_param = getattr(self.model_ema.module, model_FC[i])
        model_ema_param.weight.data[:label_index].copy_(F.normalize(cluster_centers, dim=1).float().cuda())

    # From here on the refined labels replace the raw DBSCAN output.
    self.num_clusters_list = self.new_num_clusters_list
    self.labels_list = self.new_labels_list

    # Keep only samples with a valid label in every branch; each kept
    # sample carries a list with one label per branch.
    new_dataset = []
    for i, (fname, _, cid) in enumerate(sorted(self.Target_dataset.train)):
        label = []
        for L in range(part_num):
            label.append(int(self.labels_list[L][i]))
        if -1 in label:
            continue
        new_dataset.append((fname, label, cid))

    self.logger.info('new dataset length {}'.format(len(new_dataset)))
    self.Target_train_loader = build_loader(self.cfg, None, inputset=new_dataset, num_instances = self.cfg.DATALOADER.NUM_INSTANCES, is_train=True)
    # The training loop runs exactly cfg.DATALOADER.ITERS iterations.
    assert self.cfg.DATALOADER.ITERS == len(self.Target_train_loader)
def init_train(self):
    """Create the loss modules for the current set of pseudo labels.

    One label-smoothed cross-entropy criterion is built per branch
    (global, upper part, lower part), each sized by the matching entry
    of ``self.num_clusters_list``. The hard-label triplet loss uses
    ``margin=0.0``; the teacher-guided losses are a soft cross-entropy
    and a margin-free soft triplet loss.
    """
    smoothing = self.cfg.MODEL.LOSSES.CE.EPSILON
    per_branch = (
        ('criterion_ce', 0),
        ('criterion_ce_up', 1),
        ('criterion_ce_down', 2),
    )
    for attr_name, branch in per_branch:
        criterion = CrossEntropyLabelSmooth(self.num_clusters_list[branch], epsilon=smoothing)
        setattr(self, attr_name, criterion.cuda())

    self.criterion_tri = SoftTripletLoss(margin=0.0).cuda()

    # Losses that compare the student against the EMA teacher.
    self.criterion_ce_soft = SoftEntropy().cuda()
    self.criterion_tri_soft = SoftTripletLoss(margin=None).cuda()
prob_part_up[:,:self.num_clusters_list[1]] 273 | prob_part_down = prob_part_down[:,:self.num_clusters_list[2]] 274 | 275 | [x_ema, part_up_ema, part_down_ema], _, [prob_ema, prob_part_up_ema, prob_part_down_ema] = self.model_ema(inputs, finetune = True) 276 | prob_ema = prob_ema[:,:self.num_clusters_list[0]] 277 | prob_part_up_ema = prob_part_up_ema[:,:self.num_clusters_list[1]] 278 | prob_part_down_ema = prob_part_down_ema[:,:self.num_clusters_list[2]] 279 | 280 | loss_ce = self.criterion_ce(prob, targets[0]) + \ 281 | self.criterion_ce_up(prob_part_up, targets[1]) + \ 282 | self.criterion_ce_down(prob_part_down, targets[2]) 283 | loss_tri = self.criterion_tri(x, x, targets[0]) + \ 284 | self.criterion_tri(part_up, part_up, targets[1]) + \ 285 | self.criterion_tri(part_down, part_down, targets[2]) 286 | 287 | loss_ce_soft = self.criterion_ce_soft(prob, prob_ema) + \ 288 | self.criterion_ce_soft(prob_part_up, prob_part_up_ema) + \ 289 | self.criterion_ce_soft(prob_part_down, prob_part_down_ema) 290 | loss_tri_soft = self.criterion_tri_soft(x, x_ema, targets[0]) + \ 291 | self.criterion_tri_soft(part_up, part_up_ema, targets[1]) + \ 292 | self.criterion_tri_soft(part_down, part_down_ema, targets[2]) 293 | 294 | loss = loss_ce * (1-self.cfg.MEAN_TEACH.CE_SOFT_WRIGHT) + \ 295 | loss_tri * (1-self.cfg.MEAN_TEACH.TRI_SOFT_WRIGHT) + \ 296 | loss_ce_soft * self.cfg.MEAN_TEACH.CE_SOFT_WRIGHT + \ 297 | loss_tri_soft * self.cfg.MEAN_TEACH.TRI_SOFT_WRIGHT 298 | 299 | self.optimizer.zero_grad() 300 | loss.backward() 301 | self.optimizer.step() 302 | 303 | self.update_ema_variables(self.model, self.model_ema, self.cfg.MEAN_TEACH.ALPHA, epoch*len(self.Target_train_loader) + i) 304 | 305 | prec = accuracy(prob_ema.data, targets[0].data) 306 | 307 | losses_ce.update(loss_ce.item()) 308 | losses_tri.update(loss_tri.item()) 309 | losses_ce_soft.update(loss_ce_soft.item()) 310 | losses_tri_soft.update(loss_tri_soft.item()) 311 | precisions.update(prec) 312 | 313 | 
def update_ema_variables(self, model, ema_model, alpha, global_step):
    """Move the EMA model's parameters towards the student's, in place.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * student``
    where ``decay = min(1 - 1 / (global_step + 1), alpha)``, so early
    steps follow the student closely and later steps are capped at
    ``alpha``.

    Args:
        model: student network whose parameters are read.
        ema_model: teacher network whose parameters are updated in place.
        alpha: upper bound of the decay factor.
        global_step: zero-based count of optimisation steps so far.
    """
    warmup_decay = 1 - 1 / (global_step + 1)
    decay = min(warmup_decay, alpha)
    student_share = 1 - decay

    paired = zip(ema_model.parameters(), model.parameters())
    for teacher_param, student_param in paired:
        teacher_values = teacher_param.data
        teacher_values.mul_(decay)
        teacher_values.add_(student_param.data, alpha=student_share)
def run(self):
    """Execute the whole pre-training schedule.

    Builds the datasets, the model and the optimizer, creates the loss
    modules once, then alternates training and evaluation/checkpointing
    for every epoch from the model's start epoch up to
    ``cfg.OPTIM.EPOCHS`` (exclusive).
    """
    self._build_dataset()
    first_epoch = self._build_model()
    self._build_optim()
    self.init_train()

    epoch_count = self.cfg.OPTIM.EPOCHS
    for current in range(first_epoch, epoch_count):
        self.train(current)
        self.eval_save(current)
def train(self, epoch):
    """Run one supervised pre-training epoch.

    Every iteration draws one source batch and one target batch. The
    source batch produces three features and three logit sets, which
    feed the cross-entropy and triplet losses against the source
    identities. The target batch is only forwarded through the model;
    its outputs are discarded.
    NOTE(review): the target forward pass presumably exists to expose
    the normalisation layers to target statistics - confirm.

    The learning-rate scheduler is stepped once after the last iteration.

    Args:
        epoch: zero-based epoch index, used for logging only.
    """
    self.Source_train_loader.new_epoch()
    self.Target_train_loader.new_epoch()
    self.model.train()

    batch_time, data_time = AverageMeter(), AverageMeter()
    losses_ce, losses_tri = AverageMeter(), AverageMeter()
    precisions = AverageMeter()

    report = ('Epoch: [{}][{}/{}]\t'
              'Time {:.3f} ({:.3f})\t'
              'Data {:.3f} ({:.3f})\t'
              'Loss_ce {:.3f} ({:.3f})\t'
              'Loss_tri {:.3f} ({:.3f})\t'
              'Prec {:.2%} ({:.2%})')

    total_iters = self.cfg.DATALOADER.ITERS
    tick = time.time()

    for step in range(1, total_iters + 1):
        source_batch = self.Source_train_loader.next()
        target_batch = self.Target_train_loader.next()
        data_time.update(time.time() - tick)

        s_inputs, targets = self._parse_data(source_batch)
        t_inputs, _ = self._parse_data(target_batch)

        features, logits = self.model(s_inputs)
        x, part_up, part_down = features
        prob, prob_part_up, prob_part_down = logits
        # Target images go through the network without contributing a loss.
        self.model(t_inputs)

        loss_ce = self.criterion_ce(prob, targets) \
            + self.criterion_ce(prob_part_up, targets) \
            + self.criterion_ce(prob_part_down, targets)
        loss_tri = self.criterion_triple(x, x, targets) \
            + self.criterion_triple(part_up, part_up, targets) \
            + self.criterion_triple(part_down, part_down, targets)
        loss = loss_ce + loss_tri

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Precision is measured on the global branch only.
        prec = accuracy(prob.data, targets.data)
        losses_ce.update(loss_ce.item())
        losses_tri.update(loss_tri.item())
        precisions.update(prec)

        batch_time.update(time.time() - tick)

        # Log every PRINT_PERIOD iterations and on the final iteration.
        if step % self.cfg.PRINT_PERIOD == 0 or step % total_iters == 0:
            self.logger.info(report.format(
                epoch + 1, step, total_iters,
                batch_time.val, batch_time.avg,
                data_time.val, data_time.avg,
                losses_ce.val, losses_ce.avg,
                losses_tri.val, losses_tri.avg,
                precisions.val, precisions.avg))

        tick = time.time()

    self.lr_scheduler.step()
def euclidean_dist(x, y):
    """Return the (m, n) matrix of Euclidean distances between rows of x and y.

    Squared distances are clamped to at least 1e-12 before the square
    root, so identical rows yield 1e-6 rather than 0.
    """
    sq_x = x.pow(2).sum(1, keepdim=True)        # (m, 1)
    sq_y = y.pow(2).sum(1, keepdim=True).t()    # (1, n)
    # ||a||^2 + ||b||^2 - 2 * a.b for every pair of rows
    sq_dist = torch.addmm(sq_x + sq_y, x, y.t(), beta=1, alpha=-2)
    return sq_dist.clamp(min=1e-12).sqrt()


def cosine_dist(x, y):
    """Return the (m, n) matrix of cosine distances (1 - cosine similarity)."""
    rows_x, rows_y = x.size(0), y.size(0)
    norm_x = x.pow(2).sum(1).sqrt().view(rows_x, 1)
    norm_y = y.pow(2).sum(1).sqrt().view(1, rows_y)
    similarity = torch.matmul(x, y.transpose(0, 1)) / (norm_x * norm_y)
    return 1 - similarity


def _batch_hard(mat_distance, mat_similarity, indice=False):
    """Pick the hardest positive and hardest negative for every anchor row.

    Args:
        mat_distance: (N, N) pairwise distances.
        mat_similarity: (N, N) matrix with 1 where two samples share a
            label and 0 elsewhere.
        indice: when truthy, also return the column indices selected.

    Returns:
        ``(hard_p, hard_n)`` or ``(hard_p, hard_n, hard_p_idx, hard_n_idx)``.
    """
    # Negatives are pushed far down so the largest entry is a positive.
    pos_ranked, pos_order = torch.sort(
        mat_distance + (-9999999.) * (1 - mat_similarity), dim=1, descending=True)
    # Positives are pushed far up so the smallest entry is a negative.
    neg_ranked, neg_order = torch.sort(
        mat_distance + (9999999.) * (mat_similarity), dim=1, descending=False)

    hard_p, hard_n = pos_ranked[:, 0], neg_ranked[:, 0]
    if indice:
        return hard_p, hard_n, pos_order[:, 0], neg_order[:, 0]
    return hard_p, hard_n


class TripletLoss(nn.Module):
    """Batch-hard triplet loss with a fixed margin."""

    def __init__(self, margin, normalize_feature=False):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        self.margin_loss = nn.MarginRankingLoss(margin=margin).cuda()

    def forward(self, emb, label):
        """Return ``(loss, prec)`` for one batch.

        ``prec`` is the fraction of anchors whose hardest negative is
        farther away than their hardest positive.
        """
        if self.normalize_feature:
            # with unit-length rows the ranking equals cosine similarity
            emb = F.normalize(emb)

        mat_dist = euclidean_dist(emb, emb)
        assert mat_dist.size(0) == mat_dist.size(1)
        N = mat_dist.size(0)
        label_grid = label.expand(N, N)
        mat_sim = label_grid.eq(label_grid.t()).float()

        dist_ap, dist_an = _batch_hard(mat_dist, mat_sim)
        assert dist_an.size(0)==dist_ap.size(0)
        ones = torch.ones_like(dist_ap)
        loss = self.margin_loss(dist_an, dist_ap, ones)
        prec = (dist_an.data > dist_ap.data).sum() * 1. / ones.size(0)
        return loss, prec


class SoftTripletLoss(nn.Module):
    """Softmax-based triplet loss, optionally supervised by a reference embedding.

    With ``margin`` set, the loss is a fixed weighting of the
    log-softmax over (hardest positive, hardest negative) distances.
    With ``margin=None``, the softmax over the same pairs measured in
    ``emb2`` serves as a detached soft target.
    """

    def __init__(self, margin=None, normalize_feature=False):
        super(SoftTripletLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature

    def forward(self, emb1, emb2, label):
        if self.normalize_feature:
            # with unit-length rows the ranking equals cosine similarity
            emb1 = F.normalize(emb1)
            emb2 = F.normalize(emb2)

        mat_dist = euclidean_dist(emb1, emb1)
        assert mat_dist.size(0) == mat_dist.size(1)
        N = mat_dist.size(0)
        label_grid = label.expand(N, N)
        mat_sim = label_grid.eq(label_grid.t()).float()

        dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
        assert dist_an.size(0)==dist_ap.size(0)
        log_pair = F.log_softmax(torch.stack((dist_ap, dist_an), dim=1), dim=1)

        if self.margin is not None:
            weighted = - self.margin * log_pair[:, 0] - (1 - self.margin) * log_pair[:, 1]
            return weighted.mean()

        # Look up the same (anchor, positive) / (anchor, negative) pairs
        # in the reference embedding's distance matrix.
        mat_dist_ref = euclidean_dist(emb2, emb2)
        dist_ap_ref = mat_dist_ref.gather(1, ap_idx.view(N, 1)).squeeze(1)
        dist_an_ref = mat_dist_ref.gather(1, an_idx.view(N, 1)).squeeze(1)
        soft_target = F.softmax(torch.stack((dist_ap_ref, dist_an_ref), dim=1), dim=1).detach()

        return (- soft_target * log_pair).mean(0).sum()
def pairwise_distance(x, y):
    """Return the (m, n) matrix of squared Euclidean distances.

    Each input is flattened to one row per sample before the distances
    are computed; no square root or clamping is applied.

    Args:
        x: tensor with m samples along dim 0.
        y: tensor with n samples along dim 0.
    """
    m, n = x.size(0), y.size(0)
    flat_x = x.view(m, -1)
    flat_y = y.view(n, -1)
    sq_x = flat_x.pow(2).sum(dim=1, keepdim=True)        # (m, 1)
    sq_y = flat_y.pow(2).sum(dim=1, keepdim=True).t()    # (1, n)
    # ||a||^2 + ||b||^2 - 2 * a.b
    return torch.addmm(sq_x + sq_y, flat_x, flat_y.t(), beta=1, alpha=-2)
class Evaluator(object):
    """Evaluate a multi-branch model on a query/gallery retrieval split."""

    def __init__(self, cfg, model):
        super(Evaluator, self).__init__()
        self.logger = logging.getLogger('UnReID')
        self.model = model
        self.cfg = cfg

    def evaluate(self, data_loader, query, gallery, use_cython = False):
        """Score every feature branch separately.

        Args:
            data_loader: loader yielding all query and gallery images.
            query: list of ``(fname, pid, cam)`` query records.
            gallery: list of ``(fname, pid, cam)`` gallery records.
            use_cython: forwarded to ``evaluate_all``.

        Returns:
            A list with one ``evaluate_all`` result per feature branch,
            in branch order.
        """
        features = extract_features(
            self.model, data_loader,
            print_freq=self.cfg.TEST.PRINT_PERIOD, GenLabel=False)

        def branch_matrix(records, branch):
            # One row per record, taken from the requested feature branch.
            rows = [features[fname][branch].unsqueeze(0) for fname, _, _ in records]
            return torch.cat(rows, 0)

        # The number of branches is read off the first query image.
        branch_count = len(features[query[0][0]])
        results = []
        for branch in range(branch_count):
            query_feats = branch_matrix(query, branch)
            gallery_feats = branch_matrix(gallery, branch)
            distmat = pairwise_distance(query_feats, gallery_feats)
            results.append(evaluate_all(distmat, query=query, gallery=gallery,
                                        use_cython=use_cython))
        return results
def extract_cnn_feature(model, inputs):
    """Forward one batch through ``model`` on the GPU and return the output on the CPU.

    Args:
        model: callable network producing a single tensor.
        inputs: batch accepted by ``to_torch``.

    Returns:
        The model output's underlying data, moved to the CPU.
    """
    gpu_batch = to_torch(inputs).cuda()
    return model(gpu_batch).data.cpu()
def pairwise_distance(features, query=None, gallery=None):
    """Compute squared Euclidean distances between feature vectors.

    Args:
        features: mapping ``fname -> tensor`` (one feature tensor per image).
        query: optional sequence of ``(fname, pid, camid)`` tuples selecting
            the rows of the distance matrix.
        gallery: optional sequence of ``(fname, pid, camid)`` tuples selecting
            the columns of the distance matrix.

    Returns:
        If both ``query`` and ``gallery`` are ``None``: an ``(n, n)`` tensor of
        distances between every pair of entries in ``features`` (in mapping
        order). Otherwise an ``(m, n)`` tensor of query-to-gallery distances.

    Raises:
        ValueError: if exactly one of ``query`` / ``gallery`` is ``None``.
    """
    if query is None and gallery is None:
        n = len(features)
        x = torch.cat(list(features.values()))
        x = x.view(n, -1)
        # ||x_i||^2 + ||x_j||^2 - 2 * x_i . x_j
        # The previous form used 2 * ||x_i||^2 for the first two terms, which
        # is only correct (and symmetric) when every row has the same norm.
        # For L2-normalised features the result is mathematically unchanged.
        sq_norms = torch.pow(x, 2).sum(dim=1, keepdim=True)
        dist_m = sq_norms.expand(n, n) + sq_norms.t().expand(n, n) \
            - 2 * torch.mm(x, x.t())
        return dist_m

    if query is None or gallery is None:
        raise ValueError("query and gallery must be given together "
                         "(or both omitted)")

    x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
    y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
    m, n = x.size(0), y.size(0)
    x = x.view(m, -1)
    y = y.view(n, -1)
    dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # dist_m = 1 * dist_m - 2 * (x @ y^T), done in place on the fresh sum.
    dist_m.addmm_(x, y.t(), beta=1, alpha=-2)
    return dist_m
class Evaluator(object):
    """Extract features with ``model`` and score a query/gallery split."""

    def __init__(self, cfg, model):
        super(Evaluator, self).__init__()
        self.cfg = cfg
        self.model = model
        self.logger = logging.getLogger('UnReID')

    def evaluate(self, data_loader, query, gallery, use_cython = False):
        """Evaluate retrieval quality, optionally after re-ranking.

        Args:
            data_loader: loader handed to ``extract_features``.
            query: sequence of ``(fname, pid, camid)`` tuples.
            gallery: sequence of ``(fname, pid, camid)`` tuples.
            use_cython: forwarded to ``evaluate_all`` for both the plain and
                the re-ranked evaluation.

        Returns:
            Whatever ``evaluate_all`` returns: for the plain distance matrix
            when ``cfg.TEST.RERANK.ENABLED`` is false, otherwise for the
            re-ranked distance matrix.
        """
        features, _ = extract_features(self.model, data_loader, print_freq = self.cfg.TEST.PRINT_PERIOD)
        distmat = pairwise_distance(features, query, gallery)
        # Evaluated (and logged) even when re-ranking follows, as before.
        results = evaluate_all(distmat, query=query, gallery=gallery, use_cython = use_cython)

        if (not self.cfg.TEST.RERANK.ENABLED):
            return results

        self.logger.info('Applying person re-ranking ...')
        distmat_qq = pairwise_distance(features, query, query)
        distmat_gg = pairwise_distance(features, gallery, gallery)
        distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy(),
                             k1=self.cfg.TEST.RERANK.K1, k2=self.cfg.TEST.RERANK.K2, lambda_value=self.cfg.TEST.RERANK.LAMBDA)
        # Previously the flag was dropped here, so a caller asking for the
        # Cython evaluator silently got the Python one after re-ranking.
        return evaluate_all(distmat, query=query, gallery=gallery, use_cython = use_cython)
def evaluate_rank(
    distmat,
    q_pids,
    g_pids,
    q_camids,
    g_camids,
    max_rank=10,
    use_cython=True
):
    """Evaluate CMC and mAP with the compiled Cython routine and log them.

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
        q_pids (numpy.ndarray): 1-D array containing person identities
            of each query instance.
        g_pids (numpy.ndarray): 1-D array containing person identities
            of each gallery instance.
        q_camids (numpy.ndarray): 1-D array containing camera views under
            which each query instance is captured.
        g_camids (numpy.ndarray): 1-D array containing camera views under
            which each gallery instance is captured.
        max_rank (int, optional): maximum CMC rank to be computed. Default is 10.
        use_cython (bool, optional): accepted for interface compatibility.
            This function has no pure-Python path; the Cython routine is
            used regardless of the value.

    Returns:
        tuple: ``(rank-1 CMC score, mAP)``.

    Raises:
        RuntimeError: if the Cython extension could not be imported.
    """
    if not IS_CYTHON_AVAI:
        # Without this guard the call below fails with an opaque NameError.
        raise RuntimeError(
            'Cython evaluation is unavailable: build secret/metrics/rank_cylib '
            'first, or call the evaluation with use_cython=False.'
        )

    # The market1501 protocol is always used here.
    use_metric_cuhk03 = False

    cmc_scores, mAP = evaluate_cy(
        distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
        use_metric_cuhk03
    )
    cmc_topk = (1, 5, 10)
    logger = logging.getLogger('UnReID')
    logger.info('Mean AP: {:4.1%}'.format(mAP))
    logger.info('CMC Scores:')
    for k in cmc_topk:
        # cmc_scores holds at most max_rank entries; skip ranks beyond it
        # instead of raising IndexError when max_rank < 10.
        if k > len(cmc_scores):
            continue
        logger.info(' top-{:<4}{:12.1%}'.format(k, cmc_scores[k-1]))
    return cmc_scores[0], mAP
# CUHK03-style (single-gallery-shot) evaluation.
#
# For every query:
#   1. gallery entries that share BOTH the query's pid and camid are dropped;
#   2. the query is skipped when no remaining entry matches its pid;
#   3. CMC is averaged over `num_repeats` random draws in which one image per
#      gallery identity is kept;
#   4. AP is computed on the full filtered ranking (not on the sampled one).
#
# Returns (avg_cmc as float32 ndarray of length max_rank, mAP), both averaged
# over the valid queries only.
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
                     long[:]q_camids, long[:]g_camids, long max_rank):

    cdef long num_q = distmat.shape[0]
    cdef long num_g = distmat.shape[1]

    # Cannot rank deeper than the gallery size.
    if num_g < max_rank:
        max_rank = num_g
        print('Note: number of gallery samples is quite small, got {}'.format(num_g))

    cdef:
        long num_repeats = 10
        # Per query: gallery indices sorted by ascending distance.
        long[:,:] indices = np.argsort(distmat, axis=1)
        # matches[q, r] == 1 when the r-th ranked gallery entry has the query's pid.
        long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)

        float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
        float[:] all_AP = np.zeros(num_q, dtype=np.float32)
        float num_valid_q = 0. # number of valid query

        long q_idx, q_pid, q_camid, g_idx
        long[:] order = np.zeros(num_g, dtype=np.int64)
        long keep  # NOTE(review): declared but never used in this function

        float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
        float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32)
        float[:] cmc, masked_cmc
        long num_g_real, num_g_real_masked, rank_idx, rnd_idx
        unsigned long meet_condition
        float AP  # NOTE(review): declared but never used in this function
        long[:] kept_g_pids, mask

        float num_rel
        float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
        float tmp_cmc_sum

    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        for g_idx in range(num_g):
            order[g_idx] = indices[q_idx, g_idx]
        num_g_real = 0
        meet_condition = 0
        kept_g_pids = np.zeros(num_g, dtype=np.int64)

        # Compact the kept entries to the front of raw_cmc / kept_g_pids;
        # only the first num_g_real slots are meaningful afterwards.
        for g_idx in range(num_g):
            if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
                raw_cmc[num_g_real] = matches[q_idx][g_idx]
                kept_g_pids[num_g_real] = g_pids[order[g_idx]]
                num_g_real += 1
                if matches[q_idx][g_idx] > 1e-31:
                    meet_condition = 1

        if not meet_condition:
            # this condition is true when query identity does not appear in gallery
            continue

        # cuhk03-specific setting
        # Group positions in the filtered ranking by gallery identity.
        g_pids_dict = defaultdict(list) # overhead!
        for g_idx in range(num_g_real):
            g_pids_dict[kept_g_pids[g_idx]].append(g_idx)

        cmc = np.zeros(max_rank, dtype=np.float32)
        for _ in range(num_repeats):
            mask = np.zeros(num_g_real, dtype=np.int64)

            for _, idxs in g_pids_dict.items():
                # randomly sample one image for each gallery person
                rnd_idx = np.random.choice(idxs)
                #rnd_idx = idxs[0] # use deterministic for debugging
                mask[rnd_idx] = 1

            # Keep only the sampled entries, preserving rank order.
            num_g_real_masked = 0
            for g_idx in range(num_g_real):
                if mask[g_idx] == 1:
                    masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx]
                    num_g_real_masked += 1

            masked_cmc = np.zeros(num_g, dtype=np.float32)
            function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked)
            # Clamp the cumulative count to a 0/1 "hit at or before this rank" flag.
            for g_idx in range(num_g_real_masked):
                if masked_cmc[g_idx] > 1:
                    masked_cmc[g_idx] = 1

            # Accumulate the average over the repeats.
            for rank_idx in range(max_rank):
                cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats

        for rank_idx in range(max_rank):
            all_cmc[q_idx, rank_idx] = cmc[rank_idx]
        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        function_cumsum(raw_cmc, tmp_cmc, num_g_real)
        num_rel = 0
        tmp_cmc_sum = 0
        for g_idx in range(num_g_real):
            # precision at rank g_idx+1, counted only where the entry is a match
            tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
            num_rel += raw_cmc[g_idx]
        # num_rel >= 1 here because meet_condition was set by a kept match.
        all_AP[q_idx] = tmp_cmc_sum / num_rel
        num_valid_q += 1.

    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    # compute averaged cmc
    # Skipped queries left all-zero rows, so summing over all rows and
    # dividing by num_valid_q averages over the valid queries only.
    cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32)
    for rank_idx in range(max_rank):
        for q_idx in range(num_q):
            avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
        avg_cmc[rank_idx] /= num_valid_q

    cdef float mAP = 0
    for q_idx in range(num_q):
        mAP += all_AP[q_idx]
    mAP /= num_valid_q

    return np.asarray(avg_cmc).astype(np.float32), mAP
# number of valid query 172 | 173 | long q_idx, q_pid, q_camid, g_idx 174 | long[:] order = np.zeros(num_g, dtype=np.int64) 175 | long keep 176 | 177 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 178 | float[:] cmc = np.zeros(num_g, dtype=np.float32) 179 | long num_g_real, rank_idx 180 | unsigned long meet_condition 181 | 182 | float num_rel 183 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 184 | float tmp_cmc_sum 185 | 186 | for q_idx in range(num_q): 187 | # get query pid and camid 188 | q_pid = q_pids[q_idx] 189 | q_camid = q_camids[q_idx] 190 | 191 | # remove gallery samples that have the same pid and camid with query 192 | for g_idx in range(num_g): 193 | order[g_idx] = indices[q_idx, g_idx] 194 | num_g_real = 0 195 | meet_condition = 0 196 | 197 | for g_idx in range(num_g): 198 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 199 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 200 | num_g_real += 1 201 | if matches[q_idx][g_idx] > 1e-31: 202 | meet_condition = 1 203 | 204 | if not meet_condition: 205 | # this condition is true when query identity does not appear in gallery 206 | continue 207 | 208 | # compute cmc 209 | function_cumsum(raw_cmc, cmc, num_g_real) 210 | for g_idx in range(num_g_real): 211 | if cmc[g_idx] > 1: 212 | cmc[g_idx] = 1 213 | 214 | for rank_idx in range(max_rank): 215 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 216 | num_valid_q += 1. 
def numpy_include():
    """Return the directory holding NumPy's C headers.

    Uses ``np.get_include()``; if that lookup raises ``AttributeError`` the
    legacy ``np.get_numpy_include()`` is tried instead.
    """
    try:
        return np.get_include()
    except AttributeError:
        return np.get_numpy_include()
# Benchmark script: times ``metrics.evaluate_rank`` with ``use_cython`` set to
# False and True, 20 runs each, on random inputs.
#
# NOTE(review): as written this script does not look runnable against this
# package -- see the notes below; confirm before relying on it.
from __future__ import print_function
import sys
import numpy as np
import timeit
import os.path as osp

# NOTE(review): this import executes before the sys.path tweak on the next
# statement, so the tweak cannot help locate ``secret`` for it.
from secret import metrics

sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
"""
Test the speed of cython-based evaluation code. The speed improvements
can be much bigger when using the real reid data, which contains a larger
amount of query and gallery images.

Note: you might encounter the following error:
    'AssertionError: Error: all query identities do not appear in gallery'.
This is normal because the inputs are random numbers. Just try again.
"""

print('*** Compare running time ***')

# Setup code executed by timeit before each timing run.
# NOTE(review): it imports ``metrics`` from ``torchreid`` rather than from
# ``secret``, so the timed statements exercise whatever torchreid provides
# (or fail if it is not installed). It also reads ``__file__`` inside the
# timeit namespace, which is not this script's path -- verify.
setup = '''
import sys
import os.path as osp
import numpy as np
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
from torchreid import metrics
num_q = 30
num_g = 300
max_rank = 5
distmat = np.random.rand(num_q, num_g) * 20
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)
'''

print('=> Using market1501\'s metric')
pytime = timeit.timeit(
    'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
    setup=setup,
    number=20
)
cytime = timeit.timeit(
    'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
    setup=setup,
    number=20
)
print('Python time: {} s'.format(pytime))
print('Cython time: {} s'.format(cytime))
print('Cython is {} times faster than python\n'.format(pytime / cytime))

print('=> Using cuhk03\'s metric')
# NOTE(review): these calls pass ``use_metric_cuhk03=True``; the
# ``evaluate_rank`` in this repository's secret/metrics/rank_c.py takes no
# such parameter, so they only fit the torchreid signature -- confirm.
pytime = timeit.timeit(
    'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
    setup=setup,
    number=20
)
cytime = timeit.timeit(
    'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
    setup=setup,
    number=20
)
print('Python time: {} s'.format(pytime))
print('Cython time: {} s'.format(cytime))
print('Cython is {} times faster than python\n'.format(pytime / cytime))
# The string below is a disabled precision check kept as a bare string
# literal; it is never executed (and refers to a name ``evaluate`` that this
# script does not define).
"""
print("=> Check precision")

num_q = 30
num_g = 300
max_rank = 5
distmat = np.random.rand(num_q, num_g) * 20
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)

cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
"""
pred.eq(target.view(1, -1).expand_as(pred)) 19 | 20 | ret = [] 21 | for k in topk: 22 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 23 | ret.append(correct_k.mul_(1. / batch_size)) 24 | return ret[0][0] 25 | 26 | def _unique_sample(ids_dict, num): 27 | mask = np.zeros(num, dtype=np.bool) 28 | for _, indices in ids_dict.items(): 29 | i = np.random.choice(indices) 30 | mask[i] = True 31 | return mask 32 | 33 | 34 | def cmc(distmat, query_ids=None, gallery_ids=None, 35 | query_cams=None, gallery_cams=None, topk=100, 36 | separate_camera_set=False, 37 | single_gallery_shot=False, 38 | first_match_break=False): 39 | distmat = to_numpy(distmat) 40 | m, n = distmat.shape 41 | # Fill up default values 42 | if query_ids is None: 43 | query_ids = np.arange(m) 44 | if gallery_ids is None: 45 | gallery_ids = np.arange(n) 46 | if query_cams is None: 47 | query_cams = np.zeros(m).astype(np.int32) 48 | if gallery_cams is None: 49 | gallery_cams = np.ones(n).astype(np.int32) 50 | # Ensure numpy array 51 | query_ids = np.asarray(query_ids) 52 | gallery_ids = np.asarray(gallery_ids) 53 | query_cams = np.asarray(query_cams) 54 | gallery_cams = np.asarray(gallery_cams) 55 | # Sort and find correct matches 56 | indices = np.argsort(distmat, axis=1) 57 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 58 | # Compute CMC for each query 59 | ret = np.zeros(topk) 60 | num_valid_queries = 0 61 | for i in range(m): 62 | # Filter out the same id and same camera 63 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 64 | (gallery_cams[indices[i]] != query_cams[i])) 65 | if separate_camera_set: 66 | # Filter out samples from same camera 67 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 68 | if not np.any(matches[i, valid]): continue 69 | if single_gallery_shot: 70 | repeat = 10 71 | gids = gallery_ids[indices[i][valid]] 72 | inds = np.where(valid)[0] 73 | ids_dict = defaultdict(list) 74 | for j, x in zip(inds, gids): 75 | ids_dict[x].append(j) 76 | 
def mean_ap(distmat, query_ids=None, gallery_ids=None,
            query_cams=None, gallery_cams=None):
    """Mean average precision over all queries that have a valid match.

    Args:
        distmat: query-by-gallery distance matrix (converted with ``to_numpy``).
        query_ids / gallery_ids: identities; default to ``arange`` over the
            respective axis.
        query_cams / gallery_cams: camera ids; default to all zeros for the
            queries and all ones for the gallery.

    Returns:
        The mean of the per-query average precision scores.

    Raises:
        RuntimeError: if no query has a matching gallery entry after filtering.
    """
    distmat = to_numpy(distmat)
    num_query, num_gallery = distmat.shape

    # Defaults, then coerce everything to numpy arrays.
    query_ids = np.asarray(np.arange(num_query) if query_ids is None else query_ids)
    gallery_ids = np.asarray(np.arange(num_gallery) if gallery_ids is None else gallery_ids)
    query_cams = np.asarray(
        np.zeros(num_query).astype(np.int32) if query_cams is None else query_cams)
    gallery_cams = np.asarray(
        np.ones(num_gallery).astype(np.int32) if gallery_cams is None else gallery_cams)

    # Gallery order per query (ascending distance) and where the ids agree.
    ranking = np.argsort(distmat, axis=1)
    matches = (gallery_ids[ranking] == query_ids[:, np.newaxis])

    scores = []
    for row in range(num_query):
        order = ranking[row]
        # Drop gallery entries showing the query identity from the query's camera.
        keep = ((gallery_ids[order] != query_ids[row]) |
                (gallery_cams[order] != query_cams[row]))
        relevant = matches[row, keep]
        if not np.any(relevant):
            continue
        # Smaller distance means higher score.
        ranked_scores = -distmat[row][order][keep]
        scores.append(average_precision_score(relevant, ranked_scores))

    if len(scores) == 0:
        raise RuntimeError("No valid query")
    return np.mean(scores)
def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
    """Re-rank a query-gallery distance matrix with k-reciprocal encoding.

    Args:
        q_g_dist: query-gallery distances, shape [num_query, num_gallery].
        q_q_dist: query-query distances, shape [num_query, num_query].
        g_g_dist: gallery-gallery distances, shape [num_gallery, num_gallery].
        k1: neighbourhood size for the k-reciprocal sets.
        k2: neighbourhood size for the local query expansion (1 disables it).
        lambda_value: weight of the original distance in the final mix.

    Returns:
        Re-ranked distances, numpy array of shape [num_query, num_gallery].
    """
    query_num = q_g_dist.shape[0]
    # Queries and gallery images are treated as one pool of samples.
    total_num = q_g_dist.shape[0] + q_g_dist.shape[1]

    upper = np.concatenate([q_q_dist, q_g_dist], axis=1)
    lower = np.concatenate([q_g_dist.T, g_g_dist], axis=1)
    dist = np.concatenate([upper, lower], axis=0)
    dist = np.power(dist, 2).astype(np.float32)
    # Scale each column by its maximum, then transpose.
    dist = np.transpose(1. * dist / np.max(dist, axis=0))
    V = np.zeros_like(dist).astype(np.float32)
    rank = np.argsort(dist).astype(np.int32)

    half_k1 = int(np.around(k1 / 2.)) + 1
    for i in range(total_num):
        # k-reciprocal neighbours of sample i
        forward = rank[i, :k1 + 1]
        backward = rank[forward, :k1 + 1]
        reciprocal = forward[np.where(backward == i)[0]]
        expanded = reciprocal
        for candidate in reciprocal:
            cand_forward = rank[candidate, :half_k1]
            cand_backward = rank[cand_forward, :half_k1]
            cand_reciprocal = cand_forward[np.where(cand_backward == candidate)[0]]
            shared = len(np.intersect1d(cand_reciprocal, reciprocal))
            # Absorb the candidate's set when more than 2/3 of it overlaps.
            if shared > 2. / 3 * len(cand_reciprocal):
                expanded = np.append(expanded, cand_reciprocal)

        expanded = np.unique(expanded)
        weight = np.exp(-dist[i, expanded])
        V[i, expanded] = 1. * weight / np.sum(weight)

    # Only the query rows of the original distance are needed from here on.
    dist = dist[:query_num, ]

    if k2 != 1:
        # Local query expansion: average V over each sample's k2 nearest samples.
        V_expanded = np.zeros_like(V, dtype=np.float32)
        for i in range(total_num):
            V_expanded[i, :] = np.mean(V[rank[i, :k2], :], axis=0)
        V = V_expanded
        del V_expanded
    del rank

    # For every column, the rows that hold a non-zero weight.
    inv_index = [np.where(V[:, col] != 0)[0] for col in range(total_num)]

    jaccard = np.zeros_like(dist, dtype=np.float32)
    for i in range(query_num):
        min_sum = np.zeros(shape=[1, total_num], dtype=np.float32)
        for col in np.where(V[i, :] != 0)[0]:
            rows = inv_index[col]
            min_sum[0, rows] = min_sum[0, rows] + np.minimum(V[i, col], V[rows, col])
        jaccard[i] = 1 - min_sum / (2. - min_sum)

    final_dist = jaccard * (1 - lambda_value) + dist * lambda_value
    # Keep only the query-vs-gallery part.
    return final_dist[:query_num, query_num:]
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The output has ``planes * expansion`` channels. ``downsample``, when
    given, is applied to the input before it is added to the main branch.
    All parameters are re-initialised by ``reset_params`` at construction.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None
    ):
        super(Bottleneck, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        mid_channels = int(planes * (base_width/64.)) * groups
        out_channels = planes * self.expansion
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, mid_channels)
        self.bn1 = norm(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, groups, dilation)
        self.bn2 = norm(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels)
        self.bn3 = norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        self.reset_params()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

    def reset_params(self):
        """Re-initialise every conv, batch-norm and linear sub-module."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                init.normal_(module.weight, std=0.001)
                if module.bias is not None:
                    init.constant_(module.bias, 0)
109 | self.depth = depth 110 | # Construct base (pretrained) resnet 111 | if depth not in ResNet.__factory: 112 | raise KeyError("Unsupported depth:", depth) 113 | resnet = ResNet.__factory[depth](pretrained=self.pretrained) 114 | resnet.layer4[0].conv2.stride = (1,1) 115 | resnet.layer4[0].downsample[0].stride = (1,1) 116 | self.base = nn.Sequential( 117 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 118 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4) 119 | self.gap = nn.AdaptiveAvgPool2d(1) 120 | 121 | self.num_features = resnet.fc.in_features 122 | self.num_classes = num_classes 123 | 124 | out_planes = resnet.fc.in_features 125 | 126 | # Append new layers 127 | self.num_features = out_planes 128 | self.part_detach = cfg.MODEL.PART_DETACH 129 | 130 | self.feat_bn = nn.BatchNorm1d(self.num_features) 131 | self.feat_bn.bias.requires_grad_(False) 132 | init.constant_(self.feat_bn.weight, 1) 133 | init.constant_(self.feat_bn.bias, 0) 134 | 135 | self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False) 136 | init.normal_(self.classifier.weight, std=0.001) 137 | 138 | norm_layer = nn.BatchNorm2d 139 | block = Bottleneck 140 | planes = 512 141 | self.planes = planes 142 | downsample = nn.Sequential( 143 | conv1x1(out_planes, planes * block.expansion), 144 | norm_layer(planes * block.expansion), 145 | ) 146 | self.part_bottleneck = block( 147 | out_planes, planes, downsample = downsample, norm_layer = norm_layer 148 | ) 149 | 150 | self.part_num_features = planes * block.expansion 151 | self.part_pool = nn.AdaptiveAvgPool2d((2,1)) 152 | 153 | self.partup_feat_bn = nn.BatchNorm1d(self.part_num_features) 154 | self.partup_feat_bn.bias.requires_grad_(False) 155 | init.constant_(self.partup_feat_bn.weight, 1) 156 | init.constant_(self.partup_feat_bn.bias, 0) 157 | 158 | self.partdown_feat_bn = nn.BatchNorm1d(self.part_num_features) 159 | self.partdown_feat_bn.bias.requires_grad_(False) 160 | 
init.constant_(self.partdown_feat_bn.weight, 1) 161 | init.constant_(self.partdown_feat_bn.bias, 0) 162 | 163 | self.classifier_partup = nn.Linear(self.part_num_features, self.num_classes, bias = False) 164 | init.normal_(self.classifier_partup.weight, std=0.001) 165 | self.classifier_partdown = nn.Linear(self.part_num_features, self.num_classes, bias = False) 166 | init.normal_(self.classifier_partdown.weight, std=0.001) 167 | 168 | if not self.pretrained: 169 | self.reset_params() 170 | 171 | def forward(self, x, finetune = False): 172 | featuremap = self.base(x) 173 | 174 | x = self.gap(featuremap) 175 | x = x.view(x.size(0), -1) 176 | 177 | bn_x = self.feat_bn(x) 178 | 179 | if self.part_detach: 180 | part_x = self.part_bottleneck(featuremap.detach()) 181 | else: 182 | part_x = self.part_bottleneck(featuremap) 183 | 184 | part_x = self.part_pool(part_x) 185 | part_up = part_x[:, :, 0, :] 186 | part_up = part_up.view(part_up.size(0), -1) 187 | bn_part_up = self.partup_feat_bn(part_up) 188 | 189 | part_down = part_x[:, :, 1, :] 190 | part_down = part_down.view(part_down.size(0), -1) 191 | bn_part_down = self.partdown_feat_bn(part_down) 192 | 193 | if self.training is False and finetune is False: 194 | bn_x = F.normalize(bn_x) 195 | return [bn_x] 196 | 197 | prob = self.classifier(bn_x) 198 | prob_part_up = self.classifier_partup(bn_part_up) 199 | prob_part_down = self.classifier_partdown(bn_part_down) 200 | 201 | if finetune is True: 202 | bn_x = F.normalize(bn_x) 203 | bn_part_up = F.normalize(bn_part_up) 204 | bn_part_down = F.normalize(bn_part_down) 205 | return [x, part_up, part_down], [bn_x, bn_part_up, bn_part_down], [prob, prob_part_up, prob_part_down] 206 | else: 207 | return [x, part_up, part_down], [prob, prob_part_up, prob_part_down] 208 | 209 | def reset_params(self): 210 | for m in self.modules(): 211 | if isinstance(m, nn.Conv2d): 212 | init.kaiming_normal_(m.weight, mode='fan_out') 213 | if m.bias is not None: 214 | init.constant_(m.bias, 0) 215 | 
elif isinstance(m, nn.BatchNorm2d): 216 | init.constant_(m.weight, 1) 217 | init.constant_(m.bias, 0) 218 | elif isinstance(m, nn.BatchNorm1d): 219 | init.constant_(m.weight, 1) 220 | init.constant_(m.bias, 0) 221 | elif isinstance(m, nn.Linear): 222 | init.normal_(m.weight, std=0.001) 223 | if m.bias is not None: 224 | init.constant_(m.bias, 0) 225 | 226 | resnet = PartResNet.__factory[self.depth](pretrained=self.pretrained) 227 | self.base[0].load_state_dict(resnet.conv1.state_dict()) 228 | self.base[1].load_state_dict(resnet.bn1.state_dict()) 229 | self.base[2].load_state_dict(resnet.maxpool.state_dict()) 230 | self.base[3].load_state_dict(resnet.layer1.state_dict()) 231 | self.base[4].load_state_dict(resnet.layer2.state_dict()) 232 | self.base[5].load_state_dict(resnet.layer3.state_dict()) 233 | self.base[6].load_state_dict(resnet.layer4.state_dict()) 234 | 235 | def resnet50(cfg, num_classes, **kwargs): 236 | return ResNet(50, cfg, num_classes, **kwargs) 237 | -------------------------------------------------------------------------------- /secret/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import lr_scheduler 4 | from . 
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step learning-rate decay with an initial warmup phase.

    For ``last_epoch < warmup_iters`` the base learning rate is scaled by a
    warmup factor (constant, or linearly interpolated from ``warmup_factor``
    up to 1). Independently of warmup, the rate is multiplied by ``gamma``
    once for every milestone that ``last_epoch`` has reached.

    Args:
        optimizer: wrapped optimizer.
        milestones: increasing epoch indices at which to decay.
        gamma (float): multiplicative decay applied per reached milestone.
        warmup_factor (float): scale used at the start of warmup.
        warmup_iters (int): number of warmup epochs/iterations.
        warmup_method (str): ``"constant"`` or ``"linear"``.
        last_epoch (int): index of the last epoch, ``-1`` to start fresh.

    Raises:
        ValueError: if ``milestones`` is not sorted in increasing order or
            ``warmup_method`` is not one of the accepted values.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # Fixed: the message used to be passed unformatted together with
            # ``milestones`` as a second exception argument.
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            # Fixed: missing separator between the two string fragments.
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        # These must be set before the base constructor runs, because it
        # performs an initial step that calls ``get_lr``.
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                # Interpolate from ``warmup_factor`` (alpha=0) to 1 (alpha=1).
                alpha = float(self.last_epoch) / float(self.warmup_iters)
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        # Number of milestones already reached determines the decay power.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]
Default is single_step. 56 | stepsize (int or list, optional): step size to decay learning rate. When ``lr_scheduler`` 57 | is "single_step", ``stepsize`` should be an integer. When ``lr_scheduler`` is 58 | "multi_step", ``stepsize`` is a list. Default is 1. 59 | gamma (float, optional): decay rate. Default is 0.1. 60 | max_epoch (int, optional): maximum epoch (for cosine annealing). Default is 1. 61 | Examples:: 62 | >>> # Decay learning rate by every 20 epochs. 63 | >>> scheduler = torchreid.optim.build_lr_scheduler( 64 | >>> optimizer, lr_scheduler='single_step', stepsize=20 65 | >>> ) 66 | >>> # Decay learning rate at 30, 50 and 55 epochs. 67 | >>> scheduler = torchreid.optim.build_lr_scheduler( 68 | >>> optimizer, lr_scheduler='multi_step', stepsize=[30, 50, 55] 69 | >>> ) 70 | """ 71 | stepsize = cfg.OPTIM.STEPS 72 | 73 | if cfg.OPTIM.SCHED not in __factory: 74 | raise ValueError( 75 | 'Unsupported scheduler: {}. Must be one of {}'.format( 76 | cfg.OPTIM.SCHED, __factory 77 | ) 78 | ) 79 | 80 | if cfg.OPTIM.SCHED == 'single_step': 81 | if isinstance(stepsize, list): 82 | stepsize = stepsize[-1] 83 | 84 | if not isinstance(stepsize, int): 85 | raise TypeError( 86 | 'For single_step lr_scheduler, stepsize must ' 87 | 'be an integer, but got {}'.format(type(stepsize)) 88 | ) 89 | 90 | scheduler = torch.optim.lr_scheduler.StepLR( 91 | optimizer, step_size=stepsize, gamma=cfg.OPTIM.GAMMA, last_epoch = last_epoch 92 | ) 93 | 94 | elif cfg.OPTIM.SCHED == 'multi_step': 95 | if not isinstance(stepsize, list): 96 | raise TypeError( 97 | 'For multi_step lr_scheduler, stepsize must ' 98 | 'be a list, but got {}'.format(type(stepsize)) 99 | ) 100 | 101 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 102 | optimizer, milestones=stepsize, gamma=cfg.OPTIM.GAMMA, last_epoch = last_epoch 103 | ) 104 | 105 | elif cfg.OPTIM.SCHED == 'cosine': 106 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 107 | optimizer, float(cfg.OPTIM.COSINE_MAX_EPOCH), last_epoch = 
def build_optimizer(cfg, model, LR = None):
    """Build a torch optimizer over all parameters of ``model`` from ``cfg``.

    Args:
        cfg: config node. ``cfg.OPTIM.OPT`` selects the optimizer and must be
            one of the names in the module-level ``__factory`` list. The
            other fields read are ``cfg.OPTIM.LR`` (only when ``LR`` is
            None), ``WEIGHT_DECAY``, ``ADAM_BETA1``/``ADAM_BETA2`` (adam,
            amsgrad), ``MOMENTUM`` (sgd, rmsprop), ``SGD_DAMPENING`` and
            ``SGD_NESTEROV`` (sgd), and ``RMSPROP_ALPHA`` (rmsprop).
        model (nn.Module): model whose parameters are optimized.
        LR (float, optional): learning rate overriding ``cfg.OPTIM.LR``.

    Returns:
        torch.optim.Optimizer: a single-param-group optimizer; the group also
        carries ``initial_lr`` set to the chosen learning rate.

    Raises:
        ValueError: if ``cfg.OPTIM.OPT`` is not supported.
        TypeError: if ``model`` is not an ``nn.Module``.
    """
    opt_name = cfg.OPTIM.OPT
    if opt_name not in __factory:
        raise ValueError(
            'Unsupported optim: {}. Must be one of {}'.format(
                opt_name, __factory
            )
        )

    if not isinstance(model, nn.Module):
        raise TypeError(
            'model given to build_optimizer must be an instance of nn.Module'
        )

    if LR is None:
        LR = cfg.OPTIM.LR

    # One group holding every parameter; 'initial_lr' is stored alongside.
    param_groups = [{'params': model.parameters(), 'initial_lr': LR}]

    if opt_name in ('adam', 'amsgrad'):
        # 'amsgrad' is Adam with the AMSGrad variant switched on.
        optimizer = torch.optim.Adam(
            param_groups,
            lr=LR,
            weight_decay=cfg.OPTIM.WEIGHT_DECAY,
            betas=(cfg.OPTIM.ADAM_BETA1, cfg.OPTIM.ADAM_BETA2),
            amsgrad=(opt_name == 'amsgrad'),
        )
    elif opt_name == 'sgd':
        optimizer = torch.optim.SGD(
            param_groups,
            lr=LR,
            momentum=cfg.OPTIM.MOMENTUM,
            weight_decay=cfg.OPTIM.WEIGHT_DECAY,
            dampening=cfg.OPTIM.SGD_DAMPENING,
            nesterov=cfg.OPTIM.SGD_NESTEROV,
        )
    elif opt_name == 'rmsprop':
        optimizer = torch.optim.RMSprop(
            param_groups,
            lr=LR,
            momentum=cfg.OPTIM.MOMENTUM,
            weight_decay=cfg.OPTIM.WEIGHT_DECAY,
            alpha=cfg.OPTIM.RMSPROP_ALPHA,
        )

    return optimizer
def to_numpy(tensor):
    """Convert a torch tensor to a numpy array; pass numpy objects through.

    Args:
        tensor: a ``torch.Tensor`` or an object whose type lives in the
            ``numpy`` module.

    Returns:
        The numpy view of the tensor's CPU data, or ``tensor`` itself when it
        already is a numpy object.

    Raises:
        ValueError: if ``tensor`` is neither a torch tensor nor a numpy object.
    """
    if torch.is_tensor(tensor):
        # detach() first: calling .numpy() on a tensor that requires grad
        # raises a RuntimeError.
        return tensor.detach().cpu().numpy()
    elif type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor


def to_torch(ndarray):
    """Convert a numpy array to a torch tensor; pass tensors through.

    Args:
        ndarray: an object whose type lives in the ``numpy`` module, or a
            ``torch.Tensor``.

    Returns:
        ``torch.from_numpy(ndarray)`` for numpy input, otherwise the tensor
        itself.

    Raises:
        ValueError: if ``ndarray`` is neither a numpy object nor a tensor.
    """
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:
    1. Select the visible GPUs and set up the loggers
    2. Log basic information about cmdline arguments and config
    3. Backup the config to the output directory
    4. Seed the numpy / torch / python RNGs and configure cudnn
    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(s) for s in cfg.GPU_Device)

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        PathManager.mkdirs(output_dir)

    rank = 0
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        # Read through a context manager so the handle is closed right away
        # (the original left the opened file dangling).
        with PathManager.open(args.config_file, "r") as f:
            config_text = f.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file, config_text
            )
        )

    logger.info("Running with full config:\n{}".format(cfg))
    if output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # Seed every RNG in use with the configured seed.
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    random.seed(cfg.SEED)
    cudnn.deterministic = True

    # cudnn benchmark has large overhead; whether it is enabled is left to
    # the config rather than decided here.
    torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
class _ColorfulFormatter(logging.Formatter):
    """Formatter that abbreviates the root logger name and tags warnings/errors.

    Keyword Args:
        root_name (str): required; occurrences of ``root_name + "."`` in the
            record name are replaced.
        abbrev_name (str): replacement prefix; a trailing "." is appended
            when it is non-empty. Defaults to "".

    All remaining arguments are forwarded to ``logging.Formatter``.
    """

    def __init__(self, *args, **kwargs):
        root = kwargs.pop("root_name")
        abbrev = kwargs.pop("abbrev_name", "")
        self._root_name = root + "."
        self._abbrev_name = abbrev + "." if len(abbrev) else abbrev
        super().__init__(*args, **kwargs)

    def formatMessage(self, record):
        """Format ``record``, prefixing a colored tag for WARNING and above."""
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        text = super().formatMessage(record)

        level = record.levelno
        if level == logging.WARNING:
            tag, attrs = "WARNING", ["blink"]
        elif level in (logging.ERROR, logging.CRITICAL):
            tag, attrs = "ERROR", ["blink", "underline"]
        else:
            # Other levels are emitted without a prefix.
            return text
        return colored(tag, "red", attrs=attrs) + " " + text
# Cache the opened file object so that repeated `setup_logger` calls with the
# same file name share a single stream instead of reopening the file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open ``filename`` for writing via PathManager and memoize the stream."""
    stream = PathManager.open(filename, "w")
    return stream
def _find_caller():
    """
    Walk up the stack, starting two frames above this function, and return
    information about the first frame whose file is not the logger module.

    Returns:
        str: module name of the caller ("__main__" is reported as "detectron2")
        tuple: a hashable key (filename, line number, function name) to be
            used to identify different callers
    """
    logger_marker = os.path.join("utils", "logger.")
    current = sys._getframe(2)
    while current:
        code_obj = current.f_code
        if logger_marker in code_obj.co_filename:
            # Still inside the logging helpers; keep climbing.
            current = current.f_back
            continue
        module = current.f_globals["__name__"]
        if module == "__main__":
            module = "detectron2"
        return module, (code_obj.co_filename, current.f_lineno, code_obj.co_name)
def log_every_n(lvl, msg, n=1, *, name=None):
    """
    Log once per n calls from the same call site.
    Args:
        lvl (int): the logging level
        msg (str): the message to log
        n (int): period, counted per caller
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    module, site = _find_caller()
    _LOG_COUNTER[site] += 1
    hits = _LOG_COUNTER[site]
    if n != 1 and hits % n != 1:
        return
    logging.getLogger(name or module).log(lvl, msg)


def log_every_n_seconds(lvl, msg, n=1, *, name=None):
    """
    Log no more than once per n seconds from the same call site.
    Args:
        lvl (int): the logging level
        msg (str): the message to log
        n (int): minimum number of seconds between two logged messages
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    module, site = _find_caller()
    previous = _LOG_TIMER.get(site, None)
    now = time.time()
    if previous is not None and now - previous < n:
        # Too soon since the last message from this caller.
        return
    logging.getLogger(name or module).log(lvl, msg)
    _LOG_TIMER[site] = now
class AverageMeter(object):
    """Tracks the latest value, running sum, sample count and mean."""

    def __init__(self):
        # Start from the same all-zero state that reset() establishes.
        self.reset()

    def reset(self):
        """Zero out the latest value, average, sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class PathHandler:
    """
    Base class describing the I/O operations available for one URI protocol.
    A handler serves either URIs of the form "protocol://*" or canonical file
    paths such as "/foo/bar/baz". Every operation except `_check_kwargs`
    raises NotImplementedError here and is meant to be overridden.
    """

    _strict_kwargs_check = True

    def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
        """
        Verify that no unexpected keyword arguments were supplied.
        With strict checking enabled, a non-empty `kwargs` raises ValueError;
        with it disabled, each entry is only logged as a warning.
        Args:
            kwargs (Dict[str, Any])
        """
        if not self._strict_kwargs_check:
            log = logging.getLogger(__name__)
            for key, value in kwargs.items():
                log.warning(
                    "[PathManager] {}={} argument ignored".format(key, value)
                )
            return
        if len(kwargs) > 0:
            raise ValueError("Unused arguments: {}".format(kwargs))

    def _get_supported_prefixes(self) -> List[str]:
        """
        Returns:
            List[str]: the URI prefixes handled by this PathHandler
        """
        raise NotImplementedError()

    def _get_local_path(self, path: str, **kwargs: Any) -> str:
        """
        Return a path usable with native Python I/O (`open`, `os.path`).
        For a remote URI an implementation may download and cache the
        resource locally, so the result should be treated as read-only.
        Args:
            path (str): A URI supported by this PathHandler
        Returns:
            local_path (str): a file path that exists on the local file system
        """
        raise NotImplementedError()

    def _open(
        self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
    ) -> Union[IO[str], IO[bytes]]:
        """
        Open a stream to a URI, mirroring the built-in `open`.
        Args:
            path (str): A URI supported by this PathHandler
            mode (str): Mode in which the file is opened; defaults to 'r'.
            buffering (int): Buffering policy. 0 disables buffering and an
                integer >= 1 gives the size in bytes of a fixed-size chunk
                buffer. If omitted, the policy is chosen by the underlying
                I/O implementation.
        Returns:
            file: a file-like object.
        """
        raise NotImplementedError()

    def _copy(
        self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
    ) -> bool:
        """
        Copy a source path to a destination path.
        Args:
            src_path (str): A URI supported by this PathHandler
            dst_path (str): A URI supported by this PathHandler
            overwrite (bool): force overwriting an existing file
        Returns:
            status (bool): True on success
        """
        raise NotImplementedError()

    def _exists(self, path: str, **kwargs: Any) -> bool:
        """
        Tell whether a resource exists at the given URI.
        Args:
            path (str): A URI supported by this PathHandler
        Returns:
            bool: true if the path exists
        """
        raise NotImplementedError()

    def _isfile(self, path: str, **kwargs: Any) -> bool:
        """
        Tell whether the resource at the given URI is a file.
        Args:
            path (str): A URI supported by this PathHandler
        Returns:
            bool: true if the path is a file
        """
        raise NotImplementedError()

    def _isdir(self, path: str, **kwargs: Any) -> bool:
        """
        Tell whether the resource at the given URI is a directory.
        Args:
            path (str): A URI supported by this PathHandler
        Returns:
            bool: true if the path is a directory
        """
        raise NotImplementedError()

    def _ls(self, path: str, **kwargs: Any) -> List[str]:
        """
        List the entries of the directory at the given URI.
        Args:
            path (str): A URI supported by this PathHandler
        Returns:
            List[str]: list of contents in given path
        """
        raise NotImplementedError()

    def _mkdirs(self, path: str, **kwargs: Any) -> None:
        """
        Create a directory together with any missing intermediate
        directories, in the manner of the native `os.makedirs`.
        Args:
            path (str): A URI supported by this PathHandler
        """
        raise NotImplementedError()

    def _rm(self, path: str, **kwargs: Any) -> None:
        """
        Remove the file (not directory) at the given URI.
        Args:
            path (str): A URI supported by this PathHandler
        """
        raise NotImplementedError()
class NativePathHandler(PathHandler):
    """
    PathHandler for paths reachable through Python's own system calls: every
    operation is delegated to the built-in `open()`, `os`/`os.path` or
    `shutil`. Each method first passes its extra keyword arguments to
    `self._check_kwargs`.
    """

    def _get_local_path(self, path: str, **kwargs: Any) -> str:
        """
        Return `path` unchanged; no translation is performed for native paths.
        """
        self._check_kwargs(kwargs)
        return path

    def _open(
        self,
        path: str,
        mode: str = "r",
        buffering: int = -1,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
        closefd: bool = True,
        opener: Optional[Callable] = None,
        **kwargs: Any,
    ) -> Union[IO[str], IO[bytes]]:
        """
        Open a path by forwarding every argument to the built-in `open`.
        See https://docs.python.org/3/library/functions.html#open for details.
        Args:
            path (str): A URI supported by this PathHandler
            mode (str): Mode in which the file is opened. Defaults to 'r'.
            buffering (int): Buffering policy. 0 switches buffering off, an
                integer >= 1 gives the size in bytes of a fixed-size chunk
                buffer, and the default -1 leaves the choice to `open`.
            encoding (Optional[str]): Name of the encoding used to decode or
                encode the file. Text mode only.
            errors (Optional[str]): How encoding and decoding errors are
                handled. Cannot be used in binary mode.
            newline (Optional[str]): Controls universal newlines mode (text
                mode only). One of None, '', '\n', '\r' and '\r\n'.
            closefd (bool): Must stay True (the default) when a filename is
                given. With False and a file descriptor, the descriptor is
                kept open when the file object is closed.
            opener (Optional[Callable]): Custom opener called with
                (file, flags); it must return an open file descriptor.
        Returns:
            file: a file-like object.
        """
        self._check_kwargs(kwargs)
        open_options = {
            "buffering": buffering,
            "encoding": encoding,
            "errors": errors,
            "newline": newline,
            "closefd": closefd,
            "opener": opener,
        }
        return open(path, mode, **open_options)  # type: ignore

    def _copy(
        self,
        src_path: str,
        dst_path: str,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> bool:
        """
        Copy a file with `shutil.copyfile`, reporting failure instead of raising.
        Args:
            src_path (str): A URI supported by this PathHandler
            dst_path (str): A URI supported by this PathHandler
            overwrite (bool): Allow replacing an existing destination
        Returns:
            status (bool): True on success; False when the destination exists
                and `overwrite` is false, or when the copy raised (the error
                is logged in both cases).
        """
        self._check_kwargs(kwargs)

        destination_taken = os.path.exists(dst_path)
        if destination_taken and not overwrite:
            logging.getLogger(__name__).error(
                "Destination file {} already exists.".format(dst_path)
            )
            return False

        try:
            shutil.copyfile(src_path, dst_path)
        except Exception as e:
            # Best effort by design: any copy failure is logged and reported
            # through the return value.
            logging.getLogger(__name__).error(
                "Error in file copy - {}".format(str(e))
            )
            return False
        return True

    def _exists(self, path: str, **kwargs: Any) -> bool:
        """Return `os.path.exists(path)`."""
        self._check_kwargs(kwargs)
        return os.path.exists(path)

    def _isfile(self, path: str, **kwargs: Any) -> bool:
        """Return `os.path.isfile(path)`."""
        self._check_kwargs(kwargs)
        return os.path.isfile(path)

    def _isdir(self, path: str, **kwargs: Any) -> bool:
        """Return `os.path.isdir(path)`."""
        self._check_kwargs(kwargs)
        return os.path.isdir(path)

    def _ls(self, path: str, **kwargs: Any) -> List[str]:
        """Return `os.listdir(path)`."""
        self._check_kwargs(kwargs)
        return os.listdir(path)

    def _mkdirs(self, path: str, **kwargs: Any) -> None:
        """
        Create `path` and any missing parents; an existing directory is fine.
        Raises:
            OSError: for any failure other than EEXIST.
        """
        self._check_kwargs(kwargs)
        try:
            os.makedirs(path, exist_ok=True)
        except OSError as err:
            # EEXIST can still surface when several processes create the
            # directory at once; only that case is tolerated.
            if err.errno == errno.EEXIST:
                return
            raise

    def _rm(self, path: str, **kwargs: Any) -> None:
        """Remove the file at `path` with `os.remove`."""
        self._check_kwargs(kwargs)
        os.remove(path)
class PathManager:
    """
    Static front-end that routes generic path operations to the PathHandler
    registered for the path's prefix, or to the native handler when no
    registered prefix matches.
    """

    # Registered handlers keyed by URI prefix; register_handler keeps the
    # mapping sorted in descending prefix order.
    _PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict()
    # Fallback used whenever no registered prefix matches.
    _NATIVE_PATH_HANDLER = NativePathHandler()

    @staticmethod
    def __get_path_handler(path: str) -> PathHandler:
        """
        Pick the handler responsible for `path`.
        Args:
            path (str): URI path to resource
        Returns:
            handler (PathHandler): the first registered handler whose prefix
                starts `path`, otherwise the native handler.
        """
        for prefix, handler in PathManager._PATH_HANDLERS.items():
            if path.startswith(prefix):
                return handler
        return PathManager._NATIVE_PATH_HANDLER

    @staticmethod
    def open(
        path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
    ) -> Union[IO[str], IO[bytes]]:
        """
        Open a stream to a URI, similar to the built-in `open`.
        Args:
            path (str): A URI supported by the selected PathHandler
            mode (str): Mode in which the file is opened. Defaults to 'r'.
            buffering (int): Buffering policy. 0 switches buffering off, an
                integer >= 1 gives the size in bytes of a fixed-size chunk
                buffer, and by default the policy depends on the underlying
                I/O implementation.
        Returns:
            file: a file-like object.
        """
        handler = PathManager.__get_path_handler(path)
        return handler._open(  # type: ignore
            path, mode, buffering=buffering, **kwargs
        )

    @staticmethod
    def copy(
        src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
    ) -> bool:
        """
        Copy a source path to a destination path served by the same handler.
        Args:
            src_path (str): A URI supported by the selected PathHandler
            dst_path (str): A URI supported by the selected PathHandler
            overwrite (bool): Bool flag for forcing overwrite of existing file
        Returns:
            status (bool): True on success
        """
        src_handler = PathManager.__get_path_handler(src_path)
        dst_handler = PathManager.__get_path_handler(dst_path)
        # Copying across handlers is not supported.
        assert src_handler == dst_handler
        return src_handler._copy(src_path, dst_path, overwrite, **kwargs)

    @staticmethod
    def get_local_path(path: str, **kwargs: Any) -> str:
        """
        Get a file path that works with native Python I/O such as `open` and
        `os.path`. For a remote URI the handler may download and cache the
        resource on local disk.
        Args:
            path (str): A URI supported by the selected PathHandler
        Returns:
            local_path (str): a file path which exists on the local file system
        """
        handler = PathManager.__get_path_handler(path)
        return handler._get_local_path(path, **kwargs)  # type: ignore

    @staticmethod
    def exists(path: str, **kwargs: Any) -> bool:
        """
        Check whether there is a resource at the given URI.
        Args:
            path (str): A URI supported by the selected PathHandler
        Returns:
            bool: true if the path exists
        """
        handler = PathManager.__get_path_handler(path)
        return handler._exists(path, **kwargs)  # type: ignore

    @staticmethod
    def isfile(path: str, **kwargs: Any) -> bool:
        """
        Check whether the resource at the given URI is a file.
        Args:
            path (str): A URI supported by the selected PathHandler
        Returns:
            bool: true if the path is a file
        """
        handler = PathManager.__get_path_handler(path)
        return handler._isfile(path, **kwargs)  # type: ignore

    @staticmethod
    def isdir(path: str, **kwargs: Any) -> bool:
        """
        Check whether the resource at the given URI is a directory.
        Args:
            path (str): A URI supported by the selected PathHandler
        Returns:
            bool: true if the path is a directory
        """
        handler = PathManager.__get_path_handler(path)
        return handler._isdir(path, **kwargs)  # type: ignore

    @staticmethod
    def ls(path: str, **kwargs: Any) -> List[str]:
        """
        List the contents of the directory at the provided URI.
        Args:
            path (str): A URI supported by the selected PathHandler
        Returns:
            List[str]: list of contents in given path
        """
        handler = PathManager.__get_path_handler(path)
        return handler._ls(path, **kwargs)  # type: ignore

    @staticmethod
    def mkdirs(path: str, **kwargs: Any) -> None:
        """
        Recursive directory creation, similar to the native `os.makedirs`:
        all intermediate directories needed to contain the leaf are created.
        Args:
            path (str): A URI supported by the selected PathHandler
        """
        handler = PathManager.__get_path_handler(path)
        return handler._mkdirs(path, **kwargs)  # type: ignore

    @staticmethod
    def rm(path: str, **kwargs: Any) -> None:
        """
        Remove the file (not directory) at the provided URI.
        Args:
            path (str): A URI supported by the selected PathHandler
        """
        handler = PathManager.__get_path_handler(path)
        return handler._rm(path, **kwargs)  # type: ignore

    @staticmethod
    def register_handler(handler: PathHandler) -> None:
        """
        Register `handler` under every prefix returned by
        `handler._get_supported_prefixes()`. A prefix may be registered once.
        Args:
            handler (PathHandler)
        """
        assert isinstance(handler, PathHandler), handler
        registry = PathManager._PATH_HANDLERS
        for prefix in handler._get_supported_prefixes():
            assert prefix not in registry
            registry[prefix] = handler

        # Descending prefix order makes a longer prefix win over a shorter one
        # it extends, eg: http://foo/bar is tried before http://foo
        by_prefix_desc = sorted(
            registry.items(),
            key=lambda entry: entry[0],
            reverse=True,
        )
        PathManager._PATH_HANDLERS = OrderedDict(by_prefix_desc)

    @staticmethod
    def set_strict_kwargs_checking(enable: bool) -> None:
        """
        Set `_strict_kwargs_check` on the native handler and on every
        registered handler.

        The check itself lives in the handlers (`_check_kwargs`, not shown in
        this class). Per the original documentation, strict mode raises a
        ValueError when unused keyword arguments reach a handler function,
        while non-strict mode only warns.

        With a centralized file API there is a tradeoff between convenience
        and correctness when delegating arguments to the proper I/O layer. A
        handler may support custom arguments that should not be exposed
        statically on `PathManager`. For example, a custom `HTTPURLHandler`
        may accept a `cache_timeout` argument for `open()` saying how old a
        locally cached resource may be before it is refetched; that argument
        makes no sense for a `NativePathHandler`. With strict checking
        disabled, `cache_timeout` can be passed to `PathManager.open`, which
        forwards it to the underlying handler.

        Disabling the check is the unsafe direction, because several handlers
        could reuse the same argument name with different meanings or types;
        the original documentation states that checking is enabled by default.
        Args:
            enable (bool)
        """
        every_handler = [PathManager._NATIVE_PATH_HANDLER]
        every_handler.extend(PathManager._PATH_HANDLERS.values())
        for registered in every_handler:
            registered._strict_kwargs_check = enable
def _without_keys_containing(mapping, token):
    """Return a new dict holding the entries of ``mapping`` whose key lacks ``token``."""
    return {k: v for k, v in mapping.items() if token not in k}


def _write_checkpoint(state, fpath, best_copies):
    """Serialize ``state`` to ``fpath``, then copy the file to each flagged name.

    ``best_copies`` is an iterable of ``(flag, filename)`` pairs; for every
    truthy flag the checkpoint is copied next to ``fpath`` under ``filename``.
    """
    torch.save(state, fpath, _use_new_zipfile_serialization=False)
    out_dir = osp.dirname(fpath)
    for is_best, filename in best_copies:
        if is_best:
            shutil.copy(fpath, osp.join(out_dir, filename))


def save_checkpoint(state, is_top1_best, is_mAP_best, fpath='checkpoint.pth.tar', remain=False):
    """Save a training checkpoint and optionally tag it as the best model.

    Unless ``remain`` is set, entries whose key contains ``'classifier'`` are
    dropped from ``state['state_dict']`` and, when present, from
    ``state['student_dict']``. Each dict is filtered from its own entries;
    the previous code rebuilt ``student_dict`` from ``state_dict`` and so
    overwrote the student weights. ``state`` is modified in place.

    Args:
        state (dict): checkpoint payload; must contain ``'state_dict'``.
        is_top1_best: when truthy, also copy the file to
            ``model_top1_best.pth.tar`` in the directory of ``fpath``.
        is_mAP_best: when truthy, also copy the file to
            ``model_mAP_best.pth.tar`` in the directory of ``fpath``.
        fpath (str): destination file.
        remain: filtering is skipped unless this is exactly ``False``.
    """
    if remain is False:
        state['state_dict'] = _without_keys_containing(state['state_dict'], 'classifier')
        if 'student_dict' in state:
            state['student_dict'] = _without_keys_containing(state['student_dict'], 'classifier')
    _write_checkpoint(
        state,
        fpath,
        ((is_top1_best, 'model_top1_best.pth.tar'), (is_mAP_best, 'model_mAP_best.pth.tar')),
    )


def save_checkpoint_idattr(state, is_top1_best, is_mAP_best, fpath='checkpoint.pth.tar', remain=False):
    """Variant of ``save_checkpoint`` with a narrower filter on ``state_dict``.

    Unless ``remain`` is set, ``state['state_dict']`` loses the entries whose
    key contains ``'.classifier'`` (with the leading dot), while
    ``state['student_dict']``, when present, loses the entries whose key
    contains ``'classifier'``. Both predicates are kept from the original.
    As in ``save_checkpoint``, the student dict is now filtered from its own
    entries instead of being replaced by the teacher's. ``state`` is modified
    in place.

    Args:
        state (dict): checkpoint payload; must contain ``'state_dict'``.
        is_top1_best: when truthy, also copy the file to
            ``model_top1_best.pth.tar`` in the directory of ``fpath``.
        is_mAP_best: when truthy, also copy the file to
            ``model_mAP_best.pth.tar`` in the directory of ``fpath``.
        fpath (str): destination file.
        remain: filtering is skipped unless this is exactly ``False``.
    """
    if remain is False:
        state['state_dict'] = _without_keys_containing(state['state_dict'], '.classifier')
        if 'student_dict' in state:
            state['student_dict'] = _without_keys_containing(state['student_dict'], 'classifier')
    _write_checkpoint(
        state,
        fpath,
        ((is_top1_best, 'model_top1_best.pth.tar'), (is_mAP_best, 'model_mAP_best.pth.tar')),
    )
def save_checkpoint_Attr(state, is_mA_best, is_F1_best, fpath='checkpoint.pth.tar', remain=False):
    """Save a checkpoint and optionally tag it as best by mA and/or F1.

    Unless ``remain`` is set, entries whose key contains ``'classifier'`` are
    dropped from ``state['state_dict']``. ``state`` is modified in place.

    Args:
        state (dict): checkpoint payload; must contain ``'state_dict'``.
        is_mA_best: when truthy, also copy the file to
            ``model_mA_best.pth.tar`` in the directory of ``fpath``.
        is_F1_best: when truthy, also copy the file to
            ``model_F1_best.pth.tar`` in the directory of ``fpath``.
        fpath (str): destination file.
        remain: filtering is skipped unless this is exactly ``False``.
    """
    if remain is False:
        state['state_dict'] = {
            k: v for k, v in state['state_dict'].items() if 'classifier' not in k
        }
    torch.save(state, fpath, _use_new_zipfile_serialization=False)
    out_dir = osp.dirname(fpath)
    if is_mA_best:
        shutil.copy(fpath, osp.join(out_dir, 'model_mA_best.pth.tar'))
    if is_F1_best:
        shutil.copy(fpath, osp.join(out_dir, 'model_F1_best.pth.tar'))


def load_checkpoint(fpath):
    """Load a checkpoint file onto the CPU.

    Args:
        fpath (str): path of the checkpoint file.
    Returns:
        The object stored in the file, with tensors mapped to the CPU.
    Raises:
        ValueError: if ``fpath`` is not an existing file.
    """
    if not osp.isfile(fpath):
        raise ValueError("=> No checkpoint found at '{}'".format(fpath))
    # NOTE: torch.load unpickles the file; only load checkpoints from sources
    # you trust.
    checkpoint = torch.load(fpath, map_location=torch.device('cpu'))
    logging.getLogger('UnReID').info("=> Loaded checkpoint '{}'".format(fpath))
    return checkpoint


def copy_state_dict(state_dict, model, strip=None):
    """Copy matching entries of ``state_dict`` into ``model`` in place.

    An entry is copied only when its (optionally stripped) name exists in
    ``model.state_dict()`` and its size equals the target's size. Unknown
    names and size mismatches are skipped silently, so a partially
    compatible checkpoint loads without error.

    Args:
        state_dict (dict): mapping from parameter name to tensor or Parameter.
        model: module whose ``state_dict()`` tensors receive the values.
        strip (str, optional): prefix removed from the names that start
            with it before the lookup (for instance a wrapper prefix).
    Returns:
        The same ``model`` object.
    """
    tgt_state = model.state_dict()
    for name, param in state_dict.items():
        if strip is not None and name.startswith(strip):
            name = name[len(strip):]
        if name not in tgt_state:
            continue
        if isinstance(param, Parameter):
            param = param.data
        if param.size() != tgt_state[name].size():
            continue
        # In-place copy into the tensor returned by the model's state_dict().
        tgt_state[name].copy_(param)
    return model
"""Build script for the ReIDSecret package and its Cython ranking extension."""
# Extension comes from setuptools: the stdlib distutils package the script
# used before was removed in Python 3.12 (PEP 632).
from setuptools import Extension, find_packages, setup

import numpy as np
from Cython.Build import cythonize


def numpy_include():
    """Return the directory that holds NumPy's C headers."""
    try:
        return np.get_include()
    except AttributeError:
        # Fallback kept from the original for NumPy builds without get_include().
        return np.get_numpy_include()


# Single Cython extension, compiled against the NumPy headers.
ext_modules = [
    Extension(
        'secret.metrics.rank_cylib.rank_cy',
        ['secret/metrics/rank_cylib/rank_cy.pyx'],
        include_dirs=[numpy_include()],
    )
]


setup(
    name='ReIDSecret',
    version='1.0.0',
    packages=find_packages(),
    ext_modules=cythonize(ext_modules),
)