├── .gitignore
├── README.md
├── checkpoints
│   └── .gitkeep
├── data
│   ├── __init__.py
│   ├── cifar10.py
│   ├── data_loader.py
│   ├── imagenet.py
│   ├── nus_wide.py
│   └── transform.py
├── dpsh.py
├── logs
│   └── .gitkeep
├── models
│   ├── alexnet.py
│   ├── dpsh_loss.py
│   ├── model_loader.py
│   └── vgg16.py
├── requirements.txt
├── run.py
└── utils
    └── evaluate.py

/.gitignore:
--------------------------------------------------------------------------------
# Checkpoints
*.pt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Feature Learning based Deep Supervised Hashing with Pairwise Labels

## REQUIREMENTS
1. pytorch
2. torchvision
3. loguru

`pip install -r requirements.txt`

## DATASETS
1. [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)
2. [NUS-WIDE](https://pan.baidu.com/s/1f9mKXE2T8XpIq8p7y8Fa6Q) Password: uhr3
3. [Imagenet100](https://pan.baidu.com/s/1Vihhd2hJ4q0FOiltPA-8_Q) Password: ynwf

## USAGE
```
usage: run.py [-h] [--dataset DATASET] [--root ROOT] [--num-query NUM_QUERY]
              [--arch ARCH] [--num-train NUM_TRAIN]
              [--code-length CODE_LENGTH] [--topk TOPK] [--gpu GPU] [--lr LR]
              [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
              [--num-workers NUM_WORKERS]
              [--evaluate-interval EVALUATE_INTERVAL] [--eta ETA]

DPSH_PyTorch

optional arguments:
  -h, --help            show this help message and exit
  --dataset DATASET     Dataset name.
  --root ROOT           Path of dataset.
  --num-query NUM_QUERY
                        Number of query data points. (default: 1000)
  --arch ARCH           CNN model name. (default: alexnet)
  --num-train NUM_TRAIN
                        Number of training data points. (default: 5000)
  --code-length CODE_LENGTH
                        Binary hash code length. (default: 12,24,32,48)
  --topk TOPK           Calculate mAP of top k. (default: -1, i.e. all)
  --gpu GPU             GPU id to use. (default: None, run on CPU)
  --lr LR               Learning rate. (default: 1e-5)
  --batch-size BATCH_SIZE
                        Batch size. (default: 128)
  --max-iter MAX_ITER   Number of iterations. (default: 150)
  --num-workers NUM_WORKERS
                        Number of data loading workers. (default: 6)
  --evaluate-interval EVALUATE_INTERVAL
                        Evaluation interval. (default: 10)
  --eta ETA             Hyper-parameter. (default: 0.1)
```

## EXPERIMENTS
CNN model: AlexNet. Metric: mean average precision (mAP).

cifar10: 1000 query images, 5000 training images.

nus-wide-tc21: 21 classes, 2100 query images, 10500 training images.

imagenet100: 100 classes, 5000 query images, 10000 training images.

bits | 12 | 16 | 24 | 32 | 48 | 64 | 128
:-: | :-: | :-: | :-: | :-: | :-: | :-: | :-:
cifar10@ALL | 0.6676 | 0.7131 | 0.7118 | 0.7362 | 0.7487 | 0.7542 | 0.7565
nus-wide-tc21@5000 | 0.8091 | 0.8188 | 0.8346 | 0.8403 | 0.8450 | 0.8503 | 0.8588
imagenet100@1000 | 0.1985 | 0.2497 | 0.3654 | 0.4147 | 0.4612 | 0.4950 | 0.5687
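
For example, to reproduce the first two rows of the table (the dataset roots below are illustrative; pass whatever `--root` matches your extraction path):

```
python run.py --dataset cifar-10 --root datasets/cifar-10 --code-length 12,24,32,48 --gpu 0
python run.py --dataset nus-wide-tc21 --root datasets/nus-wide --num-query 2100 --num-train 10500 --topk 5000 --gpu 0
```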
--------------------------------------------------------------------------------
/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/DPSH_PyTorch/bb6ac2e8b165276f736040d44596a271d4b5beef/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/DPSH_PyTorch/bb6ac2e8b165276f736040d44596a271d4b5beef/data/__init__.py
--------------------------------------------------------------------------------
/data/cifar10.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from PIL import Image
import os
import sys
import pickle

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset

from data.transform import train_transform, query_transform, Onehot, encode_onehot


def load_data(root, num_query, num_train, batch_size, num_workers):
    """
    Load cifar10 dataset.

    Args
        root(str): Path of dataset.
        num_query(int): Number of query data points.
        num_train(int): Number of training data points.
        batch_size(int): Batch size.
        num_workers(int): Number of loading data threads.

    Returns
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.
    """
    CIFAR10.init(root, num_query, num_train)
    train_dataset = CIFAR10('train', transform=train_transform(), target_transform=Onehot(10))
    query_dataset = CIFAR10('query', transform=query_transform(), target_transform=Onehot(10))
    retrieval_dataset = CIFAR10('database', transform=query_transform(), target_transform=Onehot(10))

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    query_dataloader = DataLoader(
        query_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    retrieval_dataloader = DataLoader(
        retrieval_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )

    return train_dataloader, query_dataloader, retrieval_dataloader


class CIFAR10(Dataset):
    """
    Cifar10 dataset.
    """
    @staticmethod
    def init(root, num_query, num_train):
        data_list = ['data_batch_1',
                     'data_batch_2',
                     'data_batch_3',
                     'data_batch_4',
                     'data_batch_5',
                     'test_batch',
                     ]
        base_folder = 'cifar-10-batches-py'

        data = []
        targets = []

        for file_name in data_list:
            file_path = os.path.join(root, base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                data.append(entry['data'])
                if 'labels' in entry:
                    targets.extend(entry['labels'])
                else:
                    targets.extend(entry['fine_labels'])

        data = np.vstack(data).reshape(-1, 3, 32, 32)
        data = data.transpose((0, 2, 3, 1))  # convert to HWC
        targets = np.array(targets)

        # Sort by class
        sort_index = targets.argsort()
        data = data[sort_index, :]
        targets = targets[sort_index]

        # (num_query / number of classes) query images per class
        # (num_train / number of classes) train images per class
        query_per_class = num_query // 10
        train_per_class = num_train // 10

        # Permute indices (range 0 - 6000 per class)
        perm_index = np.random.permutation(data.shape[0] // 10)
        query_index = perm_index[:query_per_class]
        train_index = perm_index[query_per_class: query_per_class + train_per_class]

        query_index = np.tile(query_index, 10)
        train_index = np.tile(train_index, 10)
        inc_index = np.array([i * (data.shape[0] // 10) for i in range(10)])
        query_index = query_index + inc_index.repeat(query_per_class)
        train_index = train_index + inc_index.repeat(train_per_class)
        list_query_index = [i for i in query_index]
        retrieval_index = np.array(list(set(range(data.shape[0])) - set(list_query_index)), dtype=np.int64)

        # Split data, targets
        CIFAR10.QUERY_IMG = data[query_index, :]
        CIFAR10.QUERY_TARGET = targets[query_index]
        CIFAR10.TRAIN_IMG = data[train_index, :]
        CIFAR10.TRAIN_TARGET = targets[train_index]
        CIFAR10.RETRIEVAL_IMG = data[retrieval_index, :]
        CIFAR10.RETRIEVAL_TARGET = targets[retrieval_index]

    def __init__(self, mode='train',
                 transform=None, target_transform=None,
                 ):
        self.transform = transform
        self.target_transform = target_transform

        if mode == 'train':
            self.data = CIFAR10.TRAIN_IMG
            self.targets = CIFAR10.TRAIN_TARGET
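
A note on `CIFAR10.init` above: because the data are first sorted by class, one within-class permutation can be tiled across all ten class blocks and shifted by per-class offsets, which keeps the query and training sets exactly class-balanced. A toy-sized replay of that index arithmetic (3 classes of 4 images, 2 queries per class; numbers are illustrative):

```
import numpy as np

# Toy replay of CIFAR10.init's split arithmetic, with the data assumed
# to be sorted by class, as in the real code.
num_classes, per_class, query_per_class = 3, 4, 2

perm = np.random.permutation(per_class)               # within-class picks, e.g. [2, 0, 3, 1]
query = np.tile(perm[:query_per_class], num_classes)  # the same picks are reused for every class
offsets = np.arange(num_classes) * per_class          # start of each class block: [0, 4, 8]
query = query + offsets.repeat(query_per_class)       # absolute indices into the sorted array

print(query)  # e.g. [2 0 6 4 10 8]: two images drawn from each class block
```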
            self.targets = CIFAR10.TRAIN_TARGET
        elif mode == 'query':
            self.data = CIFAR10.QUERY_IMG
            self.targets = CIFAR10.QUERY_TARGET
        else:
            self.data = CIFAR10.RETRIEVAL_IMG
            self.targets = CIFAR10.RETRIEVAL_TARGET

        self.onehot_targets = encode_onehot(self.targets, 10)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index

    def __len__(self):
        return len(self.data)

    def get_onehot_targets(self):
        """
        Return one-hot encoding targets.
        """
        return torch.from_numpy(self.onehot_targets).float()
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
import data.cifar10 as cifar10
import data.nus_wide as nuswide
import data.imagenet as imagenet

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def load_data(dataset, root, num_query, num_train, batch_size, num_workers):
    """
    Load dataset.

    Args
        dataset(str): Dataset name.
        root(str): Path of dataset.
        num_query(int): Number of query data points.
        num_train(int): Number of training data points.
        batch_size(int): Batch size.
        num_workers(int): Number of loading data threads.

    Returns
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.
    """
    if dataset == 'cifar-10':
        train_dataloader, query_dataloader, retrieval_dataloader = cifar10.load_data(root,
                                                                                     num_query,
                                                                                     num_train,
                                                                                     batch_size,
                                                                                     num_workers,
                                                                                     )
    elif dataset == 'nus-wide-tc21':
        train_dataloader, query_dataloader, retrieval_dataloader = nuswide.load_data(21,
                                                                                     root,
                                                                                     num_query,
                                                                                     num_train,
                                                                                     batch_size,
                                                                                     num_workers,
                                                                                     )
    elif dataset == 'imagenet':
        train_dataloader, query_dataloader, retrieval_dataloader = imagenet.load_data(root,
                                                                                      batch_size,
                                                                                      num_workers,
                                                                                      )
    else:
        raise ValueError("Invalid dataset name!")

    return train_dataloader, query_dataloader, retrieval_dataloader
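
A minimal sketch of calling this loader directly (the root path is an assumption; use wherever the dataset actually lives):

```
from data.data_loader import load_data

train_loader, query_loader, retrieval_loader = load_data(
    dataset='cifar-10',
    root='datasets/cifar-10',  # assumed extraction path
    num_query=1000,
    num_train=5000,
    batch_size=128,
    num_workers=6,
)

# Every dataset in this repo yields (image, target, index) triples; the index
# lets dpsh.train scatter each batch's outputs into the full code matrix U.
images, targets, indices = next(iter(train_loader))
```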
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
import torch
import torchvision.transforms as transforms

import os

from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from PIL import Image
from data.transform import encode_onehot, Onehot


def load_data(root, batch_size, workers):
    """
    Load imagenet dataset.

    Args:
        root (str): Path of imagenet dataset.
        batch_size (int): Number of samples in one batch.
        workers (int): Number of data loading threads.

    Returns:
        train_dataloader (torch.utils.data.dataloader.DataLoader): Training dataset loader.
        query_dataloader (torch.utils.data.dataloader.DataLoader): Query dataset loader.
        retrieval_dataloader (torch.utils.data.dataloader.DataLoader): Retrieval dataset loader.
    """
    # Data transform
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Construct data loader
    train_dir = os.path.join(root, 'train')
    query_dir = os.path.join(root, 'query')
    retrieval_dir = os.path.join(root, 'database')

    train_dataset = ImagenetDataset(
        train_dir,
        transform=train_transform,
        target_transform=Onehot(100),
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )

    query_dataset = ImagenetDataset(
        query_dir,
        transform=test_transform,
        target_transform=Onehot(100),
    )

    query_dataloader = DataLoader(
        query_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )

    retrieval_dataset = ImagenetDataset(
        retrieval_dir,
        transform=test_transform,
        target_transform=Onehot(100),
    )

    retrieval_dataloader = DataLoader(
        retrieval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )

    return train_dataloader, query_dataloader, retrieval_dataloader


class ImagenetDataset(Dataset):
    classes = None
    class_to_idx = None

    def __init__(self, root, transform=None, target_transform=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.imgs = []
        self.targets = []

        # Assume file alphabet order is the class order
        if ImagenetDataset.class_to_idx is None:
            ImagenetDataset.classes, ImagenetDataset.class_to_idx = self._find_classes(root)

        for cl in ImagenetDataset.classes:
            cur_class = os.path.join(self.root, cl)
            files = os.listdir(cur_class)
            files = [os.path.join(cur_class, f) for f in files]
            self.imgs.extend(files)
            self.targets.extend([ImagenetDataset.class_to_idx[cl]] * len(files))
        self.targets = torch.tensor(self.targets)
        self.onehot_targets = torch.from_numpy(encode_onehot(self.targets, 100)).float()

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        img, target = self.imgs[item], self.targets[item]

        img = Image.open(img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, item

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def get_onehot_targets(self):
        """
        Return one-hot encoding targets.
        """
        return self.onehot_targets
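
From `load_data` above, `--root` must contain `train/`, `query/`, and `database/` directories, each holding one folder per class; folder names (placeholders below) are sorted and mapped to label indices alphabetically, and all three splits need the same 100 class folders because `class_to_idx` is built once from the first split loaded:

```
root/
├── train/
│   ├── class_a/
│   │   ├── 0001.jpg
│   │   └── ...
│   └── class_b/
├── query/
│   └── (same per-class layout)
└── database/
    └── (same per-class layout)
```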
--------------------------------------------------------------------------------
/data/nus_wide.py:
--------------------------------------------------------------------------------
import torch
import os
import numpy as np

from PIL import Image, ImageFile
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

from data.transform import train_transform, query_transform

ImageFile.LOAD_TRUNCATED_IMAGES = True


def load_data(tc, root, num_query, num_train, batch_size, num_workers,
              ):
    """
    Load nus-wide dataset.

    Args:
        tc(int): Top class.
        root(str): Path of image files.
        num_query(int): Number of query data.
        num_train(int): Number of training data.
        batch_size(int): Batch size.
        num_workers(int): Number of loading data threads.

    Returns
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.
    """
    if tc == 21:
        train_dataset = NusWideDatasetTC21(
            root,
            'database_img.txt',
            'database_label_onehot.txt',
            transform=train_transform(),
            train=True,
            num_train=num_train,
        )

        query_dataset = NusWideDatasetTC21(
            root,
            'test_img.txt',
            'test_label_onehot.txt',
            transform=query_transform(),
        )

        retrieval_dataset = NusWideDatasetTC21(
            root,
            'database_img.txt',
            'database_label_onehot.txt',
            transform=query_transform(),
        )
    elif tc == 10:
        NusWideDatasetTc10.init(root, num_query, num_train)
        train_dataset = NusWideDatasetTc10(root, 'train', train_transform())
        query_dataset = NusWideDatasetTc10(root, 'query', query_transform())
        retrieval_dataset = NusWideDatasetTc10(root, 'retrieval', query_transform())
    else:
        raise ValueError('Invalid top class: {}'.format(tc))

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    query_dataloader = DataLoader(
        query_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    retrieval_dataloader = DataLoader(
        retrieval_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )

    return train_dataloader, query_dataloader, retrieval_dataloader


class NusWideDatasetTc10(Dataset):
    """
    Nus-wide dataset, 10 classes.

    Args
        root(str): Path of dataset.
        mode(str, 'train', 'query', 'retrieval'): Mode of dataset.
        transform(callable, optional): Transform images.
    """
    def __init__(self, root, mode, transform=None):
        self.root = root
        self.transform = transform

        if mode == 'train':
            self.data = NusWideDatasetTc10.TRAIN_DATA
            self.targets = NusWideDatasetTc10.TRAIN_TARGETS
        elif mode == 'query':
            self.data = NusWideDatasetTc10.QUERY_DATA
            self.targets = NusWideDatasetTc10.QUERY_TARGETS
        elif mode == 'retrieval':
            self.data = NusWideDatasetTc10.RETRIEVAL_DATA
            self.targets = NusWideDatasetTc10.RETRIEVAL_TARGETS
        else:
            raise ValueError("Invalid mode; can't load dataset!")

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.root, self.data[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[index], index

    def __len__(self):
        return self.data.shape[0]

    def get_onehot_targets(self):
        # Targets are loaded from an already one-hot encoded file.
        return torch.from_numpy(self.targets).float()

    @staticmethod
    def init(root, num_query, num_train):
        """
        Initialize dataset.

        Args
            root(str): Path of image files.
            num_query(int): Number of query data.
            num_train(int): Number of training data.
        """
        # Load dataset
        img_txt_path = os.path.join(root, 'img_tc10.txt')
        targets_txt_path = os.path.join(root, 'targets_onehot_tc10.txt')

        # Read files
        with open(img_txt_path, 'r') as f:
            data = np.array([i.strip() for i in f])
        targets = np.loadtxt(targets_txt_path, dtype=np.int64)

        # Split dataset
        perm_index = np.random.permutation(data.shape[0])
        query_index = perm_index[:num_query]
        train_index = perm_index[num_query: num_query + num_train]
        retrieval_index = perm_index[num_query:]

        NusWideDatasetTc10.QUERY_DATA = data[query_index]
        NusWideDatasetTc10.QUERY_TARGETS = targets[query_index, :]

        NusWideDatasetTc10.TRAIN_DATA = data[train_index]
        NusWideDatasetTc10.TRAIN_TARGETS = targets[train_index, :]

        NusWideDatasetTc10.RETRIEVAL_DATA = data[retrieval_index]
        NusWideDatasetTc10.RETRIEVAL_TARGETS = targets[retrieval_index, :]


class NusWideDatasetTC21(Dataset):
    """
    Nus-wide dataset, 21 classes.

    Args
        root(str): Path of image files.
        img_txt(str): Path of txt file containing image file name.
        label_txt(str): Path of txt file containing image label.
        transform(callable, optional): Transform images.
        train(bool, optional): Return training dataset.
        num_train(int, optional): Number of training data.
    """
    def __init__(self, root, img_txt, label_txt, transform=None, train=None, num_train=None):
        self.root = root
        self.transform = transform

        img_txt_path = os.path.join(root, img_txt)
        label_txt_path = os.path.join(root, label_txt)

        # Read files
        with open(img_txt_path, 'r') as f:
            self.data = np.array([i.strip() for i in f])
        self.targets = np.loadtxt(label_txt_path, dtype=np.float32)

        # Sample training dataset
        if train is True:
            perm_index = np.random.permutation(len(self.data))[:num_train]
            self.data = self.data[perm_index]
            self.targets = self.targets[perm_index]

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.root, self.data[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        return img, self.targets[index], index

    def __len__(self):
        return len(self.data)

    def get_onehot_targets(self):
        return torch.from_numpy(self.targets).float()
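
The TC21 loaders only assume two plain-text files per split: an image list readable line by line, and a label matrix readable by `np.loadtxt`. A quick sketch of that contract (file names as used above; contents illustrative):

```
import numpy as np

# database_img.txt: one image path per line, relative to --root.
with open('database_img.txt') as f:
    paths = [line.strip() for line in f]

# database_label_onehot.txt: whitespace-separated 0/1 matrix,
# one row per image and 21 columns (one per class).
labels = np.loadtxt('database_label_onehot.txt', dtype=np.float32)
assert labels.shape == (len(paths), 21)
```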
--------------------------------------------------------------------------------
/data/transform.py:
--------------------------------------------------------------------------------
import torch
import torchvision.transforms as transforms
import numpy as np


def encode_onehot(labels, num_classes=10):
    """
    One-hot encode labels.

    Args:
        labels (numpy.ndarray): labels.
        num_classes (int): Number of classes.

    Returns:
        onehot_labels (numpy.ndarray): one-hot labels.
    """
    onehot_labels = np.zeros((len(labels), num_classes))

    for i in range(len(labels)):
        onehot_labels[i, labels[i]] = 1

    return onehot_labels


class Onehot(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes

    def __call__(self, sample):
        target_onehot = torch.zeros(self.num_classes)
        target_onehot[sample] = 1

        return target_onehot


def train_transform():
    """
    Training images transform.

    Args
        None

    Returns
        transform(torchvision.transforms): transform
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])


def query_transform():
    """
    Query images transform.

    Args
        None

    Returns
        transform(torchvision.transforms): transform
    """
    # Data transform
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
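
A quick sanity check of the two one-hot helpers on toy labels:

```
import numpy as np
from data.transform import encode_onehot, Onehot

print(encode_onehot(np.array([0, 2, 1]), num_classes=3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]

print(Onehot(num_classes=3)(2))
# tensor([0., 0., 1.])
```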
--------------------------------------------------------------------------------
/dpsh.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
import time

from torch.optim.lr_scheduler import CosineAnnealingLR
from models.model_loader import load_model
from loguru import logger
from models.dpsh_loss import DPSHLoss
from utils.evaluate import mean_average_precision


def train(
        train_dataloader,
        query_dataloader,
        retrieval_dataloader,
        arch,
        code_length,
        device,
        eta,
        lr,
        max_iter,
        topk,
        evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        eta(float): Hyper-parameter.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        topk(int): Calculate map of top k.
        evaluate_interval(int): Evaluation interval.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Create model, optimizer, criterion, scheduler
    model = load_model(arch, code_length).to(device)
    criterion = DPSHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # Initialization
    N = len(train_dataloader.dataset)
    U = torch.zeros(N, code_length).to(device)
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)

    # Training
    best_map = 0.0
    iter_time = time.time()
    for it in range(max_iter):
        model.train()
        running_loss = 0.
        for data, targets, index in train_dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()

            S = (targets @ train_targets.t() > 0).float()
            U_cnn = model(data)
            U[index, :] = U_cnn.data
            loss = criterion(U_cnn, U, S)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time

            # Generate hash code and one-hot targets
            query_code = generate_code(model, query_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Save checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss,
                mAP,
                iter_time,
            ))
            iter_time = time.time()

    return checkpoint


def generate_code(model, dataloader, code_length, device):
    """
    Generate hash code.

    Args
        dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hash code length.
        device(torch.device): Using gpu or cpu.

    Returns
        code(torch.Tensor, n*code_length): Hash code.
    """
    model.eval()
    with torch.no_grad():
        N = len(dataloader.dataset)
        code = torch.zeros([N, code_length])
        for data, _, index in dataloader:
            data = data.to(device)
            hash_code = model(data)
            code[index, :] = hash_code.sign().cpu()

    model.train()
    return code
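
A sketch of reloading the checkpoint dict returned by `train` (the file name below is illustrative; `run.py` builds the real one from the run's hyper-parameters):

```
import torch

from models.model_loader import load_model

ckpt_path = 'checkpoints/example.pt'  # illustrative name
checkpoint = torch.load(ckpt_path, map_location='cpu')

model = load_model('alexnet', 12)  # must match the arch and code length of the run
model.load_state_dict(checkpoint['model'])
model.eval()

print(checkpoint['map'])       # best mAP reached during training
print(checkpoint['qB'].shape)  # query hash codes: (num_query, code_length)
```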
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tree-Shu-Zhao/DPSH_PyTorch/bb6ac2e8b165276f736040d44596a271d4b5beef/logs/.gitkeep
--------------------------------------------------------------------------------
/models/alexnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from torch.hub import load_state_dict_from_url


def load_model(code_length):
    """
    Load CNN model.

    Args
        code_length (int): Hash code length.

    Returns
        model (torch.nn.Module): CNN model.
    """
    model = AlexNet(code_length)
    state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
    # strict=False: the pretrained checkpoint has no weights for hash_layer,
    # and keys for the removed classification head are ignored.
    model.load_state_dict(state_dict, strict=False)

    return model


class AlexNet(nn.Module):

    def __init__(self, code_length):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 1000),
        )

        # Drop the 1000-way classification layer and attach a hash head instead.
        self.classifier = self.classifier[:-1]
        self.hash_layer = nn.Linear(4096, code_length)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        x = self.hash_layer(x)
        return x
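
A CPU-only shape check for the hash head (no pretrained weights needed, so `AlexNet` is constructed directly rather than through `load_model`):

```
import torch
from models.alexnet import AlexNet

model = AlexNet(code_length=12)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)  # torch.Size([2, 12]); real-valued codes, binarized later with sign()
```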
--------------------------------------------------------------------------------
/models/dpsh_loss.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch


class DPSHLoss(nn.Module):
    def __init__(self, eta):
        super(DPSHLoss, self).__init__()
        self.eta = eta

    def forward(self, U_cnn, U, S):
        theta = U_cnn @ U.t() / 2

        # Prevent overflow
        theta = torch.clamp(theta, min=-100, max=50)

        pair_loss = (torch.log(1 + torch.exp(theta)) - S * theta).mean()
        regular_term = (U_cnn - U_cnn.sign()).pow(2).mean()
        loss = pair_loss + self.eta * regular_term

        return loss
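
This is the DPSH objective with sums replaced by means: the pairwise term is the negative log-likelihood log(1 + exp(theta_ij)) - s_ij * theta_ij with theta_ij = u_i^T u_j / 2, and the regularizer penalizes the gap between the network outputs and their binarization. A hand-checkable toy case:

```
import torch
from models.dpsh_loss import DPSHLoss

criterion = DPSHLoss(eta=0.1)
U_cnn = torch.tensor([[1.0, -1.0]])          # current batch output (one sample)
U = torch.tensor([[1.0, -1.0], [1.0, 1.0]])  # cached outputs for all training points
S = torch.tensor([[1.0, 0.0]])               # similar to point 0, dissimilar to point 1

# theta = [[1.0, 0.0]]; pair term = mean([log(1 + e^1) - 1, log(2)]) ~= 0.5032,
# and the quantization term is 0 because U_cnn is already exactly +/-1.
print(criterion(U_cnn, U, S))  # tensor(0.5032)
```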
--------------------------------------------------------------------------------
/models/model_loader.py:
--------------------------------------------------------------------------------
import models.alexnet as alexnet
import models.vgg16 as vgg16


def load_model(arch, code_length):
    """
    Load CNN model.

    Args
        arch(str): CNN model name.
        code_length(int): Hash code length.

    Returns
        model(torch.nn.Module): CNN model.
    """
    if arch == 'alexnet':
        model = alexnet.load_model(code_length)
    elif arch == 'vgg16':
        model = vgg16.load_model(code_length)
    else:
        raise ValueError('Invalid cnn model name!')

    return model
--------------------------------------------------------------------------------
/models/vgg16.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from torch.hub import load_state_dict_from_url


def load_model(code_length):
    """
    Load vgg16 model.

    Args
        code_length (int): Hash code length.

    Returns
        model (torch.nn.Module): VGG16 model.
    """
    model = VGG(make_layers(cfgs['D'], batch_norm=False), code_length)
    model.load_state_dict(
        load_state_dict_from_url('https://download.pytorch.org/models/vgg16-397923af.pth'),
        strict=False,
    )

    return model


class VGG(nn.Module):

    def __init__(self, features, code_length):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 1000),
        )
        self.classifier = self.classifier[:-1]

        self.hash_layer = nn.Linear(4096, code_length)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        x = self.hash_layer(x)
        return x


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
loguru
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import dpsh
import os

import argparse
import torch
from loguru import logger
from data.data_loader import load_data


def run():
    args = load_config()
    logger.add(os.path.join('logs', '{}_model_{}_code_{}_query_{}_train_{}_topk_{}_eta_{}.log'.format(
        args.dataset,
        args.arch,
        ','.join([str(c) for c in args.code_length]),
        args.num_query,
        args.num_train,
        args.topk,
        args.eta,
    )), rotation='500 MB', level='INFO')
    logger.info(args)
    torch.backends.cudnn.benchmark = True

    # Load dataset
    train_dataloader, query_dataloader, retrieval_dataloader = load_data(
        args.dataset,
        args.root,
        args.num_query,
        args.num_train,
        args.batch_size,
        args.num_workers,
    )

    # Training
    for code_length in args.code_length:
        logger.info('[code_length:{}]'.format(code_length))
        checkpoint = dpsh.train(
            train_dataloader,
            query_dataloader,
            retrieval_dataloader,
            args.arch,
            code_length,
            args.device,
            args.eta,
            args.lr,
            args.max_iter,
            args.topk,
            args.evaluate_interval,
        )
        torch.save(checkpoint, os.path.join('checkpoints', '{}_model_{}_code_{}_query_{}_train_{}_topk_{}_eta_{}_map_{:.4f}.pt'.format(
            args.dataset, args.arch, code_length,
            args.num_query, args.num_train, args.topk, args.eta, checkpoint['map'])))
        logger.info('[code_length:{}][map:{:.4f}]'.format(code_length, checkpoint['map']))


def load_config():
    """
    Load configuration.

    Args
        None

    Returns
        args(argparse.ArgumentParser): Configuration.
    """
    parser = argparse.ArgumentParser(description='DPSH_PyTorch')
    parser.add_argument('--dataset',
                        help='Dataset name.')
    parser.add_argument('--root',
                        help='Path of dataset.')
    parser.add_argument('--num-query', default=1000, type=int,
                        help='Number of query data points. (default: 1000)')
    parser.add_argument('--arch', default='alexnet', type=str,
                        help='CNN model name. (default: alexnet)')
    parser.add_argument('--num-train', default=5000, type=int,
                        help='Number of training data points. (default: 5000)')
    parser.add_argument('--code-length', default='12,24,32,48', type=str,
                        help='Binary hash code length. (default: 12,24,32,48)')
    parser.add_argument('--topk', default=-1, type=int,
                        help='Calculate mAP of top k. (default: -1, i.e. all)')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use. (default: None, run on CPU)')
    parser.add_argument('--lr', default=1e-5, type=float,
                        help='Learning rate. (default: 1e-5)')
    parser.add_argument('--batch-size', default=128, type=int,
                        help='Batch size. (default: 128)')
    parser.add_argument('--max-iter', default=150, type=int,
                        help='Number of iterations. (default: 150)')
    parser.add_argument('--num-workers', default=6, type=int,
                        help='Number of data loading workers. (default: 6)')
    parser.add_argument('--evaluate-interval', default=10, type=int,
                        help='Evaluation interval. (default: 10)')
    parser.add_argument('--eta', default=0.1, type=float,
                        help='Hyper-parameter. (default: 0.1)')

    args = parser.parse_args()

    # GPU
    if args.gpu is None:
        args.device = torch.device("cpu")
    else:
        args.device = torch.device("cuda:%d" % args.gpu)
        torch.cuda.set_device(args.gpu)

    # Hash code length
    args.code_length = list(map(int, args.code_length.split(',')))

    return args


if __name__ == "__main__":
    run()
--------------------------------------------------------------------------------
/utils/evaluate.py:
--------------------------------------------------------------------------------
import torch


def mean_average_precision(query_code,
                           retrieval_code,
                           query_targets,
                           retrieval_targets,
                           device,
                           topk=None,
                           ):
    """
    Calculate mean average precision (mAP).

    Args:
        query_code (torch.Tensor): Query data hash code.
        retrieval_code (torch.Tensor): Database data hash code.
        query_targets (torch.Tensor): Query data targets, one-hot.
        retrieval_targets (torch.Tensor): Database data targets, one-hot.
        device (torch.device): Using CPU or GPU.
        topk (int): Calculate top k data map.

    Returns:
        meanAP (float): Mean Average Precision.
    """
    # A missing or non-positive topk means ranking the whole database.
    if topk is None or topk <= 0:
        topk = retrieval_code.shape[0]

    num_query = query_targets.shape[0]
    mean_AP = 0.0

    for i in range(num_query):
        # Retrieve images from database
        retrieval = (query_targets[i, :] @ retrieval_targets.t() > 0).float()

        # Calculate hamming distance
        hamming_dist = 0.5 * (retrieval_code.shape[1] - query_code[i, :] @ retrieval_code.t())

        # Arrange position according to hamming distance
        retrieval = retrieval[torch.argsort(hamming_dist)][:topk]

        # Retrieval count
        retrieval_cnt = retrieval.sum().int().item()

        # Can not retrieve images
        if retrieval_cnt == 0:
            continue

        # Generate score for every position
        score = torch.linspace(1, retrieval_cnt, retrieval_cnt).to(device)

        # Acquire index
        index = (torch.nonzero(retrieval == 1).squeeze() + 1.0).float()

        mean_AP += (score / index).mean()

    mean_AP = mean_AP / num_query
    return mean_AP
--------------------------------------------------------------------------------
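
A hand-checkable example of the mAP computation above (one query, three database items, 2-bit codes):

```
import torch
from utils.evaluate import mean_average_precision

qB = torch.tensor([[1.0, 1.0]])
rB = torch.tensor([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]])
qL = torch.tensor([[1.0, 0.0]])
rL = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])

# Hamming distances are [0, 1, 2], so relevance in ranked order is [1, 0, 1]
# and AP = (1/1 + 2/3) / 2 = 0.8333.
print(mean_average_precision(qB, rB, qL, rL, torch.device('cpu'), topk=3))
```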