├── Discussion.md ├── HashGAN.py ├── README.md ├── imgs ├── D_feat_CIFAR10.png ├── D_feat_MNIST.png ├── generated_CIFAR10.png └── generated_MNIST.png ├── models.py ├── models_improvedGAN.py ├── trained ├── D_cifar10.ckpt ├── D_mnist.ckpt ├── G_cifar10.ckpt └── G_mnist.ckpt └── utils.py /Discussion.md: -------------------------------------------------------------------------------- 1 | ## Discussion 2 | 3 | ### Network Structure 4 | 5 | The paper does not clearly describe the network structure it uses, so experiments were run with both the Improved GAN structure mentioned in the paper and the DCGAN structure. The paper states that it used a structure similar to Improved GAN, but training with the DCGAN structure was more stable and performed better, so the DCGAN structure was adopted. The Improved GAN structure that was tried is included in models_improvedGAN.py. 6 | 7 | 8 | ### Performance 9 | 10 | With this code, precision@1000 was 0.36-0.44 on MNIST and 0.22-0.30 on CIFAR10. These figures are far below the results reported in the paper. The suspected cause is that the Encoder network was not trained well. As training proceeds, the images produced by the Generator are of reasonably good quality, and the features of the Discriminator's last hidden layer also have good properties. Figures of the generated images and the hidden layer features are attached. (imgs/generated_{dataset}.png, imgs/D_feat_{dataset}.png) 11 | Therefore, the Generator and Discriminator can be considered sufficiently trained. 12 | 13 | To improve Encoder performance, two things were tried: 1) adjusting the loss weights, and 2) changing the form of the minimum entropy loss and the uniform frequency loss. Neither improved performance much. However, although the paper says that the weights were not tuned, the change in performance caused by weight adjustment was not negligible (~5%). The authors were contacted to resolve these issues, but no reply was received. 14 | 15 | ### Training 16 | 17 | The paper does not give the number of training epochs, the batch size, and so on, but when the number of epochs became too large, the Generator fell into mode collapse. Batch size had little effect, and 50 epochs worked best. 
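As a rough illustration of the two entropy terms discussed under Performance, the sketch below computes them for a small batch of sigmoid outputs in plain Python. It is a minimal sketch under stated assumptions, not the code used in this repository: `bernoulli_entropy`, `min_entropy_loss` and `uniform_freq_loss` are hypothetical helper names, and averaging over bits is an assumption. The 0.01 weight on the minimum entropy term is the one that appears in HashGAN.py.

```python
import math


def bernoulli_entropy(p, eps=1e-12):
    """Entropy in nats of one bit whose probability of being 1 is p."""
    # Clamp so that log(0) is never evaluated.
    p = min(max(p, eps), 1.0 - eps)
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))


def min_entropy_loss(codes):
    """Mean entropy of every output bit; small when outputs sit near 0 or 1."""
    values = [bernoulli_entropy(p) for row in codes for p in row]
    return sum(values) / len(values)


def uniform_freq_loss(codes):
    """Negative entropy of each bit's batch mean; small when a bit is on about half the time."""
    n = len(codes)
    means = [sum(row[j] for row in codes) / n for j in range(len(codes[0]))]
    return -sum(bernoulli_entropy(m) for m in means) / len(means)


# Hypothetical batch of 4 codes with 2 bits each (sigmoid outputs in [0, 1]).
codes = [[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]]

# Weighted sum of the two terms; 0.01 mirrors the weight in HashGAN.py.
total = 0.01 * min_entropy_loss(codes) + uniform_freq_loss(codes)
print(f"min entropy: {min_entropy_loss(codes):.4f}")
print(f"uniform frequency: {uniform_freq_loss(codes):.4f}")
print(f"weighted sum: {total:.4f}")
```

Sweeping the 0.01 factor over a few values is one simple way to probe the weight sensitivity described above.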
-------------------------------------------------------------------------------- /HashGAN.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | import argparse 4 | import numpy as np 5 | from tqdm import trange 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | from torchvision import datasets, transforms 11 | import torchvision.utils as vutils 12 | from torch.autograd import grad, Variable 13 | import torch.backends.cudnn as cudnn 14 | 15 | from models import Generator, Discriminator 16 | from utils import get_trmat, set_input, get_prec_topn, bit_entropy 17 | 18 | 19 | class HashGAN(object): 20 | def __init__(self, z_dim, b_dim, device, save_dir, dataset='mnist', G_dict=None, D_dict=None): 21 | """ Defines the HashGAN network. 22 | 23 | Args: 24 | z_dim: The length of a continuous part of input random variable. 25 | b_dim: The length of a binary part of input random variable. 26 | device: The device number to use. It should be an integer value. 27 | save_dir: The path to a directory where the generated images and trained models will be saved. 28 | dataset: The name of dataset to use. 'cifar10' and 'mnist' are possible. 29 | G_dict: The path of generator model to load. If this is None, generator will be randomly initialized. 30 | D_dict: The path of discriminator(encoder) model to load. If this is None, discriminator(encoder) will be randomly initialized. 31 | 32 | Returns: 33 | None. 
34 | """ 35 | self.device = device 36 | self.z_dim = z_dim 37 | self.b_dim = b_dim 38 | self.input_dim = z_dim + b_dim 39 | self.dataset = dataset 40 | 41 | self.save_dir = Path('results')/dataset/save_dir 42 | self.model_dir = self.save_dir/'models' 43 | self.img_dir = self.save_dir/'images' 44 | 45 | Path(self.model_dir).mkdir(parents=True, exist_ok=True) 46 | Path(self.img_dir).mkdir(parents=True, exist_ok=True) 47 | 48 | if self.dataset == 'mnist': 49 | img_channel = 1 50 | else: 51 | img_channel = 3 52 | 53 | self.G = Generator(self.input_dim, img_channel).to(self.device) 54 | self.D = Discriminator(self.b_dim, img_channel).to(self.device) 55 | 56 | if G_dict: 57 | G_weight = torch.load(G_dict, map_location=self.device) 58 | G_dict = self.G.state_dict() 59 | G_weight = {k: v for k, v in G_weight.items() if k in G_dict} 60 | G_dict.update(G_weight) 61 | self.G.load_state_dict(G_dict) 62 | if D_dict: 63 | D_weight = torch.load(D_dict, map_location=self.device) 64 | D_dict = self.D.state_dict() 65 | D_weight = {k: v for k, v in D_weight.items() if k in D_dict} 66 | D_dict.update(D_weight) 67 | self.D.load_state_dict(D_dict) 68 | 69 | def loss_D(self, x_real, hashing=True, zb_input=None): 70 | """ Computes the discriminator and encoder losses and stores them as attributes of the HashGAN class. 71 | 72 | Args: 73 | x_real: The real image to compute loss. 74 | hashing: Whether to compute hashing loss. 75 | zb_input: The input random noise to generate a fake image. If this is None, an input random noise will be set randomly. 76 | 77 | Returns: 78 | None. 
79 | """ 80 | bs = x_real.size(0) 81 | if zb_input is None: 82 | zb_input = set_input(bs, self.z_dim, self.b_dim, self.device) 83 | b_input = zb_input[:, -self.b_dim:] 84 | x_fake = self.G(zb_input).detach() 85 | fake_logits, fake_codes, fake_feats = self.D(x_fake) 86 | real_logits, real_codes, real_feats = self.D(x_real) 87 | 88 | self.adv_loss_real = nn.BCELoss()(real_logits, torch.ones(bs, 1).to(self.device)) 89 | self.adv_loss_fake = nn.BCELoss()(fake_logits, torch.zeros(bs, 1).to(self.device)) 90 | self.adv_loss = self.adv_loss_real + self.adv_loss_fake 91 | 92 | if hashing: 93 | aff = get_trmat(bs, 10, 0.1).to(self.device) 94 | affgrid = F.affine_grid(aff, x_real.size()) 95 | x_real_aff = F.grid_sample(x_real, affgrid, padding_mode='reflection') 96 | real_aff_codes = self.D(x_real_aff)[1] 97 | mean_codes = real_codes.mean(dim=0) 98 | 99 | self.min_entropy_loss = bit_entropy(real_codes, 'mean') 100 | self.uniform_freq_loss = - bit_entropy(mean_codes, 'mean') 101 | self.consistent_loss = (real_codes-real_aff_codes).pow(2).mean() 102 | self.independent_loss = ((self.D.encode.weight.T @ self.D.encode.weight) 103 | - torch.eye(self.D.encode.weight.size(1)).to(self.device)).pow(2).mean() 104 | 105 | self.hash_loss = 0.01*self.min_entropy_loss + self.uniform_freq_loss + self.consistent_loss + self.independent_loss 106 | 107 | self.col_l2_loss = (fake_codes - b_input).pow(2).sum(dim=1).mean() 108 | else: 109 | self.min_entropy_loss = torch.Tensor([0.0]).to(self.device) 110 | self.uniform_freq_loss = torch.Tensor([0.0]).to(self.device) 111 | self.consistent_loss = torch.Tensor([0.0]).to(self.device) 112 | self.independent_loss = torch.Tensor([0.0]).to(self.device) 113 | self.hash_loss = torch.Tensor([0.0]).to(self.device) 114 | self.col_l2_loss = torch.Tensor([0.0]).to(self.device) 115 | 116 | 117 | def loss_G(self, x_real, zb_input=None): 118 | """ Computes the feature matching loss and stores it as an attribute of HashGAN class. 
119 | 120 | Args: 121 | x_real: The real image to compute loss. 122 | zb_input: The input random noise to generate a fake image. If this is None, an input random noise will be set randomly. 123 | 124 | Returns: 125 | None. 126 | """ 127 | bs = x_real.size(0) 128 | if zb_input is None: 129 | zb_input = set_input(bs, self.z_dim, self.b_dim, self.device) 130 | x_fake = self.G(zb_input) 131 | 132 | real_feats = self.D(x_real)[2] 133 | fake_feats = self.D(x_fake)[2] 134 | 135 | self.feat_match_loss = (real_feats.mean(dim=0)-fake_feats.mean(dim=0)).pow(2).mean() 136 | 137 | 138 | def step_opt(self, loss, opt, retain_graph=False): 139 | """ Computes gradient and step optimizer. 140 | 141 | Args: 142 | loss: The loss to compute gradient. 143 | opt: The optimizer to use. 144 | retain_graph: Whether to retain the graph. 145 | 146 | Returns: 147 | None. 148 | """ 149 | opt.zero_grad() 150 | loss.backward(retain_graph=retain_graph) 151 | opt.step() 152 | 153 | def generate_code_label(self, dataloader): 154 | """ Generates codes and labels of all datapoints in given dataloader. 155 | 156 | Args: 157 | dataloader: The dataloader to get codes and labels, instance of torch.utils.data.DataLoader. 158 | 159 | Returns: 160 | Generated binary(-1 or 1) codes and corresponding one-hot labels. 
161 | """ 162 | with torch.no_grad(): 163 | bs = dataloader.batch_size 164 | num_data = len(dataloader.dataset) 165 | codes_all = torch.zeros([num_data, self.b_dim]) 166 | labels_all = torch.zeros([num_data, 10]) 167 | 168 | train_it = iter(dataloader) 169 | t_train = trange(0, len(dataloader), initial=0, total=len(dataloader)) 170 | 171 | for step in t_train: 172 | index = torch.arange(step*bs, (step+1)*bs) 173 | try: 174 | img, label = next(train_it) 175 | except StopIteration: 176 | continue 177 | img = img.to(self.device) 178 | code = self.D(img)[1] 179 | codes_all[index, :] = (code-0.5).sign().cpu() 180 | labels_all[index, label] = 1 181 | return codes_all, labels_all 182 | 183 | def eval(self, query_loader, database_loader): 184 | """ Evaluates the HashGAN network with the given query set and database set. 185 | 186 | Args: 187 | query_loader: Dataloader of query set. Instance of torch.utils.data.DataLoader. 188 | database_loader: Dataloader of database set. Instance of torch.utils.data.DataLoader. 189 | 190 | Returns: 191 | Computed precision@1000 of given query set and database set. 192 | """ 193 | self.G.eval() 194 | self.D.eval() 195 | query_code, query_label = self.generate_code_label(query_loader) 196 | database_code, database_label = self.generate_code_label(database_loader) 197 | mAP1000 = get_prec_topn(query_code, database_code, query_label, database_label, topn=1000) 198 | self.G.train() 199 | self.D.train() 200 | return mAP1000 201 | 202 | 203 | def train(self, data_loader, init_lr, final_lr, num_epoch, log_step, query_loader=None, database_loader=None): 204 | """ Trains the network. 205 | 206 | Args: 207 | data_loader: Dataloader of training set. Instance of torch.utils.data.DataLoader. 208 | init_lr: The initial learning rate. 209 | final_lr: The final learning rate. 210 | num_epoch: Maximum epoch to train. 211 | log_step: Interval to print the losses. 212 | query_loader: Dataloader of query set to evaluate network during training. 
If this is None, the network will be not evaluated during training. 213 | database_loader: Dataloader of database set to evaluate network during training. If this is None, the network will be not evaluated during training. 214 | 215 | Returns: 216 | None. 217 | """ 218 | cudnn.benchmark = True 219 | bs = data_loader.batch_size 220 | 221 | self.num_epoch = num_epoch 222 | self.G.train() 223 | self.D.train() 224 | 225 | self.G_opt = torch.optim.Adam(self.G.parameters(), lr=init_lr, betas=(0.5, 0.999)) 226 | self.D_opt = torch.optim.Adam(self.D.parameters(), lr=init_lr, betas=(0.5, 0.999)) 227 | 228 | z_example = set_input(bs, self.z_dim, self.b_dim, self.device) 229 | 230 | best_mAP1000 = 0.0 231 | 232 | for epoch in range(num_epoch): 233 | self.epoch = epoch 234 | train_it = iter(data_loader) 235 | t_train = trange(0, len(data_loader), initial=0, 236 | total=len(data_loader)) 237 | 238 | for step in t_train: 239 | try: 240 | dp = next(train_it) 241 | except StopIteration: 242 | continue 243 | 244 | x_real = dp[0] 245 | x_real = x_real.to(self.device) 246 | self.loss_G(x_real) 247 | self.step_opt(self.feat_match_loss, self.G_opt) 248 | if epoch < num_epoch/10: 249 | self.loss_D(x_real, hashing=False) 250 | self.step_opt(self.adv_loss, self.D_opt) 251 | 252 | else: 253 | self.loss_D(x_real) 254 | D_loss = self.adv_loss + self.hash_loss + 0.1*self.col_l2_loss 255 | self.step_opt(D_loss, self.D_opt) 256 | 257 | if not (step) % log_step: 258 | t_train.set_description(f'Epoch[{epoch}],' 259 | + f'G:[{self.feat_match_loss.item():.3f}],' 260 | + f'adv:[{self.adv_loss.item():.3f}],' 261 | + f'me:[{self.min_entropy_loss.item():.3f}],' 262 | + f'uf:[{self.uniform_freq_loss.item():.3f}],' 263 | + f'cons:[{self.consistent_loss.item():.3f}],' 264 | + f'W:[{self.independent_loss.item():.3f}],' 265 | + f'col:[{self.col_l2_loss.item():.3f}]' 266 | ) 267 | 268 | if epoch == 0: 269 | vutils.save_image(x_real, self.img_dir/'real.png', normalize=True, nrow=int(np.sqrt(bs))) 270 | 271 | 
x_fake_example = self.G(z_example) 272 | vutils.save_image(x_fake_example.view(x_real.size()), 273 | self.img_dir/f'fake_epoch_{epoch}.png', 274 | normalize=True, 275 | nrow=int(np.sqrt(bs))) 276 | 277 | if query_loader and database_loader and not epoch%10: 278 | mAP1000 = self.eval(query_loader, database_loader) 279 | if mAP1000 >= best_mAP1000: 280 | torch.save(self.D.state_dict(), self.model_dir/f'D_best.ckpt') 281 | torch.save(self.G.state_dict(), self.model_dir/f'G_best.ckpt') 282 | best_mAP1000 = mAP1000 283 | print(f'mAP_1000: {mAP1000:.4f}, best_mAP1000: {best_mAP1000:.4f}') 284 | 285 | if not (epoch+1) % (num_epoch//5): 286 | self.G_opt.param_groups[0]['lr'] += (final_lr-init_lr)/(4) 287 | self.D_opt.param_groups[0]['lr'] += (final_lr-init_lr)/(4) 288 | print(f'==== lr has changed to {self.G_opt.param_groups[0]["lr"]} ====') 289 | 290 | torch.save(self.D.state_dict(), self.model_dir/f'D_final.ckpt') 291 | torch.save(self.G.state_dict(), self.model_dir/f'G_final.ckpt') 292 | 293 | 294 | 295 | if __name__ == '__main__': 296 | parser = argparse.ArgumentParser() 297 | parser.add_argument('--train', action='store_true') 298 | parser.add_argument('--eval', action='store_true') 299 | parser.add_argument('--batch_size', type=int, 300 | default=100, help='The size of a batch') 301 | parser.add_argument('--z_dim', type=int, default=128, 302 | help='The size of a latent vector') 303 | parser.add_argument('--b_dim', type=int, default=16, 304 | help='The length of hash code') 305 | parser.add_argument('--init_lr', type=float, default=9e-4, 306 | help='The initial learning rate') 307 | parser.add_argument('--final_lr', type=float, 308 | default=3e-4, help='The final learning rate') 309 | parser.add_argument('--epochs', type=int, default=100, 310 | help='The number of epochs to run') 311 | parser.add_argument('--gpu', type=int, default=0, 312 | help='The ID of GPU to use') 313 | parser.add_argument('--dataset', type=str, 314 | default='mnist', help='The dataset to use') 
315 | parser.add_argument('--G_dict', type=str, 316 | default=None, help='The path to generator model to load') 317 | parser.add_argument('--D_dict', type=str, 318 | default=None, help='The path to discriminator model to load') 319 | parser.add_argument('--save_dir', type=str, 320 | default='temp', help='The name of the save directory') 321 | parser.add_argument('--log_step', type=int, 322 | default=10, help='Log step period') 323 | parser.add_argument('--data_dir', type=str, default='./data', 324 | help='The images should be in here') 325 | 326 | args = parser.parse_args() 327 | device = torch.device(f"cuda:{args.gpu}") 328 | 329 | net = HashGAN(args.z_dim, args.b_dim, device, 330 | args.save_dir, args.dataset, args.G_dict, args.D_dict) 331 | 332 | 333 | if args.dataset == 'mnist': 334 | img_transforms = transforms.Compose([transforms.Resize((32, 32)), 335 | transforms.ToTensor(), 336 | transforms.Normalize((0.5,), (0.5,)), 337 | ]) 338 | dataset = datasets.MNIST( 339 | args.data_dir, train=True, download=True, transform=img_transforms) 340 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 341 | batch_size=args.batch_size, 342 | shuffle=True, 343 | drop_last=True, 344 | pin_memory=True) 345 | query_set = datasets.MNIST( 346 | args.data_dir, train=False, download=True, transform=img_transforms) 347 | query_loader = torch.utils.data.DataLoader(dataset=query_set, 348 | batch_size=args.batch_size, 349 | shuffle=False, 350 | drop_last=True, 351 | pin_memory=True) 352 | database_set = datasets.MNIST( 353 | args.data_dir, train=True, download=True, transform=img_transforms) 354 | database_loader = torch.utils.data.DataLoader(dataset=database_set, 355 | batch_size=args.batch_size, 356 | shuffle=False, 357 | drop_last=True, 358 | pin_memory=True) 359 | 360 | elif args.dataset == 'cifar10': 361 | img_transforms = transforms.Compose([transforms.ToTensor(), 362 | transforms.Normalize( 363 | (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 364 | ]) 365 | dataset = datasets.CIFAR10( 
366 | args.data_dir, train=True, download=True, transform=img_transforms) 367 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 368 | batch_size=args.batch_size, 369 | shuffle=True, 370 | drop_last=True, 371 | pin_memory=True) 372 | query_set = datasets.CIFAR10( 373 | args.data_dir, train=False, download=True, transform=img_transforms) 374 | query_loader = torch.utils.data.DataLoader(dataset=query_set, 375 | batch_size=args.batch_size, 376 | shuffle=False, 377 | drop_last=True, 378 | pin_memory=True) 379 | database_set = datasets.CIFAR10( 380 | args.data_dir, train=True, download=True, transform=img_transforms) 381 | database_loader = torch.utils.data.DataLoader(dataset=database_set, 382 | batch_size=args.batch_size, 383 | shuffle=False, 384 | drop_last=True, 385 | pin_memory=True) 386 | 387 | else: 388 | print('The given dataset is invalid!') 389 | sys.exit() 390 | 391 | if args.train: 392 | net.train(data_loader, args.init_lr, args.final_lr, 393 | args.epochs, args.log_step) 394 | elif args.eval: 395 | net.eval(query_loader, database_loader) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HashGAN-Pytorch 2 | 3 | ## Introduction 4 | 5 | This is a PyTorch implementation of [***Unsupervised Deep Generative Adversarial Hashing Network, CVPR'18***](http://openaccess.thecvf.com/content_cvpr_2018/papers/Dizaji_Unsupervised_Deep_Generative_CVPR_2018_paper.pdf) for the CIFAR10 and MNIST datasets. 6 | 7 | * * * 8 | 9 | ## Prerequisites 10 | 11 | * **Linux** 12 | 13 | This code was written to be run on Linux. 
14 | * **Python > 3.6** 15 | 16 | Using conda is recommended: [https://docs.anaconda.com/anaconda/install/linux/](https://docs.anaconda.com/anaconda/install/linux/) 17 | * **pytorch** 18 | 19 | To install: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) 20 | * **numpy** 21 | 22 | Run this in a terminal: `conda install numpy` 23 | 24 | * **tqdm** 25 | 26 | Run this in a terminal: `conda install tqdm` 27 | 28 | * * * 29 | 30 | ## Installation 31 | 32 | Run this in a terminal: 33 | 34 | `git clone https://github.com/8uos/HashGAN-Pytorch` 35 | 36 | * * * 37 | 38 | ## Usage 39 | ### Train 40 | The simplest way is to run this in a terminal: 41 | 42 | `python HashGAN.py --train` 43 | 44 | The following additional arguments are available: 45 | 46 | `--batch_size`: The size of a minibatch. Default value is 100. 47 | 48 | `--z_dim` : The length of the continuous part of the input random variable. Default value is 128. 49 | 50 | `--b_dim` : The length of the binary part of the input random variable. Default value is 16. 51 | 52 | `--init_lr` : Initial learning rate. Default value is 9e-04. 53 | 54 | `--final_lr` : Final learning rate. Default value is 3e-04. 55 | 56 | `--epochs` : The number of epochs to train. Default value is 100. 57 | 58 | `--gpu` : The ID of the GPU to use. Default value is 0. 59 | 60 | `--dataset` : The name of the dataset to use. 'cifar10' and 'mnist' are possible, and the default value is 'mnist'. 61 | 62 | `--G_dict` : The path to the generator model to load. If this is None, the generator will be randomly initialized. Default value is None. 63 | 64 | `--D_dict` : The path to the discriminator model to load. If this is None, the discriminator will be randomly initialized. Default value is None. 65 | 66 | `--save_dir` : The name of the save directory. Everything will be saved in `results/{dataset}/{save_dir}`. Default value is 'temp'. 67 | 68 | `--log_step` : Interval to print the losses. Default value is 10. 69 | 70 | `--data_dir` : The location of the dataset. 
If the dataset does not exist in data_dir, the dataset will be downloaded. Default value is './data'. 71 | 72 | ### Evaluate 73 | 74 | ``` 75 | python HashGAN.py --eval \ 76 | --G_dict=Path/to/generator/dict/to/evaluate \ 77 | --D_dict=Path/to/discriminator/dict/to/evaluate 78 | ``` 79 | 80 | The possible additional arguments are identical to the ones above. 81 | 82 | * * * 83 | 84 | ## Structure 85 | ### HashGAN.py 86 | * HashGAN ***(class)*** 87 | 88 | * __init__ ***(method)*** 89 | 90 | Defines the HashGAN network. 91 | 92 | * loss_D ***(method)*** 93 | 94 | Computes the discriminator and encoder losses and stores them as attributes of the HashGAN class. 95 | 96 | * loss_G ***(method)*** 97 | 98 | Computes the feature matching loss and stores it as an attribute of the HashGAN class. 99 | 100 | * step_opt ***(method)*** 101 | 102 | Computes gradients and steps the optimizer. 103 | 104 | * generate_code_label ***(method)*** 105 | 106 | Generates codes with the current encoder, together with the labels, for all datapoints in the given dataloader. 107 | 108 | * eval ***(method)*** 109 | 110 | Evaluates the HashGAN network with the given query set and database set. 111 | 112 | * train ***(method)*** 113 | 114 | Trains the network. 115 | 116 | * Define the net, train, evaluate 117 | 118 | ### utils.py 119 | * get_trmat ***(function)*** 120 | 121 | Builds the transform matrix used to compute the consistent bit loss. 122 | 123 | * set_input ***(function)*** 124 | 125 | Builds the input random variable, consisting of a continuous part and a binary part. 126 | 127 | * get_prec_topn ***(function)*** 128 | 129 | Computes precision@topn with the given query and database codes. 130 | 131 | * bit_entropy ***(function)*** 132 | 133 | Computes the entropy of each bit of the given code. 134 | 135 | ### models.py 136 | * Generator ***(class)*** 137 | 138 | Defines the generator network. 
139 | 140 | * Discriminator ***(class)*** 141 | 142 | The definition of the model of discriminator and encoder network. 143 | 144 | 145 | ## References 146 | * K.dizaji, F.Zheng, N.Nourabadi, Y.Yang, C.Deng, H.Huang. “Unsupervised Deep Generative Adversarial Hashing Network”, CVPR, 2018. 147 | 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /imgs/D_feat_CIFAR10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/8uos/HashGAN-Pytorch/453448ddadcc2e4e33a52288311480eb5272205a/imgs/D_feat_CIFAR10.png -------------------------------------------------------------------------------- /imgs/D_feat_MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/8uos/HashGAN-Pytorch/453448ddadcc2e4e33a52288311480eb5272205a/imgs/D_feat_MNIST.png -------------------------------------------------------------------------------- /imgs/generated_CIFAR10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/8uos/HashGAN-Pytorch/453448ddadcc2e4e33a52288311480eb5272205a/imgs/generated_CIFAR10.png -------------------------------------------------------------------------------- /imgs/generated_MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/8uos/HashGAN-Pytorch/453448ddadcc2e4e33a52288311480eb5272205a/imgs/generated_MNIST.png -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.utils.weight_norm as weight_norm 4 | 5 | 6 | class Generator(nn.Module): 7 | def __init__(self, z_dim=128, img_channel=3, nf=64): 8 | """ Defines the generator model. 
9 | 10 | Args: 11 | z_dim: The length of input random noise. 12 | img_channel: The number of channels of output image. 3 for CIFAR10, 1 for MNIST. 13 | nf: The unit number of filters. Default is 64. 14 | 15 | Returns: 16 | None. 17 | """ 18 | super(Generator, self).__init__() 19 | self.nf = nf 20 | self.z_dim = z_dim 21 | 22 | self.latent = nn.Sequential(nn.Linear(self.z_dim, 8*nf*4*4), 23 | nn.BatchNorm1d(8*nf*4*4), 24 | nn.ReLU() 25 | ) 26 | 27 | self.network = nn.Sequential(nn.ConvTranspose2d(nf*8, nf*4, 4, 2, 1, bias=False), 28 | nn.BatchNorm2d(nf*4), 29 | nn.ReLU(inplace=True), 30 | 31 | nn.ConvTranspose2d(nf*4, nf*2, 4, 2, 1, bias=False), 32 | nn.BatchNorm2d(nf*2), 33 | nn.ReLU(inplace=True), 34 | 35 | nn.ConvTranspose2d(nf*2, nf, 4, 2, 1, bias=False), 36 | nn.BatchNorm2d(nf), 37 | nn.ReLU(inplace=True), 38 | 39 | nn.ConvTranspose2d(nf, img_channel, kernel_size=1, stride=1, padding=0, bias=False), 40 | nn.Tanh() 41 | ) 42 | self._initialize_weights() 43 | 44 | def forward(self, z): 45 | out = self.latent(z) 46 | out = out.reshape(out.size(0), 8*self.nf, 4, 4) 47 | out = self.network(out) 48 | return out 49 | 50 | 51 | def _initialize_weights(self): 52 | for m in self.modules(): 53 | if isinstance(m, nn.ConvTranspose2d): 54 | nn.init.kaiming_normal_( 55 | m.weight, mode='fan_out', nonlinearity='relu') 56 | if m.bias is not None: 57 | nn.init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.BatchNorm2d): 59 | nn.init.constant_(m.weight, 1) 60 | nn.init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.Linear): 62 | nn.init.normal_(m.weight, 0, 0.01) 63 | nn.init.constant_(m.bias, 0) 64 | 65 | 66 | class Discriminator(nn.Module): 67 | def __init__(self, len_code, img_channel=3, nf=64): 68 | """ Defines the discriminator and encoder model. 69 | 70 | Args: 71 | len_code: The length of output hash code. 72 | img_channel: The number of channels of input image. 3 for CIFAR10, 1 for MNIST. 73 | nf: The unit number of filters. Default is 64. 74 | 75 | Returns: 76 | None. 
77 | """ 78 | super(Discriminator, self).__init__() 79 | self.len_code = len_code 80 | self.nf = nf 81 | 82 | self.network = nn.Sequential(nn.Conv2d(img_channel, nf, 4, 2, 1), 83 | nn.LeakyReLU(2e-1), 84 | 85 | nn.Conv2d(nf, nf*2, 4, 2, 1), 86 | nn.BatchNorm2d(nf*2), 87 | nn.LeakyReLU(2e-1), 88 | 89 | nn.Conv2d(nf*2, nf*4, 4, 2, 1), 90 | nn.BatchNorm2d(nf*4), 91 | nn.LeakyReLU(2e-1), 92 | 93 | nn.Conv2d(nf*4, nf*8, 4, 2, 1), 94 | nn.BatchNorm2d(nf*8), 95 | nn.LeakyReLU(2e-1)) 96 | 97 | self.discriminate = nn.Linear(8*nf*2*2, 1) 98 | self.encode = nn.Linear(8*nf*2*2, len_code) 99 | 100 | self._initialize_weights() 101 | 102 | def forward(self, x): 103 | feat = self.network(x) 104 | feat = feat.view(feat.size(0), -1) 105 | disc = self.discriminate(feat) 106 | disc = nn.Sigmoid()(disc) 107 | code = self.encode(feat) 108 | code = nn.Sigmoid()(code) 109 | return disc, code, feat 110 | 111 | def _initialize_weights(self): 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | nn.init.kaiming_normal_( 115 | m.weight, mode='fan_out', nonlinearity='relu') 116 | if m.bias is not None: 117 | nn.init.constant_(m.bias, 0) 118 | elif isinstance(m, nn.Linear): 119 | nn.init.normal_(m.weight, 0, 0.01) 120 | nn.init.constant_(m.bias, 0) -------------------------------------------------------------------------------- /models_improvedGAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.utils.weight_norm as weight_norm 4 | 5 | 6 | class G_cifar10(nn.Module): 7 | """ Defines the generator model. 8 | 9 | Args: 10 | z_dim: The length of input random noise. 11 | nf: The unit number of filters. Default is 64. 12 | 13 | Returns: 14 | None. 
15 | """ 16 | def __init__(self, z_dim=128, nf=64): 17 | super(G_cifar10, self).__init__() 18 | self.nf = nf 19 | self.z_dim = z_dim 20 | 21 | self.latent = nn.Sequential(nn.Linear(self.z_dim, 8*nf*4*4), 22 | nn.BatchNorm1d(8*nf*4*4), 23 | nn.ReLU() 24 | ) 25 | self.network = nn.Sequential(nn.ConvTranspose2d(8*nf, 4*nf, 4, 2, 1, bias=False), 26 | nn.BatchNorm2d(4*nf), 27 | nn.ReLU(inplace=True), 28 | nn.ConvTranspose2d(4*nf, 2*nf, 4, 2, 1, bias=False), 29 | nn.BatchNorm2d(2*nf), 30 | nn.ReLU(inplace=True), 31 | weight_norm(nn.ConvTranspose2d(2*nf, 3, 4, 2, 1, bias=False)), 32 | nn.Tanh() 33 | ) 34 | self._initialize_weights() 35 | 36 | def forward(self, z): 37 | out = self.latent(z) 38 | out = out.reshape(out.size(0), 8*self.nf, 4, 4) 39 | out = self.network(out) 40 | return out 41 | 42 | 43 | def _initialize_weights(self): 44 | for m in self.modules(): 45 | if isinstance(m, nn.ConvTranspose2d): 46 | nn.init.kaiming_normal_( 47 | m.weight, mode='fan_out', nonlinearity='relu') 48 | if m.bias is not None: 49 | nn.init.constant_(m.bias, 0) 50 | elif isinstance(m, nn.BatchNorm2d): 51 | nn.init.constant_(m.weight, 1) 52 | nn.init.constant_(m.bias, 0) 53 | elif isinstance(m, nn.Linear): 54 | nn.init.normal_(m.weight, 0, 0.01) 55 | nn.init.constant_(m.bias, 0) 56 | 57 | 58 | class DE_cifar10(nn.Module): 59 | """ Defines the discriminator and encoder model. 60 | 61 | Args: 62 | len_code: The length of output hash code. 63 | nf: The unit number of filters. Default is 64. 64 | 65 | Returns: 66 | None. 
67 | """ 68 | def __init__(self, len_code, nf=48): 69 | super(DE_cifar10, self).__init__() 70 | self.len_code = len_code 71 | self.nf = nf 72 | 73 | self.network = nn.Sequential(nn.Dropout(0.2), 74 | weight_norm(nn.Conv2d(3, 2*nf, 3, 1, 1)), 75 | nn.LeakyReLU(2e-1), 76 | weight_norm( 77 | nn.Conv2d(2*nf, 2*nf, 3, 1, 1)), 78 | nn.LeakyReLU(2e-1), 79 | weight_norm( 80 | nn.Conv2d(2*nf, 2*nf, 3, 2, 1)), 81 | nn.LeakyReLU(2e-1), 82 | nn.Dropout(0.5), 83 | weight_norm( 84 | nn.Conv2d(2*nf, 4*nf, 3, 1, 1)), 85 | nn.LeakyReLU(2e-1), 86 | weight_norm( 87 | nn.Conv2d(4*nf, 4*nf, 3, 1, 1)), 88 | nn.LeakyReLU(2e-1), 89 | weight_norm( 90 | nn.Conv2d(4*nf, 4*nf, 3, 2, 1)), 91 | nn.LeakyReLU(2e-1), 92 | nn.Dropout(0.5), 93 | weight_norm( 94 | nn.Conv2d(4*nf, 4*nf, 3, 1, 0)), 95 | nn.LeakyReLU(2e-1), 96 | weight_norm( 97 | nn.Conv2d(4*nf, 4*nf, 1, 1, 0)), 98 | nn.LeakyReLU(2e-1), 99 | weight_norm( 100 | nn.Conv2d(4*nf, 4*nf, 1, 1, 0)), 101 | nn.LeakyReLU(2e-1), 102 | nn.AdaptiveAvgPool2d((1, 1)), 103 | ) 104 | 105 | self.discriminate = weight_norm(nn.Linear(4*nf, 1)) 106 | self.encode = nn.Linear(4*nf, len_code) 107 | 108 | self._initialize_weights() 109 | 110 | def forward(self, x): 111 | feat = self.network(x) 112 | feat = feat.view(feat.size(0), -1) 113 | disc = self.discriminate(feat) 114 | disc = nn.Sigmoid()(disc) 115 | code = self.encode(feat) 116 | code = nn.Sigmoid()(code) 117 | return disc, code, feat 118 | 119 | def _initialize_weights(self): 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_( 123 | m.weight, mode='fan_out', nonlinearity='relu') 124 | if m.bias is not None: 125 | nn.init.constant_(m.bias, 0) 126 | elif isinstance(m, nn.Linear): 127 | nn.init.normal_(m.weight, 0, 0.01) 128 | nn.init.constant_(m.bias, 0) 129 | 130 | 131 | class G_mnist(nn.Module): 132 | """ Defines the generator model. 133 | 134 | Args: 135 | z_dim: The length of input random noise. 136 | nf: The unit number of filters. Default is 64. 
137 | 138 | Returns: 139 | None. 140 | """ 141 | def __init__(self, z_dim=128): 142 | super(G_mnist, self).__init__() 143 | self.z_dim = z_dim 144 | 145 | self.network = nn.Sequential(nn.Linear(self.z_dim, 500), 146 | nn.BatchNorm1d(500), 147 | nn.Softplus(), 148 | nn.Linear(500, 500), 149 | nn.BatchNorm1d(500), 150 | nn.Softplus(), 151 | weight_norm(nn.Linear(500, 28**2)), 152 | nn.Tanh(), 153 | ) 154 | self._initialize_weights() 155 | 156 | def forward(self, z): 157 | out = self.network(z) 158 | return out 159 | 160 | def _initialize_weights(self): 161 | for m in self.modules(): 162 | if isinstance(m, nn.BatchNorm1d): 163 | nn.init.constant_(m.weight, 1) 164 | nn.init.constant_(m.bias, 0) 165 | elif isinstance(m, nn.Linear): 166 | nn.init.normal_(m.weight, 0, 0.01) 167 | nn.init.constant_(m.bias, 0) 168 | 169 | 170 | class DE_mnist(nn.Module): 171 | """ Defines the discriminator and encoder model. 172 | 173 | Args: 174 | len_code: The length of output hash code. 175 | nf: The unit number of filters. Default is 64. 176 | 177 | Returns: 178 | None. 
179 | """ 180 | def __init__(self, len_code): 181 | super(DE_mnist, self).__init__() 182 | self.len_code = len_code 183 | self.network = torch.nn.Sequential(weight_norm(nn.Linear(28**2, 1000)), 184 | nn.ReLU(), 185 | AddNoise(0.15), 186 | weight_norm(nn.Linear(1000, 500)), 187 | nn.ReLU(), 188 | AddNoise(0.15), 189 | weight_norm(nn.Linear(500, 250)), 190 | nn.ReLU(), 191 | AddNoise(0.15), 192 | weight_norm(nn.Linear(250, 250)), 193 | nn.ReLU(), 194 | AddNoise(0.15), 195 | weight_norm(nn.Linear(250, 250)), 196 | ) 197 | 198 | self.discriminate = weight_norm(nn.Linear(250, 1)) 199 | self.encode = nn.Linear(250, len_code) 200 | self._initialize_weights() 201 | 202 | def forward(self, x): 203 | x = x.view(x.size(0), -1) 204 | feat = self.network(x) 205 | disc = self.discriminate(feat) 206 | disc = nn.Sigmoid()(disc) 207 | code = self.encode(feat) 208 | code = nn.Sigmoid()(code) 209 | return disc, code, feat 210 | 211 | def _initialize_weights(self): 212 | for m in self.modules(): 213 | if isinstance(m, nn.BatchNorm1d): 214 | nn.init.constant_(m.weight, 1) 215 | nn.init.constant_(m.bias, 0) 216 | 217 | class AddNoise(nn.Module): 218 | def __init__(self, std=0.05): 219 | super().__init__() 220 | self.noise = torch.Tensor([0]) 221 | self.std = std 222 | 223 | def forward(self, x): 224 | device = torch.device(f'cuda:{x.get_device()}') 225 | x = x.cpu() 226 | if self.training and self.std != 0: 227 | scale = self.std * x 228 | noise = self.noise.repeat(*x.size()).normal_() * scale 229 | x = x + noise 230 | x = x.to(device) 231 | return x -------------------------------------------------------------------------------- /trained/D_cifar10.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/8uos/HashGAN-Pytorch/453448ddadcc2e4e33a52288311480eb5272205a/trained/D_cifar10.ckpt -------------------------------------------------------------------------------- /trained/D_mnist.ckpt: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/8uos/HashGAN-Pytorch/453448ddadcc2e4e33a52288311480eb5272205a/trained/D_mnist.ckpt
--------------------------------------------------------------------------------
/trained/G_cifar10.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/8uos/HashGAN-Pytorch/453448ddadcc2e4e33a52288311480eb5272205a/trained/G_cifar10.ckpt
--------------------------------------------------------------------------------
/trained/G_mnist.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/8uos/HashGAN-Pytorch/453448ddadcc2e4e33a52288311480eb5272205a/trained/G_mnist.ckpt
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from tqdm import trange

def get_trmat(bs, rot, trs):
    """ Builds transform matrix to compute consistent bit loss.

    Args:
        bs: Size of minibatch.
        rot: Maximum rotation degree.
        trs: Maximum translation ratio.

    Returns:
        A (bs, 2, 3) batch of affine transform matrices. Every sample in the batch shares the same randomly drawn transform.

    """
    # randint excludes its upper bound, so use rot + 1 to make the range
    # symmetric; this also avoids a ValueError when rot == 0.
    rot = int(np.random.randint(-rot, rot + 1))
    trs = float(np.random.uniform(-trs, trs))
    hf = int(np.random.choice([-1, 1]))

    pi = torch.tensor(np.pi)
    cosR = torch.cos(rot * pi / 180.0)
    sinR = torch.sin(rot * pi / 180.0)

    rotmat = torch.zeros(bs, 3, 3)
    trsmat = torch.zeros(bs, 3, 3)
    hfmat = torch.zeros(bs, 3, 3)

    # Rotation by `rot` degrees.
    rotmat[:, 0, 0] = cosR
    rotmat[:, 0, 1] = -sinR
    rotmat[:, 1, 0] = sinR
    rotmat[:, 1, 1] = cosR
    rotmat[:, 2, 2] = 1.0

    # Translation; the same offset is applied along both axes.
    trsmat[:, 0, 0] = 1.0
    trsmat[:, 0, 2] = trs
    trsmat[:, 1, 1] = 1.0
    trsmat[:, 1, 2] = trs
    trsmat[:, 2, 2] = 1.0

    # Horizontal flip when hf == -1, identity when hf == 1.
    hfmat[:, 0, 0] = hf
    hfmat[:, 1, 1] = 1.0
    hfmat[:, 2, 2] = 1.0

    # Compose translation, rotation, and flip, then drop the homogeneous row.
    mats = [trsmat, rotmat, hfmat]
    theta = mats[0]
    for matidx in range(1, len(mats)):
        theta = torch.matmul(theta, mats[matidx])
    theta = theta[:, :2, :]
    return theta

def set_input(bs, z_dim, b_dim, device):
    """ Makes the input random variable, consisting of a continuous part and a binary part.

    Args:
        bs: Size of minibatch.
        z_dim: The length of a continuous part of input random variable.
        b_dim: The length of a binary part of input random variable.
        device: The device on which the input tensor is created.

    Returns:
        The input random variable, concatenation of z_dim continuous random noise and b_dim binary random noise.
    """
    z_input = torch.FloatTensor(bs, z_dim).uniform_(0, 1).to(device)
    b_input = torch.FloatTensor(bs, b_dim).uniform_(-1, 1).to(device)
    # Map the sign of uniform(-1, 1) samples to bits in {0, 1}.
    b_input = (b_input.sign()+1)/2
    zb_input = torch.cat([z_input, b_input], dim=1)
    return zb_input

def get_prec_topn(query_code, database_code, query_labels, database_labels, topn=1000):
    """ Computes precision@topn with given query and database codes.

    Args:
        query_code: The binary codes of query images. Every element is -1 or 1.
        database_code: The binary codes of database images. Every element is -1 or 1.
        query_labels: One-hot labels of query images.
        database_labels: One-hot labels of database images.
        topn: The number of retrieved images.

    Returns:
        Precision@topn of given query and database codes.
    """
    num_query = query_labels.shape[0]

    mean_topn = 0.0

    for i in trange(num_query):
        # S[j] is 1 when database item j shares a label with query i.
        S = (query_labels[i, :] @ database_labels.t() > 0).float()
        relevant_num = S.sum().item()
        # Queries with no relevant item are skipped but still counted in the mean.
        if not relevant_num:
            continue

        # For codes in {-1, 1}, Hamming distance = (code length - inner product) / 2.
        hamming_dist = 0.5 * (
            database_code.shape[1] - query_code[i, :] @ database_code.t())
        S = S[torch.argsort(hamming_dist)]
        prec_topn = S[:topn].sum().item() / topn
        mean_topn += prec_topn

    mean_topn = mean_topn / num_query

    return mean_topn

def bit_entropy(codes, reduction='mean'):
    """ Computes the entropy of each bit of the given code.

    Args:
        codes: The codes to compute entropy. Every element is expected to lie in [0, 1].
        reduction: The way to reduce the dimension of output. 'mean' and 'sum' are possible; the default is 'mean'.

    Returns:
        The entropy of bits of given code.
    """
    eps = 1e-40
    entropy = -(codes*codes.clamp(eps).log() + (1-codes)*(1-codes).clamp(eps).log())

    if reduction == 'sum':
        entropy_loss = entropy.sum()
    elif reduction == 'mean':
        entropy_loss = entropy.mean()
    else:
        # Fail loudly instead of hitting an unbound entropy_loss below.
        raise ValueError(f"Unsupported reduction {reduction!r}; use 'mean' or 'sum'.")

    return entropy_loss

--------------------------------------------------------------------------------
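The ranking step in `get_prec_topn` relies on the identity that, for codes with entries in {-1, 1}, the Hamming distance equals half of (code length minus inner product). The sketch below illustrates that identity and the precision@topn computation on toy data in plain Python. It is not part of the repository: the helper names `hamming_distance` and `precision_at_n`, the integer labels (the repository uses one-hot labels), and the toy codes are all assumptions made for illustration.

```python
def hamming_distance(query, code):
    """Hamming distance between two codes with entries in {-1, +1}."""
    # Agreeing bits contribute +1 and differing bits -1 to the inner product,
    # so inner = K - 2 * hamming for codes of length K.
    inner = sum(q * c for q, c in zip(query, code))
    return (len(query) - inner) // 2


def precision_at_n(query, query_label, database, database_labels, topn):
    """Fraction of the topn nearest database codes that share the query's label."""
    order = sorted(range(len(database)),
                   key=lambda j: hamming_distance(query, database[j]))
    hits = sum(database_labels[j] == query_label for j in order[:topn])
    return hits / topn


# Toy data (hypothetical): four 4-bit database codes with integer class labels.
database = [[1, 1, 1, 1], [1, 1, 1, -1], [-1, -1, -1, -1], [-1, -1, 1, 1]]
database_labels = [0, 0, 1, 1]
query, query_label = [1, 1, 1, 1], 0

distances = [hamming_distance(query, code) for code in database]
prec = precision_at_n(query, query_label, database, database_labels, topn=2)
print(distances, prec)
```

With `topn=2` the two nearest codes both carry the query's label; widening to `topn=3` pulls in a code of the other class, which is the effect precision@1000 measures at scale.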