├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── .gitkeep
├── examples
│   ├── oneshot_DukeMTMC-VideoReID_used_in_paper.pickle
│   └── oneshot_mars_used_in_paper.pickle
├── logs
│   └── .gitkeep
├── reid
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dukemtmc_videoReID.py
│   │   └── mars.py
│   ├── dist_metric.py
│   ├── eug.py
│   ├── evaluation_metrics
│   │   ├── __init__.py
│   │   ├── classification.py
│   │   └── ranking.py
│   ├── evaluators.py
│   ├── feature_extraction
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   └── database.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── oim.py
│   │   ├── tri_clu_loss.py
│   │   └── triplet.py
│   ├── metric_learning
│   │   ├── __init__.py
│   │   ├── euclidean.py
│   │   └── kissme.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── end2end.py
│   │   └── resnet.py
│   ├── trainers.py
│   └── utils
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── dataset.py
│       │   ├── preprocessor.py
│       │   ├── sampler.py
│       │   └── transforms.py
│       ├── logging.py
│       ├── meters.py
│       ├── osutils.py
│       └── serialization.py
├── run.py
└── run.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | init/
3 | 
4 | *.ckpt
5 | data/*
6 | !data/.gitkeep
7 | logs/*
8 | !logs/.gitkeep
9 | 
10 | # temporary files which can be created if a process still has a handle open of a deleted file
11 | .fuse_hidden*
12 | 
13 | # KDE directory preferences
14 | .directory
15 | 
16 | # Linux trash folder which might appear on any partition or disk
17 | .Trash-*
18 | 
19 | # .nfs files are created when an open file is removed but is still being accessed
20 | .nfs*
21 | 
22 | 
23 | *.DS_Store
24 | .AppleDouble
25 | .LSOverride
26 | 
27 | # Icon must end with two \r
28 | Icon
29 | 
30 | 
31 | # Thumbnails
32 | ._*
33 | 
34 | # Files that might appear in the root of a volume
35 | .DocumentRevisions-V100
36 | .fseventsd
37 | .Spotlight-V100
38 | .TemporaryItems
39 | .Trashes
40 | .VolumeIcon.icns
41 | .com.apple.timemachine.donotpresent
42 | 
43 | # Directories potentially created on remote AFP share
44 | .AppleDB
45 | .AppleDesktop
46 | Network Trash Folder
47 | Temporary Items
48 | .apdisk
49 | 
50 | 
51 | # swap
52 | [._]*.s[a-v][a-z]
53 | [._]*.sw[a-p]
54 | [._]s[a-v][a-z]
55 | [._]sw[a-p]
56 | # session
57 | Session.vim
58 | # temporary
59 | .netrwhist
60 | *~
61 | # auto-generated tag files
62 | tags
63 | 
64 | 
65 | # cache files for sublime text
66 | *.tmlanguage.cache
67 | *.tmPreferences.cache
68 | *.stTheme.cache
69 | 
70 | # workspace files are user-specific
71 | *.sublime-workspace
72 | 
73 | # project files should be checked into the repository, unless a significant
74 | # proportion of contributors will probably not be using SublimeText
75 | # *.sublime-project
76 | 
77 | # sftp configuration file
78 | sftp-config.json
79 | 
80 | # Package control specific files
81 | Package Control.last-run
82 | Package Control.ca-list
83 | Package Control.ca-bundle
84 | Package Control.system-ca-bundle
85 | Package Control.cache/
86 | Package Control.ca-certs/
87 | Package Control.merged-ca-bundle
88 | Package Control.user-ca-bundle
89 | oscrypto-ca-bundle.crt
90 | bh_unicode_properties.cache
91 | 
92 | # Sublime-github package stores a github token in this file
93 | # https://packagecontrol.io/packages/sublime-github
94 | GitHub.sublime-settings
95 | 
96 | 
97 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
98 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
99 | 
100 | # User-specific stuff:
101 | .idea
102 | .idea/**/workspace.xml
103 | .idea/**/tasks.xml
104 | 
105 | # Sensitive or high-churn files:
106 | 
.idea/**/dataSources/ 107 | .idea/**/dataSources.ids 108 | .idea/**/dataSources.xml 109 | .idea/**/dataSources.local.xml 110 | .idea/**/sqlDataSources.xml 111 | .idea/**/dynamic.xml 112 | .idea/**/uiDesigner.xml 113 | 114 | # Gradle: 115 | .idea/**/gradle.xml 116 | .idea/**/libraries 117 | 118 | # Mongo Explorer plugin: 119 | .idea/**/mongoSettings.xml 120 | 121 | ## File-based project format: 122 | *.iws 123 | 124 | ## Plugin-specific files: 125 | 126 | # IntelliJ 127 | /out/ 128 | 129 | # mpeltonen/sbt-idea plugin 130 | .idea_modules/ 131 | 132 | # JIRA plugin 133 | atlassian-ide-plugin.xml 134 | 135 | # Crashlytics plugin (for Android Studio and IntelliJ) 136 | com_crashlytics_export_strings.xml 137 | crashlytics.properties 138 | crashlytics-build.properties 139 | fabric.properties 140 | 141 | 142 | # Byte-compiled / optimized / DLL files 143 | __pycache__/ 144 | *.py[cod] 145 | *$py.class 146 | 147 | # C extensions 148 | *.so 149 | 150 | # Distribution / packaging 151 | .Python 152 | env/ 153 | build/ 154 | develop-eggs/ 155 | dist/ 156 | downloads/ 157 | eggs/ 158 | .eggs/ 159 | lib/ 160 | lib64/ 161 | parts/ 162 | sdist/ 163 | var/ 164 | *.egg-info/ 165 | .installed.cfg 166 | *.egg 167 | 168 | # PyInstaller 169 | # Usually these files are written by a python script from a template 170 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 171 | *.manifest 172 | *.spec 173 | 174 | # Installer logs 175 | pip-log.txt 176 | pip-delete-this-directory.txt 177 | 178 | # Unit test / coverage reports 179 | htmlcov/ 180 | .tox/ 181 | .coverage 182 | .coverage.* 183 | .cache 184 | nosetests.xml 185 | coverage.xml 186 | *,cover 187 | .hypothesis/ 188 | 189 | # Translations 190 | *.mo 191 | *.pot 192 | 193 | # Django stuff: 194 | *.log 195 | local_settings.py 196 | 197 | # Flask stuff: 198 | instance/ 199 | .webassets-cache 200 | 201 | # Scrapy stuff: 202 | .scrapy 203 | 204 | # Sphinx documentation 205 | docs/_build/ 206 | 207 | # PyBuilder 208 | target/ 209 | 210 | # IPython Notebook 211 | .ipynb_checkpoints 212 | 213 | # pyenv 214 | .python-version 215 | 216 | # celery beat schedule file 217 | celerybeat-schedule 218 | 219 | # dotenv 220 | .env 221 | 222 | # virtualenv 223 | venv/ 224 | ENV/ 225 | 226 | # Spyder project settings 227 | .spyderproject 228 | 229 | # Rope project settings 230 | .ropeproject 231 | 232 | 233 | # Project specific 234 | examples/data 235 | examples/logs 236 | 237 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yu Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Exploit the Unknown Gradually: One-Shot Video-Based Person Re-Identification by Stepwise Learning](https://yu-wu.net/pdf/CVPR2018_Exploit-Unknown-Gradually.pdf)
2 | 
3 | PyTorch implementation for our paper [[Link]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Wu_Exploit_the_Unknown_CVPR_2018_paper.pdf).
4 | This code is based on the [Open-ReID](https://github.com/Cysu/open-reid) library.
5 | 
6 | ## Preparation
7 | ### Dependencies
8 | - Python 3.6
9 | - PyTorch (version >= 0.2.0)
10 | - h5py, scikit-learn, metric-learn, tqdm
11 | 
12 | ### Download datasets
13 | - DukeMTMC-VideoReID: This [page](https://github.com/Yu-Wu/DukeMTMC-VideoReID) contains more details and baseline code.
14 | - MARS: [[Google Drive]](https://drive.google.com/open?id=1m6yLgtQdhb6pLCcb6_m7sj0LLBRvkDW0)   [[BaiduYun]](https://pan.baidu.com/s/1mByTdvXFsmobXOXBEkIWFw).
15 | - Move the downloaded zip files to `./data/` and unzip them there.
16 | 
17 | 
18 | ## Train
19 | 
20 | For the DukeMTMC-VideoReID dataset:
21 | ```shell
22 | python3 run.py --dataset DukeMTMC-VideoReID --logs_dir logs/DukeMTMC-VideoReID_EF_10/ --EF 10 --mode Dissimilarity --max_frames 900
23 | ```
24 | 
25 | For the MARS dataset:
26 | ```shell
27 | python3 run.py --dataset mars --logs_dir logs/mars_EF_10/ --EF 10 --mode Dissimilarity --max_frames 900
28 | ```
29 | Training EUG (EF=10%) on DukeMTMC-VideoReID takes about 10 hours on a GTX 1080 Ti. Please set `max_frames` to a smaller value if your GPU has less than 11 GB of memory.
30 | 
31 | ## Performance
32 | 
33 | Performance varies with the random split used for the initial labeled data. To reproduce the results in our paper, please use the one-shot splits provided in `./examples/` (a minimal loading sketch is included at the end of this README).
34 | 
35 | 
36 | ## Citation
37 | 
38 | Please cite the following paper in your publications if it helps your research:
39 | 
40 |     @inproceedings{wu2018cvpr_oneshot,
41 |         title = {Exploit the Unknown Gradually: One-Shot Video-Based Person Re-Identification by Stepwise Learning},
42 |         author = {Wu, Yu and Lin, Yutian and Dong, Xuanyi and Yan, Yan and Ouyang, Wanli and Yang, Yi},
43 |         booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
44 |         year = {2018}
45 |     }
46 | 
47 | ## Future work
48 | 
49 | In the extended version, we improve EUG by leveraging the remaining unselected data in an unsupervised way, which yields a large performance improvement on image-based re-ID datasets (Market-1501 and DukeMTMC-reID). [[Paper]](https://yu-wu.net/pdf/TIP2019_One-Example-reID.pdf) [[Code]](https://github.com/Yu-Wu/One-Example-Person-ReID)
50 | 
51 | 
52 | ## Contact
53 | 
54 | To report issues with this code, please open an issue on the [issue tracker](https://github.com/Yu-Wu/Exploit-Unknown-Gradually/issues).
55 | 
56 | If you have further questions about this paper, please do not hesitate to contact me.
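For reference, the one-shot split files are plain pickles; below is a minimal sketch for inspecting one (the `"label set"` / `"unlabel set"` keys and the `[images, pid, camid, videoid]` entry layout follow `get_one_shot_in_cam1` in `reid/eug.py`):

```python
import pickle

# load one of the splits shipped in ./examples/
with open("examples/oneshot_DukeMTMC-VideoReID_used_in_paper.pickle", "rb") as fp:
    split = pickle.load(fp)

# one labeled tracklet per identity; the remaining training tracklets are unlabeled
print(len(split["label set"]), "labeled tracklets")
print(len(split["unlabel set"]), "unlabeled tracklets")

# each entry is [frame list, person id, camera id, tracklet id]
frames, pid, camid, videoid = split["label set"][0]
print(pid, camid, videoid, len(frames))
```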
57 | 
58 | [Yu Wu's Homepage](https://yu-wu.net)
59 | 
60 | 
61 | 
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yu-Wu/Exploit-Unknown-Gradually/a3ebea3851ef98e32d4611250c779c84cfbde171/data/.gitkeep
--------------------------------------------------------------------------------
/examples/oneshot_DukeMTMC-VideoReID_used_in_paper.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yu-Wu/Exploit-Unknown-Gradually/a3ebea3851ef98e32d4611250c779c84cfbde171/examples/oneshot_DukeMTMC-VideoReID_used_in_paper.pickle
--------------------------------------------------------------------------------
/examples/oneshot_mars_used_in_paper.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yu-Wu/Exploit-Unknown-Gradually/a3ebea3851ef98e32d4611250c779c84cfbde171/examples/oneshot_mars_used_in_paper.pickle
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yu-Wu/Exploit-Unknown-Gradually/a3ebea3851ef98e32d4611250c779c84cfbde171/logs/.gitkeep
--------------------------------------------------------------------------------
/reid/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from . import datasets
4 | from . import evaluation_metrics
5 | from . import feature_extraction
6 | from . import loss
7 | from . import metric_learning
8 | from . import models
9 | from . import utils
10 | from . import dist_metric
11 | from . import evaluators
12 | from . import trainers
13 | 
14 | __version__ = '0.2.0'
15 | 
--------------------------------------------------------------------------------
/reid/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import warnings
3 | 
4 | from .mars import Mars
5 | from .dukemtmc_videoReID import DukeMTMC_VideoReID
6 | 
7 | 
8 | __factory = {
9 |     'mars': Mars,
10 |     'DukeMTMC-VideoReID': DukeMTMC_VideoReID
11 | }
12 | 
13 | 
14 | def names():
15 |     return sorted(__factory.keys())
16 | 
17 | 
18 | def create(name, root, *args, **kwargs):
19 |     """
20 |     Create a dataset instance.
21 | 
22 |     Parameters
23 |     ----------
24 |     name : str
25 |         The dataset name. Can be one of 'mars' and
26 |         'DukeMTMC-VideoReID'.
27 |     root : str
28 |         The path to the dataset directory.
29 |     split_id : int, optional
30 |         The index of data split. Default: 0
31 |     num_val : int or float, optional
32 |         When int, it means the number of validation identities. When float,
33 |         it means the proportion of validation to all the trainval. Default: 100
34 |     download : bool, optional
35 |         If True, will download the dataset. Default: True
36 |     """
37 |     if name not in __factory:
38 |         raise KeyError("Unknown dataset:", name)
39 |     return __factory[name](root, *args, **kwargs)
40 | 
41 | 
42 | def get_dataset(name, root, *args, **kwargs):
43 |     warnings.warn("get_dataset is deprecated. 
Use create instead.") 44 | return create(name, root, *args, **kwargs) 45 | -------------------------------------------------------------------------------- /reid/datasets/dukemtmc_videoReID.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import os 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import write_json 7 | 8 | 9 | class DukeMTMC_VideoReID(Dataset): 10 | def __init__(self, root, split_id=0, num_val=100, download=True): 11 | super(self.__class__, self).__init__(root, split_id=split_id) 12 | self.name="DukeMTMC-VideoReID" 13 | self.num_cams = 8 14 | 15 | if download: 16 | self.download() 17 | 18 | if not self._check_integrity(): 19 | raise RuntimeError("Dataset not found or corrupted. " + 20 | "You can use download=True to download it.") 21 | 22 | self.load(num_val) 23 | 24 | def download(self): 25 | if self._check_integrity(): 26 | print("Files already downloaded and verified") 27 | return 28 | print("create new dataset") 29 | import re 30 | import hashlib 31 | import shutil 32 | from glob import glob 33 | from zipfile import ZipFile 34 | 35 | # Format 36 | images_dir = osp.join(self.root, 'images') 37 | mkdir_if_missing(images_dir) 38 | 39 | identities = [[{} for _ in range(self.num_cams)] for _ in range(7141)] 40 | 41 | def register(subdir): 42 | pids = set() 43 | relabeled_pid = -1 44 | vids = [] 45 | person_list = os.listdir(os.path.join(self.root, subdir)); person_list.sort() 46 | for person_id in person_list: 47 | count = 0 48 | pid = int(person_id) 49 | videos = os.listdir(os.path.join(self.root, subdir, person_id)); videos.sort() 50 | for video_id in videos: 51 | video_path = os.path.join(self.root, subdir, person_id, video_id) 52 | video_id = int(video_id) 53 | fnames = os.listdir(video_path) 54 | frame_list = [] 55 | for fname in fnames: 56 | count += 1 57 | cam = int(fname[6]) - 1 58 | assert 0 <= pid <= 7140 59 | assert 0 <= cam <= 8 60 | pids.add(pid) 61 | newname = ('{:04d}_{:02d}_{:04d}_{:04d}.jpg'.format(pid, cam, video_id, len(frame_list))) 62 | frame_list.append(newname) 63 | shutil.copy(osp.join(video_path, fname), osp.join(images_dir, newname)) 64 | identities[pid][cam][video_id] = frame_list 65 | vids.append(frame_list) 66 | print("ID {}, frames {}\t in {}".format(person_id, count, subdir)) 67 | return pids, vids 68 | 69 | print("begin to preprocess mars dataset") 70 | trainval_pids, _ = register('train') 71 | gallery_pids, gallery_vids = register('gallery') 72 | query_pids, query_vids = register('query') 73 | #assert query_pids <= gallery_pids 74 | assert trainval_pids.isdisjoint(gallery_pids) 75 | 76 | # Save meta information into a json file 77 | meta = {'name': 'Mars', 'shot': 'multiple', 'num_cameras': 8, 78 | 'identities': identities, 79 | 'query': query_vids, 80 | 'gallery': gallery_vids} 81 | write_json(meta, osp.join(self.root, 'meta.json')) 82 | 83 | # Save the only training / test split 84 | splits = [{ 85 | 'train': sorted(list(trainval_pids)), 86 | 'query': sorted(list(query_pids)) , 87 | 'gallery': sorted(list(gallery_pids))}] 88 | write_json(splits, osp.join(self.root, 'splits.json')) 89 | 90 | -------------------------------------------------------------------------------- /reid/datasets/mars.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path 
as osp 3 | import os 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import write_json 7 | 8 | 9 | class Mars(Dataset): 10 | def __init__(self, root, split_id=0, num_val=100, download=True): 11 | super(self.__class__, self).__init__(root, split_id=split_id) 12 | self.name="mars" 13 | self.num_cams = 6 14 | 15 | if download: 16 | self.download() 17 | 18 | if not self._check_integrity(): 19 | raise RuntimeError("Dataset not found or corrupted. " + 20 | "You can use download=True to download it.") 21 | 22 | self.load(num_val) 23 | 24 | def download(self): 25 | if self._check_integrity(): 26 | print("Files already downloaded and verified") 27 | return 28 | print("create new dataset") 29 | import re 30 | import hashlib 31 | import shutil 32 | from glob import glob 33 | from zipfile import ZipFile 34 | 35 | 36 | # get mars dataset 37 | 38 | # Format 39 | images_dir = osp.join(self.root, 'images') 40 | mkdir_if_missing(images_dir) 41 | 42 | # totally 1261 person (625+636) with 6 camera views each 43 | # id 1~625 are for training 44 | # id 999~1634 are for testing 45 | identities = [[{} for _ in range(6)] for _ in range(1635)] 46 | 47 | def register(subdir): 48 | print("Copy images for {}".format(subdir)) 49 | pids = set() 50 | vids = [] 51 | person_list = os.listdir(os.path.join(self.root, subdir)); person_list.sort() 52 | for person_id in person_list: 53 | count = 0 54 | videos = os.listdir(os.path.join(self.root, subdir, person_id)); videos.sort() 55 | 56 | pid = int(person_id) 57 | assert 0 <= pid <= 1634 # pid == 999, 1000 means background and distractors 58 | if pid == 999: 59 | print("skip junk images") 60 | continue 61 | pids.add(pid) 62 | 63 | for video_id in videos: 64 | video_path = os.path.join(self.root, subdir, person_id, video_id) 65 | video_id = int(video_id) - 1 66 | fnames = os.listdir(video_path) 67 | frame_list = [] 68 | for fname in fnames: 69 | count += 1 70 | cam = int(fname[5]) - 1 71 | assert 0 <= cam <= 5 72 | newname = ('{:04d}_{:02d}_{:04d}_{:04d}.jpg'.format(pid, cam, video_id, len(frame_list))) 73 | frame_list.append(newname) 74 | shutil.copy(osp.join(video_path, fname), osp.join(images_dir, newname)) 75 | identities[pid][cam][video_id] = frame_list 76 | vids.append(frame_list) 77 | #print("ID {}, frames {}\t in {}".format(person_id, count, subdir)) 78 | return pids, vids 79 | 80 | print("begin to preprocess mars dataset") 81 | trainval_pids, _ = register('train_split') 82 | gallery_pids, gallery_vids = register('gallery_split') 83 | query_pids, query_vids = register('query_split') 84 | #assert query_pids <= gallery_pids 85 | assert trainval_pids.isdisjoint(gallery_pids) 86 | 87 | # Save meta information into a json file 88 | meta = {'name': 'Mars', 'shot': 'multiple', 'num_cameras': 6, 89 | 'identities': identities, 90 | 'query': query_vids, 91 | 'gallery': gallery_vids} 92 | write_json(meta, osp.join(self.root, 'meta.json')) 93 | 94 | # Save the only training / test split 95 | splits = [{ 96 | 'train': sorted(list(trainval_pids)), 97 | 'query': sorted(list(query_pids)) , 98 | 'gallery': sorted(list(gallery_pids))}] 99 | write_json(splits, osp.join(self.root, 'splits.json')) 100 | 101 | -------------------------------------------------------------------------------- /reid/dist_metric.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from .evaluators import extract_features 6 | from 
.metric_learning import get_metric 7 | 8 | 9 | class DistanceMetric(object): 10 | def __init__(self, algorithm='euclidean', *args, **kwargs): 11 | super(DistanceMetric, self).__init__() 12 | self.algorithm = algorithm 13 | self.metric = get_metric(algorithm, *args, **kwargs) 14 | 15 | def train(self, model, data_loader): 16 | if self.algorithm == 'euclidean': return 17 | features, labels = extract_features(model, data_loader) 18 | features = torch.stack(features.values()).numpy() 19 | labels = torch.Tensor(list(labels.values())).numpy() 20 | self.metric.fit(features, labels) 21 | 22 | def transform(self, X): 23 | if torch.is_tensor(X): 24 | X = X.numpy() 25 | X = self.metric.transform(X) 26 | X = torch.from_numpy(X) 27 | else: 28 | X = self.metric.transform(X) 29 | return X 30 | 31 | -------------------------------------------------------------------------------- /reid/eug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from reid import models 4 | from reid.trainers import Trainer 5 | from reid.evaluators import extract_features, Evaluator 6 | from reid.dist_metric import DistanceMetric 7 | import numpy as np 8 | from collections import OrderedDict 9 | import os.path as osp 10 | import pickle 11 | from reid.utils.serialization import load_checkpoint 12 | from reid.utils.data import transforms as T 13 | from torch.utils.data import DataLoader 14 | from reid.utils.data.preprocessor import Preprocessor 15 | import random 16 | 17 | 18 | class EUG(): 19 | def __init__(self, model_name, batch_size, mode, num_classes, data_dir, l_data, u_data, save_path, dropout=0.5, max_frames=900): 20 | 21 | self.model_name = model_name 22 | self.num_classes = num_classes 23 | self.mode = mode 24 | self.data_dir = data_dir 25 | self.save_path = save_path 26 | 27 | self.l_data = l_data 28 | self.u_data = u_data 29 | self.l_label = np.array([label for _,label,_,_ in l_data]) 30 | self.u_label = np.array([label for _,label,_,_ in u_data]) 31 | 32 | 33 | self.dataloader_params = {} 34 | self.dataloader_params['height'] = 256 35 | self.dataloader_params['width'] = 128 36 | self.dataloader_params['batch_size'] = batch_size 37 | self.dataloader_params['workers'] = 6 38 | 39 | 40 | self.batch_size = batch_size 41 | self.data_height = 256 42 | self.data_width = 128 43 | self.data_workers = 6 44 | 45 | # batch size for eval mode. Default is 1. 
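        # (tracklets contain different numbers of frames, so evaluation feeds one tracklet per batch)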
46 | self.eval_bs = 1 47 | self.dropout = dropout 48 | self.max_frames = max_frames 49 | 50 | 51 | def get_dataloader(self, dataset, training=False) : 52 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225]) 54 | 55 | if training: 56 | transformer = T.Compose([ 57 | T.RandomSizedRectCrop(self.data_height, self.data_width), 58 | T.RandomHorizontalFlip(), 59 | T.ToTensor(), 60 | normalizer, 61 | ]) 62 | batch_size = self.batch_size 63 | 64 | else: 65 | transformer = T.Compose([ 66 | T.RectScale(self.data_height, self.data_width), 67 | T.ToTensor(), 68 | normalizer, 69 | ]) 70 | batch_size = self.eval_bs 71 | 72 | data_loader = DataLoader( 73 | Preprocessor(dataset, root=self.data_dir, 74 | transform=transformer, is_training=training, max_frames=self.max_frames), 75 | batch_size=batch_size, num_workers=self.data_workers, 76 | shuffle=training, pin_memory=True, drop_last=training) 77 | 78 | current_status = "Training" if training else "Test" 79 | print("create dataloader for {} with batch_size {}".format(current_status, batch_size)) 80 | return data_loader 81 | 82 | 83 | 84 | 85 | def train(self, train_data, step, epochs=70, step_size=55, init_lr=0.1, dropout=0.5): 86 | 87 | """ create model and dataloader """ 88 | model = models.create(self.model_name, dropout=self.dropout, num_classes=self.num_classes, mode=self.mode) 89 | model = nn.DataParallel(model).cuda() 90 | dataloader = self.get_dataloader(train_data, training=True) 91 | 92 | 93 | # the base parameters for the backbone (e.g. ResNet50) 94 | base_param_ids = set(map(id, model.module.CNN.base.parameters())) 95 | 96 | # we fixed the first three blocks to save GPU memory 97 | base_params_need_for_grad = filter(lambda p: p.requires_grad, model.module.CNN.parameters()) 98 | 99 | # params of the new layers 100 | new_params = [p for p in model.parameters() if id(p) not in base_param_ids] 101 | 102 | # set the learning rate for backbone to be 0.1 times 103 | param_groups = [ 104 | {'params': base_params_need_for_grad, 'lr_mult': 0.1}, 105 | {'params': new_params, 'lr_mult': 1.0}] 106 | 107 | criterion = nn.CrossEntropyLoss().cuda() 108 | optimizer = torch.optim.SGD(param_groups, lr=init_lr, momentum=0.5, weight_decay = 5e-4, nesterov=True) 109 | 110 | # change the learning rate by step 111 | def adjust_lr(epoch, step_size): 112 | lr = init_lr / (10 ** (epoch // step_size)) 113 | for g in optimizer.param_groups: 114 | g['lr'] = lr * g.get('lr_mult', 1) 115 | 116 | if epoch % step_size == 0: 117 | print("Epoch {}, current lr {}".format(epoch, lr)) 118 | 119 | """ main training process """ 120 | trainer = Trainer(model, criterion) 121 | for epoch in range(epochs): 122 | adjust_lr(epoch, step_size) 123 | trainer.train(epoch, dataloader, optimizer, print_freq=len(dataloader)//30 * 10) 124 | 125 | torch.save(model.state_dict(), osp.join(self.save_path, "{}_step_{}.ckpt".format(self.mode, step))) 126 | self.model = model 127 | 128 | 129 | def get_feature(self, dataset): 130 | dataloader = self.get_dataloader(dataset, training=False) 131 | features,_ = extract_features(self.model, dataloader) 132 | features = np.array([logit.numpy() for logit in features.values()]) 133 | return features 134 | 135 | 136 | 137 | def get_Classification_result(self): 138 | logits = self.get_feature(self.u_data) 139 | exp_logits = np.exp(logits) 140 | predict_prob = exp_logits / np.sum(exp_logits,axis=1).reshape((-1,1)) 141 | assert len(logits) == len(predict_prob) 142 | assert predict_prob.shape[1] == self.num_classes 143 | 144 | 
pred_label = np.argmax(predict_prob, axis=1)
145 |         pred_score = predict_prob.max(axis=1)
146 |         print("get_Classification_result", predict_prob.shape)
147 | 
148 | 
149 |         num_correct_pred = 0
150 |         for idx, p_label in enumerate(pred_label):
151 |             if self.u_label[idx] == p_label:
152 |                 num_correct_pred += 1
153 | 
154 | 
155 |         print("{} predictions on all the unlabeled data: {} of {} are correct, accuracy = {:0.3f}".format(
156 |             self.mode, num_correct_pred, pred_label.shape[0], num_correct_pred / pred_label.shape[0]))
157 | 
158 |         return pred_label, pred_score
159 | 
160 | 
161 | 
162 |     def get_Dissimilarity_result(self):
163 | 
164 |         # extract features
165 |         u_feas = self.get_feature(self.u_data)
166 |         l_feas = self.get_feature(self.l_data)
167 |         print("u_features", u_feas.shape, "l_features", l_feas.shape)
168 | 
169 |         scores = np.zeros((u_feas.shape[0]))
170 |         labels = np.zeros((u_feas.shape[0]))
171 | 
172 |         num_correct_pred = 0
173 |         for idx, u_fea in enumerate(u_feas):
174 |             diffs = l_feas - u_fea
175 |             dist = np.linalg.norm(diffs, axis=1)
176 |             index_min = np.argmin(dist)
177 |             scores[idx] = -dist[index_min]  # "-dist": a larger distance means a lower score
178 |             labels[idx] = self.l_label[index_min]  # take the nearest labeled neighbor as the predicted label
179 | 
180 |             # count the correct Nearest Neighbor predictions
181 |             if self.u_label[idx] == labels[idx]:
182 |                 num_correct_pred += 1
183 | 
184 |         print("{} predictions on all the unlabeled data: {} of {} are correct, accuracy = {:0.3f}".format(
185 |             self.mode, num_correct_pred, u_feas.shape[0], num_correct_pred / u_feas.shape[0]))
186 |         return labels, scores
187 | 
188 | 
189 |     def estimate_label(self):
190 | 
191 |         print("label estimation by {} mode.".format(self.mode))
192 | 
193 |         if self.mode == "Dissimilarity":
194 |             # predict labels by dissimilarity cost
195 |             [pred_label, pred_score] = self.get_Dissimilarity_result()
196 | 
197 |         elif self.mode == "Classification":
198 |             # predict labels by classification
199 |             [pred_label, pred_score] = self.get_Classification_result()
200 |         else:
201 |             raise ValueError
202 | 
203 |         return pred_label, pred_score
204 | 
205 | 
206 | 
207 |     def select_top_data(self, pred_score, nums_to_select):
208 |         v = np.zeros(len(pred_score))
209 |         index = np.argsort(-pred_score)
210 |         for i in range(nums_to_select):
211 |             v[index[i]] = 1
212 |         return v.astype('bool')
213 | 
214 | 
215 | 
216 | 
217 |     def generate_new_train_data(self, sel_idx, pred_y):
218 |         """ generate the training data for the next step """
219 | 
220 |         selected_data = []
221 |         correct, total = 0, 0
222 |         for i, flag in enumerate(sel_idx):
223 |             if flag:  # if selected
224 |                 selected_data.append([self.u_data[i][0], int(pred_y[i]), self.u_data[i][2], self.u_data[i][3]])
225 |                 total += 1
226 |                 if self.u_label[i] == int(pred_y[i]):
227 |                     correct += 1
228 |         acc = correct / total
229 | 
230 |         new_train_data = self.l_data + selected_data
231 |         print("selected pseudo-labeled data: {} of {} are correct, accuracy: {:0.4f} new train data: {}".format(
232 |             correct, len(selected_data), acc, len(new_train_data)))
233 | 
234 |         return new_train_data
235 | 
236 |     def resume(self, ckpt_file, step):
237 |         print("continued from step", step)
238 |         model = models.create(self.model_name, dropout=self.dropout, num_classes=self.num_classes, mode=self.mode)
239 |         self.model = nn.DataParallel(model).cuda()
240 |         self.model.load_state_dict(load_checkpoint(ckpt_file))
241 | 
242 |     def evaluate(self, query, gallery):
243 |         test_loader = self.get_dataloader(list(set(query) | set(gallery)), training=False)
244 | 
evaluator = Evaluator(self.model) 245 | evaluator.evaluate(test_loader, query, gallery) 246 | 247 | 248 | 249 | """ 250 | Get one-shot split for the input dataset. 251 | """ 252 | def get_one_shot_in_cam1(dataset, load_path, seed=0): 253 | 254 | np.random.seed(seed) 255 | random.seed(seed) 256 | 257 | # if previous split exists, load it and return 258 | if osp.exists(load_path): 259 | with open(load_path, "rb") as fp: 260 | dataset = pickle.load(fp) 261 | label_dataset = dataset["label set"] 262 | unlabel_dataset = dataset["unlabel set"] 263 | 264 | print(" labeled | N/A | {:8d}".format(len(label_dataset))) 265 | print(" unlabel | N/A | {:8d}".format(len(unlabel_dataset))) 266 | print("\nLoad one-shot split from", load_path) 267 | return label_dataset, unlabel_dataset 268 | 269 | 270 | 271 | #print("random create new one-shot split and save it to", load_path) 272 | 273 | label_dataset = [] 274 | unlabel_dataset = [] 275 | 276 | # dataset indexed by [pid, cam] 277 | dataset_in_pid_cam = [[[] for _ in range(dataset.num_cams)] for _ in range(dataset.num_train_ids) ] 278 | for index, (images, pid, camid, videoid) in enumerate(dataset.train): 279 | dataset_in_pid_cam[pid][camid].append([images, pid, camid, videoid]) 280 | 281 | 282 | # generate the labeled dataset by randomly selecting a tracklet from the first camera for each identity 283 | for pid, cams_data in enumerate(dataset_in_pid_cam): 284 | for camid, videos in enumerate(cams_data): 285 | if len(videos) != 0: 286 | selected_video = random.choice(videos) 287 | break 288 | label_dataset.append(selected_video) 289 | assert len(label_dataset) == dataset.num_train_ids 290 | labeled_videoIDs =[vid for _, (_,_,_, vid) in enumerate(label_dataset)] 291 | 292 | # generate unlabeled set 293 | for (imgs, pid, camid, videoid) in dataset.train: 294 | if videoid not in labeled_videoIDs: 295 | unlabel_dataset.append([imgs, pid, camid, videoid]) 296 | 297 | 298 | with open(load_path, "wb") as fp: 299 | pickle.dump({"label set":label_dataset, "unlabel set":unlabel_dataset}, fp) 300 | 301 | 302 | print(" labeled | N/A | {:8d}".format(len(label_dataset))) 303 | print(" unlabeled | N/A | {:8d}".format(len(unlabel_dataset))) 304 | print("\nCreate new one-shot split, and save it to", load_path) 305 | return label_dataset, unlabel_dataset 306 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 18 | ret.append(correct_k.mul_(1. 
/ batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics.base import _average_binary_score 6 | from sklearn.metrics import precision_recall_curve, auc 7 | # from sklearn.metrics import average_precision_score 8 | 9 | 10 | from ..utils import to_numpy 11 | 12 | 13 | def _unique_sample(ids_dict, num): 14 | mask = np.zeros(num, dtype=np.bool) 15 | for _, indices in ids_dict.items(): 16 | i = np.random.choice(indices) 17 | mask[i] = True 18 | return mask 19 | 20 | 21 | def average_precision_score(y_true, y_score, average="macro", 22 | sample_weight=None): 23 | def _binary_average_precision(y_true, y_score, sample_weight=None): 24 | precision, recall, thresholds = precision_recall_curve( 25 | y_true, y_score, sample_weight=sample_weight) 26 | return auc(recall, precision) 27 | 28 | return _average_binary_score(_binary_average_precision, y_true, y_score, 29 | average, sample_weight=sample_weight) 30 | 31 | 32 | def cmc(distmat, query_ids=None, gallery_ids=None, 33 | query_cams=None, gallery_cams=None, topk=100, 34 | separate_camera_set=False, 35 | single_gallery_shot=False, 36 | first_match_break=False): 37 | distmat = to_numpy(distmat) 38 | m, n = distmat.shape 39 | # Fill up default values 40 | if query_ids is None: 41 | query_ids = np.arange(m) 42 | if gallery_ids is None: 43 | gallery_ids = np.arange(n) 44 | if query_cams is None: 45 | query_cams = np.zeros(m).astype(np.int32) 46 | if gallery_cams is None: 47 | gallery_cams = np.ones(n).astype(np.int32) 48 | # Ensure numpy array 49 | query_ids = np.asarray(query_ids) 50 | gallery_ids = np.asarray(gallery_ids) 51 | query_cams = np.asarray(query_cams) 52 | gallery_cams = np.asarray(gallery_cams) 53 | # Sort and find correct matches 54 | indices = np.argsort(distmat, axis=1) 55 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 56 | # Compute CMC for each query 57 | ret = np.zeros(topk) 58 | num_valid_queries = 0 59 | for i in range(m): 60 | # Filter out the same id and same camera 61 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 62 | (gallery_cams[indices[i]] != query_cams[i])) 63 | if separate_camera_set: 64 | # Filter out samples from same camera 65 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 66 | if not np.any(matches[i, valid]): continue 67 | if single_gallery_shot: 68 | repeat = 10 69 | gids = gallery_ids[indices[i][valid]] 70 | inds = np.where(valid)[0] 71 | ids_dict = defaultdict(list) 72 | for j, x in zip(inds, gids): 73 | ids_dict[x].append(j) 74 | else: 75 | repeat = 1 76 | for _ in range(repeat): 77 | if single_gallery_shot: 78 | # Randomly choose one instance for each id 79 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 80 | index = np.nonzero(matches[i, sampled])[0] 81 | else: 82 | index = np.nonzero(matches[i, valid])[0] 83 | delta = 1. 
/ (len(index) * repeat) 84 | for j, k in enumerate(index): 85 | if k - j >= topk: break 86 | if first_match_break: 87 | ret[k - j] += 1 88 | break 89 | ret[k - j] += delta 90 | num_valid_queries += 1 91 | if num_valid_queries == 0: 92 | raise RuntimeError("No valid query") 93 | return ret.cumsum() / num_valid_queries 94 | 95 | 96 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 97 | query_cams=None, gallery_cams=None): 98 | distmat = to_numpy(distmat) 99 | m, n = distmat.shape 100 | # Fill up default values 101 | if query_ids is None: 102 | query_ids = np.arange(m) 103 | if gallery_ids is None: 104 | gallery_ids = np.arange(n) 105 | if query_cams is None: 106 | query_cams = np.zeros(m).astype(np.int32) 107 | if gallery_cams is None: 108 | gallery_cams = np.ones(n).astype(np.int32) 109 | # Ensure numpy array 110 | query_ids = np.asarray(query_ids) 111 | gallery_ids = np.asarray(gallery_ids) 112 | query_cams = np.asarray(query_cams) 113 | gallery_cams = np.asarray(gallery_cams) 114 | # Sort and find correct matches 115 | indices = np.argsort(distmat, axis=1) 116 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 117 | # Compute AP for each query 118 | aps = [] 119 | for i in range(m): 120 | # Filter out the same id and same camera 121 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 122 | (gallery_cams[indices[i]] != query_cams[i])) 123 | y_true = matches[i, valid] 124 | y_score = -distmat[i][indices[i]][valid] 125 | if not np.any(y_true): continue 126 | aps.append(average_precision_score(y_true, y_score)) 127 | if len(aps) == 0: 128 | raise RuntimeError("No valid query") 129 | return np.mean(aps) 130 | 131 | -------------------------------------------------------------------------------- /reid/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | from .evaluation_metrics import cmc, mean_ap 8 | from .feature_extraction import extract_cnn_feature 9 | from .utils.meters import AverageMeter 10 | 11 | import tqdm 12 | 13 | from torch.backends import cudnn 14 | 15 | def extract_features(model, data_loader, print_freq=1, metric=None): 16 | cudnn.benchmark = False 17 | model.eval() 18 | batch_time = AverageMeter() 19 | data_time = AverageMeter() 20 | 21 | features = OrderedDict() 22 | labels = OrderedDict() 23 | 24 | with tqdm.tqdm(total=len(data_loader)) as pbar: 25 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader): 26 | outputs = extract_cnn_feature(model, imgs) 27 | for fname, output, pid in zip(fnames, outputs, pids): 28 | features[fname] = output 29 | labels[fname] = pid 30 | pbar.update(1) 31 | 32 | 33 | print("Extract {} batch videos".format(len(data_loader))) 34 | cudnn.benchmark = True 35 | return features, labels 36 | 37 | 38 | def pairwise_distance(features, query=None, gallery=None, metric=None): 39 | if query is None and gallery is None: 40 | n = len(features) 41 | x = torch.cat(list(features.values())) 42 | x = x.view(n, -1) 43 | if metric is not None: 44 | x = metric.transform(x) 45 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 46 | dist = dist.expand(n, n) - 2 * torch.mm(x, x.t()) 47 | return dist 48 | 49 | x = torch.cat([features["".join(f)].unsqueeze(0) for f, _, _, _ in query], 0) 50 | y = torch.cat([features["".join(f)].unsqueeze(0) for f, _, _, _ in gallery], 0) 51 | m, n = x.size(0), y.size(0) 52 | x = x.view(m, -1) 53 | y = y.view(n, -1) 54 | if 
metric is not None: 55 | x = metric.transform(x) 56 | y = metric.transform(y) 57 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 58 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 59 | dist.addmm_(1, -2, x, y.t()) 60 | return dist 61 | 62 | 63 | def evaluate_all(distmat, query=None, gallery=None, 64 | query_ids=None, gallery_ids=None, 65 | query_cams=None, gallery_cams=None, 66 | cmc_topk=(1, 5, 10, 20)): 67 | if query is not None and gallery is not None: 68 | query_ids = [pid for _, pid, _, _ in query] 69 | gallery_ids = [pid for _, pid, _, _ in gallery] 70 | query_cams = [cam for _, _, cam, _ in query] 71 | gallery_cams = [cam for _, _, cam, _ in gallery] 72 | else: 73 | assert (query_ids is not None and gallery_ids is not None 74 | and query_cams is not None and gallery_cams is not None) 75 | 76 | # Compute mean AP 77 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 78 | print('Mean AP: {:4.1%}'.format(mAP)) 79 | 80 | # Compute all kinds of CMC scores 81 | cmc_configs = { 82 | # 'allshots': dict(separate_camera_set=False, 83 | # single_gallery_shot=False, 84 | # first_match_break=False), 85 | # 'cuhk03': dict(separate_camera_set=True, 86 | # single_gallery_shot=True, 87 | # first_match_break=False), 88 | 'market1501': dict(separate_camera_set=False, 89 | single_gallery_shot=False, 90 | first_match_break=True)} 91 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 92 | query_cams, gallery_cams, **params) 93 | for name, params in cmc_configs.items()} 94 | 95 | print('CMC Scores:') 96 | for k in cmc_topk: 97 | print(' top-{:<4}{:12.1%}' 98 | .format(k, 99 | # cmc_scores['allshots'][k - 1], 100 | # cmc_scores['cuhk03'][k - 1], 101 | cmc_scores['market1501'][k - 1])) 102 | 103 | # Use the allshots cmc top-1 score for validation criterion 104 | return cmc_scores['market1501'][0] 105 | 106 | 107 | class Evaluator(object): 108 | def __init__(self, model): 109 | super(Evaluator, self).__init__() 110 | self.model = model 111 | 112 | def evaluate(self, data_loader, query, gallery, metric=None): 113 | features, _ = extract_features(self.model, data_loader) 114 | distmat = pairwise_distance(features, query, gallery, metric=metric) 115 | return evaluate_all(distmat, query=query, gallery=gallery) 116 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cnn import extract_cnn_feature 4 | from .database import FeatureDatabase 5 | 6 | __all__ = [ 7 | 'extract_cnn_feature', 8 | 'FeatureDatabase', 9 | ] 10 | -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import OrderedDict 3 | 4 | from torch.autograd import Variable 5 | 6 | from ..utils import to_torch 7 | 8 | 9 | def extract_cnn_feature(model, inputs, modules=None): 10 | model.eval() 11 | inputs = to_torch(inputs) 12 | inputs = Variable(inputs, volatile=True) 13 | if modules is None: 14 | outputs = model(inputs) 15 | outputs = outputs.data.cpu() 16 | return outputs 17 | # Register forward hook for each module 18 | outputs = OrderedDict() 19 | handles = [] 20 | for m in modules: 21 | outputs[id(m)] = None 22 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 23 | 
handles.append(m.register_forward_hook(func)) 24 | model(inputs) 25 | for h in handles: 26 | h.remove() 27 | return list(outputs.values()) 28 | -------------------------------------------------------------------------------- /reid/feature_extraction/database.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import h5py 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FeatureDatabase(Dataset): 9 | def __init__(self, *args, **kwargs): 10 | super(FeatureDatabase, self).__init__() 11 | self.fid = h5py.File(*args, **kwargs) 12 | 13 | def __enter__(self): 14 | return self 15 | 16 | def __exit__(self, exc_type, exc_val, exc_tb): 17 | self.close() 18 | 19 | def __getitem__(self, keys): 20 | if isinstance(keys, (tuple, list)): 21 | return [self._get_single_item(k) for k in keys] 22 | return self._get_single_item(keys) 23 | 24 | def _get_single_item(self, key): 25 | return np.asarray(self.fid[key]) 26 | 27 | def __setitem__(self, key, value): 28 | if key in self.fid: 29 | if self.fid[key].shape == value.shape and \ 30 | self.fid[key].dtype == value.dtype: 31 | self.fid[key][...] = value 32 | else: 33 | del self.fid[key] 34 | self.fid.create_dataset(key, data=value) 35 | else: 36 | self.fid.create_dataset(key, data=value) 37 | 38 | def __delitem__(self, key): 39 | del self.fid[key] 40 | 41 | def __len__(self): 42 | return len(self.fid) 43 | 44 | def __iter__(self): 45 | return iter(self.fid) 46 | 47 | def flush(self): 48 | self.fid.flush() 49 | 50 | def close(self): 51 | self.fid.close() 52 | -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .oim import oim, OIM, OIMLoss 4 | from .triplet import TripletLoss 5 | from .tri_clu_loss import TripletClusteringLoss 6 | 7 | __all__ = [ 8 | 'oim', 9 | 'OIM', 10 | 'OIMLoss', 11 | 'TripletLoss', 12 | 'TripletClusteringLoss' 13 | ] 14 | -------------------------------------------------------------------------------- /reid/loss/oim.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, autograd 6 | 7 | 8 | class OIM(autograd.Function): 9 | def __init__(self, lut, momentum=0.5): 10 | super(OIM, self).__init__() 11 | self.lut = lut 12 | self.momentum = momentum 13 | 14 | def forward(self, inputs, targets): 15 | self.save_for_backward(inputs, targets) 16 | outputs = inputs.mm(self.lut.t()) 17 | return outputs 18 | 19 | def backward(self, grad_outputs): 20 | inputs, targets = self.saved_tensors 21 | grad_inputs = None 22 | if self.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(self.lut) 24 | for x, y in zip(inputs, targets): 25 | self.lut[y] = self.momentum * self.lut[y] + (1. 
- self.momentum) * x
26 |                 self.lut[y] /= self.lut[y].norm()
27 |         return grad_inputs, None
28 | 
29 | 
30 | def oim(inputs, targets, lut, momentum=0.5):
31 |     return OIM(lut, momentum=momentum)(inputs, targets)
32 | 
33 | 
34 | class OIMLoss(nn.Module):
35 |     def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
36 |                  weight=None, size_average=True):
37 |         super(OIMLoss, self).__init__()
38 |         self.num_features = num_features
39 |         self.num_classes = num_classes
40 |         self.momentum = momentum
41 |         self.scalar = scalar
42 |         self.weight = weight
43 |         self.size_average = size_average
44 | 
45 |         self.register_buffer('lut', torch.zeros(num_classes, num_features))
46 | 
47 |     def forward(self, inputs, targets):
48 |         inputs = oim(inputs, targets, self.lut, momentum=self.momentum)
49 |         inputs *= self.scalar
50 |         loss = F.cross_entropy(inputs, targets, weight=self.weight,
51 |                                size_average=self.size_average)
52 |         return loss, inputs
53 | 
--------------------------------------------------------------------------------
/reid/loss/tri_clu_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import numpy as np
7 | 
8 | 
9 | class TripletClusteringLoss(nn.Module):
10 |     def __init__(self, clusters, margin=0):
11 |         super(TripletClusteringLoss, self).__init__()
12 |         assert isinstance(clusters, torch.autograd.Variable)
13 |         self.clusters = clusters
14 |         self.margin = margin
15 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
16 |         self.num_classes = clusters.size(0)
17 |         self.num_features = clusters.size(1)
18 |         self.dist = torch.pow(self.clusters, 2).sum(dim=1, keepdim=True)
19 | 
20 |     def forward(self, inputs, targets):
21 |         assert self.num_features == inputs.size(1)
22 |         n = inputs.size(0)
23 |         dist = self.dist.expand(self.num_classes, n)
24 |         dist = dist + torch.pow(inputs, 2).sum(dim=1, keepdim=True).t()  # broadcast; avoid in-place add on the expanded view
25 |         dist.addmm_(1, -2, self.clusters, inputs.t())
26 |         dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
27 |         dist = dist.t()
28 |         # For each anchor, find the hardest positive and negative
29 |         mask = torch.zeros(n, self.num_classes, out=torch.ByteTensor())
30 |         target_ids = targets.data.numpy().astype(int)
31 |         mask[np.arange(n), target_ids] = 1
32 |         dist_ap = dist[mask == 1]
33 |         dist_an = dist[mask == 0].view(n, -1).min(dim=1)[0]  # min(dim=...) returns (values, indices)
34 |         # Compute ranking hinge loss
35 |         y = dist_an.data.new()
36 |         y.resize_as_(dist_an.data)
37 |         y.fill_(1)
38 |         y = Variable(y)
39 |         loss = self.ranking_loss(dist_an, dist_ap, y)
40 |         prec = (dist_an.data > dist_ap.data).sum() * 1. 
/ y.size(0) 41 | return loss, prec 42 | 43 | def update_clusters(self,clusters): 44 | self.clusters = clusters 45 | -------------------------------------------------------------------------------- /reid/loss/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | def __init__(self, margin=0): 10 | super(TripletLoss, self).__init__() 11 | self.margin = margin 12 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 13 | 14 | def forward(self, inputs, targets): 15 | n = inputs.size(0) 16 | # Compute pairwise distance, replace by the official when merged 17 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 18 | dist = dist + dist.t() 19 | dist.addmm_(1, -2, inputs, inputs.t()) 20 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 21 | # For each anchor, find the hardest positive and negative 22 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 23 | dist_ap, dist_an = [], [] 24 | for i in range(n): 25 | dist_ap.append(dist[i][mask[i]].max()) 26 | dist_an.append(dist[i][mask[i] == 0].min()) 27 | dist_ap = torch.cat(dist_ap) 28 | dist_an = torch.cat(dist_an) 29 | # Compute ranking hinge loss 30 | y = dist_an.data.new() 31 | y.resize_as_(dist_an.data) 32 | y.fill_(1) 33 | y = Variable(y) 34 | loss = self.ranking_loss(dist_an, dist_ap, y) 35 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0) 36 | return loss, prec 37 | -------------------------------------------------------------------------------- /reid/metric_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from metric_learn import (ITML_Supervised, LMNN, LSML_Supervised, 4 | SDML_Supervised, NCA, LFDA, RCA_Supervised) 5 | 6 | from .euclidean import Euclidean 7 | from .kissme import KISSME 8 | 9 | __factory = { 10 | 'euclidean': Euclidean, 11 | 'kissme': KISSME, 12 | 'itml': ITML_Supervised, 13 | 'lmnn': LMNN, 14 | 'lsml': LSML_Supervised, 15 | 'sdml': SDML_Supervised, 16 | 'nca': NCA, 17 | 'lfda': LFDA, 18 | 'rca': RCA_Supervised, 19 | } 20 | 21 | 22 | def get_metric(algorithm, *args, **kwargs): 23 | if algorithm not in __factory: 24 | raise KeyError("Unknown metric:", algorithm) 25 | return __factory[algorithm](*args, **kwargs) 26 | -------------------------------------------------------------------------------- /reid/metric_learning/euclidean.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from metric_learn.base_metric import BaseMetricLearner 5 | 6 | 7 | class Euclidean(BaseMetricLearner): 8 | def __init__(self): 9 | self.M_ = None 10 | 11 | def metric(self): 12 | return self.M_ 13 | 14 | def fit(self, X): 15 | self.M_ = np.eye(X.shape[1]) 16 | self.X_ = X 17 | 18 | def transform(self, X=None): 19 | if X is None: 20 | return self.X_ 21 | return X 22 | -------------------------------------------------------------------------------- /reid/metric_learning/kissme.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from metric_learn.base_metric import BaseMetricLearner 5 | 6 | 7 | def validate_cov_matrix(M): 8 | M = (M + M.T) * 0.5 9 | k = 0 10 | I = 
np.eye(M.shape[0])
11 |     while True:
12 |         try:
13 |             _ = np.linalg.cholesky(M)
14 |             break
15 |         except np.linalg.LinAlgError:
16 |             # Find the nearest positive definite matrix for M. Modified from
17 |             # http://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
18 |             # Might take several minutes
19 |             k += 1
20 |             w, v = np.linalg.eig(M)
21 |             min_eig = w.min()  # w holds the eigenvalues, v the eigenvectors
22 |             M += (-min_eig * k * k + np.spacing(min_eig)) * I
23 |     return M
24 | 
25 | 
26 | class KISSME(BaseMetricLearner):
27 |     def __init__(self):
28 |         self.M_ = None
29 | 
30 |     def metric(self):
31 |         return self.M_
32 | 
33 |     def fit(self, X, y=None):
34 |         n = X.shape[0]
35 |         if y is None:
36 |             y = np.arange(n)
37 |         X1, X2 = np.meshgrid(np.arange(n), np.arange(n))
38 |         X1, X2 = X1[X1 < X2], X2[X1 < X2]
39 |         matches = (y[X1] == y[X2])
40 |         num_matches = matches.sum()
41 |         num_non_matches = len(matches) - num_matches
42 |         idxa = X1[matches]
43 |         idxb = X2[matches]
44 |         S = X[idxa] - X[idxb]
45 |         C1 = S.transpose().dot(S) / num_matches
46 |         p = np.random.choice(num_non_matches, num_matches, replace=False)
47 |         idxa = X1[~matches]
48 |         idxb = X2[~matches]
49 |         idxa = idxa[p]
50 |         idxb = idxb[p]
51 |         S = X[idxa] - X[idxb]
52 |         C0 = S.transpose().dot(S) / num_matches
53 |         self.M_ = np.linalg.inv(C1) - np.linalg.inv(C0)
54 |         self.M_ = validate_cov_matrix(self.M_)
55 |         self.X_ = X
56 | 
--------------------------------------------------------------------------------
/reid/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .resnet import *
4 | from .end2end import *
5 | 
6 | 
7 | __factory = {
8 |     'avg_pool': End2End_AvgPooling,
9 | }
10 | 
11 | 
12 | def names():
13 |     return sorted(__factory.keys())
14 | 
15 | 
16 | def create(name, *args, **kwargs):
17 |     """
18 |     Create a model instance.
19 | 
20 |     Parameters
21 |     ----------
22 |     name : str
23 |         Model name. Currently the only registered model is 'avg_pool'
24 |         (see __factory above).
25 |     pretrained : bool, optional
26 |         Only applied for 'resnet*' models. If True, will use ImageNet pretrained
27 |         model. Default: True
28 |     cut_at_pooling : bool, optional
29 |         If True, will cut the model before the last global pooling layer and
30 |         ignore the remaining kwargs. Default: False
31 |     num_features : int, optional
32 |         If positive, will append a Linear layer after the global pooling layer,
33 |         with this number of output units, followed by a BatchNorm layer.
34 |         Otherwise these layers will not be appended. Default: 256 for
35 |         'inception', 0 for 'resnet*'
36 |     norm : bool, optional
37 |         If True, will normalize the feature to be unit L2-norm for each sample.
38 |         Otherwise will append a ReLU layer after the above Linear layer if
39 |         num_features > 0. Default: False
40 |     dropout : float, optional
41 |         If positive, will append a Dropout layer with this dropout rate.
42 |         Default: 0
43 |     num_classes : int, optional
44 |         If positive, will append a Linear layer at the end as the classifier
45 |         with this number of output units. 
Default: 0 46 | """ 47 | if name not in __factory: 48 | raise KeyError("Unknown model:", name) 49 | return __factory[name](*args, **kwargs) 50 | -------------------------------------------------------------------------------- /reid/models/end2end.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.autograd import Variable 5 | from torch.nn import functional as F 6 | from torch.nn import init 7 | import torch 8 | import torchvision 9 | import math 10 | 11 | from .resnet import * 12 | 13 | 14 | __all__ = ["End2End_AvgPooling"] 15 | 16 | 17 | class AvgPooling(nn.Module): 18 | def __init__(self, input_feature_size, num_classes, mode, embeding_fea_size=1024, dropout=0.5): 19 | super(self.__class__, self).__init__() 20 | 21 | is_output_feature = {"Dissimilarity":True, "Classification":False} 22 | self.is_output_feature = is_output_feature[mode] 23 | 24 | # embeding 25 | self.embeding = nn.Linear(input_feature_size, embeding_fea_size) 26 | self.embeding_bn = nn.BatchNorm1d(embeding_fea_size) 27 | init.kaiming_normal(self.embeding.weight, mode='fan_out') 28 | init.constant(self.embeding.bias, 0) 29 | init.constant(self.embeding_bn.weight, 1) 30 | init.constant(self.embeding_bn.bias, 0) 31 | self.drop = nn.Dropout(dropout) 32 | 33 | # classifier 34 | self.classify_fc = nn.Linear(embeding_fea_size, num_classes) 35 | init.normal(self.classify_fc.weight, std = 0.001) 36 | init.constant(self.classify_fc.bias, 0) 37 | 38 | def forward(self, inputs): 39 | avg_pool_feat = inputs.mean(dim = 1) 40 | if (not self.training) and self.is_output_feature: 41 | return F.normalize(avg_pool_feat, p=2, dim=1) 42 | 43 | # embeding 44 | net = self.drop(avg_pool_feat) 45 | net = self.embeding(net) 46 | net = self.embeding_bn(net) 47 | net = F.relu(net) 48 | 49 | net = self.drop(net) 50 | 51 | # classifier 52 | predict = self.classify_fc(net) 53 | return predict 54 | 55 | 56 | 57 | class End2End_AvgPooling(nn.Module): 58 | 59 | def __init__(self, pretrained=True, dropout=0, num_classes=0, mode="retrieval"): 60 | super(self.__class__, self).__init__() 61 | 62 | self.CNN = resnet50(dropout=dropout) 63 | self.avg_pooling = AvgPooling(input_feature_size=2048, num_classes=num_classes, dropout=dropout, mode=mode) 64 | 65 | def forward(self, x): 66 | assert len(x.data.shape) == 5 67 | # reshape (batch, samples, ...) ==> (batch * samples, ...) 68 | oriShape = x.data.shape 69 | x = x.view(-1, oriShape[2], oriShape[3], oriShape[4]) 70 | 71 | # resnet encoding 72 | resnet_feature = self.CNN(x) 73 | 74 | # reshape back into (batch, samples, ...) 
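        # (the last dimension is the 2048-d ResNet-50 feature of each frame, matching AvgPooling's input_feature_size)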
--------------------------------------------------------------------------------
/reid/models/end2end.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
from torch.nn import init
import torch
import torchvision
import math

from .resnet import *


__all__ = ["End2End_AvgPooling"]


class AvgPooling(nn.Module):
    def __init__(self, input_feature_size, num_classes, mode, embeding_fea_size=1024, dropout=0.5):
        super(AvgPooling, self).__init__()

        is_output_feature = {"Dissimilarity": True, "Classification": False}
        self.is_output_feature = is_output_feature[mode]

        # embedding
        self.embeding = nn.Linear(input_feature_size, embeding_fea_size)
        self.embeding_bn = nn.BatchNorm1d(embeding_fea_size)
        init.kaiming_normal(self.embeding.weight, mode='fan_out')
        init.constant(self.embeding.bias, 0)
        init.constant(self.embeding_bn.weight, 1)
        init.constant(self.embeding_bn.bias, 0)
        self.drop = nn.Dropout(dropout)

        # classifier
        self.classify_fc = nn.Linear(embeding_fea_size, num_classes)
        init.normal(self.classify_fc.weight, std=0.001)
        init.constant(self.classify_fc.bias, 0)

    def forward(self, inputs):
        avg_pool_feat = inputs.mean(dim=1)
        # At eval time in "Dissimilarity" mode, return the L2-normalized
        # pooled feature; otherwise run the embedding + classifier head
        # and return logits.
        if (not self.training) and self.is_output_feature:
            return F.normalize(avg_pool_feat, p=2, dim=1)

        # embedding
        net = self.drop(avg_pool_feat)
        net = self.embeding(net)
        net = self.embeding_bn(net)
        net = F.relu(net)

        net = self.drop(net)

        # classifier
        predict = self.classify_fc(net)
        return predict


class End2End_AvgPooling(nn.Module):

    def __init__(self, pretrained=True, dropout=0, num_classes=0, mode="Dissimilarity"):
        # `mode` must be a key of AvgPooling's mode table:
        # "Dissimilarity" or "Classification".
        super(End2End_AvgPooling, self).__init__()

        self.CNN = resnet50(dropout=dropout)
        self.avg_pooling = AvgPooling(input_feature_size=2048, num_classes=num_classes, dropout=dropout, mode=mode)

    def forward(self, x):
        assert len(x.data.shape) == 5
        # reshape (batch, samples, ...) ==> (batch * samples, ...)
        oriShape = x.data.shape
        x = x.view(-1, oriShape[2], oriShape[3], oriShape[4])

        # resnet encoding
        resnet_feature = self.CNN(x)

        # reshape back into (batch, samples, ...)
        resnet_feature = resnet_feature.view(oriShape[0], oriShape[1], -1)

        # average pooling over the frame dimension; returns either a
        # tracklet feature (eval + Dissimilarity) or classification logits
        predict = self.avg_pooling(resnet_feature)
        return predict
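A small shape sketch for the wrapper above (tensor sizes are illustrative: 4 tracklets of 16 frames, 3x256x128 crops; Variable matches the PyTorch version this repo targets):

import torch
from torch.autograd import Variable

clips = Variable(torch.randn(4, 16, 3, 256, 128))

model = End2End_AvgPooling(dropout=0.5, num_classes=702, mode="Classification")
model.train()
logits = model(clips)        # -> (4, 702) logits for the classification head

model = End2End_AvgPooling(dropout=0.5, num_classes=702, mode="Dissimilarity")
model.eval()
feats = model(clips)         # -> (4, 2048) L2-normalized tracklet features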
--------------------------------------------------------------------------------
/reid/models/resnet.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from torch import nn
from torch.nn import functional as F
from torch.nn import init
import torchvision


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


class ResNet(nn.Module):
    __factory = {
        18: torchvision.models.resnet18,
        34: torchvision.models.resnet34,
        50: torchvision.models.resnet50,
        101: torchvision.models.resnet101,
        152: torchvision.models.resnet152,
    }

    def __init__(self, depth, pretrained=True, cut_at_pooling=False,
                 num_features=0, norm=False, dropout=0, num_classes=0):
        super(ResNet, self).__init__()

        self.depth = depth
        self.pretrained = pretrained
        self.cut_at_pooling = cut_at_pooling

        # Construct base (pretrained) resnet
        if depth not in ResNet.__factory:
            raise KeyError("Unsupported depth:", depth)

        self.base = ResNet.__factory[depth](pretrained=pretrained)

        # freeze layers [conv1 ~ layer2]
        fixed_names = []
        for name, module in self.base._modules.items():
            if name == "layer3":
                assert fixed_names == ["conv1", "bn1", "relu", "maxpool", "layer1", "layer2"]
                break
            fixed_names.append(name)
            for param in module.parameters():
                param.requires_grad = False

        if not self.cut_at_pooling:
            self.num_features = num_features
            self.norm = norm
            self.dropout = dropout
            self.has_embedding = num_features > 0
            self.num_classes = num_classes

            out_planes = self.base.fc.in_features

            # Append new layers
            if self.has_embedding:
                self.feat = nn.Linear(out_planes, self.num_features)
                self.feat_bn = nn.BatchNorm1d(self.num_features)
                init.kaiming_normal(self.feat.weight, mode='fan_out')
                init.constant(self.feat.bias, 0)
                init.constant(self.feat_bn.weight, 1)
                init.constant(self.feat_bn.bias, 0)
            else:
                # Change the num_features to CNN output channels
                self.num_features = out_planes
            if self.dropout > 0:
                self.drop = nn.Dropout(self.dropout)
            if self.num_classes > 0:
                self.classifier = nn.Linear(self.num_features, self.num_classes)
                init.normal(self.classifier.weight, std=0.001)
                init.constant(self.classifier.bias, 0)

        if not self.pretrained:
            self.reset_params()

    def forward(self, x):
        for name, module in self.base._modules.items():
            if name == 'avgpool':
                break
            x = module(x)

        if self.cut_at_pooling:
            return x

        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(x.size(0), -1)

        # The embedding / classifier branch below is deliberately disabled:
        # this backbone only produces pooled convolutional features, and the
        # embedding + classifier live in AvgPooling (see end2end.py).
        """
        if self.has_embedding:
            x = self.feat(x)
            x = self.feat_bn(x)
        if self.norm:
            x = F.normalize(x)
        elif self.has_embedding:
            x = F.relu(x)
        if self.dropout > 0:
            x = self.drop(x)
        if self.num_classes > 0:
            x = self.classifier(x)
        """
        return x

    def reset_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant(m.weight, 1)
                init.constant(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant(m.bias, 0)


def resnet18(**kwargs):
    return ResNet(18, **kwargs)


def resnet34(**kwargs):
    return ResNet(34, **kwargs)


def resnet50(**kwargs):
    return ResNet(50, **kwargs)


def resnet101(**kwargs):
    return ResNet(101, **kwargs)


def resnet152(**kwargs):
    return ResNet(152, **kwargs)
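A sanity-check sketch of the backbone contract (input size is illustrative; the frozen blocks keep requires_grad=False and forward returns pooled 2048-d features for ResNet-50):

import torch
from torch.autograd import Variable

cnn = resnet50(pretrained=True, dropout=0.5)

# conv1 ~ layer2 are frozen; none of their parameters receive gradients.
assert not any(p.requires_grad
               for n, p in cnn.named_parameters()
               if n.startswith("base.layer2"))

feat = cnn(Variable(torch.randn(8, 3, 256, 128)))   # -> (8, 2048) pooled features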
--------------------------------------------------------------------------------
/reid/trainers.py:
--------------------------------------------------------------------------------
from __future__ import print_function, absolute_import
import time

import torch
from torch.autograd import Variable

from .evaluation_metrics import accuracy
from .loss import OIMLoss, TripletLoss
from .utils.meters import AverageMeter


class BaseTrainer(object):
    def __init__(self, model, criterion):
        super(BaseTrainer, self).__init__()
        self.model = model
        self.criterion = criterion

    def train(self, epoch, data_loader, optimizer, print_freq=30):
        self.model.train()

        # Keep the BatchNorm layers of the frozen blocks (conv1 ~ layer2) in
        # eval mode so their running statistics are not updated.
        fixed_bns = []
        for name, module in self.model.module.named_modules():
            if name.find("layer3") != -1:
                assert len(fixed_bns) == 22
                break
            if name.find("bn") != -1:
                fixed_bns.append(name)
                module.eval()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            data_time.update(time.time() - end)

            inputs, targets = self._parse_data(inputs)
            loss, prec1 = self._forward(inputs, targets)

            losses.update(loss.data[0], targets.size(0))
            precisions.update(prec1, targets.size(0))

            optimizer.zero_grad()
            loss.backward()
            # Optionally clip gradient norms to guard against exploding
            # gradients (disabled by default).
            # torch.nn.utils.clip_grad_norm(self.model.parameters(), 0.75)
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec {:.2%} ({:.2%})\t'
                      .format(epoch, i + 1, len(data_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              losses.val, losses.avg,
                              precisions.val, precisions.avg))

    def _parse_data(self, inputs):
        raise NotImplementedError

    def _forward(self, inputs, targets):
        raise NotImplementedError


class Trainer(BaseTrainer):
    def _parse_data(self, inputs):
        imgs, _, pids, _, _ = inputs
        inputs = Variable(imgs, requires_grad=False)
        targets = Variable(pids.cuda())
        return inputs, targets

    def _forward(self, inputs, targets):
        outputs = self.model(inputs)
        if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
            loss = self.criterion(outputs, targets)
            prec, = accuracy(outputs.data, targets.data)
            prec = prec[0]
        elif isinstance(self.criterion, OIMLoss):
            loss, outputs = self.criterion(outputs, targets)
            prec, = accuracy(outputs.data, targets.data)
            prec = prec[0]
        elif isinstance(self.criterion, TripletLoss):
            loss, prec = self.criterion(outputs, targets)
        else:
            raise ValueError("Unsupported loss:", self.criterion)
        return loss, prec
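A hedged sketch of how Trainer is driven (the data loader and optimizer hyperparameters are illustrative; the DataParallel wrapper is required because train() accesses self.model.module, and only trainable parameters are handed to SGD since conv1 ~ layer2 are frozen):

import torch
from torch import nn

model = nn.DataParallel(End2End_AvgPooling(num_classes=702, mode="Classification")).cuda()
criterion = nn.CrossEntropyLoss().cuda()
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=5e-4)

trainer = Trainer(model, criterion)
for epoch in range(70):
    trainer.train(epoch, train_loader, optimizer)   # train_loader: hypothetical DataLoader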
--------------------------------------------------------------------------------
/reid/utils/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch


def to_numpy(tensor):
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    elif type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor


def to_torch(ndarray):
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray
--------------------------------------------------------------------------------
/reid/utils/data/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from .dataset import Dataset
from .preprocessor import Preprocessor
--------------------------------------------------------------------------------
/reid/utils/data/dataset.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os.path as osp

import numpy as np

from ..serialization import read_json


def _pluck(identities, indices, relabel=False):
    ret = []
    for index, pid in enumerate(indices):
        pid_images = identities[pid]
        for camid, video_ids in enumerate(pid_images):
            for video_id in video_ids:
                images = video_ids[video_id]
                for fname in images:
                    # leftover per-frame parse; `name` is currently unused
                    name = osp.splitext(fname)[0]
                # one entry per tracklet, carrying its full frame list
                if relabel:
                    ret.append((images, index, camid, video_id))
                else:
                    ret.append((images, pid, camid, video_id))
    return ret


class Dataset(object):
    def __init__(self, root, split_id=0):
        self.root = root
        self.split_id = split_id
        self.meta = None
        self.split = None
        self.train = []
        self.query, self.gallery = [], []
        self.num_train_ids = 0

    @property
    def images_dir(self):
        return osp.join(self.root, 'images')

    def load(self, verbose=True):
        splits = read_json(osp.join(self.root, 'splits.json'))
        if self.split_id >= len(splits):
            raise ValueError("split_id exceeds the total number of splits {}"
                             .format(len(splits)))
        self.split = splits[self.split_id]

        train_pids = np.asarray(self.split['train'])

        self.meta = read_json(osp.join(self.root, 'meta.json'))
        identities = self.meta['identities']
        self.train = _pluck(identities, train_pids, relabel=True)
        self.query = _pluck(identities, self.split['query'])
        self.gallery = _pluck(identities, self.split['gallery'])
        self.num_train_ids = len(train_pids)

        # If the meta file carries explicit query/gallery tracklet lists,
        # rebuild the evaluation sets from them; frame names follow the
        # pid_cam_vid_frame pattern.
        if 'query' in self.meta:
            query_fnames = self.meta['query']
            gallery_fnames = self.meta['gallery']
            self.query = []
            for fname_list in query_fnames:
                name = osp.splitext(fname_list[0])[0]
                pid, cam, vid, _ = map(int, name.split('_'))
                self.query.append((tuple(fname_list), pid, cam, vid))
            self.gallery = []
            for fname_list in gallery_fnames:
                name = osp.splitext(fname_list[0])[0]
                pid, cam, vid, _ = map(int, name.split('_'))
                self.gallery.append((tuple(fname_list), pid, cam, vid))

        if verbose:
            print(self.__class__.__name__, "dataset loaded")
            print("  subset   | # ids | # tracklets")
            print("  ---------------------------")
            print("  train    | {:5d} | {:8d}"
                  .format(self.num_train_ids, len(self.train)))
            print("  query    | {:5d} | {:8d}"
                  .format(len(self.split['query']), len(self.query)))
            print("  gallery  | {:5d} | {:8d}"
                  .format(len(self.split['gallery']), len(self.gallery)))

    def _check_integrity(self):
        return osp.isdir(osp.join(self.root, 'images')) and \
            osp.isfile(osp.join(self.root, 'meta.json')) and \
            osp.isfile(osp.join(self.root, 'splits.json'))
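The pid_cam_vid_frame naming convention that load() relies on, shown on a made-up frame name:

import os.path as osp

fname = "0001_02_0003_0168.jpg"   # hypothetical tracklet frame
pid, cam, vid, frame = map(int, osp.splitext(fname)[0].split('_'))
# pid=1 (identity), cam=2 (camera), vid=3 (tracklet id), frame=168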
--------------------------------------------------------------------------------
/reid/utils/data/preprocessor.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os.path as osp

from PIL import Image
import numpy as np
import torch
import random


class Preprocessor(object):
    def __init__(self, dataset, root=None, transform=None, num_samples=16, is_training=True, max_frames=900):
        super(Preprocessor, self).__init__()
        self.dataset = dataset
        self.root = root
        self.transform = transform
        self.selected_frames_num = num_samples
        self.is_training = is_training
        self.max_frames = max_frames

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, indices):
        if isinstance(indices, (tuple, list)):
            return [self._get_single_item(index) for index in indices]
        return self._get_single_item(indices)

    def _get_single_item(self, index):
        images, pid, camid, videoid = self.dataset[index]
        image_str = "".join(images)

        # randomly sample frames during training
        if self.is_training:
            if len(images) >= self.selected_frames_num:
                images = random.sample(images, self.selected_frames_num)
            else:
                images = random.choices(images, k=self.selected_frames_num)
            images.sort()
        else:
            # for evaluation, use all the frames, capped at max_frames
            # to avoid running out of memory
            if len(images) > self.max_frames:
                images = random.sample(images, self.max_frames)

        video_frames = []
        for fname in images:
            # fall back to the bare file name when no root is given,
            # so fpath is always defined
            fpath = fname
            if self.root is not None:
                fpath = osp.join(self.root, fname)
            img = Image.open(fpath).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            video_frames.append(img)

        video_frames = torch.stack(video_frames, dim=0)
        pid = int(pid)
        return video_frames, image_str, pid, camid, videoid
--------------------------------------------------------------------------------
/reid/utils/data/sampler.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data.sampler import (
    Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler,
    WeightedRandomSampler)


class RandomIdentitySampler(Sampler):
    def __init__(self, data_source, num_instances=1):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_samples = len(self.pids)

    def __len__(self):
        return self.num_samples * self.num_instances

    def __iter__(self):
        indices = torch.randperm(self.num_samples)
        ret = []
        for i in indices:
            pid = self.pids[i]
            t = self.index_dic[pid]
            # draw num_instances indices per identity, with replacement
            # only when the identity has too few samples
            if len(t) >= self.num_instances:
                t = np.random.choice(t, size=self.num_instances, replace=False)
            else:
                t = np.random.choice(t, size=self.num_instances, replace=True)
            ret.extend(t)
        return iter(ret)
--------------------------------------------------------------------------------
/reid/utils/data/transforms.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from torchvision.transforms import *
from PIL import Image
import random, math


class RectScale(object):
    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.size
        if h == self.height and w == self.width:
            return img
        return img.resize((self.width, self.height), self.interpolation)


class RandomSizedRectCrop(object):
    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def __call__(self, img):
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.64, 1.0) * area
            aspect_ratio = random.uniform(2, 3)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert (img.size == (w, h))

                return img.resize((self.width, self.height), self.interpolation)

        # Fallback
        scale = RectScale(self.height, self.width,
                          interpolation=self.interpolation)
        return scale(img)
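A hedged sketch of a typical pipeline built from these transforms plus stock torchvision ones (the 256x128 crop size and the ImageNet normalization statistics are illustrative, not taken from this repo's configs):

from torchvision import transforms as T

train_transform = T.Compose([
    RandomSizedRectCrop(256, 128),      # random area/aspect crop, then resize
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet statistics
                std=[0.229, 0.224, 0.225]),
])

test_transform = T.Compose([
    RectScale(256, 128),                # plain resize, no cropping
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])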
--------------------------------------------------------------------------------
/reid/utils/logging.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os
import sys

from .osutils import mkdir_if_missing


class Logger(object):
    """Tees everything written to stdout into a log file."""

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        # Only close the log file; closing self.console would close
        # sys.stdout and break all subsequent printing.
        if self.file is not None:
            self.file.close()
--------------------------------------------------------------------------------
/reid/utils/meters.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/reid/utils/osutils.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os
import errno


def mkdir_if_missing(dir_path):
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
--------------------------------------------------------------------------------
/reid/utils/serialization.py:
--------------------------------------------------------------------------------
from __future__ import print_function, absolute_import
import json
import os.path as osp
import shutil

import torch
from torch.nn import Parameter

from .osutils import mkdir_if_missing


def read_json(fpath):
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    mkdir_if_missing(osp.dirname(fpath))
    torch.save(state, fpath)
    if is_best:
        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))


def load_checkpoint(fpath):
    if osp.isfile(fpath):
        checkpoint = torch.load(fpath)
        print("=> Loaded checkpoint '{}'".format(fpath))
        return checkpoint
    else:
        raise ValueError("=> No checkpoint found at '{}'".format(fpath))


def copy_state_dict(state_dict, model, strip=None):
    tgt_state = model.state_dict()
    copied_names = set()
    for name, param in state_dict.items():
        if strip is not None and name.startswith(strip):
            name = name[len(strip):]
        if name not in tgt_state:
            continue
        if isinstance(param, Parameter):
            param = param.data
        if param.size() != tgt_state[name].size():
            print('mismatch:', name, param.size(), tgt_state[name].size())
            continue
        tgt_state[name].copy_(param)
        copied_names.add(name)

    missing = set(tgt_state.keys()) - copied_names
    if len(missing) > 0:
        print("missing keys in state_dict:", missing)

    return model
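A small sketch of copy_state_dict's strip argument (the checkpoint path and its 'state_dict' key are illustrative assumptions, not guaranteed by this repo):

# A checkpoint saved from a DataParallel model prefixes keys with "module.";
# strip removes that prefix so the weights load into a bare model.
checkpoint = load_checkpoint('logs/step_0.ckpt')          # hypothetical path
model = copy_state_dict(checkpoint['state_dict'], model, strip='module.')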
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
from __future__ import print_function, absolute_import
from reid.eug import *
from reid import datasets
from reid import models
import numpy as np
import torch
import argparse
import os

from reid.utils.logging import Logger
import os.path as osp
import sys
from torch.backends import cudnn
from reid.utils.serialization import load_checkpoint
from torch import nn
import time
import pickle


def resume(args):
    import re
    pattern = re.compile(r'step_(\d+)\.ckpt')
    start_step = -1
    ckpt_file = ""

    # find the latest checkpoint step in the logs directory
    files = os.listdir(args.logs_dir)
    files.sort()
    for filename in files:
        try:
            iter_ = int(pattern.search(filename).groups()[0])
            if iter_ > start_step:
                start_step = iter_
                ckpt_file = osp.join(args.logs_dir, filename)
        except AttributeError:
            # pattern.search returns None for files that are not checkpoints
            continue

    if start_step >= 0:
        print("continued from iter step", start_step)

    return start_step, ckpt_file


def main(args):
    cudnn.benchmark = True
    cudnn.enabled = True
    save_path = args.logs_dir
    total_step = 100 // args.EF + 1
    sys.stdout = Logger(osp.join(args.logs_dir, 'log' + str(args.EF) + time.strftime(".%m_%d_%H:%M:%S") + '.txt'))

    # get all the labeled and unlabeled data for training
    dataset_all = datasets.create(args.dataset, osp.join(args.data_dir, args.dataset))
    num_all_examples = len(dataset_all.train)
    l_data, u_data = get_one_shot_in_cam1(dataset_all, load_path="./examples/oneshot_{}_used_in_paper.pickle".format(dataset_all.name))

    resume_step, ckpt_file = -1, ''
    if args.resume:
        resume_step, ckpt_file = resume(args)

    # initialize the EUG algorithm
    eug = EUG(model_name=args.arch, batch_size=args.batch_size, mode=args.mode, num_classes=dataset_all.num_train_ids,
              data_dir=dataset_all.images_dir, l_data=l_data, u_data=u_data, save_path=args.logs_dir, max_frames=args.max_frames)

    new_train_data = l_data
    for step in range(total_step):
        # skip the steps that were already finished when resuming
        if step < resume_step:
            continue

        nums_to_select = min(int(len(u_data) * (step + 1) * args.EF / 100), len(u_data))
        print("This is running {} with EF={}%, step {}:\t nums_to_select {}, \t logs-dir {}".format(
            args.mode, args.EF, step, nums_to_select, save_path))

        # train the model at this step, or load its checkpoint when resuming
        if step == resume_step:
            eug.resume(ckpt_file, step)
        else:
            eug.train(new_train_data, step, epochs=70, step_size=55, init_lr=0.1)

        # pseudo-label the unlabeled data and get confidence scores
        pred_y, pred_score = eug.estimate_label()

        # select the most confident samples
        selected_idx = eug.select_top_data(pred_score, nums_to_select)

        # enlarge the training set with the selected pseudo-labeled data
        new_train_data = eug.generate_new_train_data(selected_idx, pred_y)

        # evaluate
        eug.evaluate(dataset_all.query, dataset_all.gallery)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Exploit the Unknown Gradually')
    parser.add_argument('-d', '--dataset', type=str, default='mars',
                        choices=datasets.names())
    parser.add_argument('-b', '--batch-size', type=int, default=16)
    parser.add_argument('-a', '--arch', type=str, default='avg_pool',
                        choices=models.names())
    parser.add_argument('-i', '--iter-step', type=int, default=5)
    parser.add_argument('-g', '--gamma', type=float, default=0.3)
    parser.add_argument('-l', '--l', type=float)
    parser.add_argument('--EF', type=int, default=10)
    working_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument('--data_dir', type=str, metavar='PATH',
                        default=os.path.join(working_dir, 'data'))
    parser.add_argument('--logs_dir', type=str, metavar='PATH',
                        default=os.path.join(working_dir, 'logs'))
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--continuous', action="store_true")
    parser.add_argument('--mode', type=str, choices=["Classification", "Dissimilarity"])
    parser.add_argument('--max_frames', type=int, default=900)
    main(parser.parse_args())
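To make the enlargement schedule concrete, a worked example of the arithmetic above (the pool size is made up):

u = 5000                                    # hypothetical size of the unlabeled pool
EF = 10
total_step = 100 // EF + 1                  # 11 steps (0..10)
schedule = [min(int(u * (k + 1) * EF / 100), u) for k in range(total_step)]
# -> [500, 1000, 1500, ..., 5000, 5000]: grow by EF% of the pool per step, capped at u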
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
# run DukeMTMC-VideoReID
python3 run.py --dataset DukeMTMC-VideoReID --logs_dir logs/DukeMTMC-VideoReID_EF_10/ --EF 10 --mode Dissimilarity --max_frames 900

# run MARS
#python3 run.py --dataset mars --logs_dir logs/mars_EF_10/ --EF 10 --mode Dissimilarity --max_frames 900

# if you need to resume
#python3 run.py --dataset mars --logs_dir logs/mars_EF_10/ --EF 10 --mode Dissimilarity --max_frames 900 --resume logs/mars_EF_10/
--------------------------------------------------------------------------------