├── feat_stas
│   ├── strategies
│   │   ├── __init__.py
│   │   ├── core_set.py
│   │   ├── util.py
│   │   └── strategy.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── inception.cpython-37.pyc
│   │   │   ├── inception.cpython-39.pyc
│   │   │   ├── net2layer.cpython-37.pyc
│   │   │   └── net2layer.cpython-39.pyc
│   │   └── inception.py
│   ├── feat_extraction.py
│   ├── SnP.py
│   └── dataloader.py
├── images
│   ├── SnP.gif
│   ├── SnP.jpg
│   └── datasets.jpg
├── trainingset_search_vehicle.py
├── trainingset_search_person.py
└── README.md
/feat_stas/strategies/__init__.py:
--------------------------------------------------------------------------------
1 | from .core_set import CoreSet
2 |
--------------------------------------------------------------------------------
/images/SnP.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yorkeyao/SnP/HEAD/images/SnP.gif
--------------------------------------------------------------------------------
/images/SnP.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yorkeyao/SnP/HEAD/images/SnP.jpg
--------------------------------------------------------------------------------
/images/datasets.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yorkeyao/SnP/HEAD/images/datasets.jpg
--------------------------------------------------------------------------------
/feat_stas/models/__pycache__/inception.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yorkeyao/SnP/HEAD/feat_stas/models/__pycache__/inception.cpython-37.pyc
--------------------------------------------------------------------------------
/feat_stas/models/__pycache__/inception.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yorkeyao/SnP/HEAD/feat_stas/models/__pycache__/inception.cpython-39.pyc
--------------------------------------------------------------------------------
/feat_stas/models/__pycache__/net2layer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yorkeyao/SnP/HEAD/feat_stas/models/__pycache__/net2layer.cpython-37.pyc
--------------------------------------------------------------------------------
/feat_stas/models/__pycache__/net2layer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yorkeyao/SnP/HEAD/feat_stas/models/__pycache__/net2layer.cpython-39.pyc
--------------------------------------------------------------------------------
/trainingset_search_vehicle.py:
--------------------------------------------------------------------------------
1 | import os
2 | from feat_stas import SnP
3 | import argparse
4 | import numpy as np
5 | import logging
6 |
7 | parser = argparse.ArgumentParser(description='outputs')
8 | parser.add_argument('--result_dir', type=str, metavar='PATH', default='sample_data/')
9 | parser.add_argument('--c_num', default=50, type=int, help='number of cluster')
10 | parser.add_argument('--n_num_id', default=2000, type=int, help='number of ids')
11 | parser.add_argument('--target', type=str, default='veri', choices=['veri', 'alice-vehicle'], help='target dataset used to calculate the sampling score')
12 | parser.add_argument('--output_data', type=str, metavar='PATH', default='/data/reid_data/alice-vehicle/searched_2000(3000)_directly_6data')
13 | parser.add_argument('--ID_sampling_method', type=str, default='SnP', choices=['greedy', 'random', 'SnP'], help='how to sample')
14 | parser.add_argument('--img_sampling_ratio', default=1.0, type=float, help='image sampling ratio')
15 | parser.add_argument('--img_sampling_method', type=str, default='FPS', choices=['FPS', 'random'], help='how to sample')
16 | parser.add_argument('--no_sample', action='store_true', help='do not perform sample')
17 | parser.add_argument('--cuda', action='store_true', help='whether cuda is enabled')
18 | parser.add_argument('--FD_model', type=str, default='inception', choices=['inception', 'posenet'],
19 | help='model to calculate FD distance')
20 |
21 | opt = parser.parse_args()
22 | result_dir=opt.result_dir
23 | c_num=opt.c_num
24 | n_num_id=opt.n_num_id
25 |
26 | np.random.seed(0)
27 |
28 | # data pool
29 | data_dict = {
30 | 'veri': '/data/reid_data/VeRi/',
31 | 'aic': '/data/reid_data/AIC19-reid/',
32 | 'alice-vehicle': '/data/reid_data/alice-vehicle/',
33 | 'vid': '/data/reid_data/VehicleID_V1.0/',
34 | 'vehiclex': '/data/reid_data/alice-vehicle/vehicleX_random_attributes/',
35 | 'veri-wild': '/data/reid_data/veri-wild/VeRI-Wild/',
36 | 'stanford_cars': '/data/reid_data/stanfordcar/',
37 | 'compcars': '/data/reid_data/compcar/CompCars/',
38 | 'vd1': '/data/reid_data/PKU-VD/PKU-VD/VD1/',
39 | 'vd2': '/data/reid_data/PKU-VD/PKU-VD/VD2/'
40 | }
41 |
42 | database_id = ['veri', 'aic', 'vid', 'veri-wild', 'vehiclex', 'stanford_cars', 'vd1', 'vd2']
43 |
44 | if opt.target == 'alice-vehicle':
45 | target = data_dict['alice-vehicle'] + 'alice-vehicle_train'
46 | if opt.target == 'veri':
47 | target = data_dict['veri'] + "image_train"
48 |     database_id.remove('veri')
49 |
50 |
52 |
53 | if not os.path.isdir(result_dir):
54 | os.makedirs(result_dir)
55 |
56 | if os.path.isdir(opt.output_data):
57 |     raise RuntimeError('output dir already exists: ' + opt.output_data)
58 |
59 | SnP.training_set_search(target, data_dict, database_id, opt, result_dir, c_num, version="vehicle")
60 |
61 |
62 |
--------------------------------------------------------------------------------
/trainingset_search_person.py:
--------------------------------------------------------------------------------
1 | import os
2 | from feat_stas import SnP
3 | import argparse
4 | import numpy as np
5 |
6 |
7 | parser = argparse.ArgumentParser(description='outputs')
8 | parser.add_argument('--result_dir', type=str, metavar='PATH', default='sample_data/')
9 | parser.add_argument('--c_num', default=50, type=int, help='number of cluster')
10 | parser.add_argument('--n_num_id', default=301, type=int, help='number of ids')
11 | parser.add_argument('--target', type=str, default='market', choices=['market', 'alice-person'], help='target dataset used to calculate the sampling score')
12 | parser.add_argument('--output_data', type=str, metavar='PATH', default='/data/reid_data/alice-person/random_301')
13 | parser.add_argument('--ID_sampling_method', type=str, default='SnP', choices=['greedy', 'random', 'SnP'], help='how to sample')
14 | parser.add_argument('--img_sampling_ratio', default=1.0, type=float, help='image sampling ratio')
15 | parser.add_argument('--img_sampling_method', type=str, default='FPS', choices=['FPS', 'random'], help='how to sample')
16 | parser.add_argument('--no_sample', action='store_true', help='do not perform sample')
17 | parser.add_argument('--cuda', action='store_true', help='whether cuda is enabled')
18 | parser.add_argument('--FD_model', type=str, default='inception', choices=['inception', 'posenet'],
19 | help='model to calculate FD distance')
20 |
21 |
22 | opt = parser.parse_args()
23 | result_dir=opt.result_dir
24 | c_num=opt.c_num
25 | n_num_id=opt.n_num_id
26 | np.random.seed(0)
27 |
28 | # data pool
29 | data_dict = {
30 | 'duke': '/data/reid_data/duke_reid/bounding_box_train/', #duke
31 | 'market': '/data/reid_data/market/bounding_box_train/', #market
32 | 'msmt': '/data/reid_data/MSMT/MSMT_bounding_box_train/', #msmt
33 | 'cuhk': '/data/reid_data/cuhk03_release/', #cuhk
34 | 'alice-person': '/data/reid_data/alice-person/bounding_box_train/', #alice
35 | 'raid': '/data/reid_data/RAiD_Dataset-master/', #raid
36 | 'unreal': '/data/reid_data/unreal/UnrealPerson-data/unreal_v3.1/images/', #unreal
37 | 'personx': '/data/reid_data/personx/bounding_box_train/', #personx
38 | 'randperson': '/data/reid_data/randperson_subset/randperson_subset/', #randperson
39 | 'pku': '/data/reid_data/PKU-Reid/PKUv1a_128x48/', #pku
40 | 'ilids': '/data/reid_data/i-LIDS-VID/', #ilids
41 | 'viper': '/data/reid_data/VIPeR/', # viper
42 | }
43 |
44 |
45 | database_id = ['duke', 'market', 'msmt', 'cuhk', 'raid', 'unreal', 'personx', 'randperson', 'pku', 'viper']
46 |
47 | # database_id = ['duke', 'market', 'msmt'] # pool_A
48 | # database_id = ['duke', 'market', 'msmt', 'unreal'] # pool_B
49 | # database_id = ['duke', 'market', 'msmt', 'unreal', 'personx', 'randperson'] # pool_C
50 |
51 | if opt.target == 'alice-person':
52 | target = data_dict['alice-person']
53 | if opt.target == 'market':
54 | target = data_dict['market']
55 |     database_id.remove('market')
56 |
58 |
59 | if not os.path.isdir(result_dir):
60 | os.makedirs(result_dir)
61 |
62 | if os.path.isdir(opt.output_data):
63 |     raise RuntimeError('output dir already exists: ' + opt.output_data)
64 |
65 | SnP.training_set_search(target, data_dict, database_id, opt, result_dir, c_num, version="person")
66 |
67 |
68 |
--------------------------------------------------------------------------------
/feat_stas/strategies/core_set.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pdb
3 | from .strategy import Strategy
4 | from sklearn.neighbors import NearestNeighbors
5 | import pickle
6 | from datetime import datetime
7 | from sklearn.metrics import pairwise_distances
8 | from tqdm import tqdm
9 |
10 | class CoreSet(Strategy):
11 | def __init__(self, X, Y, idxs_lb, net, handler, args, tor=1e-4):
12 | super(CoreSet, self).__init__(X, Y, idxs_lb, net, handler, args)
13 | self.tor = tor
14 |
15 | def furthest_first(self, X, X_set, n):
16 | m = np.shape(X)[0]
17 | if np.shape(X_set)[0] == 0:
18 | min_dist = np.tile(float("inf"), m)
19 | else:
20 | dist_ctr = pairwise_distances(X, X_set)
21 | min_dist = np.amin(dist_ctr, axis=1)
22 |
23 | idxs = []
24 |
25 | for i in tqdm(range(n)):
26 | idx = min_dist.argmax()
27 | idxs.append(idx)
28 | dist_new_ctr = pairwise_distances(X, X[[idx], :])
29 | for j in range(m):
30 | min_dist[j] = min(min_dist[j], dist_new_ctr[j, 0])
31 |
32 | return idxs
33 |
34 | def query(self, n, embedding, idxs_lb):
35 | t_start = datetime.now()
36 | idxs_unlabeled = np.arange(self.n_pool)[~idxs_lb]
37 | lb_flag = self.idxs_lb.copy()
38 | # embedding = self.get_embedding(self.X, self.Y)
39 | # embedding = embedding.numpy()
40 |
41 | chosen = self.furthest_first(embedding[idxs_unlabeled, :], embedding[lb_flag, :], n)
42 |
43 | return idxs_unlabeled[chosen]
44 |
45 |
46 | def query_old(self, n):
47 | lb_flag = self.idxs_lb.copy()
48 | embedding = self.get_embedding(self.X, self.Y)
49 | embedding = embedding.numpy()
50 |
51 | print('calculate distance matrix')
52 | t_start = datetime.now()
53 | dist_mat = np.matmul(embedding, embedding.transpose())
54 | sq = np.array(dist_mat.diagonal()).reshape(len(self.X), 1)
55 | dist_mat *= -2
56 | dist_mat += sq
57 | dist_mat += sq.transpose()
58 | dist_mat = np.sqrt(dist_mat)
59 | print(datetime.now() - t_start)
60 | print('calculate greedy solution')
61 | t_start = datetime.now()
62 | mat = dist_mat[~lb_flag, :][:, lb_flag]
63 |
64 | for i in range(n):
65 | if i % 10 == 0:
66 | print('greedy solution {}/{}'.format(i, n))
67 | mat_min = mat.min(axis=1)
68 | q_idx_ = mat_min.argmax()
69 | q_idx = np.arange(self.n_pool)[~lb_flag][q_idx_]
70 | lb_flag[q_idx] = True
71 | mat = np.delete(mat, q_idx_, 0)
72 | mat = np.append(mat, dist_mat[~lb_flag, q_idx][:, None], axis=1)
73 |
74 | print(datetime.now() - t_start)
75 | opt = mat.min(axis=1).max()
76 |
77 | bound_u = opt
78 | bound_l = opt/2.0
79 | delta = opt
80 |
81 | xx, yy = np.where(dist_mat <= opt)
82 | dd = dist_mat[xx, yy]
83 |
84 | lb_flag_ = self.idxs_lb.copy()
85 | subset = np.where(lb_flag_==True)[0].tolist()
86 |
87 | SEED = 5
88 | sols = None
89 |
90 | if sols is None:
91 | q_idxs = lb_flag
92 | else:
93 | lb_flag_[sols] = True
94 | q_idxs = lb_flag_
95 | print('sum q_idxs = {}'.format(q_idxs.sum()))
96 |
97 | return np.arange(self.n_pool)[(self.idxs_lb ^ q_idxs)]
98 |
--------------------------------------------------------------------------------
/feat_stas/strategies/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import logging
4 |
5 | def logger_config(log_path,logging_name):
6 | logger = logging.getLogger(logging_name)
7 | logger.setLevel(level=logging.DEBUG)
8 | handler = logging.FileHandler(log_path, encoding='UTF-8',mode='w')
9 | handler.setLevel(logging.INFO)
10 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
11 | handler.setFormatter(formatter)
12 | console = logging.StreamHandler()
13 | console.setLevel(logging.DEBUG)
14 | logger.addHandler(handler)
15 | logger.addHandler(console)
16 | return logger
17 |
18 |
19 |
20 | def save_df_as_npy(path, df):
21 | """
22 | Save pandas dataframe (multi-index or non multi-index) as an NPY file
23 | for later retrieval. It gets a list of input dataframe's index levels,
24 | column levels and underlying array data and saves it as an NPY file.
25 |
26 | Parameters
27 | ----------
28 | path : str
29 | Path for saving the dataframe.
30 | df : pandas dataframe
31 | Input dataframe's index, column and underlying array data are gathered
32 | in a nested list and saved as an NPY file.
33 | This is capable of handling multi-index dataframes.
34 |
35 | Returns
36 | -------
37 | out : None
38 |
39 | """
40 |
41 | if df.index.nlevels>1:
42 | lvls = [list(i) for i in df.index.levels]
43 |         lbls = [list(i) for i in df.index.codes]  # .labels was renamed to .codes in pandas 0.24
44 | indx = [lvls, lbls]
45 | else:
46 | indx = list(df.index)
47 |
48 | if df.columns.nlevels>1:
49 | lvls = [list(i) for i in df.columns.levels]
50 |         lbls = [list(i) for i in df.columns.codes]  # .labels was renamed to .codes in pandas 0.24
51 | cols = [lvls, lbls]
52 | else:
53 | cols = list(df.columns)
54 |
55 | data_flat = df.values.ravel()
56 | df_all = [indx, cols, data_flat]
57 | np.save(path, df_all)
58 |
59 | def load_df_from_npy(path):
60 | """
61 | Load pandas dataframe (multi-index or regular one) from NPY file.
62 |
63 | Parameters
64 | ----------
65 | path : str
66 | Path to the NPY file containing the saved pandas dataframe data.
67 |
68 | Returns
69 | -------
70 | df : Pandas dataframe
71 | Pandas dataframe that's retrieved back saved earlier as an NPY file.
72 |
73 | """
74 |
75 |     df_all = np.load(path, allow_pickle=True)  # the saved object array needs allow_pickle on NumPy >= 1.16.3
76 | if isinstance(df_all[0][0], list):
77 |         indx = pd.MultiIndex(levels=df_all[0][0], codes=df_all[0][1])
78 | else:
79 | indx = df_all[0]
80 |
81 | if isinstance(df_all[1][0], list):
82 |         cols = pd.MultiIndex(levels=df_all[1][0], codes=df_all[1][1])
83 | else:
84 | cols = df_all[1]
85 |
86 | df0 = pd.DataFrame(index=indx, columns=cols)
87 | df0[:] = df_all[2].reshape(df0.shape)
88 | return df0
89 |
90 | def max_columns(df0, cols=''):
91 | """
92 | Get dataframe with best configurations
93 |
94 | Parameters
95 | ----------
96 | df0 : pandas dataframe
97 | Input pandas dataframe, which could be a multi-index or a regular one.
98 | cols : list, optional
99 | List of strings that would be used as the column IDs for
100 | output pandas dataframe.
101 |
102 | Returns
103 | -------
104 | df : Pandas dataframe
105 | Pandas dataframe with best configurations for each row of the input
106 | dataframe for maximum value, where configurations refer to the column
107 | IDs of the input dataframe.
108 |
109 | """
110 |
111 |     df = df0.reindex(columns=sorted(df0.columns))  # reindex_axis was removed in pandas 0.21+
112 | if df.columns.nlevels==1:
113 | idx = df.values.argmax(-1)
114 | max_vals = df.values[range(len(idx)), idx]
115 | max_df = pd.DataFrame({'':df.columns[idx], 'Out':max_vals})
116 | max_df.index = df.index
117 | else:
118 | input_args = [list(i) for i in df.columns.levels]
119 | input_arg_lens = [len(i) for i in input_args]
120 |
121 | shp = [len(list(i)) for i in df.index.levels] + input_arg_lens
122 | speedups = df.values.reshape(shp)
123 |
124 | idx = speedups.reshape(speedups.shape[:2] + (-1,)).argmax(-1)
125 | argmax_idx = np.dstack((np.unravel_index(idx, input_arg_lens)))
126 | best_args = np.array(input_args)[np.arange(argmax_idx.shape[-1]), argmax_idx]
127 |
128 | N = len(input_arg_lens)
129 | max_df = pd.DataFrame(best_args.reshape(-1,N), index=df.index)
130 | max_vals = speedups.max(axis=tuple(-np.arange(len(input_arg_lens))-1)).ravel()
131 | max_df['Out'] = max_vals
132 | if cols!='':
133 | max_df.columns = cols
134 | return max_df
135 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Search and Pruning (SnP) Framework for Training Set Search
2 |
3 | This repository contains our code for the paper 'Large-scale Training Data Search for Object Re-identification', CVPR 2023.
4 |
5 | Related material: [Paper](https://arxiv.org/pdf/2303.16186), [Video](https://youtu.be/OAZ0Pka2mKE), [Zhihu](https://zhuanlan.zhihu.com/p/641872113)
6 |
7 |
8 |
9 |
14 |
15 | As shown in the figure above, we present a search and pruning (SnP) solution to the training data search problem in object re-ID. The source data pool is one order of magnitude larger than existing re-ID training sets in terms of the number of images and the number of identities. When the target is AlicePerson, our method (SnP) produces a training set 80% smaller than the source pool while achieving similar or even higher re-ID accuracy. The searched training set is also superior to existing individual training sets such as Market-1501, DukeMTMC-reID, and MSMT17.
16 |
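To make the search-and-prune idea concrete, below is a minimal, illustrative sketch of the two stages over precomputed per-ID mean features. It is not the repository implementation: the real search in `feat_stas/SnP.py` ranks candidates with Fréchet-distance statistics, and the pruning in `feat_stas/strategies/core_set.py` uses farthest-point sampling; all names in the sketch are ours.

```python
import numpy as np

def snp_sketch(id_feats, target_mean, n_ids, img_ratio):
    """Illustrative search-and-prune over per-ID mean features."""
    # Search: rank source IDs by closeness to the target-domain mean
    # (a crude stand-in for the FD-based ranking in feat_stas/SnP.py).
    gaps = np.linalg.norm(id_feats - target_mean, axis=1)
    kept = np.argsort(gaps)[:n_ids]

    # Prune: farthest-point sampling keeps a diverse fraction of the
    # searched IDs, mirroring --img_sampling_method 'FPS' with
    # --img_sampling_ratio.
    feats = id_feats[kept]
    chosen = [0]
    min_d = np.linalg.norm(feats - feats[0], axis=1)
    n_keep = max(1, int(img_ratio * len(kept)))
    while len(chosen) < n_keep:
        nxt = int(min_d.argmax())
        chosen.append(nxt)
        min_d = np.minimum(min_d, np.linalg.norm(feats - feats[nxt], axis=1))
    return kept[np.array(chosen)]
```
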
17 | ## Requirements
18 |
19 | - scikit-learn
20 | - SciPy 1.2.1
21 | - PyTorch 1.7.0 + torchvision 0.8.1
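
Judging from the imports under `feat_stas/`, the code additionally needs `pandas`, `k-means-constrained`, `scikit-image`, `matplotlib`, `h5py`, and (optionally) `tqdm`. A possible install line, with the package set inferred from those imports and versions left unpinned:

```
pip install scikit-learn scipy torch torchvision pandas k-means-constrained scikit-image matplotlib h5py tqdm
```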
22 |
23 | ## Re-ID Datasets Preparation
24 |
25 |
26 | 
27 |
28 | Please prepare the following datasets for person re-ID: [DukeMTMC-reID](https://exposing.ai/duke_mtmc/), [Market-1501](https://zheng-lab.cecs.anu.edu.au/Project/project_reid.html), [MSMT17](http://www.pkuvmc.com/publications/msmt17.html), [CUHK03](https://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html), [RAiD](https://cs-people.bu.edu/dasabir/raid.php), [PersonX](https://github.com/sxzrt/Instructions-of-the-PersonX-dataset), [UnrealPerson](https://github.com/FlyHighest/UnrealPerson), [RandPerson](https://github.com/VideoObjectSearch/RandPerson), [PKU-Reid](https://github.com/charliememory/PKU-Reid-Dataset), [VIPeR](https://vision.soe.ucsc.edu/node/178), [AlicePerson (target data in VisDA20)](https://github.com/Simon4Yan/VisDA2020).
29 |
30 | You may need to sign up to get access to some of these datasets. Please store them in a file structure like this:
31 |
32 | ```
33 | ~
34 | └───reid_data
35 | └───duke_reid
36 | │ │ bounding_box_train
37 | │ │ ...
38 | │
39 | └───market
40 | │ │ bounding_box_train
41 | │ │ ...
42 | │
43 | └───MSMT
44 | │ │ MSMT_bounding_box_train
45 | │ │ ...
46 | │
47 | └───cuhk03_release
48 | │ │ cuhk-03.mat
49 | │ │ ...
50 | │
51 | └───alice-person
52 | │ │ bounding_box_train
53 | │ │ ...
54 | │
55 | └───RAiD_Dataset-master
56 | │ │ bounding_box_train
57 | │ │ ...
58 | │
59 | └───unreal
60 | │ │ UnrealPerson-data
61 | │ │ ...
62 | │
63 | └───randperson_subset
64 | │ │ randperson_subset
65 | │ │ ...
66 | │
67 | └───PKU-Reid
68 | │ │ PKUv1a_128x48
69 | │ │ ...
70 | │
71 | └───i-LIDS-VID
72 | │ │ images
73 | │ │ ...
74 | │
75 | └───VIPeR
76 | │ │ images
77 | │ │ ...
78 | ```
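
These directory names mirror the `data_dict` paths defined at the top of `trainingset_search_person.py`; if your data lives elsewhere, edit that dictionary rather than moving the datasets.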
79 |
80 | Please prepare the following datasets for vehicle re-ID: [VeRi](https://github.com/JDAI-CV/VeRidataset), [CityFlow-reID](https://www.aicitychallenge.org/), [VehicleID](https://www.pkuml.org/resources/pku-vehicleid.html), [VeRi-Wild](https://github.com/PKU-IMRE/VERI-Wild), [VehicleX](https://drive.google.com/file/d/1qySICqFJdgjMVi6CrLwVxEOuvgcQgtF_/view?usp=sharing), [Stanford Cars](http://ai.stanford.edu/~jkrause/cars/car_dataset.html), [PKU-vd1 and PKU-vd2](https://www.pkuml.org/resources/pku-vds.html). AliceVehicle will be made publicly available by our team shortly.
81 |
82 | Please store these datasets in a file structure like this:
83 |
84 | ```
85 | ~
86 | └───reid_data
87 | └───VeRi
88 | │ │ bounding_box_train
89 | │ │ ...
90 | │
91 | └───AIC19-reid
92 | │ │ bounding_box_train
93 | │ │ ...
94 | │
95 | └───VehicleID_V1.0
96 | │ │ image
97 | │ │ ...
98 | │
99 | └───vehicleX_random_attributes
100 | │ │ ...
101 | │
102 | └───veri-wild
103 | │ │ VeRI-Wild
104 | │ │ ...
105 | │
106 | └───stanford_cars
107 | │ │ cars_train
108 | │ │ ...
109 | │
110 | └───compcars
111 | │ │ CompCars
112 | │ │ ...
113 | │
114 | └───PKU-VD
115 | │ │ VD1
116 | │ │ VD2
117 | │ │ ...
118 | ```
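
Likewise, these locations mirror the `data_dict` paths in `trainingset_search_vehicle.py`; adjust that dictionary if your layout differs.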
119 |
120 | ## Running example
121 |
122 |
127 |
128 | The SnP framework is shown in the animation above. To run this process with Market as the target, we can search a training set with 2860 IDs using the command below:
129 |
130 | ```bash
131 | python trainingset_search_person.py --target 'market' \
132 | --result_dir 'results/sample_data_market/' --n_num_id 2860 \
133 | --ID_sampling_method SnP --img_sampling_method 'FPS' --img_sampling_ratio 0.5 \
134 | --output_data '/data/reid_data/market/SnP_2860IDs_0.5Imgs_0610'
135 | ```
136 |
137 | When VeRi is used as the target, the command is:
138 |
139 | ```bash
140 | python trainingset_search_vehicle.py --target 'veri' \
141 | --result_dir './results/sample_data_veri/' --n_num_id 3118 \
142 | --ID_sampling_method SnP --img_sampling_method 'FPS' --img_sampling_ratio 0.5 \
143 | --output_data '/data/data/VeRi/SnP_3118IDs_0.5Imgs_0610'
144 | ```
145 |
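Both scripts cache extracted features (e.g. `target_feature.npy` and `feature_infer.npy`) under `--result_dir`, so re-running with the same `--result_dir` reuses them instead of re-extracting. Add `--cuda` to run the feature extractor on GPU.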
146 |
147 |
148 | ## Citation
149 |
150 | If you find this code useful, please kindly cite:
151 |
152 | ```
153 | @article{yao2023large,
154 | title={Large-scale Training Data Search for Object Re-identification},
155 | author={Yao, Yue and Lei, Huan and Gedeon, Tom and Zheng, Liang},
156 | journal={arXiv preprint arXiv:2303.16186},
157 | year={2023}
158 | }
159 | ```
160 |
161 | If you have any questions, feel free to contact yue.yao@anu.edu.au.
162 |
--------------------------------------------------------------------------------
/feat_stas/feat_extraction.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import torchvision
4 |
5 | import numpy as np
6 | import torch
7 | from scipy import linalg
8 | from matplotlib.pyplot import imread, imsave
9 | from torch.nn.functional import adaptive_avg_pool2d, adaptive_max_pool2d
10 | from scipy.special import softmax
11 | from shutil import copyfile
12 | from k_means_constrained import KMeansConstrained
13 | from glob import glob
14 | import os.path as osp
15 | from PIL import Image
17 | from skimage.transform import resize
18 | from sklearn.cluster import KMeans
19 | import xml.dom.minidom as XD
20 | try:
21 | from tqdm import tqdm
22 | except ImportError:
23 |     # If tqdm is not available, provide a mock version of it
24 | def tqdm(x): return x
25 |
26 |
27 |
28 | def make_square(image, max_dim = 512):
29 | max_dim = max(np.shape(image)[0], np.shape(image)[1])
30 | h, w = image.shape[:2]
31 | top_pad = (max_dim - h) // 2
32 | bottom_pad = max_dim - h - top_pad
33 | left_pad = (max_dim - w) // 2
34 | right_pad = max_dim - w - left_pad
35 | padding = [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
36 | image = np.pad(image, padding, mode='constant', constant_values=0)
37 | window = (top_pad, left_pad, h + top_pad, w + left_pad)
38 | return image
39 |
40 | def get_activations(opt, files, model, batch_size=50, dims=8192,
41 | cuda=False, verbose=False):
42 | """Calculates the activations of the pool_3 layer for all images.
43 |
44 | Params:
45 | -- files : List of image files paths
46 | -- model : Instance of inception model
47 | -- batch_size : Batch size of images for the model to process at once.
48 | Make sure that the number of samples is a multiple of
49 | the batch size, otherwise some samples are ignored. This
50 | behavior is retained to match the original FID score
51 | implementation.
52 | -- dims : Dimensionality of features returned by Inception
53 | -- cuda : If set to True, use GPU
54 | -- verbose : If set to True and parameter out_step is given, the number
55 | of calculated batches is reported.
56 | Returns:
57 | -- A numpy array of dimension (num images, dims) that contains the
58 | activations of the given tensor when feeding inception with the
59 | query tensor.
60 | """
61 | model.eval()
62 |
63 | # if len(files) % batch_size != 0:
64 | # print(('Warning: number of images is not a multiple of the '
65 | # 'batch size. Some samples are going to be ignored.'))
66 | if batch_size > len(files):
67 | print(('Warning: batch size is bigger than the data size. '
68 | 'Setting batch size to data size'), len(files))
69 | batch_size = len(files)
70 |
71 | n_batches = len(files) // batch_size
72 |     n_remainder = len(files) % batch_size
73 |
74 |     print('number of batches is %d' % n_batches)
75 | n_used_imgs = n_batches * batch_size
76 |
77 | pred_arr = np.empty((n_used_imgs+n_remainder, dims))
78 | if n_remainder!=0:
79 | n_batches=n_batches+1
80 | for i in tqdm(range(n_batches)):
81 | if verbose:
82 | print('\rPropagating batch %d/%d' % (i + 1, n_batches),
83 | end='', flush=True)
84 | start = i * batch_size
85 | if n_remainder!=0 and i==n_batches-1:
86 | end = start + n_remainder
87 | else:
88 | end = start + batch_size
89 |
90 | # print (files[start:end])
91 | images = np.array([resize( imread(str(f)).astype(np.float32), (64, 64, 3) ).astype(np.float32)
92 | for f in files[start:end]])
93 |
94 | images = images.transpose((0, 3, 1, 2))
95 | images /= 255
96 |
97 | batch = torch.from_numpy(images).type(torch.FloatTensor)
98 | if cuda:
99 | batch = batch.cuda()
100 |
101 | if opt.FD_model == 'inception':
102 | pred = model(batch)[0]
103 | # If model output is not scalar, apply global spatial average pooling.
104 | # This happens if you choose a dimensionality not equal 2048.
105 | if pred.shape[2] != 1 or pred.shape[3] != 1:
106 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
107 | if opt.FD_model == 'posenet':
108 | pred = model(batch)
109 | # print (np.shape (pred))
110 | pred = adaptive_max_pool2d(pred, output_size=(1, 1))
111 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(end - start, -1)
112 | print('\rPropagating batch %d/%d' % (i + 1, n_batches))
113 |
114 | if verbose:
115 | print(' done')
116 |
117 | return pred_arr
118 |
119 |
120 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
121 | """Numpy implementation of the Frechet Distance.
122 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
123 | and X_2 ~ N(mu_2, C_2) is
124 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
125 |
126 | Stable version by Dougal J. Sutherland.
127 |
128 | Params:
129 | -- mu1 : Numpy array containing the activations of a layer of the
130 | inception net (like returned by the function 'get_predictions')
131 | for generated samples.
132 | -- mu2 : The sample mean over activations, precalculated on an
133 | representative data set.
134 | -- sigma1: The covariance matrix over activations for generated samples.
135 | -- sigma2: The covariance matrix over activations, precalculated on an
136 | representative data set.
137 |
138 | Returns:
139 | -- : The Frechet Distance.
140 | """
141 |
142 | mu1 = np.atleast_1d(mu1)
143 | mu2 = np.atleast_1d(mu2)
144 |
145 | sigma1 = np.atleast_2d(sigma1)
146 | sigma2 = np.atleast_2d(sigma2)
147 |
148 | assert mu1.shape == mu2.shape, \
149 | 'Training and test mean vectors have different lengths'
150 | assert sigma1.shape == sigma2.shape, \
151 | 'Training and test covariances have different dimensions'
152 |
153 | diff = mu1 - mu2
154 |
155 | # Product might be almost singular
156 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
157 | if not np.isfinite(covmean).all():
158 | msg = ('fid calculation produces singular product; '
159 | 'adding %s to diagonal of cov estimates') % eps
160 | print(msg)
161 | offset = np.eye(sigma1.shape[0]) * eps
162 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
163 |
164 | # Numerical error might give slight imaginary component
165 | if np.iscomplexobj(covmean):
166 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
167 | m = np.max(np.abs(covmean.imag))
168 | raise ValueError('Imaginary component {}'.format(m))
169 | covmean = covmean.real
170 |
171 | tr_covmean = np.trace(covmean)
172 |
173 | return (diff.dot(diff) + np.trace(sigma1) +
174 | np.trace(sigma2) - 2 * tr_covmean)
175 |
176 |
177 | def calculate_activation_statistics(opt, files, model, batch_size=50,
178 | dims=8192, cuda=False, verbose=False):
179 | """Calculation of the statistics used by the FID.
180 | Params:
181 | -- files : List of image files paths
182 | -- model : Instance of inception model
183 | -- batch_size : The images numpy array is split into batches with
184 | batch size batch_size. A reasonable batch size
185 | depends on the hardware.
186 | -- dims : Dimensionality of features returned by Inception
187 | -- cuda : If set to True, use GPU
188 | -- verbose : If set to True and parameter out_step is given, the
189 | number of calculated batches is reported.
190 | Returns:
191 | -- mu : The mean over samples of the activations of the pool_3 layer of
192 | the inception model.
193 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
194 | the inception model.
195 | """
196 | act = get_activations(opt, files, model, batch_size, dims, cuda, verbose)
197 | mu = np.mean(act, axis=0)
198 | sigma = np.cov(act, rowvar=False)
199 | #eigen_vals, eigen_vecs= np.linalg.eig(sigma)
200 | #sum_eigen_val=eigen_vals.sum().real
201 | sum_eigen_val = (sigma.diagonal()).sum()
202 | return mu, sigma, sum_eigen_val
203 |
204 |
205 | def _compute_statistics_of_path(opt, path, model, batch_size, dims, cuda):
206 | if path.endswith('.npz'):
207 | f = np.load(path)
208 | m, s = f['mu'][:], f['sigma'][:]
209 | f.close()
210 | else:
211 | path = pathlib.Path(path)
212 | files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
213 | #random.shuffle(files)
214 | #files = files[:2000]
215 | m, s, sum_eigen_val = calculate_activation_statistics(opt, files, model, batch_size,
216 | dims, cuda)
217 | return m, s, sum_eigen_val
218 |
--------------------------------------------------------------------------------
/feat_stas/models/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 | # Inception weights ported to Pytorch from
12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14 |
15 |
16 | class InceptionV3(nn.Module):
17 | """Pretrained InceptionV3 network returning feature maps"""
18 |
19 | # Index of default block of inception to return,
20 | # corresponds to output of final average pooling
21 | DEFAULT_BLOCK_INDEX = 3
22 |
23 | # Maps feature dimensionality to their output blocks indices
24 | BLOCK_INDEX_BY_DIM = {
25 | 64: 0, # First max pooling features
26 |         192: 1,   # Second max pooling features
27 | 768: 2, # Pre-aux classifier features
28 | 2048: 3 # Final average pooling features
29 | }
30 |
31 | def __init__(self,
32 | output_blocks=[DEFAULT_BLOCK_INDEX],
33 | resize_input=True,
34 | normalize_input=True,
35 | requires_grad=False,
36 | use_fid_inception=True):
37 | """Build pretrained InceptionV3
38 |
39 | Parameters
40 | ----------
41 | output_blocks : list of int
42 | Indices of blocks to return features of. Possible values are:
43 | - 0: corresponds to output of first max pooling
44 | - 1: corresponds to output of second max pooling
45 | - 2: corresponds to output which is fed to aux classifier
46 | - 3: corresponds to output of final average pooling
47 | resize_input : bool
48 | If true, bilinearly resizes input to width and height 299 before
49 | feeding input to model. As the network without fully connected
50 | layers is fully convolutional, it should be able to handle inputs
51 | of arbitrary size, so resizing might not be strictly needed
52 | normalize_input : bool
53 | If true, scales the input from range (0, 1) to the range the
54 | pretrained Inception network expects, namely (-1, 1)
55 | requires_grad : bool
56 | If true, parameters of the model require gradients. Possibly useful
57 | for finetuning the network
58 | use_fid_inception : bool
59 | If true, uses the pretrained Inception model used in Tensorflow's
60 | FID implementation. If false, uses the pretrained Inception model
61 | available in torchvision. The FID Inception model has different
62 | weights and a slightly different structure from torchvision's
63 | Inception model. If you want to compute FID scores, you are
64 | strongly advised to set this parameter to true to get comparable
65 | results.
66 | """
67 | super(InceptionV3, self).__init__()
68 |
69 | self.resize_input = resize_input
70 | self.normalize_input = normalize_input
71 | self.output_blocks = sorted(output_blocks)
72 | self.last_needed_block = max(output_blocks)
73 |
74 | assert self.last_needed_block <= 3, \
75 | 'Last possible output block index is 3'
76 |
77 | self.blocks = nn.ModuleList()
78 |
79 | if use_fid_inception:
80 | inception = fid_inception_v3()
81 | else:
82 | inception = models.inception_v3(pretrained=True)
83 |
84 | # Block 0: input to maxpool1
85 | block0 = [
86 | inception.Conv2d_1a_3x3,
87 | inception.Conv2d_2a_3x3,
88 | inception.Conv2d_2b_3x3,
89 | nn.MaxPool2d(kernel_size=3, stride=2)
90 | ]
91 | self.blocks.append(nn.Sequential(*block0))
92 |
93 | # Block 1: maxpool1 to maxpool2
94 | if self.last_needed_block >= 1:
95 | block1 = [
96 | inception.Conv2d_3b_1x1,
97 | inception.Conv2d_4a_3x3,
98 | nn.MaxPool2d(kernel_size=3, stride=2)
99 | ]
100 | self.blocks.append(nn.Sequential(*block1))
101 |
102 | # Block 2: maxpool2 to aux classifier
103 | if self.last_needed_block >= 2:
104 | block2 = [
105 | inception.Mixed_5b,
106 | inception.Mixed_5c,
107 | inception.Mixed_5d,
108 | inception.Mixed_6a,
109 | inception.Mixed_6b,
110 | inception.Mixed_6c,
111 | inception.Mixed_6d,
112 | inception.Mixed_6e,
113 | ]
114 | self.blocks.append(nn.Sequential(*block2))
115 |
116 | # Block 3: aux classifier to final avgpool
117 | if self.last_needed_block >= 3:
118 | block3 = [
119 | inception.Mixed_7a,
120 | inception.Mixed_7b,
121 | inception.Mixed_7c,
122 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
123 | ]
124 | self.blocks.append(nn.Sequential(*block3))
125 |
126 | for param in self.parameters():
127 | param.requires_grad = requires_grad
128 |
129 | def forward(self, inp):
130 | """Get Inception feature maps
131 |
132 | Parameters
133 | ----------
134 | inp : torch.autograd.Variable
135 | Input tensor of shape Bx3xHxW. Values are expected to be in
136 | range (0, 1)
137 |
138 | Returns
139 | -------
140 | List of torch.autograd.Variable, corresponding to the selected output
141 | block, sorted ascending by index
142 | """
143 | outp = []
144 | x = inp
145 |
146 | if self.resize_input:
147 | x = F.interpolate(x,
148 | size=(299, 299),
149 | mode='bilinear',
150 | align_corners=False)
151 |
152 | if self.normalize_input:
153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154 |
155 | for idx, block in enumerate(self.blocks):
156 | x = block(x)
157 | if idx in self.output_blocks:
158 | outp.append(x)
159 |
160 | if idx == self.last_needed_block:
161 | break
162 |
163 | return outp
164 |
165 |
166 | def fid_inception_v3():
167 | """Build pretrained Inception model for FID computation
168 |
169 | The Inception model for FID computation uses a different set of weights
170 | and has a slightly different structure than torchvision's Inception.
171 |
172 | This method first constructs torchvision's Inception and then patches the
173 | necessary parts that are different in the FID Inception model.
174 | """
175 | inception = models.inception_v3(num_classes=1008,
176 | aux_logits=False,
177 | pretrained=False)
178 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
179 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
180 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
181 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
182 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
183 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
184 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
185 | inception.Mixed_7b = FIDInceptionE_1(1280)
186 | inception.Mixed_7c = FIDInceptionE_2(2048)
187 |
188 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
189 | inception.load_state_dict(state_dict)
193 | return inception
194 |
195 |
196 | class FIDInceptionA(models.inception.InceptionA):
197 | """InceptionA block patched for FID computation"""
198 | def __init__(self, in_channels, pool_features):
199 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
200 |
201 | def forward(self, x):
202 | branch1x1 = self.branch1x1(x)
203 |
204 | branch5x5 = self.branch5x5_1(x)
205 | branch5x5 = self.branch5x5_2(branch5x5)
206 |
207 | branch3x3dbl = self.branch3x3dbl_1(x)
208 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
209 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
210 |
211 | # Patch: Tensorflow's average pool does not use the padded zero's in
212 | # its average calculation
213 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
214 | count_include_pad=False)
215 | branch_pool = self.branch_pool(branch_pool)
216 |
217 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
218 | return torch.cat(outputs, 1)
219 |
220 |
221 | class FIDInceptionC(models.inception.InceptionC):
222 | """InceptionC block patched for FID computation"""
223 | def __init__(self, in_channels, channels_7x7):
224 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
225 |
226 | def forward(self, x):
227 | branch1x1 = self.branch1x1(x)
228 |
229 | branch7x7 = self.branch7x7_1(x)
230 | branch7x7 = self.branch7x7_2(branch7x7)
231 | branch7x7 = self.branch7x7_3(branch7x7)
232 |
233 | branch7x7dbl = self.branch7x7dbl_1(x)
234 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
235 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
236 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
237 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
238 |
239 | # Patch: Tensorflow's average pool does not use the padded zero's in
240 | # its average calculation
241 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
242 | count_include_pad=False)
243 | branch_pool = self.branch_pool(branch_pool)
244 |
245 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
246 | return torch.cat(outputs, 1)
247 |
248 |
249 | class FIDInceptionE_1(models.inception.InceptionE):
250 | """First InceptionE block patched for FID computation"""
251 | def __init__(self, in_channels):
252 | super(FIDInceptionE_1, self).__init__(in_channels)
253 |
254 | def forward(self, x):
255 | branch1x1 = self.branch1x1(x)
256 |
257 | branch3x3 = self.branch3x3_1(x)
258 | branch3x3 = [
259 | self.branch3x3_2a(branch3x3),
260 | self.branch3x3_2b(branch3x3),
261 | ]
262 | branch3x3 = torch.cat(branch3x3, 1)
263 |
264 | branch3x3dbl = self.branch3x3dbl_1(x)
265 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
266 | branch3x3dbl = [
267 | self.branch3x3dbl_3a(branch3x3dbl),
268 | self.branch3x3dbl_3b(branch3x3dbl),
269 | ]
270 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
271 |
272 | # Patch: Tensorflow's average pool does not use the padded zero's in
273 | # its average calculation
274 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
275 | count_include_pad=False)
276 | branch_pool = self.branch_pool(branch_pool)
277 |
278 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
279 | return torch.cat(outputs, 1)
280 |
281 |
282 | class FIDInceptionE_2(models.inception.InceptionE):
283 | """Second InceptionE block patched for FID computation"""
284 | def __init__(self, in_channels):
285 | super(FIDInceptionE_2, self).__init__(in_channels)
286 |
287 | def forward(self, x):
288 | branch1x1 = self.branch1x1(x)
289 |
290 | branch3x3 = self.branch3x3_1(x)
291 | branch3x3 = [
292 | self.branch3x3_2a(branch3x3),
293 | self.branch3x3_2b(branch3x3),
294 | ]
295 | branch3x3 = torch.cat(branch3x3, 1)
296 |
297 | branch3x3dbl = self.branch3x3dbl_1(x)
298 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
299 | branch3x3dbl = [
300 | self.branch3x3dbl_3a(branch3x3dbl),
301 | self.branch3x3dbl_3b(branch3x3dbl),
302 | ]
303 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
304 |
305 | # Patch: The FID Inception model uses max pooling instead of average
306 | # pooling. This is likely an error in this specific Inception
307 | # implementation, as other Inception models use average pooling here
308 | # (which matches the description in the paper).
309 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
310 | branch_pool = self.branch_pool(branch_pool)
311 |
312 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
313 | return torch.cat(outputs, 1)
314 |
--------------------------------------------------------------------------------
/feat_stas/strategies/strategy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch import nn
3 | import sys
4 | import torch
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torch.autograd import Variable
8 | from torch.utils.data import DataLoader
9 | from copy import deepcopy
10 | import pdb
11 | # import resnet
12 | from torch.distributions.categorical import Categorical
13 | class Strategy:
14 | def __init__(self, X, Y, idxs_lb, net, handler, args):
15 | self.X = X
16 | self.Y = Y
17 | self.idxs_lb = idxs_lb
18 | self.net = net
19 | self.handler = handler
20 | self.args = args
21 | self.n_pool = len(Y)
22 | use_cuda = torch.cuda.is_available()
23 |
24 | def query(self, n):
25 | pass
26 |
27 | def update(self, idxs_lb):
28 | self.idxs_lb = idxs_lb
29 |
30 | def _train(self, epoch, loader_tr, optimizer):
31 | self.clf.train()
32 | accFinal = 0.
33 | for batch_idx, (x, y, idxs) in enumerate(loader_tr):
34 | x, y = Variable(x.cuda()), Variable(y.cuda())
35 | optimizer.zero_grad()
36 | out, e1 = self.clf(x)
37 | loss = F.cross_entropy(out, y)
38 | accFinal += torch.sum((torch.max(out,1)[1] == y).float()).data.item()
39 | loss.backward()
40 |
41 | # clamp gradients, just in case
42 | for p in filter(lambda p: p.grad is not None, self.clf.parameters()): p.grad.data.clamp_(min=-.1, max=.1)
43 | optimizer.step()
44 |
45 | return accFinal / len(loader_tr.dataset.X), loss.item()
46 |
47 |
48 | def train(self, reset=True, optimizer=0, verbose=True, data=[], net=[]):
49 | def weight_reset(m):
50 | newLayer = deepcopy(m)
51 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
52 | m.reset_parameters()
53 |
54 | n_epoch = self.args['n_epoch']
55 | if reset: self.clf = self.net.apply(weight_reset).cuda()
56 | if type(net) != list: self.clf = net
57 | if type(optimizer) == int: optimizer = optim.Adam(self.clf.parameters(), lr = self.args['lr'], weight_decay=0)
58 |
59 | idxs_train = np.arange(self.n_pool)[self.idxs_lb]
60 | loader_tr = DataLoader(self.handler(self.X[idxs_train], torch.Tensor(self.Y.numpy()[idxs_train]).long(), transform=self.args['transform']), shuffle=True, **self.args['loader_tr_args'])
61 | if len(data) > 0:
62 | loader_tr = DataLoader(self.handler(data[0], torch.Tensor(data[1]).long(), transform=self.args['transform']), shuffle=True, **self.args['loader_tr_args'])
63 |
64 | epoch = 1
65 | accCurrent = 0.
66 | bestAcc = 0.
67 | attempts = 0
68 | while accCurrent < 0.99:
69 | accCurrent, lossCurrent = self._train(epoch, loader_tr, optimizer)
70 | if bestAcc < accCurrent:
71 | bestAcc = accCurrent
72 | attempts = 0
73 | else: attempts += 1
74 | epoch += 1
75 | if verbose: print(str(epoch) + ' ' + str(attempts) + ' training accuracy: ' + str(accCurrent), flush=True)
76 | # reset if not converging
77 | if (epoch % 1000 == 0) and (accCurrent < 0.2) and (self.args['modelType'] != 'linear'):
78 | self.clf = self.net.apply(weight_reset)
79 | optimizer = optim.Adam(self.clf.parameters(), lr = self.args['lr'], weight_decay=0)
80 | if attempts >= 50 and self.args['modelType'] == 'linear': break
81 | #if attempts >= 50 and self.args['modelType'] != 'linear' and len(idxs_train) > 1000:
82 | # self.clf = self.net.apply(weight_reset)
83 | # optimizer = optim.Adam(self.clf.parameters(), lr = self.args['lr'], weight_decay=0)
84 | # attempts = 0
85 |
86 |
87 | def train_val(self, valFrac=0.1, opt='adam', verbose=False):
88 | def weight_reset(m):
89 | newLayer = deepcopy(m)
90 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
91 | newLayer.reset_parameters()
92 | m.reset_parameters()
93 |
94 | if verbose: print(' ',flush=True)
95 | if verbose: print('getting validation minimizing number of epochs', flush=True)
96 | self.clf = self.net.apply(weight_reset).cuda()
97 | if opt == 'adam': optimizer = optim.Adam(self.clf.parameters(), lr=self.args['lr'], weight_decay=0)
98 | if opt == 'sgd': optimizer = optim.SGD(self.clf.parameters(), lr=self.args['lr'], weight_decay=0)
99 |
100 | idxs_train = np.arange(self.n_pool)[self.idxs_lb]
101 | nVal = int(len(idxs_train) * valFrac)
102 | idxs_train = idxs_train[np.random.permutation(len(idxs_train))]
103 | idxs_val = idxs_train[:nVal]
104 | idxs_train = idxs_train[nVal:]
105 |
106 | loader_tr = DataLoader(self.handler(self.X[idxs_train], torch.Tensor(self.Y.numpy()[idxs_train]).long(), transform=self.args['transform']), shuffle=True, **self.args['loader_tr_args'])
107 |
108 | epoch = 1
109 | accCurrent = 0.
110 | bestLoss = np.inf
111 | attempts = 0
112 | ce = nn.CrossEntropyLoss()
113 | valTensor = torch.Tensor(self.Y.numpy()[idxs_val]).long()
114 | attemptThresh = 10
115 | while True:
116 | accCurrent, lossCurrent = self._train(epoch, loader_tr, optimizer)
117 | valPreds = self.predict_prob(self.X[idxs_val], valTensor, exp=False)
118 | loss = ce(valPreds, valTensor).item()
119 | if loss < bestLoss:
120 | bestLoss = loss
121 | attempts = 0
122 | bestEpoch = epoch
123 | else:
124 | attempts += 1
125 | if attempts == attemptThresh: break
126 | if verbose: print(epoch, attempts, loss, bestEpoch, bestLoss, flush=True)
127 | epoch += 1
128 |
129 | return bestEpoch
130 |
131 | def get_dist(self, epochs, nEns=1, opt='adam', verbose=False):
132 |
133 | def weight_reset(m):
134 | newLayer = deepcopy(m)
135 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
136 | newLayer.reset_parameters()
137 | m.reset_parameters()
138 |
139 | if verbose: print(' ',flush=True)
140 | if verbose: print('training to indicated number of epochs', flush=True)
141 |
142 | ce = nn.CrossEntropyLoss()
143 | idxs_train = np.arange(self.n_pool)[self.idxs_lb]
144 | loader_tr = DataLoader(self.handler(self.X[idxs_train], torch.Tensor(self.Y.numpy()[idxs_train]).long(), transform=self.args['transform']), shuffle=True, **self.args['loader_tr_args'])
145 | dataSize = len(idxs_train)
146 | N = np.round((epochs * len(loader_tr)) ** 0.5)
147 | allAvs = []
148 | allWeights = []
149 | for m in range(nEns):
150 |
151 | # initialize new model and optimizer
152 | net = self.net.apply(weight_reset).cuda()
153 | if opt == 'adam': optimizer = optim.Adam(net.parameters(), lr=self.args['lr'], weight_decay=0)
154 | if opt == 'sgd': optimizer = optim.SGD(net.parameters(), lr=self.args['lr'], weight_decay=0)
155 |
156 | nUpdates = k = 0
157 | ek = (k + 1) * N
158 |             pVec = torch.cat([torch.zeros_like(p).cpu().flatten() for p in net.parameters()])  # use the freshly initialized net, matching the reset below
159 |
160 | avIterates = []
161 | for epoch in range(epochs + 1):
162 | correct = lossTrain = 0.
163 | net = net.train()
164 | for ind, (x, y, _) in enumerate(loader_tr):
165 | x, y = x.cuda(), y.cuda()
166 | optimizer.zero_grad()
167 | output, _ = net(x)
168 | correct += torch.sum(output.argmax(1) == y).item()
169 | loss = ce(output, y)
170 | loss.backward()
171 | lossTrain += loss.item() * len(y)
172 | optimizer.step()
173 | flat = torch.cat([deepcopy(p.detach().cpu()).flatten() for p in net.parameters()])
174 | pVec = pVec + flat
175 | nUpdates += 1
176 | if nUpdates > ek:
177 | avIterates.append(pVec / N)
178 | pVec = torch.cat([torch.zeros_like(p).cpu().flatten() for p in net.parameters()])
179 | k += 1
180 | ek = (k + 1) * N
181 |
182 | lossTrain /= dataSize
183 | accuracy = correct / dataSize
184 | if verbose: print(m, epoch, nUpdates, accuracy, lossTrain, flush=True)
185 | allAvs.append(avIterates)
186 | allWeights.append(torch.cat([deepcopy(p.detach().cpu()).flatten() for p in net.parameters()]))
187 |
188 | for m in range(nEns):
189 | avIterates = torch.stack(allAvs[m])
190 | if k > 1: avIterates = torch.stack(allAvs[m][1:])
191 | avIterates = avIterates - torch.mean(avIterates, 0)
192 | allAvs[m] = avIterates
193 |
194 | return allWeights, allAvs, optimizer, net
195 |
196 | def getNet(self, params):
197 | i = 0
198 | model = deepcopy(self.clf).cuda()
199 | for p in model.parameters():
200 | L = len(p.flatten())
201 | param = params[i:(i + L)]
202 | p.data = param.view(p.size())
203 | i += L
204 | return model
205 |
206 | def fitBatchnorm(self, model):
207 | idxs_train = np.arange(self.n_pool)[self.idxs_lb]
208 | loader_tr = DataLoader(self.handler(self.X[idxs_train], torch.Tensor(self.Y.numpy()[idxs_train]).long(), transform=self.args['transform']), shuffle=True, **self.args['loader_tr_args'])
209 | model = model.cuda()
210 | for ind, (x, y, _) in enumerate(loader_tr):
211 | x, y = x.cuda(), y.cuda()
212 | output = model(x)
213 | return model
214 |
215 | def sampleNet(self, weights, iterates):
216 | nEns = len(weights)
217 | k = len(iterates[0])
218 | i = np.random.randint(nEns)
219 | z = torch.randn(k, 1)
220 | weightSample = weights[i].view(-1) - torch.mm(iterates[i].t(), z).view(-1) / np.sqrt(k)
221 | sampleNet = self.getNet(weightSample).cuda()
222 | sampleNet = self.fitBatchnorm(sampleNet)
223 | return sampleNet
224 |
225 | def getPosterior(self, weights, iterates, X, Y, nSamps=50):
226 | net = self.fitBatchnorm(self.sampleNet(weights, iterates))
227 | output = self.predict_prob(X, Y, model=net) / nSamps
228 | print(' ', flush=True)
229 | ce = nn.CrossEntropyLoss()
230 | print('sampling models', flush=True)
231 | for i in range(nSamps - 1):
232 | net = self.fitBatchnorm(self.sampleNet(weights, iterates))
233 | output = output + self.predict_prob(X, Y, model=net) / nSamps
234 | print(i+2, torch.sum(torch.argmax(output, 1) == Y).item() / len(Y), flush=True)
235 | return output.numpy()
236 |
237 | def predict(self, X, Y):
238 | if type(X) is np.ndarray:
239 | loader_te = DataLoader(self.handler(X, Y, transform=self.args['transformTest']),
240 | shuffle=False, **self.args['loader_te_args'])
241 | else:
242 | loader_te = DataLoader(self.handler(X.numpy(), Y, transform=self.args['transformTest']),
243 | shuffle=False, **self.args['loader_te_args'])
244 |
245 | self.clf.eval()
246 | P = torch.zeros(len(Y)).long()
247 | with torch.no_grad():
248 | for x, y, idxs in loader_te:
249 | x, y = Variable(x.cuda()), Variable(y.cuda())
250 | out, e1 = self.clf(x)
251 | pred = out.max(1)[1]
252 | P[idxs] = pred.data.cpu()
253 | return P
254 |
255 | def predict_prob(self, X, Y, model=[], exp=True):
256 | if type(model) == list: model = self.clf
257 |
258 | loader_te = DataLoader(self.handler(X, Y, transform=self.args['transformTest']), shuffle=False, **self.args['loader_te_args'])
259 | model = model.eval()
260 | probs = torch.zeros([len(Y), len(np.unique(self.Y))])
261 | with torch.no_grad():
262 | for x, y, idxs in loader_te:
263 | x, y = Variable(x.cuda()), Variable(y.cuda())
264 | out, e1 = model(x)
265 | if exp: out = F.softmax(out, dim=1)
266 | probs[idxs] = out.cpu().data
267 |
268 | return probs
269 |
270 | def predict_prob_dropout(self, X, Y, n_drop):
271 | loader_te = DataLoader(self.handler(X, Y, transform=self.args['transformTest']),
272 | shuffle=False, **self.args['loader_te_args'])
273 |
274 | self.clf.train()
275 | probs = torch.zeros([len(Y), len(np.unique(Y))])
276 | with torch.no_grad():
277 | for i in range(n_drop):
278 | print('n_drop {}/{}'.format(i+1, n_drop))
279 | for x, y, idxs in loader_te:
280 | x, y = Variable(x.cuda()), Variable(y.cuda())
281 | out, e1 = self.clf(x)
282 | prob = F.softmax(out, dim=1)
283 |                     probs[idxs] += prob.cpu().data  # accumulate the softmax output, not the raw logits
284 | probs /= n_drop
285 |
286 | return probs
287 |
288 | def predict_prob_dropout_split(self, X, Y, n_drop):
289 | loader_te = DataLoader(self.handler(X, Y, transform=self.args['transformTest']),
290 | shuffle=False, **self.args['loader_te_args'])
291 |
292 | self.clf.train()
293 | probs = torch.zeros([n_drop, len(Y), len(np.unique(Y))])
294 | with torch.no_grad():
295 | for i in range(n_drop):
296 | print('n_drop {}/{}'.format(i+1, n_drop))
297 | for x, y, idxs in loader_te:
298 | x, y = Variable(x.cuda()), Variable(y.cuda())
299 | out, e1 = self.clf(x)
300 | probs[i][idxs] += F.softmax(out, dim=1).cpu().data
301 | return probs
302 |
303 | def get_embedding(self, X, Y):
304 | loader_te = DataLoader(self.handler(X, Y, transform=self.args['transformTest']),
305 | shuffle=False, **self.args['loader_te_args'])
306 | self.clf.eval()
307 | embedding = torch.zeros([len(Y), self.clf.get_embedding_dim()])
308 | with torch.no_grad():
309 | for x, y, idxs in loader_te:
310 | x, y = Variable(x.cuda()), Variable(y.cuda())
311 | out, e1 = self.clf(x)
312 | embedding[idxs] = e1.data.cpu()
313 |
314 | return embedding
315 |
316 | # gradient embedding for badge (assumes cross-entropy loss)
317 | def get_grad_embedding(self, X, Y, model=[]):
318 | if type(model) == list:
319 | model = self.clf
320 |
321 | embDim = model.get_embedding_dim()
322 | model.eval()
323 | nLab = len(np.unique(Y))
324 | embedding = np.zeros([len(Y), embDim * nLab])
325 | loader_te = DataLoader(self.handler(X, Y, transform=self.args['transformTest']),
326 | shuffle=False, **self.args['loader_te_args'])
327 | with torch.no_grad():
328 | for x, y, idxs in loader_te:
329 | x, y = Variable(x.cuda()), Variable(y.cuda())
330 | cout, out = model(x)
331 | out = out.data.cpu().numpy()
332 | batchProbs = F.softmax(cout, dim=1).data.cpu().numpy()
333 | maxInds = np.argmax(batchProbs,1)
334 | for j in range(len(y)):
335 | for c in range(nLab):
336 | if c == maxInds[j]:
337 | embedding[idxs[j]][embDim * c : embDim * (c+1)] = deepcopy(out[j]) * (1 - batchProbs[j][c])
338 | else:
339 | embedding[idxs[j]][embDim * c : embDim * (c+1)] = deepcopy(out[j]) * (-1 * batchProbs[j][c])
340 | return torch.Tensor(embedding)
341 |
342 | # fisher embedding for bait (assumes cross-entropy loss)
343 | def get_exp_grad_embedding(self, X, Y, probs=[], model=[]):
344 | if type(model) == list:
345 | model = self.clf
346 |
347 | embDim = model.get_embedding_dim()
348 | model.eval()
349 | nLab = len(np.unique(Y))
350 |
351 | embedding = np.zeros([len(Y), nLab, embDim * nLab])
352 | for ind in range(nLab):
353 | loader_te = DataLoader(self.handler(X, Y, transform=self.args['transformTest']),
354 | shuffle=False, **self.args['loader_te_args'])
355 | with torch.no_grad():
356 | for x, y, idxs in loader_te:
357 | x, y = Variable(x.cuda()), Variable(y.cuda())
358 | cout, out = model(x)
359 | out = out.data.cpu().numpy()
360 | batchProbs = F.softmax(cout, dim=1).data.cpu().numpy()
361 | for j in range(len(y)):
362 | for c in range(nLab):
363 | if c == ind:
364 | embedding[idxs[j]][ind][embDim * c : embDim * (c+1)] = deepcopy(out[j]) * (1 - batchProbs[j][c])
365 | else:
366 | embedding[idxs[j]][ind][embDim * c : embDim * (c+1)] = deepcopy(out[j]) * (-1 * batchProbs[j][c])
367 | if len(probs) > 0: embedding[idxs[j]][ind] = embedding[idxs[j]][ind] * np.sqrt(probs[idxs[j]][ind])
368 | else: embedding[idxs[j]][ind] = embedding[idxs[j]][ind] * np.sqrt(batchProbs[j][ind])
369 | return torch.Tensor(embedding)
370 |
371 |
372 |
--------------------------------------------------------------------------------
/feat_stas/SnP.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import torchvision
4 | import numpy as np
5 | import torch
6 | from scipy import linalg
7 | from matplotlib.pyplot import imread, imsave
8 | from torch.nn.functional import adaptive_avg_pool2d, adaptive_max_pool2d
9 | import random
10 | import re
11 | import copy
12 | from scipy.special import softmax
13 | from collections import defaultdict
14 | from shutil import copyfile
15 | import matplotlib.pyplot as plt
16 | from sklearn.decomposition import PCA
17 | from k_means_constrained import KMeansConstrained
18 | from glob import glob
19 | import os.path as osp
20 | import h5py
21 | import scipy.io
22 | from PIL import Image
23 | import collections
24 | import numpy as np
25 | from skimage.transform import resize
26 | from sklearn.cluster import KMeans
27 | import xml.dom.minidom as XD
28 | try:
29 | from tqdm import tqdm
30 | except ImportError:
31 | # If tqdm is not available, provide a mock version of it
32 | def tqdm(x): return x
33 |
34 | from feat_stas.models.inception import InceptionV3
35 | from feat_stas.strategies import CoreSet
36 | from feat_stas.dataloader import get_id_path_of_data_vehicles
37 | from feat_stas.dataloader import get_id_path_of_data_person
38 | from feat_stas.feat_extraction import get_activations, calculate_frechet_distance, calculate_activation_statistics
39 |
40 |
41 | def training_set_search(tpaths, data_dict, dataset_id, opt, result_dir, c_num, version):
42 | """main function of the SnP framework"""
43 |
44 | if version == 'vehicle':
45 | img_paths, person_ids, dataset_ids, meta_dataset = get_id_path_of_data_vehicles (dataset_id, data_dict)
46 | if version == 'person':
47 | img_paths, person_ids, dataset_ids, meta_dataset = get_id_path_of_data_person (dataset_id, data_dict)
48 |
49 | cuda = opt.cuda
50 | if opt.FD_model == 'inception':
51 | dims = 2048
52 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
53 | model = InceptionV3([block_idx])
54 |
55 | if cuda:
56 | model.cuda()
57 | batch_size=256
58 |
59 | # calculate the feature statistics of the target set
60 | print('=========== calculating the feature statistics of the target set ===========')
61 | files = []
62 | if version == 'vehicle':
63 | if opt.target == 'alice-vehicle':
64 | for i in range (1, 17):
65 | target_path = pathlib.Path(tpaths + "/cam" + str(i))
66 | files.extend(list(target_path.glob('*.jpg')) + list(target_path.glob('*.png')))
67 | if opt.target == 'veri':
68 | target_path = pathlib.Path(tpaths)
69 | files = list(target_path.glob('*.jpg')) + list(target_path.glob('*.png'))
70 | if version == 'person':
71 | target_path = pathlib.Path(tpaths)
72 | files = list(target_path.glob('*.jpg')) + list(target_path.glob('*.png'))
73 |
74 |
75 | if not os.path.exists(result_dir + '/target_feature.npy'):
76 | target_feature = get_activations(opt, files, model, batch_size, dims, cuda, verbose=False)
77 | np.save(result_dir + '/target_feature.npy', target_feature)
78 | else:
79 | target_feature = np.load(result_dir + '/target_feature.npy')
80 | m1 = np.mean(target_feature, axis=0)
81 | s1 = np.cov(target_feature, rowvar=False)
82 |
83 | # extract features for the data pool
84 | if not os.path.exists(result_dir + '/feature_infer.npy'):
85 | print('=========== extracting features of the data pool ===========')
86 | model.eval()
87 | feature_infer = get_activations(opt, img_paths, model, batch_size, dims, cuda, verbose=False)
88 | if not os.path.isdir(result_dir):
89 | os.mkdir(result_dir)
90 | np.save(result_dir + '/feature_infer.npy', feature_infer)
91 | else:
92 | feature_infer = np.load(result_dir + '/feature_infer.npy')
93 |
94 | person_ids_array=np.array(person_ids)
95 | mean_feature_per_id=[]
96 | pid_per_id=[]
97 | did_per_id=[]
98 |
99 | # get the mean feature of each ID, and record each ID's pid and dataset id
100 | if not os.path.exists(result_dir + '/mean_feature_per_id.npy'):
101 | for did in range(1, len(dataset_id) + 1):
102 | ind_of_set = np.argwhere(np.array(dataset_ids) == did).squeeze()
103 | dataset_feature = feature_infer[ind_of_set]
104 | dataset_pid = person_ids_array[ind_of_set]
105 | pid_of_dataset=set(dataset_pid)
106 | for pid in pid_of_dataset:
107 | ind_of_pid = np.argwhere(np.array(dataset_pid) == pid).squeeze()
108 | feature_per_id = dataset_feature[ind_of_pid]
109 | id_ave_feature=feature_per_id.mean(0)
110 | mean_feature_per_id.append(id_ave_feature)
111 | pid_per_id.append(pid)
112 | did_per_id.append(did)
113 | np.save(result_dir+ '/mean_feature_per_id.npy',mean_feature_per_id)
114 | pid_did_fid_var = np.c_[np.array(pid_per_id), np.array(did_per_id)]
115 | np.save(result_dir+ '/pid_did_fid_var.npy', pid_did_fid_var)
116 | else:
117 | mean_feature_per_id = np.load(result_dir + '/mean_feature_per_id.npy')
118 | pid_did_fid_var = np.load(result_dir + '/pid_did_fid_var.npy')
119 |
120 | # remove junk IDs (-1 for vehicle data; -1 and 0 for person data)
121 | ori_pid_per_id = pid_did_fid_var[:, 0]
122 | if version == 'vehicle':
123 | remove_ind = np.argwhere(ori_pid_per_id == -1).squeeze()
124 | else:
125 | remove_ind = np.r_[np.argwhere(ori_pid_per_id == -1), np.argwhere(ori_pid_per_id == 0)].squeeze()
126 |
127 | new_pid_did_fid_var = np.delete(pid_did_fid_var, remove_ind, 0)
128 | new_mean_feature_per_id = np.delete(mean_feature_per_id, remove_ind, 0)
129 |
130 |
131 | print('\r=========== clustering the data pool ===========')
132 | pid_per_id = new_pid_did_fid_var[:,0]
133 | did_per_id = new_pid_did_fid_var[:,1]
134 | # clustering ids based on ids' mean feature
135 | if not os.path.exists(result_dir + '/label_cluster_'+str(c_num)+'.npy'):
136 | # estimator = KMeans(n_clusters=c_num)
137 | estimator = KMeansConstrained(n_clusters=c_num, size_min=int(np.shape (new_mean_feature_per_id)[0] / c_num ), size_max=int(np.shape (new_mean_feature_per_id)[0] / c_num))
138 | estimator.fit(new_mean_feature_per_id)
139 | label_pred = estimator.labels_
140 | np.save(result_dir + '/label_cluster_'+str(c_num)+'.npy',label_pred)
141 | else:
142 | label_pred = np.load(result_dir + '/label_cluster_'+str(c_num)+'.npy')
143 |
144 | print('\r=========== calculating the FID between T and C_k ===========')
145 | if not os.path.exists(result_dir + '/cluster_fid_div.npy'):
146 | cluster_feature = []
147 | cluster_fid = []
148 | cluster_div = []
149 | for k in tqdm(range(c_num)):
150 | # gather the features of all IDs in cluster k
151 | initial_pid=pid_per_id[label_pred==k]
152 | initial_did=did_per_id[label_pred==k]
153 |
154 | initial_feature_infer = feature_infer[(dataset_ids == int(initial_did[0])) & (person_ids_array == initial_pid[0])]
155 |
156 | for j in range(1,len(initial_pid)):
157 | current_feature_infer=feature_infer[(dataset_ids == int(initial_did[j])) & (person_ids_array == initial_pid[j])]
158 | initial_feature_infer=np.r_[initial_feature_infer, current_feature_infer]
159 |
160 | cluster_feature.append(initial_feature_infer)
161 |
162 | mu = np.mean(initial_feature_infer, axis=0)
163 | sigma = np.cov(initial_feature_infer, rowvar=False)
164 |
165 | fea_corrcoef = np.corrcoef(initial_feature_infer)
166 | fea_corrcoef = np.ones(np.shape(fea_corrcoef)) - fea_corrcoef
167 | diversity_sum = np.sum(np.sum(fea_corrcoef)) - np.sum(np.diagonal(fea_corrcoef))
168 | current_div = diversity_sum / (np.shape (fea_corrcoef)[0] ** 2 - np.shape (fea_corrcoef)[0])
169 |
170 | # calculate the domain gap (FD between the target and cluster k)
171 | current_fid = calculate_frechet_distance(m1, s1, mu, sigma)
172 | cluster_fid.append(current_fid)
173 | cluster_div.append(current_div)
174 | # cluster_mmd.append(current_mmd)
175 | np.save(result_dir + '/cluster_fid_div.npy', np.c_[np.array(cluster_fid), np.array(cluster_div)])
176 | #np.save(result_dir+'/cluster_fid_var.npy', np.c_[np.array(cluster_fid),np.array(cluster_var_gap)])
177 | else:
178 | cluster_fid_var=np.load(result_dir + '/cluster_fid_div.npy')
179 | cluster_fid=cluster_fid_var[:,0]
180 | cluster_div=cluster_fid_var[:,1]
181 |
182 | cluster_fida=np.array(cluster_fid)
183 | score_fid = softmax(-cluster_fida)
184 | sample_rate = score_fid
185 |
186 | c_num_len = []
187 | for kk in range(c_num):
188 | initial_pid = pid_per_id[label_pred == kk]
189 | c_num_len.append(len(initial_pid))
190 |
191 | id_score = []
192 | for jj in range(len(label_pred)):
193 | id_score.append(sample_rate[label_pred[jj]] / c_num_len[label_pred[jj]])
194 |
195 | if opt.ID_sampling_method == 'random':
196 | selected_data_ind = np.sort(np.random.choice(range(len(id_score)), opt.n_num_id, replace=False))
197 | if opt.ID_sampling_method == 'SnP':
198 | lowest_fd = float('inf')
199 | lowest_id_list = []
200 | if not os.path.exists(result_dir + '/domain_selective_ids.npy'):
201 | cluster_rank = np.argsort(cluster_fida)
202 | current_list = []
203 | cluster_feature_aggressive = []
204 | for k in tqdm(cluster_rank):
205 | id_list = np.where (label_pred==k)[0]
206 | initial_pid=pid_per_id[label_pred==k]
207 | initial_did=did_per_id[label_pred==k]
208 | initial_feature_infer = feature_infer[(dataset_ids == int(initial_did[0])) & (person_ids_array == initial_pid[0])]
209 | for j in range(1,len(initial_pid)):
210 | current_feature_infer=feature_infer[(dataset_ids == int(initial_did[j])) & (person_ids_array == initial_pid[j])]
211 | initial_feature_infer=np.r_[initial_feature_infer, current_feature_infer]
212 |
213 | cluster_feature_aggressive.extend(initial_feature_infer)
214 | cluster_feature_aggressive_fixed = cluster_feature_aggressive
215 | target_feature_fixed = target_feature
216 | if len (cluster_feature_aggressive) > len (target_feature):
217 | cluster_idx = np.random.choice(range(len (cluster_feature_aggressive)), len(target_feature), replace=False)
218 | cluster_feature_aggressive_fixed = np.array([cluster_feature_aggressive[ii] for ii in cluster_idx])
219 | if len (cluster_feature_aggressive) < len (target_feature):
220 | cluster_idx = np.random.choice(range(len(target_feature)), len (cluster_feature_aggressive), replace=False)
221 | target_feature_fixed = target_feature[cluster_idx]
222 | mu = np.mean(cluster_feature_aggressive_fixed, axis=0)
223 | sigma = np.cov(cluster_feature_aggressive_fixed, rowvar=False)
224 | current_fid = calculate_frechet_distance(m1, s1, mu, sigma)
225 | current_list.extend (list (id_list))
226 | print('lowest FD so far:', lowest_fd, 'current FD:', current_fid)
227 | if lowest_fd > current_fid:
228 | lowest_fd = current_fid
229 | lowest_id_list = copy.deepcopy(current_list)
230 | np.save(result_dir + '/domain_selective_ids.npy', lowest_id_list)
231 | else:
232 | lowest_id_list = np.load(result_dir + '/domain_selective_ids.npy')
233 | print ("searched IDs", len (lowest_id_list))
234 | direct_selected_data_ind = np.array(lowest_id_list)
235 | if opt.n_num_id < len(direct_selected_data_ind):
236 | selected_data_ind = np.sort(np.random.choice(direct_selected_data_ind, opt.n_num_id, replace=False))
237 | else:
238 | selected_data_ind = np.array(lowest_id_list)
239 |
240 | if opt.ID_sampling_method == 'greedy':
241 | selected_data_ind = np.argsort(id_score)[-opt.n_num_id:]
242 |
243 |
244 | sdid = did_per_id[selected_data_ind]
245 | spid = pid_per_id[selected_data_ind]
246 | print (collections.Counter(sdid))
247 | data_dir = result_dir + '/proxy_set'
248 | if not os.path.isdir(data_dir):
249 | os.mkdir(data_dir)
250 | print('\r=========== building training set ===========')
251 | sampled_data=np.c_[sdid,spid]
252 | if not os.path.exists(result_dir + '/' + str(opt.ID_sampling_method) + '_' + str(opt.n_num_id) + '_result_IMG_list_'+str(c_num)+'.npy'):
253 | result_IMG_list, one_img_perID = IDidx2IMGidx(data_dict, dataset_id, sampled_data, opt.output_data, meta_dataset, feature_infer, opt)
254 | np.save(result_dir + '/' + str(opt.ID_sampling_method) + '_' + str(opt.n_num_id) + '_result_IMG_list.npy', result_IMG_list)
255 | np.save(result_dir + '/' + str(opt.ID_sampling_method) + '_' + str(opt.n_num_id) + '_one_img_perID.npy', one_img_perID)
256 | else:
257 | result_IMG_list = np.load(result_dir + '/' + str(opt.ID_sampling_method) + '_' + str(opt.n_num_id) + '_result_IMG_list.npy')
258 | one_img_perID = np.load(result_dir + '/' + str(opt.ID_sampling_method) + '_' + str(opt.n_num_id) + '_one_img_perID.npy')
259 |
260 | result_IMG_list = np.array(result_IMG_list)
261 | one_img_perID = np.array(one_img_perID)
262 |
263 | if opt.img_sampling_ratio < 1:
264 | if opt.img_sampling_method == 'FPS':
265 | direct_selected_data_ind = np.array(result_IMG_list)
266 | idb_is = np.zeros(len(direct_selected_data_ind), dtype=bool)
267 | for ii in one_img_perID:
268 | index_lowest_first_img_per_id = list(result_IMG_list).index(ii)
269 | idb_is[index_lowest_first_img_per_id] = 1
270 | selected_data_feature = feature_infer [direct_selected_data_ind]
271 | strategy = CoreSet (None, np.zeros(len(direct_selected_data_ind), dtype=bool), np.zeros(len(direct_selected_data_ind), dtype=bool), None, None, None)
272 | img_num = int (opt.img_sampling_ratio * len (result_IMG_list))
273 | div_selected_data_ind = strategy.query (img_num - len(one_img_perID), selected_data_feature, idb_is)
274 | selected_img_ind = list(one_img_perID)
275 | selected_img_ind.extend (direct_selected_data_ind[div_selected_data_ind])
276 |
277 | if opt.img_sampling_method == 'random':
278 | lowest_img_list = list(result_IMG_list)
279 | for ii in one_img_perID:
280 | lowest_img_list.remove (ii)
281 | direct_selected_data_ind = np.array(lowest_img_list)
282 | img_num = int (opt.img_sampling_ratio * len (result_IMG_list))
283 | if img_num > len(one_img_perID):
284 | selected_img_ind = list(np.sort(np.random.choice(direct_selected_data_ind, img_num - len(one_img_perID), replace=False)))
285 | else:
286 | selected_img_ind = []
287 | selected_img_ind.extend (one_img_perID)
288 | else:
289 | selected_img_ind = result_IMG_list
290 |
291 | result_feature = dataset_build_img(data_dict, dataset_id, selected_img_ind, data_dir, meta_dataset, feature_infer, opt)
292 |
293 | mu = np.mean(result_feature, axis=0)
294 | sigma = np.cov(result_feature, rowvar=False)
295 | current_fid = calculate_frechet_distance(m1, s1, mu, sigma)
296 |
297 | print('finished with a dataset that has FD', current_fid, 'to the target')
298 | return sampled_data
299 |
300 |
301 | def IDidx2IMGidx(data_dict, dataset_id, sampled_data, result_dir, meta_dataset, feature_infer, opt):
302 | pattern = re.compile(r'([-\d]+)_c([-\d]+)')
303 | pid_sampled = sampled_data[:, 1]
304 | did_sampled = sampled_data[:, 0]
305 |
306 | all_pids = {}
307 | image_searched = []
308 | one_img_perID = []
309 | id_count = 0
310 |
311 | for idx, (fname, pid, did, cid) in enumerate(meta_dataset):
312 | exist_judge = False
313 | for j in range (len(pid_sampled)):
314 | if str(pid) == str(pid_sampled[j]) and str(did) == str(did_sampled[j]):
315 | exist_judge = True
316 | if exist_judge:
317 | if pid not in all_pids:
318 | all_pids[pid] = {}
319 | all_pids[pid][did] = id_count
320 | one_img_perID.append (idx)
321 | id_count += 1
322 | if pid in all_pids:
323 | if did not in all_pids[pid]:
324 | all_pids[pid][did] = id_count
325 | one_img_perID.append (idx)
326 | id_count += 1
327 |
328 | image_searched.append (idx)
329 | print ("id count", id_count, "img count", len (image_searched))
330 | return image_searched, one_img_perID
331 |
332 | def dataset_build_img(data_dict, dataset_id, sampled_data, result_dir, meta_dataset, feature_infer, opt):
333 | all_pids = {}
334 | dstr_path = opt.output_data
335 | if not os.path.isdir(dstr_path):
336 | os.mkdir(dstr_path)
337 | img_count = 0
338 | id_count = 0
339 | feature_searched = []
340 | for idx in sampled_data:
341 | fname, pid, did, cid = meta_dataset[idx]
342 | if pid not in all_pids:
343 | all_pids[pid] = {}
344 | all_pids[pid][did] = id_count
345 | id_count += 1
346 | if pid in all_pids:
347 | if did not in all_pids[pid]:
348 | all_pids[pid][did] = id_count
349 | id_count += 1
350 | feature_searched.append (idx)
351 | pid = all_pids[pid][did]
352 | img_count += 1
353 | # print (pid, cid, did)
354 | new_path = dstr_path + '/' + '{:04}'.format(pid) + "_c" + '{:03}'.format(cid) + "_d" + '{:03}'.format(did) + "_" + str(img_count) + '.jpg'
355 | if opt.no_sample:
356 | continue
357 | copyfile(fname, new_path)
358 |
359 | print ("successfully created a dataset with", img_count, "images and", id_count, "ids")
360 | feature_new = feature_infer[feature_searched]
361 | return feature_new
362 |
--------------------------------------------------------------------------------
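SnP.py above scores every candidate pool with calculate_frechet_distance, imported from feat_stas/feat_extraction.py. For readers skimming this file, a minimal sketch of the standard Fréchet distance between Gaussians fitted to two feature sets (the usual FID formula; not necessarily the repository's exact implementation):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # FD(N(mu1, s1), N(mu2, s2)) = ||mu1 - mu2||^2 + Tr(s1 + s2 - 2 (s1 s2)^(1/2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # regularize near-singular products before retrying the matrix sqrt
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop numerical imaginary residue
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

Note that training_set_search subsamples the larger of the two feature sets before fitting (cluster_feature_aggressive_fixed / target_feature_fixed), so both Gaussians are estimated from the same number of samples.
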
/feat_stas/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchvision
3 | import numpy as np
4 | from matplotlib.pyplot import imread, imsave
5 | import random
6 | import re
7 | from collections import defaultdict
8 | import matplotlib.pyplot as plt
9 | from sklearn.decomposition import PCA
10 | from glob import glob
11 | import os.path as osp
12 | import h5py
13 | import scipy.io
14 | from PIL import Image
15 | import numpy as np
16 | import xml.dom.minidom as XD
17 |
18 |
19 |
20 |
21 | def makeDir(path):
22 | if not os.path.isdir(path):
23 | os.mkdir(path)
24 |
25 | def process_veri_wild_vehicle(root, vehicle_info):
26 | imgid2vid = {}
27 | imgid2camid = {}
28 | imgid2imgpath = {}
29 | vehicle_info_lines = open(vehicle_info, 'r').readlines()
30 |
31 | for idx, line in enumerate(vehicle_info_lines):
32 | # if idx < 10:
33 | vid = line.strip().split('/')[0]
34 | imgid = line.strip().split(';')[0].split('/')[1]
35 | camid = line.strip().split(';')[1]
36 | img_path = osp.join(root, 'images', imgid + '.jpg')
37 | imgid2vid[imgid] = vid
38 | imgid2camid[imgid] = camid
39 | imgid2imgpath[imgid] = img_path
40 |
41 | assert len(imgid2vid) == len(vehicle_info_lines)
42 | return imgid2vid, imgid2camid, imgid2imgpath
43 |
44 | def get_id_path_of_data_vehicles (dataset_id, data_dict):
45 | """create data loader for vehicle re-ID datasets"""
46 | img_paths = []
47 | dataset_ids = []
48 | person_ids = []
49 | ret = []
50 | pattern = re.compile(r'([-\d]+)_c([-\d]+)')
51 | total_ids = 0
52 |
53 | # for each dataset in candidate list:
54 | for dd_idx, dataset in enumerate(dataset_id):
55 | all_pids = {}
56 | last_img_num = len (img_paths)
57 | if dataset == 'veri':
58 | root = data_dict['veri'] + '/image_train/'
59 | fpaths = sorted(glob(osp.join(root, '*.jpg')))
60 | dataset_id_list = [dd_idx + 1 for n in range(len(fpaths))]
61 | dataset_ids.extend (dataset_id_list)
62 | for fpath in fpaths:
63 | fname = root + osp.basename(fpath)
64 | pid, cam = map(int, pattern.search(fname).groups())
65 | if pid == -1: continue
66 | if pid not in all_pids:
67 | all_pids[pid] = len(all_pids)
68 | pid = all_pids[pid]
69 |
70 | img_paths.append(fname)
71 | person_ids.append(pid)
72 | ret.append((fname, pid, dd_idx + 1, cam))
73 |
74 | if dataset == 'aic':
75 | root = data_dict['aic']
76 | train_path = osp.join(root, 'image_train')
77 | xml_dir = osp.join(root, 'train_label.xml')
78 | reid_info = XD.parse(xml_dir).documentElement.getElementsByTagName('Item')
79 | index_by_fname_dict = defaultdict()
80 | for index in range(len(reid_info)):
81 | fname = reid_info[index].getAttribute('imageName')
82 | index_by_fname_dict[fname] = index
83 |
84 | fpaths = sorted(glob(osp.join(train_path, '*.jpg')))
85 | dataset_id_list = [dd_idx + 1 for n in range(len(fpaths))]
86 | dataset_ids.extend (dataset_id_list)
87 | for fpath in fpaths:
88 | fname = osp.basename(fpath)
89 | pid, cam = map(int, [reid_info[index_by_fname_dict[fname]].getAttribute('vehicleID'),
90 | reid_info[index_by_fname_dict[fname]].getAttribute('cameraID')[1:]])
91 | if pid not in all_pids:
92 | all_pids[pid] = len(all_pids)
93 | pid = all_pids[pid]
94 | cam -= 1
95 | fname = train_path + "/" + fname
96 | img_paths.append(fname)
97 | person_ids.append(pid)
98 | ret.append((fname, pid, dd_idx + 1, cam))
99 |
100 | if dataset == 'vid':
101 | root = data_dict['vid']
102 | label_path = osp.join (root, 'train_test_split')
103 | train_path_label = osp.join(label_path, 'train_list.txt')
104 |
105 | with open(train_path_label, 'r', encoding='utf-8') as f:
106 | lines = f.readlines()
107 | lines = [line.strip().split(' ') for line in lines]
108 | for line in lines:
109 | fname, pid = line
110 | fname = fname + ".jpg"
111 | if int(pid) == -1: continue  # pid is read as a string here
112 | fname = root + "image/" + fname
113 | if pid not in all_pids:
114 | all_pids[pid] = len(all_pids)
115 | pid = all_pids[pid]
116 | img_paths.append(fname)
117 | person_ids.append(pid)
118 | ret.append((fname, pid, dd_idx + 1, 0))  # camera id is not provided in train_list.txt, so record 0
119 | dataset_ids.append (dd_idx + 1)
120 |
121 | if dataset == 'vehiclex':
122 | root = data_dict['vehiclex']
123 | fpaths = sorted(glob(osp.join(root, '*.jpg')))
124 | dataset_id_list = [dd_idx + 1 for n in range(len(fpaths))]
125 | dataset_ids.extend (dataset_id_list)
126 | random.shuffle(fpaths)
127 | for fpath in fpaths:
128 | fname = osp.basename(fpath)
129 | pid, cam = map(int, pattern.search(fname).groups())
130 | fname = root + fname
131 | if pid not in all_pids:
132 | all_pids[pid] = len(all_pids)
133 | pid = all_pids[pid]
134 | img_paths.append(fname)
135 | person_ids.append(pid)
136 | ret.append((fname, pid, dd_idx + 1, cam))
137 |
138 | if dataset == 'veri-wild':
139 | root = data_dict['veri-wild']
140 | train_list = osp.join(root, 'train_test_split/train_list_start0.txt')
141 | vehicle_info = osp.join(root, 'train_test_split/vehicle_info.txt')
142 | imgid2vid, imgid2camid, imgid2imgpath = process_veri_wild_vehicle(root, vehicle_info)
143 | vid_container = set()
144 | img_list_lines = open(train_list, 'r').readlines()
145 | for idx, line in enumerate(img_list_lines):
146 | line = line.strip()
147 | vid = line.split('/')[0]
148 | vid_container.add(vid)
149 | vid2label = {vid: label for label, vid in enumerate(vid_container)}
150 |
151 | dataset_id_list = [dd_idx + 1 for n in range(len(img_list_lines))]
152 | dataset_ids.extend (dataset_id_list)
153 |
154 | for idx, line in enumerate(img_list_lines):
155 | line = line.strip()
156 | pid = int(line.split('/')[0])
157 | if pid not in all_pids:
158 | all_pids[pid] = len(all_pids)
159 | pid = all_pids[pid]
160 |
161 | imgid = line.split('/')[1].split('.')[0]
162 | # if relabel: vid = vid2label[vid]
163 | img_paths.append(imgid2imgpath[imgid])
164 | person_ids.append(pid)
165 | # print ((imgid2imgpath[imgid], int(vid), 5, int(imgid2camid[imgid])))
166 | ret.append((imgid2imgpath[imgid], pid, dd_idx + 1, int(imgid2camid[imgid])))
167 |
168 | if dataset == 'stanford_cars':
169 | root = data_dict['stanford_cars']
170 | stanford_dataset = torchvision.datasets.StanfordCars(root=root, download=True)
171 |
172 | dataset_id_list = [dd_idx + 1 for n in range(len(stanford_dataset))]
173 | dataset_ids.extend (dataset_id_list)
174 |
175 | for fname, pid in stanford_dataset._samples:
176 | img_paths.append(fname)
177 | if pid not in all_pids:
178 | all_pids[pid] = len(all_pids)
179 | pid = all_pids[pid]
180 | person_ids.append(int(pid))
181 | ret.append((fname, int(pid), dd_idx + 1, 1))
182 |
183 | if dataset == 'vd1':
184 | root = data_dict['vd1']
185 | label_path = osp.join (root, 'train_test')
186 | train_path_label = osp.join(label_path, 'trainlist.txt')
187 |
188 | with open(train_path_label, 'r', encoding='utf-8') as f:
189 | lines = f.readlines()
190 | lines = [line.strip().split(' ') for line in lines]
191 | for line in lines:
192 | fname, pid, _, _ = line
193 | fname = fname + ".jpg"
194 | if int(pid) == -1: continue  # pid is read as a string here
195 | fname = root + "image/" + fname
196 | if pid not in all_pids:
197 | all_pids[pid] = len(all_pids)
198 | pid = all_pids[pid]
199 | img_paths.append(fname)
200 | person_ids.append(pid)
201 | ret.append((fname, pid, dd_idx + 1, 0))  # camera id is not provided in trainlist.txt, so record 0
202 | dataset_ids.append (dd_idx + 1)
203 |
204 | if dataset == 'vd2':
205 | root = data_dict['vd2']
206 | label_path = osp.join (root, 'train_test')
207 | train_path_label = osp.join(label_path, 'trainlist.txt')
208 |
209 | with open(train_path_label, 'r', encoding='utf-8') as f:
210 | lines = f.readlines()
211 | lines = [line.strip().split(' ') for line in lines]
212 | for line in lines:
213 | fname, pid, _, _ = line
214 | fname = fname + ".jpg"
215 | if int(pid) == -1: continue  # pid is read as a string here
216 | fname = root + "image/" + fname
217 | if pid not in all_pids:
218 | all_pids[pid] = len(all_pids)
219 | pid = all_pids[pid]
220 | img_paths.append(fname)
221 | person_ids.append(pid)
222 | ret.append((fname, pid, dd_idx + 1, 0))  # camera id is not provided in trainlist.txt, so record 0
223 | dataset_ids.append (dd_idx + 1)
224 |
225 |
226 | total_ids += len (all_pids)
227 | print("dataset", dd_idx + 1, ":", dataset, "loaded")
228 | print(" subset | # ids | # images")
229 | print(" ---------------------------")
230 | print(" train | {:5d} | {:8d}"
231 | .format(len(all_pids), len(img_paths) - last_img_num))
232 |
233 | print ("the whole data pool contains", total_ids, "ids and", len (img_paths), "images")
234 | return img_paths, person_ids, np.array(dataset_ids), ret
235 |
236 | def get_id_path_of_data_person (dataset_id, data_dict):
237 | """create data loader for person re-ID datasets"""
238 | img_paths = []
239 | dataset_ids = []
240 | person_ids = []
241 | ret = []
242 | total_ids = 0
243 |
244 | for idx, dataset in enumerate(dataset_id):
245 | all_pids = {}
246 | last_img_num = len (img_paths)
247 | if dataset in ['duke', 'market', 'msmt', 'unreal', 'personx', 'randperson', 'pku']:
248 | root = data_dict[dataset]
249 | fpaths = sorted(glob(osp.join(root, '*.jpg')) + glob(osp.join(root, '*.png')))
250 |
251 | dataset_id_list = [idx + 1 for n in range(len(fpaths))]
252 | dataset_ids.extend (dataset_id_list)
253 |
254 | pattern = re.compile(r'([-\d]+)_c([-\d]+)')
255 | if dataset == 'randperson':
256 | pattern = re.compile(r'([-\d]+)_s([-\d]+)_c([-\d]+)')
257 | if dataset == 'pku':
258 | pattern = re.compile(r'([-\d]+)_([-\d]+)_([-\d]+)')
259 |
260 |
261 | for fpath in fpaths:
262 | fname = root + osp.basename(fpath)
263 | if fname.endswith('.png'):
264 | Image.open(fname).save(osp.splitext(fname)[0] + '.jpg')
265 | fname = osp.splitext(fname)[0] + '.jpg'
266 | pid, cam = 0, 0
267 | if dataset == 'randperson':
268 | pid, sid, cam = map(int, pattern.search(fname).groups())
269 | elif dataset == 'pku':
270 | pid, sid, cnt_num = map(int, pattern.search(fname).groups())
271 | else:
272 | pid, cam = map(int, pattern.search(fname).groups())
273 | if pid == -1: continue
274 | if pid not in all_pids:
275 | all_pids[pid] = len(all_pids)
276 | pid = all_pids[pid]
277 | img_paths.append(fname)
278 | person_ids.append(pid)
279 | ret.append((fname, pid, idx + 1, cam))
280 |
281 | if dataset == 'raid':
282 | root = data_dict[dataset]
283 | raid = h5py.File(os.path.join(root, 'RAiD_4Cams.mat'), 'r')
284 | images = raid['dataset']['images']
285 | camID = raid['dataset']['cam']
286 | labels = raid['dataset']['personID']
287 |
288 | dataset_id_list = [idx + 1 for n in range(len(images))]
289 | dataset_ids.extend (dataset_id_list)
290 | images_dst_path = os.path.join(root, "train_all")
291 |
292 | for idx_img in range (len(images)):
293 | np_image = images[idx_img].T
294 | img = Image.fromarray(np_image)
295 | pid = labels[idx_img][0]
296 | if pid not in all_pids:
297 | all_pids[pid] = len(all_pids)
298 | pid = all_pids[pid]
299 | cid = camID[idx_img][0]
300 | if not os.path.isdir(images_dst_path):
301 | os.mkdir(images_dst_path)
302 | fname = images_dst_path + '/' + str(pid).zfill(4) + '_' + 'c' + \
303 | str(cid) + '_' + str(idx_img).zfill(5) + '.jpg'  # zero-padded relabeled pid serves as the id label
304 |
305 | if not os.path.exists(fname):
306 | img.save(fname)  # fname already includes images_dst_path
307 |
308 | img_paths.append(fname)
309 | person_ids.append(pid)
310 | ret.append((fname, pid, idx + 1, cid))
311 |
312 | if dataset == 'cuhk':
313 | '''
314 | download "cuhk03_new_protocol_config_detected.mat" from "https://github.com/zhunzhong07/person-re-ranking/tree/master/evaluation/data/CUHK03"
315 | '''
316 | root = data_dict[dataset]
317 | cuhk03 = h5py.File(os.path.join(root, 'cuhk-03.mat'))
318 | config = scipy.io.loadmat(os.path.join(
319 | root, 'cuhk03_new_protocol_config_detected.mat'))
320 | train_idx = config['train_idx'].flatten()
321 | # gallery_idx = config['gallery_idx'].flatten()
322 | # query_idx = config['query_idx'].flatten()
323 | labels = config['labels'].flatten()
324 | filelist = config['filelist'].flatten()
325 | cam_id = config['camId'].flatten()
326 |
327 | imgs = cuhk03['detected'][0]
328 | cam_imgs = []
329 | for i in range(len(imgs)):
330 | cam_imgs.append(cuhk03[imgs[i]][:].T)
331 |
332 | images_dst_path = os.path.join(root, "train_all")
333 | makeDir(images_dst_path)
334 |
335 | dataset_id_list = [idx + 1 for n in range(len(train_idx))]
336 | dataset_ids.extend (dataset_id_list)
337 |
338 | for i in train_idx:
339 | i -= 1 # Start from 0
340 | file_name = filelist[i][0]
341 | cam_pair_id = int(file_name[0])
342 | cam_label = int(file_name[2: 5])
343 | cam_image_idx = int(file_name[8: 10])
344 |
345 | np_image = cuhk03[cam_imgs[cam_pair_id - 1]
346 | [cam_label - 1][cam_image_idx - 1]][:].T
347 |
348 | unified_cam_id = (cam_pair_id - 1) * 2 + cam_id[i]
349 | img = Image.fromarray(np_image)
350 |
351 | pid = labels[i]
352 | if pid not in all_pids:
353 | all_pids[pid] = len(all_pids)
354 | pid = all_pids[pid]
355 |
356 | id_label = str(labels[i]).zfill(4)
357 | img_dst_path = os.path.join(images_dst_path, id_label)
358 |
359 | # create the per-ID directory if it does not exist yet
360 | if not os.path.isdir(img_dst_path):
361 | os.mkdir(img_dst_path)
362 |
363 | fname = os.path.join(img_dst_path, id_label + '_' + 'c' +
364 | str(unified_cam_id) + '_' + str(cam_image_idx).zfill(2) + '.jpg')
365 | if not os.path.exists(fname):
366 | img.save(fname)  # fname already includes img_dst_path
367 |
368 | img_paths.append(fname)
369 | person_ids.append(pid)
370 | ret.append((fname, pid, idx + 1, unified_cam_id))
371 |
372 | if dataset == 'viper':
373 | root = data_dict[dataset]
374 | images_dir = osp.join(root, 'images/')
375 | makeDir(images_dir)
376 | cameras = [sorted(glob(osp.join(root, 'cam_a', '*.bmp'))),
377 | sorted(glob(osp.join(root, 'cam_b', '*.bmp')))]
378 | assert len(cameras[0]) == len(cameras[1])
379 |
380 | for pid, (cam1, cam2) in enumerate(zip(*cameras)):
381 |
382 | if pid not in all_pids:
383 | all_pids[pid] = len(all_pids)
384 | pid = all_pids[pid]
385 | # view-0
386 | fname = images_dir + '{:08d}_c{:02d}_{:04d}.jpg'.format(pid, 0, 0)
387 | imsave(fname, imread(cam1))  # fname already includes images_dir
388 |
389 | img_paths.append(fname)
390 | person_ids.append(pid)
391 | ret.append((fname, pid, idx + 1, 1))
392 |
393 | # view-1
394 | fname = images_dir + '{:08d}_c{:02d}_{:04d}.jpg'.format(pid, 1, 0)
395 | imsave(fname, imread(cam2))  # fname already includes images_dir
396 |
397 | img_paths.append(fname)
398 | person_ids.append(pid)
399 | ret.append((fname, pid, idx + 1, 1))
400 |
401 | dataset_id_list = [idx + 1 for n in range(len(img_paths) - last_img_num)]
402 | dataset_ids.extend (dataset_id_list)
403 |
404 |
405 | total_ids += len (all_pids)
406 | print("dataset", dataset, "loaded")
407 | print(" subset | # ids | # images")
408 | print(" ---------------------------")
409 | print(" train | {:5d} | {:8d}"
410 | .format(len(all_pids), len(img_paths) - last_img_num))
411 |
412 | print ("the whole data pool contains", total_ids, "ids and", len (img_paths), "images")
413 | return img_paths, person_ids, np.array(dataset_ids), ret
--------------------------------------------------------------------------------
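For context, get_id_path_of_data_vehicles and get_id_path_of_data_person return the same four parallel structures, which training_set_search in SnP.py consumes directly. A minimal, hypothetical driver (the data_dict paths below are placeholders, not the repository's defaults):

from feat_stas.dataloader import get_id_path_of_data_person

# order in dataset_id defines the 1-based dataset ids used throughout SnP.py
data_dict = {'market': '/data/reid_data/market/bounding_box_train/',
             'duke': '/data/reid_data/duke/bounding_box_train/'}
dataset_id = ['market', 'duke']

img_paths, person_ids, dataset_ids, meta = get_id_path_of_data_person(dataset_id, data_dict)
# img_paths[i]   -> path of image i
# person_ids[i]  -> identity label, relabeled per dataset to start from 0
# dataset_ids[i] -> 1-based index into dataset_id
# meta[i]        -> (fname, pid, dataset_idx, cam), the meta_dataset tuples used by SnP.py
assert len(img_paths) == len(person_ids) == len(dataset_ids) == len(meta)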