├── README.md
├── adsh.py
├── checkpoints
│   └── .gitignore
├── data
│   ├── __init__.py
│   ├── cifar10.py
│   ├── data_loader.py
│   ├── flickr25k.py
│   ├── imagenet.py
│   ├── nus_wide.py
│   └── transform.py
├── logs
│   └── .gitignore
├── models
│   ├── __init__.py
│   ├── adsh_loss.py
│   └── alexnet.py
├── run.py
└── utils
    ├── __init__.py
    └── evaluate.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Asymmetric Deep Supervised Hashing

## REQUIREMENTS
1. pytorch >= 1.0
2. loguru

## DATASETS
1. [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)
2. [NUS-WIDE](https://pan.baidu.com/s/1f9mKXE2T8XpIq8p7y8Fa6Q) Password: uhr3

## USAGE
```
usage: run.py [-h] [--dataset DATASET] [--root ROOT] [--batch-size BATCH_SIZE]
              [--lr LR] [--code-length CODE_LENGTH] [--max-iter MAX_ITER]
              [--max-epoch MAX_EPOCH] [--num-query NUM_QUERY]
              [--num-samples NUM_SAMPLES] [--num-workers NUM_WORKERS]
              [--topk TOPK] [--gpu GPU] [--gamma GAMMA]

ADSH_PyTorch

optional arguments:
  -h, --help            show this help message and exit
  --dataset DATASET     Dataset name.
  --root ROOT           Path of dataset.
  --batch-size BATCH_SIZE
                        Batch size. (default: 64)
  --lr LR               Learning rate. (default: 1e-4)
  --code-length CODE_LENGTH
                        Binary hash code length. (default: 12,24,32,48)
  --max-iter MAX_ITER   Number of iterations. (default: 50)
  --max-epoch MAX_EPOCH
                        Number of epochs. (default: 3)
  --num-query NUM_QUERY
                        Number of query data points. (default: 1000)
  --num-samples NUM_SAMPLES
                        Number of sampled data points per iteration. (default: 2000)
  --num-workers NUM_WORKERS
                        Number of data loading threads. (default: 0)
  --topk TOPK           Calculate mAP of top k. (default: all)
  --gpu GPU             GPU id to use. (default: None, use CPU)
  --gamma GAMMA         Hyper-parameter. (default: 200)
```

## EXPERIMENTS
cifar10: 1000 query images, 2000 sampled images.

nus-wide: top 21 classes, 2100 query images, 2000 sampled images.

model: AlexNet

|                   | 12 bits | 24 bits | 32 bits | 48 bits |
| :-:               | :-:     | :-:     | :-:     | :-:     |
| cifar-10 MAP@ALL  | 0.9075  | 0.9047  | 0.9116  | 0.9045  |
| nus-wide MAP@5000 | 0.8698  | 0.9022  | 0.9079  | 0.9133  |
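For example, to reproduce the CIFAR-10 setting above (the dataset root below is a placeholder):

```
python run.py --dataset cifar-10 --root ./datasets/cifar-10 --code-length 12,24,32,48 --max-iter 50 --max-epoch 3 --num-query 1000 --num-samples 2000 --gpu 0
```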
--------------------------------------------------------------------------------
/adsh.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
import os
import time
import models.alexnet as alexnet
import utils.evaluate as evaluate

from loguru import logger
from models.adsh_loss import ADSH_Loss
from data.data_loader import sample_dataloader


def train(
        query_dataloader,
        retrieval_dataloader,
        code_length,
        device,
        lr,
        max_iter,
        max_epoch,
        num_samples,
        batch_size,
        root,
        dataset,
        gamma,
        topk,
):
    """
    Train model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loaders.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points per iteration.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma(float): Hyper-parameter.
        topk(int): Compute mAP over the top-k retrieved results.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = ADSH_Loss(code_length, gamma)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)
    B = torch.randn(num_retrieval, code_length).to(device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)

    start = time.time()
    for it in range(max_iter):
        iter_start = time.time()
        # Sample training data for CNN learning
        train_dataloader, sample_index = sample_dataloader(retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix: +1 for pairs sharing at least one label, -1 otherwise
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soften the similarity matrix to aid convergence: similar entries stay +1,
        # dissimilar entries are rescaled to balance positive and negative pairs
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Train CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                # sample_index[index] maps batch items to database rows
                cnn_loss = criterion(F, B, S[index, :], sample_index[index])

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, max_iter, iter_loss, time.time()-iter_start))
    logger.info('[Training time:{:.2f}]'.format(time.time()-start))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    # Save checkpoints
    torch.save(query_code.cpu(), os.path.join('checkpoints', 'query_code.t'))
    torch.save(B.cpu(), os.path.join('checkpoints', 'database_code.t'))
    torch.save(query_dataloader.dataset.get_onehot_targets().cpu(), os.path.join('checkpoints', 'query_targets.t'))
    torch.save(retrieval_targets.cpu(), os.path.join('checkpoints', 'database_targets.t'))
    torch.save(model.cpu(), os.path.join('checkpoints', 'model.t'))

    return mAP


def solve_dcc(B, U, expand_U, S, code_length, gamma):
    """
    Solve the DCC (discrete cyclic coordinate descent) subproblem:
    update B one bit at a time while keeping the remaining bits fixed.
    """
    Q = (code_length * S).t() @ U + gamma * expand_U

    for bit in range(code_length):
        q = Q[:, bit]
        u = U[:, bit]
        B_prime = torch.cat((B[:, :bit], B[:, bit+1:]), dim=1)
        U_prime = torch.cat((U[:, :bit], U[:, bit+1:]), dim=1)

        B[:, bit] = (q.t() - B_prime @ U_prime.t() @ u.t()).sign()

    return B

def calc_loss(U, B, S, code_length, omega, gamma):
    """
    Calculate the total loss over the sampled points.
    """
    hash_loss = ((code_length * S - U @ B.t()) ** 2).sum()
    quantization_loss = ((U - B[omega, :]) ** 2).sum()
    loss = (hash_loss + gamma * quantization_loss) / (U.shape[0] * B.shape[0])

    return loss.item()


def generate_code(model, dataloader, code_length, device):
    """
    Generate hash codes.

    Args
        dataloader(torch.utils.data.DataLoader): Data loader.
        code_length(int): Hash code length.
        device(torch.device): Using GPU or CPU.

    Returns
        code(torch.Tensor): Hash codes.
    """
    model.eval()
    with torch.no_grad():
        N = len(dataloader.dataset)
        code = torch.zeros([N, code_length])
        for data, _, index in dataloader:
            data = data.to(device)
            hash_code = model(data)
            code[index, :] = hash_code.sign().cpu()

    model.train()
    return code
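For reference, `ADSH_Loss` and `calc_loss` above evaluate the ADSH objective, where $U$ holds the network outputs on the sampled points $\Omega$, $B$ the database codes, $S$ the similarity matrix, and $q$ the code length; the code additionally divides by the number of sampled points times the database size:

$$
\min_{B,\,U}\ \lVert qS - UB^{\top} \rVert_F^2 \;+\; \gamma\, \lVert U - B_{\Omega} \rVert_F^2
$$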
--------------------------------------------------------------------------------
/checkpoints/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/ADSH_PyTorch/50df03f1380d68b88b9c6965d9fb93ce0b459328/checkpoints/.gitignore
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/ADSH_PyTorch/50df03f1380d68b88b9c6965d9fb93ce0b459328/data/__init__.py
--------------------------------------------------------------------------------
/data/cifar10.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from PIL import Image
import os
import sys
import pickle

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset

from data.transform import train_transform, query_transform, Onehot, encode_onehot


def load_data(root, num_query, num_train, batch_size, num_workers):
    """
    Load CIFAR-10 dataset.

    Args
        root(str): Path of dataset.
        num_query(int): Number of query data points.
        num_train(int): Number of training data points.
        batch_size(int): Batch size.
        num_workers(int): Number of data loading threads.

    Returns
        query_dataloader, train_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loaders.
    """
    CIFAR10.init(root, num_query, num_train)
    query_dataset = CIFAR10('query', transform=query_transform(), target_transform=Onehot())
    train_dataset = CIFAR10('train', transform=train_transform(), target_transform=None)
    retrieval_dataset = CIFAR10('database', transform=query_transform(), target_transform=Onehot())

    query_dataloader = DataLoader(
        query_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    retrieval_dataloader = DataLoader(
        retrieval_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )

    return query_dataloader, train_dataloader, retrieval_dataloader


class CIFAR10(Dataset):
    """
    CIFAR-10 dataset.
    """
    @staticmethod
    def init(root, num_query, num_train):
        data_list = ['data_batch_1',
                     'data_batch_2',
                     'data_batch_3',
                     'data_batch_4',
                     'data_batch_5',
                     'test_batch',
                     ]
        base_folder = 'cifar-10-batches-py'

        data = []
        targets = []

        for file_name in data_list:
            file_path = os.path.join(root, base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                data.append(entry['data'])
                if 'labels' in entry:
                    targets.extend(entry['labels'])
                else:
                    targets.extend(entry['fine_labels'])

        data = np.vstack(data).reshape(-1, 3, 32, 32)
        data = data.transpose((0, 2, 3, 1))  # convert to HWC
        targets = np.array(targets)

        # Sort by class
        sort_index = targets.argsort()
        data = data[sort_index, :]
        targets = targets[sort_index]

        # (num_query / number of classes) query images per class,
        # (num_train / number of classes) train images per class
        query_per_class = num_query // 10
        train_per_class = num_train // 10

        # Permute indices (6,000 images per class)
        perm_index = np.random.permutation(data.shape[0] // 10)
        query_index = perm_index[:query_per_class]
        train_index = perm_index[query_per_class: query_per_class + train_per_class]

        query_index = np.tile(query_index, 10)
        train_index = np.tile(train_index, 10)
        inc_index = np.array([i * (data.shape[0] // 10) for i in range(10)])
        query_index = query_index + inc_index.repeat(query_per_class)
        train_index = train_index + inc_index.repeat(train_per_class)
        retrieval_index = np.array(list(set(range(data.shape[0])) - set(query_index.tolist())), dtype=np.int64)

        # Split data, targets
        CIFAR10.QUERY_IMG = data[query_index, :]
        CIFAR10.QUERY_TARGET = targets[query_index]
        CIFAR10.TRAIN_IMG = data[train_index, :]
        CIFAR10.TRAIN_TARGET = targets[train_index]
        CIFAR10.RETRIEVAL_IMG = data[retrieval_index, :]
        CIFAR10.RETRIEVAL_TARGET = targets[retrieval_index]

    def __init__(self, mode='train',
                 transform=None, target_transform=None,
                 ):
        self.transform = transform
        self.target_transform = target_transform

        if mode == 'train':
            self.data = CIFAR10.TRAIN_IMG
            self.targets = CIFAR10.TRAIN_TARGET
        elif mode == 'query':
            self.data = CIFAR10.QUERY_IMG
            self.targets = CIFAR10.QUERY_TARGET
        else:
            self.data = CIFAR10.RETRIEVAL_IMG
            self.targets = CIFAR10.RETRIEVAL_TARGET

        self.onehot_targets = encode_onehot(self.targets, 10)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index) where target is the index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # Return a PIL Image for consistency with the other datasets
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index

    def __len__(self):
        return len(self.data)

    def get_onehot_targets(self):
        """
        Return one-hot encoded targets.
        """
        return torch.from_numpy(self.onehot_targets).float()
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
import torch
import os

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from PIL import Image, ImageFile

import data.cifar10 as cifar10
import data.nus_wide as nuswide
import data.flickr25k as flickr25k
import data.imagenet as imagenet

from data.transform import train_transform, encode_onehot

ImageFile.LOAD_TRUNCATED_IMAGES = True


def load_data(dataset, root, num_query, num_train, batch_size, num_workers):
    """
    Load dataset.

    Args
        dataset(str): Dataset name.
        root(str): Path of dataset.
        num_query(int): Number of query data points.
        num_train(int): Number of training data points.
        batch_size(int): Batch size.
        num_workers(int): Number of data loading threads.

    Returns
        query_dataloader, train_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loaders.
    """
    if dataset == 'cifar-10':
        query_dataloader, train_dataloader, retrieval_dataloader = cifar10.load_data(
            root,
            num_query,
            num_train,
            batch_size,
            num_workers,
        )
    elif dataset == 'nus-wide-tc10':
        query_dataloader, train_dataloader, retrieval_dataloader = nuswide.load_data(
            10,
            root,
            num_query,
            num_train,
            batch_size,
            num_workers,
        )
    elif dataset == 'nus-wide-tc21':
        query_dataloader, train_dataloader, retrieval_dataloader = nuswide.load_data(
            21,
            root,
            num_query,
            num_train,
            batch_size,
            num_workers,
        )
    elif dataset == 'flickr25k':
        query_dataloader, train_dataloader, retrieval_dataloader = flickr25k.load_data(
            root,
            num_query,
            num_train,
            batch_size,
            num_workers,
        )
    elif dataset == 'imagenet':
        query_dataloader, train_dataloader, retrieval_dataloader = imagenet.load_data(
            root,
            batch_size,
            num_workers,
        )
    else:
        raise ValueError("Invalid dataset name!")

    return query_dataloader, train_dataloader, retrieval_dataloader


def sample_dataloader(dataloader, num_samples, batch_size, root, dataset):
    """
    Sample data from a dataloader.

    Args
        dataloader (torch.utils.data.DataLoader): Dataloader.
        num_samples (int): Number of samples.
        batch_size (int): Batch size.
        root (str): Path of dataset.
        dataset (str): Dataset name.

    Returns
        sample_dataloader (torch.utils.data.DataLoader): Dataloader over the sampled data.
        sample_index (torch.Tensor): Database indices of the sampled data.
    """
    data = dataloader.dataset.data
    targets = dataloader.dataset.targets

    sample_index = torch.randperm(data.shape[0])[:num_samples]
    data = data[sample_index]
    targets = targets[sample_index]
    sample = wrap_data(data, targets, batch_size, root, dataset)

    return sample, sample_index


def wrap_data(data, targets, batch_size, root, dataset):
    """
    Wrap data into a dataloader.

    Args
        data (np.ndarray): Data.
        targets (np.ndarray): Targets.
        batch_size (int): Batch size.
        root (str): Path of dataset.
        dataset (str): Dataset name.

    Returns
        dataloader (torch.utils.data.DataLoader): Data loader.
    """
    class MyDataset(Dataset):
        def __init__(self, data, targets, root, dataset):
            self.data = data
            self.targets = targets
            self.root = root
            self.transform = train_transform()
            self.dataset = dataset
            if dataset == 'cifar-10':
                self.onehot_targets = encode_onehot(self.targets, 10)
            else:
                self.onehot_targets = self.targets

        def __getitem__(self, index):
            if self.dataset == 'cifar-10':
                img = Image.fromarray(self.data[index])
                if self.transform is not None:
                    img = self.transform(img)
            else:
                img = Image.open(os.path.join(self.root, self.data[index])).convert('RGB')
                img = self.transform(img)
            return img, self.targets[index], index

        def __len__(self):
            return self.data.shape[0]

        def get_onehot_targets(self):
            """
            Return one-hot encoded targets.
            """
            return torch.from_numpy(self.onehot_targets).float()

    dataset = MyDataset(data, targets, root, dataset)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )

    return dataloader
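As a usage note, here is a minimal sketch of how `run.py` wires these loaders together; the dataset root is a placeholder and must contain the extracted `cifar-10-batches-py` folder:

```python
from data.data_loader import load_data, sample_dataloader

# Build the query/train/retrieval loaders for CIFAR-10.
query_loader, train_loader, retrieval_loader = load_data(
    'cifar-10', './datasets/cifar-10', num_query=1000, num_train=2000,
    batch_size=64, num_workers=0,
)

# Each ADSH iteration re-samples 2000 training points from the database.
sampled_loader, sample_index = sample_dataloader(
    retrieval_loader, num_samples=2000, batch_size=64,
    root='./datasets/cifar-10', dataset='cifar-10',
)
for images, targets, index in sampled_loader:
    print(images.shape, targets.shape, index.shape)  # e.g. [64, 3, 224, 224], [64], [64]
    break
```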
--------------------------------------------------------------------------------
/data/flickr25k.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import os

from PIL import Image, ImageFile
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

from data.transform import train_transform, query_transform

ImageFile.LOAD_TRUNCATED_IMAGES = True


def load_data(root, num_query, num_train, batch_size, num_workers):
    """
    Load Flickr25k dataset.

    Args:
        root(str): Path of image files.
        num_query(int): Number of query data points.
        num_train(int): Number of training data points.
        batch_size(int): Batch size.
        num_workers(int): Number of data loading threads.

    Returns
        query_dataloader, train_dataloader, retrieval_dataloader (torch.utils.data.DataLoader): Data loaders.
    """
    Flickr25k.init(root, num_query, num_train)
    query_dataset = Flickr25k(root, 'query', query_transform())
    train_dataset = Flickr25k(root, 'train', train_transform())
    retrieval_dataset = Flickr25k(root, 'retrieval', query_transform())

    query_dataloader = DataLoader(
        query_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    retrieval_dataloader = DataLoader(
        retrieval_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )

    return query_dataloader, train_dataloader, retrieval_dataloader

class Flickr25k(Dataset):
    """
    Flickr25k dataset.

    Args
        root(str): Path of dataset.
        mode(str, 'train', 'query', 'retrieval'): Mode of dataset.
        transform(callable, optional): Transform images.
    """
    def __init__(self, root, mode, transform=None):
        self.root = root
        self.transform = transform

        if mode == 'train':
            self.data = Flickr25k.TRAIN_DATA
            self.targets = Flickr25k.TRAIN_TARGETS
        elif mode == 'query':
            self.data = Flickr25k.QUERY_DATA
            self.targets = Flickr25k.QUERY_TARGETS
        elif mode == 'retrieval':
            self.data = Flickr25k.RETRIEVAL_DATA
            self.targets = Flickr25k.RETRIEVAL_TARGETS
        else:
            raise ValueError("Invalid mode argument, can't load dataset!")

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.root, 'images', self.data[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[index], index

    def __len__(self):
        return self.data.shape[0]

    def get_onehot_targets(self):
        return torch.from_numpy(self.targets).float()

    @staticmethod
    def init(root, num_query, num_train):
        """
        Initialize the dataset.

        Args
            root(str): Path of image files.
            num_query(int): Number of query data points.
            num_train(int): Number of training data points.
        """
        # Load dataset
        img_txt_path = os.path.join(root, 'img.txt')
        targets_txt_path = os.path.join(root, 'targets.txt')

        # Read files
        with open(img_txt_path, 'r') as f:
            data = np.array([i.strip() for i in f])
        targets = np.loadtxt(targets_txt_path, dtype=np.int64)

        # Split dataset: the retrieval set is the whole database minus the
        # queries; training images are drawn from the retrieval set
        perm_index = np.random.permutation(data.shape[0])
        query_index = perm_index[:num_query]
        train_index = perm_index[num_query: num_query + num_train]
        retrieval_index = perm_index[num_query:]

        Flickr25k.QUERY_DATA = data[query_index]
        Flickr25k.QUERY_TARGETS = targets[query_index, :]

        Flickr25k.TRAIN_DATA = data[train_index]
        Flickr25k.TRAIN_TARGETS = targets[train_index, :]

        Flickr25k.RETRIEVAL_DATA = data[retrieval_index]
        Flickr25k.RETRIEVAL_TARGETS = targets[retrieval_index, :]
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
import torchvision.transforms as transforms

import os

from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from PIL import Image


def load_data(root, batch_size, workers):
    """
    Load ImageNet dataset.

    Args:
        root (str): Path of ImageNet dataset.
        batch_size (int): Number of samples in one batch.
        workers (int): Number of data loading threads.

    Returns:
        query_loader (torch.utils.data.DataLoader): Query dataset loader.
        train_loader (torch.utils.data.DataLoader): Training dataset loader.
        val_loader (torch.utils.data.DataLoader): Validation (database) dataset loader.
    """
    # Data transform
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    query_val_init_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Construct data loaders
    traindir = os.path.join(root, 'train')
    valdir = os.path.join(root, 'val')

    train_dataset = ImagenetDataset(
        traindir,
        transform=train_transform,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )

    query_dataset = ImagenetDataset(
        valdir,
        transform=query_val_init_transform,
        num_samples=5000,
    )

    query_loader = DataLoader(
        query_dataset,
        batch_size=batch_size,
        num_workers=workers,
        pin_memory=True,
    )

    val_dataset = ImagenetDataset(
        valdir,
        transform=query_val_init_transform,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )

    return query_loader, train_loader, val_loader


class ImagenetDataset(Dataset):
    classes = None
    class_to_idx = None

    def __init__(self, root, transform=None, num_samples=None):
        self.root = root
        self.transform = transform
        self.imgs = []
        self.targets = []

        # Assume alphabetical file order is the class order
        if ImagenetDataset.class_to_idx is None:
            ImagenetDataset.classes, ImagenetDataset.class_to_idx = self._find_classes(root)

        for i, cl in enumerate(ImagenetDataset.classes):
            cur_class = os.path.join(self.root, cl)
            files = os.listdir(cur_class)
            if num_samples is not None:
                num_per_class = num_samples // len(ImagenetDataset.classes)
                sample_files = files[:num_per_class]
                sample_files = [os.path.join(cur_class, file) for file in sample_files]

                self.imgs.extend(sample_files)
                self.targets.extend([ImagenetDataset.class_to_idx[cl] for _ in range(num_per_class)])
            else:
                files = [os.path.join(cur_class, f) for f in files]
                self.imgs.extend(files)
                self.targets.extend([ImagenetDataset.class_to_idx[cl] for _ in range(len(files))])

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        img, target = self.imgs[item], self.targets[item]

        img = Image.open(img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return img, target, item

    def _find_classes(self, dir):
        """
        Find the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to dir, and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx
--------------------------------------------------------------------------------
/data/nus_wide.py:
--------------------------------------------------------------------------------
import torch
import os
import numpy as np

from PIL import Image, ImageFile
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

from data.transform import train_transform, query_transform

ImageFile.LOAD_TRUNCATED_IMAGES = True


def load_data(tc, root, num_query, num_train, batch_size, num_workers):
    """
    Load NUS-WIDE dataset.

    Args:
        tc(int): Number of top classes (10 or 21).
        root(str): Path of image files.
        num_query(int): Number of query data points.
        num_train(int): Number of training data points.
        batch_size(int): Batch size.
        num_workers(int): Number of data loading threads.

    Returns
        query_dataloader, train_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loaders.
    """
    if tc == 21:
        query_dataset = NusWideDatasetTC21(
            root,
            'test_img.txt',
            'test_label_onehot.txt',
            transform=query_transform(),
        )

        train_dataset = NusWideDatasetTC21(
            root,
            'database_img.txt',
            'database_label_onehot.txt',
            transform=train_transform(),
            train=True,
            num_train=num_train,
        )

        retrieval_dataset = NusWideDatasetTC21(
            root,
            'database_img.txt',
            'database_label_onehot.txt',
            transform=query_transform(),
        )
    elif tc == 10:
        NusWideDatasetTc10.init(root, num_query, num_train)
        query_dataset = NusWideDatasetTc10(root, 'query', query_transform())
        train_dataset = NusWideDatasetTc10(root, 'train', train_transform())
        retrieval_dataset = NusWideDatasetTc10(root, 'retrieval', query_transform())
    else:
        raise ValueError('Invalid number of top classes: tc must be 10 or 21.')

    query_dataloader = DataLoader(
        query_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    retrieval_dataloader = DataLoader(
        retrieval_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )

    return query_dataloader, train_dataloader, retrieval_dataloader


class NusWideDatasetTc10(Dataset):
    """
    NUS-WIDE dataset, top 10 classes.

    Args
        root(str): Path of dataset.
        mode(str, 'train', 'query', 'retrieval'): Mode of dataset.
        transform(callable, optional): Transform images.
    """
    def __init__(self, root, mode, transform=None):
        self.root = root
        self.transform = transform

        if mode == 'train':
            self.data = NusWideDatasetTc10.TRAIN_DATA
            self.targets = NusWideDatasetTc10.TRAIN_TARGETS
        elif mode == 'query':
            self.data = NusWideDatasetTc10.QUERY_DATA
            self.targets = NusWideDatasetTc10.QUERY_TARGETS
        elif mode == 'retrieval':
            self.data = NusWideDatasetTc10.RETRIEVAL_DATA
            self.targets = NusWideDatasetTc10.RETRIEVAL_TARGETS
        else:
            raise ValueError("Invalid mode argument, can't load dataset!")

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.root, self.data[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[index], index

    def __len__(self):
        return self.data.shape[0]

    def get_onehot_targets(self):
        # Named get_onehot_targets for consistency with the other datasets,
        # which train() in adsh.py relies on
        return torch.from_numpy(self.targets).float()

    @staticmethod
    def init(root, num_query, num_train):
        """
        Initialize dataset.

        Args
            root(str): Path of image files.
            num_query(int): Number of query data points.
            num_train(int): Number of training data points.
        """
        # Load dataset
        img_txt_path = os.path.join(root, 'img_tc10.txt')
        targets_txt_path = os.path.join(root, 'targets_onehot_tc10.txt')

        # Read files
        with open(img_txt_path, 'r') as f:
            data = np.array([i.strip() for i in f])
        targets = np.loadtxt(targets_txt_path, dtype=np.int64)

        # Split dataset
        perm_index = np.random.permutation(data.shape[0])
        query_index = perm_index[:num_query]
        train_index = perm_index[num_query: num_query + num_train]
        retrieval_index = perm_index[num_query:]

        NusWideDatasetTc10.QUERY_DATA = data[query_index]
        NusWideDatasetTc10.QUERY_TARGETS = targets[query_index, :]

        NusWideDatasetTc10.TRAIN_DATA = data[train_index]
        NusWideDatasetTc10.TRAIN_TARGETS = targets[train_index, :]

        NusWideDatasetTc10.RETRIEVAL_DATA = data[retrieval_index]
        NusWideDatasetTc10.RETRIEVAL_TARGETS = targets[retrieval_index, :]


class NusWideDatasetTC21(Dataset):
    """
    NUS-WIDE dataset, top 21 classes.

    Args
        root(str): Path of image files.
        img_txt(str): Path of txt file containing image file names.
        label_txt(str): Path of txt file containing image labels.
        transform(callable, optional): Transform images.
        train(bool, optional): Return training dataset.
        num_train(int, optional): Number of training data points.
    """
    def __init__(self, root, img_txt, label_txt, transform=None, train=None, num_train=None):
        self.root = root
        self.transform = transform

        img_txt_path = os.path.join(root, img_txt)
        label_txt_path = os.path.join(root, label_txt)

        # Read files
        with open(img_txt_path, 'r') as f:
            self.data = np.array([i.strip() for i in f])
        self.targets = np.loadtxt(label_txt_path, dtype=np.float32)

        # Sample training dataset
        if train is True:
            perm_index = np.random.permutation(len(self.data))[:num_train]
            self.data = self.data[perm_index]
            self.targets = self.targets[perm_index]

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.root, self.data[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        return img, self.targets[index], index

    def __len__(self):
        return len(self.data)

    def get_onehot_targets(self):
        return torch.from_numpy(self.targets).float()
--------------------------------------------------------------------------------
/data/transform.py:
--------------------------------------------------------------------------------
import torch
import torchvision.transforms as transforms
import numpy as np


def encode_onehot(labels, num_classes=10):
    """
    One-hot encode labels.

    Args:
        labels (numpy.ndarray): Integer class labels.
        num_classes (int): Number of classes.

    Returns:
        onehot_labels (numpy.ndarray): One-hot encoded labels.
    """
    onehot_labels = np.zeros((len(labels), num_classes))

    for i in range(len(labels)):
        onehot_labels[i, labels[i]] = 1

    return onehot_labels


class Onehot(object):
    def __call__(self, sample, num_classes=10):
        target_onehot = torch.zeros(num_classes)
        target_onehot[sample] = 1

        return target_onehot


def train_transform():
    """
    Training image transform.

    Args
        None

    Returns
        transform(torchvision.transforms): transform
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])


def query_transform():
    """
    Query image transform.

    Args
        None

    Returns
        transform(torchvision.transforms): transform
    """
    # Data transform
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
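A quick sketch of the two label encoders above, assuming the 10-class CIFAR-10 setting:

```python
import numpy as np
from data.transform import encode_onehot, Onehot

labels = np.array([3, 0, 7])
print(encode_onehot(labels, num_classes=10).shape)  # (3, 10), one row per label

onehot = Onehot()
print(onehot(3))  # tensor with a 1 at position 3, zeros elsewhere
```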
--------------------------------------------------------------------------------
/logs/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/ADSH_PyTorch/50df03f1380d68b88b9c6965d9fb93ce0b459328/logs/.gitignore
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/ADSH_PyTorch/50df03f1380d68b88b9c6965d9fb93ce0b459328/models/__init__.py
--------------------------------------------------------------------------------
/models/adsh_loss.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class ADSH_Loss(nn.Module):
    """
    Loss function of ADSH.

    Args:
        code_length(int): Hash code length.
        gamma(float): Hyper-parameter weighting the quantization term.
    """
    def __init__(self, code_length, gamma):
        super(ADSH_Loss, self).__init__()
        self.code_length = code_length
        self.gamma = gamma

    def forward(self, F, B, S, omega):
        hash_loss = ((self.code_length * S - F @ B.t()) ** 2).sum()
        quantization_loss = ((F - B[omega, :]) ** 2).sum()
        loss = (hash_loss + self.gamma * quantization_loss) / (F.shape[0] * B.shape[0])

        return loss
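A minimal sketch of the loss on random tensors; all sizes below are illustrative, not values used by the trainer:

```python
import torch
from models.adsh_loss import ADSH_Loss

m, n, q = 8, 100, 12               # sampled points, database size, code length
criterion = ADSH_Loss(code_length=q, gamma=200)

F = torch.tanh(torch.randn(m, q))  # network outputs for the sampled batch
B = torch.randn(n, q).sign()       # current database codes
S = torch.randn(m, n).sign()       # +1/-1 similarity between batch and database
omega = torch.randperm(n)[:m]      # database indices of the sampled points

loss = criterion(F, B, S, omega)
print(loss.item())
```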
--------------------------------------------------------------------------------
/models/alexnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from torch.hub import load_state_dict_from_url


def load_model(code_length):
    """
    Load CNN model.

    Args
        code_length (int): Hash code length.

    Returns
        model (torch.nn.Module): CNN model.
    """
    model = AlexNet(code_length)
    state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
    # strict=False: the pretrained checkpoint has no weights for hash_layer
    model.load_state_dict(state_dict, strict=False)

    return model


class AlexNet(nn.Module):

    def __init__(self, code_length):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 1000),
        )

        # Drop the final 1000-way classification layer; hash_layer replaces it
        self.classifier = self.classifier[:-1]
        self.hash_layer = nn.Sequential(
            nn.Linear(4096, code_length),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        x = self.hash_layer(x)
        return x
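A sketch of the model's input/output contract (running this downloads the pretrained AlexNet weights on first use):

```python
import torch
import models.alexnet as alexnet

model = alexnet.load_model(code_length=12)
x = torch.randn(4, 3, 224, 224)   # a dummy batch of ImageNet-sized images
out = model(x)
print(out.shape)       # torch.Size([4, 12]), values in (-1, 1) from tanh
print(out.sign()[0])   # binarize to obtain the hash code, as generate_code does
```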
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import torch
import argparse
import adsh

from loguru import logger
from data.data_loader import load_data


def run():
    args = load_config()
    logger.add('logs/{time}.log', rotation='500 MB', level='INFO')
    logger.info(args)

    torch.backends.cudnn.benchmark = True

    # Load dataset
    query_dataloader, _, retrieval_dataloader = load_data(
        args.dataset,
        args.root,
        args.num_query,
        args.num_samples,
        args.batch_size,
        args.num_workers,
    )

    for code_length in args.code_length:
        mAP = adsh.train(
            query_dataloader,
            retrieval_dataloader,
            code_length,
            args.device,
            args.lr,
            args.max_iter,
            args.max_epoch,
            args.num_samples,
            args.batch_size,
            args.root,
            args.dataset,
            args.gamma,
            args.topk,
        )
        logger.info('[code_length:{}][map:{:.4f}]'.format(code_length, mAP))


def load_config():
    """
    Load configuration.

    Args
        None

    Returns
        args(argparse.Namespace): Configuration.
    """
    parser = argparse.ArgumentParser(description='ADSH_PyTorch')
    parser.add_argument('--dataset',
                        help='Dataset name.')
    parser.add_argument('--root',
                        help='Path of dataset.')
    parser.add_argument('--batch-size', default=64, type=int,
                        help='Batch size. (default: 64)')
    parser.add_argument('--lr', default=1e-4, type=float,
                        help='Learning rate. (default: 1e-4)')
    parser.add_argument('--code-length', default='12,24,32,48', type=str,
                        help='Binary hash code length. (default: 12,24,32,48)')
    parser.add_argument('--max-iter', default=50, type=int,
                        help='Number of iterations. (default: 50)')
    parser.add_argument('--max-epoch', default=3, type=int,
                        help='Number of epochs. (default: 3)')
    parser.add_argument('--num-query', default=1000, type=int,
                        help='Number of query data points. (default: 1000)')
    parser.add_argument('--num-samples', default=2000, type=int,
                        help='Number of sampled data points. (default: 2000)')
    parser.add_argument('--num-workers', default=0, type=int,
                        help='Number of data loading threads. (default: 0)')
    # default None so that slicing with [:topk] keeps all results;
    # a default of -1 would silently drop the last ranked item
    parser.add_argument('--topk', default=None, type=int,
                        help='Calculate mAP of top k. (default: all)')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use. (default: None, use CPU)')
    parser.add_argument('--gamma', default=200, type=float,
                        help='Hyper-parameter. (default: 200)')

    args = parser.parse_args()

    # GPU
    if args.gpu is None:
        args.device = torch.device("cpu")
    else:
        args.device = torch.device("cuda:%d" % args.gpu)

    # Hash code length
    args.code_length = list(map(int, args.code_length.split(',')))

    return args


if __name__ == '__main__':
    run()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/ADSH_PyTorch/50df03f1380d68b88b9c6965d9fb93ce0b459328/utils/__init__.py
--------------------------------------------------------------------------------
/utils/evaluate.py:
--------------------------------------------------------------------------------
import torch


def mean_average_precision(query_code,
                           database_code,
                           query_labels,
                           database_labels,
                           device,
                           topk=None,
                           ):
    """
    Calculate mean average precision (mAP).

    Args:
        query_code (torch.Tensor): Query data hash codes.
        database_code (torch.Tensor): Database data hash codes.
        query_labels (torch.Tensor): Query data targets, one-hot.
        database_labels (torch.Tensor): Database data targets, one-hot.
        device (torch.device): Using CPU or GPU.
        topk (int, optional): Compute mAP over the top k retrieved results.

    Returns:
        meanAP (float): Mean Average Precision.
    """
    num_query = query_labels.shape[0]
    mean_AP = 0.0

    for i in range(num_query):
        # Retrieve images from database: 1 if at least one label is shared
        retrieval = (query_labels[i, :] @ database_labels.t() > 0).float()

        # Calculate Hamming distance
        hamming_dist = 0.5 * (database_code.shape[1] - query_code[i, :] @ database_code.t())

        # Rank database items by Hamming distance and keep the top k
        retrieval = retrieval[torch.argsort(hamming_dist)][:topk]

        # Number of relevant items retrieved
        retrieval_cnt = retrieval.sum().int().item()

        # No relevant images retrieved
        if retrieval_cnt == 0:
            continue

        # Generate score for every position
        score = torch.linspace(1, retrieval_cnt, retrieval_cnt).to(device)

        # Acquire index
        index = (torch.nonzero(retrieval == 1).squeeze() + 1.0).float()

        mean_AP += (score / index).mean().item()

    mean_AP = mean_AP / num_query
    return mean_AP
--------------------------------------------------------------------------------
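A small self-check of the metric on synthetic codes and labels; sizes are arbitrary:

```python
import torch
from utils.evaluate import mean_average_precision

device = torch.device('cpu')
num_query, num_db, q, c = 5, 50, 12, 10

query_code = torch.randn(num_query, q).sign()
db_code = torch.randn(num_db, q).sign()
query_labels = torch.eye(c)[torch.randint(c, (num_query,))]  # one-hot rows
db_labels = torch.eye(c)[torch.randint(c, (num_db,))]

mAP = mean_average_precision(query_code, db_code, query_labels,
                             db_labels, device, topk=None)
print('mAP: {:.4f}'.format(mAP))
```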