.
├── .gitignore
├── README.md
├── checkpoints
│   └── .gitkeep
├── data
│   └── dataloader.py
├── logs
│   └── .gitkeep
├── requirements.txt
├── run.py
├── sdh.py
└── utils
    └── evaluate.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Checkpoints
*.pt

# Log
*.log

# Script
*.sh

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Supervised Discrete Hashing

## REQUIREMENTS
`pip install -r requirements.txt`

1. pytorch >= 1.0
2. loguru

## DATASETS
1. [cifar10-gist.mat](https://pan.baidu.com/s/1qE9KiAOTNs5ORn_WoDDwUg) password: umb6
2. [cifar-10_alexnet.t](https://pan.baidu.com/s/1ciJIYGCfS3m0marQvatNjQ) password: f1b7
3. [nus-wide-tc21_alexnet.t](https://pan.baidu.com/s/1YglFwoxB-3j7xTEyAc8ykw) password: vfeu
4. [imagenet-tc100_alexnet.t](https://pan.baidu.com/s/1ayv4wdtCOzEDsJy01SjRew) password: 6w5i
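Each `*_alexnet.t` file is a dictionary saved with `torch.save`; the keys below are the ones `data/dataloader.py` reads. A quick sketch for inspecting a downloaded file (the path is a placeholder for wherever you saved it):

```python
# Inspect a downloaded feature file; the keys are the ones _load_data expects.
import torch

data = torch.load('/path/to/cifar-10_alexnet.t')
for key in ('train_features', 'train_targets', 'query_features',
            'query_targets', 'retrieval_features', 'retrieval_targets'):
    print(key, tuple(data[key].shape))
```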
## USAGE
```
usage: run.py [-h] [--dataset DATASET] [--root ROOT]
              [--code-length CODE_LENGTH] [--max-iter MAX_ITER]
              [--num-anchor NUM_ANCHOR] [--topk TOPK] [--gpu GPU]
              [--seed SEED] [--lamda LAMDA] [--nu NU] [--sigma SIGMA]

SDH_PyTorch

optional arguments:
  -h, --help            show this help message and exit
  --dataset DATASET     Dataset name.
  --root ROOT           Path of dataset.
  --code-length CODE_LENGTH
                        Binary hash code length.(default:
                        12,16,24,32,48,64,128)
  --max-iter MAX_ITER   Number of iterations.(default: 3)
  --num-anchor NUM_ANCHOR
                        Number of anchors.(default: 1000)
  --topk TOPK           Calculate mAP of top k.(default: -1, all)
  --gpu GPU             GPU id to use.(default: None, run on CPU)
  --seed SEED           Random seed.(default: 3367)
  --lamda LAMDA         Hyper-parameter.(default: 1)
  --nu NU               Hyper-parameter.(default: 1e-5)
  --sigma SIGMA         Hyper-parameter. 2e-3 for cifar10-gist, 5e-4 for
                        others.(default: 5e-4)
```
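A typical invocation (the dataset path is a placeholder; see EXPERIMENTS below for the per-dataset `sigma` values):

```
python run.py --dataset cifar10-gist --root /path/to/cifar10-gist.mat --code-length 12,32,64 --max-iter 3 --num-anchor 1000 --sigma 2e-3 --gpu 0
```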
## EXPERIMENTS

cifar-10-gist: GIST features, 1000 query images, 5000 training images, sigma=2e-3, mAP@ALL.

cifar-10-alexnet: AlexNet features, 1000 query images, 5000 training images, sigma=5e-4, mAP@ALL.

nus-wide-tc21-alexnet: AlexNet features, top 21 classes, 2100 query images, 10500 training images, sigma=5e-4, mAP@5000.

imagenet-tc100-alexnet: AlexNet features, top 100 classes, 5000 query images, 10000 training images, sigma=5e-4, mAP@1000.

bits | 12 | 16 | 24 | 32 | 48 | 64 | 128
:-: | :-: | :-: | :-: | :-: | :-: | :-: | :-:
cifar-10-gist@ALL | 0.3964 | 0.4335 | 0.4357 | 0.4611 | 0.4729 | 0.4826 | 0.4973
cifar-10-alexnet@ALL | 0.4966 | 0.4837 | 0.5209 | 0.5373 | 0.5411 | 0.5629 | 0.5750
nus-wide-tc21-alexnet@5000 | 0.7504 | 0.7684 | 0.7745 | 0.7932 | 0.7912 | 0.8035 | 0.8162
imagenet-tc100-alexnet@1000 | 0.3529 | 0.4166 | 0.4790 | 0.5096 | 0.5429 | 0.5586 | 0.5974

--------------------------------------------------------------------------------
/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/SDH_PyTorch/a82cc317f73e4819fa606e086207a807a733d40f/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import scipy.io as sio


def load_data(dataset, root):
    """
    Load dataset.

    Args
        dataset(str): Dataset name.
        root(str): Path of dataset.
    """
    if dataset == 'cifar10-gist':
        return load_data_gist(root)
    elif dataset in ('cifar-10', 'nus-wide-tc21', 'imagenet-tc100'):
        return _load_data(root)
    else:
        raise ValueError('Invalid dataset name!')


def _load_data(root):
    """
    Load AlexNet fc7 features.

    Args
        root(str): Path of dataset.

    Returns
        train_data(torch.Tensor, num_train*4096): Training data.
        train_targets(torch.Tensor, num_train*num_classes): One-hot training targets.
        query_data(torch.Tensor, num_query*4096): Query data.
        query_targets(torch.Tensor, num_query*num_classes): One-hot query targets.
        retrieval_data(torch.Tensor, num_retrieval*4096): Retrieval data.
        retrieval_targets(torch.Tensor, num_retrieval*num_classes): One-hot retrieval targets.
    """
    data = torch.load(root)
    train_data = data['train_features']
    train_targets = data['train_targets']
    query_data = data['query_features']
    query_targets = data['query_targets']
    retrieval_data = data['retrieval_features']
    retrieval_targets = data['retrieval_targets']

    # Normalize every split with the retrieval set statistics
    mean = retrieval_data.mean()
    std = retrieval_data.std()
    train_data = (train_data - mean) / std
    query_data = (query_data - mean) / std
    retrieval_data = (retrieval_data - mean) / std

    return train_data, train_targets, query_data, query_targets, retrieval_data, retrieval_targets


def load_data_gist(root):
    """
    Load cifar10-gist dataset.

    Args
        root(str): Path of dataset.

    Returns
        train_data(torch.Tensor, num_train*512): Training data.
        train_targets(torch.Tensor, num_train*10): One-hot training targets.
        query_data(torch.Tensor, num_query*512): Query data.
        query_targets(torch.Tensor, num_query*10): One-hot query targets.
        retrieval_data(torch.Tensor, num_train*512): Retrieval data.
        retrieval_targets(torch.Tensor, num_train*10): One-hot retrieval targets.
    """
    # Load data
    mat_data = sio.loadmat(root)
    query_data = mat_data['testdata']
    query_targets = mat_data['testgnd'].astype(int)  # np.int was removed in NumPy 1.24
    retrieval_data = mat_data['traindata']
    retrieval_targets = mat_data['traingnd'].astype(int)

    # One-hot
    query_targets = encode_onehot(query_targets)
    retrieval_targets = encode_onehot(retrieval_targets)

    # Normalization
    data = np.concatenate((query_data, retrieval_data), axis=0)
    data = (data - data.mean()) / data.std()
    query_data = data[:query_data.shape[0], :]
    retrieval_data = data[query_data.shape[0]:, :]

    # Sample training data
    num_train = 5000
    train_index = np.random.permutation(len(retrieval_data))[:num_train]
    train_data = retrieval_data[train_index, :]
    train_targets = retrieval_targets[train_index, :]

    train_data = torch.from_numpy(train_data).float()
    train_targets = torch.from_numpy(train_targets).float()
    query_data = torch.from_numpy(query_data).float()
    query_targets = torch.from_numpy(query_targets).float()
    retrieval_data = torch.from_numpy(retrieval_data).float()
    retrieval_targets = torch.from_numpy(retrieval_targets).float()

    # The sampled training split also serves as the retrieval set here,
    # matching the shapes documented in the docstring above
    return train_data, train_targets, query_data, query_targets, train_data, train_targets


def encode_onehot(labels, num_classes=10):
    """
    One-hot labels.

    Args:
        labels (numpy.ndarray): labels.
        num_classes (int): Number of classes.

    Returns:
        onehot_labels (numpy.ndarray): one-hot labels.
    """
    onehot_labels = np.zeros((len(labels), num_classes))

    for i in range(len(labels)):
        onehot_labels[i, labels[i]] = 1

    return onehot_labels
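`encode_onehot` assigns labels one sample at a time. For large label arrays, an equivalent vectorized sketch (a hypothetical helper with the same semantics for integer labels in `[0, num_classes)`):

```python
import numpy as np

def encode_onehot_vectorized(labels, num_classes=10):
    # Same result as encode_onehot, built with a single fancy-indexing write.
    labels = np.asarray(labels).reshape(-1)
    onehot_labels = np.zeros((labels.shape[0], num_classes))
    onehot_labels[np.arange(labels.shape[0]), labels] = 1
    return onehot_labels
```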
114 | """ 115 | onehot_labels = np.zeros((len(labels), num_classes)) 116 | 117 | for i in range(len(labels)): 118 | onehot_labels[i, labels[i]] = 1 119 | 120 | return onehot_labels 121 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tree-Shu-Zhao/SDH_PyTorch/a82cc317f73e4819fa606e086207a807a733d40f/logs/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | loguru -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import random 4 | import numpy as np 5 | import sdh 6 | 7 | from loguru import logger 8 | from data.dataloader import load_data 9 | 10 | 11 | def run(): 12 | # Load configuration 13 | args = load_config() 14 | logger.add('logs/{}_code_{}_anchor_{}_lamda_{}_nu_{}_sigma_{}_topk_{}.log'.format( 15 | args.dataset, 16 | '_'.join([str(code_length) for code_length in args.code_length]), 17 | args.num_anchor, 18 | args.lamda, 19 | args.nu, 20 | args.sigma, 21 | args.topk, 22 | ), 23 | rotation='500 MB', 24 | level='INFO', 25 | ) 26 | logger.info(args) 27 | 28 | # Set seed 29 | random.seed(args.seed) 30 | torch.manual_seed(args.seed) 31 | torch.cuda.manual_seed(args.seed) 32 | np.random.seed(args.seed) 33 | 34 | # Load data 35 | train_data, train_targets, query_data, query_targets, retrieval_data, retrieval_targets = load_data(args.dataset, args.root) 36 | 37 | # Training 38 | for code_length in args.code_length: 39 | checkpoint = sdh.train( 40 | train_data, 41 | train_targets, 42 | query_data, 43 | query_targets, 44 | retrieval_data, 45 | retrieval_targets, 46 | code_length, 47 | args.num_anchor, 48 | args.max_iter, 49 | args.lamda, 50 | args.nu, 51 | args.sigma, 52 | args.device, 53 | args.topk, 54 | ) 55 | logger.info('[code length:{}][map:{:.4f}]'.format(code_length, checkpoint['map'])) 56 | 57 | # Save checkpoint 58 | torch.save(checkpoint, 'checkpoints/{}_code_{}_anchor_{}_lamda_{}_nu_{}_sigma_{}_topk_{}_map_{:.4f}.pt'.format( 59 | args.dataset, 60 | code_length, 61 | args.num_anchor, 62 | args.lamda, 63 | args.nu, 64 | args.sigma, 65 | args.topk, 66 | checkpoint['map'], 67 | )) 68 | 69 | 70 | def load_config(): 71 | """ 72 | Load configuration. 73 | 74 | Args 75 | None 76 | 77 | Returns 78 | args(argparse.ArgumentParser): Configuration. 
79 | """ 80 | parser = argparse.ArgumentParser(description='SDH_PyTorch') 81 | parser.add_argument('--dataset', 82 | help='Dataset name.') 83 | parser.add_argument('--root', 84 | help='Path of dataset') 85 | parser.add_argument('--code-length', default='12,16,24,32,48,64,128', type=str, 86 | help='Binary hash code length.(default: 12,16,24,32,48,64,128)') 87 | parser.add_argument('--max-iter', default=3, type=int, 88 | help='Number of iterations.(default: 3)') 89 | parser.add_argument('--num-anchor', default=1000, type=int, 90 | help='Number of anchor.(default: 1000)') 91 | parser.add_argument('--topk', default=-1, type=int, 92 | help='Calculate map of top k.(default: all)') 93 | parser.add_argument('--gpu', default=None, type=int, 94 | help='Using gpu.(default: False)') 95 | parser.add_argument('--seed', default=3367, type=int, 96 | help='Random seed.(default: 3367)') 97 | parser.add_argument('--lamda', default=1, type=float, 98 | help='Hyper-parameter.(default: 1)') 99 | parser.add_argument('--nu', default=1e-5, type=float, 100 | help='Hyper-parameter.(default: 1e-5)') 101 | parser.add_argument('--sigma', default=5e-4, type=float, 102 | help='Hyper-parameter. 2e-3 for cifar-10-gist, 5e-4 for others.') 103 | 104 | args = parser.parse_args() 105 | 106 | # GPU 107 | if args.gpu is None: 108 | args.device = torch.device("cpu") 109 | else: 110 | args.device = torch.device("cuda:%d" % args.gpu) 111 | 112 | # Hash code length 113 | args.code_length = list(map(int, args.code_length.split(','))) 114 | 115 | return args 116 | 117 | 118 | if __name__ == '__main__': 119 | run() 120 | -------------------------------------------------------------------------------- /sdh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from sklearn.metrics.pairwise import rbf_kernel 4 | from loguru import logger 5 | from utils.evaluate import mean_average_precision, pr_curve 6 | 7 | 8 | def train( 9 | train_data, 10 | train_targets, 11 | query_data, 12 | query_targets, 13 | retrieval_data, 14 | retrieval_targets, 15 | code_length, 16 | num_anchor, 17 | max_iter, 18 | lamda, 19 | nu, 20 | sigma, 21 | device, 22 | topk, 23 | ): 24 | """ 25 | Training model. 26 | 27 | Args 28 | train_data(torch.Tensor): Training data. 29 | train_targets(torch.Tensor): Training targets. 30 | query_data(torch.Tensor): Query data. 31 | query_targets(torch.Tensor): Query targets. 32 | retrieval_data(torch.Tensor): Retrieval data. 33 | retrieval_targets(torch.Tensor): Retrieval targets. 34 | code_length(int): Hash code length. 35 | num_anchor(int): Number of anchors. 36 | max_iter(int): Number of iterations. 37 | lamda, nu, sigma(float): Hyper-parameters. 38 | device(torch.device): GPU or CPU. 39 | topk(int): Compute mAP using top k retrieval result. 40 | 41 | Returns 42 | checkpoint(dict): Checkpoint. 
43 | """ 44 | # Initialization 45 | n = train_data.shape[0] 46 | L = code_length 47 | m = num_anchor 48 | t = max_iter 49 | X = train_data.t() 50 | Y = train_targets.t() 51 | B = torch.randn(L, n).sign() 52 | 53 | # Permute data 54 | perm_index = torch.randperm(n) 55 | X = X[:, perm_index] 56 | Y = Y[:, perm_index] 57 | 58 | # Randomly select num_anchor samples from the training data 59 | anchor = X[:, :m] 60 | 61 | # Map training data via RBF kernel 62 | phi_x = torch.from_numpy(rbf_kernel(X.numpy().T, anchor.numpy().T, sigma)).t() 63 | 64 | # Training 65 | B = B.to(device) 66 | Y = Y.to(device) 67 | phi_x = phi_x.to(device) 68 | for it in range(t): 69 | # G-Step 70 | W = torch.pinverse(B @ B.t() + lamda * torch.eye(code_length, device=device)) @ B @ Y.t() 71 | 72 | # F-Step 73 | P = torch.pinverse(phi_x @ phi_x.t()) @ phi_x @ B.t() 74 | F_X = P.t() @ phi_x 75 | 76 | # B-Step 77 | B = solve_dcc(B, W, Y, F_X, nu) 78 | 79 | # Evaluate 80 | query_code = generate_code(query_data.t(), anchor, P, sigma) 81 | retrieval_code = generate_code(retrieval_data.t(), anchor, P, sigma) 82 | 83 | # Compute map 84 | mAP = mean_average_precision( 85 | query_code.t().to(device), 86 | retrieval_code.t().to(device), 87 | query_targets.to(device), 88 | retrieval_targets.to(device), 89 | device, 90 | topk, 91 | ) 92 | 93 | # PR curve 94 | Precision, R = pr_curve( 95 | query_code.t().to(device), 96 | retrieval_code.t().to(device), 97 | query_targets.to(device), 98 | retrieval_targets.to(device), 99 | device, 100 | ) 101 | 102 | # Save checkpoint 103 | checkpoint = { 104 | 'tB': B, 105 | 'tL': train_targets, 106 | 'qB': query_code, 107 | 'qL': query_targets, 108 | 'rB': retrieval_code, 109 | 'rL': retrieval_targets, 110 | 'anchor': anchor, 111 | 'projection': P, 112 | 'P': Precision, 113 | 'R': R, 114 | 'map': mAP, 115 | } 116 | 117 | return checkpoint 118 | 119 | 120 | def solve_dcc(B, W, Y, F_X, nu): 121 | """Solve DCC(Discrete Cyclic Coordinate Descent) problem 122 | """ 123 | for i in range(B.shape[0]): 124 | Q = W @ Y + nu * F_X 125 | 126 | q = Q[i, :] 127 | v = W[i, :] 128 | W_prime = torch.cat((W[:i, :], W[i+1:, :])) 129 | B_prime = torch.cat((B[:i, :], B[i+1:, :])) 130 | 131 | B[i, :] = (q - B_prime.t() @ W_prime @ v).sign() 132 | 133 | return B 134 | 135 | 136 | def generate_code(data, anchor, P, sigma): 137 | """ 138 | Generate hash code from data using projection matrix. 139 | 140 | Args 141 | data(torch.Tensor): Data. 142 | anchor(torch.Tensor): Anchor points. 143 | P(torch.Tensor): Projection matrix. 144 | sigma(float): RBF kernel width. 145 | 146 | Returns 147 | code(torch.Tensor): Hash code. 148 | """ 149 | phi_x = torch.from_numpy(rbf_kernel(data.cpu().numpy().T, anchor.cpu().numpy().T, sigma)).t().to(P.device) 150 | return (P.t() @ phi_x).sign() 151 | 152 | -------------------------------------------------------------------------------- /utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mean_average_precision(query_code, 5 | retrieval_code, 6 | query_targets, 7 | retrieval_targets, 8 | device, 9 | topk=None, 10 | ): 11 | """ 12 | Calculate mean average precision(map). 13 | 14 | Args: 15 | query_code (torch.Tensor): Query data hash code. 16 | retrieval_code (torch.Tensor): Retrieval data hash code. 17 | query_targets (torch.Tensor): Query data targets, one-hot 18 | retrieval_targets (torch.Tensor): retrieval data targets, one-host 19 | device (torch.device): Using CPU or GPU. 
        topk (int): Calculate top k data mAP.

    Returns:
        meanAP (float): Mean Average Precision.
    """
    num_query = query_targets.shape[0]
    mean_AP = 0.0

    # Treat topk=None or a non-positive topk (run.py passes -1 for "all")
    # as ranking the whole retrieval set; slicing with -1 would silently
    # drop the last retrieved item
    if topk is None or topk <= 0:
        topk = retrieval_targets.shape[0]

    for i in range(num_query):
        # Retrieve images from database
        retrieval = (query_targets[i, :] @ retrieval_targets.t() > 0).float()

        # Calculate hamming distance
        hamming_dist = 0.5 * (retrieval_code.shape[1] - query_code[i, :] @ retrieval_code.t())

        # Arrange position according to hamming distance
        retrieval = retrieval[torch.argsort(hamming_dist)][:topk]

        # Retrieval count
        retrieval_cnt = retrieval.sum().int().item()

        # Skip queries that retrieve no relevant images
        if retrieval_cnt == 0:
            continue

        # Generate score for every position
        score = torch.linspace(1, retrieval_cnt, retrieval_cnt).to(device)

        # Acquire index
        index = (torch.nonzero(retrieval == 1).squeeze() + 1.0).float()

        mean_AP += (score / index).mean()

    mean_AP = mean_AP / num_query
    return mean_AP.item()


def pr_curve(query_code, retrieval_code, query_targets, retrieval_targets, device):
    """
    P-R curve.

    Args
        query_code(torch.Tensor): Query hash code.
        retrieval_code(torch.Tensor): Retrieval hash code.
        query_targets(torch.Tensor): Query targets.
        retrieval_targets(torch.Tensor): Retrieval targets.
        device (torch.device): Using CPU or GPU.

    Returns
        P(torch.Tensor): Precision.
        R(torch.Tensor): Recall.
    """
    num_query = query_code.shape[0]
    num_bit = query_code.shape[1]
    P = torch.zeros(num_query, num_bit + 1).to(device)
    R = torch.zeros(num_query, num_bit + 1).to(device)
    for i in range(num_query):
        gnd = (query_targets[i].unsqueeze(0).mm(retrieval_targets.t()) > 0).float().squeeze()
        tsum = torch.sum(gnd)
        if tsum == 0:
            continue
        hamm = 0.5 * (retrieval_code.shape[1] - query_code[i, :] @ retrieval_code.t())
        # One precision/recall point per Hamming radius 0..num_bit
        tmp = (hamm <= torch.arange(0, num_bit + 1).reshape(-1, 1).float().to(device)).float()
        total = tmp.sum(dim=-1)
        total = total + (total == 0).float() * 0.1  # avoid division by zero
        t = gnd * tmp
        count = t.sum(dim=-1)
        p = count / total
        r = count / tsum
        P[i] = p
        R[i] = r
    # Average only over queries that contribute a non-zero precision at each radius
    mask = (P > 0).float().sum(dim=0)
    mask = mask + (mask == 0).float() * 0.1
    P = P.sum(dim=0) / mask
    R = R.sum(dim=0) / mask

    return P, R
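Finally, a toy sanity check for `mean_average_precision`: two queries against a four-item database, with ±1 codes as produced by `.sign()`. Each query's two relevant items are strictly closest in Hamming distance, so the expected mAP is 1.0:

```python
# Tiny hand-checkable example for utils.evaluate.mean_average_precision;
# run from the repository root.
import torch
from utils.evaluate import mean_average_precision

device = torch.device('cpu')
query_code = torch.tensor([[1., 1.], [-1., -1.]])
retrieval_code = torch.tensor([[1., 1.], [1., 1.], [-1., -1.], [-1., -1.]])
query_targets = torch.tensor([[1., 0.], [0., 1.]])
retrieval_targets = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])

print(mean_average_precision(query_code, retrieval_code,
                             query_targets, retrieval_targets, device))  # 1.0
```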