├── .gitignore
├── .gitmodules
├── DATASETS.md
├── LICENSE
├── README.md
├── docs
│   └── visualization
│       └── coseg_model_mars_viz.ipynb
├── scratch
│   └── readme
└── src
    ├── data_manager.py
    ├── eval_metrics.py
    ├── losses.py
    ├── main_video_person_reid.py
    ├── models
    │   ├── ResNet.py
    │   ├── SE_ResNet.py
    │   ├── __init__.py
    │   ├── aggregation_layers.py
    │   ├── cosam.py
    │   └── senet.py
    ├── project_utils.py
    ├── samplers.py
    ├── transforms.py
    ├── utils.py
    └── video_loader.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "src/person_re_ranking"]
2 | path = src/person_re_ranking
3 | url = https://github.com/InnovArul/person_re_ranking
4 |
--------------------------------------------------------------------------------
/DATASETS.md:
--------------------------------------------------------------------------------
1 | ## MARS
2 |
3 | MARS is the largest dataset available to date for video-based person reID. The instructions are copied here:
4 |
5 | * Create a directory named mars/ under data/.
6 | * Download dataset to data/mars/ from http://www.liangzheng.com.cn/Project/project_mars.html.
7 | * Extract bbox_train.zip and bbox_test.zip.
8 | * Download split information from https://github.com/liangzheng06/MARS-evaluation/tree/master/info and put info/ in data/mars (we want to follow the standard split in [8]).
9 | * The data structure would look like:
10 |
11 | ```
12 | mars/
13 | bbox_test/
14 | bbox_train/
15 | info/
16 | ```
17 | Use -d mars when running the training code.
18 |
19 | ## DukeMTMC-VideoReID
20 |
21 | * Create a directory named dukemtmc-vidreid/ under data/.
22 |
23 | * Download “DukeMTMC-VideoReID” from http://vision.cs.duke.edu/DukeMTMC/ and unzip the file to “dukemtmc-vidreid/”.
24 |
25 | * The data structure should look like
26 |
27 | ```
28 | dukemtmc-vidreid/
29 | DukeMTMC-VideoReID/
30 | train/
31 | query/
32 | gallery/
33 | ```
34 |
35 | ## iLIDS-VID
36 |
37 | * Create a directory named ilids-vid/ under data/.
38 |
39 | * Download the dataset from http://www.eecs.qmul.ac.uk/~xiatian/downloads_qmul_iLIDS-VID_ReID_dataset.html to "ilids-vid".
40 |
41 | * Organize the data structure to match
42 |
43 | ```
44 | ilids-vid/
45 | i-LIDS-VID/
46 | train-test people splits
47 | ```
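
Use `-d dukemtmcvidreid` or `-d ilidsvid` for the other two datasets (these are the dataset keys registered in `src/data_manager.py`).

As a quick sanity check before training, a small script like the one below can confirm that the expected folders are in place. This is only an illustrative sketch; the paths mirror the defaults in `src/data_manager.py`, which expect the datasets under `../data/` relative to `src/`:

```python
import os.path as osp

# expected sub-folders for each dataset, relative to its dataset directory
EXPECTED = {
    'mars': ['bbox_train', 'bbox_test', 'info'],
    'dukemtmcvidreid': ['DukeMTMC-VideoReID/train',
                        'DukeMTMC-VideoReID/query',
                        'DukeMTMC-VideoReID/gallery'],
    'ilidsvid': ['i-LIDS-VID', 'train-test people splits'],
}

DATASET_DIRS = {'mars': 'mars',
                'dukemtmcvidreid': 'dukemtmc-vidreid',
                'ilidsvid': 'ilids-vid'}

def check_dataset(name, data_root='../data'):
    """Print which expected sub-folders are present or missing for a dataset."""
    root = osp.join(data_root, DATASET_DIRS[name])
    for sub in EXPECTED[name]:
        path = osp.join(root, sub)
        status = 'ok' if osp.isdir(path) else 'MISSING'
        print('{:7s} {}'.format(status, path))

if __name__ == '__main__':
    check_dataset('mars')
```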
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Video-based Person Re-identification
2 | Source code for the ICCV-2019 paper "Co-segmentation Inspired Attention Networks for Video-based Person Re-identification". Our paper can be found here.
3 |
4 | ## Introduction
5 |
6 | Our work tackles some of the challenges in video-based person re-identification (Re-ID), such as background clutter, misalignment and partial occlusion, by means of a co-segmentation-inspired approach. The intention is to attend to the task-dependent common regions across the video frames of a person, which helps the network focus on the most relevant features. This repository contains the code for the co-segmentation-inspired Re-ID architecture built around the "Co-segmentation Activation Module (COSAM)".
7 | The co-segmentation masks are interpretable and help in understanding where and how the network attends while building a description of a person.
8 |
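To give a rough idea of the kind of attention COSAM computes, here is a minimal, illustrative sketch of a channel-attention block that pools information across the frames of a clip and re-weights each frame's feature maps. This is not the paper's actual module; the real implementation lives in `src/models/cosam.py`, and the names below are made up for illustration.

```python
import torch
import torch.nn as nn

class ToyCosegChannelAttention(nn.Module):
    """Illustrative only: re-weight each frame's channels using a summary
    pooled over all frames of the clip, so the attention is common to the clip."""
    def __init__(self, num_channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(num_channels, num_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(num_channels // reduction, num_channels),
            nn.Sigmoid(),
        )

    def forward(self, x, b, t):
        # x: (b*t, c, h, w) frame-level feature maps from a backbone
        bt, c, h, w = x.size()
        feats = x.view(b, t, c, h, w)
        # clip-level descriptor shared by all frames: average over time, then space
        clip_desc = feats.mean(dim=1).view(b, c, -1).mean(dim=2)   # (b, c)
        weights = self.fc(clip_desc).view(b, 1, c, 1, 1)           # per-channel weights
        return (feats * weights).view(bt, c, h, w)

# usage: b=2 clips of t=4 frames each, 256 feature channels
x = torch.randn(2 * 4, 256, 16, 8)
out = ToyCosegChannelAttention(256)(x, b=2, t=4)
print(out.shape)   # torch.Size([8, 256, 16, 8])
```
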
9 | ### Credits
10 |
11 | The source code is built upon the GitHub repositories Video-Person-ReID (from jiyanggao) and deep-person-reid (from KaiyangZhou). Mainly, the data loading, data sampling and training code are borrowed from those repositories. The strong baseline performance comes from the models in the Video-Person-ReID codebase. Check out their papers: Revisiting Temporal Modeling for Video-based Person ReID (Gao et al.) and OSNet (Zhou et al., ICCV 2019).
12 |
13 | We would like to thank jiyanggao and KaiyangZhou for generously releasing their code to the community.
14 |
15 | ## Datasets
16 |
17 | Dataset preparation instructions can be found in the Video-Person-ReID and deep-person-reid repositories. For completeness, they are compiled in DATASETS.md.
18 |
19 | ## Training
20 |
21 | `python main_video_person_reid.py -a resnet50_cosam45_tp -d <dataset-name> --gpu-devices <gpu-ids>`
22 |
23 | `<dataset-name>` can be `mars` or `dukemtmcvidreid`
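
For example, to train on MARS using GPU 0 (an illustrative invocation; all other options keep the defaults defined in `main_video_person_reid.py`):

`python main_video_person_reid.py -a resnet50_cosam45_tp -d mars --gpu-devices 0`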
24 |
25 | ## Testing
26 |
27 | `python main_video_person_reid.py -a resnet50_cosam45_tp -d <dataset-name> --gpu-devices <gpu-ids> --evaluate --pretrained-model <path-to-pretrained-model>`
28 |
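For example, to evaluate a saved checkpoint on MARS (illustrative paths; checkpoints are written under `scratch/` by the training script):

`python main_video_person_reid.py -a resnet50_cosam45_tp -d mars --gpu-devices 0 --evaluate --pretrained-model ../scratch/<run-dir>/<checkpoint>.pth.tar`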
29 |
--------------------------------------------------------------------------------
/scratch/readme:
--------------------------------------------------------------------------------
1 | the training folders will be created here
--------------------------------------------------------------------------------
/src/data_manager.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | import glob
4 | import re
5 | import sys
6 | import urllib.request
7 | import tarfile
8 | import zipfile
9 | import os.path as osp
10 | from scipy.io import loadmat
11 | import numpy as np
12 |
13 | from utils import mkdir_if_missing, write_json, read_json
14 |
15 | """Dataset classes"""
16 |
17 |
18 | class Mars(object):
19 | """
20 | MARS
21 |
22 | Reference:
23 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016.
24 |
25 | Dataset statistics:
26 | # identities: 1261
27 | # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery)
28 | # cameras: 6
29 |
30 | Args:
31 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0).
32 | """
33 |
34 | def __init__(self, root='../data', min_seq_len=0):
35 |
36 | self.root = os.path.join(root, 'mars')
37 |
38 | root = self.root
39 | self.train_name_path = osp.join(root, 'info/train_name.txt')
40 | self.test_name_path = osp.join(root, 'info/test_name.txt')
41 | self.track_train_info_path = osp.join(root, 'info/tracks_train_info.mat')
42 | self.track_test_info_path = osp.join(root, 'info/tracks_test_info.mat')
43 | self.query_IDX_path = osp.join(root, 'info/query_IDX.mat')
44 |
45 | self._check_before_run()
46 |
47 | # prepare meta data
48 | train_names = self._get_names(self.train_name_path)
49 | test_names = self._get_names(self.test_name_path)
50 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4)
51 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4)
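        # each row of track_*_info is [start_index, end_index, pid, camid];
        # start/end are 1-based indices into the corresponding name list (see _process_data)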
52 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,)
53 | query_IDX -= 1 # index from 0
54 | track_query = track_test[query_IDX,:]
55 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX]
56 | track_gallery = track_test[gallery_IDX,:]
57 |
58 | train, num_train_tracklets, num_train_pids, num_train_imgs = \
59 | self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len)
60 |
61 | query, num_query_tracklets, num_query_pids, num_query_imgs = \
62 | self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
63 |
64 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \
65 | self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
66 |
67 | num_imgs_per_tracklet = num_train_imgs + num_query_imgs + num_gallery_imgs
68 | min_num = np.min(num_imgs_per_tracklet)
69 | max_num = np.max(num_imgs_per_tracklet)
70 | avg_num = np.mean(num_imgs_per_tracklet)
71 |
72 | num_total_pids = num_train_pids + num_query_pids
73 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets
74 |
75 | print("=> MARS loaded")
76 | print("Dataset statistics:")
77 | print(" ------------------------------")
78 | print(" subset | # ids | # tracklets")
79 | print(" ------------------------------")
80 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
81 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
82 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
83 | print(" ------------------------------")
84 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
85 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
86 | print(" ------------------------------")
87 |
88 | self.train = train
89 | self.query = query
90 | self.gallery = gallery
91 |
92 | self.num_train_pids = num_train_pids
93 | self.num_query_pids = num_query_pids
94 | self.num_gallery_pids = num_gallery_pids
95 |
96 | def _check_before_run(self):
97 | """Check if all files are available before going deeper"""
98 | if not osp.exists(self.root):
99 | raise RuntimeError("'{}' is not available".format(self.root))
100 | if not osp.exists(self.train_name_path):
101 | raise RuntimeError("'{}' is not available".format(self.train_name_path))
102 | if not osp.exists(self.test_name_path):
103 | raise RuntimeError("'{}' is not available".format(self.test_name_path))
104 | if not osp.exists(self.track_train_info_path):
105 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path))
106 | if not osp.exists(self.track_test_info_path):
107 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path))
108 | if not osp.exists(self.query_IDX_path):
109 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path))
110 |
111 | def _get_names(self, fpath):
112 | names = []
113 | with open(fpath, 'r') as f:
114 | for line in f:
115 | new_line = line.rstrip()
116 | names.append(new_line)
117 | return names
118 |
119 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0):
120 | assert home_dir in ['bbox_train', 'bbox_test']
121 | num_tracklets = meta_data.shape[0]
122 | pid_list = list(set(meta_data[:,2].tolist()))
123 | num_pids = len(pid_list)
124 |
125 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)}
126 | tracklets = []
127 | num_imgs_per_tracklet = []
128 |
129 | for tracklet_idx in range(num_tracklets):
130 | data = meta_data[tracklet_idx,...]
131 | start_index, end_index, pid, camid = data
132 | if pid == -1: continue # junk images are just ignored
133 | assert 1 <= camid <= 6
134 | if relabel: pid = pid2label[pid]
135 | camid -= 1 # index starts from 0
136 | img_names = names[start_index-1:end_index]
137 |
138 | # make sure image names correspond to the same person
139 | pnames = [img_name[:4] for img_name in img_names]
140 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images"
141 |
142 | # make sure all images are captured under the same camera
143 | camnames = [img_name[5] for img_name in img_names]
144 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
145 |
146 | # append image names with directory information
147 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names]
148 | if len(img_paths) >= min_seq_len:
149 | img_paths = tuple(img_paths)
150 | tracklets.append((img_paths, pid, camid))
151 | num_imgs_per_tracklet.append(len(img_paths))
152 |
153 | num_tracklets = len(tracklets)
154 |
155 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet
156 |
157 |
158 |
159 | class DukeMTMCVidReID(object):
160 | """
161 | DukeMTMCVidReID
162 |
163 | Reference:
164 | Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person
165 | Re-Identification by Stepwise Learning. CVPR 2018.
166 |
167 | URL: https://github.com/Yu-Wu/DukeMTMC-VideoReID
168 |
169 | Dataset statistics:
170 | # identities: 702 (train) + 702 (test)
171 | # tracklets: 2196 (train) + 2636 (test)
172 | """
173 | dataset_dir = 'dukemtmc-vidreid'
174 |
175 | def __init__(self, root='../data', min_seq_len=0, verbose=True, **kwargs):
176 | self.dataset_dir = osp.join(root, self.dataset_dir)
177 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip'
178 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/train')
179 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/query')
180 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/gallery')
181 | self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json')
182 | self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json')
183 | self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json')
184 |
185 | self.min_seq_len = min_seq_len
186 | self._download_data()
187 | self._check_before_run()
188 | print("Note: if root path is changed, the previously generated json files need to be re-generated (so delete them first)")
189 |
190 | train, num_train_tracklets, num_train_pids, num_imgs_train = \
191 | self._process_dir(self.train_dir, self.split_train_json_path, relabel=True)
192 | query, num_query_tracklets, num_query_pids, num_imgs_query = \
193 | self._process_dir(self.query_dir, self.split_query_json_path, relabel=False)
194 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \
195 | self._process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)
196 |
197 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery
198 | min_num = np.min(num_imgs_per_tracklet)
199 | max_num = np.max(num_imgs_per_tracklet)
200 | avg_num = np.mean(num_imgs_per_tracklet)
201 |
202 | num_total_pids = num_train_pids + num_query_pids
203 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets
204 |
205 | if verbose:
206 | print("=> DukeMTMC-VideoReID loaded")
207 | print("Dataset statistics:")
208 | print(" ------------------------------")
209 | print(" subset | # ids | # tracklets")
210 | print(" ------------------------------")
211 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
212 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
213 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
214 | print(" ------------------------------")
215 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
216 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
217 | print(" ------------------------------")
218 |
219 | self.train = train
220 | self.query = query
221 | self.gallery = gallery
222 |
223 | self.num_train_pids = num_train_pids
224 | self.num_query_pids = num_query_pids
225 | self.num_gallery_pids = num_gallery_pids
226 |
227 | def _download_data(self):
228 | if osp.exists(self.dataset_dir):
229 | print("This dataset has been downloaded.")
230 | return
231 |
232 | print("Creating directory {}".format(self.dataset_dir))
233 | mkdir_if_missing(self.dataset_dir)
234 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
235 |
236 | print("Downloading DukeMTMC-VideoReID dataset")
237 |         # urllib.URLopener is Python 2 only; urlretrieve works on Python 3
238 |         urllib.request.urlretrieve(self.dataset_url, fpath)
239 |
240 | print("Extracting files")
241 | zip_ref = zipfile.ZipFile(fpath, 'r')
242 | zip_ref.extractall(self.dataset_dir)
243 | zip_ref.close()
244 |
245 | def _check_before_run(self):
246 | """Check if all files are available before going deeper"""
247 | if not osp.exists(self.dataset_dir):
248 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
249 | if not osp.exists(self.train_dir):
250 | raise RuntimeError("'{}' is not available".format(self.train_dir))
251 | if not osp.exists(self.query_dir):
252 | raise RuntimeError("'{}' is not available".format(self.query_dir))
253 | if not osp.exists(self.gallery_dir):
254 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
255 |
256 | def _process_dir(self, dir_path, json_path, relabel):
257 | if osp.exists(json_path):
258 | print("=> {} generated before, awesome!".format(json_path))
259 | split = read_json(json_path)
260 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet']
261 |
262 |         print("=> Automatically generating split (might take a while for the first time, have a coffee)")
263 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store
264 | print("Processing {} with {} person identities".format(dir_path, len(pdirs)))
265 |
266 | pid_container = set()
267 | for pdir in pdirs:
268 | pid = int(osp.basename(pdir))
269 | pid_container.add(pid)
270 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
271 |
272 | tracklets = []
273 | num_imgs_per_tracklet = []
274 | for pdir in pdirs:
275 | pid = int(osp.basename(pdir))
276 | if relabel: pid = pid2label[pid]
277 | tdirs = glob.glob(osp.join(pdir, '*'))
278 | for tdir in tdirs:
279 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg'))
280 | num_imgs = len(raw_img_paths)
281 |
282 | if num_imgs < self.min_seq_len:
283 | continue
284 |
285 | num_imgs_per_tracklet.append(num_imgs)
286 | img_paths = []
287 | for img_idx in range(num_imgs):
288 | # some tracklet starts from 0002 instead of 0001
289 | img_idx_name = 'F' + str(img_idx+1).zfill(4)
290 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg'))
291 | if len(res) == 0:
292 | print("Warn: index name {} in {} is missing, jump to next".format(img_idx_name, tdir))
293 | continue
294 | img_paths.append(res[0])
295 | img_name = osp.basename(img_paths[0])
296 | if img_name.find('_') == -1:
297 | # old naming format: 0001C6F0099X30823.jpg
298 | camid = int(img_name[5]) - 1
299 | else:
300 | # new naming format: 0001_C6_F0099_X30823.jpg
301 | camid = int(img_name[6]) - 1
302 | img_paths = tuple(img_paths)
303 | tracklets.append((img_paths, pid, camid))
304 |
305 | num_pids = len(pid_container)
306 | num_tracklets = len(tracklets)
307 |
308 | print("Saving split to {}".format(json_path))
309 | split_dict = {
310 | 'tracklets': tracklets,
311 | 'num_tracklets': num_tracklets,
312 | 'num_pids': num_pids,
313 | 'num_imgs_per_tracklet': num_imgs_per_tracklet,
314 | }
315 | write_json(split_dict, json_path)
316 |
317 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet
318 |
319 |
320 | class iLIDSVID(object):
321 | """
322 | iLIDS-VID
323 |
324 | Reference:
325 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014.
326 |
327 | Dataset statistics:
328 | # identities: 300
329 | # tracklets: 600
330 | # cameras: 2
331 |
332 | Args:
333 | split_id (int): indicates which split to use. There are totally 10 splits.
334 | """
335 | root = '../data/ilids-vid'
336 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar'
337 | data_dir = osp.join(root, 'i-LIDS-VID')
338 | split_dir = osp.join(root, 'train-test people splits')
339 | split_mat_path = osp.join(split_dir, 'train_test_splits_ilidsvid.mat')
340 | split_path = osp.join(root, 'splits.json')
341 | cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1')
342 | cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2')
343 |
344 | def __init__(self, split_id=0):
345 | self._download_data()
346 | self._check_before_run()
347 |
348 | self._prepare_split()
349 | splits = read_json(self.split_path)
350 | if split_id >= len(splits):
351 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1))
352 | split = splits[split_id]
353 | train_dirs, test_dirs = split['train'], split['test']
354 |         print("# train identities: {}, # test identities: {}".format(len(train_dirs), len(test_dirs)))
355 |
356 | train, num_train_tracklets, num_train_pids, num_imgs_train = \
357 | self._process_data(train_dirs, cam1=True, cam2=True)
358 | query, num_query_tracklets, num_query_pids, num_imgs_query = \
359 | self._process_data(test_dirs, cam1=True, cam2=False)
360 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \
361 | self._process_data(test_dirs, cam1=False, cam2=True)
362 |
363 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery
364 | min_num = np.min(num_imgs_per_tracklet)
365 | max_num = np.max(num_imgs_per_tracklet)
366 | avg_num = np.mean(num_imgs_per_tracklet)
367 |
368 | num_total_pids = num_train_pids + num_query_pids
369 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets
370 |
371 | print("=> iLIDS-VID loaded")
372 | print("Dataset statistics:")
373 | print(" ------------------------------")
374 | print(" subset | # ids | # tracklets")
375 | print(" ------------------------------")
376 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
377 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
378 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
379 | print(" ------------------------------")
380 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
381 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
382 | print(" ------------------------------")
383 |
384 | self.train = train
385 | self.query = query
386 | self.gallery = gallery
387 |
388 | self.num_train_pids = num_train_pids
389 | self.num_query_pids = num_query_pids
390 | self.num_gallery_pids = num_gallery_pids
391 |
392 | def _download_data(self):
393 | if osp.exists(self.root):
394 | print("This dataset has been downloaded.")
395 | return
396 |
397 | mkdir_if_missing(self.root)
398 | fpath = osp.join(self.root, osp.basename(self.dataset_url))
399 |
400 | print("Downloading iLIDS-VID dataset")
401 |         # urllib.URLopener is Python 2 only; urlretrieve works on Python 3
402 |         urllib.request.urlretrieve(self.dataset_url, fpath)
403 |
404 | print("Extracting files")
405 | tar = tarfile.open(fpath)
406 | tar.extractall(path=self.root)
407 | tar.close()
408 |
409 | def _check_before_run(self):
410 | """Check if all files are available before going deeper"""
411 | if not osp.exists(self.root):
412 | raise RuntimeError("'{}' is not available".format(self.root))
413 | if not osp.exists(self.data_dir):
414 | raise RuntimeError("'{}' is not available".format(self.data_dir))
415 | if not osp.exists(self.split_dir):
416 | raise RuntimeError("'{}' is not available".format(self.split_dir))
417 |
418 | def _prepare_split(self):
419 | if not osp.exists(self.split_path):
420 | print("Creating splits")
421 | mat_split_data = loadmat(self.split_mat_path)['ls_set']
422 |
423 | num_splits = mat_split_data.shape[0]
424 | num_total_ids = mat_split_data.shape[1]
425 | assert num_splits == 10
426 | assert num_total_ids == 300
427 |             num_ids_each = num_total_ids // 2  # integer division (used as a slice index below)
428 |
429 | # pids in mat_split_data are indices, so we need to transform them
430 | # to real pids
431 | person_cam1_dirs = os.listdir(self.cam_1_path)
432 | person_cam2_dirs = os.listdir(self.cam_2_path)
433 |
434 | # make sure persons in one camera view can be found in the other camera view
435 | assert set(person_cam1_dirs) == set(person_cam2_dirs)
436 |
437 | splits = []
438 | for i_split in range(num_splits):
439 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14.
440 | train_idxs = sorted(list(mat_split_data[i_split,num_ids_each:]))
441 | test_idxs = sorted(list(mat_split_data[i_split,:num_ids_each]))
442 |
443 | train_idxs = [int(i)-1 for i in train_idxs]
444 | test_idxs = [int(i)-1 for i in test_idxs]
445 |
446 | # transform pids to person dir names
447 | train_dirs = [person_cam1_dirs[i] for i in train_idxs]
448 | test_dirs = [person_cam1_dirs[i] for i in test_idxs]
449 |
450 | split = {'train': train_dirs, 'test': test_dirs}
451 | splits.append(split)
452 |
453 |             print("{} splits created in total, following Wang et al. ECCV'14".format(len(splits)))
454 | print("Split file is saved to {}".format(self.split_path))
455 | write_json(splits, self.split_path)
456 |
457 | print("Splits created")
458 |
459 | def _process_data(self, dirnames, cam1=True, cam2=True):
460 | tracklets = []
461 | num_imgs_per_tracklet = []
462 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)}
463 |
464 | for dirname in dirnames:
465 | if cam1:
466 | person_dir = osp.join(self.cam_1_path, dirname)
467 | img_names = glob.glob(osp.join(person_dir, '*.png'))
468 | assert len(img_names) > 0
469 | img_names = tuple(img_names)
470 | pid = dirname2pid[dirname]
471 | tracklets.append((img_names, pid, 0))
472 | num_imgs_per_tracklet.append(len(img_names))
473 |
474 | if cam2:
475 | person_dir = osp.join(self.cam_2_path, dirname)
476 | img_names = glob.glob(osp.join(person_dir, '*.png'))
477 | assert len(img_names) > 0
478 | img_names = tuple(img_names)
479 | pid = dirname2pid[dirname]
480 | tracklets.append((img_names, pid, 1))
481 | num_imgs_per_tracklet.append(len(img_names))
482 |
483 | num_tracklets = len(tracklets)
484 | num_pids = len(dirnames)
485 |
486 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet
487 |
488 | """Create dataset"""
489 |
490 | __factory = {
491 | 'mars': Mars,
492 | 'ilidsvid': iLIDSVID,
493 | 'dukemtmcvidreid': DukeMTMCVidReID,
494 | }
495 |
496 | def get_names():
497 | return __factory.keys()
498 |
499 | def init_dataset(name, *args, **kwargs):
500 | if name not in __factory.keys():
501 | raise KeyError("Unknown dataset: {}".format(name))
502 | return __factory[name](*args, **kwargs)
503 |
504 | if __name__ == '__main__':
505 | # test
506 | #dataset = Market1501()
507 | #dataset = Mars()
508 | dataset = iLIDSVID()
509 |     #dataset = PRID()  # PRID loader is not included in this repository
510 |
--------------------------------------------------------------------------------
/src/eval_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import numpy as np
3 | import copy
4 |
5 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
6 | num_q, num_g = distmat.shape
7 | if num_g < max_rank:
8 | max_rank = num_g
9 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
10 | indices = np.argsort(distmat, axis=1)
11 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
12 |
13 | # compute cmc curve for each query
14 | all_cmc = []
15 | all_AP = []
16 | num_valid_q = 0.
17 | for q_idx in range(num_q):
18 | # get query pid and camid
19 | q_pid = q_pids[q_idx]
20 | q_camid = q_camids[q_idx]
21 |
22 | # remove gallery samples that have the same pid and camid with query
23 | order = indices[q_idx]
24 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
25 | keep = np.invert(remove)
26 |
27 | # compute cmc curve
28 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
29 | if not np.any(orig_cmc):
30 | # this condition is true when query identity does not appear in gallery
31 | continue
32 |
33 | cmc = orig_cmc.cumsum()
34 | cmc[cmc > 1] = 1
35 |
36 | all_cmc.append(cmc[:max_rank])
37 | num_valid_q += 1.
38 |
39 | # compute average precision
40 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
41 | num_rel = orig_cmc.sum()
42 | tmp_cmc = orig_cmc.cumsum()
43 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
44 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
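        # tmp_cmc now holds precision@k only at the ranks k where a correct match occurs;
        # averaging these values over num_rel gives the average precision for this query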
45 | AP = tmp_cmc.sum() / num_rel
46 | all_AP.append(AP)
47 |
48 |     assert num_valid_q > 0, "Error: none of the query identities appear in the gallery"
49 |
50 | all_cmc = np.asarray(all_cmc).astype(np.float32)
51 | all_cmc = all_cmc.sum(0) / num_valid_q
52 | mAP = np.mean(all_AP)
53 |
54 | return all_cmc, mAP
55 |
56 |
57 |
--------------------------------------------------------------------------------
/src/losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 |
7 | """
8 | Shorthands for loss:
9 | - CrossEntropyLabelSmooth: xent
10 | - TripletLoss: htri
11 | - CenterLoss: cent
12 | """
13 | __all__ = ['CrossEntropyLabelSmooth', 'TripletLoss']
14 |
15 |
16 | class CrossEntropyLabelSmooth(nn.Module):
17 | """Cross entropy loss with label smoothing regularizer.
18 |
19 | Reference:
20 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
21 | Equation: y = (1 - epsilon) * y + epsilon / K.
22 |
23 | Args:
24 | num_classes (int): number of classes.
25 | epsilon (float): weight.
26 | """
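    # e.g. with num_classes=4 and epsilon=0.1, a one-hot target [0, 1, 0, 0] becomes
    # [0.025, 0.925, 0.025, 0.025]: (1 - 0.1) * 1 + 0.1 / 4 = 0.925 for the true class
    # and 0.1 / 4 = 0.025 for every other class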
27 |
28 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
29 | super(CrossEntropyLabelSmooth, self).__init__()
30 | self.num_classes = num_classes
31 | self.epsilon = epsilon
32 | self.use_gpu = use_gpu
33 | self.logsoftmax = nn.LogSoftmax(dim=1)
34 |
35 | def forward(self, inputs, targets):
36 | """
37 | Args:
38 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
39 |             targets: ground truth labels with shape (batch_size)
40 | """
41 | log_probs = self.logsoftmax(inputs)
42 | targets = torch.zeros(log_probs.size()).scatter_(
43 | 1, targets.unsqueeze(1).data.cpu(), 1)
44 | if self.use_gpu:
45 | targets = targets.cuda()
46 | targets = Variable(targets, requires_grad=False)
47 | targets = (1 - self.epsilon) * targets + \
48 | self.epsilon / self.num_classes
49 | loss = (- targets * log_probs).mean(0).sum()
50 | return loss
51 |
52 |
53 | class TripletLoss(nn.Module):
54 | """Triplet loss with hard positive/negative mining.
55 |
56 | Reference:
57 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
58 |
59 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
60 |
61 | Args:
62 | margin (float): margin for triplet.
63 | """
64 |
65 | def __init__(self, margin=0.3):
66 | super(TripletLoss, self).__init__()
67 | self.margin = margin
68 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
69 |
70 | def forward(self, inputs, targets):
71 | """
72 | Args:
73 | inputs: feature matrix with shape (batch_size, feat_dim)
74 |             targets: ground truth labels with shape (batch_size)
75 | """
76 | n = inputs.size(0)
77 | # Compute pairwise distance, replace by the official when merged
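        # uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b: the two pow/sum terms give the
        # squared norms, and addmm_ subtracts twice the pairwise inner products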
78 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
79 | dist = dist + dist.t()
80 | dist.addmm_(1, -2, inputs, inputs.t())
81 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
82 | # For each anchor, find the hardest positive and negative
83 | mask = targets.expand(n, n).eq(targets.expand(n, n).t())
84 | dist_ap, dist_an = [], []
85 | for i in range(n):
86 | dist_ap.append(dist[i][mask[i]].max())
87 | dist_an.append(dist[i][mask[i] == 0].min())
88 |
89 | dist_ap = torch.stack(dist_ap)
90 | dist_an = torch.stack(dist_an)
91 |
92 | # Compute ranking hinge loss
93 | y = dist_an.data.new()
94 | y.resize_as_(dist_an.data)
95 | y.fill_(1)
96 | y = Variable(y)
97 | loss = self.ranking_loss(dist_an, dist_ap, y)
98 | return loss
99 |
100 |
101 | if __name__ == '__main__':
102 | pass
103 |
--------------------------------------------------------------------------------
/src/main_video_person_reid.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | import sys
4 | import time
5 | import datetime
6 | import argparse
7 | import os.path as osp
8 | import numpy as np
9 | from tqdm import tqdm
10 | import more_itertools as mit
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | from torch.utils.data import DataLoader
16 | from torch.autograd import Variable
17 | from torch.optim import lr_scheduler
18 |
19 | import data_manager
20 | from video_loader import VideoDataset
21 | import transforms as T
22 | import models
23 | import utils
24 | from losses import CrossEntropyLabelSmooth, TripletLoss
25 | from utils import AverageMeter, Logger, save_checkpoint, get_features
26 | from eval_metrics import evaluate
27 | from samplers import RandomIdentitySampler
28 | from project_utils import *
29 |
30 | parser = argparse.ArgumentParser(
31 | description='Train video model with cross entropy loss')
32 | # Datasets
33 | parser.add_argument('-d', '--dataset', type=str, default='mars',
34 | choices=data_manager.get_names())
35 | parser.add_argument('-j', '--workers', default=0, type=int,
36 |                     help="number of data loading workers (default: 0)")
37 | parser.add_argument('--height', type=int, default=256,
38 | help="height of an image (default: 256)")
39 | parser.add_argument('--width', type=int, default=128,
40 | help="width of an image (default: 128)")
41 | parser.add_argument('--seq-len', type=int, default=4,
42 | help="number of images to sample in a tracklet")
43 | parser.add_argument('--test-num-tracks', type=int, default=16,
44 | help="number of tracklets to pass to GPU during test (to avoid OOM error)")
45 | # Optimization options
46 | parser.add_argument('--max-epoch', default=800, type=int,
47 | help="maximum epochs to run")
48 | parser.add_argument('--start-epoch', default=0, type=int,
49 | help="manual epoch number (useful on restarts)")
50 | parser.add_argument('--data-selection', type=str,
51 | default='random', help="random/evenly")
52 | parser.add_argument('--train-batch', default=32, type=int,
53 | help="train batch size")
54 | parser.add_argument('--test-batch', default=1, type=int, help="has to be 1")
55 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
56 | help="initial learning rate, use 0.0001 for rnn, use 0.0003 for pooling and attention")
57 | parser.add_argument('--stepsize', default=200, type=int,
58 | help="stepsize to decay learning rate (>0 means this is enabled)")
59 | parser.add_argument('--gamma', default=0.1, type=float,
60 | help="learning rate decay")
61 | parser.add_argument('--weight-decay', default=5e-04, type=float,
62 | help="weight decay (default: 5e-04)")
63 | parser.add_argument('--margin', type=float, default=0.3,
64 | help="margin for triplet loss")
65 | parser.add_argument('--num-instances', type=int, default=4,
66 | help="number of instances per identity")
67 |
68 | # Architecture
69 | parser.add_argument('-a', '--arch', type=str, default='se_resnet50_cosam45_tp',
70 | help="se_resnet50_tp, se_resnet50_ta")
71 |
72 | # Miscs
73 | parser.add_argument('--print-freq', type=int,
74 | default=78, help="print frequency")
75 | parser.add_argument('--seed', type=int, default=1, help="manual seed")
76 | parser.add_argument('--pretrained-model', type=str,
77 | default=None, help='need to be set for loading pretrained models')
78 | parser.add_argument('--evaluate', action='store_true', help="evaluation only")
79 | parser.add_argument('--eval-step', type=int, default=200,
80 | help="run evaluation for every N epochs (set to -1 to test after training)")
81 | parser.add_argument('--save-step', type=int, default=50)
82 | parser.add_argument('--save-dir', type=str, default='')
83 | parser.add_argument('--save-prefix', type=str, default='se_resnet50_cosam45_tp')
84 | parser.add_argument('--use-cpu', action='store_true', help="use cpu")
85 | parser.add_argument('--gpu-devices', default='1', type=str,
86 | help='gpu device ids for CUDA_VISIBLE_DEVICES')
87 |
88 | args = parser.parse_args()
89 |
90 |
91 | def main():
92 | args.save_dir = args.arch + '_' + args.save_dir
93 | args.save_prefix = args.arch + '_' + args.save_dir
94 |
95 | torch.manual_seed(args.seed)
96 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
97 | use_gpu = torch.cuda.is_available()
98 | if args.use_cpu:
99 | use_gpu = False
100 |
101 | # append date with save_dir
102 | args.save_dir = '../scratch/' + utils.get_currenttime_prefix() + '_' + \
103 | args.dataset + '_' + args.save_dir
104 | if args.pretrained_model is not None:
105 | args.save_dir = os.path.dirname(args.pretrained_model)
106 |
107 | if not osp.exists(args.save_dir):
108 | os.makedirs(args.save_dir)
109 |
110 | if not args.evaluate:
111 | sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
112 | else:
113 | sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
114 | print("==========\nArgs:{}\n==========".format(args))
115 |
116 | if use_gpu:
117 | print("Currently using GPU {}".format(args.gpu_devices))
118 | cudnn.benchmark = True
119 | torch.cuda.manual_seed_all(args.seed)
120 | else:
121 | print("Currently using CPU (GPU is highly recommended)")
122 |
123 | print("Initializing dataset {}".format(args.dataset))
124 | dataset = data_manager.init_dataset(name=args.dataset)
125 |
126 | transform_train = T.Compose([
127 | T.Random2DTranslation(args.height, args.width),
128 | T.RandomHorizontalFlip(),
129 | T.ToTensor(),
130 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
131 | ])
132 |
133 | transform_test = T.Compose([
134 | T.Resize((args.height, args.width)),
135 | T.ToTensor(),
136 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
137 | ])
138 |
139 | pin_memory = True if use_gpu else False
140 |
141 | trainloader = DataLoader(
142 | VideoDataset(dataset.train, seq_len=args.seq_len,
143 | sample=args.data_selection, transform=transform_train),
144 | sampler=RandomIdentitySampler(
145 | dataset.train, num_instances=args.num_instances),
146 | batch_size=args.train_batch, num_workers=args.workers,
147 | pin_memory=pin_memory, drop_last=True,
148 | )
149 |
150 | queryloader = DataLoader(
151 | VideoDataset(dataset.query, seq_len=args.seq_len,
152 | sample='dense', transform=transform_test),
153 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
154 | pin_memory=pin_memory, drop_last=False,
155 | )
156 |
157 | galleryloader = DataLoader(
158 | VideoDataset(dataset.gallery, seq_len=args.seq_len,
159 | sample='dense', transform=transform_test),
160 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
161 | pin_memory=pin_memory, drop_last=False,
162 | )
163 |
164 | print("Initializing model: {}".format(args.arch))
165 | model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, seq_len=args.seq_len)
166 |
167 | # pretrained model loading
168 | if args.pretrained_model is not None:
169 | if not os.path.exists(args.pretrained_model):
170 | raise IOError("Can't find pretrained model: {}".format(
171 | args.pretrained_model))
172 | print("Loading checkpoint from '{}'".format(args.pretrained_model))
173 | pretrained_state = torch.load(args.pretrained_model)['state_dict']
174 | print(len(pretrained_state), ' keys in pretrained model')
175 |
176 | current_model_state = model.state_dict()
177 | pretrained_state = {key: val
178 | for key, val in pretrained_state.items()
179 | if key in current_model_state and val.size() == current_model_state[key].size()}
180 |
181 | print(len(pretrained_state),
182 | ' keys in pretrained model are available in current model')
183 | current_model_state.update(pretrained_state)
184 | model.load_state_dict(current_model_state)
185 |
186 | print("Model size: {:.5f}M".format(sum(p.numel()
187 | for p in model.parameters())/1000000.0))
188 |
189 | if use_gpu:
190 | model = nn.DataParallel(model).cuda()
191 |
192 | criterion_xent = CrossEntropyLabelSmooth(
193 | num_classes=dataset.num_train_pids, use_gpu=use_gpu)
194 | criterion_htri = TripletLoss(margin=args.margin)
195 |
196 | optimizer = torch.optim.Adam(
197 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
198 | if args.stepsize > 0:
199 | scheduler = lr_scheduler.StepLR(
200 | optimizer, step_size=args.stepsize, gamma=args.gamma)
201 | start_epoch = args.start_epoch
202 |
203 | if args.evaluate:
204 | print("Evaluate only")
205 | test(model, queryloader, galleryloader, use_gpu)
206 | return
207 |
208 | start_time = time.time()
209 | best_rank1 = -np.inf
210 |
211 | is_first_time = True
212 | for epoch in range(start_epoch, args.max_epoch):
213 | print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
214 |
215 | train(model, criterion_xent, criterion_htri,
216 | optimizer, trainloader, use_gpu)
217 |
218 | if args.stepsize > 0:
219 | scheduler.step()
220 |
221 | rank1 = 'NA'
222 | is_best = False
223 |
224 | if args.eval_step > 0 and (epoch+1) % args.eval_step == 0 or (epoch+1) == args.max_epoch:
225 | print("==> Test")
226 | rank1 = test(model, queryloader, galleryloader, use_gpu)
227 | is_best = rank1 > best_rank1
228 | if is_best:
229 | best_rank1 = rank1
230 |
231 | # save the model as required
232 | if (epoch+1) % args.save_step == 0:
233 | if use_gpu:
234 | state_dict = model.module.state_dict()
235 | else:
236 | state_dict = model.state_dict()
237 |
238 | save_checkpoint({
239 | 'state_dict': state_dict,
240 | 'rank1': rank1,
241 | 'epoch': epoch,
242 | }, is_best, osp.join(args.save_dir, args.save_prefix + 'checkpoint_ep' + str(epoch+1) + '.pth.tar'))
243 |
244 | is_first_time = False
245 | if not is_first_time:
246 | utils.disable_all_print_once()
247 |
248 | elapsed = round(time.time() - start_time)
249 | elapsed = str(datetime.timedelta(seconds=elapsed))
250 | print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
251 |
252 |
253 | def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
254 | model.train()
255 | losses = AverageMeter()
256 | import time
257 | torch.backends.cudnn.benchmark = True
258 |
259 | start = time.time()
260 | for batch_idx, (imgs, pids, _) in enumerate(tqdm(trainloader)):
261 | if use_gpu:
262 | imgs, pids = imgs.cuda(), pids.cuda()
263 |
264 | imgs, pids = Variable(imgs), Variable(pids)
265 |
266 | outputs, features = model(imgs)
267 |
268 | # combine hard triplet loss with cross entropy loss
269 | xent_loss = criterion_xent(outputs, pids)
270 | htri_loss = criterion_htri(features, pids)
271 | loss = xent_loss + htri_loss
272 |
273 | optimizer.zero_grad()
274 | loss.backward()
275 | optimizer.step()
276 |
277 | losses.update(loss.item(), pids.size(0))
278 | if (batch_idx+1) % args.print_freq == 0:
279 | print("Batch {}/{}\t Loss {:.6f} ({:.6f})".format(batch_idx +
280 | 1, len(trainloader), losses.val, losses.avg))
281 |
282 |
283 | def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
284 | model.eval()
285 |
286 | qf, q_pids, q_camids = [], [], []
287 | print('extracting query feats')
288 | for batch_idx, (imgs, pids, camids) in enumerate(tqdm(queryloader)):
289 | if use_gpu:
290 | imgs = imgs.cuda()
291 |
292 | with torch.no_grad():
293 | #imgs = Variable(imgs, volatile=True)
294 | # b=1, n=number of clips, s=16
295 | b, n, s, c, h, w = imgs.size()
296 | assert(b == 1)
297 | imgs = imgs.view(b*n, s, c, h, w)
298 | features = get_features(model, imgs, args.test_num_tracks)
299 | features = torch.mean(features, 0)
300 | features = features.data.cpu()
301 | qf.append(features)
302 | q_pids.extend(pids)
303 | q_camids.extend(camids)
304 | torch.cuda.empty_cache()
305 |
306 | qf = torch.stack(qf)
307 | q_pids = np.asarray(q_pids)
308 | q_camids = np.asarray(q_camids)
309 |
310 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
311 |
312 | gf, g_pids, g_camids = [], [], []
313 | print('extracting gallery feats')
314 | for batch_idx, (imgs, pids, camids) in enumerate(tqdm(galleryloader)):
315 | if use_gpu:
316 | imgs = imgs.cuda()
317 |
318 | with torch.no_grad():
319 | #imgs = Variable(imgs, volatile=True)
320 | b, n, s, c, h, w = imgs.size()
321 | imgs = imgs.view(b*n, s, c, h, w)
322 | assert(b == 1)
323 | # handle chunked data
324 | features = get_features(model, imgs, args.test_num_tracks)
325 | features = torch.mean(features, 0)
326 | torch.cuda.empty_cache()
327 |
328 | features = features.data.cpu()
329 | gf.append(features)
330 | g_pids.extend(pids)
331 | g_camids.extend(camids)
332 | gf = torch.stack(gf)
333 | g_pids = np.asarray(g_pids)
334 | g_camids = np.asarray(g_camids)
335 |
336 | print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
337 | print("Computing distance matrix")
338 |
339 | m, n = qf.size(0), gf.size(0)
340 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
341 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
342 |     distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
343 | distmat = distmat.numpy()
344 |
345 | print("Computing CMC and mAP")
346 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
347 |
348 | print("Results ----------")
349 | print("mAP: {:.1%}".format(mAP))
350 | print("CMC curve")
351 | for r in ranks:
352 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
353 | print("------------------")
354 |
355 | # re-ranking
356 | from person_re_ranking.python_version.re_ranking_feature import re_ranking
357 | rerank_distmat = re_ranking(
358 | qf.numpy(), gf.numpy(), k1=20, k2=6, lambda_value=0.3)
359 | print("Computing CMC and mAP for re-ranking")
360 | cmc, mAP = evaluate(rerank_distmat, q_pids, g_pids, q_camids, g_camids)
361 | print("Results ----------")
362 | print("mAP: {:.1%}".format(mAP))
363 | print("CMC curve")
364 | for r in ranks:
365 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
366 | print("------------------")
367 |
368 | return cmc[0]
369 |
370 |
371 | if __name__ == '__main__':
372 | main()
373 |
--------------------------------------------------------------------------------
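
A minimal, self-contained sketch (not part of the repository) of the distance-matrix step in test() above: pairwise squared Euclidean distances between query and gallery features are computed via the expansion ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g, shown here on dummy tensors.

import torch

qf = torch.randn(5, 2048)    # dummy query features (m x d)
gf = torch.randn(7, 2048)    # dummy gallery features (n x d)
m, n = qf.size(0), gf.size(0)

# ||q||^2 + ||g||^2 term, broadcast to an m x n matrix
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
          torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
# subtract 2 * qf @ gf^T in place
distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
print(distmat.shape)         # torch.Size([5, 7])
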
/src/models/ResNet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import os
5 | import sys
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from torch.autograd import Variable
9 | import torchvision
10 | import torchvision.models as models
11 | import utils
12 | from .aggregation_layers import AGGREGATION
13 | from .cosam import COSEG_ATTENTION
14 |
15 |
16 | def get_ResNet(net_type):
17 | if net_type == "resnet50":
18 | model = models.resnet50(pretrained=True)
19 | model = nn.Sequential(*list(model.children())[:-2])
20 |     elif net_type == "resnet101":
21 |         model = models.resnet101(pretrained=True)
22 | model = nn.Sequential(*list(model.children())[:-2])
23 | else:
24 | assert False, "unknown ResNet type : " + net_type
25 |
26 | return model
27 |
28 |
29 | class Identity(nn.Module):
30 | def __init__(self, *args, **kwargs):
31 | super().__init__()
32 | print("instantiating " + self.__class__.__name__)
33 |
34 | def forward(self, x, b, t):
35 | return [x, None, None]
36 |
37 |
38 | class CosegAttention(nn.Module):
39 | def __init__(self, attention_types, num_feat_maps, h_w, t):
40 | super().__init__()
41 | print("instantiating " + self.__class__.__name__)
42 | self.attention_modules = nn.ModuleList()
43 |
44 | for i, attention_type in enumerate(attention_types):
45 | if attention_type in COSEG_ATTENTION:
46 | self.attention_modules.append(
47 | COSEG_ATTENTION[attention_type](
48 | num_feat_maps[i], h_w=h_w[i], t=t)
49 | )
50 | else:
51 | assert False, "unknown attention type " + attention_type
52 |
53 | def forward(self, x, i, b, t):
54 | return self.attention_modules[i](x, b, t)
55 |
56 |
57 | class ResNet(nn.Module):
58 | def __init__(
59 | self,
60 | num_classes,
61 | net_type="resnet50",
62 | attention_types=["NONE", "NONE", "NONE", "NONE"],
63 | aggregation_type="tp",
64 | seq_len=4,
65 | **kwargs
66 | ):
67 | super(ResNet, self).__init__()
68 | print(
69 | "instantiating "
70 | + self.__class__.__name__
71 |             + " net type "
72 | + net_type
73 | + " from "
74 | + __file__
75 | )
76 | print("attention type", attention_types)
77 |
78 | # base network instantiation
79 | self.base = get_ResNet(net_type=net_type)
80 | self.feat_dim = 2048
81 |
82 | # attention modules
83 | self.num_feat_maps = [256, 512, 1024, 2048]
84 |         # spatial (h, w) per stage for 256 x 128 inputs; the last entry depends on the aggregation type (set below)
85 | if aggregation_type == "ta":
86 | self.h_w = [(64, 32), (32, 16), (16, 8), (8, 4)]
87 | else:
88 | utils.set_stride(self.base[7], 1)
89 | self.h_w = [(64, 32), (32, 16), (16, 8), (16, 8)]
90 |
91 | print(self.h_w)
92 | self.attention_modules = CosegAttention(
93 | attention_types, num_feat_maps=self.num_feat_maps, h_w=self.h_w, t=seq_len
94 | )
95 |
96 | # aggregation module
97 | self.aggregation = AGGREGATION[aggregation_type](
98 | self.feat_dim, h_w=self.h_w[-1], t=seq_len
99 | )
100 |
101 | # classifier
102 | self.classifier = nn.Linear(self.aggregation.feat_dim, num_classes)
103 |
104 | def base_layer0(self, x):
105 | x = self.base[0](x)
106 | x = self.base[1](x)
107 | x = self.base[2](x)
108 | x = self.base[3](x)
109 | x = self.base[4](x)
110 | return x
111 |
112 | def base_layer1(self, x):
113 | return self.base[5](x)
114 |
115 | def base_layer2(self, x):
116 | return self.base[6](x)
117 |
118 | def base_layer3(self, x):
119 | return self.base[7](x)
120 |
121 | def extract_features(self, x, b, t):
122 | features_in = x
123 | attentions = []
124 |
125 | for index in range(len(self.num_feat_maps)):
126 | features_before_attention = getattr(self, "base_layer" + str(index))(
127 | features_in
128 | )
129 | features_out, channelwise, spatialwise = self.attention_modules(
130 | features_before_attention, index, b, t
131 | )
132 | features_in = features_out
133 |
134 | # note down the attentions
135 | attentions.append(
136 | (features_before_attention, features_out, channelwise, spatialwise)
137 | )
138 |
139 | return features_out, attentions
140 |
141 | def return_values(
142 | self, features, logits, attentions, is_training, return_attentions
143 | ):
144 | buffer = []
145 | if not is_training:
146 | buffer.append(features)
147 | if not return_attentions:
148 | return features
149 | else:
150 | buffer.append(logits)
151 | buffer.append(features)
152 |
153 | if return_attentions:
154 | buffer.append(attentions)
155 |
156 | return buffer
157 |
158 | def forward(self, x, return_attentions=False):
159 | b = x.size(0)
160 | t = x.size(1)
161 | x = x.view(b * t, x.size(2), x.size(3), x.size(4))
162 |
163 | final_spatial_features, attentions = self.extract_features(x, b, t)
164 | f = self.aggregation(final_spatial_features, b, t)
165 |
166 | if not self.training:
167 | return self.return_values(
168 | f, None, attentions, self.training, return_attentions
169 | )
170 |
171 | y = self.classifier(f)
172 | return self.return_values(f, y, attentions, self.training, return_attentions)
173 |
174 |
175 | # all mutual correlation attention models
176 | def ResNet50_COSAM45_TP(num_classes, **kwargs):
177 | return ResNet(
178 | num_classes,
179 | net_type="resnet50",
180 | aggregation_type="tp",
181 | attention_types=["NONE", "NONE",
182 | "COSAM", "COSAM"],
183 | **kwargs
184 | )
185 |
186 |
187 | def ResNet50_COSAM45_TA(num_classes, **kwargs):
188 | return ResNet(
189 | num_classes,
190 | net_type="resnet50",
191 | aggregation_type="ta",
192 | attention_types=["NONE", "NONE",
193 | "COSAM", "COSAM"],
194 | **kwargs
195 | )
196 |
197 |
198 | def ResNet50_COSAM45_RNN(num_classes, **kwargs):
199 | return ResNet(
200 | num_classes,
201 | net_type="resnet50",
202 | aggregation_type="rnn",
203 | attention_types=["NONE", "NONE",
204 | "COSAM", "COSAM"],
205 | **kwargs
206 | )
207 |
208 |
--------------------------------------------------------------------------------
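
A hedged smoke test (not part of the repository) for the COSAM-augmented ResNet-50 defined above. It assumes src/ is on PYTHONPATH, 256 x 128 input frames (consistent with the hard-coded h_w table), and that torchvision can fetch the ImageNet-pretrained ResNet-50 weights; num_classes=625 is just a placeholder.

import torch
from models.ResNet import ResNet50_COSAM45_TP

model = ResNet50_COSAM45_TP(num_classes=625, seq_len=4)
clips = torch.randn(2, 4, 3, 256, 128)      # (batch, seq_len, C, H, W)

model.train()
logits, feats = model(clips)                # training mode returns [logits, clip features]
print(logits.shape, feats.shape)            # torch.Size([2, 625]) torch.Size([2, 2048])

model.eval()
with torch.no_grad():
    feats = model(clips)                    # eval mode returns clip features only
print(feats.shape)                          # torch.Size([2, 2048])
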
/src/models/SE_ResNet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import os
5 | import sys
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from torch.autograd import Variable
9 | import torchvision
10 | import utils as utils
11 |
12 | this_path = os.path.split(__file__)[0]
13 | sys.path.append(this_path)
14 | from .aggregation_layers import AGGREGATION
15 | from .cosam import COSEG_ATTENTION
16 | import senet
17 |
18 | def get_SENet(net_type):
19 | if net_type == "senet50":
20 | model = senet.se_resnet50(pretrained=True)
21 | elif net_type == "senet101":
22 | model = senet.se_resnet101(pretrained=True)
23 | else:
24 | assert False, "unknown SE ResNet type : " + net_type
25 |
26 | return model
27 |
28 |
29 | class CosegAttention(nn.Module):
30 | def __init__(self, attention_types, num_feat_maps, h_w, t):
31 | super().__init__()
32 | print("instantiating " + self.__class__.__name__)
33 | self.attention_modules = nn.ModuleList()
34 |
35 | for i, attention_type in enumerate(attention_types):
36 | if attention_type in COSEG_ATTENTION:
37 | self.attention_modules.append(
38 | COSEG_ATTENTION[attention_type](
39 | num_feat_maps[i], h_w=h_w[i], t=t)
40 | )
41 | else:
42 | assert False, "unknown attention type " + attention_type
43 |
44 | def forward(self, x, i, b, t):
45 | return self.attention_modules[i](x, b, t)
46 |
47 |
48 | class SE_ResNet(nn.Module):
49 | def __init__(
50 | self,
51 | num_classes,
52 | net_type="senet50",
53 | attention_types=["NONE", "NONE", "NONE", "NONE", "NONE"],
54 | aggregation_type="tp",
55 | seq_len=4,
56 | is_baseline=False,
57 | **kwargs
58 | ):
59 | super(SE_ResNet, self).__init__()
60 | print(
61 | "instantiating "
62 | + self.__class__.__name__
63 |             + " net type "
64 | + net_type
65 | + " from "
66 | + __file__
67 | )
68 | print("attention type", attention_types)
69 |
70 | # base network instantiation
71 | self.base = get_SENet(net_type=net_type)
72 | self.feat_dim = self.base.feature_dim
73 |
74 | # attention modules
75 | self.num_feat_maps = [64, 256, 512, 1024, 2048]
76 | self.h_w = [(64, 32), (64, 32), (32, 16), (16, 8), (8, 4)]
77 |
78 | # allow reducing spatial dimension for Temporal attention (ta) to keep the params at a manageable number
79 | # according to the baseline paper
80 | if aggregation_type == "ta":
81 | self.h_w = [(64, 32), (64, 32), (32, 16), (16, 8), (8, 4)]
82 | else:
83 | utils.set_stride(self.base.layer4, 1)
84 | self.h_w = [(64, 32), (64, 32), (32, 16), (16, 8), (16, 8)]
85 |
86 | print(self.h_w)
87 |
88 | self.attention_modules = CosegAttention(
89 | attention_types, num_feat_maps=self.num_feat_maps, h_w=self.h_w, t=seq_len
90 | )
91 |
92 | # aggregation module
93 | self.aggregation = AGGREGATION[aggregation_type](
94 | self.feat_dim, h_w=self.h_w[-1], t=seq_len
95 | )
96 |
97 | # classifier
98 | self.classifier = nn.Linear(self.aggregation.feat_dim, num_classes)
99 |
100 | def extract_features(self, x, b, t):
101 | features_in = x
102 | attentions = []
103 |
104 | for index in range(5):
105 | features_before_attention = getattr(self.base, "layer" + str(index))(
106 | features_in
107 | )
108 | features_out, channelwise, spatialwise = self.attention_modules(
109 | features_before_attention, index, b, t
110 | )
111 | # print(features_out.shape)
112 | features_in = features_out
113 |
114 | # note down the attentions
115 | attentions.append(
116 | (features_before_attention, features_out, channelwise, spatialwise)
117 | )
118 |
119 | return features_out, attentions
120 |
121 | def return_values(
122 | self, features, logits, attentions, is_training, return_attentions
123 | ):
124 | buffer = []
125 | if not is_training:
126 | buffer.append(features)
127 | if not return_attentions:
128 | return features
129 | else:
130 | buffer.append(logits)
131 | buffer.append(features)
132 |
133 | if return_attentions:
134 | buffer.append(attentions)
135 |
136 | return buffer
137 |
138 | def forward(self, x, return_attentions=False):
139 | b = x.size(0)
140 | t = x.size(1)
141 | x = x.view(b * t, x.size(2), x.size(3), x.size(4))
142 |
143 | final_spatial_features, attentions = self.extract_features(x, b, t)
144 | f = self.aggregation(final_spatial_features, b, t)
145 |
146 | if not self.training:
147 | return self.return_values(
148 | f, None, attentions, self.training, return_attentions
149 | )
150 |
151 | y = self.classifier(f)
152 | return self.return_values(f, y, attentions, self.training, return_attentions)
153 |
154 |
155 |
156 | # all attention models
157 | def SE_ResNet50_COSAM45_TP(num_classes, **kwargs):
158 | return SE_ResNet(
159 | num_classes,
160 | net_type="senet50",
161 | aggregation_type="tp",
162 | attention_types=[
163 | "NONE",
164 | "NONE",
165 | "NONE",
166 | "COSAM",
167 | "COSAM",
168 | ],
169 | **kwargs
170 | )
171 |
172 |
173 | def SE_ResNet50_COSAM45_TA(num_classes, **kwargs):
174 | return SE_ResNet(
175 | num_classes,
176 | net_type="senet50",
177 | aggregation_type="ta",
178 | attention_types=[
179 | "NONE",
180 | "NONE",
181 | "NONE",
182 | "COSAM",
183 | "COSAM",
184 | ],
185 | **kwargs
186 | )
187 |
188 |
189 | def SE_ResNet50_COSAM45_RNN(num_classes, **kwargs):
190 | return SE_ResNet(
191 | num_classes,
192 | net_type="senet50",
193 | aggregation_type="rnn",
194 | attention_types=[
195 | "NONE",
196 | "NONE",
197 | "NONE",
198 | "COSAM",
199 | "COSAM",
200 | ],
201 | **kwargs
202 | )
203 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .ResNet import *
4 | from .SE_ResNet import *
5 |
6 | __factory = {
7 | # ResNet50 network
8 | "resnet50_cosam45_tp": ResNet50_COSAM45_TP,
9 | "resnet50_cosam45_ta": ResNet50_COSAM45_TA,
10 | "resnet50_cosam45_rnn": ResNet50_COSAM45_RNN,
11 |
12 |     # Squeeze-and-Excitation (SE-ResNet50) network
13 | "se_resnet50_cosam45_tp": SE_ResNet50_COSAM45_TP,
14 | "se_resnet50_cosam45_ta": SE_ResNet50_COSAM45_TA,
15 | "se_resnet50_cosam45_rnn": SE_ResNet50_COSAM45_RNN,
16 |
17 | }
18 |
19 |
20 | def get_names():
21 | return __factory.keys()
22 |
23 |
24 | def init_model(name, *args, **kwargs):
25 | if name not in __factory.keys():
26 | raise KeyError("Unknown model: {}".format(name))
27 | return __factory[name](*args, **kwargs)
28 |
--------------------------------------------------------------------------------
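
A short, hedged sketch (not part of the repository) of how the factory above is meant to be used: an architecture is selected by name and init_model forwards the remaining keyword arguments to the constructor. Constructing a model downloads the ImageNet-pretrained backbone on first use.

from models import get_names, init_model

print(sorted(get_names()))   # all registered architecture names
model = init_model("se_resnet50_cosam45_tp", num_classes=625, seq_len=4)
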
/src/models/aggregation_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class AggregationTP(nn.Module):
7 | def __init__(self, feat_dim, *args, **kwargs):
8 | super().__init__()
9 | print("instantiating " + self.__class__.__name__)
10 |
11 | self.feat_dim = feat_dim
12 |
13 | def forward(self, x, b, t):
14 | x = F.avg_pool2d(x, x.size()[2:])
15 | x = x.view(b, t, -1)
16 | x = x.permute(0, 2, 1)
17 | f = F.avg_pool1d(x, t)
18 | f = f.view(b, self.feat_dim)
19 | return f
20 |
21 |
22 | class AggregationTA(nn.Module):
23 | def __init__(self, feat_dim, *args, **kwargs):
24 | super().__init__()
25 | print("instantiating " + self.__class__.__name__)
26 |
27 | self.feat_dim = feat_dim
28 | self.middle_dim = 256
29 | self.attention_conv = nn.Conv2d(
30 | self.feat_dim, self.middle_dim, [8, 4]
31 |         )  # kernel (8, 4) corresponds to a 256 x 128 input image
32 | self.attention_tconv = nn.Conv1d(self.middle_dim, 1, 3, padding=1)
33 |
34 | def forward(self, x, b, t):
35 |
36 | # spatial attention
37 | a = F.relu(self.attention_conv(x))
38 |
39 | # arrange into batch temporal view
40 | a = a.view(b, t, self.middle_dim)
41 |
42 | # temporal attention
43 | a = a.permute(0, 2, 1)
44 | a = F.relu(self.attention_tconv(a))
45 | a = a.view(b, t)
46 | a = F.softmax(a, dim=1)
47 |
48 | # global avg pooling of conv features
49 | x = F.avg_pool2d(x, x.size()[2:])
50 |
51 | # apply temporal attention
52 | x = x.view(b, t, -1)
53 | a = torch.unsqueeze(a, -1)
54 | a = a.expand(b, t, self.feat_dim)
55 | att_x = torch.mul(x, a)
56 | att_x = torch.sum(att_x, 1)
57 | f = att_x.view(b, self.feat_dim)
58 |
59 | return f
60 |
61 |
62 | class AggregationRNN(nn.Module):
63 | def __init__(self, feat_dim, *args, **kwargs):
64 | super().__init__()
65 | print("instantiating " + self.__class__.__name__)
66 |
67 | self.hidden_dim = 512
68 | self.lstm = nn.LSTM(
69 | input_size=feat_dim,
70 | hidden_size=self.hidden_dim,
71 | num_layers=1,
72 | batch_first=True,
73 | )
74 | self.feat_dim = self.hidden_dim
75 |
76 | def forward(self, x, b, t):
77 | x = F.avg_pool2d(x, x.size()[2:])
78 | x = x.view(b, t, -1)
79 |
80 | # apply LSTM
81 | output, (h_n, c_n) = self.lstm(x)
82 | output = output.permute(0, 2, 1)
83 | f = F.avg_pool1d(output, t)
84 | f = f.view(b, self.hidden_dim)
85 |
86 | return f
87 |
88 |
89 | AGGREGATION = {
90 | "tp": AggregationTP,
91 | "ta": AggregationTA,
92 | "rnn": AggregationRNN,
93 | }
94 |
--------------------------------------------------------------------------------
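
A hedged example (not part of the repository) of the simplest aggregator above, AggregationTP: per-frame feature maps are average-pooled spatially and then averaged over the t frames of a clip. It assumes src/ is on PYTHONPATH.

import torch
from models.aggregation_layers import AggregationTP

b, t, c, h, w = 2, 4, 2048, 16, 8
frame_feats = torch.randn(b * t, c, h, w)   # frames of all clips stacked along dim 0
agg = AggregationTP(feat_dim=c)
clip_feat = agg(frame_feats, b, t)          # spatial avg-pool, then average over t frames
print(clip_feat.shape)                      # torch.Size([2, 2048])
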
/src/models/cosam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import utils as utils
4 | import torch.nn.functional as F
5 |
6 | class Identity(nn.Module):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__()
9 | print("instantiating " + self.__class__.__name__)
10 |
11 | def forward(self, x, b, t):
12 | return [x, None, None]
13 |
14 |
15 | class COSAM(nn.Module):
16 | def __init__(self, in_channels, t, h_w, *args):
17 | super().__init__()
18 | print(
19 | "instantiating "
20 | + self.__class__.__name__
21 | + " with in-channels: "
22 | + str(in_channels)
23 | + ", t = "
24 | + str(t)
25 | + ", hw = "
26 | + str(h_w)
27 | )
28 |
29 | self.eps = 1e-4
30 | self.h_w = h_w
31 | self.t = t
32 | self.mid_channels = 256
33 | if in_channels <= 256:
34 | self.mid_channels = 64
35 |
36 | self.dim_reduction = nn.Sequential(
37 | nn.Conv2d(in_channels, self.mid_channels, kernel_size=1),
38 | nn.BatchNorm2d(self.mid_channels),
39 | nn.ReLU(),
40 | )
41 |
42 | self.spatial_mask_summary = nn.Sequential(
43 | nn.Conv2d((t - 1) * h_w[0] * h_w[1], 64, kernel_size=1),
44 | nn.BatchNorm2d(64),
45 | nn.ReLU(),
46 | nn.Conv2d(64, 1, kernel_size=1),
47 | )
48 |
49 | self.channelwise_attention = nn.Sequential(
50 | nn.Linear(in_channels, self.mid_channels),
51 | nn.Tanh(),
52 | nn.Linear(self.mid_channels, in_channels),
53 | nn.Sigmoid(),
54 | )
55 |
56 | def get_selectable_indices(self, t):
57 | init_list = list(range(t))
58 | index_list = []
59 | for i in range(t):
60 | list_instance = list(init_list)
61 | list_instance.remove(i)
62 | index_list.append(list_instance)
63 |
64 | return index_list
65 |
66 | def get_channelwise_attention(self, feat_maps, b, t):
67 | num_imgs, num_channels, h, w = feat_maps.shape
68 |
69 | # perform global average pooling
70 | channel_avg_pool = F.adaptive_avg_pool2d(feat_maps, output_size=1)
71 | # pass the global average pooled features through the fully connected network with sigmoid activation
72 | channelwise_attention = self.channelwise_attention(channel_avg_pool.view(num_imgs, -1))
73 |
74 | # perform group attention
75 | # groupify the attentions
76 | idwise_channelattention = channelwise_attention.view(b, t, -1)
77 |
78 | # take the mean of attention to attend common channels between frames
79 | group_attention = torch.mean(idwise_channelattention, dim=1, keepdim=True).expand_as(idwise_channelattention)
80 | channelwise_attention = group_attention.contiguous().view(num_imgs, num_channels, 1, 1)
81 |
82 | return channelwise_attention
83 |
84 | def get_spatial_attention(self, feat_maps, b, t):
85 | total, c, h, w = feat_maps.shape
86 | dim_reduced_featuremaps = self.dim_reduction(feat_maps) # #frames x C x H x W
87 |
88 | # resize the feature maps for temporal processing
89 | identitywise_maps = dim_reduced_featuremaps.view(b, t, dim_reduced_featuremaps.shape[1], h, w)
90 |
91 | # get the combination of frame indices
92 | index_list = self.get_selectable_indices(t)
93 |
94 | # select the other images within same id
95 | other_selected_imgs = identitywise_maps[:, index_list]
96 | other_selected_imgs = other_selected_imgs.view(-1, t - 1, dim_reduced_featuremaps.shape[1], h, w) # #frames x t-1 x C x H x W
97 |
98 |         # move the channel (descriptor) dim ahead of the frame and spatial dims
99 | other_selected_imgs = other_selected_imgs.permute((0, 2, 1, 3, 4)) # #frames x C x t-1 x H x W
100 |
101 | # prepare two matrices for multiplication
102 | dim_reduced_featuremaps = dim_reduced_featuremaps.view(total, self.mid_channels, -1) # #frames x C x (H * W)
103 | dim_reduced_featuremaps = dim_reduced_featuremaps.permute((0, 2, 1)) # frames x (H * W) x C
104 | other_selected_imgs = other_selected_imgs.contiguous().view(total, self.mid_channels, -1) # #frames x C x (t-1 * H * W)
105 |
106 | # mean subtract and divide by std
107 | dim_reduced_featuremaps = dim_reduced_featuremaps - torch.mean(dim_reduced_featuremaps, dim=2, keepdim=True)
108 | dim_reduced_featuremaps = dim_reduced_featuremaps / (torch.std(dim_reduced_featuremaps, dim=2, keepdim=True) + self.eps)
109 | other_selected_imgs = other_selected_imgs - torch.mean(other_selected_imgs, dim=1, keepdim=True)
110 | other_selected_imgs = other_selected_imgs / (torch.std(other_selected_imgs, dim=1, keepdim=True) + self.eps)
111 |
112 | mutual_correlation = (torch.bmm(dim_reduced_featuremaps, other_selected_imgs)
113 | / other_selected_imgs.shape[1]) # #frames x (HW) x (t-1 * H * W)
114 |
115 | mutual_correlation = mutual_correlation.permute(0, 2, 1) # #frames x (t-1 * H * W) x (HW)
116 | mutual_correlation = mutual_correlation.view(total, -1, h, w) # #frames x (t-1 * H * W) x H x W
117 | mutual_correlation_mask = self.spatial_mask_summary(mutual_correlation).sigmoid() # #frames x 1 x H x W
118 |
119 | return mutual_correlation_mask
120 |
121 | def forward(self, feat_maps, b, t):
122 | # get the spatial attention mask
123 | mutualcorr_spatial_mask = self.get_spatial_attention(
124 | feat_maps=feat_maps, b=b, t=t
125 | )
126 | attended_out = torch.mul(feat_maps, mutualcorr_spatial_mask)
127 |
128 | # channel-wise attention
129 | channelwise_mask = self.get_channelwise_attention(
130 | feat_maps=attended_out, b=b, t=t
131 | )
132 | attended_out = attended_out + torch.mul(attended_out, channelwise_mask)
133 |
134 | return attended_out, channelwise_mask, mutualcorr_spatial_mask
135 |
136 |
137 | COSEG_ATTENTION = {
138 | "NONE": Identity,
139 | "COSAM": COSAM,
140 | }
141 |
--------------------------------------------------------------------------------
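
A hedged shape check (not part of the repository) for a single COSAM block on dummy mid-level features (1024 channels, 16 x 8 spatial, t = 4 frames per tracklet); src/ is assumed to be on PYTHONPATH.

import torch
from models.cosam import COSAM

b, t, c, h, w = 2, 4, 1024, 16, 8
block = COSAM(c, t=t, h_w=(h, w))
frame_feats = torch.randn(b * t, c, h, w)

out, channel_mask, spatial_mask = block(frame_feats, b, t)
print(out.shape)            # torch.Size([8, 1024, 16, 8])  attended feature maps
print(channel_mask.shape)   # torch.Size([8, 1024, 1, 1])   group channel-wise attention
print(spatial_mask.shape)   # torch.Size([8, 1, 16, 8])     mutual-correlation spatial mask
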
/src/models/senet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | from collections import OrderedDict
5 | import math
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils import model_zoo
10 | from torch.nn import functional as F
11 | import torchvision
12 |
13 |
14 | """
15 | Code imported from https://github.com/Cadene/pretrained-models.pytorch
16 | """
17 |
18 |
19 | __all__ = ['senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', 'se_resnext101_32x4d',
20 | 'se_resnet50_fc512']
21 |
22 |
23 | pretrained_settings = {
24 | 'senet154': {
25 | 'imagenet': {
26 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
27 | 'input_space': 'RGB',
28 | 'input_size': [3, 224, 224],
29 | 'input_range': [0, 1],
30 | 'mean': [0.485, 0.456, 0.406],
31 | 'std': [0.229, 0.224, 0.225],
32 | 'num_classes': 1000
33 | }
34 | },
35 | 'se_resnet50': {
36 | 'imagenet': {
37 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
38 | 'input_space': 'RGB',
39 | 'input_size': [3, 224, 224],
40 | 'input_range': [0, 1],
41 | 'mean': [0.485, 0.456, 0.406],
42 | 'std': [0.229, 0.224, 0.225],
43 | 'num_classes': 1000
44 | }
45 | },
46 | 'se_resnet101': {
47 | 'imagenet': {
48 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
49 | 'input_space': 'RGB',
50 | 'input_size': [3, 224, 224],
51 | 'input_range': [0, 1],
52 | 'mean': [0.485, 0.456, 0.406],
53 | 'std': [0.229, 0.224, 0.225],
54 | 'num_classes': 1000
55 | }
56 | },
57 | 'se_resnet152': {
58 | 'imagenet': {
59 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
60 | 'input_space': 'RGB',
61 | 'input_size': [3, 224, 224],
62 | 'input_range': [0, 1],
63 | 'mean': [0.485, 0.456, 0.406],
64 | 'std': [0.229, 0.224, 0.225],
65 | 'num_classes': 1000
66 | }
67 | },
68 | 'se_resnext50_32x4d': {
69 | 'imagenet': {
70 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
71 | 'input_space': 'RGB',
72 | 'input_size': [3, 224, 224],
73 | 'input_range': [0, 1],
74 | 'mean': [0.485, 0.456, 0.406],
75 | 'std': [0.229, 0.224, 0.225],
76 | 'num_classes': 1000
77 | }
78 | },
79 | 'se_resnext101_32x4d': {
80 | 'imagenet': {
81 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
82 | 'input_space': 'RGB',
83 | 'input_size': [3, 224, 224],
84 | 'input_range': [0, 1],
85 | 'mean': [0.485, 0.456, 0.406],
86 | 'std': [0.229, 0.224, 0.225],
87 | 'num_classes': 1000
88 | }
89 | },
90 | }
91 |
92 |
93 | class SEModule(nn.Module):
94 |
95 | def __init__(self, channels, reduction):
96 | super(SEModule, self).__init__()
97 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
98 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
99 | self.relu = nn.ReLU(inplace=True)
100 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
101 | self.sigmoid = nn.Sigmoid()
102 |
103 | def forward(self, x):
104 | module_input = x
105 | x = self.avg_pool(x)
106 | x = self.fc1(x)
107 | x = self.relu(x)
108 | x = self.fc2(x)
109 | x = self.sigmoid(x)
110 | return module_input * x
111 |
112 |
113 | class Bottleneck(nn.Module):
114 | """
115 | Base class for bottlenecks that implements `forward()` method.
116 | """
117 | def forward(self, x):
118 | residual = x
119 |
120 | out = self.conv1(x)
121 | out = self.bn1(out)
122 | out = self.relu(out)
123 |
124 | out = self.conv2(out)
125 | out = self.bn2(out)
126 | out = self.relu(out)
127 |
128 | out = self.conv3(out)
129 | out = self.bn3(out)
130 |
131 | if self.downsample is not None:
132 | residual = self.downsample(x)
133 |
134 | out = self.se_module(out) + residual
135 | out = self.relu(out)
136 |
137 | return out
138 |
139 |
140 | class SEBottleneck(Bottleneck):
141 | """
142 | Bottleneck for SENet154.
143 | """
144 | expansion = 4
145 |
146 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
147 | downsample=None):
148 | super(SEBottleneck, self).__init__()
149 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
150 | self.bn1 = nn.BatchNorm2d(planes * 2)
151 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
152 | stride=stride, padding=1, groups=groups,
153 | bias=False)
154 | self.bn2 = nn.BatchNorm2d(planes * 4)
155 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
156 | bias=False)
157 | self.bn3 = nn.BatchNorm2d(planes * 4)
158 | self.relu = nn.ReLU(inplace=True)
159 | self.se_module = SEModule(planes * 4, reduction=reduction)
160 | self.downsample = downsample
161 | self.stride = stride
162 |
163 |
164 | class SEResNetBottleneck(Bottleneck):
165 | """
166 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
167 | implementation and uses `stride=stride` in `conv1` and not in `conv2`
168 | (the latter is used in the torchvision implementation of ResNet).
169 | """
170 | expansion = 4
171 |
172 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
173 | downsample=None):
174 | super(SEResNetBottleneck, self).__init__()
175 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
176 | stride=stride)
177 | self.bn1 = nn.BatchNorm2d(planes)
178 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
179 | groups=groups, bias=False)
180 | self.bn2 = nn.BatchNorm2d(planes)
181 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
182 | self.bn3 = nn.BatchNorm2d(planes * 4)
183 | self.relu = nn.ReLU(inplace=True)
184 | self.se_module = SEModule(planes * 4, reduction=reduction)
185 | self.downsample = downsample
186 | self.stride = stride
187 |
188 |
189 | class SEResNeXtBottleneck(Bottleneck):
190 | """
191 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
192 | """
193 | expansion = 4
194 |
195 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
196 | downsample=None, base_width=4):
197 | super(SEResNeXtBottleneck, self).__init__()
198 | width = int(math.floor(planes * (base_width / 64.)) * groups)
199 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
200 | stride=1)
201 | self.bn1 = nn.BatchNorm2d(width)
202 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
203 | padding=1, groups=groups, bias=False)
204 | self.bn2 = nn.BatchNorm2d(width)
205 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
206 | self.bn3 = nn.BatchNorm2d(planes * 4)
207 | self.relu = nn.ReLU(inplace=True)
208 | self.se_module = SEModule(planes * 4, reduction=reduction)
209 | self.downsample = downsample
210 | self.stride = stride
211 |
212 |
213 | class SENet(nn.Module):
214 | """
215 | Squeeze-and-excitation network
216 |
217 | Reference:
218 | Hu et al. Squeeze-and-Excitation Networks. CVPR 2018.
219 | """
220 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
221 | inplanes=128, input_3x3=True, downsample_kernel_size=3, downsample_padding=1,
222 | last_stride=2, fc_dims=None, **kwargs):
223 | """
224 | Parameters
225 | ----------
226 | block (nn.Module): Bottleneck class.
227 | - For SENet154: SEBottleneck
228 | - For SE-ResNet models: SEResNetBottleneck
229 | - For SE-ResNeXt models: SEResNeXtBottleneck
230 | layers (list of ints): Number of residual blocks for 4 layers of the
231 | network (layer1...layer4).
232 | groups (int): Number of groups for the 3x3 convolution in each
233 | bottleneck block.
234 | - For SENet154: 64
235 | - For SE-ResNet models: 1
236 | - For SE-ResNeXt models: 32
237 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
238 | - For all models: 16
239 | dropout_p (float or None): Drop probability for the Dropout layer.
240 | If `None` the Dropout layer is not used.
241 | - For SENet154: 0.2
242 | - For SE-ResNet models: None
243 | - For SE-ResNeXt models: None
244 | inplanes (int): Number of input channels for layer1.
245 | - For SENet154: 128
246 | - For SE-ResNet models: 64
247 | - For SE-ResNeXt models: 64
248 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
249 | a single 7x7 convolution in layer0.
250 | - For SENet154: True
251 | - For SE-ResNet models: False
252 | - For SE-ResNeXt models: False
253 | downsample_kernel_size (int): Kernel size for downsampling convolutions
254 | in layer2, layer3 and layer4.
255 | - For SENet154: 3
256 | - For SE-ResNet models: 1
257 | - For SE-ResNeXt models: 1
258 | downsample_padding (int): Padding for downsampling convolutions in
259 | layer2, layer3 and layer4.
260 | - For SENet154: 1
261 | - For SE-ResNet models: 0
262 | - For SE-ResNeXt models: 0
263 | num_classes (int): Number of outputs in `classifier` layer.
264 | """
265 | super(SENet, self).__init__()
266 | self.inplanes = inplanes
267 |
268 | if input_3x3:
269 | layer0_modules = [
270 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
271 | bias=False)),
272 | ('bn1', nn.BatchNorm2d(64)),
273 | ('relu1', nn.ReLU(inplace=True)),
274 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
275 | bias=False)),
276 | ('bn2', nn.BatchNorm2d(64)),
277 | ('relu2', nn.ReLU(inplace=True)),
278 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
279 | bias=False)),
280 | ('bn3', nn.BatchNorm2d(inplanes)),
281 | ('relu3', nn.ReLU(inplace=True)),
282 | ]
283 | else:
284 | layer0_modules = [
285 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
286 | padding=3, bias=False)),
287 | ('bn1', nn.BatchNorm2d(inplanes)),
288 | ('relu1', nn.ReLU(inplace=True)),
289 | ]
290 | # To preserve compatibility with Caffe weights `ceil_mode=True`
291 | # is used instead of `padding=1`.
292 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
293 | ceil_mode=True)))
294 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
295 | self.layer1 = self._make_layer(
296 | block,
297 | planes=64,
298 | blocks=layers[0],
299 | groups=groups,
300 | reduction=reduction,
301 | downsample_kernel_size=1,
302 | downsample_padding=0
303 | )
304 | self.layer2 = self._make_layer(
305 | block,
306 | planes=128,
307 | blocks=layers[1],
308 | stride=2,
309 | groups=groups,
310 | reduction=reduction,
311 | downsample_kernel_size=downsample_kernel_size,
312 | downsample_padding=downsample_padding
313 | )
314 | self.layer3 = self._make_layer(
315 | block,
316 | planes=256,
317 | blocks=layers[2],
318 | stride=2,
319 | groups=groups,
320 | reduction=reduction,
321 | downsample_kernel_size=downsample_kernel_size,
322 | downsample_padding=downsample_padding
323 | )
324 | self.layer4 = self._make_layer(
325 | block,
326 | planes=512,
327 | blocks=layers[3],
328 | stride=last_stride,
329 | groups=groups,
330 | reduction=reduction,
331 | downsample_kernel_size=downsample_kernel_size,
332 | downsample_padding=downsample_padding
333 | )
334 |
335 | self.global_avgpool = nn.AdaptiveAvgPool2d(1)
336 | self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p)
337 |
338 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
339 | downsample_kernel_size=1, downsample_padding=0):
340 | downsample = None
341 | if stride != 1 or self.inplanes != planes * block.expansion:
342 | downsample = nn.Sequential(
343 | nn.Conv2d(self.inplanes, planes * block.expansion,
344 | kernel_size=downsample_kernel_size, stride=stride,
345 | padding=downsample_padding, bias=False),
346 | nn.BatchNorm2d(planes * block.expansion),
347 | )
348 |
349 | layers = []
350 | layers.append(block(self.inplanes, planes, groups, reduction, stride,
351 | downsample))
352 | self.inplanes = planes * block.expansion
353 | for i in range(1, blocks):
354 | layers.append(block(self.inplanes, planes, groups, reduction))
355 |
356 | return nn.Sequential(*layers)
357 |
358 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
359 | """
360 | Construct fully connected layer
361 |
362 | - fc_dims (list or tuple): dimensions of fc layers, if None,
363 | no fc layers are constructed
364 | - input_dim (int): input dimension
365 | - dropout_p (float): dropout probability, if None, dropout is unused
366 | """
367 | if fc_dims is None:
368 | self.feature_dim = input_dim
369 | return None
370 |
371 | assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format(type(fc_dims))
372 |
373 | layers = []
374 | for dim in fc_dims:
375 | layers.append(nn.Linear(input_dim, dim))
376 | layers.append(nn.BatchNorm1d(dim))
377 | layers.append(nn.ReLU(inplace=True))
378 | if dropout_p is not None:
379 | layers.append(nn.Dropout(p=dropout_p))
380 | input_dim = dim
381 |
382 | self.feature_dim = fc_dims[-1]
383 |
384 | return nn.Sequential(*layers)
385 |
386 | def featuremaps(self, x):
387 | x = self.layer0(x)
388 | x = self.layer1(x)
389 | x = self.layer2(x)
390 | x = self.layer3(x)
391 | x = self.layer4(x)
392 | return x
393 |
394 | def forward(self, x):
395 | f = self.featuremaps(x)
396 | return f
397 |
398 | def init_pretrained_weights(model, model_url):
399 | """
400 | Initialize model with pretrained weights.
401 | Layers that don't match with pretrained layers in name or size are kept unchanged.
402 | """
403 | pretrain_dict = model_zoo.load_url(model_url)
404 | model_dict = model.state_dict()
405 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
406 | model_dict.update(pretrain_dict)
407 | model.load_state_dict(model_dict)
408 | print('Initialized model with pretrained weights from {}'.format(model_url))
409 |
410 |
411 | def senet154(pretrained=True, **kwargs):
412 | model = SENet(
413 |         # SENet is used here purely as a backbone; the re-ID wrappers add their own
414 |         # classifier on top, so no classifier-related kwargs are passed
415 | block=SEBottleneck,
416 | layers=[3, 8, 36, 3],
417 | groups=64,
418 | reduction=16,
419 | dropout_p=0.2,
420 | last_stride=2,
421 | fc_dims=None,
422 | **kwargs
423 | )
424 | if pretrained:
425 | model_url = pretrained_settings['senet154']['imagenet']['url']
426 | init_pretrained_weights(model, model_url)
427 | return model
428 |
429 |
430 | def se_resnet50(pretrained=True, **kwargs):
431 | model = SENet(
432 | block=SEResNetBottleneck,
433 | layers=[3, 4, 6, 3],
434 | groups=1,
435 | reduction=16,
436 | dropout_p=None,
437 | inplanes=64,
438 | input_3x3=False,
439 | downsample_kernel_size=1,
440 | downsample_padding=0,
441 | last_stride=2,
442 | fc_dims=None,
443 | **kwargs
444 | )
445 | if pretrained:
446 | model_url = pretrained_settings['se_resnet50']['imagenet']['url']
447 | init_pretrained_weights(model, model_url)
448 | return model
449 |
450 |
451 | def se_resnet50_fc512(pretrained=True, **kwargs):
452 | model = SENet(
453 | block=SEResNetBottleneck,
454 | layers=[3, 4, 6, 3],
455 | groups=1,
456 | reduction=16,
457 | dropout_p=None,
458 | inplanes=64,
459 | input_3x3=False,
460 | downsample_kernel_size=1,
461 | downsample_padding=0,
462 | last_stride=1,
463 | fc_dims=[512],
464 | **kwargs
465 | )
466 | if pretrained:
467 | model_url = pretrained_settings['se_resnet50']['imagenet']['url']
468 | init_pretrained_weights(model, model_url)
469 | return model
470 |
471 |
472 | def se_resnet101(pretrained=True, **kwargs):
473 | model = SENet(
474 | block=SEResNetBottleneck,
475 | layers=[3, 4, 23, 3],
476 | groups=1,
477 | reduction=16,
478 | dropout_p=None,
479 | inplanes=64,
480 | input_3x3=False,
481 | downsample_kernel_size=1,
482 | downsample_padding=0,
483 | last_stride=2,
484 | fc_dims=None,
485 | **kwargs
486 | )
487 | if pretrained:
488 | model_url = pretrained_settings['se_resnet101']['imagenet']['url']
489 | init_pretrained_weights(model, model_url)
490 | return model
491 |
492 |
493 | def se_resnet152(pretrained=True, **kwargs):
494 | model = SENet(
495 | block=SEResNetBottleneck,
496 | layers=[3, 8, 36, 3],
497 | groups=1,
498 | reduction=16,
499 | dropout_p=None,
500 | inplanes=64,
501 | input_3x3=False,
502 | downsample_kernel_size=1,
503 | downsample_padding=0,
504 | last_stride=2,
505 | fc_dims=None,
506 | **kwargs
507 | )
508 | if pretrained:
509 | model_url = pretrained_settings['se_resnet152']['imagenet']['url']
510 | init_pretrained_weights(model, model_url)
511 | return model
512 |
513 |
514 | def se_resnext50_32x4d(pretrained=True, **kwargs):
515 | model = SENet(
516 | block=SEResNeXtBottleneck,
517 | layers=[3, 4, 6, 3],
518 | groups=32,
519 | reduction=16,
520 | dropout_p=None,
521 | inplanes=64,
522 | input_3x3=False,
523 | downsample_kernel_size=1,
524 | downsample_padding=0,
525 | last_stride=2,
526 | fc_dims=None,
527 | **kwargs
528 | )
529 | if pretrained:
530 | model_url = pretrained_settings['se_resnext50_32x4d']['imagenet']['url']
531 | init_pretrained_weights(model, model_url)
532 | return model
533 |
534 |
535 | def se_resnext101_32x4d(pretrained=True, **kwargs):
536 | model = SENet(
537 | block=SEResNeXtBottleneck,
538 | layers=[3, 4, 23, 3],
539 | groups=32,
540 | reduction=16,
541 | dropout_p=None,
542 | inplanes=64,
543 | input_3x3=False,
544 | downsample_kernel_size=1,
545 | downsample_padding=0,
546 | last_stride=2,
547 | fc_dims=None,
548 | **kwargs
549 | )
550 | if pretrained:
551 | model_url = pretrained_settings['se_resnext101_32x4d']['imagenet']['url']
552 | init_pretrained_weights(model, model_url)
553 | return model
--------------------------------------------------------------------------------
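
A hedged sketch (not part of the repository): the SENet here is used as a backbone that returns raw feature maps (no pooling, no classifier). pretrained=False skips the weight download, and src/ is assumed to be on PYTHONPATH.

import torch
from models import senet

backbone = senet.se_resnet50(pretrained=False)
x = torch.randn(2, 3, 256, 128)             # a batch of person crops
fmap = backbone(x)
print(fmap.shape, backbone.feature_dim)     # torch.Size([2, 2048, 8, 4]) 2048
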
/src/project_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | import errno, copy
5 | import os.path as osp
6 | import more_itertools as mit
7 | import torch, torch.nn as nn
8 | import numpy as np
9 | import scipy.io as sio
10 |
11 | from torch.utils.data.sampler import Sampler
12 |
13 |
14 | def read_Duke_attributes(attribute_file_path):
15 | attributes = [
16 | "age",
17 | "backpack",
18 | "bag",
19 | "boots",
20 | "clothes",
21 | "down",
22 | "gender",
23 | "hair",
24 | "handbag",
25 | "hat",
26 | "shoes",
27 | # "top", # patched up by changing to up
28 | "up",
29 | {"upcloth": [
30 | "upblack",
31 | "upblue",
32 | "upbrown",
33 | "upgray",
34 | "upgreen",
35 | "uppurple",
36 | "upred",
37 | "upwhite",
38 | "upyellow",
39 | ]},
40 | {"downcloth": [
41 | "downblack",
42 | "downblue",
43 | "downbrown",
44 | "downgray",
45 | "downgreen",
46 | "downpink",
47 | "downpurple",
48 | "downred",
49 | "downwhite",
50 | "downyellow",
51 | ]},
52 | ]
53 |
54 | attribute2index = {}
55 | for index, item in enumerate(attributes):
56 | key = item
57 | if isinstance(item, dict):
58 | key = list(item.keys())[0]
59 | attribute2index[key] = index
60 |
61 | num_options_per_attributes = {
62 | "age": ["young", "teen", "adult", "old"],
63 | "backpack": ["no", "yes"],
64 | "bag": ["no", "yes"],
65 | "boots": ["no", "yes"],
66 | "clothes": ["dress", "pants"],
67 | "down": ["long", "short"],
68 | "downcloth": [
69 | "downblack",
70 | "downblue",
71 | "downbrown",
72 | "downgray",
73 | "downgreen",
74 | "downpink",
75 | "downpurple",
76 | "downred",
77 | "downwhite",
78 | "downyellow",
79 | ],
80 | "gender": ["male", "female"],
81 | "hair": ["short", "long"],
82 | "handbag": ["no", "yes"],
83 | "hat": ["no", "yes"],
84 | "shoes": ["dark", "light"],
85 | # "top": ["short", "long"],
86 | "up": ["long", "short"],
87 | "upcloth": [
88 | "upblack",
89 | "upblue",
90 | "upbrown",
91 | "upgray",
92 | "upgreen",
93 | "uppurple",
94 | "upred",
95 | "upwhite",
96 | "upyellow",
97 | ],
98 | }
99 |
100 | def parse_attributes_by_id(attributes):
101 | # get the order of attributes in the buffer (from dtype)
102 | attribute_names = attributes.dtype.names
103 | attribute2index = {
104 | attribute: index
105 | for index, attribute in enumerate(attribute_names)
106 | if attribute != "image_index"
107 | }
108 |
109 | # get person id and arrange by index
110 | current_attributes = attributes.item()
111 | all_ids = current_attributes[-1]
112 | id2index = {id_: index for index, id_ in enumerate(all_ids)}
113 |
114 | # collect attributes for each ID
115 | attributes_byID = {}
116 | for attribute_name, attribute_index in attribute2index.items():
117 | # each attribute values are stored as a row in the attribute annotation file
118 | current_attribute_values = current_attributes[attribute_index]
119 |
120 | for id_, id_index in id2index.items():
121 | if id_ not in attributes_byID:
122 | attributes_byID[id_] = {}
123 |
124 | attribute_value = current_attribute_values[id_index]
125 |
126 | # patch for up vs. top
127 | if attribute_name == "top":
128 | attribute_name = "up"
129 | attribute_value = (attribute_value % 2) + 1
130 |
131 | attributes_byID[id_][attribute_name] = (
132 | attribute_value - 1
133 | ) # 0-based values
134 |
135 | return attributes_byID
136 |
137 | def merge_ID_attributes(train_attributes, test_attributes):
138 | all_ID_attributes = copy.deepcopy(train_attributes)
139 |
140 | for ID, val in test_attributes.items():
141 | assert (
142 | ID not in train_attributes.keys()
143 | ), "attribute merge: test ID {} already exists in train".format(ID)
144 |
145 | all_ID_attributes[ID] = val
146 |
147 | return all_ID_attributes
148 |
149 |
150 | root_dir, filename = osp.split(attribute_file_path)
151 | root_key = osp.splitext(filename)[0]
152 | buffer_file_path = osp.join(root_dir, root_key + "_attribute_cache.pth")
153 |
154 | if not osp.exists(buffer_file_path):
155 | print(buffer_file_path + " does not exist!, reading the attributes")
156 | att_file_contents = sio.loadmat(
157 | attribute_file_path, squeeze_me=True
158 | ) # squeeze 1-dim elements
159 |
160 | attributes = att_file_contents[root_key]
161 |
162 | train_attributes = attributes.item()[0]
163 | test_attributes = attributes.item()[1]
164 |
165 | train_attributes_byID = parse_attributes_by_id(train_attributes)
166 | test_attributes_byID = parse_attributes_by_id(test_attributes)
167 | print('train', len(train_attributes_byID), 'test', len(test_attributes_byID))
168 | all_ID_attributes = merge_ID_attributes(
169 | train_attributes_byID, test_attributes_byID
170 | )
171 |
172 | # writing attribute cache
173 | print("saving the attributes to cache " + buffer_file_path)
174 | torch.save(all_ID_attributes, buffer_file_path)
175 |
176 | else:
177 | print("reading the attributes from cache " + buffer_file_path)
178 | all_ID_attributes = torch.load(buffer_file_path)
179 |
180 | return all_ID_attributes
181 |
182 |
183 | def read_Mars_attributes(attributes_file):
184 |     raise NotImplementedError("parsing of MARS attribute annotations is not implemented yet")
185 |
186 |
187 | def shortlist_Mars_on_attribute(data_source, attribute, value):
188 | attributes_file = ''
189 | all_attributes = read_Mars_attributes(attributes_file)
190 |
191 | for img_path, pid, camid in data_source:
192 | pass
193 |
194 |
195 | def shortlist_Duke_on_attribute(data_source, attribute, value):
196 | attributes_file = '/media/data1/datasets/personreid/dukemtmc-reid/attributes/duke_attribute.mat'
197 | all_attributes = read_Duke_attributes(attributes_file)
198 |
199 | # collect all indices where the attribute has particular value
200 | relevant_indices = []
201 | not_found_pids = []
202 | for index, (img_path, pid, camid) in enumerate(data_source):
203 | stringify_pid = "%04d" % pid
204 |
205 | if stringify_pid not in all_attributes:
206 | not_found_pids.append(stringify_pid)
207 | continue
208 |
209 | current_pid_attributes = all_attributes[stringify_pid]
210 | if current_pid_attributes[attribute] == value:
211 | relevant_indices.append(index)
212 |
213 | print('not found pids : {}'.format(not_found_pids))
214 | print(len(relevant_indices), ' sampled out of ', len(data_source))
215 | return relevant_indices, all_attributes
216 |
217 |
218 | class AttributeBasedSampler(Sampler):
219 | def __init__(self, data_source, attribute, value, dataset_name):
220 | super().__init__(data_source)
221 |
222 | if dataset_name == 'mars':
223 | relevant_indices, _ = shortlist_Mars_on_attribute(data_source, attribute, value)
224 | elif dataset_name == 'dukemtmcvidreid':
225 | relevant_indices, _ = shortlist_Duke_on_attribute(data_source, attribute, value)
226 | else:
227 | assert False, 'unknown dataset ' + dataset_name
228 |
229 | # get the indices of data instances which has
230 | # attribute's value as value
231 | self.instance_indices = relevant_indices
232 |
233 | def __len__(self):
234 | return len(self.instance_indices)
235 |
236 | def __iter__(self):
237 | return iter(self.instance_indices)
238 |
--------------------------------------------------------------------------------
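
A hedged usage sketch (not part of the repository) of AttributeBasedSampler for attribute-conditioned evaluation. `query_set` is a hypothetical DukeMTMC-VideoReID query split yielding (img_paths, pid, camid), the duke_attribute.mat path hard-coded above is machine-specific, and value=1 assumes the 0-based encoding produced by read_Duke_attributes (e.g. backpack: 0 = no, 1 = yes).

import torch
from project_utils import AttributeBasedSampler

# query_set is a placeholder dataset, not defined in this repository
sampler = AttributeBasedSampler(query_set, attribute="backpack", value=1,
                                dataset_name="dukemtmcvidreid")
queryloader = torch.utils.data.DataLoader(query_set, batch_size=1, sampler=sampler)
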
/src/samplers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 | import numpy as np
4 |
5 | import torch
6 |
7 | class RandomIdentitySampler(torch.utils.data.Sampler):
8 | """
9 | Randomly sample N identities, then for each identity,
10 | randomly sample K instances, therefore batch size is N*K.
11 |
12 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
13 |
14 | Args:
15 | data_source (Dataset): dataset to sample from.
16 | num_instances (int): number of instances per identity.
17 | """
18 | def __init__(self, data_source, num_instances=4):
19 | self.data_source = data_source
20 | self.num_instances = num_instances
21 | self.index_dic = defaultdict(list)
22 | for index, (_, pid, _) in enumerate(data_source):
23 | self.index_dic[pid].append(index)
24 | self.pids = list(self.index_dic.keys())
25 | self.num_identities = len(self.pids)
26 |
27 | def __iter__(self):
28 | indices = torch.randperm(self.num_identities)
29 | ret = []
30 | for i in indices:
31 | pid = self.pids[i]
32 | t = self.index_dic[pid]
33 |             replace = len(t) < self.num_instances
34 | t = np.random.choice(t, size=self.num_instances, replace=replace)
35 | ret.extend(t)
36 | return iter(ret)
37 |
38 | def __len__(self):
39 | return self.num_identities * self.num_instances
40 |
--------------------------------------------------------------------------------
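
A hedged usage sketch (not part of the repository) of the P x K batching that the hard-triplet loss relies on: each batch holds P identities with K tracklets each. `train_set` is a placeholder for a video re-ID training split yielding (img_paths, pid, camid).

import torch
from samplers import RandomIdentitySampler

# train_set is a placeholder dataset, not defined in this repository
sampler = RandomIdentitySampler(train_set, num_instances=4)            # K = 4 tracklets per identity
trainloader = torch.utils.data.DataLoader(train_set, batch_size=32,    # P = 8 identities per batch
                                          sampler=sampler, drop_last=True)
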
/src/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torchvision.transforms import *
4 | from PIL import Image
5 | import random
6 | import numpy as np
7 |
8 | class Random2DTranslation(object):
9 | """
10 | With a probability, first increase image size to (1 + 1/8), and then perform random crop.
11 |
12 | Args:
13 | height (int): target height.
14 | width (int): target width.
15 | p (float): probability of performing this transformation. Default: 0.5.
16 | """
17 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
18 | self.height = height
19 | self.width = width
20 | self.p = p
21 | self.interpolation = interpolation
22 |
23 | def __call__(self, img):
24 | """
25 | Args:
26 | img (PIL Image): Image to be cropped.
27 |
28 | Returns:
29 | PIL Image: Cropped image.
30 | """
31 | if random.random() < self.p:
32 | return img.resize((self.width, self.height), self.interpolation)
33 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125))
34 | resized_img = img.resize((new_width, new_height), self.interpolation)
35 | x_maxrange = new_width - self.width
36 | y_maxrange = new_height - self.height
37 | x1 = int(round(random.uniform(0, x_maxrange)))
38 | y1 = int(round(random.uniform(0, y_maxrange)))
39 |         cropped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
40 |         return cropped_img
41 |
42 | if __name__ == '__main__':
43 | pass
44 |
--------------------------------------------------------------------------------
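
A hedged sketch (not part of the repository) of a typical training-time pipeline built around Random2DTranslation for 256 x 128 person crops; src/ is assumed to be on PYTHONPATH so that `transforms` resolves to the module above.

import torchvision.transforms as T
from transforms import Random2DTranslation

transform_train = T.Compose([
    Random2DTranslation(height=256, width=128, p=0.5),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
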
/src/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | import errno
5 | import shutil
6 | import json, math
7 | import os.path as osp
8 | import more_itertools as mit
9 | import torch, torch.nn as nn
10 | import numpy as np
11 | import inspect
12 |
13 | is_print_once_enabled = True
14 |
15 |
16 | def static_vars(**kwargs):
17 | def decorate(func):
18 | for k in kwargs:
19 | setattr(func, k, kwargs[k])
20 | return func
21 |
22 | return decorate
23 |
24 |
25 | def disable_all_print_once():
26 | global is_print_once_enabled
27 | is_print_once_enabled = False
28 |
29 |
30 | @static_vars(lines={})
31 | def print_once(msg):
32 | # return from the function if the API is disabled
33 | global is_print_once_enabled
34 | if not is_print_once_enabled:
35 | return
36 |
37 | from inspect import getframeinfo, stack
38 |
39 | caller = getframeinfo(stack()[1][0])
40 | current_file_line = "%s:%d" % (caller.filename, caller.lineno)
41 |
42 | # if the current called file and line is not in buffer print once
43 | if current_file_line not in print_once.lines:
44 | print(msg)
45 | print_once.lines[current_file_line] = True
46 |
47 |
48 | def get_executing_filepath():
49 | frame = inspect.stack()[1]
50 | module = inspect.getmodule(frame[0])
51 | filename = module.__file__
52 | return os.path.split(filename)[0]
53 |
54 |
55 | def set_stride(module, stride):
56 |     """
57 |     Set the stride of every Conv2d and MaxPool2d inside `module` to `stride`.
58 |     """
59 | print("setting stride of ", module, " to ", stride)
60 | for internal_module in module.modules():
61 | if isinstance(internal_module, nn.Conv2d) or isinstance(
62 | internal_module, nn.MaxPool2d
63 | ):
64 | internal_module.stride = stride
65 |
66 | return internal_module
67 |
68 |
69 | def get_gaussian_kernel(channels, kernel_size=5, mean=0, sigma=[1, 4]):
70 |     # pack the per-axis sigmas into a 2x2 diagonal matrix (as a torch tensor)
71 | sigma_ = torch.zeros((2, 2)).float()
72 | sigma_[0, 0] = sigma[0]
73 | sigma_[1, 1] = sigma[1]
74 | sigma = sigma_
75 |
76 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
77 | x_cord = torch.linspace(-1, 1, kernel_size)
78 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
79 | y_grid = x_grid.t()
80 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
81 |
82 | variance = (sigma @ sigma.t()).float()
83 | inv_variance = torch.inverse(variance)
84 |
85 | # Calculate the 2-dimensional gaussian kernel which is
86 | # the product of two gaussian distributions for two different
87 | # variables (in this case called x and y)
88 | gaussian_kernel = (1.0 / (2.0 * math.pi * torch.det(variance))) * torch.exp(
89 | -torch.sum(
90 | ((xy_grid - mean) @ inv_variance.unsqueeze(0)) * (xy_grid - mean), dim=-1
91 | )
92 | / 2
93 | )
94 |
95 | # Make sure sum of values in gaussian kernel equals 1.
96 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
97 |
98 | # Reshape to 2d depthwise convolutional weight
99 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
100 | gaussian_kernel = gaussian_kernel.repeat(1, channels, 1, 1)
101 | return gaussian_kernel
102 |
103 |
104 | def mkdir_if_missing(directory):
105 | """to create a directory
106 |
107 | Arguments:
108 | directory {str} -- directory path
109 | """
110 |
111 | if not osp.exists(directory):
112 | try:
113 | os.makedirs(directory)
114 | except OSError as e:
115 | if e.errno != errno.EEXIST:
116 | raise
117 |
118 |
119 | # function to freeze certain module's weight
120 | def freeze_weights(to_be_freezed):
121 | for param in to_be_freezed.parameters():
122 | param.requires_grad = False
123 |
124 | for module in to_be_freezed.children():
125 | for param in module.parameters():
126 | param.requires_grad = False
127 |
128 |
129 | def load_pretrained_model(model, pretrained_model_path, verbose=False):
130 | """To load the pretrained model considering the number of keys and their sizes
131 |
132 | Arguments:
133 | model {loaded model} -- already loaded model
134 | pretrained_model_path {str} -- path to the pretrained model file
135 |
136 | Raises:
137 | IOError -- if the file path is not found
138 |
139 | Returns:
140 | model -- model with loaded params
141 | """
142 |
143 | if not os.path.exists(pretrained_model_path):
144 | raise IOError("Can't find pretrained model: {}".format(pretrained_model_path))
145 |
146 | print("Loading checkpoint from '{}'".format(pretrained_model_path))
147 | pretrained_state = torch.load(pretrained_model_path)["state_dict"]
148 | print(len(pretrained_state), " keys in pretrained model")
149 |
150 | current_model_state = model.state_dict()
151 | print(len(current_model_state), " keys in current model")
152 | pretrained_state = {
153 | key: val
154 | for key, val in pretrained_state.items()
155 | if key in current_model_state and val.size() == current_model_state[key].size()
156 | }
157 |
158 | print(
159 | len(pretrained_state),
160 | " keys in pretrained model are available in current model",
161 | )
162 | current_model_state.update(pretrained_state)
163 | model.load_state_dict(current_model_state)
164 |
165 | if verbose:
166 | non_available_keys_in_pretrained = [
167 | key
168 | for key, val in pretrained_state.items()
169 | if key not in current_model_state
170 | or val.size() != current_model_state[key].size()
171 | ]
172 | non_available_keys_in_current = [
173 | key
174 | for key, val in current_model_state.items()
175 | if key not in pretrained_state or val.size() != pretrained_state[key].size()
176 | ]
177 |
178 | print(
179 | "not available keys in pretrained model: ", non_available_keys_in_pretrained
180 | )
181 | print("not available keys in current model: ", non_available_keys_in_current)
182 |
183 | return model
184 |
185 |
186 | def get_currenttime_prefix():
187 | """to get a prefix of current time
188 |
189 | Returns:
190 | [str] -- current time encoded into string
191 | """
192 |
193 | from time import localtime, strftime
194 |
195 | return strftime("%d-%b-%Y_%H:%M:%S", localtime())
196 |
197 |
198 | def get_learnable_params(model):
199 | """to get the list of learnable params
200 |
201 | Arguments:
202 | model {model} -- loaded model
203 |
204 | Returns:
205 | list -- learnable params
206 | """
207 |
208 | # list down the names of learnable params
209 | details = []
210 | for name, param in model.named_parameters():
211 | if param.requires_grad:
212 | details.append((name, param.shape))
213 | print("learnable params (" + str(len(details)) + ") : ", details)
214 |
215 | # short list the params which has requires_grad as true
216 | learnable_params = [param for param in model.parameters() if param.requires_grad]
217 |
218 | print(
219 | "Model size: {:.5f}M".format(
220 | sum(p.numel() for p in learnable_params) / 1000000.0
221 | )
222 | )
223 | return learnable_params
224 |
225 |
226 | def get_features(model, imgs, test_num_tracks):
227 |     """Extract features in chunks to avoid out-of-memory errors on long sequences;
228 |     used specifically at test time.
229 | 
230 |     Arguments:
231 |         model -- model under test
232 |         imgs -- images to extract features for
233 |         test_num_tracks -- number of tracklets to process per chunk
234 |     Returns:
235 |         features -- concatenated features of all chunks
236 |     """
237 |
238 | # handle chunked data
239 | all_features = []
240 |
241 | for test_imgs in mit.chunked(imgs, test_num_tracks):
242 | current_test_imgs = torch.stack(test_imgs)
243 | num_current_test_imgs = current_test_imgs.shape[0]
244 | # print(current_test_imgs.shape)
245 | features = model(current_test_imgs)
246 | features = features.view(num_current_test_imgs, -1)
247 | all_features.append(features)
248 |
249 | return torch.cat(all_features)
250 |
251 |
252 | def get_spatial_features(model, imgs, test_num_tracks):
253 |     """Extract features and spatial feature maps in chunks to avoid out-of-memory
254 |     errors on long sequences; used specifically at test time.
255 | 
256 |     Arguments:
257 |         model -- model under test
258 |         imgs -- images to extract features for
259 |         test_num_tracks -- number of tracklets to process per chunk
260 |     Returns:
261 |         features, spatial features -- concatenated over all chunks
262 |     """
263 |
264 | # handle chunked data
265 | all_features, all_spatial_features = [], []
266 |
267 | for test_imgs in mit.chunked(imgs, test_num_tracks):
268 | current_test_imgs = torch.stack(test_imgs)
269 | num_current_test_imgs = current_test_imgs.shape[0]
270 | features, spatial_feats = model(current_test_imgs)
271 | features = features.view(num_current_test_imgs, -1)
272 |
273 | all_spatial_features.append(spatial_feats)
274 | all_features.append(features)
275 |
276 | return torch.cat(all_features), torch.cat(all_spatial_features)
277 |
278 |
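
Both helpers split the tracklets into chunks of `test_num_tracks` via `mit.chunked` (presumably `more_itertools`, imported as `mit` earlier in this file) so that long test sequences fit in GPU memory. A self-contained sketch of `get_features` with a dummy model and made-up shapes:

```python
import torch

# dummy model for illustration only: pools each tracklet to a small feature vector
class DummyModel(torch.nn.Module):
    def forward(self, x):                # x: (n_tracklets, seq_len, C, H, W)
        return x.mean(dim=(2, 3, 4))     # -> (n_tracklets, seq_len)

imgs = torch.randn(10, 4, 3, 8, 8)       # 10 tracklets of 4 frames each
feats = get_features(DummyModel(), imgs, test_num_tracks=3)  # 3 tracklets per chunk
print(feats.shape)                       # torch.Size([10, 4])
```
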
279 | class AverageMeter(object):
280 | """Computes and stores the average and current value.
281 |
282 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
283 | """
284 |
285 | def __init__(self):
286 | self.reset()
287 |
288 | def reset(self):
289 | self.val = 0
290 | self.avg = 0
291 | self.sum = 0
292 | self.count = 0
293 |
294 | def update(self, val, n=1):
295 | self.val = val
296 | self.sum += val * n
297 | self.count += n
298 | self.avg = self.sum / self.count
299 |
300 |
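
A typical training-loop pattern, sketched with made-up loss values and batch sizes:

```python
# AverageMeter is the class defined above (src/utils.py)
losses = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    losses.update(batch_loss, n=batch_size)

print(losses.val)  # 0.5  -> loss of the most recent batch
print(losses.avg)  # 0.74 -> running average weighted by batch size
```
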
301 | def save_checkpoint(state, is_best, fpath="checkpoint.pth.tar"):
302 | mkdir_if_missing(osp.dirname(fpath))
303 | print("saving model to " + fpath)
304 | torch.save(state, fpath)
305 | if is_best:
306 | shutil.copy(fpath, osp.join(osp.dirname(fpath), "best_model.pth.tar"))
307 |
308 |
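
A sketch of the expected call, consistent with the `state_dict` layout that `load_pretrained_model` above reads back; the bookkeeping fields, metric values, and path are illustrative placeholders:

```python
import torchvision

# save_checkpoint is the helper defined above (src/utils.py)
model = torchvision.models.resnet18()
epoch, rank1, best_rank1 = 10, 0.81, 0.79   # placeholder training state

save_checkpoint(
    {"state_dict": model.state_dict(), "epoch": epoch, "rank1": rank1},
    is_best=(rank1 > best_rank1),            # also copied to best_model.pth.tar if best
    fpath="log/checkpoint_ep{}.pth.tar".format(epoch),
)
```
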
309 | def open_all_layers(model):
310 | """
311 | Open all layers in model for training.
312 |
313 | Args:
314 | - model (nn.Module): neural net model.
315 | """
316 | model.train()
317 | for p in model.parameters():
318 | p.requires_grad = True
319 |
320 |
321 | def open_specified_layers(model, open_layers):
322 | """
323 | Open specified layers in model for training while keeping
324 | other layers frozen.
325 |
326 | Args:
327 | - model (nn.Module): neural net model.
328 | - open_layers (list): list of layer names.
329 | """
330 | if isinstance(model, nn.DataParallel):
331 | model = model.module
332 |
333 | for layer in open_layers:
334 | assert hasattr(
335 | model, layer
336 | ), '"{}" is not an attribute of the model, please provide the correct name'.format(
337 | layer
338 | )
339 |
340 | # check if all the open layers are there in model
341 | all_names = [name for name, module in model.named_children()]
342 | for tobeopen_layer in open_layers:
343 | assert tobeopen_layer in all_names, "{} not in model".format(tobeopen_layer)
344 |
345 | for name, module in model.named_children():
346 | if name in open_layers:
347 | module.train()
348 | for p in module.parameters():
349 | p.requires_grad = True
350 | else:
351 | module.eval()
352 | for p in module.parameters():
353 | p.requires_grad = False
354 |
355 |
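
These two helpers support two-stage fine-tuning: warm up only the newly added layers, then open the whole network. A minimal sketch with a toy module; the layer names are assumptions chosen for illustration, not names guaranteed by this repo:

```python
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.backbone = nn.Linear(8, 8)
        self.classifier = nn.Linear(8, 4)

    def forward(self, x):
        return self.classifier(self.backbone(x))

model = TinyNet()

# warm-up epochs: train only the classifier, keep the backbone frozen and in eval mode
open_specified_layers(model, ["classifier"])

# remaining epochs: fine-tune everything
open_all_layers(model)
```
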
356 | class Logger(object):
357 | """
358 | Write console output to external text file.
359 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
360 | """
361 |
362 | def __init__(self, fpath=None):
363 | self.console = sys.stdout
364 | self.file = None
365 | if fpath is not None:
366 | mkdir_if_missing(os.path.dirname(fpath))
367 | self.file = open(fpath, "a")
368 |
369 | def __del__(self):
370 | self.close()
371 |
372 |     def __enter__(self):
373 |         return self
374 |
375 | def __exit__(self, *args):
376 | self.close()
377 |
378 | def write(self, msg):
379 | self.console.write(msg)
380 | if self.file is not None:
381 | self.file.write(msg)
382 |
383 | def flush(self):
384 | self.console.flush()
385 | if self.file is not None:
386 | self.file.flush()
387 | os.fsync(self.file.fileno())
388 |
389 | def close(self):
390 | self.console.close()
391 | if self.file is not None:
392 | self.file.close()
393 |
394 |
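
The logger is meant to replace `sys.stdout` so that console output is also mirrored to a log file; a minimal sketch (the path is illustrative):

```python
import sys

# Logger is the class defined above (src/utils.py)
sys.stdout = Logger("log/train.log")
print("this line goes to both the console and log/train.log")
```
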
395 | def read_json(fpath):
396 | with open(fpath, "r") as f:
397 | obj = json.load(f)
398 | return obj
399 |
400 |
401 | def write_json(obj, fpath):
402 | mkdir_if_missing(osp.dirname(fpath))
403 | with open(fpath, "w") as f:
404 | json.dump(obj, f, indent=4, separators=(",", ": "))
405 |
406 |
--------------------------------------------------------------------------------
/src/video_loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 | import random
9 |
10 | def read_image(img_path):
11 |     """Keep trying to read the image until it succeeds.
12 |     This avoids IOErrors caused by heavy IO load."""
13 | got_img = False
14 | while not got_img:
15 | try:
16 | img = Image.open(img_path).convert('RGB')
17 | got_img = True
18 | except IOError:
19 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
20 | pass
21 | return img
22 |
23 |
24 | class VideoDataset(Dataset):
25 | """Video Person ReID Dataset.
26 | Note batch data has shape (batch, seq_len, channel, height, width).
27 | """
28 |     sample_methods = ['evenly', 'random', 'dense']
29 |
30 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None):
31 | self.dataset = dataset
32 | self.seq_len = seq_len
33 | self.sample = sample
34 | self.transform = transform
35 |
36 | def __len__(self):
37 | return len(self.dataset)
38 |
39 | def __getitem__(self, index):
40 | img_paths, pid, camid = self.dataset[index]
41 | num = len(img_paths)
42 | if self.sample == 'random':
43 | """
44 |             Randomly sample seq_len consecutive frames from the num available frames;
45 |             if num is smaller than seq_len, replicate items.
46 | This sampling strategy is used in training phase.
47 | """
48 | frame_indices = range(num)
49 | rand_end = max(0, len(frame_indices) - self.seq_len - 1)
50 | begin_index = random.randint(0, rand_end)
51 | end_index = min(begin_index + self.seq_len, len(frame_indices))
52 |
53 | indices = list(frame_indices[begin_index:end_index])
54 |
55 | for index in indices:
56 | if len(indices) >= self.seq_len:
57 | break
58 | indices.append(index)
59 | indices=np.array(indices)
60 | imgs = []
61 | for index in indices:
62 | index=int(index)
63 | img_path = img_paths[index]
64 | img = read_image(img_path)
65 | if self.transform is not None:
66 | img = self.transform(img)
67 | img = img.unsqueeze(0)
68 | imgs.append(img)
69 | imgs = torch.cat(imgs, dim=0)
70 | #imgs=imgs.permute(1,0,2,3)
71 | return imgs, pid, camid
72 |
73 |         elif self.sample == 'evenly':
74 | """
75 | Evenly sample seq_len items from num items.
76 | """
77 | if num >= self.seq_len:
78 | num -= num % self.seq_len
79 | indices = np.arange(0, num, int(num/self.seq_len))
80 | else:
81 | # if num is smaller than seq_len, simply replicate the last image
82 | # until the seq_len requirement is satisfied
83 | indices = np.arange(0, num)
84 | num_pads = self.seq_len - num
85 | indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num-1)])
86 |
87 | assert len(indices) == self.seq_len
88 |
89 |             # note: `indices` is already a numpy array with exactly seq_len entries
90 |             # (see the assert above), so, unlike the 'random' branch, no list-style
91 |             # replication loop is needed here (np.ndarray has no append anyway).
92 | 
93 |             indices = np.array(indices)
94 | imgs = []
95 | for index in indices:
96 | index=int(index)
97 | img_path = img_paths[index]
98 | img = read_image(img_path)
99 | if self.transform is not None:
100 | img = self.transform(img)
101 | img = img.unsqueeze(0)
102 | imgs.append(img)
103 | imgs = torch.cat(imgs, dim=0)
104 | #imgs=imgs.permute(1,0,2,3)
105 | return imgs, pid, camid
106 |
107 | elif self.sample == 'dense':
108 | """
109 |             Sample all frames of a video into a list of clips; each clip contains seq_len frames, and batch_size needs to be set to 1.
110 | This sampling strategy is used in test phase.
111 | """
112 | cur_index=0
113 | frame_indices = list(range(num))
114 | indices_list=[]
115 | while num-cur_index > self.seq_len:
116 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len])
117 | cur_index+=self.seq_len
118 | last_seq=frame_indices[cur_index:]
119 | for index in last_seq:
120 | if len(last_seq) >= self.seq_len:
121 | break
122 | last_seq.append(index)
123 | indices_list.append(last_seq)
124 |
125 | imgs_list=[]
126 | for indices in indices_list:
127 | imgs = []
128 | for index in indices:
129 | index=int(index)
130 | img_path = img_paths[index]
131 | img = read_image(img_path)
132 | if self.transform is not None:
133 | img = self.transform(img)
134 | img = img.unsqueeze(0)
135 | imgs.append(img)
136 | imgs = torch.cat(imgs, dim=0)
137 | #imgs=imgs.permute(1,0,2,3)
138 | imgs_list.append(imgs)
139 | imgs_array = torch.stack(imgs_list)
140 | return imgs_array, pid, camid
141 |
142 | else:
143 | raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
--------------------------------------------------------------------------------
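
A self-contained sketch of how this loader is typically wired up, using a tiny fake split in the `(img_paths, pid, camid)` tuple format that `data_manager.py` provides; the frame size, transforms, and sequence length are illustrative. At test time, `sample='dense'` returns a stack of clips per tracklet, so the batch size must then be 1.

```python
import os
import numpy as np
from PIL import Image
import torchvision.transforms as T
from torch.utils.data import DataLoader

# build a tiny fake split: one tracklet of 6 black frames on disk
os.makedirs("tmp_frames", exist_ok=True)
paths = []
for i in range(6):
    p = "tmp_frames/frame_{}.jpg".format(i)
    Image.fromarray(np.zeros((128, 64, 3), dtype=np.uint8)).save(p)
    paths.append(p)
split = [(tuple(paths), 0, 0)]            # (img_paths, pid, camid)

transform = T.Compose([T.Resize((256, 128)), T.ToTensor()])
train_set = VideoDataset(split, seq_len=4, sample='random', transform=transform)
loader = DataLoader(train_set, batch_size=1, shuffle=True)

imgs, pid, camid = next(iter(loader))
print(imgs.shape)   # (batch, seq_len, channel, height, width) -> torch.Size([1, 4, 3, 256, 128])
```
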