├── README.md ├── config ├── config_loader.py ├── default_config.yml └── logger.py ├── data ├── MultiEpochsDataLoader.py └── datasets.py ├── main.py ├── models └── models.py ├── train.py ├── transforms ├── __init__.py └── bow.py └── utils ├── meters.py ├── metric.py └── plotter.py /README.md: -------------------------------------------------------------------------------- 1 | # HNH 2 | *********************************************************************************************************** 3 | 4 | This repository is for ["High-order nonlocal Hashing for unsupervised 5 | cross-modal retrieval"](https://sci-hub.se/https://doi.org/10.1007/s11280-020-00859-y) 6 | 7 | 8 | *********************************************************************************************************** 9 | 10 | ## Usage 11 | ### Requirements 12 | - python == 3.11.5 13 | - pytorch == 2.1.0 14 | - torchvision 15 | - CV2 16 | - PIL 17 | - h5py 18 | 19 | ### Datasets 20 | For datasets, we follow [Deep Cross-Modal Hashing's Github (Jiang, CVPR 2017)](https://github.com/jiangqy/DCMH-CVPR2017/tree/master/DCMH_matlab/DCMH_matlab). You can download these datasets from: 21 | - Wikipedia articles, [Link](http://www.svcl.ucsd.edu/projects/crossmodal/) 22 | - MIRFLICKR25K, [[OneDrive](https://pkueducn-my.sharepoint.com/:f:/g/personal/zszhong_pku_edu_cn/EpLD8yNN2lhIpBgQ7Kl8LKABzM68icvJJahchO7pYNPV1g?e=IYoeqn)], [[Baidu Pan](https://pan.baidu.com/s/1o5jSliFjAezBavyBOiJxew), password: 8dub] 23 | - NUS-WIDE (top-10 concept), [[OneDrive](https://pkueducn-my.sharepoint.com/:f:/g/personal/zszhong_pku_edu_cn/EoPpgpDlPR1OqK-ywrrYiN0By6fdnBvY4YoyaBV5i5IvFQ?e=kja8Kj)], [[Baidu Pan](https://pan.baidu.com/s/1GFljcAtWDQFDVhgx6Jv_nQ), password: ml4y] 24 | 25 | 26 | ### Process 27 | 28 | __The following experiment results are the average values, if you demand for better results, please run the experiment a few more times (2~5).__ 29 | 30 | - Clone this repo: `git clone https://github.com/youonly-once/HNH.git`. 31 | - Change the 'dataPath' in `default_config.yml` to where you place the datasets. 32 | - An example to train a model: 33 | ```bash 34 | python main.py 35 | ``` 36 | - Modify the parameter `eval = True` in `default_config.yml` for validation. 37 | - Ablation studies (__optional__): if you want to evaluate other components of our HNH, please refer to our paper and `default_config.yml`. 38 | 39 | *********************************************************************************************************** 40 | 41 | 42 | All rights are reserved by the authors. 43 | *********************************************************************************************************** 44 | -------------------------------------------------------------------------------- /config/config_loader.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/11/7 2 | # @Author : SXS 3 | # @Github : https://github.com/SXS-PRIVATE/HNH 4 | import yaml 5 | from torchvision import transforms 6 | from transforms.bow import * 7 | 8 | 9 | class Config(object): 10 | def __init__(self, config_path): 11 | with open(f'config/{config_path}', 'r') as f: 12 | config = yaml.load(f, Loader=yaml.FullLoader) 13 | self.training = config['training'] # type: dict 14 | # self.dataset_path = config[self.training['method']] # type: dict 15 | if(self.training['dataName'] == 'mirFlickr25k'): 16 | self.dataset_config = config['mirFlickr25k'] 17 | if(self.training['dataName'] == 'nusWide'): 18 | self.dataset_config = config['nusWide'] 19 | if(self.training['dataName'] == 'wiki'): 20 | self.dataset_config = config['wiki'] 21 | self.data_preprocess = config['dataPreprocess'] 22 | self.data_augmentation = config['dataAugmentation'] 23 | #self.dataset_names = [key.lower() for key in self.dataset_path.keys()] 24 | #self.dataset_path = [value for value in self.dataset_path.values()] 25 | self.img_training_transform = self.__get_img_train_transform() 26 | self.txt_training_transform = self.__get_txt_train_transform() 27 | self.img_valid_transform = self.__get_img_valid_transform() 28 | self.txt_valid_transform = self.__get_txt_valid_transform() 29 | 30 | def __get_img_valid_transform(self): 31 | transform_list = [] 32 | data_preprocess = self.data_preprocess['img'] 33 | resize = data_preprocess['resize'] 34 | resize_transform = transforms.Resize(resize) 35 | transform_list.append(resize_transform) 36 | if data_preprocess['toTensor']: 37 | transform_list.append(transforms.ToTensor()) 38 | mean = data_preprocess['normalize']['mean'] 39 | std = data_preprocess['normalize']['std'] 40 | normalize_transform = transforms.Normalize(mean, std) 41 | transform_list.append(normalize_transform) 42 | transform = transforms.Compose(transform_list) 43 | return transform 44 | 45 | # not finish 46 | def __get_txt_valid_transform(self): 47 | return None 48 | 49 | def __get_img_train_transform(self): 50 | transform_list = [] 51 | data_preprocess = self.data_preprocess['img'] 52 | resize = data_preprocess['resize'] 53 | resize_transform = transforms.Resize(resize) 54 | transform_list.append(resize_transform) 55 | data_augmentation = self.data_augmentation['img'] 56 | if self.data_augmentation["enable"] and data_augmentation["enable"]: 57 | original_retention = float(data_augmentation['originalRetention']) 58 | data_augmentation_transform_list = [] 59 | if data_augmentation["randomRotation"]["enable"]: 60 | rotation_list = data_augmentation["randomRotation"]["rotationAngle"] 61 | rotation_transforms = [] 62 | for rotation in rotation_list: 63 | rotation_transforms.append(transforms.RandomRotation(rotation)) 64 | probability = float(data_augmentation["randomRotation"]["probability"]) 65 | random_rotation = transforms.RandomApply(rotation_transforms, probability) 66 | data_augmentation_transform_list.append(random_rotation) 67 | if data_augmentation["RandomHorizontalFlip"]["enable"]: 68 | probability = data_augmentation["RandomHorizontalFlip"]["probability"] 69 | horizontal_flip = transforms.RandomHorizontalFlip(probability) 70 | data_augmentation_transform_list.append(horizontal_flip) 71 | if data_augmentation["RandomVerticalFlip"]["enable"]: 72 | probability = data_augmentation["RandomVerticalFlip"]["probability"] 73 | vertical_flip = transforms.RandomVerticalFlip(probability) 74 | data_augmentation_transform_list.append(vertical_flip) 75 | data_augmentation_transform = transforms.RandomApply(data_augmentation_transform_list, p=1-original_retention) 76 | transform_list.append(data_augmentation_transform) 77 | if data_preprocess['toTensor']: 78 | transform_list.append(transforms.ToTensor()) 79 | mean = data_preprocess['normalize']['mean'] 80 | std = data_preprocess['normalize']['std'] 81 | normalize_transform = transforms.Normalize(mean, std) 82 | transform_list.append(normalize_transform) 83 | transform = transforms.Compose(transform_list) 84 | return transform 85 | 86 | def __get_txt_train_transform(self): 87 | transform_list = [] 88 | data_augmentation = self.data_augmentation['txt'] 89 | if self.data_augmentation["enable"]: 90 | if data_augmentation['RandomErasure']['enable']: 91 | prob = float(data_augmentation['RandomErasure']['probability']) 92 | value = float(data_augmentation['RandomErasure']['defaultValue']) 93 | random_erasure = RandomErasure(prob, value) 94 | transform_list.append(random_erasure) 95 | transform = transforms.Compose(transform_list) 96 | return transform 97 | 98 | def get_dataset_path(self, dataset_name: str): 99 | if dataset_name.lower() not in self.dataset_names: 100 | raise ValueError("there are not dataset name is %s" % dataset_name) 101 | paths = self.dataset_path[self.dataset_names.index(dataset_name.lower())] 102 | return paths['img_dir'] 103 | 104 | def get_img_dir(self): 105 | return self.get_dataset_path(self.training['dataName']) 106 | -------------------------------------------------------------------------------- /config/default_config.yml: -------------------------------------------------------------------------------- 1 | training: 2 | method: HNH #HNH2 3 | dataName: nusWide 4 | batchSize: 32 5 | bit: 16 6 | cuda: True 7 | device: 0 8 | numEpoch: 500 9 | eval: False 10 | numWorkers: 4 11 | evalInterval: 1 12 | modelDir: './checkpoint' 13 | mirFlickr25k: 14 | dataPath: 15 | labelDir: 'E:/MIRFlickr/LALL/mirflickr25k-lall.mat' 16 | txtDir: 'E:/MIRFlickr/YALL/mirflickr25k-yall.mat' 17 | imgDir: 'E:/MIRFlickr/IALL/mirflickr25k-iall.mat' 18 | 19 | beta: 1 20 | lambda: 1 21 | gamma: 0.9 22 | alpha: 40 23 | lrImg: 0.0001 24 | lrTxt: 0.01 25 | evalInterval: 1 26 | momentum: 0.9 27 | weightDecay: 0.0005 28 | eval: False 29 | kX: 2 30 | kY: 2 31 | 32 | nusWide: 33 | dataPath: 34 | labelDir: 'E:/nus-wide-tc10/nus-wide-tc10-lall.mat' 35 | txtDir: 'E:/nus-wide-tc10/nus-wide-tc10-yall.mat' 36 | imgDir: 'E:/nus-wide-tc10/nus-wide-tc10-iall.mat' 37 | 38 | beta: 1 39 | lambda: 1 40 | gamma: 0.6 41 | alpha: 40 42 | lrImg: 0.0001 43 | lrTxt: 0.01 44 | evalInterval: 1 45 | momentum: 0.9 46 | weightDecay: 0.0005 47 | eval: False 48 | kX: 2 49 | kY: 2 50 | 51 | wiki: 52 | dataPath: 53 | dataDir: 'E:/wikipedia_dataset/images' 54 | labelDir: 'E:/wikipedia_dataset/raw_features.mat' 55 | trainLabel: 'E:/wikipedia_dataset/trainset_txt_img_cat.list' 56 | testLabel: 'E:/wikipedia_dataset/testset_txt_img_cat.list' 57 | 58 | beta: 0.3 59 | lambda: 0.01 60 | gamma: 0.8 61 | alpha: 40 62 | lrImg: 0.0001 63 | lrTxt: 0.01 64 | evalInterval: 1 65 | momentum: 0.9 66 | weightDecay: 0.0005 67 | eval: False 68 | kX: 2 69 | kY: 2 70 | dataPreprocess: 71 | img: 72 | resize: [224, 224] 73 | normalize: 74 | mean: [0.485, 0.456, 0.406] 75 | std: [0.229, 0.224, 0.225] 76 | toTensor: True 77 | txt: 78 | normalize: 79 | enable: False 80 | label: 81 | onehot: True 82 | dataAugmentation: 83 | enable: True 84 | img: 85 | enable: True 86 | originalRetention: 0.2 87 | randomRotation: 88 | enable: True 89 | probability: 0.5 90 | rotationAngle: [[90, 90], [180, 180], [270, 270]] 91 | RandomHorizontalFlip: 92 | enable: True 93 | probability: 0.5 94 | RandomVerticalFlip: 95 | enable: True 96 | probability: 0.5 97 | txt: 98 | enable: False 99 | originalRetention: 0.5 100 | RandomErasure: 101 | enable: True 102 | probability: 0.2 103 | defaultValue: 0 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /config/logger.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/11/7 2 | # @Author : SXS 3 | # @Github : https://github.com/SXS-PRIVATE/HNH 4 | import logging 5 | import time 6 | import os.path as osp 7 | def get_logger(): 8 | logger = logging.getLogger('train') 9 | logger.setLevel(logging.INFO) 10 | now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) 11 | log_name = now + '_log.txt' 12 | log_dir = './log' 13 | txt_log = logging.FileHandler(osp.join(log_dir, log_name)) 14 | txt_log.setLevel(logging.INFO) 15 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 16 | txt_log.setFormatter(formatter) 17 | logger.addHandler(txt_log) 18 | 19 | stream_log = logging.StreamHandler() 20 | stream_log.setLevel(logging.INFO) 21 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 22 | stream_log.setFormatter(formatter) 23 | logger.addHandler(stream_log) 24 | return logger 25 | -------------------------------------------------------------------------------- /data/MultiEpochsDataLoader.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/11/7 2 | # @Author : SXS 3 | # @Github : https://github.com/SXS-PRIVATE/HNH 4 | import torch 5 | class MultiEpochsDataLoader(torch.utils.data.DataLoader): 6 | 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._DataLoader__initialized = False 10 | self.batch_sampler = _RepeatSampler(self.batch_sampler) 11 | self._DataLoader__initialized = True 12 | self.iterator = super().__iter__() 13 | 14 | def __len__(self): 15 | return len(self.batch_sampler.sampler) 16 | 17 | def __iter__(self): 18 | for i in range(len(self)): 19 | yield next(self.iterator) 20 | 21 | 22 | class _RepeatSampler(object): 23 | """ Sampler that repeats forever. 24 | Args: 25 | sampler (Sampler) 26 | """ 27 | 28 | def __init__(self, sampler): 29 | self.sampler = sampler 30 | 31 | def __iter__(self): 32 | while True: 33 | yield from iter(self.sampler) 34 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/11/7 2 | # @Author : SXS 3 | # @Github : https://github.com/SXS-PRIVATE/HNH 4 | import threading 5 | 6 | import torch 7 | import cv2 8 | import scipy.io as scio 9 | from PIL import Image 10 | from torch.utils.data.dataset import T_co 11 | 12 | import numpy as np 13 | import scipy.io as scio 14 | from torchvision import transforms 15 | import h5py 16 | from tqdm import tqdm 17 | import time 18 | from torch.utils.data import Dataset 19 | 20 | 21 | class DataSetBase(Dataset): 22 | txt_feat_len = 0 23 | 24 | def __init__(self): 25 | super().__init__() 26 | 27 | def __getitem__(self, index): 28 | pass 29 | 30 | 31 | class WIKI(DataSetBase): 32 | wiki_train_transform = transforms.Compose([ 33 | transforms.Resize(256), 34 | transforms.CenterCrop(224), 35 | transforms.ToTensor(), 36 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 37 | ]) 38 | 39 | wiki_test_transform = transforms.Compose([ 40 | transforms.Resize(256), 41 | transforms.CenterCrop(224), 42 | transforms.ToTensor(), 43 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 44 | ]) 45 | initialized = False 46 | lock = threading.Lock() 47 | train_label = None 48 | train_img_name = None 49 | train_txt = None 50 | test_label = None 51 | test_img_name = None 52 | test_txt = None 53 | @classmethod 54 | def init(cls,data_path_config): 55 | with cls.lock: 56 | if not cls.initialized: 57 | print("Initialization code here") 58 | cls.initialized = True 59 | label_set = scio.loadmat(data_path_config['labelDir']) 60 | cls.test_txt = np.array(label_set['T_te'], dtype=np.float32) 61 | cls.train_txt = np.array(label_set['T_tr'], dtype=np.float32) 62 | 63 | cls.test_label = [] 64 | with open(data_path_config['testLabel'], 'r') as f: 65 | total_lines = len(f.readlines()) 66 | f.seek(0) # Reset the file pointer back to the beginning of the file 67 | for line in tqdm(f.readlines(), total=total_lines, desc="Processing testLabel"): 68 | cls.test_label.extend([int(line.split()[-1]) - 1]) 69 | 70 | cls.test_img_name = [] 71 | with open(data_path_config['testLabel'], 'r') as f: 72 | total_lines = len(f.readlines()) 73 | f.seek(0) # Reset the file pointer back to the beginning of the file 74 | for line in tqdm(f.readlines(), total=total_lines, desc="Processing test_img_name"): 75 | cls.test_img_name.extend([line.split()[1]]) 76 | 77 | cls.train_label = [] 78 | with open(data_path_config['trainLabel'], 'r') as f: 79 | total_lines = len(f.readlines()) 80 | f.seek(0) # Reset the file pointer back to the beginning of the file 81 | for line in tqdm(f.readlines(), total=total_lines, desc="Processing trainLabel"): 82 | cls.train_label.extend([int(line.split()[-1]) - 1]) 83 | 84 | cls.train_img_name = [] 85 | with open(data_path_config['trainLabel'], 'r') as f: 86 | total_lines = len(f.readlines()) 87 | f.seek(0) # Reset the file pointer back to the beginning of the file 88 | for line in tqdm(f.readlines(), total=total_lines, desc="Processing train_img_name"): 89 | cls.train_img_name.extend([line.split()[1]]) 90 | 91 | txt_feat_len = cls.train_txt.shape[1] 92 | DataSetBase.txt_feat_len = txt_feat_len 93 | def __init__(self, transform=None, target_transform=None, train=True, data_path_config=None): 94 | super().__init__() 95 | WIKI.init(data_path_config) 96 | if train: 97 | self.label = self.__class__.train_label 98 | self.img_name = self.__class__.train_img_name 99 | self.txt = self.__class__.train_txt 100 | else: 101 | self.label = self.__class__.test_label 102 | self.img_name = self.__class__.test_img_name 103 | self.txt = self.__class__.test_txt 104 | self.transform = transform 105 | self.target_transform = target_transform 106 | 107 | self.f_name = ['art', 'biology', 'geography', 'history', 'literature', 'media', 'music', 'royalty', 'sport', 108 | 'warfare'] 109 | self.data_dir = data_path_config['dataDir'] 110 | 111 | def __getitem__(self, index): 112 | """ 113 | Args: 114 | index (int): Index 115 | Returns: 116 | tuple: (image, target) where target is index of the target class. 117 | """ 118 | 119 | path = self.data_dir + '/' + self.f_name[self.label[index]] + '/' + self.img_name[index] + '.jpg' 120 | img = cv2.imread(path) 121 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 122 | img = Image.fromarray(img) 123 | target = self.label[index] 124 | txt = self.txt[index] 125 | 126 | if self.transform is not None: 127 | img = self.transform(img) 128 | 129 | if self.target_transform is not None: 130 | target = self.target_transform(target) 131 | 132 | return img, txt, target, index 133 | 134 | def __len__(self): 135 | return len(self.label) 136 | 137 | 138 | class MIRFlickr(DataSetBase): 139 | mir_train_transform = transforms.Compose([ 140 | transforms.RandomHorizontalFlip(), 141 | transforms.Resize(256), 142 | transforms.RandomCrop(224), 143 | transforms.ToTensor(), 144 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 145 | ]) 146 | 147 | mir_test_transform = transforms.Compose([ 148 | transforms.Resize(256), 149 | transforms.CenterCrop(224), 150 | transforms.ToTensor(), 151 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 152 | ]) 153 | initialized = False 154 | lock = threading.Lock() 155 | indexTest = None 156 | indexDatabase = None 157 | indexTrain = None 158 | label_set = None 159 | txt_set = None 160 | @classmethod 161 | def init(cls,data_path_config): 162 | with cls.lock: 163 | if not cls.initialized: 164 | print("Initialization code here") 165 | cls.initialized = True 166 | cls.label_set = scio.loadmat(data_path_config['labelDir']) 167 | cls.label_set = np.array(cls.label_set['LAll'], dtype=np.float32) 168 | cls.txt_set = scio.loadmat(data_path_config['txtDir']) 169 | cls.txt_set = np.array(cls.txt_set['YAll'], dtype=np.float32) 170 | database_index = np.arange(cls.label_set.shape[0]) 171 | first = True 172 | for label in tqdm(range(cls.label_set.shape[1]), desc=f'初始化'): 173 | index = np.where(cls.label_set[:, label] == 1)[0] 174 | 175 | N = index.shape[0] 176 | perm = np.random.permutation(N) 177 | index = index[perm] 178 | if first: 179 | test_index = index[:160] 180 | train_index = index[160:160 + 400] 181 | first = False 182 | else: 183 | ind = np.array([i for i in list(index) if i not in (list(train_index) + list(test_index))]) 184 | test_index = np.concatenate((test_index, ind[:80])) 185 | train_index = np.concatenate((train_index, ind[80:80 + 200])) 186 | 187 | #database_index = np.array([i for i in list(range(self.label_set.shape[0])) if i not in list(test_index)]) 188 | #database_index = train_index 189 | 190 | if train_index.shape[0] < 5000: 191 | pick = np.array([i for i in list(database_index) if i not in list(train_index)]) 192 | N = pick.shape[0] 193 | perm = np.random.permutation(N) 194 | pick = pick[perm] 195 | res = 5000 - train_index.shape[0] 196 | train_index = np.concatenate((train_index, pick[:res])) 197 | 198 | cls.indexTest = test_index 199 | cls.indexDatabase = database_index 200 | cls.indexTrain = train_index 201 | 202 | DataSetBase.txt_feat_len = cls.txt_set.shape[1] 203 | 204 | def __init__(self, transform=None, target_transform=None, train=True, database=False, data_path_config=None): 205 | super().__init__() 206 | MIRFlickr.init(data_path_config) 207 | self.transform = transform 208 | self.target_transform = target_transform 209 | if train: 210 | self.train_labels = self.label_set[self.indexTrain] 211 | self.train_index = self.indexTrain 212 | self.txt = self.txt_set[self.indexTrain] 213 | elif database: 214 | self.train_labels = self.label_set[self.indexDatabase] 215 | self.train_index = self.indexDatabase 216 | self.txt = self.txt_set[self.indexDatabase] 217 | else: 218 | self.train_labels = self.label_set[self.indexTest] 219 | self.train_index = self.indexTest 220 | self.txt = self.txt_set[self.indexTest] 221 | 222 | self.img_dir = data_path_config['imgDir'] 223 | 224 | def __getitem__(self, index): 225 | mir_flickr = h5py.File(self.img_dir, 'r', libver='latest', swmr=True) 226 | img, target = mir_flickr['IAll'][self.train_index[index]], self.train_labels[index] 227 | img = Image.fromarray(np.transpose(img, (2, 1, 0))) 228 | mir_flickr.close() 229 | 230 | txt = self.txt[index] 231 | 232 | if self.transform is not None: 233 | img = self.transform(img) 234 | 235 | if self.target_transform is not None: 236 | target = self.target_transform(target) 237 | 238 | return img, txt, target, index 239 | 240 | def __len__(self): 241 | return len(self.train_labels) 242 | 243 | 244 | class NUSWIDE(DataSetBase): 245 | nus_train_transform = transforms.Compose([ 246 | transforms.RandomHorizontalFlip(), 247 | transforms.Resize(256), 248 | transforms.RandomCrop(224), 249 | transforms.ToTensor(), 250 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 251 | ]) 252 | 253 | nus_test_transform = transforms.Compose([ 254 | transforms.Resize(256), 255 | transforms.CenterCrop(224), 256 | transforms.ToTensor(), 257 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 258 | ]) 259 | initialized = False 260 | lock = threading.Lock() 261 | indexTest = None 262 | indexDatabase = None 263 | indexTrain = None 264 | label_set = None 265 | txt_set = None 266 | @classmethod 267 | def init(cls,data_path_config): 268 | with cls.lock: 269 | if not cls.initialized: 270 | print("Initialization code here") 271 | cls.initialized = True 272 | cls.label_set = scio.loadmat(data_path_config['labelDir']) 273 | cls.label_set = np.array(cls.label_set['LAll'], dtype=np.float32) 274 | txt_file = h5py.File(data_path_config['txtDir']) 275 | cls.txt_set = np.array(txt_file['YAll']).transpose() 276 | txt_file.close() 277 | 278 | first = True 279 | 280 | for label in tqdm(range(cls.label_set.shape[1]), desc=f'初始化'): 281 | index = np.where(cls.label_set[:, label] == 1)[0] 282 | 283 | N = index.shape[0] 284 | perm = np.random.permutation(N) 285 | index = index[perm] 286 | 287 | if first: 288 | test_index = index[:200] 289 | train_index = index[200:700] 290 | first = False 291 | else: 292 | ind = np.array([i for i in list(index) if i not in (list(train_index) + list(test_index))]) 293 | test_index = np.concatenate((test_index, ind[:200])) 294 | train_index = np.concatenate((train_index, ind[200:700])) 295 | 296 | database_index = np.array([i for i in list(range(cls.label_set.shape[0])) if i not in list(test_index)]) 297 | 298 | cls.indexTest = test_index 299 | cls.indexDatabase = database_index 300 | cls.indexTrain = train_index 301 | 302 | txt_feat_len = cls.txt_set.shape[1] 303 | DataSetBase.txt_feat_len = txt_feat_len 304 | def __init__(self, transform=None, target_transform=None, train=True, database=False, data_path_config=None): 305 | super().__init__() 306 | NUSWIDE.init(data_path_config) 307 | if train: 308 | self.train_labels = self.__class__.label_set[self.__class__.indexTrain] 309 | self.train_index = self.__class__.indexTrain 310 | self.txt = self.__class__.txt_set[self.__class__.indexTrain] 311 | elif database: 312 | self.train_labels = self.__class__.label_set[self.__class__.indexDatabase] 313 | self.train_index = self.__class__.indexDatabase 314 | self.txt = self.__class__.txt_set[self.__class__.indexDatabase] 315 | else: 316 | self.train_labels = self.__class__.label_set[self.__class__.indexTest] 317 | self.train_index = self.__class__.indexTest 318 | self.txt = self.__class__.txt_set[self.__class__.indexTest] 319 | 320 | self.transform = transform 321 | self.target_transform = target_transform 322 | 323 | self.img_dir = data_path_config['imgDir'] 324 | 325 | def __getitem__(self, index): 326 | nuswide = h5py.File(self.img_dir, 'r', libver='latest', swmr=True) 327 | img, target = nuswide['IAll'][self.train_index[index]], self.train_labels[index] 328 | img = Image.fromarray(np.transpose(img, (2, 1, 0))) 329 | nuswide.close() 330 | 331 | txt = self.txt[index] 332 | 333 | if self.transform is not None: 334 | img = self.transform(img) 335 | 336 | if self.target_transform is not None: 337 | target = self.target_transform(target) 338 | 339 | return img, txt, target, index 340 | 341 | def __len__(self): 342 | return len(self.train_labels) 343 | 344 | 345 | def get_dataset(dataset_name, data_path): 346 | if dataset_name == "wiki": 347 | train_dataset = WIKI(data_path_config=data_path, train=True, 348 | transform=WIKI.wiki_train_transform) 349 | test_dataset = WIKI(data_path_config=data_path, train=False, 350 | transform=WIKI.wiki_test_transform) 351 | database_dataset = WIKI(data_path_config=data_path, train=True, 352 | transform=WIKI.wiki_test_transform) 353 | 354 | if dataset_name == "mirFlickr25k": 355 | train_dataset = MIRFlickr(train=True, transform=MIRFlickr.mir_train_transform, 356 | data_path_config=data_path) 357 | test_dataset = MIRFlickr(train=False, database=False, 358 | transform=MIRFlickr.mir_test_transform, 359 | data_path_config=data_path) 360 | database_dataset = MIRFlickr(train=False, database=True, 361 | transform=MIRFlickr.mir_test_transform, 362 | data_path_config=data_path) 363 | 364 | if dataset_name == "nusWide": 365 | train_dataset = NUSWIDE(train=True, transform=NUSWIDE.nus_train_transform, 366 | data_path_config=data_path) 367 | test_dataset = NUSWIDE(train=False, database=False, 368 | transform=NUSWIDE.nus_test_transform, 369 | data_path_config=data_path) 370 | database_dataset = NUSWIDE(train=False, database=True, 371 | transform=NUSWIDE.nus_test_transform, 372 | data_path_config=data_path) 373 | return train_dataset, test_dataset, database_dataset 374 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | # @Time : 2023/11/7 3 | # @Author : SXS 4 | # @Github : https://github.com/SXS-PRIVATE/HNH 5 | import train 6 | import torch 7 | 8 | torch.backends.cudnn.benchmark = True 9 | 10 | 11 | if __name__ == '__main__': 12 | train.run() 13 | 14 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/11/7 2 | # @Author : SXS 3 | # @Github : https://github.com/SXS-PRIVATE/HNH 4 | import torch 5 | import math 6 | import torchvision 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision.models import AlexNet_Weights 10 | 11 | 12 | class ImgNet(nn.Module): 13 | def __init__(self, code_len): 14 | super(ImgNet, self).__init__() 15 | self.alexnet = torchvision.models.alexnet(weights=AlexNet_Weights.DEFAULT) 16 | self.alexnet.classifier = nn.Sequential(*list(self.alexnet.classifier.children())[:6]) 17 | self.fc_encode = nn.Linear(4096, code_len) 18 | self.alpha = 1.0 19 | 20 | 21 | def forward(self, x): 22 | x = self.alexnet.features(x) 23 | x = x.view(x.size(0), -1) 24 | feat = self.alexnet.classifier(x) 25 | hid = self.fc_encode(feat) 26 | code = F.tanh(hid) 27 | 28 | return feat, hid, code 29 | 30 | def set_alpha(self, epoch): 31 | self.alpha = math.pow((1.0 * epoch + 1.0), 0.5) 32 | 33 | 34 | class TxtNet(nn.Module): 35 | def __init__(self, code_len, txt_feat_len): 36 | super(TxtNet, self).__init__() 37 | self.fc1 = nn.Linear(txt_feat_len, 8192) 38 | self.fc2 = nn.Linear(8192, 4096) 39 | self.fc3 = nn.Linear(4096, code_len) 40 | self.alpha = 1.0 41 | 42 | def forward(self, x): 43 | x = F.relu(self.fc1(x)) 44 | feat = F.relu(self.fc2(x)) 45 | hid = self.fc3(feat) 46 | code = F.tanh(self.alpha * hid) 47 | return feat, hid, code 48 | 49 | def set_alpha(self, epoch): 50 | self.alpha = math.pow((1.0 * epoch + 1.0), 0.5) 51 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/11/7 2 | # @Author : SXS 3 | # @Github : https://github.com/SXS-PRIVATE/HNH 4 | import datetime 5 | import os 6 | import sys 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | from data.MultiEpochsDataLoader import MultiEpochsDataLoader 13 | from config.config_loader import Config 14 | from utils.metric import compress_wiki, compress, calculate_top_map 15 | from data import datasets 16 | from models.models import ImgNet, TxtNet 17 | import os.path as osp 18 | from tqdm import tqdm 19 | from utils.meters import AverageMeter 20 | from utils.plotter import get_plotter 21 | import visdom 22 | from config.logger import get_logger 23 | 24 | 25 | class HNH: 26 | def __init__(self, config: Config): 27 | self.load_config(config) 28 | 29 | torch.manual_seed(1) 30 | torch.cuda.manual_seed_all(1) 31 | torch.cuda.set_device(self.device) 32 | 33 | self.train_dataset, self.test_dataset, self.database_dataset = datasets.get_dataset(self.dataset_name, 34 | self.data_path) 35 | 36 | self.loss_store = ["common space loss", 'intra loss', 'inter loss', 'loss'] 37 | self.plotter = get_plotter(self.name) if visdom else None 38 | self.loss_store = self._loss_store_init(self.loss_store) 39 | 40 | self.CodeNet_I = ImgNet(code_len=self.bit) 41 | self.FeatNet_I = ImgNet(code_len=self.bit) 42 | self.CodeNet_T = TxtNet(code_len=self.bit, txt_feat_len=datasets.DataSetBase.txt_feat_len) 43 | 44 | self.set_train_loader() 45 | self.set_optimizer() 46 | 47 | def set_train_loader(self): 48 | # Data Loader (Input Pipeline) 49 | self.train_loader = MultiEpochsDataLoader(dataset=self.train_dataset, 50 | batch_size=self.batch_size, 51 | shuffle=True, 52 | num_workers=self.num_workers, 53 | drop_last=True, 54 | pin_memory=True) 55 | 56 | self.test_loader = MultiEpochsDataLoader(dataset=self.test_dataset, 57 | batch_size=self.batch_size, 58 | shuffle=False, 59 | num_workers=self.num_workers) 60 | 61 | self.database_loader = MultiEpochsDataLoader(dataset=self.database_dataset, 62 | batch_size=self.batch_size, 63 | shuffle=False, 64 | num_workers=self.num_workers) 65 | 66 | def set_optimizer(self): 67 | if self.dataset_name == "wiki": 68 | self.opt_I = torch.optim.SGD([{'params': self.CodeNet_I.fc_encode.parameters(), 'lr': self.lr_img}, 69 | {'params': self.CodeNet_I.alexnet.classifier.parameters(), 70 | 'lr': self.lr_img}], 71 | momentum=self.momentum, weight_decay=self.weight_decay) 72 | 73 | if self.dataset_name == "mirFlickr25k" or self.dataset_name == "nusWide": 74 | self.opt_I = torch.optim.SGD(self.CodeNet_I.parameters(), lr=self.lr_img, momentum=self.momentum, 75 | weight_decay=self.weight_decay) 76 | 77 | self.opt_T = torch.optim.SGD(self.CodeNet_T.parameters(), lr=self.lr_txt, momentum=self.momentum, 78 | weight_decay=self.weight_decay) 79 | 80 | def load_config(self, config: Config): 81 | self.logger = get_logger() 82 | self.method = config.training['method'] 83 | self.dataset_name = config.training['dataName'] 84 | self.name = f'{self.method}_{self.dataset_name}' 85 | self.model_dir = config.training['modelDir'] 86 | self.bit = int(config.training['bit']) 87 | self.batch_size = int(config.training['batchSize']) 88 | self.device = config.training['device'] 89 | self.max_epoch = config.training['numEpoch'] 90 | self.num_workers = config.training['numWorkers'] 91 | self.dataset_config = config.dataset_config 92 | self.data_path = config.dataset_config['dataPath'] 93 | self.lr_img = self.dataset_config['lrImg'] 94 | self.lr_txt = self.dataset_config['lrTxt'] 95 | self.weight_decay = self.dataset_config['weightDecay'] 96 | self.momentum = self.dataset_config['momentum'] 97 | self.eval_interval = self.dataset_config['evalInterval'] 98 | self.eval = self.dataset_config['eval'] 99 | self.gamma = self.dataset_config['gamma'] 100 | self.lambda_ = self.dataset_config['lambda'] 101 | self.beta = self.dataset_config['beta'] 102 | self.alpha = self.dataset_config['alpha'] 103 | self.k_x = self.dataset_config['kX'] 104 | self.k_y = self.dataset_config['kY'] 105 | 106 | os.environ['CUDA_VISIBLE_DEVICES'] = str(self.device) 107 | cuda = bool(config.training['cuda']) 108 | self.img_training_transform = config.img_training_transform 109 | self.img_valid_transform = config.img_valid_transform 110 | self.txt_training_transform = config.txt_training_transform 111 | self.txt_valid_transform = config.txt_valid_transform 112 | t = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 113 | # sys.stdout = Logger(os.path.join('..', 'logs', self.name, self.dataset_name, t + '.txt')) 114 | if cuda: 115 | print("using gpu device: %s" % str(self.device)) 116 | else: 117 | print("using cpu") 118 | print("training transform:") 119 | print("img:", config.img_training_transform) 120 | print("txt:", config.txt_training_transform) 121 | print("valid transform") 122 | print("img:", config.img_valid_transform) 123 | print("txt:", config.txt_valid_transform) 124 | 125 | def train(self, epoch): 126 | self.CodeNet_I.cuda().train() 127 | self.FeatNet_I.cuda().eval() 128 | self.CodeNet_T.cuda().train() 129 | 130 | self.CodeNet_I.set_alpha(epoch) 131 | self.CodeNet_T.set_alpha(epoch) 132 | 133 | for idx, (img, F_T, labels, _) in enumerate(tqdm(self.train_loader)): 134 | img = Variable(img.cuda()) 135 | # LDA topic vectors or the tag occurrence features ? 136 | F_T = torch.FloatTensor(F_T.numpy()).cuda() # batch_size * 1386 137 | labels = Variable(labels.cuda()) 138 | 139 | self.opt_I.zero_grad() 140 | self.opt_T.zero_grad() 141 | # extract F_I 142 | F_I, _, _ = self.FeatNet_I(img) # batch_size * 4096 143 | _, hid_I, code_I = self.CodeNet_I(img) 144 | _, hid_T, code_T = self.CodeNet_T(F_T) 145 | 146 | # ==========calculate S_tilde========= 147 | F_I = F.normalize(F_I) 148 | A_x = torch.matmul(F_I, F_I.t()) 149 | # t_text = torch.squeeze(text, dim=1).squeeze(dim=2) 150 | F_T = F.normalize(F_T) 151 | A_y = torch.matmul(F_T, F_T.t()) 152 | A_tilde_x = self.k_x * (A_x * (A_x.t().mm(A_x))) - 1 153 | A_tilde_y = self.k_y * (A_y * (A_y.t().mm(A_y))) - 1 154 | S_tilde = self.gamma * A_tilde_x + (1 - self.gamma) * A_tilde_y 155 | 156 | # train 157 | 158 | B_x = F.tanh(hid_I).t() 159 | B_y = F.tanh(hid_T).t() 160 | B_x = F.normalize(B_x) 161 | B_y = F.normalize(B_y) 162 | 163 | if (self.method == 'HNH2'): 164 | J3 = self.lambda_ * F.mse_loss(S_tilde, B_x.t() @ B_y) 165 | J1 = self.alpha * F.mse_loss(S_tilde, B_x.t() @ B_x) 166 | J2 = self.beta * F.mse_loss(S_tilde, B_y.t() @ B_y) 167 | else: 168 | # calculate U 169 | Ic = torch.eye(self.bit).cuda() 170 | Ic_1 = torch.eye(self.batch_size).cuda() 171 | # U = (2 * Ic + (beta / alpha) * Bx * Bx^T + (beta / alpha) * By * By^T)^(-1) 172 | b_d_a = (self.beta / self.alpha) 173 | U = torch.inverse(2 * Ic + b_d_a * B_x @ B_x.t() + b_d_a * B_y @ B_y.t()) 174 | # (Bx + By) * (Ic + (beta / alpha) * S_tilde) 175 | temp = (B_x + B_y) @ (Ic_1 + b_d_a * S_tilde) 176 | 177 | # calculate U * temp 178 | U = U @ temp 179 | 180 | # calculate loss 181 | J1 = self.alpha * (F.mse_loss(U, B_x) + F.mse_loss(U, B_y)) 182 | J2 = self.beta * (F.mse_loss(S_tilde, U.t() @ B_x) + F.mse_loss(S_tilde, U.t() @ B_y)) 183 | J3 = self.lambda_ * F.mse_loss(S_tilde, B_x.t() @ B_y) 184 | 185 | loss = J1 + J2 + J3 186 | 187 | loss.backward() 188 | self.opt_I.step() 189 | self.opt_T.step() 190 | self.loss_store['common space loss'].update(J1.item()) 191 | self.loss_store['intra loss'].update(J2.item()) 192 | self.loss_store['inter loss'].update(J3.item()) 193 | self.loss_store['loss'].update(loss.item()) 194 | self.remark_loss(J1, J2, J3, loss) 195 | 196 | self.print_loss(epoch) 197 | self.plot_loss("loss") 198 | self.reset_loss() 199 | # eval the Model 200 | if (epoch + 1) % self.eval_interval == 0: 201 | self.evaluate() 202 | #self.lr_schedule() 203 | self.plotter.next_epoch() 204 | # save the model 205 | if epoch + 1 == self.max_epoch: 206 | self.save_checkpoints(step=epoch + 1) 207 | 208 | def evaluate(self): 209 | self.logger.info('--------------------Evaluation: Calculate top MAP-------------------') 210 | # Change model to 'eval' mode (BN uses moving mean/var). 211 | self.CodeNet_I.eval().cuda() 212 | self.CodeNet_T.eval().cuda() 213 | 214 | if self.dataset_name == "wiki": 215 | re_BI, re_BT, re_L, qu_BI, qu_BT, qu_L = compress_wiki(self.database_loader, self.test_loader, 216 | self.CodeNet_I, self.CodeNet_T, 217 | self.database_dataset, self.test_dataset) 218 | 219 | if self.dataset_name == "mirFlickr25k" or self.dataset_name == "nusWide": 220 | re_BI, re_BT, re_L, qu_BI, qu_BT, qu_L = compress(self.database_loader, self.test_loader, self.CodeNet_I, 221 | self.CodeNet_T, self.database_dataset, self.test_dataset) 222 | 223 | MAP_I2T = calculate_top_map(qu_B=qu_BI, re_B=re_BT, qu_L=qu_L, re_L=re_L, topk=50) 224 | MAP_T2I = calculate_top_map(qu_B=qu_BT, re_B=re_BI, qu_L=qu_L, re_L=re_L, topk=50) 225 | if self.plotter: 226 | self.plotter.plot("mAP", 'i->t', MAP_I2T.item()) 227 | self.plotter.plot("mAP", "t->i", MAP_T2I.item()) 228 | self.logger.info('MAP of Image to Text: %.3f, MAP of Text to Image: %.3f' % (MAP_I2T, MAP_T2I)) 229 | self.logger.info('--------------------------------------------------------------------') 230 | 231 | def save_checkpoints(self, step, file_name='latest.pth'): 232 | ckp_path = osp.join(self.model_dir, file_name) 233 | obj = { 234 | 'ImgNet': self.CodeNet_I.state_dict(), 235 | 'TxtNet': self.CodeNet_T.state_dict(), 236 | 'step': step, 237 | } 238 | torch.save(obj, ckp_path) 239 | self.logger.info('**********Save the trained model successfully.**********') 240 | 241 | def load_checkpoints(self, file_name='latest.pth'): 242 | ckp_path = osp.join(self.model_dir, file_name) 243 | try: 244 | obj = torch.load(ckp_path, map_location=lambda storage, loc: storage.cuda()) 245 | self.logger.info('**************** Load checkpoint %s ****************' % ckp_path) 246 | except IOError: 247 | self.logger.error('********** No checkpoint %s!*********' % ckp_path) 248 | return 249 | self.CodeNet_I.load_state_dict(obj['ImgNet']) 250 | self.CodeNet_T.load_state_dict(obj['TxtNet']) 251 | self.logger.info('********** The loaded model has been trained for %d epochs.*********' % obj['step']) 252 | 253 | @staticmethod 254 | def _loss_store_init(loss_store): 255 | """ 256 | initialize loss store, transform list to dict by (loss name -> loss register) 257 | :param loss_store: the list with name of loss 258 | :return: the dict of loss store 259 | """ 260 | dict_store = {} 261 | for loss_name in loss_store: 262 | dict_store[loss_name] = AverageMeter() 263 | loss_store = dict_store 264 | return loss_store 265 | 266 | def plot_loss(self, title, loss_store=None): 267 | """ 268 | plot loss in loss_store at a figure 269 | :param title: the title of figure name 270 | :param loss_store: the loss store to plot, if none, the default loss store will plot 271 | """ 272 | if loss_store is None: 273 | loss_store = self.loss_store 274 | if self.plotter: 275 | for name, loss in loss_store.items(): 276 | self.plotter.plot(title, name, loss.avg) 277 | 278 | def print_loss(self, epoch, loss_store=None): 279 | loss_str = "epoch: [%3d/%3d], " % (epoch + 1, self.max_epoch) 280 | if loss_store is None: 281 | loss_store = self.loss_store 282 | for name, value in loss_store.items(): 283 | loss_str += name + " {:4.3f}".format(value.avg) + "\t" 284 | print(loss_str) 285 | sys.stdout.flush() 286 | 287 | def reset_loss(self, loss_store=None): 288 | if loss_store is None: 289 | loss_store = self.loss_store 290 | for store in loss_store.values(): 291 | store.reset() 292 | 293 | def remark_loss(self, *args, n=1): 294 | """ 295 | store loss into loss store by order 296 | :param args: loss to store 297 | :return: 298 | """ 299 | for i, loss_name in enumerate(self.loss_store.keys()): 300 | if isinstance(args[i], torch.Tensor): 301 | self.loss_store[loss_name].update(args[i].item(), n) 302 | else: 303 | self.loss_store[loss_name].update(args[i], n) 304 | 305 | 306 | def run(config_path='default_config.yml', **kwargs): 307 | config = Config(config_path) 308 | hnh = HNH(config) 309 | if hnh.eval: 310 | hnh.load_checkpoints() 311 | hnh.eval() 312 | 313 | else: 314 | for epoch in range(hnh.max_epoch): 315 | # train the Model 316 | hnh.train(epoch) 317 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/11/7 2 | # @Author : SXS 3 | # @Github : https://github.com/SXS-PRIVATE/HNH 4 | -------------------------------------------------------------------------------- /transforms/bow.py: -------------------------------------------------------------------------------- 1 | 2 | # @Time : 2023/11/7 3 | # @Author : SXS 4 | # @Github : https://github.com/SXS-PRIVATE/HNH 5 | import torch 6 | import numpy as np 7 | 8 | 9 | class RandomErasure(object): 10 | """ 11 | Randomly erasure input bow vector. 12 | change one to a assign value in a bow vector 13 | """ 14 | def __init__(self, prob=0.2, value=0): 15 | if prob < 0 or prob > 1: 16 | raise ValueError("probability only can be a float in 0 to 1") 17 | self.prob = prob 18 | self.value = value 19 | 20 | def __call__(self, vector: torch.Tensor): 21 | vector_length = vector.shape[-1] 22 | index = np.random.permutation(vector_length) 23 | change_num = int(vector_length * self.prob) 24 | vector[index[: change_num]] = self.value 25 | return vector 26 | 27 | def __repr__(self): 28 | return self.__class__.__name__ + '(prob={0}, value={1})'.format(self.prob, self.value) 29 | 30 | -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | import torchvision.datasets as dsets 6 | from torchvision import transforms 7 | from torch.autograd import Variable 8 | import torchvision 9 | import math 10 | import numpy as np 11 | from tqdm import tqdm 12 | def compress_wiki(train_loader, test_loader, modeli, modelt, train_dataset, test_dataset, classes=10): 13 | re_BI = list([]) 14 | re_BT = list([]) 15 | re_L = list([]) 16 | for _, (data_I, data_T, target, _) in enumerate(train_loader): 17 | var_data_I = Variable(data_I.cuda()) 18 | _,_,code_I = modeli(var_data_I) 19 | code_I = torch.sign(code_I) 20 | re_BI.extend(code_I.cpu().data.numpy()) 21 | re_L.extend(target) 22 | 23 | var_data_T = Variable(torch.FloatTensor(data_T.numpy()).cuda()) 24 | _,_,code_T = modelt(var_data_T) 25 | code_T = torch.sign(code_T) 26 | re_BT.extend(code_T.cpu().data.numpy()) 27 | 28 | qu_BI = list([]) 29 | qu_BT = list([]) 30 | qu_L = list([]) 31 | for _, (data_I, data_T, target, _) in enumerate(test_loader): 32 | var_data_I = Variable(data_I.cuda()) 33 | _,_,code_I = modeli(var_data_I) 34 | code_I = torch.sign(code_I) 35 | qu_BI.extend(code_I.cpu().data.numpy()) 36 | qu_L.extend(target) 37 | 38 | var_data_T = Variable(torch.FloatTensor(data_T.numpy()).cuda()) 39 | _,_,code_T = modelt(var_data_T) 40 | code_T = torch.sign(code_T) 41 | qu_BT.extend(code_T.cpu().data.numpy()) 42 | 43 | re_BI = np.array(re_BI) 44 | re_BT = np.array(re_BT) 45 | re_L = np.eye(classes)[np.array(re_L)] 46 | 47 | qu_BI = np.array(qu_BI) 48 | qu_BT = np.array(qu_BT) 49 | qu_L = np.eye(classes)[np.array(qu_L)] 50 | 51 | return re_BI, re_BT, re_L, qu_BI, qu_BT, qu_L 52 | 53 | 54 | def compress(train_loader, test_loader, model_I, model_T, train_dataset, test_dataset): 55 | re_BI = list([]) 56 | re_BT = list([]) 57 | re_L = list([]) 58 | for _, (data_I, data_T, _, _) in enumerate(tqdm(train_loader,desc='compress for train')): 59 | var_data_I = Variable(data_I.cuda()) 60 | _,_,code_I = model_I(var_data_I) 61 | code_I = torch.sign(code_I) 62 | re_BI.extend(code_I.cpu().data.numpy()) 63 | 64 | var_data_T = Variable(torch.FloatTensor(data_T.numpy()).cuda()) 65 | _,_,code_T = model_T(var_data_T) 66 | code_T = torch.sign(code_T) 67 | re_BT.extend(code_T.cpu().data.numpy()) 68 | 69 | qu_BI = list([]) 70 | qu_BT = list([]) 71 | qu_L = list([]) 72 | for _, (data_I, data_T, _, _) in enumerate(tqdm(test_loader,desc='compress for test')): 73 | var_data_I = Variable(data_I.cuda()) 74 | _,_,code_I = model_I(var_data_I) 75 | code_I = torch.sign(code_I) 76 | qu_BI.extend(code_I.cpu().data.numpy()) 77 | 78 | var_data_T = Variable(torch.FloatTensor(data_T.numpy()).cuda()) 79 | _,_,code_T = model_T(var_data_T) 80 | code_T = torch.sign(code_T) 81 | qu_BT.extend(code_T.cpu().data.numpy()) 82 | 83 | re_BI = np.array(re_BI) 84 | re_BT = np.array(re_BT) 85 | re_L = train_dataset.train_labels 86 | 87 | qu_BI = np.array(qu_BI) 88 | qu_BT = np.array(qu_BT) 89 | qu_L = test_dataset.train_labels 90 | return re_BI, re_BT, re_L, qu_BI, qu_BT, qu_L 91 | 92 | def calculate_hamming(B1, B2): 93 | """ 94 | :param B1: vector [n] 95 | :param B2: vector [r*n] 96 | :return: hamming distance [r] 97 | """ 98 | leng = B2.shape[1] # max inner product value 99 | distH = 0.5 * (leng - np.dot(B1, B2.transpose())) 100 | return distH 101 | 102 | 103 | def calculate_map(qu_B, re_B, qu_L, re_L): 104 | """ 105 | :param qu_B: {-1,+1}^{mxq} query bits 106 | :param re_B: {-1,+1}^{nxq} retrieval bits 107 | :param qu_L: {0,1}^{mxl} query label 108 | :param re_L: {0,1}^{nxl} retrieval label 109 | :return: 110 | """ 111 | num_query = qu_L.shape[0] 112 | map = 0 113 | for iter in range(num_query): 114 | gnd = (np.dot(qu_L[iter, :], re_L.transpose()) > 0).astype(np.float32) 115 | tsum = np.sum(gnd) 116 | if tsum == 0: 117 | continue 118 | hamm = calculate_hamming(qu_B[iter, :], re_B) 119 | ind = np.argsort(hamm) 120 | gnd = gnd[ind] 121 | 122 | count = np.linspace(1, tsum, tsum) # [1,2, tsum] 123 | tindex = np.asarray(np.where(gnd == 1)) + 1.0 124 | map_ = np.mean(count / (tindex)) 125 | map = map + map_ 126 | map = map / num_query 127 | return map 128 | 129 | 130 | def calculate_top_map(qu_B, re_B, qu_L, re_L, topk): 131 | """ 132 | :param qu_B: {-1,+1}^{mxq} query bits 133 | :param re_B: {-1,+1}^{nxq} retrieval bits 134 | :param qu_L: {0,1}^{mxl} query label 135 | :param re_L: {0,1}^{nxl} retrieval label 136 | :param topk: 137 | :return: 138 | """ 139 | num_query = qu_L.shape[0] 140 | topkmap = 0 141 | for iter in tqdm(range(num_query),desc='calculate_top_map'): 142 | gnd = (np.dot(qu_L[iter, :], re_L.transpose()) > 0).astype(np.float32) 143 | hamm = calculate_hamming(qu_B[iter, :], re_B) 144 | ind = np.argsort(hamm) 145 | gnd = gnd[ind] 146 | 147 | tgnd = gnd[0:topk] 148 | tsum = np.sum(tgnd) 149 | if tsum == 0: 150 | continue 151 | count = np.linspace(1, tsum, int(tsum)) 152 | tindex = np.asarray(np.where(tgnd == 1)) + 1.0 153 | topkmap_ = np.mean(count / (tindex)) 154 | topkmap = topkmap + topkmap_ 155 | topkmap = topkmap / num_query 156 | return topkmap 157 | -------------------------------------------------------------------------------- /utils/plotter.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/11/7 2 | # @Author : SXS 3 | # @Github : https://github.com/SXS-PRIVATE/HNH 4 | import numpy as np 5 | from visdom import Visdom 6 | 7 | 8 | class VisdomLinePlotter(object): 9 | """Plots to Visdom""" 10 | def __init__(self, env_name='plotter', port=8097): 11 | self.viz = Visdom(port=port) 12 | self.env = env_name 13 | self.plots = {} 14 | self.epoch = 0 15 | 16 | def plot(self, var_name, split_name, y, x=None): 17 | if x is None: 18 | x = self.epoch 19 | if var_name not in self.plots: 20 | self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, opts=dict( 21 | legend=[split_name], 22 | title=var_name, 23 | xlabel='Epochs', 24 | ylabel=var_name 25 | )) 26 | else: 27 | self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update='append') 28 | 29 | def next_epoch(self): 30 | self.epoch += 1 31 | 32 | def reset_epoch(self): 33 | self.epoch = 0 34 | 35 | 36 | def get_plotter(env_name: str): 37 | return VisdomLinePlotter(env_name) 38 | 39 | --------------------------------------------------------------------------------