├── .gitignore ├── .gitmodules ├── DATASETS.md ├── LICENSE ├── README.md ├── docs └── visualization │ └── coseg_model_mars_viz.ipynb ├── scratch └── readme └── src ├── data_manager.py ├── eval_metrics.py ├── losses.py ├── main_video_person_reid.py ├── models ├── ResNet.py ├── SE_ResNet.py ├── __init__.py ├── aggregation_layers.py ├── cosam.py └── senet.py ├── project_utils.py ├── samplers.py ├── transforms.py ├── utils.py └── video_loader.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/person_re_ranking"] 2 | path = src/person_re_ranking 3 | url = https://github.com/InnovArul/person_re_ranking 4 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | ## MARS 2 | 3 | MARS is the largest dataset available to date for video-based person reID. The instructions are copied here: 4 | 5 | * Create a directory named mars/ under data/. 6 | * Download dataset to data/mars/ from http://www.liangzheng.com.cn/Project/project_mars.html. 7 | * Extract bbox_train.zip and bbox_test.zip. 8 | * Download split information from https://github.com/liangzheng06/MARS-evaluation/tree/master/info and put info/ in data/mars (we want to follow the standard split in [8]). 9 | * The data structure would look like: 10 | 11 | ``` 12 | mars/ 13 | bbox_test/ 14 | bbox_train/ 15 | info/ 16 | ``` 17 | Use -d mars when running the training code. 18 | 19 | ## DukeMTMC-VideoReID 20 | 21 | * Create a directory named dukemtmc-vidreid/ under data/. 
22 | 23 | * Download “DukeMTMC-VideoReID” from http://vision.cs.duke.edu/DukeMTMC/ and unzip the file to “dukemtmc-vidreid/”. 24 | 25 | * The data structure should look like 26 | 27 | ``` 28 | dukemtmc-vidreid/ 29 | DukeMTMC-VideoReID/ 30 | train/ 31 | query/ 32 | gallery/ 33 | ``` 34 | 35 | ## iLIDS-VID 36 | 37 | * Create a directory named ilids-vid/ under data/. 38 | 39 | * Download the dataset from http://www.eecs.qmul.ac.uk/~xiatian/downloads_qmul_iLIDS-VID_ReID_dataset.html to "ilids-vid". 40 | 41 | * Organize the data structure to match 42 | 43 | ``` 44 | ilids-vid/ 45 | i-LIDS-VID/ 46 | train-test people splits 47 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Video-based Person Re-identification 2 | Source code for the ICCV-2019 paper "Co-segmentation Inspired Attention Networks for Video-based Person Re-identification". Our paper can be found here. 
3 | 4 | ## Introduction 5 | 6 | Our work tackles some of the challenges in video-based person re-identification (Re-ID), such as background clutter, misalignment errors and partial occlusion, by means of a co-segmentation-inspired approach. The intention is to attend to the task-dependent common portions of the images (i.e., the video frames of a person), which may help the network focus on the most relevant features. This repository contains the code for our co-segmentation-inspired Re-ID architecture, the “Co-segmentation Activation Module (COSAM)”. 7 | The co-segmentation masks are interpretable and help in understanding how and where the network attends while building a description of a person. 8 | 9 | ### Credits 10 | 11 | The source code is built upon the GitHub repositories Video-Person-ReID (from jiyanggao) and deep-person-reid (from KaiyangZhou). In particular, the data-loading, data-sampling and training parts are borrowed from their repositories. The strong baseline performances are based on the models from the Video-Person-ReID codebase. Check out their papers: Revisiting Temporal Modeling for Video-based Person ReID (Gao et al.) and OSNet (Zhou et al., ICCV 2019). 12 | 13 | We would like to thank jiyanggao and KaiyangZhou for generously releasing their code to the community. 14 | 15 | ## Datasets 16 | 17 | Dataset preparation instructions can be found in the Video-Person-ReID and deep-person-reid repositories. For completeness, the dataset instructions are also compiled in DATASETS.md. 18 | 19 | ## Training 20 | 21 | `python main_video_person_reid.py -a resnet50_cosam45_tp -d <dataset_name> --gpu-devices <gpu_ids>` 22 | 23 | `<dataset_name>` can be `mars` or `dukemtmcvidreid` 24 | 25 | ## Testing 26 | 27 | `python main_video_person_reid.py -a resnet50_cosam45_tp -d <dataset_name> --gpu-devices <gpu_ids> --evaluate --pretrained-model <path_to_pretrained_model>` 28 | 29 | -------------------------------------------------------------------------------- /scratch/readme: -------------------------------------------------------------------------------- 1 | the training folders will be created here -------------------------------------------------------------------------------- /src/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import glob 4 | import re 5 | import sys 6 | import urllib 7 | import tarfile 8 | import zipfile 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | import numpy as np 12 | 13 | from utils import mkdir_if_missing, write_json, read_json 14 | 15 | """Dataset classes""" 16 | 17 | 18 | class Mars(object): 19 | """ 20 | MARS 21 | 22 | Reference: 23 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 24 | 25 | Dataset statistics: 26 | # identities: 1261 27 | # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery) 28 | # cameras: 6 29 | 30 | Args: 31 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0).
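    Example (a minimal usage sketch; assumes the MARS files are arranged as described in DATASETS.md):
        >>> dataset = Mars(root='../data', min_seq_len=0)
        >>> img_paths, pid, camid = dataset.train[0]  # each tracklet is (image paths, person id, camera id)
        >>> print(dataset.num_train_pids)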
32 | """ 33 | 34 | def __init__(self, root='../data', min_seq_len=0): 35 | 36 | self.root = os.path.join(root, 'mars') 37 | 38 | root = self.root 39 | self.train_name_path = osp.join(root, 'info/train_name.txt') 40 | self.test_name_path = osp.join(root, 'info/test_name.txt') 41 | self.track_train_info_path = osp.join(root, 'info/tracks_train_info.mat') 42 | self.track_test_info_path = osp.join(root, 'info/tracks_test_info.mat') 43 | self.query_IDX_path = osp.join(root, 'info/query_IDX.mat') 44 | 45 | self._check_before_run() 46 | 47 | # prepare meta data 48 | train_names = self._get_names(self.train_name_path) 49 | test_names = self._get_names(self.test_name_path) 50 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 51 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 52 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 53 | query_IDX -= 1 # index from 0 54 | track_query = track_test[query_IDX,:] 55 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 56 | track_gallery = track_test[gallery_IDX,:] 57 | 58 | train, num_train_tracklets, num_train_pids, num_train_imgs = \ 59 | self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 60 | 61 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 62 | self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 63 | 64 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 65 | self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 66 | 67 | num_imgs_per_tracklet = num_train_imgs + num_query_imgs + num_gallery_imgs 68 | min_num = np.min(num_imgs_per_tracklet) 69 | max_num = np.max(num_imgs_per_tracklet) 70 | avg_num = np.mean(num_imgs_per_tracklet) 71 | 72 | num_total_pids = num_train_pids + num_query_pids 73 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 74 | 75 | print("=> MARS loaded") 76 | print("Dataset statistics:") 77 | print(" ------------------------------") 78 | print(" subset | # ids | # tracklets") 79 | print(" ------------------------------") 80 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 81 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 82 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 83 | print(" ------------------------------") 84 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 85 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 86 | print(" ------------------------------") 87 | 88 | self.train = train 89 | self.query = query 90 | self.gallery = gallery 91 | 92 | self.num_train_pids = num_train_pids 93 | self.num_query_pids = num_query_pids 94 | self.num_gallery_pids = num_gallery_pids 95 | 96 | def _check_before_run(self): 97 | """Check if all files are available before going deeper""" 98 | if not osp.exists(self.root): 99 | raise RuntimeError("'{}' is not available".format(self.root)) 100 | if not osp.exists(self.train_name_path): 101 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 102 | if not osp.exists(self.test_name_path): 103 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 104 | if not 
osp.exists(self.track_train_info_path): 105 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 106 | if not osp.exists(self.track_test_info_path): 107 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 108 | if not osp.exists(self.query_IDX_path): 109 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 110 | 111 | def _get_names(self, fpath): 112 | names = [] 113 | with open(fpath, 'r') as f: 114 | for line in f: 115 | new_line = line.rstrip() 116 | names.append(new_line) 117 | return names 118 | 119 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 120 | assert home_dir in ['bbox_train', 'bbox_test'] 121 | num_tracklets = meta_data.shape[0] 122 | pid_list = list(set(meta_data[:,2].tolist())) 123 | num_pids = len(pid_list) 124 | 125 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 126 | tracklets = [] 127 | num_imgs_per_tracklet = [] 128 | 129 | for tracklet_idx in range(num_tracklets): 130 | data = meta_data[tracklet_idx,...] 131 | start_index, end_index, pid, camid = data 132 | if pid == -1: continue # junk images are just ignored 133 | assert 1 <= camid <= 6 134 | if relabel: pid = pid2label[pid] 135 | camid -= 1 # index starts from 0 136 | img_names = names[start_index-1:end_index] 137 | 138 | # make sure image names correspond to the same person 139 | pnames = [img_name[:4] for img_name in img_names] 140 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 141 | 142 | # make sure all images are captured under the same camera 143 | camnames = [img_name[5] for img_name in img_names] 144 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 145 | 146 | # append image names with directory information 147 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 148 | if len(img_paths) >= min_seq_len: 149 | img_paths = tuple(img_paths) 150 | tracklets.append((img_paths, pid, camid)) 151 | num_imgs_per_tracklet.append(len(img_paths)) 152 | 153 | num_tracklets = len(tracklets) 154 | 155 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 156 | 157 | 158 | 159 | class DukeMTMCVidReID(object): 160 | """ 161 | DukeMTMCVidReID 162 | 163 | Reference: 164 | Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person 165 | Re-Identification by Stepwise Learning. CVPR 2018. 
166 | 167 | URL: https://github.com/Yu-Wu/DukeMTMC-VideoReID 168 | 169 | Dataset statistics: 170 | # identities: 702 (train) + 702 (test) 171 | # tracklets: 2196 (train) + 2636 (test) 172 | """ 173 | dataset_dir = 'dukemtmc-vidreid' 174 | 175 | def __init__(self, root='../data', min_seq_len=0, verbose=True, **kwargs): 176 | self.dataset_dir = osp.join(root, self.dataset_dir) 177 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip' 178 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/train') 179 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/query') 180 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/gallery') 181 | self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json') 182 | self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json') 183 | self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json') 184 | 185 | self.min_seq_len = min_seq_len 186 | self._download_data() 187 | self._check_before_run() 188 | print("Note: if root path is changed, the previously generated json files need to be re-generated (so delete them first)") 189 | 190 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 191 | self._process_dir(self.train_dir, self.split_train_json_path, relabel=True) 192 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 193 | self._process_dir(self.query_dir, self.split_query_json_path, relabel=False) 194 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 195 | self._process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) 196 | 197 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 198 | min_num = np.min(num_imgs_per_tracklet) 199 | max_num = np.max(num_imgs_per_tracklet) 200 | avg_num = np.mean(num_imgs_per_tracklet) 201 | 202 | num_total_pids = num_train_pids + num_query_pids 203 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 204 | 205 | if verbose: 206 | print("=> DukeMTMC-VideoReID loaded") 207 | print("Dataset statistics:") 208 | print(" ------------------------------") 209 | print(" subset | # ids | # tracklets") 210 | print(" ------------------------------") 211 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 212 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 213 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 214 | print(" ------------------------------") 215 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 216 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 217 | print(" ------------------------------") 218 | 219 | self.train = train 220 | self.query = query 221 | self.gallery = gallery 222 | 223 | self.num_train_pids = num_train_pids 224 | self.num_query_pids = num_query_pids 225 | self.num_gallery_pids = num_gallery_pids 226 | 227 | def _download_data(self): 228 | if osp.exists(self.dataset_dir): 229 | print("This dataset has been downloaded.") 230 | return 231 | 232 | print("Creating directory {}".format(self.dataset_dir)) 233 | mkdir_if_missing(self.dataset_dir) 234 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 235 | 236 | print("Downloading DukeMTMC-VideoReID dataset") 237 | url_opener = urllib.URLopener() 238 | url_opener.retrieve(self.dataset_url, fpath) 239 | 240 | print("Extracting files") 
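        # NOTE (added comment, not part of the original code): `urllib.URLopener` used in the
        # download step above exists only in Python 2. Under Python 3, an equivalent download
        # could be written, for example, as:
        #     from urllib.request import urlretrieve
        #     urlretrieve(self.dataset_url, fpath)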
241 | zip_ref = zipfile.ZipFile(fpath, 'r') 242 | zip_ref.extractall(self.dataset_dir) 243 | zip_ref.close() 244 | 245 | def _check_before_run(self): 246 | """Check if all files are available before going deeper""" 247 | if not osp.exists(self.dataset_dir): 248 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 249 | if not osp.exists(self.train_dir): 250 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 251 | if not osp.exists(self.query_dir): 252 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 253 | if not osp.exists(self.gallery_dir): 254 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 255 | 256 | def _process_dir(self, dir_path, json_path, relabel): 257 | if osp.exists(json_path): 258 | print("=> {} generated before, awesome!".format(json_path)) 259 | split = read_json(json_path) 260 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet'] 261 | 262 | print("=> Automatically generating split (might take a while for the first time, have a coffe)") 263 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store 264 | print("Processing {} with {} person identities".format(dir_path, len(pdirs))) 265 | 266 | pid_container = set() 267 | for pdir in pdirs: 268 | pid = int(osp.basename(pdir)) 269 | pid_container.add(pid) 270 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 271 | 272 | tracklets = [] 273 | num_imgs_per_tracklet = [] 274 | for pdir in pdirs: 275 | pid = int(osp.basename(pdir)) 276 | if relabel: pid = pid2label[pid] 277 | tdirs = glob.glob(osp.join(pdir, '*')) 278 | for tdir in tdirs: 279 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg')) 280 | num_imgs = len(raw_img_paths) 281 | 282 | if num_imgs < self.min_seq_len: 283 | continue 284 | 285 | num_imgs_per_tracklet.append(num_imgs) 286 | img_paths = [] 287 | for img_idx in range(num_imgs): 288 | # some tracklet starts from 0002 instead of 0001 289 | img_idx_name = 'F' + str(img_idx+1).zfill(4) 290 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')) 291 | if len(res) == 0: 292 | print("Warn: index name {} in {} is missing, jump to next".format(img_idx_name, tdir)) 293 | continue 294 | img_paths.append(res[0]) 295 | img_name = osp.basename(img_paths[0]) 296 | if img_name.find('_') == -1: 297 | # old naming format: 0001C6F0099X30823.jpg 298 | camid = int(img_name[5]) - 1 299 | else: 300 | # new naming format: 0001_C6_F0099_X30823.jpg 301 | camid = int(img_name[6]) - 1 302 | img_paths = tuple(img_paths) 303 | tracklets.append((img_paths, pid, camid)) 304 | 305 | num_pids = len(pid_container) 306 | num_tracklets = len(tracklets) 307 | 308 | print("Saving split to {}".format(json_path)) 309 | split_dict = { 310 | 'tracklets': tracklets, 311 | 'num_tracklets': num_tracklets, 312 | 'num_pids': num_pids, 313 | 'num_imgs_per_tracklet': num_imgs_per_tracklet, 314 | } 315 | write_json(split_dict, json_path) 316 | 317 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 318 | 319 | 320 | class iLIDSVID(object): 321 | """ 322 | iLIDS-VID 323 | 324 | Reference: 325 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 326 | 327 | Dataset statistics: 328 | # identities: 300 329 | # tracklets: 600 330 | # cameras: 2 331 | 332 | Args: 333 | split_id (int): indicates which split to use. There are totally 10 splits. 
334 | """ 335 | root = '../data/ilids-vid' 336 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 337 | data_dir = osp.join(root, 'i-LIDS-VID') 338 | split_dir = osp.join(root, 'train-test people splits') 339 | split_mat_path = osp.join(split_dir, 'train_test_splits_ilidsvid.mat') 340 | split_path = osp.join(root, 'splits.json') 341 | cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1') 342 | cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2') 343 | 344 | def __init__(self, split_id=0): 345 | self._download_data() 346 | self._check_before_run() 347 | 348 | self._prepare_split() 349 | splits = read_json(self.split_path) 350 | if split_id >= len(splits): 351 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 352 | split = splits[split_id] 353 | train_dirs, test_dirs = split['train'], split['test'] 354 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 355 | 356 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 357 | self._process_data(train_dirs, cam1=True, cam2=True) 358 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 359 | self._process_data(test_dirs, cam1=True, cam2=False) 360 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 361 | self._process_data(test_dirs, cam1=False, cam2=True) 362 | 363 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 364 | min_num = np.min(num_imgs_per_tracklet) 365 | max_num = np.max(num_imgs_per_tracklet) 366 | avg_num = np.mean(num_imgs_per_tracklet) 367 | 368 | num_total_pids = num_train_pids + num_query_pids 369 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 370 | 371 | print("=> iLIDS-VID loaded") 372 | print("Dataset statistics:") 373 | print(" ------------------------------") 374 | print(" subset | # ids | # tracklets") 375 | print(" ------------------------------") 376 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 377 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 378 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 379 | print(" ------------------------------") 380 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 381 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 382 | print(" ------------------------------") 383 | 384 | self.train = train 385 | self.query = query 386 | self.gallery = gallery 387 | 388 | self.num_train_pids = num_train_pids 389 | self.num_query_pids = num_query_pids 390 | self.num_gallery_pids = num_gallery_pids 391 | 392 | def _download_data(self): 393 | if osp.exists(self.root): 394 | print("This dataset has been downloaded.") 395 | return 396 | 397 | mkdir_if_missing(self.root) 398 | fpath = osp.join(self.root, osp.basename(self.dataset_url)) 399 | 400 | print("Downloading iLIDS-VID dataset") 401 | url_opener = urllib.URLopener() 402 | url_opener.retrieve(self.dataset_url, fpath) 403 | 404 | print("Extracting files") 405 | tar = tarfile.open(fpath) 406 | tar.extractall(path=self.root) 407 | tar.close() 408 | 409 | def _check_before_run(self): 410 | """Check if all files are available before going deeper""" 411 | if not osp.exists(self.root): 412 | raise RuntimeError("'{}' is not available".format(self.root)) 413 | if not osp.exists(self.data_dir): 414 | raise RuntimeError("'{}' is not 
available".format(self.data_dir)) 415 | if not osp.exists(self.split_dir): 416 | raise RuntimeError("'{}' is not available".format(self.split_dir)) 417 | 418 | def _prepare_split(self): 419 | if not osp.exists(self.split_path): 420 | print("Creating splits") 421 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 422 | 423 | num_splits = mat_split_data.shape[0] 424 | num_total_ids = mat_split_data.shape[1] 425 | assert num_splits == 10 426 | assert num_total_ids == 300 427 | num_ids_each = num_total_ids/2 428 | 429 | # pids in mat_split_data are indices, so we need to transform them 430 | # to real pids 431 | person_cam1_dirs = os.listdir(self.cam_1_path) 432 | person_cam2_dirs = os.listdir(self.cam_2_path) 433 | 434 | # make sure persons in one camera view can be found in the other camera view 435 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 436 | 437 | splits = [] 438 | for i_split in range(num_splits): 439 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 440 | train_idxs = sorted(list(mat_split_data[i_split,num_ids_each:])) 441 | test_idxs = sorted(list(mat_split_data[i_split,:num_ids_each])) 442 | 443 | train_idxs = [int(i)-1 for i in train_idxs] 444 | test_idxs = [int(i)-1 for i in test_idxs] 445 | 446 | # transform pids to person dir names 447 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 448 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 449 | 450 | split = {'train': train_dirs, 'test': test_dirs} 451 | splits.append(split) 452 | 453 | print("Totally {} splits are created, following Wang et al. ECCV'14".format(len(splits))) 454 | print("Split file is saved to {}".format(self.split_path)) 455 | write_json(splits, self.split_path) 456 | 457 | print("Splits created") 458 | 459 | def _process_data(self, dirnames, cam1=True, cam2=True): 460 | tracklets = [] 461 | num_imgs_per_tracklet = [] 462 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 463 | 464 | for dirname in dirnames: 465 | if cam1: 466 | person_dir = osp.join(self.cam_1_path, dirname) 467 | img_names = glob.glob(osp.join(person_dir, '*.png')) 468 | assert len(img_names) > 0 469 | img_names = tuple(img_names) 470 | pid = dirname2pid[dirname] 471 | tracklets.append((img_names, pid, 0)) 472 | num_imgs_per_tracklet.append(len(img_names)) 473 | 474 | if cam2: 475 | person_dir = osp.join(self.cam_2_path, dirname) 476 | img_names = glob.glob(osp.join(person_dir, '*.png')) 477 | assert len(img_names) > 0 478 | img_names = tuple(img_names) 479 | pid = dirname2pid[dirname] 480 | tracklets.append((img_names, pid, 1)) 481 | num_imgs_per_tracklet.append(len(img_names)) 482 | 483 | num_tracklets = len(tracklets) 484 | num_pids = len(dirnames) 485 | 486 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 487 | 488 | """Create dataset""" 489 | 490 | __factory = { 491 | 'mars': Mars, 492 | 'ilidsvid': iLIDSVID, 493 | 'dukemtmcvidreid': DukeMTMCVidReID, 494 | } 495 | 496 | def get_names(): 497 | return __factory.keys() 498 | 499 | def init_dataset(name, *args, **kwargs): 500 | if name not in __factory.keys(): 501 | raise KeyError("Unknown dataset: {}".format(name)) 502 | return __factory[name](*args, **kwargs) 503 | 504 | if __name__ == '__main__': 505 | # test 506 | #dataset = Market1501() 507 | #dataset = Mars() 508 | dataset = iLIDSVID() 509 | dataset = PRID() 510 | -------------------------------------------------------------------------------- /src/eval_metrics.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | import copy 4 | 5 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 6 | num_q, num_g = distmat.shape 7 | if num_g < max_rank: 8 | max_rank = num_g 9 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 10 | indices = np.argsort(distmat, axis=1) 11 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 12 | 13 | # compute cmc curve for each query 14 | all_cmc = [] 15 | all_AP = [] 16 | num_valid_q = 0. 17 | for q_idx in range(num_q): 18 | # get query pid and camid 19 | q_pid = q_pids[q_idx] 20 | q_camid = q_camids[q_idx] 21 | 22 | # remove gallery samples that have the same pid and camid with query 23 | order = indices[q_idx] 24 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 25 | keep = np.invert(remove) 26 | 27 | # compute cmc curve 28 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 29 | if not np.any(orig_cmc): 30 | # this condition is true when query identity does not appear in gallery 31 | continue 32 | 33 | cmc = orig_cmc.cumsum() 34 | cmc[cmc > 1] = 1 35 | 36 | all_cmc.append(cmc[:max_rank]) 37 | num_valid_q += 1. 38 | 39 | # compute average precision 40 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 41 | num_rel = orig_cmc.sum() 42 | tmp_cmc = orig_cmc.cumsum() 43 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 44 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 45 | AP = tmp_cmc.sum() / num_rel 46 | all_AP.append(AP) 47 | 48 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 49 | 50 | all_cmc = np.asarray(all_cmc).astype(np.float32) 51 | all_cmc = all_cmc.sum(0) / num_valid_q 52 | mAP = np.mean(all_AP) 53 | 54 | return all_cmc, mAP 55 | 56 | 57 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | 7 | """ 8 | Shorthands for loss: 9 | - CrossEntropyLabelSmooth: xent 10 | - TripletLoss: htri 11 | - CenterLoss: cent 12 | """ 13 | __all__ = ['CrossEntropyLabelSmooth', 'TripletLoss'] 14 | 15 | 16 | class CrossEntropyLabelSmooth(nn.Module): 17 | """Cross entropy loss with label smoothing regularizer. 18 | 19 | Reference: 20 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 21 | Equation: y = (1 - epsilon) * y + epsilon / K. 22 | 23 | Args: 24 | num_classes (int): number of classes. 25 | epsilon (float): weight. 
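    Example (a minimal sketch; the class count and tensor shapes below are illustrative, not fixed by this code):
        >>> criterion = CrossEntropyLabelSmooth(num_classes=625, use_gpu=False)
        >>> logits = torch.randn(32, 625)            # (batch_size, num_classes) pre-softmax scores
        >>> targets = torch.randint(0, 625, (32,))   # (batch_size,) integer identity labels
        >>> loss = criterion(logits, targets)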
26 | """ 27 | 28 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 29 | super(CrossEntropyLabelSmooth, self).__init__() 30 | self.num_classes = num_classes 31 | self.epsilon = epsilon 32 | self.use_gpu = use_gpu 33 | self.logsoftmax = nn.LogSoftmax(dim=1) 34 | 35 | def forward(self, inputs, targets): 36 | """ 37 | Args: 38 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 39 | targets: ground truth labels with shape (num_classes) 40 | """ 41 | log_probs = self.logsoftmax(inputs) 42 | targets = torch.zeros(log_probs.size()).scatter_( 43 | 1, targets.unsqueeze(1).data.cpu(), 1) 44 | if self.use_gpu: 45 | targets = targets.cuda() 46 | targets = Variable(targets, requires_grad=False) 47 | targets = (1 - self.epsilon) * targets + \ 48 | self.epsilon / self.num_classes 49 | loss = (- targets * log_probs).mean(0).sum() 50 | return loss 51 | 52 | 53 | class TripletLoss(nn.Module): 54 | """Triplet loss with hard positive/negative mining. 55 | 56 | Reference: 57 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 58 | 59 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 60 | 61 | Args: 62 | margin (float): margin for triplet. 63 | """ 64 | 65 | def __init__(self, margin=0.3): 66 | super(TripletLoss, self).__init__() 67 | self.margin = margin 68 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 69 | 70 | def forward(self, inputs, targets): 71 | """ 72 | Args: 73 | inputs: feature matrix with shape (batch_size, feat_dim) 74 | targets: ground truth labels with shape (num_classes) 75 | """ 76 | n = inputs.size(0) 77 | # Compute pairwise distance, replace by the official when merged 78 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 79 | dist = dist + dist.t() 80 | dist.addmm_(1, -2, inputs, inputs.t()) 81 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 82 | # For each anchor, find the hardest positive and negative 83 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 84 | dist_ap, dist_an = [], [] 85 | for i in range(n): 86 | dist_ap.append(dist[i][mask[i]].max()) 87 | dist_an.append(dist[i][mask[i] == 0].min()) 88 | 89 | dist_ap = torch.stack(dist_ap) 90 | dist_an = torch.stack(dist_an) 91 | 92 | # Compute ranking hinge loss 93 | y = dist_an.data.new() 94 | y.resize_as_(dist_an.data) 95 | y.fill_(1) 96 | y = Variable(y) 97 | loss = self.ranking_loss(dist_an, dist_ap, y) 98 | return loss 99 | 100 | 101 | if __name__ == '__main__': 102 | pass 103 | -------------------------------------------------------------------------------- /src/main_video_person_reid.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import sys 4 | import time 5 | import datetime 6 | import argparse 7 | import os.path as osp 8 | import numpy as np 9 | from tqdm import tqdm 10 | import more_itertools as mit 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | from torch.utils.data import DataLoader 16 | from torch.autograd import Variable 17 | from torch.optim import lr_scheduler 18 | 19 | import data_manager 20 | from video_loader import VideoDataset 21 | import transforms as T 22 | import models 23 | import utils 24 | from losses import CrossEntropyLabelSmooth, TripletLoss 25 | from utils import AverageMeter, Logger, save_checkpoint, get_features 26 | from eval_metrics import evaluate 27 | from samplers 
import RandomIdentitySampler 28 | from project_utils import * 29 | 30 | parser = argparse.ArgumentParser( 31 | description='Train video model with cross entropy loss') 32 | # Datasets 33 | parser.add_argument('-d', '--dataset', type=str, default='mars', 34 | choices=data_manager.get_names()) 35 | parser.add_argument('-j', '--workers', default=0, type=int, 36 | help="number of data loading workers (default: 4)") 37 | parser.add_argument('--height', type=int, default=256, 38 | help="height of an image (default: 256)") 39 | parser.add_argument('--width', type=int, default=128, 40 | help="width of an image (default: 128)") 41 | parser.add_argument('--seq-len', type=int, default=4, 42 | help="number of images to sample in a tracklet") 43 | parser.add_argument('--test-num-tracks', type=int, default=16, 44 | help="number of tracklets to pass to GPU during test (to avoid OOM error)") 45 | # Optimization options 46 | parser.add_argument('--max-epoch', default=800, type=int, 47 | help="maximum epochs to run") 48 | parser.add_argument('--start-epoch', default=0, type=int, 49 | help="manual epoch number (useful on restarts)") 50 | parser.add_argument('--data-selection', type=str, 51 | default='random', help="random/evenly") 52 | parser.add_argument('--train-batch', default=32, type=int, 53 | help="train batch size") 54 | parser.add_argument('--test-batch', default=1, type=int, help="has to be 1") 55 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float, 56 | help="initial learning rate, use 0.0001 for rnn, use 0.0003 for pooling and attention") 57 | parser.add_argument('--stepsize', default=200, type=int, 58 | help="stepsize to decay learning rate (>0 means this is enabled)") 59 | parser.add_argument('--gamma', default=0.1, type=float, 60 | help="learning rate decay") 61 | parser.add_argument('--weight-decay', default=5e-04, type=float, 62 | help="weight decay (default: 5e-04)") 63 | parser.add_argument('--margin', type=float, default=0.3, 64 | help="margin for triplet loss") 65 | parser.add_argument('--num-instances', type=int, default=4, 66 | help="number of instances per identity") 67 | 68 | # Architecture 69 | parser.add_argument('-a', '--arch', type=str, default='se_resnet50_cosam45_tp', 70 | help="se_resnet50_tp, se_resnet50_ta") 71 | 72 | # Miscs 73 | parser.add_argument('--print-freq', type=int, 74 | default=78, help="print frequency") 75 | parser.add_argument('--seed', type=int, default=1, help="manual seed") 76 | parser.add_argument('--pretrained-model', type=str, 77 | default=None, help='need to be set for loading pretrained models') 78 | parser.add_argument('--evaluate', action='store_true', help="evaluation only") 79 | parser.add_argument('--eval-step', type=int, default=200, 80 | help="run evaluation for every N epochs (set to -1 to test after training)") 81 | parser.add_argument('--save-step', type=int, default=50) 82 | parser.add_argument('--save-dir', type=str, default='') 83 | parser.add_argument('--save-prefix', type=str, default='se_resnet50_cosam45_tp') 84 | parser.add_argument('--use-cpu', action='store_true', help="use cpu") 85 | parser.add_argument('--gpu-devices', default='1', type=str, 86 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 87 | 88 | args = parser.parse_args() 89 | 90 | 91 | def main(): 92 | args.save_dir = args.arch + '_' + args.save_dir 93 | args.save_prefix = args.arch + '_' + args.save_dir 94 | 95 | torch.manual_seed(args.seed) 96 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 97 | use_gpu = torch.cuda.is_available() 98 | if 
args.use_cpu: 99 | use_gpu = False 100 | 101 | # append date with save_dir 102 | args.save_dir = '../scratch/' + utils.get_currenttime_prefix() + '_' + \ 103 | args.dataset + '_' + args.save_dir 104 | if args.pretrained_model is not None: 105 | args.save_dir = os.path.dirname(args.pretrained_model) 106 | 107 | if not osp.exists(args.save_dir): 108 | os.makedirs(args.save_dir) 109 | 110 | if not args.evaluate: 111 | sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt')) 112 | else: 113 | sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt')) 114 | print("==========\nArgs:{}\n==========".format(args)) 115 | 116 | if use_gpu: 117 | print("Currently using GPU {}".format(args.gpu_devices)) 118 | cudnn.benchmark = True 119 | torch.cuda.manual_seed_all(args.seed) 120 | else: 121 | print("Currently using CPU (GPU is highly recommended)") 122 | 123 | print("Initializing dataset {}".format(args.dataset)) 124 | dataset = data_manager.init_dataset(name=args.dataset) 125 | 126 | transform_train = T.Compose([ 127 | T.Random2DTranslation(args.height, args.width), 128 | T.RandomHorizontalFlip(), 129 | T.ToTensor(), 130 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 131 | ]) 132 | 133 | transform_test = T.Compose([ 134 | T.Resize((args.height, args.width)), 135 | T.ToTensor(), 136 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 137 | ]) 138 | 139 | pin_memory = True if use_gpu else False 140 | 141 | trainloader = DataLoader( 142 | VideoDataset(dataset.train, seq_len=args.seq_len, 143 | sample=args.data_selection, transform=transform_train), 144 | sampler=RandomIdentitySampler( 145 | dataset.train, num_instances=args.num_instances), 146 | batch_size=args.train_batch, num_workers=args.workers, 147 | pin_memory=pin_memory, drop_last=True, 148 | ) 149 | 150 | queryloader = DataLoader( 151 | VideoDataset(dataset.query, seq_len=args.seq_len, 152 | sample='dense', transform=transform_test), 153 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 154 | pin_memory=pin_memory, drop_last=False, 155 | ) 156 | 157 | galleryloader = DataLoader( 158 | VideoDataset(dataset.gallery, seq_len=args.seq_len, 159 | sample='dense', transform=transform_test), 160 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 161 | pin_memory=pin_memory, drop_last=False, 162 | ) 163 | 164 | print("Initializing model: {}".format(args.arch)) 165 | model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, seq_len=args.seq_len) 166 | 167 | # pretrained model loading 168 | if args.pretrained_model is not None: 169 | if not os.path.exists(args.pretrained_model): 170 | raise IOError("Can't find pretrained model: {}".format( 171 | args.pretrained_model)) 172 | print("Loading checkpoint from '{}'".format(args.pretrained_model)) 173 | pretrained_state = torch.load(args.pretrained_model)['state_dict'] 174 | print(len(pretrained_state), ' keys in pretrained model') 175 | 176 | current_model_state = model.state_dict() 177 | pretrained_state = {key: val 178 | for key, val in pretrained_state.items() 179 | if key in current_model_state and val.size() == current_model_state[key].size()} 180 | 181 | print(len(pretrained_state), 182 | ' keys in pretrained model are available in current model') 183 | current_model_state.update(pretrained_state) 184 | model.load_state_dict(current_model_state) 185 | 186 | print("Model size: {:.5f}M".format(sum(p.numel() 187 | for p in model.parameters())/1000000.0)) 188 | 189 | if use_gpu: 190 | model = 
nn.DataParallel(model).cuda() 191 | 192 | criterion_xent = CrossEntropyLabelSmooth( 193 | num_classes=dataset.num_train_pids, use_gpu=use_gpu) 194 | criterion_htri = TripletLoss(margin=args.margin) 195 | 196 | optimizer = torch.optim.Adam( 197 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 198 | if args.stepsize > 0: 199 | scheduler = lr_scheduler.StepLR( 200 | optimizer, step_size=args.stepsize, gamma=args.gamma) 201 | start_epoch = args.start_epoch 202 | 203 | if args.evaluate: 204 | print("Evaluate only") 205 | test(model, queryloader, galleryloader, use_gpu) 206 | return 207 | 208 | start_time = time.time() 209 | best_rank1 = -np.inf 210 | 211 | is_first_time = True 212 | for epoch in range(start_epoch, args.max_epoch): 213 | print("==> Epoch {}/{}".format(epoch+1, args.max_epoch)) 214 | 215 | train(model, criterion_xent, criterion_htri, 216 | optimizer, trainloader, use_gpu) 217 | 218 | if args.stepsize > 0: 219 | scheduler.step() 220 | 221 | rank1 = 'NA' 222 | is_best = False 223 | 224 | if args.eval_step > 0 and (epoch+1) % args.eval_step == 0 or (epoch+1) == args.max_epoch: 225 | print("==> Test") 226 | rank1 = test(model, queryloader, galleryloader, use_gpu) 227 | is_best = rank1 > best_rank1 228 | if is_best: 229 | best_rank1 = rank1 230 | 231 | # save the model as required 232 | if (epoch+1) % args.save_step == 0: 233 | if use_gpu: 234 | state_dict = model.module.state_dict() 235 | else: 236 | state_dict = model.state_dict() 237 | 238 | save_checkpoint({ 239 | 'state_dict': state_dict, 240 | 'rank1': rank1, 241 | 'epoch': epoch, 242 | }, is_best, osp.join(args.save_dir, args.save_prefix + 'checkpoint_ep' + str(epoch+1) + '.pth.tar')) 243 | 244 | is_first_time = False 245 | if not is_first_time: 246 | utils.disable_all_print_once() 247 | 248 | elapsed = round(time.time() - start_time) 249 | elapsed = str(datetime.timedelta(seconds=elapsed)) 250 | print("Finished. 
Total elapsed time (h:m:s): {}".format(elapsed)) 251 | 252 | 253 | def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu): 254 | model.train() 255 | losses = AverageMeter() 256 | import time 257 | torch.backends.cudnn.benchmark = True 258 | 259 | start = time.time() 260 | for batch_idx, (imgs, pids, _) in enumerate(tqdm(trainloader)): 261 | if use_gpu: 262 | imgs, pids = imgs.cuda(), pids.cuda() 263 | 264 | imgs, pids = Variable(imgs), Variable(pids) 265 | 266 | outputs, features = model(imgs) 267 | 268 | # combine hard triplet loss with cross entropy loss 269 | xent_loss = criterion_xent(outputs, pids) 270 | htri_loss = criterion_htri(features, pids) 271 | loss = xent_loss + htri_loss 272 | 273 | optimizer.zero_grad() 274 | loss.backward() 275 | optimizer.step() 276 | 277 | losses.update(loss.item(), pids.size(0)) 278 | if (batch_idx+1) % args.print_freq == 0: 279 | print("Batch {}/{}\t Loss {:.6f} ({:.6f})".format(batch_idx + 280 | 1, len(trainloader), losses.val, losses.avg)) 281 | 282 | 283 | def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]): 284 | model.eval() 285 | 286 | qf, q_pids, q_camids = [], [], [] 287 | print('extracting query feats') 288 | for batch_idx, (imgs, pids, camids) in enumerate(tqdm(queryloader)): 289 | if use_gpu: 290 | imgs = imgs.cuda() 291 | 292 | with torch.no_grad(): 293 | #imgs = Variable(imgs, volatile=True) 294 | # b=1, n=number of clips, s=16 295 | b, n, s, c, h, w = imgs.size() 296 | assert(b == 1) 297 | imgs = imgs.view(b*n, s, c, h, w) 298 | features = get_features(model, imgs, args.test_num_tracks) 299 | features = torch.mean(features, 0) 300 | features = features.data.cpu() 301 | qf.append(features) 302 | q_pids.extend(pids) 303 | q_camids.extend(camids) 304 | torch.cuda.empty_cache() 305 | 306 | qf = torch.stack(qf) 307 | q_pids = np.asarray(q_pids) 308 | q_camids = np.asarray(q_camids) 309 | 310 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1))) 311 | 312 | gf, g_pids, g_camids = [], [], [] 313 | print('extracting gallery feats') 314 | for batch_idx, (imgs, pids, camids) in enumerate(tqdm(galleryloader)): 315 | if use_gpu: 316 | imgs = imgs.cuda() 317 | 318 | with torch.no_grad(): 319 | #imgs = Variable(imgs, volatile=True) 320 | b, n, s, c, h, w = imgs.size() 321 | imgs = imgs.view(b*n, s, c, h, w) 322 | assert(b == 1) 323 | # handle chunked data 324 | features = get_features(model, imgs, args.test_num_tracks) 325 | features = torch.mean(features, 0) 326 | torch.cuda.empty_cache() 327 | 328 | features = features.data.cpu() 329 | gf.append(features) 330 | g_pids.extend(pids) 331 | g_camids.extend(camids) 332 | gf = torch.stack(gf) 333 | g_pids = np.asarray(g_pids) 334 | g_camids = np.asarray(g_camids) 335 | 336 | print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1))) 337 | print("Computing distance matrix") 338 | 339 | m, n = qf.size(0), gf.size(0) 340 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 341 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 342 | distmat.addmm_(1, -2, qf, gf.t()) 343 | distmat = distmat.numpy() 344 | 345 | print("Computing CMC and mAP") 346 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) 347 | 348 | print("Results ----------") 349 | print("mAP: {:.1%}".format(mAP)) 350 | print("CMC curve") 351 | for r in ranks: 352 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1])) 353 | print("------------------") 354 | 355 | # re-ranking 
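    # Added note: the call below applies k-reciprocal re-ranking (Zhong et al., CVPR 2017) provided
    # by the person_re_ranking submodule; k1=20, k2=6 and lambda_value=0.3 are the parameter values
    # commonly used with that method. The query-gallery distance matrix is recomputed from the raw
    # features, so the CMC/mAP printed after this block are the re-ranked results.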
356 | from person_re_ranking.python_version.re_ranking_feature import re_ranking 357 | rerank_distmat = re_ranking( 358 | qf.numpy(), gf.numpy(), k1=20, k2=6, lambda_value=0.3) 359 | print("Computing CMC and mAP for re-ranking") 360 | cmc, mAP = evaluate(rerank_distmat, q_pids, g_pids, q_camids, g_camids) 361 | print("Results ----------") 362 | print("mAP: {:.1%}".format(mAP)) 363 | print("CMC curve") 364 | for r in ranks: 365 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1])) 366 | print("------------------") 367 | 368 | return cmc[0] 369 | 370 | 371 | if __name__ == '__main__': 372 | main() 373 | -------------------------------------------------------------------------------- /src/models/ResNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import os 5 | import sys 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.autograd import Variable 9 | import torchvision 10 | import torchvision.models as models 11 | import utils 12 | from .aggregation_layers import AGGREGATION 13 | from .cosam import COSEG_ATTENTION 14 | 15 | 16 | def get_ResNet(net_type): 17 | if net_type == "resnet50": 18 | model = models.resnet50(pretrained=True) 19 | model = nn.Sequential(*list(model.children())[:-2]) 20 | elif net_type == "senet101": 21 | model = models.resnet101(pretrained=True) 22 | model = nn.Sequential(*list(model.children())[:-2]) 23 | else: 24 | assert False, "unknown ResNet type : " + net_type 25 | 26 | return model 27 | 28 | 29 | class Identity(nn.Module): 30 | def __init__(self, *args, **kwargs): 31 | super().__init__() 32 | print("instantiating " + self.__class__.__name__) 33 | 34 | def forward(self, x, b, t): 35 | return [x, None, None] 36 | 37 | 38 | class CosegAttention(nn.Module): 39 | def __init__(self, attention_types, num_feat_maps, h_w, t): 40 | super().__init__() 41 | print("instantiating " + self.__class__.__name__) 42 | self.attention_modules = nn.ModuleList() 43 | 44 | for i, attention_type in enumerate(attention_types): 45 | if attention_type in COSEG_ATTENTION: 46 | self.attention_modules.append( 47 | COSEG_ATTENTION[attention_type]( 48 | num_feat_maps[i], h_w=h_w[i], t=t) 49 | ) 50 | else: 51 | assert False, "unknown attention type " + attention_type 52 | 53 | def forward(self, x, i, b, t): 54 | return self.attention_modules[i](x, b, t) 55 | 56 | 57 | class ResNet(nn.Module): 58 | def __init__( 59 | self, 60 | num_classes, 61 | net_type="resnet50", 62 | attention_types=["NONE", "NONE", "NONE", "NONE"], 63 | aggregation_type="tp", 64 | seq_len=4, 65 | **kwargs 66 | ): 67 | super(ResNet, self).__init__() 68 | print( 69 | "instantiating " 70 | + self.__class__.__name__ 71 | + " net type" 72 | + net_type 73 | + " from " 74 | + __file__ 75 | ) 76 | print("attention type", attention_types) 77 | 78 | # base network instantiation 79 | self.base = get_ResNet(net_type=net_type) 80 | self.feat_dim = 2048 81 | 82 | # attention modules 83 | self.num_feat_maps = [256, 512, 1024, 2048] 84 | self.h_w = [(64, 32), (32, 16), (16, 8), (8, 4)] 85 | if aggregation_type == "ta": 86 | self.h_w = [(64, 32), (32, 16), (16, 8), (8, 4)] 87 | else: 88 | utils.set_stride(self.base[7], 1) 89 | self.h_w = [(64, 32), (32, 16), (16, 8), (16, 8)] 90 | 91 | print(self.h_w) 92 | self.attention_modules = CosegAttention( 93 | attention_types, num_feat_maps=self.num_feat_maps, h_w=self.h_w, t=seq_len 94 | ) 95 | 96 | # aggregation module 97 | self.aggregation = 
AGGREGATION[aggregation_type]( 98 | self.feat_dim, h_w=self.h_w[-1], t=seq_len 99 | ) 100 | 101 | # classifier 102 | self.classifier = nn.Linear(self.aggregation.feat_dim, num_classes) 103 | 104 | def base_layer0(self, x): 105 | x = self.base[0](x) 106 | x = self.base[1](x) 107 | x = self.base[2](x) 108 | x = self.base[3](x) 109 | x = self.base[4](x) 110 | return x 111 | 112 | def base_layer1(self, x): 113 | return self.base[5](x) 114 | 115 | def base_layer2(self, x): 116 | return self.base[6](x) 117 | 118 | def base_layer3(self, x): 119 | return self.base[7](x) 120 | 121 | def extract_features(self, x, b, t): 122 | features_in = x 123 | attentions = [] 124 | 125 | for index in range(len(self.num_feat_maps)): 126 | features_before_attention = getattr(self, "base_layer" + str(index))( 127 | features_in 128 | ) 129 | features_out, channelwise, spatialwise = self.attention_modules( 130 | features_before_attention, index, b, t 131 | ) 132 | features_in = features_out 133 | 134 | # note down the attentions 135 | attentions.append( 136 | (features_before_attention, features_out, channelwise, spatialwise) 137 | ) 138 | 139 | return features_out, attentions 140 | 141 | def return_values( 142 | self, features, logits, attentions, is_training, return_attentions 143 | ): 144 | buffer = [] 145 | if not is_training: 146 | buffer.append(features) 147 | if not return_attentions: 148 | return features 149 | else: 150 | buffer.append(logits) 151 | buffer.append(features) 152 | 153 | if return_attentions: 154 | buffer.append(attentions) 155 | 156 | return buffer 157 | 158 | def forward(self, x, return_attentions=False): 159 | b = x.size(0) 160 | t = x.size(1) 161 | x = x.view(b * t, x.size(2), x.size(3), x.size(4)) 162 | 163 | final_spatial_features, attentions = self.extract_features(x, b, t) 164 | f = self.aggregation(final_spatial_features, b, t) 165 | 166 | if not self.training: 167 | return self.return_values( 168 | f, None, attentions, self.training, return_attentions 169 | ) 170 | 171 | y = self.classifier(f) 172 | return self.return_values(f, y, attentions, self.training, return_attentions) 173 | 174 | 175 | # all mutual correlation attention models 176 | def ResNet50_COSAM45_TP(num_classes, **kwargs): 177 | return ResNet( 178 | num_classes, 179 | net_type="resnet50", 180 | aggregation_type="tp", 181 | attention_types=["NONE", "NONE", 182 | "COSAM", "COSAM"], 183 | **kwargs 184 | ) 185 | 186 | 187 | def ResNet50_COSAM45_TA(num_classes, **kwargs): 188 | return ResNet( 189 | num_classes, 190 | net_type="resnet50", 191 | aggregation_type="ta", 192 | attention_types=["NONE", "NONE", 193 | "COSAM", "COSAM"], 194 | **kwargs 195 | ) 196 | 197 | 198 | def ResNet50_COSAM45_RNN(num_classes, **kwargs): 199 | return ResNet( 200 | num_classes, 201 | net_type="resnet50", 202 | aggregation_type="rnn", 203 | attention_types=["NONE", "NONE", 204 | "COSAM", "COSAM"], 205 | **kwargs 206 | ) 207 | 208 | -------------------------------------------------------------------------------- /src/models/SE_ResNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import os 5 | import sys 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.autograd import Variable 9 | import torchvision 10 | import utils as utils 11 | 12 | this_path = os.path.split(__file__)[0] 13 | sys.path.append(this_path) 14 | from .aggregation_layers import AGGREGATION 15 | from .cosam import COSEG_ATTENTION 16 | import senet 17 
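The stage indexing used by `base_layer0`-`base_layer3` in `ResNet.py` above relies on the fixed ordering of torchvision's ResNet-50 children (conv1, bn1, relu, maxpool, layer1-layer4, avgpool, fc). The following is only an illustrative sanity check of that slicing against the hard-coded `num_feat_maps`/`h_w` values for a 256x128 input; `weights=None` is used just to avoid a download, and the default layer4 stride corresponds to the `ta` branch (the other branch sets the stride to 1 via `utils.set_stride`):

```python
import torch
import torch.nn as nn
import torchvision.models as models

base = nn.Sequential(*list(models.resnet50(weights=None).children())[:-2])  # drop avgpool, fc

stage0 = nn.Sequential(*base[:5])                    # conv1, bn1, relu, maxpool, layer1
stages = [stage0, base[5], base[6], base[7]]         # matches base_layer0 .. base_layer3

x = torch.randn(2 * 4, 3, 256, 128)                  # (b*t, C, H, W): b=2 tracklets, t=4 frames
for stage, channels, (h, w) in zip(stages, [256, 512, 1024, 2048],
                                   [(64, 32), (32, 16), (16, 8), (8, 4)]):
    x = stage(x)
    assert x.shape[1:] == (channels, h, w), x.shape
print("final feature map:", tuple(x.shape))          # (8, 2048, 8, 4)
```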
| 18 | def get_SENet(net_type): 19 | if net_type == "senet50": 20 | model = senet.se_resnet50(pretrained=True) 21 | elif net_type == "senet101": 22 | model = senet.se_resnet101(pretrained=True) 23 | else: 24 | assert False, "unknown SE ResNet type : " + net_type 25 | 26 | return model 27 | 28 | 29 | class CosegAttention(nn.Module): 30 | def __init__(self, attention_types, num_feat_maps, h_w, t): 31 | super().__init__() 32 | print("instantiating " + self.__class__.__name__) 33 | self.attention_modules = nn.ModuleList() 34 | 35 | for i, attention_type in enumerate(attention_types): 36 | if attention_type in COSEG_ATTENTION: 37 | self.attention_modules.append( 38 | COSEG_ATTENTION[attention_type]( 39 | num_feat_maps[i], h_w=h_w[i], t=t) 40 | ) 41 | else: 42 | assert False, "unknown attention type " + attention_type 43 | 44 | def forward(self, x, i, b, t): 45 | return self.attention_modules[i](x, b, t) 46 | 47 | 48 | class SE_ResNet(nn.Module): 49 | def __init__( 50 | self, 51 | num_classes, 52 | net_type="senet50", 53 | attention_types=["NONE", "NONE", "NONE", "NONE", "NONE"], 54 | aggregation_type="tp", 55 | seq_len=4, 56 | is_baseline=False, 57 | **kwargs 58 | ): 59 | super(SE_ResNet, self).__init__() 60 | print( 61 | "instantiating " 62 | + self.__class__.__name__ 63 | + " net type" 64 | + net_type 65 | + " from " 66 | + __file__ 67 | ) 68 | print("attention type", attention_types) 69 | 70 | # base network instantiation 71 | self.base = get_SENet(net_type=net_type) 72 | self.feat_dim = self.base.feature_dim 73 | 74 | # attention modules 75 | self.num_feat_maps = [64, 256, 512, 1024, 2048] 76 | self.h_w = [(64, 32), (64, 32), (32, 16), (16, 8), (8, 4)] 77 | 78 | # allow reducing spatial dimension for Temporal attention (ta) to keep the params at a manageable number 79 | # according to the baseline paper 80 | if aggregation_type == "ta": 81 | self.h_w = [(64, 32), (64, 32), (32, 16), (16, 8), (8, 4)] 82 | else: 83 | utils.set_stride(self.base.layer4, 1) 84 | self.h_w = [(64, 32), (64, 32), (32, 16), (16, 8), (16, 8)] 85 | 86 | print(self.h_w) 87 | 88 | self.attention_modules = CosegAttention( 89 | attention_types, num_feat_maps=self.num_feat_maps, h_w=self.h_w, t=seq_len 90 | ) 91 | 92 | # aggregation module 93 | self.aggregation = AGGREGATION[aggregation_type]( 94 | self.feat_dim, h_w=self.h_w[-1], t=seq_len 95 | ) 96 | 97 | # classifier 98 | self.classifier = nn.Linear(self.aggregation.feat_dim, num_classes) 99 | 100 | def extract_features(self, x, b, t): 101 | features_in = x 102 | attentions = [] 103 | 104 | for index in range(5): 105 | features_before_attention = getattr(self.base, "layer" + str(index))( 106 | features_in 107 | ) 108 | features_out, channelwise, spatialwise = self.attention_modules( 109 | features_before_attention, index, b, t 110 | ) 111 | # print(features_out.shape) 112 | features_in = features_out 113 | 114 | # note down the attentions 115 | attentions.append( 116 | (features_before_attention, features_out, channelwise, spatialwise) 117 | ) 118 | 119 | return features_out, attentions 120 | 121 | def return_values( 122 | self, features, logits, attentions, is_training, return_attentions 123 | ): 124 | buffer = [] 125 | if not is_training: 126 | buffer.append(features) 127 | if not return_attentions: 128 | return features 129 | else: 130 | buffer.append(logits) 131 | buffer.append(features) 132 | 133 | if return_attentions: 134 | buffer.append(attentions) 135 | 136 | return buffer 137 | 138 | def forward(self, x, return_attentions=False): 139 | b = x.size(0) 140 | t = 
x.size(1) 141 | x = x.view(b * t, x.size(2), x.size(3), x.size(4)) 142 | 143 | final_spatial_features, attentions = self.extract_features(x, b, t) 144 | f = self.aggregation(final_spatial_features, b, t) 145 | 146 | if not self.training: 147 | return self.return_values( 148 | f, None, attentions, self.training, return_attentions 149 | ) 150 | 151 | y = self.classifier(f) 152 | return self.return_values(f, y, attentions, self.training, return_attentions) 153 | 154 | 155 | 156 | # all attention models 157 | def SE_ResNet50_COSAM45_TP(num_classes, **kwargs): 158 | return SE_ResNet( 159 | num_classes, 160 | net_type="senet50", 161 | aggregation_type="tp", 162 | attention_types=[ 163 | "NONE", 164 | "NONE", 165 | "NONE", 166 | "COSAM", 167 | "COSAM", 168 | ], 169 | **kwargs 170 | ) 171 | 172 | 173 | def SE_ResNet50_COSAM45_TA(num_classes, **kwargs): 174 | return SE_ResNet( 175 | num_classes, 176 | net_type="senet50", 177 | aggregation_type="ta", 178 | attention_types=[ 179 | "NONE", 180 | "NONE", 181 | "NONE", 182 | "COSAM", 183 | "COSAM", 184 | ], 185 | **kwargs 186 | ) 187 | 188 | 189 | def SE_ResNet50_COSAM45_RNN(num_classes, **kwargs): 190 | return SE_ResNet( 191 | num_classes, 192 | net_type="senet50", 193 | aggregation_type="rnn", 194 | attention_types=[ 195 | "NONE", 196 | "NONE", 197 | "NONE", 198 | "COSAM", 199 | "COSAM", 200 | ], 201 | **kwargs 202 | ) 203 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .ResNet import * 4 | from .SE_ResNet import * 5 | 6 | __factory = { 7 | # ResNet50 network 8 | "resnet50_cosam45_tp": ResNet50_COSAM45_TP, 9 | "resnet50_cosam45_ta": ResNet50_COSAM45_TA, 10 | "resnet50_cosam45_rnn": ResNet50_COSAM45_RNN, 11 | 12 | # Squeeze and Expand network 13 | "se_resnet50_cosam45_tp": SE_ResNet50_COSAM45_TP, 14 | "se_resnet50_cosam45_ta": SE_ResNet50_COSAM45_TA, 15 | "se_resnet50_cosam45_rnn": SE_ResNet50_COSAM45_RNN, 16 | 17 | } 18 | 19 | 20 | def get_names(): 21 | return __factory.keys() 22 | 23 | 24 | def init_model(name, *args, **kwargs): 25 | if name not in __factory.keys(): 26 | raise KeyError("Unknown model: {}".format(name)) 27 | return __factory[name](*args, **kwargs) 28 | -------------------------------------------------------------------------------- /src/models/aggregation_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AggregationTP(nn.Module): 7 | def __init__(self, feat_dim, *args, **kwargs): 8 | super().__init__() 9 | print("instantiating " + self.__class__.__name__) 10 | 11 | self.feat_dim = feat_dim 12 | 13 | def forward(self, x, b, t): 14 | x = F.avg_pool2d(x, x.size()[2:]) 15 | x = x.view(b, t, -1) 16 | x = x.permute(0, 2, 1) 17 | f = F.avg_pool1d(x, t) 18 | f = f.view(b, self.feat_dim) 19 | return f 20 | 21 | 22 | class AggregationTA(nn.Module): 23 | def __init__(self, feat_dim, *args, **kwargs): 24 | super().__init__() 25 | print("instantiating " + self.__class__.__name__) 26 | 27 | self.feat_dim = feat_dim 28 | self.middle_dim = 256 29 | self.attention_conv = nn.Conv2d( 30 | self.feat_dim, self.middle_dim, [8, 4] 31 | ) # 8, 4 cooresponds to 256, 128 input image size 32 | self.attention_tconv = nn.Conv1d(self.middle_dim, 1, 3, padding=1) 33 | 34 | def forward(self, x, b, t): 35 | 36 | # spatial attention 37 | 
a = F.relu(self.attention_conv(x)) 38 | 39 | # arrange into batch temporal view 40 | a = a.view(b, t, self.middle_dim) 41 | 42 | # temporal attention 43 | a = a.permute(0, 2, 1) 44 | a = F.relu(self.attention_tconv(a)) 45 | a = a.view(b, t) 46 | a = F.softmax(a, dim=1) 47 | 48 | # global avg pooling of conv features 49 | x = F.avg_pool2d(x, x.size()[2:]) 50 | 51 | # apply temporal attention 52 | x = x.view(b, t, -1) 53 | a = torch.unsqueeze(a, -1) 54 | a = a.expand(b, t, self.feat_dim) 55 | att_x = torch.mul(x, a) 56 | att_x = torch.sum(att_x, 1) 57 | f = att_x.view(b, self.feat_dim) 58 | 59 | return f 60 | 61 | 62 | class AggregationRNN(nn.Module): 63 | def __init__(self, feat_dim, *args, **kwargs): 64 | super().__init__() 65 | print("instantiating " + self.__class__.__name__) 66 | 67 | self.hidden_dim = 512 68 | self.lstm = nn.LSTM( 69 | input_size=feat_dim, 70 | hidden_size=self.hidden_dim, 71 | num_layers=1, 72 | batch_first=True, 73 | ) 74 | self.feat_dim = self.hidden_dim 75 | 76 | def forward(self, x, b, t): 77 | x = F.avg_pool2d(x, x.size()[2:]) 78 | x = x.view(b, t, -1) 79 | 80 | # apply LSTM 81 | output, (h_n, c_n) = self.lstm(x) 82 | output = output.permute(0, 2, 1) 83 | f = F.avg_pool1d(output, t) 84 | f = f.view(b, self.hidden_dim) 85 | 86 | return f 87 | 88 | 89 | AGGREGATION = { 90 | "tp": AggregationTP, 91 | "ta": AggregationTA, 92 | "rnn": AggregationRNN, 93 | } 94 | -------------------------------------------------------------------------------- /src/models/cosam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import utils as utils 4 | import torch.nn.functional as F 5 | 6 | class Identity(nn.Module): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__() 9 | print("instantiating " + self.__class__.__name__) 10 | 11 | def forward(self, x, b, t): 12 | return [x, None, None] 13 | 14 | 15 | class COSAM(nn.Module): 16 | def __init__(self, in_channels, t, h_w, *args): 17 | super().__init__() 18 | print( 19 | "instantiating " 20 | + self.__class__.__name__ 21 | + " with in-channels: " 22 | + str(in_channels) 23 | + ", t = " 24 | + str(t) 25 | + ", hw = " 26 | + str(h_w) 27 | ) 28 | 29 | self.eps = 1e-4 30 | self.h_w = h_w 31 | self.t = t 32 | self.mid_channels = 256 33 | if in_channels <= 256: 34 | self.mid_channels = 64 35 | 36 | self.dim_reduction = nn.Sequential( 37 | nn.Conv2d(in_channels, self.mid_channels, kernel_size=1), 38 | nn.BatchNorm2d(self.mid_channels), 39 | nn.ReLU(), 40 | ) 41 | 42 | self.spatial_mask_summary = nn.Sequential( 43 | nn.Conv2d((t - 1) * h_w[0] * h_w[1], 64, kernel_size=1), 44 | nn.BatchNorm2d(64), 45 | nn.ReLU(), 46 | nn.Conv2d(64, 1, kernel_size=1), 47 | ) 48 | 49 | self.channelwise_attention = nn.Sequential( 50 | nn.Linear(in_channels, self.mid_channels), 51 | nn.Tanh(), 52 | nn.Linear(self.mid_channels, in_channels), 53 | nn.Sigmoid(), 54 | ) 55 | 56 | def get_selectable_indices(self, t): 57 | init_list = list(range(t)) 58 | index_list = [] 59 | for i in range(t): 60 | list_instance = list(init_list) 61 | list_instance.remove(i) 62 | index_list.append(list_instance) 63 | 64 | return index_list 65 | 66 | def get_channelwise_attention(self, feat_maps, b, t): 67 | num_imgs, num_channels, h, w = feat_maps.shape 68 | 69 | # perform global average pooling 70 | channel_avg_pool = F.adaptive_avg_pool2d(feat_maps, output_size=1) 71 | # pass the global average pooled features through the fully connected network with sigmoid activation 72 | channelwise_attention = 
self.channelwise_attention(channel_avg_pool.view(num_imgs, -1)) 73 | 74 | # perform group attention 75 | # groupify the attentions 76 | idwise_channelattention = channelwise_attention.view(b, t, -1) 77 | 78 | # take the mean of attention to attend common channels between frames 79 | group_attention = torch.mean(idwise_channelattention, dim=1, keepdim=True).expand_as(idwise_channelattention) 80 | channelwise_attention = group_attention.contiguous().view(num_imgs, num_channels, 1, 1) 81 | 82 | return channelwise_attention 83 | 84 | def get_spatial_attention(self, feat_maps, b, t): 85 | total, c, h, w = feat_maps.shape 86 | dim_reduced_featuremaps = self.dim_reduction(feat_maps) # #frames x C x H x W 87 | 88 | # resize the feature maps for temporal processing 89 | identitywise_maps = dim_reduced_featuremaps.view(b, t, dim_reduced_featuremaps.shape[1], h, w) 90 | 91 | # get the combination of frame indices 92 | index_list = self.get_selectable_indices(t) 93 | 94 | # select the other images within same id 95 | other_selected_imgs = identitywise_maps[:, index_list] 96 | other_selected_imgs = other_selected_imgs.view(-1, t - 1, dim_reduced_featuremaps.shape[1], h, w) # #frames x t-1 x C x H x W 97 | 98 | # permutate the other dimensions except descriptor dim to last 99 | other_selected_imgs = other_selected_imgs.permute((0, 2, 1, 3, 4)) # #frames x C x t-1 x H x W 100 | 101 | # prepare two matrices for multiplication 102 | dim_reduced_featuremaps = dim_reduced_featuremaps.view(total, self.mid_channels, -1) # #frames x C x (H * W) 103 | dim_reduced_featuremaps = dim_reduced_featuremaps.permute((0, 2, 1)) # frames x (H * W) x C 104 | other_selected_imgs = other_selected_imgs.contiguous().view(total, self.mid_channels, -1) # #frames x C x (t-1 * H * W) 105 | 106 | # mean subtract and divide by std 107 | dim_reduced_featuremaps = dim_reduced_featuremaps - torch.mean(dim_reduced_featuremaps, dim=2, keepdim=True) 108 | dim_reduced_featuremaps = dim_reduced_featuremaps / (torch.std(dim_reduced_featuremaps, dim=2, keepdim=True) + self.eps) 109 | other_selected_imgs = other_selected_imgs - torch.mean(other_selected_imgs, dim=1, keepdim=True) 110 | other_selected_imgs = other_selected_imgs / (torch.std(other_selected_imgs, dim=1, keepdim=True) + self.eps) 111 | 112 | mutual_correlation = (torch.bmm(dim_reduced_featuremaps, other_selected_imgs) 113 | / other_selected_imgs.shape[1]) # #frames x (HW) x (t-1 * H * W) 114 | 115 | mutual_correlation = mutual_correlation.permute(0, 2, 1) # #frames x (t-1 * H * W) x (HW) 116 | mutual_correlation = mutual_correlation.view(total, -1, h, w) # #frames x (t-1 * H * W) x H x W 117 | mutual_correlation_mask = self.spatial_mask_summary(mutual_correlation).sigmoid() # #frames x 1 x H x W 118 | 119 | return mutual_correlation_mask 120 | 121 | def forward(self, feat_maps, b, t): 122 | # get the spatial attention mask 123 | mutualcorr_spatial_mask = self.get_spatial_attention( 124 | feat_maps=feat_maps, b=b, t=t 125 | ) 126 | attended_out = torch.mul(feat_maps, mutualcorr_spatial_mask) 127 | 128 | # channel-wise attention 129 | channelwise_mask = self.get_channelwise_attention( 130 | feat_maps=attended_out, b=b, t=t 131 | ) 132 | attended_out = attended_out + torch.mul(attended_out, channelwise_mask) 133 | 134 | return attended_out, channelwise_mask, mutualcorr_spatial_mask 135 | 136 | 137 | COSEG_ATTENTION = { 138 | "NONE": Identity, 139 | "COSAM": COSAM, 140 | } 141 | -------------------------------------------------------------------------------- /src/models/senet.py: 
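Before the SENet backbone source below, a toy shape walk-through of the mutual-correlation step in `COSAM.get_spatial_attention` above. It uses random tensors and deliberately small, made-up sizes, and omits the per-descriptor mean/std standardisation for brevity; nothing here is taken from the repo's configuration:

```python
import torch

b, t, c_mid, h, w = 2, 4, 8, 16, 8            # 2 tracklets of 4 frames, reduced to 8 channels
total = b * t

frames = torch.randn(total, c_mid, h, w)       # dim-reduced feature maps, one per frame
per_id = frames.view(b, t, c_mid, h, w)

# for every frame index, the indices of the other t-1 frames of the same tracklet
index_list = [[j for j in range(t) if j != i] for i in range(t)]
others = per_id[:, index_list]                                # b x t x (t-1) x C x H x W
others = others.view(total, t - 1, c_mid, h, w).permute(0, 2, 1, 3, 4)

ref = frames.view(total, c_mid, h * w).permute(0, 2, 1)       # frames x (HW) x C
oth = others.contiguous().view(total, c_mid, -1)              # frames x C x ((t-1)*HW)

corr = torch.bmm(ref, oth) / oth.shape[1]                     # frames x (HW) x ((t-1)*HW)
print(tuple(corr.shape))                                      # (8, 128, 384)
```

The `spatial_mask_summary` convolution then collapses the `(t-1)*HW` correlation channels into a single sigmoid mask per frame.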
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from collections import OrderedDict 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils import model_zoo 10 | from torch.nn import functional as F 11 | import torchvision 12 | 13 | 14 | """ 15 | Code imported from https://github.com/Cadene/pretrained-models.pytorch 16 | """ 17 | 18 | 19 | __all__ = ['senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', 'se_resnext101_32x4d', 20 | 'se_resnet50_fc512'] 21 | 22 | 23 | pretrained_settings = { 24 | 'senet154': { 25 | 'imagenet': { 26 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', 27 | 'input_space': 'RGB', 28 | 'input_size': [3, 224, 224], 29 | 'input_range': [0, 1], 30 | 'mean': [0.485, 0.456, 0.406], 31 | 'std': [0.229, 0.224, 0.225], 32 | 'num_classes': 1000 33 | } 34 | }, 35 | 'se_resnet50': { 36 | 'imagenet': { 37 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', 38 | 'input_space': 'RGB', 39 | 'input_size': [3, 224, 224], 40 | 'input_range': [0, 1], 41 | 'mean': [0.485, 0.456, 0.406], 42 | 'std': [0.229, 0.224, 0.225], 43 | 'num_classes': 1000 44 | } 45 | }, 46 | 'se_resnet101': { 47 | 'imagenet': { 48 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', 49 | 'input_space': 'RGB', 50 | 'input_size': [3, 224, 224], 51 | 'input_range': [0, 1], 52 | 'mean': [0.485, 0.456, 0.406], 53 | 'std': [0.229, 0.224, 0.225], 54 | 'num_classes': 1000 55 | } 56 | }, 57 | 'se_resnet152': { 58 | 'imagenet': { 59 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', 60 | 'input_space': 'RGB', 61 | 'input_size': [3, 224, 224], 62 | 'input_range': [0, 1], 63 | 'mean': [0.485, 0.456, 0.406], 64 | 'std': [0.229, 0.224, 0.225], 65 | 'num_classes': 1000 66 | } 67 | }, 68 | 'se_resnext50_32x4d': { 69 | 'imagenet': { 70 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 71 | 'input_space': 'RGB', 72 | 'input_size': [3, 224, 224], 73 | 'input_range': [0, 1], 74 | 'mean': [0.485, 0.456, 0.406], 75 | 'std': [0.229, 0.224, 0.225], 76 | 'num_classes': 1000 77 | } 78 | }, 79 | 'se_resnext101_32x4d': { 80 | 'imagenet': { 81 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 82 | 'input_space': 'RGB', 83 | 'input_size': [3, 224, 224], 84 | 'input_range': [0, 1], 85 | 'mean': [0.485, 0.456, 0.406], 86 | 'std': [0.229, 0.224, 0.225], 87 | 'num_classes': 1000 88 | } 89 | }, 90 | } 91 | 92 | 93 | class SEModule(nn.Module): 94 | 95 | def __init__(self, channels, reduction): 96 | super(SEModule, self).__init__() 97 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 98 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0) 101 | self.sigmoid = nn.Sigmoid() 102 | 103 | def forward(self, x): 104 | module_input = x 105 | x = self.avg_pool(x) 106 | x = self.fc1(x) 107 | x = self.relu(x) 108 | x = self.fc2(x) 109 | x = self.sigmoid(x) 110 | return module_input * x 111 | 112 | 113 | class Bottleneck(nn.Module): 114 | """ 115 | Base class for bottlenecks that implements `forward()` method. 
116 | """ 117 | def forward(self, x): 118 | residual = x 119 | 120 | out = self.conv1(x) 121 | out = self.bn1(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv2(out) 125 | out = self.bn2(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv3(out) 129 | out = self.bn3(out) 130 | 131 | if self.downsample is not None: 132 | residual = self.downsample(x) 133 | 134 | out = self.se_module(out) + residual 135 | out = self.relu(out) 136 | 137 | return out 138 | 139 | 140 | class SEBottleneck(Bottleneck): 141 | """ 142 | Bottleneck for SENet154. 143 | """ 144 | expansion = 4 145 | 146 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 147 | downsample=None): 148 | super(SEBottleneck, self).__init__() 149 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 150 | self.bn1 = nn.BatchNorm2d(planes * 2) 151 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 152 | stride=stride, padding=1, groups=groups, 153 | bias=False) 154 | self.bn2 = nn.BatchNorm2d(planes * 4) 155 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 156 | bias=False) 157 | self.bn3 = nn.BatchNorm2d(planes * 4) 158 | self.relu = nn.ReLU(inplace=True) 159 | self.se_module = SEModule(planes * 4, reduction=reduction) 160 | self.downsample = downsample 161 | self.stride = stride 162 | 163 | 164 | class SEResNetBottleneck(Bottleneck): 165 | """ 166 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 167 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 168 | (the latter is used in the torchvision implementation of ResNet). 169 | """ 170 | expansion = 4 171 | 172 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 173 | downsample=None): 174 | super(SEResNetBottleneck, self).__init__() 175 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 176 | stride=stride) 177 | self.bn1 = nn.BatchNorm2d(planes) 178 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 179 | groups=groups, bias=False) 180 | self.bn2 = nn.BatchNorm2d(planes) 181 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 182 | self.bn3 = nn.BatchNorm2d(planes * 4) 183 | self.relu = nn.ReLU(inplace=True) 184 | self.se_module = SEModule(planes * 4, reduction=reduction) 185 | self.downsample = downsample 186 | self.stride = stride 187 | 188 | 189 | class SEResNeXtBottleneck(Bottleneck): 190 | """ 191 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 192 | """ 193 | expansion = 4 194 | 195 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 196 | downsample=None, base_width=4): 197 | super(SEResNeXtBottleneck, self).__init__() 198 | width = int(math.floor(planes * (base_width / 64.)) * groups) 199 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 200 | stride=1) 201 | self.bn1 = nn.BatchNorm2d(width) 202 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 203 | padding=1, groups=groups, bias=False) 204 | self.bn2 = nn.BatchNorm2d(width) 205 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 206 | self.bn3 = nn.BatchNorm2d(planes * 4) 207 | self.relu = nn.ReLU(inplace=True) 208 | self.se_module = SEModule(planes * 4, reduction=reduction) 209 | self.downsample = downsample 210 | self.stride = stride 211 | 212 | 213 | class SENet(nn.Module): 214 | """ 215 | Squeeze-and-excitation network 216 | 217 | Reference: 218 | Hu et al. Squeeze-and-Excitation Networks. CVPR 2018. 
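As a quick illustration of the channel gating described in that reference and implemented by `SEModule` earlier in this file, here is a toy squeeze-and-excitation block with made-up sizes; it is a sketch of the mechanism only, not code from this repository:

```python
import torch
from torch import nn

channels, reduction = 8, 2                     # illustrative sizes
se_gate = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),                   # "squeeze": one value per channel
    nn.Conv2d(channels, channels // reduction, kernel_size=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(channels // reduction, channels, kernel_size=1),
    nn.Sigmoid(),                              # "excitation": per-channel gate in (0, 1)
)

x = torch.randn(2, channels, 16, 8)
gate = se_gate(x)                              # (2, 8, 1, 1)
out = x * gate                                 # channels re-weighted, spatial layout untouched
print(gate.view(2, channels))
```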
219 | """ 220 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 221 | inplanes=128, input_3x3=True, downsample_kernel_size=3, downsample_padding=1, 222 | last_stride=2, fc_dims=None, **kwargs): 223 | """ 224 | Parameters 225 | ---------- 226 | block (nn.Module): Bottleneck class. 227 | - For SENet154: SEBottleneck 228 | - For SE-ResNet models: SEResNetBottleneck 229 | - For SE-ResNeXt models: SEResNeXtBottleneck 230 | layers (list of ints): Number of residual blocks for 4 layers of the 231 | network (layer1...layer4). 232 | groups (int): Number of groups for the 3x3 convolution in each 233 | bottleneck block. 234 | - For SENet154: 64 235 | - For SE-ResNet models: 1 236 | - For SE-ResNeXt models: 32 237 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 238 | - For all models: 16 239 | dropout_p (float or None): Drop probability for the Dropout layer. 240 | If `None` the Dropout layer is not used. 241 | - For SENet154: 0.2 242 | - For SE-ResNet models: None 243 | - For SE-ResNeXt models: None 244 | inplanes (int): Number of input channels for layer1. 245 | - For SENet154: 128 246 | - For SE-ResNet models: 64 247 | - For SE-ResNeXt models: 64 248 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 249 | a single 7x7 convolution in layer0. 250 | - For SENet154: True 251 | - For SE-ResNet models: False 252 | - For SE-ResNeXt models: False 253 | downsample_kernel_size (int): Kernel size for downsampling convolutions 254 | in layer2, layer3 and layer4. 255 | - For SENet154: 3 256 | - For SE-ResNet models: 1 257 | - For SE-ResNeXt models: 1 258 | downsample_padding (int): Padding for downsampling convolutions in 259 | layer2, layer3 and layer4. 260 | - For SENet154: 1 261 | - For SE-ResNet models: 0 262 | - For SE-ResNeXt models: 0 263 | num_classes (int): Number of outputs in `classifier` layer. 264 | """ 265 | super(SENet, self).__init__() 266 | self.inplanes = inplanes 267 | 268 | if input_3x3: 269 | layer0_modules = [ 270 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 271 | bias=False)), 272 | ('bn1', nn.BatchNorm2d(64)), 273 | ('relu1', nn.ReLU(inplace=True)), 274 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 275 | bias=False)), 276 | ('bn2', nn.BatchNorm2d(64)), 277 | ('relu2', nn.ReLU(inplace=True)), 278 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 279 | bias=False)), 280 | ('bn3', nn.BatchNorm2d(inplanes)), 281 | ('relu3', nn.ReLU(inplace=True)), 282 | ] 283 | else: 284 | layer0_modules = [ 285 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 286 | padding=3, bias=False)), 287 | ('bn1', nn.BatchNorm2d(inplanes)), 288 | ('relu1', nn.ReLU(inplace=True)), 289 | ] 290 | # To preserve compatibility with Caffe weights `ceil_mode=True` 291 | # is used instead of `padding=1`. 
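        # Illustrative numbers (not from this repo): with a 256x128 input the stride-2 stem
        # gives 128x64 maps; a 3x3/stride-2 pool then yields 63x31 with floor rounding but
        # 64x32 with ceil_mode=True (padding=1 would also give 64x32, just with different
        # border handling), matching the 64x32 size assumed in SE_ResNet.h_w.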
292 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 293 | ceil_mode=True))) 294 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 295 | self.layer1 = self._make_layer( 296 | block, 297 | planes=64, 298 | blocks=layers[0], 299 | groups=groups, 300 | reduction=reduction, 301 | downsample_kernel_size=1, 302 | downsample_padding=0 303 | ) 304 | self.layer2 = self._make_layer( 305 | block, 306 | planes=128, 307 | blocks=layers[1], 308 | stride=2, 309 | groups=groups, 310 | reduction=reduction, 311 | downsample_kernel_size=downsample_kernel_size, 312 | downsample_padding=downsample_padding 313 | ) 314 | self.layer3 = self._make_layer( 315 | block, 316 | planes=256, 317 | blocks=layers[2], 318 | stride=2, 319 | groups=groups, 320 | reduction=reduction, 321 | downsample_kernel_size=downsample_kernel_size, 322 | downsample_padding=downsample_padding 323 | ) 324 | self.layer4 = self._make_layer( 325 | block, 326 | planes=512, 327 | blocks=layers[3], 328 | stride=last_stride, 329 | groups=groups, 330 | reduction=reduction, 331 | downsample_kernel_size=downsample_kernel_size, 332 | downsample_padding=downsample_padding 333 | ) 334 | 335 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 336 | self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p) 337 | 338 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 339 | downsample_kernel_size=1, downsample_padding=0): 340 | downsample = None 341 | if stride != 1 or self.inplanes != planes * block.expansion: 342 | downsample = nn.Sequential( 343 | nn.Conv2d(self.inplanes, planes * block.expansion, 344 | kernel_size=downsample_kernel_size, stride=stride, 345 | padding=downsample_padding, bias=False), 346 | nn.BatchNorm2d(planes * block.expansion), 347 | ) 348 | 349 | layers = [] 350 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 351 | downsample)) 352 | self.inplanes = planes * block.expansion 353 | for i in range(1, blocks): 354 | layers.append(block(self.inplanes, planes, groups, reduction)) 355 | 356 | return nn.Sequential(*layers) 357 | 358 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): 359 | """ 360 | Construct fully connected layer 361 | 362 | - fc_dims (list or tuple): dimensions of fc layers, if None, 363 | no fc layers are constructed 364 | - input_dim (int): input dimension 365 | - dropout_p (float): dropout probability, if None, dropout is unused 366 | """ 367 | if fc_dims is None: 368 | self.feature_dim = input_dim 369 | return None 370 | 371 | assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format(type(fc_dims)) 372 | 373 | layers = [] 374 | for dim in fc_dims: 375 | layers.append(nn.Linear(input_dim, dim)) 376 | layers.append(nn.BatchNorm1d(dim)) 377 | layers.append(nn.ReLU(inplace=True)) 378 | if dropout_p is not None: 379 | layers.append(nn.Dropout(p=dropout_p)) 380 | input_dim = dim 381 | 382 | self.feature_dim = fc_dims[-1] 383 | 384 | return nn.Sequential(*layers) 385 | 386 | def featuremaps(self, x): 387 | x = self.layer0(x) 388 | x = self.layer1(x) 389 | x = self.layer2(x) 390 | x = self.layer3(x) 391 | x = self.layer4(x) 392 | return x 393 | 394 | def forward(self, x): 395 | f = self.featuremaps(x) 396 | return f 397 | 398 | def init_pretrained_weights(model, model_url): 399 | """ 400 | Initialize model with pretrained weights. 401 | Layers that don't match with pretrained layers in name or size are kept unchanged. 
402 | """ 403 | pretrain_dict = model_zoo.load_url(model_url) 404 | model_dict = model.state_dict() 405 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 406 | model_dict.update(pretrain_dict) 407 | model.load_state_dict(model_dict) 408 | print('Initialized model with pretrained weights from {}'.format(model_url)) 409 | 410 | 411 | def senet154(pretrained=True, **kwargs): 412 | model = SENet( 413 | num_classes=num_classes, 414 | loss=loss, 415 | block=SEBottleneck, 416 | layers=[3, 8, 36, 3], 417 | groups=64, 418 | reduction=16, 419 | dropout_p=0.2, 420 | last_stride=2, 421 | fc_dims=None, 422 | **kwargs 423 | ) 424 | if pretrained: 425 | model_url = pretrained_settings['senet154']['imagenet']['url'] 426 | init_pretrained_weights(model, model_url) 427 | return model 428 | 429 | 430 | def se_resnet50(pretrained=True, **kwargs): 431 | model = SENet( 432 | block=SEResNetBottleneck, 433 | layers=[3, 4, 6, 3], 434 | groups=1, 435 | reduction=16, 436 | dropout_p=None, 437 | inplanes=64, 438 | input_3x3=False, 439 | downsample_kernel_size=1, 440 | downsample_padding=0, 441 | last_stride=2, 442 | fc_dims=None, 443 | **kwargs 444 | ) 445 | if pretrained: 446 | model_url = pretrained_settings['se_resnet50']['imagenet']['url'] 447 | init_pretrained_weights(model, model_url) 448 | return model 449 | 450 | 451 | def se_resnet50_fc512(pretrained=True, **kwargs): 452 | model = SENet( 453 | block=SEResNetBottleneck, 454 | layers=[3, 4, 6, 3], 455 | groups=1, 456 | reduction=16, 457 | dropout_p=None, 458 | inplanes=64, 459 | input_3x3=False, 460 | downsample_kernel_size=1, 461 | downsample_padding=0, 462 | last_stride=1, 463 | fc_dims=[512], 464 | **kwargs 465 | ) 466 | if pretrained: 467 | model_url = pretrained_settings['se_resnet50']['imagenet']['url'] 468 | init_pretrained_weights(model, model_url) 469 | return model 470 | 471 | 472 | def se_resnet101(pretrained=True, **kwargs): 473 | model = SENet( 474 | block=SEResNetBottleneck, 475 | layers=[3, 4, 23, 3], 476 | groups=1, 477 | reduction=16, 478 | dropout_p=None, 479 | inplanes=64, 480 | input_3x3=False, 481 | downsample_kernel_size=1, 482 | downsample_padding=0, 483 | last_stride=2, 484 | fc_dims=None, 485 | **kwargs 486 | ) 487 | if pretrained: 488 | model_url = pretrained_settings['se_resnet101']['imagenet']['url'] 489 | init_pretrained_weights(model, model_url) 490 | return model 491 | 492 | 493 | def se_resnet152(pretrained=True, **kwargs): 494 | model = SENet( 495 | block=SEResNetBottleneck, 496 | layers=[3, 8, 36, 3], 497 | groups=1, 498 | reduction=16, 499 | dropout_p=None, 500 | inplanes=64, 501 | input_3x3=False, 502 | downsample_kernel_size=1, 503 | downsample_padding=0, 504 | last_stride=2, 505 | fc_dims=None, 506 | **kwargs 507 | ) 508 | if pretrained: 509 | model_url = pretrained_settings['se_resnet152']['imagenet']['url'] 510 | init_pretrained_weights(model, model_url) 511 | return model 512 | 513 | 514 | def se_resnext50_32x4d(pretrained=True, **kwargs): 515 | model = SENet( 516 | block=SEResNeXtBottleneck, 517 | layers=[3, 4, 6, 3], 518 | groups=32, 519 | reduction=16, 520 | dropout_p=None, 521 | inplanes=64, 522 | input_3x3=False, 523 | downsample_kernel_size=1, 524 | downsample_padding=0, 525 | last_stride=2, 526 | fc_dims=None, 527 | **kwargs 528 | ) 529 | if pretrained: 530 | model_url = pretrained_settings['se_resnext50_32x4d']['imagenet']['url'] 531 | init_pretrained_weights(model, model_url) 532 | return model 533 | 534 | 535 | def 
se_resnext101_32x4d(pretrained=True, **kwargs): 536 | model = SENet( 537 | block=SEResNeXtBottleneck, 538 | layers=[3, 4, 23, 3], 539 | groups=32, 540 | reduction=16, 541 | dropout_p=None, 542 | inplanes=64, 543 | input_3x3=False, 544 | downsample_kernel_size=1, 545 | downsample_padding=0, 546 | last_stride=2, 547 | fc_dims=None, 548 | **kwargs 549 | ) 550 | if pretrained: 551 | model_url = pretrained_settings['se_resnext101_32x4d']['imagenet']['url'] 552 | init_pretrained_weights(model, model_url) 553 | return model -------------------------------------------------------------------------------- /src/project_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import errno, copy 5 | import os.path as osp 6 | import more_itertools as mit 7 | import torch, torch.nn as nn 8 | import numpy as np 9 | import scipy.io as sio 10 | 11 | from torch.utils.data.sampler import Sampler 12 | 13 | 14 | def read_Duke_attributes(attribute_file_path): 15 | attributes = [ 16 | "age", 17 | "backpack", 18 | "bag", 19 | "boots", 20 | "clothes", 21 | "down", 22 | "gender", 23 | "hair", 24 | "handbag", 25 | "hat", 26 | "shoes", 27 | # "top", # patched up by changing to up 28 | "up", 29 | {"upcloth": [ 30 | "upblack", 31 | "upblue", 32 | "upbrown", 33 | "upgray", 34 | "upgreen", 35 | "uppurple", 36 | "upred", 37 | "upwhite", 38 | "upyellow", 39 | ]}, 40 | {"downcloth": [ 41 | "downblack", 42 | "downblue", 43 | "downbrown", 44 | "downgray", 45 | "downgreen", 46 | "downpink", 47 | "downpurple", 48 | "downred", 49 | "downwhite", 50 | "downyellow", 51 | ]}, 52 | ] 53 | 54 | attribute2index = {} 55 | for index, item in enumerate(attributes): 56 | key = item 57 | if isinstance(item, dict): 58 | key = list(item.keys())[0] 59 | attribute2index[key] = index 60 | 61 | num_options_per_attributes = { 62 | "age": ["young", "teen", "adult", "old"], 63 | "backpack": ["no", "yes"], 64 | "bag": ["no", "yes"], 65 | "boots": ["no", "yes"], 66 | "clothes": ["dress", "pants"], 67 | "down": ["long", "short"], 68 | "downcloth": [ 69 | "downblack", 70 | "downblue", 71 | "downbrown", 72 | "downgray", 73 | "downgreen", 74 | "downpink", 75 | "downpurple", 76 | "downred", 77 | "downwhite", 78 | "downyellow", 79 | ], 80 | "gender": ["male", "female"], 81 | "hair": ["short", "long"], 82 | "handbag": ["no", "yes"], 83 | "hat": ["no", "yes"], 84 | "shoes": ["dark", "light"], 85 | # "top": ["short", "long"], 86 | "up": ["long", "short"], 87 | "upcloth": [ 88 | "upblack", 89 | "upblue", 90 | "upbrown", 91 | "upgray", 92 | "upgreen", 93 | "uppurple", 94 | "upred", 95 | "upwhite", 96 | "upyellow", 97 | ], 98 | } 99 | 100 | def parse_attributes_by_id(attributes): 101 | # get the order of attributes in the buffer (from dtype) 102 | attribute_names = attributes.dtype.names 103 | attribute2index = { 104 | attribute: index 105 | for index, attribute in enumerate(attribute_names) 106 | if attribute != "image_index" 107 | } 108 | 109 | # get person id and arrange by index 110 | current_attributes = attributes.item() 111 | all_ids = current_attributes[-1] 112 | id2index = {id_: index for index, id_ in enumerate(all_ids)} 113 | 114 | # collect attributes for each ID 115 | attributes_byID = {} 116 | for attribute_name, attribute_index in attribute2index.items(): 117 | # each attribute values are stored as a row in the attribute annotation file 118 | current_attribute_values = current_attributes[attribute_index] 119 | 120 | for id_, id_index in 
id2index.items(): 121 | if id_ not in attributes_byID: 122 | attributes_byID[id_] = {} 123 | 124 | attribute_value = current_attribute_values[id_index] 125 | 126 | # patch for up vs. top 127 | if attribute_name == "top": 128 | attribute_name = "up" 129 | attribute_value = (attribute_value % 2) + 1 130 | 131 | attributes_byID[id_][attribute_name] = ( 132 | attribute_value - 1 133 | ) # 0-based values 134 | 135 | return attributes_byID 136 | 137 | def merge_ID_attributes(train_attributes, test_attributes): 138 | all_ID_attributes = copy.deepcopy(train_attributes) 139 | 140 | for ID, val in test_attributes.items(): 141 | assert ( 142 | ID not in train_attributes.keys() 143 | ), "attribute merge: test ID {} already exists in train".format(ID) 144 | 145 | all_ID_attributes[ID] = val 146 | 147 | return all_ID_attributes 148 | 149 | 150 | root_dir, filename = osp.split(attribute_file_path) 151 | root_key = osp.splitext(filename)[0] 152 | buffer_file_path = osp.join(root_dir, root_key + "_attribute_cache.pth") 153 | 154 | if not osp.exists(buffer_file_path): 155 | print(buffer_file_path + " does not exist!, reading the attributes") 156 | att_file_contents = sio.loadmat( 157 | attribute_file_path, squeeze_me=True 158 | ) # squeeze 1-dim elements 159 | 160 | attributes = att_file_contents[root_key] 161 | 162 | train_attributes = attributes.item()[0] 163 | test_attributes = attributes.item()[1] 164 | 165 | train_attributes_byID = parse_attributes_by_id(train_attributes) 166 | test_attributes_byID = parse_attributes_by_id(test_attributes) 167 | print('train', len(train_attributes_byID), 'test', len(test_attributes_byID)) 168 | all_ID_attributes = merge_ID_attributes( 169 | train_attributes_byID, test_attributes_byID 170 | ) 171 | 172 | # writing attribute cache 173 | print("saving the attributes to cache " + buffer_file_path) 174 | torch.save(all_ID_attributes, buffer_file_path) 175 | 176 | else: 177 | print("reading the attributes from cache " + buffer_file_path) 178 | all_ID_attributes = torch.load(buffer_file_path) 179 | 180 | return all_ID_attributes 181 | 182 | 183 | def read_Mars_attributes(attributes_file): 184 | pass 185 | 186 | 187 | def shortlist_Mars_on_attribute(data_source, attribute, value): 188 | attributes_file = '' 189 | all_attributes = read_Mars_attributes(attributes_file) 190 | 191 | for img_path, pid, camid in data_source: 192 | pass 193 | 194 | 195 | def shortlist_Duke_on_attribute(data_source, attribute, value): 196 | attributes_file = '/media/data1/datasets/personreid/dukemtmc-reid/attributes/duke_attribute.mat' 197 | all_attributes = read_Duke_attributes(attributes_file) 198 | 199 | # collect all indices where the attribute has particular value 200 | relevant_indices = [] 201 | not_found_pids = [] 202 | for index, (img_path, pid, camid) in enumerate(data_source): 203 | stringify_pid = "%04d" % pid 204 | 205 | if stringify_pid not in all_attributes: 206 | not_found_pids.append(stringify_pid) 207 | continue 208 | 209 | current_pid_attributes = all_attributes[stringify_pid] 210 | if current_pid_attributes[attribute] == value: 211 | relevant_indices.append(index) 212 | 213 | print('not found pids : {}'.format(not_found_pids)) 214 | print(len(relevant_indices), ' sampled out of ', len(data_source)) 215 | return relevant_indices, all_attributes 216 | 217 | 218 | class AttributeBasedSampler(Sampler): 219 | def __init__(self, data_source, attribute, value, dataset_name): 220 | super().__init__(data_source) 221 | 222 | if dataset_name == 'mars': 223 | relevant_indices, _ = 
shortlist_Mars_on_attribute(data_source, attribute, value) 224 | elif dataset_name == 'dukemtmcvidreid': 225 | relevant_indices, _ = shortlist_Duke_on_attribute(data_source, attribute, value) 226 | else: 227 | assert False, 'unknown dataset ' + dataset_name 228 | 229 | # get the indices of data instances which has 230 | # attribute's value as value 231 | self.instance_indices = relevant_indices 232 | 233 | def __len__(self): 234 | return len(self.instance_indices) 235 | 236 | def __iter__(self): 237 | return iter(self.instance_indices) 238 | -------------------------------------------------------------------------------- /src/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | import torch 6 | 7 | class RandomIdentitySampler(torch.utils.data.Sampler): 8 | """ 9 | Randomly sample N identities, then for each identity, 10 | randomly sample K instances, therefore batch size is N*K. 11 | 12 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 13 | 14 | Args: 15 | data_source (Dataset): dataset to sample from. 16 | num_instances (int): number of instances per identity. 17 | """ 18 | def __init__(self, data_source, num_instances=4): 19 | self.data_source = data_source 20 | self.num_instances = num_instances 21 | self.index_dic = defaultdict(list) 22 | for index, (_, pid, _) in enumerate(data_source): 23 | self.index_dic[pid].append(index) 24 | self.pids = list(self.index_dic.keys()) 25 | self.num_identities = len(self.pids) 26 | 27 | def __iter__(self): 28 | indices = torch.randperm(self.num_identities) 29 | ret = [] 30 | for i in indices: 31 | pid = self.pids[i] 32 | t = self.index_dic[pid] 33 | replace = False if len(t) >= self.num_instances else True 34 | t = np.random.choice(t, size=self.num_instances, replace=replace) 35 | ret.extend(t) 36 | return iter(ret) 37 | 38 | def __len__(self): 39 | return self.num_identities * self.num_instances 40 | -------------------------------------------------------------------------------- /src/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import numpy as np 7 | 8 | class Random2DTranslation(object): 9 | """ 10 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 11 | 12 | Args: 13 | height (int): target height. 14 | width (int): target width. 15 | p (float): probability of performing this transformation. Default: 0.5. 16 | """ 17 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 18 | self.height = height 19 | self.width = width 20 | self.p = p 21 | self.interpolation = interpolation 22 | 23 | def __call__(self, img): 24 | """ 25 | Args: 26 | img (PIL Image): Image to be cropped. 27 | 28 | Returns: 29 | PIL Image: Cropped image. 
30 | """ 31 | if random.random() < self.p: 32 | return img.resize((self.width, self.height), self.interpolation) 33 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 34 | resized_img = img.resize((new_width, new_height), self.interpolation) 35 | x_maxrange = new_width - self.width 36 | y_maxrange = new_height - self.height 37 | x1 = int(round(random.uniform(0, x_maxrange))) 38 | y1 = int(round(random.uniform(0, y_maxrange))) 39 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 40 | return croped_img 41 | 42 | if __name__ == '__main__': 43 | pass 44 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import errno 5 | import shutil 6 | import json, math 7 | import os.path as osp 8 | import more_itertools as mit 9 | import torch, torch.nn as nn 10 | import numpy as np 11 | import inspect 12 | 13 | is_print_once_enabled = True 14 | 15 | 16 | def static_vars(**kwargs): 17 | def decorate(func): 18 | for k in kwargs: 19 | setattr(func, k, kwargs[k]) 20 | return func 21 | 22 | return decorate 23 | 24 | 25 | def disable_all_print_once(): 26 | global is_print_once_enabled 27 | is_print_once_enabled = False 28 | 29 | 30 | @static_vars(lines={}) 31 | def print_once(msg): 32 | # return from the function if the API is disabled 33 | global is_print_once_enabled 34 | if not is_print_once_enabled: 35 | return 36 | 37 | from inspect import getframeinfo, stack 38 | 39 | caller = getframeinfo(stack()[1][0]) 40 | current_file_line = "%s:%d" % (caller.filename, caller.lineno) 41 | 42 | # if the current called file and line is not in buffer print once 43 | if current_file_line not in print_once.lines: 44 | print(msg) 45 | print_once.lines[current_file_line] = True 46 | 47 | 48 | def get_executing_filepath(): 49 | frame = inspect.stack()[1] 50 | module = inspect.getmodule(frame[0]) 51 | filename = module.__file__ 52 | return os.path.split(filename)[0] 53 | 54 | 55 | def set_stride(module, stride): 56 | """ 57 | 58 | """ 59 | print("setting stride of ", module, " to ", stride) 60 | for internal_module in module.modules(): 61 | if isinstance(internal_module, nn.Conv2d) or isinstance( 62 | internal_module, nn.MaxPool2d 63 | ): 64 | internal_module.stride = stride 65 | 66 | return internal_module 67 | 68 | 69 | def get_gaussian_kernel(channels, kernel_size=5, mean=0, sigma=[1, 4]): 70 | # CONVERT INTO NP ARRAY 71 | sigma_ = torch.zeros((2, 2)).float() 72 | sigma_[0, 0] = sigma[0] 73 | sigma_[1, 1] = sigma[1] 74 | sigma = sigma_ 75 | 76 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 77 | x_cord = torch.linspace(-1, 1, kernel_size) 78 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) 79 | y_grid = x_grid.t() 80 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 81 | 82 | variance = (sigma @ sigma.t()).float() 83 | inv_variance = torch.inverse(variance) 84 | 85 | # Calculate the 2-dimensional gaussian kernel which is 86 | # the product of two gaussian distributions for two different 87 | # variables (in this case called x and y) 88 | gaussian_kernel = (1.0 / (2.0 * math.pi * torch.det(variance))) * torch.exp( 89 | -torch.sum( 90 | ((xy_grid - mean) @ inv_variance.unsqueeze(0)) * (xy_grid - mean), dim=-1 91 | ) 92 | / 2 93 | ) 94 | 95 | # Make sure sum of values in gaussian kernel equals 1. 
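    # (the density values are sampled on a coarse kernel_size x kernel_size grid over
    # [-1, 1], so their discrete sum is not 1 by itself; dividing by the sum makes the
    # kernel usable as an intensity-preserving smoothing filter)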
96 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 97 | 98 | # Reshape to 2d depthwise convolutional weight 99 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 100 | gaussian_kernel = gaussian_kernel.repeat(1, channels, 1, 1) 101 | return gaussian_kernel 102 | 103 | 104 | def mkdir_if_missing(directory): 105 | """to create a directory 106 | 107 | Arguments: 108 | directory {str} -- directory path 109 | """ 110 | 111 | if not osp.exists(directory): 112 | try: 113 | os.makedirs(directory) 114 | except OSError as e: 115 | if e.errno != errno.EEXIST: 116 | raise 117 | 118 | 119 | # function to freeze certain module's weight 120 | def freeze_weights(to_be_freezed): 121 | for param in to_be_freezed.parameters(): 122 | param.requires_grad = False 123 | 124 | for module in to_be_freezed.children(): 125 | for param in module.parameters(): 126 | param.requires_grad = False 127 | 128 | 129 | def load_pretrained_model(model, pretrained_model_path, verbose=False): 130 | """To load the pretrained model considering the number of keys and their sizes 131 | 132 | Arguments: 133 | model {loaded model} -- already loaded model 134 | pretrained_model_path {str} -- path to the pretrained model file 135 | 136 | Raises: 137 | IOError -- if the file path is not found 138 | 139 | Returns: 140 | model -- model with loaded params 141 | """ 142 | 143 | if not os.path.exists(pretrained_model_path): 144 | raise IOError("Can't find pretrained model: {}".format(pretrained_model_path)) 145 | 146 | print("Loading checkpoint from '{}'".format(pretrained_model_path)) 147 | pretrained_state = torch.load(pretrained_model_path)["state_dict"] 148 | print(len(pretrained_state), " keys in pretrained model") 149 | 150 | current_model_state = model.state_dict() 151 | print(len(current_model_state), " keys in current model") 152 | pretrained_state = { 153 | key: val 154 | for key, val in pretrained_state.items() 155 | if key in current_model_state and val.size() == current_model_state[key].size() 156 | } 157 | 158 | print( 159 | len(pretrained_state), 160 | " keys in pretrained model are available in current model", 161 | ) 162 | current_model_state.update(pretrained_state) 163 | model.load_state_dict(current_model_state) 164 | 165 | if verbose: 166 | non_available_keys_in_pretrained = [ 167 | key 168 | for key, val in pretrained_state.items() 169 | if key not in current_model_state 170 | or val.size() != current_model_state[key].size() 171 | ] 172 | non_available_keys_in_current = [ 173 | key 174 | for key, val in current_model_state.items() 175 | if key not in pretrained_state or val.size() != pretrained_state[key].size() 176 | ] 177 | 178 | print( 179 | "not available keys in pretrained model: ", non_available_keys_in_pretrained 180 | ) 181 | print("not available keys in current model: ", non_available_keys_in_current) 182 | 183 | return model 184 | 185 | 186 | def get_currenttime_prefix(): 187 | """to get a prefix of current time 188 | 189 | Returns: 190 | [str] -- current time encoded into string 191 | """ 192 | 193 | from time import localtime, strftime 194 | 195 | return strftime("%d-%b-%Y_%H:%M:%S", localtime()) 196 | 197 | 198 | def get_learnable_params(model): 199 | """to get the list of learnable params 200 | 201 | Arguments: 202 | model {model} -- loaded model 203 | 204 | Returns: 205 | list -- learnable params 206 | """ 207 | 208 | # list down the names of learnable params 209 | details = [] 210 | for name, param in model.named_parameters(): 211 | if param.requires_grad: 212 | 
details.append((name, param.shape)) 213 | print("learnable params (" + str(len(details)) + ") : ", details) 214 | 215 | # short list the params which has requires_grad as true 216 | learnable_params = [param for param in model.parameters() if param.requires_grad] 217 | 218 | print( 219 | "Model size: {:.5f}M".format( 220 | sum(p.numel() for p in learnable_params) / 1000000.0 221 | ) 222 | ) 223 | return learnable_params 224 | 225 | 226 | def get_features(model, imgs, test_num_tracks): 227 | """to handle higher seq length videos due to OOM error 228 | specifically used during test 229 | 230 | Arguments: 231 | model -- model under test 232 | imgs -- imgs to get features for 233 | 234 | Returns: 235 | features 236 | """ 237 | 238 | # handle chunked data 239 | all_features = [] 240 | 241 | for test_imgs in mit.chunked(imgs, test_num_tracks): 242 | current_test_imgs = torch.stack(test_imgs) 243 | num_current_test_imgs = current_test_imgs.shape[0] 244 | # print(current_test_imgs.shape) 245 | features = model(current_test_imgs) 246 | features = features.view(num_current_test_imgs, -1) 247 | all_features.append(features) 248 | 249 | return torch.cat(all_features) 250 | 251 | 252 | def get_spatial_features(model, imgs, test_num_tracks): 253 | """to handle higher seq length videos due to OOM error 254 | specifically used during test 255 | 256 | Arguments: 257 | model -- model under test 258 | imgs -- imgs to get features for 259 | 260 | Returns: 261 | features 262 | """ 263 | 264 | # handle chunked data 265 | all_features, all_spatial_features = [], [] 266 | 267 | for test_imgs in mit.chunked(imgs, test_num_tracks): 268 | current_test_imgs = torch.stack(test_imgs) 269 | num_current_test_imgs = current_test_imgs.shape[0] 270 | features, spatial_feats = model(current_test_imgs) 271 | features = features.view(num_current_test_imgs, -1) 272 | 273 | all_spatial_features.append(spatial_feats) 274 | all_features.append(features) 275 | 276 | return torch.cat(all_features), torch.cat(all_spatial_features) 277 | 278 | 279 | class AverageMeter(object): 280 | """Computes and stores the average and current value. 281 | 282 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 283 | """ 284 | 285 | def __init__(self): 286 | self.reset() 287 | 288 | def reset(self): 289 | self.val = 0 290 | self.avg = 0 291 | self.sum = 0 292 | self.count = 0 293 | 294 | def update(self, val, n=1): 295 | self.val = val 296 | self.sum += val * n 297 | self.count += n 298 | self.avg = self.sum / self.count 299 | 300 | 301 | def save_checkpoint(state, is_best, fpath="checkpoint.pth.tar"): 302 | mkdir_if_missing(osp.dirname(fpath)) 303 | print("saving model to " + fpath) 304 | torch.save(state, fpath) 305 | if is_best: 306 | shutil.copy(fpath, osp.join(osp.dirname(fpath), "best_model.pth.tar")) 307 | 308 | 309 | def open_all_layers(model): 310 | """ 311 | Open all layers in model for training. 312 | 313 | Args: 314 | - model (nn.Module): neural net model. 315 | """ 316 | model.train() 317 | for p in model.parameters(): 318 | p.requires_grad = True 319 | 320 | 321 | def open_specified_layers(model, open_layers): 322 | """ 323 | Open specified layers in model for training while keeping 324 | other layers frozen. 325 | 326 | Args: 327 | - model (nn.Module): neural net model. 328 | - open_layers (list): list of layer names. 
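Stepping back to `get_features`/`get_spatial_features` above: they avoid out-of-memory errors on long tracklets by splitting the clips with `more_itertools.chunked` and concatenating the per-chunk outputs. A toy sketch of that pattern follows; a trivial mean stands in for the real network and all sizes are made up:

```python
import torch
import more_itertools as mit

clips = torch.randn(10, 4, 3, 256, 128)        # 10 clips of 4 frames each (made-up tracklet)
test_num_tracks = 4                            # how many clips fit on the GPU at once

outputs = []
for chunk in mit.chunked(clips, test_num_tracks):
    batch = torch.stack(chunk)                 # (<=4, 4, 3, 256, 128)
    feats = batch.mean(dim=(1, 2, 3, 4))       # stand-in for model(batch)
    outputs.append(feats.view(batch.shape[0], -1))
features = torch.cat(outputs)                  # (10, 1) here; (10, feat_dim) with a real model
print(features.shape)
```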
329 | """ 330 | if isinstance(model, nn.DataParallel): 331 | model = model.module 332 | 333 | for layer in open_layers: 334 | assert hasattr( 335 | model, layer 336 | ), '"{}" is not an attribute of the model, please provide the correct name'.format( 337 | layer 338 | ) 339 | 340 | # check if all the open layers are there in model 341 | all_names = [name for name, module in model.named_children()] 342 | for tobeopen_layer in open_layers: 343 | assert tobeopen_layer in all_names, "{} not in model".format(tobeopen_layer) 344 | 345 | for name, module in model.named_children(): 346 | if name in open_layers: 347 | module.train() 348 | for p in module.parameters(): 349 | p.requires_grad = True 350 | else: 351 | module.eval() 352 | for p in module.parameters(): 353 | p.requires_grad = False 354 | 355 | 356 | class Logger(object): 357 | """ 358 | Write console output to external text file. 359 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 360 | """ 361 | 362 | def __init__(self, fpath=None): 363 | self.console = sys.stdout 364 | self.file = None 365 | if fpath is not None: 366 | mkdir_if_missing(os.path.dirname(fpath)) 367 | self.file = open(fpath, "a") 368 | 369 | def __del__(self): 370 | self.close() 371 | 372 | def __enter__(self): 373 | pass 374 | 375 | def __exit__(self, *args): 376 | self.close() 377 | 378 | def write(self, msg): 379 | self.console.write(msg) 380 | if self.file is not None: 381 | self.file.write(msg) 382 | 383 | def flush(self): 384 | self.console.flush() 385 | if self.file is not None: 386 | self.file.flush() 387 | os.fsync(self.file.fileno()) 388 | 389 | def close(self): 390 | self.console.close() 391 | if self.file is not None: 392 | self.file.close() 393 | 394 | 395 | def read_json(fpath): 396 | with open(fpath, "r") as f: 397 | obj = json.load(f) 398 | return obj 399 | 400 | 401 | def write_json(obj, fpath): 402 | mkdir_if_missing(osp.dirname(fpath)) 403 | with open(fpath, "w") as f: 404 | json.dump(obj, f, indent=4, separators=(",", ": ")) 405 | 406 | -------------------------------------------------------------------------------- /src/video_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | import random 9 | 10 | def read_image(img_path): 11 | """Keep reading image until succeed. 12 | This can avoid IOError incurred by heavy IO process.""" 13 | got_img = False 14 | while not got_img: 15 | try: 16 | img = Image.open(img_path).convert('RGB') 17 | got_img = True 18 | except IOError: 19 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 20 | pass 21 | return img 22 | 23 | 24 | class VideoDataset(Dataset): 25 | """Video Person ReID Dataset. 26 | Note batch data has shape (batch, seq_len, channel, height, width). 
27 |     """
28 |     sample_methods = ['evenly', 'random', 'dense']
29 | 
30 |     def __init__(self, dataset, seq_len=15, sample='evenly', transform=None):
31 |         self.dataset = dataset
32 |         self.seq_len = seq_len
33 |         self.sample = sample
34 |         self.transform = transform
35 | 
36 |     def __len__(self):
37 |         return len(self.dataset)
38 | 
39 |     def __getitem__(self, index):
40 |         img_paths, pid, camid = self.dataset[index]
41 |         num = len(img_paths)
42 |         if self.sample == 'random':
43 |             """
44 |             Randomly sample seq_len consecutive frames from the num available frames;
45 |             if num is smaller than seq_len, replicate frames until seq_len is reached.
46 |             This sampling strategy is used in the training phase.
47 |             """
48 |             frame_indices = range(num)
49 |             rand_end = max(0, len(frame_indices) - self.seq_len - 1)
50 |             begin_index = random.randint(0, rand_end)
51 |             end_index = min(begin_index + self.seq_len, len(frame_indices))
52 | 
53 |             indices = list(frame_indices[begin_index:end_index])
54 |             # if fewer than seq_len frames were selected, pad by cycling over them
55 |             for index in indices:
56 |                 if len(indices) >= self.seq_len:
57 |                     break
58 |                 indices.append(index)
59 |             indices = np.array(indices)
60 |             imgs = []
61 |             for index in indices:
62 |                 index = int(index)
63 |                 img_path = img_paths[index]
64 |                 img = read_image(img_path)
65 |                 if self.transform is not None:
66 |                     img = self.transform(img)
67 |                 img = img.unsqueeze(0)
68 |                 imgs.append(img)
69 |             imgs = torch.cat(imgs, dim=0)
70 |             #imgs=imgs.permute(1,0,2,3)
71 |             return imgs, pid, camid
72 | 
73 |         elif self.sample == 'evenly':
74 |             """
75 |             Evenly sample seq_len frames from the num available frames.
76 |             """
77 |             if num >= self.seq_len:
78 |                 num -= num % self.seq_len
79 |                 indices = np.arange(0, num, num // self.seq_len)
80 |             else:
81 |                 # if num is smaller than seq_len, simply replicate the last image
82 |                 # until the seq_len requirement is satisfied
83 |                 indices = np.arange(0, num)
84 |                 num_pads = self.seq_len - num
85 |                 indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num-1)])
86 | 
87 |             assert len(indices) == self.seq_len
88 | 
89 |             # `indices` already holds exactly seq_len entries at this point
90 |             # (guaranteed by the assert above), so no further padding is needed,
91 |             # unlike in the 'random' branch.
92 |             indices = np.array(indices)
93 | 
94 |             imgs = []
95 |             for index in indices:
96 |                 index = int(index)
97 |                 img_path = img_paths[index]
98 |                 img = read_image(img_path)
99 |                 if self.transform is not None:
100 |                     img = self.transform(img)
101 |                 img = img.unsqueeze(0)
102 |                 imgs.append(img)
103 |             imgs = torch.cat(imgs, dim=0)
104 |             #imgs=imgs.permute(1,0,2,3)
105 |             return imgs, pid, camid
106 | 
107 |         elif self.sample == 'dense':
108 |             """
109 |             Sample all frames of a video into a list of clips, each containing seq_len frames; batch_size needs to be set to 1.
110 |             This sampling strategy is used in the test phase.
111 | """ 112 | cur_index=0 113 | frame_indices = list(range(num)) 114 | indices_list=[] 115 | while num-cur_index > self.seq_len: 116 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len]) 117 | cur_index+=self.seq_len 118 | last_seq=frame_indices[cur_index:] 119 | for index in last_seq: 120 | if len(last_seq) >= self.seq_len: 121 | break 122 | last_seq.append(index) 123 | indices_list.append(last_seq) 124 | 125 | imgs_list=[] 126 | for indices in indices_list: 127 | imgs = [] 128 | for index in indices: 129 | index=int(index) 130 | img_path = img_paths[index] 131 | img = read_image(img_path) 132 | if self.transform is not None: 133 | img = self.transform(img) 134 | img = img.unsqueeze(0) 135 | imgs.append(img) 136 | imgs = torch.cat(imgs, dim=0) 137 | #imgs=imgs.permute(1,0,2,3) 138 | imgs_list.append(imgs) 139 | imgs_array = torch.stack(imgs_list) 140 | return imgs_array, pid, camid 141 | 142 | else: 143 | raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods)) 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | --------------------------------------------------------------------------------
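A short usage sketch may help tie `video_loader.py` and `utils.py` together: `VideoDataset` with `sample='dense'` yields one batch of shape `(1, num_clips, seq_len, C, H, W)` per tracklet, and a chunked helper such as `get_features` keeps GPU memory bounded by pushing only `test_num_tracks` clips through the model per forward pass. The snippet below is a minimal, hypothetical wiring of the two, not part of the repository: the `ToyEncoder`, the temporary fake frames, and the values `seq_len=4` and `test_num_tracks=2` are all illustrative assumptions, and it is assumed to be run from inside `src/`. The real pipeline uses the actual backbones under `src/models/` instead.

```
# Hypothetical end-to-end sketch (assumptions throughout): fabricate one tracklet
# on disk, wrap it in VideoDataset with dense sampling, and extract one descriptor
# per tracklet via chunked forward passes, the way get_features() is used at test time.
import os
import tempfile

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader

from utils import get_features          # chunked feature extraction helper
from video_loader import VideoDataset

# --- fake data: one tracklet of 10 frames as (image paths, person id, camera id) ---
tmp_dir = tempfile.mkdtemp()
img_paths = []
for i in range(10):
    path = os.path.join(tmp_dir, "frame_{:04d}.jpg".format(i))
    Image.new("RGB", (64, 128)).save(path)
    img_paths.append(path)
dataset = [(img_paths, 0, 0)]

# --- dense sampling: every frame lands in some clip of length seq_len ---
transform = T.Compose([T.Resize((128, 64)), T.ToTensor()])
video_data = VideoDataset(dataset, seq_len=4, sample="dense", transform=transform)
loader = DataLoader(video_data, batch_size=1, shuffle=False)  # batch_size must be 1


class ToyEncoder(nn.Module):
    """Toy stand-in for the real re-id backbone: one small vector per clip."""

    def __init__(self, feat_dim=8):
        super(ToyEncoder, self).__init__()
        self.fc = nn.Linear(3, feat_dim)

    def forward(self, x):
        # x: (num_clips, seq_len, C, H, W) -> average over time and space -> (num_clips, C)
        pooled = x.reshape(x.size(0), x.size(1), x.size(2), -1).mean(dim=-1).mean(dim=1)
        return self.fc(pooled)            # (num_clips, feat_dim)


model = ToyEncoder().eval()

with torch.no_grad():
    for imgs, pid, camid in loader:
        imgs = imgs.squeeze(0)            # (1, clips, seq_len, C, H, W) -> (clips, seq_len, C, H, W)
        # push at most 2 clips through the model per forward pass
        clip_feats = get_features(model, imgs, 2)
        track_feat = clip_feats.mean(dim=0)   # one descriptor for the whole tracklet
        print(clip_feats.shape, track_feat.shape)   # e.g. torch.Size([3, 8]) torch.Size([8])
```

Averaging the clip features into a single tracklet descriptor is only one reasonable aggregation choice for this sketch; the repository's own evaluation code decides how clip-level features are combined.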