├── .gitignore ├── README.md ├── data ├── README.md ├── coco │ └── 80_coco │ │ ├── database.txt │ │ ├── test.txt │ │ └── train.txt └── nuswide │ └── nuswide_81 │ ├── database.txt │ ├── test.txt │ └── train.txt ├── preprocess ├── README.md └── __init__.py └── src ├── __init__.py ├── common ├── __init__.py ├── fake_demo.py ├── logger.py ├── lr_scheduler.py └── mmhh_config.py ├── dataloader ├── __init__.py ├── image_list.py └── image_preprocess.py ├── evaluate ├── __init__.py ├── measure_utils.py └── optimize_metric_demo.py ├── mmhh.py ├── mmhh_loss.py ├── mmhh_network.py ├── semi_batch.py ├── test_mmhh.py └── train_mmhh.py /.gitignore: -------------------------------------------------------------------------------- 1 | log 2 | *.pyc 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MMHH 2 | The PyTorch implementation of Maximum-Margin Hamming Hashing. 3 | 4 | ## Requirements 5 | 6 | The code requires some common packages: 7 | 8 | ```shell 9 | # python>=3.6 10 | # Anaconda: not required, but recommended since it bundles many common packages. 11 | conda create -n py36 python=3.6 12 | source activate py36 13 | 14 | # pytorch = 0.4.1 15 | conda install pytorch=0.4.1 cuda92 -c pytorch 16 | 17 | # Optional: tensorboardX for detailed analysis 18 | pip install tensorboardX 19 | ``` 20 | 21 | ## Data Preparation 22 | 23 | We recommend following [HashNet](https://github.com/thuml/HashNet/) to prepare the dataset images. 24 | 25 | Our paper also reports experiments on **noisy data** and under the **Unseen Classes Retrieval Protocol**. The preprocessing is described in the paper and is easy to follow. We will add the preprocessing scripts soon.
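The list files under `data/` (for example `data/coco/80_coco/train.txt`) are read by `make_dataset` in `src/dataloader/image_list.py`, which expects one image per line: a path followed by space-separated integer labels. The following is a minimal sketch of that parsing for a multi-label line. The helper name `parse_list_line`, the example path, and the four-class label vector are illustrative assumptions, not part of the repository.

```python
def parse_list_line(line):
    """Split one list-file line into (image_path, list of integer labels)."""
    fields = line.split()
    path = fields[0]
    # make_dataset stores these labels as a np.uint8 array; a plain list
    # keeps this sketch free of third-party dependencies.
    labels = [int(v) for v in fields[1:]]
    return path, labels


# Hypothetical line: both the path and the labels are made up.
path, labels = parse_list_line("images/0001.jpg 0 1 0 1\n")
print(path, labels)
```

A line with a single label after the path is handled by `make_dataset` as a plain integer class index rather than a vector.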
26 | 27 | ## Example Usage 28 | 29 | To train the model, run, for example: 30 | 31 | ```shell 32 | python train_mmhh.py --gpu_id=0 --s_dataset="coco_80" --hash_bit=48 --annotation="MMHH-train" --loss_lambda=0.001 --num_iters=1000 --image_network="AlexNetFc" --batch_size=48 --radius 2 --distance_type "MMHH" --similar_weight "1" --lr 0.0001 --decay_step 200 --gamma 10.0 --opt-test True 33 | ``` 34 | 35 | The model is evaluated at the end of training. To test a saved model separately, run, for example: 36 | 37 | ```shell 38 | python test_mmhh.py --gpu_id=0 --dataset="coco_80" --model_path "../snapshot/hash/MMHH-train_coco_80_coco_80_iter_01000" --batch_size=48 --radius 2 --opt-test --annotation="MMHH-test" --test_sample_ratio 1.0 39 | ``` 40 | 41 | The basic metric functions are adapted from [DeepHash](https://github.com/thulab/DeepHash) (MAP@H<=2) and [HashNet](https://github.com/thuml/HashNet/) (MAP@TopK). We optimized them carefully, making them 2× to 10× faster. 42 | 43 | Because of `numpy` randomness, results from the optimized versions may differ slightly from the original ones; after extensive testing, we believe the difference is negligible. 44 | 45 | ## Acknowledgments 46 | 47 | Our code mainly builds on the following repositories, and we sincerely thank their authors for their invaluable help: 48 | 49 | * [HashNet](https://github.com/thuml/HashNet/): the dataset, data processing, the network backbones, etc. 50 | * [DeepHash](https://github.com/thulab/DeepHash): the DCH implementation and the training parameters. 51 | * [Snca.pytorch](https://github.com/microsoft/snca.pytorch): the augmented memory. 52 | 53 | ## Citations 54 | 55 | If you find the code helpful to your work, please cite our paper: 56 | 57 | ``` 58 | @inproceedings{DBLP:conf/iccv/Kang0L0Y19, 59 | author = {Rong Kang and 60 | Yue Cao and 61 | Mingsheng Long and 62 | Jianmin Wang and 63 | Philip S.
Yu}, 64 | title = {Maximum-Margin Hamming Hashing}, 65 | booktitle = {2019 {IEEE/CVF} International Conference on Computer Vision, {ICCV} 66 | 2019, Seoul, Korea (South), October 27 - November 2, 2019}, 67 | pages = {8251--8260}, 68 | publisher = {{IEEE}}, 69 | year = {2019}, 70 | } 71 | ``` 72 | 73 | If you encounter any issues, please feel free to send an email to [kangr15@mails.tsinghua.edu.cn](mailto:kangr15@mails.tsinghua.edu.cn). We will do our best to address your concerns. 74 | 75 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Dataset 2 | Thanks to the PyTorch part of [HashNet](https://github.com/thuml/HashNet). 3 | 4 | For now, we provide image datasets (2D). 5 | 6 | Series (1D) and voxel (3D) datasets will also be added in the future. 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /preprocess/README.md: -------------------------------------------------------------------------------- 1 | # Preprocess 2 | Contains scripts for preprocessing data. 3 | 4 | e.g.: 5 | 6 | * Split a dataset into seen/unseen subsets. 7 | * Add noise to the original data at a specified ratio. 8 | * And so on. 9 | 10 | TO BE UPDATED.
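Until the scripts are added, the "add noise at a specified ratio" item can be sketched as below. This is only an illustration under stated assumptions: the function name `add_label_noise`, the noise model (a chosen fraction of entries receives the label vector of another randomly drawn entry), and the `seed` parameter are all hypothetical and are not taken from the paper or from this repository. The list-file format (path followed by space-separated labels) is the one parsed by `make_dataset`.

```python
import random


def add_label_noise(lines, ratio, seed=0):
    """Return stripped copies of `lines` where int(len(lines) * ratio) entries
    get the labels of a randomly drawn entry (hypothetical noise model)."""
    rng = random.Random(seed)
    noisy = [line.strip() for line in lines]
    n_noisy = int(len(lines) * ratio)
    for i in rng.sample(range(len(lines)), n_noisy):
        path = lines[i].split()[0]
        # The donor may carry the same labels, so fewer entries may change.
        donor_labels = rng.choice(lines).split()[1:]
        noisy[i] = " ".join([path] + donor_labels)
    return noisy


# Made-up list-file lines, only for demonstration.
demo = ["a.jpg 1 0\n", "b.jpg 0 1\n", "c.jpg 1 1\n", "d.jpg 0 0\n"]
print(add_label_noise(demo, ratio=0.5))
```

Image paths are never altered, so the noisy list can be written back to a file and loaded in place of the clean `train.txt`.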
11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/MMHH/65514f0c0f03c50e23df49bbd41c78b73e6d950b/preprocess/__init__.py -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/MMHH/65514f0c0f03c50e23df49bbd41c78b73e6d950b/src/__init__.py -------------------------------------------------------------------------------- /src/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/MMHH/65514f0c0f03c50e23df49bbd41c78b73e6d950b/src/common/__init__.py -------------------------------------------------------------------------------- /src/common/fake_demo.py: -------------------------------------------------------------------------------- 1 | from common.mmhh_config import data_config 2 | 3 | 4 | def get_fake_train_list(s_dataset, t_dataset): 5 | return get_fake_list(s_dataset), get_fake_list(t_dataset) 6 | 7 | 8 | def get_fake_test_list(config, t_dataset): 9 | config["data"]["database"]["list_path"] = get_fake_list(t_dataset) 10 | config["data"]["test"]["list_path"] = get_fake_list(t_dataset) 11 | config["R"] = 10 12 | 13 | 14 | def get_fake_list(dataset): 15 | if dataset in ['ElectricDevices', 'Crop', 'InsectWingbeat']: 16 | return data_config["ElectricDevices"]["train"] 17 | fake_dir = '../data/fake' 18 | # fake_dir = 'data/fake' 19 | if dataset in ['shapenet_9', 'shapenet_13', 'shape_pro_13', 'shape_pro_9', 'modelnet_10', 'modelnet_40', 20 | 'modelnet_sm_11']: 21 | return fake_dir + '/voxel_list.txt' 22 | else: 23 | return fake_dir + '/image_list.txt' 24 | 
-------------------------------------------------------------------------------- /src/common/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | 5 | 6 | def get_log(log_dir, annotation): 7 | if not os.path.exists(log_dir): 8 | os.system("mkdir -p " + log_dir) 9 | logger = logging.getLogger(annotation) 10 | logger.setLevel(logging.INFO) 11 | dd = datetime.datetime.now() 12 | fh = logging.FileHandler(log_dir + "out_project_%s.log.%s" % (annotation, dd.isoformat())) 13 | fh.setLevel(logging.INFO) 14 | ch = logging.StreamHandler() 15 | ch.setLevel(logging.INFO) 16 | # log format 17 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 18 | ch.setFormatter(formatter) 19 | fh.setFormatter(formatter) 20 | logger.addHandler(ch) 21 | logger.addHandler(fh) 22 | # log part end 23 | return logger 24 | 25 | 26 | def get_log_one_path(path): 27 | logger = logging.getLogger(path) 28 | logger.setLevel(logging.INFO) 29 | fh = logging.FileHandler(path) 30 | fh.setLevel(logging.INFO) 31 | ch = logging.StreamHandler() 32 | ch.setLevel(logging.INFO) 33 | # log format 34 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 35 | ch.setFormatter(formatter) 36 | fh.setFormatter(formatter) 37 | logger.addHandler(ch) 38 | logger.addHandler(fh) 39 | # log part end 40 | return logger 41 | -------------------------------------------------------------------------------- /src/common/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | class INVScheduler(object): 2 | def __init__(self, gamma, decay_rate, init_lr=0.001): 3 | self.gamma = gamma 4 | self.decay_rate = decay_rate 5 | self.init_lr = init_lr 6 | 7 | def next_optimizer(self, group_ratios, optimizer, iter_num): 8 | lr = self.init_lr * (1 + self.gamma * iter_num) ** (-self.decay_rate) 9 | for param_group, group_ratio in 
zip(optimizer.param_groups, group_ratios): 10 | param_group['lr'] = lr * group_ratio 11 | return optimizer 12 | 13 | 14 | class StepScheduler(object): 15 | def __init__(self, gamma=0.5, step=2000, init_lr=0.0003): 16 | self.step = step 17 | self.gamma = gamma 18 | self.init_lr = init_lr 19 | self.last_step = -1 20 | 21 | def next_optimizer(self, group_ratios, optimizer, iter_num, logger=None): 22 | lr = self.init_lr * (self.gamma ** (iter_num // self.step)) 23 | if logger and lr != self.last_step: 24 | self.last_step = lr 25 | for param_group, group_ratio in zip(optimizer.param_groups, group_ratios): 26 | param_group['lr'] = lr * group_ratio 27 | return optimizer 28 | -------------------------------------------------------------------------------- /src/common/mmhh_config.py: -------------------------------------------------------------------------------- 1 | from dataloader.image_list import load_images 2 | 3 | from enum import Enum 4 | 5 | global_debugging = False 6 | 7 | 8 | class ImageLossType(Enum): 9 | DHN = 1 10 | HashNet = 2 11 | 12 | 13 | class TestType(Enum): 14 | Test = 1 15 | NoFirst = 2 16 | NoTest = 3 17 | 18 | 19 | class DistanceType(Enum): 20 | Hamming = 1 21 | tSNE = 2 22 | Cauchy = 3 23 | Margin1 = 4 24 | Margin2 = 5 25 | Metric = 6 26 | MMHH = 7 27 | 28 | 29 | class BatchType(Enum): 30 | PairBatch = 1 31 | SemiMem = 2 32 | BatchInitMem = 3 33 | BatchSelectMem = 4 34 | 35 | 36 | class SemiInitType(Enum): 37 | RANDOM = 1 38 | MODEL = 2 39 | 40 | 41 | class Mission(Enum): 42 | Cross_Modal_Transfer = 1 43 | Hashing = 2 44 | Cross_Domain_Transfer = 3 45 | 46 | 47 | class LrScheduleType(Enum): 48 | Stair_Step = 1 49 | DANN_INV = 2 50 | HashNet_Step = 3 51 | ShapeNet_Step = 4 52 | TimeSeries_Step = 5 53 | 54 | 55 | class MarginParams(object): 56 | def __init__(self, sim_in, dis_in, sim_out=1.0, dis_out=1.0, margin=2.0): 57 | self.sim_in = sim_in 58 | self.dis_in = dis_in 59 | self.sim_out = sim_out 60 | self.dis_out = dis_out 61 | self.margin = margin 62 
| 63 | @staticmethod 64 | def from_string_params(param_str: str): 65 | param_strs = param_str.split(":") 66 | if len(param_strs) == 2: 67 | return MarginParams(float(param_strs[0]), float(param_strs[1])) 68 | if len(param_strs) == 3: 69 | return MarginParams(float(param_strs[0]), float(param_strs[1]), 70 | float(param_strs[2])) 71 | if len(param_strs) == 4: 72 | return MarginParams(float(param_strs[0]), float(param_strs[1]), float(param_strs[2]), 73 | float(param_strs[3])) 74 | if len(param_strs) == 5: 75 | return MarginParams(float(param_strs[0]), float(param_strs[1]), float(param_strs[2]), 76 | float(param_strs[3]), float(param_strs[4])) 77 | 78 | 79 | data_config = { 80 | "coco_80": { 81 | "train": "../data/coco/80_coco/train.txt", 82 | "database": "../data/coco/80_coco/database.txt", 83 | "test": "../data/coco/80_coco/test.txt", 84 | "R": 5000, 85 | "loader": load_images, "class_num": 1}, 86 | "nuswide_21": { 87 | "train": "../data/nuswide/nuswide_21/train.txt", 88 | "database": "../data/nuswide/nuswide_21/database.txt", 89 | "test": "../data/nuswide/nuswide_21/test.txt", 90 | "R": 5000, 91 | "loader": load_images, "class_num": 5}, 92 | "nuswide_81": { 93 | "train": "../data/nuswide/nuswide_81/train.txt", 94 | "database": "../data/nuswide/nuswide_81/database.txt", 95 | "test": "../data/nuswide/nuswide_81/test.txt", 96 | "R": 5000, 97 | "loader": load_images, "class_num": 5}, 98 | } 99 | 100 | tensorboard_interval = 50 101 | -------------------------------------------------------------------------------- /src/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/MMHH/65514f0c0f03c50e23df49bbd41c78b73e6d950b/src/dataloader/__init__.py -------------------------------------------------------------------------------- /src/dataloader/image_list.py: -------------------------------------------------------------------------------- 1 | # from __future__ import
print_function, division 2 | from dataloader import image_preprocess as prep 3 | from torchvision import transforms 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | import torch.utils.data as data 8 | 9 | 10 | def make_dataset(image_list, labels): 11 | if labels: 12 | len_ = len(image_list) 13 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 14 | else: 15 | if len(image_list[0].split()) > 2: 16 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]], dtype=np.uint8)) for val in image_list] 17 | else: 18 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 19 | return images 20 | 21 | 22 | def pil_loader(path): 23 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 24 | with open(path, 'rb') as f: 25 | with Image.open(f) as img: 26 | return img.convert('RGB') 27 | 28 | 29 | def accimage_loader(path): 30 | import accimage 31 | try: 32 | return accimage.Image(path) 33 | except IOError: 34 | # Potentially a decoding problem, fall back to PIL.Image 35 | return pil_loader(path) 36 | 37 | 38 | def default_loader(path): 39 | # from torchvision import get_image_backend 40 | # if get_image_backend() == 'accimage': 41 | # return accimage_loader(path) 42 | # else: 43 | return pil_loader(path) 44 | 45 | 46 | class ImageListWithIndex(object): 47 | """A generic data loader where the images are arranged in this way: :: 48 | root/dog/xxx.png 49 | root/dog/xxy.png 50 | root/dog/xxz.png 51 | root/cat/123.png 52 | root/cat/nsdf3.png 53 | root/cat/asd932_.png 54 | Args: 55 | root (string): Root directory path. 56 | transform (callable, optional): A function/transform that takes in an PIL image 57 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 58 | target_transform (callable, optional): A function/transform that takes in the 59 | target and transforms it. 60 | loader (callable, optional): A function to load an image given its path. 
61 | Attributes: 62 | classes (list): List of the class names. 63 | class_to_idx (dict): Dict with items (class_name, class_index). 64 | imgs (list): List of (image path, class_index) tuples 65 | """ 66 | 67 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, 68 | loader=default_loader, need_make_dataset=True): 69 | """ 70 | 71 | :param image_list: 72 | :param labels: 73 | :param transform: 74 | :param target_transform: 75 | :param loader: 76 | """ 77 | if need_make_dataset: 78 | imgs = make_dataset(image_list, labels) 79 | else: 80 | imgs = image_list 81 | if len(imgs) == 0: 82 | raise RuntimeError("Found 0 images in subfolders of: " + image_list) 83 | 84 | self.imgs = imgs 85 | self.transform = transform 86 | self.target_transform = target_transform 87 | self.loader = loader 88 | 89 | def __getitem__(self, index): 90 | """ 91 | Args: 92 | index (int): Index 93 | Returns: 94 | tuple: (image, target) where target is class_index of the target class. 95 | """ 96 | path, target = self.imgs[index] 97 | img = self.loader(path) 98 | if self.transform is not None: 99 | img = self.transform(img) 100 | if self.target_transform is not None: 101 | target = self.target_transform(target) 102 | 103 | return img, target, index 104 | 105 | def __len__(self): 106 | return len(self.imgs) 107 | 108 | 109 | def load_images(images_file_path, batch_size, resize_size=256, is_train=True, crop_size=224, test_sample_ratio=1.0): 110 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 111 | if not is_train: 112 | start_center = (resize_size - crop_size - 1) / 2 113 | transformer = transforms.Compose([ 114 | prep.ResizeImage(resize_size), 115 | prep.PlaceCrop(crop_size, start_center, start_center), 116 | transforms.ToTensor(), 117 | normalize]) 118 | image_lines = open(images_file_path).readlines() 119 | if test_sample_ratio < 1.0: 120 | sample_line = int(len(image_lines) * test_sample_ratio) 121 | print("sample ratio: 
%.3f, ori: %d, sample: %d" % (test_sample_ratio, len(image_lines), sample_line)) 122 | image_lines = np.random.choice(image_lines, int(len(image_lines) * test_sample_ratio), replace=False) 123 | else: 124 | print("no sample ori: %d" % (len(image_lines))) 125 | images = ImageListWithIndex(image_lines, transform=transformer) 126 | images_loader = torch.utils.data.DataLoader(images, batch_size=batch_size, shuffle=False, num_workers=4) 127 | else: 128 | transformer = transforms.Compose([prep.ResizeImage(resize_size), 129 | transforms.RandomResizedCrop(crop_size), 130 | transforms.RandomHorizontalFlip(), 131 | transforms.ToTensor(), 132 | normalize]) 133 | image_lines = open(images_file_path).readlines() 134 | if test_sample_ratio < 1.0: 135 | raise ValueError("we shouldn't sample train set!") # a bare `assert "..."` never fails 136 | # sample_line = int(len(image_lines) * test_sample_ratio) 137 | # image_lines = np.random.choice(image_lines, sample_line, replace=False) 138 | images = ImageListWithIndex(image_lines, transform=transformer) 139 | images_loader = torch.utils.data.DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=4) 140 | return images_loader 141 | 142 | 143 | def inverted_imgs(imgs): 144 | total_class_num = len(imgs[0][1]) 145 | inverts = [[] for _ in range(total_class_num)] 146 | for item in imgs: 147 | for i, c in enumerate(item[1]): 148 | if c == 1: 149 | inverts[i].append(item) 150 | return inverts 151 | 152 | 153 | def load_balance_images(images_file_path, batch_size, resize_size=256, is_train=True, crop_size=224): 154 | if not is_train: 155 | return load_images(images_file_path, batch_size, resize_size=resize_size, is_train=is_train, crop_size=crop_size) 156 | else: 157 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 158 | transformer = transforms.Compose([prep.ResizeImage(resize_size), 159 | transforms.RandomResizedCrop(crop_size), 160 | transforms.RandomHorizontalFlip(), 161 | transforms.ToTensor(), 162 | normalize]) 163 | # 
images = ImageBalanceList(open(images_file_path).readlines(), transform=transformer) 164 | # images = ImageList(open(images_file_path).readlines(), transform=transformer) 165 | # images_loader = torch.utils.data.DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=4, 166 | # drop_last=True) 167 | 168 | imgs = make_dataset(open(images_file_path).readlines(), None) 169 | inverts = inverted_imgs(imgs) 170 | image_loaders = [] 171 | half_batch = batch_size // 2 172 | for invert in inverts: 173 | # print("class num: %d" % len(invert)) 174 | invert_len = len(invert) 175 | i_class_images = ImageListWithIndex(invert, transform=transformer, need_make_dataset=False) 176 | loader = torch.utils.data.DataLoader(i_class_images, batch_size=min(invert_len, half_batch), shuffle=True, num_workers=4, 177 | drop_last=True) 178 | image_loaders.append(loader) 179 | return image_loaders 180 | 181 | -------------------------------------------------------------------------------- /src/dataloader/image_preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import transforms 3 | import os 4 | from PIL import Image, ImageOps 5 | import numbers 6 | import torch 7 | 8 | class ResizeImage(): 9 | def __init__(self, size): 10 | if isinstance(size, int): 11 | self.size = (int(size), int(size)) 12 | else: 13 | self.size = size 14 | def __call__(self, img): 15 | th, tw = self.size 16 | return img.resize((th, tw)) 17 | 18 | 19 | class PlaceCrop(object): 20 | """Crops the given PIL.Image at the particular index. 21 | Args: 22 | size (sequence or int): Desired output size of the crop. If size is an 23 | int instead of sequence like (w, h), a square crop (size, size) is 24 | made. 
25 | """ 26 | 27 | def __init__(self, size, start_x, start_y): 28 | if isinstance(size, int): 29 | self.size = (int(size), int(size)) 30 | else: 31 | self.size = size 32 | self.start_x = start_x 33 | self.start_y = start_y 34 | 35 | def __call__(self, img): 36 | """ 37 | Args: 38 | img (PIL.Image): Image to be cropped. 39 | Returns: 40 | PIL.Image: Cropped image. 41 | """ 42 | th, tw = self.size 43 | return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th)) 44 | 45 | 46 | class ForceFlip(object): 47 | """Horizontally flip the given PIL.Image unconditionally (no randomness).""" 48 | 49 | def __call__(self, img): 50 | """ 51 | Args: 52 | img (PIL.Image): Image to be flipped. 53 | Returns: 54 | PIL.Image: Horizontally flipped image. 55 | """ 56 | return img.transpose(Image.FLIP_LEFT_RIGHT) 57 | 58 | def image_train(resize_size=256, crop_size=224): 59 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 60 | std=[0.229, 0.224, 0.225]) 61 | return transforms.Compose([ 62 | ResizeImage(resize_size), 63 | transforms.RandomResizedCrop(crop_size), 64 | transforms.RandomHorizontalFlip(), 65 | transforms.ToTensor(), 66 | normalize 67 | ]) 68 | 69 | def image_test(resize_size=256, crop_size=224): 70 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 71 | std=[0.229, 0.224, 0.225]) 72 | # only the center crop is used for validation; start_first and start_last are unused 73 | start_first = 0 74 | start_center = (resize_size - crop_size - 1) / 2 75 | start_last = resize_size - crop_size - 1 76 | 77 | return transforms.Compose([ 78 | ResizeImage(resize_size), 79 | PlaceCrop(crop_size, start_center, start_center), 80 | transforms.ToTensor(), 81 | normalize 82 | ]) 83 | 84 | 85 | def image_train_cifar(resize_size=256, crop_size=224): 86 | normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], 87 | std=[0.2023, 0.1994, 0.2010]) 88 | return transforms.Compose([ 89 | transforms.RandomCrop(32, padding=4), 90 | 
transforms.RandomHorizontalFlip(), 91 | transforms.ToTensor(), 92 | normalize 93 | ]) 94 | 95 | def image_test_cifar(resize_size=256, crop_size=224): 96 | normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], 97 | std=[0.2023, 0.1994, 0.2010]) 98 | return transforms.Compose([ 99 | transforms.ToTensor(), 100 | normalize 101 | ]) 102 | -------------------------------------------------------------------------------- /src/evaluate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/MMHH/65514f0c0f03c50e23df49bbd41c78b73e6d950b/src/evaluate/__init__.py -------------------------------------------------------------------------------- /src/evaluate/measure_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | 6 | class ds: 7 | def __init__(self): 8 | self.output = [] 9 | self.label = [] 10 | 11 | 12 | def mean_average_precision_normal(database_output, database_labels, query_output, query_labels, R, verbose=0): 13 | query_num = query_output.shape[0] 14 | 15 | sim = np.dot(database_output, query_output.T) 16 | start_time = time.time() 17 | ids = np.argsort(-sim, axis=0) 18 | end_time = time.time() 19 | print("total query: {:d}, sorting time: {:.3f}".format(query_num, end_time - start_time)) 20 | APx = [] 21 | precX = [] 22 | recX = [] 23 | label_matchs = calc_label_match_matrix(database_labels, query_labels) 24 | for i in range(query_num): 25 | if i % 100 == 0: 26 | tmp_time = time.time() 27 | print("query map {:d}, time: {:.3f}".format(i, tmp_time - end_time)) 28 | end_time = tmp_time 29 | label = query_labels[i, :] 30 | label[label == 0] = -1 31 | idx = ids[:, i] 32 | imatch = np.sum(database_labels[idx[0:R], :] == label, axis=1) > 0 33 | relevant_num = np.sum(imatch) 34 | 35 | all_sim_num = label_matchs.all_sims[i] 36 | recX.append(float(relevant_num) / all_sim_num) 37 | 
precX.append(float(relevant_num) / R) 38 | 39 | Lx = np.cumsum(imatch) 40 | 41 | Px = Lx.astype(float) / np.arange(1, R + 1, 1) 42 | if relevant_num != 0: 43 | APx.append(np.sum(Px * imatch) / relevant_num) 44 | if verbose > 1: 45 | print(relevant_num, relevant_num, APx[-1]) 46 | if verbose > 0: 47 | print("MAP: %f" % np.mean(np.array(APx))) 48 | print("total time: {:.3f}".format(time.time() - start_time)) 49 | return np.mean(np.array(precX), 0), np.mean(np.array(recX), 0), np.mean(np.array(APx), 0) 50 | 51 | 52 | def mean_average_precision_normal_optimized_label(database_output, database_labels, query_output, query_labels, 53 | R, verbose=0, label_matchs=None): 54 | """ 55 | Optimizing the primary function by calculating the label-similarity matrix in advance. 56 | :param database: 57 | :param query: 58 | :param R: top R 59 | :param verbose: 60 | :param label_matchs: In this optimization, we suppose the test and database lists are fixed, so we only 61 | calculate the test-db label matching relation once and store it in a matrix with space complexity O(db_size * test_size). 
62 | :return: 63 | """ 64 | 65 | query_labels[query_labels < 0] = 0 66 | database_labels[database_labels < 0] = 0 67 | 68 | label_matrix_time = -1 69 | if label_matchs is None: 70 | tmp_time = time.time() 71 | label_matchs = calc_label_match_matrix(database_labels, query_labels) 72 | label_matrix_time = time.time() - tmp_time 73 | print("calc label matrix: time: {:.3f}".format(label_matrix_time)) 74 | 75 | query_num = query_output.shape[0] 76 | 77 | sim = np.dot(database_output, query_output.T) 78 | start_time = time.time() 79 | ids = np.argsort(-sim, axis=0) 80 | end_time = time.time() 81 | sort_time = end_time - start_time 82 | print("total query: {:d}, sorting time: {:.3f}".format(query_num, sort_time)) 83 | APx = [] 84 | precX = [] 85 | recX = [] 86 | 87 | for i in range(query_num): 88 | if i % 100 == 0: 89 | tmp_time = time.time() 90 | print("query map {:d}, time: {:.3f}".format(i, tmp_time - end_time)) 91 | end_time = tmp_time 92 | idx = ids[:, i] 93 | imatch = label_matchs.label_match_matrix[i, idx[0:R]] 94 | relevant_num = np.sum(imatch) 95 | 96 | all_sim_num = label_matchs.all_sims[i] 97 | recX.append(float(relevant_num) / all_sim_num) 98 | precX.append(float(relevant_num) / R) 99 | 100 | 101 | Lx = np.cumsum(imatch) 102 | Px = Lx.astype(float) / np.arange(1, R + 1, 1) 103 | if relevant_num > 0: 104 | APx.append(np.sum(Px * imatch) / relevant_num) 105 | if verbose > 1: 106 | print(relevant_num, relevant_num, APx[-1]) 107 | if verbose > 0: 108 | print("MAP: %f" % np.mean(np.array(APx))) 109 | print("total query: {:d}, sorting time: {:.3f}".format(query_num, sort_time)) 110 | print("total time(no label matrix): {:.3f}".format(time.time() - start_time)) 111 | if label_matrix_time > 0: 112 | print("calc label matrix: time: {:.3f}".format(label_matrix_time)) 113 | return np.mean(np.array(precX), 0), np.mean(np.array(recX), 0), np.mean(np.array(APx), 0) 114 | 115 | 116 | def partition_arg_topK(matrix, K, axis=0): 117 | """ 118 | perform topK based on 
np.argpartition 119 | :param matrix: to be sorted 120 | :param K: select and sort the top K items 121 | :param axis: 0 or 1. dimension to be sorted. 122 | :return: 123 | """ 124 | a_part = np.argpartition(matrix, K, axis=axis) 125 | if axis == 0: 126 | row_index = np.arange(matrix.shape[1 - axis]) 127 | a_sec_argsort_K = np.argsort(matrix[a_part[0:K, :], row_index], axis=axis) 128 | return a_part[0:K, :][a_sec_argsort_K, row_index] 129 | else: 130 | column_index = np.arange(matrix.shape[1 - axis])[:, None] 131 | a_sec_argsort_K = np.argsort(matrix[column_index, a_part[:, 0:K]], axis=axis) 132 | return a_part[:, 0:K][column_index, a_sec_argsort_K] 133 | 134 | 135 | def mean_average_precision_normal_optimized_topK(database_output, database_labels, query_output, query_labels, R, 136 | verbose=0, label_matchs=None): 137 | """ 138 | Optimizing the primary function by calculating the label-similarity matrix in advance. 139 | Furthermore, optimize the topK. 140 | :param database: 141 | :param query: 142 | :param R: top R 143 | :param verbose: 144 | :param label_matchs: In this optimization, we suppose the test and database lists are fixed, so we only 145 | calculate the test-db label matching relation once and store it in a matrix with space complexity O(db_size * test_size). 
146 | :return: 147 | """ 148 | 149 | query_labels[query_labels < 0] = 0 150 | database_labels[database_labels < 0] = 0 151 | 152 | label_matrix_time = -1 153 | if label_matchs is None: 154 | tmp_time = time.time() 155 | label_matchs = calc_label_match_matrix(database_labels, query_labels) 156 | label_matrix_time = time.time() - tmp_time 157 | print("calc label matrix: time: {:.3f}".format(label_matrix_time)) 158 | 159 | query_num = query_output.shape[0] 160 | 161 | sim = -np.dot(query_output, database_output.T) 162 | start_time = time.time() 163 | # ids = np.argsort(-sim, axis=0) 164 | topk_ids = partition_arg_topK(sim, R, axis=1) 165 | end_time = time.time() 166 | sort_time = end_time - start_time 167 | print("total query: {:d}, sorting time: {:.3f}".format(query_num, sort_time)) 168 | # APx = [] 169 | # precX = [] 170 | # recX = [] 171 | 172 | column_index = np.arange(query_num)[:, None] 173 | imatchs = label_matchs.label_match_matrix[column_index, topk_ids] 174 | relevant_nums = np.sum(imatchs, axis=1) 175 | 176 | # all_sim_num = label_matchs.all_sims[i] 177 | recX = relevant_nums.astype(float) / label_matchs.all_sims 178 | precX = relevant_nums.astype(float) / R 179 | 180 | Lxs = np.cumsum(imatchs, axis=1) 181 | Pxs = Lxs.astype(float) / np.arange(1, R + 1, 1) 182 | APxs = np.sum(Pxs * imatchs, axis=1)[relevant_nums > 0] / relevant_nums[relevant_nums > 0] 183 | meanAPxs = np.sum(APxs) / query_num 184 | if verbose > 0: 185 | print("MAP: %f" % meanAPxs) 186 | print("total query: {:d}, sorting time: {:.3f}".format(query_num, sort_time)) 187 | print("total time(no label matrix): {:.3f}".format(time.time() - start_time)) 188 | if label_matrix_time > 0: 189 | print("calc label matrix: time: {:.3f}".format(label_matrix_time)) 190 | return np.mean(np.array(precX), 0), np.mean(np.array(recX), 0), meanAPxs 191 | 192 | 193 | def get_precision_recall_by_Hamming_Radius_All(database_output, database_labels, query_output, query_labels): 194 | signed_query_output = 
np.sign(query_output) 195 | signed_database_output = np.sign(database_output) 196 | 197 | bit_n = signed_query_output.shape[1] 198 | 199 | ips = np.dot(signed_query_output, signed_database_output.T) 200 | ips = (bit_n - ips) / 2 201 | precX = np.zeros((ips.shape[0], bit_n + 1)) 202 | recX = np.zeros((ips.shape[0], bit_n + 1)) 203 | mAPX = np.zeros((ips.shape[0], bit_n + 1)) 204 | 205 | start_time = time.time() 206 | ids = np.argsort(ips, 1) 207 | end_time = time.time() 208 | print("total query: {:d}, sorting {:.3f}".format(ips.shape[0], end_time - start_time)) 209 | for i in range(ips.shape[0]): 210 | if i % 100 == 0: 211 | tmp_time = time.time() 212 | print("query map {:d}, {:.3f}".format(i, tmp_time - end_time)) 213 | end_time = tmp_time 214 | label = query_labels[i, :] 215 | label[label == 0] = -1 216 | 217 | idx = ids[i, :] 218 | imatch = np.sum(database_labels[idx[:], :] == label, 1) > 0 219 | all_sim_num = int(np.sum(imatch)) 220 | 221 | counts = np.bincount(ips[i, :].astype(np.int64)) 222 | 223 | for r in range(bit_n + 1): 224 | if r >= len(counts): 225 | precX[i, r] = precX[i, r - 1] 226 | recX[i, r] = recX[i, r - 1] 227 | mAPX[i, r] = mAPX[i, r - 1] 228 | continue 229 | 230 | all_num = int(np.sum(counts[0:r + 1])) 231 | 232 | if all_num != 0: 233 | match_num = np.sum(imatch[0:all_num]) 234 | precX[i, r] = float(match_num) / all_num 235 | recX[i, r] = float(match_num) / all_sim_num 236 | 237 | rel = match_num 238 | Lx = np.cumsum(imatch[0:all_num]) 239 | Px = Lx.astype(float) / np.arange(1, all_num + 1, 1) 240 | if rel != 0: 241 | mAPX[i, r] = np.sum(Px * imatch[0:all_num]) / rel 242 | print("total time: {:.3f}".format(time.time() - start_time)) 243 | return np.mean(np.array(precX), 0), np.mean(np.array(recX), 0), np.mean(np.array(mAPX), 0) 244 | 245 | 246 | def get_precision_recall_by_Hamming_Radius(database_output, database_labels, query_output, query_labels, radius=2): 247 | signed_query_output = np.sign(query_output) 248 | signed_database_output = 
np.sign(database_output) 249 | 250 | bit_n = signed_query_output.shape[1] 251 | 252 | ips = np.dot(signed_query_output, signed_database_output.T) 253 | ips = (bit_n - ips) / 2 254 | 255 | start_time = time.time() 256 | ids = np.argsort(ips, 1) 257 | end_time = time.time() 258 | sort_time = end_time - start_time 259 | print("total query: {:d}, sorting time: {:.3f}".format(ips.shape[0], sort_time)) 260 | 261 | precX = [] 262 | recX = [] 263 | mAPX = [] 264 | matchX = [] 265 | allX = [] 266 | # query_labels = query.label 267 | # database_labels = database.label 268 | 269 | zero_count = 0 270 | for i in range(ips.shape[0]): 271 | if i % 100 == 0: 272 | tmp_time = time.time() 273 | print("query map {:d}, time: {:.3f}".format(i, tmp_time - end_time)) 274 | end_time = tmp_time 275 | label = query_labels[i, :] 276 | label[label == 0] = -1 277 | idx = np.reshape(np.argwhere(ips[i, :] <= radius), (-1)) 278 | all_num = len(idx) 279 | if all_num != 0: 280 | imatch = np.sum(database_labels[idx[:], :] == label, 1) > 0 281 | match_num = np.sum(imatch) 282 | precX.append(np.float(match_num) / all_num) 283 | matchX.append(match_num) 284 | allX.append(all_num) 285 | all_sim_num = np.sum( 286 | np.sum(database_labels[:, :] == label, 1) > 0) 287 | recX.append(np.float(match_num) / all_sim_num) 288 | if radius < 10: 289 | ips_trad = np.dot( 290 | query_output[i, :], database_output[ids[i, 0:all_num], :].T) 291 | ids_trad = np.argsort(-ips_trad, axis=0) 292 | db_labels = database_labels[ids[i, 0:all_num], :] 293 | 294 | rel = match_num 295 | imatch = np.sum(db_labels[ids_trad, :] == label, 1) > 0 296 | Lx = np.cumsum(imatch) 297 | Px = Lx.astype(float) / np.arange(1, all_num + 1, 1) 298 | if rel != 0: 299 | mAPX.append(np.sum(Px * imatch) / rel) 300 | else: 301 | mAPX.append(np.float(match_num) / all_num) 302 | # print('%d\tret_num:%d\tmatch_num:%d\tall_sim_num:%d\tAP:%f' % (i, all_num, np.int(match_num), all_sim_num, np.float(mAPX[-1]))) 303 | 304 | else: 305 | print('zero: %d, no 
return' % zero_count) 306 | zero_count += 1 307 | precX.append(np.float(0.0)) 308 | recX.append(np.float(0.0)) 309 | mAPX.append(np.float(0.0)) 310 | matchX.append(0.0) 311 | allX.append(0.0) 312 | print("total query: {:d}, sorting time: {:.3f}".format(ips.shape[0], sort_time)) 313 | print("total time: {:.3f}".format(time.time() - start_time)) 314 | return np.mean(np.array(precX)), np.mean(np.array(recX)), np.mean(np.array(mAPX)) 315 | 316 | 317 | class LabelMatchs(object): 318 | def __init__(self, label_match_matrix): 319 | self.label_match_matrix = label_match_matrix 320 | self.all_sims = np.sum(label_match_matrix, axis=1) 321 | 322 | 323 | def calc_label_match_matrix(database_labels, query_labels): 324 | """ 325 | Compute which database items share at least one label with each query. 326 | :param database_labels: N * C multi-hot label matrix of the database 327 | :param query_labels: T * C multi-hot label matrix of the queries 328 | :return: LabelMatchs wrapping a T * N boolean matrix: N for database size and T for query size 329 | 330 | Notes 331 | ----- 332 | Entry (t, n) is True when query t and database item n share at least 333 | one label, i.e. the dot product of their label vectors is positive. 334 | 
335 | 336 | Examples 337 | -------- 338 | query_labels = np.array([[0,1,0], [1,1,0]]) 339 | array([[0, 1, 0], 340 | [1, 1, 0]]) 341 | database_labels = np.array([[1,0,0], [1,1,0], [1,0,1], [0,0,1]]) 342 | array([[1, 0, 0], 343 | [1, 1, 0], 344 | [1, 0, 1], 345 | [0, 0, 1]]) 346 | ret = np.dot(query_labels, database_labels.T) > 0 347 | array([[False, True, False, False], 348 | [ True, True, True, False]]) 349 | """ 350 | return LabelMatchs(np.dot(query_labels, database_labels.T) > 0) 351 | 352 | 353 | def get_precision_recall_by_Hamming_Radius_optimized(database_output, database_labels, query_output, query_labels, 354 | radius=2, label_matchs=None, coarse_sign=True, fine_sign=False): 355 | """ 356 | 357 | :param database: 358 | :param query: 359 | :param radius: 360 | :param label_match_matrix: In this optimization, we suppose the test and database lists are fixed, so we only 361 | calculate the test-db label matching relation once and store it in a matrix with space complexity O(db_size * test_size). 362 | :return: 363 | """ 364 | # query_output = query.output 365 | # database_output = database.output 366 | # query_labels = query.label 367 | # database_labels = database.label 368 | # prevent impact from other measure function 369 | query_labels[query_labels < 0] = 0 370 | database_labels[database_labels < 0] = 0 371 | bit_n = query_output.shape[1] # i.e. 
K 372 | 373 | coarse_query_output = np.sign(query_output) 374 | coarse_database_output = np.sign(database_output) 375 | 376 | fine_query_output = coarse_query_output if fine_sign else query_output 377 | fine_database_output = coarse_database_output if fine_sign else database_output 378 | 379 | label_matrix_time = -1 380 | if label_matchs is None: 381 | tmp_time = time.time() 382 | label_matchs = calc_label_match_matrix(database_labels, query_labels) 383 | label_matrix_time = time.time() - tmp_time 384 | print("calc label matrix: time: {:.3f}".format(label_matrix_time)) 385 | 386 | start_time = time.time() 387 | 388 | ips = np.dot(coarse_query_output, coarse_database_output.T) 389 | ips = (bit_n - ips) / 2 390 | ids = np.argsort(ips, 1) 391 | end_time = time.time() 392 | sort_time = end_time - start_time 393 | print("total query: {:d}, sorting time: {:.3f}".format(ips.shape[0], sort_time)) 394 | all_nums = np.sum(ips <= radius, axis=1) 395 | precX = [] 396 | recX = [] 397 | mAPX = [] 398 | matchX = [] 399 | allX = [] 400 | 401 | for i in range(ips.shape[0]): 402 | if i % 100 == 0: 403 | tmp_time = time.time() 404 | print("query map {:d}, time: {:.3f}".format(i, tmp_time - end_time)) 405 | end_time = tmp_time 406 | all_num = all_nums[i] 407 | 408 | if all_num != 0: 409 | idx = ids[i, 0:all_num] 410 | if fine_sign: 411 | imatch = label_matchs.label_match_matrix[i, idx[:]] 412 | else: 413 | ips_continue = np.dot(fine_query_output[i, :], fine_database_output[idx, :].T) 414 | subset_idx = np.argsort(-ips_continue, axis=0) 415 | idx_continue = idx[subset_idx] 416 | imatch = label_matchs.label_match_matrix[i, idx_continue] 417 | 418 | match_num = int(np.sum(imatch)) 419 | matchX.append(match_num) 420 | allX.append(all_num) 421 | precX.append(np.float(match_num) / all_num) 422 | all_sim_num = label_matchs.all_sims[i] 423 | recX.append(np.float(match_num) / all_sim_num) 424 | 425 | Lx = np.cumsum(imatch) 426 | Px = Lx.astype(float) / np.arange(1, all_num + 1, 1) 427 | if 
match_num != 0: 428 | mAPX.append(np.sum(Px * imatch) / match_num) 429 | else: 430 | mAPX.append(0) 431 | 432 | print("total query: {:d}, sorting time: {:.3f}".format(ips.shape[0], sort_time)) 433 | print("total time(no label matrix): {:.3f}".format(time.time() - start_time)) 434 | if label_matrix_time > 0: 435 | print("calc label matrix: time: {:.3f}".format(label_matrix_time)) 436 | meanPrecX = 0 if len(precX) == 0 else np.mean(np.array(precX)) 437 | meanRecX = 0 if len(recX) == 0 else np.mean(np.array(recX)) 438 | meanMAPX = 0 if len(mAPX) == 0 else np.mean(np.array(mAPX)) 439 | return meanPrecX, meanRecX, meanMAPX, label_matchs 440 | -------------------------------------------------------------------------------- /src/evaluate/optimize_metric_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | For both Hamming Space Retrieval (MAP@H<=2) and Ranking Retrieval (MAP@TopK), 3 | we carefully optimize the measurement functions. 4 | In general, this speeds them up by a factor of O(logK)/O(logN). 5 | 6 | This script provides some examples that demonstrate their efficiency. 
7 | """ 8 | import argparse 9 | import pickle 10 | 11 | import numpy as np 12 | import os 13 | import sys 14 | import time 15 | 16 | sys.path.append('..') 17 | sys.path.append(os.path.abspath('../valid')) 18 | 19 | from evaluate.measure_utils \ 20 | import get_precision_recall_by_Hamming_Radius, \ 21 | get_precision_recall_by_Hamming_Radius_optimized, mean_average_precision_normal, \ 22 | mean_average_precision_normal_optimized_label, mean_average_precision_normal_optimized_topK 23 | 24 | 25 | if __name__ == "__main__": 26 | print("sync sync" + "*" * 20) 27 | print(os.getcwd()) 28 | parser = argparse.ArgumentParser(description='check-result') 29 | parser.add_argument('--dir_name', type=str, default='../output/coco_48bit', 30 | help="directory holding test_code.npy, test_labels.npy, database_code.npy and database_labels.npy") 31 | parser.add_argument('--format', type=str, default="default", help="default, ADSH") 32 | parser.add_argument('--test_radius', type=str, default="True", help="'True' to run the Hamming-radius tests") 33 | parser.add_argument('--R', type=int, default=0, help="recall@R; 0 skips this test") 34 | parser.add_argument('--verbose', type=int, default=1, help="verbose level") 35 | parser.add_argument('--sample_ratio', type=float, default=1.0, help="fraction of the query and database sets to evaluate") 36 | print(os.getcwd()) 37 | args = parser.parse_args() 38 | dir_name = args.dir_name 39 | R = args.R 40 | test_radius = args.test_radius == 'True' 41 | 42 | ## ========= special for local PyCharm test start ========= 43 | # dir_name = '../../test_output/models_coco_48_mmhh_seen' 44 | # test_radius = True 45 | # R = 5000 46 | ## ========= special for local PyCharm test end ========= 47 | 48 | print('valid file: %s' % dir_name) 49 | verbose = args.verbose 50 | if args.format == 'default': 51 | query_output = np.load(dir_name + '/test_code.npy') 52 | query_labels = np.load(dir_name + '/test_labels.npy') 53 | database_output = np.load(dir_name + '/database_code.npy') 54 | database_labels = np.load(dir_name + '/database_labels.npy') 55 | elif args.format == 'ADSH': 56 | query_output, query_labels, 
database_output, database_labels = load_ADSH(dir_name) 57 | else: 58 | raise NotImplementedError 59 | if args.sample_ratio < 1: 60 | query_len = int(query_output.shape[0] * args.sample_ratio) 61 | base_len = int(database_output.shape[0] * args.sample_ratio) 62 | # query_len = min(3, query_labels.shape[0]) 63 | # base_len = min(5, database_labels.shape[0]) 64 | # query_len = min(20, query_labels.shape[0]) 65 | # base_len = min(100, database_labels.shape[0]) 66 | query_output = query_output[:query_len] 67 | query_labels = query_labels[:query_len] 68 | database_output = database_output[:base_len] 69 | database_labels = database_labels[:base_len] 70 | print("sample %.2f%% query and base, query len: %d, base len: %d" % 71 | (args.sample_ratio * 100, query_len, base_len)) 72 | # database_output = database_output[:50, :20] 73 | # database_labels = database_labels[:50, :20] 74 | # query_output = query_output[:30, :20] 75 | # query_labels = query_labels[:30, :20] 76 | 77 | output_dim = query_output.shape[1] 78 | if test_radius: 79 | line_prec = [] 80 | line_rec = [] 81 | line_mmap = [] 82 | line_time = [] 83 | # prec, rec, mmap = get_precision_recall_by_Hamming_Radius_All(img_database, img_query) 84 | # for i in range(output_dim + 1): 85 | # print('Results ham dist [%d], prec:%s, rec:%s, mAP:%s' % (i, prec[i], rec[i], mmap[i])) 86 | 87 | print("test target radius") 88 | start = time.time() 89 | prec, rec, mmap = get_precision_recall_by_Hamming_Radius( 90 | database_output, database_labels, query_output, query_labels, 2) 91 | end = time.time() 92 | line_prec.append(prec) 93 | line_rec.append(rec) 94 | line_mmap.append(mmap) 95 | line_time.append(end - start) 96 | 97 | print("test radius refine") 98 | start = time.time() 99 | prec, rec, mmap, label_matchs = get_precision_recall_by_Hamming_Radius_optimized( 100 | database_output, database_labels, query_output, query_labels, 2) 101 | end = time.time() 102 | line_prec.append(prec) 103 | line_rec.append(rec) 104 | 
line_mmap.append(mmap) 105 | line_time.append(end - start) 106 | 107 | time.sleep(0.1) 108 | print("rate\titem\tstd \trefine") 109 | print("rate %.2f\t" % args.sample_ratio + "prec\t" + "\t".join(["%.4f" % l for l in line_prec])) 110 | print("rate %.2f\t" % args.sample_ratio + "recall\t" + "\t".join(["%.4f" % l for l in line_rec])) 111 | print("rate %.2f\t" % args.sample_ratio + "mmap\t" + "\t".join(["%.4f" % l for l in line_mmap])) 112 | print("rate %.2f\t" % args.sample_ratio + "time\t" + "\t".join(["%.4f" % l for l in line_time])) 113 | 114 | print("\n") 115 | if R > 0: 116 | R = int(args.sample_ratio * R) 117 | line_prec = [] 118 | line_rec = [] 119 | line_mmap = [] 120 | line_time = [] 121 | 122 | print("test linear mAP") 123 | start = time.time() 124 | prec_norm, rec_norm, mmap_norm = mean_average_precision_normal( 125 | database_output, database_labels, query_output, query_labels, R) 126 | end = time.time() 127 | 128 | line_prec.append(prec_norm) 129 | line_rec.append(rec_norm) 130 | line_mmap.append(mmap_norm) 131 | line_time.append(end - start) 132 | 133 | print("test linear mAP refine label") 134 | start = time.time() 135 | prec_norm, rec_norm, mmap_norm = mean_average_precision_normal_optimized_label( 136 | database_output, database_labels, query_output, query_labels, R) 137 | end = time.time() 138 | print(end - start) 139 | line_prec.append(prec_norm) 140 | line_rec.append(rec_norm) 141 | line_mmap.append(mmap_norm) 142 | line_time.append(end - start) 143 | 144 | print("test linear mAP huge refine topK") 145 | start = time.time() 146 | prec_norm, rec_norm, mmap_norm = mean_average_precision_normal_optimized_topK( 147 | database_output, database_labels, query_output, query_labels, R) 148 | end = time.time() 149 | line_prec.append(prec_norm) 150 | line_rec.append(rec_norm) 151 | line_mmap.append(mmap_norm) 152 | line_time.append(end - start) 153 | 154 | time.sleep(0.1) 155 | 156 | print("rate item\tstd \tfaster\tfastest") 157 | print("rate %.2f\t" % 
args.sample_ratio + "prec\t" + "\t".join(["%.4f" % l for l in line_prec])) 158 | print("rate %.2f\t" % args.sample_ratio + "recall\t" + "\t".join(["%.4f" % l for l in line_rec])) 159 | print("rate %.2f\t" % args.sample_ratio + "mmap\t" + "\t".join(["%.4f" % l for l in line_mmap])) 160 | print("rate %.2f\t" % args.sample_ratio + "time\t" + "\t".join(["%.4f" % l for l in line_time])) 161 | -------------------------------------------------------------------------------- /src/mmhh.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from common.mmhh_config import ImageLossType, DistanceType, Mission, global_debugging, BatchType 3 | from mmhh_loss import * 4 | import torch 5 | 6 | 7 | class MMHH(object): 8 | def __init__(self, feature_net, hash_bit, 9 | trade_off=1.0, loss_lambda=0.1, use_gpu=True, distance_type=DistanceType.Hamming, similar_weight=1., 10 | image_loss_type=ImageLossType.HashNet, mission=Mission.Hashing, radius=2, gamma=1.0, 11 | sigmoid_param=1.0): 12 | self.debug = global_debugging 13 | self.gamma = gamma 14 | self.hash_bit = hash_bit 15 | self.trade_off = trade_off 16 | self.loss_lambda = loss_lambda 17 | self.use_gpu = use_gpu 18 | self.is_train = False 19 | self.distance_type = distance_type 20 | self.similar_weight = similar_weight 21 | self.image_loss_type = image_loss_type 22 | self.mission = mission 23 | self.radius = radius 24 | self.sigmoid_param = sigmoid_param 25 | self.iter_num = 0 26 | if mission == Mission.Hashing: 27 | self.feature_network = feature_net 28 | 29 | if self.use_gpu: 30 | self.feature_network = self.feature_network.cuda() 31 | 32 | def get_semi_hash_loss(self, logger, semi_batch, inputs_batch, labels_batch, iter_num, 33 | batch_type, batch_params, margin_params=None): 34 | self.iter_num = iter_num 35 | hash_batch = self.feature_network(inputs_batch) 36 | 37 | sigmoid_param = self.sigmoid_param 38 | if batch_type == BatchType.PairBatch: 39 | hash2, labels2 = hash_batch, 
labels_batch 40 | elif batch_type == BatchType.BatchInitMem: 41 | if iter_num < int(batch_params): 42 | hash2, labels2 = hash_batch, labels_batch 43 | else: 44 | hash2, labels2 = semi_batch.aug_memory, semi_batch.labels 45 | else: 46 | raise NotImplementedError("Wrong BatchType: " + str(batch_type)) 47 | 48 | if self.distance_type == DistanceType.MMHH: 49 | hash_loss = mmhh_loss(hash_batch, hash2, labels_batch, labels2, margin_params, 50 | gamma=self.gamma, similar_weight=self.similar_weight) 51 | elif self.distance_type is DistanceType.Hamming: 52 | hash_loss = pairwise_loss(hash_batch, hash2, labels_batch, labels2, 53 | sigmoid_param=sigmoid_param, similar_weight=self.similar_weight) 54 | else: 55 | raise NotImplementedError 56 | # quantization loss 57 | if self.loss_lambda > 0: 58 | q_loss = quantization_loss(hash_batch) 59 | else: 60 | q_loss = 0 # plain number when the quantization term is disabled 61 | if iter_num % 1 == 0: 62 | logger.info("Iter %05d hash loss %.5f, quan loss %.5f" % (iter_num, hash_loss.item(), float(q_loss))) 63 | return hash_loss + self.loss_lambda * q_loss, hash_batch.data 64 | 65 | def predict(self, inputs): 66 | return self.feature_network(inputs) 67 | 68 | def get_parameter_list(self): 69 | if self.mission in [Mission.Cross_Modal_Transfer, Mission.Cross_Domain_Transfer]: 70 | print("transfer, parameter involves d_net") 71 | return self.feature_network.get_parameter_list() 72 | elif self.mission == Mission.Hashing: 73 | return self.feature_network.get_parameter_list() 74 | else: 75 | raise NotImplementedError 76 | 77 | def set_train(self, mode): 78 | self.feature_network.train(mode) 79 | self.is_train = mode 80 | -------------------------------------------------------------------------------- /src/mmhh_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | from common.mmhh_config import MarginParams 5 | 6 | 7 | def pairwise_loss(outputs1, outputs2, label1, label2, sigmoid_param=1.0, 
l_threshold=15.0, similar_weight=1.0): 8 | """ 9 | Refer to https://github.com/thuml/HashNet 10 | :param outputs1: 11 | :param outputs2: 12 | :param label1: 13 | :param label2: 14 | :param sigmoid_param: 15 | :param l_threshold: 16 | :param similar_weight: 17 | :return: 18 | """ 19 | if similar_weight == "auto": 20 | similar_weight = calc_similar_rate(label1, label2) 21 | assert similar_weight >= 0 22 | 23 | similarity = Variable(torch.mm(label1.data.float(), label2.data.float().t()) > 0).float() 24 | dot_product = sigmoid_param * torch.mm(outputs1, outputs2.t()) 25 | exp_product = torch.exp(dot_product) 26 | mask_dot = dot_product.data > l_threshold 27 | mask_exp = dot_product.data <= l_threshold 28 | mask_positive = similarity.data > 0 29 | mask_negative = similarity.data <= 0 30 | mask_dp = mask_dot & mask_positive 31 | mask_dn = mask_dot & mask_negative 32 | mask_ep = mask_exp & mask_positive 33 | mask_en = mask_exp & mask_negative 34 | 35 | dot_loss = dot_product * (1 - similarity) 36 | exp_loss = (torch.log(1 + exp_product) - similarity * dot_product) 37 | loss = (torch.sum(torch.masked_select(exp_loss, Variable(mask_ep))) + torch.sum( 38 | torch.masked_select(dot_loss, Variable(mask_dp)))) * similar_weight + torch.sum( 39 | torch.masked_select(exp_loss, Variable(mask_en))) + torch.sum(torch.masked_select(dot_loss, Variable(mask_dn))) 40 | 41 | return loss / (torch.sum(mask_positive.float()) * similar_weight + torch.sum(mask_negative.float())) 42 | 43 | 44 | def quantization_loss(outputs): 45 | return torch.sum(torch.log(torch.cosh(torch.abs(outputs) - 1))) / outputs.size(0) 46 | 47 | 48 | def mmhh_loss(outputs1, outputs2, label1, label2, margin_params: MarginParams = None, gamma=1.0, 49 | similar_weight=1.0): 50 | if similar_weight == "auto": 51 | _, _, similar_weight = calc_similar_rate_triplet(label1, label2) 52 | # calculate similarity 53 | similarity = Variable(torch.mm(label1.data.float(), label2.data.float().t()) > 0).float() 54 | weight_similarity = 
torch.abs(similarity - 1.0) + similarity * similar_weight 55 | 56 | dist_ham = calc_ham_dist(outputs1, outputs2) 57 | if margin_params: 58 | weight_similarity = margin_reweight(similarity, weight_similarity, dist_ham, 59 | margin_params.sim_in, margin_params.dis_in, 60 | margin_params.sim_out, margin_params.dis_out, margin_params.margin) 61 | alpha = 0.5 62 | probs = alpha * gamma / (dist_ham + gamma) 63 | 64 | loss_matrix = weight_similarity * (similarity * torch.log(1.0 / probs) + 65 | (1 - similarity) * torch.log(1.0 / (1.0 - probs))) 66 | return loss_matrix.mean() 67 | 68 | 69 | def calc_similar_rate(label1, label2): 70 | similar_sum = torch.sum(torch.mm(label1.data.float(), label2.data.float().t()) > 0) 71 | dis_sum = label1.shape[0] * label2.shape[0] - similar_sum 72 | return float(dis_sum) / float(similar_sum) 73 | 74 | 75 | def calc_similar_rate_triplet(label1, label2): 76 | similar_sum = torch.sum(torch.mm(label1.data.float(), label2.data.float().t()) > 0) 77 | dis_sum = label1.shape[0] * label2.shape[0] - similar_sum 78 | return similar_sum, dis_sum, float(dis_sum) / float(similar_sum) 79 | 80 | 81 | def calc_ham_dist(outputs1, outputs2): 82 | ip = torch.mm(outputs1, outputs2.t()) 83 | mod = torch.mm((outputs1 ** 2).sum(dim=1).reshape(-1, 1), (outputs2 ** 2).sum(dim=1).reshape(1, -1)) 84 | cos = ip / mod.sqrt() 85 | hash_bit = outputs1.shape[1] 86 | dist_ham = hash_bit / 2.0 * (1.0 - cos) 87 | return dist_ham 88 | 89 | 90 | def margin_reweight(similarity, weight_similarity, dist_hum, in_sim_weight, in_dis_weight, out_sim_weight, 91 | out_dis_weight, margin): 92 | mask_in = dist_hum.data <= margin 93 | mask_out = dist_hum.data > margin 94 | mask_sim = similarity.data > 0 95 | mask_dis = similarity.data <= 0 96 | mask_in_sim = mask_in & mask_sim 97 | mask_in_dis = mask_in & mask_dis 98 | mask_out_sim = mask_out & mask_sim 99 | mask_out_dis = mask_out & mask_dis 100 | weight_similarity[mask_in_sim] *= in_sim_weight 101 | weight_similarity[mask_in_dis] *= 
in_dis_weight 102 | weight_similarity[mask_out_sim] *= out_sim_weight 103 | weight_similarity[mask_out_dis] *= out_dis_weight 104 | return weight_similarity 105 | -------------------------------------------------------------------------------- /src/mmhh_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | 4 | import torch 5 | 6 | import torch.nn as nn 7 | from torchvision import models 8 | import math 9 | import torch.nn.functional as F 10 | 11 | """ 12 | Refer to https://github.com/thuml/HashNet 13 | """ 14 | 15 | 16 | class AlexNetFc(nn.Module): 17 | def __init__(self, logger, hash_bit, increase_scale=True): 18 | """ 19 | :param hash_bit: output hash bit 20 | :param increase_scale: if the scale increase gradually. True for HashNet, False for DHN 21 | """ 22 | super(AlexNetFc, self).__init__() 23 | model_alexnet = models.alexnet(pretrained=True) 24 | self.features = model_alexnet.features 25 | self.classifier = nn.Sequential() 26 | for i in range(6): 27 | self.classifier.add_module("classifier" + str(i), model_alexnet.classifier[i]) 28 | self.feature_layers = nn.Sequential(self.features, self.classifier) 29 | 30 | hash_layer = nn.Linear(model_alexnet.classifier[6].in_features, hash_bit) 31 | hash_layer.weight.data.normal_(0, 0.01) 32 | hash_layer.bias.data.fill_(0.0) 33 | self.hash_layer = hash_layer 34 | self.__in_features = hash_bit 35 | self.activation = nn.Tanh() 36 | 37 | # HashNet part 38 | self.iter_num = 0 39 | self.step_size = 200 40 | self.gamma = 0.005 41 | self.power = 0.5 42 | self.init_scale = 1.0 43 | self.scale = self.init_scale 44 | self.increase_scale = increase_scale 45 | logger.info("increase_scale is %s" % increase_scale) 46 | 47 | def forward(self, x): 48 | if self.training: 49 | self.iter_num += 1 50 | x = self.features(x) 51 | x = x.view(x.size(0), 256 * 6 * 6) 52 | x = self.classifier(x) 53 | y = self.hash_layer(x) 54 | if 
self.increase_scale and self.iter_num % self.step_size == 0: 55 | self.scale = self.init_scale * (math.pow((1. + self.gamma * self.iter_num), self.power)) 56 | y = self.activation(self.scale * y) 57 | return y 58 | 59 | def get_parameter_list(self): 60 | return [{"params": self.feature_layers.parameters(), "lr": 1}, 61 | {"params": self.hash_layer.parameters(), "lr": 10}] 62 | 63 | def output_num(self): 64 | return self.__in_features 65 | 66 | 67 | resnet_dict = {"ResNet18": models.resnet18, "ResNet34": models.resnet34, "ResNet50": models.resnet50, 68 | "ResNet101": models.resnet101, "ResNet152": models.resnet152} 69 | 70 | 71 | class ResNetFc(nn.Module): 72 | def __init__(self, logger, name, hash_bit, increase_scale=True): 73 | super(ResNetFc, self).__init__() 74 | model_resnet = resnet_dict[name](pretrained=True) 75 | conv1 = model_resnet.conv1 76 | bn1 = model_resnet.bn1 77 | relu = model_resnet.relu 78 | maxpool = model_resnet.maxpool 79 | layer1 = model_resnet.layer1 80 | layer2 = model_resnet.layer2 81 | layer3 = model_resnet.layer3 82 | layer4 = model_resnet.layer4 83 | avgpool = model_resnet.avgpool 84 | self.feature_layers = nn.Sequential(conv1, bn1, relu, maxpool, 85 | layer1, layer2, layer3, layer4, avgpool) 86 | hash_layer = nn.Linear(model_resnet.fc.in_features, hash_bit) 87 | hash_layer.weight.data.normal_(0, 0.01) 88 | hash_layer.bias.data.fill_(0.0) 89 | self.hash_layer = hash_layer 90 | self.__in_features = hash_bit 91 | self.activation = nn.Tanh() 92 | 93 | # HashNet part 94 | self.iter_num = 0 95 | self.step_size = 200 96 | self.gamma = 0.005 97 | self.power = 0.5 98 | self.init_scale = 1.0 99 | self.scale = self.init_scale 100 | self.increase_scale = increase_scale 101 | logger.info("increase_scale is %s" % increase_scale) 102 | 103 | def forward(self, x): 104 | if self.training: 105 | self.iter_num += 1 106 | x = self.feature_layers(x) 107 | x = x.view(x.size(0), -1) 108 | y = self.hash_layer(x) 109 | if self.increase_scale and self.iter_num % 
self.step_size == 0: 110 | self.scale = self.init_scale * (math.pow((1. + self.gamma * self.iter_num), self.power)) 111 | y = self.activation(self.scale * y) 112 | return y 113 | 114 | def get_parameter_list(self): 115 | return [{"params": self.feature_layers.parameters(), "lr": 1}, 116 | {"params": self.hash_layer.parameters(), "lr": 10}] 117 | 118 | def output_num(self): 119 | return self.__in_features 120 | 121 | 122 | vgg_dict = {"VGG11": models.vgg11, "VGG13": models.vgg13, "VGG16": models.vgg16, "VGG19": models.vgg19, 123 | "VGG11BN": models.vgg11_bn, "VGG13BN": models.vgg13_bn, "VGG16BN": models.vgg16_bn, 124 | "VGG19BN": models.vgg19_bn} 125 | 126 | 127 | class VGGFc(nn.Module): 128 | def __init__(self, name, hash_bit, increase_scale=True): 129 | super(VGGFc, self).__init__() 130 | model_vgg = vgg_dict[name](pretrained=True) 131 | self.features = model_vgg.features 132 | self.classifier = nn.Sequential() 133 | for i in range(6): 134 | self.classifier.add_module("classifier" + str(i), model_vgg.classifier[i]) 135 | self.feature_layers = nn.Sequential(self.features, self.classifier) 136 | 137 | self.hash_layer = nn.Linear(model_vgg.classifier[6].in_features, hash_bit) 138 | self.hash_layer.weight.data.normal_(0, 0.01) 139 | self.hash_layer.bias.data.fill_(0.0) 140 | self.iter_num = 0 141 | self.__in_features = hash_bit 142 | self.step_size = 200 143 | self.gamma = 0.005 144 | self.power = 0.5 145 | self.init_scale = 1.0 146 | self.activation = nn.Tanh() 147 | self.scale = self.init_scale 148 | self.increase_scale = increase_scale 149 | 150 | def forward(self, x): 151 | if self.training: 152 | self.iter_num += 1 153 | x = self.features(x) 154 | x = x.view(x.size(0), 25088) 155 | x = self.classifier(x) 156 | y = self.hash_layer(x) 157 | if self.increase_scale and self.iter_num % self.step_size == 0: 158 | self.scale = self.init_scale * (math.pow((1. 
+ self.gamma * self.iter_num), self.power)) 159 | y = self.activation(self.scale * y) 160 | return y 161 | 162 | def output_num(self): 163 | return self.__in_features 164 | 165 | 166 | def weights_init(m): 167 | classname = m.__class__.__name__ 168 | if classname.find('Conv3d') != -1: 169 | std = np.float( 170 | np.sqrt(6. / (m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * (m.in_channels + m.out_channels)))) 171 | m.weight.data.normal_(0.0, std) 172 | elif classname.find('Linear') != -1: 173 | std = np.float(np.sqrt(1. / (m.in_features * m.out_features))) 174 | m.weight.data.normal_(0.0, std) 175 | elif classname.find('BatchNorm') != -1: 176 | m.weight.data.normal_(1.0, 0.02) 177 | m.bias.data.fill_(0) 178 | 179 | 180 | class VoxNet(nn.Module): 181 | 182 | def __init__(self, logger, hash_bit, input_shape=(32, 32, 32), n_channels=1): 183 | super(VoxNet, self).__init__() 184 | self.body = torch.nn.Sequential(OrderedDict([ 185 | ('conv1', torch.nn.Conv3d(in_channels=n_channels, 186 | out_channels=32, kernel_size=5, stride=2)), 187 | ('lkrelu1', torch.nn.LeakyReLU()), 188 | ('drop1', torch.nn.Dropout(p=0.2)), 189 | ('conv2', torch.nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3)), 190 | ('lkrelu2', torch.nn.LeakyReLU()), 191 | ('pool2', torch.nn.MaxPool3d(2)), 192 | ('drop2', torch.nn.Dropout(p=0.3)) 193 | ])) 194 | 195 | # Trick to accept different input shapes 196 | x = self.body(torch.autograd.Variable( 197 | torch.rand((1, 1) + input_shape))) 198 | first_fc_in_features = 1 199 | for n in x.size()[1:]: 200 | first_fc_in_features *= n 201 | 202 | self.head = torch.nn.Sequential(OrderedDict([ 203 | ('fc1', torch.nn.Linear(first_fc_in_features, 128)), 204 | ('relu1', torch.nn.ReLU()), 205 | ('drop3', torch.nn.Dropout(p=0.4)), 206 | ])) 207 | # self.fc = torch.nn.Linear(128, num_classes) 208 | self.fc = nn.Linear(128, hash_bit) 209 | # self.fc.weight.data.normal_(0, np.float(np.sqrt(1. 
/ (self.fc.in_features * self.fc.out_features)))) 210 | self.apply(weights_init) 211 | self.fc.bias.data.fill_(0.0) 212 | 213 | # def feature_parameters(self): 214 | # return [self.body.parameters(), self.head.parameters()] 215 | def get_parameter_list(self): 216 | voxnet_lr = 0.5 217 | return [{"params": self.body.parameters(), "lr": voxnet_lr}, 218 | {"params": self.head.parameters(), "lr": voxnet_lr}, 219 | {"params": self.fc.parameters(), "lr": voxnet_lr}] 220 | 221 | def forward(self, x): 222 | x = self.body(x) 223 | x = x.view(x.size(0), -1) 224 | x = self.head(x) 225 | y = self.fc(x) 226 | output = nn.Tanh()(y) 227 | return output 228 | 229 | 230 | class Inception3dLayer(nn.Module): 231 | def __init__(self, param_dict): 232 | super(Inception3dLayer, self).__init__() 233 | self.branch = [] 234 | for i, dictionary in enumerate(param_dict): 235 | temp_layer_list = [] 236 | for j, layer in enumerate(dictionary['layers']): 237 | temp_layer_list.append(layer(**dictionary['layer_params'][j])) 238 | if not (dictionary['activation'][j] is None): 239 | temp_layer_list.append(dictionary['activation'][j]()) 240 | if dictionary["bnorm"][j]: 241 | if "out_channels" in dictionary["layer_params"][j]: 242 | temp_layer_list.append(nn.BatchNorm3d(dictionary["layer_params"][j]["out_channels"])) 243 | temp_channels = dictionary["layer_params"][j]["out_channels"] 244 | else: 245 | temp_layer_list.append(nn.BatchNorm3d(temp_channels)) 246 | self.branch.append(nn.Sequential(*temp_layer_list)) 247 | for i in range(len(self.branch)): 248 | exec("self.branch" + str(i) + "=self.branch[i]") 249 | 250 | def forward(self, x): 251 | output = [] 252 | for branch in self.branch: 253 | output.append(branch(x)) 254 | output = torch.cat(output, dim=1) 255 | return output 256 | 257 | 258 | class VRNBasicBlock1(nn.Module): 259 | def __init__(self, in_channels, drop_rate): 260 | super(VRNBasicBlock1, self).__init__() 261 | inception_dict = [{"layers": [nn.Conv3d, nn.Conv3d, nn.Conv3d], \ 262 | 
"layer_params": [{"in_channels": in_channels, "out_channels": in_channels / 4, \ 263 | "kernel_size": 1, "stride": 1, "padding": 0}, \ 264 | {"in_channels": in_channels / 4, "out_channels": in_channels / 4, \ 265 | "kernel_size": 3, "stride": 1, "padding": 1}, \ 266 | {"in_channels": in_channels / 4, "out_channels": in_channels / 2, \ 267 | "kernel_size": 1, "stride": 1, "padding": 0}], \ 268 | "activation": [nn.ELU, nn.ELU, None], \ 269 | "bnorm": [True, True, False]}, \ 270 | {"layers": [nn.Conv3d, nn.Conv3d], \ 271 | "layer_params": [{"in_channels": in_channels, "out_channels": in_channels / 4, \ 272 | "kernel_size": 3, "stride": 1, "padding": 1}, \ 273 | {"in_channels": in_channels / 4, "out_channels": in_channels / 2, \ 274 | "kernel_size": 3, "stride": 1, "padding": 1}], \ 275 | "activation": [nn.ELU, None], \ 276 | "bnorm": [True, False]} 277 | ] 278 | self.bn = nn.BatchNorm3d(in_channels) 279 | self.activation = nn.ELU() 280 | self.inception = Inception3dLayer(inception_dict) 281 | self.drop = nn.Dropout(drop_rate) 282 | 283 | def forward(self, x): 284 | x1 = self.drop(self.inception(self.activation(self.bn(x)))) 285 | x = x + x1 286 | return x 287 | 288 | 289 | class VRNBasicBlock2(nn.Module): 290 | def __init__(self, in_channels): 291 | super(VRNBasicBlock2, self).__init__() 292 | inception_dict = [ 293 | {"layers": [nn.Conv3d], \ 294 | "layer_params": [{"in_channels": in_channels, "out_channels": in_channels / 2, \ 295 | "kernel_size": 3, "stride": 2, "padding": 1}], \ 296 | "activation": [None], \ 297 | "bnorm": [True]}, \ 298 | {"layers": [nn.Conv3d], \ 299 | "layer_params": [{"in_channels": in_channels, "out_channels": in_channels / 2, \ 300 | "kernel_size": 1, "stride": 2, "padding": 0}], \ 301 | "activation": [None], \ 302 | "bnorm": [True]}, \ 303 | {"layers": [nn.Conv3d, nn.MaxPool3d], \ 304 | "layer_params": [{"in_channels": in_channels, "out_channels": in_channels / 2, \ 305 | "kernel_size": 3, "stride": 1, "padding": 1}, \ 306 | {"kernel_size": 
3, "stride": 2, "padding": 1}], \ 307 | "activation": [None, None], \ 308 | "bnorm": [False, True]}, \ 309 | {"layers": [nn.Conv3d, nn.AvgPool3d], \ 310 | "layer_params": [{"in_channels": in_channels, "out_channels": in_channels // 2, \ 311 | "kernel_size": 3, "stride": 1, "padding": 1}, \ 312 | {"kernel_size": 3, "stride": 2, "padding": 1, "count_include_pad": True}], \ 313 | "activation": [None, None], \ 314 | "bnorm": [False, True]} 315 | ] 316 | 317 | self.bn = nn.BatchNorm3d(in_channels) 318 | self.activation = nn.ELU() 319 | self.inception = Inception3dLayer(inception_dict) 320 | 321 | def forward(self, x): 322 | x = self.inception(self.activation(self.bn(x))) 323 | return x 324 | 325 | 326 | class VRN(nn.Module): 327 | def __init__(self, num_classes, input_shape=(32, 32, 32), n_channels=1): 328 | super(VRN, self).__init__() 329 | self.conv0 = nn.Conv3d(n_channels, 32, (3, 3, 3), (1, 1, 1), (1, 1, 1), bias=False) 330 | 331 | block1 = VRNBasicBlock1(32, 0.05) 332 | block2 = VRNBasicBlock1(32, 0.1) 333 | block3 = VRNBasicBlock1(32, 0.2) 334 | block4 = VRNBasicBlock2(32) 335 | self.block_chain1 = nn.Sequential(block1, block2, block3, block4) 336 | 337 | block5 = VRNBasicBlock1(64, 0.3) 338 | block6 = VRNBasicBlock1(64, 0.4) 339 | block7 = VRNBasicBlock1(64, 0.5) 340 | block8 = VRNBasicBlock2(64) 341 | self.block_chain2 = nn.Sequential(block5, block6, block7, block8) 342 | 343 | block9 = VRNBasicBlock1(128, 0.5) 344 | block10 = VRNBasicBlock1(128, 0.55) 345 | block11 = VRNBasicBlock1(128, 0.6) 346 | block12 = VRNBasicBlock2(128) 347 | self.block_chain3 = nn.Sequential(block9, block10, block11, block12) 348 | 349 | block13 = VRNBasicBlock1(256, 0.65) 350 | block14 = VRNBasicBlock1(256, 0.7) 351 | block15 = VRNBasicBlock1(256, 0.75) 352 | block16 = VRNBasicBlock2(256) 353 | self.block_chain4 = nn.Sequential(block13, block14, block15, block16) 354 | 355 | self.conv17 = nn.Conv3d(512, 512, 3, stride=1, padding=1) 356 | self.bn17 = nn.BatchNorm3d(512) 357 | self.drop17 
= nn.Dropout(0.5) 358 | self.activation17 = nn.ELU() 359 | 360 | self.bn18 = nn.BatchNorm1d(512) 361 | self.linear19 = nn.Linear(512, 512) 362 | self.bn19 = nn.BatchNorm1d(512) 363 | self.activation19 = nn.ELU() 364 | self.fc = nn.Linear(512, num_classes) 365 | self.apply(weights_init) 366 | 367 | def feature_patameters(self): 368 | return [self.conv0.parameters(), self.block_chain1.parameters(), \ 369 | self.block_chain2.parameters(), self.block_chain3.parameters(), \ 370 | self.block_chain4.parameters(), self.bn17.parameters(), \ 371 | self.conv17.parameters(), self.activation17.parameters(), \ 372 | self.bn18.parameters(), self.linear19.parameters(), \ 373 | self.bn19.parameters(), self.activation19.parameters()] 374 | 375 | def forward(self, x): 376 | x = self.conv0(x) 377 | 378 | x = self.block_chain1(x) 379 | x = self.block_chain2(x) 380 | x = self.block_chain3(x) 381 | x = self.block_chain4(x) 382 | 383 | x17 = self.drop17(self.bn17(self.conv17(x))) 384 | x = x + x17 385 | x = self.activation17(x) 386 | 387 | x = F.avg_pool3d(x, (x.size(2), x.size(3), x.size(4))) 388 | x = x.view(x.size(0), -1) 389 | x = self.bn18(x) 390 | x = self.activation19(self.bn19(self.linear19(x))) 391 | y = self.fc(x) 392 | return y 393 | -------------------------------------------------------------------------------- /src/semi_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from common.mmhh_config import BatchType, SemiInitType 4 | 5 | 6 | class SemiBatch(object): 7 | def __init__(self, total_num, D_out, total_iter_num, init_momentum=0.5, semi_init_type=SemiInitType.MODEL, 8 | model=None, loader=None, use_gpu=False): 9 | # init hash codes 10 | self.total_iter_num = total_iter_num 11 | self.init_momentum = init_momentum 12 | self.momentum = 0 13 | if semi_init_type == SemiInitType.RANDOM: 14 | self.aug_memory = torch.randn(total_num, D_out, requires_grad=False) 15 | elif semi_init_type == SemiInitType.MODEL: 16 | 
self.aug_memory = torch.zeros((total_num, D_out), dtype=torch.float, requires_grad=False) 17 | else: 18 | raise NotImplementedError 19 | # init labels 20 | self.labels = None 21 | if use_gpu: 22 | self.aug_memory = self.aug_memory.cuda() 23 | total_num = len(loader.dataset) 24 | calc_num = 0 25 | for inputs_batch, labels_batch, indices_batch in loader: 26 | if use_gpu: 27 | inputs_batch, labels_batch = inputs_batch.cuda(), labels_batch.cuda() 28 | if self.labels is None: 29 | self.labels = torch.zeros((total_num, labels_batch.shape[1]), dtype=torch.uint8, requires_grad=False) 30 | if use_gpu: 31 | self.labels = self.labels.cuda() 32 | self.labels[indices_batch] = labels_batch.data 33 | if semi_init_type == SemiInitType.MODEL: 34 | x_out = model.predict(inputs_batch) 35 | self.aug_memory[indices_batch] = x_out.data 36 | calc_num += len(indices_batch) 37 | if calc_num != total_num: 38 | raise Exception("predict num %d != total %d" % (calc_num, total_num)) 39 | 40 | def update_memory(self, x_out: torch.Tensor, bidxs, norm_memory_batch=False): 41 | # update the non-parametric data 42 | m_x = self.aug_memory[bidxs] 43 | weighted_m_x = m_x * self.momentum + x_out.data * (1. 
- self.momentum) 44 | if norm_memory_batch: 45 | w_norm = (weighted_m_x ** 2).sum(axis=1, keepdim=True).pow(0.5) 46 | weighted_m_x = weighted_m_x.div(w_norm) 47 | self.aug_memory[bidxs] = weighted_m_x 48 | 49 | def update_momentum(self, iter_num, batch_type, batch_params) -> None: 50 | if batch_type == BatchType.BatchInitMem and iter_num < int(batch_params): 51 | self.momentum = 0 52 | else: 53 | self.momentum = self.init_momentum + (1 - self.init_momentum) * ( 54 | float(iter_num) / self.total_iter_num) ** 0.5 55 | -------------------------------------------------------------------------------- /src/test_mmhh.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import argparse 4 | import os 5 | import os.path as osp 6 | 7 | import numpy as np 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from common.fake_demo import get_fake_test_list 12 | from common.logger import get_log 13 | from common.mmhh_config import data_config 14 | from evaluate.measure_utils import \ 15 | get_precision_recall_by_Hamming_Radius_optimized, get_precision_recall_by_Hamming_Radius, \ 16 | mean_average_precision_normal, mean_average_precision_normal_optimized_topK 17 | 18 | 19 | def save_and_test(summary_writer, m_logger, model_instance, snap_path, annotation, s_dataset, t_dataset, iter_num, 20 | batch_size, use_gpu, fake_cpu_demo=False, radius=0, opt_test=True, test_sample_ratio=1.): 21 | """ 22 | evaluate the performance during training 23 | :param model_instance: 24 | :param snap_path: 25 | :param annotation: 26 | :param s_dataset: 27 | :param t_dataset: 28 | :param iter_num: 29 | :param batch_size: 30 | :param use_gpu: 31 | :return: 32 | """ 33 | torch.save(model_instance, 34 | snap_path + "{:s}_{:s}_{:s}_iter_{:05d}".format(annotation, s_dataset, t_dataset, iter_num)) 35 | m_logger.info(snap_path + "{:s}_{:s}_{:s}_iter_{:05d}".format(annotation, s_dataset, t_dataset, iter_num)) 36 | eval_st = time.perf_counter() 
37 | m_config = get_test_config(t_dataset, batch_size) 38 | if fake_cpu_demo: 39 | print("fake test list") 40 | get_fake_test_list(m_config, t_dataset) 41 | 42 | model_instance.set_train(False) 43 | con_mAP, con_time = evaluate(m_logger, m_config, model_instance, use_gpu, 44 | '../test_output/{:s}_iter_{:05d}'.format(annotation, iter_num), 45 | radius, opt_test, test_sample_ratio) 46 | summary_writer.add_scalar('Test/conMAP', con_mAP, iter_num) 47 | summary_writer.add_scalar('Test/con_time', con_time, iter_num) 48 | 49 | model_instance.set_train(True) 50 | eval_end = time.perf_counter() 51 | m_logger.info("iter_num %d, map: %.3f, use time: %.3f" % 52 | (iter_num, con_mAP, (eval_end - eval_st))) 53 | return con_mAP 54 | 55 | 56 | def save_code_and_label(params, path): 57 | database_code = params['database_code'] 58 | validation_code = params['test_code'] 59 | database_labels = params['database_labels'] 60 | validation_labels = params['test_labels'] 61 | np.save(path + "/database_code.npy", database_code) 62 | np.save(path + "/database_labels.npy", database_labels) 63 | np.save(path + "/test_code.npy", validation_code) 64 | np.save(path + "/test_labels.npy", validation_labels) 65 | 66 | 67 | def code_predict(loader, model, name, use_gpu=True): 68 | start_test = True 69 | 70 | iter_val = iter(loader[name]) 71 | print("name: %s; length: %d" % (name, len(loader[name]))) 72 | display_interval = 100 73 | for i in range(len(loader[name])): 74 | if i % display_interval == 0: 75 | print("iter: %d" % i) 76 | data = next(iter_val) 77 | inputs = data[0] 78 | labels = data[1] 79 | if use_gpu: 80 | inputs = Variable(inputs.cuda()) 81 | else: 82 | inputs = Variable(inputs) 83 | outputs = model.predict(inputs) 84 | if start_test: 85 | all_output = outputs.data.cpu().float() 86 | all_label = labels.float() 87 | start_test = False 88 | else: 89 | all_output = torch.cat((all_output, outputs.data.cpu().float()), 0) 90 | all_label = torch.cat((all_label, labels.float()), 0) 91 | return 
all_output, all_label 92 | 93 | 94 | def predict(config, model_instance, use_gpu, test_sample_ratio=1.0): 95 | dset_loaders = {} 96 | data_config = config["data"] 97 | 98 | print("loading base list") 99 | dset_loaders["database"] = config["loader"](data_config["database"]["list_path"], 100 | batch_size=data_config["database"]["batch_size"], resize_size=256, 101 | is_train=False, test_sample_ratio=test_sample_ratio) 102 | print("loading test list") 103 | dset_loaders["test"] = config["loader"](data_config["test"]["list_path"], 104 | batch_size=data_config["test"]["batch_size"], resize_size=256, 105 | is_train=False, test_sample_ratio=test_sample_ratio) 106 | print("start database predict") 107 | database_codes, database_labels = code_predict(dset_loaders, model_instance, "database", 108 | use_gpu=use_gpu) 109 | print("start test predict") 110 | test_codes, test_labels = code_predict(dset_loaders, model_instance, "test", 111 | use_gpu=use_gpu) 112 | print("done predict") 113 | 114 | return {"database_code": database_codes.numpy(), "database_labels": database_labels.numpy(), 115 | "test_code": test_codes.numpy(), "test_labels": test_labels.numpy()} 116 | 117 | 118 | def get_test_config(dataset, batch_size): 119 | test_config = { 120 | "prep": {"resize_size": 256, "crop_size": 224}, 121 | "dataset": dataset, 122 | "batch_size": batch_size, 123 | "data": { 124 | "database": { 125 | "list_path": data_config[dataset]['database'], 126 | "batch_size": batch_size}, 127 | "test": { 128 | "list_path": data_config[dataset]['test'], 129 | "batch_size": batch_size} 130 | }, 131 | "R": data_config[dataset]['R'], 132 | "loader": data_config[dataset]["loader"]} 133 | return test_config 134 | 135 | 136 | def evaluate(logger, config, model_instance, use_gpu, output_path=None, radius=0, opt_test=True, 137 | test_sample_ratio=1.0): 138 | print('R=%d' % int(config["R"])) 139 | # prepare data 140 | code_and_label = predict(config, model_instance, use_gpu, 
test_sample_ratio=test_sample_ratio) 141 | query_output = code_and_label["test_code"] 142 | query_labels = code_and_label["test_labels"] 143 | database_output = code_and_label["database_code"] 144 | database_labels = code_and_label["database_labels"] 145 | # sign_query_output = np.sign(query_output) 146 | # sign_database_output = np.sign(database_output) 147 | 148 | if output_path is not None: 149 | print("saving to %s" % output_path) 150 | if not osp.exists(output_path): 151 | os.system("mkdir -p " + output_path) 152 | save_code_and_label(code_and_label, output_path) 153 | st = time.time() 154 | con_mAP = evaluate_measurement(logger, database_output, database_labels, query_output, query_labels, 155 | config["R"], radius, opt_test) 156 | con_time = time.time() - st 157 | return con_mAP, con_time 158 | 159 | 160 | def evaluate_measurement(logger, database_output, database_labels, query_output, query_labels, topR, radius, opt_test): 161 | if radius > 0: 162 | print("evaluate hamming %d" % radius) 163 | if opt_test: 164 | logger.info("radius {:d}, optimized evaluation".format(radius)) 165 | 166 | precs, recs, mAP, _ = get_precision_recall_by_Hamming_Radius_optimized( 167 | database_output, database_labels, query_output, query_labels, radius=radius) 168 | else: 169 | logger.info("radius {:d}, primary evaluation".format(radius)) 170 | precs, recs, mAP = get_precision_recall_by_Hamming_Radius( 171 | database_output, database_labels, query_output, query_labels, radius) 172 | else: 173 | print("no lookup, hamming %d" % radius) 174 | if opt_test: 175 | logger.info("linear scan, optimized evaluation") 176 | precs, recs, mAP = mean_average_precision_normal_optimized_topK( 177 | database_output, database_labels, query_output, query_labels, topR) 178 | else: 179 | logger.info("linear scan, primary evaluation") 180 | precs, recs, mAP = mean_average_precision_normal( 181 | database_output, database_labels, query_output, query_labels, topR) 182 | return mAP 183 | 184 | 185 | if 
__name__ == "__main__": 186 | parser = argparse.ArgumentParser(description='Transfer Learning') 187 | parser.add_argument('--gpu_id', type=str, default='0', help="device id to run") 188 | parser.add_argument('--dataset', type=str, default='shapenet_13', help="dataset name") 189 | parser.add_argument('--batch_size', type=int, default=16, help="batch size") 190 | parser.add_argument('--output_path', type=str, help="path to save the code and labels") 191 | parser.add_argument('--model_path', type=str, help="model path") 192 | parser.add_argument('--radius', type=int, default=2, help="radius") 193 | parser.add_argument('--opt-test', action='store_true', help='setting this will use the optimized evaluation') 194 | parser.add_argument('--log_dir', type=str, default='../log/', help="log dir") 195 | parser.add_argument('--annotation', type=str, default='empty', help="annotation for distinguishing") 196 | parser.add_argument('--test_sample_ratio', type=float, default=1.0, help="sample ratio to test") 197 | 198 | args = parser.parse_args() 199 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 200 | use_gpu = torch.cuda.is_available() 201 | output_path = args.output_path 202 | model_path = args.model_path 203 | dataset = args.dataset 204 | batch_size = args.batch_size 205 | radius = args.radius 206 | opt_test = args.opt_test 207 | log_dir = args.log_dir 208 | annotation = args.annotation 209 | test_sample_ratio = args.test_sample_ratio 210 | 211 | config = get_test_config(dataset, batch_size) 212 | model_instance = torch.load(args.model_path) 213 | model_instance.set_train(False) 214 | print("calc mean_average_precision") 215 | logger = get_log(log_dir, annotation) 216 | con_mAP, use_time = evaluate(logger, config, model_instance, use_gpu, output_path, radius, 217 | opt_test=opt_test, test_sample_ratio=test_sample_ratio) 218 | logger.info("mAP: %.3f, use time: %.3f" % (con_mAP, use_time)) 219 | print("saving done") 220 | 
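Note on the Hamming-radius evaluation: `evaluate_measurement` above delegates to `get_precision_recall_by_Hamming_Radius*` in `evaluate/measure_utils.py`, which is not shown in this file. The block below is only a rough NumPy sketch of what "precision and recall within Hamming radius r" usually means, not the repository's implementation. It assumes codes are binarized with `np.sign`, and that a database item is relevant when it shares at least one label with the query; the function name is hypothetical.

```python
import numpy as np


def hamming_radius_precision_recall(db_codes, db_labels, q_codes, q_labels, radius=2):
    """Mean precision/recall of a Hamming-ball lookup (illustrative sketch only)."""
    # Assumption: real-valued network outputs are binarized to +/-1 by sign.
    # (An exact 0 output would stay 0 here; real code may need a tie-break.)
    db_bits = np.sign(db_codes)
    q_bits = np.sign(q_codes)
    n_bits = db_bits.shape[1]

    # For +/-1 codes: hamming(a, b) = (n_bits - <a, b>) / 2
    dist = 0.5 * (n_bits - q_bits @ db_bits.T)
    within = dist <= radius

    # Assumption: multi-label relevance means "shares at least one label".
    relevant = (q_labels @ db_labels.T) > 0

    hits = (within & relevant).sum(axis=1)
    retrieved = within.sum(axis=1)
    total_relevant = relevant.sum(axis=1)

    # Queries that retrieve nothing (or have no relevant items) score 0.
    precision = np.where(retrieved > 0, hits / np.maximum(retrieved, 1), 0.0)
    recall = np.where(total_relevant > 0, hits / np.maximum(total_relevant, 1), 0.0)
    return float(precision.mean()), float(recall.mean())
```

With `radius=0` only exact code matches count as retrieved, which is why the `radius > 0` branch above is treated as a hash lookup and the `else` branch as a linear scan.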
-------------------------------------------------------------------------------- /src/train_mmhh.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import time 6 | from pprint import pprint 7 | 8 | import torch.multiprocessing 9 | import torch.optim as optim 10 | 11 | import mmhh_network 12 | from mmhh import MMHH 13 | from semi_batch import SemiBatch 14 | from common.fake_demo import get_fake_train_list 15 | from common.lr_scheduler import INVScheduler, StepScheduler 16 | from tensorboardX import SummaryWriter 17 | from common.logger import get_log 18 | from common.mmhh_config import data_config, ImageLossType, DistanceType, Mission, TestType, BatchType, SemiInitType, \ 19 | MarginParams, LrScheduleType 20 | from test_mmhh import save_and_test 21 | 22 | torch.multiprocessing.set_sharing_strategy('file_system') 23 | 24 | 25 | def get_optimizer(m_logger, parameter_list, lr_schedule_type, init_lr=1.0, decay_step=10000, weight_decay=0.1): 26 | # type = "DANN_INV" 27 | # type = "HashNet_Step" 28 | if lr_schedule_type == LrScheduleType.HashNet_Step: 29 | m_logger.info("HashNet_Step") 30 | optimizer = optim.SGD(parameter_list, lr=1.0, momentum=0.9, weight_decay=0.0005, nesterov=True) 31 | lr_scheduler = StepScheduler(gamma=0.5, step=2000, init_lr=0.0003) 32 | elif lr_schedule_type == LrScheduleType.Stair_Step: 33 | m_logger.info("Stair_Step") 34 | optimizer = optim.SGD(parameter_list, lr=0.5, momentum=0.9, weight_decay=weight_decay, nesterov=True) 35 | lr_scheduler = StepScheduler(gamma=0.5, step=decay_step, init_lr=init_lr) 36 | elif lr_schedule_type == LrScheduleType.DANN_INV: 37 | m_logger.info("DANN_INV") 38 | optimizer = optim.SGD(parameter_list, lr=0.3, momentum=0.9, weight_decay=0.0005, nesterov=True) 39 | lr_scheduler = INVScheduler(gamma=0.0003, decay_rate=0.75, init_lr=0.0003) 40 | else: 41 | raise NotImplementedError 42 | group_ratios = 
[param_group["lr"] for param_group in optimizer.param_groups] 43 | return lr_scheduler, optimizer, group_ratios 44 | 45 | 46 | def get_feature_net(m_logger, dataset, hash_bit, image_loss_type=ImageLossType.HashNet, 47 | network_name='ResNetFc'): 48 | if dataset in ["shapenet_13", "shapenet_9", 'modelnet_10', 'modelnet_40', 'modelnet_sm_11']: 49 | raise NotImplementedError 50 | elif dataset in ["ElectricDevices", "Crop", 'InsectWingbeat']: 51 | raise NotImplementedError 52 | else: 53 | if network_name == 'ResNetFc': 54 | m_logger.info("feature net: ResNetFc") 55 | _feature_net = mmhh_network.ResNetFc(m_logger, 'ResNet50', hash_bit, 56 | increase_scale=(image_loss_type == ImageLossType.HashNet)) 57 | elif network_name == 'AlexNetFc': 58 | m_logger.info("feature net: AlexNetFc") 59 | _feature_net = mmhh_network.AlexNetFc(m_logger, hash_bit, 60 | increase_scale=(image_loss_type == ImageLossType.HashNet)) 61 | else: 62 | raise NotImplementedError 63 | return _feature_net 64 | 65 | 66 | def get_next_iter_with_index(m_logger, data_loader, batch_size, data_loader_iter=None, balance_sampling=False): 67 | if data_loader_iter is None: 68 | data_loader_iter = iter(data_loader) 69 | try: 70 | inputs, labels, indices = next(data_loader_iter) 71 | except StopIteration: 72 | m_logger.info('stop iter, re-init') 73 | data_loader_iter, inputs, labels, indices = get_next_iter_with_index(m_logger, data_loader, batch_size, 74 | balance_sampling=balance_sampling) 75 | return data_loader_iter, inputs, labels, indices 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser(description='MMHH') 80 | parser.add_argument('--gpu_id', type=str, default='0', help="device id to run") 81 | parser.add_argument('--s_dataset', type=str, default='imagenet_13', help="dataloader dataset name") 82 | parser.add_argument('--t_dataset', type=str, default='', help="target dataset name") 83 | parser.add_argument('--batch_size', type=int, default=48, help="batch size") 84 | 
parser.add_argument('--snap_path', type=str, default='../snapshot/hash/', help="save path prefix") 85 | parser.add_argument('--hash_bit', type=int, default=48, help="output hash bit") 86 | parser.add_argument('--lr', type=float, default=0.0003, help="learning rate") 87 | parser.add_argument('--annotation', type=str, default='empty', help="annotation for distinguishing") 88 | parser.add_argument('--loss_lambda', type=float, default='0.1', help="loss_lambda") 89 | parser.add_argument('--num_iters', type=int, default=1000, help="number of iterations") 90 | parser.add_argument('--pre_model_path', type=str, default='', help="continue training based on previous model") 91 | # network 92 | parser.add_argument('--image_network', type=str, default='AlexNetFc', help="ResNetFc or AlexNetFc") 93 | # different schema 94 | parser.add_argument('--image_loss_type', type=str, default='DHN', help="HashNet or DHN") 95 | parser.add_argument('--distance_type', type=str, default='Hamming', 96 | help="Hamming, tSNE or Cauchy, Metric, Margin1, Margin2") 97 | parser.add_argument('--mission', type=str, default='Hashing', 98 | help="Cross_Modal_Transfer, Hashing, Cross_Domain_Transfer") 99 | parser.add_argument('--log_dir', type=str, default='../log/', help="log dir") 100 | parser.add_argument('--lr_schedule_type', type=str, default='Stair_Step', 101 | help="DANN_INV, HashNet_Step, ShapeNet_Step, Stair_Step, Timeseries_Step") 102 | 103 | parser.add_argument('--gamma', type=float, default=1.0, help="gamma") 104 | parser.add_argument('--sigmoid_param', type=float, default=1.0, help="sigmoid function") 105 | parser.add_argument('--radius', type=int, default=0, help="radius") 106 | 107 | parser.add_argument('--decay_step', type=int, default=200, help="decay_step") 108 | parser.add_argument('--weight_decay', type=float, default=0.1, help="weight_decay") 109 | parser.add_argument('--similar_weight_type', type=str, default="config", 110 | help="config: config-preset. 
auto: auto calc per batch, other number: manually") 111 | parser.add_argument('--norm_memory_batch', action="store_true", help="L2-normalize the memory rows after each update") 112 | parser.add_argument('--batch_type', type=str, help="PairBatch, SemiMem, BatchInitMem, BatchSelectMem") 113 | parser.add_argument('--batch_params', type=str, default="300", help="params for batch type, optional") 114 | parser.add_argument('--semi_init_type', type=str, default="MODEL", help="MODEL, RANDOM") 115 | parser.add_argument('--semi_init_momentum', type=float, default=0.5, help="for debugs") 116 | parser.add_argument('--margin_params', type=str, default="0.5:0.8:1.0:1.0:2.0", 117 | help="margin parameters: sim_in:dis_in:sim_out:dis_out:margin_radius") 118 | 119 | parser.add_argument('--fake_cpu_demo', type=str, default='False', help="use toy dataset to validate the process") 120 | parser.add_argument('--test_sample_ratio', type=float, default=0.1, help="sample ratio to test") 121 | parser.add_argument('--snapshot-interval', type=int, default=1000, help="the interval of snapshot") 122 | parser.add_argument('--opt-test', type=str, default='False', help='setting this will use the optimized evaluation') 123 | args = parser.parse_args() 124 | pprint(vars(args)) 125 | 126 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 127 | use_gpu = torch.cuda.is_available() 128 | fake_cpu_demo = not use_gpu 129 | 130 | s_dataset = args.s_dataset 131 | t_dataset = s_dataset if args.t_dataset == '' else args.t_dataset 132 | batch_size = args.batch_size 133 | snap_path = args.snap_path 134 | hash_bit = args.hash_bit 135 | arg_lr = args.lr 136 | annotation = args.annotation 137 | loss_lambda = args.loss_lambda 138 | num_iterations = args.num_iters 139 | pre_model_path = args.pre_model_path 140 | fake_cpu_demo = args.fake_cpu_demo == 'True' 141 | radius = args.radius 142 | gamma = args.gamma 143 | sigmoid_param = args.sigmoid_param 144 | opt_test = args.opt_test == 'True' 145 | snapshot_interval = args.snapshot_interval 146 | test_sample_ratio = 
args.test_sample_ratio 147 | 148 | # network 149 | log_dir = args.log_dir 150 | image_network = args.image_network 151 | logger = get_log(log_dir, annotation) 152 | summary_writer = SummaryWriter('../runs/' + annotation) 153 | 154 | # Some Enum parameters 155 | if args.similar_weight_type == "auto": 156 | similar_weight = "auto" 157 | elif args.similar_weight_type == "config": 158 | similar_weight = data_config[s_dataset]["class_num"] 159 | else: 160 | similar_weight = float(args.similar_weight_type) 161 | batch_type = args.batch_type 162 | if batch_type: 163 | batch_type = BatchType[args.batch_type] 164 | else: 165 | batch_type = BatchType.SemiMem if similar_weight != "auto" and similar_weight > 100 else BatchType.PairBatch 166 | lr_schedule_type = LrScheduleType[args.lr_schedule_type] 167 | batch_params = args.batch_params 168 | mission = Mission[args.mission] 169 | semi_init_type = SemiInitType[args.semi_init_type] 170 | image_loss_type = ImageLossType[args.image_loss_type] 171 | distance_type = DistanceType[args.distance_type] 172 | margin_params = MarginParams.from_string_params(args.margin_params) 173 | 174 | if not osp.exists(snap_path): 175 | os.makedirs(snap_path, exist_ok=True) 176 | 177 | # Init model 178 | feature_net = get_feature_net(logger, s_dataset, hash_bit, image_loss_type=image_loss_type, 179 | network_name=image_network) 180 | if image_loss_type == ImageLossType.HashNet: 181 | loss_lambda = 0 182 | if pre_model_path != '': 183 | logger.info('load previous model: %s' % pre_model_path) 184 | model_instance = torch.load(pre_model_path) 185 | else: 186 | model_instance = MMHH(feature_net, hash_bit, trade_off=1.0, 187 | use_gpu=use_gpu, loss_lambda=loss_lambda, similar_weight=similar_weight, 188 | image_loss_type=image_loss_type, mission=mission, radius=radius, 189 | distance_type=distance_type, gamma=gamma, sigmoid_param=sigmoid_param) 190 | 191 | # Prepare data 192 | if fake_cpu_demo: 193 | source_train_list, _ = get_fake_train_list(s_dataset, t_dataset) 194 | else: 195 | 
source_train_list = data_config[s_dataset]["train"] 196 | train_loader = data_config[s_dataset]["loader"](source_train_list, batch_size=batch_size, 197 | resize_size=256, is_train=True, crop_size=224) 198 | # Set optimizer 199 | parameter_list = model_instance.get_parameter_list() 200 | lr_scheduler, optimizer, group_ratios = get_optimizer(logger, parameter_list, lr_schedule_type=lr_schedule_type, 201 | init_lr=arg_lr, decay_step=args.decay_step, 202 | weight_decay=args.weight_decay) 203 | semi_batch = None 204 | if batch_type in [BatchType.SemiMem, BatchType.BatchInitMem]: 205 | semi_batch = SemiBatch(len(train_loader.dataset), hash_bit, num_iterations, 206 | init_momentum=args.semi_init_momentum, semi_init_type=semi_init_type, 207 | model=model_instance, loader=train_loader, use_gpu=use_gpu) 208 | iter_batch = None 209 | all_st = time.perf_counter() 210 | 211 | logger.info("start train...") 212 | for iter_num in range(num_iterations): 213 | model_instance.set_train(True) 214 | 215 | iter_batch, inputs_batch, labels_batch, indices_batch = get_next_iter_with_index(logger, train_loader, 216 | batch_size, iter_batch) 217 | if model_instance.use_gpu: 218 | inputs_batch, labels_batch = inputs_batch.cuda(), labels_batch.cuda() 219 | 220 | optimizer = lr_scheduler.next_optimizer(group_ratios, optimizer, iter_num, logger) 221 | optimizer.zero_grad() 222 | 223 | total_loss, hash_batch = model_instance.get_semi_hash_loss(logger, semi_batch, inputs_batch, labels_batch, 224 | iter_num, batch_type, batch_params, margin_params) 225 | total_loss.backward() 226 | optimizer.step() 227 | 228 | if batch_type in [BatchType.SemiMem, BatchType.BatchInitMem]: 229 | semi_batch.update_momentum(iter_num, batch_type, batch_params) 230 | semi_batch.update_memory(hash_batch, indices_batch, norm_memory_batch=args.norm_memory_batch) 231 | 232 | con_mAP = save_and_test(summary_writer, logger, model_instance, snap_path, annotation, s_dataset, t_dataset, 233 | num_iterations, batch_size, use_gpu, 
radius=radius, opt_test=opt_test, 234 | test_sample_ratio=test_sample_ratio) 235 | 236 | summary_writer.close() 237 | all_end = time.perf_counter() 238 | logger.info("finish train.") 239 | logger.info("All training is finished, total time: %.3f" % (all_end - all_st)) 240 | --------------------------------------------------------------------------------
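Note on the augmented-memory update: after every optimizer step the training loop calls `semi_batch.update_momentum(...)` and then `semi_batch.update_memory(...)`. The sketch below restates that arithmetic in plain NumPy so the schedule can be inspected without a GPU or a dataset. It mirrors the formulas in `SemiBatch`, but the function names are hypothetical and the `BatchInitMem` warm-up branch (momentum held at 0 for the first `batch_params` iterations) is left out.

```python
import numpy as np


def momentum_at(iter_num, total_iter_num, init_momentum=0.5):
    """Square-root ramp from init_momentum (iteration 0) to 1.0 (last iteration)."""
    return init_momentum + (1.0 - init_momentum) * (float(iter_num) / total_iter_num) ** 0.5


def update_memory(memory, batch_out, batch_idx, momentum, normalize=False):
    """Blend stored codes with fresh batch outputs; optionally L2-normalize the rows."""
    blended = memory[batch_idx] * momentum + batch_out * (1.0 - momentum)
    if normalize:
        blended = blended / np.linalg.norm(blended, axis=1, keepdims=True)
    memory[batch_idx] = blended  # in-place write into the memory bank
    return memory


if __name__ == "__main__":
    bank = np.zeros((3, 2))
    m = momentum_at(0, 1000)  # 0.5 at the first iteration
    update_memory(bank, np.array([[2.0, 4.0]]), [1], m)
    print(m, bank[1])
```

As the momentum approaches 1 the stored codes change less and less, so late in training the memory mostly keeps the codes it has already accumulated.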