├── README.md ├── __pycache__ ├── data_manager.cpython-37.pyc ├── eval_metrics.cpython-37.pyc ├── losses.cpython-37.pyc ├── samplers.cpython-37.pyc ├── transforms.cpython-37.pyc ├── utils.cpython-37.pyc └── video_loader.cpython-37.pyc ├── data_manager.py ├── data_manager_search.py ├── eval_metrics.py ├── losses.py ├── main_video_person_reid_hypergraphsage_part.py ├── models ├── ResNet.py ├── ResNet_hypergraphsage_part.py ├── __init__.py ├── __pycache__ │ ├── ResNet.cpython-37.pyc │ ├── ResNet_dart.cpython-37.pyc │ ├── ResNet_dart_search.cpython-37.pyc │ ├── ResNet_dynamichypergraphsage.cpython-37.pyc │ ├── ResNet_graph.cpython-37.pyc │ ├── ResNet_graphsage.cpython-37.pyc │ ├── ResNet_graphsage_bpm.cpython-37.pyc │ ├── ResNet_graphsage_part.cpython-37.pyc │ ├── ResNet_graphsage_part_new.cpython-37.pyc │ ├── ResNet_graphsage_part_new_att.cpython-37.pyc │ ├── ResNet_hypergraphsage.cpython-37.pyc │ ├── ResNet_hypergraphsage_part.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── convlstm.cpython-37.pyc │ ├── hypnn.cpython-37.pyc │ ├── non_local_embedded_gaussian.cpython-37.pyc │ ├── resnet.cpython-37.pyc │ ├── resnet3d.cpython-37.pyc │ └── utils.cpython-37.pyc ├── architect.py ├── convlstm.py ├── hypnn.py ├── network.py ├── non_local_concatenation.py ├── non_local_dot_product.py ├── non_local_embedded_gaussian.py ├── non_local_gaussian.py ├── resnet.py ├── resnet3d.py └── utils.py ├── run_hypergraphsage_part.sh ├── samplers.py ├── transforms.py ├── utils.py └── video_loader.py /README.md: -------------------------------------------------------------------------------- 1 | # hypergraph_reid 2 | 3 | Implementation of "Learning Multi-Granular Hypergraphs for Video-Based Person Re-Identification" 4 | If you find this helpful for your research, please cite 5 | 6 | @inproceedings{DBLP:conf/cvpr/YanQC0ZT020, 7 | author = {Yichao Yan and 8 | Jie Qin and 9 | Jiaxin Chen and 10 | Li Liu and 11 | Fan Zhu and 12 | Ying Tai and 13 | Ling Shao}, 14 | title = {Learning Multi-Granular Hypergraphs for Video-Based Person Re-Identification}, 15 | booktitle = {2020 {IEEE/CVF} Conference on Computer Vision and Pattern Recognition, 16 | {CVPR} 2020, Seattle, WA, USA, June 13-19, 2020}, 17 | pages = {2896--2905}, 18 | publisher = {{IEEE}}, 19 | year = {2020} 20 | } 21 | 22 | 23 | ## Installation 24 | We use Python 3.7 and PyTorch 0.4. 25 | 26 | ## Data preparation 27 | All experiments are done on MARS, as it is the largest dataset available to date for video-based person reID. Please follow [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid) to prepare the data. The instructions are copied here: 28 | 29 | 1. Create a directory named `mars/` under `data/`. 30 | 2. Download the dataset to `data/mars/` from http://www.liangzheng.com.cn/Project/project_mars.html. 31 | 3. Extract `bbox_train.zip` and `bbox_test.zip`. 32 | 4. Download the split information from https://github.com/liangzheng06/MARS-evaluation/tree/master/info and put `info/` in `data/mars` (we want to follow the standard split in [8]). The data structure would look like: 33 | ``` 34 | mars/ 35 | bbox_test/ 36 | bbox_train/ 37 | info/ 38 | ``` 39 | 40 | ### Usage 41 | To train the model, please run 42 | 43 | sh run_hypergraphsage_part.sh 44 | 45 | ### Performance 46 | Trained model [[Google]](https://drive.google.com/file/d/1KBuPWYAHBC2QVLKihpSe-GQEOUkVLCXl/view?usp=sharing) 47 | 48 | The shared trained model achieves 85.6% mAP and 89.5% rank-1 accuracy. According to my training log, the best model achieves 86.2% mAP and 90.0% rank-1 accuracy.
Reproducing the latter may require adjusting the hyperparameters. 49 | 50 | ### Acknowledgements 51 | Our code is developed based on Video-Person-ReID (https://github.com/jiyanggao/Video-Person-ReID). 52 | -------------------------------------------------------------------------------- /__pycache__/data_manager.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/__pycache__/data_manager.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/eval_metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/__pycache__/eval_metrics.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/samplers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/__pycache__/samplers.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/video_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/__pycache__/video_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import glob 4 | import re 5 | import sys 6 | import urllib 7 | import tarfile 8 | import zipfile 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | import numpy as np 12 | import random 13 | 14 | from utils import mkdir_if_missing, write_json, read_json 15 | 16 | """Dataset classes""" 17 | 18 | 19 | class Mars(object): 20 | """ 21 | MARS 22 | 23 | Reference: 24 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 25 | 26 | Dataset statistics: 27 | # identities: 1261 28 | # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery) 29 | # cameras: 6 30 | 31 | Args: 32 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0).
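
    Usage sketch (illustrative addition, not part of the original docstring; it assumes the MARS files are laid out under ./data/mars as described in the README):

        >>> from data_manager import Mars
        >>> dataset = Mars(min_seq_len=0)             # parses info/*.txt and info/*.mat into tracklets
        >>> img_paths, pid, camid = dataset.train[0]  # each tracklet is (tuple of image paths, person id, camera id)
        >>> len(dataset.query), len(dataset.gallery)  # 1980 query and 9330 gallery tracklets, per the statistics above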
33 | """ 34 | root = './data/mars' 35 | train_name_path = osp.join(root, 'info/train_name.txt') 36 | test_name_path = osp.join(root, 'info/test_name.txt') 37 | track_train_info_path = osp.join(root, 'info/tracks_train_info.mat') 38 | track_test_info_path = osp.join(root, 'info/tracks_test_info.mat') 39 | query_IDX_path = osp.join(root, 'info/query_IDX.mat') 40 | 41 | def __init__(self, min_seq_len=0): 42 | self._check_before_run() 43 | 44 | # prepare meta data 45 | train_names = self._get_names(self.train_name_path) 46 | test_names = self._get_names(self.test_name_path) 47 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 48 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 49 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 50 | query_IDX -= 1 # index from 0 51 | track_query = track_test[query_IDX,:] 52 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 53 | track_gallery = track_test[gallery_IDX,:] 54 | 55 | train, num_train_tracklets, num_train_pids, num_train_imgs = \ 56 | self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 57 | 58 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 59 | self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 60 | 61 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 62 | self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 63 | 64 | num_imgs_per_tracklet = num_train_imgs + num_query_imgs + num_gallery_imgs 65 | min_num = np.min(num_imgs_per_tracklet) 66 | max_num = np.max(num_imgs_per_tracklet) 67 | avg_num = np.mean(num_imgs_per_tracklet) 68 | 69 | num_total_pids = num_train_pids + num_query_pids 70 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 71 | 72 | print("=> MARS loaded") 73 | print("Dataset statistics:") 74 | print(" ------------------------------") 75 | print(" subset | # ids | # tracklets") 76 | print(" ------------------------------") 77 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 78 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 79 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 80 | print(" ------------------------------") 81 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 82 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 83 | print(" ------------------------------") 84 | 85 | self.train = train 86 | self.query = query 87 | self.gallery = gallery 88 | 89 | self.num_train_pids = num_train_pids 90 | self.num_query_pids = num_query_pids 91 | self.num_gallery_pids = num_gallery_pids 92 | 93 | def _check_before_run(self): 94 | """Check if all files are available before going deeper""" 95 | if not osp.exists(self.root): 96 | raise RuntimeError("'{}' is not available".format(self.root)) 97 | if not osp.exists(self.train_name_path): 98 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 99 | if not osp.exists(self.test_name_path): 100 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 101 | if not osp.exists(self.track_train_info_path): 102 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 103 | 
if not osp.exists(self.track_test_info_path): 104 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 105 | if not osp.exists(self.query_IDX_path): 106 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 107 | 108 | def _get_names(self, fpath): 109 | names = [] 110 | with open(fpath, 'r') as f: 111 | for line in f: 112 | new_line = line.rstrip() 113 | names.append(new_line) 114 | return names 115 | 116 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 117 | assert home_dir in ['bbox_train', 'bbox_test'] 118 | num_tracklets = meta_data.shape[0] 119 | pid_list = list(set(meta_data[:,2].tolist())) 120 | num_pids = len(pid_list) 121 | 122 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 123 | tracklets = [] 124 | num_imgs_per_tracklet = [] 125 | 126 | #txt_name = self.root + home_dir + str(len(meta_data)) + '.txt' 127 | #fid = open(txt_name, "w") 128 | 129 | for tracklet_idx in range(num_tracklets): 130 | data = meta_data[tracklet_idx,...] 131 | start_index, end_index, pid, camid = data 132 | if pid == -1: continue # junk images are just ignored 133 | assert 1 <= camid <= 6 134 | if relabel: pid = pid2label[pid] 135 | camid -= 1 # index starts from 0 136 | img_names = names[start_index-1:end_index] 137 | 138 | # make sure image names correspond to the same person 139 | pnames = [img_name[:4] for img_name in img_names] 140 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 141 | 142 | # make sure all images are captured under the same camera 143 | camnames = [img_name[5] for img_name in img_names] 144 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 145 | 146 | # append image names with directory information 147 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 148 | if len(img_paths) >= min_seq_len: 149 | img_paths = tuple(img_paths) 150 | tracklets.append((img_paths, pid, camid)) 151 | num_imgs_per_tracklet.append(len(img_paths)) 152 | #fid.write(img_names[0] + '\n') 153 | 154 | #fid.close() 155 | num_tracklets = len(tracklets) 156 | 157 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 158 | 159 | class iLIDSVID(object): 160 | """ 161 | iLIDS-VID 162 | 163 | Reference: 164 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 165 | 166 | Dataset statistics: 167 | # identities: 300 168 | # tracklets: 600 169 | # cameras: 2 170 | 171 | Args: 172 | split_id (int): indicates which split to use. There are totally 10 splits. 
173 | """ 174 | root = './data/ilids-vid' 175 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 176 | data_dir = osp.join(root, 'i-LIDS-VID') 177 | split_dir = osp.join(root, 'train-test people splits') 178 | split_mat_path = osp.join(split_dir, 'train_test_splits_ilidsvid.mat') 179 | split_path = osp.join(root, 'splits.json') 180 | cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1') 181 | cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2') 182 | 183 | def __init__(self, split_id=0): 184 | self._download_data() 185 | self._check_before_run() 186 | 187 | self._prepare_split() 188 | splits = read_json(self.split_path) 189 | if split_id >= len(splits): 190 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 191 | split = splits[split_id] 192 | train_dirs, test_dirs = split['train'], split['test'] 193 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 194 | 195 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 196 | self._process_data(train_dirs, cam1=True, cam2=True) 197 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 198 | self._process_data(test_dirs, cam1=True, cam2=False) 199 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 200 | self._process_data(test_dirs, cam1=False, cam2=True) 201 | 202 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 203 | min_num = np.min(num_imgs_per_tracklet) 204 | max_num = np.max(num_imgs_per_tracklet) 205 | avg_num = np.mean(num_imgs_per_tracklet) 206 | 207 | num_total_pids = num_train_pids + num_query_pids 208 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 209 | 210 | print("=> iLIDS-VID loaded") 211 | print("Dataset statistics:") 212 | print(" ------------------------------") 213 | print(" subset | # ids | # tracklets") 214 | print(" ------------------------------") 215 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 216 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 217 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 218 | print(" ------------------------------") 219 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 220 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 221 | print(" ------------------------------") 222 | 223 | self.train = train 224 | self.query = query 225 | self.gallery = gallery 226 | 227 | self.num_train_pids = num_train_pids 228 | self.num_query_pids = num_query_pids 229 | self.num_gallery_pids = num_gallery_pids 230 | 231 | def _download_data(self): 232 | if osp.exists(self.root): 233 | print("This dataset has been downloaded.") 234 | return 235 | 236 | mkdir_if_missing(self.root) 237 | fpath = osp.join(self.root, osp.basename(self.dataset_url)) 238 | 239 | print("Downloading iLIDS-VID dataset") 240 | url_opener = urllib.URLopener() 241 | url_opener.retrieve(self.dataset_url, fpath) 242 | 243 | print("Extracting files") 244 | tar = tarfile.open(fpath) 245 | tar.extractall(path=self.root) 246 | tar.close() 247 | 248 | def _check_before_run(self): 249 | """Check if all files are available before going deeper""" 250 | if not osp.exists(self.root): 251 | raise RuntimeError("'{}' is not available".format(self.root)) 252 | if not osp.exists(self.data_dir): 253 | raise RuntimeError("'{}' is not 
available".format(self.data_dir)) 254 | if not osp.exists(self.split_dir): 255 | raise RuntimeError("'{}' is not available".format(self.split_dir)) 256 | 257 | def _prepare_split(self): 258 | if not osp.exists(self.split_path): 259 | print("Creating splits") 260 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 261 | 262 | num_splits = mat_split_data.shape[0] 263 | num_total_ids = mat_split_data.shape[1] 264 | assert num_splits == 10 265 | assert num_total_ids == 300 266 | num_ids_each = int(num_total_ids/2) 267 | 268 | # pids in mat_split_data are indices, so we need to transform them 269 | # to real pids 270 | person_cam1_dirs = os.listdir(self.cam_1_path) 271 | person_cam2_dirs = os.listdir(self.cam_2_path) 272 | 273 | # make sure persons in one camera view can be found in the other camera view 274 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 275 | 276 | splits = [] 277 | for i_split in range(num_splits): 278 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 279 | train_idxs = sorted(list(mat_split_data[i_split,num_ids_each:])) 280 | test_idxs = sorted(list(mat_split_data[i_split,:num_ids_each])) 281 | 282 | train_idxs = [int(i)-1 for i in train_idxs] 283 | test_idxs = [int(i)-1 for i in test_idxs] 284 | 285 | # transform pids to person dir names 286 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 287 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 288 | 289 | split = {'train': train_dirs, 'test': test_dirs} 290 | splits.append(split) 291 | 292 | print("Totally {} splits are created, following Wang et al. ECCV'14".format(len(splits))) 293 | print("Split file is saved to {}".format(self.split_path)) 294 | write_json(splits, self.split_path) 295 | 296 | print("Splits created") 297 | 298 | def _process_data(self, dirnames, cam1=True, cam2=True): 299 | tracklets = [] 300 | num_imgs_per_tracklet = [] 301 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 302 | 303 | #txt_name = 'ilids' + str(cam1) + str(cam2) + '.txt' 304 | #fid = open(txt_name, "w") 305 | 306 | for dirname in dirnames: 307 | if cam1: 308 | person_dir = osp.join(self.cam_1_path, dirname) 309 | img_names = glob.glob(osp.join(person_dir, '*.png')) 310 | assert len(img_names) > 0 311 | img_names = tuple(img_names) 312 | pid = dirname2pid[dirname] 313 | tracklets.append((img_names, pid, 0)) 314 | num_imgs_per_tracklet.append(len(img_names)) 315 | #fid.write("cam1_" + dirname + '\n') 316 | if cam2: 317 | person_dir = osp.join(self.cam_2_path, dirname) 318 | img_names = glob.glob(osp.join(person_dir, '*.png')) 319 | assert len(img_names) > 0 320 | img_names = tuple(img_names) 321 | pid = dirname2pid[dirname] 322 | tracklets.append((img_names, pid, 1)) 323 | num_imgs_per_tracklet.append(len(img_names)) 324 | #fid.write("cam2_" + dirname + '\n') 325 | 326 | num_tracklets = len(tracklets) 327 | num_pids = len(dirnames) 328 | 329 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 330 | 331 | class PRID(object): 332 | """ 333 | PRID 334 | 335 | Reference: 336 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011. 337 | 338 | Dataset statistics: 339 | # identities: 200 340 | # tracklets: 400 341 | # cameras: 2 342 | 343 | Args: 344 | split_id (int): indicates which split to use. There are totally 10 splits. 345 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0). 
346 | """ 347 | root = './data/prid2011' 348 | dataset_url = 'https://files.icg.tugraz.at/f/6ab7e8ce8f/?raw=1' 349 | split_path = osp.join(root, 'splits_prid2011.json') 350 | cam_a_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_a') 351 | cam_b_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_b') 352 | 353 | def __init__(self, split_id=0, min_seq_len=0): 354 | self._check_before_run() 355 | splits = read_json(self.split_path) 356 | if split_id >= len(splits): 357 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 358 | split = splits[split_id] 359 | train_dirs, test_dirs = split['train'], split['test'] 360 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 361 | 362 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 363 | self._process_data(train_dirs, cam1=True, cam2=True) 364 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 365 | self._process_data(test_dirs, cam1=True, cam2=False) 366 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 367 | self._process_data(test_dirs, cam1=False, cam2=True) 368 | 369 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 370 | min_num = np.min(num_imgs_per_tracklet) 371 | max_num = np.max(num_imgs_per_tracklet) 372 | avg_num = np.mean(num_imgs_per_tracklet) 373 | 374 | num_total_pids = num_train_pids + num_query_pids 375 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 376 | 377 | print("=> PRID-2011 loaded") 378 | print("Dataset statistics:") 379 | print(" ------------------------------") 380 | print(" subset | # ids | # tracklets") 381 | print(" ------------------------------") 382 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 383 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 384 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 385 | print(" ------------------------------") 386 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 387 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 388 | print(" ------------------------------") 389 | 390 | self.train = train 391 | self.query = query 392 | self.gallery = gallery 393 | 394 | self.num_train_pids = num_train_pids 395 | self.num_query_pids = num_query_pids 396 | self.num_gallery_pids = num_gallery_pids 397 | 398 | def _check_before_run(self): 399 | """Check if all files are available before going deeper""" 400 | if not osp.exists(self.root): 401 | raise RuntimeError("'{}' is not available".format(self.root)) 402 | 403 | def _process_data(self, dirnames, cam1=True, cam2=True): 404 | tracklets = [] 405 | num_imgs_per_tracklet = [] 406 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 407 | 408 | #txt_name = 'prid' + str(cam1) + str(cam2) + '.txt' 409 | #fid = open(txt_name, "w") 410 | 411 | for dirname in dirnames: 412 | if cam1: 413 | person_dir = osp.join(self.cam_a_path, dirname) 414 | img_names = glob.glob(osp.join(person_dir, '*.png')) 415 | assert len(img_names) > 0 416 | img_names = tuple(img_names) 417 | pid = dirname2pid[dirname] 418 | tracklets.append((img_names, pid, 0)) 419 | num_imgs_per_tracklet.append(len(img_names)) 420 | #fid.write("cama_" + dirname + '\n') 421 | if cam2: 422 | person_dir = osp.join(self.cam_b_path, dirname) 423 | img_names = 
glob.glob(osp.join(person_dir, '*.png')) 424 | assert len(img_names) > 0 425 | img_names = tuple(img_names) 426 | pid = dirname2pid[dirname] 427 | tracklets.append((img_names, pid, 1)) 428 | num_imgs_per_tracklet.append(len(img_names)) 429 | #fid.write("camb_" + dirname + '\n') 430 | 431 | num_tracklets = len(tracklets) 432 | num_pids = len(dirnames) 433 | 434 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 435 | 436 | """Create dataset""" 437 | 438 | __factory = { 439 | 'mars': Mars, 440 | 'ilidsvid': iLIDSVID, 441 | 'prid': PRID, 442 | } 443 | 444 | def get_names(): 445 | return __factory.keys() 446 | 447 | def init_dataset(name, *args, **kwargs): 448 | if name not in __factory.keys(): 449 | raise KeyError("Unknown dataset: {}".format(name)) 450 | return __factory[name](*args, **kwargs) 451 | 452 | if __name__ == '__main__': 453 | # test 454 | #dataset = Market1501() 455 | #dataset = Mars() 456 | dataset = iLIDSVID() 457 | dataset = PRID() 458 | 459 | 460 | 461 | 462 | 463 | 464 | 465 | -------------------------------------------------------------------------------- /data_manager_search.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import glob 4 | import re 5 | import sys 6 | import urllib 7 | import tarfile 8 | import zipfile 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | import numpy as np 12 | import random 13 | 14 | from utils import mkdir_if_missing, write_json, read_json 15 | 16 | """Dataset classes""" 17 | 18 | 19 | class Mars(object): 20 | """ 21 | MARS 22 | 23 | Reference: 24 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 25 | 26 | Dataset statistics: 27 | # identities: 1261 28 | # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery) 29 | # cameras: 6 30 | 31 | Args: 32 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0). 
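
    Usage sketch (illustrative addition, not part of the original docstring): this variant additionally splits the tracklets of each training identity roughly in half into train and valid subsets (see _process_data_train below), again assuming the ./data/mars layout from the README:

        >>> from data_manager_search import Mars
        >>> dataset = Mars(min_seq_len=0)
        >>> train, valid = dataset.train, dataset.valid    # both are lists of (image paths, person id, camera id)
        >>> dataset.num_train_pids, dataset.num_valid_pids # identity counts of the two halves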
33 | """ 34 | root = './data/mars' 35 | train_name_path = osp.join(root, 'info/train_name.txt') 36 | test_name_path = osp.join(root, 'info/test_name.txt') 37 | track_train_info_path = osp.join(root, 'info/tracks_train_info.mat') 38 | track_test_info_path = osp.join(root, 'info/tracks_test_info.mat') 39 | query_IDX_path = osp.join(root, 'info/query_IDX.mat') 40 | 41 | def __init__(self, min_seq_len=0): 42 | self._check_before_run() 43 | 44 | # prepare meta data 45 | train_names = self._get_names(self.train_name_path) 46 | test_names = self._get_names(self.test_name_path) 47 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 48 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 49 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 50 | query_IDX -= 1 # index from 0 51 | track_query = track_test[query_IDX,:] 52 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 53 | track_gallery = track_test[gallery_IDX,:] 54 | 55 | train, num_train_tracklets, num_train_pids, num_train_imgs, valid, num_valid_tracklets, num_valid_pids, num_valid_imgs = \ 56 | self._process_data_train(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 57 | 58 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 59 | self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 60 | 61 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 62 | self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 63 | 64 | num_imgs_per_tracklet = num_train_imgs + num_query_imgs + num_gallery_imgs 65 | min_num = np.min(num_imgs_per_tracklet) 66 | max_num = np.max(num_imgs_per_tracklet) 67 | avg_num = np.mean(num_imgs_per_tracklet) 68 | 69 | num_total_pids = num_train_pids + num_query_pids 70 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 71 | 72 | print("=> MARS loaded") 73 | print("Dataset statistics:") 74 | print(" ------------------------------") 75 | print(" subset | # ids | # tracklets") 76 | print(" ------------------------------") 77 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 78 | print(" valid | {:5d} | {:8d}".format(num_valid_pids, num_valid_tracklets)) 79 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 80 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 81 | print(" ------------------------------") 82 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 83 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 84 | print(" ------------------------------") 85 | 86 | self.train = train 87 | self.valid = valid 88 | self.query = query 89 | self.gallery = gallery 90 | 91 | self.num_train_pids = num_train_pids 92 | self.num_valid_pids = num_valid_pids 93 | self.num_query_pids = num_query_pids 94 | self.num_gallery_pids = num_gallery_pids 95 | 96 | def _check_before_run(self): 97 | """Check if all files are available before going deeper""" 98 | if not osp.exists(self.root): 99 | raise RuntimeError("'{}' is not available".format(self.root)) 100 | if not osp.exists(self.train_name_path): 101 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 102 | if not osp.exists(self.test_name_path): 103 | 
raise RuntimeError("'{}' is not available".format(self.test_name_path)) 104 | if not osp.exists(self.track_train_info_path): 105 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 106 | if not osp.exists(self.track_test_info_path): 107 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 108 | if not osp.exists(self.query_IDX_path): 109 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 110 | 111 | def _get_names(self, fpath): 112 | names = [] 113 | with open(fpath, 'r') as f: 114 | for line in f: 115 | new_line = line.rstrip() 116 | names.append(new_line) 117 | return names 118 | 119 | def _process_data_train(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 120 | assert home_dir in ['bbox_train'] 121 | num_tracklets = meta_data.shape[0] 122 | pid_list = list(set(meta_data[:,2].tolist())) 123 | num_pids = len(pid_list) 124 | 125 | if relabel: 126 | pid2label = {pid:label for label, pid in enumerate(pid_list)} 127 | label = [pid2label[meta_data[i, 2]] for i in range(num_tracklets)] 128 | tracklets_train = [] 129 | tracklets_valid = [] 130 | num_imgs_per_tracklet_train = [] 131 | num_imgs_per_tracklet_valid = [] 132 | pid_train = [] 133 | pid_valid = [] 134 | 135 | for p_id in range(num_pids): 136 | pidx = [i for i, x in enumerate(label) if x == p_id] 137 | num_t = len(pidx) 138 | random.shuffle(pidx) 139 | idx_split = int(num_t/2) 140 | for i, idx in enumerate(pidx): 141 | data = meta_data[idx,...] 142 | start_index, end_index, pid, camid = data 143 | if pid == -1: continue # junk images are just ignored 144 | assert 1 <= camid <= 6 145 | if relabel: pid = pid2label[pid] 146 | camid -= 1 # index starts from 0 147 | img_names = names[start_index-1:end_index] 148 | 149 | # make sure image names correspond to the same person 150 | pnames = [img_name[:4] for img_name in img_names] 151 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 152 | 153 | # make sure all images are captured under the same camera 154 | camnames = [img_name[5] for img_name in img_names] 155 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 156 | 157 | # append image names with directory information 158 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 159 | if len(img_paths) >= min_seq_len: 160 | img_paths = tuple(img_paths) 161 | if i < idx_split: 162 | tracklets_valid.append((img_paths, pid, camid)) 163 | num_imgs_per_tracklet_valid.append(len(img_paths)) 164 | pid_valid.append(pid) 165 | else: 166 | tracklets_train.append((img_paths, pid, camid)) 167 | num_imgs_per_tracklet_train.append(len(img_paths)) 168 | pid_train.append(pid) 169 | if num_t == 1: 170 | tracklets_valid.append((img_paths, pid, camid)) 171 | num_imgs_per_tracklet_valid.append(len(img_paths)) 172 | pid_valid.append(pid) 173 | ''' 174 | for tracklet_idx in range(num_tracklets): 175 | data = meta_data[tracklet_idx,...] 
176 | start_index, end_index, pid, camid = data 177 | if pid == -1: continue # junk images are just ignored 178 | assert 1 <= camid <= 6 179 | if relabel: pid = pid2label[pid] 180 | camid -= 1 # index starts from 0 181 | img_names = names[start_index-1:end_index] 182 | 183 | # make sure image names correspond to the same person 184 | pnames = [img_name[:4] for img_name in img_names] 185 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 186 | 187 | # make sure all images are captured under the same camera 188 | camnames = [img_name[5] for img_name in img_names] 189 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 190 | 191 | # append image names with directory information 192 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 193 | if len(img_paths) >= min_seq_len: 194 | img_paths = tuple(img_paths) 195 | if random.random() > 0.5: 196 | tracklets_train.append((img_paths, pid, camid)) 197 | num_imgs_per_tracklet_train.append(len(img_paths)) 198 | pid_train.append(pid) 199 | else: 200 | tracklets_valid.append((img_paths, pid, camid)) 201 | num_imgs_per_tracklet_valid.append(len(img_paths)) 202 | pid_valid.append(pid) 203 | ''' 204 | num_tracklets_train = len(tracklets_train) 205 | num_tracklets_valid = len(tracklets_valid) 206 | num_pids_train = len(list(set(pid_train))) 207 | num_pids_valid = len(list(set(pid_valid))) 208 | 209 | return tracklets_train, num_tracklets_train, num_pids_train, num_imgs_per_tracklet_train, tracklets_valid, num_tracklets_valid, num_pids_valid, num_imgs_per_tracklet_valid 210 | 211 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 212 | assert home_dir in ['bbox_train', 'bbox_test'] 213 | num_tracklets = meta_data.shape[0] 214 | pid_list = list(set(meta_data[:,2].tolist())) 215 | num_pids = len(pid_list) 216 | 217 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 218 | tracklets = [] 219 | num_imgs_per_tracklet = [] 220 | 221 | for tracklet_idx in range(num_tracklets): 222 | data = meta_data[tracklet_idx,...] 223 | start_index, end_index, pid, camid = data 224 | if pid == -1: continue # junk images are just ignored 225 | assert 1 <= camid <= 6 226 | if relabel: pid = pid2label[pid] 227 | camid -= 1 # index starts from 0 228 | img_names = names[start_index-1:end_index] 229 | 230 | # make sure image names correspond to the same person 231 | pnames = [img_name[:4] for img_name in img_names] 232 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 233 | 234 | # make sure all images are captured under the same camera 235 | camnames = [img_name[5] for img_name in img_names] 236 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 237 | 238 | # append image names with directory information 239 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 240 | if len(img_paths) >= min_seq_len: 241 | img_paths = tuple(img_paths) 242 | tracklets.append((img_paths, pid, camid)) 243 | num_imgs_per_tracklet.append(len(img_paths)) 244 | 245 | num_tracklets = len(tracklets) 246 | 247 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 248 | 249 | class iLIDSVID(object): 250 | """ 251 | iLIDS-VID 252 | 253 | Reference: 254 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 
255 | 256 | Dataset statistics: 257 | # identities: 300 258 | # tracklets: 600 259 | # cameras: 2 260 | 261 | Args: 262 | split_id (int): indicates which split to use. There are totally 10 splits. 263 | """ 264 | root = './data/ilids-vid' 265 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 266 | data_dir = osp.join(root, 'i-LIDS-VID') 267 | split_dir = osp.join(root, 'train-test people splits') 268 | split_mat_path = osp.join(split_dir, 'train_test_splits_ilidsvid.mat') 269 | split_path = osp.join(root, 'splits.json') 270 | cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1') 271 | cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2') 272 | 273 | def __init__(self, split_id=0): 274 | self._download_data() 275 | self._check_before_run() 276 | 277 | self._prepare_split() 278 | splits = read_json(self.split_path) 279 | if split_id >= len(splits): 280 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 281 | split = splits[split_id] 282 | train_dirs, test_dirs = split['train'], split['test'] 283 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 284 | 285 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 286 | self._process_data(train_dirs, cam1=True, cam2=True) 287 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 288 | self._process_data(test_dirs, cam1=True, cam2=False) 289 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 290 | self._process_data(test_dirs, cam1=False, cam2=True) 291 | 292 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 293 | min_num = np.min(num_imgs_per_tracklet) 294 | max_num = np.max(num_imgs_per_tracklet) 295 | avg_num = np.mean(num_imgs_per_tracklet) 296 | 297 | num_total_pids = num_train_pids + num_query_pids 298 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 299 | 300 | print("=> iLIDS-VID loaded") 301 | print("Dataset statistics:") 302 | print(" ------------------------------") 303 | print(" subset | # ids | # tracklets") 304 | print(" ------------------------------") 305 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 306 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 307 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 308 | print(" ------------------------------") 309 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 310 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 311 | print(" ------------------------------") 312 | 313 | self.train = train 314 | self.query = query 315 | self.gallery = gallery 316 | 317 | self.num_train_pids = num_train_pids 318 | self.num_query_pids = num_query_pids 319 | self.num_gallery_pids = num_gallery_pids 320 | 321 | def _download_data(self): 322 | if osp.exists(self.root): 323 | print("This dataset has been downloaded.") 324 | return 325 | 326 | mkdir_if_missing(self.root) 327 | fpath = osp.join(self.root, osp.basename(self.dataset_url)) 328 | 329 | print("Downloading iLIDS-VID dataset") 330 | url_opener = urllib.URLopener() 331 | url_opener.retrieve(self.dataset_url, fpath) 332 | 333 | print("Extracting files") 334 | tar = tarfile.open(fpath) 335 | tar.extractall(path=self.root) 336 | tar.close() 337 | 338 | def _check_before_run(self): 339 | """Check if all files are available before 
going deeper""" 340 | if not osp.exists(self.root): 341 | raise RuntimeError("'{}' is not available".format(self.root)) 342 | if not osp.exists(self.data_dir): 343 | raise RuntimeError("'{}' is not available".format(self.data_dir)) 344 | if not osp.exists(self.split_dir): 345 | raise RuntimeError("'{}' is not available".format(self.split_dir)) 346 | 347 | def _prepare_split(self): 348 | if not osp.exists(self.split_path): 349 | print("Creating splits") 350 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 351 | 352 | num_splits = mat_split_data.shape[0] 353 | num_total_ids = mat_split_data.shape[1] 354 | assert num_splits == 10 355 | assert num_total_ids == 300 356 | num_ids_each = num_total_ids/2 357 | 358 | # pids in mat_split_data are indices, so we need to transform them 359 | # to real pids 360 | person_cam1_dirs = os.listdir(self.cam_1_path) 361 | person_cam2_dirs = os.listdir(self.cam_2_path) 362 | 363 | # make sure persons in one camera view can be found in the other camera view 364 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 365 | 366 | splits = [] 367 | for i_split in range(num_splits): 368 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 369 | train_idxs = sorted(list(mat_split_data[i_split,num_ids_each:])) 370 | test_idxs = sorted(list(mat_split_data[i_split,:num_ids_each])) 371 | 372 | train_idxs = [int(i)-1 for i in train_idxs] 373 | test_idxs = [int(i)-1 for i in test_idxs] 374 | 375 | # transform pids to person dir names 376 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 377 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 378 | 379 | split = {'train': train_dirs, 'test': test_dirs} 380 | splits.append(split) 381 | 382 | print("Totally {} splits are created, following Wang et al. ECCV'14".format(len(splits))) 383 | print("Split file is saved to {}".format(self.split_path)) 384 | write_json(splits, self.split_path) 385 | 386 | print("Splits created") 387 | 388 | def _process_data(self, dirnames, cam1=True, cam2=True): 389 | tracklets = [] 390 | num_imgs_per_tracklet = [] 391 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 392 | 393 | for dirname in dirnames: 394 | if cam1: 395 | person_dir = osp.join(self.cam_1_path, dirname) 396 | img_names = glob.glob(osp.join(person_dir, '*.png')) 397 | assert len(img_names) > 0 398 | img_names = tuple(img_names) 399 | pid = dirname2pid[dirname] 400 | tracklets.append((img_names, pid, 0)) 401 | num_imgs_per_tracklet.append(len(img_names)) 402 | 403 | if cam2: 404 | person_dir = osp.join(self.cam_2_path, dirname) 405 | img_names = glob.glob(osp.join(person_dir, '*.png')) 406 | assert len(img_names) > 0 407 | img_names = tuple(img_names) 408 | pid = dirname2pid[dirname] 409 | tracklets.append((img_names, pid, 1)) 410 | num_imgs_per_tracklet.append(len(img_names)) 411 | 412 | num_tracklets = len(tracklets) 413 | num_pids = len(dirnames) 414 | 415 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 416 | 417 | class PRID(object): 418 | """ 419 | PRID 420 | 421 | Reference: 422 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011. 423 | 424 | Dataset statistics: 425 | # identities: 200 426 | # tracklets: 400 427 | # cameras: 2 428 | 429 | Args: 430 | split_id (int): indicates which split to use. There are totally 10 splits. 431 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0). 
432 | """ 433 | root = './data/prid2011' 434 | dataset_url = 'https://files.icg.tugraz.at/f/6ab7e8ce8f/?raw=1' 435 | split_path = osp.join(root, 'splits_prid2011.json') 436 | cam_a_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_a') 437 | cam_b_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_b') 438 | 439 | def __init__(self, split_id=0, min_seq_len=0): 440 | self._check_before_run() 441 | splits = read_json(self.split_path) 442 | if split_id >= len(splits): 443 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 444 | split = splits[split_id] 445 | train_dirs, test_dirs = split['train'], split['test'] 446 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 447 | 448 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 449 | self._process_data(train_dirs, cam1=True, cam2=True) 450 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 451 | self._process_data(test_dirs, cam1=True, cam2=False) 452 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 453 | self._process_data(test_dirs, cam1=False, cam2=True) 454 | 455 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 456 | min_num = np.min(num_imgs_per_tracklet) 457 | max_num = np.max(num_imgs_per_tracklet) 458 | avg_num = np.mean(num_imgs_per_tracklet) 459 | 460 | num_total_pids = num_train_pids + num_query_pids 461 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 462 | 463 | print("=> PRID-2011 loaded") 464 | print("Dataset statistics:") 465 | print(" ------------------------------") 466 | print(" subset | # ids | # tracklets") 467 | print(" ------------------------------") 468 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 469 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 470 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 471 | print(" ------------------------------") 472 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 473 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 474 | print(" ------------------------------") 475 | 476 | self.train = train 477 | self.query = query 478 | self.gallery = gallery 479 | 480 | self.num_train_pids = num_train_pids 481 | self.num_query_pids = num_query_pids 482 | self.num_gallery_pids = num_gallery_pids 483 | 484 | def _check_before_run(self): 485 | """Check if all files are available before going deeper""" 486 | if not osp.exists(self.root): 487 | raise RuntimeError("'{}' is not available".format(self.root)) 488 | 489 | def _process_data(self, dirnames, cam1=True, cam2=True): 490 | tracklets = [] 491 | num_imgs_per_tracklet = [] 492 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 493 | 494 | for dirname in dirnames: 495 | if cam1: 496 | person_dir = osp.join(self.cam_a_path, dirname) 497 | img_names = glob.glob(osp.join(person_dir, '*.png')) 498 | assert len(img_names) > 0 499 | img_names = tuple(img_names) 500 | pid = dirname2pid[dirname] 501 | tracklets.append((img_names, pid, 0)) 502 | num_imgs_per_tracklet.append(len(img_names)) 503 | 504 | if cam2: 505 | person_dir = osp.join(self.cam_b_path, dirname) 506 | img_names = glob.glob(osp.join(person_dir, '*.png')) 507 | assert len(img_names) > 0 508 | img_names = tuple(img_names) 509 | pid = dirname2pid[dirname] 510 | 
tracklets.append((img_names, pid, 1)) 511 | num_imgs_per_tracklet.append(len(img_names)) 512 | 513 | num_tracklets = len(tracklets) 514 | num_pids = len(dirnames) 515 | 516 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 517 | 518 | """Create dataset""" 519 | 520 | __factory = { 521 | 'mars': Mars, 522 | 'ilidsvid': iLIDSVID, 523 | 'prid': PRID, 524 | } 525 | 526 | def get_names(): 527 | return __factory.keys() 528 | 529 | def init_dataset(name, *args, **kwargs): 530 | if name not in __factory.keys(): 531 | raise KeyError("Unknown dataset: {}".format(name)) 532 | return __factory[name](*args, **kwargs) 533 | 534 | if __name__ == '__main__': 535 | # test 536 | #dataset = Market1501() 537 | #dataset = Mars() 538 | dataset = iLIDSVID() 539 | dataset = PRID() 540 | 541 | 542 | 543 | 544 | 545 | 546 | 547 | -------------------------------------------------------------------------------- /eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | import copy 4 | 5 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 6 | num_q, num_g = distmat.shape 7 | if num_g < max_rank: 8 | max_rank = num_g 9 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 10 | indices = np.argsort(distmat, axis=1) 11 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 12 | 13 | # compute cmc curve for each query 14 | all_cmc = [] 15 | all_AP = [] 16 | num_valid_q = 0. 17 | for q_idx in range(num_q): 18 | # get query pid and camid 19 | q_pid = q_pids[q_idx] 20 | q_camid = q_camids[q_idx] 21 | 22 | # remove gallery samples that have the same pid and camid with query 23 | order = indices[q_idx] 24 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 25 | keep = np.invert(remove) 26 | 27 | # compute cmc curve 28 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 29 | if not np.any(orig_cmc): 30 | # this condition is true when query identity does not appear in gallery 31 | continue 32 | 33 | cmc = orig_cmc.cumsum() 34 | cmc[cmc > 1] = 1 35 | 36 | all_cmc.append(cmc[:max_rank]) 37 | num_valid_q += 1. 38 | 39 | # compute average precision 40 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 41 | num_rel = orig_cmc.sum() 42 | tmp_cmc = orig_cmc.cumsum() 43 | tmp_cmc = [x / (i+1.) 
for i, x in enumerate(tmp_cmc)] 44 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 45 | AP = tmp_cmc.sum() / num_rel 46 | all_AP.append(AP) 47 | 48 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 49 | 50 | all_cmc = np.asarray(all_cmc).astype(np.float32) 51 | all_cmc = all_cmc.sum(0) / num_valid_q 52 | mAP = np.mean(all_AP) 53 | 54 | return all_cmc, mAP 55 | 56 | def get_rank_list(dist_vec, q_id, q_cam, g_ids, g_cams, rank_list_size): 57 | sort_inds = np.argsort(dist_vec) 58 | rank_list = [] 59 | same_id = [] 60 | i = 0 61 | for ind, g_id, g_cam in zip(sort_inds, g_ids[sort_inds], g_cams[sort_inds]): 62 | # Skip gallery images with same id and same camera as query 63 | if (q_id == g_id) and (q_cam == g_cam): 64 | continue 65 | same_id.append(q_id == g_id) 66 | rank_list.append(ind) 67 | i += 1 68 | if i >= rank_list_size: 69 | break 70 | return rank_list, same_id 71 | 72 | def save_rank_list_to_file(rank_list, same_id, g_file_path, save_path): 73 | g_imgs = [] 74 | with open(g_file_path, "r") as fid: 75 | while True: 76 | line = fid.readline() 77 | if not line: 78 | break 79 | g_imgs.append(line.rstrip()) 80 | 81 | fid = open(save_path, "a+") 82 | for ind, sid in zip(rank_list, same_id): 83 | fid.write(g_imgs[ind] + '\t') 84 | fid.write(str(sid) + '\t') 85 | fid.write('\n') 86 | fid.close() 87 | 88 | def save_results(distmat, q_pids, g_pids, q_camids, g_camids): 89 | #save_path = "data/mars/mars_results_graphsage_new.txt" 90 | #query_file = "data/mars/mars_query.txt" 91 | #gallery_file = "data/mars/mars_gallery.txt" 92 | 93 | save_path = "data/prid2011/prid_results_graphsage_part.txt" 94 | query_file = "data/prid2011/prid_query.txt" 95 | gallery_file = "data/prid2011/prid_gallery.txt" 96 | 97 | #save_path = "data/ilids-vid/ilids_results_graphsage_part.txt" 98 | #query_file = "data/ilids-vid/ilids_query.txt" 99 | #gallery_file = "data/ilids-vid/ilids_gallery.txt" 100 | 101 | with open(save_path, "a+") as fid: 102 | fid.write("\n") 103 | fid.write("++++++++++++++++++++++++++++++++++++++++++\n") 104 | 105 | q_imgs = [] 106 | with open(query_file, "r") as fid: 107 | while True: 108 | line = fid.readline() 109 | if not line: 110 | break 111 | q_imgs.append(line.rstrip()) 112 | 113 | for i in range(q_pids.shape[0]): 114 | rank_list, same_id = get_rank_list( 115 | distmat[i], q_pids[i], q_camids[i], g_pids, g_camids, 10) 116 | with open(save_path, "a+") as fid: 117 | fid.write(q_imgs[i] + ":\t") 118 | save_rank_list_to_file(rank_list, same_id, gallery_file, save_path) 119 | 120 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import models 7 | from models import hypnn 8 | 9 | """ 10 | Shorthands for loss: 11 | - CrossEntropyLabelSmooth: xent 12 | - TripletLoss: htri 13 | - CenterLoss: cent 14 | """ 15 | __all__ = ['CrossEntropyLabelSmooth', 'TripletLoss', 'CenterLoss'] 16 | 17 | class CrossEntropyLabelSmooth(nn.Module): 18 | """Cross entropy loss with label smoothing regularizer. 19 | 20 | Reference: 21 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 22 | Equation: y = (1 - epsilon) * y + epsilon / K. 23 | 24 | Args: 25 | num_classes (int): number of classes. 26 | epsilon (float): weight. 
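
    Worked example (illustrative addition, not part of the original docstring): with num_classes K = 4 and epsilon = 0.1, a one-hot target [0, 1, 0, 0] is smoothed to (1 - 0.1) * [0, 1, 0, 0] + 0.1 / 4 = [0.025, 0.925, 0.025, 0.025], and the loss is the cross entropy between this smoothed distribution and the log-softmax of the logits:

        >>> import torch
        >>> criterion = CrossEntropyLabelSmooth(num_classes=4, epsilon=0.1, use_gpu=False)
        >>> logits = torch.randn(4, 4)            # (batch_size, num_classes) raw scores
        >>> targets = torch.tensor([0, 1, 2, 3])  # (batch_size,) integer class labels
        >>> loss = criterion(logits, targets)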
27 | """ 28 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 29 | super(CrossEntropyLabelSmooth, self).__init__() 30 | self.num_classes = num_classes 31 | self.epsilon = epsilon 32 | self.use_gpu = use_gpu 33 | self.logsoftmax = nn.LogSoftmax(dim=1) 34 | 35 | def forward(self, inputs, targets): 36 | """ 37 | Args: 38 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 39 | targets: ground truth labels with shape (num_classes) 40 | """ 41 | log_probs = self.logsoftmax(inputs) 42 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 43 | if self.use_gpu: targets = targets.cuda() 44 | targets = Variable(targets, requires_grad=False) 45 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 46 | loss = (- targets * log_probs).mean(0).sum() 47 | return loss 48 | 49 | class TripletLoss(nn.Module): 50 | """Triplet loss with hard positive/negative mining. 51 | 52 | Reference: 53 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 54 | 55 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 56 | 57 | Args: 58 | margin (float): margin for triplet. 59 | """ 60 | def __init__(self, margin=0.3): 61 | super(TripletLoss, self).__init__() 62 | self.margin = margin 63 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 64 | 65 | def forward(self, inputs, targets): 66 | """ 67 | Args: 68 | inputs: feature matrix with shape (batch_size, feat_dim) 69 | targets: ground truth labels with shape (num_classes) 70 | """ 71 | n = inputs.size(0) 72 | # Compute pairwise distance, replace by the official when merged 73 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 74 | dist = dist + dist.t() 75 | dist.addmm_(1, -2, inputs, inputs.t()) 76 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 77 | # For each anchor, find the hardest positive and negative 78 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 79 | dist_ap, dist_an = [], [] 80 | for i in range(n): 81 | dist_ap.append(dist[i][mask[i]].max()) 82 | dist_an.append(dist[i][mask[i] == 0].min()) 83 | #dist_ap = torch.cat(dist_ap) 84 | #dist_an = torch.cat(dist_an) 85 | dist_ap = torch.stack(dist_ap, dim=0) 86 | dist_an = torch.stack(dist_an, dim=0) 87 | # Compute ranking hinge loss 88 | y = dist_an.data.new() 89 | y.resize_as_(dist_an.data) 90 | y.fill_(1) 91 | y = Variable(y) 92 | loss = self.ranking_loss(dist_an, dist_ap, y) 93 | return loss 94 | 95 | class CenterLoss(nn.Module): 96 | """Center loss. 97 | 98 | Reference: 99 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 100 | 101 | Args: 102 | num_classes (int): number of classes. 103 | feat_dim (int): feature dimension. 104 | """ 105 | def __init__(self, num_classes=10, feat_dim=2, use_gpu=True): 106 | super(CenterLoss, self).__init__() 107 | self.num_classes = num_classes 108 | self.feat_dim = feat_dim 109 | self.use_gpu = use_gpu 110 | 111 | if self.use_gpu: 112 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 113 | else: 114 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 115 | 116 | def forward(self, x, labels): 117 | """ 118 | Args: 119 | x: feature matrix with shape (batch_size, feat_dim). 120 | labels: ground truth labels with shape (num_classes). 
121 | """ 122 | batch_size = x.size(0) 123 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 124 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 125 | distmat.addmm_(1, -2, x, self.centers.t()) 126 | 127 | classes = torch.arange(self.num_classes).long() 128 | if self.use_gpu: classes = classes.cuda() 129 | classes = Variable(classes) 130 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 131 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 132 | 133 | dist = [] 134 | for i in range(batch_size): 135 | value = distmat[i][mask[i]] 136 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 137 | dist.append(value) 138 | dist = torch.cat(dist) 139 | loss = dist.mean() 140 | 141 | return loss 142 | 143 | class PoincareTripletLoss(nn.Module): 144 | """Triplet loss with hard positive/negative mining. 145 | 146 | Reference: 147 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 148 | 149 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 150 | 151 | Args: 152 | margin (float): margin for triplet. 153 | """ 154 | def __init__(self, margin=0.3): 155 | super(PoincareTripletLoss, self).__init__() 156 | self.margin = margin 157 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 158 | #self.epsilon = torch.from_numpy(np.eye(n)*1e-8) 159 | 160 | 161 | def forward(self, inputs, targets): 162 | """ 163 | Args: 164 | inputs: feature matrix with shape (batch_size, feat_dim) 165 | targets: ground truth labels with shape (num_classes) 166 | """ 167 | #print(inputs) 168 | n = inputs.size(0) 169 | #print("inputs", inputs.shape) 170 | 171 | dist_mat = hypnn.dist_batch(inputs, inputs, c=1e-4) 172 | #print(dist_mat.shape) 173 | #print(dist_mat) 174 | 175 | # For each anchor, find the hardest positive and negative 176 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 177 | dist_ap, dist_an = [], [] 178 | for i in range(n): 179 | dist_ap.append(dist_mat[i][mask[i]].max()) 180 | dist_an.append(dist_mat[i][mask[i] == 0].min()) 181 | #dist_ap = torch.cat(dist_ap) 182 | #dist_an = torch.cat(dist_an) 183 | dist_ap = torch.stack(dist_ap, dim=0) 184 | dist_an = torch.stack(dist_an, dim=0) 185 | # Compute ranking hinge loss 186 | y = dist_an.data.new() 187 | y.resize_as_(dist_an.data) 188 | y.fill_(1) 189 | y = Variable(y) 190 | loss = self.ranking_loss(dist_an, dist_ap, y) 191 | return loss 192 | 193 | 194 | if __name__ == '__main__': 195 | pass 196 | -------------------------------------------------------------------------------- /main_video_person_reid_hypergraphsage_part.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import sys 4 | import time 5 | import datetime 6 | import argparse 7 | import os.path as osp 8 | import numpy as np 9 | from torch.nn import functional as F 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.backends.cudnn as cudnn 14 | from torch.utils.data import DataLoader 15 | from torch.autograd import Variable 16 | from torch.optim import lr_scheduler 17 | 18 | import data_manager 19 | from video_loader import VideoDataset 20 | import transforms as T 21 | import models 22 | from models import resnet3d 23 | from losses import CrossEntropyLabelSmooth, TripletLoss 24 | from utils import AverageMeter, Logger, save_checkpoint, build_adj, 
build_adj_full, build_adj_full_full, WarmupMultiStepLR 25 | from eval_metrics import evaluate 26 | from samplers import RandomIdentitySampler 27 | 28 | parser = argparse.ArgumentParser(description='Train video model with cross entropy loss') 29 | # Datasets 30 | parser.add_argument('-d', '--dataset', type=str, default='mars', 31 | choices=data_manager.get_names()) 32 | parser.add_argument('-j', '--workers', default=4, type=int, 33 | help="number of data loading workers (default: 4)") 34 | parser.add_argument('--height', type=int, default=224, 35 | help="height of an image (default: 224)") 36 | parser.add_argument('--width', type=int, default=112, 37 | help="width of an image (default: 112)") 38 | parser.add_argument('--seq-len', type=int, default=4, help="number of images to sample in a tracklet") 39 | # Optimization options 40 | parser.add_argument('--max-epoch', default=800, type=int, 41 | help="maximum epochs to run") 42 | parser.add_argument('--start-epoch', default=0, type=int, 43 | help="manual epoch number (useful on restarts)") 44 | parser.add_argument('--train-batch', default=32, type=int, 45 | help="train batch size") 46 | parser.add_argument('--test-batch', default=1, type=int, help="has to be 1") 47 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float, 48 | help="initial learning rate, use 0.0001 for rnn, use 0.0003 for pooling and attention") 49 | parser.add_argument('--stepsize', default=200, type=int, 50 | help="stepsize to decay learning rate (>0 means this is enabled)") 51 | parser.add_argument('--gamma', default=0.1, type=float, 52 | help="learning rate decay") 53 | parser.add_argument('--weight-decay', default=5e-04, type=float, 54 | help="weight decay (default: 5e-04)") 55 | parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss") 56 | parser.add_argument('--num-instances', type=int, default=4, 57 | help="number of instances per identity") 58 | parser.add_argument('--htri-only', action='store_true', default=False, 59 | help="if this is True, only htri loss is used in training") 60 | parser.add_argument('--xent-only', action='store_true', default=False, 61 | help="if this is True, only xent loss is used in training") 62 | # Architecture 63 | parser.add_argument('-a', '--arch', type=str, default='resnet50tp', help="resnet503d, resnet50tp, resnet50ta, resnetrnn") 64 | parser.add_argument('--pool', type=str, default='avg', choices=['avg', 'max']) 65 | 66 | # Miscs 67 | parser.add_argument('--print-freq', type=int, default=80, help="print frequency") 68 | parser.add_argument('--seed', type=int, default=1, help="manual seed") 69 | parser.add_argument('--pretrained-model', type=str, default='/home/jiyang/Workspace/Works/video-person-reid/3dconv-person-reid/pretrained_models/resnet-50-kinetics.pth', help='need to be set for resnet3d models') 70 | parser.add_argument('--evaluate', action='store_true', help="evaluation only") 71 | parser.add_argument('--eval-step', type=int, default=50, 72 | help="run evaluation for every N epochs (set to -1 to test after training)") 73 | parser.add_argument('--save-dir', type=str, default='log') 74 | parser.add_argument('--use-cpu', action='store_true', help="use cpu") 75 | parser.add_argument('--gpu-devices', default='0,1', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') 76 | 77 | parser.add_argument('--dropout', default=0.1, type=float, help='dropout ratio for GAT') 78 | parser.add_argument('--nhid', default=512, type=int, help='hidden dimension of GAT') 79 | 
parser.add_argument('--nheads', default=8, type=int, help='number of attention heads for GAT') 80 | parser.add_argument('--concat', default=False, type=bool, help='') 81 | parser.add_argument('--part1', default=4, type=int, help='') 82 | parser.add_argument('--part2', default=8, type=int, help='') 83 | parser.add_argument('--part3', default=2, type=int, help='') 84 | parser.add_argument('--warmup', action='store_true', help='use warming up scheduler') 85 | 86 | args = parser.parse_args() 87 | 88 | def main(): 89 | torch.manual_seed(args.seed) 90 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 91 | use_gpu = torch.cuda.is_available() 92 | if args.use_cpu: use_gpu = False 93 | 94 | if not args.evaluate: 95 | sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt')) 96 | else: 97 | sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt')) 98 | print("==========\nArgs:{}\n==========".format(args)) 99 | 100 | if use_gpu: 101 | print("Currently using GPU {}".format(args.gpu_devices)) 102 | cudnn.benchmark = True 103 | torch.cuda.manual_seed_all(args.seed) 104 | else: 105 | print("Currently using CPU (GPU is highly recommended)") 106 | 107 | print("Initializing dataset {}".format(args.dataset)) 108 | dataset = data_manager.init_dataset(name=args.dataset) 109 | 110 | transform_train = T.Compose([ 111 | T.Random2DTranslation(args.height, args.width), 112 | T.RandomHorizontalFlip(), 113 | T.ToTensor(), 114 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 115 | T.RandomErasing(probability = 0.5, mean=[0.0, 0.0, 0.0]), 116 | ]) 117 | 118 | transform_test = T.Compose([ 119 | T.Resize((args.height, args.width)), 120 | T.ToTensor(), 121 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 122 | ]) 123 | 124 | pin_memory = True if use_gpu else False 125 | 126 | if args.xent_only: 127 | trainloader = DataLoader( 128 | VideoDataset(dataset.train, seq_len=args.seq_len, sample='random',transform=transform_train), 129 | batch_size=args.train_batch, shuffle=True, num_workers=args.workers, 130 | pin_memory=pin_memory, drop_last=True, 131 | ) 132 | else: 133 | trainloader = DataLoader( 134 | VideoDataset(dataset.train, seq_len=args.seq_len, sample='random',transform=transform_train), 135 | sampler=RandomIdentitySampler(dataset.train, num_instances=args.num_instances), 136 | batch_size=args.train_batch, num_workers=args.workers, 137 | pin_memory=pin_memory, drop_last=True, 138 | ) 139 | 140 | queryloader = DataLoader( 141 | VideoDataset(dataset.query, seq_len=args.seq_len, sample='dense', transform=transform_test), 142 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 143 | pin_memory=pin_memory, drop_last=False, 144 | ) 145 | 146 | galleryloader = DataLoader( 147 | VideoDataset(dataset.gallery, seq_len=args.seq_len, sample='dense', transform=transform_test), 148 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 149 | pin_memory=pin_memory, drop_last=False, 150 | ) 151 | 152 | print("Initializing model: {}".format(args.arch)) 153 | if args.arch=='resnet503d': 154 | model = resnet3d.resnet50(num_classes=dataset.num_train_pids, sample_width=args.width, sample_height=args.height, sample_duration=args.seq_len) 155 | if not os.path.exists(args.pretrained_model): 156 | raise IOError("Can't find pretrained model: {}".format(args.pretrained_model)) 157 | print("Loading checkpoint from '{}'".format(args.pretrained_model)) 158 | checkpoint = torch.load(args.pretrained_model) 159 | state_dict = {} 160 | for key in 
checkpoint['state_dict']: 161 | if 'fc' in key: continue 162 | state_dict[key.partition("module.")[2]] = checkpoint['state_dict'][key] 163 | model.load_state_dict(state_dict, strict=False) 164 | else: 165 | #model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, dropout=args.dropout, nhid=args.nhid, nheads=args.nheads, concat=args.concat, loss={'xent', 'htri'}) 166 | model = models.init_model(name=args.arch, pool_size=8, input_shape=2048, n_classes=dataset.num_train_pids, loss={'xent', 'htri'}) 167 | print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0)) 168 | 169 | if os.path.exists(args.pretrained_model): 170 | print("Loading checkpoint from '{}'".format(args.pretrained_model)) 171 | checkpoint = torch.load(args.pretrained_model) 172 | model_dict = model.state_dict() 173 | pretrain_dict = checkpoint['state_dict'] 174 | pretrain_dict = {k:v for k, v in pretrain_dict.items() if k in model_dict} 175 | model_dict.update(pretrain_dict) 176 | model.load_state_dict(model_dict) 177 | 178 | criterion_xent = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu) 179 | criterion_htri = TripletLoss(margin=args.margin) 180 | 181 | #optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 182 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay) 183 | if args.stepsize > 0: 184 | if args.warmup: 185 | scheduler = WarmupMultiStepLR(optimizer, [200, 400, 600]) 186 | else: 187 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma) 188 | start_epoch = args.start_epoch 189 | 190 | if use_gpu: 191 | model = nn.DataParallel(model).cuda() 192 | 193 | if args.evaluate: 194 | print("Evaluate only") 195 | test(model, queryloader, galleryloader, args.pool, use_gpu) 196 | return 197 | 198 | start_time = time.time() 199 | best_rank1 = -np.inf 200 | if args.arch=='resnet503d': 201 | torch.backends.cudnn.benchmark = False 202 | 203 | ''' 204 | adj1 = build_adj_full_full(4, args.part1) 205 | adj2 = build_adj_full_full(4, args.part2) 206 | adj3 = build_adj_full_full(4, args.part3) 207 | if use_gpu: 208 | adj1 = adj1.cuda() 209 | adj2 = adj2.cuda() 210 | adj2 = adj2.cuda() 211 | adj1 = Variable(adj1) 212 | adj2 = Variable(adj2) 213 | adj3 = Variable(adj3) 214 | ''' 215 | 216 | for epoch in range(start_epoch, args.max_epoch): 217 | print("==> Epoch {}/{} lr:{}".format(epoch+1, args.max_epoch, scheduler.get_lr()[0])) 218 | 219 | #train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, adj1, adj2, adj3) 220 | train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu) 221 | 222 | if args.stepsize > 0: scheduler.step() 223 | 224 | if args.eval_step > 0 and (epoch+1) % args.eval_step == 0 or (epoch+1) == args.max_epoch: 225 | print("==> Test") 226 | #rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu, adj1, adj2, adj3) 227 | rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu) 228 | is_best = rank1 > best_rank1 229 | if is_best: best_rank1 = rank1 230 | 231 | if use_gpu: 232 | state_dict = model.module.state_dict() 233 | else: 234 | state_dict = model.state_dict() 235 | 236 | save_checkpoint({ 237 | 'state_dict': state_dict, 238 | 'rank1': rank1, 239 | 'epoch': epoch, 240 | }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch+1) + '.pth.tar')) 241 | 242 | 243 | elapsed = round(time.time() - start_time) 244 | elapsed = 
str(datetime.timedelta(seconds=elapsed)) 245 | print("Finished. Total elapsed time (h:m:s): {}".format(elapsed)) 246 | 247 | #def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, adj1, adj2, adj3): 248 | def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu): 249 | model.train() 250 | losses = AverageMeter() 251 | #print(adj) 252 | for batch_idx, (imgs, pids, _) in enumerate(trainloader): 253 | if use_gpu: 254 | imgs, pids= imgs.cuda(), pids.cuda() 255 | imgs, pids= Variable(imgs), Variable(pids) 256 | #outputs, features = model(imgs, adj1, adj2, adj3) 257 | outputs, features = model(imgs) 258 | if args.htri_only: 259 | # only use hard triplet loss to train the network 260 | loss = criterion_htri(features, pids) 261 | elif args.xent_only: 262 | loss = criterion_xent(outputs, pids) 263 | else: 264 | # combine hard triplet loss with cross entropy loss 265 | #xent_loss = torch.sum(torch.cat([criterion_xent(logits, pids) for logits in outputs])) 266 | xent_loss = torch.sum(torch.stack([criterion_xent(logits, pids) for logits in outputs], dim=0)) 267 | #htri_loss = torch.sum(torch.stack([criterion_htri(feats, pids) for feats in features])) 268 | #xent_loss = criterion_xent(outputs, pids) 269 | htri_loss = criterion_htri(features, pids) 270 | loss = xent_loss + htri_loss 271 | optimizer.zero_grad() 272 | loss.backward() 273 | optimizer.step() 274 | #losses.update(loss.data[0].item(), pids.size(0)) 275 | losses.update(loss.data.item(), pids.size(0)) 276 | 277 | if (batch_idx+1) % args.print_freq == 0: 278 | print("Batch {}/{}\t Loss {:.6f} ({:.6f})".format(batch_idx+1, len(trainloader), losses.val, losses.avg)) 279 | 280 | #def test(model, queryloader, galleryloader, pool, use_gpu, adj1, adj2, adj3, ranks=[1, 5, 10, 20]): 281 | def test(model, queryloader, galleryloader, pool, use_gpu, ranks=[1, 5, 10, 20]): 282 | model.eval() 283 | 284 | qf, q_pids, q_camids = [], [], [] 285 | with torch.no_grad(): 286 | for batch_idx, (imgs, pids, camids) in enumerate(queryloader): 287 | if use_gpu: 288 | imgs = imgs.cuda() 289 | imgs = Variable(imgs) 290 | # b=1, n=number of clips, s=16 291 | b, n, s, c, h, w = imgs.size() 292 | assert(b==1) 293 | imgs = imgs.view(b*n, s, c, h, w) 294 | #features = model(imgs, adj1, adj2, adj3) 295 | features = model(imgs) 296 | features = features.view(n, -1) 297 | features = torch.mean(features, 0) 298 | features = features.data.cpu() 299 | qf.append(features) 300 | q_pids.extend(pids) 301 | q_camids.extend(camids) 302 | qf = torch.stack(qf) 303 | q_pids = np.asarray(q_pids) 304 | q_camids = np.asarray(q_camids) 305 | 306 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1))) 307 | 308 | gf, g_pids, g_camids = [], [], [] 309 | with torch.no_grad(): 310 | for batch_idx, (imgs, pids, camids) in enumerate(galleryloader): 311 | if use_gpu: 312 | imgs = imgs.cuda() 313 | imgs = Variable(imgs) 314 | b, n, s, c, h, w = imgs.size() 315 | imgs = imgs.view(b*n, s , c, h, w) 316 | assert(b==1) 317 | #features = model(imgs, adj1, adj2, adj3) 318 | features = model(imgs) 319 | features = features.view(n, -1) 320 | if pool == 'avg': 321 | features = torch.mean(features, 0) 322 | else: 323 | features, _ = torch.max(features, 0) 324 | features = features.data.cpu() 325 | gf.append(features) 326 | g_pids.extend(pids) 327 | g_camids.extend(camids) 328 | gf = torch.stack(gf) 329 | g_pids = np.asarray(g_pids) 330 | g_camids = np.asarray(g_camids) 331 | 332 | print("Extracted features for 
gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1))) 333 | print("Computing distance matrix") 334 | 335 | m, n = qf.size(0), gf.size(0) 336 | # cosine distance 337 | qf = F.normalize(qf, dim=1, p=2) 338 | gf = F.normalize(gf, dim=1, p=2) 339 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 340 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 341 | distmat.addmm_(1, -2, qf, gf.t()) 342 | distmat = distmat.numpy() 343 | 344 | print("Computing CMC and mAP") 345 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) 346 | 347 | print("Results ----------") 348 | print("mAP: {:.1%}".format(mAP)) 349 | print("CMC curve") 350 | for r in ranks: 351 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1])) 352 | print("------------------") 353 | 354 | return cmc[0] 355 | 356 | if __name__ == '__main__': 357 | main() 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | -------------------------------------------------------------------------------- /models/ResNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Variable 7 | import torchvision 8 | from .convlstm import ConvLSTM 9 | import math 10 | __all__ = ['ResNet50TP', 'ResNet50TPICA', 'ResNet50TA', 'ResNet50RNN', 'ResNet50CONVRNN', 'ResNet50GRU', 'ResNet50TPNEW', 'ResNet50TPPART'] 11 | from .resnet import ResNet, BasicBlock, Bottleneck, ResNetNonLocal 12 | 13 | class ResNet50TP(nn.Module): 14 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 15 | super(ResNet50TP, self).__init__() 16 | self.loss = loss 17 | resnet50 = torchvision.models.resnet50(pretrained=True) 18 | #self.base = nn.Sequential(*list(resnet50.children())[:-2]) 19 | #===== res50 with stride = 1 ================== 20 | #self.base = ResNet(last_stride=1, 21 | # block=Bottleneck, 22 | # layers=[3, 4, 6, 3]) 23 | self.base = ResNetNonLocal(last_stride=1, 24 | block=Bottleneck, 25 | layers=[3, 4, 6, 3]) 26 | self.base.load_param('/home/yy1/.torch/models/resnet50-19c8e357.pth') 27 | #print(self.base.state_dict()['layer1.2.conv1.weight']) 28 | 29 | self.feat_dim = 2048 30 | self.classifier = nn.Linear(self.feat_dim, num_classes) 31 | 32 | def forward(self, x): 33 | b = x.size(0) 34 | t = x.size(1) 35 | x = x.view(b*t,x.size(2), x.size(3), x.size(4)) 36 | x = self.base(x) 37 | x = F.avg_pool2d(x, x.size()[2:]) 38 | x = x.view(b,t,-1) 39 | x=x.permute(0,2,1) 40 | f = F.avg_pool1d(x,t) 41 | f = f.view(b, self.feat_dim) 42 | if not self.training: 43 | return f 44 | y = self.classifier(f) 45 | 46 | if self.loss == {'xent'}: 47 | return y 48 | elif self.loss == {'xent', 'htri'}: 49 | return y, f 50 | elif self.loss == {'cent'}: 51 | return y, f 52 | else: 53 | raise KeyError("Unsupported loss: {}".format(self.loss)) 54 | 55 | class ResNet50TPICA(nn.Module): 56 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 57 | super(ResNet50TPICA, self).__init__() 58 | self.loss = loss 59 | resnet50 = torchvision.models.resnet50(pretrained=True) 60 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 61 | self.feat_dim = 2048 62 | self.hidden_dim = self.feat_dim/4 63 | self.out_dim = self.feat_dim/2 64 | 65 | #======================= 66 | self.batchnorm1 = nn.BatchNorm2d(self.feat_dim) 67 | self.lrelu = nn.LeakyReLU(0.1) 68 | self.conv1 = nn.Conv2d(self.feat_dim, self.hidden_dim, 1) 69 | self.batchnorm2 = nn.BatchNorm2d(self.hidden_dim) 
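Because test() L2-normalizes the query and gallery features before the squared-Euclidean expansion, the distance matrix it builds is an affine function of cosine similarity: for unit vectors, ||q - g||^2 = 2 - 2 q·g. Below is a small self-contained sketch of the same computation without the in-place addmm_, assuming a recent PyTorch; cosine_distmat is an illustrative name, not a function from the repository.

```python
import torch
import torch.nn.functional as F

def cosine_distmat(qf, gf):
    # Squared Euclidean distance between L2-normalized rows,
    # which equals 2 - 2 * cosine similarity.
    qf = F.normalize(qf, p=2, dim=1)
    gf = F.normalize(gf, p=2, dim=1)
    m, n = qf.size(0), gf.size(0)
    return (qf.pow(2).sum(1, keepdim=True).expand(m, n)
            + gf.pow(2).sum(1, keepdim=True).expand(n, m).t()
            - 2 * qf @ gf.t())

if __name__ == '__main__':
    q, g = torch.randn(5, 2048), torch.randn(7, 2048)
    d = cosine_distmat(q, g)
    cos = F.normalize(q, dim=1) @ F.normalize(g, dim=1).t()
    print(torch.allclose(d, 2 - 2 * cos, atol=1e-5))   # True up to float error
```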
70 | self.fc = nn.Linear(self.hidden_dim, self.out_dim) 71 | #======================= 72 | self.classifier = nn.Linear(self.out_dim, num_classes) 73 | 74 | def forward(self, x): 75 | b = x.size(0) 76 | t = x.size(1) 77 | x = x.view(b*t,x.size(2), x.size(3), x.size(4)) 78 | x = self.base(x) 79 | x = self.batchnorm1(x) 80 | x = self.lrelu(x) 81 | x = F.avg_pool2d(x, x.size()[2:]) 82 | x = x.view(b,t,-1) 83 | x=x.permute(0,2,1) 84 | f = F.avg_pool1d(x,t) 85 | f = f.view(b, self.feat_dim) 86 | f = f.unsqueeze(2) 87 | f = f.unsqueeze(3) 88 | f = self.conv1(f) 89 | f = self.batchnorm2(f) 90 | f = f.view(b, self.hidden_dim) 91 | f = self.fc(f) 92 | 93 | if not self.training: 94 | return f 95 | y = self.classifier(f) 96 | 97 | if self.loss == {'xent'}: 98 | return y 99 | elif self.loss == {'xent', 'htri'}: 100 | return y, f 101 | elif self.loss == {'cent'}: 102 | return y, f 103 | else: 104 | raise KeyError("Unsupported loss: {}".format(self.loss)) 105 | 106 | class ResNet50TPPART(nn.Module): 107 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 108 | super(ResNet50TPPART, self).__init__() 109 | self.loss = loss 110 | resnet50 = torchvision.models.resnet50(pretrained=True) 111 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 112 | self.feat_dim = 2048 113 | self.classifier = nn.Linear(self.feat_dim, num_classes) 114 | self.p = 4. 115 | 116 | def forward(self, x): 117 | b = x.size(0) 118 | t = x.size(1) 119 | x = x.view(b*t,x.size(2), x.size(3), x.size(4)) 120 | x = self.base(x) 121 | #print(x.shape) 122 | 123 | #----------------- 124 | ''' 125 | x1 = F.avg_pool2d(x, x.size()[2:]) 126 | x1 = x1.view(b,t,-1) 127 | x1=x1.permute(0,2,1) 128 | f1 = F.avg_pool1d(x1,t) 129 | f1 = f1.view(b, self.feat_dim) 130 | ''' 131 | #----------------- 132 | 133 | #x = F.avg_pool2d(x, (int(math.ceil(x.size(-2)/self.p)), x.size(-1)), ceil_mode=True) # 128, 2048, 4, 1 134 | x = F.avg_pool2d(x, (2, x.size(-1))) 135 | 136 | #====================== 137 | x = x.permute(0,2,1,3) 138 | x = x.contiguous().view(b, t, int(self.p), -1) 139 | x = x.view(b, t*int(self.p), -1) 140 | x = x.permute(0, 2, 1) 141 | f = F.avg_pool1d(x, t*int(self.p)) 142 | f = f.view(b, self.feat_dim) 143 | #====================== 144 | #print(x.shape) 145 | #x = x.view(b, t, x.size(1), x.size(2), -1) 146 | #x = x.permute(0, 2, 1) 147 | 148 | #x = x.view(b,int(t*self.p),-1) 149 | #print(x.shape) 150 | #f = x.permute(0, 2, 1) 151 | #f = F.avg_pool1d(f, int(t*self.p)) 152 | #f = f.view(b, self.feat_dim) 153 | #print(f-f1) 154 | 155 | if not self.training: 156 | return f 157 | y = self.classifier(f) 158 | 159 | if self.loss == {'xent'}: 160 | return y 161 | elif self.loss == {'xent', 'htri'}: 162 | return y, f 163 | elif self.loss == {'cent'}: 164 | return y, f 165 | else: 166 | raise KeyError("Unsupported loss: {}".format(self.loss)) 167 | 168 | def weights_init_kaiming(m): 169 | classname = m.__class__.__name__ 170 | if classname.find('Linear') != -1: 171 | nn.init.kaiming_normal(m.weight, a=0, mode='fan_out') 172 | nn.init.constant_(m.bias, 0.0) 173 | elif classname.find('Conv') != -1: 174 | nn.init.kaiming_normal(m.weight, a=0, mode='fan_in') 175 | if m.bias is not None: 176 | nn.init.constant_(m.bias, 0.0) 177 | elif classname.find('BatchNorm') != -1: 178 | if m.affine: 179 | nn.init.constant(m.weight, 1.0) 180 | nn.init.constant(m.bias, 0.0) 181 | 182 | 183 | def weights_init_classifier(m): 184 | classname = m.__class__.__name__ 185 | if classname.find('Linear') != -1: 186 | nn.init.normal(m.weight, std=0.001) 187 | if m.bias: 
188 | nn.init.constant(m.bias, 0.0) 189 | 190 | class ResNet50TPNEW(nn.Module): 191 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 192 | super(ResNet50TPNEW, self).__init__() 193 | self.loss = loss 194 | resnet50 = torchvision.models.resnet50(pretrained=True) 195 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 196 | self.feat_dim = 2048 197 | self.classifier = nn.Linear(self.feat_dim, num_classes, bias=False) 198 | 199 | self.bottleneck = nn.BatchNorm1d(self.feat_dim) 200 | self.bottleneck.bias.requires_grad = False 201 | 202 | self.bottleneck.apply(weights_init_kaiming) 203 | self.classifier.apply(weights_init_classifier) 204 | 205 | def forward(self, x): 206 | b = x.size(0) 207 | t = x.size(1) 208 | x = x.view(b*t,x.size(2), x.size(3), x.size(4)) 209 | x = self.base(x) 210 | x = F.avg_pool2d(x, x.size()[2:]) 211 | x = x.view(x.size(0), -1) 212 | x = self.bottleneck(x) 213 | x = x.view(b,t,-1) 214 | x=x.permute(0,2,1) 215 | f = F.avg_pool1d(x,t) 216 | f = f.view(b, self.feat_dim) 217 | if not self.training: 218 | return f 219 | y = self.classifier(f) 220 | 221 | if self.loss == {'xent'}: 222 | return y 223 | elif self.loss == {'xent', 'htri'}: 224 | return y, f 225 | elif self.loss == {'cent'}: 226 | return y, f 227 | else: 228 | raise KeyError("Unsupported loss: {}".format(self.loss)) 229 | 230 | 231 | class ResNet50TA(nn.Module): 232 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 233 | super(ResNet50TA, self).__init__() 234 | self.loss = loss 235 | resnet50 = torchvision.models.resnet50(pretrained=True) 236 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 237 | self.att_gen = 'softmax' # method for attention generation: softmax or sigmoid 238 | self.feat_dim = 2048 # feature dimension 239 | self.middle_dim = 256 # middle layer dimension 240 | self.classifier = nn.Linear(self.feat_dim, num_classes) 241 | self.attention_conv = nn.Conv2d(self.feat_dim, self.middle_dim, [7,4]) # 7,4 cooresponds to 224, 112 input image size 242 | self.attention_tconv = nn.Conv1d(self.middle_dim, 1, 3, padding=1) 243 | def forward(self, x): 244 | b = x.size(0) 245 | t = x.size(1) 246 | x = x.view(b*t, x.size(2), x.size(3), x.size(4)) 247 | x = self.base(x) 248 | a = F.relu(self.attention_conv(x)) 249 | a = a.view(b, t, self.middle_dim) 250 | a = a.permute(0,2,1) 251 | a = F.relu(self.attention_tconv(a)) 252 | a = a.view(b, t) 253 | x = F.avg_pool2d(x, x.size()[2:]) 254 | if self. 
att_gen=='softmax': 255 | a = F.softmax(a, dim=1) 256 | elif self.att_gen=='sigmoid': 257 | a = F.sigmoid(a) 258 | a = F.normalize(a, p=1, dim=1) 259 | else: 260 | raise KeyError("Unsupported attention generation function: {}".format(self.att_gen)) 261 | x = x.view(b, t, -1) 262 | a = torch.unsqueeze(a, -1) 263 | a = a.expand(b, t, self.feat_dim) 264 | att_x = torch.mul(x,a) 265 | att_x = torch.sum(att_x,1) 266 | 267 | f = att_x.view(b,self.feat_dim) 268 | if not self.training: 269 | return f 270 | y = self.classifier(f) 271 | 272 | if self.loss == {'xent'}: 273 | return y 274 | elif self.loss == {'xent', 'htri'}: 275 | return y, f 276 | elif self.loss == {'cent'}: 277 | return y, f 278 | else: 279 | raise KeyError("Unsupported loss: {}".format(self.loss)) 280 | 281 | 282 | class ResNet50RNN(nn.Module): 283 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 284 | super(ResNet50RNN, self).__init__() 285 | self.loss = loss 286 | resnet50 = torchvision.models.resnet50(pretrained=True) 287 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 288 | self.hidden_dim = 512 289 | self.feat_dim = 2048 290 | self.classifier = nn.Linear(self.hidden_dim, num_classes) 291 | self.lstm = nn.LSTM(input_size=self.feat_dim, hidden_size=self.hidden_dim, num_layers=1, batch_first=False) 292 | def forward(self, x): 293 | b = x.size(0) 294 | t = x.size(1) 295 | x = x.view(b*t,x.size(2), x.size(3), x.size(4)) 296 | x = self.base(x) 297 | x = F.avg_pool2d(x, x.size()[2:]) 298 | x = x.view(b,t,-1) 299 | x = x.permute(1, 0, 2) 300 | output, (h_n, c_n) = self.lstm(x) 301 | #print(output.shape) 302 | #output = output.permute(0, 2, 1) 303 | output = output.permute(1, 2, 0) 304 | #print(output.shape) 305 | f = F.avg_pool1d(output, t) 306 | #print(f.shape) 307 | f = f.view(b, self.hidden_dim) 308 | #print(f.shape) 309 | ''' 310 | torch.Size([32, 4, 512]) 311 | torch.Size([32, 512, 4]) 312 | torch.Size([32, 512, 1]) 313 | torch.Size([32, 512]) 314 | ''' 315 | if not self.training: 316 | return f 317 | y = self.classifier(f) 318 | 319 | if self.loss == {'xent'}: 320 | return y 321 | elif self.loss == {'xent', 'htri'}: 322 | return y, f 323 | elif self.loss == {'cent'}: 324 | return y, f 325 | else: 326 | raise KeyError("Unsupported loss: {}".format(self.loss)) 327 | 328 | class ResNet50GRU(nn.Module): 329 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 330 | super(ResNet50GRU, self).__init__() 331 | self.loss = loss 332 | resnet50 = torchvision.models.resnet50(pretrained=True) 333 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 334 | self.hidden_dim = 512 335 | self.feat_dim = 2048 336 | self.classifier = nn.Linear(self.hidden_dim, num_classes) 337 | self.lstm = nn.GRU(input_size=self.feat_dim, hidden_size=self.hidden_dim, num_layers=1, batch_first=False) 338 | def forward(self, x): 339 | b = x.size(0) 340 | t = x.size(1) 341 | x = x.view(b*t,x.size(2), x.size(3), x.size(4)) 342 | x = self.base(x) 343 | x = F.avg_pool2d(x, x.size()[2:]) 344 | x = x.view(b,t,-1) 345 | x = x.permute(1, 0, 2) 346 | output, h_n = self.lstm(x) 347 | #print(output.shape) 348 | #output = output.permute(0, 2, 1) 349 | output = output.permute(1, 2, 0) 350 | #print(output.shape) 351 | f = F.avg_pool1d(output, t) 352 | #print(f.shape) 353 | f = f.view(b, self.hidden_dim) 354 | #print(f.shape) 355 | ''' 356 | torch.Size([32, 4, 512]) 357 | torch.Size([32, 512, 4]) 358 | torch.Size([32, 512, 1]) 359 | torch.Size([32, 512]) 360 | ''' 361 | if not self.training: 362 | return f 363 | y = self.classifier(f) 364 | 
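ResNet50TA above reduces each frame to a scalar score with a spatial convolution followed by a temporal convolution, normalizes the scores (softmax, or sigmoid plus L1 normalization), and takes the score-weighted sum of the per-frame features. A toy sketch of just that weighting step, assuming the softmax branch and random per-frame features; the helper name is made up.

```python
import torch
import torch.nn.functional as F

def temporal_attention_pool(frame_feats, scores):
    # frame_feats: (b, t, d) per-frame features, scores: (b, t) raw frame scores.
    a = F.softmax(scores, dim=1)                         # one normalized weight per frame
    return (frame_feats * a.unsqueeze(-1)).sum(dim=1)    # (b, d) clip-level feature

if __name__ == '__main__':
    feats = torch.randn(32, 4, 2048)     # batch of 32 clips, 4 frames each
    scores = torch.randn(32, 4)
    print(temporal_attention_pool(feats, scores).shape)  # torch.Size([32, 2048])
```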
365 | if self.loss == {'xent'}: 366 | return y 367 | elif self.loss == {'xent', 'htri'}: 368 | return y, f 369 | elif self.loss == {'cent'}: 370 | return y, f 371 | else: 372 | raise KeyError("Unsupported loss: {}".format(self.loss)) 373 | 374 | class ResNet50CONVRNN(nn.Module): 375 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 376 | super(ResNet50CONVRNN, self).__init__() 377 | self.loss = loss 378 | resnet50 = torchvision.models.resnet50(pretrained=True) 379 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 380 | self.hidden_dim = 512 381 | self.input_size = 7 382 | self.feat_dim = 2048 383 | self.kernel_size = 3 384 | self.classifier = nn.Linear(self.hidden_dim, num_classes) 385 | self.lstm = ConvLSTM(input_size=(self.input_size, int(self.input_size/2 + 1)), input_dim=self.feat_dim, hidden_dim=self.hidden_dim, kernel_size=(self.kernel_size, self.kernel_size), num_layers=1, batch_first=False) 386 | 387 | def forward(self, x): 388 | b = x.size(0) 389 | t = x.size(1) 390 | x = x.view(b*t,x.size(2), x.size(3), x.size(4)) 391 | x = self.base(x) 392 | #print(x.shape) 393 | #x = F.avg_pool2d(x, x.size()[2:]) 394 | #x = x.view(b,t,-1) 395 | #x = x.permute(1, 0, 2) 396 | x = x.view(b, t, x.size(1), x.size(2), x.size(3)) 397 | x = x.permute(1, 0, 2, 3, 4) 398 | 399 | output, last_state = self.lstm(x) 400 | #print(output.shape) 401 | output = output[0] 402 | #print(output.shape) 403 | #output = output.permute(0, 2, 1) 404 | #output = output.permute(1, 2, 0) 405 | output = output.view(t*b, output.size(2), output.size(3), output.size(4)) 406 | output = F.avg_pool2d(output, output.size()[2:]) 407 | output = output.view(t, b, output.size(1)) 408 | output = output.permute(1, 2, 0) 409 | #print(output.shape) 410 | f = F.avg_pool1d(output, t) 411 | #print(f.shape) 412 | f = f.view(b, self.hidden_dim) 413 | #print(f.shape) 414 | ''' 415 | torch.Size([128, 2048, 7, 4]) 416 | torch.Size([4, 32, 512, 7, 4]) 417 | torch.Size([32, 512, 4]) 418 | torch.Size([32, 512, 1]) 419 | torch.Size([32, 512]) 420 | ''' 421 | if not self.training: 422 | return f 423 | y = self.classifier(f) 424 | 425 | if self.loss == {'xent'}: 426 | return y 427 | elif self.loss == {'xent', 'htri'}: 428 | return y, f 429 | elif self.loss == {'cent'}: 430 | return y, f 431 | else: 432 | raise KeyError("Unsupported loss: {}".format(self.loss)) 433 | 434 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .ResNet import * 4 | from .ResNet_hypergraphsage_part import ResNet50GRAPHPOOLPARTHyper 5 | 6 | __factory = { 7 | 'resnet50graphpoolparthyper': ResNet50GRAPHPOOLPARTHyper, 8 | } 9 | 10 | def get_names(): 11 | return __factory.keys() 12 | 13 | def init_model(name, *args, **kwargs): 14 | if name not in __factory.keys(): 15 | raise KeyError("Unknown model: {}".format(name)) 16 | return __factory[name](*args, **kwargs) 17 | -------------------------------------------------------------------------------- /models/__pycache__/ResNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_dart.cpython-37.pyc: 
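models/__init__.py above exposes a small factory: get_names() lists the registered architectures (only 'resnet50graphpoolparthyper' here) and init_model() looks the name up in __factory and forwards the remaining arguments to the constructor. A hypothetical usage sketch mirroring the call in main_video_person_reid_hypergraphsage_part.py; it assumes the repository is on the Python path, and the class count 625 is just an example value.

```python
# Illustrative only: run from the repository root; 625 is a made-up class count.
import models

print(list(models.get_names()))   # ['resnet50graphpoolparthyper']
net = models.init_model(name='resnet50graphpoolparthyper',
                        pool_size=8, input_shape=2048,
                        n_classes=625, loss={'xent', 'htri'})
```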
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_dart.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_dart_search.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_dart_search.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_dynamichypergraphsage.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_dynamichypergraphsage.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_graph.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_graph.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_graphsage.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_graphsage.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_graphsage_bpm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_graphsage_bpm.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_graphsage_part.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_graphsage_part.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_graphsage_part_new.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_graphsage_part_new.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_graphsage_part_new_att.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_graphsage_part_new_att.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_hypergraphsage.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_hypergraphsage.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet_hypergraphsage_part.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/ResNet_hypergraphsage_part.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/convlstm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/convlstm.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/hypnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/hypnn.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/non_local_embedded_gaussian.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/non_local_embedded_gaussian.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/resnet3d.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daodaofr/hypergraph_reid/7485b89667b5e9f5de95d50f9f5a21b2fcb28045/models/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /models/architect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | def _concat(xs): 8 | return torch.cat([x.view(-1) for x in xs]) 9 | 10 | 11 | def _clip(grads, max_norm): 12 | total_norm = 0 13 | for g in grads: 14 | param_norm = g.data.norm(2) 15 | total_norm += param_norm ** 2 16 | total_norm = 
total_norm ** 0.5 17 | clip_coef = max_norm / (total_norm + 1e-6) 18 | if clip_coef < 1: 19 | for g in grads: 20 | g.data.mul_(clip_coef) 21 | return clip_coef 22 | 23 | 24 | class Architect(object): 25 | 26 | def __init__(self, model, args): 27 | self.network_weight_decay = args.wdecay 28 | self.network_clip = args.clip 29 | self.model = model 30 | self.optimizer = torch.optim.Adam(self.model.arch_parameters(), lr=args.arch_lr, weight_decay=args.arch_wdecay) 31 | 32 | def _compute_unrolled_model(self, input, target, eta): 33 | loss = self.model._loss(input, target) 34 | theta = _concat(self.model.parameters()).data 35 | grads = torch.autograd.grad(loss, self.model.parameters()) 36 | clip_coef = _clip(grads, self.network_clip) 37 | dtheta = _concat(grads).data + self.network_weight_decay*theta 38 | unrolled_model = self._construct_model_from_theta(theta.sub(eta, dtheta)) 39 | return unrolled_model, clip_coef 40 | 41 | def step(self, 42 | input_train, target_train, 43 | input_valid, target_valid, 44 | network_optimizer, unrolled): 45 | eta = network_optimizer.param_groups[0]['lr'] 46 | self.optimizer.zero_grad() 47 | if unrolled: 48 | self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta) 49 | else: 50 | self._backward_step(hidden_valid, input_valid, target_valid) 51 | self.optimizer.step() 52 | return 53 | 54 | def _backward_step(self, hidden, input, target): 55 | loss, hidden_next = self.model._loss( input, target) 56 | loss.backward() 57 | return hidden_next 58 | 59 | def _backward_step_unrolled(self, 60 | input_train, target_train, 61 | input_valid, target_valid, eta): 62 | unrolled_model, clip_coef = self._compute_unrolled_model(input_train, target_train, eta) 63 | unrolled_loss = unrolled_model._loss(input_valid, target_valid) 64 | 65 | unrolled_loss.backward() 66 | dalpha = [v.grad for v in unrolled_model.arch_parameters()] 67 | dtheta = [v.grad for v in unrolled_model.parameters()] 68 | _clip(dtheta, self.network_clip) 69 | vector = [dt.data for dt in dtheta] 70 | implicit_grads = self._hessian_vector_product(vector, input_train, target_train, r=1e-2) 71 | 72 | for g, ig in zip(dalpha, implicit_grads): 73 | g.data.sub_(eta * clip_coef, ig.data) 74 | 75 | for v, g in zip(self.model.arch_parameters(), dalpha): 76 | if v.grad is None: 77 | v.grad = Variable(g.data) 78 | else: 79 | v.grad.data.copy_(g.data) 80 | return 81 | 82 | def _construct_model_from_theta(self, theta): 83 | model_new = self.model.new() 84 | model_dict = self.model.state_dict() 85 | 86 | params, offset = {}, 0 87 | for k, v in self.model.named_parameters(): 88 | v_length = np.prod(v.size()) 89 | params[k] = theta[offset: offset+v_length].view(v.size()) 90 | offset += v_length 91 | 92 | assert offset == len(theta) 93 | model_dict.update(params) 94 | model_new.load_state_dict(model_dict) 95 | return model_new.cuda() 96 | 97 | def _hessian_vector_product(self, vector, input, target, r=1e-2): 98 | R = r / _concat(vector).norm() 99 | for p, v in zip(self.model.parameters(), vector): 100 | p.data.add_(R, v) 101 | loss = self.model._loss(input, target) 102 | grads_p = torch.autograd.grad(loss, self.model.arch_parameters()) 103 | 104 | for p, v in zip(self.model.parameters(), vector): 105 | p.data.sub_(2*R, v) 106 | loss = self.model._loss(input, target) 107 | grads_n = torch.autograd.grad(loss, self.model.arch_parameters()) 108 | 109 | for p, v in zip(self.model.parameters(), vector): 110 | p.data.add_(R, v) 111 | 112 | return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)] 113 | 114 | 
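Architect._hessian_vector_product above approximates the second-order term of the unrolled (bilevel) gradient with a central finite difference: perturb the weights by +R*v and -R*v, take the gradient of the loss with respect to the architecture parameters at both points, and divide the difference by 2R. The self-contained toy check below verifies that identity on a small quadratic objective; it illustrates the trick and is not code from the repository.

```python
import torch

# Toy objective L(w, a) = sum((w * a)**2); its mixed second derivative is
# diagonal, so the Hessian-vector product has a closed form to compare against.
torch.manual_seed(0)
w = torch.randn(5)
a = torch.randn(5, requires_grad=True)
v = torch.randn(5)                     # the vector multiplying the Hessian
r = 1e-2
R = r / v.norm()                       # same scaling rule as in the class above

def grad_wrt_a(weights):
    loss = (weights * a).pow(2).sum()
    return torch.autograd.grad(loss, a)[0]

# Central finite difference over the weights, as in _hessian_vector_product.
hvp = (grad_wrt_a(w + R * v) - grad_wrt_a(w - R * v)) / (2 * R)

# Analytic value: dL/da = 2 * w**2 * a, so applying d/dw of that to v gives
# 4 * w * a * v elementwise.
print(torch.allclose(hvp, 4 * w * a.detach() * v, atol=1e-4))
```

On this quadratic objective the central difference is exact, so the check passes up to floating-point error.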
-------------------------------------------------------------------------------- /models/convlstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch 4 | 5 | 6 | class ConvLSTMCell(nn.Module): 7 | 8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): 9 | """ 10 | Initialize ConvLSTM cell. 11 | 12 | Parameters 13 | ---------- 14 | input_size: (int, int) 15 | Height and width of input tensor as (height, width). 16 | input_dim: int 17 | Number of channels of input tensor. 18 | hidden_dim: int 19 | Number of channels of hidden state. 20 | kernel_size: (int, int) 21 | Size of the convolutional kernel. 22 | bias: bool 23 | Whether or not to add the bias. 24 | """ 25 | 26 | super(ConvLSTMCell, self).__init__() 27 | 28 | self.height, self.width = input_size 29 | self.input_dim = input_dim 30 | self.hidden_dim = hidden_dim 31 | 32 | self.kernel_size = kernel_size 33 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 34 | self.bias = bias 35 | 36 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 37 | out_channels=4 * self.hidden_dim, 38 | kernel_size=self.kernel_size, 39 | padding=self.padding, 40 | bias=self.bias) 41 | 42 | def forward(self, input_tensor, cur_state): 43 | 44 | h_cur, c_cur = cur_state 45 | 46 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 47 | 48 | combined_conv = self.conv(combined) 49 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 50 | i = torch.sigmoid(cc_i) 51 | f = torch.sigmoid(cc_f) 52 | o = torch.sigmoid(cc_o) 53 | g = torch.tanh(cc_g) 54 | 55 | c_next = f * c_cur + i * g 56 | h_next = o * torch.tanh(c_next) 57 | 58 | return h_next, c_next 59 | 60 | def init_hidden(self, batch_size): 61 | return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda(), 62 | Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda()) 63 | 64 | 65 | class ConvLSTM(nn.Module): 66 | 67 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers, 68 | batch_first=False, bias=True, return_all_layers=False): 69 | super(ConvLSTM, self).__init__() 70 | 71 | self._check_kernel_size_consistency(kernel_size) 72 | 73 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 74 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 75 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 76 | if not len(kernel_size) == len(hidden_dim) == num_layers: 77 | raise ValueError('Inconsistent list length.') 78 | 79 | self.height, self.width = input_size 80 | 81 | self.input_dim = input_dim 82 | self.hidden_dim = hidden_dim 83 | self.kernel_size = kernel_size 84 | self.num_layers = num_layers 85 | self.batch_first = batch_first 86 | self.bias = bias 87 | self.return_all_layers = return_all_layers 88 | 89 | cell_list = [] 90 | for i in range(0, self.num_layers): 91 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1] 92 | 93 | cell_list.append(ConvLSTMCell(input_size=(self.height, self.width), 94 | input_dim=cur_input_dim, 95 | hidden_dim=self.hidden_dim[i], 96 | kernel_size=self.kernel_size[i], 97 | bias=self.bias)) 98 | 99 | self.cell_list = nn.ModuleList(cell_list) 100 | 101 | def forward(self, input_tensor, hidden_state=None): 102 | """ 103 | 104 | Parameters 105 | ---------- 106 | input_tensor: todo 107 | 5-D Tensor either of shape 
(t, b, c, h, w) or (b, t, c, h, w) 108 | hidden_state: todo 109 | None. todo implement stateful 110 | 111 | Returns 112 | ------- 113 | last_state_list, layer_output 114 | """ 115 | if not self.batch_first: 116 | # (t, b, c, h, w) -> (b, t, c, h, w) 117 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 118 | 119 | # Implement stateful ConvLSTM 120 | if hidden_state is not None: 121 | raise NotImplementedError() 122 | else: 123 | hidden_state = self._init_hidden(batch_size=input_tensor.size(0)) 124 | 125 | layer_output_list = [] 126 | last_state_list = [] 127 | 128 | seq_len = input_tensor.size(1) 129 | cur_layer_input = input_tensor 130 | 131 | for layer_idx in range(self.num_layers): 132 | 133 | h, c = hidden_state[layer_idx] 134 | output_inner = [] 135 | for t in range(seq_len): 136 | 137 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], 138 | cur_state=[h, c]) 139 | output_inner.append(h) 140 | 141 | layer_output = torch.stack(output_inner, dim=1) 142 | cur_layer_input = layer_output 143 | 144 | layer_output_list.append(layer_output) 145 | last_state_list.append([h, c]) 146 | 147 | if not self.return_all_layers: 148 | layer_output_list = layer_output_list[-1:] 149 | last_state_list = last_state_list[-1:] 150 | 151 | return layer_output_list, last_state_list 152 | 153 | def _init_hidden(self, batch_size): 154 | init_states = [] 155 | for i in range(self.num_layers): 156 | init_states.append(self.cell_list[i].init_hidden(batch_size)) 157 | return init_states 158 | 159 | @staticmethod 160 | def _check_kernel_size_consistency(kernel_size): 161 | if not (isinstance(kernel_size, tuple) or 162 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 163 | raise ValueError('`kernel_size` must be tuple or list of tuples') 164 | 165 | @staticmethod 166 | def _extend_for_multilayer(param, num_layers): 167 | if not isinstance(param, list): 168 | param = [param] * num_layers 169 | return param 170 | -------------------------------------------------------------------------------- /models/hypnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | 7 | import numpy as np 8 | from scipy.special import gamma 9 | 10 | def tanh(x, clamp=15): 11 | return x.clamp(-clamp, clamp).tanh() 12 | 13 | 14 | class Artanh(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, x): 17 | x = x.clamp(-1 + 1e-5, 1 - 1e-5) 18 | ctx.save_for_backward(x) 19 | res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5) 20 | return res 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | input, = ctx.saved_tensors 25 | return grad_output / (1 - input ** 2) 26 | 27 | 28 | class Arsinh(torch.autograd.Function): 29 | @staticmethod 30 | def forward(ctx, x): 31 | ctx.save_for_backward(x) 32 | return (x + torch.sqrt_(1 + x.pow(2))).clamp_min_(1e-5).log_() 33 | 34 | @staticmethod 35 | def backward(ctx, grad_output): 36 | input, = ctx.saved_tensors 37 | return grad_output / (1 + input ** 2) ** 0.5 38 | 39 | 40 | def artanh(x): 41 | return Artanh.apply(x) 42 | 43 | 44 | def arsinh(x): 45 | return Arsinh.apply(x) 46 | 47 | 48 | def arcosh(x, eps=1e-5): # pragma: no cover 49 | x = x.clamp(-1 + eps, 1 - eps) 50 | return torch.log(x + torch.sqrt(1 + x) * torch.sqrt(x - 1)) 51 | 52 | 53 | def project(x, *, c=1.0): 54 | r""" 55 | Safe projection on the manifold for numerical stability. 
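ConvLSTMCell above implements the standard LSTM update with convolutions: the input and the previous hidden state are concatenated along channels, a single convolution produces 4 * hidden_dim channels, and the result is split into input, forget and output gates plus a candidate. A shape-level sketch of one cell step with made-up dimensions, assuming a recent PyTorch (the repository cell additionally wraps its states in Variable and places them on CUDA).

```python
import torch
import torch.nn as nn

# One ConvLSTM step at the tensor level (sketch; dimensions are illustrative).
b, c_in, c_hid, h, w = 2, 8, 16, 7, 4
conv = nn.Conv2d(c_in + c_hid, 4 * c_hid, kernel_size=3, padding=1)

x = torch.randn(b, c_in, h, w)          # input at one time step
h_cur = torch.zeros(b, c_hid, h, w)     # previous hidden state
c_cur = torch.zeros(b, c_hid, h, w)     # previous cell state

gates = conv(torch.cat([x, h_cur], dim=1))
i, f, o, g = torch.split(gates, c_hid, dim=1)
c_next = torch.sigmoid(f) * c_cur + torch.sigmoid(i) * torch.tanh(g)
h_next = torch.sigmoid(o) * torch.tanh(c_next)
print(h_next.shape)                      # torch.Size([2, 16, 7, 4])
```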
This was mentioned in [1]_ 56 | Parameters 57 | ---------- 58 | x : tensor 59 | point on the Poincare ball 60 | c : float|tensor 61 | ball negative curvature 62 | Returns 63 | ------- 64 | tensor 65 | projected vector on the manifold 66 | References 67 | ---------- 68 | .. [1] Hyperbolic Neural Networks, NIPS2018 69 | https://arxiv.org/abs/1805.09112 70 | """ 71 | c = torch.as_tensor(c).type_as(x) 72 | return _project(x, c) 73 | 74 | 75 | def _project(x, c): 76 | norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5) 77 | maxnorm = (1 - 1e-3) / (c ** 0.5) 78 | cond = norm > maxnorm 79 | projected = x / norm * maxnorm 80 | return torch.where(cond, projected, x) 81 | 82 | 83 | def lambda_x(x, *, c=1.0, keepdim=False): 84 | r""" 85 | Compute the conformal factor :math:`\lambda^c_x` for a point on the ball 86 | .. math:: 87 | \lambda^c_x = \frac{1}{1 - c \|x\|_2^2} 88 | Parameters 89 | ---------- 90 | x : tensor 91 | point on the Poincare ball 92 | c : float|tensor 93 | ball negative curvature 94 | keepdim : bool 95 | retain the last dim? (default: false) 96 | Returns 97 | ------- 98 | tensor 99 | conformal factor 100 | """ 101 | c = torch.as_tensor(c).type_as(x) 102 | return _lambda_x(x, c, keepdim=keepdim) 103 | 104 | 105 | def _lambda_x(x, c, keepdim: bool = False): 106 | return 2 / (1 - c * x.pow(2).sum(-1, keepdim=keepdim)) 107 | 108 | 109 | def mobius_add(x, y, *, c=1.0): 110 | r""" 111 | Mobius addition is a special operation in a hyperbolic space. 112 | .. math:: 113 | x \oplus_c y = \frac{ 114 | (1 + 2 c \langle x, y\rangle + c \|y\|^2_2) x + (1 - c \|x\|_2^2) y 115 | }{ 116 | 1 + 2 c \langle x, y\rangle + c^2 \|x\|^2_2 \|y\|^2_2 117 | } 118 | In general this operation is not commutative: 119 | .. math:: 120 | x \oplus_c y \ne y \oplus_c x 121 | But in some cases this property holds: 122 | * zero vector case 123 | .. math:: 124 | \mathbf{0} \oplus_c x = x \oplus_c \mathbf{0} 125 | * zero negative curvature case that is same as Euclidean addition 126 | .. math:: 127 | x \oplus_0 y = y \oplus_0 x 128 | Another usefull property is so called left-cancellation law: 129 | .. math:: 130 | (-x) \oplus_c (x \oplus_c y) = y 131 | Parameters 132 | ---------- 133 | x : tensor 134 | point on the Poincare ball 135 | y : tensor 136 | point on the Poincare ball 137 | c : float|tensor 138 | ball negative curvature 139 | Returns 140 | ------- 141 | tensor 142 | the result of mobius addition 143 | """ 144 | c = torch.as_tensor(c).type_as(x) 145 | return _mobius_add(x, y, c) 146 | 147 | 148 | def _mobius_add(x, y, c): 149 | x2 = x.pow(2).sum(dim=-1, keepdim=True) 150 | y2 = y.pow(2).sum(dim=-1, keepdim=True) 151 | xy = (x * y).sum(dim=-1, keepdim=True) 152 | num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y 153 | denom = 1 + 2 * c * xy + c ** 2 * x2 * y2 154 | return num / (denom + 1e-5) 155 | 156 | 157 | def dist(x, y, *, c=1.0, keepdim=False): 158 | r""" 159 | Distance on the Poincare ball 160 | .. math:: 161 | d_c(x, y) = \frac{2}{\sqrt{c}}\tanh^{-1}(\sqrt{c}\|(-x)\oplus_c y\|_2) 162 | .. plot:: plots/extended/poincare/distance.py 163 | Parameters 164 | ---------- 165 | x : tensor 166 | point on poincare ball 167 | y : tensor 168 | point on poincare ball 169 | c : float|tensor 170 | ball negative curvature 171 | keepdim : bool 172 | retain the last dim? 
(default: false) 173 | Returns 174 | ------- 175 | tensor 176 | geodesic distance between :math:`x` and :math:`y` 177 | """ 178 | c = torch.as_tensor(c).type_as(x) 179 | #print("x", x.shape) 180 | #print("y", y.shape) 181 | return _dist(x, y, c, keepdim=keepdim) 182 | 183 | def dist_batch(x, y, *, c=1.0, keepdim=False): 184 | r""" 185 | Distance on the Poincare ball 186 | .. math:: 187 | d_c(x, y) = \frac{2}{\sqrt{c}}\tanh^{-1}(\sqrt{c}\|(-x)\oplus_c y\|_2) 188 | .. plot:: plots/extended/poincare/distance.py 189 | Parameters 190 | ---------- 191 | x : tensor 192 | point on poincare ball 193 | y : tensor 194 | point on poincare ball 195 | c : float|tensor 196 | ball negative curvature 197 | keepdim : bool 198 | retain the last dim? (default: false) 199 | Returns 200 | ------- 201 | tensor 202 | geodesic distance between :math:`x` and :math:`y` 203 | """ 204 | c = torch.as_tensor(c).type_as(x) 205 | #print("x", x.shape) 206 | #print("y", y.shape) 207 | return _dist_batch(x, y, c, keepdim=keepdim) 208 | 209 | def _dist(x, y, c, keepdim: bool = False): 210 | sqrt_c = c ** 0.5 211 | tmp = _mobius_add(-x, y, c) 212 | #print("tmp",tmp.shape) 213 | tmp = tmp.norm(dim=-1, p=2, keepdim=keepdim) 214 | #print("tmp", tmp.shape) 215 | dist_c = artanh(sqrt_c * _mobius_add(-x, y, c).norm(dim=-1, p=2, keepdim=keepdim)) 216 | #print(dist_c) 217 | return dist_c * 2 / sqrt_c 218 | 219 | def _dist_batch(x, y, c, keepdim: bool = False): 220 | sqrt_c = c ** 0.5 221 | tmp = _mobius_addition_batch(-x, y, c) 222 | #print("tmp",tmp.shape) 223 | tmp = tmp.norm(dim=-1, p=2, keepdim=keepdim) 224 | #print("tmp", tmp.shape) 225 | dist_c = artanh(sqrt_c * _mobius_addition_batch(-x, y, c).norm(dim=-1, p=2, keepdim=keepdim)) 226 | #print(dist_c) 227 | return dist_c * 2 / sqrt_c 228 | 229 | def dist0(x, *, c=1.0, keepdim=False): 230 | r""" 231 | Distance on the Poincare ball to zero 232 | Parameters 233 | ---------- 234 | x : tensor 235 | point on poincare ball 236 | c : float|tensor 237 | ball negative curvature 238 | keepdim : bool 239 | retain the last dim? (default: false) 240 | Returns 241 | ------- 242 | tensor 243 | geodesic distance between :math:`x` and :math:`0` 244 | """ 245 | c = torch.as_tensor(c).type_as(x) 246 | return _dist0(x, c, keepdim=keepdim) 247 | 248 | 249 | def _dist0(x, c, keepdim: bool = False): 250 | sqrt_c = c ** 0.5 251 | dist_c = artanh(sqrt_c * x.norm(dim=-1, p=2, keepdim=keepdim)) 252 | return dist_c * 2 / sqrt_c 253 | 254 | 255 | def expmap(x, u, *, c=1.0): 256 | r""" 257 | Exponential map for Poincare ball model. This is tightly related with :func:`geodesic`. 258 | Intuitively Exponential map is a smooth constant travelling from starting point :math:`x` with speed :math:`u`. 259 | A bit more formally this is travelling along curve :math:`\gamma_{x, u}(t)` such that 260 | .. math:: 261 | \gamma_{x, u}(0) = x\\ 262 | \dot\gamma_{x, u}(0) = u\\ 263 | \|\dot\gamma_{x, u}(t)\|_{\gamma_{x, u}(t)} = \|u\|_x 264 | The existence of this curve relies on uniqueness of differential equation solution, that is local. 265 | For the Poincare ball model the solution is well defined globally and we have. 266 | .. 
math:: 267 | \operatorname{Exp}^c_x(u) = \gamma_{x, u}(1) = \\ 268 | x\oplus_c \tanh(\sqrt{c}/2 \|u\|_x) \frac{u}{\sqrt{c}\|u\|_2} 269 | Parameters 270 | ---------- 271 | x : tensor 272 | starting point on poincare ball 273 | u : tensor 274 | speed vector on poincare ball 275 | c : float|tensor 276 | ball negative curvature 277 | Returns 278 | ------- 279 | tensor 280 | :math:`\gamma_{x, u}(1)` end point 281 | """ 282 | c = torch.as_tensor(c).type_as(x) 283 | return _expmap(x, u, c) 284 | 285 | 286 | def _expmap(x, u, c): # pragma: no cover 287 | sqrt_c = c ** 0.5 288 | u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5) 289 | second_term = ( 290 | tanh(sqrt_c / 2 * _lambda_x(x, c, keepdim=True) * u_norm) 291 | * u 292 | / (sqrt_c * u_norm) 293 | ) 294 | gamma_1 = _mobius_add(x, second_term, c) 295 | return gamma_1 296 | 297 | 298 | def expmap0(u, *, c=1.0): 299 | r""" 300 | Exponential map for Poincare ball model from :math:`0`. 301 | .. math:: 302 | \operatorname{Exp}^c_0(u) = \tanh(\sqrt{c}/2 \|u\|_2) \frac{u}{\sqrt{c}\|u\|_2} 303 | Parameters 304 | ---------- 305 | u : tensor 306 | speed vector on poincare ball 307 | c : float|tensor 308 | ball negative curvature 309 | Returns 310 | ------- 311 | tensor 312 | :math:`\gamma_{0, u}(1)` end point 313 | """ 314 | c = torch.as_tensor(c).type_as(u) 315 | return _expmap0(u, c) 316 | 317 | 318 | def _expmap0(u, c): 319 | sqrt_c = c ** 0.5 320 | u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5) 321 | gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm) 322 | return gamma_1 323 | 324 | 325 | def logmap(x, y, *, c=1.0): 326 | r""" 327 | Logarithmic map for two points :math:`x` and :math:`y` on the manifold. 328 | .. math:: 329 | \operatorname{Log}^c_x(y) = \frac{2}{\sqrt{c}\lambda_x^c} \tanh^{-1}( 330 | \sqrt{c} \|(-x)\oplus_c y\|_2 331 | ) * \frac{(-x)\oplus_c y}{\|(-x)\oplus_c y\|_2} 332 | The result of Logarithmic map is a vector such that 333 | .. math:: 334 | y = \operatorname{Exp}^c_x(\operatorname{Log}^c_x(y)) 335 | Parameters 336 | ---------- 337 | x : tensor 338 | starting point on poincare ball 339 | y : tensor 340 | target point on poincare ball 341 | c : float|tensor 342 | ball negative curvature 343 | Returns 344 | ------- 345 | tensor 346 | tangent vector that transports :math:`x` to :math:`y` 347 | """ 348 | c = torch.as_tensor(c).type_as(x) 349 | return _logmap(x, y, c) 350 | 351 | 352 | def _logmap(x, y, c): # pragma: no cover 353 | sub = _mobius_add(-x, y, c) 354 | sub_norm = sub.norm(dim=-1, p=2, keepdim=True) 355 | lam = _lambda_x(x, c, keepdim=True) 356 | sqrt_c = c ** 0.5 357 | return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm 358 | 359 | 360 | def logmap0(y, *, c=1.0): 361 | r""" 362 | Logarithmic map for :math:`y` from :math:`0` on the manifold. 363 | .. math:: 364 | \operatorname{Log}^c_0(y) = \tanh^{-1}(\sqrt{c}\|y\|_2) \frac{y}{\|y\|_2} 365 | The result is such that 366 | .. 
math:: 367 | y = \operatorname{Exp}^c_0(\operatorname{Log}^c_0(y)) 368 | Parameters 369 | ---------- 370 | y : tensor 371 | target point on poincare ball 372 | c : float|tensor 373 | ball negative curvature 374 | Returns 375 | ------- 376 | tensor 377 | tangent vector that transports :math:`0` to :math:`y` 378 | """ 379 | c = torch.as_tensor(c).type_as(y) 380 | return _logmap0(y, c) 381 | 382 | 383 | def _logmap0(y, c): 384 | sqrt_c = c ** 0.5 385 | y_norm = torch.clamp_min(y.norm(dim=-1, p=2, keepdim=True), 1e-5) 386 | return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm) 387 | 388 | 389 | def mobius_matvec(m, x, *, c=1.0): 390 | r""" 391 | Generalization for matrix-vector multiplication to hyperbolic space defined as 392 | .. math:: 393 | M \otimes_c x = (1/\sqrt{c}) \tanh\left( 394 | \frac{\|Mx\|_2}{\|x\|_2}\tanh^{-1}(\sqrt{c}\|x\|_2) 395 | \right)\frac{Mx}{\|Mx\|_2} 396 | Parameters 397 | ---------- 398 | m : tensor 399 | matrix for multiplication 400 | x : tensor 401 | point on poincare ball 402 | c : float|tensor 403 | negative ball curvature 404 | Returns 405 | ------- 406 | tensor 407 | Mobius matvec result 408 | """ 409 | c = torch.as_tensor(c).type_as(x) 410 | return _mobius_matvec(m, x, c) 411 | 412 | 413 | def _mobius_matvec(m, x, c): 414 | x_norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5) 415 | sqrt_c = c ** 0.5 416 | mx = x @ m.transpose(-1, -2) 417 | mx_norm = mx.norm(dim=-1, keepdim=True, p=2) 418 | res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c) 419 | cond = (mx == 0).prod(-1, keepdim=True, dtype=torch.uint8) 420 | res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device) 421 | res = torch.where(cond, res_0, res_c) 422 | return _project(res, c) 423 | 424 | 425 | def _tensor_dot(x, y): 426 | res = torch.einsum('ij,kj->ik', (x, y)) 427 | return res 428 | 429 | 430 | def _mobius_addition_batch(x, y, c): 431 | #print("x", x.shape) 432 | #print("y", y.shape) 433 | xy = _tensor_dot(x, y) # B x C 434 | #print("xy", xy.shape) 435 | x2 = x.pow(2).sum(-1, keepdim=True) # B x 1 436 | y2 = y.pow(2).sum(-1, keepdim=True) # C x 1 437 | num = (1 + 2 * c * xy + c * y2.permute(1, 0)) # B x C 438 | num = num.unsqueeze(2) * x.unsqueeze(1) 439 | num = num + (1 - c * x2).unsqueeze(2) * y # B x C x D 440 | #print("num", num.shape) 441 | denom_part1 = 1 + 2 * c * xy # B x C 442 | denom_part2 = c ** 2 * x2 * y2.permute(1, 0) 443 | denom = denom_part1 + denom_part2 444 | #print("denom", denom.shape) 445 | res = num / (denom.unsqueeze(2) + 1e-5) 446 | #print("res", res.shape) 447 | return res 448 | 449 | 450 | def _hyperbolic_softmax(X, A, P, c): 451 | lambda_pkc = 2 / (1 - c * P.pow(2).sum(dim=1)) 452 | k = lambda_pkc * torch.norm(A, dim=1) / torch.sqrt(c) 453 | mob_add = _mobius_addition_batch(-P, X, c) 454 | num = 2 * torch.sqrt(c) * torch.sum(mob_add * A.unsqueeze(1), dim=-1) 455 | denom = torch.norm(A, dim=1, keepdim=True) * (1 - c * mob_add.pow(2).sum(dim=2)) 456 | logit = k.unsqueeze(1) * arsinh(num / denom) 457 | return logit.permute(1, 0) 458 | 459 | 460 | def p2k(x, c): 461 | denom = 1 + c * x.pow(2).sum(-1, keepdim=True) 462 | return 2 * x / denom 463 | 464 | 465 | def k2p(x, c): 466 | denom = 1 + torch.sqrt(1 - c * x.pow(2).sum(-1, keepdim=True)) 467 | return x / denom 468 | 469 | 470 | def lorenz_factor(x, *, c=1.0, dim=-1, keepdim=False): 471 | """ 472 | 473 | Parameters 474 | ---------- 475 | x : tensor 476 | point on Klein disk 477 | c : float 478 | negative curvature 479 | dim : int 480 | dimension to calculate Lorenz factor 
481 | keepdim : bool 482 | retain the last dim? (default: false) 483 | 484 | Returns 485 | ------- 486 | tensor 487 | Lorenz factor 488 | """ 489 | return 1 / torch.sqrt(1 - c * x.pow(2).sum(dim=dim, keepdim=keepdim)) 490 | 491 | 492 | def poincare_mean(x, dim=0, c=1.0): 493 | x = p2k(x, c) 494 | lamb = lorenz_factor(x, c=c, keepdim=True) 495 | mean = torch.sum(lamb * x, dim=dim, keepdim=True) / torch.sum(lamb, dim=dim, keepdim=True) 496 | mean = k2p(mean, c) 497 | return mean.squeeze(dim) 498 | 499 | 500 | def _dist_matrix(x, y, c): 501 | sqrt_c = c ** 0.5 502 | return 2 / sqrt_c * artanh(sqrt_c * torch.norm(_mobius_addition_batch(-x, y, c=c), dim=-1)) 503 | 504 | 505 | def dist_matrix(x, y, c=1.0): 506 | c = torch.as_tensor(c).type_as(x) 507 | return _dist_matrix(x, y, c) 508 | 509 | 510 | def auto_select_c(d): 511 | """ 512 | calculates the radius of the Poincare ball, 513 | such that the d-dimensional ball has constant volume equal to pi 514 | """ 515 | dim2 = d / 2.0 516 | R = gamma(dim2 + 1) / (np.pi ** (dim2 - 1)) 517 | R = R ** (1 / float(d)) 518 | c = 1 / (R ** 2) 519 | return c 520 | 521 | 522 | 523 | 524 | class HyperbolicMLR(nn.Module): 525 | r""" 526 | Module which performs softmax classification 527 | in Hyperbolic space. 528 | """ 529 | def __init__(self, ball_dim, n_classes, c): 530 | super(HyperbolicMLR, self).__init__() 531 | self.a_vals = nn.Parameter(torch.Tensor(n_classes, ball_dim)) 532 | self.p_vals = nn.Parameter(torch.Tensor(n_classes, ball_dim)) 533 | self.c = c 534 | self.n_classes = n_classes 535 | self.ball_dim = ball_dim 536 | self.reset_parameters() 537 | 538 | def forward(self, x, c=None): 539 | if c is None: 540 | c = torch.as_tensor(self.c).type_as(x) 541 | else: 542 | c = torch.as_tensor(c).type_as(x) 543 | p_vals_poincare = expmap0(self.p_vals, c=c) 544 | conformal_factor = (1 - c * p_vals_poincare.pow(2).sum(dim=1, keepdim=True)) 545 | a_vals_poincare = self.a_vals * conformal_factor 546 | logits = _hyperbolic_softmax(x, a_vals_poincare, p_vals_poincare, c) 547 | return logits 548 | 549 | 550 | def extra_repr(self): 551 | return 'Poincare ball dim={}, n_classes={}, c={}'.format( 552 | self.ball_dim, self.n_classes, self.c 553 | ) 554 | 555 | 556 | def reset_parameters(self): 557 | init.kaiming_uniform_(self.a_vals, a=math.sqrt(5)) 558 | init.kaiming_uniform_(self.p_vals, a=math.sqrt(5)) 559 | 560 | 561 | class HypLinear(nn.Module): 562 | def __init__(self, in_features, out_features, c, bias=True): 563 | super(HypLinear, self).__init__() 564 | self.in_features = in_features 565 | self.out_features = out_features 566 | self.c = c 567 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 568 | if bias: 569 | self.bias = nn.Parameter(torch.Tensor(out_features)) 570 | else: 571 | self.register_parameter('bias', None) 572 | self.reset_parameters() 573 | 574 | 575 | def reset_parameters(self): 576 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 577 | if self.bias is not None: 578 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 579 | bound = 1 / math.sqrt(fan_in) 580 | init.uniform_(self.bias, -bound, bound) 581 | 582 | def forward(self, x, c=None): 583 | if c is None: 584 | c = self.c 585 | mv = mobius_matvec(self.weight, x, c=c) 586 | if self.bias is None: 587 | return project(mv, c=c) 588 | else: 589 | bias = expmap0(self.bias, c=c) 590 | return project(mobius_add(mv, bias), c=c) 591 | 592 | 593 | def extra_repr(self): 594 | return 'in_features={}, out_features={}, bias={}, c={}'.format( 595 | self.in_features, 
self.out_features, self.bias is not None, self.c 596 | ) 597 | 598 | 599 | class ConcatPoincareLayer(nn.Module): 600 | def __init__(self, d1, d2, d_out, c): 601 | super(ConcatPoincareLayer, self).__init__() 602 | self.d1 = d1 603 | self.d2 = d2 604 | self.d_out = d_out 605 | 606 | self.l1 = HypLinear(d1, d_out, bias=False, c=c) 607 | self.l2 = HypLinear(d2, d_out, bias=False, c=c) 608 | self.c = c 609 | 610 | def forward(self, x1, x2, c=None): 611 | if c is None: 612 | c = self.c 613 | return mobius_add(self.l1(x1), self.l2(x2), c=c) 614 | 615 | 616 | def extra_repr(self): 617 | return 'dims {} and {} ---> dim {}'.format( 618 | self.d1, self.d2, self.d_out 619 | ) 620 | 621 | 622 | class HyperbolicDistanceLayer(nn.Module): 623 | def __init__(self, c): 624 | super(HyperbolicDistanceLayer, self).__init__() 625 | self.c = c 626 | 627 | def forward(self, x1, x2, c=None): 628 | if c is None: 629 | c = self.c 630 | return dist(x1, x2, c=c, keepdim=True) 631 | 632 | def extra_repr(self): 633 | return 'c={}'.format(self.c) 634 | 635 | 636 | class ToPoincare(nn.Module): 637 | r""" 638 | Module which maps points in n-dim Euclidean space 639 | to n-dim Poincare ball 640 | """ 641 | def __init__(self, c, train_c=False, train_x=False, ball_dim=None): 642 | super(ToPoincare, self).__init__() 643 | if train_x: 644 | if ball_dim is None: 645 | raise ValueError("if train_x=True, ball_dim has to be integer, got {}".format(ball_dim)) 646 | self.xp = nn.Parameter(torch.zeros((ball_dim,))) 647 | else: 648 | self.register_parameter('xp', None) 649 | 650 | if train_c: 651 | self.c = nn.Parameter(torch.Tensor([c,])) 652 | else: 653 | self.c = c 654 | 655 | self.train_x = train_x 656 | 657 | def forward(self, x): 658 | if self.train_x: 659 | xp = project(expmap0(self.xp, c=self.c), c=self.c) 660 | return project(expmap(xp, x, c=self.c), c=self.c) 661 | return project(expmap0(x, c=self.c), c=self.c) 662 | 663 | def extra_repr(self): 664 | return 'c={}, train_x={}'.format(self.c, self.train_x) 665 | 666 | 667 | class FromPoincare(nn.Module): 668 | r""" 669 | Module which maps points in n-dim Poincare ball 670 | to n-dim Euclidean space 671 | """ 672 | def __init__(self, c, train_c=False, train_x=False, ball_dim=None): 673 | 674 | super(FromPoincare, self).__init__() 675 | 676 | if train_x: 677 | if ball_dim is None: 678 | raise ValueError("if train_x=True, ball_dim has to be integer, got {}".format(ball_dim)) 679 | self.xp = nn.Parameter(torch.zeros((ball_dim,))) 680 | else: 681 | self.register_parameter('xp', None) 682 | 683 | if train_c: 684 | self.c = nn.Parameter(torch.Tensor([c,])) 685 | else: 686 | self.c = c 687 | 688 | self.train_c = train_c 689 | self.train_x = train_x 690 | 691 | def forward(self, x): 692 | if self.train_x: 693 | xp = project(expmap0(self.xp, c=self.c), c=self.c) 694 | return logmap(xp, x, c=self.c) 695 | return logmap0(x, c=self.c) 696 | 697 | def extra_repr(self): 698 | return 'train_c={}, train_x={}'.format(self.train_c, self.train_x) 699 | 700 | 701 | 702 | -------------------------------------------------------------------------------- /models/network.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | # from lib.non_local_concatenation import NONLocalBlock2D 3 | # from lib.non_local_gaussian import NONLocalBlock2D 4 | from lib.non_local_embedded_gaussian import NONLocalBlock2D 5 | # from lib.non_local_dot_product import NONLocalBlock2D 6 | 7 | 8 | class Network(nn.Module): 9 | def __init__(self): 10 | super(Network, 
self).__init__() 11 | 12 | self.convs = nn.Sequential( 13 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), 14 | nn.BatchNorm2d(32), 15 | nn.ReLU(), 16 | nn.MaxPool2d(2), 17 | 18 | NONLocalBlock2D(in_channels=32), 19 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), 20 | nn.BatchNorm2d(64), 21 | nn.ReLU(), 22 | nn.MaxPool2d(2), 23 | 24 | NONLocalBlock2D(in_channels=64), 25 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 26 | nn.BatchNorm2d(128), 27 | nn.ReLU(), 28 | nn.MaxPool2d(2), 29 | ) 30 | 31 | self.fc = nn.Sequential( 32 | nn.Linear(in_features=128*3*3, out_features=256), 33 | nn.ReLU(), 34 | nn.Dropout(0.5), 35 | 36 | nn.Linear(in_features=256, out_features=10) 37 | ) 38 | 39 | def forward(self, x): 40 | batch_size = x.size(0) 41 | output = self.convs(x).view(batch_size, -1) 42 | output = self.fc(output) 43 | return output 44 | 45 | if __name__ == '__main__': 46 | import torch 47 | from torch.autograd import Variable 48 | 49 | img = Variable(torch.randn(3, 1, 28, 28)) 50 | net = Network() 51 | out = net(img) 52 | print(out.size()) 53 | 54 | -------------------------------------------------------------------------------- /models/non_local_concatenation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant(self.W[1].weight, 0) 46 | nn.init.constant(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant(self.W.weight, 0) 51 | nn.init.constant(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | self.concat_project = nn.Sequential( 60 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 61 | nn.ReLU() 62 | ) 63 | 64 | if sub_sample: 65 | self.g = nn.Sequential(self.g, max_pool_layer) 66 | self.phi = nn.Sequential(self.phi, max_pool_layer) 67 | 68 | def 
forward(self, x): 69 | ''' 70 | :param x: (b, c, t, h, w) 71 | :return: 72 | ''' 73 | 74 | batch_size = x.size(0) 75 | 76 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 77 | g_x = g_x.permute(0, 2, 1) 78 | 79 | # (b, c, N, 1) 80 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 81 | # (b, c, 1, N) 82 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 83 | 84 | h = theta_x.size(2) 85 | w = phi_x.size(3) 86 | theta_x = theta_x.repeat(1, 1, 1, w) 87 | phi_x = phi_x.repeat(1, 1, h, 1) 88 | 89 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 90 | f = self.concat_project(concat_feature) 91 | b, _, h, w = f.size() 92 | f = f.view(b, h, w) 93 | 94 | N = f.size(-1) 95 | f_div_C = f / N 96 | 97 | y = torch.matmul(f_div_C, g_x) 98 | y = y.permute(0, 2, 1).contiguous() 99 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 100 | W_y = self.W(y) 101 | z = W_y + x 102 | 103 | return z 104 | 105 | 106 | class NONLocalBlock1D(_NonLocalBlockND): 107 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 108 | super(NONLocalBlock1D, self).__init__(in_channels, 109 | inter_channels=inter_channels, 110 | dimension=1, sub_sample=sub_sample, 111 | bn_layer=bn_layer) 112 | 113 | 114 | class NONLocalBlock2D(_NonLocalBlockND): 115 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 116 | super(NONLocalBlock2D, self).__init__(in_channels, 117 | inter_channels=inter_channels, 118 | dimension=2, sub_sample=sub_sample, 119 | bn_layer=bn_layer) 120 | 121 | 122 | class NONLocalBlock3D(_NonLocalBlockND): 123 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 124 | super(NONLocalBlock3D, self).__init__(in_channels, 125 | inter_channels=inter_channels, 126 | dimension=3, sub_sample=sub_sample, 127 | bn_layer=bn_layer) 128 | 129 | 130 | if __name__ == '__main__': 131 | from torch.autograd import Variable 132 | import torch 133 | 134 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 135 | img = Variable(torch.zeros(2, 3, 20)) 136 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 137 | out = net(img) 138 | print(out.size()) 139 | 140 | img = Variable(torch.zeros(2, 3, 20, 20)) 141 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 142 | out = net(img) 143 | print(out.size()) 144 | 145 | img = Variable(torch.randn(2, 3, 8, 20, 20)) 146 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 147 | out = net(img) 148 | print(out.size()) 149 | -------------------------------------------------------------------------------- /models/non_local_dot_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = 
nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant(self.W[1].weight, 0) 46 | nn.init.constant(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant(self.W.weight, 0) 51 | nn.init.constant(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | if sub_sample: 60 | self.g = nn.Sequential(self.g, max_pool_layer) 61 | self.phi = nn.Sequential(self.phi, max_pool_layer) 62 | 63 | def forward(self, x): 64 | ''' 65 | :param x: (b, c, t, h, w) 66 | :return: 67 | ''' 68 | 69 | batch_size = x.size(0) 70 | 71 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 72 | g_x = g_x.permute(0, 2, 1) 73 | 74 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 75 | theta_x = theta_x.permute(0, 2, 1) 76 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 77 | f = torch.matmul(theta_x, phi_x) 78 | N = f.size(-1) 79 | f_div_C = f / N 80 | 81 | y = torch.matmul(f_div_C, g_x) 82 | y = y.permute(0, 2, 1).contiguous() 83 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 84 | W_y = self.W(y) 85 | z = W_y + x 86 | 87 | return z 88 | 89 | 90 | class NONLocalBlock1D(_NonLocalBlockND): 91 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 92 | super(NONLocalBlock1D, self).__init__(in_channels, 93 | inter_channels=inter_channels, 94 | dimension=1, sub_sample=sub_sample, 95 | bn_layer=bn_layer) 96 | 97 | 98 | class NONLocalBlock2D(_NonLocalBlockND): 99 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 100 | super(NONLocalBlock2D, self).__init__(in_channels, 101 | inter_channels=inter_channels, 102 | dimension=2, sub_sample=sub_sample, 103 | bn_layer=bn_layer) 104 | 105 | 106 | class NONLocalBlock3D(_NonLocalBlockND): 107 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 108 | super(NONLocalBlock3D, self).__init__(in_channels, 109 | inter_channels=inter_channels, 110 | dimension=3, sub_sample=sub_sample, 111 | bn_layer=bn_layer) 112 | 113 | 114 | if __name__ == '__main__': 115 | from torch.autograd import Variable 116 | import torch 117 | 118 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 119 | img = Variable(torch.zeros(2, 3, 20)) 120 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 121 | out = net(img) 122 | print(out.size()) 123 | 124 | img = Variable(torch.zeros(2, 3, 20, 20)) 125 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 126 | out = net(img) 127 | print(out.size()) 128 | 129 | img = Variable(torch.randn(2, 3, 8, 20, 20)) 130 | net = NONLocalBlock3D(3, 
sub_sample=sub_sample, bn_layer=bn_layer) 131 | out = net(img) 132 | print(out.size()) 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /models/non_local_embedded_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 56 | kernel_size=1, stride=1, padding=0) 57 | 58 | if sub_sample: 59 | self.g = nn.Sequential(self.g, max_pool_layer) 60 | self.phi = nn.Sequential(self.phi, max_pool_layer) 61 | 62 | def forward(self, x): 63 | ''' 64 | :param x: (b, c, t, h, w) 65 | :return: 66 | ''' 67 | 68 | batch_size = x.size(0) 69 | 70 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 71 | g_x = g_x.permute(0, 2, 1) 72 | 73 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 74 | theta_x = theta_x.permute(0, 2, 1) 75 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 76 | f = torch.matmul(theta_x, phi_x) 77 | f_div_C = F.softmax(f, dim=-1) 78 | 79 | y = torch.matmul(f_div_C, g_x) 80 | y = y.permute(0, 2, 1).contiguous() 81 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 82 | W_y = self.W(y) 83 | z = W_y + x 84 | 85 | return z 86 | 87 | 88 | class NONLocalBlock1D(_NonLocalBlockND): 89 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 90 | super(NONLocalBlock1D, self).__init__(in_channels, 91 | inter_channels=inter_channels, 92 | dimension=1, sub_sample=sub_sample, 93 | bn_layer=bn_layer) 94 | 95 | 96 | class NONLocalBlock2D(_NonLocalBlockND): 97 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 98 | super(NONLocalBlock2D, self).__init__(in_channels, 99 | 
inter_channels=inter_channels, 100 | dimension=2, sub_sample=sub_sample, 101 | bn_layer=bn_layer) 102 | 103 | 104 | class NONLocalBlock3D(_NonLocalBlockND): 105 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 106 | super(NONLocalBlock3D, self).__init__(in_channels, 107 | inter_channels=inter_channels, 108 | dimension=3, sub_sample=sub_sample, 109 | bn_layer=bn_layer) 110 | 111 | 112 | if __name__ == '__main__': 113 | from torch.autograd import Variable 114 | import torch 115 | 116 | sub_sample = True 117 | bn_layer = True 118 | 119 | img = Variable(torch.zeros(2, 3, 20)) 120 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 121 | out = net(img) 122 | print(out.size()) 123 | 124 | img = Variable(torch.zeros(2, 3, 20, 20)) 125 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 126 | out = net(img) 127 | print(out.size()) 128 | 129 | img = Variable(torch.randn(2, 3, 10, 20, 20)) 130 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 131 | out = net(img) 132 | print(out.size()) 133 | 134 | -------------------------------------------------------------------------------- /models/non_local_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant(self.W[1].weight, 0) 46 | nn.init.constant(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant(self.W.weight, 0) 51 | nn.init.constant(self.W.bias, 0) 52 | 53 | if sub_sample: 54 | self.g = nn.Sequential(self.g, max_pool_layer) 55 | self.phi = max_pool_layer 56 | 57 | def forward(self, x): 58 | ''' 59 | :param x: (b, c, t, h, w) 60 | :return: 61 | ''' 62 | 63 | batch_size = x.size(0) 64 | 65 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 66 | 67 | g_x = g_x.permute(0, 2, 1) 68 | 69 | theta_x = x.view(batch_size, self.in_channels, -1) 70 | theta_x = theta_x.permute(0, 2, 1) 71 | 72 | if self.sub_sample: 73 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1) 74 | else: 75 | phi_x = x.view(batch_size, self.in_channels, -1) 76 | 77 | f = 
torch.matmul(theta_x, phi_x) 78 | f_div_C = F.softmax(f, dim=-1) 79 | 80 | y = torch.matmul(f_div_C, g_x) 81 | y = y.permute(0, 2, 1).contiguous() 82 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 83 | W_y = self.W(y) 84 | z = W_y + x 85 | 86 | return z 87 | 88 | 89 | class NONLocalBlock1D(_NonLocalBlockND): 90 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 91 | super(NONLocalBlock1D, self).__init__(in_channels, 92 | inter_channels=inter_channels, 93 | dimension=1, sub_sample=sub_sample, 94 | bn_layer=bn_layer) 95 | 96 | 97 | class NONLocalBlock2D(_NonLocalBlockND): 98 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 99 | super(NONLocalBlock2D, self).__init__(in_channels, 100 | inter_channels=inter_channels, 101 | dimension=2, sub_sample=sub_sample, 102 | bn_layer=bn_layer) 103 | 104 | 105 | class NONLocalBlock3D(_NonLocalBlockND): 106 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 107 | super(NONLocalBlock3D, self).__init__(in_channels, 108 | inter_channels=inter_channels, 109 | dimension=3, sub_sample=sub_sample, 110 | bn_layer=bn_layer) 111 | 112 | 113 | if __name__ == '__main__': 114 | from torch.autograd import Variable 115 | import torch 116 | 117 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 118 | img = Variable(torch.zeros(2, 3, 20)) 119 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 120 | out = net(img) 121 | print(out.size()) 122 | 123 | img = Variable(torch.zeros(2, 3, 20, 20)) 124 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 125 | out = net(img) 126 | print(out.size()) 127 | 128 | img = Variable(torch.randn(2, 3, 8, 20, 20)) 129 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 130 | out = net(img) 131 | print(out.size()) 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | from torch import nn 11 | #from non_local_concatenation import NONLocalBlock2D 12 | #from non_local_gaussian import NONLocalBlock2D 13 | from .non_local_embedded_gaussian import NONLocalBlock2D 14 | #from .non_local_dot_product import NONLocalBlock2D 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 
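# Note on this Bottleneck variant: unlike torchvision's block it takes an extra
# nonloc flag; when nonloc=True a NONLocalBlock2D(in_channels=planes * 4) is built
# and, in forward(), applied after the residual addition and final ReLU, so the
# non-local attention operates on the complete block output.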
55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, nonloc=False): 58 | super(Bottleneck, self).__init__() 59 | #print(stride) 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * 4) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | self.nonloc=nonloc 71 | if self.nonloc: 72 | self.nonlocalblock = NONLocalBlock2D(in_channels=(planes * 4)) 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | 95 | if self.nonloc: 96 | out = self.nonlocalblock(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 103 | self.inplanes = 64 104 | super(ResNet, self).__init__() 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(64) 108 | # self.relu = nn.ReLU(inplace=True) # add missed relu 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | self.layer1 = self._make_layer(block, 64, layers[0]) 111 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 113 | self.layer4 = self._make_layer( 114 | block, 512, layers[3], stride=last_stride) 115 | 116 | def _make_layer(self, block, planes, blocks, stride=1): 117 | downsample = None 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | downsample = nn.Sequential( 120 | nn.Conv2d(self.inplanes, planes * block.expansion, 121 | kernel_size=1, stride=stride, bias=False), 122 | nn.BatchNorm2d(planes * block.expansion), 123 | ) 124 | 125 | layers = [] 126 | layers.append(block(self.inplanes, planes, stride, downsample)) 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.inplanes, planes)) 130 | 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.bn1(x) 136 | # x = self.relu(x) # add missed relu 137 | x = self.maxpool(x) 138 | 139 | x = self.layer1(x) 140 | x = self.layer2(x) 141 | x = self.layer3(x) 142 | x = self.layer4(x) 143 | 144 | return x 145 | 146 | def load_param(self, model_path): 147 | param_dict = torch.load(model_path) 148 | #print(param_dict['layer1.2.conv1.weight']) 149 | #print(param_dict.keys()) 150 | #for k,v in param_dict.items(): 151 | # print(k) 152 | #for k,v in self.state_dict().items(): 153 | # print(k) 154 | 155 | net1_dict = {k:v for k,v in param_dict.items() if k in self.state_dict().keys()} 156 | #print(net1_dict.keys()) 157 | #print(net1_dict['layer1.2.conv1.weight']) 158 | #self.state_dict().update(net1_dict) 159 | #print(self.state_dict()['layer1.2.conv1.weight']) 160 | self.load_state_dict(net1_dict) 161 | ''' 162 | for i in param_dict: 163 | if 'fc' in i: 164 | continue 165 | 
self.state_dict()[i].copy_(param_dict[i]) 166 | ''' 167 | def random_init(self): 168 | for m in self.modules(): 169 | if isinstance(m, nn.Conv2d): 170 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 171 | m.weight.data.normal_(0, math.sqrt(2. / n)) 172 | elif isinstance(m, nn.BatchNorm2d): 173 | m.weight.data.fill_(1) 174 | m.bias.data.zero_() 175 | 176 | class ResNetNonLocal(nn.Module): 177 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 178 | self.inplanes = 64 179 | super(ResNetNonLocal, self).__init__() 180 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 181 | bias=False) 182 | self.bn1 = nn.BatchNorm2d(64) 183 | # self.relu = nn.ReLU(inplace=True) # add missed relu 184 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 185 | self.layer1 = self._make_layer(block, 64, layers[0]) 186 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, nonloc=True) 187 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, nonloc=True) 188 | self.layer4 = self._make_layer( 189 | block, 512, layers[3], stride=last_stride) 190 | 191 | def _make_layer(self, block, planes, blocks, stride=1, nonloc=False): 192 | downsample = None 193 | if stride != 1 or self.inplanes != planes * block.expansion: 194 | downsample = nn.Sequential( 195 | nn.Conv2d(self.inplanes, planes * block.expansion, 196 | kernel_size=1, stride=stride, bias=False), 197 | nn.BatchNorm2d(planes * block.expansion), 198 | ) 199 | 200 | layers = [] 201 | layers.append(block(self.inplanes, planes, stride, downsample)) 202 | self.inplanes = planes * block.expansion 203 | for i in range(1, blocks): 204 | isnonlocal = False 205 | if nonloc and i % 2 == 1: 206 | isnonlocal = True 207 | #print( self.inplanes, planes) 208 | layers.append(block(self.inplanes, planes, nonloc=isnonlocal)) 209 | 210 | return nn.Sequential(*layers) 211 | 212 | def forward(self, x): 213 | x = self.conv1(x) 214 | x = self.bn1(x) 215 | # x = self.relu(x) # add missed relu 216 | x = self.maxpool(x) 217 | 218 | x = self.layer1(x) 219 | x = self.layer2(x) 220 | x = self.layer3(x) 221 | x = self.layer4(x) 222 | 223 | return x 224 | 225 | def load_param(self, model_path): 226 | param_dict = torch.load(model_path) 227 | #print(param_dict['layer1.2.conv1.weight']) 228 | #print(param_dict.keys()) 229 | #for k,v in param_dict.items(): 230 | # print(k) 231 | #for k,v in self.state_dict().items(): 232 | # print(k) 233 | model_dict = self.state_dict() 234 | param_dict = {k:v for k,v in param_dict.items() if k in model_dict.keys()} 235 | #print(net1_dict.keys()) 236 | #print(param_dict['layer1.2.conv1.weight']) 237 | model_dict.update(param_dict) 238 | #print(self.state_dict()['layer1.2.conv1.weight']) 239 | self.load_state_dict(model_dict) 240 | ''' 241 | for i in param_dict: 242 | if 'fc' in i: 243 | continue 244 | self.state_dict()[i].copy_(param_dict[i]) 245 | ''' 246 | def random_init(self): 247 | for m in self.modules(): 248 | if isinstance(m, nn.Conv2d): 249 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 250 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 251 | elif isinstance(m, nn.BatchNorm2d): 252 | m.weight.data.fill_(1) 253 | m.bias.data.zero_() 254 | -------------------------------------------------------------------------------- /models/resnet3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = [ 9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnet200' 11 | ] 12 | 13 | 14 | def conv3x3x3(in_planes, out_planes, stride=1): 15 | # 3x3x3 convolution with padding 16 | return nn.Conv3d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=1, 22 | bias=False) 23 | 24 | 25 | def downsample_basic_block(x, planes, stride): 26 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 27 | zero_pads = torch.Tensor( 28 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 29 | out.size(4)).zero_() 30 | if isinstance(out.data, torch.cuda.FloatTensor): 31 | zero_pads = zero_pads.cuda() 32 | 33 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 34 | 35 | return out 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None): 42 | super(BasicBlock, self).__init__() 43 | self.conv1 = conv3x3x3(inplanes, planes, stride) 44 | self.bn1 = nn.BatchNorm3d(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3x3(planes, planes) 47 | self.bn2 = nn.BatchNorm3d(planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | residual = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class Bottleneck(nn.Module): 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None): 74 | super(Bottleneck, self).__init__() 75 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 76 | self.bn1 = nn.BatchNorm3d(planes) 77 | self.conv2 = nn.Conv3d( 78 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 79 | self.bn2 = nn.BatchNorm3d(planes) 80 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 81 | self.bn3 = nn.BatchNorm3d(planes * 4) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | residual = x 88 | 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv2(out) 94 | out = self.bn2(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv3(out) 98 | out = self.bn3(out) 99 | 100 | if self.downsample is not None: 101 | residual = self.downsample(x) 102 | 103 | out += residual 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | class ResNet(nn.Module): 110 | 111 | def __init__(self, 112 | block, 113 | layers, 114 | sample_height, 115 | sample_width, 116 | sample_duration, 117 | shortcut_type='B', 118 | num_classes=400): 119 | self.inplanes = 64 120 | super(ResNet, self).__init__() 121 | self.conv1 = nn.Conv3d( 122 | 3, 123 | 64, 124 | kernel_size=7, 125 | stride=(1, 2, 2), 126 | padding=(3, 3, 3), 127 | bias=False) 128 | self.bn1 = 
nn.BatchNorm3d(64) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 132 | self.layer2 = self._make_layer( 133 | block, 128, layers[1], shortcut_type, stride=2) 134 | self.layer3 = self._make_layer( 135 | block, 256, layers[2], shortcut_type, stride=2) 136 | self.layer4 = self._make_layer( 137 | block, 512, layers[3], shortcut_type, stride=2) 138 | last_duration = int(math.ceil(sample_duration / 16.0)) 139 | last_height = int(math.ceil(sample_height / 32.0)) 140 | last_width = int(math.ceil(sample_width / 32.0)) 141 | self.avgpool = nn.AvgPool3d( 142 | (last_duration, last_height, last_width), stride=1) 143 | self.fc = nn.Linear(512 * block.expansion, num_classes) 144 | 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv3d): 147 | m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') 148 | elif isinstance(m, nn.BatchNorm3d): 149 | m.weight.data.fill_(1) 150 | m.bias.data.zero_() 151 | 152 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 153 | downsample = None 154 | if stride != 1 or self.inplanes != planes * block.expansion: 155 | if shortcut_type == 'A': 156 | downsample = partial( 157 | downsample_basic_block, 158 | planes=planes * block.expansion, 159 | stride=stride) 160 | else: 161 | downsample = nn.Sequential( 162 | nn.Conv3d( 163 | self.inplanes, 164 | planes * block.expansion, 165 | kernel_size=1, 166 | stride=stride, 167 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 168 | 169 | layers = [] 170 | layers.append(block(self.inplanes, planes, stride, downsample)) 171 | self.inplanes = planes * block.expansion 172 | for i in range(1, blocks): 173 | layers.append(block(self.inplanes, planes)) 174 | 175 | return nn.Sequential(*layers) 176 | 177 | def load_matched_state_dict(self, state_dict): 178 | 179 | own_state = self.state_dict() 180 | for name, param in state_dict.items(): 181 | if name not in own_state: 182 | continue 183 | #if isinstance(param, Parameter): 184 | # backwards compatibility for serialized parameters 185 | param = param.data 186 | print("loading "+name) 187 | own_state[name].copy_(param) 188 | 189 | def forward(self, x): 190 | # default size is (b, s, c, w, h), s for seq_len, c for channel 191 | # convert for 3d cnn, (b, c, s, w, h) 192 | x=x.permute(0,2,1,3,4) 193 | x = self.conv1(x) 194 | x = self.bn1(x) 195 | x = self.relu(x) 196 | x = self.maxpool(x) 197 | 198 | x = self.layer1(x) 199 | x = self.layer2(x) 200 | x = self.layer3(x) 201 | x = self.layer4(x) 202 | x = self.avgpool(x) 203 | x = x.view(x.size(0), -1) 204 | y = self.fc(x) 205 | 206 | return y, x 207 | 208 | 209 | def get_fine_tuning_parameters(model, ft_begin_index): 210 | if ft_begin_index == 0: 211 | return model.parameters() 212 | 213 | ft_module_names = [] 214 | for i in range(ft_begin_index, 5): 215 | ft_module_names.append('layer{}'.format(i)) 216 | ft_module_names.append('fc') 217 | 218 | parameters = [] 219 | for k, v in model.named_parameters(): 220 | for ft_module in ft_module_names: 221 | if ft_module in k: 222 | parameters.append({'params': v}) 223 | break 224 | else: 225 | parameters.append({'params': v, 'lr': 0.0}) 226 | 227 | return parameters 228 | 229 | 230 | def resnet10(**kwargs): 231 | """Constructs a ResNet-18 model. 232 | """ 233 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 234 | return model 235 | 236 | 237 | def resnet18(**kwargs): 238 | """Constructs a ResNet-18 model. 
239 | """ 240 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 241 | return model 242 | 243 | 244 | def resnet34(**kwargs): 245 | """Constructs a ResNet-34 model. 246 | """ 247 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 248 | return model 249 | 250 | 251 | def resnet50(**kwargs): 252 | """Constructs a ResNet-50 model. 253 | """ 254 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 255 | return model 256 | 257 | 258 | def resnet101(**kwargs): 259 | """Constructs a ResNet-101 model. 260 | """ 261 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 262 | return model 263 | 264 | 265 | def resnet152(**kwargs): 266 | """Constructs a ResNet-101 model. 267 | """ 268 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 269 | return model 270 | 271 | 272 | def resnet200(**kwargs): 273 | """Constructs a ResNet-101 model. 274 | """ 275 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 276 | return model 277 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os, shutil 4 | import numpy as np 5 | import scipy.sparse as sp 6 | from torch.autograd import Variable 7 | 8 | def normalize_adj(mx): 9 | """Row-normalize sparse matrix""" 10 | rowsum = np.array(mx.sum(1)) 11 | r_inv_sqrt = np.power(rowsum, -0.5).flatten() 12 | r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0. 13 | r_mat_inv_sqrt = sp.diags(r_inv_sqrt) 14 | return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt) 15 | 16 | def build_adj(t=4, p=4): 17 | rows = [] 18 | cols = [] 19 | for j in range(t-1): 20 | for i in range(p): 21 | if i == 0: 22 | rows += [i+j*p, i+j*p] 23 | cols += [i+(j+1)*p, i+(j+1)*p+1] 24 | elif i == p-1: 25 | rows += [i+j*p, i+j*p] 26 | cols += [i+(j+1)*p-1, i+(j+1)*p] 27 | else: 28 | rows += [i+j*p, i+j*p, i+j*p] 29 | cols += [i+(j+1)*p-1, i+(j+1)*p, i+(j+1)*p+1] 30 | data = np.ones(len(rows)) 31 | rows = np.asarray(rows) 32 | cols = np.asarray(cols) 33 | adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32) 34 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 35 | #print(adj) 36 | adj = normalize_adj(adj + sp.eye(adj.shape[0])) 37 | adj = torch.FloatTensor(np.array(adj.todense())) 38 | return adj 39 | 40 | def build_adj_full(t=4, p=4): 41 | rows = [] 42 | cols = [] 43 | for j in range(t-1): 44 | for i in range(p): 45 | rows += [i+j*p for k in range(p)] 46 | cols += range((j+1)*p, (j+1)*p+p) 47 | data = np.ones(len(rows)) 48 | rows = np.asarray(rows) 49 | cols = np.asarray(cols) 50 | adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32) 51 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 52 | #print(adj) 53 | adj = normalize_adj(adj + sp.eye(adj.shape[0])) 54 | adj = torch.FloatTensor(np.array(adj.todense())) 55 | return adj 56 | 57 | def build_adj_full_d(t=4, p=4, d=1): 58 | rows = [] 59 | cols = [] 60 | for dd in range(d): 61 | for j in range(t-dd-1): 62 | for i in range(p): 63 | rows += [i+j*p for k in range(p)] 64 | cols += range((j+1+dd)*p, (j+1+dd)*p+p) 65 | data = np.ones(len(rows)) 66 | rows = np.asarray(rows) 67 | cols = np.asarray(cols) 68 | adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32) 69 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 70 | #print(adj) 71 | adj = normalize_adj(adj + sp.eye(adj.shape[0])) 72 | adj = torch.FloatTensor(np.array(adj.todense())) 73 | return adj 
74 | 75 | def build_adj_full_full(t=4, p=4): 76 | rows = [] 77 | cols = [] 78 | for j in range(t-1): 79 | for i in range(p): 80 | rows += [i+j*p for k in range(p*(t-1-j))] 81 | cols += range((j+1)*p, p*t) 82 | data = np.ones(len(rows)) 83 | rows = np.asarray(rows) 84 | cols = np.asarray(cols) 85 | adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32) 86 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 87 | adj = normalize_adj(adj + sp.eye(adj.shape[0])) 88 | adj = torch.FloatTensor(np.array(adj.todense())) 89 | return adj 90 | 91 | def build_adj_full_circle(t=4, p=4): 92 | rows = [] 93 | cols = [] 94 | for j in range(t-1): 95 | for i in range(p): 96 | if j == 0: 97 | rows += [i+j*p for k in range(p)] 98 | cols += range((t-1)*p, (t-1)*p + p) 99 | rows += [i+j*p for k in range(p)] 100 | cols += range((j+1)*p, (j+1)*p+p) 101 | 102 | data = np.ones(len(rows)) 103 | rows = np.asarray(rows) 104 | cols = np.asarray(cols) 105 | adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32) 106 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 107 | #print(adj) 108 | adj = normalize_adj(adj + sp.eye(adj.shape[0])) 109 | adj = torch.FloatTensor(np.array(adj.todense())) 110 | 111 | 112 | def repackage_hidden(h): 113 | if type(h) == Variable: 114 | return Variable(h.data) 115 | else: 116 | return tuple(repackage_hidden(v) for v in h) 117 | 118 | 119 | def batchify(data, bsz, args): 120 | nbatch = data.size(0) // bsz 121 | data = data.narrow(0, 0, nbatch * bsz) 122 | data = data.view(bsz, -1).t().contiguous() 123 | print(data.size()) 124 | if args.cuda: 125 | data = data.cuda() 126 | return data 127 | 128 | 129 | def get_batch(source, i, args, seq_len=None, evaluation=False): 130 | seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i) 131 | data = Variable(source[i:i+seq_len], volatile=evaluation) 132 | target = Variable(source[i+1:i+1+seq_len]) 133 | return data, target 134 | 135 | 136 | def create_exp_dir(path, scripts_to_save=None): 137 | if not os.path.exists(path): 138 | os.mkdir(path) 139 | 140 | print('Experiment dir : {}'.format(path)) 141 | if scripts_to_save is not None: 142 | os.mkdir(os.path.join(path, 'scripts')) 143 | for script in scripts_to_save: 144 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 145 | shutil.copyfile(script, dst_file) 146 | 147 | 148 | def save_checkpoint(model, optimizer, epoch, path, finetune=False): 149 | if finetune: 150 | torch.save(model, os.path.join(path, 'finetune_model.pt')) 151 | torch.save(optimizer.state_dict(), os.path.join(path, 'finetune_optimizer.pt')) 152 | else: 153 | torch.save(model, os.path.join(path, 'model.pt')) 154 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer.pt')) 155 | torch.save({'epoch': epoch+1}, os.path.join(path, 'misc.pt')) 156 | 157 | 158 | def embedded_dropout(embed, words, dropout=0.1, scale=None): 159 | if dropout: 160 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout) 161 | mask = Variable(mask) 162 | masked_embed_weight = mask * embed.weight 163 | else: 164 | masked_embed_weight = embed.weight 165 | if scale: 166 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight 167 | 168 | padding_idx = embed.padding_idx 169 | if padding_idx is None: 170 | padding_idx = -1 171 | X = embed._backend.Embedding.apply(words, masked_embed_weight, 172 | padding_idx, embed.max_norm, embed.norm_type, 173 | 
embed.scale_grad_by_freq, embed.sparse 174 | ) 175 | return X 176 | 177 | 178 | class LockedDropout(nn.Module): 179 | def __init__(self): 180 | super(LockedDropout, self).__init__() 181 | 182 | def forward(self, x, dropout=0.5): 183 | if not self.training or not dropout: 184 | return x 185 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) 186 | mask = Variable(m.div_(1 - dropout), requires_grad=False) 187 | mask = mask.expand_as(x) 188 | return mask * x 189 | 190 | 191 | def mask2d(B, D, keep_prob, cuda=True): 192 | m = torch.floor(torch.rand(B, D) + keep_prob) / keep_prob 193 | m = Variable(m, requires_grad=False) 194 | if cuda: 195 | m = m.cuda() 196 | return m 197 | 198 | -------------------------------------------------------------------------------- /run_hypergraphsage_part.sh: -------------------------------------------------------------------------------- 1 | python main_video_person_reid_hypergraphsage_part.py --arch=resnet50graphpoolparthyper --gpu-devices='0' --save-dir=log_hypergraphsagepart --eval-step=100 --height=256 --width=128 --warmup --train-batch=32 --seq-len=8 2 | 3 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | import torch 6 | 7 | class RandomIdentitySampler(torch.utils.data.sampler.Sampler): 8 | """ 9 | Randomly sample N identities, then for each identity, 10 | randomly sample K instances, therefore batch size is N*K. 11 | 12 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 13 | 14 | Args: 15 | data_source (Dataset): dataset to sample from. 16 | num_instances (int): number of instances per identity. 17 | """ 18 | def __init__(self, data_source, num_instances=4): 19 | self.data_source = data_source 20 | self.num_instances = num_instances 21 | self.index_dic = defaultdict(list) 22 | for index, (_, pid, _) in enumerate(data_source): 23 | self.index_dic[pid].append(index) 24 | self.pids = list(self.index_dic.keys()) 25 | self.num_identities = len(self.pids) 26 | 27 | def __iter__(self): 28 | indices = torch.randperm(self.num_identities) 29 | ret = [] 30 | for i in indices: 31 | pid = self.pids[i] 32 | t = self.index_dic[pid] 33 | replace = False if len(t) >= self.num_instances else True 34 | t = np.random.choice(t, size=self.num_instances, replace=replace) 35 | ret.extend(t) 36 | return iter(ret) 37 | 38 | def __len__(self): 39 | return self.num_identities * self.num_instances 40 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import numpy as np 7 | import math 8 | 9 | class Random2DTranslation(object): 10 | """ 11 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 12 | 13 | Args: 14 | height (int): target height. 15 | width (int): target width. 16 | p (float): probability of performing this transformation. Default: 0.5. 
17 | """ 18 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 19 | self.height = height 20 | self.width = width 21 | self.p = p 22 | self.interpolation = interpolation 23 | 24 | def __call__(self, img): 25 | """ 26 | Args: 27 | img (PIL Image): Image to be cropped. 28 | 29 | Returns: 30 | PIL Image: Cropped image. 31 | """ 32 | if random.random() < self.p: 33 | return img.resize((self.width, self.height), self.interpolation) 34 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 35 | resized_img = img.resize((new_width, new_height), self.interpolation) 36 | x_maxrange = new_width - self.width 37 | y_maxrange = new_height - self.height 38 | x1 = int(round(random.uniform(0, x_maxrange))) 39 | y1 = int(round(random.uniform(0, y_maxrange))) 40 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 41 | return croped_img 42 | 43 | class RandomErasing(object): 44 | """ Randomly selects a rectangle region in an image and erases its pixels. 45 | 'Random Erasing Data Augmentation' by Zhong et al. 46 | See https://arxiv.org/pdf/1708.04896.pdf 47 | Args: 48 | probability: The probability that the Random Erasing operation will be performed. 49 | sl: Minimum proportion of erased area against input image. 50 | sh: Maximum proportion of erased area against input image. 51 | r1: Minimum aspect ratio of erased area. 52 | mean: Erasing value. 53 | """ 54 | 55 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 56 | self.probability = probability 57 | self.mean = mean 58 | self.sl = sl 59 | self.sh = sh 60 | self.r1 = r1 61 | 62 | def __call__(self, img): 63 | 64 | if random.uniform(0, 1) > self.probability: 65 | return img 66 | 67 | for attempt in range(100): 68 | area = img.size()[1] * img.size()[2] 69 | 70 | target_area = random.uniform(self.sl, self.sh) * area 71 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 72 | 73 | h = int(round(math.sqrt(target_area * aspect_ratio))) 74 | w = int(round(math.sqrt(target_area / aspect_ratio))) 75 | 76 | if w < img.size()[2] and h < img.size()[1]: 77 | x1 = random.randint(0, img.size()[1] - h) 78 | y1 = random.randint(0, img.size()[2] - w) 79 | if img.size()[0] == 3: 80 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 81 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 82 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 83 | else: 84 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 85 | return img 86 | 87 | return img 88 | 89 | if __name__ == '__main__': 90 | pass 91 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import errno 5 | import shutil 6 | import json 7 | import os.path as osp 8 | import numpy as np 9 | import scipy.sparse as sp 10 | import torch 11 | from bisect import bisect_right 12 | 13 | def build_adj(t=4, p=4): 14 | rows = [] 15 | cols = [] 16 | for j in range(t-1): 17 | for i in range(p): 18 | if i == 0: 19 | rows += [i+j*p, i+j*p] 20 | cols += [i+(j+1)*p, i+(j+1)*p+1] 21 | elif i == p-1: 22 | rows += [i+j*p, i+j*p] 23 | cols += [i+(j+1)*p-1, i+(j+1)*p] 24 | else: 25 | rows += [i+j*p, i+j*p, i+j*p] 26 | cols += [i+(j+1)*p-1, i+(j+1)*p, i+(j+1)*p+1] 27 | data = np.ones(len(rows)) 28 | rows = np.asarray(rows) 29 | cols = np.asarray(cols) 30 | adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32) 31 | adj = adj + 
adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
32 |     #print(adj)
33 |     adj = normalize_adj(adj + sp.eye(adj.shape[0]))
34 |     adj = torch.FloatTensor(np.array(adj.todense()))
35 |     return adj
36 | 
37 | def build_adj_full(t=4, p=4):
38 |     rows = []
39 |     cols = []
40 |     for j in range(t-1):
41 |         for i in range(p):
42 |             rows += [i+j*p for k in range(p)]
43 |             cols += range((j+1)*p, (j+1)*p+p)
44 |     data = np.ones(len(rows))
45 |     rows = np.asarray(rows)
46 |     cols = np.asarray(cols)
47 |     adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32)
48 |     adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
49 |     #print(adj)
50 |     adj = normalize_adj(adj + sp.eye(adj.shape[0]))
51 |     adj = torch.FloatTensor(np.array(adj.todense()))
52 |     return adj
53 | 
54 | def build_adj_full_full(t=4, p=4):
55 |     rows = []
56 |     cols = []
57 |     for j in range(t-1):
58 |         for i in range(p):
59 |             rows += [i+j*p for k in range(p*(t-1-j))]
60 |             cols += range((j+1)*p, p*t)
61 |     data = np.ones(len(rows))
62 |     rows = np.asarray(rows)
63 |     cols = np.asarray(cols)
64 |     adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32)
65 |     adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
66 |     adj = normalize_adj(adj + sp.eye(adj.shape[0]))
67 |     adj = torch.FloatTensor(np.array(adj.todense()))
68 |     return adj
69 | 
70 | def build_adj_full_circle(t=4, p=4):
71 |     rows = []
72 |     cols = []
73 |     for j in range(t-1):
74 |         for i in range(p):
75 |             if j == 0:
76 |                 rows += [i+j*p for k in range(p)]
77 |                 cols += range((t-1)*p, (t-1)*p + p)
78 |             rows += [i+j*p for k in range(p)]
79 |             cols += range((j+1)*p, (j+1)*p+p)
80 | 
81 |     data = np.ones(len(rows))
82 |     rows = np.asarray(rows)
83 |     cols = np.asarray(cols)
84 |     adj = sp.coo_matrix((data, (rows, cols)), shape=(t*p, t*p), dtype=np.float32)
85 |     adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
86 |     #print(adj)
87 |     adj = normalize_adj(adj + sp.eye(adj.shape[0]))
88 |     adj = torch.FloatTensor(np.array(adj.todense()))
89 |     return adj
90 | 
91 | def normalize_adj(mx):
92 |     """Symmetrically normalize a sparse adjacency matrix: D^-1/2 * A * D^-1/2."""
93 |     rowsum = np.array(mx.sum(1))
94 |     r_inv_sqrt = np.power(rowsum, -0.5).flatten()
95 |     r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.
96 |     r_mat_inv_sqrt = sp.diags(r_inv_sqrt)
97 |     return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt)
98 | 
99 | def mkdir_if_missing(directory):
100 |     if not osp.exists(directory):
101 |         try:
102 |             os.makedirs(directory)
103 |         except OSError as e:
104 |             if e.errno != errno.EEXIST:
105 |                 raise
106 | 
107 | class AverageMeter(object):
108 |     """Computes and stores the average and current value.
109 | 
110 |     Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
111 |     """
112 |     def __init__(self):
113 |         self.reset()
114 | 
115 |     def reset(self):
116 |         self.val = 0
117 |         self.avg = 0
118 |         self.sum = 0
119 |         self.count = 0
120 | 
121 |     def update(self, val, n=1):
122 |         self.val = val
123 |         self.sum += val * n
124 |         self.count += n
125 |         self.avg = self.sum / self.count
126 | 
127 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
128 |     mkdir_if_missing(osp.dirname(fpath))
129 |     torch.save(state, fpath)
130 |     if is_best:
131 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
132 | 
133 | class Logger(object):
134 |     """
135 |     Write console output to external text file.
136 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
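    Typically used by redirecting standard output, e.g. sys.stdout = Logger(osp.join(save_dir, 'log_train.txt')), so that console output is also saved to file (save_dir here is illustrative).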
137 |     """
138 |     def __init__(self, fpath=None):
139 |         self.console = sys.stdout
140 |         self.file = None
141 |         if fpath is not None:
142 |             mkdir_if_missing(os.path.dirname(fpath))
143 |             self.file = open(fpath, 'w')
144 | 
145 |     def __del__(self):
146 |         self.close()
147 | 
148 |     def __enter__(self):
149 |         return self
150 | 
151 |     def __exit__(self, *args):
152 |         self.close()
153 | 
154 |     def write(self, msg):
155 |         self.console.write(msg)
156 |         if self.file is not None:
157 |             self.file.write(msg)
158 | 
159 |     def flush(self):
160 |         self.console.flush()
161 |         if self.file is not None:
162 |             self.file.flush()
163 |             os.fsync(self.file.fileno())
164 | 
165 |     def close(self):
166 |         self.console.close()
167 |         if self.file is not None:
168 |             self.file.close()
169 | 
170 | def read_json(fpath):
171 |     with open(fpath, 'r') as f:
172 |         obj = json.load(f)
173 |     return obj
174 | 
175 | def write_json(obj, fpath):
176 |     mkdir_if_missing(osp.dirname(fpath))
177 |     with open(fpath, 'w') as f:
178 |         json.dump(obj, f, indent=4, separators=(',', ': '))
179 | 
180 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
181 |     def __init__(
182 |         self,
183 |         optimizer,
184 |         milestones,
185 |         gamma=0.1,
186 |         warmup_factor=0.01,
187 |         warmup_iters=20,
188 |         warmup_method="linear",
189 |         last_epoch=-1,
190 |     ):
191 |         if not list(milestones) == sorted(milestones):
192 |             raise ValueError(
193 |                 "Milestones should be a list of"
194 |                 " increasing integers. Got {}".format(milestones)
195 |             )
196 | 
197 |         if warmup_method not in ("constant", "linear"):
198 |             raise ValueError(
199 |                 "Only 'constant' or 'linear' warmup_method accepted, "
200 |                 "got {}".format(warmup_method)
201 |             )
202 |         self.milestones = milestones
203 |         self.gamma = gamma
204 |         self.warmup_factor = warmup_factor
205 |         self.warmup_iters = warmup_iters
206 |         self.warmup_method = warmup_method
207 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
208 |         #print(self.last_epoch)
209 | 
210 |     def get_lr(self):
211 |         warmup_factor = 1
212 |         if self.last_epoch < self.warmup_iters:
213 |             if self.warmup_method == "constant":
214 |                 warmup_factor = self.warmup_factor
215 |             elif self.warmup_method == "linear":
216 |                 #print(self.last_epoch)
217 |                 alpha = (self.last_epoch + 1) / self.warmup_iters
218 |                 #print(alpha)
219 |                 warmup_factor = self.warmup_factor * (1 - alpha) + alpha
220 |                 #print(warmup_factor)
221 |         return [
222 |             base_lr
223 |             * warmup_factor
224 |             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
225 |             for base_lr in self.base_lrs
226 |         ]
227 | 
228 | 
-------------------------------------------------------------------------------- /video_loader.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | 
6 | import torch
7 | from torch.utils.data import Dataset
8 | import random
9 | 
10 | def read_image(img_path):
11 |     """Keep reading the image until it succeeds.
12 |     This avoids IOError incurred by a heavy IO process."""
13 |     got_img = False
14 |     while not got_img:
15 |         try:
16 |             img = Image.open(img_path).convert('RGB')
17 |             got_img = True
18 |         except IOError:
19 |             print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
20 |             pass
21 |     return img
22 | 
23 | 
24 | class VideoDataset(Dataset):
25 |     """Video Person ReID Dataset.
26 |     Note batch data has shape (batch, seq_len, channel, height, width).
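    Two sampling modes are implemented: 'random' draws one clip of seq_len consecutive frames (replicating frames if the tracklet is shorter) and is used for training; 'dense' splits the whole tracklet into consecutive clips of seq_len frames and is used at test time with batch_size=1.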
27 |     """
28 |     sample_methods = ['random', 'dense']
29 | 
30 |     def __init__(self, dataset, seq_len=15, sample='random', transform=None):
31 |         self.dataset = dataset
32 |         self.seq_len = seq_len
33 |         self.sample = sample
34 |         self.transform = transform
35 | 
36 |     def __len__(self):
37 |         return len(self.dataset)
38 | 
39 |     def __getitem__(self, index):
40 |         #print(index, len(self.dataset))
41 |         img_paths, pid, camid = self.dataset[index]
42 |         num = len(img_paths)
43 |         if self.sample == 'random':
44 |             """
45 |             Randomly sample seq_len consecutive frames from num frames,
46 |             if num is smaller than seq_len, then replicate items.
47 |             This sampling strategy is used in training phase.
48 |             """
49 |             frame_indices = list(range(num))
50 |             rand_end = max(0, len(frame_indices) - self.seq_len - 1)
51 |             begin_index = random.randint(0, rand_end)
52 |             end_index = min(begin_index + self.seq_len, len(frame_indices))
53 | 
54 |             indices = frame_indices[begin_index:end_index]
55 | 
56 |             for index in indices:
57 |                 if len(indices) >= self.seq_len:
58 |                     break
59 |                 indices.append(index)
60 |             indices=np.array(indices)
61 |             imgs = []
62 |             for index in indices:
63 |                 index=int(index)
64 |                 img_path = img_paths[index]
65 |                 img = read_image(img_path)
66 |                 if self.transform is not None:
67 |                     img = self.transform(img)
68 |                 img = img.unsqueeze(0)
69 |                 imgs.append(img)
70 |             imgs = torch.cat(imgs, dim=0)
71 |             #imgs=imgs.permute(1,0,2,3)
72 |             return imgs, pid, camid
73 | 
74 |         elif self.sample == 'dense':
75 |             """
76 |             Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1.
77 |             This sampling strategy is used in test phase.
78 |             """
79 |             cur_index=0
80 |             frame_indices = list(range(num))
81 |             indices_list=[]
82 |             while num-cur_index > self.seq_len:
83 |                 indices_list.append(frame_indices[cur_index:cur_index+self.seq_len])
84 |                 cur_index+=self.seq_len
85 |             last_seq=frame_indices[cur_index:]
86 |             for index in last_seq:
87 |                 if len(last_seq) >= self.seq_len:
88 |                     break
89 |                 last_seq.append(index)
90 |             indices_list.append(last_seq)
91 |             imgs_list=[]
92 |             for indices in indices_list:
93 |                 imgs = []
94 |                 for index in indices:
95 |                     index=int(index)
96 |                     img_path = img_paths[index]
97 |                     img = read_image(img_path)
98 |                     if self.transform is not None:
99 |                         img = self.transform(img)
100 |                     img = img.unsqueeze(0)
101 |                     imgs.append(img)
102 |                 imgs = torch.cat(imgs, dim=0)
103 |                 #imgs=imgs.permute(1,0,2,3)
104 |                 imgs_list.append(imgs)
105 |             imgs_array = torch.stack(imgs_list)
106 |             return imgs_array, pid, camid
107 | 
108 |         else:
109 |             raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))
110 | 
111 | 
112 | 
113 | 
114 | 
115 | 
116 | 
117 | 
--------------------------------------------------------------------------------
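For orientation, the following is a minimal, hypothetical sketch of how the utilities above (the MARS data manager, VideoDataset, RandomIdentitySampler, the transforms and the warmup scheduler) are typically wired together into a training dataloader. The constructor arguments, normalization statistics, attribute names such as `dataset.train`, and the scheduler milestones are illustrative assumptions rather than values copied from the training script; see `main_video_person_reid_hypergraphsage_part.py` and `run_hypergraphsage_part.sh` for the actual configuration.

```python
# Hypothetical end-to-end sketch (not part of the repository).
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T

import data_manager
from video_loader import VideoDataset
from samplers import RandomIdentitySampler
from transforms import Random2DTranslation, RandomErasing
from utils import WarmupMultiStepLR

# Assumes MARS has been prepared under data/mars as described in the README.
dataset = data_manager.Mars(min_seq_len=0)

transform_train = T.Compose([
    Random2DTranslation(height=256, width=128),  # operates on PIL images
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    RandomErasing(),                             # operates on tensors, so placed last
])

# 'random' sampling yields clips of seq_len frames; RandomIdentitySampler groups
# num_instances tracklets per identity so the triplet loss sees valid batches.
train_loader = DataLoader(
    VideoDataset(dataset.train, seq_len=8, sample='random', transform=transform_train),
    sampler=RandomIdentitySampler(dataset.train, num_instances=4),
    batch_size=32, num_workers=4, pin_memory=True, drop_last=True,
)

# Model and optimizer are omitted here; the warmup scheduler from utils.py would
# then be attached to the optimizer (milestone values below are assumptions):
# optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=5e-4)
# scheduler = WarmupMultiStepLR(optimizer, milestones=[200, 400], gamma=0.1,
#                               warmup_factor=0.01, warmup_iters=20)
```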