├── .idea ├── deployment.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── vcs.xml ├── workspace.xml └── zzz.iml ├── README.md ├── RegDB_test ├── RegDB_test.py ├── __init__.py ├── data_manager.py ├── eval_metrics.py └── utils.py ├── __pycache__ └── utlis.cpython-37.pyc ├── evaluation ├── README.md ├── __pycache__ │ └── gen_utils.cpython-37.pyc ├── data_split │ ├── rand_perm_cam.mat │ ├── test_id.mat │ └── train_id.mat ├── demo.m ├── euclidean_dist.m ├── evaluate_SYSU_MM01.py ├── evaluation_SYSU_MM01.m ├── gen_utils.py ├── get_cmc_multi_cam.m ├── get_map_multi_cam.m ├── get_testing_set.m ├── result │ ├── result__euclidean_all_search_10shot.mat │ ├── result__euclidean_all_search_1shot.mat │ ├── result__euclidean_indoor_search_10shot.mat │ └── result__euclidean_indoor_search_1shot.mat ├── train_id.mat └── val_id.mat ├── extract_feature.py ├── images ├── embedding_spaces.jpg ├── framework.jpg └── joint_embedding_spaces.jpg ├── mm01.py ├── regdb.py ├── reid ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── New_trainers_3.cpython-36.pyc │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── dist_metric.cpython-36.pyc │ ├── dist_metric.cpython-37.pyc │ ├── evaluators.cpython-36.pyc │ ├── evaluators.cpython-37.pyc │ ├── trainers.cpython-36.pyc │ └── trainers.cpython-37.pyc ├── datasets │ ├── RegDB.py │ ├── RegDB.pyc │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── RegDB.cpython-36.pyc │ │ ├── RegDB.cpython-37.pyc │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── sysu.cpython-36.pyc │ │ └── sysu.cpython-37.pyc │ ├── sysu.py │ └── sysu.pyc ├── dist_metric.py ├── dist_metric.pyc ├── evaluation_metrics │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── ranking.cpython-35.pyc │ │ ├── ranking.cpython-36.pyc │ │ └── ranking.cpython-37.pyc │ ├── ranking.py │ └── ranking.pyc ├── evaluators.py ├── evaluators.pyc ├── evaluators_regdb.py ├── feature_extraction │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── cnn.cpython-35.pyc │ │ ├── cnn.cpython-36.pyc │ │ └── cnn.cpython-37.pyc │ ├── cnn.py │ └── cnn.pyc ├── metric_learning │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── euclidean.cpython-36.pyc │ │ └── euclidean.cpython-37.pyc │ ├── euclidean.py │ └── euclidean.pyc ├── models │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── baseline.cpython-37.pyc │ │ ├── newresnet.cpython-36.pyc │ │ └── newresnet.cpython-37.pyc │ ├── baseline.py │ ├── baseline.pyc │ ├── newresnet.py │ └── newresnet.pyc ├── trainers.py ├── trainers.pyc └── utils │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── logging.cpython-36.pyc │ ├── logging.cpython-37.pyc │ ├── meters.cpython-36.pyc │ ├── meters.cpython-37.pyc │ ├── osutils.cpython-35.pyc │ ├── osutils.cpython-36.pyc │ ├── osutils.cpython-37.pyc │ ├── serialization.cpython-35.pyc │ ├── serialization.cpython-36.pyc │ └── serialization.cpython-37.pyc │ ├── data │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── 
__init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── dataset.cpython-35.pyc │ │ ├── dataset.cpython-36.pyc │ │ ├── dataset.cpython-37.pyc │ │ ├── preprocessor.cpython-35.pyc │ │ ├── preprocessor.cpython-36.pyc │ │ ├── preprocessor.cpython-37.pyc │ │ ├── sampler.cpython-36.pyc │ │ ├── sampler.cpython-37.pyc │ │ ├── transforms.cpython-36.pyc │ │ └── transforms.cpython-37.pyc │ ├── dataset.py │ ├── dataset.pyc │ ├── preprocessor.py │ ├── preprocessor.pyc │ ├── sampler.py │ ├── sampler.pyc │ ├── transforms.py │ └── transforms.pyc │ ├── logging.py │ ├── logging.pyc │ ├── meters.py │ ├── meters.pyc │ ├── osutils.py │ ├── osutils.pyc │ ├── serialization.py │ └── serialization.pyc ├── train.py └── utlis.py
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 6 | 
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 6 | 7 | 
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 
--------------------------------------------------------------------------------
/.idea/zzz.iml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Farewell to Mutual Information: Variational Distillation for Cross-Modal Person Re-identification
2 | 
3 | Official implementation of the Variational Distillation framework from "Farewell to Mutual Information: Variational Distillation for Cross-Modal Person Re-identification (CVPR '21 oral)".
4 | 
5 | 
6 | ![framework](images/framework.jpg)
7 | 
8 | 
9 | Please read [our paper](https://arxiv.org/abs/2104.02862) for a more detailed description of the training procedure.
10 | 
11 | ### BibTeX
12 | Please use the following BibTeX entry for citations:
13 | ```latex
14 | @inproceedings{VariationalDistillation,
15 | title={Farewell to Mutual Information: Variational Distillation for Cross-Modal Person Re-identification},
16 | author={Xudong Tian and Zhizhong Zhang and Shaohui Lin and Yanyun Qu and Yuan Xie and Lizhuang Ma},
17 | booktitle={Computer Vision and Pattern Recognition},
18 | year={2021}
19 | }
20 | ```
21 | 
22 | ## Get Started
23 | 1. `cd` to the folder where you want to download this project
24 | 
25 | 2. Run `git clone https://github.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification.git`
26 | 
27 | 3. Install dependencies:
28 | - python>=3.7.0
29 | - [pytorch>=1.3.0](https://pytorch.org/)
30 | - torchvision
31 | 
32 | 4. Prepare datasets
33 | - Download RegDB
34 | - Download [SYSU-MM01](https://github.com/wuancong/SYSU-MM01)
35 | 
36 | Create a directory to store the required datasets, either under or outside this project, and remember to set `--data-dir` to the right path before training.
37 | 
38 | ## Train
39 | This project provides code to train and evaluate different architectures on both datasets. You can run `/mm01.py` and `/regdb.py` directly under the default settings, or customize them for either dataset. Example commands are given at the end of this README.
40 | 
41 | ## Evaluation
42 | - MM01: To evaluate the model under the standard protocol, first run `/extract_feature.py` to obtain features, then run `/evaluation/evaluate_SYSU_MM01.py` to conduct the standard evaluation.
43 | - RegDB: You can directly run `/RegDB_test/RegDB_test.py` to obtain Visible-Thermal performance; change the default settings to evaluate the model under the other setting, i.e., Thermal-Visible.
44 | 
45 | ## Results
46 | SYSU-MM01 (all-search mode)
47 | | Metric | Value |
48 | | --- | --- |
49 | | Rank-1 | 60.02\% |
50 | | Rank-10 | 94.18\% |
51 | | Rank-20 | 98.14\% |
52 | | mAP | 58.80\% |
53 | 
54 | SYSU-MM01 (indoor-search mode)
55 | | Metric | Value |
56 | | --- | --- |
57 | | Rank-1 | 66.05\% |
58 | | Rank-10 | 96.59\% |
59 | | Rank-20 | 99.38\% |
60 | | mAP | 72.98\% |
61 | 
62 | RegDB
63 | | Mode | Rank-1 (mAP) |
64 | | --- | --- |
65 | | Visible-Thermal | 73.2\% (71.6\%) |
66 | | Thermal-Visible | 71.8\% (70.1\%) |
67 | 
68 | ## Visualization
69 | 2-D projections of the embedding space obtained with t-SNE, for our method and for the conventional information bottleneck on the SYSU-MM01 dataset. Different colors denote different person IDs.
70 | 
71 | 
72 | ![embedding spaces](images/embedding_spaces.jpg)
73 | 
74 | 
75 | In addition, we plot the joint embedding space of data from the two modalities for better visualization. More descriptions and details can be found in [our paper](https://arxiv.org/abs/2104.02862).
76 | 
77 | 
78 | ![joint embedding spaces](images/joint_embedding_spaces.jpg)
79 | 
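80 | 
81 | ## Example commands
82 | A minimal usage sketch under the default settings. The dataset locations below are placeholders, and `extract_feature.py`, `evaluation/evaluate_SYSU_MM01.py` and `RegDB_test/RegDB_test.py` read their model and feature paths from variables near the top of each script, so adjust those before running:
83 | ```bash
84 | # SYSU-MM01: train, extract per-camera features, then run the standard evaluation
85 | python mm01.py --data-dir ./data/sysu
86 | python extract_feature.py
87 | python evaluation/evaluate_SYSU_MM01.py
88 | 
89 | # RegDB: train, then evaluate the Visible-Thermal setting
90 | python regdb.py --data-dir ./data/RegDB
91 | python RegDB_test/RegDB_test.py
92 | ```
93 | 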
--------------------------------------------------------------------------------
/RegDB_test/RegDB_test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from torch import nn
3 | import argparse
4 | import torch.backends.cudnn as cudnn
5 | import torchvision.transforms as transforms
6 | from reid.utils.serialization import load_checkpoint
7 | from reid.models import ft_net
8 | from eval_metrics import eval_regdb
9 | from utils import *
10 | from PIL import Image
11 | from reid.utils.data import transforms as T
12 | 
13 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
14 | parser.add_argument('--dataset', default='regdb', help='dataset name: regdb or sysu')
15 | 
16 | parser.add_argument('--features', type=int, default=2048, help='feature dimensions')
17 | parser.add_argument('--z_dim', type=int, default=256, help='information bottleneck dimensions')
18 | 
19 | parser.add_argument('--height', type=int, default=256, help='image height, should be chosen in {256, 288, 312}')
20 | parser.add_argument('--width', type=int, default=128, help='image width, should be chosen in {128, 144, 156}')
21 | 
22 | parser.add_argument('--model_path', default='/home/txd/Variational Distillation/Exp_RegDB/RegDB_2/', type=str)
23 | args = parser.parse_args()
24 | 
25 | # overall settings
26 | data_path = '../data/RegDB/'
27 | n_class = 206
28 | test_mode = [2, 1]
29 | 
30 | # model settings
31 | overall_feats = 8192
32 | os.environ["CUDA_VISIBLE_DEVICES"] = "2, 1"
33 | model = ft_net(args=args, num_classes=n_class, num_features=args.features)
34 | model = nn.DataParallel(model, device_ids=[0, 1])
35 | cudnn.benchmark = True
36 | 
37 | #checkpoint_path = args.model_path
38 | 
39 | print('==> Loading data..')
40 | 
41 | # Data loading code
42 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
43 | 
44 | test_transformer = T.Compose([
45 |     #T.Resize((312, 156)),
46 |     T.Resize((args.height, args.width)),
47 |     T.ToTensor(),
48 |     normalizer,
49 | ])
50 | 
51 | if args.dataset == 'regdb':
52 |     data_dir = '../data/RegDB/'
53 | 
54 | for trial in range(1, 9):
55 |     test_trial = trial + 1
56 |     model_path = args.model_path + '{}/model_best.pth.tar'.format(trial)
57 | 
58 |     #########################################
59 |     checkpoint = load_checkpoint(model_path)
60 |     state_dict = checkpoint['model']
61 |     from collections import OrderedDict
62 |     new_state_dict = OrderedDict()
63 |     for k, v in state_dict.items():
64 |         if 'module' not in k or 'cam_module' in k:
65 |             k = 'module.' + k
66 |         else:
67 |             k = k.replace('features.module.', 'module.features.')
68 |         new_state_dict[k] = v
69 |     model.load_state_dict(new_state_dict)
70 |     #########################################
71 |     model = model.cuda()
72 |     model.eval()
73 | 
74 |     # v to t
75 |     query_list = data_dir + 'idx/test_visible_{}'.format(test_trial) + '.txt'
76 |     gallery_list = data_dir + 'idx/test_thermal_{}'.format(test_trial) + '.txt'
77 |     # t to v
78 |     #query_list = data_dir + 'idx/test_thermal_{}'.format(test_trial) + '.txt'
79 |     #gallery_list = data_dir + 'idx/test_visible_{}'.format(test_trial) + '.txt'
80 | 
81 |     query_image_list, query_label = load_data(query_list)
82 |     gallery_image_list, gallery_label = load_data(gallery_list)
83 |     temp_list = []
84 |     for index, pth in enumerate(query_image_list):
85 |         filename = data_dir + pth
86 |         img = Image.open(filename)
87 |         img = test_transformer(img)
88 |         img = img.view([1, 3, args.height, args.width])
89 |         img = img.cuda()
90 | 
91 |         i_observation, i_representation, i_ms_observation, i_ms_representation, \
92 |         v_observation, v_representation, v_ms_observation, v_ms_representation = model(img)
93 |         result_y = torch.cat(tensors=[i_observation[1], i_ms_observation[1],
94 |                                       v_observation[1], v_ms_observation[1]], dim=1)
95 | 
96 |         result_y = torch.nn.functional.normalize(result_y, dim=1, p=2)
97 |         result_y = result_y.view(-1, overall_feats)
98 |         result_y = result_y.squeeze()
99 |         result_npy = result_y.data.cpu().numpy()
100 |         result_npy = result_npy.astype('double')
101 |         temp_list.append(result_npy)
102 | 
103 |     query_feat, query_label = np.array(temp_list), np.array(query_label)
104 |     print('Query feature extraction: done')
105 | 
106 |     temp_list = []
107 |     for index, pth in enumerate(gallery_image_list):
108 |         filename = data_dir + pth
109 |         img = Image.open(filename)
110 |         img = test_transformer(img)
111 |         img = img.view([1, 3, args.height, args.width])
112 |         img = img.cuda()
113 | 
114 |         i_observation, i_representation, i_ms_observation, i_ms_representation, \
115 |         v_observation, v_representation, v_ms_observation, v_ms_representation = model(img)
116 |         result_y = torch.cat(tensors=[i_observation[1], i_ms_observation[1],
117 |                                       v_observation[1], v_ms_observation[1]], dim=1)
118 | 
119 |         result_y = torch.nn.functional.normalize(result_y, dim=1, p=2)
120 |         result_y = result_y.view(-1, overall_feats)
121 |         result_y = result_y.squeeze()
122 |         result_npy = result_y.data.cpu().numpy()
123 |         result_npy = result_npy.astype('double')
124 |         temp_list.append(result_npy)
125 | 
126 |     gallery_feat, gallery_label = np.array(temp_list), np.array(gallery_label)
127 | 
128 |     print('Gallery feature extraction: done')
129 |     distmat = np.matmul(query_feat, np.transpose(gallery_feat))
130 |     cmc, mAP, mINP = eval_regdb(-distmat, query_label, gallery_label)
131 | 
132 |     # accumulate per-trial results so an average can be reported after the loop
133 |     if trial == 1:
134 |         all_cmc = cmc
135 |         all_mAP = mAP
136 |         all_mINP = mINP
137 |     else:
138 |         all_cmc = all_cmc + cmc
139 |         all_mAP = all_mAP + mAP
140 |         all_mINP = all_mINP + mINP
141 | 
142 |     print(
143 |         'Results: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%} | Rank-20: {:.2%} | mAP: {:.2%} | mINP: {:.2%}'.format(
144 |             cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
145 | 
146 | num_trials = len(range(1, 9))
147 | print(
148 |     'Average: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%} | Rank-20: {:.2%} | mAP: {:.2%} | mINP: {:.2%}'.format(
149 |         all_cmc[0] / num_trials, all_cmc[4] / num_trials, all_cmc[9] / num_trials, all_cmc[19] / num_trials,
150 |         all_mAP / num_trials, all_mINP / num_trials))
151 | 
152 | 
153 | def load_data(input_data_path):
154 |     with open(input_data_path) as f:
155 |         data_file_list = f.read().splitlines()
156 |     file_image = [s.split(' ')[0] for s in data_file_list]
157 |     file_label = [int(s.split(' ')[1]) for s in data_file_list]
158 | 
159 |     return file_image, file_label
--------------------------------------------------------------------------------
/RegDB_test/__init__.py:
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import utils 4 | from . import data_manager 5 | from . import eval_metrics 6 | from . import RegDB_test 7 | from .eval_metrics import eval_regdb 8 | from .utils import * 9 | 10 | __version__ = '0.2.0' -------------------------------------------------------------------------------- /RegDB_test/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import numpy as np 4 | import random 5 | 6 | def process_query_sysu(data_path, mode = 'all', relabel=False): 7 | if mode== 'all': 8 | ir_cameras = ['cam3','cam6'] 9 | elif mode =='indoor': 10 | ir_cameras = ['cam3','cam6'] 11 | 12 | file_path = os.path.join(data_path,'exp/test_id.txt') 13 | files_rgb = [] 14 | files_ir = [] 15 | 16 | with open(file_path, 'r') as file: 17 | ids = file.read().splitlines() 18 | ids = [int(y) for y in ids[0].split(',')] 19 | ids = ["%04d" % x for x in ids] 20 | 21 | for id in sorted(ids): 22 | for cam in ir_cameras: 23 | img_dir = os.path.join(data_path,cam,id) 24 | if os.path.isdir(img_dir): 25 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 26 | files_ir.extend(new_files) 27 | query_img = [] 28 | query_id = [] 29 | query_cam = [] 30 | for img_path in files_ir: 31 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 32 | query_img.append(img_path) 33 | query_id.append(pid) 34 | query_cam.append(camid) 35 | return query_img, np.array(query_id), np.array(query_cam) 36 | 37 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False): 38 | 39 | random.seed(trial) 40 | 41 | if mode== 'all': 42 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 43 | elif mode =='indoor': 44 | rgb_cameras = ['cam1','cam2'] 45 | 46 | file_path = os.path.join(data_path,'exp/test_id.txt') 47 | files_rgb = [] 48 | with open(file_path, 'r') as file: 49 | ids = file.read().splitlines() 50 | ids = [int(y) for y in ids[0].split(',')] 51 | ids = ["%04d" % x for x in ids] 52 | 53 | for id in sorted(ids): 54 | for cam in rgb_cameras: 55 | img_dir = os.path.join(data_path,cam,id) 56 | if os.path.isdir(img_dir): 57 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 58 | files_rgb.append(random.choice(new_files)) 59 | gall_img = [] 60 | gall_id = [] 61 | gall_cam = [] 62 | for img_path in files_rgb: 63 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 64 | gall_img.append(img_path) 65 | gall_id.append(pid) 66 | gall_cam.append(camid) 67 | return gall_img, np.array(gall_id), np.array(gall_cam) 68 | 69 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'): 70 | if modal=='visible': 71 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 72 | elif modal=='thermal': 73 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 74 | 75 | with open(input_data_path) as f: 76 | data_file_list = open(input_data_path, 'rt').read().splitlines() 77 | # Get full list of image and labels 78 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 79 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 80 | 81 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /RegDB_test/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, 
absolute_import 2 | import numpy as np 3 | """Cross-Modality ReID""" 4 | import pdb 5 | 6 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20): 7 | """Evaluation with sysu metric 8 | Key: for each query identity, its gallery images from the same camera view are discarded. "Following the original setting in ite dataset" 9 | """ 10 | num_q, num_g = distmat.shape 11 | if num_g < max_rank: 12 | max_rank = num_g 13 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 14 | indices = np.argsort(distmat, axis=1) 15 | pred_label = g_pids[indices] 16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 17 | 18 | # compute cmc curve for each query 19 | new_all_cmc = [] 20 | all_cmc = [] 21 | all_AP = [] 22 | all_INP = [] 23 | num_valid_q = 0. # number of valid query 24 | for q_idx in range(num_q): 25 | # get query pid and camid 26 | q_pid = q_pids[q_idx] 27 | q_camid = q_camids[q_idx] 28 | 29 | # remove gallery samples that have the same pid and camid with query 30 | order = indices[q_idx] 31 | remove = (q_camid == 3) & (g_camids[order] == 2) 32 | keep = np.invert(remove) 33 | 34 | # compute cmc curve 35 | # the cmc calculation is different from standard protocol 36 | # we follow the protocol of the author's released code 37 | new_cmc = pred_label[q_idx][keep] 38 | new_index = np.unique(new_cmc, return_index=True)[1] 39 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 40 | 41 | new_match = (new_cmc == q_pid).astype(np.int32) 42 | new_cmc = new_match.cumsum() 43 | new_all_cmc.append(new_cmc[:max_rank]) 44 | 45 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 46 | if not np.any(orig_cmc): 47 | # this condition is true when query identity does not appear in gallery 48 | continue 49 | 50 | cmc = orig_cmc.cumsum() 51 | 52 | # compute mINP 53 | # refernece Deep Learning for Person Re-identification: A Survey and Outlook 54 | pos_idx = np.where(orig_cmc == 1) 55 | pos_max_idx = np.max(pos_idx) 56 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 57 | all_INP.append(inp) 58 | 59 | cmc[cmc > 1] = 1 60 | 61 | all_cmc.append(cmc[:max_rank]) 62 | num_valid_q += 1. 63 | 64 | # compute average precision 65 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 66 | num_rel = orig_cmc.sum() 67 | tmp_cmc = orig_cmc.cumsum() 68 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 69 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 70 | AP = tmp_cmc.sum() / num_rel 71 | all_AP.append(AP) 72 | 73 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 74 | 75 | all_cmc = np.asarray(all_cmc).astype(np.float32) 76 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 77 | 78 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 79 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 80 | mAP = np.mean(all_AP) 81 | mINP = np.mean(all_INP) 82 | return new_all_cmc, mAP, mINP 83 | 84 | 85 | 86 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20): 87 | num_q, num_g = distmat.shape 88 | if num_g < max_rank: 89 | max_rank = num_g 90 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 91 | indices = np.argsort(distmat, axis=1) 92 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 93 | 94 | # compute cmc curve for each query 95 | all_cmc = [] 96 | all_AP = [] 97 | all_INP = [] 98 | num_valid_q = 0. 
# number of valid query 99 | 100 | # only two cameras 101 | q_camids = np.ones(num_q).astype(np.int32) 102 | g_camids = 2* np.ones(num_g).astype(np.int32) 103 | 104 | for q_idx in range(num_q): 105 | # get query pid and camid 106 | q_pid = q_pids[q_idx] 107 | q_camid = q_camids[q_idx] 108 | 109 | # remove gallery samples that have the same pid and camid with query 110 | order = indices[q_idx] 111 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 112 | keep = np.invert(remove) 113 | 114 | # compute cmc curve 115 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 116 | if not np.any(raw_cmc): 117 | # this condition is true when query identity does not appear in gallery 118 | continue 119 | 120 | cmc = raw_cmc.cumsum() 121 | 122 | # compute mINP 123 | # refernece Deep Learning for Person Re-identification: A Survey and Outlook 124 | pos_idx = np.where(raw_cmc == 1) 125 | pos_max_idx = np.max(pos_idx) 126 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 127 | all_INP.append(inp) 128 | 129 | cmc[cmc > 1] = 1 130 | 131 | all_cmc.append(cmc[:max_rank]) 132 | num_valid_q += 1. 133 | 134 | # compute average precision 135 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 136 | num_rel = raw_cmc.sum() 137 | tmp_cmc = raw_cmc.cumsum() 138 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 139 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 140 | AP = tmp_cmc.sum() / num_rel 141 | all_AP.append(AP) 142 | 143 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 144 | 145 | all_cmc = np.asarray(all_cmc).astype(np.float32) 146 | all_cmc = all_cmc.sum(0) / num_valid_q 147 | mAP = np.mean(all_AP) 148 | mINP = np.mean(all_INP) 149 | return all_cmc, mAP, mINP -------------------------------------------------------------------------------- /RegDB_test/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | import sys 5 | import os.path as osp 6 | import torch 7 | 8 | def load_data(input_data_path ): 9 | with open(input_data_path) as f: 10 | data_file_list = open(input_data_path, 'rt').read().splitlines() 11 | # Get full list of color image and labels 12 | file_image = [s.split(' ')[0] for s in data_file_list] 13 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 14 | 15 | return file_image, file_label 16 | 17 | 18 | def GenIdx( train_color_label, train_thermal_label): 19 | color_pos = [] 20 | unique_label_color = np.unique(train_color_label) 21 | for i in range(len(unique_label_color)): 22 | tmp_pos = [k for k,v in enumerate(train_color_label) if v==unique_label_color[i]] 23 | color_pos.append(tmp_pos) 24 | 25 | thermal_pos = [] 26 | unique_label_thermal = np.unique(train_thermal_label) 27 | for i in range(len(unique_label_thermal)): 28 | tmp_pos = [k for k,v in enumerate(train_thermal_label) if v==unique_label_thermal[i]] 29 | thermal_pos.append(tmp_pos) 30 | return color_pos, thermal_pos 31 | 32 | def GenCamIdx(gall_img, gall_label, mode): 33 | if mode =='indoor': 34 | camIdx = [1,2] 35 | else: 36 | camIdx = [1,2,4,5] 37 | gall_cam = [] 38 | for i in range(len(gall_img)): 39 | gall_cam.append(int(gall_img[i][-10])) 40 | 41 | sample_pos = [] 42 | unique_label = np.unique(gall_label) 43 | for i in range(len(unique_label)): 44 | for j in range(len(camIdx)): 45 | id_pos = [k for k,v in enumerate(gall_label) if v==unique_label[i] and 
gall_cam[k]==camIdx[j]] 46 | if id_pos: 47 | sample_pos.append(id_pos) 48 | return sample_pos 49 | 50 | def ExtractCam(gall_img): 51 | gall_cam = [] 52 | for i in range(len(gall_img)): 53 | cam_id = int(gall_img[i][-10]) 54 | # if cam_id ==3: 55 | # cam_id = 2 56 | gall_cam.append(cam_id) 57 | 58 | return np.array(gall_cam) 59 | 60 | class IdentitySampler(Sampler): 61 | """Sample person identities evenly in each batch. 62 | Args: 63 | train_color_label, train_thermal_label: labels of two modalities 64 | color_pos, thermal_pos: positions of each identity 65 | batchSize: batch size 66 | """ 67 | 68 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch): 69 | uni_label = np.unique(train_color_label) 70 | self.n_classes = len(uni_label) 71 | 72 | 73 | N = np.maximum(len(train_color_label), len(train_thermal_label)) 74 | for j in range(int(N/(batchSize*num_pos))+1): 75 | batch_idx = np.random.choice(uni_label, batchSize, replace = False) 76 | for i in range(batchSize): 77 | sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos) 78 | sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos) 79 | 80 | if j ==0 and i==0: 81 | index1= sample_color 82 | index2= sample_thermal 83 | else: 84 | index1 = np.hstack((index1, sample_color)) 85 | index2 = np.hstack((index2, sample_thermal)) 86 | 87 | self.index1 = index1 88 | self.index2 = index2 89 | self.N = N 90 | 91 | def __iter__(self): 92 | return iter(np.arange(len(self.index1))) 93 | 94 | def __len__(self): 95 | return self.N 96 | 97 | class AverageMeter(object): 98 | """Computes and stores the average and current value""" 99 | def __init__(self): 100 | self.reset() 101 | 102 | def reset(self): 103 | self.val = 0 104 | self.avg = 0 105 | self.sum = 0 106 | self.count = 0 107 | 108 | def update(self, val, n=1): 109 | self.val = val 110 | self.sum += val * n 111 | self.count += n 112 | self.avg = self.sum / self.count 113 | 114 | def mkdir_if_missing(directory): 115 | if not osp.exists(directory): 116 | try: 117 | os.makedirs(directory) 118 | except OSError as e: 119 | if e.errno != errno.EEXIST: 120 | raise 121 | class Logger(object): 122 | """ 123 | Write console output to external text file. 124 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 
125 | """ 126 | def __init__(self, fpath=None): 127 | self.console = sys.stdout 128 | self.file = None 129 | if fpath is not None: 130 | mkdir_if_missing(osp.dirname(fpath)) 131 | self.file = open(fpath, 'w') 132 | 133 | def __del__(self): 134 | self.close() 135 | 136 | def __enter__(self): 137 | pass 138 | 139 | def __exit__(self, *args): 140 | self.close() 141 | 142 | def write(self, msg): 143 | self.console.write(msg) 144 | if self.file is not None: 145 | self.file.write(msg) 146 | 147 | def flush(self): 148 | self.console.flush() 149 | if self.file is not None: 150 | self.file.flush() 151 | os.fsync(self.file.fileno()) 152 | 153 | def close(self): 154 | self.console.close() 155 | if self.file is not None: 156 | self.file.close() 157 | 158 | def set_seed(seed, cuda=True): 159 | np.random.seed(seed) 160 | torch.manual_seed(seed) 161 | if cuda: 162 | torch.cuda.manual_seed(seed) 163 | 164 | def set_requires_grad(nets, requires_grad=False): 165 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 166 | Parameters: 167 | nets (network list) -- a list of networks 168 | requires_grad (bool) -- whether the networks require gradients or not 169 | """ 170 | if not isinstance(nets, list): 171 | nets = [nets] 172 | for net in nets: 173 | if net is not None: 174 | for param in net.parameters(): 175 | param.requires_grad = requires_grad -------------------------------------------------------------------------------- /__pycache__/utlis.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/__pycache__/utlis.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Testing demo for SYSU-MM01 dataset 2 | A MATLAB version testing demo is provided for evaluation on SYSU-MM01 dataset in "RGB-Infrared Cross-Modality Person Re-identification, ICCV 2017". 3 | 4 | Dataset download links: 5 | 6 | Baiduyun: http://pan.baidu.com/s/1gfIlcmZ 7 | 8 | Dropbox: https://www.dropbox.com/sh/v036mg1q4yg7awb/AABhxU-FJ4X2oyq7-Ts6bgD0a?dl=0 9 | 10 | Project page: http://isee.sysu.edu.cn/project/RGBIRReID.htm 11 | 12 | Testing code: 13 | https://github.com/wuancong/SYSU-MM01/blob/master/evaluation 14 | 15 | ## Citation 16 | If you use the dataset, please cite the following paper: 17 | 18 | Ancong Wu, Wei-Shi Zheng, Hong-Xing Yu, Shaogang Gong and Jianhuang Lai. RGB-Infrared Cross-Modality Person Re-Identification. IEEE International Conference on Computer Vision (ICCV), 2017. 19 | 20 | ## Testing procedure 21 | 22 | 1. Train a model using the samples of the IDs in "./data_split/train_id.mat" on SYSU-MM01. 23 | 24 | 2. Extract features of SYSU-MM01 dataset using data in "SYSU_MM01.zip" (can be found in the download links). 25 | To apply our provided testing code, the features should be saved in the following form: 26 | Features of each camera are saved in seperated mat files named "name_cam#.mat". 27 | In each mat file, feature{id}(i,:) is a row feature vector of the i-th image of id. 28 | An example of our proposed deep zero padding features is provided in "./feature". 29 | 30 | 3. Run "demo.m". The default setting in "demo.m" is single-shot all-search mode. 
The input parameters can be set according to the comments in "demo.m". A fixed data split of testing set and 10 trials is provided in "./data_split". The average CMC and mAP results of 10 trials of random split will be displayed when testing is finished. The result of this demo is an re-implemented version and is slightly better than the one reported in the paper. 31 | 32 | ## Contact Information 33 | If you have any questions, please feel free to contact wuancong@mail2.sysu.edu.cn. -------------------------------------------------------------------------------- /evaluation/__pycache__/gen_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/__pycache__/gen_utils.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/data_split/rand_perm_cam.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/data_split/rand_perm_cam.mat -------------------------------------------------------------------------------- /evaluation/data_split/test_id.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/data_split/test_id.mat -------------------------------------------------------------------------------- /evaluation/data_split/train_id.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/data_split/train_id.mat -------------------------------------------------------------------------------- /evaluation/demo.m: -------------------------------------------------------------------------------- 1 | % A demo for evaluation on SYSU-MM01 dataset using features learned by deep 2 | % zero padding in "RGB-Infrared Cross-Modality Person Re-identification" 3 | 4 | % Features of each cameras are saved in seperated mat files named "name_cam#.mat" 5 | % In the mat files, feature{id}(i,:) is a row feature vector of the i-th image of id 6 | %feature_info.name = 'feat_deep_zero_padding_'; 7 | %feature_info.dir = './feature'; 8 | feature_info.dir = '../'; 9 | feature_info.name = ''; 10 | result_dir = './result'; % directory for saving result 11 | 12 | setting.mode = 'all_search'; %'all_search' 'indoor_search' 13 | %setting.mode = 'indoor_search'; 14 | setting.number_shot = 1; % 1 for single shot, 10 for multi-shot 15 | %setting.number_shot = 10; 16 | model.test_fun = @euclidean_dist; % Similarity measurement function 17 | % (You could define your own function in the same way as euclidean_dist function) 18 | model.name = 'euclidean'; % model name 19 | model.para = []; % No parameter is needed for euclidean distance here 20 | % (If mahalanobis distance is used, the parameter can be the metric M learned from training data) 21 | 22 | % load data split 23 | 
content = load('./data_split/test_id.mat'); % fixed testing person IDs 24 | data_split.test_id = content.id; 25 | 26 | content = load('./data_split/rand_perm_cam.mat'); % fixed permutation of samples in each camera 27 | data_split.rand_perm_cam = content.rand_perm_cam; 28 | 29 | % evaluatio 30 | disp('all-search_single_shot') 31 | performance = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir); 32 | disp('all-search_multi_shot') 33 | setting.mode = 'all_search'; %'all_search' 'indoor_search' 34 | %setting.mode = 'indoor_search'; 35 | setting.number_shot = 10; % 1 for single shot, 10 for multi-shot 36 | performance = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir); 37 | 38 | disp('indoor_search_single_shot') 39 | setting.mode = 'indoor_search'; %'all_search' 'indoor_search' 40 | %setting.mode = 'indoor_search'; 41 | setting.number_shot = 1; % 1 for single shot, 10 for multi-shot 42 | performance = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir); 43 | disp('indoor_search_multi_shot') 44 | setting.mode = 'indoor_search'; %'all_search' 'indoor_search' 45 | %setting.mode = 'indoor_search'; 46 | setting.number_shot = 10; % 1 for single shot, 10 for multi-shot 47 | performance = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir); -------------------------------------------------------------------------------- /evaluation/euclidean_dist.m: -------------------------------------------------------------------------------- 1 | function [ dist ] = euclidean_dist( X_gallery, X_probe, model_para ) 2 | dist=pdist2(X_probe,X_gallery,'euclidean'); 3 | end 4 | 5 | -------------------------------------------------------------------------------- /evaluation/evaluate_SYSU_MM01.py: -------------------------------------------------------------------------------- 1 | import os, sys, os.path as osp 2 | from evaluation import gen_utils 3 | import numpy as np 4 | 5 | 6 | def evaluate(feature_dir, prefix, settings, total_runs=10): 7 | """to evaluate and rank results for SYSU_MM01 dataset 8 | 9 | Arguments: 10 | feature_dir {str} -- a dir where features are saved 11 | prefix {str} -- prefix of file names 12 | """ 13 | gallery_cams, probe_cams = gen_utils.get_cam_settings(settings) 14 | all_cams = list(set(gallery_cams + probe_cams)) # get unique cams 15 | features = {} 16 | 17 | # get permutation indices 18 | cam_permutations = gen_utils.get_cam_permutation_indices(all_cams) 19 | 20 | # get test ids 21 | test_ids = gen_utils.get_test_ids() 22 | 23 | for cam_index in all_cams: 24 | # read features 25 | cam_feature_file = osp.join(feature_dir, ("cam{}").format(cam_index)) 26 | features["cam" + str(cam_index)] = gen_utils.load_feature_file(cam_feature_file) 27 | 28 | # perform testing 29 | print(list(features.keys())) 30 | cam_id_locations = [1, 2, 2, 4, 5, 6] 31 | # camera 2 and 3 are in the same location 32 | mAPs = [] 33 | cmcs = [] 34 | 35 | for run_index in range(total_runs): 36 | print("trial #{}".format(run_index)) 37 | X_gallery, Y_gallery, cam_gallery, X_probe, Y_probe, cam_probe = gen_utils.get_testing_set( 38 | features, 39 | cam_permutations, 40 | test_ids, 41 | run_index, 42 | cam_id_locations, 43 | gallery_cams, 44 | probe_cams, 45 | settings, 46 | ) 47 | 48 | # print(X_gallery.shape, Y_gallery.shape, cam_gallery.shape, X_probe.shape, Y_probe.shape, cam_probe.shape) 49 | 50 | dist = gen_utils.euclidean_dist(X_probe, X_gallery) 51 | 52 | cmc = gen_utils.get_cmc_multi_cam( 53 | Y_gallery, cam_gallery, 
Y_probe, cam_probe, dist 54 | ) 55 | mAP = gen_utils.get_mAP_multi_cam( 56 | Y_gallery, cam_gallery, Y_probe, cam_probe, dist 57 | ) 58 | 59 | print("rank 1 5 10 20", cmc[[0, 4, 9, 19]]) 60 | print("mAP", mAP) 61 | cmcs.append(cmc) 62 | mAPs.append(mAP) 63 | 64 | # find mean mAP and cmc 65 | cmcs = np.array(cmcs) # 10 x #gallery 66 | mAPs = np.array(mAPs) # 10 67 | mean_cmc = np.mean(cmcs, axis=0) 68 | mean_mAP = np.mean(mAPs) 69 | print("mean rank 1 5 10 20", mean_cmc[[0, 4, 9, 19]]) 70 | print("mean mAP", mean_mAP) 71 | 72 | 73 | def evaluate_results(feature_dir, prefix, mode, number_shot, total_runs=10): 74 | # evaluation settings 75 | print( 76 | "running evaluation for features from {} with prefix {} with mode {} and number shot {}".format( 77 | feature_dir, prefix, mode, number_shot 78 | ) 79 | ) 80 | print('total test runs:', total_runs) 81 | settings = {} 82 | settings["mode"] = mode # indoor | all 83 | settings["number_shot"] = number_shot # 1 = single-shot | 10 = multi-shot 84 | evaluate(feature_dir, prefix, settings, total_runs) 85 | 86 | 87 | if __name__ == "__main__": 88 | # example function call 89 | feature_dir = "/home/txd/Variational Distillation/" 90 | prefix = "cam" 91 | 92 | evaluate_results(feature_dir, prefix, 'all', 1) -------------------------------------------------------------------------------- /evaluation/evaluation_SYSU_MM01.m: -------------------------------------------------------------------------------- 1 | function [performance] = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir) 2 | % evaluation on SYSU-MM01 dataset with input features and model 3 | % input: 4 | % Features of each cameras are saved in seperated mat files named "name_cam#.mat" 5 | % In the mat files, feature{id}(i,:) is a row feature vector of the i-th image of id 6 | % feature_info.name = 'feat_deep_zero_padding'; 7 | % feature_info.dir = './feature'; 8 | % result_dir = './result'; % directory for saving result 9 | % 10 | % setting.mode = 'all_search'; %'all_search' 'indoor_search' 11 | % setting.number_shot = 1; % 1 for single shot, 10 for multi-shot 12 | % 13 | % model.test_fun = @euclidean_dist; % Similarity measurement function (You could define your own function in the same way as euclidean_dist function) 14 | % model.name = 'euclidean'; % model name 15 | % model.para = []; % No parameter is needed for euclidean distance here (If mahalanobis distance is used, the parameter can be the metric M learned from training data) 16 | % 17 | % content = load('./data_split/test_id.mat'); % fixed testing person IDs 18 | % data_split.test_id = content.id; 19 | % 20 | % content = load('./data_split/rand_perm_cam.mat'); % fixed permutation of samples in each camera 21 | % data_split.rand_perm_cam = content.rand_perm_cam; 22 | % 23 | % output: 24 | % performance.cmc_mean & performance.map_mean - average results of 10 trials 25 | % performance.cmc_all & performance.map_mean - results of each trial 26 | 27 | feature_name = feature_info.name; 28 | feature_dir = feature_info.dir; 29 | 30 | mode = setting.mode; % 'all_search' 'indoor_search' 31 | number_shot = setting.number_shot; % 1 for single shot, 10 for multi-shot 32 | test_id = data_split.test_id; 33 | rand_perm_cam = data_split.rand_perm_cam; 34 | 35 | %% begin 36 | switch mode 37 | case 'all_search' 38 | gallery_cam_list=[1 2 4 5]; 39 | probe_cam_list=[3 6]; 40 | case 'indoor_search' 41 | gallery_cam_list=[1 2]; 42 | probe_cam_list=[3 6]; 43 | otherwise 44 | disp('mode input error'); 45 | end 46 | 47 | % load features of 
6 cameras 48 | load_cam_list = union(probe_cam_list,gallery_cam_list); 49 | cam_count = 6; 50 | feature_cam=cell(cam_count,1); 51 | Y=cell(cam_count,1); 52 | 53 | cam_id = [1 2 2 4 5 6]; % camera 2 and 3 are in the same location 54 | 55 | for i_cam=1:length(load_cam_list) 56 | cam_label=load_cam_list(i_cam); 57 | load_name=[feature_name 'cam' num2str(cam_label) '.mat']; 58 | content=load(fullfile(feature_dir,load_name)); 59 | feature_cam{cam_label}=content.feature_test; 60 | feature_cam{cam_label}=feature_cam{cam_label}'; 61 | Y{cam_label}=(1:length(content.feature_test)); 62 | end 63 | clear content 64 | 65 | % begin testing 66 | cmc_all=cell(10,1); 67 | map_all=zeros(10,1); 68 | 69 | for run_time=1:10 70 | %disp(['trial #',num2str(run_time)]); 71 | % For X_..., each row is an observation 72 | [X_gallery, Y_gallery, Y_cam_gallery, X_probe, Y_probe, Y_cam_probe]=get_testing_set... 73 | (feature_cam, Y, rand_perm_cam, run_time, number_shot, gallery_cam_list, probe_cam_list, test_id, cam_id); 74 | dist = model.test_fun(X_gallery,X_probe,model.para); 75 | cmc = get_cmc_multi_cam(Y_gallery,Y_cam_gallery,Y_probe,Y_cam_probe,dist); 76 | map = get_map_multi_cam(Y_gallery,Y_cam_gallery,Y_probe,Y_cam_probe,dist); 77 | %disp('rank 1 5 10 20'); 78 | %disp(cmc([1 5 10 20])); 79 | %disp('mAP'); 80 | %disp(map); 81 | cmc_all{run_time}=cmc; 82 | map_all(run_time,:)=map; 83 | end 84 | cmc_all=cell2mat(cmc_all); 85 | performance.cmc_all=cmc_all; 86 | performance.map_all=map_all; 87 | cmc_mean=mean(performance.cmc_all); 88 | performance.cmc_mean=cmc_mean; 89 | map_mean=mean(performance.map_all); 90 | performance.map_mean=map_mean; 91 | 92 | % display 93 | disp('Average CMC:'); 94 | disp('rank 1 5 10 20'); 95 | disp(cmc_mean([ 1 5 10 20])); 96 | disp('Average mAP:'); 97 | disp(map_mean); 98 | 99 | % save 100 | save_path=fullfile(result_dir,['result_' feature_name '_' model.name '_' mode '_' num2str(number_shot) 'shot.mat']); 101 | save(save_path,'performance','setting','-v7.3'); 102 | 103 | end -------------------------------------------------------------------------------- /evaluation/gen_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | from scipy.spatial.distance import cdist 4 | import scipy.io as sio 5 | 6 | 7 | def euclidean_dist(X_probe, X_gallery): 8 | return cdist(X_probe, X_gallery) 9 | 10 | 11 | def get_cam_settings(settings): 12 | if settings["mode"] == "indoor": 13 | gallery_cams, probe_cams = [1, 2], [3, 6] 14 | elif settings["mode"] == "all": 15 | gallery_cams, probe_cams = [1, 2, 4, 5], [3, 6] 16 | else: 17 | assert False, "unknown search mode : " + settings["mode"] 18 | 19 | return gallery_cams, probe_cams 20 | 21 | 22 | def get_test_ids(): 23 | test_id_filepath = "./data_split/test_id.mat" 24 | filecontents = sio.loadmat(test_id_filepath) 25 | test_ids = filecontents["id"].squeeze() 26 | 27 | # make the test ids 0-based 28 | test_ids = test_ids - 1 29 | return test_ids 30 | 31 | 32 | def get_cam_permutation_indices(all_cams): 33 | rand_cam_perm_filepath = "./data_split/rand_perm_cam.mat" 34 | filecontents = sio.loadmat(rand_cam_perm_filepath) 35 | mat_cam_permutations = filecontents["rand_perm_cam"] 36 | 37 | # buffer to hold all permutations 38 | all_permutations = {} 39 | 40 | for cam_index in all_cams: 41 | cam_permutations = mat_cam_permutations[cam_index - 1][0].squeeze() 42 | cam_name = "cam" + str(cam_index) 43 | if cam_name not in all_permutations: 44 | all_permutations[cam_name] = {} 45 | 46 | # 
logistics 47 | print("{} ids found in cam {}".format(len(cam_permutations), cam_index)) 48 | 49 | # collect permutations for all person indices 50 | for person_index, rand_permutations in enumerate(cam_permutations): 51 | all_permutations[cam_name][person_index] = rand_permutations - 1 52 | 53 | return all_permutations 54 | 55 | 56 | def load_feature_file(cam_feature_file): 57 | filecontents = sio.loadmat(cam_feature_file) 58 | mat_features = filecontents["feature_test"].squeeze() 59 | all_features = {} 60 | 61 | # collect features for each person (total 333 persons) 62 | for person_index, current_features in enumerate(mat_features): 63 | all_features[person_index] = current_features 64 | 65 | return all_features 66 | 67 | 68 | def get_testing_set( 69 | features, 70 | cam_permutations, 71 | test_ids, 72 | run_index, 73 | cam_id_locations, 74 | gallery_cams, 75 | probe_cams, 76 | settings, 77 | ): 78 | # cam is indexed from 1 - 6 79 | # person indices are indexed using 0-based numbers 80 | X_gallery_rgb, Y_gallery_rgb, cam_gallery_rgb, X_probe_IR, Y_probe_IR, cam_probe_IR = ( 81 | [], 82 | [], 83 | [], 84 | [], 85 | [], 86 | [], 87 | ) 88 | 89 | number_shot = settings["number_shot"] 90 | 91 | # collect rgb images as gallery 92 | for cam_index in gallery_cams: 93 | cam_name = "cam" + str(cam_index) 94 | current_cam_features = features[cam_name] 95 | 96 | # for all the test ids, collect features 97 | for test_id in test_ids: 98 | current_id_features = current_cam_features[test_id] 99 | 100 | if np.any(np.array(current_id_features.shape) == 0): 101 | continue 102 | # assert (not np.any(np.array(current_id_features.shape) == 0)), 'test id feature count is 0' 103 | 104 | # get the current permutation 105 | current_permutation = cam_permutations[cam_name][test_id][run_index] 106 | current_permutation = current_permutation[:number_shot] 107 | 108 | selected_features = current_id_features[current_permutation, :] 109 | 110 | if len(X_gallery_rgb) == 0: 111 | X_gallery_rgb = selected_features 112 | else: 113 | X_gallery_rgb = np.concatenate( 114 | (X_gallery_rgb, selected_features), axis=0 115 | ) 116 | 117 | Y_gallery_rgb += [test_id] * number_shot 118 | cam_gallery_rgb += [cam_id_locations[cam_index - 1]] * number_shot 119 | 120 | Y_gallery_rgb = np.array(Y_gallery_rgb, dtype=np.int) 121 | cam_gallery_rgb = np.array(cam_gallery_rgb, dtype=np.int) 122 | 123 | # collect all the IR 124 | for cam_index in probe_cams: 125 | cam_name = "cam" + str(cam_index) 126 | current_cam_features = features[cam_name] 127 | 128 | # for all the test ids, collect features 129 | for test_id in test_ids: 130 | current_id_features = current_cam_features[test_id] 131 | if np.any(np.array(current_id_features.shape) == 0): 132 | continue 133 | # assert len(current_id_features) != 0, 'test id feature count is 0' 134 | 135 | if len(X_probe_IR) == 0: 136 | X_probe_IR = current_id_features 137 | 138 | else: 139 | X_probe_IR = np.concatenate((X_probe_IR, current_id_features), axis=0) 140 | 141 | Y_probe_IR += [test_id] * len(current_id_features) 142 | cam_probe_IR += [cam_id_locations[cam_index - 1]] * len(current_id_features) 143 | 144 | Y_probe_IR = np.array(Y_probe_IR, dtype=np.int) 145 | cam_probe_IR = np.array(cam_probe_IR, dtype=np.int) 146 | 147 | return ( 148 | X_gallery_rgb, 149 | Y_gallery_rgb, 150 | cam_gallery_rgb, 151 | X_probe_IR, 152 | Y_probe_IR, 153 | cam_probe_IR, 154 | ) 155 | 156 | 157 | def get_unique(array): 158 | _, idx = np.unique(array, return_index=True) 159 | return array[np.sort(idx)] 160 | 161 | 162 | 
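# Note (illustrative example): get_unique keeps first-occurrence order, unlike
# a bare np.unique, which sorts its output, e.g.
#     get_unique(np.array([5, 3, 5, 1, 3]))  ->  array([5, 3, 1])
# get_cmc_multi_cam below relies on this stable de-duplication when collapsing
# the ranked gallery labels to a single entry per identity.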
def get_cmc_multi_cam(Y_gallery, cam_gallery, Y_probe, cam_probe, dist): 163 | # dist = #probe x #gallery 164 | num_probes, num_gallery = dist.shape 165 | gallery_unique_count = get_unique(Y_gallery).shape[0] 166 | match_counter = np.zeros((gallery_unique_count)) 167 | 168 | # sort the distance matrix 169 | sorted_indices = np.argsort(dist, axis=-1) 170 | 171 | Y_result = Y_gallery[sorted_indices] 172 | cam_locations_result = cam_gallery[sorted_indices] 173 | 174 | valid_probe_sample_count = 0 175 | 176 | for probe_index in range(num_probes): 177 | # remove gallery samples from the same camera of the probe 178 | Y_result_i = Y_result[probe_index, :] 179 | Y_result_i[cam_locations_result[probe_index, :] == cam_probe[probe_index]] = -1 180 | 181 | # remove the -1 entries from the label result 182 | # print(Y_result_i.shape) 183 | Y_result_i = np.array([i for i in Y_result_i if i != -1]) 184 | # print(Y_result_i.shape) 185 | 186 | # remove duplicated id in "stable" manner 187 | Y_result_i_unique = get_unique(Y_result_i) 188 | # print(Y_result_i_unique, gallery_unique_count) 189 | 190 | # match for probe i 191 | match_i = Y_result_i_unique == Y_probe[probe_index] 192 | 193 | if np.sum(match_i) != 0: # if there is true matching in gallery 194 | valid_probe_sample_count += 1 195 | match_counter += match_i 196 | 197 | rankk = match_counter / valid_probe_sample_count 198 | cmc = np.cumsum(rankk) 199 | return cmc 200 | 201 | 202 | def get_mAP_multi_cam(Y_gallery, cam_gallery, Y_probe, cam_probe, dist): 203 | # dist = #probe x #gallery 204 | num_probes, num_gallery = dist.shape 205 | 206 | # sort the distance matrix 207 | sorted_indices = np.argsort(dist, axis=-1) 208 | 209 | Y_result = Y_gallery[sorted_indices] 210 | cam_locations_result = cam_gallery[sorted_indices] 211 | 212 | valid_probe_sample_count = 0 213 | avg_precision_sum = 0 214 | 215 | for probe_index in range(num_probes): 216 | # remove gallery samples from the same camera of the probe 217 | Y_result_i = Y_result[probe_index, :] 218 | Y_result_i[cam_locations_result[probe_index, :] == cam_probe[probe_index]] = -1 219 | 220 | # remove the -1 entries from the label result 221 | # print(Y_result_i.shape) 222 | Y_result_i = np.array([i for i in Y_result_i if i != -1]) 223 | # print(Y_result_i.shape) 224 | 225 | # match for probe i 226 | match_i = Y_result_i == Y_probe[probe_index] 227 | true_match_count = np.sum(match_i) 228 | 229 | if true_match_count != 0: # if there is true matching in gallery 230 | valid_probe_sample_count += 1 231 | true_match_rank = np.where(match_i)[0] 232 | 233 | ap = np.mean( 234 | np.array(range(1, true_match_count + 1)) / (true_match_rank + 1) 235 | ) 236 | avg_precision_sum += ap 237 | 238 | mAP = avg_precision_sum / valid_probe_sample_count 239 | return mAP -------------------------------------------------------------------------------- /evaluation/get_cmc_multi_cam.m: -------------------------------------------------------------------------------- 1 | function [cmc,ind]=get_cmc_multi_cam(Y_gallery,Y_cam_gallery,Y_probe,Y_cam_probe,dist) 2 | [~, ind]=sort(dist,2); 3 | Y_result=Y_gallery(ind); 4 | Y_cam_result=Y_cam_gallery(ind); 5 | valid_probe_sample_count=0; 6 | gallery_unique_count=length(unique(Y_gallery)); 7 | match_counter=zeros(1,gallery_unique_count); 8 | probe_sample_count=length(Y_probe); 9 | for i=1:probe_sample_count 10 | % remove gallery samples from the same camera of the probe 11 | Y_result_i=Y_result(i,:); 12 | Y_result_i(Y_cam_result(i,:)==Y_cam_probe(i))=[]; 13 | % remove duplicated id 14 | 
Y_result_unique_i=unique(Y_result_i,'stable'); 15 | % match for probe i 16 | match_i=(Y_result_unique_i==Y_probe(i)); 17 | if sum(match_i)~=0 % if there is true matching in gallery 18 | valid_probe_sample_count=valid_probe_sample_count+1; 19 | for r=1:length(match_i) 20 | match_counter(r)=match_counter(r)+match_i(r); 21 | end 22 | end 23 | end 24 | rankk=match_counter/valid_probe_sample_count; 25 | cmc=cumsum(rankk); 26 | end -------------------------------------------------------------------------------- /evaluation/get_map_multi_cam.m: -------------------------------------------------------------------------------- 1 | function map=get_map_multi_cam(Y_gallery,Y_cam_gallery,Y_probe,Y_cam_probe,dist) 2 | [~, ind]=sort(dist,2); 3 | Y_result=Y_gallery(ind); 4 | Y_cam_result=Y_cam_gallery(ind); 5 | valid_probe_sample_count=0; 6 | probe_sample_count=length(Y_probe); 7 | ap_sum=0; 8 | for i=1:probe_sample_count 9 | % remove gallery samples from the same camera of the probe 10 | Y_result_i=Y_result(i,:); 11 | Y_result_i(Y_cam_result(i,:)==Y_cam_probe(i))=[]; 12 | % match for probe i 13 | match_i=(Y_result_i==Y_probe(i)); 14 | true_match_count=sum(match_i); 15 | if true_match_count~=0 % if there is true matching in gallery 16 | valid_probe_sample_count=valid_probe_sample_count+1; 17 | true_match_rank=find(match_i==1); 18 | ap=mean((1:true_match_count)./true_match_rank); 19 | ap_sum=ap_sum+ap; 20 | end 21 | end 22 | map=ap_sum/valid_probe_sample_count; 23 | end -------------------------------------------------------------------------------- /evaluation/get_testing_set.m: -------------------------------------------------------------------------------- 1 | function [X_gallery, Y_gallery, Y_cam_gallery, X_probe, Y_probe, Y_cam_probe]=get_testing_set... 2 | (feature_cam, Y, rand_perm_cam, run_time, number_shot, gallery_cam_list, probe_cam_list, test_id, cam_id) 3 | % get testing set for SYSU-MM01 multi-modality re-id dataset 4 | % input: 5 | % feature_cam - feature_cam{i}{id} is feature matrix (each row is a feature) of cam i of person id 6 | % Y - person id for each cell 7 | % rand_perm - rand permutation of indices for selecting gallery 8 | % run_time - current count of evaluation time 9 | % number_shot - 1 single shot, 5 multi-shot, -1 all except one 10 | % test_cam_list - cam list of testing set 11 | % test_id - list of testing persons 12 | % output: 13 | % X_... - feature vectors in each row 14 | % Y_... - person label 15 | % Y_cam_... 
- camera label 16 | 17 | gallery_cam_count=length(gallery_cam_list); 18 | X_gallery=cell(gallery_cam_count,1); 19 | Y_gallery=cell(gallery_cam_count,1); 20 | Y_cam_gallery=cell(gallery_cam_count,1); 21 | probe_cam_count=length(probe_cam_list); 22 | X_probe=cell(probe_cam_count,1); 23 | Y_probe=cell(probe_cam_count,1); 24 | Y_cam_probe=cell(probe_cam_count,1); 25 | 26 | % gallery 27 | for i_cam=1:gallery_cam_count 28 | cam_num=gallery_cam_list(i_cam); 29 | cam_label=cam_id(cam_num); 30 | Y_i_cam=Y{cam_num}; 31 | id_count=length(Y_i_cam); 32 | 33 | X_gallery{i_cam}=cell(id_count,1); 34 | Y_gallery{i_cam}=[]; 35 | Y_cam_gallery{i_cam}=[]; 36 | for i_id=1:id_count 37 | if isempty(find(test_id==Y_i_cam(i_id))) 38 | continue; 39 | end 40 | rand_perm_this=rand_perm_cam{cam_num}{i_id}(run_time,:); 41 | if isempty(rand_perm_this) 42 | continue; 43 | end 44 | if number_shot<0 45 | ind_gallery_this=rand_perm_this(1:end+number_shot); 46 | else 47 | ind_gallery_this=rand_perm_this(1:number_shot); 48 | end 49 | X_test_this=feature_cam{cam_num}{i_id}; 50 | X_gallery{i_cam}{i_id}=X_test_this(ind_gallery_this,:); 51 | frame_count=size(X_gallery{i_cam}{i_id},1); 52 | Y_gallery{i_cam}=[Y_gallery{i_cam};repmat(Y_i_cam(i_id),frame_count,1)]; 53 | Y_cam_gallery{i_cam}=[Y_cam_gallery{i_cam};repmat(cam_label,frame_count,1)]; 54 | end 55 | X_gallery{i_cam}=cell2mat(X_gallery{i_cam}); 56 | end 57 | 58 | % probe 59 | for i_cam=1:probe_cam_count 60 | cam_num=probe_cam_list(i_cam); 61 | cam_label=cam_id(cam_num); 62 | Y_i_cam=Y{cam_num}; 63 | id_count=length(Y_i_cam); 64 | 65 | X_probe{i_cam}=cell(id_count,1); 66 | Y_probe{i_cam}=[]; 67 | Y_cam_probe{i_cam}=[]; 68 | for i_id=1:id_count 69 | if isempty(find(test_id==Y_i_cam(i_id))) 70 | continue; 71 | end 72 | rand_perm_this=rand_perm_cam{cam_num}{i_id}(run_time,:); 73 | if isempty(rand_perm_this) 74 | continue; 75 | end 76 | X_test_this=feature_cam{cam_num}{i_id}; 77 | X_probe{i_cam}{i_id}=X_test_this; 78 | frame_count=size(X_probe{i_cam}{i_id},1); 79 | Y_probe{i_cam}=[Y_probe{i_cam};repmat(Y_i_cam(i_id),frame_count,1)]; 80 | Y_cam_probe{i_cam}=[Y_cam_probe{i_cam};repmat(cam_label,frame_count,1)]; 81 | end 82 | X_probe{i_cam}=cell2mat(X_probe{i_cam}); 83 | end 84 | 85 | X_gallery=cell2mat(X_gallery); 86 | Y_gallery=cell2mat(Y_gallery); 87 | Y_cam_gallery=cell2mat(Y_cam_gallery); 88 | 89 | X_probe=cell2mat(X_probe); 90 | Y_probe=cell2mat(Y_probe); 91 | Y_cam_probe=cell2mat(Y_cam_probe); 92 | 93 | end -------------------------------------------------------------------------------- /evaluation/result/result__euclidean_all_search_10shot.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/result/result__euclidean_all_search_10shot.mat -------------------------------------------------------------------------------- /evaluation/result/result__euclidean_all_search_1shot.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/result/result__euclidean_all_search_1shot.mat -------------------------------------------------------------------------------- /evaluation/result/result__euclidean_indoor_search_10shot.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/result/result__euclidean_indoor_search_10shot.mat -------------------------------------------------------------------------------- /evaluation/result/result__euclidean_indoor_search_1shot.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/result/result__euclidean_indoor_search_1shot.mat -------------------------------------------------------------------------------- /evaluation/train_id.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/train_id.mat -------------------------------------------------------------------------------- /evaluation/val_id.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/val_id.mat -------------------------------------------------------------------------------- /extract_feature.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | 4 | import argparse 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import scipy.io as sio 9 | from torch.backends import cudnn 10 | 11 | from reid.models.newresnet import * 12 | from reid.utils.serialization import load_checkpoint 13 | import torchvision.transforms as transforms 14 | from PIL import Image 15 | from reid.utils.data import transforms as T 16 | 17 | 18 | def getArgs(): 19 | parser = argparse.ArgumentParser(description="Cross_modality for Person Re-identification") 20 | # data 21 | parser.add_argument('--height', type=int, default=256, 22 | help="input height, default: 256 for resnet*, " 23 | "144 for inception") 24 | parser.add_argument('--width', type=int, default=128, 25 | help="input width, default: 128 for resnet*, " 26 | "56 for inception") 27 | 28 | parser.add_argument('--features', type=int, default=2048) 29 | parser.add_argument('--dropout', type=float, default=0) 30 | 31 | # Bottleneck 32 | parser.add_argument('-z_dim', type=int, default=256, 33 | help="dimension of latent z, better belongs to {128, 256, 512}") 34 | # device set 35 | parser.add_argument('--visible_device', default='1, 0', type=str, help='gpu_ids: e.g. 
0; 0,1,2; 0,2') 36 | parser.add_argument('--weight-decay', type=float, default=5e-4) 37 | return parser.parse_args() 38 | 39 | def count_param(model): 40 | param_count = 0 41 | for param in model.parameters(): 42 | param_count += param.view(-1).size()[0] 43 | return param_count 44 | 45 | args = getArgs() 46 | 47 | features = 8192 # 8192 when concatenating the four observations; 2048 when concatenating the four representations 48 | os.environ["CUDA_VISIBLE_DEVICES"] = "1, 0" 49 | model = ft_net(args=args, num_classes=395, num_features=args.features) 50 | model_path = "/home/txd/Variational Distillation/Exp_before_OpenSource_7/" 51 | # the checkpoint is loaded once below, with its keys remapped to the DataParallel layout 52 | model = nn.DataParallel(model, device_ids=[0, 1]) 53 | cudnn.benchmark = True 54 | ######################################### 55 | checkpoint = load_checkpoint(model_path + 'model_best.pth.tar') 56 | state_dict = checkpoint['model'] 57 | from collections import OrderedDict 58 | new_state_dict = OrderedDict() 59 | for k, v in state_dict.items(): 60 | if 'module' not in k or 'cam_module' in k: 61 | k = 'module.' + k # prepend the DataParallel prefix to bare keys 62 | else: 63 | k = k.replace('features.module.', 'module.features.') # move a misplaced prefix to the front 64 | new_state_dict[k] = v 65 | model.load_state_dict(new_state_dict) 66 | ######################################### 67 | model = model.cuda() 68 | model.eval() 69 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], 70 | std=[0.229, 0.224, 0.225]) 71 | 72 | test_transformer = T.Compose([ 73 | T.Resize((256, 128)), 74 | T.ToTensor(), 75 | normalizer, 76 | ]) 77 | 78 | print('model initialized\n') 79 | print("The dimension of the feature is " + str(features)) 80 | print("Total params: " + str(count_param(model))) 81 | 82 | cam_name = '' 83 | 84 | path_list = ['./data/sysu/SYSU-MM01/cam1/','./data/sysu/SYSU-MM01/cam2/','./data/sysu/SYSU-MM01/cam3/', \ 85 | './data/sysu/SYSU-MM01/cam4/','./data/sysu/SYSU-MM01/cam5/','./data/sysu/SYSU-MM01/cam6/'] 86 | pic_num = [333,333,533,533,533,333] # largest person id appearing in each camera 87 | for index, path in enumerate(path_list): 88 | print(index) 89 | cams = torch.LongTensor([index]) 90 | sub = ((cams == 2).long() + (cams == 5).long()).cuda() # cam3 and cam6 are the infrared cameras 91 | count = 1 92 | array_list = [] 93 | person_id_list = [] 94 | dict_person = {} 95 | tot_num = pic_num[index] 96 | array_list_to_array = [[] for _ in range(tot_num)] 97 | #print(path) 98 | for fpathe,dirs,fs in os.walk(path): 99 | person_id = fpathe.split('/')[-1] 100 | if(person_id == ''): 101 | continue 102 | cam_name = fpathe[-9:-5] 103 | fs.sort() 104 | person_id_list.append(person_id) 105 | dict_person[person_id] = fs 106 | person_id_list.sort() 107 | for person in person_id_list: 108 | temp_list = [] 109 | for imagename in dict_person[person]: 110 | filename = path + str(person) + '/' + imagename 111 | img=Image.open(filename) 112 | img=test_transformer(img) 113 | img=img.view([1,3,256,128]) 114 | img=img.cuda() 115 | i_observation, i_representation, i_ms_observation, i_ms_representation, \ 116 | v_observation, v_representation, v_ms_observation, v_ms_representation = model(img) 117 | result_y = torch.cat(tensors=[i_observation[1], i_ms_observation[1], 118 | v_observation[1], v_ms_observation[1]], dim=1) 119 | 120 | result_y = torch.nn.functional.normalize(result_y, dim=1, p=2) 121 | result_y = result_y.view(-1, features) 122 | result_y = result_y.squeeze() 123 | result_npy=result_y.data.cpu().numpy() 124 | result_npy=result_npy.astype('double') 125 | temp_list.append(result_npy) 126 | temp_array = np.array(temp_list) 127 | array_list_to_array[int(person)-1]=temp_array 128 |
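Each slot of array_list_to_array holds a (frame_count x 8192) matrix, and frame counts differ per person, so the container must be object-dtype for savemat to emit the MATLAB cell array that evaluation/get_testing_set.m indexes as feature_cam{cam}{id}. A minimal sketch of that round trip (shapes and filename are illustrative; note that newer NumPy refuses np.array() on ragged lists unless dtype=object is given):

import numpy as np
import scipy.io as sio

num_persons = 3
feature_cells = np.empty(num_persons, dtype=object)         # one cell per person id
for pid, frame_count in enumerate([5, 2, 7]):               # ragged per-person frame counts
    feature_cells[pid] = np.random.rand(frame_count, 8192)  # one 8192-d feature per frame
sio.savemat('observation_cam1_demo.mat', {'feature_test': feature_cells})

loaded = sio.loadmat('observation_cam1_demo.mat')['feature_test']
print(loaded.shape)        # (1, 3): a 1x3 cell array on the MATLAB side
print(loaded[0, 0].shape)  # (5, 8192): the first person's per-frame features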
array_list_to_array = np.array(array_list_to_array) 129 | 130 | sio.savemat(model_path + "observation" + cam_name + '.mat', {'feature_test':array_list_to_array}) 131 | 132 | -------------------------------------------------------------------------------- /images/embedding_spaces.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/images/embedding_spaces.jpg -------------------------------------------------------------------------------- /images/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/images/framework.jpg -------------------------------------------------------------------------------- /images/joint_embedding_spaces.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/images/joint_embedding_spaces.jpg -------------------------------------------------------------------------------- /mm01.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import os 5 | import numpy as np 6 | import sys 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | from torch.utils.data import DataLoader 11 | from reid import datasets 12 | from reid.dist_metric import DistanceMetric 13 | from reid.models import ft_net 14 | from reid.trainers import Trainer 15 | from reid.evaluators import Evaluator 16 | from reid.utils.data import transforms as T 17 | from reid.utils.data.preprocessor import Preprocessor 18 | from reid.utils.data.sampler import CamRandomIdentitySampler as RandomIdentitySampler 19 | from reid.utils.data.sampler import CamSampler 20 | from reid.utils.logging import Logger 21 | from reid.utils.serialization import load_checkpoint, save_checkpoint 22 | from utlis import RandomErasing, WarmupMultiStepLR, CrossEntropyLabelSmooth, Rank_loss, ASS_loss 23 | 24 | 25 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances, 26 | workers, combine_trainval, flip_prob, padding, re_prob, using_HuaWeiCloud, cloud_dataset_root): 27 | root = osp.join(data_dir, name) 28 | 29 | if using_HuaWeiCloud: root = cloud_dataset_root 30 | 31 | print(root) 32 | 33 | dataset = datasets.create(name, root, split_id=split_id) 34 | 35 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 36 | std=[0.229, 0.224, 0.225]) 37 | 38 | trainvallabel = dataset.trainvallabel 39 | train_set = dataset.trainval if combine_trainval else dataset.train 40 | num_classes = (dataset.num_trainval_ids if combine_trainval 41 | else dataset.num_train_ids) 42 | 43 | train_transformer = T.Compose([ 44 | T.Resize((height, width)), 45 | T.RandomHorizontalFlip(p=flip_prob), 46 | T.Pad(padding), 47 | T.RandomCrop((height, width)), 48 | T.ToTensor(), 49 | normalizer, 50 | RandomErasing(probability=re_prob, mean=[0.485, 0.456, 0.406]) 51 | ]) 52 | 53 | test_transformer = T.Compose([ 54 | 
T.Resize((height, width)), 55 | T.ToTensor(), 56 | normalizer, 57 | ]) 58 | 59 | val_loader = DataLoader( 60 | Preprocessor(dataset.val, root=dataset.images_dir, 61 | transform=test_transformer), 62 | batch_size=32, num_workers=workers, 63 | shuffle=False, pin_memory=True) 64 | 65 | query_loader = DataLoader( 66 | Preprocessor(list(set(dataset.query)), 67 | root=dataset.images_dir, transform=test_transformer), 68 | batch_size=32, num_workers=workers, 69 | sampler=CamSampler(list(set(dataset.query)), [2,5]), 70 | shuffle=False, pin_memory=True) 71 | 72 | gallery_loader = DataLoader( 73 | Preprocessor(list(set(dataset.gallery)), 74 | root=dataset.images_dir, transform=test_transformer), 75 | batch_size=32, num_workers=workers, 76 | sampler=CamSampler(list(set(dataset.gallery)), [0,1,3,4], 4), 77 | shuffle=False, pin_memory=True) 78 | 79 | train_loader = DataLoader( 80 | Preprocessor(train_set, root=dataset.images_dir, 81 | transform=train_transformer), 82 | batch_size=batch_size, num_workers=workers, 83 | sampler=RandomIdentitySampler(train_set, num_instances), 84 | pin_memory=True, drop_last=True) 85 | 86 | return dataset, num_classes, train_loader, trainvallabel, val_loader, query_loader, gallery_loader 87 | 88 | def main(args): 89 | np.random.seed(args.seed) 90 | torch.manual_seed(args.seed) 91 | cudnn.benchmark = True 92 | 93 | if not args.evaluate: 94 | sys.stdout = Logger(osp.join(args.logs_dir+'/log')) 95 | 96 | if args.height is None or args.width is None: args.height, args.width = (256, 128) 97 | 98 | # Dataset and loader 99 | dataset, num_classes, train_loader, trainvallabel, val_loader, query_loader, gallery_loader = \ 100 | get_data(args.dataset, args.split, args.data_dir, args.height, 101 | args.width, args.batch_size, args.num_instances, args.workers, 102 | args.combine_trainval, args.flip_prob, args.padding, args.re_prob, 103 | args.HUAWEI_cloud, args.dataset_root) 104 | 105 | # Model settings 106 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_device 107 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 108 | model = ft_net(args=args, num_classes=num_classes, num_features=args.features) 109 | model = nn.DataParallel(model, device_ids=[0,1]) 110 | model = model.to(device) 111 | 112 | # Evaluation components 113 | evaluator = Evaluator(model) 114 | metric = DistanceMetric(algorithm=args.dist_metric) 115 | 116 | start_epoch = 0 117 | if args.resume: 118 | ######################################### 119 | checkpoint = load_checkpoint(args.resume) 120 | state_dict = checkpoint['model'] 121 | from collections import OrderedDict 122 | new_state_dict = OrderedDict() 123 | for k, v in state_dict.items(): 124 | if 'module' not in k: 125 | k = 'module.' 
+ k 126 | else: 127 | k = k.replace('features.module.', 'module.features.') 128 | new_state_dict[k] = v 129 | model.load_state_dict(new_state_dict) 130 | ######################################### 131 | start_epoch = checkpoint['epoch'] 132 | print("=> Start epoch {}".format(start_epoch)) 133 | 134 | if args.evaluate: 135 | metric.train(model, train_loader) 136 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 137 | exit() 138 | 139 | # Losses 140 | ce_Loss = CrossEntropyLabelSmooth(num_classes= num_classes, epsilon=args.epsilon).cuda() 141 | associate_loss = ASS_loss().cuda() 142 | rank_Loss = Rank_loss(margin_1= args.margin_1, margin_2 =args.margin_2, alpha_1 =args.alpha_1, alpha_2= args.alpha_2).cuda() 143 | 144 | print(args) 145 | # optimizers and schedulers 146 | conv_optim = model.module.optims() 147 | 148 | conv_scheduler = WarmupMultiStepLR(conv_optim, args.mile_stone, args.gamma, args.warmup_factor, 149 | args.warmup_iters, args.warmup_methods) 150 | 151 | trainer = Trainer(args, model, ce_Loss, rank_Loss, associate_loss, trainvallabel) 152 | 153 | best_top1 = -1 154 | 155 | # Start training 156 | for epoch in range(start_epoch, args.epochs): 157 | conv_scheduler.step() 158 | 159 | triple_loss, tot_loss = trainer.train(epoch, train_loader, conv_optim) 160 | 161 | save_checkpoint({ 162 | 'model': model.module.state_dict(), 163 | 'epoch': epoch + 1, 164 | 'best_top1': best_top1, 165 | }, False, epoch, args.logs_dir, fpath='checkpoint.pth.tar') 166 | 167 | if epoch < args.begin_test: 168 | continue 169 | if not epoch % args.evaluate_freq == 0: 170 | continue 171 | 172 | top1 = evaluator.evaluate(query_loader, gallery_loader, metric) 173 | 174 | is_best = top1 > best_top1 175 | best_top1 = max(top1, best_top1) 176 | save_checkpoint({ 177 | 'model': model.module.state_dict(), 178 | 'epoch': epoch + 1, 179 | 'best_top1': best_top1, 180 | }, is_best, epoch, args.logs_dir, fpath='checkpoint.pth.tar') 181 | 182 | print('Test with best model:') 183 | print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'. 
184 | format(epoch, top1, best_top1, ' *' if is_best else '')) 185 | print(args) 186 | 187 | 188 | if __name__ == '__main__': 189 | parser = argparse.ArgumentParser(description="Cross_modality for Person Re-identification") 190 | 191 | # dataset settings 192 | parser.add_argument('-d', '--dataset', type=str, default='sysu', choices=datasets.names()) 193 | parser.add_argument('-b', '--batch-size', type=int, default=128) 194 | parser.add_argument('-j', '--workers', type=int, default=4) 195 | parser.add_argument('--split', type=int, default=0) 196 | 197 | # transformer 198 | parser.add_argument('--height', type=int, default=256) 199 | parser.add_argument('--width', type=int, default=128) 200 | parser.add_argument('--flip_prob', type=float, default=0.5) 201 | parser.add_argument('--re_prob', type=float, default=0.0) 202 | parser.add_argument('--padding', type=int, default=0) 203 | parser.add_argument('--combine-trainval', default=True, action='store_true', 204 | help="train and val sets together for training, " 205 | "val set alone for validation") 206 | parser.add_argument('--num-instances', type=int, default=8, 207 | help="each minibatch consists of " 208 | "(batch_size // num_instances) identities, and " 209 | "each identity has num_instances instances, " 210 | "default: 8") 211 | # model 212 | parser.add_argument('--features', type=int, default=2048) 213 | 214 | # rank loss settings 215 | parser.add_argument('--margin_1', type=float, default=0.9, help="margin_1 of the triplet loss, default: 0.9") 216 | parser.add_argument('--margin_2', type=float, default=1.0, help="margin_2 of the triplet loss, default: 1.0") 217 | parser.add_argument('--alpha_1', type=float, default=2.2, help="alpha_1 of the triplet loss, default: 2.2") 218 | parser.add_argument('--alpha_2', type=float, default=2.0, help="alpha_2 of the triplet loss, default: 2.0") 219 | 220 | # optimizer and scheduler 221 | parser.add_argument('--lr', type=float, default=2.6e-4, help="learning rate of all parameters") 222 | parser.add_argument('--weight-decay', type=float, default=5e-4) 223 | parser.add_argument('--use_adam', action='store_true', help="use Adam as the optimizer, otherwise SGD") 224 | parser.add_argument('--gamma', type=float, default=0.1, help="gamma for learning rate decay") 225 | 226 | parser.add_argument('--mile_stone', type=list, default=[210], help="epochs at which the learning rate decays") 227 | 228 | parser.add_argument('--warmup_iters', type=int, default=10) 229 | parser.add_argument('--warmup_methods', type=str, default='linear', choices=('linear', 'constant')) 230 | parser.add_argument('--warmup_factor', type=float, default=0.01) 231 | 232 | # training configs 233 | parser.add_argument('--resume', type=str, default='', metavar='') 234 | parser.add_argument('--evaluate', action='store_true', 235 | help="skip training and only run the Python evaluation; " 236 | "reported results should be produced with the official evaluation code") 237 | 238 | parser.add_argument('--epochs', type=int, default=600) 239 | parser.add_argument('--start_save', type=int, default=100, help="start saving checkpoints after a specific epoch") 240 | parser.add_argument('--begin_test', type=int, default=100) 241 | parser.add_argument('--evaluate_freq', type=int, default=5) 242 | 243 | parser.add_argument('--seed', type=int, default=1) 244 | 245 | # adopted metric 246 | parser.add_argument('--dist-metric', type=str, default='euclidean') 247 | 248 | # misc 249 | working_dir = osp.dirname(osp.abspath(__file__)) 250 | parser.add_argument('--data-dir', type=str, metavar='PATH', 251 |
default=osp.join(working_dir, 'data')) 252 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 253 | default=osp.join(working_dir, 'Exp_before_OpenSource_7_1')) 254 | # hyper-parameters 255 | parser.add_argument('-CE_loss', type=int, default=1, help="weight of cross entropy loss") 256 | parser.add_argument('-epsilon', type=float, default=0.1, help="label smooth") 257 | parser.add_argument('-Triplet_loss', type=int, default=1, help="weight of triplet loss") 258 | parser.add_argument('-Associate_loss', type=float, default=0.0, help="weight of loss") 259 | 260 | parser.add_argument('-CML_loss', type=int, default=8, help="the weight of conventional mutual learning") 261 | parser.add_argument('-VCD_loss', type=int, default=2, help="the weight of VCD and VML") 262 | parser.add_argument('-VSD_loss', type=float, default=2, help="weight of VSD") 263 | parser.add_argument('-temperature', type=int, default=1, help="the temperature used in knowledge distillation") 264 | 265 | # Bottleneck 266 | parser.add_argument('-z_dim', type=int, default=256, help="dimension of latent z, better set to {128, 256, 512}") 267 | # device set 268 | parser.add_argument('--visible_device', default='2, 1', type=str, help='gpu_ids: e.g. 0, 0,1,2 0,2') 269 | 270 | # HUAWEI cloud 271 | parser.add_argument('--HUAWEI_cloud', type=bool, default=False) 272 | parser.add_argument('--dataset_root', type=str, metavar='PATH', default="/test-ddag/dataset") 273 | parser.add_argument('--data_url', type=str, default="") 274 | parser.add_argument('--init_method', type=str, default="") 275 | parser.add_argument('--train_url', type=str, default="") 276 | 277 | main(parser.parse_args()) 278 | -------------------------------------------------------------------------------- /regdb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import os 5 | import numpy as np 6 | import sys 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | from torch.utils.data import DataLoader 11 | from reid import datasets 12 | from reid.trainers import Trainer 13 | from reid.dist_metric import DistanceMetric 14 | from reid.evaluators_regdb import Evaluator 15 | from reid.models import ft_net 16 | from reid.utils.data import transforms as T 17 | from reid.utils.data.preprocessor import Preprocessor 18 | from reid.utils.data.sampler import CamRandomIdentitySampler as RandomIdentitySampler 19 | from reid.utils.data.sampler import CamSampler 20 | from reid.utils.logging import Logger 21 | from reid.utils.serialization import load_checkpoint, save_checkpoint 22 | from utlis import RandomErasing, WarmupMultiStepLR, CrossEntropyLabelSmooth, Rank_loss, ASS_loss 23 | 24 | 25 | 26 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances, 27 | workers, combine_trainval, flip_prob, padding, re_prob, ii): 28 | root = osp.join(data_dir, name) 29 | print(root) 30 | 31 | dataset = datasets.create(name, root, split_id=split_id, ii=ii) 32 | 33 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 34 | std=[0.229, 0.224, 0.225]) 35 | 36 | trainvallabel = dataset.trainvallabel 37 | train_set = dataset.trainval if combine_trainval else dataset.train 38 | num_classes = (dataset.num_trainval_ids if combine_trainval else dataset.num_train_ids) 39 | #print("Number of training classes:" + str(dataset.num_trainval_ids)) 40 | 41 | 42 | train_transformer = T.Compose([ 43 | 
T.Resize((height, width)), 44 | T.RandomHorizontalFlip(p=flip_prob), 45 | T.Pad(padding), 46 | T.RandomCrop((height, width)), 47 | T.ToTensor(), 48 | normalizer, 49 | RandomErasing(probability=re_prob, mean=[0.485, 0.456, 0.406]) 50 | ]) 51 | 52 | test_transformer = T.Compose([ 53 | T.Resize((height, width)), 54 | T.ToTensor(), 55 | normalizer, 56 | ]) 57 | 58 | val_loader = DataLoader( 59 | Preprocessor(dataset.val, root=dataset.images_dir, 60 | transform=test_transformer), 61 | batch_size=32, num_workers=workers, 62 | shuffle=False, pin_memory=True) 63 | 64 | query_loader = DataLoader( 65 | Preprocessor(list(set(dataset.query)), 66 | root=dataset.images_dir, transform=test_transformer), 67 | batch_size=32, num_workers=workers, 68 | sampler=CamSampler(list(set(dataset.query)), [2]), 69 | shuffle=False, pin_memory=True) 70 | 71 | gallery_loader = DataLoader( 72 | Preprocessor(list(set(dataset.gallery)), 73 | root=dataset.images_dir, transform=test_transformer), 74 | batch_size=32, num_workers=workers, 75 | sampler=CamSampler(list(set(dataset.gallery)), [0], 1), 76 | shuffle=False, pin_memory=True) 77 | 78 | train_loader = DataLoader( 79 | Preprocessor(train_set, root=dataset.images_dir, 80 | transform=train_transformer), 81 | batch_size=batch_size, num_workers=workers, 82 | sampler=RandomIdentitySampler(train_set, num_instances), 83 | pin_memory=True, drop_last=True) 84 | 85 | return dataset, num_classes, train_loader, trainvallabel, val_loader, query_loader, gallery_loader 86 | 87 | 88 | def main(args): 89 | np.random.seed(args.seed) 90 | torch.manual_seed(args.seed) 91 | cudnn.benchmark = True 92 | for ii in range(1,11): 93 | #if ii == 5 or ii == 6: ii = ii - 4 94 | print(ii) 95 | if not osp.exists(args.logs_dir+'/{}'.format(ii)): 96 | os.mkdir(args.logs_dir+'/{}'.format(ii)) 97 | sys.stdout = Logger(osp.join(args.logs_dir+'/{}/log'.format(ii))) 98 | dataset, num_classes, train_loader, trainvallabel, val_loader, query_loader, gallery_loader = \ 99 | get_data(args.dataset, args.split, args.data_dir, args.height, 100 | args.width, args.batch_size, args.num_instances, args.workers, 101 | args.combine_trainval, args.flip_prob, args.padding, args.re_prob,ii+1) 102 | if not args.evaluate: 103 | sys.stdout = Logger(osp.join(args.logs_dir+'/log')) 104 | 105 | # Model settings 106 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_device 107 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 108 | model = ft_net(args=args, num_classes=num_classes, num_features=args.features) 109 | model = nn.DataParallel(model, device_ids=[0, 1]) 110 | model = model.to(device) 111 | 112 | # Evaluator settings 113 | evaluator = Evaluator(model) 114 | metric = DistanceMetric(algorithm=args.dist_metric) 115 | 116 | # Resume settings 117 | if args.resume: 118 | checkpoint = load_checkpoint(args.resume) 119 | state_dict = checkpoint['model'] 120 | from collections import OrderedDict 121 | new_state_dict = OrderedDict() 122 | for k, v in state_dict.items(): 123 | if 'module' not in k: 124 | k = 'module.' 
+ k 125 | else: 126 | k = k.replace('features.module.', 'module.features.') 127 | new_state_dict[k] = v 128 | model.load_state_dict(new_state_dict) 129 | start_epoch = checkpoint['epoch'] 130 | print("=> Start epoch {}".format(start_epoch)) 131 | 132 | # Skip training and conducting evaluation in python 133 | if args.evaluate: 134 | metric.train(model, train_loader) 135 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 136 | exit() 137 | 138 | # Losses 139 | ce_Loss = CrossEntropyLabelSmooth(num_classes=num_classes, epsilon=args.epsilon).cuda() 140 | associate_loss = ASS_loss().cuda() 141 | rank_Loss = Rank_loss(margin_1=args.margin_1, margin_2=args.margin_2, alpha_1=args.alpha_1, alpha_2=args.alpha_2).cuda() 142 | 143 | # Optimizers and schedulers 144 | conv_optim = model.module.optims() 145 | conv_scheduler = WarmupMultiStepLR(conv_optim, args.mile_stone, args.gamma, args.warmup_factor, 146 | args.warmup_iters, args.warmup_methods) 147 | 148 | trainer = Trainer(args, model, ce_Loss, rank_Loss, associate_loss, trainvallabel) 149 | best_top1 = -1 150 | start_epoch = 0 151 | print(args) 152 | for epoch in range(start_epoch, args.epochs): 153 | conv_scheduler.step() 154 | triple_loss, tot_loss = trainer.train(epoch, train_loader, conv_optim) 155 | 156 | save_checkpoint({ 157 | 'model': model.module.state_dict(), 158 | 'epoch': epoch + 1, 159 | 'best_top1': best_top1, 160 | }, False, epoch, (args.logs_dir+'/{}'.format(ii)), fpath='checkpoint.pth.tar') 161 | 162 | if epoch < args.begin_test: 163 | continue 164 | if not epoch % args.evaluate_freq == 0: 165 | continue 166 | 167 | top1, cmc, mAP = evaluator.evaluate(query_loader, gallery_loader, metric) 168 | 169 | is_best = top1 > best_top1 170 | best_top1 = max(top1, best_top1) 171 | save_checkpoint({ 172 | 'model': model.module.state_dict(), 173 | 'epoch': epoch + 1, 174 | 'best_top1': best_top1, 175 | }, is_best, epoch, (args.logs_dir+'/{}'.format(ii)), fpath='checkpoint.pth.tar') 176 | print(args) 177 | 178 | if __name__ == '__main__': 179 | parser = argparse.ArgumentParser(description="Cross_modality for Person Re-identification") 180 | 181 | # dataset settings 182 | parser.add_argument('-d', '--dataset', type=str, default='RegDB', choices=datasets.names()) 183 | parser.add_argument('-b', '--batch-size', type=int, default=64) 184 | parser.add_argument('-j', '--workers', type=int, default=4) 185 | parser.add_argument('--split', type=int, default=0) 186 | 187 | # transformer 188 | parser.add_argument('--height', type=int, default=256) 189 | parser.add_argument('--width', type=int, default= 128) 190 | parser.add_argument('--flip_prob', type=float, default=0.5) 191 | parser.add_argument('--re_prob', type=float, default=0.0) 192 | parser.add_argument('--padding', type=int, default=0) 193 | parser.add_argument('--combine-trainval', default=True, action='store_true', 194 | help="train and val sets together for training, " 195 | "val set alone for validation") 196 | parser.add_argument('--num-instances', type=int, default=8, 197 | help="each minibatch consist of " 198 | "(batch_size // num_instances) identities, and " 199 | "each identity has num_instances instances, " 200 | "default: 4") 201 | # model 202 | parser.add_argument('--features', type=int, default=2048) 203 | 204 | # rank loss settings 205 | parser.add_argument('--margin_1', type=float, default=0.9, help="margin_1 of the triplet loss, default: 0.9") 206 | parser.add_argument('--margin_2', type=float, default=1.0, help="margin_1 of the triplet loss, 
default: 1.5") 207 | parser.add_argument('--alpha_1', type=float, default=2.2, help="alpha_1 of the triplet loss, default: 2.4") 208 | parser.add_argument('--alpha_2', type=float, default=2.0, help="alpha_2 of the triplet loss, default: 2.2") 209 | 210 | # optimizer and scheduler 211 | parser.add_argument('--lr', type=float, default=2.6e-4, help="learning rate of all parameters") 212 | parser.add_argument('--weight-decay', type=float, default=5e-4) 213 | parser.add_argument('--use_adam', action='store_true', help="use Adam as the optimizer, elsewise SGD ") 214 | parser.add_argument('--gamma', type=float, default = 0.1, help="gamma for learning rate decay") 215 | 216 | parser.add_argument('--mile_stone', type=list, default=[210]) 217 | 218 | parser.add_argument('--warmup_iters', type=int, default=10) 219 | parser.add_argument('--warmup_methods', type=str, default = 'linear', choices=('linear', 'constant')) 220 | parser.add_argument('--warmup_factor', type=float, default = 0.01 ) 221 | 222 | # training configs 223 | parser.add_argument('--resume', type=str, default='', metavar='') 224 | parser.add_argument('--evaluate',action='store_true', 225 | help="this option meaningless " 226 | "since it is required to conduct evaluation on officially approved codes") 227 | 228 | parser.add_argument('--epochs', type=int, default=600) 229 | parser.add_argument('--start_save', type=int, default=100, help="start saving checkpoints after specific epoch") 230 | parser.add_argument('--begin_test', type=int, default=100) 231 | parser.add_argument('--evaluate_freq', type=int, default=5) 232 | 233 | parser.add_argument('--seed', type=int, default=1) 234 | 235 | # adopted metric 236 | parser.add_argument('--dist-metric', type=str, default='euclidean') 237 | 238 | # misc 239 | working_dir = osp.dirname(osp.abspath(__file__)) 240 | parser.add_argument('--data-dir', type=str, metavar='PATH', 241 | default=osp.join(working_dir, 'data')) 242 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 243 | default=osp.join(working_dir, 'Exp_RegDB/RegDB_2')) 244 | # hyper-parameters 245 | parser.add_argument('-CE_loss', type=int, default=1, help="weight of cross entropy loss") 246 | parser.add_argument('-epsilon', type=float, default=0.1, help="label smooth") 247 | parser.add_argument('-Triplet_loss', type=int, default=1, help="weight of triplet loss") 248 | parser.add_argument('-Associate_loss', type=float, default=0.0, help="weight of loss") 249 | 250 | parser.add_argument('-CML_loss', type=int, default=8, help="the weight of conventional mutual learning") 251 | parser.add_argument('-VCD_loss', type=int, default=2, help="the weight of VCD and VML") 252 | parser.add_argument('-VSD_loss', type=float, default=2, help="weight of VSD") 253 | parser.add_argument('-temperature', type=int, default=1, help="the temperature used in knowledge distillation") 254 | 255 | # Bottleneck 256 | parser.add_argument('-z_dim', type=int, default=256, help="dimension of latent z, better set to {128, 256, 512}") 257 | # device set 258 | parser.add_argument('--visible_device', default='2, 1', type=str, help='gpu_ids: e.g. 0, 0,1,2 0,2') 259 | 260 | 261 | main(parser.parse_args()) 262 | -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import feature_extraction 6 | from . 
import metric_learning 7 | from . import models 8 | from . import utils 9 | from . import dist_metric 10 | from . import evaluators 11 | from . import trainers 12 | 13 | __version__ = '0.2.0' 14 | -------------------------------------------------------------------------------- /reid/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__init__.pyc -------------------------------------------------------------------------------- /reid/__pycache__/New_trainers_3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/New_trainers_3.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/dist_metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/dist_metric.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/dist_metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/dist_metric.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/evaluators.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/evaluators.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/trainers.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/trainers.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/RegDB.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | from ..utils.data import Dataset 4 | from ..utils.serialization import write_json 5 | 6 | class RegDB(Dataset): 7 | def __init__(self, root, split_id=0, ii=0, num_val=100, download=True): 8 | super(RegDB, self).__init__(root, split_id=split_id) 9 | self.ii = ii 10 | if download: 11 | self.download() 12 | 13 | self.load(num_val) 14 | 15 | def download(self): 16 | index_train_RGB = open('./data/RegDB/idx/train_visible_{}.txt'.format(self.ii),'r') 17 | index_train_IR = open('./data/RegDB/idx/train_thermal_{}.txt'.format(self.ii),'r') 18 | index_test_RGB = open('./data/RegDB/idx/test_visible_{}.txt'.format(self.ii),'r') 19 | index_test_IR = open('./data/RegDB/idx/test_thermal_{}.txt'.format(self.ii),'r') 20 | 21 | def loadIdx(index): 22 | Lines = index.readlines() 23 | idx = [] 24 | for line in Lines: 25 | tmp = line.strip('\n') 26 | tmp = tmp.split(' ') 27 | idx.append(tmp) 28 | return idx 29 | 30 | index_train_RGB = loadIdx(index_train_RGB) 31 | index_train_IR = loadIdx(index_train_IR) 32 | index_test_RGB = loadIdx(index_test_RGB) 33 | index_test_IR = loadIdx(index_test_IR) 34 | 35 | # 412 identities with 3 camera views each 36 | identities = [[[] for _ in range(3)] for _ in range(412)] 37 | def insertToMeta(index, cam, delta): 38 | for idx in index: 39 | fname = osp.basename(idx[0]) 40 | 41 | pid = int(idx[1]) + delta 42 | 43 | identities[pid][cam].append(fname) 44 | 45 | insertToMeta(index_train_RGB, 0, 0) 46 | insertToMeta(index_train_IR, 2, 0) 47 | insertToMeta(index_test_RGB, 0, 206) 48 | insertToMeta(index_test_IR, 2, 206) 49 | 50 | trainval_pids = set() 51 | gallery_pids = set() 52 | query_pids = set() 53 | for i in range(206): 54 | trainval_pids.add(i) 55 | gallery_pids.add(i + 206) 56 | query_pids.add(i + 
206) 57 | 58 | # Save meta information into a json file 59 | meta = {'name': 'RegDB', 'shot': 'multiple', 'num_cameras': 3, 60 | 'identities': identities} 61 | write_json(meta, osp.join(self.root, 'meta.json')) 62 | 63 | # Save the only training / test split 64 | splits = [{ 65 | 'trainval': sorted(list(trainval_pids)), 66 | 'query': sorted(list(query_pids)), 67 | 'gallery': sorted(list(gallery_pids))}] 68 | write_json(splits, osp.join(self.root, 'splits.json')) 69 | -------------------------------------------------------------------------------- /reid/datasets/RegDB.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/RegDB.pyc -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .sysu import SYSU 5 | from .RegDB import RegDB 6 | 7 | 8 | __factory = { 9 | 'sysu': SYSU, 10 | 'RegDB': RegDB, 11 | } 12 | 13 | 14 | def names(): 15 | return sorted(__factory.keys()) 16 | 17 | 18 | def create(name, root, *args, **kwargs): 19 | """ 20 | Create a dataset instance. 21 | 22 | Parameters 23 | ---------- 24 | name : str 25 | The dataset name. Can be one of 'viper', 'cuhk01', 'cuhk03', 26 | 'market1501', and 'dukemtmc'. 27 | root : str 28 | The path to the dataset directory. 29 | split_id : int, optional 30 | The index of data split. Default: 0 31 | num_val : int or float, optional 32 | When int, it means the number of validation identities. When float, 33 | it means the proportion of validation to all the trainval. Default: 100 34 | download : bool, optional 35 | If True, will download the dataset. Default: False 36 | """ 37 | if name not in __factory: 38 | raise KeyError("Unknown dataset:", name) 39 | return __factory[name](root, *args, **kwargs) 40 | 41 | 42 | def get_dataset(name, root, *args, **kwargs): 43 | warnings.warn("get_dataset is deprecated. 
Use create instead.") 44 | return create(name, root, *args, **kwargs) 45 | -------------------------------------------------------------------------------- /reid/datasets/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__init__.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/RegDB.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/RegDB.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/RegDB.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/RegDB.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/sysu.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/sysu.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/sysu.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/sysu.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/sysu.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import write_json 7 | 8 | 9 | class SYSU(Dataset): 10 | def __init__(self, root, split_id=0, num_val=100, download=True): 11 | super(SYSU, self).__init__(root, split_id=split_id) 12 | 13 | self.root += "/SYSU-MM01" 14 | 15 | if download: 16 | self.download() 17 | 18 | if not self._check_integrity(): 19 | raise RuntimeError("Dataset not found or corrupted. ") 20 | 21 | self.load(num_val) 22 | 23 | def download(self): 24 | 25 | import shutil 26 | from glob import glob 27 | 28 | # Format 29 | images_dir = osp.join(self.root+'/images') 30 | mkdir_if_missing(images_dir) 31 | 32 | # gain the spilt from .mat 33 | import scipy.io as scio 34 | data = scio.loadmat(self.root+'/exp/train_id.mat') 35 | train_id = data['id'][0] 36 | data = scio.loadmat(self.root+'/exp/val_id.mat') 37 | val_id = data['id'][0] 38 | data = scio.loadmat(self.root+'/exp/test_id.mat') 39 | test_id = data['id'][0] 40 | 41 | # 533 identities with 6 camera views each 42 | identities = [[[] for _ in range(6)] for _ in range(533)] 43 | for pid in range(1, 534): 44 | for cam in range(1,7): 45 | images_path = self.root+"/cam"+str(cam)+"/"+str(pid).zfill(4) 46 | fpaths = sorted(glob(images_path+"/*.jpg")) 47 | for fpath in fpaths: 48 | # print(fpath) 49 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 50 | .format(pid-1, cam-1, len(identities[pid-1][cam-1]))) 51 | identities[pid-1][cam-1].append(fname) 52 | shutil.copy(fpath, osp.join(images_dir, fname)) 53 | 54 | trainval_pids = set() 55 | gallery_pids = set() 56 | query_pids = set() 57 | train_val_ = numpy.concatenate((train_id,val_id)) 58 | for i in (train_val_): 59 | trainval_pids.add(int(i) - 1) 60 | for i in test_id: 61 | gallery_pids.add(int(i) - 1) 62 | query_pids.add(int(i) - 1) 63 | 64 | # Save meta information into a json file 65 | meta = {'name': 'sysu', 'shot': 'multiple', 'num_cameras': 6, 66 | 'identities': identities} 67 | write_json(meta, osp.join(self.root, 'meta.json')) 68 | 69 | # Save the only training / test split 70 | splits = [{ 71 | 'trainval': sorted(list(trainval_pids)), 72 | 'query': sorted(list(query_pids)), 73 | 'gallery': sorted(list(gallery_pids))}] 74 | write_json(splits, osp.join(self.root, 'splits.json')) 75 | -------------------------------------------------------------------------------- /reid/datasets/sysu.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/sysu.pyc -------------------------------------------------------------------------------- /reid/dist_metric.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from .evaluators import extract_features 6 | from .metric_learning import get_metric 
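DistanceMetric, defined next, only requires its metric object to expose fit(X, y) and transform(X); with the default 'euclidean' algorithm, train() returns immediately and features pass through unchanged. A stand-in sketch of that contract (the real class lives in reid/metric_learning/euclidean.py; this minimal version is an assumption about its interface, not a copy of it):

import numpy as np

class EuclideanStandIn(object):
    """Minimal metric object: nothing to fit, identity transform."""
    def fit(self, X, y=None):
        return self  # plain L2 distance needs no training

    def transform(self, X):
        return X     # leave features unchanged; distances stay Euclidean

metric = EuclideanStandIn()
X = np.random.rand(4, 8).astype(np.float32)
assert np.allclose(metric.transform(X), X)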
7 | 8 | 9 | class DistanceMetric(object): 10 | def __init__(self, algorithm='euclidean', *args, **kwargs): 11 | super(DistanceMetric, self).__init__() 12 | self.algorithm = algorithm 13 | self.metric = get_metric(algorithm, *args, **kwargs) 14 | 15 | def train(self, model, data_loader): 16 | if self.algorithm == 'euclidean': return 17 | features, labels = extract_features(model, data_loader) 18 | features = torch.stack(features.values()).numpy() 19 | labels = torch.Tensor(list(labels.values())).numpy() 20 | self.metric.fit(features, labels) 21 | 22 | def transform(self, X): 23 | if torch.is_tensor(X): 24 | X = X.numpy() 25 | X = self.metric.transform(X) 26 | X = torch.from_numpy(X) 27 | else: 28 | X = self.metric.transform(X) 29 | return X 30 | 31 | -------------------------------------------------------------------------------- /reid/dist_metric.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/dist_metric.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .ranking import cmc, mean_ap 4 | 5 | __all__ = [ 6 | 'cmc', 7 | 'mean_ap', 8 | ] 9 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__init__.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/ranking.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/ranking.cpython-35.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/ranking.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/ranking.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/ranking.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/ranking.cpython-37.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=bool) # np.bool was removed in NumPy 1.24; the builtin bool works on all versions 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | 24 | distmat = to_numpy(distmat) 25 | m, n = distmat.shape 26 | 27 | # Fill up default values 28 | if query_ids is None: 29 | query_ids = np.arange(m) 30 | if gallery_ids is None: 31 | gallery_ids = np.arange(n) 32 | if query_cams is None: 33 | query_cams = np.zeros(m).astype(np.int32) 34 | if gallery_cams is None: 35 | gallery_cams = np.ones(n).astype(np.int32) # default cams differ from the query's, so nothing is filtered 36 | 37 | # Ensure numpy array 38 | query_ids = np.asarray(query_ids) 39 | gallery_ids = np.asarray(gallery_ids) 40 | query_cams = np.asarray(query_cams) 41 | gallery_cams = np.asarray(gallery_cams) 42 | 43 | # Sort and find correct matches 44 | indices = np.argsort(distmat, axis=1) 45 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 46 | 47 | # Compute CMC for each query 48 | ret = np.zeros(topk) 49 | num_valid_queries = 0 50 | for i in range(m): 51 | # Filter out the same id and same camera 52 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 53 | (gallery_cams[indices[i]] != query_cams[i])) 54 | if separate_camera_set: 55 | # Filter out samples from same camera 56 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 57 | if not np.any(matches[i, valid]): continue 58 | if single_gallery_shot: 59 | repeat = 10 60 | gids = gallery_ids[indices[i][valid]] 61 | inds = np.where(valid)[0] 62 | ids_dict = defaultdict(list) 63 | for j, x in zip(inds, gids): 64 | ids_dict[x].append(j) 65 | else: 66 | repeat = 1 67 | for _ in range(repeat): 68 | if single_gallery_shot: 69 | # Randomly choose one instance for each id 70 | sampled = (valid &
_unique_sample(ids_dict, len(valid))) 71 | index = np.nonzero(matches[i, sampled])[0] 72 | else: 73 | index = np.nonzero(matches[i, valid])[0] 74 | delta = 1. / (len(index) * repeat) 75 | for j, k in enumerate(index): 76 | if k - j >= topk: break 77 | if first_match_break: 78 | ret[k - j] += 1 79 | break 80 | ret[k - j] += delta 81 | num_valid_queries += 1 82 | if num_valid_queries == 0: 83 | raise RuntimeError("No valid query") 84 | return ret.cumsum() / num_valid_queries 85 | 86 | 87 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 88 | query_cams=None, gallery_cams=None): 89 | distmat = to_numpy(distmat) 90 | m, n = distmat.shape 91 | # Fill up default values 92 | if query_ids is None: 93 | query_ids = np.arange(m) 94 | if gallery_ids is None: 95 | gallery_ids = np.arange(n) 96 | if query_cams is None: 97 | query_cams = np.zeros(m).astype(np.int32) 98 | if gallery_cams is None: 99 | gallery_cams = np.ones(n).astype(np.int32) 100 | # Ensure numpy array 101 | query_ids = np.asarray(query_ids) 102 | gallery_ids = np.asarray(gallery_ids) 103 | query_cams = np.asarray(query_cams) 104 | gallery_cams = np.asarray(gallery_cams) 105 | # Sort and find correct matches 106 | indices = np.argsort(distmat, axis=1) 107 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 108 | # Compute AP for each query 109 | aps = [] 110 | for i in range(m): 111 | # Filter out the same id and same camera 112 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 113 | (gallery_cams[indices[i]] != query_cams[i])) 114 | y_true = matches[i, valid] 115 | y_score = -distmat[i][indices[i]][valid] 116 | if not np.any(y_true): continue 117 | aps.append(average_precision_score(y_true, y_score)) 118 | if len(aps) == 0: 119 | raise RuntimeError("No valid query") 120 | return np.mean(aps) 121 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/ranking.pyc -------------------------------------------------------------------------------- /reid/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | from .evaluation_metrics import cmc, mean_ap 8 | from .feature_extraction import extract_cnn_feature 9 | from .utils.meters import AverageMeter 10 | 11 | 12 | def extract_features(model, data_loader): 13 | model.eval() 14 | batch_time = AverageMeter() 15 | data_time = AverageMeter() 16 | 17 | features = OrderedDict() 18 | labels = OrderedDict() 19 | filenames = [] 20 | 21 | end = time.time() 22 | for i, (imgs, fnames, pids, cams) in enumerate(data_loader): 23 | data_time.update(time.time() - end) 24 | 25 | subs = ((cams == 2).long() + (cams == 5).long()).cuda() 26 | outputs = extract_cnn_feature(model=model, inputs=imgs, sub=subs) 27 | for fname, output, pid in zip(fnames, outputs, pids): 28 | features[fname] = output 29 | labels[fname] = pid 30 | filenames.append(fname) 31 | 32 | batch_time.update(time.time() - end) 33 | end = time.time() 34 | 35 | 36 | return features, labels, filenames 37 | 38 | 39 | def pairwise_distance(features1, features2, fnames1=None, fnames2=None, 
metric=None): 40 | x = torch.cat([features1[f].unsqueeze(0) for f in fnames1], 0) 41 | y = torch.cat([features2[f].unsqueeze(0) for f in fnames2], 0) 42 | 43 | m, n = x.size(0), y.size(0) 44 | x = x.view(m, -1) 45 | y = y.view(n, -1) 46 | 47 | # normalize 48 | x = torch.nn.functional.normalize(x, dim=1, p=2) 49 | y = torch.nn.functional.normalize(y, dim=1, p=2) 50 | 51 | if metric is not None: 52 | x = metric.transform(x) 53 | y = metric.transform(y) 54 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 55 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 56 | dist.addmm_(1, -2, x, y.t()) 57 | return dist 58 | 59 | 60 | def evaluate_all(distmat, labels1, labels2, fnames1, fnames2, flag, cmc_topk=(1, 10, 20)): 61 | query_ids = [labels1[f] for f in fnames1] 62 | gallery_ids = [labels2[f] for f in fnames2] 63 | query_cams = [0 for f in fnames1] 64 | gallery_cams = [2 for f in fnames2] 65 | 66 | """Compute mean AP""" 67 | 68 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 69 | print('Mean AP: {:4.2%}'.format(mAP)) 70 | if flag: # return mAP only 71 | return mAP 72 | 73 | """Compute all kinds of CMC scores""" 74 | 75 | cmc_configs = { 76 | 'MM01': dict(separate_camera_set=False, 77 | single_gallery_shot=False, 78 | first_match_break=True)} 79 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, query_cams, gallery_cams, **params) 80 | for name, params in cmc_configs.items()} 81 | # name:MM01 82 | # params{'separate_camera_set': False, 'single_gallery_shot': False, 'first_match_break': True} 83 | print('CMC Scores{:>12}'.format('MM01')) 84 | for k in cmc_topk: 85 | print(' top-{:<4}{:12.2%}'.format(k, cmc_scores['MM01'][k - 1])) 86 | # using Rank-1 as the main evaluation criterion 87 | return cmc_scores['MM01'][0] 88 | 89 | 90 | class Evaluator(object): 91 | def __init__(self, model): 92 | super(Evaluator, self).__init__() 93 | self.model = model 94 | 95 | def evaluate(self, data_loader1, data_loader2, metric=None, flag=False): 96 | features1, labels1, fnames1 = extract_features(self.model, data_loader1) 97 | features2, labels2, fnames2 = extract_features(self.model, data_loader2) 98 | distmat = pairwise_distance(features1, features2, fnames1, fnames2, metric=metric) 99 | return evaluate_all(distmat, labels1, labels2, fnames1, fnames2, flag) 100 | -------------------------------------------------------------------------------- /reid/evaluators.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluators.pyc -------------------------------------------------------------------------------- /reid/evaluators_regdb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import numpy as np 7 | from .evaluation_metrics import cmc, mean_ap 8 | from .feature_extraction import extract_cnn_feature 9 | from .utils.meters import AverageMeter 10 | 11 | def extract_features(model, data_loader): 12 | model.eval() 13 | batch_time = AverageMeter() 14 | data_time = AverageMeter() 15 | 16 | features = OrderedDict() 17 | labels = OrderedDict() 18 | filenames = [] 19 | 20 | end = time.time() 21 | for i, (imgs, fnames, pids, cams) in enumerate(data_loader): 22 | 
data_time.update(time.time() - end) 23 | 24 | subs = ((cams == 2).long() + (cams == 5).long()).cuda() 25 | outputs = extract_cnn_feature(model, imgs, subs) 26 | for fname, output, pid in zip(fnames, outputs, pids): 27 | features[fname] = output 28 | labels[fname] = pid 29 | filenames.append(fname) 30 | 31 | batch_time.update(time.time() - end) 32 | end = time.time() 33 | 34 | return features, labels, filenames 35 | 36 | 37 | def pairwise_distance(features1, features2, fnames1=None, fnames2=None, metric=None): 38 | 39 | x = torch.cat([features1[f].unsqueeze(0) for f in fnames1], 0) 40 | y = torch.cat([features2[f].unsqueeze(0) for f in fnames2], 0) 41 | 42 | m, n = x.size(0), y.size(0) 43 | x = x.view(m, -1) 44 | y = y.view(n, -1) 45 | 46 | # normalize 47 | x = torch.nn.functional.normalize(x, dim=1, p=2) 48 | y = torch.nn.functional.normalize(y, dim=1, p=2) 49 | 50 | if metric is not None: 51 | x = metric.transform(x) 52 | y = metric.transform(y) 53 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 54 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 55 | dist.addmm_(1, -2, x, y.t()) 56 | return dist 57 | 58 | 59 | def evaluate_all(distmat, labels1, labels2, fnames1, fnames2, flag, cmc_topk=(1, 10, 20)): 60 | query_ids = [labels1[f] for f in fnames1] 61 | gallery_ids = [labels2[f] for f in fnames2] 62 | query_cams = [0 for f in fnames1] 63 | gallery_cams = [2 for f in fnames2] 64 | 65 | # Compute mean AP 66 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 67 | print('Mean AP: {:4.2%}'.format(mAP)) 68 | 69 | # return mAP 70 | if flag: 71 | return mAP 72 | # Compute all kinds of CMC scores 73 | cmc_configs = { 74 | 'RegDB': dict(separate_camera_set=False, 75 | single_gallery_shot=False, 76 | first_match_break=True)} 77 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 78 | query_cams, gallery_cams, **params) 79 | for name, params in cmc_configs.items()} 80 | 81 | print('CMC Scores{:>12}'.format('RegDB') 82 | ) 83 | for k in cmc_topk: 84 | print(' top-{:<4}{:12.2%}' 85 | .format(k,cmc_scores['RegDB'][k - 1]) 86 | ) 87 | 88 | # Use the allshots cmc top-1 score for validation criterion 89 | return cmc_scores['RegDB'][0] 90 | 91 | 92 | def eval_regdb(distmat, labels1, labels2, fnames1, fnames2, max_rank = 20, cmc_topk=(1, 10, 20)): 93 | q_pids = [labels1[f].numpy() for f in fnames1] 94 | g_pids = [labels2[f].numpy() for f in fnames2] 95 | q_pids= np.array(q_pids) 96 | g_pids= np.array(g_pids) 97 | # q_pids = q_pids.numpy() 98 | # g_pids = g_pids.numpy() 99 | num_q, num_g = distmat.shape 100 | if num_g < max_rank: 101 | max_rank = num_g 102 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 103 | indices = np.argsort(distmat, axis=1) 104 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 105 | 106 | # compute cmc curve for each query 107 | all_cmc = [] 108 | all_AP = [] 109 | all_INP = [] 110 | num_valid_q = 0. 
# number of valid query 111 | 112 | # only two cameras 113 | q_camids = np.ones(num_q).astype(np.int32) 114 | g_camids = 2 * np.ones(num_g).astype(np.int32) 115 | 116 | for q_idx in range(num_q): 117 | # get query pid and camid 118 | q_pid = q_pids[q_idx] 119 | q_camid = q_camids[q_idx] 120 | 121 | # remove gallery samples that have the same pid and camid as the query 122 | order = indices[q_idx] 123 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 124 | keep = np.invert(remove) 125 | 126 | # compute cmc curve 127 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 128 | if not np.any(raw_cmc): 129 | # this condition is true when the query identity does not appear in the gallery 130 | continue 131 | 132 | cmc = raw_cmc.cumsum() 133 | 134 | # compute mINP 135 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook 136 | pos_idx = np.where(raw_cmc == 1) 137 | pos_max_idx = np.max(pos_idx) 138 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 139 | all_INP.append(inp) 140 | cmc[cmc > 1] = 1 141 | all_cmc.append(cmc[:max_rank]) 142 | num_valid_q += 1. 143 | num_rel = raw_cmc.sum() 144 | tmp_cmc = raw_cmc.cumsum() 145 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 146 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 147 | AP = tmp_cmc.sum() / num_rel 148 | all_AP.append(AP) 149 | 150 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 151 | 152 | all_cmc = np.asarray(all_cmc).astype(np.float32) 153 | all_cmc = all_cmc.sum(0) / num_valid_q 154 | mAP = np.mean(all_AP) 155 | mINP = np.mean(all_INP) 156 | 157 | print('Mean AP: {:4.2%}'.format(mAP)) 158 | print('CMC Scores{:>12}'.format('RegDB') 159 | ) 160 | for k in cmc_topk: 161 | print(' top-{:<4}{:12.2%}' 162 | .format(k, all_cmc[k - 1]) 163 | ) 164 | return all_cmc[0], all_cmc, mAP 165 | 166 | 167 | class Evaluator(object): 168 | def __init__(self, model, regdb=True): 169 | super(Evaluator, self).__init__() 170 | self.model = model 171 | self.regdb = regdb 172 | 173 | def evaluate(self, data_loader1, data_loader2, metric=None, flag=False): 174 | features1, labels1, fnames1 = extract_features(model=self.model, data_loader=data_loader1) 175 | features2, labels2, fnames2 = extract_features(model=self.model, data_loader=data_loader2) 176 | distmat = pairwise_distance(features1, features2, fnames1, fnames2, metric=metric) 177 | 178 | if self.regdb: 179 | top1, all_cmc, mAP = eval_regdb(distmat, labels1, labels2, fnames1, fnames2) 180 | return top1, all_cmc, mAP 181 | else: 182 | return evaluate_all(distmat, labels1, labels2, fnames1, fnames2, flag) 183 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cnn import extract_cnn_feature 4 | 5 | __all__ = [ 6 | 'extract_cnn_feature' 7 | ] 8 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__init__.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/__pycache__/__init__.cpython-35.pyc:
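As a quick sanity check of the cmc and mean_ap metrics defined in reid/evaluation_metrics/ranking.py above, the toy setup below (hypothetical data, not part of the repository) gives every query identity exactly one gallery match in a different camera, so no match is filtered out:

import numpy as np
from reid.evaluation_metrics import cmc, mean_ap

# 4 queries, 6 gallery images; ids 0-3 each appear once in the gallery.
distmat = np.random.rand(4, 6)
query_ids = np.array([0, 1, 2, 3])
gallery_ids = np.array([0, 1, 2, 3, 4, 5])
query_cams = np.zeros(4, dtype=np.int32)
gallery_cams = np.ones(6, dtype=np.int32)

scores = cmc(distmat, query_ids, gallery_ids, query_cams, gallery_cams,
             topk=5, first_match_break=True)
mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
print('top-1: {:.2%}  top-5: {:.2%}  mAP: {:.2%}'.format(scores[0], scores[4], mAP))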
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/__pycache__/cnn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/cnn.cpython-35.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/__pycache__/cnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/cnn.cpython-36.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/__pycache__/cnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/cnn.cpython-37.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import OrderedDict 3 | 4 | import torch 5 | 6 | from ..utils import to_torch 7 | 8 | def extract_cnn_feature(model, inputs, sub, modules=None): 9 | model.eval() 10 | inputs = to_torch(inputs) 11 | inputs = inputs.cuda() 12 | with torch.no_grad(): 13 | if modules is None: 14 | # If `modules` is None, return the concatenated features of all eight output branches; otherwise, register forward hooks on the given modules and return their outputs.
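# The model here is the ft_net defined in reid/models/newresnet.py: it returns
# eight (logits, feature) pairs -- an observation and a VIB representation for
# the infrared branch, the visible branch, and the two passes through the
# modal-shared branch. Index [1] below selects the feature of each pair, and
# the eight features are concatenated channel-wise into a single embedding.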
15 | #_, outputs = model(inputs) 16 | i_observation, i_representation, i_ms_observation, i_ms_representation, \ 17 | v_observation, v_representation, v_ms_observation, v_ms_representation = model(inputs) 18 | outputs = torch.cat(tensors=[i_observation[1], i_ms_observation[1], i_representation[1], i_ms_representation[1], 19 | v_observation[1], v_ms_observation[1], v_representation[1], v_ms_representation[1]], dim=1) 20 | outputs = outputs.data.cpu() 21 | return outputs 22 | # Register forward hook for each module 23 | outputs = OrderedDict() 24 | handles = [] 25 | for m in modules: 26 | outputs[id(m)] = None 27 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 28 | handles.append(m.register_forward_hook(func)) 29 | model(inputs) 30 | for h in handles: 31 | h.remove() 32 | return list(outputs.values()) 33 | -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/cnn.pyc -------------------------------------------------------------------------------- /reid/metric_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | from .euclidean import Euclidean 5 | 6 | __factory = { 7 | 'euclidean': Euclidean 8 | } 9 | 10 | 11 | def get_metric(algorithm, *args, **kwargs): 12 | if algorithm not in __factory: 13 | raise KeyError("Unknown metric:", algorithm) 14 | return __factory[algorithm](*args, **kwargs) 15 | -------------------------------------------------------------------------------- /reid/metric_learning/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__init__.pyc -------------------------------------------------------------------------------- /reid/metric_learning/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /reid/metric_learning/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/metric_learning/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
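A usage sketch for the hook branch of extract_cnn_feature above (hypothetical layer choice; `model`, `imgs` and `sub` stand for the ft_net instance and a batch from the data loaders used elsewhere in this repo):

# Hypothetical: capture an intermediate layer's output instead of the
# concatenated embedding by passing modules=[...].
layer = model.RGB_backbone.base.layer3        # any nn.Module inside the model
feats = extract_cnn_feature(model, imgs, sub, modules=[layer])
# feats is a list with one tensor: the hooked layer's output, moved to CPU.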
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/metric_learning/__pycache__/euclidean.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/euclidean.cpython-36.pyc -------------------------------------------------------------------------------- /reid/metric_learning/__pycache__/euclidean.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/euclidean.cpython-37.pyc -------------------------------------------------------------------------------- /reid/metric_learning/euclidean.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from numpy.linalg import cholesky 3 | from sklearn.base import BaseEstimator, TransformerMixin 4 | from sklearn.utils.validation import check_array 5 | import numpy as np 6 | 7 | class BaseMetricLearner(BaseEstimator, TransformerMixin): 8 | def __init__(self): 9 | raise NotImplementedError('BaseMetricLearner should not be instantiated') 10 | 11 | def metric(self): 12 | """Computes the Mahalanobis matrix from the transformation matrix. 13 | 14 | .. math:: M = L^{\\top} L 15 | 16 | Returns 17 | ------- 18 | M : (d x d) matrix 19 | """ 20 | L = self.transformer() 21 | return L.T.dot(L) 22 | 23 | def transformer(self): 24 | """Computes the transformation matrix from the Mahalanobis matrix. 25 | 26 | L = cholesky(M).T 27 | 28 | Returns 29 | ------- 30 | L : upper triangular (d x d) matrix 31 | """ 32 | return cholesky(self.metric()).T 33 | 34 | def transform(self, X=None): 35 | """Applies the metric transformation. 36 | 37 | Parameters 38 | ---------- 39 | X : (n x d) matrix, optional 40 | Data to transform. If not supplied, the training data will be used. 
41 | 42 | Returns 43 | ------- 44 | transformed : (n x d) matrix 45 | Input data transformed to the metric space by :math:`XL^{\\top}` 46 | """ 47 | if X is None: 48 | X = self.X_ 49 | else: 50 | X = check_array(X, accept_sparse=True) 51 | L = self.transformer() 52 | return X.dot(L.T) 53 | 54 | 55 | 56 | class Euclidean(BaseMetricLearner): 57 | def __init__(self): 58 | self.M_ = None 59 | 60 | def metric(self): 61 | return self.M_ 62 | 63 | def fit(self, X): 64 | self.M_ = np.eye(X.shape[1]) 65 | self.X_ = X 66 | 67 | def transform(self, X=None): 68 | if X is None: 69 | return self.X_ 70 | return X 71 | 72 | def get_metric(self): 73 | pass 74 | def score_pairs(self): 75 | pass 76 | 77 | 78 | a = Euclidean() -------------------------------------------------------------------------------- /reid/metric_learning/euclidean.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/euclidean.pyc -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .newresnet import * 4 | 5 | __factory = { 6 | 'ft_net': ft_net 7 | } 8 | 9 | def names(): 10 | return sorted(__factory.keys()) 11 | 12 | def create(name, *args, **kwargs): 13 | if name not in __factory: 14 | raise KeyError("Unknown model:", name) 15 | return __factory[name](*args, **kwargs) 16 | -------------------------------------------------------------------------------- /reid/models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__init__.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/baseline.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/baseline.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/newresnet.cpython-36.pyc: 
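The Euclidean class above is effectively an identity placeholder; a short sketch of the get_metric factory interface it plugs into (toy data):

import numpy as np
from reid.metric_learning import get_metric

metric = get_metric('euclidean')
X = np.random.rand(5, 4)
metric.fit(X)                                # stores an identity Mahalanobis matrix
assert np.allclose(metric.transform(X), X)   # Euclidean.transform is a passthrough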
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/newresnet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/newresnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/newresnet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | 5 | try: 6 | from torch.hub import load_state_dict_from_url 7 | except ImportError: 8 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 9 | 10 | ################################################################################## 11 | # Initialization function 12 | ################################################################################## 13 | def weights_init_kaiming(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Linear') != -1: 16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 17 | nn.init.constant_(m.bias, 0.0) 18 | elif classname.find('Conv') != -1: 19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 20 | if m.bias is not None: 21 | nn.init.constant_(m.bias, 0.0) 22 | elif classname.find('BatchNorm') != -1: 23 | if m.affine: 24 | nn.init.constant_(m.weight, 1.0) 25 | nn.init.constant_(m.bias, 0.0) 26 | 27 | 28 | def weights_init_Classifier(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('Linear') != -1: 31 | nn.init.normal_(m.weight, std=0.001) 32 | if m.bias: 33 | nn.init.constant_(m.bias, 0.0) 34 | 35 | def weights_init_classifier(m): 36 | classname = m.__class__.__name__ 37 | if classname.find('Linear') != -1: 38 | nn.init.normal_(m.weight.data, std=0.001) 39 | nn.init.constant_(m.bias.data, 0.0) 40 | 41 | ################################################################################## 42 | # Encoder used in framework 43 | ################################################################################## 44 | class Baseline(nn.Module): 45 | in_planes = 2048 46 | def __init__(self, num_classes, num_features): 47 | super(Baseline, self).__init__() 48 | #backbone = resnet50(pretrained= True) 49 | self.cam_256 = ShallowCAM(256) 50 | self.cam_512 = ShallowCAM(512) 51 | self.cam_1024 = ShallowCAM(1024) 52 | self.base = resnet50(pretrained= True) 53 | self.gap = nn.AdaptiveAvgPool2d(1) 54 | 55 | self.num_classes = num_classes 56 | self.in_planes = num_features 57 | 58 | ## use a neck to produce triplet feature and score 59 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 60 | self.bottleneck.bias.requires_grad_(False) 61 | 62 | classifier = [nn.Linear(num_features, num_classes)] 63 | classifier = nn.Sequential(*classifier) 64 | self.classifier = classifier 65 | self.bottleneck.apply(weights_init_kaiming) 66 | self.classifier.apply(weights_init_classifier) 67 | 68 | def forward(self, x): 69 | 70 | for name, module in self.base._modules.items(): 71 | if name == 'avgpool': 72 | break 73 | if 
name == 'layer2': 74 | x = self.cam_256(x) 75 | if name == 'layer3': 76 | x = self.cam_512(x) 77 | if name == 'layer4': 78 | x = self.cam_1024(x) 79 | x = module(x) 80 | global_feat = self.gap(x) # (b, 2048, 1, 1) 81 | 82 | global_feat = global_feat.view(global_feat.shape[0], -1) 83 | feat1 = self.bottleneck(global_feat) 84 | 85 | cls_score = self.classifier(feat1) 86 | 87 | if self.training: 88 | return cls_score, feat1.view(feat1.shape[0], -1) 89 | else: 90 | return global_feat, feat1 91 | ################################################################################## 92 | # Shallow Cam Module 93 | ################################################################################## 94 | 95 | class CAM_Module(nn.Module): 96 | """ Channel attention module""" 97 | 98 | def __init__(self, in_dim): 99 | super(CAM_Module,self).__init__() 100 | self.channel_in = in_dim 101 | 102 | self.gamma = Parameter(torch.zeros(1)) 103 | self.softmax = torch.nn.Softmax(dim=-1) 104 | 105 | def forward(self, x): 106 | """ 107 | inputs : 108 | x : input feature maps( B X C X H X W) 109 | returns : 110 | out : attention value + input feature 111 | attention: B X C X C 112 | """ 113 | m_batchsize, C, height, width = x.size() 114 | proj_query = x.view(m_batchsize, C, -1) 115 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 116 | energy = torch.bmm(proj_query, proj_key) 117 | max_energy_0 = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) 118 | energy_new = max_energy_0 - energy 119 | attention = self.softmax(energy_new) 120 | proj_value = x.view(m_batchsize, C, -1) 121 | 122 | out = torch.bmm(attention, proj_value) 123 | out = out.view(m_batchsize, C, height, width) 124 | 125 | gamma = self.gamma.to(out.device) 126 | out = gamma * out + x 127 | return out 128 | 129 | class ShallowCAM(nn.Module): 130 | 131 | def __init__(self, feature_dim): 132 | 133 | super(ShallowCAM,self).__init__() 134 | self.input_feature_dim = feature_dim 135 | self._cam_module = CAM_Module(self.input_feature_dim) 136 | 137 | def forward(self, x): 138 | x = self._cam_module(x) 139 | 140 | return x 141 | 142 | class bnneck(nn.Module): 143 | def __init__(self, indim): 144 | super(bnneck,self).__init__() 145 | self.indim = indim 146 | self.bn = nn.BatchNorm1d(self.indim) 147 | self.bn.bias.requires_grad_(False) 148 | self.bn.apply(weights_init_kaiming) 149 | def forward(self, x): 150 | x = self.bn(x) 151 | return x 152 | 153 | ################################################################################## 154 | # ResNet 155 | ################################################################################## 156 | __all__ = ['resnet50'] 157 | 158 | model_urls = { 159 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 160 | } 161 | 162 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 163 | """3x3 convolution with padding""" 164 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 165 | padding=dilation, groups=groups, bias=False, dilation=dilation) 166 | 167 | 168 | def conv1x1(in_planes, out_planes, stride=1): 169 | """1x1 convolution""" 170 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 171 | 172 | 173 | class BasicBlock(nn.Module): 174 | expansion = 1 175 | __constants__ = ['downsample'] 176 | 177 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 178 | base_width=64, dilation=1, norm_layer=None): 179 | super(BasicBlock, self).__init__() 180 | if norm_layer is None: 181 | norm_layer = nn.BatchNorm2d 
182 | if groups != 1 or base_width != 64: 183 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 184 | if dilation > 1: 185 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 186 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 187 | self.conv1 = conv3x3(inplanes, planes, stride) 188 | self.bn1 = norm_layer(planes) 189 | self.relu = nn.ReLU(inplace=True) 190 | self.conv2 = conv3x3(planes, planes) 191 | self.bn2 = norm_layer(planes) 192 | self.downsample = downsample 193 | self.stride = stride 194 | 195 | def forward(self, x): 196 | identity = x 197 | 198 | out = self.conv1(x) 199 | out = self.bn1(out) 200 | out = self.relu(out) 201 | 202 | out = self.conv2(out) 203 | out = self.bn2(out) 204 | 205 | if self.downsample is not None: 206 | identity = self.downsample(x) 207 | 208 | out += identity 209 | out = self.relu(out) 210 | 211 | return out 212 | 213 | 214 | class Bottleneck(nn.Module): 215 | expansion = 4 216 | __constants__ = ['downsample'] 217 | 218 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 219 | base_width=64, dilation=1, norm_layer=None): 220 | super(Bottleneck, self).__init__() 221 | if norm_layer is None: 222 | norm_layer = nn.BatchNorm2d 223 | width = int(planes * (base_width / 64.)) * groups 224 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 225 | self.conv1 = conv1x1(inplanes, width) 226 | self.bn1 = norm_layer(width) 227 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 228 | self.bn2 = norm_layer(width) 229 | self.conv3 = conv1x1(width, planes * self.expansion) 230 | self.bn3 = norm_layer(planes * self.expansion) 231 | self.relu = nn.ReLU(inplace=True) 232 | self.downsample = downsample 233 | self.stride = stride 234 | 235 | def forward(self, x): 236 | identity = x 237 | 238 | out = self.conv1(x) 239 | out = self.bn1(out) 240 | out = self.relu(out) 241 | 242 | out = self.conv2(out) 243 | out = self.bn2(out) 244 | out = self.relu(out) 245 | 246 | out = self.conv3(out) 247 | out = self.bn3(out) 248 | 249 | if self.downsample is not None: 250 | identity = self.downsample(x) 251 | 252 | out += identity 253 | out = self.relu(out) 254 | 255 | return out 256 | 257 | 258 | class ResNet(nn.Module): 259 | 260 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 261 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 262 | norm_layer=None): 263 | super(ResNet, self).__init__() 264 | if norm_layer is None: 265 | norm_layer = nn.BatchNorm2d 266 | self._norm_layer = norm_layer 267 | 268 | self.inplanes = 64 269 | self.dilation = 1 270 | if replace_stride_with_dilation is None: 271 | # each element in the tuple indicates if we should replace 272 | # the 2x2 stride with a dilated convolution instead 273 | replace_stride_with_dilation = [False, False, False] 274 | if len(replace_stride_with_dilation) != 3: 275 | raise ValueError("replace_stride_with_dilation should be None " 276 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 277 | self.groups = groups 278 | self.base_width = width_per_group 279 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 280 | bias=False) 281 | self.bn1 = norm_layer(self.inplanes) 282 | self.relu = nn.ReLU(inplace=True) 283 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 284 | self.layer1 = self._make_layer(block, 64, layers[0]) 285 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 
286 | dilate=replace_stride_with_dilation[0]) 287 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 288 | dilate=replace_stride_with_dilation[1]) 289 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 290 | dilate=replace_stride_with_dilation[2]) 291 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 292 | self.fc = nn.Linear(512 * block.expansion, num_classes) 293 | 294 | for m in self.modules(): 295 | if isinstance(m, nn.Conv2d): 296 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 297 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 298 | nn.init.constant_(m.weight, 1) 299 | nn.init.constant_(m.bias, 0) 300 | 301 | # Zero-initialize the last BN in each residual branch, 302 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 303 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 304 | if zero_init_residual: 305 | for m in self.modules(): 306 | if isinstance(m, Bottleneck): 307 | nn.init.constant_(m.bn3.weight, 0) 308 | elif isinstance(m, BasicBlock): 309 | nn.init.constant_(m.bn2.weight, 0) 310 | 311 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 312 | norm_layer = self._norm_layer 313 | downsample = None 314 | previous_dilation = self.dilation 315 | if dilate: 316 | self.dilation *= stride 317 | stride = 1 318 | if stride != 1 or self.inplanes != planes * block.expansion: 319 | downsample = nn.Sequential( 320 | conv1x1(self.inplanes, planes * block.expansion, stride), 321 | norm_layer(planes * block.expansion), 322 | ) 323 | 324 | layers = [] 325 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 326 | self.base_width, previous_dilation, norm_layer)) 327 | self.inplanes = planes * block.expansion 328 | for _ in range(1, blocks): 329 | layers.append(block(self.inplanes, planes, groups=self.groups, 330 | base_width=self.base_width, dilation=self.dilation, 331 | norm_layer=norm_layer)) 332 | 333 | return nn.Sequential(*layers) 334 | 335 | def _forward_impl(self, x): 336 | # See note [TorchScript super()] 337 | x = self.conv1(x) 338 | x = self.bn1(x) 339 | x = self.relu(x) 340 | x = self.maxpool(x) 341 | 342 | x = self.layer1(x) 343 | x = self.layer2(x) 344 | x = self.layer3(x) 345 | x = self.layer4(x) 346 | 347 | x = self.avgpool(x) 348 | x = torch.flatten(x, 1) 349 | x = self.fc(x) 350 | 351 | return x 352 | 353 | def forward(self, x): 354 | return self._forward_impl(x) 355 | 356 | 357 | def load_param(self, model_path): 358 | param_dict = torch.load(model_path) 359 | for i in param_dict: 360 | if 'fc.weight' in param_dict: 361 | continue 362 | self.state_dict()[i].copy_(param_dict[i]) 363 | 364 | 365 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 366 | model = ResNet(block, layers, **kwargs) 367 | if pretrained: 368 | state_dict = load_state_dict_from_url(model_urls[arch], 369 | progress=progress) 370 | model.load_state_dict(state_dict) 371 | return model 372 | 373 | 374 | def resnet50(pretrained=False, progress=True, **kwargs): 375 | """ResNet-50 model from 376 | `"Deep Residual Learning for Image Recognition" `_ 377 | Args: 378 | pretrained (bool): If True, returns a model pre-trained on ImageNet 379 | progress (bool): If True, displays a progress bar of the download to stderr 380 | """ 381 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 382 | **kwargs) 383 | -------------------------------------------------------------------------------- 
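A minimal shape check for the Baseline encoder above (a sketch, not part of the repository; it assumes the pretrained resnet50 weights can be downloaded and uses the 256x128 crop size common in person re-id):

import torch
from reid.models.baseline import Baseline

encoder = Baseline(num_classes=395, num_features=2048)  # 395 ids as in SYSU-MM01
encoder.eval()                         # eval mode returns features instead of logits
x = torch.randn(2, 3, 256, 128)        # two dummy person crops
with torch.no_grad():
    global_feat, feat1 = encoder(x)    # pooled feature and its BN-necked version
print(global_feat.shape, feat1.shape)  # both torch.Size([2, 2048])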
/reid/models/baseline.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/baseline.pyc -------------------------------------------------------------------------------- /reid/models/newresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import init 6 | from torch.nn.functional import softplus 7 | 8 | from reid.models.baseline import Baseline 9 | from utlis import ChannelCompress, to_edge 10 | 11 | __all__ = ['ft_net'] 12 | 13 | ################################################################################## 14 | # Initialization function 15 | ################################################################################## 16 | def weights_init_kaiming(m): 17 | classname = m.__class__.__name__ 18 | # print(classname) 19 | if classname.find('Conv') != -1: 20 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 21 | elif classname.find('Linear') != -1: 22 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 23 | init.constant_(m.bias.data, 0.0) 24 | elif classname.find('BatchNorm1d') != -1: 25 | init.normal_(m.weight.data, 1.0, 0.02) 26 | init.constant_(m.bias.data, 0.0) 27 | 28 | def weights_init_classifier(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('Linear') != -1: 31 | init.normal_(m.weight.data, std=0.001) 32 | init.constant_(m.bias.data, 0.0) 33 | 34 | ################################################################################## 35 | # framework 36 | ################################################################################## 37 | class ft_net(nn.Module): 38 | @staticmethod 39 | def _init_reduction(reduction): 40 | # conv 41 | nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in') 42 | # nn.init.constant_(reduction[0].bias, 0.) 43 | 44 | # bn 45 | nn.init.normal_(reduction[1].weight, mean=1., std=0.02) 46 | nn.init.constant_(reduction[1].bias, 0.) 47 | 48 | @staticmethod 49 | def _init_fc(fc): 50 | nn.init.kaiming_normal_(fc.weight, mode='fan_out') 51 | # nn.init.normal_(fc.weight, std=0.001) 52 | nn.init.constant_(fc.bias, 0.) 
53 | 54 | 55 | def __init__(self, args, num_classes, num_features): 56 | super(ft_net, self).__init__() 57 | 58 | # Load basic config to initialize encoders, decoders and discriminators 59 | self.args = args 60 | 61 | self.IR_backbone = Baseline(num_classes, num_features) 62 | self.IR_Bottleneck = VIB(in_ch=2048, z_dim=self.args.z_dim, num_class= num_classes) 63 | #self.IR_MIE = MIEstimator(size1=2048, size2=int(2 * self.args.z_dim)) 64 | 65 | self.RGB_backbone = Baseline(num_classes, num_features) 66 | self.RGB_Bottleneck = VIB(in_ch=2048, z_dim=self.args.z_dim, num_class= num_classes) 67 | #self.RGB_MIE = MIEstimator(size1=2048, size2=int(2 * self.args.z_dim)) 68 | 69 | self.shared_backbone = Baseline(num_classes, num_features) 70 | self.shared_Bottleneck = VIB(in_ch=2048, z_dim=self.args.z_dim, num_class= num_classes) 71 | #self.shared_MIE = MIEstimator(size1=2048, size2=int(2 * self.args.z_dim)) 72 | 73 | def forward(self, x): 74 | # visible branch 75 | v_observation = self.RGB_backbone(x) 76 | v_representation = self.RGB_Bottleneck(v_observation[1]) 77 | 78 | # modal-shared branch 79 | x_grey = to_edge(x) 80 | i_ms_input = torch.cat([x_grey, x_grey, x_grey], dim=1) 81 | 82 | i_ms_observation = self.shared_backbone(i_ms_input) 83 | v_ms_observation = self.shared_backbone(x) 84 | 85 | i_ms_representation = self.shared_Bottleneck(i_ms_observation[1]) 86 | v_ms_representation = self.shared_Bottleneck(v_ms_observation[1]) 87 | 88 | # infrared branch 89 | i_observation = self.IR_backbone(i_ms_input) 90 | i_representation = self.IR_Bottleneck(i_observation[1]) 91 | 92 | return i_observation, i_representation, i_ms_observation, i_ms_representation, \ 93 | v_observation, v_representation, v_ms_observation, v_ms_representation 94 | 95 | def optims(self): 96 | conv_params = [] 97 | 98 | conv_params += list(self.IR_backbone.parameters()) 99 | conv_params += list(self.IR_Bottleneck.bottleneck.parameters()) 100 | conv_params += list(self.IR_Bottleneck.classifier.parameters()) 101 | 102 | conv_params += list(self.RGB_backbone.parameters()) 103 | conv_params += list(self.RGB_Bottleneck.bottleneck.parameters()) 104 | conv_params += list(self.RGB_Bottleneck.classifier.parameters()) 105 | 106 | conv_params += list(self.shared_backbone.parameters()) 107 | conv_params += list(self.shared_Bottleneck.bottleneck.parameters()) 108 | conv_params += list(self.shared_Bottleneck.classifier.parameters()) 109 | 110 | conv_optim = torch.optim.Adam([p for p in conv_params if p.requires_grad], lr=self.args.lr, weight_decay=5e-4) 111 | 112 | return conv_optim 113 | 114 | ################################################################################## 115 | # Variational Information Bottleneck 116 | ################################################################################## 117 | class VIB(nn.Module): 118 | def __init__(self, in_ch=2048, z_dim=256, num_class=395): 119 | super(VIB, self).__init__() 120 | self.in_ch = in_ch 121 | self.out_ch = z_dim * 2 122 | self.num_class = num_class 123 | self.bottleneck = ChannelCompress(in_ch=self.in_ch, out_ch=self.out_ch) 124 | # classifier of VIB, maybe modified later. 
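# The head below maps the compressed code z (of length 2 * z_dim, produced by
# ChannelCompress) to class logits through a small MLP; VIB.forward() returns
# (logits, z), so callers use z as the representation and the logits for the
# distillation losses in reid/trainers.py.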
125 | classifier = [] 126 | classifier += [nn.Linear(self.out_ch, self.out_ch // 2)] 127 | classifier += [nn.BatchNorm1d(self.out_ch // 2)] 128 | classifier += [nn.LeakyReLU(0.1)] 129 | classifier += [nn.Dropout(0.5)] 130 | classifier += [nn.Linear(self.out_ch // 2, self.num_class)] 131 | classifier = nn.Sequential(*classifier) 132 | self.classifier = classifier 133 | self.classifier.apply(weights_init_classifier) 134 | 135 | def forward(self, v): 136 | z_given_v = self.bottleneck(v) 137 | p_y_given_z = self.classifier(z_given_v) 138 | return p_y_given_z, z_given_v 139 | 140 | ################################################################################## 141 | # Mutual Information Estimator 142 | ################################################################################## 143 | class MIEstimator(nn.Module): 144 | def __init__(self, size1=2048, size2=512): 145 | super(MIEstimator, self).__init__() 146 | self.size1 = size1 147 | self.size2 = size2 148 | self.in_ch = size1 + size2 149 | add_block = [] 150 | add_block += [nn.Linear(self.in_ch, 2048)] 151 | add_block += [nn.BatchNorm1d(2048)] 152 | add_block += [nn.ReLU()] 153 | add_block += [nn.Linear(2048, 512)] 154 | add_block += [nn.BatchNorm1d(512)] 155 | add_block += [nn.ReLU()] 156 | add_block += [nn.Linear(512, 1)] 157 | 158 | add_block = nn.Sequential(*add_block) 159 | add_block.apply(weights_init_kaiming) 160 | self.block = add_block 161 | 162 | # Gradient for JSD mutual information estimation and EB-based estimation 163 | def forward(self, x1, x2, x1_shuff): 164 | """ 165 | :param x1: observation 166 | :param x2: representation 167 | """ 168 | pos = self.block(torch.cat([x1, x2], 1)) # Positive Samples 169 | neg = self.block(torch.cat([x1_shuff, x2], 1)) 170 | 171 | return -softplus(-pos).mean() - softplus(neg).mean(), pos.mean() - neg.exp().mean() + 1 -------------------------------------------------------------------------------- /reid/models/newresnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/newresnet.pyc -------------------------------------------------------------------------------- /reid/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import torch 4 | 5 | from utlis import SinkhornDistance 6 | from .utils.meters import AverageMeter 7 | from torch.nn.functional import kl_div 8 | 9 | 10 | class BaseTrainer(object): 11 | def __init__(self, args, model, ce_loss, rank_loss, associate_loss, trainvallabel): 12 | super(BaseTrainer, self).__init__() 13 | self.model = model 14 | 15 | self.args = args 16 | 17 | self.CE_Loss = ce_loss 18 | self.rank_loss = rank_loss 19 | 20 | self.softmax = torch.nn.Softmax(dim=1) 21 | self.KLD = torch.nn.KLDivLoss() 22 | self.W_dist = SinkhornDistance().cuda() 23 | 24 | self.associate_loss = associate_loss 25 | 26 | self.trainvallabel = trainvallabel 27 | 28 | def train(self, epoch, data_loader, conv_optim, print_freq=24): 29 | self.model.train() 30 | 31 | batch_time = AverageMeter() 32 | data_time = AverageMeter() 33 | losses_total = AverageMeter() 34 | 35 | losses_triple = AverageMeter() 36 | losses_celoss = AverageMeter() 37 | 38 | losses_cml = AverageMeter() 39 | losses_vsd = AverageMeter() 40 | losses_vcd = AverageMeter() 41 | 
42 | end = time.time() 43 | for i, inputs in enumerate(data_loader): 44 | data_time.update(time.time() - end) 45 | 46 | inputs, sub, label = self._parse_data(inputs) 47 | 48 | # Calc the loss 49 | ce_loss, triplet_Loss, conventional_ML, vsd_loss, vcd_loss = self._forward(inputs, label, sub, epoch) 50 | L = self.args.CE_loss * ce_loss + \ 51 | self.args.Triplet_loss * triplet_Loss + \ 52 | self.args.CML_loss * conventional_ML + \ 53 | self.args.VSD_loss * vsd_loss + \ 54 | self.args.VCD_loss * vcd_loss 55 | 56 | conv_optim.zero_grad() 57 | L.backward() 58 | conv_optim.step() 59 | 60 | losses_total.update(L.data.item(), label.size(0)) 61 | 62 | losses_celoss.update(ce_loss.item(), label.size(0)) 63 | losses_triple.update(triplet_Loss.item(), label.size(0)) 64 | 65 | losses_cml.update(conventional_ML.item(), label.size(0)) 66 | losses_vcd.update(vcd_loss.item(), label.size(0)) 67 | losses_vsd.update(vsd_loss.item(), label.size(0)) 68 | 69 | # losses_sharedMI.update(shared_MI.item(), label.size(0)) 70 | # losses_specificMI.update(specific_MI.item(), label.size(0)) 71 | 72 | batch_time.update(time.time() - end) 73 | end = time.time() 74 | 75 | if (i + 1) % print_freq == 0: 76 | print('Epoch: [{}][{}/{}]\t' 77 | 'Time {:.2f} ({:.2f})\t' 78 | 'Total Loss {:.2f} ({:.2f})\t' 79 | 'IDE Loss {:.2f} ({:.2f})\t' 80 | 'Triple Loss {:.2f} ({:.2f})\t' 81 | 'CML Loss {:.3f} ({:.3f})\t' 82 | 'VSD Loss {:.3f} ({:.3f})\t' 83 | 'VCD Loss {:.3f} ({:.3f})\t' 84 | .format(epoch, i + 1, len(data_loader), 85 | batch_time.val, batch_time.avg, 86 | losses_total.val, losses_total.avg, 87 | losses_celoss.val, losses_celoss.avg, 88 | losses_triple.val, losses_triple.avg, 89 | losses_cml.val, losses_cml.avg, 90 | losses_vsd.val, losses_vsd.avg, 91 | losses_vcd.val, losses_vcd.avg)) 92 | return losses_triple.avg, losses_total.avg 93 | 94 | def _parse_data(self, inputs): 95 | raise NotImplementedError 96 | 97 | def _forward(self, inputs, targets): 98 | raise NotImplementedError 99 | 100 | 101 | class Trainer(BaseTrainer): 102 | def _parse_data(self, inputs): 103 | imgs, _, pids, cams = inputs 104 | inputs = imgs.cuda() 105 | pids = pids.cuda() 106 | sub = ((cams == 2).long() + (cams == 5).long()).cuda() 107 | label = torch.cuda.LongTensor(range(pids.size(0))) 108 | for i in range(pids.size(0)): 109 | label[i] = self.trainvallabel[pids[i].item()] 110 | return inputs, sub, label 111 | 112 | def _forward(self, inputs, label, sub, epoch): 113 | i_observation, i_representation, i_ms_observation, i_ms_representation, \ 114 | v_observation, v_representation, v_ms_observation, v_ms_representation = self.model(inputs) 115 | 116 | # Classification loss 117 | ce_loss = 0.5 * (self.CE_Loss(i_observation[0], label) + self.CE_Loss(i_representation[0], label)) + \ 118 | 0.5 * (self.CE_Loss(v_observation[0], label) + self.CE_Loss(v_representation[0], label)) + \ 119 | 0.25 * (self.CE_Loss(i_ms_observation[0], label) + self.CE_Loss(i_ms_representation[0], label)) + \ 120 | 0.25 * (self.CE_Loss(v_ms_observation[0], label) + self.CE_Loss(v_ms_representation[0], label)) 121 | 122 | # Metric learning, notice rank loss are applied to v and z, respectively. 
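# Branch weighting below mirrors ce_loss above: modal-specific terms (infrared,
# visible) are weighted 0.5, modal-shared terms 0.25.
# A caveat on the VSD/VCD terms further down: torch.nn.functional.kl_div
# expects its `input` argument to be log-probabilities, but plain softmax
# outputs are passed here, so the value computed is not a textbook KL
# divergence; log_softmax on the input side would give the standard form.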
123 | triplet_Loss = 0.5 * ( self.rank_loss(i_observation[1], label, sub) + self.rank_loss(i_representation[1], label, sub)) + \ 124 | 0.5 * (self.rank_loss(v_observation[1], label, sub) + self.rank_loss(v_representation[1], label, sub)) + \ 125 | 0.25 * (self.rank_loss(i_ms_observation[1], label, sub) + self.rank_loss(i_ms_representation[1], label, sub)) + \ 126 | 0.25 * (self.rank_loss(v_ms_observation[1], label, sub) + self.rank_loss(v_ms_representation[1], label, sub)) 127 | 128 | #associate_loss = 0.5 * (self.associate_loss(i_observation[1], label, sub) + self.associate_loss(i_representation[1], label, sub)) + \ 129 | # 0.5 * (self.associate_loss(v_observation[1], label, sub) + self.associate_loss(v_representation[1], label, sub)) + \ 130 | # 0.25 * (self.associate_loss(i_ms_observation[1], label, sub) + self.associate_loss(i_ms_representation[1], label, sub)) + \ 131 | # 0.25 * (self.associate_loss(v_ms_observation[1], label, sub) + self.associate_loss(v_ms_representation[1], label, sub)) 132 | 133 | # Conventional mutual learning strategy, conducted only between observations of modal-specific branches. 134 | conventional_ML = self.W_dist(self.softmax(i_observation[0]), self.softmax(v_observation[0])) 135 | 136 | # Variational Self-Distillation, preserving sufficiency 137 | vsd_loss = kl_div(input=self.softmax(i_observation[0].detach() / self.args.temperature), 138 | target=self.softmax(i_representation[0] / self.args.temperature)) + \ 139 | kl_div(input=self.softmax(v_observation[0].detach() / self.args.temperature), 140 | target=self.softmax(v_representation[0] / self.args.temperature)) 141 | 142 | vcd_loss = 0.5 * kl_div(input=self.softmax(v_ms_observation[0].detach()), 143 | target=self.softmax(i_ms_representation[0])) + \ 144 | 0.5 * kl_div(input=self.softmax(i_ms_observation[0].detach()), 145 | target=self.softmax(v_ms_representation[0])) 146 | 147 | # mutual information estimation for modal-specific branches and modal-shared branch 148 | # shuff_order = np.random.permutation(self.args.batch_size) 149 | # specific_MI_I = self.model.module.IR_MIE(x1=i_observation[1], x2=i_representation[1], 150 | # x1_shuff=i_observation[1][shuff_order, :])[0].mean() 151 | # specific_MI_V = self.model.module.IR_MIE(x1=v_observation[1], x2=v_representation[1], 152 | # x1_shuff=v_observation[1][shuff_order, :])[0].mean() 153 | # shared_MI_I = self.model.module.shared_MIE(x1=i_ms_observation[1], x2=i_ms_representation[1], 154 | # x1_shuff=i_ms_observation[1][shuff_order, :])[0].mean() 155 | # shared_MI_V = self.model.module.shared_MIE(x1=v_ms_observation[1], x2=v_ms_representation[1], 156 | # x1_shuff=v_ms_observation[1][shuff_order, :])[0].mean() 157 | 158 | return ce_loss, triplet_Loss, conventional_ML, vsd_loss, vcd_loss#, associate_loss 159 | -------------------------------------------------------------------------------- /reid/trainers.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/trainers.pyc -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif 
type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /reid/utils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__init__.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/logging.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/logging.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/logging.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/logging.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/meters.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/meters.cpython-36.pyc -------------------------------------------------------------------------------- 
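The to_numpy/to_torch converters defined in reid/utils/__init__.py above are used throughout the evaluators; a quick round-trip (toy sketch):

import numpy as np
import torch
from reid.utils import to_numpy, to_torch

t = torch.ones(2, 3)
a = to_numpy(t)        # torch.Tensor -> numpy.ndarray (via .cpu().numpy())
t2 = to_torch(a)       # numpy.ndarray -> torch.Tensor
assert isinstance(a, np.ndarray) and torch.is_tensor(t2)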
/reid/utils/__pycache__/meters.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/meters.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/osutils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/osutils.cpython-35.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/osutils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/osutils.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/osutils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/osutils.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/serialization.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/serialization.cpython-35.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/serialization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/serialization.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/serialization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/serialization.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .dataset import Dataset 4 | from .preprocessor import Preprocessor 5 | 6 | 7 | -------------------------------------------------------------------------------- /reid/utils/data/__init__.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__init__.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/dataset.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/dataset.cpython-35.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/preprocessor.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/preprocessor.cpython-35.pyc 
-------------------------------------------------------------------------------- /reid/utils/data/__pycache__/preprocessor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/preprocessor.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/preprocessor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/preprocessor.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/sampler.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | try: 17 | x, y, _ = map(int, name.split('_')) 18 | assert pid == x and camid == y 19 | except: 20 | _, _, _, _, x = map(str, name.split('_')) 21 | if relabel: 22 | ret.append((fname, index, camid)) 
23 | else: 24 | ret.append((fname, pid, camid)) 25 | return ret 26 | 27 | 28 | class Dataset(object): 29 | def __init__(self, root, split_id=0): 30 | self.root = root 31 | self.split_id = split_id 32 | self.meta = None 33 | self.split = None 34 | self.train, self.val, self.trainval = [], [], [] 35 | self.query, self.gallery = [], [] 36 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 37 | self.trainvallabel = {} 38 | 39 | @property 40 | def images_dir(self): 41 | return osp.join(self.root, 'images') 42 | 43 | def load(self, num_val=0.3, verbose=False): 44 | splits = read_json(osp.join(self.root, 'splits.json')) 45 | if self.split_id >= len(splits): 46 | raise ValueError("split_id exceeds total splits {}" 47 | .format(len(splits))) 48 | self.split = splits[self.split_id] 49 | 50 | # Randomly split train / val 51 | trainval_pids = np.asarray(self.split['trainval']) 52 | test_pids = np.asarray(self.split['query']) 53 | # np.random.shuffle(trainval_pids) 54 | num = len(trainval_pids) 55 | if isinstance(num_val, float): 56 | num_val = int(round(num * num_val)) 57 | if num_val >= num or num_val < 0: 58 | raise ValueError("num_val exceeds total identities {}" 59 | .format(num)) 60 | train_pids = sorted(trainval_pids[:-num_val]) 61 | val_pids = sorted(trainval_pids[-num_val:]) 62 | 63 | self.meta = read_json(osp.join(self.root, 'meta.json')) 64 | identities = self.meta['identities'] 65 | self.train = _pluck(identities, train_pids, relabel=False) 66 | self.val = _pluck(identities, val_pids, relabel=False) 67 | self.trainval = _pluck(identities, trainval_pids, relabel=False) 68 | # print(self.trainval[1],self.trainval[1][1]) 69 | countIR = 0 70 | countRGB = 0 71 | for image in self.trainval: 72 | # print(image[2] == 2 or image[2] == 5) 73 | if image[2] == 2 or image[2] == 5: 74 | countIR = countIR + 1 75 | else: 76 | countRGB = countRGB + 1 77 | self.query = _pluck(identities, self.split['query']) 78 | self.gallery = _pluck(identities, self.split['gallery']) 79 | query = 0 80 | gallery = 0 81 | for image in self.query: 82 | if image[2] == 2 or image[2] == 5: 83 | query += 1 84 | else: 85 | gallery += 1 86 | 87 | self.num_train_ids = len(train_pids) 88 | self.num_val_ids = len(val_pids) 89 | self.num_trainval_ids = len(trainval_pids) 90 | # print(sorted(trainval_pids)) 91 | for index,i in enumerate(sorted(trainval_pids)): 92 | self.trainvallabel[i] = index 93 | # print (index) 94 | 95 | if verbose: 96 | print(self.__class__.__name__, "dataset loaded") 97 | print(" subset | # ids | # images") 98 | print(" ---------------------+----------+---------") 99 | print(" train | {:8d} | {:8d}" 100 | .format(self.num_train_ids, len(self.train))) 101 | print(" val | {:8d} | {:8d}" 102 | .format(self.num_val_ids, len(self.val))) 103 | print(" trainval | {:8d} | {:8d}" 104 | .format(self.num_trainval_ids, len(self.trainval))) 105 | print(" query | {:8d} | {:8d}" 106 | .format(len(test_pids), len(test_pids) * 4)) 107 | print(" gallery | {:8d} | {:8d}" 108 | .format(len(test_pids), gallery)) 109 | print(" num of RGB and IR | {:8d} | {:8d}" 110 | .format(countRGB, countIR)) 111 | 112 | def _check_integrity(self): 113 | return osp.isdir(osp.join(self.root, 'images')) and \ 114 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 115 | osp.isfile(osp.join(self.root, 'splits.json')) 116 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/dataset.pyc -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | 4 | from PIL import Image 5 | 6 | 7 | class Preprocessor(object): 8 | def __init__(self, dataset, root=None, transform=None): 9 | super(Preprocessor, self).__init__() 10 | self.dataset = dataset 11 | self.root = root 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | def __getitem__(self, indices): 18 | if isinstance(indices, (tuple, list)): 19 | print("indices is tuple list") 20 | return [self._get_single_item(index) for index in indices] 21 | return self._get_single_item(indices) 22 | 23 | def _get_single_item(self, index): 24 | fname, pid, camid = self.dataset[index] 25 | fpath = fname 26 | if self.root is not None: 27 | fpath = osp.join(self.root, fname) 28 | img = Image.open(fpath).convert('RGB') 29 | # img = Image.open(fpath).convert('L') 30 | # print(img) 31 | if self.transform is not None: 32 | img = self.transform(img) 33 | return img, fname, pid, camid 34 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/preprocessor.pyc -------------------------------------------------------------------------------- /reid/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from random import shuffle 7 | from torch.utils.data.sampler import ( 8 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 9 | WeightedRandomSampler) 10 | 11 | 12 | class RandomIdentitySampler(Sampler): 13 | def __init__(self, data_source, num_instances=1): 14 | self.data_source = data_source 15 | self.num_instances = num_instances 16 | self.index_dic = defaultdict(list) 17 | for index, (_, pid, _) in enumerate(data_source): 18 | self.index_dic[pid].append(index) 19 | self.pids = list(self.index_dic.keys()) 20 | self.num_samples = len(self.pids) 21 | 22 | def __len__(self): 23 | return self.num_samples * self.num_instances 24 | 25 | def __iter__(self): 26 | indices = torch.randperm(self.num_samples) 27 | ret = [] 28 | for i in indices: 29 | pid = self.pids[i] 30 | t = self.index_dic[pid] 31 | if len(t) >= self.num_instances: 32 | t = np.random.choice(t, size=self.num_instances, replace=False) 33 | else: 34 | t = np.random.choice(t, size=self.num_instances, replace=True) 35 | ret.extend(t) 36 | return iter(ret) 37 | 38 | class CamSampler(Sampler): 39 | def __init__(self, data_source, need_cam, num=0): 40 | self.data_source = data_source 41 | self.index_dic = [] 42 | self.id_cam = [[] for _ in range(533)] 43 | 44 | if num>0: 45 | for index, (_, pid, cam) in enumerate(data_source): 46 | if cam in need_cam: 47 | self.id_cam[pid].append(index) 48 | for i in range(533): 
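# NOTE: 533 appears to correspond to the highest person-id value in the
# SYSU-MM01 split files (the dataset has 491 identities with non-contiguous
# ids); one bucket is reserved per possible id, and ids that never occur in
# data_source simply keep empty lists.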
49 | if len(self.id_cam[i])>num: 50 | self.index_dic.extend(self.id_cam[i][:num]) 51 | else: 52 | for index, (_, pid, cam) in enumerate(data_source): 53 | if cam in need_cam: 54 | self.index_dic.append(index) 55 | 56 | def __len__(self): 57 | return len(self.index_dic) 58 | 59 | def __iter__(self): 60 | return iter(self.index_dic) 61 | 62 | class CamRandomIdentitySampler(Sampler): 63 | def __init__(self, data_source, num_instances=2): 64 | self.data_source = data_source 65 | self.num_instances = num_instances 66 | if num_instances % 2 > 0: 67 | raise ValueError("The num_instances should be a even number") 68 | self.index_dic_I = defaultdict(list) 69 | self.index_dic_IR = defaultdict(list) 70 | for index, (name, pid, cam) in enumerate(data_source): 71 | if cam == 2 or cam == 5: 72 | self.index_dic_IR[pid].append(index) 73 | else: 74 | self.index_dic_I[pid].append(index) 75 | self.pids = list(self.index_dic_I.keys()) 76 | self.num_samples = len(self.pids) 77 | 78 | def __len__(self): 79 | return self.num_samples * self.num_instances 80 | 81 | def __iter__(self): 82 | indices = torch.randperm(self.num_samples) 83 | ret = [] 84 | for i in indices: 85 | pid_I = self.pids[i] 86 | pid_IR = self.pids[i] 87 | t_I = self.index_dic_I[pid_I] 88 | t_IR = self.index_dic_IR[pid_IR] 89 | if len(t_I) >= self.num_instances / 2: 90 | t_I = np.random.choice(t_I, size=int(self.num_instances / 2), replace=False) 91 | else: 92 | t_I = np.random.choice(t_I, size=int(self.num_instances / 2), replace=True) 93 | if len(t_IR) >= self.num_instances / 2: 94 | t_IR = np.random.choice(t_IR, size=int(self.num_instances / 2), replace=False) 95 | else: 96 | t_IR = np.random.choice(t_IR, size=int(self.num_instances / 2), replace=True) 97 | # ret.extend(t_I) 98 | # ret.extend(t_IR) 99 | for j in range(self.num_instances // 2): 100 | ret.append(t_I[j]) 101 | ret.append(t_IR[j]) 102 | return iter(ret) 103 | -------------------------------------------------------------------------------- /reid/utils/data/sampler.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/sampler.pyc -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / 
aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | class RandomErasing(object): 52 | """ Randomly selects a rectangle region in an image and erases its pixels. 53 | 'Random Erasing Data Augmentation' by Zhong et al. 54 | See https://arxiv.org/pdf/1708.04896.pdf 55 | Args: 56 | probability: The probability that the Random Erasing operation will be performed. 57 | sl: Minimum proportion of erased area against input image. 58 | sh: Maximum proportion of erased area against input image. 59 | r1: Minimum aspect ratio of erased area. 60 | mean: Erasing value. 61 | """ 62 | 63 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 64 | self.probability = probability 65 | self.mean = mean 66 | self.sl = sl 67 | self.sh = sh 68 | self.r1 = r1 69 | 70 | def __call__(self, img): 71 | 72 | if random.uniform(0, 1) >= self.probability: 73 | return img 74 | 75 | for attempt in range(100): 76 | area = img.size()[1] * img.size()[2] 77 | 78 | target_area = random.uniform(self.sl, self.sh) * area 79 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 80 | 81 | h = int(round(math.sqrt(target_area * aspect_ratio))) 82 | w = int(round(math.sqrt(target_area / aspect_ratio))) 83 | 84 | if w < img.size()[2] and h < img.size()[1]: 85 | x1 = random.randint(0, img.size()[1] - h) 86 | y1 = random.randint(0, img.size()[2] - w) 87 | if img.size()[0] == 3: 88 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 89 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 90 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 91 | else: 92 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 93 | return img 94 | 95 | return img 96 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/transforms.pyc -------------------------------------------------------------------------------- /reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | 
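The Logger above tees everything written to stdout into a log file as well as the console. A minimal usage sketch (the log path is hypothetical, not from the repo):

import sys
from reid.utils.logging import Logger

# Every subsequent print() goes to both the console and the log file.
sys.stdout = Logger('./logs/sysu/log.txt')   # hypothetical path
print('epoch 1 | loss 3.21')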
-------------------------------------------------------------------------------- /reid/utils/logging.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/logging.pyc -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /reid/utils/meters.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/meters.pyc -------------------------------------------------------------------------------- /reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /reid/utils/osutils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/osutils.pyc -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, epoch, dirname, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(dirname)) 26 | torch.save(state, osp.join(dirname,fpath)) 27 | if is_best: 28 | shutil.copyfile(dirname+'/'+fpath, dirname+'/'+'model_best.pth.tar') 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def 
copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /reid/utils/serialization.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/serialization.pyc -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | # encoding: utf-8 3 | import os 4 | c = 0.15 5 | for i in range(0,16): 6 | c = c + 0.15 7 | file_data = "" 8 | f = open("demo_sysu.sh") 9 | lines = f.readlines() 10 | print(c) 11 | with open("demo_sysu.sh", "w") as fw: 12 | for line in lines: 13 | print(line) 14 | if "D_loss" in line: 15 | line = "-D_loss " + str(c) + " \\" + '\n' 16 | if "logs-dir" in line: 17 | line = "--logs-dir ./weight_of_D_Loss_" + str(c) + ' \\' + '\n' 18 | file_data += line 19 | fw.write(file_data) 20 | os.system('sh demo_sysu.sh') 21 | 22 | -------------------------------------------------------------------------------- /utlis.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from bisect import bisect_right 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.nn import init 8 | ############################################################################################ 9 | # Channel Compress 10 | ############################################################################################ 11 | 12 | def weights_init_kaiming(m): 13 | classname = m.__class__.__name__ 14 | # print(classname) 15 | if classname.find('Conv') != -1: 16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 17 | elif classname.find('Linear') != -1: 18 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 19 | init.constant_(m.bias.data, 0.0) 20 | elif classname.find('BatchNorm1d') != -1: 21 | init.normal_(m.weight.data, 1.0, 0.02) 22 | init.constant_(m.bias.data, 0.0) 23 | 24 | def weights_init_classifier(m): 25 | classname = m.__class__.__name__ 26 | if classname.find('Linear') != -1: 27 | init.normal_(m.weight.data, std=0.001) 28 | init.constant_(m.bias.data, 0.0) 29 | 30 | class ChannelCompress(nn.Module): 31 | def __init__(self, in_ch=2048, out_ch=256): 32 | """ 33 | reduce the amount of channels to prevent final embeddings overwhelming shallow feature maps 34 | out_ch could be 512, 256, 128 35 | """ 36 | super(ChannelCompress, self).__init__() 37 | num_bottleneck = 1000 38 | add_block = [] 39 | add_block += [nn.Linear(in_ch, num_bottleneck)] 40 | add_block += [nn.BatchNorm1d(num_bottleneck)] 41 | add_block += [nn.ReLU()] 42 | 43 | add_block += 
[nn.Linear(num_bottleneck, 500)]
44 | add_block += [nn.BatchNorm1d(500)]
45 | add_block += [nn.ReLU()]
46 | add_block += [nn.Linear(500, out_ch)]
47 |
48 | # Extra BN layer on the output; intentionally disabled, kept for reference
49 | #add_block += [nn.BatchNorm1d(out_ch)]
50 |
51 | add_block = nn.Sequential(*add_block)
52 | add_block.apply(weights_init_kaiming)
53 | self.model = add_block
54 |
55 | def forward(self, x):
56 | x = self.model(x)
57 | return x
58 |
59 | ############################################################################################
60 | # Classification Loss
61 | ############################################################################################
62 | class CrossEntropyLabelSmooth(nn.Module):
63 | """Cross entropy loss with label smoothing regularizer.
64 |
65 | Reference:
66 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
67 | Equation: y = (1 - epsilon) * y + epsilon / K.
68 |
69 | Args:
70 | num_classes (int): number of classes.
71 | epsilon (float): smoothing weight.
72 | """
73 | def __init__(self, num_classes, epsilon=0.0, use_gpu=True):
74 | super(CrossEntropyLabelSmooth, self).__init__()
75 | self.num_classes = num_classes
76 | self.epsilon = epsilon
77 | self.use_gpu = use_gpu
78 | self.logsoftmax = nn.LogSoftmax(dim=1)
79 |
80 | def forward(self, inputs, targets):
81 | """
82 | Args:
83 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
84 | targets: ground truth labels with shape (batch_size,)
85 | """
86 | log_probs = self.logsoftmax(inputs)
87 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
88 | if self.use_gpu: targets = targets.cuda()
89 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
90 | loss = (- targets * log_probs).mean(0).sum()
91 | return loss
92 |
93 |
94 | ############################################################################################
95 | # gray_scale function
96 | ############################################################################################
97 | def to_edge(x):
98 | x = x.data.cpu()
99 | out = torch.FloatTensor(x.size(0), x.size(2), x.size(3))
100 | for i in range(x.size(0)):
101 | item = x[i,:,:,:]
102 | # item: (3, H, W) RGB image
103 | r, g, b = item[0, :, :], item[1, :, :], item[2, :, :]
104 | xx = 0.2989 * r + 0.5870 * g + 0.1140 * b
105 | # xx: (H, W) luma image (ITU-R BT.601 weights)
106 | out[i, :, :] = xx
107 | out = out.unsqueeze(1)
108 | return out.cuda()
109 |
110 | ############################################################################################
111 | # Random Erasing
112 | ############################################################################################
113 | class RandomErasing(object):
114 | """ Randomly selects a rectangle region in an image and erases its pixels.
115 | 'Random Erasing Data Augmentation' by Zhong et al.
116 | See https://arxiv.org/pdf/1708.04896.pdf
117 | Args:
118 | probability: The probability that the Random Erasing operation will be performed.
119 | sl: Minimum proportion of erased area against input image.
120 | sh: Maximum proportion of erased area against input image.
121 | r1: Minimum aspect ratio of erased area.
122 | mean: Erasing value.
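Example (illustrative): er = RandomErasing(probability=0.5); er(img) takes a
(C, H, W) tensor and, with probability 0.5, fills one random rectangle
(area fraction in [sl, sh], aspect ratio in [r1, 1/r1]) with `mean`.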
123 | """ 124 | 125 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 126 | self.probability = probability 127 | self.mean = mean 128 | self.sl = sl 129 | self.sh = sh 130 | self.r1 = r1 131 | 132 | def __call__(self, img): 133 | 134 | if random.uniform(0, 1) >= self.probability: 135 | return img 136 | 137 | for attempt in range(100): 138 | area = img.size()[1] * img.size()[2] 139 | 140 | target_area = random.uniform(self.sl, self.sh) * area 141 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 142 | 143 | h = int(round(math.sqrt(target_area * aspect_ratio))) 144 | w = int(round(math.sqrt(target_area / aspect_ratio))) 145 | 146 | if w < img.size()[2] and h < img.size()[1]: 147 | x1 = random.randint(0, img.size()[1] - h) 148 | y1 = random.randint(0, img.size()[2] - w) 149 | if img.size()[0] == 3: 150 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 151 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 152 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 153 | else: 154 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 155 | return img 156 | 157 | return img 158 | 159 | ############################################################################################ 160 | # Warmup scheduler 161 | ############################################################################################ 162 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 163 | def __init__( 164 | self, 165 | optimizer, 166 | milestones, 167 | gamma=0.1, 168 | warmup_factor=1.0 / 3, 169 | warmup_iters=500, 170 | warmup_method="linear", 171 | last_epoch=-1, 172 | ): 173 | if not list(milestones) == sorted(milestones): 174 | raise ValueError( 175 | "Milestones should be a list of" " increasing integers. Got {}", 176 | milestones, 177 | ) 178 | 179 | if warmup_method not in ("constant", "linear"): 180 | raise ValueError( 181 | "Only 'constant' or 'linear' warmup_method accepted" 182 | "got {}".format(warmup_method) 183 | ) 184 | self.milestones = milestones 185 | self.gamma = gamma 186 | self.warmup_factor = warmup_factor 187 | self.warmup_iters = warmup_iters 188 | self.warmup_method = warmup_method 189 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 190 | 191 | def get_lr(self): 192 | warmup_factor = 1 193 | if self.last_epoch < self.warmup_iters: 194 | if self.warmup_method == "constant": 195 | warmup_factor = self.warmup_factor 196 | elif self.warmup_method == "linear": 197 | alpha = self.last_epoch / self.warmup_iters 198 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 199 | return [ 200 | base_lr 201 | * warmup_factor 202 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 203 | for base_lr in self.base_lrs 204 | ] 205 | ############################################################################################ 206 | # Rank loss 207 | ############################################################################################ 208 | class Rank_loss(nn.Module): 209 | 210 | ## Basic idea for cross_modality rank_loss 8 211 | 212 | def __init__(self, margin_1=1.0, margin_2=1.5, alpha_1=2.4, alpha_2=2.2, tval=1.0): 213 | super(Rank_loss, self).__init__() 214 | self.margin_1 = margin_1 # for same modality 215 | self.margin_2 = margin_2 # for different modalities 216 | self.alpha_1 = alpha_1 # for same modality 217 | self.alpha_2 = alpha_2 # for different modalities 218 | self.tval = tval 219 | 220 | def forward(self, x, targets, sub, norm = True): 221 | if norm: 222 | #x = self.normalize(x) 223 | x = torch.nn.functional.normalize(x, dim=1, 
p=2) 224 | dist_mat = self.euclidean_dist(x, x) # compute the distance 225 | loss = self.rank_loss(dist_mat, targets, sub) 226 | return loss #,dist_mat 227 | 228 | def rank_loss(self, dist, targets, sub): 229 | loss = 0.0 230 | for i in range(dist.size(0)): 231 | is_pos = targets.eq(targets[i]) 232 | is_pos[i] = 0 233 | is_neg = targets.ne(targets[i]) 234 | 235 | intra_modality = sub.eq(sub[i]) 236 | cross_modality = ~ intra_modality 237 | 238 | mask_pos_intra = is_pos* intra_modality 239 | mask_pos_cross = is_pos* cross_modality 240 | mask_neg_intra = is_neg* intra_modality 241 | mask_neg_cross = is_neg* cross_modality 242 | 243 | ap_pos_intra = torch.clamp(torch.add(dist[i][mask_pos_intra], self.margin_1-self.alpha_1),0) 244 | ap_pos_cross = torch.clamp(torch.add(dist[i][mask_pos_cross], self.margin_2-self.alpha_2),0) 245 | 246 | loss_ap = torch.div(torch.sum(ap_pos_intra), ap_pos_intra.size(0)+1e-5) 247 | loss_ap += torch.div(torch.sum(ap_pos_cross), ap_pos_cross.size(0)+1e-5) 248 | 249 | dist_an_intra = dist[i][mask_neg_intra] 250 | dist_an_cross = dist[i][mask_neg_cross] 251 | 252 | an_less_intra = dist_an_intra[torch.lt(dist[i][mask_neg_intra], self.alpha_1)] 253 | an_less_cross = dist_an_cross[torch.lt(dist[i][mask_neg_cross], self.alpha_2)] 254 | 255 | an_weight_intra = torch.exp(self.tval*(-1* an_less_intra +self.alpha_1)) 256 | an_weight_intra_sum = torch.sum(an_weight_intra)+1e-5 257 | an_weight_cross = torch.exp(self.tval*(-1* an_less_cross +self.alpha_2)) 258 | an_weight_cross_sum = torch.sum(an_weight_cross)+1e-5 259 | an_sum_intra = torch.sum(torch.mul(self.alpha_1-an_less_intra,an_weight_intra)) 260 | an_sum_cross = torch.sum(torch.mul(self.alpha_2-an_less_cross,an_weight_cross)) 261 | 262 | loss_an =torch.div(an_sum_intra,an_weight_intra_sum ) +torch.div(an_sum_cross, an_weight_cross_sum) 263 | #loss_an = torch.div(an_sum_cross,an_weight_cross_sum ) 264 | loss += loss_ap + loss_an 265 | #loss += loss_an 266 | return loss * 1.0/ dist.size(0) 267 | 268 | def normalize(self, x, axis=-1): 269 | x = 1.* x /(torch.norm(x, 2, axis, keepdim = True).expand_as(x)+ 1e-12) 270 | return x 271 | 272 | def euclidean_dist(self, x, y): 273 | m, n =x.size(0), y.size(0) 274 | 275 | xx = torch.pow(x,2).sum(1, keepdim= True).expand(m,n) 276 | yy = torch.pow(y,2).sum(1, keepdim= True).expand(n,m).t() 277 | dist = xx + yy 278 | dist.addmm_(1, -2, x, y.t()) 279 | dist = dist.clamp(min =1e-12).sqrt() 280 | 281 | return dist 282 | 283 | ############################################################################################ 284 | # Associate Loss 285 | ############################################################################################ 286 | class ASS_loss(nn.Module): 287 | def __init__(self, walker_loss=1.0, visit_loss=1.0): 288 | super(ASS_loss, self).__init__() 289 | self.walker_loss = walker_loss 290 | self.visit_loss = visit_loss 291 | self.ce = nn.CrossEntropyLoss() 292 | self.logsoftmax = nn.LogSoftmax(dim=-1) 293 | 294 | def forward(self, feature, targets, sub): 295 | ## normalize 296 | feature = torch.nn.functional.normalize(feature, dim=1, p=2) 297 | loss = 0.0 298 | for i in range(feature.size(0)): 299 | cross_modality = sub.ne(sub[i]) 300 | 301 | p_logit_ab, v_loss_ab = self.probablity(feature, cross_modality, targets) 302 | p_logit_ba, v_loss_ba = self.probablity(feature, ~cross_modality, targets) 303 | n1 = targets[cross_modality].size(0) 304 | n2 = targets[~cross_modality].size(0) 305 | 306 | is_pos_ab = 
targets[cross_modality].expand(n1,n1).eq(targets[cross_modality].expand(n1,n1).t()) 307 | 308 | p_target_ab = is_pos_ab.float()/torch.sum(is_pos_ab, dim=1).float().expand_as(is_pos_ab) 309 | 310 | is_pos_ba = targets[~cross_modality].expand(n2,n2).eq(targets[cross_modality].expand(n2,n2).t()) 311 | p_target_ba = is_pos_ba.float()/torch.sum(is_pos_ba, dim=1).float().expand_as(is_pos_ba) 312 | 313 | p_logit_ab = self.logsoftmax(p_logit_ab) 314 | p_logit_ba = self.logsoftmax(p_logit_ba) 315 | 316 | loss += (- p_target_ab * p_logit_ab).mean(0).sum()+ (- p_target_ba * p_logit_ba).mean(0).sum() 317 | 318 | loss += 1.0*(v_loss_ab+v_loss_ba) 319 | 320 | return loss/feature.size(0)/4 321 | 322 | def probablity(self, feature, cross_modality, target): 323 | a = feature[cross_modality] 324 | b = feature[~cross_modality] 325 | 326 | match_ab = a.mm(b.t()) 327 | 328 | p_ab = F.softmax(match_ab, dim=-1) 329 | p_ba = F.softmax(match_ab, dim=-1) 330 | p_aba = torch.log(1e-8+p_ab.mm(p_ba)) 331 | 332 | visit_loss = self.new_visit(p_ab, target, cross_modality) 333 | 334 | return p_aba, visit_loss 335 | 336 | def new_visit(self, p_ab, target, cross_modality): 337 | p_ab = torch.log(1e-8 +p_ab) 338 | visit_probability = p_ab.mean(dim=0).expand_as(p_ab) 339 | n1 = target[cross_modality].size(0) 340 | n2 = target[~cross_modality].size(0) 341 | p_target_ab = target[cross_modality].expand(n1,n1).eq(target[~cross_modality].expand(n2,n2)) 342 | p_target_ab = p_target_ab.float()/torch.sum(p_target_ab, dim=1).float().expand_as(p_target_ab) 343 | loss = (- p_target_ab * visit_probability).mean(0).sum() 344 | return loss 345 | 346 | def normalize(self, x, axis=-1): 347 | x = 1.* x /(torch.norm(x, 2, axis, keepdim = True).expand_as(x)+ 1e-12) 348 | return x 349 | 350 | ############################################################################################ 351 | # Wasserstein Distance 352 | ############################################################################################ 353 | class SinkhornDistance(nn.Module): 354 | def __init__(self, eps=0.01, max_iter=100, reduction='mean'): 355 | super(SinkhornDistance, self).__init__() 356 | self.eps = eps 357 | self.max_iter = max_iter 358 | self.reduction = reduction 359 | 360 | def forward(self, x, y): 361 | # The Sinkhorn algorithm takes as input three variables : 362 | C = self._cost_matrix(x, y) # Wasserstein cost function 363 | C = C.cuda() 364 | n_points = x.shape[-2] 365 | if x.dim() == 2: 366 | batch_size = 1 367 | else: 368 | batch_size = x.shape[0] 369 | 370 | # both marginals are fixed with equal weights 371 | mu = torch.empty(batch_size, n_points, dtype=torch.float, 372 | requires_grad=False).fill_(1.0 / n_points).squeeze() 373 | nu = torch.empty(batch_size, n_points, dtype=torch.float, 374 | requires_grad=False).fill_(1.0 / n_points).squeeze() 375 | 376 | u = torch.zeros_like(mu) 377 | u = u.cuda() 378 | v = torch.zeros_like(nu) 379 | v = v.cuda() 380 | # To check if algorithm terminates because of threshold 381 | # or max iterations reached 382 | actual_nits = 0 383 | # Stopping criterion 384 | thresh = 1e-1 385 | 386 | # Sinkhorn iterations 387 | for i in range(self.max_iter): 388 | u1 = u # useful to check the update 389 | u = self.eps * (torch.log(mu + 1e-8).cuda() - self.lse(self.M(C, u, v))).cuda() + u 390 | v = self.eps * (torch.log(nu + 1e-8).cuda() - self.lse(self.M(C, u, v).transpose(-2, -1))).cuda() + v 391 | err = (u - u1).abs().sum(-1).mean() 392 | err = err.cuda() 393 | 394 | actual_nits += 1 395 | if err.item() < thresh: 396 | break 397 | 
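# NOTE: u and v now hold (approximate) log-domain scaling potentials; each
# iteration above is a log-sum-exp form of the classic Sinkhorn row/column
# normalisation, terminating early once the marginal violation drops
# below `thresh`.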
398 | U, V = u, v 399 | # Transport plan pi = diag(a)*K*diag(b) 400 | pi = torch.exp(self.M(C, U, V)) 401 | # Sinkhorn distance 402 | cost = torch.sum(pi * C, dim=(-2, -1)) 403 | 404 | if self.reduction == 'mean': 405 | cost = cost.mean() 406 | elif self.reduction == 'sum': 407 | cost = cost.sum() 408 | 409 | return cost 410 | 411 | def M(self, C, u, v): 412 | "Modified cost for logarithmic updates" 413 | "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$" 414 | return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps 415 | 416 | @staticmethod 417 | def _cost_matrix(x, y, p=2): 418 | "Returns the matrix of $|x_i-y_j|^p$." 419 | x_col = x.unsqueeze(-2) 420 | y_lin = y.unsqueeze(-3) 421 | C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1) 422 | return C 423 | 424 | @staticmethod 425 | def lse(A): 426 | "log-sum-exp" 427 | # add 10^-6 to prevent NaN 428 | result = torch.log(torch.exp(A).sum(-1) + 1e-6) 429 | return result 430 | 431 | @staticmethod 432 | def ave(u, u1, tau): 433 | "Barycenter subroutine, used by kinetic acceleration through extrapolation." 434 | return tau * u + (1 - tau) * u1 --------------------------------------------------------------------------------
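As a usage sketch for the SinkhornDistance module above (shapes and values are illustrative; a CUDA device is required because the module moves its tensors to the GPU):

import torch
from utlis import SinkhornDistance

# Two sets of 128 embeddings in R^256, e.g. one per modality.
x = torch.randn(128, 256)
y = torch.randn(128, 256)

sinkhorn = SinkhornDistance(eps=0.01, max_iter=100, reduction='mean')
cost = sinkhorn(x, y)  # entropy-regularised transport cost (scalar tensor)
print(cost.item())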