├── kernelPoly.m
├── results.png
├── logmvnpdf.m
├── Inductive_setting.m
├── Transductive_setting.m
├── awa1.m
├── awa2.m
├── sun.m
├── cub.m
└── README.md

/kernelPoly.m:
--------------------------------------------------------------------------------
function [XX] = kernelPoly(X1,X2,d)
%KERNELPOLY Inhomogeneous polynomial kernel between the rows of X1 and X2.
%   XX = KERNELPOLY(X1, X2, d) returns XX(i,j) = (1 + X1(i,:)*X2(j,:)')^d,
%   so for X1 of size N1xF and X2 of size N2xF the result is N1xN2.
%   The operator is elementwise power (.^), applied after the +1 shift.
XX = (1+X1*X2').^d;
end

--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vkverma01/Zero-Shot-Learning/HEAD/results.png
--------------------------------------------------------------------------------
/logmvnpdf.m:
--------------------------------------------------------------------------------
function [logp] = logmvnpdf(x,mu,Sigma)
%LOGMVNPDF Row-wise log-density of a multivariate Gaussian N(mu, Sigma).
%   logp = LOGMVNPDF(x, mu, Sigma) returns the Nx1 vector whose n-th entry
%   is log N(x(n,:) | mu, Sigma).
%
%   x     - NxD observations, one per row
%   mu    - 1xD mean, subtracted from every row of x
%   Sigma - either a DxD symmetric positive-definite covariance matrix, or
%           a length-D vector of variances that is interpreted as
%           diag(Sigma). The vector form skips building and factorising a
%           dense DxD matrix, which is much cheaper for large D.
%
%   Raises an error if the covariance is not positive definite.

[~, D] = size(x);
xc = bsxfun(@minus, x, mu);                  % centre every observation

if isvector(Sigma) && numel(Sigma) == D
    % Diagonal covariance supplied as a vector of variances.
    s = Sigma(:)';
    if ~all(s > 0)                           % also rejects NaN variances
        error('logmvnpdf:notPositiveDefinite', ...
              'Covariance must be positive definite.');
    end
    mahal  = sum(bsxfun(@rdivide, xc.^2, s), 2);   % N x 1
    logdet = sum(log(s));
else
    % Full covariance: factorise once (Sigma = U'*U) and reuse U for both
    % the Mahalanobis term and the log-determinant.
    U = chol(Sigma);
    z = xc / U;                              % solves z*U = xc
    mahal  = sum(z.^2, 2);                   % = diag(xc*inv(Sigma)*xc'), N x 1
    logdet = 2 * sum(log(diag(U)));
end

logp = -0.5 * mahal - 0.5 * (D * log(2*pi) + logdet);

end

--------------------------------------------------------------------------------
/Inductive_setting.m:
--------------------------------------------------------------------------------
function [Accuracy]=Inductive_setting(test_feat,opt)
%INDUCTIVE_SETTING Classify test samples with the predicted unseen-class Gaussians.
%   Accuracy = INDUCTIVE_SETTING(test_feat, opt) scores every test sample
%   (one per column of test_feat) under each Gaussian
%   N(opt.mu_unk(:,c), diag(opt.sigma_unk(:,c) + opt.regulariser)). It
%   assigns each sample to the class with the highest log-likelihood and
%   returns the percentage of samples whose predicted label equals
%   opt.test_labels.
%
%   Fields of opt read here:
%     mu_unk          - d x C predicted class means
%     sigma_unk       - d x C predicted per-dimension variances
%     regulariser     - scalar added to every variance for numerical stability
%     testClassLabels - C labels; column c of mu_unk belongs to testClassLabels(c)
%     test_labels     - ground-truth label of each test sample

nClasses = size(opt.mu_unk, 2);
Nt = size(test_feat, 2);
X = test_feat';                              % one sample per row, as logmvnpdf expects

y = zeros(Nt, nClasses);                     % y(n,c): log-likelihood of sample n under class c
for c = 1:nClasses
    % The variances are passed as a vector, which logmvnpdf treats as a
    % diagonal covariance (same result as diag(...), without a dense DxD matrix).
    y(:, c) = logmvnpdf(X, opt.mu_unk(:, c)', opt.sigma_unk(:, c) + opt.regulariser);
end

[~, Ind] = max(y, [], 2);
index = opt.testClassLabels(Ind);
% Compare as column vectors so a row/column mismatch between predicted and
% true labels cannot silently broadcast into an Nt x Nt comparison.
Accuracy = (sum(index(:) == opt.test_labels(:)) / Nt) * 100;

end
--------------------------------------------------------------------------------
/Transductive_setting.m:
--------------------------------------------------------------------------------
function [Accuracy]=Transductive_setting(test_feat,opt)
%TRANSDUCTIVE_SETTING Refine the unseen-class Gaussians on unlabeled test data, then classify.
%   Accuracy = TRANSDUCTIVE_SETTING(test_feat, opt) fits a diagonal-covariance
%   Gaussian mixture with one component per unseen class to test_feat
%   (one sample per column). It uses VLFeat's vl_gmm, initialised from the
%   inductive predictions opt.mu_unk, opt.sigma_unk and opt.PComponents.
%   Each test sample is then assigned to the component with the highest
%   log-likelihood, and the percentage of samples whose label matches
%   opt.test_labels is returned.
%
%   Optional fields of opt (the defaults are the constants used previously):
%     trans_covariance_bound - lower bound on the GMM variances during EM (0.4)
%     trans_regulariser      - value added to each refined variance when
%                              scoring the test samples (0.05)

covarianceBound  = getOptionalField(opt, 'trans_covariance_bound', 0.4);
scoreRegulariser = getOptionalField(opt, 'trans_regulariser', 0.05);

nClasses = size(opt.mu_unk, 2);
Nt = size(test_feat, 2);

[model.MU, model.S, model.PI] = vl_gmm(test_feat, nClasses, ...
    'initialization', 'custom', ...
    'InitMeans', opt.mu_unk, ...
    'InitCovariances', opt.sigma_unk, ...
    'InitPriors', opt.PComponents, ...
    'CovarianceBound', covarianceBound, ...
    'MaxNumIterations', 1000);

X = test_feat';                              % one sample per row, as logmvnpdf expects
nComponents = size(model.MU, 2);
y = zeros(Nt, nComponents);                  % y(n,c): log-likelihood of sample n under component c
for c = 1:nComponents
    % Add the regulariser to the variances BEFORE building the diagonal
    % matrix. Adding it after diag() would also put it on every off-diagonal
    % entry, giving a dense covariance with spurious correlations.
    y(:, c) = logmvnpdf(X, model.MU(:, c)', diag(model.S(:, c) + scoreRegulariser));
end

[~, clusterX] = max(y, [], 2);
index = opt.testClassLabels(clusterX);
% Compare as column vectors so a row/column mismatch cannot broadcast.
Accuracy = (sum(index(:) == opt.test_labels(:)) / Nt) * 100;
end

function value = getOptionalField(s, name, default)
% Return s.(name) when the field exists, otherwise the supplied default.
if isfield(s, name)
    value = s.(name);
else
    value = default;
end
end

--------------------------------------------------------------------------------
/awa1.m:
--------------------------------------------------------------------------------
% AWA1  Zero-shot evaluation on the AWA1 dataset (inductive and transductive).
%
% Fits a Gaussian (mean and per-dimension variance) to each seen class. It
% then regresses those means and log-variances on the class attributes with
% kernel ridge regression (degree-2 polynomial kernel) to predict Gaussians
% for the unseen classes, and finally reports inductive and transductive accuracy.
%
% NOTE(review): dataset/AWA1.mat is assumed to provide train_feat (d x Ns),
% train_labels, test_feat, test_labels, trainClassLabels, testClassLabels
% and classAttributes (#attribute x #class) — confirm against the data files.

%clc;
clear;

% Load dataset (the variable is not called "path" so MATLAB's built-in PATH is not shadowed)
dataPath = 'dataset';
load(fullfile(dataPath, 'AWA1.mat'));

train_class = numel(trainClassLabels);      % number of seen (train) classes
test_class  = numel(testClassLabels);       % number of unseen (test) classes

test_feat = double(test_feat);
classAttributes = classAttributes';         % now #class x #attribute

d = size(train_feat, 1);                    % feature dimension
A = classAttributes(trainClassLabels, :)';  % #attribute x train_class

%================================================
K_trtr = kernelPoly(A', A', 2);
K_trte = kernelPoly(A', classAttributes(testClassLabels, :), 2);
%================================================

% Mean and per-dimension variance of every seen class
mu_cap  = zeros(d, train_class);
sigma_s = zeros(d, train_class);
for i = 1:train_class
    class_feat = train_feat(:, train_labels == trainClassLabels(i));
    mu_cap(:, i)  = mean(class_feat, 2);
    sigma_s(:, i) = var(class_feat, 0, 2);
end

% Hyperparameters: ridge penalties for the mean (lamda1) and log-variance
% (lamda2) regressions, and the variance regulariser used when scoring.
lamda1 = 0.1; lamda2 = 100000000; reg = 0.05;

% Kernel ridge regression: class attributes -> class means
alpha_mu = (K_trtr + lamda1*eye(train_class)) \ mu_cap';
mu_unk = alpha_mu' * K_trte;

% Kernel ridge regression: class attributes -> log-variances
logsigmaS = log(sigma_s + .001);            % 0.001 added for stability
alpha = (K_trtr + lamda2*eye(train_class)) \ logsigmaS';
sigma_unk = exp(alpha' * K_trte);

opt.PComponents = ones(1, test_class) / test_class;   % uniform mixture priors
opt.testClassLabels = testClassLabels;
opt.test_labels = test_labels;
opt.regulariser = reg;                      % for stability
opt.mu_unk = mu_unk;
opt.sigma_unk = sigma_unk;

% Inductive setting
Accuracy1 = Inductive_setting(test_feat, opt);
% Transductive setting
Accuracy2 = Transductive_setting(test_feat, opt);
result = [Accuracy1, Accuracy2];

disp(['Inductive Accuracy = ',num2str(Accuracy1), '% Transductive Accuracy = ', num2str(Accuracy2),'%'])

--------------------------------------------------------------------------------
/awa2.m:
--------------------------------------------------------------------------------
% AWA2  Zero-shot evaluation on the AWA2 dataset (inductive and transductive).
%
% Fits a Gaussian (mean and per-dimension variance) to each seen class. It
% then regresses those means and log-variances on the class attributes with
% kernel ridge regression (degree-2 polynomial kernel) to predict Gaussians
% for the unseen classes, and finally reports inductive and transductive accuracy.
%
% NOTE(review): dataset/AWA2.mat is assumed to provide train_feat (d x Ns),
% train_labels, test_feat, test_labels, trainClassLabels, testClassLabels
% and classAttributes (#attribute x #class) — confirm against the data files.

%clc;
clear;

% Load dataset (the variable is not called "path" so MATLAB's built-in PATH is not shadowed)
dataPath = 'dataset';
load(fullfile(dataPath, 'AWA2.mat'));

train_class = numel(trainClassLabels);      % number of seen (train) classes
test_class  = numel(testClassLabels);       % number of unseen (test) classes

test_feat = double(test_feat);
classAttributes = classAttributes';         % now #class x #attribute

d = size(train_feat, 1);                    % feature dimension
A = classAttributes(trainClassLabels, :)';  % #attribute x train_class

%================================================
K_trtr = kernelPoly(A', A', 2);
K_trte = kernelPoly(A', classAttributes(testClassLabels, :), 2);
%================================================

% Mean and per-dimension variance of every seen class
mu_cap  = zeros(d, train_class);
sigma_s = zeros(d, train_class);
for i = 1:train_class
    class_feat = train_feat(:, train_labels == trainClassLabels(i));
    mu_cap(:, i)  = mean(class_feat, 2);
    sigma_s(:, i) = var(class_feat, 0, 2);
end

% Hyperparameters: ridge penalties for the mean (lamda1) and log-variance
% (lamda2) regressions, and the variance regulariser used when scoring.
lamda1 = 0.01; lamda2 = 100000000; reg = 0.05;

% Kernel ridge regression: class attributes -> class means
alpha_mu = (K_trtr + lamda1*eye(train_class)) \ mu_cap';
mu_unk = alpha_mu' * K_trte;

% Kernel ridge regression: class attributes -> log-variances
logsigmaS = log(sigma_s + .001);            % 0.001 added for stability
alpha = (K_trtr + lamda2*eye(train_class)) \ logsigmaS';
sigma_unk = exp(alpha' * K_trte);

opt.PComponents = ones(1, test_class) / test_class;   % uniform mixture priors
opt.testClassLabels = testClassLabels;
opt.test_labels = test_labels;
opt.regulariser = reg;                      % for stability
opt.mu_unk = mu_unk;
opt.sigma_unk = sigma_unk;

% Inductive setting
Accuracy1 = Inductive_setting(test_feat, opt);
% Transductive setting
Accuracy2 = Transductive_setting(test_feat, opt);
result = [Accuracy1, Accuracy2];

disp(['Inductive Accuracy = ',num2str(Accuracy1), '% Transductive Accuracy = ', num2str(Accuracy2),'%'])

--------------------------------------------------------------------------------
/sun.m:
--------------------------------------------------------------------------------
% SUN  Zero-shot evaluation on the SUN attribute dataset (inductive and transductive).
%
% Fits a Gaussian (mean and per-dimension variance) to each seen class. It
% then regresses those means and log-variances on the class attributes with
% kernel ridge regression (degree-2 polynomial kernel) to predict Gaussians
% for the unseen classes, and finally reports inductive and transductive accuracy.
%
% NOTE(review): dataset/SUN.mat is assumed to provide train_feat (d x Ns),
% train_labels, test_feat, test_labels, trainClassLabels, testClassLabels
% and classAttributes (#attribute x #class) — confirm against the data files.

%clc;
clear;

% Load dataset (the variable is not called "path" so MATLAB's built-in PATH is not shadowed)
dataPath = 'dataset';
load(fullfile(dataPath, 'SUN.mat'));

train_class = numel(trainClassLabels);      % number of seen (train) classes
test_class  = numel(testClassLabels);       % number of unseen (test) classes

test_feat = double(test_feat);
classAttributes = classAttributes';         % now #class x #attribute

d = size(train_feat, 1);                    % feature dimension
A = classAttributes(trainClassLabels, :)';  % #attribute x train_class

%================================================
K_trtr = kernelPoly(A', A', 2);
K_trte = kernelPoly(A', classAttributes(testClassLabels, :), 2);
%================================================

% Mean and per-dimension variance of every seen class
mu_cap  = zeros(d, train_class);
sigma_s = zeros(d, train_class);
for i = 1:train_class
    class_feat = train_feat(:, train_labels == trainClassLabels(i));
    mu_cap(:, i)  = mean(class_feat, 2);
    sigma_s(:, i) = var(class_feat, 0, 2);
end

% Hyperparameters: ridge penalties for the mean (lamda1) and log-variance
% (lamda2) regressions, and the variance regulariser used when scoring.
lamda1 = 0.1; lamda2 = 100000000; reg = 0.05;

% Kernel ridge regression: class attributes -> class means
alpha_mu = (K_trtr + lamda1*eye(train_class)) \ mu_cap';
mu_unk = alpha_mu' * K_trte;

% Kernel ridge regression: class attributes -> log-variances
logsigmaS = log(sigma_s + .001);            % 0.001 added for stability
alpha = (K_trtr + lamda2*eye(train_class)) \ logsigmaS';
sigma_unk = exp(alpha' * K_trte);

opt.PComponents = ones(1, test_class) / test_class;   % uniform mixture priors
opt.testClassLabels = testClassLabels;
opt.test_labels = test_labels;
opt.regulariser = reg;                      % for stability
opt.mu_unk = mu_unk;
opt.sigma_unk = sigma_unk;

% Inductive setting
Accuracy1 = Inductive_setting(test_feat, opt);
% Transductive setting
Accuracy2 = Transductive_setting(test_feat, opt);
result = [Accuracy1, Accuracy2];

disp(['Inductive Accuracy = ',num2str(Accuracy1), '% Transductive Accuracy = ', num2str(Accuracy2),'%'])

--------------------------------------------------------------------------------
/cub.m:
--------------------------------------------------------------------------------
% CUB  Zero-shot evaluation on the CUB-200 dataset (inductive and transductive).
%
% Fits a Gaussian (mean and per-dimension variance) to each seen class. It
% then regresses those means and log-variances on the class attributes with
% kernel ridge regression (degree-2 polynomial kernel) to predict Gaussians
% for the unseen classes, and finally reports inductive and transductive accuracy.
%
% NOTE(review): dataset/CUB200.mat is assumed to provide train_feat (d x Ns),
% train_labels, test_feat, test_labels, trainClassLabels, testClassLabels
% and classAttributes (#attribute x #class) — confirm against the data files.

%clc;
clear;

% Load dataset (the variable is not called "path" so MATLAB's built-in PATH is not shadowed)
dataPath = 'dataset';
load(fullfile(dataPath, 'CUB200.mat'));

train_class = numel(trainClassLabels);      % number of seen (train) classes
test_class  = numel(testClassLabels);       % number of unseen (test) classes

test_feat = double(test_feat);
classAttributes = classAttributes';         % now #class x #attribute

d = size(train_feat, 1);                    % feature dimension
A = classAttributes(trainClassLabels, :)';  % #attribute x train_class

%================================================
K_trtr = kernelPoly(A', A', 2);
K_trte = kernelPoly(A', classAttributes(testClassLabels, :), 2);
%================================================

% Mean and per-dimension variance of every seen class
mu_cap  = zeros(d, train_class);
sigma_s = zeros(d, train_class);
for i = 1:train_class
    class_feat = train_feat(:, train_labels == trainClassLabels(i));
    mu_cap(:, i)  = mean(class_feat, 2);
    sigma_s(:, i) = var(class_feat, 0, 2);
end

% Hyperparameters: ridge penalties for the mean (lamda1) and log-variance
% (lamda2) regressions, and the variance regulariser used when scoring.
lamda1 = 1; lamda2 = 10000000; reg = 0.05;

% Kernel ridge regression: class attributes -> class means
alpha_mu = (K_trtr + lamda1*eye(train_class)) \ mu_cap';
mu_unk = alpha_mu' * K_trte;

% Kernel ridge regression: class attributes -> log-variances
logsigmaS = log(sigma_s + .0001);           % 0.0001 added for stability
alpha = (K_trtr + lamda2*eye(train_class)) \ logsigmaS';
sigma_unk = exp(alpha' * K_trte);

% Uniform mixture priors, one component per unseen class
opt.PComponents = ones(1, test_class) / test_class;
opt.testClassLabels = testClassLabels;
opt.test_labels = test_labels;
opt.regulariser = reg;                      % for stability
opt.mu_unk = mu_unk;
opt.sigma_unk = sigma_unk;

% Inductive setting
Accuracy1 = Inductive_setting(test_feat, opt);
% Transductive setting
Accuracy2 = Transductive_setting(test_feat, opt);
result = [Accuracy1, Accuracy2];

disp(['Inductive Accuracy = ',num2str(Accuracy1), '% Transductive Accuracy = ', num2str(Accuracy2),'%'])

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### Note: This code is based on the new split proposed by [Y. Xian et al.](https://arxiv.org/pdf/1707.00600.pdf)

# A Simple Exponential Family Framework for Zero-Shot Learning

## Abstract
We present a simple generative framework for learning to predict previously
unseen classes, based on estimating class-attribute-gated class-conditional
distributions. We model each class-conditional distribution as an exponential family
distribution and the parameters of the distribution of each seen/unseen class
are defined as functions of the respective observed class attributes. These functions
can be learned using only the seen class data and can be used to predict
the parameters of the class-conditional distribution of each unseen class. Unlike
most existing methods for zero-shot learning that represent classes as fixed embeddings
in some vector space, our generative model naturally represents each
class as a probability distribution. It is simple to implement and also allows leveraging
additional unlabeled data from unseen classes to improve the estimates of
their class-conditional distributions using transductive/semi-supervised learning.
Moreover, it extends seamlessly to few-shot learning by easily updating these
distributions when provided with a small number of additional labelled examples
from unseen classes. Through a comprehensive set of experiments on several
benchmark data sets, we demonstrate the efficacy of our framework.

## Prerequisites

```
MATLAB
VLFeat toolbox
```

## Usage
Run one of the following scripts from MATLAB; each loads its `.mat` file from the `dataset/` folder:
```
cub.m
sun.m
awa1.m
awa2.m
```

## Dataset
* AWA2: [Animals with Attributes 2](https://cvml.ist.ac.at/AwA2/)
* CUB: [Caltech-UCSD Birds-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
* SUN: [SUN Attribute](https://cs.brown.edu/~gen/sunattributes.html)

The complete datasets can be downloaded [here](https://drive.google.com/open?id=1o0uvjk0y3saLzaOT0dn4jMfV4EXthVcy).
For more details about the train/test split, please refer to our [paper](https://arxiv.org/pdf/1707.08040.pdf).

## Result
![Results](results.png)

Here, GFZSL-Trans denotes the results in the transductive setting.

## References
If you use this work, please cite the ECML-17 paper:

```
@inproceedings{verma2017simple,
  title={A simple exponential family framework for zero-shot learning},
  author={Verma, Vinay Kumar and Rai, Piyush},
  booktitle={Joint European Conference on Machine Learning and Knowledge Discovery in Databases},
  pages={792--808},
  year={2017},
  organization={Springer}
}
```

## License

The code is released for non-commercial and research purposes **only**. For commercial purposes, please contact the authors.
--------------------------------------------------------------------------------