8 |
9 | Please read [our paper](https://arxiv.org/abs/2104.02862) for a more detailed description of the training procedure.
10 |
11 | ### BibTeX
12 | Please use the following BibTeX entry for citations:
13 | ```latex
14 | @inproceedings{VariationalDistillation,
15 | title={Farewell to Mutual Information: Variational Distillation for Cross-Modal Person Re-identification},
16 | author={Xudong Tian and Zhizhong Zhang and Shaohui Lin and Yanyun Qu and Yuan Xie and Lizhuang Ma},
17 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
18 | year={2021}
19 | }
20 | ```
21 |
22 | ## Get Started
23 | 1. `cd` to the folder where you want to download this project
24 |
25 | 2. Run `git clone https://github.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification.git`
26 |
27 | 3. Install dependencies:
28 | - python>=3.7.0
29 | - [pytorch>=1.3.0](https://pytorch.org/)
30 | - torchvision
31 |
32 | 4. Prepare datasets
33 | - Download RegDB
34 | - Download [SYSU-MM01](https://github.com/wuancong/SYSU-MM01)
35 |
36 | Create a directory for the required datasets, either inside or outside this project, and remember to set `--data-dir` to that path before training.
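
For reference, the sketch below checks the layout the test scripts expect. It is a minimal check under stated assumptions: `data_dir` is an example value (match your `--data-dir`), the RegDB `idx/test_*_{trial}.txt` lists ship with the dataset (see `/RegDB_test/data_manager.py`), and the SYSU-MM01 folder names follow `/extract_feature.py`:

```python
import os

data_dir = './data'  # example; point this at your --data-dir

# RegDB ships ten pre-defined trial splits per modality
for trial in range(1, 11):
    for modality in ('visible', 'thermal'):
        split = os.path.join(data_dir, 'RegDB', 'idx',
                             'test_{}_{}.txt'.format(modality, trial))
        assert os.path.isfile(split), 'missing RegDB split list: ' + split

# SYSU-MM01 is organized as cam1..cam6 image folders plus the exp/ id lists
for name in ['cam{}'.format(i) for i in range(1, 7)] + [os.path.join('exp', 'test_id.txt')]:
    assert os.path.exists(os.path.join(data_dir, 'SYSU-MM01', name)), 'missing ' + name
print('dataset layout looks complete')
```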
37 |
38 | ## Train
39 | This project provides code to train and evaluate different architectures on both datasets. You can run `/mm01.py` and `/regdb.py` directly with the default settings, or customize them for either dataset.
40 |
41 | ## Evaluation
42 | - MM01: To evaluate the model under the standard protocol, first run `/extract_feature.py` to extract and save features, then run `/evaluation/evaluate_SYSU_MM01.py` to conduct the standard evaluation (see the sketch below this list).
43 | - RegDB: You can directly run `/RegDB_test/RegDB_test.py` to obtain Visible-Thermal performance; swap the query and gallery split lists in the script to evaluate the other direction, i.e., Thermal-Visible.
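
Concretely, the two MM01 steps chain together as in the sketch below. This is a minimal sketch under stated assumptions: `/extract_feature.py` writes one `observationcam<k>.mat` file per camera (MATLAB key `feature_test`) into its `model_path`, `feature_dir` is a placeholder for that location, and `gen_utils.py` resolves `./data_split/*.mat` against the working directory, so run this where that folder is visible:

```python
from evaluation.evaluate_SYSU_MM01 import evaluate_results

# step 1: python extract_feature.py  ->  observationcam1.mat ... observationcam6.mat
# step 2: score them under the standard protocol (single-shot, all-search, 10 trials)
feature_dir = '/path/to/saved/features/'  # placeholder: wherever step 1 wrote the .mat files
evaluate_results(feature_dir, prefix='observation', mode='all', number_shot=1)
```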
44 |
45 | ## Results
46 | SYSU-MM01 (all-search mode)
47 | | Metric | Value |
48 | | --- | --- |
49 | | Rank-1 | 60.02% |
50 | | Rank-10 | 94.18% |
51 | | Rank-20 | 98.14% |
52 | | mAP | 58.80% |
53 |
54 | SYSU-MM01 (indoor-search mode)
55 | | Metric | Value |
56 | | --- | --- |
57 | | Rank-1 | 66.05% |
58 | | Rank-10 | 96.59% |
59 | | Rank-20 | 99.38% |
60 | | mAP | 72.98% |
61 |
62 | RegDB
63 | | Mode | Rank-1 (mAP) |
64 | | --- | --- |
65 | | Visible-Thermal | 73.2% (71.6%) |
66 | | Thermal-Visible | 71.8% (70.1%) |
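
The two RegDB rows differ only in which modality serves as query and which as gallery; in `/RegDB_test/RegDB_test.py` this is the commented v-to-t / t-to-v block inside the trial loop. A hypothetical helper, `regdb_split_lists`, wrapping that choice:

```python
def regdb_split_lists(data_dir, test_trial, direction='v2t'):
    """Return (query_list, gallery_list) paths, as built in /RegDB_test/RegDB_test.py."""
    visible = data_dir + 'idx/test_visible_{}.txt'.format(test_trial)
    thermal = data_dir + 'idx/test_thermal_{}.txt'.format(test_trial)
    # Visible-Thermal: visible images query a thermal gallery; reversed for Thermal-Visible
    return (visible, thermal) if direction == 'v2t' else (thermal, visible)
```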
67 |
68 | ## Visualization
69 | 2-D projection of the embedding space obtained using t-SNE. The results come from our method and the conventional information bottleneck on the SYSU-MM01 dataset; different colors denote different person IDs.
70 |
71 |
72 |
73 |
74 |
75 | In addition, we plot the joint embedding space of data from different modalities for better visualization. More descriptions and details can be found in [our paper](https://arxiv.org/abs/2104.02862).
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/RegDB_test/RegDB_test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from torch import nn
3 | import argparse
4 | import torch.backends.cudnn as cudnn
5 | import torchvision.transforms as transforms
6 | from reid.utils.serialization import load_checkpoint
7 | from reid.models import ft_net
8 | from eval_metrics import eval_regdb
9 | from utils import *  # star import also provides os, np and torch used later in this script
10 | from PIL import Image
11 | from reid.utils.data import transforms as T
12 |
13 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
14 | parser.add_argument('--dataset', default='regdb', help='dataset name: regdb or sysu')
15 |
16 | parser.add_argument('--features', type=int, default=2048, help='feature dimensions')
17 | parser.add_argument('--z_dim', type=int, default=256, help='information bottleneck dimensions')
18 |
19 | parser.add_argument('--height', type=int, default=256, help='image height, should be chosen in {256, 288, 312}')
20 | parser.add_argument('--width', type=int, default=128, help='image width, should be chosen in {128, 144, 156}')
21 |
22 | parser.add_argument('--model_path', default='/home/txd/Variational Distillation/Exp_RegDB/RegDB_2/', type=str)
23 | args = parser.parse_args()
24 |
25 | # overall settings
26 | data_path = '../data/RegDB/'
27 | n_class = 206
28 | test_mode = [2, 1]
29 |
30 | # model settings
31 | overall_feats = 8192  # four 2048-d observation features are concatenated per image below
32 | os.environ["CUDA_VISIBLE_DEVICES"] = "2, 1"
33 | model = ft_net(args=args, num_classes=n_class, num_features=args.features)
34 | model = nn.DataParallel(model, device_ids=[0, 1])
35 | cudnn.benchmark = True
36 |
37 | #checkpoint_path = args.model_path
38 |
39 | print('==> Loading data..')
40 |
41 | # Data loading code
42 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
43 |
44 | test_transformer = T.Compose([
45 | #T.Resize((312, 156)),
46 | T.Resize((args.height, args.width)),
47 | T.ToTensor(),
48 | normalizer,
49 | ])
50 |
51 | if args.dataset == 'regdb':
52 | data_dir = '../data/RegDB/'
53 |
54 | for trial in range(1, 9):  # checkpoint folders 1-8; model {trial} is tested on split {trial + 1}
55 |     test_trial = trial + 1
56 | model_path = args.model_path + '{}/model_best.pth.tar'.format(trial)
57 |
58 | #########################################
59 | checkpoint = load_checkpoint(model_path)
60 | state_dict = checkpoint['model']
61 | from collections import OrderedDict
62 | new_state_dict = OrderedDict()
63 | for k, v in state_dict.items():
64 | if 'module' not in k or 'cam_module' in k:
65 | k = 'module.' + k
66 | else:
67 | k = k.replace('features.module.', 'module.features.')
68 | new_state_dict[k] = v
69 | model.load_state_dict(new_state_dict)
70 | #########################################
71 | model = model.cuda()
72 | model.eval()
73 |
74 | # v to t
75 | query_list = data_dir + 'idx/test_visible_{}'.format(test_trial)+ '.txt'
76 | gallery_list = data_dir + 'idx/test_thermal_{}'.format(test_trial)+ '.txt'
77 | # t to v
78 | #query_list = data_dir + 'idx/test_thermal_{}'.format(test_trial)+ '.txt'
79 | #gallery_list = data_dir + 'idx/test_visible_{}'.format(test_trial)+ '.txt'
80 |
81 | query_image_list, query_label = load_data(query_list)
82 |     gallery_image_list, gallery_label = load_data(gallery_list)
83 | temp_list = []
84 | for index, pth in enumerate(query_image_list):
85 | filename = data_dir + pth
86 | img = Image.open(filename)
87 | img = test_transformer(img)
88 | img = img.view([1, 3, args.height, args.width])
89 | img = img.cuda()
90 |
91 | i_observation, i_representation, i_ms_observation, i_ms_representation, \
92 | v_observation, v_representation, v_ms_observation, v_ms_representation = model(img)
93 | result_y = torch.cat(tensors=[i_observation[1], i_ms_observation[1],
94 | v_observation[1], v_ms_observation[1]], dim=1)
95 |
96 | result_y = torch.nn.functional.normalize(result_y, dim=1, p=2)
97 | result_y = result_y.view(-1, overall_feats)
98 | result_y = result_y.squeeze()
99 | result_npy = result_y.data.cpu().numpy()
100 | result_npy = result_npy.astype('double')
101 | temp_list.append(result_npy)
102 |
103 | query_feat, query_label = np.array(temp_list), np.array(query_label)
104 | print('Query feature extraction: done')
105 |
106 | temp_list = []
107 | for index, pth in enumerate(gallery_image_list):
108 | filename = data_dir + pth
109 |         img = Image.open(filename)
110 |         img = test_transformer(img)
111 |         img = img.view([1, 3, args.height, args.width])
112 |         img = img.cuda()
113 |
114 | i_observation, i_representation, i_ms_observation, i_ms_representation, \
115 | v_observation, v_representation, v_ms_observation, v_ms_representation = model(img)
116 | result_y = torch.cat(tensors=[i_observation[1], i_ms_observation[1],
117 | v_observation[1], v_ms_observation[1]], dim=1)
118 |
119 | result_y = torch.nn.functional.normalize(result_y, dim=1, p=2)
120 | result_y = result_y.view(-1, overall_feats)
121 | result_y = result_y.squeeze()
122 |         result_npy = result_y.data.cpu().numpy()
123 |         result_npy = result_npy.astype('double')
124 | temp_list.append(result_npy)
125 |
126 | gallery_feat, gallery_label = np.array(temp_list), np.array(gallery_label)
127 |
128 | print('Gallery feature extraction: done')
129 |     distmat = np.matmul(query_feat, np.transpose(gallery_feat))  # cosine similarity: both sets are L2-normalized
130 |     cmc, mAP, mINP = eval_regdb(-distmat, query_label, gallery_label)  # negate similarity to obtain a distance
131 |
132 |     if trial == 1:  # first iteration (the loop starts at 1, so a `trial == 0` check would never run)
133 |         all_cmc, all_mAP, all_mINP = cmc, mAP, mINP
134 |     else:
135 |         all_cmc, all_mAP, all_mINP = all_cmc + cmc, all_mAP + mAP, all_mINP + mINP
136 |
137 | print(
138 | 'Results: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
139 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
140 | print('Average over 8 trials: Rank-1: {:.2%} | mAP: {:.2%} | mINP: {:.2%}'.format(all_cmc[0] / 8, all_mAP / 8, all_mINP / 8))
141 | 
142 | def load_data(input_data_path):
143 |     with open(input_data_path) as f:
144 |         data_file_list = f.read().splitlines()
145 | file_image = [s.split(' ')[0] for s in data_file_list]
146 | file_label = [int(s.split(' ')[1]) for s in data_file_list]
147 |
148 | return file_image, file_label
--------------------------------------------------------------------------------
/RegDB_test/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from . import utils
4 | from . import data_manager
5 | from . import eval_metrics
6 | from . import RegDB_test
7 | from .eval_metrics import eval_regdb
8 | from .utils import *
9 |
10 | __version__ = '0.2.0'
--------------------------------------------------------------------------------
/RegDB_test/data_manager.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | import numpy as np
4 | import random
5 |
6 | def process_query_sysu(data_path, mode = 'all', relabel=False):
7 |     if mode == 'all':
8 |         ir_cameras = ['cam3','cam6']
9 |     elif mode == 'indoor':
10 |         ir_cameras = ['cam3','cam6']  # IR probe cameras are the same in both search modes
11 |
12 | file_path = os.path.join(data_path,'exp/test_id.txt')
13 | files_rgb = []
14 | files_ir = []
15 |
16 | with open(file_path, 'r') as file:
17 | ids = file.read().splitlines()
18 | ids = [int(y) for y in ids[0].split(',')]
19 | ids = ["%04d" % x for x in ids]
20 |
21 | for id in sorted(ids):
22 | for cam in ir_cameras:
23 | img_dir = os.path.join(data_path,cam,id)
24 | if os.path.isdir(img_dir):
25 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
26 | files_ir.extend(new_files)
27 | query_img = []
28 | query_id = []
29 | query_cam = []
30 | for img_path in files_ir:
31 | camid, pid = int(img_path[-15]), int(img_path[-13:-9])
32 | query_img.append(img_path)
33 | query_id.append(pid)
34 | query_cam.append(camid)
35 | return query_img, np.array(query_id), np.array(query_cam)
36 |
37 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False):
38 |
39 | random.seed(trial)
40 |
41 |     if mode == 'all':
42 |         rgb_cameras = ['cam1','cam2','cam4','cam5']
43 |     elif mode == 'indoor':
44 |         rgb_cameras = ['cam1','cam2']
45 |
46 | file_path = os.path.join(data_path,'exp/test_id.txt')
47 | files_rgb = []
48 | with open(file_path, 'r') as file:
49 | ids = file.read().splitlines()
50 | ids = [int(y) for y in ids[0].split(',')]
51 | ids = ["%04d" % x for x in ids]
52 |
53 | for id in sorted(ids):
54 | for cam in rgb_cameras:
55 | img_dir = os.path.join(data_path,cam,id)
56 | if os.path.isdir(img_dir):
57 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
58 | files_rgb.append(random.choice(new_files))
59 | gall_img = []
60 | gall_id = []
61 | gall_cam = []
62 | for img_path in files_rgb:
63 | camid, pid = int(img_path[-15]), int(img_path[-13:-9])
64 | gall_img.append(img_path)
65 | gall_id.append(pid)
66 | gall_cam.append(camid)
67 | return gall_img, np.array(gall_id), np.array(gall_cam)
68 |
69 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'):
70 | if modal=='visible':
71 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt'
72 | elif modal=='thermal':
73 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt'
74 |
75 | with open(input_data_path) as f:
76 |         data_file_list = f.read().splitlines()
77 | # Get full list of image and labels
78 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list]
79 | file_label = [int(s.split(' ')[1]) for s in data_file_list]
80 |
81 | return file_image, np.array(file_label)
--------------------------------------------------------------------------------
/RegDB_test/eval_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import numpy as np
3 | """Cross-Modality ReID"""
4 | import pdb
5 |
6 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20):
7 | """Evaluation with sysu metric
8 |     Key: for each query identity, its gallery images from the same camera view are discarded, following the original setting of the dataset.
9 | """
10 | num_q, num_g = distmat.shape
11 | if num_g < max_rank:
12 | max_rank = num_g
13 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
14 | indices = np.argsort(distmat, axis=1)
15 | pred_label = g_pids[indices]
16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
17 |
18 | # compute cmc curve for each query
19 | new_all_cmc = []
20 | all_cmc = []
21 | all_AP = []
22 | all_INP = []
23 | num_valid_q = 0. # number of valid query
24 | for q_idx in range(num_q):
25 | # get query pid and camid
26 | q_pid = q_pids[q_idx]
27 | q_camid = q_camids[q_idx]
28 |
29 | # remove gallery samples that have the same pid and camid with query
30 | order = indices[q_idx]
31 | remove = (q_camid == 3) & (g_camids[order] == 2)
32 | keep = np.invert(remove)
33 |
34 | # compute cmc curve
35 | # the cmc calculation is different from standard protocol
36 | # we follow the protocol of the author's released code
37 | new_cmc = pred_label[q_idx][keep]
38 | new_index = np.unique(new_cmc, return_index=True)[1]
39 | new_cmc = [new_cmc[index] for index in sorted(new_index)]
40 |
41 | new_match = (new_cmc == q_pid).astype(np.int32)
42 | new_cmc = new_match.cumsum()
43 | new_all_cmc.append(new_cmc[:max_rank])
44 |
45 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
46 | if not np.any(orig_cmc):
47 | # this condition is true when query identity does not appear in gallery
48 | continue
49 |
50 | cmc = orig_cmc.cumsum()
51 |
52 | # compute mINP
53 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
54 | pos_idx = np.where(orig_cmc == 1)
55 | pos_max_idx = np.max(pos_idx)
56 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
57 | all_INP.append(inp)
58 |
59 | cmc[cmc > 1] = 1
60 |
61 | all_cmc.append(cmc[:max_rank])
62 | num_valid_q += 1.
63 |
64 | # compute average precision
65 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
66 | num_rel = orig_cmc.sum()
67 | tmp_cmc = orig_cmc.cumsum()
68 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
69 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
70 | AP = tmp_cmc.sum() / num_rel
71 | all_AP.append(AP)
72 |
73 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
74 |
75 | all_cmc = np.asarray(all_cmc).astype(np.float32)
76 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC
77 |
78 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
79 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q
80 | mAP = np.mean(all_AP)
81 | mINP = np.mean(all_INP)
82 | return new_all_cmc, mAP, mINP
83 |
84 |
85 |
86 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20):  # CMC / mAP / mINP for the RegDB protocol
87 | num_q, num_g = distmat.shape
88 | if num_g < max_rank:
89 | max_rank = num_g
90 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
91 | indices = np.argsort(distmat, axis=1)
92 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
93 |
94 | # compute cmc curve for each query
95 | all_cmc = []
96 | all_AP = []
97 | all_INP = []
98 | num_valid_q = 0. # number of valid query
99 |
100 | # only two cameras
101 | q_camids = np.ones(num_q).astype(np.int32)
102 | g_camids = 2* np.ones(num_g).astype(np.int32)
103 |
104 | for q_idx in range(num_q):
105 | # get query pid and camid
106 | q_pid = q_pids[q_idx]
107 | q_camid = q_camids[q_idx]
108 |
109 | # remove gallery samples that have the same pid and camid with query
110 | order = indices[q_idx]
111 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
112 | keep = np.invert(remove)
113 |
114 | # compute cmc curve
115 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
116 | if not np.any(raw_cmc):
117 | # this condition is true when query identity does not appear in gallery
118 | continue
119 |
120 | cmc = raw_cmc.cumsum()
121 |
122 | # compute mINP
123 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
124 | pos_idx = np.where(raw_cmc == 1)
125 | pos_max_idx = np.max(pos_idx)
126 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
127 | all_INP.append(inp)
128 |
129 | cmc[cmc > 1] = 1
130 |
131 | all_cmc.append(cmc[:max_rank])
132 | num_valid_q += 1.
133 |
134 | # compute average precision
135 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
136 | num_rel = raw_cmc.sum()
137 | tmp_cmc = raw_cmc.cumsum()
138 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
139 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
140 | AP = tmp_cmc.sum() / num_rel
141 | all_AP.append(AP)
142 |
143 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
144 |
145 | all_cmc = np.asarray(all_cmc).astype(np.float32)
146 | all_cmc = all_cmc.sum(0) / num_valid_q
147 | mAP = np.mean(all_AP)
148 | mINP = np.mean(all_INP)
149 | return all_cmc, mAP, mINP
--------------------------------------------------------------------------------
/RegDB_test/utils.py:
--------------------------------------------------------------------------------
1 | import os, errno  # errno is needed by mkdir_if_missing below
2 | import numpy as np
3 | from torch.utils.data.sampler import Sampler
4 | import sys
5 | import os.path as osp
6 | import torch
7 |
8 | def load_data(input_data_path):
9 |     with open(input_data_path) as f:
10 |         data_file_list = f.read().splitlines()
11 | # Get full list of color image and labels
12 | file_image = [s.split(' ')[0] for s in data_file_list]
13 | file_label = [int(s.split(' ')[1]) for s in data_file_list]
14 |
15 | return file_image, file_label
16 |
17 |
18 | def GenIdx( train_color_label, train_thermal_label):
19 | color_pos = []
20 | unique_label_color = np.unique(train_color_label)
21 | for i in range(len(unique_label_color)):
22 | tmp_pos = [k for k,v in enumerate(train_color_label) if v==unique_label_color[i]]
23 | color_pos.append(tmp_pos)
24 |
25 | thermal_pos = []
26 | unique_label_thermal = np.unique(train_thermal_label)
27 | for i in range(len(unique_label_thermal)):
28 | tmp_pos = [k for k,v in enumerate(train_thermal_label) if v==unique_label_thermal[i]]
29 | thermal_pos.append(tmp_pos)
30 | return color_pos, thermal_pos
31 |
32 | def GenCamIdx(gall_img, gall_label, mode):
33 | if mode =='indoor':
34 | camIdx = [1,2]
35 | else:
36 | camIdx = [1,2,4,5]
37 | gall_cam = []
38 | for i in range(len(gall_img)):
39 | gall_cam.append(int(gall_img[i][-10]))
40 |
41 | sample_pos = []
42 | unique_label = np.unique(gall_label)
43 | for i in range(len(unique_label)):
44 | for j in range(len(camIdx)):
45 | id_pos = [k for k,v in enumerate(gall_label) if v==unique_label[i] and gall_cam[k]==camIdx[j]]
46 | if id_pos:
47 | sample_pos.append(id_pos)
48 | return sample_pos
49 |
50 | def ExtractCam(gall_img):
51 | gall_cam = []
52 | for i in range(len(gall_img)):
53 | cam_id = int(gall_img[i][-10])
54 | # if cam_id ==3:
55 | # cam_id = 2
56 | gall_cam.append(cam_id)
57 |
58 | return np.array(gall_cam)
59 |
60 | class IdentitySampler(Sampler):
61 | """Sample person identities evenly in each batch.
62 | Args:
63 | train_color_label, train_thermal_label: labels of two modalities
64 | color_pos, thermal_pos: positions of each identity
65 | batchSize: batch size
66 | """
67 |
68 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch):
69 | uni_label = np.unique(train_color_label)
70 | self.n_classes = len(uni_label)
71 |
72 |
73 | N = np.maximum(len(train_color_label), len(train_thermal_label))
74 | for j in range(int(N/(batchSize*num_pos))+1):
75 | batch_idx = np.random.choice(uni_label, batchSize, replace = False)
76 | for i in range(batchSize):
77 | sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos)
78 | sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos)
79 |
80 | if j ==0 and i==0:
81 | index1= sample_color
82 | index2= sample_thermal
83 | else:
84 | index1 = np.hstack((index1, sample_color))
85 | index2 = np.hstack((index2, sample_thermal))
86 |
87 | self.index1 = index1
88 | self.index2 = index2
89 | self.N = N
90 |
91 | def __iter__(self):
92 | return iter(np.arange(len(self.index1)))
93 |
94 | def __len__(self):
95 | return self.N
96 |
97 | class AverageMeter(object):
98 | """Computes and stores the average and current value"""
99 | def __init__(self):
100 | self.reset()
101 |
102 | def reset(self):
103 | self.val = 0
104 | self.avg = 0
105 | self.sum = 0
106 | self.count = 0
107 |
108 | def update(self, val, n=1):
109 | self.val = val
110 | self.sum += val * n
111 | self.count += n
112 | self.avg = self.sum / self.count
113 |
114 | def mkdir_if_missing(directory):
115 | if not osp.exists(directory):
116 | try:
117 | os.makedirs(directory)
118 | except OSError as e:
119 | if e.errno != errno.EEXIST:
120 | raise
121 | class Logger(object):
122 | """
123 | Write console output to external text file.
124 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
125 | """
126 | def __init__(self, fpath=None):
127 | self.console = sys.stdout
128 | self.file = None
129 | if fpath is not None:
130 | mkdir_if_missing(osp.dirname(fpath))
131 | self.file = open(fpath, 'w')
132 |
133 | def __del__(self):
134 | self.close()
135 |
136 |     def __enter__(self):
137 |         return self
138 |
139 | def __exit__(self, *args):
140 | self.close()
141 |
142 | def write(self, msg):
143 | self.console.write(msg)
144 | if self.file is not None:
145 | self.file.write(msg)
146 |
147 | def flush(self):
148 | self.console.flush()
149 | if self.file is not None:
150 | self.file.flush()
151 | os.fsync(self.file.fileno())
152 |
153 | def close(self):
154 | self.console.close()
155 | if self.file is not None:
156 | self.file.close()
157 |
158 | def set_seed(seed, cuda=True):
159 | np.random.seed(seed)
160 | torch.manual_seed(seed)
161 | if cuda:
162 | torch.cuda.manual_seed(seed)
163 |
164 | def set_requires_grad(nets, requires_grad=False):
165 |     """Set requires_grad=False for all the networks to avoid unnecessary computations
166 | Parameters:
167 | nets (network list) -- a list of networks
168 | requires_grad (bool) -- whether the networks require gradients or not
169 | """
170 | if not isinstance(nets, list):
171 | nets = [nets]
172 | for net in nets:
173 | if net is not None:
174 | for param in net.parameters():
175 | param.requires_grad = requires_grad
--------------------------------------------------------------------------------
/__pycache__/utlis.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/__pycache__/utlis.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # Testing demo for SYSU-MM01 dataset
2 | A MATLAB testing demo is provided for evaluation on the SYSU-MM01 dataset from "RGB-Infrared Cross-Modality Person Re-identification, ICCV 2017".
3 |
4 | Dataset download links:
5 |
6 | Baiduyun: http://pan.baidu.com/s/1gfIlcmZ
7 |
8 | Dropbox: https://www.dropbox.com/sh/v036mg1q4yg7awb/AABhxU-FJ4X2oyq7-Ts6bgD0a?dl=0
9 |
10 | Project page: http://isee.sysu.edu.cn/project/RGBIRReID.htm
11 |
12 | Testing code:
13 | https://github.com/wuancong/SYSU-MM01/blob/master/evaluation
14 |
15 | ## Citation
16 | If you use the dataset, please cite the following paper:
17 |
18 | Ancong Wu, Wei-Shi Zheng, Hong-Xing Yu, Shaogang Gong and Jianhuang Lai. RGB-Infrared Cross-Modality Person Re-Identification. IEEE International Conference on Computer Vision (ICCV), 2017.
19 |
20 | ## Testing procedure
21 |
22 | 1. Train a model using the samples of the IDs in "./data_split/train_id.mat" on SYSU-MM01.
23 |
24 | 2. Extract features of SYSU-MM01 dataset using data in "SYSU_MM01.zip" (can be found in the download links).
25 | To apply our provided testing code, the features should be saved in the following form:
26 | Features of each camera are saved in separate mat files named "name_cam#.mat".
27 | In each mat file, feature{id}(i,:) is a row feature vector of the i-th image of id.
28 | An example of our proposed deep zero padding features is provided in "./feature"; a Python sketch of this format follows the list below.
29 |
30 | 3. Run "demo.m". The default setting in "demo.m" is single-shot all-search mode. The input parameters can be set according to the comments in "demo.m". A fixed data split of the testing set with 10 trials is provided in "./data_split". The average CMC and mAP results of the 10 trials of random splits will be displayed when testing finishes. The result of this demo comes from a re-implemented version and is slightly better than the one reported in the paper.
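
For reference, the same feature layout can be produced from Python with scipy. This is a minimal sketch with dummy data; the object-array-of-matrices encoding is an assumption inferred from how `/evaluation/gen_utils.py` reads `feature_test`:

```python
import numpy as np
import scipy.io as sio

# feature_test behaves like a MATLAB cell array: entry id holds an (n_images, dim)
# matrix, so feature_test[id][i, :] is the row feature of the i-th image of that id.
features_per_id = [np.random.rand(5, 2048) for _ in range(333)]  # dummy features, 333 ids
feature_test = np.empty(len(features_per_id), dtype=object)
for person_index, feats in enumerate(features_per_id):
    feature_test[person_index] = feats
sio.savemat('name_cam1.mat', {'feature_test': feature_test})
```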
31 |
32 | ## Contact Information
33 | If you have any questions, please feel free to contact wuancong@mail2.sysu.edu.cn.
--------------------------------------------------------------------------------
/evaluation/__pycache__/gen_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/__pycache__/gen_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/data_split/rand_perm_cam.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/data_split/rand_perm_cam.mat
--------------------------------------------------------------------------------
/evaluation/data_split/test_id.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/data_split/test_id.mat
--------------------------------------------------------------------------------
/evaluation/data_split/train_id.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/data_split/train_id.mat
--------------------------------------------------------------------------------
/evaluation/demo.m:
--------------------------------------------------------------------------------
1 | % A demo for evaluation on SYSU-MM01 dataset using features learned by deep
2 | % zero padding in "RGB-Infrared Cross-Modality Person Re-identification"
3 |
4 | % Features of each camera are saved in separate mat files named "name_cam#.mat"
5 | % In the mat files, feature{id}(i,:) is a row feature vector of the i-th image of id
6 | %feature_info.name = 'feat_deep_zero_padding_';
7 | %feature_info.dir = './feature';
8 | feature_info.dir = '../';
9 | feature_info.name = '';
10 | result_dir = './result'; % directory for saving result
11 |
12 | setting.mode = 'all_search'; %'all_search' 'indoor_search'
13 | %setting.mode = 'indoor_search';
14 | setting.number_shot = 1; % 1 for single shot, 10 for multi-shot
15 | %setting.number_shot = 10;
16 | model.test_fun = @euclidean_dist; % Similarity measurement function
17 | % (You could define your own function in the same way as euclidean_dist function)
18 | model.name = 'euclidean'; % model name
19 | model.para = []; % No parameter is needed for euclidean distance here
20 | % (If mahalanobis distance is used, the parameter can be the metric M learned from training data)
21 |
22 | % load data split
23 | content = load('./data_split/test_id.mat'); % fixed testing person IDs
24 | data_split.test_id = content.id;
25 |
26 | content = load('./data_split/rand_perm_cam.mat'); % fixed permutation of samples in each camera
27 | data_split.rand_perm_cam = content.rand_perm_cam;
28 |
29 | % evaluation
30 | disp('all-search_single_shot')
31 | performance = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir);
32 | disp('all-search_multi_shot')
33 | setting.mode = 'all_search'; %'all_search' 'indoor_search'
34 | %setting.mode = 'indoor_search';
35 | setting.number_shot = 10; % 1 for single shot, 10 for multi-shot
36 | performance = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir);
37 |
38 | disp('indoor_search_single_shot')
39 | setting.mode = 'indoor_search'; %'all_search' 'indoor_search'
40 | %setting.mode = 'indoor_search';
41 | setting.number_shot = 1; % 1 for single shot, 10 for multi-shot
42 | performance = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir);
43 | disp('indoor_search_multi_shot')
44 | setting.mode = 'indoor_search'; %'all_search' 'indoor_search'
45 | %setting.mode = 'indoor_search';
46 | setting.number_shot = 10; % 1 for single shot, 10 for multi-shot
47 | performance = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir);
--------------------------------------------------------------------------------
/evaluation/euclidean_dist.m:
--------------------------------------------------------------------------------
1 | function [ dist ] = euclidean_dist( X_gallery, X_probe, model_para )
2 | dist=pdist2(X_probe,X_gallery,'euclidean');
3 | end
4 |
5 |
--------------------------------------------------------------------------------
/evaluation/evaluate_SYSU_MM01.py:
--------------------------------------------------------------------------------
1 | import os, sys, os.path as osp
2 | from evaluation import gen_utils
3 | import numpy as np
4 |
5 |
6 | def evaluate(feature_dir, prefix, settings, total_runs=10):
7 | """to evaluate and rank results for SYSU_MM01 dataset
8 |
9 | Arguments:
10 | feature_dir {str} -- a dir where features are saved
11 | prefix {str} -- prefix of file names
12 | """
13 | gallery_cams, probe_cams = gen_utils.get_cam_settings(settings)
14 | all_cams = list(set(gallery_cams + probe_cams)) # get unique cams
15 | features = {}
16 |
17 | # get permutation indices
18 | cam_permutations = gen_utils.get_cam_permutation_indices(all_cams)
19 |
20 | # get test ids
21 | test_ids = gen_utils.get_test_ids()
22 |
23 | for cam_index in all_cams:
24 | # read features
25 |         cam_feature_file = osp.join(feature_dir, "{}cam{}".format(prefix, cam_index))  # "<prefix>cam<k>.mat", mirroring the MATLAB naming
26 | features["cam" + str(cam_index)] = gen_utils.load_feature_file(cam_feature_file)
27 |
28 | # perform testing
29 | print(list(features.keys()))
30 | cam_id_locations = [1, 2, 2, 4, 5, 6]
31 | # camera 2 and 3 are in the same location
32 | mAPs = []
33 | cmcs = []
34 |
35 | for run_index in range(total_runs):
36 | print("trial #{}".format(run_index))
37 | X_gallery, Y_gallery, cam_gallery, X_probe, Y_probe, cam_probe = gen_utils.get_testing_set(
38 | features,
39 | cam_permutations,
40 | test_ids,
41 | run_index,
42 | cam_id_locations,
43 | gallery_cams,
44 | probe_cams,
45 | settings,
46 | )
47 |
48 | # print(X_gallery.shape, Y_gallery.shape, cam_gallery.shape, X_probe.shape, Y_probe.shape, cam_probe.shape)
49 |
50 | dist = gen_utils.euclidean_dist(X_probe, X_gallery)
51 |
52 | cmc = gen_utils.get_cmc_multi_cam(
53 | Y_gallery, cam_gallery, Y_probe, cam_probe, dist
54 | )
55 | mAP = gen_utils.get_mAP_multi_cam(
56 | Y_gallery, cam_gallery, Y_probe, cam_probe, dist
57 | )
58 |
59 | print("rank 1 5 10 20", cmc[[0, 4, 9, 19]])
60 | print("mAP", mAP)
61 | cmcs.append(cmc)
62 | mAPs.append(mAP)
63 |
64 | # find mean mAP and cmc
65 | cmcs = np.array(cmcs) # 10 x #gallery
66 | mAPs = np.array(mAPs) # 10
67 | mean_cmc = np.mean(cmcs, axis=0)
68 | mean_mAP = np.mean(mAPs)
69 | print("mean rank 1 5 10 20", mean_cmc[[0, 4, 9, 19]])
70 | print("mean mAP", mean_mAP)
71 |
72 |
73 | def evaluate_results(feature_dir, prefix, mode, number_shot, total_runs=10):
74 | # evaluation settings
75 | print(
76 | "running evaluation for features from {} with prefix {} with mode {} and number shot {}".format(
77 | feature_dir, prefix, mode, number_shot
78 | )
79 | )
80 | print('total test runs:', total_runs)
81 | settings = {}
82 | settings["mode"] = mode # indoor | all
83 | settings["number_shot"] = number_shot # 1 = single-shot | 10 = multi-shot
84 | evaluate(feature_dir, prefix, settings, total_runs)
85 |
86 |
87 | if __name__ == "__main__":
88 | # example function call
89 | feature_dir = "/home/txd/Variational Distillation/"
90 |     prefix = "observation"  # matches the "observation" + cam name files written by extract_feature.py
91 |
92 | evaluate_results(feature_dir, prefix, 'all', 1)
--------------------------------------------------------------------------------
/evaluation/evaluation_SYSU_MM01.m:
--------------------------------------------------------------------------------
1 | function [performance] = evaluation_SYSU_MM01(feature_info, data_split, model, setting, result_dir)
2 | % evaluation on SYSU-MM01 dataset with input features and model
3 | % input:
4 | % Features of each camera are saved in separate mat files named "name_cam#.mat"
5 | % In the mat files, feature{id}(i,:) is a row feature vector of the i-th image of id
6 | % feature_info.name = 'feat_deep_zero_padding';
7 | % feature_info.dir = './feature';
8 | % result_dir = './result'; % directory for saving result
9 | %
10 | % setting.mode = 'all_search'; %'all_search' 'indoor_search'
11 | % setting.number_shot = 1; % 1 for single shot, 10 for multi-shot
12 | %
13 | % model.test_fun = @euclidean_dist; % Similarity measurement function (You could define your own function in the same way as euclidean_dist function)
14 | % model.name = 'euclidean'; % model name
15 | % model.para = []; % No parameter is needed for euclidean distance here (If mahalanobis distance is used, the parameter can be the metric M learned from training data)
16 | %
17 | % content = load('./data_split/test_id.mat'); % fixed testing person IDs
18 | % data_split.test_id = content.id;
19 | %
20 | % content = load('./data_split/rand_perm_cam.mat'); % fixed permutation of samples in each camera
21 | % data_split.rand_perm_cam = content.rand_perm_cam;
22 | %
23 | % output:
24 | % performance.cmc_mean & performance.map_mean - average results of 10 trials
25 | %   performance.cmc_all & performance.map_all - results of each trial
26 |
27 | feature_name = feature_info.name;
28 | feature_dir = feature_info.dir;
29 |
30 | mode = setting.mode; % 'all_search' 'indoor_search'
31 | number_shot = setting.number_shot; % 1 for single shot, 10 for multi-shot
32 | test_id = data_split.test_id;
33 | rand_perm_cam = data_split.rand_perm_cam;
34 |
35 | %% begin
36 | switch mode
37 | case 'all_search'
38 | gallery_cam_list=[1 2 4 5];
39 | probe_cam_list=[3 6];
40 | case 'indoor_search'
41 | gallery_cam_list=[1 2];
42 | probe_cam_list=[3 6];
43 | otherwise
44 | disp('mode input error');
45 | end
46 |
47 | % load features of 6 cameras
48 | load_cam_list = union(probe_cam_list,gallery_cam_list);
49 | cam_count = 6;
50 | feature_cam=cell(cam_count,1);
51 | Y=cell(cam_count,1);
52 |
53 | cam_id = [1 2 2 4 5 6]; % camera 2 and 3 are in the same location
54 |
55 | for i_cam=1:length(load_cam_list)
56 | cam_label=load_cam_list(i_cam);
57 | load_name=[feature_name 'cam' num2str(cam_label) '.mat'];
58 | content=load(fullfile(feature_dir,load_name));
59 | feature_cam{cam_label}=content.feature_test;
60 | feature_cam{cam_label}=feature_cam{cam_label}';
61 | Y{cam_label}=(1:length(content.feature_test));
62 | end
63 | clear content
64 |
65 | % begin testing
66 | cmc_all=cell(10,1);
67 | map_all=zeros(10,1);
68 |
69 | for run_time=1:10
70 | %disp(['trial #',num2str(run_time)]);
71 | % For X_..., each row is an observation
72 | [X_gallery, Y_gallery, Y_cam_gallery, X_probe, Y_probe, Y_cam_probe]=get_testing_set...
73 | (feature_cam, Y, rand_perm_cam, run_time, number_shot, gallery_cam_list, probe_cam_list, test_id, cam_id);
74 | dist = model.test_fun(X_gallery,X_probe,model.para);
75 | cmc = get_cmc_multi_cam(Y_gallery,Y_cam_gallery,Y_probe,Y_cam_probe,dist);
76 | map = get_map_multi_cam(Y_gallery,Y_cam_gallery,Y_probe,Y_cam_probe,dist);
77 | %disp('rank 1 5 10 20');
78 | %disp(cmc([1 5 10 20]));
79 | %disp('mAP');
80 | %disp(map);
81 | cmc_all{run_time}=cmc;
82 | map_all(run_time,:)=map;
83 | end
84 | cmc_all=cell2mat(cmc_all);
85 | performance.cmc_all=cmc_all;
86 | performance.map_all=map_all;
87 | cmc_mean=mean(performance.cmc_all);
88 | performance.cmc_mean=cmc_mean;
89 | map_mean=mean(performance.map_all);
90 | performance.map_mean=map_mean;
91 |
92 | % display
93 | disp('Average CMC:');
94 | disp('rank 1 5 10 20');
95 | disp(cmc_mean([ 1 5 10 20]));
96 | disp('Average mAP:');
97 | disp(map_mean);
98 |
99 | % save
100 | save_path=fullfile(result_dir,['result_' feature_name '_' model.name '_' mode '_' num2str(number_shot) 'shot.mat']);
101 | save(save_path,'performance','setting','-v7.3');
102 |
103 | end
--------------------------------------------------------------------------------
/evaluation/gen_utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | from scipy.spatial.distance import cdist
4 | import scipy.io as sio
5 |
6 |
7 | def euclidean_dist(X_probe, X_gallery):
8 | return cdist(X_probe, X_gallery)
9 |
10 |
11 | def get_cam_settings(settings):
12 | if settings["mode"] == "indoor":
13 | gallery_cams, probe_cams = [1, 2], [3, 6]
14 | elif settings["mode"] == "all":
15 | gallery_cams, probe_cams = [1, 2, 4, 5], [3, 6]
16 | else:
17 | assert False, "unknown search mode : " + settings["mode"]
18 |
19 | return gallery_cams, probe_cams
20 |
21 |
22 | def get_test_ids():
23 | test_id_filepath = "./data_split/test_id.mat"
24 | filecontents = sio.loadmat(test_id_filepath)
25 | test_ids = filecontents["id"].squeeze()
26 |
27 | # make the test ids 0-based
28 | test_ids = test_ids - 1
29 | return test_ids
30 |
31 |
32 | def get_cam_permutation_indices(all_cams):
33 | rand_cam_perm_filepath = "./data_split/rand_perm_cam.mat"
34 | filecontents = sio.loadmat(rand_cam_perm_filepath)
35 | mat_cam_permutations = filecontents["rand_perm_cam"]
36 |
37 | # buffer to hold all permutations
38 | all_permutations = {}
39 |
40 | for cam_index in all_cams:
41 | cam_permutations = mat_cam_permutations[cam_index - 1][0].squeeze()
42 | cam_name = "cam" + str(cam_index)
43 | if cam_name not in all_permutations:
44 | all_permutations[cam_name] = {}
45 |
46 | # logistics
47 | print("{} ids found in cam {}".format(len(cam_permutations), cam_index))
48 |
49 | # collect permutations for all person indices
50 | for person_index, rand_permutations in enumerate(cam_permutations):
51 | all_permutations[cam_name][person_index] = rand_permutations - 1
52 |
53 | return all_permutations
54 |
55 |
56 | def load_feature_file(cam_feature_file):
57 | filecontents = sio.loadmat(cam_feature_file)
58 | mat_features = filecontents["feature_test"].squeeze()
59 | all_features = {}
60 |
61 | # collect features for each person (total 333 persons)
62 | for person_index, current_features in enumerate(mat_features):
63 | all_features[person_index] = current_features
64 |
65 | return all_features
66 |
67 |
68 | def get_testing_set(
69 | features,
70 | cam_permutations,
71 | test_ids,
72 | run_index,
73 | cam_id_locations,
74 | gallery_cams,
75 | probe_cams,
76 | settings,
77 | ):
78 | # cam is indexed from 1 - 6
79 | # person indices are indexed using 0-based numbers
80 | X_gallery_rgb, Y_gallery_rgb, cam_gallery_rgb, X_probe_IR, Y_probe_IR, cam_probe_IR = (
81 | [],
82 | [],
83 | [],
84 | [],
85 | [],
86 | [],
87 | )
88 |
89 | number_shot = settings["number_shot"]
90 |
91 | # collect rgb images as gallery
92 | for cam_index in gallery_cams:
93 | cam_name = "cam" + str(cam_index)
94 | current_cam_features = features[cam_name]
95 |
96 | # for all the test ids, collect features
97 | for test_id in test_ids:
98 | current_id_features = current_cam_features[test_id]
99 |
100 | if np.any(np.array(current_id_features.shape) == 0):
101 | continue
102 | # assert (not np.any(np.array(current_id_features.shape) == 0)), 'test id feature count is 0'
103 |
104 | # get the current permutation
105 | current_permutation = cam_permutations[cam_name][test_id][run_index]
106 | current_permutation = current_permutation[:number_shot]
107 |
108 | selected_features = current_id_features[current_permutation, :]
109 |
110 | if len(X_gallery_rgb) == 0:
111 | X_gallery_rgb = selected_features
112 | else:
113 | X_gallery_rgb = np.concatenate(
114 | (X_gallery_rgb, selected_features), axis=0
115 | )
116 |
117 | Y_gallery_rgb += [test_id] * number_shot
118 | cam_gallery_rgb += [cam_id_locations[cam_index - 1]] * number_shot
119 |
120 |     Y_gallery_rgb = np.array(Y_gallery_rgb, dtype=int)  # builtin int: np.int is removed in NumPy >= 1.24
121 |     cam_gallery_rgb = np.array(cam_gallery_rgb, dtype=int)
122 |
123 | # collect all the IR
124 | for cam_index in probe_cams:
125 | cam_name = "cam" + str(cam_index)
126 | current_cam_features = features[cam_name]
127 |
128 | # for all the test ids, collect features
129 | for test_id in test_ids:
130 | current_id_features = current_cam_features[test_id]
131 | if np.any(np.array(current_id_features.shape) == 0):
132 | continue
133 | # assert len(current_id_features) != 0, 'test id feature count is 0'
134 |
135 | if len(X_probe_IR) == 0:
136 | X_probe_IR = current_id_features
137 |
138 | else:
139 | X_probe_IR = np.concatenate((X_probe_IR, current_id_features), axis=0)
140 |
141 | Y_probe_IR += [test_id] * len(current_id_features)
142 | cam_probe_IR += [cam_id_locations[cam_index - 1]] * len(current_id_features)
143 |
144 |     Y_probe_IR = np.array(Y_probe_IR, dtype=int)
145 |     cam_probe_IR = np.array(cam_probe_IR, dtype=int)
146 |
147 | return (
148 | X_gallery_rgb,
149 | Y_gallery_rgb,
150 | cam_gallery_rgb,
151 | X_probe_IR,
152 | Y_probe_IR,
153 | cam_probe_IR,
154 | )
155 |
156 |
157 | def get_unique(array):
158 | _, idx = np.unique(array, return_index=True)
159 | return array[np.sort(idx)]
160 |
161 |
162 | def get_cmc_multi_cam(Y_gallery, cam_gallery, Y_probe, cam_probe, dist):
163 | # dist = #probe x #gallery
164 | num_probes, num_gallery = dist.shape
165 | gallery_unique_count = get_unique(Y_gallery).shape[0]
166 | match_counter = np.zeros((gallery_unique_count))
167 |
168 | # sort the distance matrix
169 | sorted_indices = np.argsort(dist, axis=-1)
170 |
171 | Y_result = Y_gallery[sorted_indices]
172 | cam_locations_result = cam_gallery[sorted_indices]
173 |
174 | valid_probe_sample_count = 0
175 |
176 | for probe_index in range(num_probes):
177 | # remove gallery samples from the same camera of the probe
178 | Y_result_i = Y_result[probe_index, :]
179 | Y_result_i[cam_locations_result[probe_index, :] == cam_probe[probe_index]] = -1
180 |
181 | # remove the -1 entries from the label result
182 | # print(Y_result_i.shape)
183 | Y_result_i = np.array([i for i in Y_result_i if i != -1])
184 | # print(Y_result_i.shape)
185 |
186 | # remove duplicated id in "stable" manner
187 | Y_result_i_unique = get_unique(Y_result_i)
188 | # print(Y_result_i_unique, gallery_unique_count)
189 |
190 | # match for probe i
191 | match_i = Y_result_i_unique == Y_probe[probe_index]
192 |
193 | if np.sum(match_i) != 0: # if there is true matching in gallery
194 | valid_probe_sample_count += 1
195 | match_counter += match_i
196 |
197 | rankk = match_counter / valid_probe_sample_count
198 | cmc = np.cumsum(rankk)
199 | return cmc
200 |
201 |
202 | def get_mAP_multi_cam(Y_gallery, cam_gallery, Y_probe, cam_probe, dist):
203 | # dist = #probe x #gallery
204 | num_probes, num_gallery = dist.shape
205 |
206 | # sort the distance matrix
207 | sorted_indices = np.argsort(dist, axis=-1)
208 |
209 | Y_result = Y_gallery[sorted_indices]
210 | cam_locations_result = cam_gallery[sorted_indices]
211 |
212 | valid_probe_sample_count = 0
213 | avg_precision_sum = 0
214 |
215 | for probe_index in range(num_probes):
216 | # remove gallery samples from the same camera of the probe
217 | Y_result_i = Y_result[probe_index, :]
218 | Y_result_i[cam_locations_result[probe_index, :] == cam_probe[probe_index]] = -1
219 |
220 | # remove the -1 entries from the label result
221 | # print(Y_result_i.shape)
222 | Y_result_i = np.array([i for i in Y_result_i if i != -1])
223 | # print(Y_result_i.shape)
224 |
225 | # match for probe i
226 | match_i = Y_result_i == Y_probe[probe_index]
227 | true_match_count = np.sum(match_i)
228 |
229 | if true_match_count != 0: # if there is true matching in gallery
230 | valid_probe_sample_count += 1
231 | true_match_rank = np.where(match_i)[0]
232 |
233 | ap = np.mean(
234 | np.array(range(1, true_match_count + 1)) / (true_match_rank + 1)
235 | )
236 | avg_precision_sum += ap
237 |
238 | mAP = avg_precision_sum / valid_probe_sample_count
239 | return mAP
--------------------------------------------------------------------------------
/evaluation/get_cmc_multi_cam.m:
--------------------------------------------------------------------------------
1 | function [cmc,ind]=get_cmc_multi_cam(Y_gallery,Y_cam_gallery,Y_probe,Y_cam_probe,dist)
2 | [~, ind]=sort(dist,2);
3 | Y_result=Y_gallery(ind);
4 | Y_cam_result=Y_cam_gallery(ind);
5 | valid_probe_sample_count=0;
6 | gallery_unique_count=length(unique(Y_gallery));
7 | match_counter=zeros(1,gallery_unique_count);
8 | probe_sample_count=length(Y_probe);
9 | for i=1:probe_sample_count
10 | % remove gallery samples from the same camera of the probe
11 | Y_result_i=Y_result(i,:);
12 | Y_result_i(Y_cam_result(i,:)==Y_cam_probe(i))=[];
13 | % remove duplicated id
14 | Y_result_unique_i=unique(Y_result_i,'stable');
15 | % match for probe i
16 | match_i=(Y_result_unique_i==Y_probe(i));
17 | if sum(match_i)~=0 % if there is true matching in gallery
18 | valid_probe_sample_count=valid_probe_sample_count+1;
19 | for r=1:length(match_i)
20 | match_counter(r)=match_counter(r)+match_i(r);
21 | end
22 | end
23 | end
24 | rankk=match_counter/valid_probe_sample_count;
25 | cmc=cumsum(rankk);
26 | end
--------------------------------------------------------------------------------
/evaluation/get_map_multi_cam.m:
--------------------------------------------------------------------------------
1 | function map=get_map_multi_cam(Y_gallery,Y_cam_gallery,Y_probe,Y_cam_probe,dist)
2 | [~, ind]=sort(dist,2);
3 | Y_result=Y_gallery(ind);
4 | Y_cam_result=Y_cam_gallery(ind);
5 | valid_probe_sample_count=0;
6 | probe_sample_count=length(Y_probe);
7 | ap_sum=0;
8 | for i=1:probe_sample_count
9 | % remove gallery samples from the same camera of the probe
10 | Y_result_i=Y_result(i,:);
11 | Y_result_i(Y_cam_result(i,:)==Y_cam_probe(i))=[];
12 | % match for probe i
13 | match_i=(Y_result_i==Y_probe(i));
14 | true_match_count=sum(match_i);
15 | if true_match_count~=0 % if there is true matching in gallery
16 | valid_probe_sample_count=valid_probe_sample_count+1;
17 | true_match_rank=find(match_i==1);
18 | ap=mean((1:true_match_count)./true_match_rank);
19 | ap_sum=ap_sum+ap;
20 | end
21 | end
22 | map=ap_sum/valid_probe_sample_count;
23 | end
--------------------------------------------------------------------------------
/evaluation/get_testing_set.m:
--------------------------------------------------------------------------------
1 | function [X_gallery, Y_gallery, Y_cam_gallery, X_probe, Y_probe, Y_cam_probe]=get_testing_set...
2 | (feature_cam, Y, rand_perm_cam, run_time, number_shot, gallery_cam_list, probe_cam_list, test_id, cam_id)
3 | % get testing set for SYSU-MM01 multi-modality re-id dataset
4 | % input:
5 | % feature_cam - feature_cam{i}{id} is feature matrix (each row is a feature) of cam i of person id
6 | % Y - person id for each cell
7 | % rand_perm_cam - random permutation of indices for selecting gallery samples
8 | % run_time - current evaluation trial index
9 | % number_shot - 1 single shot, 10 multi-shot, -1 all except one
10 | % gallery_cam_list, probe_cam_list - camera lists of the testing set
11 | % test_id - list of testing persons
12 | % output:
13 | % X_... - feature vectors in each row
14 | % Y_... - person label
15 | % Y_cam_... - camera label
16 |
17 | gallery_cam_count=length(gallery_cam_list);
18 | X_gallery=cell(gallery_cam_count,1);
19 | Y_gallery=cell(gallery_cam_count,1);
20 | Y_cam_gallery=cell(gallery_cam_count,1);
21 | probe_cam_count=length(probe_cam_list);
22 | X_probe=cell(probe_cam_count,1);
23 | Y_probe=cell(probe_cam_count,1);
24 | Y_cam_probe=cell(probe_cam_count,1);
25 |
26 | % gallery
27 | for i_cam=1:gallery_cam_count
28 | cam_num=gallery_cam_list(i_cam);
29 | cam_label=cam_id(cam_num);
30 | Y_i_cam=Y{cam_num};
31 | id_count=length(Y_i_cam);
32 |
33 | X_gallery{i_cam}=cell(id_count,1);
34 | Y_gallery{i_cam}=[];
35 | Y_cam_gallery{i_cam}=[];
36 | for i_id=1:id_count
37 | if isempty(find(test_id==Y_i_cam(i_id)))
38 | continue;
39 | end
40 | rand_perm_this=rand_perm_cam{cam_num}{i_id}(run_time,:);
41 | if isempty(rand_perm_this)
42 | continue;
43 | end
44 | if number_shot<0
45 | ind_gallery_this=rand_perm_this(1:end+number_shot);
46 | else
47 | ind_gallery_this=rand_perm_this(1:number_shot);
48 | end
49 | X_test_this=feature_cam{cam_num}{i_id};
50 | X_gallery{i_cam}{i_id}=X_test_this(ind_gallery_this,:);
51 | frame_count=size(X_gallery{i_cam}{i_id},1);
52 | Y_gallery{i_cam}=[Y_gallery{i_cam};repmat(Y_i_cam(i_id),frame_count,1)];
53 | Y_cam_gallery{i_cam}=[Y_cam_gallery{i_cam};repmat(cam_label,frame_count,1)];
54 | end
55 | X_gallery{i_cam}=cell2mat(X_gallery{i_cam});
56 | end
57 |
58 | % probe
59 | for i_cam=1:probe_cam_count
60 | cam_num=probe_cam_list(i_cam);
61 | cam_label=cam_id(cam_num);
62 | Y_i_cam=Y{cam_num};
63 | id_count=length(Y_i_cam);
64 |
65 | X_probe{i_cam}=cell(id_count,1);
66 | Y_probe{i_cam}=[];
67 | Y_cam_probe{i_cam}=[];
68 | for i_id=1:id_count
69 | if isempty(find(test_id==Y_i_cam(i_id)))
70 | continue;
71 | end
72 | rand_perm_this=rand_perm_cam{cam_num}{i_id}(run_time,:);
73 | if isempty(rand_perm_this)
74 | continue;
75 | end
76 | X_test_this=feature_cam{cam_num}{i_id};
77 | X_probe{i_cam}{i_id}=X_test_this;
78 | frame_count=size(X_probe{i_cam}{i_id},1);
79 | Y_probe{i_cam}=[Y_probe{i_cam};repmat(Y_i_cam(i_id),frame_count,1)];
80 | Y_cam_probe{i_cam}=[Y_cam_probe{i_cam};repmat(cam_label,frame_count,1)];
81 | end
82 | X_probe{i_cam}=cell2mat(X_probe{i_cam});
83 | end
84 |
85 | X_gallery=cell2mat(X_gallery);
86 | Y_gallery=cell2mat(Y_gallery);
87 | Y_cam_gallery=cell2mat(Y_cam_gallery);
88 |
89 | X_probe=cell2mat(X_probe);
90 | Y_probe=cell2mat(Y_probe);
91 | Y_cam_probe=cell2mat(Y_cam_probe);
92 |
93 | end
--------------------------------------------------------------------------------
/evaluation/result/result__euclidean_all_search_10shot.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/result/result__euclidean_all_search_10shot.mat
--------------------------------------------------------------------------------
/evaluation/result/result__euclidean_all_search_1shot.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/result/result__euclidean_all_search_1shot.mat
--------------------------------------------------------------------------------
/evaluation/result/result__euclidean_indoor_search_10shot.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/result/result__euclidean_indoor_search_10shot.mat
--------------------------------------------------------------------------------
/evaluation/result/result__euclidean_indoor_search_1shot.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/result/result__euclidean_indoor_search_1shot.mat
--------------------------------------------------------------------------------
/evaluation/train_id.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/train_id.mat
--------------------------------------------------------------------------------
/evaluation/val_id.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/evaluation/val_id.mat
--------------------------------------------------------------------------------
/extract_feature.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 |
4 | import argparse
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | import scipy.io as sio
9 | from torch.backends import cudnn
10 |
11 | from reid.models.newresnet import *
12 | from reid.utils.serialization import load_checkpoint
13 | import torchvision.transforms as transforms
14 | from PIL import Image
15 | from reid.utils.data import transforms as T
16 |
17 |
18 | def getArgs():
19 | parser = argparse.ArgumentParser(description="Cross_modality for Person Re-identification")
20 | # data
21 | parser.add_argument('--height', type=int, default=256,
22 | help="input height, default: 256 for resnet*, "
23 | "144 for inception")
24 | parser.add_argument('--width', type=int, default=128,
25 | help="input width, default: 128 for resnet*, "
26 | "56 for inception")
27 |
28 | parser.add_argument('--features', type=int, default=2048)
29 | parser.add_argument('--dropout', type=float, default=0)
30 |
31 | # Bottleneck
32 | parser.add_argument('-z_dim', type=int, default=256,
33 | help="dimension of latent z, better belongs to {128, 256, 512}")
34 | # device set
35 | parser.add_argument('--visible_device', default='1, 0', type=str, help='gpu_ids: e.g. 0, 0,1,2 0,2')
36 | parser.add_argument('--weight-decay', type=float, default=5e-4)
37 | return parser.parse_args()
38 |
39 | def count_param(model):
40 | param_count = 0
41 | for param in model.parameters():
42 | param_count += param.view(-1).size()[0]
43 | return param_count
44 |
45 | args = getArgs()
46 |
47 | features = 8192 # 8192 for observation (four in total), 2048 for representation (four in total)
48 | os.environ["CUDA_VISIBLE_DEVICES"] = "1, 0"
49 | model = ft_net(args=args, num_classes=395, num_features=args.features)
50 | model_path = "/home/txd/Variational Distillation/Exp_before_OpenSource_7/"
51 | model = nn.DataParallel(model, device_ids=[0, 1])
52 | cudnn.benchmark = True
53 | 
54 | #########################################
55 | checkpoint = load_checkpoint(model_path + 'model_best.pth.tar')
56 | state_dict = checkpoint['model']
57 | from collections import OrderedDict
58 | new_state_dict = OrderedDict()
59 | for k, v in state_dict.items():
60 | if 'module' not in k or 'cam_module' in k:
61 | k = 'module.' + k
62 | else:
63 | k = k.replace('features.module.', 'module.features.')
64 | new_state_dict[k] = v
65 | model.load_state_dict(new_state_dict)
66 | #########################################
67 | model = model.cuda()
68 | model.eval()
69 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],
70 | std=[0.229, 0.224, 0.225])
71 |
72 | test_transformer = T.Compose([
73 | T.Resize((256, 128)),
74 | T.ToTensor(),
75 | normalizer,
76 | ])
77 |
78 | print('model initialized\n')
79 | print("The dimension of the feature is " + str(features))
80 | print("Total params: " + str(count_param(model)))
81 |
82 | cam_name = ''
83 |
84 | path_list = ['./data/sysu/SYSU-MM01/cam1/','./data/sysu/SYSU-MM01/cam2/','./data/sysu/SYSU-MM01/cam3/', \
85 | './data/sysu/SYSU-MM01/cam4/','./data/sysu/SYSU-MM01/cam5/','./data/sysu/SYSU-MM01/cam6/']
86 | pic_num = [333,333,533,533,533,333]
87 | for index, path in enumerate(path_list):
88 | print(index)
89 | cams = torch.LongTensor([index])
90 | sub = ((cams == 2).long() + (cams == 5).long()).cuda()
91 | count = 1
92 | array_list = []
93 | person_id_list = []
94 | dict_person = {}
95 | tot_num = pic_num[index]
96 | array_list_to_array = [[] for _ in range(tot_num)]
97 | #print(path)
98 | for fpathe,dirs,fs in os.walk(path):
99 | person_id = fpathe.split('/')[-1]
100 | if(person_id == ''):
101 | continue
102 | cam_name = fpathe[-9:-5]
103 | fs.sort()
104 | person_id_list.append(person_id)
105 | dict_person[person_id] = fs
106 | person_id_list.sort()
107 | for person in person_id_list:
108 | temp_list = []
109 | for imagename in dict_person[person]:
110 | filename = path + str(person) + '/' + imagename
111 | img = Image.open(filename)
112 | img = test_transformer(img)
113 | img = img.view([1, 3, 256, 128])
114 | img = img.cuda()
115 | i_observation, i_representation, i_ms_observation, i_ms_representation, \
116 | v_observation, v_representation, v_ms_observation, v_ms_representation = model(img)
117 | result_y = torch.cat(tensors=[i_observation[1], i_ms_observation[1],
118 | v_observation[1], v_ms_observation[1]], dim=1)
119 |
120 | result_y = torch.nn.functional.normalize(result_y, dim=1, p=2)
121 | result_y = result_y.view(-1, features)
122 | result_y = result_y.squeeze()
123 | result_npy = result_y.data.cpu().numpy()
124 | result_npy = result_npy.astype('double')
125 | temp_list.append(result_npy)
126 | temp_array = np.array(temp_list)
127 | array_list_to_array[int(person) - 1] = temp_array
128 | array_list_to_array = np.array(array_list_to_array)
129 |
130 | sio.savemat(model_path + "observation" + cam_name + '.mat', {'feature_test':array_list_to_array})
131 |
132 |
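133 | # Reading the saved features back (a minimal sketch; the key 'feature_test'
134 | # matches the savemat call above):
135 | #   feats = sio.loadmat(model_path + "observation" + cam_name + '.mat')['feature_test']
136 | #   print(feats.shape)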
--------------------------------------------------------------------------------
/images/embedding_spaces.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/images/embedding_spaces.jpg
--------------------------------------------------------------------------------
/images/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/images/framework.jpg
--------------------------------------------------------------------------------
/images/joint_embedding_spaces.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/images/joint_embedding_spaces.jpg
--------------------------------------------------------------------------------
/mm01.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os.path as osp
4 | import os
5 | import numpy as np
6 | import sys
7 | import torch
8 | from torch import nn
9 | from torch.backends import cudnn
10 | from torch.utils.data import DataLoader
11 | from reid import datasets
12 | from reid.dist_metric import DistanceMetric
13 | from reid.models import ft_net
14 | from reid.trainers import Trainer
15 | from reid.evaluators import Evaluator
16 | from reid.utils.data import transforms as T
17 | from reid.utils.data.preprocessor import Preprocessor
18 | from reid.utils.data.sampler import CamRandomIdentitySampler as RandomIdentitySampler
19 | from reid.utils.data.sampler import CamSampler
20 | from reid.utils.logging import Logger
21 | from reid.utils.serialization import load_checkpoint, save_checkpoint
22 | from utlis import RandomErasing, WarmupMultiStepLR, CrossEntropyLabelSmooth, Rank_loss, ASS_loss
23 |
24 |
25 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances,
26 | workers, combine_trainval, flip_prob, padding, re_prob, using_HuaWeiCloud, cloud_dataset_root):
27 | root = osp.join(data_dir, name)
28 |
29 | if using_HuaWeiCloud: root = cloud_dataset_root
30 |
31 | print(root)
32 |
33 | dataset = datasets.create(name, root, split_id=split_id)
34 |
35 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
36 | std=[0.229, 0.224, 0.225])
37 |
38 | trainvallabel = dataset.trainvallabel
39 | train_set = dataset.trainval if combine_trainval else dataset.train
40 | num_classes = (dataset.num_trainval_ids if combine_trainval
41 | else dataset.num_train_ids)
42 |
43 | train_transformer = T.Compose([
44 | T.Resize((height, width)),
45 | T.RandomHorizontalFlip(p=flip_prob),
46 | T.Pad(padding),
47 | T.RandomCrop((height, width)),
48 | T.ToTensor(),
49 | normalizer,
50 | RandomErasing(probability=re_prob, mean=[0.485, 0.456, 0.406])
51 | ])
52 |
53 | test_transformer = T.Compose([
54 | T.Resize((height, width)),
55 | T.ToTensor(),
56 | normalizer,
57 | ])
58 |
59 | val_loader = DataLoader(
60 | Preprocessor(dataset.val, root=dataset.images_dir,
61 | transform=test_transformer),
62 | batch_size=32, num_workers=workers,
63 | shuffle=False, pin_memory=True)
64 |
65 | query_loader = DataLoader(
66 | Preprocessor(list(set(dataset.query)),
67 | root=dataset.images_dir, transform=test_transformer),
68 | batch_size=32, num_workers=workers,
69 | sampler=CamSampler(list(set(dataset.query)), [2,5]),
70 | shuffle=False, pin_memory=True)
71 |
72 | gallery_loader = DataLoader(
73 | Preprocessor(list(set(dataset.gallery)),
74 | root=dataset.images_dir, transform=test_transformer),
75 | batch_size=32, num_workers=workers,
76 | sampler=CamSampler(list(set(dataset.gallery)), [0,1,3,4], 4),
77 | shuffle=False, pin_memory=True)
78 |
79 | train_loader = DataLoader(
80 | Preprocessor(train_set, root=dataset.images_dir,
81 | transform=train_transformer),
82 | batch_size=batch_size, num_workers=workers,
83 | sampler=RandomIdentitySampler(train_set, num_instances),
84 | pin_memory=True, drop_last=True)
85 |
86 | return dataset, num_classes, train_loader, trainvallabel, val_loader, query_loader, gallery_loader
87 |
88 | def main(args):
89 | np.random.seed(args.seed)
90 | torch.manual_seed(args.seed)
91 | cudnn.benchmark = True
92 |
93 | if not args.evaluate:
94 | sys.stdout = Logger(osp.join(args.logs_dir+'/log'))
95 |
96 | if args.height is None or args.width is None: args.height, args.width = (256, 128)
97 |
98 | # Dataset and loader
99 | dataset, num_classes, train_loader, trainvallabel, val_loader, query_loader, gallery_loader = \
100 | get_data(args.dataset, args.split, args.data_dir, args.height,
101 | args.width, args.batch_size, args.num_instances, args.workers,
102 | args.combine_trainval, args.flip_prob, args.padding, args.re_prob,
103 | args.HUAWEI_cloud, args.dataset_root)
104 |
105 | # Model settings
106 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_device
107 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
108 | model = ft_net(args=args, num_classes=num_classes, num_features=args.features)
109 | model = nn.DataParallel(model, device_ids=[0,1])
110 | model = model.to(device)
111 |
112 | # Evaluation components
113 | evaluator = Evaluator(model)
114 | metric = DistanceMetric(algorithm=args.dist_metric)
115 |
116 | start_epoch = 0
117 | if args.resume:
118 | #########################################
119 | checkpoint = load_checkpoint(args.resume)
120 | state_dict = checkpoint['model']
121 | from collections import OrderedDict
122 | new_state_dict = OrderedDict()
123 | for k, v in state_dict.items():
124 | if 'module' not in k:
125 | k = 'module.' + k
126 | else:
127 | k = k.replace('features.module.', 'module.features.')
128 | new_state_dict[k] = v
129 | model.load_state_dict(new_state_dict)
130 | #########################################
131 | start_epoch = checkpoint['epoch']
132 | print("=> Start epoch {}".format(start_epoch))
133 |
134 | if args.evaluate:
135 | metric.train(model, train_loader)
136 | evaluator.evaluate(query_loader, gallery_loader, metric)  # signature: (loader1, loader2, metric=None, flag=False)
137 | exit()
138 |
139 | # Losses
140 | ce_Loss = CrossEntropyLabelSmooth(num_classes= num_classes, epsilon=args.epsilon).cuda()
141 | associate_loss = ASS_loss().cuda()
142 | rank_Loss = Rank_loss(margin_1= args.margin_1, margin_2 =args.margin_2, alpha_1 =args.alpha_1, alpha_2= args.alpha_2).cuda()
143 |
144 | print(args)
145 | # optimizers and schedulers
146 | conv_optim = model.module.optims()
147 |
148 | conv_scheduler = WarmupMultiStepLR(conv_optim, args.mile_stone, args.gamma, args.warmup_factor,
149 | args.warmup_iters, args.warmup_methods)
150 |
151 | trainer = Trainer(args, model, ce_Loss, rank_Loss, associate_loss, trainvallabel)
152 |
153 | best_top1 = -1
154 |
155 | # Start training
156 | for epoch in range(start_epoch, args.epochs):
157 | conv_scheduler.step()
158 |
159 | triple_loss, tot_loss = trainer.train(epoch, train_loader, conv_optim)
160 |
161 | save_checkpoint({
162 | 'model': model.module.state_dict(),
163 | 'epoch': epoch + 1,
164 | 'best_top1': best_top1,
165 | }, False, epoch, args.logs_dir, fpath='checkpoint.pth.tar')
166 |
167 | if epoch < args.begin_test:
168 | continue
169 | if not epoch % args.evaluate_freq == 0:
170 | continue
171 |
172 | top1 = evaluator.evaluate(query_loader, gallery_loader, metric)
173 |
174 | is_best = top1 > best_top1
175 | best_top1 = max(top1, best_top1)
176 | save_checkpoint({
177 | 'model': model.module.state_dict(),
178 | 'epoch': epoch + 1,
179 | 'best_top1': best_top1,
180 | }, is_best, epoch, args.logs_dir, fpath='checkpoint.pth.tar')
181 |
182 | print('Test with best model:')
183 | print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'.
184 | format(epoch, top1, best_top1, ' *' if is_best else ''))
185 | print(args)
186 |
187 |
188 | if __name__ == '__main__':
189 | parser = argparse.ArgumentParser(description="Cross_modality for Person Re-identification")
190 |
191 | # dataset settings
192 | parser.add_argument('-d', '--dataset', type=str, default='sysu', choices=datasets.names())
193 | parser.add_argument('-b', '--batch-size', type=int, default=128)
194 | parser.add_argument('-j', '--workers', type=int, default=4)
195 | parser.add_argument('--split', type=int, default=0)
196 |
197 | # transformer
198 | parser.add_argument('--height', type=int, default=256)
199 | parser.add_argument('--width', type=int, default= 128)
200 | parser.add_argument('--flip_prob', type=float, default=0.5)
201 | parser.add_argument('--re_prob', type=float, default=0.0)
202 | parser.add_argument('--padding', type=int, default=0)
203 | parser.add_argument('--combine-trainval', default=True, action='store_true',
204 | help="train and val sets together for training, "
205 | "val set alone for validation")
206 | parser.add_argument('--num-instances', type=int, default=8,
207 | help="each minibatch consists of "
208 | "(batch_size // num_instances) identities, and "
209 | "each identity has num_instances instances, "
210 | "default: 8")
211 | # model
212 | parser.add_argument('--features', type=int, default=2048)
213 |
214 | # rank loss settings
215 | parser.add_argument('--margin_1', type=float, default=0.9, help="margin_1 of the triplet loss, default: 0.9")
216 | parser.add_argument('--margin_2', type=float, default=1.0, help="margin_2 of the triplet loss, default: 1.0")
217 | parser.add_argument('--alpha_1', type=float, default=2.2, help="alpha_1 of the triplet loss, default: 2.2")
218 | parser.add_argument('--alpha_2', type=float, default=2.0, help="alpha_2 of the triplet loss, default: 2.0")
219 |
220 | # optimizer and scheduler
221 | parser.add_argument('--lr', type=float, default=2.6e-4, help="learning rate of all parameters")
222 | parser.add_argument('--weight-decay', type=float, default=5e-4)
223 | parser.add_argument('--use_adam', action='store_true', help="use Adam as the optimizer, otherwise SGD")
224 | parser.add_argument('--gamma', type=float, default = 0.1, help="gamma for learning rate decay")
225 |
226 | parser.add_argument('--mile_stone', type=list, default=[210])
227 |
228 | parser.add_argument('--warmup_iters', type=int, default=10)
229 | parser.add_argument('--warmup_methods', type=str, default = 'linear', choices=('linear', 'constant'))
230 | parser.add_argument('--warmup_factor', type=float, default = 0.01 )
231 |
232 | # training configs
233 | parser.add_argument('--resume', type=str, default='', metavar='')
234 | parser.add_argument('--evaluate', action='store_true',
235 | help="run evaluation only and skip training; the standard "
236 | "SYSU-MM01 protocol relies on the official evaluation code")
237 |
238 | parser.add_argument('--epochs', type=int, default=600)
239 | parser.add_argument('--start_save', type=int, default=100, help="start saving checkpoints after specific epoch")
240 | parser.add_argument('--begin_test', type=int, default=100)
241 | parser.add_argument('--evaluate_freq', type=int, default=5)
242 |
243 | parser.add_argument('--seed', type=int, default=1)
244 |
245 | # adopted metric
246 | parser.add_argument('--dist-metric', type=str, default='euclidean')
247 |
248 | # misc
249 | working_dir = osp.dirname(osp.abspath(__file__))
250 | parser.add_argument('--data-dir', type=str, metavar='PATH',
251 | default=osp.join(working_dir, 'data'))
252 | parser.add_argument('--logs-dir', type=str, metavar='PATH',
253 | default=osp.join(working_dir, 'Exp_before_OpenSource_7_1'))
254 | # hyper-parameters
255 | parser.add_argument('-CE_loss', type=int, default=1, help="weight of cross entropy loss")
256 | parser.add_argument('-epsilon', type=float, default=0.1, help="label smooth")
257 | parser.add_argument('-Triplet_loss', type=int, default=1, help="weight of triplet loss")
258 | parser.add_argument('-Associate_loss', type=float, default=0.0, help="weight of the associate loss")
259 |
260 | parser.add_argument('-CML_loss', type=int, default=8, help="the weight of conventional mutual learning")
261 | parser.add_argument('-VCD_loss', type=int, default=2, help="the weight of VCD and VML")
262 | parser.add_argument('-VSD_loss', type=float, default=2, help="weight of VSD")
263 | parser.add_argument('-temperature', type=int, default=1, help="the temperature used in knowledge distillation")
264 |
265 | # Bottleneck
266 | parser.add_argument('-z_dim', type=int, default=256, help="dimension of latent z, better set to {128, 256, 512}")
267 | # device set
268 | parser.add_argument('--visible_device', default='2, 1', type=str, help='gpu_ids: e.g. 0, 0,1,2 0,2')
269 |
270 | # HUAWEI cloud
271 | parser.add_argument('--HUAWEI_cloud', action='store_true')  # type=bool would parse any non-empty string as True
272 | parser.add_argument('--dataset_root', type=str, metavar='PATH', default="/test-ddag/dataset")
273 | parser.add_argument('--data_url', type=str, default="")
274 | parser.add_argument('--init_method', type=str, default="")
275 | parser.add_argument('--train_url', type=str, default="")
276 |
277 | main(parser.parse_args())
278 |
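279 | # Example invocation (a sketch; paths and GPU ids are placeholders, all other
280 | # flags fall back to the defaults defined above):
281 | #   python mm01.py -d sysu --data-dir ./data --logs-dir ./logs \
282 | #       --batch-size 128 --num-instances 8 --visible_device 0,1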
--------------------------------------------------------------------------------
/regdb.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os.path as osp
4 | import os
5 | import numpy as np
6 | import sys
7 | import torch
8 | from torch import nn
9 | from torch.backends import cudnn
10 | from torch.utils.data import DataLoader
11 | from reid import datasets
12 | from reid.trainers import Trainer
13 | from reid.dist_metric import DistanceMetric
14 | from reid.evaluators_regdb import Evaluator
15 | from reid.models import ft_net
16 | from reid.utils.data import transforms as T
17 | from reid.utils.data.preprocessor import Preprocessor
18 | from reid.utils.data.sampler import CamRandomIdentitySampler as RandomIdentitySampler
19 | from reid.utils.data.sampler import CamSampler
20 | from reid.utils.logging import Logger
21 | from reid.utils.serialization import load_checkpoint, save_checkpoint
22 | from utlis import RandomErasing, WarmupMultiStepLR, CrossEntropyLabelSmooth, Rank_loss, ASS_loss
23 |
24 |
25 |
26 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances,
27 | workers, combine_trainval, flip_prob, padding, re_prob, ii):
28 | root = osp.join(data_dir, name)
29 | print(root)
30 |
31 | dataset = datasets.create(name, root, split_id=split_id, ii=ii)
32 |
33 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
34 | std=[0.229, 0.224, 0.225])
35 |
36 | trainvallabel = dataset.trainvallabel
37 | train_set = dataset.trainval if combine_trainval else dataset.train
38 | num_classes = (dataset.num_trainval_ids if combine_trainval else dataset.num_train_ids)
39 | #print("Number of training classes:" + str(dataset.num_trainval_ids))
40 |
41 |
42 | train_transformer = T.Compose([
43 | T.Resize((height, width)),
44 | T.RandomHorizontalFlip(p=flip_prob),
45 | T.Pad(padding),
46 | T.RandomCrop((height, width)),
47 | T.ToTensor(),
48 | normalizer,
49 | RandomErasing(probability=re_prob, mean=[0.485, 0.456, 0.406])
50 | ])
51 |
52 | test_transformer = T.Compose([
53 | T.Resize((height, width)),
54 | T.ToTensor(),
55 | normalizer,
56 | ])
57 |
58 | val_loader = DataLoader(
59 | Preprocessor(dataset.val, root=dataset.images_dir,
60 | transform=test_transformer),
61 | batch_size=32, num_workers=workers,
62 | shuffle=False, pin_memory=True)
63 |
64 | query_loader = DataLoader(
65 | Preprocessor(list(set(dataset.query)),
66 | root=dataset.images_dir, transform=test_transformer),
67 | batch_size=32, num_workers=workers,
68 | sampler=CamSampler(list(set(dataset.query)), [2]),
69 | shuffle=False, pin_memory=True)
70 |
71 | gallery_loader = DataLoader(
72 | Preprocessor(list(set(dataset.gallery)),
73 | root=dataset.images_dir, transform=test_transformer),
74 | batch_size=32, num_workers=workers,
75 | sampler=CamSampler(list(set(dataset.gallery)), [0], 1),
76 | shuffle=False, pin_memory=True)
77 |
78 | train_loader = DataLoader(
79 | Preprocessor(train_set, root=dataset.images_dir,
80 | transform=train_transformer),
81 | batch_size=batch_size, num_workers=workers,
82 | sampler=RandomIdentitySampler(train_set, num_instances),
83 | pin_memory=True, drop_last=True)
84 |
85 | return dataset, num_classes, train_loader, trainvallabel, val_loader, query_loader, gallery_loader
86 |
87 |
88 | def main(args):
89 | np.random.seed(args.seed)
90 | torch.manual_seed(args.seed)
91 | cudnn.benchmark = True
92 | for ii in range(1,11):
93 | #if ii == 5 or ii == 6: ii = ii - 4
94 | print(ii)
95 | if not osp.exists(args.logs_dir+'/{}'.format(ii)):
96 | os.mkdir(args.logs_dir+'/{}'.format(ii))
97 | sys.stdout = Logger(osp.join(args.logs_dir+'/{}/log'.format(ii)))
98 | dataset, num_classes, train_loader, trainvallabel, val_loader, query_loader, gallery_loader = \
99 | get_data(args.dataset, args.split, args.data_dir, args.height,
100 | args.width, args.batch_size, args.num_instances, args.workers,
101 | args.combine_trainval, args.flip_prob, args.padding, args.re_prob, ii)  # RegDB trials are numbered 1-10; ii already covers them
102 | if not args.evaluate:
103 | sys.stdout = Logger(osp.join(args.logs_dir+'/log'))
104 |
105 | # Model settings
106 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_device
107 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
108 | model = ft_net(args=args, num_classes=num_classes, num_features=args.features)
109 | model = nn.DataParallel(model, device_ids=[0, 1])
110 | model = model.to(device)
111 |
112 | # Evaluator settings
113 | evaluator = Evaluator(model)
114 | metric = DistanceMetric(algorithm=args.dist_metric)
115 |
116 | # Resume settings
117 | if args.resume:
118 | checkpoint = load_checkpoint(args.resume)
119 | state_dict = checkpoint['model']
120 | from collections import OrderedDict
121 | new_state_dict = OrderedDict()
122 | for k, v in state_dict.items():
123 | if 'module' not in k:
124 | k = 'module.' + k
125 | else:
126 | k = k.replace('features.module.', 'module.features.')
127 | new_state_dict[k] = v
128 | model.load_state_dict(new_state_dict)
129 | start_epoch = checkpoint['epoch']
130 | print("=> Start epoch {}".format(start_epoch))
131 |
132 | # Skip training and run evaluation only
133 | if args.evaluate:
134 | metric.train(model, train_loader)
135 | evaluator.evaluate(query_loader, gallery_loader, metric)
136 | exit()
137 |
138 | # Losses
139 | ce_Loss = CrossEntropyLabelSmooth(num_classes=num_classes, epsilon=args.epsilon).cuda()
140 | associate_loss = ASS_loss().cuda()
141 | rank_Loss = Rank_loss(margin_1=args.margin_1, margin_2=args.margin_2, alpha_1=args.alpha_1, alpha_2=args.alpha_2).cuda()
142 |
143 | # Optimizers and schedulers
144 | conv_optim = model.module.optims()
145 | conv_scheduler = WarmupMultiStepLR(conv_optim, args.mile_stone, args.gamma, args.warmup_factor,
146 | args.warmup_iters, args.warmup_methods)
147 |
148 | trainer = Trainer(args, model, ce_Loss, rank_Loss, associate_loss, trainvallabel)
149 | best_top1 = -1
150 | start_epoch = 0
151 | print(args)
152 | for epoch in range(start_epoch, args.epochs):
153 | conv_scheduler.step()
154 | triple_loss, tot_loss = trainer.train(epoch, train_loader, conv_optim)
155 |
156 | save_checkpoint({
157 | 'model': model.module.state_dict(),
158 | 'epoch': epoch + 1,
159 | 'best_top1': best_top1,
160 | }, False, epoch, (args.logs_dir+'/{}'.format(ii)), fpath='checkpoint.pth.tar')
161 |
162 | if epoch < args.begin_test:
163 | continue
164 | if not epoch % args.evaluate_freq == 0:
165 | continue
166 |
167 | top1 = evaluator.evaluate(query_loader, gallery_loader, metric)
168 |
169 | is_best = top1 > best_top1
170 | best_top1 = max(top1, best_top1)
171 | save_checkpoint({
172 | 'model': model.module.state_dict(),
173 | 'epoch': epoch + 1,
174 | 'best_top1': best_top1,
175 | }, is_best, epoch, (args.logs_dir+'/{}'.format(ii)), fpath='checkpoint.pth.tar')
176 | print(args)
177 |
178 | if __name__ == '__main__':
179 | parser = argparse.ArgumentParser(description="Cross_modality for Person Re-identification")
180 |
181 | # dataset settings
182 | parser.add_argument('-d', '--dataset', type=str, default='RegDB', choices=datasets.names())
183 | parser.add_argument('-b', '--batch-size', type=int, default=64)
184 | parser.add_argument('-j', '--workers', type=int, default=4)
185 | parser.add_argument('--split', type=int, default=0)
186 |
187 | # transformer
188 | parser.add_argument('--height', type=int, default=256)
189 | parser.add_argument('--width', type=int, default= 128)
190 | parser.add_argument('--flip_prob', type=float, default=0.5)
191 | parser.add_argument('--re_prob', type=float, default=0.0)
192 | parser.add_argument('--padding', type=int, default=0)
193 | parser.add_argument('--combine-trainval', default=True, action='store_true',
194 | help="train and val sets together for training, "
195 | "val set alone for validation")
196 | parser.add_argument('--num-instances', type=int, default=8,
197 | help="each minibatch consists of "
198 | "(batch_size // num_instances) identities, and "
199 | "each identity has num_instances instances, "
200 | "default: 8")
201 | # model
202 | parser.add_argument('--features', type=int, default=2048)
203 |
204 | # rank loss settings
205 | parser.add_argument('--margin_1', type=float, default=0.9, help="margin_1 of the triplet loss, default: 0.9")
206 | parser.add_argument('--margin_2', type=float, default=1.0, help="margin_2 of the triplet loss, default: 1.0")
207 | parser.add_argument('--alpha_1', type=float, default=2.2, help="alpha_1 of the triplet loss, default: 2.2")
208 | parser.add_argument('--alpha_2', type=float, default=2.0, help="alpha_2 of the triplet loss, default: 2.0")
209 |
210 | # optimizer and scheduler
211 | parser.add_argument('--lr', type=float, default=2.6e-4, help="learning rate of all parameters")
212 | parser.add_argument('--weight-decay', type=float, default=5e-4)
213 | parser.add_argument('--use_adam', action='store_true', help="use Adam as the optimizer, otherwise SGD")
214 | parser.add_argument('--gamma', type=float, default = 0.1, help="gamma for learning rate decay")
215 |
216 | parser.add_argument('--mile_stone', type=list, default=[210])
217 |
218 | parser.add_argument('--warmup_iters', type=int, default=10)
219 | parser.add_argument('--warmup_methods', type=str, default = 'linear', choices=('linear', 'constant'))
220 | parser.add_argument('--warmup_factor', type=float, default = 0.01 )
221 |
222 | # training configs
223 | parser.add_argument('--resume', type=str, default='', metavar='')
224 | parser.add_argument('--evaluate', action='store_true',
225 | help="run evaluation only and skip training")
226 |
227 |
228 | parser.add_argument('--epochs', type=int, default=600)
229 | parser.add_argument('--start_save', type=int, default=100, help="start saving checkpoints after specific epoch")
230 | parser.add_argument('--begin_test', type=int, default=100)
231 | parser.add_argument('--evaluate_freq', type=int, default=5)
232 |
233 | parser.add_argument('--seed', type=int, default=1)
234 |
235 | # adopted metric
236 | parser.add_argument('--dist-metric', type=str, default='euclidean')
237 |
238 | # misc
239 | working_dir = osp.dirname(osp.abspath(__file__))
240 | parser.add_argument('--data-dir', type=str, metavar='PATH',
241 | default=osp.join(working_dir, 'data'))
242 | parser.add_argument('--logs-dir', type=str, metavar='PATH',
243 | default=osp.join(working_dir, 'Exp_RegDB/RegDB_2'))
244 | # hyper-parameters
245 | parser.add_argument('-CE_loss', type=int, default=1, help="weight of cross entropy loss")
246 | parser.add_argument('-epsilon', type=float, default=0.1, help="label smooth")
247 | parser.add_argument('-Triplet_loss', type=int, default=1, help="weight of triplet loss")
248 | parser.add_argument('-Associate_loss', type=float, default=0.0, help="weight of the associate loss")
249 |
250 | parser.add_argument('-CML_loss', type=int, default=8, help="the weight of conventional mutual learning")
251 | parser.add_argument('-VCD_loss', type=int, default=2, help="the weight of VCD and VML")
252 | parser.add_argument('-VSD_loss', type=float, default=2, help="weight of VSD")
253 | parser.add_argument('-temperature', type=int, default=1, help="the temperature used in knowledge distillation")
254 |
255 | # Bottleneck
256 | parser.add_argument('-z_dim', type=int, default=256, help="dimension of latent z, better set to {128, 256, 512}")
257 | # device set
258 | parser.add_argument('--visible_device', default='2, 1', type=str, help='gpu_ids: e.g. 0, 0,1,2 0,2')
259 |
260 |
261 | main(parser.parse_args())
262 |
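263 | # Example invocation (a sketch; paths are placeholders). The script itself
264 | # loops over the ten RegDB trials and writes one checkpoint directory per trial:
265 | #   python regdb.py -d RegDB --data-dir ./data --logs-dir ./Exp_RegDB/RegDB_2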
--------------------------------------------------------------------------------
/reid/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from . import datasets
4 | from . import evaluation_metrics
5 | from . import feature_extraction
6 | from . import metric_learning
7 | from . import models
8 | from . import utils
9 | from . import dist_metric
10 | from . import evaluators
11 | from . import trainers
12 |
13 | __version__ = '0.2.0'
14 |
--------------------------------------------------------------------------------
/reid/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__init__.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/New_trainers_3.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/New_trainers_3.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/dist_metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/dist_metric.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/dist_metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/dist_metric.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/evaluators.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/evaluators.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/evaluators.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/evaluators.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/trainers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/trainers.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/__pycache__/trainers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/__pycache__/trainers.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/datasets/RegDB.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | from ..utils.data import Dataset
4 | from ..utils.serialization import write_json
5 |
6 | class RegDB(Dataset):
7 | def __init__(self, root, split_id=0, ii=0, num_val=100, download=True):
8 | super(RegDB, self).__init__(root, split_id=split_id)
9 | self.ii = ii
10 | if download:
11 | self.download()
12 |
13 | self.load(num_val)
14 |
15 | def download(self):
16 | index_train_RGB = open('./data/RegDB/idx/train_visible_{}.txt'.format(self.ii),'r')
17 | index_train_IR = open('./data/RegDB/idx/train_thermal_{}.txt'.format(self.ii),'r')
18 | index_test_RGB = open('./data/RegDB/idx/test_visible_{}.txt'.format(self.ii),'r')
19 | index_test_IR = open('./data/RegDB/idx/test_thermal_{}.txt'.format(self.ii),'r')
20 |
21 | def loadIdx(index):
22 | Lines = index.readlines()
23 | idx = []
24 | for line in Lines:
25 | tmp = line.strip('\n')
26 | tmp = tmp.split(' ')
27 | idx.append(tmp)
28 | return idx
29 |
30 | index_train_RGB = loadIdx(index_train_RGB)
31 | index_train_IR = loadIdx(index_train_IR)
32 | index_test_RGB = loadIdx(index_test_RGB)
33 | index_test_IR = loadIdx(index_test_IR)
34 |
35 | # 412 identities with 3 camera views each
36 | identities = [[[] for _ in range(3)] for _ in range(412)]
37 | def insertToMeta(index, cam, delta):
38 | for idx in index:
39 | fname = osp.basename(idx[0])
40 |
41 | pid = int(idx[1]) + delta
42 |
43 | identities[pid][cam].append(fname)
44 |
45 | insertToMeta(index_train_RGB, 0, 0)
46 | insertToMeta(index_train_IR, 2, 0)
47 | insertToMeta(index_test_RGB, 0, 206)
48 | insertToMeta(index_test_IR, 2, 206)
49 |
50 | trainval_pids = set()
51 | gallery_pids = set()
52 | query_pids = set()
53 | for i in range(206):
54 | trainval_pids.add(i)
55 | gallery_pids.add(i + 206)
56 | query_pids.add(i + 206)
57 |
58 | # Save meta information into a json file
59 | meta = {'name': 'RegDB', 'shot': 'multiple', 'num_cameras': 3,
60 | 'identities': identities}
61 | write_json(meta, osp.join(self.root, 'meta.json'))
62 |
63 | # Save the only training / test split
64 | splits = [{
65 | 'trainval': sorted(list(trainval_pids)),
66 | 'query': sorted(list(query_pids)),
67 | 'gallery': sorted(list(gallery_pids))}]
68 | write_json(splits, osp.join(self.root, 'splits.json'))
69 |
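70 | # Layout note: identities[pid][cam] holds the image names of one identity under
71 | # one view; RegDB uses cam 0 for visible and cam 2 for thermal, with pids
72 | # 0-205 forming trainval and 206-411 the query/gallery split (see above).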
--------------------------------------------------------------------------------
/reid/datasets/RegDB.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/RegDB.pyc
--------------------------------------------------------------------------------
/reid/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import warnings
3 |
4 | from .sysu import SYSU
5 | from .RegDB import RegDB
6 |
7 |
8 | __factory = {
9 | 'sysu': SYSU,
10 | 'RegDB': RegDB,
11 | }
12 |
13 |
14 | def names():
15 | return sorted(__factory.keys())
16 |
17 |
18 | def create(name, root, *args, **kwargs):
19 | """
20 | Create a dataset instance.
21 |
22 | Parameters
23 | ----------
24 | name : str
25 | The dataset name. Can be one of the registered names,
26 | i.e. 'sysu' or 'RegDB'.
27 | root : str
28 | The path to the dataset directory.
29 | split_id : int, optional
30 | The index of data split. Default: 0
31 | num_val : int or float, optional
32 | When int, it means the number of validation identities. When float,
33 | it means the proportion of validation to all the trainval. Default: 100
34 | download : bool, optional
35 | If True, will prepare (format) the dataset from local files. Default: True
36 | """
37 | if name not in __factory:
38 | raise KeyError("Unknown dataset:", name)
39 | return __factory[name](root, *args, **kwargs)
40 |
41 |
42 | def get_dataset(name, root, *args, **kwargs):
43 | warnings.warn("get_dataset is deprecated. Use create instead.")
44 | return create(name, root, *args, **kwargs)
45 |
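46 | # Usage sketch (the data path is a placeholder); 'sysu' and 'RegDB' are the
47 | # only registered factory names:
48 | #   from reid import datasets
49 | #   dataset = datasets.create('sysu', './data/sysu', split_id=0)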
--------------------------------------------------------------------------------
/reid/datasets/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__init__.pyc
--------------------------------------------------------------------------------
/reid/datasets/__pycache__/RegDB.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/RegDB.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/datasets/__pycache__/RegDB.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/RegDB.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/datasets/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/datasets/__pycache__/sysu.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/sysu.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/datasets/__pycache__/sysu.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/__pycache__/sysu.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/datasets/sysu.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import numpy
4 | from ..utils.data import Dataset
5 | from ..utils.osutils import mkdir_if_missing
6 | from ..utils.serialization import write_json
7 |
8 |
9 | class SYSU(Dataset):
10 | def __init__(self, root, split_id=0, num_val=100, download=True):
11 | super(SYSU, self).__init__(root, split_id=split_id)
12 |
13 | self.root += "/SYSU-MM01"
14 |
15 | if download:
16 | self.download()
17 |
18 | if not self._check_integrity():
19 | raise RuntimeError("Dataset not found or corrupted. ")
20 |
21 | self.load(num_val)
22 |
23 | def download(self):
24 |
25 | import shutil
26 | from glob import glob
27 |
28 | # Format
29 | images_dir = osp.join(self.root+'/images')
30 | mkdir_if_missing(images_dir)
31 |
32 | # load the train/val/test id splits from the .mat files
33 | import scipy.io as scio
34 | data = scio.loadmat(self.root+'/exp/train_id.mat')
35 | train_id = data['id'][0]
36 | data = scio.loadmat(self.root+'/exp/val_id.mat')
37 | val_id = data['id'][0]
38 | data = scio.loadmat(self.root+'/exp/test_id.mat')
39 | test_id = data['id'][0]
40 |
41 | # 533 identities with 6 camera views each
42 | identities = [[[] for _ in range(6)] for _ in range(533)]
43 | for pid in range(1, 534):
44 | for cam in range(1,7):
45 | images_path = self.root+"/cam"+str(cam)+"/"+str(pid).zfill(4)
46 | fpaths = sorted(glob(images_path+"/*.jpg"))
47 | for fpath in fpaths:
48 | # print(fpath)
49 | fname = ('{:08d}_{:02d}_{:04d}.jpg'
50 | .format(pid-1, cam-1, len(identities[pid-1][cam-1])))
51 | identities[pid-1][cam-1].append(fname)
52 | shutil.copy(fpath, osp.join(images_dir, fname))
53 |
54 | trainval_pids = set()
55 | gallery_pids = set()
56 | query_pids = set()
57 | train_val_ = numpy.concatenate((train_id,val_id))
58 | for i in (train_val_):
59 | trainval_pids.add(int(i) - 1)
60 | for i in test_id:
61 | gallery_pids.add(int(i) - 1)
62 | query_pids.add(int(i) - 1)
63 |
64 | # Save meta information into a json file
65 | meta = {'name': 'sysu', 'shot': 'multiple', 'num_cameras': 6,
66 | 'identities': identities}
67 | write_json(meta, osp.join(self.root, 'meta.json'))
68 |
69 | # Save the only training / test split
70 | splits = [{
71 | 'trainval': sorted(list(trainval_pids)),
72 | 'query': sorted(list(query_pids)),
73 | 'gallery': sorted(list(gallery_pids))}]
74 | write_json(splits, osp.join(self.root, 'splits.json'))
75 |
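76 | # Naming note: copied images follow '{pid:08d}_{cam:02d}_{index:04d}.jpg', so
77 | # identity and camera indices can be recovered from a file name alone.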
--------------------------------------------------------------------------------
/reid/datasets/sysu.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/datasets/sysu.pyc
--------------------------------------------------------------------------------
/reid/dist_metric.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 |
5 | from .evaluators import extract_features
6 | from .metric_learning import get_metric
7 |
8 |
9 | class DistanceMetric(object):
10 | def __init__(self, algorithm='euclidean', *args, **kwargs):
11 | super(DistanceMetric, self).__init__()
12 | self.algorithm = algorithm
13 | self.metric = get_metric(algorithm, *args, **kwargs)
14 |
15 | def train(self, model, data_loader):
16 | if self.algorithm == 'euclidean': return
17 | features, labels, _ = extract_features(model, data_loader)
18 | features = torch.stack(list(features.values())).numpy()
19 | labels = torch.Tensor(list(labels.values())).numpy()
20 | self.metric.fit(features, labels)
21 |
22 | def transform(self, X):
23 | if torch.is_tensor(X):
24 | X = X.numpy()
25 | X = self.metric.transform(X)
26 | X = torch.from_numpy(X)
27 | else:
28 | X = self.metric.transform(X)
29 | return X
30 |
31 |
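32 | # Usage sketch: with the default 'euclidean' algorithm train() returns
33 | # immediately; any other algorithm first fits a metric-learning transform on
34 | # features extracted from the given data loader, e.g.:
35 | #   metric = DistanceMetric(algorithm='euclidean')
36 | #   metric.train(model, train_loader)   # no-op for 'euclidean'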
--------------------------------------------------------------------------------
/reid/dist_metric.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/dist_metric.pyc
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .ranking import cmc, mean_ap
4 |
5 | __all__ = [
6 | 'cmc',
7 | 'mean_ap',
8 | ]
9 |
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__init__.pyc
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__pycache__/ranking.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/ranking.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__pycache__/ranking.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/ranking.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__pycache__/ranking.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/__pycache__/ranking.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/evaluation_metrics/ranking.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | from sklearn.metrics import average_precision_score
6 |
7 | from ..utils import to_numpy
8 |
9 |
10 | def _unique_sample(ids_dict, num):
11 | mask = np.zeros(num, dtype=bool)  # np.bool was removed in NumPy >= 1.24
12 | for _, indices in ids_dict.items():
13 | i = np.random.choice(indices)
14 | mask[i] = True
15 | return mask
16 |
17 |
18 | def cmc(distmat, query_ids=None, gallery_ids=None,
19 | query_cams=None, gallery_cams=None, topk=100,
20 | separate_camera_set=False,
21 | single_gallery_shot=False,
22 | first_match_break=False):
23 |
24 | distmat = to_numpy(distmat)
25 | m, n = distmat.shape
26 |
27 | # Fill up default values
28 | if query_ids is None:
29 | query_ids = np.arange(m)
30 | if gallery_ids is None:
31 | gallery_ids = np.arange(n)
32 | if query_cams is None:
33 | query_cams = np.zeros(m).astype(np.int32)
34 | if gallery_cams is None:
35 | gallery_cams = np.ones(n).astype(np.int32)
36 |
37 | # Ensure numpy array
38 | query_ids = np.asarray(query_ids)
39 | gallery_ids = np.asarray(gallery_ids)
40 | query_cams = np.asarray(query_cams)
41 | gallery_cams = np.asarray(gallery_cams)
42 |
43 | # Sort and find correct matches
44 | indices = np.argsort(distmat, axis=1)
45 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
46 |
47 | # Compute CMC for each query
48 | ret = np.zeros(topk)
49 | num_valid_queries = 0
50 | for i in range(m):
51 | # Filter out the same id and same camera
52 | valid = ((gallery_ids[indices[i]] != query_ids[i]) |
53 | (gallery_cams[indices[i]] != query_cams[i]))
54 | if separate_camera_set:
55 | # Filter out samples from same camera
56 | valid &= (gallery_cams[indices[i]] != query_cams[i])
57 | if not np.any(matches[i, valid]): continue
58 | if single_gallery_shot:
59 | repeat = 10
60 | gids = gallery_ids[indices[i][valid]]
61 | inds = np.where(valid)[0]
62 | ids_dict = defaultdict(list)
63 | for j, x in zip(inds, gids):
64 | ids_dict[x].append(j)
65 | else:
66 | repeat = 1
67 | for _ in range(repeat):
68 | if single_gallery_shot:
69 | # Randomly choose one instance for each id
70 | sampled = (valid & _unique_sample(ids_dict, len(valid)))
71 | index = np.nonzero(matches[i, sampled])[0]
72 | else:
73 | index = np.nonzero(matches[i, valid])[0]
74 | delta = 1. / (len(index) * repeat)
75 | for j, k in enumerate(index):
76 | if k - j >= topk: break
77 | if first_match_break:
78 | ret[k - j] += 1
79 | break
80 | ret[k - j] += delta
81 | num_valid_queries += 1
82 | if num_valid_queries == 0:
83 | raise RuntimeError("No valid query")
84 | return ret.cumsum() / num_valid_queries
85 |
86 |
87 | def mean_ap(distmat, query_ids=None, gallery_ids=None,
88 | query_cams=None, gallery_cams=None):
89 | distmat = to_numpy(distmat)
90 | m, n = distmat.shape
91 | # Fill up default values
92 | if query_ids is None:
93 | query_ids = np.arange(m)
94 | if gallery_ids is None:
95 | gallery_ids = np.arange(n)
96 | if query_cams is None:
97 | query_cams = np.zeros(m).astype(np.int32)
98 | if gallery_cams is None:
99 | gallery_cams = np.ones(n).astype(np.int32)
100 | # Ensure numpy array
101 | query_ids = np.asarray(query_ids)
102 | gallery_ids = np.asarray(gallery_ids)
103 | query_cams = np.asarray(query_cams)
104 | gallery_cams = np.asarray(gallery_cams)
105 | # Sort and find correct matches
106 | indices = np.argsort(distmat, axis=1)
107 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
108 | # Compute AP for each query
109 | aps = []
110 | for i in range(m):
111 | # Filter out the same id and same camera
112 | valid = ((gallery_ids[indices[i]] != query_ids[i]) |
113 | (gallery_cams[indices[i]] != query_cams[i]))
114 | y_true = matches[i, valid]
115 | y_score = -distmat[i][indices[i]][valid]
116 | if not np.any(y_true): continue
117 | aps.append(average_precision_score(y_true, y_score))
118 | if len(aps) == 0:
119 | raise RuntimeError("No valid query")
120 | return np.mean(aps)
121 |
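122 | if __name__ == '__main__':
123 |     # Self-contained sanity sketch with synthetic scores (illustrative only,
124 |     # not part of the evaluation protocol):
125 |     demo = np.random.rand(4, 10)
126 |     print('Rank-1: {:.2%}'.format(cmc(demo, first_match_break=True)[0]))
127 |     print('mAP:    {:.2%}'.format(mean_ap(demo)))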
--------------------------------------------------------------------------------
/reid/evaluation_metrics/ranking.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluation_metrics/ranking.pyc
--------------------------------------------------------------------------------
/reid/evaluators.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | from collections import OrderedDict
4 |
5 | import torch
6 |
7 | from .evaluation_metrics import cmc, mean_ap
8 | from .feature_extraction import extract_cnn_feature
9 | from .utils.meters import AverageMeter
10 |
11 |
12 | def extract_features(model, data_loader):
13 | model.eval()
14 | batch_time = AverageMeter()
15 | data_time = AverageMeter()
16 |
17 | features = OrderedDict()
18 | labels = OrderedDict()
19 | filenames = []
20 |
21 | end = time.time()
22 | for i, (imgs, fnames, pids, cams) in enumerate(data_loader):
23 | data_time.update(time.time() - end)
24 |
25 | subs = ((cams == 2).long() + (cams == 5).long()).cuda()
26 | outputs = extract_cnn_feature(model=model, inputs=imgs, sub=subs)
27 | for fname, output, pid in zip(fnames, outputs, pids):
28 | features[fname] = output
29 | labels[fname] = pid
30 | filenames.append(fname)
31 |
32 | batch_time.update(time.time() - end)
33 | end = time.time()
34 |
35 |
36 | return features, labels, filenames
37 |
38 |
39 | def pairwise_distance(features1, features2, fnames1=None, fnames2=None, metric=None):
40 | x = torch.cat([features1[f].unsqueeze(0) for f in fnames1], 0)
41 | y = torch.cat([features2[f].unsqueeze(0) for f in fnames2], 0)
42 |
43 | m, n = x.size(0), y.size(0)
44 | x = x.view(m, -1)
45 | y = y.view(n, -1)
46 |
47 | # normalize
48 | x = torch.nn.functional.normalize(x, dim=1, p=2)
49 | y = torch.nn.functional.normalize(y, dim=1, p=2)
50 |
51 | if metric is not None:
52 | x = metric.transform(x)
53 | y = metric.transform(y)
54 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
55 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
56 | dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = ||x||^2 + ||y||^2 - 2 * x @ y.T; keyword form replaces the deprecated positional addmm_ signature
57 | return dist
58 |
59 |
60 | def evaluate_all(distmat, labels1, labels2, fnames1, fnames2, flag, cmc_topk=(1, 10, 20)):
61 | query_ids = [labels1[f] for f in fnames1]
62 | gallery_ids = [labels2[f] for f in fnames2]
63 | query_cams = [0 for f in fnames1]
64 | gallery_cams = [2 for f in fnames2]
65 |
66 | """Compute mean AP"""
67 |
68 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
69 | print('Mean AP: {:4.2%}'.format(mAP))
70 | if flag: # return mAP only
71 | return mAP
72 |
73 | """Compute all kinds of CMC scores"""
74 |
75 | cmc_configs = {
76 | 'MM01': dict(separate_camera_set=False,
77 | single_gallery_shot=False,
78 | first_match_break=True)}
79 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, query_cams, gallery_cams, **params)
80 | for name, params in cmc_configs.items()}
81 | # name:MM01
82 | # params{'separate_camera_set': False, 'single_gallery_shot': False, 'first_match_break': True}
83 | print('CMC Scores{:>12}'.format('MM01'))
84 | for k in cmc_topk:
85 | print(' top-{:<4}{:12.2%}'.format(k, cmc_scores['MM01'][k - 1]))
86 | # using Rank-1 as the main evaluation criterion
87 | return cmc_scores['MM01'][0]
88 |
89 |
90 | class Evaluator(object):
91 | def __init__(self, model):
92 | super(Evaluator, self).__init__()
93 | self.model = model
94 |
95 | def evaluate(self, data_loader1, data_loader2, metric=None, flag=False):
96 | features1, labels1, fnames1 = extract_features(self.model, data_loader1)
97 | features2, labels2, fnames2 = extract_features(self.model, data_loader2)
98 | distmat = pairwise_distance(features1, features2, fnames1, fnames2, metric=metric)
99 | return evaluate_all(distmat, labels1, labels2, fnames1, fnames2, flag)
100 |
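The distance computation in `pairwise_distance` expands the squared Euclidean distance as ||x||^2 + ||y||^2 - 2·x·yᵀ; because both sides are L2-normalized first, this equals 2 - 2·(cosine similarity). A minimal self-contained sketch on toy tensors (the keyword form of `addmm_` below is the non-deprecated equivalent of the call above):

```python
import torch

# Toy stand-ins for query/gallery features; the real code uses model outputs.
x = torch.nn.functional.normalize(torch.randn(4, 8), dim=1, p=2)
y = torch.nn.functional.normalize(torch.randn(5, 8), dim=1, p=2)

m, n = x.size(0), y.size(0)
dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
       torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = ||x||^2 + ||y||^2 - 2 * x @ y.T

# With unit-norm rows this is 2 - 2 * cosine similarity:
assert torch.allclose(dist, 2 - 2 * x @ y.t(), atol=1e-5)
```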
--------------------------------------------------------------------------------
/reid/evaluators.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/evaluators.pyc
--------------------------------------------------------------------------------
/reid/evaluators_regdb.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | from collections import OrderedDict
4 |
5 | import torch
6 | import numpy as np
7 | from .evaluation_metrics import cmc, mean_ap
8 | from .feature_extraction import extract_cnn_feature
9 | from .utils.meters import AverageMeter
10 |
11 | def extract_features(model, data_loader):
12 | model.eval()
13 | batch_time = AverageMeter()
14 | data_time = AverageMeter()
15 |
16 | features = OrderedDict()
17 | labels = OrderedDict()
18 | filenames = []
19 |
20 | end = time.time()
21 | for i, (imgs, fnames, pids, cams) in enumerate(data_loader):
22 | data_time.update(time.time() - end)
23 |
24 | subs = ((cams == 2).long() + (cams == 5).long()).cuda()
25 | outputs = extract_cnn_feature(model, imgs, subs)
26 | for fname, output, pid in zip(fnames, outputs, pids):
27 | features[fname] = output
28 | labels[fname] = pid
29 | filenames.append(fname)
30 |
31 | batch_time.update(time.time() - end)
32 | end = time.time()
33 |
34 | return features, labels, filenames
35 |
36 |
37 | def pairwise_distance(features1, features2, fnames1=None, fnames2=None, metric=None):
38 |
39 | x = torch.cat([features1[f].unsqueeze(0) for f in fnames1], 0)
40 | y = torch.cat([features2[f].unsqueeze(0) for f in fnames2], 0)
41 |
42 | m, n = x.size(0), y.size(0)
43 | x = x.view(m, -1)
44 | y = y.view(n, -1)
45 |
46 | # normalize
47 | x = torch.nn.functional.normalize(x, dim=1, p=2)
48 | y = torch.nn.functional.normalize(y, dim=1, p=2)
49 |
50 | if metric is not None:
51 | x = metric.transform(x)
52 | y = metric.transform(y)
53 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
54 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
55 | dist.addmm_(1, -2, x, y.t())
56 | return dist
57 |
58 |
59 | def evaluate_all(distmat, labels1, labels2, fnames1, fnames2, flag, cmc_topk=(1, 10, 20)):
60 | query_ids = [labels1[f] for f in fnames1]
61 | gallery_ids = [labels2[f] for f in fnames2]
62 | query_cams = [0 for f in fnames1]
63 | gallery_cams = [2 for f in fnames2]
64 |
65 | # Compute mean AP
66 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
67 | print('Mean AP: {:4.2%}'.format(mAP))
68 |
69 | # return mAP
70 | if flag:
71 | return mAP
72 | # Compute all kinds of CMC scores
73 | cmc_configs = {
74 | 'RegDB': dict(separate_camera_set=False,
75 | single_gallery_shot=False,
76 | first_match_break=True)}
77 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
78 | query_cams, gallery_cams, **params)
79 | for name, params in cmc_configs.items()}
80 |
81 |     print('CMC Scores{:>12}'.format('RegDB'))
82 |
83 |     for k in cmc_topk:
84 |         print(' top-{:<4}{:12.2%}'
85 |               .format(k, cmc_scores['RegDB'][k - 1]))
86 |
87 |
88 | # Use the allshots cmc top-1 score for validation criterion
89 | return cmc_scores['RegDB'][0]
90 |
91 |
92 | def eval_regdb(distmat, labels1, labels2, fnames1, fnames2, max_rank = 20, cmc_topk=(1, 10, 20)):
93 | q_pids = [labels1[f].numpy() for f in fnames1]
94 | g_pids = [labels2[f].numpy() for f in fnames2]
95 |     q_pids = np.array(q_pids)
96 |     g_pids = np.array(g_pids)
97 |
98 |
99 | num_q, num_g = distmat.shape
100 | if num_g < max_rank:
101 | max_rank = num_g
102 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
103 | indices = np.argsort(distmat, axis=1)
104 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
105 |
106 | # compute cmc curve for each query
107 | all_cmc = []
108 | all_AP = []
109 | all_INP = []
110 | num_valid_q = 0. # number of valid query
111 |
112 | # only two cameras
113 | q_camids = np.ones(num_q).astype(np.int32)
114 |     g_camids = 2 * np.ones(num_g).astype(np.int32)
115 |
116 | for q_idx in range(num_q):
117 | # get query pid and camid
118 | q_pid = q_pids[q_idx]
119 | q_camid = q_camids[q_idx]
120 |
121 | # remove gallery samples that have the same pid and camid with query
122 | order = indices[q_idx]
123 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
124 | keep = np.invert(remove)
125 |
126 | # compute cmc curve
127 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
128 | if not np.any(raw_cmc):
129 | # this condition is true when query identity does not appear in gallery
130 | continue
131 |
132 | cmc = raw_cmc.cumsum()
133 |
134 | # compute mINP
135 |     # reference: "Deep Learning for Person Re-identification: A Survey and Outlook"
136 |     pos_idx = np.where(raw_cmc == 1)
137 |     pos_max_idx = np.max(pos_idx)
138 |     inp = cmc[pos_max_idx] / (pos_max_idx + 1.0)
139 | all_INP.append(inp)
140 | cmc[cmc > 1] = 1
141 | all_cmc.append(cmc[:max_rank])
142 | num_valid_q += 1.
143 | num_rel = raw_cmc.sum()
144 | tmp_cmc = raw_cmc.cumsum()
145 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
146 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
147 | AP = tmp_cmc.sum() / num_rel
148 | all_AP.append(AP)
149 |
150 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
151 |
152 | all_cmc = np.asarray(all_cmc).astype(np.float32)
153 | all_cmc = all_cmc.sum(0) / num_valid_q
154 | mAP = np.mean(all_AP)
155 | mINP = np.mean(all_INP)
156 |
157 | print('Mean AP: {:4.2%}'.format(mAP))
158 |     print('CMC Scores{:>12}'.format('RegDB'))
159 |
160 |     for k in cmc_topk:
161 |         print(' top-{:<4}{:12.2%}'
162 |               .format(k, all_cmc[k - 1]))
163 |
164 | return all_cmc[0], all_cmc, mAP
165 |
166 |
167 | class Evaluator(object):
168 | def __init__(self, model, regdb=True):
169 | super(Evaluator, self).__init__()
170 | self.model = model
171 |         self.regdb = regdb
172 |
173 | def evaluate(self, data_loader1, data_loader2, metric=None, flag=False):
174 | features1, labels1, fnames1 = extract_features(model=self.model, data_loader=data_loader1)
175 | features2, labels2, fnames2 = extract_features(model=self.model, data_loader=data_loader2)
176 | distmat = pairwise_distance(features1, features2, fnames1, fnames2, metric=metric)
177 |
178 | if self.regdb:
179 | top1, all_cmc, mAP= eval_regdb(distmat, labels1, labels2, fnames1, fnames2)
180 | return top1, all_cmc, mAP
181 | else:
182 | return evaluate_all(distmat, labels1, labels2, fnames1, fnames2, flag)
183 |
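As a sanity check on the per-query average precision computed inside `eval_regdb`, here is the same arithmetic on a hand-made ranking; `[1, 0, 1]` means the first- and third-ranked gallery items are correct matches (toy values, not dataset results):

```python
import numpy as np

raw_cmc = np.array([1, 0, 1])        # binary match vector for one query
num_rel = raw_cmc.sum()              # 2 relevant gallery items
tmp_cmc = raw_cmc.cumsum()           # [1, 1, 2]
prec = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]  # precision@k: [1.0, 0.5, 0.667]
AP = (np.asarray(prec) * raw_cmc).sum() / num_rel     # (1.0 + 0.667) / 2 = 0.833
print(round(AP, 3))
```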
--------------------------------------------------------------------------------
/reid/feature_extraction/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .cnn import extract_cnn_feature
4 |
5 | __all__ = [
6 | 'extract_cnn_feature'
7 | ]
8 |
--------------------------------------------------------------------------------
/reid/feature_extraction/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__init__.pyc
--------------------------------------------------------------------------------
/reid/feature_extraction/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/feature_extraction/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/feature_extraction/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/feature_extraction/__pycache__/cnn.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/cnn.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/feature_extraction/__pycache__/cnn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/cnn.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/feature_extraction/__pycache__/cnn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/__pycache__/cnn.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/feature_extraction/cnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import OrderedDict
3 |
4 | import torch
5 |
6 | from ..utils import to_torch
7 |
8 | def extract_cnn_feature(model, inputs, sub, modules=None):
9 | model.eval()
10 | inputs = to_torch(inputs)
11 | inputs = inputs.cuda()
12 | with torch.no_grad():
13 | if modules is None:
14 | # whether "modules" is None or not, this function is used to extract feature from a certain dataset.
15 | #_, outputs = model(inputs)
16 | i_observation, i_representation, i_ms_observation, i_ms_representation, \
17 | v_observation, v_representation, v_ms_observation, v_ms_representation = model(inputs)
18 | outputs = torch.cat(tensors=[i_observation[1], i_ms_observation[1], i_representation[1], i_ms_representation[1],
19 | v_observation[1], v_ms_observation[1], v_representation[1], v_ms_representation[1]], dim=1)
20 | outputs = outputs.data.cpu()
21 | return outputs
22 | # Register forward hook for each module
23 | outputs = OrderedDict()
24 | handles = []
25 | for m in modules:
26 | outputs[id(m)] = None
27 | def func(m, i, o): outputs[id(m)] = o.data.cpu()
28 | handles.append(m.register_forward_hook(func))
29 | model(inputs)
30 | for h in handles:
31 | h.remove()
32 | return list(outputs.values())
33 |
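For reference, the concatenation in `extract_cnn_feature` stacks eight per-branch vectors along the channel dimension. A shape-only sketch, assuming the defaults used elsewhere in this repository (2048-d observations and, with `z_dim=256`, 512-d representations):

```python
import torch

batch = 4
i_obs, i_ms_obs, v_obs, v_ms_obs = (torch.randn(batch, 2048) for _ in range(4))
i_rep, i_ms_rep, v_rep, v_ms_rep = (torch.randn(batch, 512) for _ in range(4))

# Same ordering as in extract_cnn_feature above.
outputs = torch.cat([i_obs, i_ms_obs, i_rep, i_ms_rep,
                     v_obs, v_ms_obs, v_rep, v_ms_rep], dim=1)
print(outputs.shape)  # torch.Size([4, 10240]) = 4 * 2048 + 4 * 512
```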
--------------------------------------------------------------------------------
/reid/feature_extraction/cnn.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/feature_extraction/cnn.pyc
--------------------------------------------------------------------------------
/reid/metric_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 |
4 | from .euclidean import Euclidean
5 |
6 | __factory = {
7 | 'euclidean': Euclidean
8 | }
9 |
10 |
11 | def get_metric(algorithm, *args, **kwargs):
12 | if algorithm not in __factory:
13 | raise KeyError("Unknown metric:", algorithm)
14 | return __factory[algorithm](*args, **kwargs)
15 |
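A minimal usage sketch of the factory; `'euclidean'` is the only registered algorithm, and its `transform` is the identity:

```python
import numpy as np
from reid.metric_learning import get_metric

metric = get_metric('euclidean')
X = np.random.randn(6, 4)
metric.fit(X)                               # stores an identity Mahalanobis matrix
assert np.allclose(metric.transform(X), X)  # Euclidean transform leaves X unchanged
# get_metric('cosine') would raise KeyError: only 'euclidean' is registered.
```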
--------------------------------------------------------------------------------
/reid/metric_learning/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__init__.pyc
--------------------------------------------------------------------------------
/reid/metric_learning/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/metric_learning/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/metric_learning/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/metric_learning/__pycache__/euclidean.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/euclidean.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/metric_learning/__pycache__/euclidean.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/__pycache__/euclidean.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/metric_learning/euclidean.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from numpy.linalg import cholesky
3 | from sklearn.base import BaseEstimator, TransformerMixin
4 | from sklearn.utils.validation import check_array
5 | import numpy as np
6 |
7 | class BaseMetricLearner(BaseEstimator, TransformerMixin):
8 | def __init__(self):
9 | raise NotImplementedError('BaseMetricLearner should not be instantiated')
10 |
11 | def metric(self):
12 | """Computes the Mahalanobis matrix from the transformation matrix.
13 |
14 | .. math:: M = L^{\\top} L
15 |
16 | Returns
17 | -------
18 | M : (d x d) matrix
19 | """
20 | L = self.transformer()
21 | return L.T.dot(L)
22 |
23 | def transformer(self):
24 | """Computes the transformation matrix from the Mahalanobis matrix.
25 |
26 | L = cholesky(M).T
27 |
28 | Returns
29 | -------
30 | L : upper triangular (d x d) matrix
31 | """
32 | return cholesky(self.metric()).T
33 |
34 | def transform(self, X=None):
35 | """Applies the metric transformation.
36 |
37 | Parameters
38 | ----------
39 | X : (n x d) matrix, optional
40 | Data to transform. If not supplied, the training data will be used.
41 |
42 | Returns
43 | -------
44 | transformed : (n x d) matrix
45 | Input data transformed to the metric space by :math:`XL^{\\top}`
46 | """
47 | if X is None:
48 | X = self.X_
49 | else:
50 | X = check_array(X, accept_sparse=True)
51 | L = self.transformer()
52 | return X.dot(L.T)
53 |
54 |
55 |
56 | class Euclidean(BaseMetricLearner):
57 | def __init__(self):
58 | self.M_ = None
59 |
60 | def metric(self):
61 | return self.M_
62 |
63 | def fit(self, X):
64 | self.M_ = np.eye(X.shape[1])
65 | self.X_ = X
66 |
67 | def transform(self, X=None):
68 | if X is None:
69 | return self.X_
70 | return X
71 |
72 | def get_metric(self):
73 | pass
74 | def score_pairs(self):
75 | pass
76 |
77 |
78 | a = Euclidean()  # instantiated at import time; the instance is unused in this module
--------------------------------------------------------------------------------
/reid/metric_learning/euclidean.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/metric_learning/euclidean.pyc
--------------------------------------------------------------------------------
/reid/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .newresnet import *
4 |
5 | __factory = {
6 | 'ft_net': ft_net
7 | }
8 |
9 | def names():
10 | return sorted(__factory.keys())
11 |
12 | def create(name, *args, **kwargs):
13 | if name not in __factory:
14 | raise KeyError("Unknown model:", name)
15 | return __factory[name](*args, **kwargs)
16 |
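A usage sketch for the model factory, assuming an `argparse`-style namespace carrying the fields `ft_net` reads (`z_dim` at construction; `lr` only inside `optims()`). Run it from the repository root so the `utlis` module resolves; the constructor builds three ResNet-50 backbones with `pretrained=True`, so the first run downloads ImageNet weights:

```python
from argparse import Namespace
from reid import models

args = Namespace(z_dim=256, lr=3e-4)  # placeholder values, not the paper's settings
model = models.create('ft_net', args, num_classes=395, num_features=2048)
print(models.names())  # ['ft_net']
```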
--------------------------------------------------------------------------------
/reid/models/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__init__.pyc
--------------------------------------------------------------------------------
/reid/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/models/__pycache__/baseline.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/baseline.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/models/__pycache__/newresnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/newresnet.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/models/__pycache__/newresnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/__pycache__/newresnet.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/models/baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 |
5 | try:
6 | from torch.hub import load_state_dict_from_url
7 | except ImportError:
8 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
9 |
10 | ##################################################################################
11 | # Initialization function
12 | ##################################################################################
13 | def weights_init_kaiming(m):
14 | classname = m.__class__.__name__
15 | if classname.find('Linear') != -1:
16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
17 | nn.init.constant_(m.bias, 0.0)
18 | elif classname.find('Conv') != -1:
19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
20 | if m.bias is not None:
21 | nn.init.constant_(m.bias, 0.0)
22 | elif classname.find('BatchNorm') != -1:
23 | if m.affine:
24 | nn.init.constant_(m.weight, 1.0)
25 | nn.init.constant_(m.bias, 0.0)
26 |
27 |
28 | def weights_init_Classifier(m):
29 | classname = m.__class__.__name__
30 | if classname.find('Linear') != -1:
31 | nn.init.normal_(m.weight, std=0.001)
32 |         if m.bias is not None:
33 | nn.init.constant_(m.bias, 0.0)
34 |
35 | def weights_init_classifier(m):
36 | classname = m.__class__.__name__
37 | if classname.find('Linear') != -1:
38 | nn.init.normal_(m.weight.data, std=0.001)
39 | nn.init.constant_(m.bias.data, 0.0)
40 |
41 | ##################################################################################
42 | # Encoder used in framework
43 | ##################################################################################
44 | class Baseline(nn.Module):
45 | in_planes = 2048
46 | def __init__(self, num_classes, num_features):
47 | super(Baseline, self).__init__()
48 | #backbone = resnet50(pretrained= True)
49 | self.cam_256 = ShallowCAM(256)
50 | self.cam_512 = ShallowCAM(512)
51 | self.cam_1024 = ShallowCAM(1024)
52 |         self.base = resnet50(pretrained=True)
53 | self.gap = nn.AdaptiveAvgPool2d(1)
54 |
55 | self.num_classes = num_classes
56 | self.in_planes = num_features
57 |
58 | ## use a neck to produce triplet feature and score
59 | self.bottleneck = nn.BatchNorm1d(self.in_planes)
60 | self.bottleneck.bias.requires_grad_(False)
61 |
62 | classifier = [nn.Linear(num_features, num_classes)]
63 | classifier = nn.Sequential(*classifier)
64 | self.classifier = classifier
65 | self.bottleneck.apply(weights_init_kaiming)
66 | self.classifier.apply(weights_init_classifier)
67 |
68 | def forward(self, x):
69 |
70 | for name, module in self.base._modules.items():
71 | if name == 'avgpool':
72 | break
73 | if name == 'layer2':
74 | x = self.cam_256(x)
75 | if name == 'layer3':
76 | x = self.cam_512(x)
77 | if name == 'layer4':
78 | x = self.cam_1024(x)
79 | x = module(x)
80 | global_feat = self.gap(x) # (b, 2048, 1, 1)
81 |
82 | global_feat = global_feat.view(global_feat.shape[0], -1)
83 | feat1 = self.bottleneck(global_feat)
84 |
85 | cls_score = self.classifier(feat1)
86 |
87 | if self.training:
88 | return cls_score, feat1.view(feat1.shape[0], -1)
89 | else:
90 | return global_feat, feat1
91 | ##################################################################################
92 | # Shallow Cam Module
93 | ##################################################################################
94 |
95 | class CAM_Module(nn.Module):
96 | """ Channel attention module"""
97 |
98 | def __init__(self, in_dim):
99 | super(CAM_Module,self).__init__()
100 | self.channel_in = in_dim
101 |
102 | self.gamma = Parameter(torch.zeros(1))
103 | self.softmax = torch.nn.Softmax(dim=-1)
104 |
105 | def forward(self, x):
106 | """
107 | inputs :
108 | x : input feature maps( B X C X H X W)
109 | returns :
110 | out : attention value + input feature
111 | attention: B X C X C
112 | """
113 | m_batchsize, C, height, width = x.size()
114 | proj_query = x.view(m_batchsize, C, -1)
115 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
116 | energy = torch.bmm(proj_query, proj_key)
117 | max_energy_0 = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)
118 | energy_new = max_energy_0 - energy
119 | attention = self.softmax(energy_new)
120 | proj_value = x.view(m_batchsize, C, -1)
121 |
122 | out = torch.bmm(attention, proj_value)
123 | out = out.view(m_batchsize, C, height, width)
124 |
125 | gamma = self.gamma.to(out.device)
126 | out = gamma * out + x
127 | return out
128 |
129 | class ShallowCAM(nn.Module):
130 |
131 | def __init__(self, feature_dim):
132 |
133 | super(ShallowCAM,self).__init__()
134 | self.input_feature_dim = feature_dim
135 | self._cam_module = CAM_Module(self.input_feature_dim)
136 |
137 | def forward(self, x):
138 | x = self._cam_module(x)
139 |
140 | return x
141 |
142 | class bnneck(nn.Module):
143 | def __init__(self, indim):
144 | super(bnneck,self).__init__()
145 | self.indim = indim
146 | self.bn = nn.BatchNorm1d(self.indim)
147 | self.bn.bias.requires_grad_(False)
148 | self.bn.apply(weights_init_kaiming)
149 | def forward(self, x):
150 | x = self.bn(x)
151 | return x
152 |
153 | ##################################################################################
154 | # ResNet
155 | ##################################################################################
156 | __all__ = ['resnet50']
157 |
158 | model_urls = {
159 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
160 | }
161 |
162 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
163 | """3x3 convolution with padding"""
164 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
165 | padding=dilation, groups=groups, bias=False, dilation=dilation)
166 |
167 |
168 | def conv1x1(in_planes, out_planes, stride=1):
169 | """1x1 convolution"""
170 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
171 |
172 |
173 | class BasicBlock(nn.Module):
174 | expansion = 1
175 | __constants__ = ['downsample']
176 |
177 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
178 | base_width=64, dilation=1, norm_layer=None):
179 | super(BasicBlock, self).__init__()
180 | if norm_layer is None:
181 | norm_layer = nn.BatchNorm2d
182 | if groups != 1 or base_width != 64:
183 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
184 | if dilation > 1:
185 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
186 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
187 | self.conv1 = conv3x3(inplanes, planes, stride)
188 | self.bn1 = norm_layer(planes)
189 | self.relu = nn.ReLU(inplace=True)
190 | self.conv2 = conv3x3(planes, planes)
191 | self.bn2 = norm_layer(planes)
192 | self.downsample = downsample
193 | self.stride = stride
194 |
195 | def forward(self, x):
196 | identity = x
197 |
198 | out = self.conv1(x)
199 | out = self.bn1(out)
200 | out = self.relu(out)
201 |
202 | out = self.conv2(out)
203 | out = self.bn2(out)
204 |
205 | if self.downsample is not None:
206 | identity = self.downsample(x)
207 |
208 | out += identity
209 | out = self.relu(out)
210 |
211 | return out
212 |
213 |
214 | class Bottleneck(nn.Module):
215 | expansion = 4
216 | __constants__ = ['downsample']
217 |
218 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
219 | base_width=64, dilation=1, norm_layer=None):
220 | super(Bottleneck, self).__init__()
221 | if norm_layer is None:
222 | norm_layer = nn.BatchNorm2d
223 | width = int(planes * (base_width / 64.)) * groups
224 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
225 | self.conv1 = conv1x1(inplanes, width)
226 | self.bn1 = norm_layer(width)
227 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
228 | self.bn2 = norm_layer(width)
229 | self.conv3 = conv1x1(width, planes * self.expansion)
230 | self.bn3 = norm_layer(planes * self.expansion)
231 | self.relu = nn.ReLU(inplace=True)
232 | self.downsample = downsample
233 | self.stride = stride
234 |
235 | def forward(self, x):
236 | identity = x
237 |
238 | out = self.conv1(x)
239 | out = self.bn1(out)
240 | out = self.relu(out)
241 |
242 | out = self.conv2(out)
243 | out = self.bn2(out)
244 | out = self.relu(out)
245 |
246 | out = self.conv3(out)
247 | out = self.bn3(out)
248 |
249 | if self.downsample is not None:
250 | identity = self.downsample(x)
251 |
252 | out += identity
253 | out = self.relu(out)
254 |
255 | return out
256 |
257 |
258 | class ResNet(nn.Module):
259 |
260 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
261 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
262 | norm_layer=None):
263 | super(ResNet, self).__init__()
264 | if norm_layer is None:
265 | norm_layer = nn.BatchNorm2d
266 | self._norm_layer = norm_layer
267 |
268 | self.inplanes = 64
269 | self.dilation = 1
270 | if replace_stride_with_dilation is None:
271 | # each element in the tuple indicates if we should replace
272 | # the 2x2 stride with a dilated convolution instead
273 | replace_stride_with_dilation = [False, False, False]
274 | if len(replace_stride_with_dilation) != 3:
275 | raise ValueError("replace_stride_with_dilation should be None "
276 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
277 | self.groups = groups
278 | self.base_width = width_per_group
279 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
280 | bias=False)
281 | self.bn1 = norm_layer(self.inplanes)
282 | self.relu = nn.ReLU(inplace=True)
283 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
284 | self.layer1 = self._make_layer(block, 64, layers[0])
285 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
286 | dilate=replace_stride_with_dilation[0])
287 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
288 | dilate=replace_stride_with_dilation[1])
289 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
290 | dilate=replace_stride_with_dilation[2])
291 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
292 | self.fc = nn.Linear(512 * block.expansion, num_classes)
293 |
294 | for m in self.modules():
295 | if isinstance(m, nn.Conv2d):
296 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
297 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
298 | nn.init.constant_(m.weight, 1)
299 | nn.init.constant_(m.bias, 0)
300 |
301 | # Zero-initialize the last BN in each residual branch,
302 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
303 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
304 | if zero_init_residual:
305 | for m in self.modules():
306 | if isinstance(m, Bottleneck):
307 | nn.init.constant_(m.bn3.weight, 0)
308 | elif isinstance(m, BasicBlock):
309 | nn.init.constant_(m.bn2.weight, 0)
310 |
311 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
312 | norm_layer = self._norm_layer
313 | downsample = None
314 | previous_dilation = self.dilation
315 | if dilate:
316 | self.dilation *= stride
317 | stride = 1
318 | if stride != 1 or self.inplanes != planes * block.expansion:
319 | downsample = nn.Sequential(
320 | conv1x1(self.inplanes, planes * block.expansion, stride),
321 | norm_layer(planes * block.expansion),
322 | )
323 |
324 | layers = []
325 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
326 | self.base_width, previous_dilation, norm_layer))
327 | self.inplanes = planes * block.expansion
328 | for _ in range(1, blocks):
329 | layers.append(block(self.inplanes, planes, groups=self.groups,
330 | base_width=self.base_width, dilation=self.dilation,
331 | norm_layer=norm_layer))
332 |
333 | return nn.Sequential(*layers)
334 |
335 | def _forward_impl(self, x):
336 | # See note [TorchScript super()]
337 | x = self.conv1(x)
338 | x = self.bn1(x)
339 | x = self.relu(x)
340 | x = self.maxpool(x)
341 |
342 | x = self.layer1(x)
343 | x = self.layer2(x)
344 | x = self.layer3(x)
345 | x = self.layer4(x)
346 |
347 | x = self.avgpool(x)
348 | x = torch.flatten(x, 1)
349 | x = self.fc(x)
350 |
351 | return x
352 |
353 | def forward(self, x):
354 | return self._forward_impl(x)
355 |
356 |
357 | def load_param(self, model_path):
358 | param_dict = torch.load(model_path)
359 | for i in param_dict:
360 |             if 'fc' in i:  # skip the checkpoint's classifier weights
361 | continue
362 | self.state_dict()[i].copy_(param_dict[i])
363 |
364 |
365 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
366 | model = ResNet(block, layers, **kwargs)
367 | if pretrained:
368 | state_dict = load_state_dict_from_url(model_urls[arch],
369 | progress=progress)
370 | model.load_state_dict(state_dict)
371 | return model
372 |
373 |
374 | def resnet50(pretrained=False, progress=True, **kwargs):
375 | """ResNet-50 model from
376 | `"Deep Residual Learning for Image Recognition" `_
377 | Args:
378 | pretrained (bool): If True, returns a model pre-trained on ImageNet
379 | progress (bool): If True, displays a progress bar of the download to stderr
380 | """
381 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
382 | **kwargs)
383 |
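A shape sketch for `Baseline` in eval mode, assuming the usual 256x128 person-crop input; `pretrained=True` inside the constructor needs network access for the ImageNet weights:

```python
import torch
from reid.models.baseline import Baseline

model = Baseline(num_classes=395, num_features=2048).eval()
with torch.no_grad():
    global_feat, feat1 = model(torch.randn(2, 3, 256, 128))
print(global_feat.shape, feat1.shape)  # both torch.Size([2, 2048])
```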
--------------------------------------------------------------------------------
/reid/models/baseline.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/baseline.pyc
--------------------------------------------------------------------------------
/reid/models/newresnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import init
6 | from torch.nn.functional import softplus
7 |
8 | from reid.models.baseline import Baseline
9 | from utlis import ChannelCompress, to_edge
10 |
11 | __all__ = ['ft_net']
12 |
13 | ##################################################################################
14 | # Initialization function
15 | ##################################################################################
16 | def weights_init_kaiming(m):
17 | classname = m.__class__.__name__
18 | # print(classname)
19 | if classname.find('Conv') != -1:
20 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
21 | elif classname.find('Linear') != -1:
22 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
23 | init.constant_(m.bias.data, 0.0)
24 | elif classname.find('BatchNorm1d') != -1:
25 | init.normal_(m.weight.data, 1.0, 0.02)
26 | init.constant_(m.bias.data, 0.0)
27 |
28 | def weights_init_classifier(m):
29 | classname = m.__class__.__name__
30 | if classname.find('Linear') != -1:
31 | init.normal_(m.weight.data, std=0.001)
32 | init.constant_(m.bias.data, 0.0)
33 |
34 | ##################################################################################
35 | # framework
36 | ##################################################################################
37 | class ft_net(nn.Module):
38 | @staticmethod
39 | def _init_reduction(reduction):
40 | # conv
41 | nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in')
42 | # nn.init.constant_(reduction[0].bias, 0.)
43 |
44 | # bn
45 | nn.init.normal_(reduction[1].weight, mean=1., std=0.02)
46 | nn.init.constant_(reduction[1].bias, 0.)
47 |
48 | @staticmethod
49 | def _init_fc(fc):
50 | nn.init.kaiming_normal_(fc.weight, mode='fan_out')
51 | # nn.init.normal_(fc.weight, std=0.001)
52 | nn.init.constant_(fc.bias, 0.)
53 |
54 |
55 | def __init__(self, args, num_classes, num_features):
56 | super(ft_net, self).__init__()
57 |
58 | # Load basic config to initialize encoders, decoders and discriminators
59 | self.args = args
60 |
61 |         self.IR_backbone = Baseline(num_classes, num_features)
62 |         self.IR_Bottleneck = VIB(in_ch=2048, z_dim=self.args.z_dim, num_class=num_classes)
63 |         #self.IR_MIE = MIEstimator(size1=2048, size2=int(2 * self.args.z_dim))
64 |
65 |         self.RGB_backbone = Baseline(num_classes, num_features)
66 |         self.RGB_Bottleneck = VIB(in_ch=2048, z_dim=self.args.z_dim, num_class=num_classes)
67 |         #self.RGB_MIE = MIEstimator(size1=2048, size2=int(2 * self.args.z_dim))
68 |
69 |         self.shared_backbone = Baseline(num_classes, num_features)
70 |         self.shared_Bottleneck = VIB(in_ch=2048, z_dim=self.args.z_dim, num_class=num_classes)
71 |         #self.shared_MIE = MIEstimator(size1=2048, size2=int(2 * self.args.z_dim))
72 |
73 | def forward(self, x):
74 | # visible branch
75 | v_observation = self.RGB_backbone(x)
76 | v_representation = self.RGB_Bottleneck(v_observation[1])
77 |
78 | # modal-shared branch
79 | x_grey = to_edge(x)
80 | i_ms_input = torch.cat([x_grey, x_grey, x_grey], dim=1)
81 |
82 | i_ms_observation = self.shared_backbone(i_ms_input)
83 | v_ms_observation = self.shared_backbone(x)
84 |
85 | i_ms_representation = self.shared_Bottleneck(i_ms_observation[1])
86 | v_ms_representation = self.shared_Bottleneck(v_ms_observation[1])
87 |
88 | # infrared branch
89 | i_observation = self.IR_backbone(i_ms_input)
90 | i_representation = self.IR_Bottleneck(i_observation[1])
91 |
92 | return i_observation, i_representation, i_ms_observation, i_ms_representation, \
93 | v_observation, v_representation, v_ms_observation, v_ms_representation
94 |
95 | def optims(self):
96 | conv_params = []
97 |
98 | conv_params += list(self.IR_backbone.parameters())
99 | conv_params += list(self.IR_Bottleneck.bottleneck.parameters())
100 | conv_params += list(self.IR_Bottleneck.classifier.parameters())
101 |
102 | conv_params += list(self.RGB_backbone.parameters())
103 | conv_params += list(self.RGB_Bottleneck.bottleneck.parameters())
104 | conv_params += list(self.RGB_Bottleneck.classifier.parameters())
105 |
106 | conv_params += list(self.shared_backbone.parameters())
107 | conv_params += list(self.shared_Bottleneck.bottleneck.parameters())
108 | conv_params += list(self.shared_Bottleneck.classifier.parameters())
109 |
110 | conv_optim = torch.optim.Adam([p for p in conv_params if p.requires_grad], lr=self.args.lr, weight_decay=5e-4)
111 |
112 | return conv_optim
113 |
114 | ##################################################################################
115 | # Variational Information Bottleneck
116 | ##################################################################################
117 | class VIB(nn.Module):
118 | def __init__(self, in_ch=2048, z_dim=256, num_class=395):
119 | super(VIB, self).__init__()
120 | self.in_ch = in_ch
121 | self.out_ch = z_dim * 2
122 | self.num_class = num_class
123 | self.bottleneck = ChannelCompress(in_ch=self.in_ch, out_ch=self.out_ch)
124 | # classifier of VIB, maybe modified later.
125 | classifier = []
126 | classifier += [nn.Linear(self.out_ch, self.out_ch // 2)]
127 | classifier += [nn.BatchNorm1d(self.out_ch // 2)]
128 | classifier += [nn.LeakyReLU(0.1)]
129 | classifier += [nn.Dropout(0.5)]
130 | classifier += [nn.Linear(self.out_ch // 2, self.num_class)]
131 | classifier = nn.Sequential(*classifier)
132 | self.classifier = classifier
133 | self.classifier.apply(weights_init_classifier)
134 |
135 | def forward(self, v):
136 | z_given_v = self.bottleneck(v)
137 | p_y_given_z = self.classifier(z_given_v)
138 | return p_y_given_z, z_given_v
139 |
140 | ##################################################################################
141 | # Mutual Information Estimator
142 | ##################################################################################
143 | class MIEstimator(nn.Module):
144 | def __init__(self, size1=2048, size2=512):
145 | super(MIEstimator, self).__init__()
146 | self.size1 = size1
147 | self.size2 = size2
148 | self.in_ch = size1 + size2
149 | add_block = []
150 | add_block += [nn.Linear(self.in_ch, 2048)]
151 | add_block += [nn.BatchNorm1d(2048)]
152 | add_block += [nn.ReLU()]
153 | add_block += [nn.Linear(2048, 512)]
154 | add_block += [nn.BatchNorm1d(512)]
155 | add_block += [nn.ReLU()]
156 | add_block += [nn.Linear(512, 1)]
157 |
158 | add_block = nn.Sequential(*add_block)
159 | add_block.apply(weights_init_kaiming)
160 | self.block = add_block
161 |
162 | # Gradient for JSD mutual information estimation and EB-based estimation
163 | def forward(self, x1, x2, x1_shuff):
164 | """
165 | :param x1: observation
166 | :param x2: representation
167 | """
168 | pos = self.block(torch.cat([x1, x2], 1)) # Positive Samples
169 | neg = self.block(torch.cat([x1_shuff, x2], 1))
170 |
171 | return -softplus(-pos).mean() - softplus(neg).mean(), pos.mean() - neg.exp().mean() + 1
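A shape sketch for the `VIB` head: a 2048-d observation `v` is compressed into a `2 * z_dim`-dimensional representation `z` and then classified. `ChannelCompress` lives in the repository's `utlis` module; the shapes below assume it maps `in_ch` to `out_ch` as its arguments suggest:

```python
import torch
from reid.models.newresnet import VIB  # requires the repo's `utlis` module on the path

vib = VIB(in_ch=2048, z_dim=256, num_class=395).eval()
with torch.no_grad():
    p_y_given_z, z_given_v = vib(torch.randn(8, 2048))
print(p_y_given_z.shape, z_given_v.shape)  # torch.Size([8, 395]) torch.Size([8, 512])
```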
--------------------------------------------------------------------------------
/reid/models/newresnet.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/models/newresnet.pyc
--------------------------------------------------------------------------------
/reid/trainers.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | import torch
4 |
5 | from utlis import SinkhornDistance
6 | from .utils.meters import AverageMeter
7 | from torch.nn.functional import kl_div
8 |
9 |
10 | class BaseTrainer(object):
11 | def __init__(self, args, model, ce_loss, rank_loss, associate_loss, trainvallabel):
12 | super(BaseTrainer, self).__init__()
13 | self.model = model
14 |
15 | self.args = args
16 |
17 | self.CE_Loss = ce_loss
18 | self.rank_loss = rank_loss
19 |
20 | self.softmax = torch.nn.Softmax(dim=1)
21 | self.KLD = torch.nn.KLDivLoss()
22 | self.W_dist = SinkhornDistance().cuda()
23 |
24 | self.associate_loss = associate_loss
25 |
26 | self.trainvallabel = trainvallabel
27 |
28 | def train(self, epoch, data_loader, conv_optim, print_freq=24):
29 | self.model.train()
30 |
31 | batch_time = AverageMeter()
32 | data_time = AverageMeter()
33 | losses_total = AverageMeter()
34 |
35 | losses_triple = AverageMeter()
36 | losses_celoss = AverageMeter()
37 |
38 | losses_cml = AverageMeter()
39 | losses_vsd = AverageMeter()
40 | losses_vcd = AverageMeter()
41 |
42 | end = time.time()
43 | for i, inputs in enumerate(data_loader):
44 | data_time.update(time.time() - end)
45 |
46 | inputs, sub, label = self._parse_data(inputs)
47 |
48 | # Calc the loss
49 | ce_loss, triplet_Loss, conventional_ML, vsd_loss, vcd_loss = self._forward(inputs, label, sub, epoch)
50 | L = self.args.CE_loss * ce_loss + \
51 | self.args.Triplet_loss * triplet_Loss + \
52 | self.args.CML_loss * conventional_ML + \
53 | self.args.VSD_loss * vsd_loss + \
54 | self.args.VCD_loss * vcd_loss
55 |
56 | conv_optim.zero_grad()
57 | L.backward()
58 | conv_optim.step()
59 |
60 | losses_total.update(L.data.item(), label.size(0))
61 |
62 | losses_celoss.update(ce_loss.item(), label.size(0))
63 | losses_triple.update(triplet_Loss.item(), label.size(0))
64 |
65 | losses_cml.update(conventional_ML.item(), label.size(0))
66 | losses_vcd.update(vcd_loss.item(), label.size(0))
67 | losses_vsd.update(vsd_loss.item(), label.size(0))
68 |
69 | # losses_sharedMI.update(shared_MI.item(), label.size(0))
70 | # losses_specificMI.update(specific_MI.item(), label.size(0))
71 |
72 | batch_time.update(time.time() - end)
73 | end = time.time()
74 |
75 | if (i + 1) % print_freq == 0:
76 | print('Epoch: [{}][{}/{}]\t'
77 | 'Time {:.2f} ({:.2f})\t'
78 | 'Total Loss {:.2f} ({:.2f})\t'
79 | 'IDE Loss {:.2f} ({:.2f})\t'
80 | 'Triple Loss {:.2f} ({:.2f})\t'
81 | 'CML Loss {:.3f} ({:.3f})\t'
82 | 'VSD Loss {:.3f} ({:.3f})\t'
83 | 'VCD Loss {:.3f} ({:.3f})\t'
84 | .format(epoch, i + 1, len(data_loader),
85 | batch_time.val, batch_time.avg,
86 | losses_total.val, losses_total.avg,
87 | losses_celoss.val, losses_celoss.avg,
88 | losses_triple.val, losses_triple.avg,
89 | losses_cml.val, losses_cml.avg,
90 | losses_vsd.val, losses_vsd.avg,
91 | losses_vcd.val, losses_vcd.avg))
92 | return losses_triple.avg, losses_total.avg
93 |
94 | def _parse_data(self, inputs):
95 | raise NotImplementedError
96 |
97 |     def _forward(self, inputs, label, sub, epoch):
98 | raise NotImplementedError
99 |
100 |
101 | class Trainer(BaseTrainer):
102 | def _parse_data(self, inputs):
103 | imgs, _, pids, cams = inputs
104 | inputs = imgs.cuda()
105 | pids = pids.cuda()
106 | sub = ((cams == 2).long() + (cams == 5).long()).cuda()
107 | label = torch.cuda.LongTensor(range(pids.size(0)))
108 | for i in range(pids.size(0)):
109 | label[i] = self.trainvallabel[pids[i].item()]
110 | return inputs, sub, label
111 |
112 | def _forward(self, inputs, label, sub, epoch):
113 | i_observation, i_representation, i_ms_observation, i_ms_representation, \
114 | v_observation, v_representation, v_ms_observation, v_ms_representation = self.model(inputs)
115 |
116 | # Classification loss
117 | ce_loss = 0.5 * (self.CE_Loss(i_observation[0], label) + self.CE_Loss(i_representation[0], label)) + \
118 | 0.5 * (self.CE_Loss(v_observation[0], label) + self.CE_Loss(v_representation[0], label)) + \
119 | 0.25 * (self.CE_Loss(i_ms_observation[0], label) + self.CE_Loss(i_ms_representation[0], label)) + \
120 | 0.25 * (self.CE_Loss(v_ms_observation[0], label) + self.CE_Loss(v_ms_representation[0], label))
121 |
122 |         # Metric learning; note the ranking loss is applied to both v and z.
123 |         triplet_Loss = 0.5 * (self.rank_loss(i_observation[1], label, sub) + self.rank_loss(i_representation[1], label, sub)) + \
124 | 0.5 * (self.rank_loss(v_observation[1], label, sub) + self.rank_loss(v_representation[1], label, sub)) + \
125 | 0.25 * (self.rank_loss(i_ms_observation[1], label, sub) + self.rank_loss(i_ms_representation[1], label, sub)) + \
126 | 0.25 * (self.rank_loss(v_ms_observation[1], label, sub) + self.rank_loss(v_ms_representation[1], label, sub))
127 |
128 | #associate_loss = 0.5 * (self.associate_loss(i_observation[1], label, sub) + self.associate_loss(i_representation[1], label, sub)) + \
129 | # 0.5 * (self.associate_loss(v_observation[1], label, sub) + self.associate_loss(v_representation[1], label, sub)) + \
130 | # 0.25 * (self.associate_loss(i_ms_observation[1], label, sub) + self.associate_loss(i_ms_representation[1], label, sub)) + \
131 | # 0.25 * (self.associate_loss(v_ms_observation[1], label, sub) + self.associate_loss(v_ms_representation[1], label, sub))
132 |
133 | # Conventional mutual learning strategy, conducted only between observations of modal-specific branches.
134 | conventional_ML = self.W_dist(self.softmax(i_observation[0]), self.softmax(v_observation[0]))
135 |
136 | # Variational Self-Distillation, preserving sufficiency
137 | vsd_loss = kl_div(input=self.softmax(i_observation[0].detach() / self.args.temperature),
138 | target=self.softmax(i_representation[0] / self.args.temperature)) + \
139 | kl_div(input=self.softmax(v_observation[0].detach() / self.args.temperature),
140 | target=self.softmax(v_representation[0] / self.args.temperature))
141 |
142 | vcd_loss = 0.5 * kl_div(input=self.softmax(v_ms_observation[0].detach()),
143 | target=self.softmax(i_ms_representation[0])) + \
144 | 0.5 * kl_div(input=self.softmax(i_ms_observation[0].detach()),
145 | target=self.softmax(v_ms_representation[0]))
146 |
147 | # mutual information estimation for modal-specific branches and modal-shared branch
148 | # shuff_order = np.random.permutation(self.args.batch_size)
149 | # specific_MI_I = self.model.module.IR_MIE(x1=i_observation[1], x2=i_representation[1],
150 | # x1_shuff=i_observation[1][shuff_order, :])[0].mean()
151 | # specific_MI_V = self.model.module.IR_MIE(x1=v_observation[1], x2=v_representation[1],
152 | # x1_shuff=v_observation[1][shuff_order, :])[0].mean()
153 | # shared_MI_I = self.model.module.shared_MIE(x1=i_ms_observation[1], x2=i_ms_representation[1],
154 | # x1_shuff=i_ms_observation[1][shuff_order, :])[0].mean()
155 | # shared_MI_V = self.model.module.shared_MIE(x1=v_ms_observation[1], x2=v_ms_representation[1],
156 | # x1_shuff=v_ms_observation[1][shuff_order, :])[0].mean()
157 |
158 | return ce_loss, triplet_Loss, conventional_ML, vsd_loss, vcd_loss#, associate_loss
159 |
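One caveat worth flagging for the VSD/VCD terms above: PyTorch's `torch.nn.functional.kl_div(input, target)` documents `input` as *log*-probabilities and `target` as probabilities, while the trainer passes softmax outputs on both sides. For reference, the documented usage looks like this (toy logits, purely illustrative):

```python
import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 10)  # hypothetical
teacher_logits = torch.randn(4, 10)  # hypothetical

loss = F.kl_div(F.log_softmax(student_logits, dim=1),
                F.softmax(teacher_logits, dim=1),
                reduction='batchmean')
print(loss.item())
```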
--------------------------------------------------------------------------------
/reid/trainers.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/trainers.pyc
--------------------------------------------------------------------------------
/reid/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 |
5 |
6 | def to_numpy(tensor):
7 | if torch.is_tensor(tensor):
8 | return tensor.cpu().numpy()
9 | elif type(tensor).__module__ != 'numpy':
10 | raise ValueError("Cannot convert {} to numpy array"
11 | .format(type(tensor)))
12 | return tensor
13 |
14 |
15 | def to_torch(ndarray):
16 | if type(ndarray).__module__ == 'numpy':
17 | return torch.from_numpy(ndarray)
18 | elif not torch.is_tensor(ndarray):
19 | raise ValueError("Cannot convert {} to torch tensor"
20 | .format(type(ndarray)))
21 | return ndarray
22 |
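A round-trip sketch for the two converters:

```python
import numpy as np
from reid.utils import to_numpy, to_torch

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
t = to_torch(arr)    # wraps the array without copying
back = to_numpy(t)   # back to numpy
assert (back == arr).all()
```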
--------------------------------------------------------------------------------
/reid/utils/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__init__.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/logging.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/logging.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/logging.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/logging.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/meters.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/meters.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/meters.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/meters.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/osutils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/osutils.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/osutils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/osutils.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/osutils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/osutils.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/serialization.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/serialization.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/serialization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/serialization.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/__pycache__/serialization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/__pycache__/serialization.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .dataset import Dataset
4 | from .preprocessor import Preprocessor
5 |
6 |
7 |
--------------------------------------------------------------------------------
/reid/utils/data/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__init__.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/dataset.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/dataset.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/preprocessor.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/preprocessor.cpython-35.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/preprocessor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/preprocessor.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/preprocessor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/preprocessor.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/sampler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/sampler.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/reid/utils/data/__pycache__/transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/__pycache__/transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/reid/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os.path as osp
3 |
4 | import numpy as np
5 |
6 | from ..serialization import read_json
7 |
8 |
9 | def _pluck(identities, indices, relabel=False):
10 | ret = []
11 | for index, pid in enumerate(indices):
12 | pid_images = identities[pid]
13 | for camid, cam_images in enumerate(pid_images):
14 | for fname in cam_images:
15 | name = osp.splitext(fname)[0]
16 | try:
17 | x, y, _ = map(int, name.split('_'))
18 | assert pid == x and camid == y
19 |                 except (ValueError, AssertionError):
20 |                     pass  # file names not in "pid_camid_idx" form cannot be cross-checked
21 | if relabel:
22 | ret.append((fname, index, camid))
23 | else:
24 | ret.append((fname, pid, camid))
25 | return ret
26 |
27 |
28 | class Dataset(object):
29 | def __init__(self, root, split_id=0):
30 | self.root = root
31 | self.split_id = split_id
32 | self.meta = None
33 | self.split = None
34 | self.train, self.val, self.trainval = [], [], []
35 | self.query, self.gallery = [], []
36 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0
37 | self.trainvallabel = {}
38 |
39 | @property
40 | def images_dir(self):
41 | return osp.join(self.root, 'images')
42 |
43 | def load(self, num_val=0.3, verbose=False):
44 | splits = read_json(osp.join(self.root, 'splits.json'))
45 | if self.split_id >= len(splits):
46 | raise ValueError("split_id exceeds total splits {}"
47 | .format(len(splits)))
48 | self.split = splits[self.split_id]
49 |
50 | # Randomly split train / val
51 | trainval_pids = np.asarray(self.split['trainval'])
52 | test_pids = np.asarray(self.split['query'])
53 | # np.random.shuffle(trainval_pids)
54 | num = len(trainval_pids)
55 | if isinstance(num_val, float):
56 | num_val = int(round(num * num_val))
57 | if num_val >= num or num_val < 0:
58 | raise ValueError("num_val exceeds total identities {}"
59 | .format(num))
60 | train_pids = sorted(trainval_pids[:-num_val])
61 | val_pids = sorted(trainval_pids[-num_val:])
62 |
63 | self.meta = read_json(osp.join(self.root, 'meta.json'))
64 | identities = self.meta['identities']
65 | self.train = _pluck(identities, train_pids, relabel=False)
66 | self.val = _pluck(identities, val_pids, relabel=False)
67 | self.trainval = _pluck(identities, trainval_pids, relabel=False)
68 | # print(self.trainval[1],self.trainval[1][1])
69 | countIR = 0
70 | countRGB = 0
71 | for image in self.trainval:
72 | # print(image[2] == 2 or image[2] == 5)
73 | if image[2] == 2 or image[2] == 5:
74 | countIR = countIR + 1
75 | else:
76 | countRGB = countRGB + 1
77 | self.query = _pluck(identities, self.split['query'])
78 | self.gallery = _pluck(identities, self.split['gallery'])
79 | query = 0
80 | gallery = 0
81 | for image in self.query:
82 | if image[2] == 2 or image[2] == 5:
83 | query += 1
84 | else:
85 | gallery += 1
86 |
87 | self.num_train_ids = len(train_pids)
88 | self.num_val_ids = len(val_pids)
89 | self.num_trainval_ids = len(trainval_pids)
90 | # print(sorted(trainval_pids))
91 | for index,i in enumerate(sorted(trainval_pids)):
92 | self.trainvallabel[i] = index
93 | # print (index)
94 |
95 | if verbose:
96 | print(self.__class__.__name__, "dataset loaded")
97 | print(" subset | # ids | # images")
98 | print(" ---------------------+----------+---------")
99 | print(" train | {:8d} | {:8d}"
100 | .format(self.num_train_ids, len(self.train)))
101 | print(" val | {:8d} | {:8d}"
102 | .format(self.num_val_ids, len(self.val)))
103 | print(" trainval | {:8d} | {:8d}"
104 | .format(self.num_trainval_ids, len(self.trainval)))
105 | print(" query | {:8d} | {:8d}"
106 | .format(len(test_pids), len(test_pids) * 4))
107 | print(" gallery | {:8d} | {:8d}"
108 | .format(len(test_pids), gallery))
109 | print(" num of RGB and IR | {:8d} | {:8d}"
110 | .format(countRGB, countIR))
111 |
112 | def _check_integrity(self):
113 | return osp.isdir(osp.join(self.root, 'images')) and \
114 | osp.isfile(osp.join(self.root, 'meta.json')) and \
115 | osp.isfile(osp.join(self.root, 'splits.json'))
116 |
--------------------------------------------------------------------------------
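
To make the expected inputs concrete, here is a hypothetical sketch of the `meta.json`/`splits.json` layout that `Dataset.load` reads. The real preparation scripts for SYSU-MM01 produce richer files; every path and identity below is made up, and the import assumes you run from the repository root.

```python
# Hypothetical sketch of the metadata layout dataset.py expects.
import json
import os

from reid.utils.data import Dataset

root = '/tmp/toy_dataset'
os.makedirs(os.path.join(root, 'images'), exist_ok=True)

# meta['identities'][pid][camid] is a list of file names "pid_camid_idx.ext",
# which _pluck() cross-checks against its own (pid, camid) indices.
meta = {'identities': [
    [['00000000_00_0000.jpg'], ['00000000_01_0000.jpg']],  # pid 0, cams 0-1
    [['00000001_00_0000.jpg'], ['00000001_01_0000.jpg']],  # pid 1, cams 0-1
]}
splits = [{'trainval': [0, 1], 'query': [0], 'gallery': [1]}]

with open(os.path.join(root, 'meta.json'), 'w') as f:
    json.dump(meta, f)
with open(os.path.join(root, 'splits.json'), 'w') as f:
    json.dump(splits, f)

d = Dataset(root)
d.load(num_val=0.3, verbose=True)  # splits the trainval pids 70/30 into train/val
```
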
/reid/utils/data/dataset.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/dataset.pyc
--------------------------------------------------------------------------------
/reid/utils/data/preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os.path as osp
3 |
4 | from PIL import Image
5 |
6 |
7 | class Preprocessor(object):
8 | def __init__(self, dataset, root=None, transform=None):
9 | super(Preprocessor, self).__init__()
10 | self.dataset = dataset
11 | self.root = root
12 | self.transform = transform
13 |
14 | def __len__(self):
15 | return len(self.dataset)
16 |
17 | def __getitem__(self, indices):
18 | if isinstance(indices, (tuple, list)):
19 |             # a batch of indices was passed: return one sample per index
20 | return [self._get_single_item(index) for index in indices]
21 | return self._get_single_item(indices)
22 |
23 | def _get_single_item(self, index):
24 | fname, pid, camid = self.dataset[index]
25 | fpath = fname
26 | if self.root is not None:
27 | fpath = osp.join(self.root, fname)
28 | img = Image.open(fpath).convert('RGB')
29 | # img = Image.open(fpath).convert('L')
30 | # print(img)
31 | if self.transform is not None:
32 | img = self.transform(img)
33 | return img, fname, pid, camid
34 |
--------------------------------------------------------------------------------
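
A minimal usage sketch, assuming a `dataset` object loaded as in the example above and image files actually present under `dataset.images_dir`:

```python
# Wrap a dataset split in a standard PyTorch DataLoader.
from torch.utils.data import DataLoader
from reid.utils.data import Preprocessor
from reid.utils.data import transforms as T

transform = T.Compose([
    T.RectScale(256, 128),  # resize every crop to height 256, width 128
    T.ToTensor(),
])
loader = DataLoader(
    Preprocessor(dataset.trainval, root=dataset.images_dir, transform=transform),
    batch_size=32, shuffle=True, num_workers=4)

for imgs, fnames, pids, camids in loader:
    pass  # imgs is a (B, 3, 256, 128) float tensor
```
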
/reid/utils/data/preprocessor.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/preprocessor.pyc
--------------------------------------------------------------------------------
/reid/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | import torch
6 | from random import shuffle
7 | from torch.utils.data.sampler import (
8 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler,
9 | WeightedRandomSampler)
10 |
11 |
12 | class RandomIdentitySampler(Sampler):
13 | def __init__(self, data_source, num_instances=1):
14 | self.data_source = data_source
15 | self.num_instances = num_instances
16 | self.index_dic = defaultdict(list)
17 | for index, (_, pid, _) in enumerate(data_source):
18 | self.index_dic[pid].append(index)
19 | self.pids = list(self.index_dic.keys())
20 | self.num_samples = len(self.pids)
21 |
22 | def __len__(self):
23 | return self.num_samples * self.num_instances
24 |
25 | def __iter__(self):
26 | indices = torch.randperm(self.num_samples)
27 | ret = []
28 | for i in indices:
29 | pid = self.pids[i]
30 | t = self.index_dic[pid]
31 | if len(t) >= self.num_instances:
32 | t = np.random.choice(t, size=self.num_instances, replace=False)
33 | else:
34 | t = np.random.choice(t, size=self.num_instances, replace=True)
35 | ret.extend(t)
36 | return iter(ret)
37 |
38 | class CamSampler(Sampler):
39 | def __init__(self, data_source, need_cam, num=0):
40 | self.data_source = data_source
41 | self.index_dic = []
42 |         self.id_cam = [[] for _ in range(533)]  # one bucket per possible person id
43 |
44 | if num>0:
45 | for index, (_, pid, cam) in enumerate(data_source):
46 | if cam in need_cam:
47 | self.id_cam[pid].append(index)
48 | for i in range(533):
49 | if len(self.id_cam[i])>num:
50 | self.index_dic.extend(self.id_cam[i][:num])
51 | else:
52 | for index, (_, pid, cam) in enumerate(data_source):
53 | if cam in need_cam:
54 | self.index_dic.append(index)
55 |
56 | def __len__(self):
57 | return len(self.index_dic)
58 |
59 | def __iter__(self):
60 | return iter(self.index_dic)
61 |
62 | class CamRandomIdentitySampler(Sampler):
63 | def __init__(self, data_source, num_instances=2):
64 | self.data_source = data_source
65 | self.num_instances = num_instances
66 | if num_instances % 2 > 0:
67 |             raise ValueError("The num_instances should be an even number")
68 | self.index_dic_I = defaultdict(list)
69 | self.index_dic_IR = defaultdict(list)
70 | for index, (name, pid, cam) in enumerate(data_source):
71 | if cam == 2 or cam == 5:
72 | self.index_dic_IR[pid].append(index)
73 | else:
74 | self.index_dic_I[pid].append(index)
75 | self.pids = list(self.index_dic_I.keys())
76 | self.num_samples = len(self.pids)
77 |
78 | def __len__(self):
79 | return self.num_samples * self.num_instances
80 |
81 | def __iter__(self):
82 | indices = torch.randperm(self.num_samples)
83 | ret = []
84 | for i in indices:
85 | pid_I = self.pids[i]
86 | pid_IR = self.pids[i]
87 | t_I = self.index_dic_I[pid_I]
88 | t_IR = self.index_dic_IR[pid_IR]
89 | if len(t_I) >= self.num_instances / 2:
90 | t_I = np.random.choice(t_I, size=int(self.num_instances / 2), replace=False)
91 | else:
92 | t_I = np.random.choice(t_I, size=int(self.num_instances / 2), replace=True)
93 | if len(t_IR) >= self.num_instances / 2:
94 | t_IR = np.random.choice(t_IR, size=int(self.num_instances / 2), replace=False)
95 | else:
96 | t_IR = np.random.choice(t_IR, size=int(self.num_instances / 2), replace=True)
97 | # ret.extend(t_I)
98 | # ret.extend(t_IR)
99 | for j in range(self.num_instances // 2):
100 | ret.append(t_I[j])
101 | ret.append(t_IR[j])
102 | return iter(ret)
103 |
--------------------------------------------------------------------------------
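
A sketch of cross-modal batch construction with `CamRandomIdentitySampler`. The camera checks above treat SYSU-MM01 cameras 2 and 5 as infrared; the sampler interleaves one RGB and one IR instance at a time, so every batch contains both modalities. `Preprocessor`, `dataset`, and `transform` are assumed from the earlier sketches.

```python
from torch.utils.data import DataLoader
from reid.utils.data.sampler import CamRandomIdentitySampler

sampler = CamRandomIdentitySampler(dataset.trainval, num_instances=4)
loader = DataLoader(
    Preprocessor(dataset.trainval, root=dataset.images_dir, transform=transform),
    batch_size=32,       # a multiple of num_instances keeps identities intact
    sampler=sampler,     # sampler and shuffle=True are mutually exclusive
    drop_last=True)
```
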
/reid/utils/data/sampler.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/sampler.pyc
--------------------------------------------------------------------------------
/reid/utils/data/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torchvision.transforms import *
4 | from PIL import Image
5 | import random
6 | import math
7 |
8 |
9 | class RectScale(object):
10 | def __init__(self, height, width, interpolation=Image.BILINEAR):
11 | self.height = height
12 | self.width = width
13 | self.interpolation = interpolation
14 |
15 | def __call__(self, img):
16 | w, h = img.size
17 | if h == self.height and w == self.width:
18 | return img
19 | return img.resize((self.width, self.height), self.interpolation)
20 |
21 |
22 | class RandomSizedRectCrop(object):
23 | def __init__(self, height, width, interpolation=Image.BILINEAR):
24 | self.height = height
25 | self.width = width
26 | self.interpolation = interpolation
27 |
28 | def __call__(self, img):
29 | for attempt in range(10):
30 | area = img.size[0] * img.size[1]
31 | target_area = random.uniform(0.64, 1.0) * area
32 | aspect_ratio = random.uniform(2, 3)
33 |
34 | h = int(round(math.sqrt(target_area * aspect_ratio)))
35 | w = int(round(math.sqrt(target_area / aspect_ratio)))
36 |
37 | if w <= img.size[0] and h <= img.size[1]:
38 | x1 = random.randint(0, img.size[0] - w)
39 | y1 = random.randint(0, img.size[1] - h)
40 |
41 | img = img.crop((x1, y1, x1 + w, y1 + h))
42 | assert(img.size == (w, h))
43 |
44 | return img.resize((self.width, self.height), self.interpolation)
45 |
46 | # Fallback
47 | scale = RectScale(self.height, self.width,
48 | interpolation=self.interpolation)
49 | return scale(img)
50 |
51 | class RandomErasing(object):
52 | """ Randomly selects a rectangle region in an image and erases its pixels.
53 | 'Random Erasing Data Augmentation' by Zhong et al.
54 | See https://arxiv.org/pdf/1708.04896.pdf
55 | Args:
56 | probability: The probability that the Random Erasing operation will be performed.
57 | sl: Minimum proportion of erased area against input image.
58 | sh: Maximum proportion of erased area against input image.
59 | r1: Minimum aspect ratio of erased area.
60 | mean: Erasing value.
61 | """
62 |
63 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
64 | self.probability = probability
65 | self.mean = mean
66 | self.sl = sl
67 | self.sh = sh
68 | self.r1 = r1
69 |
70 | def __call__(self, img):
71 |
72 | if random.uniform(0, 1) >= self.probability:
73 | return img
74 |
75 | for attempt in range(100):
76 | area = img.size()[1] * img.size()[2]
77 |
78 | target_area = random.uniform(self.sl, self.sh) * area
79 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
80 |
81 | h = int(round(math.sqrt(target_area * aspect_ratio)))
82 | w = int(round(math.sqrt(target_area / aspect_ratio)))
83 |
84 | if w < img.size()[2] and h < img.size()[1]:
85 | x1 = random.randint(0, img.size()[1] - h)
86 | y1 = random.randint(0, img.size()[2] - w)
87 | if img.size()[0] == 3:
88 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
89 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
90 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
91 | else:
92 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
93 | return img
94 |
95 | return img
96 |
--------------------------------------------------------------------------------
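
A plausible training-time pipeline built from these transforms. The normalization statistics are the common ImageNet values, an assumption rather than something this file pins down:

```python
from reid.utils.data import transforms as T

train_transform = T.Compose([
    T.RandomSizedRectCrop(256, 128),  # random crop with h/w ratio in [2, 3]
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    T.RandomErasing(probability=0.5),  # tensor-based, so it must follow ToTensor
])
```
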
/reid/utils/data/transforms.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/data/transforms.pyc
--------------------------------------------------------------------------------
/reid/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 |
5 | from .osutils import mkdir_if_missing
6 |
7 |
8 | class Logger(object):
9 | def __init__(self, fpath=None):
10 | self.console = sys.stdout
11 | self.file = None
12 | if fpath is not None:
13 | mkdir_if_missing(os.path.dirname(fpath))
14 | self.file = open(fpath, 'w')
15 |
16 | def __del__(self):
17 | self.close()
18 |
19 | def __enter__(self):
20 |         return self  # allow use as a context manager
21 |
22 | def __exit__(self, *args):
23 | self.close()
24 |
25 | def write(self, msg):
26 | self.console.write(msg)
27 | if self.file is not None:
28 | self.file.write(msg)
29 |
30 | def flush(self):
31 | self.console.flush()
32 | if self.file is not None:
33 | self.file.flush()
34 | os.fsync(self.file.fileno())
35 |
36 | def close(self):
37 |         # keep the console stream (sys.stdout) open; only close the file we own
38 | if self.file is not None:
39 | self.file.close()
40 |
--------------------------------------------------------------------------------
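
Typical usage is to replace `sys.stdout` so everything printed is also written to a log file (the path below is illustrative):

```python
import sys
from reid.utils.logging import Logger

sys.stdout = Logger('./logs/train.log')
print('written to both the console and ./logs/train.log')
```
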
/reid/utils/logging.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/logging.pyc
--------------------------------------------------------------------------------
/reid/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
--------------------------------------------------------------------------------
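
Usage sketch: keep a running average of the per-batch loss during training.

```python
from reid.utils.meters import AverageMeter

losses = AverageMeter()
for step in range(100):
    loss_value = 1.0 / (step + 1)    # stand-in for loss.item()
    losses.update(loss_value, n=32)  # n weights the value by batch size
print(losses.val, losses.avg)        # last value vs. running mean
```
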
/reid/utils/meters.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/meters.pyc
--------------------------------------------------------------------------------
/reid/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 |
5 |
6 | def mkdir_if_missing(dir_path):
7 | try:
8 | os.makedirs(dir_path)
9 | except OSError as e:
10 | if e.errno != errno.EEXIST:
11 | raise
12 |
--------------------------------------------------------------------------------
/reid/utils/osutils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/osutils.pyc
--------------------------------------------------------------------------------
/reid/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import json
3 | import os.path as osp
4 | import shutil
5 |
6 | import torch
7 | from torch.nn import Parameter
8 |
9 | from .osutils import mkdir_if_missing
10 |
11 |
12 | def read_json(fpath):
13 | with open(fpath, 'r') as f:
14 | obj = json.load(f)
15 | return obj
16 |
17 |
18 | def write_json(obj, fpath):
19 | mkdir_if_missing(osp.dirname(fpath))
20 | with open(fpath, 'w') as f:
21 | json.dump(obj, f, indent=4, separators=(',', ': '))
22 |
23 |
24 | def save_checkpoint(state, is_best, epoch, dirname, fpath='checkpoint.pth.tar'):
25 |     mkdir_if_missing(dirname)  # the checkpoint directory itself must exist
26 |     torch.save(state, osp.join(dirname, fpath))
27 |     if is_best:
28 |         shutil.copyfile(osp.join(dirname, fpath), osp.join(dirname, 'model_best.pth.tar'))
29 |
30 |
31 | def load_checkpoint(fpath):
32 | if osp.isfile(fpath):
33 | checkpoint = torch.load(fpath)
34 | print("=> Loaded checkpoint '{}'".format(fpath))
35 | return checkpoint
36 | else:
37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath))
38 |
39 |
40 | def copy_state_dict(state_dict, model, strip=None):
41 | tgt_state = model.state_dict()
42 | copied_names = set()
43 | for name, param in state_dict.items():
44 | if strip is not None and name.startswith(strip):
45 | name = name[len(strip):]
46 | if name not in tgt_state:
47 | continue
48 | if isinstance(param, Parameter):
49 | param = param.data
50 | if param.size() != tgt_state[name].size():
51 | print('mismatch:', name, param.size(), tgt_state[name].size())
52 | continue
53 | tgt_state[name].copy_(param)
54 | copied_names.add(name)
55 |
56 | missing = set(tgt_state.keys()) - copied_names
57 | if len(missing) > 0:
58 | print("missing keys in state_dict:", missing)
59 |
60 | return model
61 |
--------------------------------------------------------------------------------
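
A hedged sketch of the checkpoint round trip; `model` and `epoch` are placeholders for the real training state:

```python
from reid.utils.serialization import save_checkpoint, load_checkpoint, copy_state_dict

save_checkpoint(
    {'state_dict': model.state_dict(), 'epoch': epoch},
    is_best=True, epoch=epoch, dirname='./logs')

checkpoint = load_checkpoint('./logs/model_best.pth.tar')
# strip='module.' drops the prefix that nn.DataParallel adds to parameter names
model = copy_state_dict(checkpoint['state_dict'], model, strip='module.')
```
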
/reid/utils/serialization.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutabaSakuraXD/Farewell-to-Mutual-Information-Variational-Distiilation-for-Cross-Modal-Person-Re-identification/ae7c0187d2ed36e6ed5109c2e0476e7f17bc8ce2/reid/utils/serialization.pyc
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 |
2 | # encoding: utf-8
3 | import os
4 | c = 0.15
5 | for i in range(0,16):
6 | c = c + 0.15
7 | file_data = ""
8 |     with open("demo_sysu.sh") as f:  # read (and close) the current launch script
9 |         lines = f.readlines()
10 | print(c)
11 | with open("demo_sysu.sh", "w") as fw:
12 | for line in lines:
13 | print(line)
14 | if "D_loss" in line:
15 | line = "-D_loss " + str(c) + " \\" + '\n'
16 | if "logs-dir" in line:
17 | line = "--logs-dir ./weight_of_D_Loss_" + str(c) + ' \\' + '\n'
18 | file_data += line
19 | fw.write(file_data)
20 | os.system('sh demo_sysu.sh')
21 |
22 |
--------------------------------------------------------------------------------
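
For context, `train.py` is a small hyper-parameter sweep driver: each iteration rewrites the `-D_loss` weight and the `--logs-dir` line of `demo_sysu.sh` and relaunches the script, covering weights from roughly 0.30 to 2.55 in steps of 0.15 (sixteen runs).
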
/utlis.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | from bisect import bisect_right
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from torch.nn import init
8 | ############################################################################################
9 | # Channel Compress
10 | ############################################################################################
11 |
12 | def weights_init_kaiming(m):
13 | classname = m.__class__.__name__
14 | # print(classname)
15 | if classname.find('Conv') != -1:
16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
17 | elif classname.find('Linear') != -1:
18 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
19 | init.constant_(m.bias.data, 0.0)
20 | elif classname.find('BatchNorm1d') != -1:
21 | init.normal_(m.weight.data, 1.0, 0.02)
22 | init.constant_(m.bias.data, 0.0)
23 |
24 | def weights_init_classifier(m):
25 | classname = m.__class__.__name__
26 | if classname.find('Linear') != -1:
27 | init.normal_(m.weight.data, std=0.001)
28 | init.constant_(m.bias.data, 0.0)
29 |
30 | class ChannelCompress(nn.Module):
31 | def __init__(self, in_ch=2048, out_ch=256):
32 |         """
33 |         Reduce the channel dimension so the final embedding does not
34 |         overwhelm shallow feature maps. out_ch is typically 512, 256, or 128.
35 |         """
36 | super(ChannelCompress, self).__init__()
37 | num_bottleneck = 1000
38 | add_block = []
39 | add_block += [nn.Linear(in_ch, num_bottleneck)]
40 | add_block += [nn.BatchNorm1d(num_bottleneck)]
41 | add_block += [nn.ReLU()]
42 |
43 | add_block += [nn.Linear(num_bottleneck, 500)]
44 | add_block += [nn.BatchNorm1d(500)]
45 | add_block += [nn.ReLU()]
46 | add_block += [nn.Linear(500, out_ch)]
47 |
48 |         # Extra BN layer, needs to be removed
49 | #add_block += [nn.BatchNorm1d(out_ch)]
50 |
51 | add_block = nn.Sequential(*add_block)
52 | add_block.apply(weights_init_kaiming)
53 | self.model = add_block
54 |
55 | def forward(self, x):
56 | x = self.model(x)
57 | return x
58 |
59 | ############################################################################################
60 | # Classification Loss
61 | ############################################################################################
62 | class CrossEntropyLabelSmooth(nn.Module):
63 | """Cross entropy loss with label smoothing regularizer.
64 |
65 | Reference:
66 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
67 | Equation: y = (1 - epsilon) * y + epsilon / K.
68 |
69 | Args:
70 | num_classes (int): number of classes.
71 | epsilon (float): weight.
72 | """
73 | def __init__(self, num_classes, epsilon=0.0, use_gpu=True):
74 | super(CrossEntropyLabelSmooth, self).__init__()
75 | self.num_classes = num_classes
76 | self.epsilon = epsilon
77 | self.use_gpu = use_gpu
78 | self.logsoftmax = nn.LogSoftmax(dim=1)
79 |
80 | def forward(self, inputs, targets):
81 | """
82 | Args:
83 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
84 |             targets: ground truth labels with shape (batch_size)
85 | """
86 | log_probs = self.logsoftmax(inputs)
87 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
88 | if self.use_gpu: targets = targets.cuda()
89 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
90 | loss = (- targets * log_probs).mean(0).sum()
91 | return loss
92 |
93 |
94 | ############################################################################################
95 | # gray_scale function
96 | ############################################################################################
97 | def to_edge(x):  # despite the name, converts a batch of RGB tensors to grayscale
98 | x = x.data.cpu()
99 | out = torch.FloatTensor(x.size(0), x.size(2), x.size(3))
100 | for i in range(x.size(0)):
101 | item = x[i,:,:,:]
102 | #print(item.shape)
103 | r, g, b = item[0, :, :], item[1, :, :], item[2, :, :]
104 | xx = 0.2989 * r + 0.5870 * g + 0.1140 * b
105 | #print(xx.shape)
106 | out[i, :, :] = xx
107 | out = out.unsqueeze(1)
108 | return out.cuda()
109 |
110 | ############################################################################################
111 | # Random Erasing
112 | ############################################################################################
113 | class RandomErasing(object):
114 | """ Randomly selects a rectangle region in an image and erases its pixels.
115 | 'Random Erasing Data Augmentation' by Zhong et al.
116 | See https://arxiv.org/pdf/1708.04896.pdf
117 | Args:
118 | probability: The probability that the Random Erasing operation will be performed.
119 | sl: Minimum proportion of erased area against input image.
120 | sh: Maximum proportion of erased area against input image.
121 | r1: Minimum aspect ratio of erased area.
122 | mean: Erasing value.
123 | """
124 |
125 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
126 | self.probability = probability
127 | self.mean = mean
128 | self.sl = sl
129 | self.sh = sh
130 | self.r1 = r1
131 |
132 | def __call__(self, img):
133 |
134 | if random.uniform(0, 1) >= self.probability:
135 | return img
136 |
137 | for attempt in range(100):
138 | area = img.size()[1] * img.size()[2]
139 |
140 | target_area = random.uniform(self.sl, self.sh) * area
141 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
142 |
143 | h = int(round(math.sqrt(target_area * aspect_ratio)))
144 | w = int(round(math.sqrt(target_area / aspect_ratio)))
145 |
146 | if w < img.size()[2] and h < img.size()[1]:
147 | x1 = random.randint(0, img.size()[1] - h)
148 | y1 = random.randint(0, img.size()[2] - w)
149 | if img.size()[0] == 3:
150 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
151 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
152 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
153 | else:
154 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
155 | return img
156 |
157 | return img
158 |
159 | ############################################################################################
160 | # Warmup scheduler
161 | ############################################################################################
162 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
163 | def __init__(
164 | self,
165 | optimizer,
166 | milestones,
167 | gamma=0.1,
168 | warmup_factor=1.0 / 3,
169 | warmup_iters=500,
170 | warmup_method="linear",
171 | last_epoch=-1,
172 | ):
173 | if not list(milestones) == sorted(milestones):
174 | raise ValueError(
175 |                 "Milestones should be a list of "
176 |                 "increasing integers. Got {}".format(milestones)
177 | )
178 |
179 | if warmup_method not in ("constant", "linear"):
180 | raise ValueError(
181 |                 "Only 'constant' or 'linear' warmup_method accepted, "
182 | "got {}".format(warmup_method)
183 | )
184 | self.milestones = milestones
185 | self.gamma = gamma
186 | self.warmup_factor = warmup_factor
187 | self.warmup_iters = warmup_iters
188 | self.warmup_method = warmup_method
189 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
190 |
191 | def get_lr(self):
192 | warmup_factor = 1
193 | if self.last_epoch < self.warmup_iters:
194 | if self.warmup_method == "constant":
195 | warmup_factor = self.warmup_factor
196 | elif self.warmup_method == "linear":
197 | alpha = self.last_epoch / self.warmup_iters
198 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
199 | return [
200 | base_lr
201 | * warmup_factor
202 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
203 | for base_lr in self.base_lrs
204 | ]
205 | ############################################################################################
206 | # Rank loss
207 | ############################################################################################
208 | class Rank_loss(nn.Module):
209 |
210 |     ## Basic idea of the cross-modality rank loss
211 |
212 | def __init__(self, margin_1=1.0, margin_2=1.5, alpha_1=2.4, alpha_2=2.2, tval=1.0):
213 | super(Rank_loss, self).__init__()
214 | self.margin_1 = margin_1 # for same modality
215 | self.margin_2 = margin_2 # for different modalities
216 | self.alpha_1 = alpha_1 # for same modality
217 | self.alpha_2 = alpha_2 # for different modalities
218 | self.tval = tval
219 |
220 | def forward(self, x, targets, sub, norm = True):
221 | if norm:
222 | #x = self.normalize(x)
223 | x = torch.nn.functional.normalize(x, dim=1, p=2)
224 | dist_mat = self.euclidean_dist(x, x) # compute the distance
225 | loss = self.rank_loss(dist_mat, targets, sub)
226 | return loss #,dist_mat
227 |
228 | def rank_loss(self, dist, targets, sub):
229 | loss = 0.0
230 | for i in range(dist.size(0)):
231 | is_pos = targets.eq(targets[i])
232 | is_pos[i] = 0
233 | is_neg = targets.ne(targets[i])
234 |
235 | intra_modality = sub.eq(sub[i])
236 | cross_modality = ~ intra_modality
237 |
238 | mask_pos_intra = is_pos* intra_modality
239 | mask_pos_cross = is_pos* cross_modality
240 | mask_neg_intra = is_neg* intra_modality
241 | mask_neg_cross = is_neg* cross_modality
242 |
243 | ap_pos_intra = torch.clamp(torch.add(dist[i][mask_pos_intra], self.margin_1-self.alpha_1),0)
244 | ap_pos_cross = torch.clamp(torch.add(dist[i][mask_pos_cross], self.margin_2-self.alpha_2),0)
245 |
246 | loss_ap = torch.div(torch.sum(ap_pos_intra), ap_pos_intra.size(0)+1e-5)
247 | loss_ap += torch.div(torch.sum(ap_pos_cross), ap_pos_cross.size(0)+1e-5)
248 |
249 | dist_an_intra = dist[i][mask_neg_intra]
250 | dist_an_cross = dist[i][mask_neg_cross]
251 |
252 | an_less_intra = dist_an_intra[torch.lt(dist[i][mask_neg_intra], self.alpha_1)]
253 | an_less_cross = dist_an_cross[torch.lt(dist[i][mask_neg_cross], self.alpha_2)]
254 |
255 | an_weight_intra = torch.exp(self.tval*(-1* an_less_intra +self.alpha_1))
256 | an_weight_intra_sum = torch.sum(an_weight_intra)+1e-5
257 | an_weight_cross = torch.exp(self.tval*(-1* an_less_cross +self.alpha_2))
258 | an_weight_cross_sum = torch.sum(an_weight_cross)+1e-5
259 | an_sum_intra = torch.sum(torch.mul(self.alpha_1-an_less_intra,an_weight_intra))
260 | an_sum_cross = torch.sum(torch.mul(self.alpha_2-an_less_cross,an_weight_cross))
261 |
262 | loss_an =torch.div(an_sum_intra,an_weight_intra_sum ) +torch.div(an_sum_cross, an_weight_cross_sum)
263 | #loss_an = torch.div(an_sum_cross,an_weight_cross_sum )
264 | loss += loss_ap + loss_an
265 | #loss += loss_an
266 | return loss * 1.0/ dist.size(0)
267 |
268 | def normalize(self, x, axis=-1):
269 | x = 1.* x /(torch.norm(x, 2, axis, keepdim = True).expand_as(x)+ 1e-12)
270 | return x
271 |
272 | def euclidean_dist(self, x, y):
273 | m, n =x.size(0), y.size(0)
274 |
275 | xx = torch.pow(x,2).sum(1, keepdim= True).expand(m,n)
276 | yy = torch.pow(y,2).sum(1, keepdim= True).expand(n,m).t()
277 | dist = xx + yy
278 |         dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = xx + yy - 2 * x @ y.t()
279 | dist = dist.clamp(min =1e-12).sqrt()
280 |
281 | return dist
282 |
283 | ############################################################################################
284 | # Associate Loss
285 | ############################################################################################
286 | class ASS_loss(nn.Module):
287 | def __init__(self, walker_loss=1.0, visit_loss=1.0):
288 | super(ASS_loss, self).__init__()
289 | self.walker_loss = walker_loss
290 | self.visit_loss = visit_loss
291 | self.ce = nn.CrossEntropyLoss()
292 | self.logsoftmax = nn.LogSoftmax(dim=-1)
293 |
294 | def forward(self, feature, targets, sub):
295 | ## normalize
296 | feature = torch.nn.functional.normalize(feature, dim=1, p=2)
297 | loss = 0.0
298 | for i in range(feature.size(0)):
299 | cross_modality = sub.ne(sub[i])
300 |
301 | p_logit_ab, v_loss_ab = self.probablity(feature, cross_modality, targets)
302 | p_logit_ba, v_loss_ba = self.probablity(feature, ~cross_modality, targets)
303 | n1 = targets[cross_modality].size(0)
304 | n2 = targets[~cross_modality].size(0)
305 |
306 | is_pos_ab = targets[cross_modality].expand(n1,n1).eq(targets[cross_modality].expand(n1,n1).t())
307 |
308 | p_target_ab = is_pos_ab.float()/torch.sum(is_pos_ab, dim=1).float().expand_as(is_pos_ab)
309 |
310 |         is_pos_ba = targets[~cross_modality].expand(n2,n2).eq(targets[~cross_modality].expand(n2,n2).t())  # mirror of is_pos_ab
311 | p_target_ba = is_pos_ba.float()/torch.sum(is_pos_ba, dim=1).float().expand_as(is_pos_ba)
312 |
313 | p_logit_ab = self.logsoftmax(p_logit_ab)
314 | p_logit_ba = self.logsoftmax(p_logit_ba)
315 |
316 | loss += (- p_target_ab * p_logit_ab).mean(0).sum()+ (- p_target_ba * p_logit_ba).mean(0).sum()
317 |
318 | loss += 1.0*(v_loss_ab+v_loss_ba)
319 |
320 | return loss/feature.size(0)/4
321 |
322 | def probablity(self, feature, cross_modality, target):
323 | a = feature[cross_modality]
324 | b = feature[~cross_modality]
325 |
326 | match_ab = a.mm(b.t())
327 |
328 | p_ab = F.softmax(match_ab, dim=-1)
329 |         p_ba = F.softmax(match_ab.t(), dim=-1)  # b->a step uses the transposed affinities
330 | p_aba = torch.log(1e-8+p_ab.mm(p_ba))
331 |
332 | visit_loss = self.new_visit(p_ab, target, cross_modality)
333 |
334 | return p_aba, visit_loss
335 |
336 | def new_visit(self, p_ab, target, cross_modality):
337 | p_ab = torch.log(1e-8 +p_ab)
338 | visit_probability = p_ab.mean(dim=0).expand_as(p_ab)
339 | n1 = target[cross_modality].size(0)
340 | n2 = target[~cross_modality].size(0)
341 |         p_target_ab = target[cross_modality].unsqueeze(1).eq(target[~cross_modality].unsqueeze(0))  # (n1, n2) match matrix; assumes balanced batches
342 | p_target_ab = p_target_ab.float()/torch.sum(p_target_ab, dim=1).float().expand_as(p_target_ab)
343 | loss = (- p_target_ab * visit_probability).mean(0).sum()
344 | return loss
345 |
346 | def normalize(self, x, axis=-1):
347 | x = 1.* x /(torch.norm(x, 2, axis, keepdim = True).expand_as(x)+ 1e-12)
348 | return x
349 |
350 | ############################################################################################
351 | # Wasserstein Distance
352 | ############################################################################################
353 | class SinkhornDistance(nn.Module):
354 | def __init__(self, eps=0.01, max_iter=100, reduction='mean'):
355 | super(SinkhornDistance, self).__init__()
356 | self.eps = eps
357 | self.max_iter = max_iter
358 | self.reduction = reduction
359 |
360 | def forward(self, x, y):
361 | # The Sinkhorn algorithm takes as input three variables :
362 | C = self._cost_matrix(x, y) # Wasserstein cost function
363 | C = C.cuda()
364 | n_points = x.shape[-2]
365 | if x.dim() == 2:
366 | batch_size = 1
367 | else:
368 | batch_size = x.shape[0]
369 |
370 | # both marginals are fixed with equal weights
371 | mu = torch.empty(batch_size, n_points, dtype=torch.float,
372 | requires_grad=False).fill_(1.0 / n_points).squeeze()
373 | nu = torch.empty(batch_size, n_points, dtype=torch.float,
374 | requires_grad=False).fill_(1.0 / n_points).squeeze()
375 |
376 | u = torch.zeros_like(mu)
377 | u = u.cuda()
378 | v = torch.zeros_like(nu)
379 | v = v.cuda()
380 | # To check if algorithm terminates because of threshold
381 | # or max iterations reached
382 | actual_nits = 0
383 | # Stopping criterion
384 | thresh = 1e-1
385 |
386 | # Sinkhorn iterations
387 | for i in range(self.max_iter):
388 | u1 = u # useful to check the update
389 | u = self.eps * (torch.log(mu + 1e-8).cuda() - self.lse(self.M(C, u, v))).cuda() + u
390 | v = self.eps * (torch.log(nu + 1e-8).cuda() - self.lse(self.M(C, u, v).transpose(-2, -1))).cuda() + v
391 | err = (u - u1).abs().sum(-1).mean()
392 | err = err.cuda()
393 |
394 | actual_nits += 1
395 | if err.item() < thresh:
396 | break
397 |
398 | U, V = u, v
399 | # Transport plan pi = diag(a)*K*diag(b)
400 | pi = torch.exp(self.M(C, U, V))
401 | # Sinkhorn distance
402 | cost = torch.sum(pi * C, dim=(-2, -1))
403 |
404 | if self.reduction == 'mean':
405 | cost = cost.mean()
406 | elif self.reduction == 'sum':
407 | cost = cost.sum()
408 |
409 | return cost
410 |
411 | def M(self, C, u, v):
412 | "Modified cost for logarithmic updates"
413 | "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
414 | return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps
415 |
416 | @staticmethod
417 | def _cost_matrix(x, y, p=2):
418 | "Returns the matrix of $|x_i-y_j|^p$."
419 | x_col = x.unsqueeze(-2)
420 | y_lin = y.unsqueeze(-3)
421 | C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
422 | return C
423 |
424 | @staticmethod
425 | def lse(A):
426 | "log-sum-exp"
427 | # add 10^-6 to prevent NaN
428 | result = torch.log(torch.exp(A).sum(-1) + 1e-6)
429 | return result
430 |
431 | @staticmethod
432 | def ave(u, u1, tau):
433 | "Barycenter subroutine, used by kinetic acceleration through extrapolation."
434 | return tau * u + (1 - tau) * u1
--------------------------------------------------------------------------------
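
A minimal sketch of the label-smoothing loss and warmup schedule defined above; the linear model and all numbers are placeholders, not the paper's settings, and the import assumes `utlis.py` is on the Python path.

```python
import torch
from torch import nn

from utlis import CrossEntropyLabelSmooth, WarmupMultiStepLR

num_classes = 395  # e.g. the number of training identities (illustrative)
model = nn.Linear(2048, num_classes)          # stand-in for the real backbone
criterion = CrossEntropyLabelSmooth(num_classes, epsilon=0.1, use_gpu=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = WarmupMultiStepLR(optimizer, milestones=[40, 70],
                              warmup_iters=10, warmup_method='linear')

for epoch in range(80):
    logits = model(torch.randn(8, 2048))      # fake batch of pooled features
    loss = criterion(logits, torch.randint(0, num_classes, (8,)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # stepped per epoch, so milestones/warmup count epochs
```
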