├── .gitignore
├── Overall.png
├── README.md
├── config
│   └── awa_config.yml
├── dataset.py
├── modules
│   ├── AdversarialDomainAdaptation
│   │   ├── generator.py
│   │   ├── classifier.py
│   │   ├── discriminator.py
│   │   └── cycleGAN.py
│   ├── baseZSL
│   │   ├── baseZSL.py
│   │   ├── base_fc_layer_c.py
│   │   ├── base_fc_layer_mean.py
│   │   └── base_model.py
│   └── training_utils.py
├── registry.py
├── train_baseZSL.py
└── trainADA.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
code

--------------------------------------------------------------------------------
/Overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vkkhare/ZSL-ADA/HEAD/Overall.png

--------------------------------------------------------------------------------
/modules/training_utils.py:
--------------------------------------------------------------------------------
import registry
import torch

# register(...) returns a decorator; calling it directly records the torch
# classes under a (kind, name) pair so they can be constructed from config
# by name.
registry.register("loss", "mse")(torch.nn.MSELoss)
registry.register("loss", "l1")(torch.nn.L1Loss)
registry.register("loss", "nll")(torch.nn.NLLLoss)

registry.register("optimizer", "rms")(torch.optim.RMSprop)

--------------------------------------------------------------------------------
/config/awa_config.yml:
--------------------------------------------------------------------------------
dataset:
  file: "../../dataset/AWA2.mat"
  dset_name: "AWA2"
  additional_attributes: ""
  write_location: "../FinalWeights/"

model:
  base_zsl_net:
    # these names must match the keys registered in modules/baseZSL/
    FC_layer_m: "awa2"
    FC_layer_c: "awa2"
  ADA:
    classifier: "linear"
    discriminator: "lnr_cls"  # registered in discriminator.py
    generator: "V1"           # registered in generator.py

--------------------------------------------------------------------------------
/modules/AdversarialDomainAdaptation/generator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

import registry


@registry.register("generator", "V1")
class Generator(nn.Module):
    """Maps a 2048-d feature vector to a 2048-d feature in the other domain."""

    def __init__(self):
        super(Generator, self).__init__()

        self.generator = nn.Sequential(
            nn.Linear(2048, 1200),
            nn.Dropout(0.5),
            nn.BatchNorm1d(1200),
            nn.LeakyReLU(0.2),
            nn.Linear(1200, 1200),
            nn.Dropout(0.5),
            nn.BatchNorm1d(1200),
            nn.LeakyReLU(0.2),
            nn.Linear(1200, 2048)
        )

    def forward(self, x):
        return self.generator(x)

--------------------------------------------------------------------------------
/modules/baseZSL/baseZSL.py:
--------------------------------------------------------------------------------
import shutil

import torch
import torch.nn as nn

import registry


@registry.register('base_zsl', 'zsl_cls')
class BaseZSLNet(nn.Module):
    """Predicts per-class Gaussian parameters (mean and diagonal inverse
    covariance) from class attribute vectors."""

    def __init__(self, layer_m, layer_c):
        super(BaseZSLNet, self).__init__()

        self.FC_layer_m = registry.construct("layer_m", layer_m)
        self.FC_layer_c = registry.construct("layer_c", layer_c)

    def forward(self, x):
        mean = self.FC_layer_m(x)
        # squash the inverse variances into (0.5, 1.5) for numerical stability
        cov = 0.5 + torch.sigmoid(self.FC_layer_c(x))
        return mean, cov


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

--------------------------------------------------------------------------------
/modules/AdversarialDomainAdaptation/classifier.py:
--------------------------------------------------------------------------------
import torch.nn as nn

import registry


# Registered so the config's `classifier: "linear"` entry resolves here.
@registry.register("classifier", "linear")
class Classifier(nn.Module):

    def __init__(self, num_classes):
        super(Classifier, self).__init__()

        # A deeper head (2048 -> 1600 -> 800 -> num_classes with dropout and
        # batch norm) was tried and left commented out in the original code;
        # the single linear layer below is what is actually used.
        self.classifier = nn.Sequential(
            nn.Linear(2048, num_classes),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        return self.classifier(x)
--------------------------------------------------------------------------------
/modules/AdversarialDomainAdaptation/discriminator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

import registry


@registry.register("disc_cls", "lnr_cls")
class Discriminator_Class(nn.Module):
    """Joint critic and classifier: `discriminator` scores how real a feature
    looks, `classify` predicts its class label."""

    def __init__(self, num_classes):
        super(Discriminator_Class, self).__init__()

        self.discriminator = nn.Sequential(
            nn.Linear(2048, 1600),
            nn.Dropout(0.5),
            nn.BatchNorm1d(1600),
            nn.LeakyReLU(0.2),
            nn.Linear(1600, 1)
        )

        self.classify = nn.Sequential(
            nn.Linear(2048, num_classes),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        valid = self.discriminator(x)
        label = self.classify(x)
        return valid, label

--------------------------------------------------------------------------------
/registry.py:
--------------------------------------------------------------------------------
import collections
import collections.abc
import inspect
import sys


LOOKUP_DICT = collections.defaultdict(dict)


def register(kind, name):
    """Return a decorator that records the decorated object under
    LOOKUP_DICT[kind][name]. (Previously named `load`, but invoked as
    `registry.register` throughout the codebase.)"""
    registry = LOOKUP_DICT[kind]

    def decorator(obj):
        if name in registry:
            raise LookupError('{} already present under "{}"'.format(name, kind))
        registry[name] = obj
        return obj

    return decorator


def lookup(kind, name):
    if isinstance(name, collections.abc.Mapping):
        name = name['name']

    if kind not in LOOKUP_DICT:
        raise KeyError('Nothing registered under "{}"'.format(kind))
    return LOOKUP_DICT[kind][name]


def construct(kind, config, unused_keys=(), **kwargs):
    # Allow bare string configs such as `generator: "V1"` in the yml files.
    if isinstance(config, str):
        config = {'name': config}
    return instantiate(
        lookup(kind, config),
        config,
        unused_keys + ('name',),
        **kwargs)


def instantiate(callable, config, unused_keys=(), **kwargs):
    merged = {**config, **kwargs}
    signature = inspect.signature(callable)
    for name, param in signature.parameters.items():
        if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL):
            raise ValueError('Unsupported kind for param {}: {}'.format(name, param.kind))

    if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
        return callable(**merged)

    superfluous = {}
    for key in list(merged.keys()):
        if key not in signature.parameters:
            if key not in unused_keys:
                superfluous[key] = merged[key]
            merged.pop(key)
    if superfluous:
        print('WARNING {}: superfluous {}'.format(callable, superfluous), file=sys.stderr)
    return callable(**merged)
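A minimal usage sketch of the registry (the `toy`/`Affine` names below are illustrative, not part of the repo):

```python
import registry

@registry.register("toy", "affine")
class Affine:
    def __init__(self, scale, shift=0.0):
        self.scale, self.shift = scale, shift

    def __call__(self, x):
        return self.scale * x + self.shift

# construct() resolves the class via config["name"] and forwards the remaining
# keys as keyword arguments; unknown keys produce a warning, not an error.
f = registry.construct("toy", {"name": "affine", "scale": 2.0})
print(f(3.0))  # 6.0
```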
--------------------------------------------------------------------------------
/train_baseZSL.py:
--------------------------------------------------------------------------------
import sys

import torch
import torch.utils.data as data
import yaml

# Imported for their registration side effects.
import modules.training_utils  # noqa: F401 -- registers losses/optimizers
import modules.baseZSL.baseZSL  # noqa: F401 -- registers the base ZSL net
import modules.baseZSL.base_fc_layer_mean  # noqa: F401
import modules.baseZSL.base_fc_layer_c  # noqa: F401
from dataset import MeanCovDataset, save_checkpoint
from modules.baseZSL.base_model import BaseZSL

config = sys.argv[1]

# config = "running_configs/job_config.yml"
with open(config, 'r') as ymlfile:
    cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)

BatchSize = cfg['batch_size']
num_epochs = cfg['num_epochs']
device = torch.device('cuda:1')
kwargs = {'num_workers': 4, 'pin_memory': True}
learning_rate = cfg['lr']

train_dataset = MeanCovDataset(cfg['dataset']['file'], False)
test_dataset = MeanCovDataset(cfg['dataset']['file'], True)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BatchSize, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=BatchSize, shuffle=True, **kwargs)

# Neural network for learning the relation between attribute vectors and the
# per-class Gaussian parameters (cfg must supply base_zsl, optimizer, decay).
zslNet = BaseZSL(cfg['base_zsl'], cfg['optimizer'], cfg['decay'], learning_rate, device)
global_loss = []
partial_loss = []

for epoch in range(1, num_epochs + 1):
    acc_class = {}
    count_class = {}
    zslNet.train_zsl(train_loader, epoch, partial_loss)
    zslNet.test(test_loader, epoch, global_loss, acc_class, count_class)

    if epoch % 100 == 0 or global_loss[-1] > 0.70:
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': zslNet.model.state_dict(),
            'optimizer': zslNet.optimizer.state_dict(),
        }, False, 'AwaTest/awanet_Awa_' + str(epoch) + '_' + str(global_loss[-1]) + '_.pth.tar')

print('\n Max Res: Average accuracy: {:.8f},\n'.format(max(global_loss)))

--------------------------------------------------------------------------------
/modules/baseZSL/base_fc_layer_c.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

import registry


@registry.register("layer_c", "awa2")
class AwaLayerCov(nn.Module):
    def __init__(self):
        super(AwaLayerCov, self).__init__()
        self.fc_layer = nn.Sequential(
            nn.Linear(85, 1600),
            nn.BatchNorm1d(1600),
            nn.LeakyReLU(0.95),
            nn.Linear(1600, 2048),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        return self.fc_layer(x)


@registry.register("layer_c", "sun")
class SunLayerCov(nn.Module):
    def __init__(self):
        super(SunLayerCov, self).__init__()
        self.fc_layer = nn.Sequential(
            nn.Linear(102, 1800),
            nn.BatchNorm1d(1800),
            nn.ReLU(),
            nn.Linear(1800, 2048),
            nn.Dropout(0.05)
        )

    def forward(self, x):
        return self.fc_layer(x)


@registry.register("layer_c", "cub")
class CubLayerCov(nn.Module):
    def __init__(self):
        super(CubLayerCov, self).__init__()
        self.fc_layer = nn.Sequential(
            nn.Linear(312, 1200),
            nn.BatchNorm1d(1200),
            nn.ReLU(),
            nn.Linear(1200, 1800),
            nn.Dropout(0.1),
            nn.BatchNorm1d(1800),
            nn.ReLU(),
            nn.Linear(1800, 2048),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        return self.fc_layer(x)


@registry.register("layer_c", "cub_large")
class CubLargeLayerCov(nn.Module):
    def __init__(self):
        super(CubLargeLayerCov, self).__init__()
        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 1500),
            nn.BatchNorm1d(1500),
            nn.ReLU(),
            nn.Linear(1500, 1800),
            nn.Dropout(0.3),
            nn.BatchNorm1d(1800),
            nn.ReLU(),
            nn.Linear(1800, 2048),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        return self.fc_layer(x)
--------------------------------------------------------------------------------
/modules/baseZSL/base_fc_layer_mean.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

import registry


@registry.register("layer_m", "awa2")
class AwaLayerMean(nn.Module):
    def __init__(self):
        super(AwaLayerMean, self).__init__()
        self.fc_layer = nn.Sequential(
            nn.Linear(85, 1600),
            nn.BatchNorm1d(1600),
            nn.LeakyReLU(0.95),
            nn.Linear(1600, 2048),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        return self.fc_layer(x)


@registry.register("layer_m", "sun")
class SunLayerMean(nn.Module):
    def __init__(self):
        super(SunLayerMean, self).__init__()
        self.fc_layer = nn.Sequential(
            nn.Linear(102, 1800),
            nn.BatchNorm1d(1800),
            nn.ReLU(),
            nn.Linear(1800, 2048),
            nn.Dropout(0.05)
        )

    def forward(self, x):
        return self.fc_layer(x)


@registry.register("layer_m", "cub")
class CubLayerMean(nn.Module):
    def __init__(self):
        super(CubLayerMean, self).__init__()
        self.fc_layer = nn.Sequential(
            nn.Linear(312, 1200),
            nn.BatchNorm1d(1200),
            nn.ReLU(),
            nn.Linear(1200, 1800),
            nn.Dropout(0.1),
            nn.BatchNorm1d(1800),
            nn.ReLU(),
            nn.Linear(1800, 2048),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        return self.fc_layer(x)


@registry.register("layer_m", "cub_large")
class CubLargeLayerMean(nn.Module):
    def __init__(self):
        super(CubLargeLayerMean, self).__init__()
        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 1500),
            nn.BatchNorm1d(1500),
            nn.LeakyReLU(0.2),
            nn.Linear(1500, 1800),
            nn.Dropout(0.3),
            nn.BatchNorm1d(1800),
            nn.LeakyReLU(0.2),
            nn.Linear(1800, 2048),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        return self.fc_layer(x)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ZSL-ADA
Code accompanying the paper [A Generative Framework for Zero Shot Learning with Adversarial Domain Adaptation](https://arxiv.org/abs/1906.03038), published in [WACV 2020](http://wacv20.wacv.net/), by [Varun Khare*](https://vkkhare.github.io/), [Divyat Mahajan*](https://divyat09.github.io/), [Homanga Bharadhwaj](https://homangab.github.io/), Vinay K. Verma, and Piyush Rai

# Instructions

Modify the file `config/awa_config.yml` if needed and run `python trainADA.py config/awa_config.yml` to start training.
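Besides the `dataset` and `model` blocks in the shipped config, the training scripts also read `batch_size`, `num_epochs`, and `lr` from the same file. The values below are illustrative placeholders, not recommended settings:

```yaml
batch_size: 64
num_epochs: 500
lr: 0.0001
```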
# Brief note about the paper

This paper presents a domain-adaptation-based generative framework for zero-shot learning. We address the problem of domain shift between the seen- and unseen-class distributions in Zero-Shot Learning (ZSL) and seek to minimize it by developing a generative model and training it via adversarial domain adaptation. Our approach is based on end-to-end learning of the class distributions of both seen and unseen classes. To enable the model to learn the class distributions of unseen classes, we parameterize these class distributions in terms of the class attribute information (which is available for both seen and unseen classes). This provides a very simple way to learn the class distribution of any unseen class, given only its class attribute information and no labeled training data. Training this model with adversarial domain adaptation provides robustness against the distribution mismatch between the data from seen and unseen classes. It also engenders a novel way of training neural-network-based classifiers that overcomes the hubness problem in zero-shot learning. Through a comprehensive set of experiments, we show that our model yields superior accuracies, compared to various state-of-the-art zero-shot learning models, on a variety of benchmark datasets.

![Overall architecture](Overall.png)
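To make the idea concrete, here is a minimal sketch of how a class's feature distribution is derived from its attribute vector and then sampled via the reparameterization trick. The dimensions and layer sizes follow the AWA2 networks in `modules/baseZSL`, but the sketch itself is illustrative, not the repo's training code:

```python
import torch
import torch.nn as nn

attr_dim, feat_dim = 85, 2048  # AWA2 attribute / ResNet feature sizes

mean_net = nn.Sequential(nn.Linear(attr_dim, 1600), nn.LeakyReLU(0.95),
                         nn.Linear(1600, feat_dim))
cov_net = nn.Sequential(nn.Linear(attr_dim, 1600), nn.LeakyReLU(0.95),
                        nn.Linear(1600, feat_dim))

a_c = torch.randn(1, attr_dim)            # a class attribute vector (dummy here)
mean = mean_net(a_c)
cov = 0.5 + torch.sigmoid(cov_net(a_c))   # bounded in (0.5, 1.5), as in the code

# Reparameterization trick: synthetic features for this (possibly unseen) class.
eps = torch.randn(32, feat_dim)
fake_features = mean + eps * cov
```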
If you have questions/comments about the code or the paper, please contact [Varun Khare](http://home.iitk.ac.in/~varun/), [Divyat Mahajan](https://divyat09.github.io/), or [Homanga Bharadhwaj](https://homangab.github.io/).

# If you find this repo useful, please consider citing our paper

```bibtex
@inproceedings{khare2020generative,
  title={A Generative Framework for Zero Shot Learning with Adversarial Domain Adaptation},
  author={Khare, Varun and Mahajan, Divyat and Bharadhwaj, Homanga and Verma, Vinay Kumar and Rai, Piyush},
  booktitle={The IEEE Winter Conference on Applications of Computer Vision},
  pages={3101--3110},
  year={2020}
}
```
--------------------------------------------------------------------------------
/trainADA.py:
--------------------------------------------------------------------------------
import datetime
import sys
import time

import torch
import yaml

import registry
# Imported for their registration side effects.
import modules.training_utils  # noqa: F401 -- registers losses/optimizers
import modules.baseZSL.baseZSL  # noqa: F401 -- registers the base ZSL net
import modules.baseZSL.base_fc_layer_mean  # noqa: F401
import modules.baseZSL.base_fc_layer_c  # noqa: F401
import modules.AdversarialDomainAdaptation.generator  # noqa: F401
import modules.AdversarialDomainAdaptation.discriminator  # noqa: F401
import modules.AdversarialDomainAdaptation.cycleGAN  # noqa: F401
from dataset import MeanCovDataset, make_variable
from modules.baseZSL.base_model import BaseZSL

config = sys.argv[1]

# config = "running_configs/job_config.yml"
with open(config, 'r') as ymlfile:
    cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)

BatchSize = cfg['batch_size']
num_epochs = cfg['num_epochs']
device = torch.device('cuda:1')
kwargs = {'num_workers': 4, 'pin_memory': True}
learning_rate = cfg['lr']

train_dataset = MeanCovDataset(cfg['dataset']['file'], False)
test_dataset = MeanCovDataset(cfg['dataset']['file'], True)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BatchSize, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=BatchSize, shuffle=True, **kwargs)
dt = test_dataset

# CycleGAN-style adversarial domain adaptation module plus the pretrained base
# ZSL network that supplies the unseen-class Gaussians.
zslNet = registry.construct('ada', cfg['ada'], numTestClass=dt.test_class_dim,
                            lr=learning_rate)
zslNet.toCuda()
awa = BaseZSL(cfg['base_zsl'], cfg['optimizer'], cfg['decay'], learning_rate, device)
awa.load(cfg['baseZSL_checkpoint'])

prev_time = time.time()

for epoch in range(num_epochs):
    for i, batch in enumerate(test_loader):
        zslNet.D_A.train()
        real_A = batch['feature'].cuda()

        #### Generation from awaNet
        C_all = torch.Tensor(dt.AttributeData[dt.testClassLabels - 1])[:, 0, :].cuda()
        dummy_labels = torch.arange(0, len(dt.testClassLabels)).cuda()

        # Tile the unseen-class attributes so the synthetic batch matches the
        # real batch size (see the illustration after this file).
        iterates = int(real_A.shape[0] // len(dt.testClassLabels))
        left = real_A.shape[0] % len(dt.testClassLabels)
        sample_input = C_all.repeat(iterates, 1)
        labels_B = dummy_labels.repeat(iterates)

        if left != 0:
            sample_input = torch.cat((sample_input, C_all[:left, :]), 0)
            labels_B = torch.cat((labels_B, dummy_labels[:left]), 0)
        means, covs = awa.model(sample_input)

        ## reparametrisation trick
        noise = torch.randn_like(real_A)
        real_B = means.detach() + noise * covs.detach()

        # prepare real and fake labels; pseudo-label the real unseen-class
        # features with the base ZSL network
        valid = make_variable(torch.ones(real_A.size(0), 1).type(torch.FloatTensor))
        fake = make_variable(torch.zeros(real_B.size(0), 1).type(torch.FloatTensor))
        labels_A_soft, mask = awa.predict_labels(real_A, dt)

        loss_G, loss_cycle, loss_identity, loss_GAN = zslNet.trainGenerator(
            real_A, real_B, labels_A_soft, labels_B)
        for iterate_disc in range(5):
            loss_D = zslNet.trainDiscriminators(real_A, real_B, labels_A_soft,
                                                labels_B, epoch)

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(test_loader) + i
        batches_left = num_epochs * len(test_loader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        print("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s" %
              (epoch, num_epochs,
               i, len(test_loader),
               loss_D.item(), loss_G.item(),
               loss_GAN.item(), loss_cycle.item(),
               loss_identity.item(), time_left))

    zslNet.test_cycle(real_B, labels_B, test_loader)
    zslNet.test_cycle_mapped(awa.model, real_B, labels_B, test_loader)
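The tiling step above simply cycles through every unseen class until the synthetic batch matches the real batch size; a small self-contained illustration:

```python
import torch

num_classes, batch = 3, 8
C_all = torch.eye(num_classes)            # stand-in for per-class attributes
dummy_labels = torch.arange(num_classes)

iterates, left = batch // num_classes, batch % num_classes
sample_input = C_all.repeat(iterates, 1)
labels = dummy_labels.repeat(iterates)
if left:
    sample_input = torch.cat((sample_input, C_all[:left]), 0)
    labels = torch.cat((labels, dummy_labels[:left]), 0)

print(sample_input.shape, labels)  # torch.Size([8, 3]) tensor([0, 1, 2, 0, 1, 2, 0, 1])
```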
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import shutil

import hdf5storage
import numpy as np
import torch
import torch.utils.data as data
from torch.autograd import Variable

seed = 100
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


def load_checkpoint(file, model, optimizer, best_prec1=None):
    if os.path.isfile(file):
        print("=> loading checkpoint '{}'".format(file))
        checkpoint = torch.load(file)
        start_epoch = checkpoint['epoch']
        # best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(file, checkpoint['epoch']))
        return start_epoch
    else:
        print("=> no checkpoint found at '{}'".format(file))
        return 0


def make_variable(tensor, volatile=False):
    """Move a tensor to GPU (if available) and wrap it in a Variable."""
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return Variable(tensor, volatile=volatile)


class MeanCovDataset(data.Dataset):
    def __init__(self, mat_file, test_dataset=False, transform=None,
                 generalized=False, additional_attributes=None):
        mat = hdf5storage.loadmat(mat_file)
        self.test_bool = test_dataset
        # Loading data: from mat format to arrays
        # (Train_Classes_Size, 1)
        self.trainClassLabels = mat['trainClassLabels'].astype(int)
        # (Test_Classes_Size, 1)
        self.testClassLabels = mat['testClassLabels'].astype(int)
        # 40 for AwA
        self.train_class_dim = len(self.trainClassLabels)
        # 10 for AwA
        self.test_class_dim = len(self.testClassLabels)
        # feat matrices are stored (D, N), hence the transposes below
        self.TestData = np.array(mat['test_feat'], dtype='float32').T

        # CUB with the Reed et al. attributes ships them in a separate .npy
        # file (this replaces an undefined `case == 4` flag in the original).
        if additional_attributes:
            self.AttributeData = np.float32(np.load(additional_attributes))
        else:
            self.AttributeData = np.array(mat['classAttributes'], dtype='float32').T  # C*D
        self.TrainData = np.array(mat['train_feat'], dtype='float32').T  # N*D
        self.TrainLabels = np.array(mat['train_labels'])  # N*1
        self.TestLabels = np.array(mat['test_labels'])
        self.AttributeDim = np.array(mat['classAttributes']).shape[0]
        [self.FeatureDim, self.TrainSize] = np.array(mat['train_feat']).shape
        self.transform = transform

        if generalized:
            # Move a fifth of the seen-class data into the test split.
            indices = np.random.choice(self.TrainSize, int(self.TrainSize / 5), replace=False)

            self.TestData = np.concatenate((self.TestData, self.TrainData[indices]), axis=0)
            self.TestLabels = np.concatenate((self.TestLabels, self.TrainLabels[indices]), axis=0)
            self.TrainData = np.delete(self.TrainData, indices, 0)
            self.TrainLabels = np.delete(self.TrainLabels, indices, 0)
            self.TrainSize = len(self.TrainData)
            self.trainClassLabels = np.unique(self.TrainLabels)
            self.testClassLabels = np.unique(self.TestLabels)
            self.train_class_dim = len(self.trainClassLabels)
            self.test_class_dim = len(self.testClassLabels)

    def __len__(self):
        if not self.test_bool:
            return self.TrainSize
        else:
            return len(self.TestData)

    def __getitem__(self, idx):
        if not self.test_bool:
            x_n = self.TrainData[idx, :]
            class_label = int(self.TrainLabels[idx])
            label_index = -1  # unused for the training split
        else:
            x_n = self.TestData[idx, :]
            class_label = int(self.TestLabels[idx])
            label_index = np.argwhere(self.testClassLabels == class_label)[0][0]

        class_attribute = self.AttributeData[class_label - 1, :]
        sample = {'feature': x_n, 'class_label': class_label,
                  'attribute': class_attribute, 'label_index': label_index}
        if self.transform:
            sample['feature'] = self.transform(sample['feature'])
            sample['class_label'] = self.transform(sample['class_label'])
            sample['attribute'] = self.transform(sample['attribute'])
        return sample
--------------------------------------------------------------------------------
/modules/baseZSL/base_model.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch
import torch.nn.functional as F

import registry


class BaseZSL:
    def __init__(self, base_model, optimizer, decay, lr, device):
        self.model = registry.construct('base_zsl', base_model).to(device)
        self.optimizer = registry.construct(
            'optimizer', optimizer,
            params=[
                {'params': filter(lambda p: p.requires_grad, self.model.FC_layer_m.parameters()),
                 'weight_decay': decay['fc_m']},
                {'params': filter(lambda p: p.requires_grad, self.model.FC_layer_c.parameters()),
                 'weight_decay': decay['fc_c']}
            ],
            lr=lr)
        self.device = device

    def load(self, file):
        if os.path.isfile(file):
            print("=> loading checkpoint '{}'".format(file))
            checkpoint = torch.load(file)
            self.model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(file))
        else:
            print("=> no checkpoint found at '{}'".format(file))
            return 0

    def train_zsl(self, train_loader, epoch, losslist):
        self.model.train()
        b_idx = 0
        for x in train_loader:
            b_idx += 1
            x_feat = x['feature'].to(self.device)
            label = x['class_label'].type(torch.LongTensor).to(self.device)
            attribute = x['attribute'].to(self.device)
            self.optimizer.zero_grad()
            means, covs = self.model(attribute)
            # Negative Gaussian log-likelihood (up to constants and a factor
            # of 1/2 on the quadratic term), with `covs` as inverse variances.
            loss_eval = torch.sum((x_feat - means) * covs * (x_feat - means)) \
                - torch.sum(torch.log(covs)) / 2
            loss_eval.backward()
            self.optimizer.step()

            if b_idx % 6 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, b_idx * x_feat.shape[0], len(train_loader.dataset),
                    100. * b_idx / len(train_loader), loss_eval.item()))

        losslist.append(loss_eval.item())

    def predict_labels(self, x_feat, dataset):
        self.model.eval()
        with torch.no_grad():
            C_all = torch.Tensor(dataset.AttributeData[dataset.testClassLabels - 1])[:, 0, :].to(self.device)
            means, Covs = self.model(C_all)
            PredMat = torch.zeros((x_feat.shape[0], len(dataset.testClassLabels)),
                                  dtype=torch.float32).to(self.device)
            # uniform class prior (currently unused)
            class_probab = torch.ones(len(dataset.testClassLabels)).to(self.device) / len(dataset.testClassLabels)
            for Iter in range(len(dataset.testClassLabels)):
                Mean = means[Iter, :]
                CovarianceI = Covs[Iter, :]  # diagonal of the inverse covariance

                logDet = torch.sum(torch.log(CovarianceI))
                logExp = -1 * torch.sum((x_feat - Mean) * CovarianceI * (x_feat - Mean), dim=1)
                Likelihood = logExp + logDet / 2
                PredMat[:, Iter] = Likelihood  # likelihood of the whole batch under class Iter

            P_c_x = F.softmax(PredMat, 1)
            mask = torch.max(P_c_x, 1)[0] > 0.5
            labels = torch.argmax(P_c_x, dim=1)
            return labels.detach(), mask.unsqueeze(1)

    def test(self, test_loader, epoch, losslist, acc_class, count_class):
        self.model.eval()
        test_loss = 0    # mean per-class accuracy, despite the name
        test_loss_2 = 0  # overall (per-sample) accuracy
        with torch.no_grad():
            dt = test_loader.dataset
            C_all = torch.Tensor(dt.AttributeData[dt.testClassLabels - 1])[:, 0, :].to(self.device)
            means, Covs = self.model(C_all)

            for x in test_loader:
                x_feat, TrueLabel = x['feature'].to(self.device), x['class_label'].to(self.device)
                PredMat = torch.zeros((x_feat.shape[0], len(dt.testClassLabels)),
                                      dtype=torch.float32).to(self.device)
                for Iter in range(len(dt.testClassLabels)):
                    Mean = means[Iter, :]
                    CovarianceI = Covs[Iter, :]

                    logDet = torch.sum(torch.log(CovarianceI))
                    logExp = -1 * torch.sum((x_feat - Mean) * CovarianceI * (x_feat - Mean), dim=1)
                    Likelihood = logExp + logDet / 2
                    PredMat[:, Iter] = Likelihood  # likelihood of the whole batch under class Iter

                PredLabel = torch.argmax(PredMat, dim=1)  # index of max along each row
                PredLabel = PredLabel.cpu().numpy()
                TrueLabel = TrueLabel.cpu().numpy()
                BatchAcc = (dt.testClassLabels[PredLabel].reshape(x_feat.shape[0]) - TrueLabel == 0) + 0

                for i in range(0, BatchAcc.shape[0]):
                    if TrueLabel[i] not in acc_class.keys():
                        acc_class[TrueLabel[i]] = BatchAcc[i]
                        count_class[TrueLabel[i]] = 1
                    else:
                        acc_class[TrueLabel[i]] += BatchAcc[i]
                        count_class[TrueLabel[i]] += 1

                test_loss_2 += np.sum(BatchAcc)
                # print(dt.testClassLabels[PredLabel].reshape(x_feat.shape[0]))  # debug
            # print("means \n", means, "means>0", means[means > 1].shape, "covs \n", Covs)  # debug

            for key in acc_class.keys():
                test_loss += acc_class[key] / count_class[key]

            print(acc_class, '\n')
            test_loss /= len(dt.testClassLabels)
            test_loss_2 /= len(test_loader.dataset.TestData)
            print('\n Test set: Average per-class accuracy: {:.8f},\n'.format(test_loss))

            losslist.append(test_loss)
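For reference, the class score accumulated in `PredMat` above is, up to an additive constant and a factor of ½ on the quadratic term (which the code drops), the log-density of a diagonal Gaussian with mean $\mu_c$ and inverse variances $s_c$:

```latex
\log \mathcal{N}\!\left(x \mid \mu_c, \operatorname{diag}(s_c)^{-1}\right)
  = \tfrac{1}{2}\textstyle\sum_i \log s_{c,i}
  - \tfrac{1}{2}\textstyle\sum_i s_{c,i}\,(x_i - \mu_{c,i})^2
  - \tfrac{D}{2}\log 2\pi,
\qquad
\hat{y}(x) = \arg\max_c \operatorname{PredMat}[x, c]
```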
--------------------------------------------------------------------------------
/modules/AdversarialDomainAdaptation/cycleGAN.py:
--------------------------------------------------------------------------------
import itertools

import numpy as np
import torch

import registry


@registry.register('ada', 'cycleGAN')
class CycleGAN:
    def __init__(self, generator, discriminator, criterions, optimizers,
                 numTestClass, lr, lambda_cyc=10.0, lambda_id=5.0, lmda=1.0):
        # NOTE: lambda_cyc / lambda_id / lmda were referenced but never defined
        # in the original code; the defaults here are placeholders meant to be
        # overridden from the config.
        self.criterion_dist = registry.construct("loss", criterions["c_dist"])
        self.criterion_GAN = registry.construct("loss", criterions["c_GAN"])
        self.criterion_cycle = registry.construct("loss", criterions["c_cyc"])
        self.criterion_identity = registry.construct("loss", criterions["c_id"])
        self.criterion_task_loss = registry.construct("loss", criterions["c_task"])
        self.G_AB = registry.construct("generator", generator)
        self.G_BA = registry.construct("generator", generator)
        self.D_A = registry.construct("disc_cls", discriminator, num_classes=numTestClass)
        self.D_B = registry.construct("disc_cls", discriminator, num_classes=numTestClass)
        self.optimizer_G = registry.construct(
            "optimizer", optimizers["generators"],
            params=itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()),
            lr=lr)
        self.optimizer_D_A = registry.construct(
            "optimizer", optimizers["disc_A"], params=self.D_A.parameters(), lr=lr)
        self.optimizer_D_B = registry.construct(
            "optimizer", optimizers["disc_B"], params=self.D_B.parameters(), lr=lr)
        self.lambda_cyc = lambda_cyc
        self.lambda_id = lambda_id
        self.lmda = lmda
        self.g_loss = []  # discriminator test-accuracy history
        self.p_loss = []  # generator-mapped means test-accuracy history

    def toCuda(self):
        self.G_AB.cuda()
        self.G_BA.cuda()
        self.D_A.cuda()
        self.D_B.cuda()
        self.criterion_GAN.cuda()
        self.criterion_cycle.cuda()
        self.criterion_identity.cuda()
        self.criterion_task_loss.cuda()
        self.criterion_dist.cuda()

    def load(self, gab_file, gba_file, da_file, db_file):
        self.G_AB.load_state_dict(torch.load(gab_file))
        self.G_BA.load_state_dict(torch.load(gba_file))
        self.D_A.load_state_dict(torch.load(da_file))
        self.D_B.load_state_dict(torch.load(db_file))

    def predict(self, x):
        self.D_A.eval()
        _, predictions = self.D_A(x)
        labels = torch.argmax(predictions, dim=1)  # index of max along each row
        mask = torch.max(predictions, 1)[0] > 0.1
        return labels.detach(), mask.unsqueeze(1)

    def test_cycle(self, real_B, labels, test_loader):
        self.D_A.eval()
        test_loss = 0
        acc_class = {}
        count_class = {}
        dt = test_loader.dataset
        samples = self.G_BA(real_B).detach()
        with torch.no_grad():
            for x in test_loader:
                x_feat, TrueLabel = x['feature'].cuda(), x['class_label'].cuda()
                _, predictions = self.D_A(x_feat)
                PredLabel = torch.argmax(predictions, dim=1).cpu()  # index of max along each row
                TrueLabel = TrueLabel.cpu().numpy()
                BatchAcc = (dt.testClassLabels[PredLabel].reshape(x_feat.shape[0]) - TrueLabel == 0) + 0

                for i in range(0, BatchAcc.shape[0]):
                    if TrueLabel[i] not in acc_class.keys():
                        acc_class[TrueLabel[i]] = BatchAcc[i]
                        count_class[TrueLabel[i]] = 1
                    else:
                        acc_class[TrueLabel[i]] += BatchAcc[i]
                        count_class[TrueLabel[i]] += 1

            # Mean per-class accuracy (the original code mixed per-sample and
            # per-class counts into one variable).
            for key in acc_class.keys():
                test_loss += acc_class[key] / count_class[key]
            _, predictions = self.D_A(samples)
            PredLabel = torch.argmax(predictions, dim=1)  # index of max along each row
            sample_acc = torch.sum(PredLabel - labels == 0).item()

            test_loss /= len(dt.testClassLabels)
            sample_acc /= labels.shape[0]

            print('\n Test set discriminator prediction: Average acc: {:.8f},\n'.format(test_loss))
            self.g_loss.append(test_loss)
            print('\n Sample set acc: Average acc: {:.8f},\n'.format(sample_acc))

    def test_cycle_mapped(self, awa, real_B, labels, test_loader):
        awa.eval()
        self.G_BA.eval()
        test_loss = 0
        acc_class = {}
        count_class = {}
        dt = test_loader.dataset
        samples = self.G_BA(real_B).detach()
        with torch.no_grad():
            C_all = torch.Tensor(dt.AttributeData[dt.testClassLabels - 1])[:, 0, :].cuda()
            means, Covs = awa(C_all)
            up_means = torch.zeros(means.shape[0], means.shape[1]).cuda()
            # Map 100 samples from each class Gaussian through G_BA and take
            # their mean as the updated class centre.
            for i in range(means.shape[0]):
                noise = torch.randn(100, means.shape[1]).cuda()
                noise_Vec = noise * Covs[i] + means[i]
                map_cluster = self.G_BA(noise_Vec)
                up_means[i, :] = torch.sum(map_cluster, 0) / 100
            for x in test_loader:
                x_feat, TrueLabel = x['feature'].cuda(), x['class_label'].cuda()
                PredMat = torch.zeros((x_feat.shape[0], len(dt.testClassLabels)), dtype=torch.float32).cuda()
                for Iter in range(len(dt.testClassLabels)):
                    Mean = up_means[Iter, :]
                    CovarianceI = Covs[Iter, :]

                    logDet = torch.sum(torch.log(CovarianceI))
                    logExp = -1 * torch.sum((x_feat - Mean) * CovarianceI * (x_feat - Mean), dim=1)
                    Likelihood = logExp + logDet / 2
                    PredMat[:, Iter] = Likelihood  # likelihood of the whole batch under class Iter

                PredLabel = torch.argmax(PredMat, dim=1)  # index of max along each row

                TrueLabel = TrueLabel.cpu().numpy()
                BatchAcc = (dt.testClassLabels[PredLabel.cpu()].reshape(x_feat.shape[0]) - TrueLabel == 0) + 0

                for i in range(0, BatchAcc.shape[0]):
                    if TrueLabel[i] not in acc_class.keys():
                        acc_class[TrueLabel[i]] = BatchAcc[i]
                        count_class[TrueLabel[i]] = 1
                    else:
                        acc_class[TrueLabel[i]] += BatchAcc[i]
                        count_class[TrueLabel[i]] += 1

            for key in acc_class.keys():
                test_loss += acc_class[key] / count_class[key]
            test_loss /= np.unique(TrueLabel).shape[0]
            print('\nGen Mapped Test set means knn: Average Acc: {:.8f},\n'.format(test_loss))
        self.G_BA.train()
        self.p_loss.append(test_loss)

    def trainGenerator(self, real_A, real_B, labels_A_soft, labels_B):
        self.optimizer_G.zero_grad()

        # Identity loss
        loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
        loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = self.G_AB(real_A)
        validity_B, pred_labels_B = self.D_B(fake_B)
        loss_GAN_AB = -torch.mean(validity_B)
        fake_A = self.G_BA(real_B)
        validity_A, pred_labels_A = self.D_A(fake_A)
        loss_GAN_BA = -torch.mean(validity_A)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Task loss: the discriminators' class heads should recover the
        # (pseudo-)labels on the translated features.
        loss_task_B = self.criterion_task_loss(pred_labels_B * validity_B,
                                               labels_B)
        loss_task_A = self.criterion_task_loss(pred_labels_A * validity_A,
                                               labels_A_soft)
        task_loss = (loss_task_A + loss_task_B) / 2

        # Cycle loss
        recov_A = self.G_BA(fake_B)
        loss_cycle_A = self.criterion_cycle(recov_A, real_A)
        recov_B = self.G_AB(fake_A)
        loss_cycle_B = self.criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + \
            self.lambda_cyc * loss_cycle + \
            self.lambda_id * loss_identity + self.lmda * 1e-6 * task_loss
        loss_G.backward()
        self.optimizer_G.step()
        # Cache the translated batches for the discriminator updates.
        self.fake_A, self.fake_B = fake_A.detach(), fake_B.detach()
        return loss_G, loss_cycle, loss_identity, loss_GAN

    def trainDiscriminators(self, real_A, real_B, labels_A_soft, labels_B, epoch):
        self.optimizer_D_A.zero_grad()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        # Real loss
        validity_A, pred_labels_A = self.D_A(real_A)
        loss_real = -torch.mean(validity_A) + \
            self.lmda * 0.2 * \
            self.criterion_task_loss(pred_labels_A, labels_A_soft)
        # Fake loss (on the batch of previously generated samples)
        fake_A_ = self.fake_A
        validity_fake_A, pred_labels_fake_A = self.D_A(fake_A_.detach())
        if epoch > 500:
            loss_fake = torch.mean(validity_fake_A) + \
                self.lmda * 0.8 * \
                self.criterion_task_loss(pred_labels_fake_A, labels_B)
        else:
            loss_fake = torch.mean(validity_fake_A)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2
        loss_D_A.backward()
        self.optimizer_D_A.step()
        # WGAN-style weight clipping on the critic
        for p in self.D_A.parameters():
            p.data.clamp_(-0.01, 0.01)

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        self.optimizer_D_B.zero_grad()

        # Real loss
        validity_B, pred_labels_B = self.D_B(real_B)
        loss_real = -torch.mean(validity_B) + \
            self.lmda * 0.2 * \
            self.criterion_task_loss(pred_labels_B, labels_B)
        fake_B_ = self.fake_B
        validity_fake_B, pred_labels_fake_B = self.D_B(fake_B_.detach())
        if epoch > 500:
            loss_fake = torch.mean(validity_fake_B) + \
                self.lmda * 0.8 * \
                self.criterion_task_loss(pred_labels_fake_B, labels_A_soft)
        else:
            loss_fake = torch.mean(validity_fake_B)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2
        loss_D_B.backward()
        self.optimizer_D_B.step()
        for p in self.D_B.parameters():
            p.data.clamp_(-0.01, 0.01)
        loss_D = (loss_D_A + loss_D_B) / 2
        return loss_D

--------------------------------------------------------------------------------