├── LICENSE ├── README.md ├── stkernelAwA ├── attrClasses.mat ├── binAttrClasses.mat ├── classes.mat ├── experimentIndices.mat ├── getClassesAndIndices.m ├── getClassesAndIndicesAndFixClassMess.m ├── getClassesAndIndicesRaw.m └── runExperiment.m ├── stkernelSUNAttributes ├── allSignsGeneralscript.m ├── attrClasses.mat ├── binAttrClasses.mat ├── experimentIndices.mat ├── getDataInfo.m ├── infoPrediction.mat ├── mapScores.mat ├── runExperiment.m └── visualizationScript.m ├── stkernelaPY ├── attrClasses.mat ├── binAttrClasses.mat ├── experimentIndices.mat ├── getClassesAndIndices.m ├── getClassesAndIndicesRaw.m ├── getDataInfo.m ├── getMinMaxAttributes.m ├── getSoftAttributes.m ├── mapScores.mat ├── meanStdAttrClasses.mat ├── runExperiment.m └── track.mat └── validationClasses_generalscript.m /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Bernardino Romera-Paredes 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Embarrassingly-simple-ZSL 2 | This repository contains the code for the real-data experiments in our paper "An embarrassingly simple approach to zero-shot learning", presented at ICML 2015. 3 | 4 | Framework used in our paper. 5 | 6 | Before running the experiments, download the datasets from http://pub.ist.ac.at/~chl/ABC/. 7 | Uncompress each dataset into a directory of your choice, and set the datasetPath variable in the corresponding runExperiment.m to that directory. 8 | Then move to the directory of the dataset you want to use (stkernelAwA, stkernelSUNAttributes or stkernelaPY) and run runExperiment.m. 9 | -------------------------------------------------------------------------------- /stkernelAwA/attrClasses.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelAwA/attrClasses.mat -------------------------------------------------------------------------------- /stkernelAwA/binAttrClasses.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelAwA/binAttrClasses.mat -------------------------------------------------------------------------------- /stkernelAwA/classes.mat: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelAwA/classes.mat
--------------------------------------------------------------------------------
/stkernelAwA/experimentIndices.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelAwA/experimentIndices.mat
--------------------------------------------------------------------------------
/stkernelAwA/getClassesAndIndices.m:
--------------------------------------------------------------------------------
1 | % getClassesAndIndices: AwA zero-shot split (40 training / 10 test classes)
2 | % built from the features in AwA_finalFeat.mat.
3 | % Files are loaded through full paths, so the current folder never changes
4 | % and an error cannot leave MATLAB sitting in the data folder.
5 | % Define dataDir before running the script to use another data location.
6 | if ~exist('dataDir', 'var')
7 |     dataDir = fullfile('..', '..', 'data', 'Dinesh-Jayaraman');
8 | end
9 | load(fullfile(dataDir, 'AwA_finalFeat.mat'));   % expected to define 'classes'
10 | attributesMatrix = load(fullfile(dataDir, 'predicate-matrix-binary.txt'));
11 | 
12 | y = classes;
13 | testClassesIndices = [25 39 15 6 42 14 18 48 34 24];
14 | trainClassesIndices = setdiff(1:50, testClassesIndices);
15 | 
16 | % Instances of each training class, relabelled 1..numel(trainClassesIndices)
17 | % in the order of trainClassesIndices; both results are row vectors.
18 | trainInstancesIndices = [];
19 | trainInstancesLabels = [];
20 | for i = 1:length(trainClassesIndices)
21 |     indices = reshape(find(y == trainClassesIndices(i)), 1, []);
22 |     trainInstancesIndices = [trainInstancesIndices, indices];
23 |     trainInstancesLabels = [trainInstancesLabels, i*ones(1, length(indices))];
24 | end
25 | 
26 | % Same for the test classes, relabelled 1..numel(testClassesIndices).
27 | testInstancesIndices = [];
28 | testInstancesLabels = [];
29 | for i = 1:length(testClassesIndices)
30 |     indices = reshape(find(y == testClassesIndices(i)), 1, []);
31 |     testInstancesIndices = [testInstancesIndices, indices];
32 |     testInstancesLabels = [testInstancesLabels, i*ones(1, length(indices))];
33 | end
34 | 
--------------------------------------------------------------------------------
/stkernelAwA/getClassesAndIndicesAndFixClassMess.m:
--------------------------------------------------------------------------------
1 | % getClassesAndIndicesAndFixClassMess: like getClassesAndIndices, but first
2 | % renumbers the AwA classes using classSource/classTarget from the local
3 | % classes.mat and reorders the attribute rows to match (output: attrClasses).
4 | if ~exist('dataDir', 'var')
5 |     dataDir = fullfile('..', '..', 'data', 'Dinesh-Jayaraman');
6 | end
7 | load(fullfile(dataDir, 'AwA_finalFeat.mat'));
8 | attributesMatrix = load(fullfile(dataDir, 'predicate-matrix-binary.txt'));
9 | load classes   % local classes.mat; expected to define classSource and classTarget
10 | 
11 | uniqueClassSource = unique(classSource, 'stable');
12 | uniqueClassTarget = unique(classTarget, 'stable');
13 | 
14 | nClasses = length(uniqueClassSource);
15 | 
16 | testClassesIndices = [25 39 15 6 42 14 18 48 34 24];
17 | trainClassesIndices = setdiff(1:50, testClassesIndices);
18 | 
19 | % New label i goes to every instance whose classSource equals the i-th
20 | % distinct classTarget; attribute rows are reordered the same way.
21 | % NOTE(review): y is built class by class, so it matches the instance order
22 | % only if instances are grouped by class in that order -- confirm.
23 | y = [];
24 | newAttributesMatrix = [];
25 | for i = 1:nClasses
26 |     targetClass = uniqueClassTarget(i);
27 |     newAttributesMatrix = [newAttributesMatrix; attributesMatrix(targetClass, :)];
28 |     nMembers = nnz(classSource == targetClass);
29 |     y = [y; i*ones(nMembers, 1)];
30 | end
31 | attributesMatrix = newAttributesMatrix;
32 | 
33 | % Instances of each training / test class, relabelled 1..k (row vectors).
34 | trainInstancesIndices = [];
35 | trainInstancesLabels = [];
36 | for i = 1:length(trainClassesIndices)
37 |     indices = reshape(find(y == trainClassesIndices(i)), 1, []);
38 |     trainInstancesIndices = [trainInstancesIndices, indices];
39 |     trainInstancesLabels = [trainInstancesLabels, i*ones(1, length(indices))];
40 | end
41 | 
42 | testInstancesIndices = [];
43 | testInstancesLabels = [];
44 | for i = 1:length(testClassesIndices)
45 |     indices = reshape(find(y == testClassesIndices(i)), 1, []);
46 |     testInstancesIndices = [testInstancesIndices, indices];
47 |     testInstancesLabels = [testInstancesLabels, i*ones(1, length(indices))];
48 | end
49 | 
50 | attrClasses = attributesMatrix;
51 | 
--------------------------------------------------------------------------------
/stkernelAwA/getClassesAndIndicesRaw.m:
--------------------------------------------------------------------------------
1 | % getClassesAndIndicesRaw: AwA split read directly from the ABC text files;
2 | % split0.mask marks training instances with 1 and test instances with 0.
3 | % Define datasetPath before running the script to use another location.
4 | if ~exist('datasetPath', 'var')
5 |     datasetPath = '/home/bernard/Data/ABC/animals';
6 | end
7 | attrClasses = dlmread(fullfile(datasetPath, 'all-perclass.attributes'));
8 | indices = dlmread(fullfile(datasetPath, 'split0.mask'));
9 | classes = dlmread(fullfile(datasetPath, 'all.classid'));
10 | 
11 | % Class ids of each split in order of first appearance; instance labels are
12 | % those ids renumbered 1..k (third output of unique), as row vectors.
13 | % This also works when an id is 0 or negative.
14 | trainInstancesIndices = find(indices == 1);
15 | [trainClassesIndices, ~, trainInstancesLabels] = unique(classes(indices == 1), 'stable');
16 | trainInstancesLabels = trainInstancesLabels(:)';
17 | 
18 | testInstancesIndices = find(indices == 0);
19 | [testClassesIndices, ~, testInstancesLabels] = unique(classes(indices == 0), 'stable');
20 | testInstancesLabels = testInstancesLabels(:)';
21 | 
22 | % One attribute row per class, stored at row = class id.
23 | % NOTE(review): relies on every class having a distinct attribute row, and on
24 | % the unique rows appearing in the same order as the classes -- confirm.
25 | attrClasses = unique(attrClasses, 'rows', 'stable');
26 | uniqueClasses = unique(classes, 'stable');
27 | newAttrClasses = zeros(size(attrClasses));
28 | for i = 1:length(uniqueClasses)
29 |     newAttrClasses(uniqueClasses(i), :) = attrClasses(i, :);
30 | end
31 | attrClasses = newAttrClasses;
32 | 
--------------------------------------------------------------------------------
/stkernelAwA/runExperiment.m:
--------------------------------------------------------------------------------
1 | 
2 | datasetPath='/home/bernard/HD1/ABC/animals';
3 | addpath('..');
4 | lambdas=10.^[-3:3];
5 | gammas=10.^[-3:3];
6 | tasks=[0];
7 | 
8 | validationClasses_generalscript(datasetPath, lambdas, gammas, tasks);
9 | 
--------------------------------------------------------------------------------
/stkernelSUNAttributes/allSignsGeneralscript.m:
--------------------------------------------------------------------------------
1 | 
2 | function allSignsGeneralscript(path, lambdas, sigmas, tasks)
3 | % allSignsGeneralscript: kernel ESZSL variant that regresses the per-image
4 | % attribute vectors (all-perimage.attributes) of the training instances,
5 | % sweeps sigma and lambda, and logs the test accuracy of every pair.
6 | %
7 | %   path    - folder with all.kernel and the ABC attribute files; added to
8 | %             the MATLAB path so those files can be found
9 | %   lambdas - candidate values of lambda in (K'K + lambda*K)
10 | %   sigmas  - candidate values of sigma in (S'S + sigma*I)
11 | %   tasks   - each entry repeats the sweep and is stored in column 3 of hist
12 | %
13 | % Rows [lambda, sigma, task, accuracy] are appended to hist and saved to
14 | % hist4.mat after every evaluation.
15 | 
16 | addpath(path);
17 | 
18 | % all.kernel: square kernel matrix stored as flat single-precision values.
19 | kernelFID = fopen('all.kernel');
20 | if kernelFID < 0
21 |     error('Cannot open all.kernel (is it in %s?)', path);
22 | end
23 | A = fread(kernelFID, 'float');
24 | fclose(kernelFID);   % release the file handle
25 | K = reshape(A, [sqrt(length(A)), sqrt(length(A))]);
26 | 
27 | perclassattributes = dlmread('all-perclass.attributes');
28 | S = dlmread('all-perimage.attributes');
29 | load experimentIndices   % expected to define the train/test index and label variables
30 | 
31 | % Class signatures: unique rows of the per-class file, unless a local
32 | % attrClasses*.mat is present, in which case its attrClasses is used.
33 | A = unique(perclassattributes, 'rows');
34 | if ~isempty(dir('attrClasses*'))
35 |     load attrClasses
36 |     A = attrClasses;
37 | end
38 | Stest = A(testClassesIndices, :);
39 | 
40 | % These quantities do not depend on the hyper-parameters: compute them once.
41 | disp('Precalculating statistics...');
42 | S = S(trainInstancesIndices, :);
43 | KTrain = K(trainInstancesIndices, trainInstancesIndices);
44 | KTest = K(testInstancesIndices, trainInstancesIndices);
45 | KK = KTrain'*KTrain;
46 | KYS = KTrain*S;
47 | SS = S'*S;
48 | gt = testInstancesLabels;
49 | 
50 | hist = [];
51 | for t = 1:length(tasks)
52 |     nNewTasks = tasks(t);
53 |     for s = 1:length(sigmas)
54 |         sigma = sigmas(s);
55 |         KYS_invSS = KYS/(SS + sigma*eye(size(S, 2)));
56 | 
57 |         disp('Learning...');
58 |         for lambdaIndex = 1:length(lambdas)
59 |             lambda = lambdas(lambdaIndex);
60 |             Alpha = (KK + lambda*KTrain)\KYS_invSS;
61 | 
62 |             disp('Predicting...');
63 |             pred = Stest*Alpha'*KTest';
64 |             [~, classPred] = max(pred', [], 2);
65 | 
66 |             disp('Evaluating...');
67 |             r = mean(gt(:) == classPred(:))
68 |             hist = [hist; [lambda, sigma, nNewTasks, r]];
69 |             save('hist4', 'hist');
70 |         end
71 |     end
72 | end
73 | end
74 | 
--------------------------------------------------------------------------------
/stkernelSUNAttributes/attrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/attrClasses.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/binAttrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/binAttrClasses.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/experimentIndices.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/experimentIndices.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/getDataInfo.m:
--------------------------------------------------------------------------------
1 | % getDataInfo: for every attribute, the spread of the class means relative
2 | % to the average within-class spread (ra = std(B)./mean(V)), followed by
3 | % the mean of ra over the attributes where it is not NaN.
4 | % Define datasetPath before running the script to use another location.
5 | if ~exist('datasetPath', 'var')
6 |     datasetPath = '/home/bernard/Data/ABC/sun';
7 | end
8 | attrs = dlmread(fullfile(datasetPath, 'all-perimage.attributes'));
9 | classes = dlmread(fullfile(datasetPath, 'all.classid'));
10 | 
11 | % Loop over the distinct class ids (classes(i) is the class of the i-th
12 | % instance, not the i-th class, so it must not be used as the class here).
13 | classIds = unique(classes);
14 | nClasses = length(classIds);
15 | nAttrs = size(attrs, 2);
16 | V = zeros(nClasses, nAttrs);   % per-class std of each attribute
17 | B = zeros(nClasses, nAttrs);   % per-class mean of each attribute
18 | for i = 1:nClasses
19 |     members = attrs(classes == classIds(i), :);
20 |     % explicit dimension: a class with one instance still yields a row
21 |     B(i, :) = mean(members, 1);
22 |     V(i, :) = std(members, 0, 1);
23 | end
24 | ra = std(B, 0, 1)./mean(V, 1)
25 | ra = ra(~isnan(ra));
26 | mean(ra)
--------------------------------------------------------------------------------
/stkernelSUNAttributes/infoPrediction.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/infoPrediction.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/mapScores.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/mapScores.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/runExperiment.m:
--------------------------------------------------------------------------------
1 | 
2 | datasetPath='/home/bernard/HD1/ABC/sun';
3 | addpath('..');
4 | lambdas = 10.^[-3:3];
5 | sigmas=10.^[-3:3];
6 | tasks=[0];
7 | 
8 | validationClasses_generalscript(datasetPath, lambdas, sigmas, tasks);
9 | 
--------------------------------------------------------------------------------
/stkernelSUNAttributes/visualizationScript.m:
--------------------------------------------------------------------------------
1 | % visualizationScript: grid of test images, one row per true test class and
2 | % one column per predicted class; each cell shows a random image of that
3 | % row's class that was predicted as that column's class (blank if none).
4 | % Uses gt and classPred from infoPrediction.mat.
5 | here=pwd;
6 | load infoPrediction
7 | nClasses=max(gt);
8 | nInstancesPerClass=20;   % NOTE(review): assumes classPred holds 20 consecutive predictions per class -- confirm
9 | newSize=[70,70];
10 | testClasses={'archive', 'art school', 'chemical plant', 'flea market', 'inn','lab classroom', 'lake', 'mineshaft', 'outhouse', 'shoe shop'};
11 | nLabels=numel(testClasses);
12 | 
13 | imMatrix=uint8(255+zeros([newSize*nClasses,3]));   % white canvas
14 | 
15 | cd('~/Dropbox/Data/SUNAttributes');
16 | % One sub-folder per test class. Keep only real sub-folders instead of
17 | % assuming that '.' and '..' are always the first two entries.
18 | listDirs=dir();
19 | listDirs=listDirs([listDirs.isdir] & ~ismember({listDirs.name}, {'.', '..'}));
20 | 
21 | for i=1:nClasses
22 |     cd(listDirs(i).name);
23 |     classPredAux=classPred((i-1)*nInstancesPerClass+1:i*nInstancesPerClass);
24 |     listImages=dir('*.jpg');
25 |     for j=1:nClasses
26 |         index=find(classPredAux==j);
27 |         if isempty(index)
28 |             continue
29 |         end
30 |         index=index(floor(rand*length(index))+1);
31 |         im=imread(listImages(index).name);
32 |         im=imresize(im, newSize);
33 |         if size(im,3)==1
34 |             im=repmat(im,[1 1 3]);   % grayscale JPEG: fill all three channels
35 |         end
36 |         imMatrix(newSize(1)*(i-1)+1:newSize(1)*i,newSize(2)*(j-1)+1:newSize(2)*j,:)=im;
37 |     end
38 |     cd ..
39 | end
40 | 
41 | cd(here);
42 | imshow(imMatrix);
43 | % Row labels on the left, placed bottom-up so the top one is testClasses{1}.
44 | for i=1:nLabels
45 |     annotation('textbox',...
46 |         [0.01 i/(nLabels+1.5)+0.065 0.1 0.01],...
47 |         'String', testClasses{nLabels+1-i}, ...
48 |         'LineStyle','none');
49 | end
50 | % Column labels: first 7 characters of each class name, rotated 90 degrees.
51 | for i=1:nLabels
52 |     xl=get(gca,'XLim');
53 |     yl=get(gca,'YLim');
54 | 
55 |     name=testClasses{i};
56 |     if length(name)>7
57 |         name=name(1:7);
58 |     end
59 |     ht = text(i/nLabels*xl(2)-40,0.1*yl(1)+1.11*yl(2),name);
60 |     set(ht,'Rotation',90)
61 | end
62 | 
--------------------------------------------------------------------------------
/stkernelaPY/attrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/attrClasses.mat
--------------------------------------------------------------------------------
/stkernelaPY/binAttrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/binAttrClasses.mat
--------------------------------------------------------------------------------
/stkernelaPY/experimentIndices.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/experimentIndices.mat
--------------------------------------------------------------------------------
/stkernelaPY/getClassesAndIndices.m:
--------------------------------------------------------------------------------
1 | % getClassesAndIndices: aPY split from the features in Farhadi_PCA105.mat.
2 | % Rows 1..nTestInstances are test instances, the remaining rows training.
3 | % Define dataDir before running the script to use another data location.
4 | if ~exist('dataDir', 'var')
5 |     dataDir = fullfile('..', '..', 'data', 'Dinesh-Jayaraman');
6 | end
7 | load(fullfile(dataDir, 'Farhadi_PCA105.mat'));   % expected to define features and classes
8 | 
9 | y = classes;
10 | nTestInstances = 2644;
11 | testIndices = randperm(nTestInstances);
12 | trainIndices = nTestInstances+1:size(features, 1);
13 | 
14 | % NOTE(review): the class lists are inferred from the two instance ranges
15 | % above (test = rows 1..2644, train = the rest); confirm this matches the
16 | % intended aPascal/aYahoo protocol.
17 | testClassesIndices = reshape(unique(y(testIndices)), 1, []);
18 | trainClassesIndices = reshape(unique(y(trainIndices)), 1, []);
19 | 
20 | % Instances of each class, relabelled 1..k per split (row vectors).
21 | trainInstancesIndices = [];
22 | trainInstancesLabels = [];
23 | for i = 1:length(trainClassesIndices)
24 |     indices = reshape(find(y == trainClassesIndices(i)), 1, []);
25 |     trainInstancesIndices = [trainInstancesIndices, indices];
26 |     trainInstancesLabels = [trainInstancesLabels, i*ones(1, length(indices))];
27 | end
28 | 
29 | testInstancesIndices = [];
30 | testInstancesLabels = [];
31 | for i = 1:length(testClassesIndices)
32 |     indices = reshape(find(y == testClassesIndices(i)), 1, []);
33 |     testInstancesIndices = [testInstancesIndices, indices];
34 |     testInstancesLabels = [testInstancesLabels, i*ones(1, length(indices))];
35 | end
36 | 
37 | testClassesIndices = 1:length(testClassesIndices);
38 | trainClassesIndices = 1:length(trainClassesIndices);
39 | 
--------------------------------------------------------------------------------
/stkernelaPY/getClassesAndIndicesRaw.m:
--------------------------------------------------------------------------------
1 | % getClassesAndIndicesRaw: aPY split read directly from the ABC text files;
2 | % split0.mask marks training instances with 1 and test instances with 0.
3 | % Define datasetPath before running the script to use another location.
4 | if ~exist('datasetPath', 'var')
5 |     datasetPath = '/home/bernard/Data/ABC/apascal';
6 | end
7 | attrClasses = dlmread(fullfile(datasetPath, 'all-perclass.attributes'));
8 | indices = dlmread(fullfile(datasetPath, 'split0.mask'));
9 | classes = dlmread(fullfile(datasetPath, 'all.classid'));
10 | 
11 | classes = classes+1;   % shift ids so they start at 1 (presumably 0-based in the file)
12 | 
13 | % Class ids of each split in order of first appearance; instance labels are
14 | % those ids renumbered 1..k (third output of unique), as row vectors.
15 | trainInstancesIndices = find(indices == 1);
16 | [trainClassesIndices, ~, trainInstancesLabels] = unique(classes(indices == 1), 'stable');
17 | trainInstancesLabels = trainInstancesLabels(:)';
18 | 
19 | testInstancesIndices = find(indices == 0);
20 | [testClassesIndices, ~, testInstancesLabels] = unique(classes(indices == 0), 'stable');
21 | testInstancesLabels = testInstancesLabels(:)';
22 | 
23 | % Attribute row of class c = row of its first instance, stored at row c.
24 | % Sized by the largest id so rows stay aligned with ids even if an id is
25 | % missing (its row is left as zeros).
26 | nClassIds = max(classes);
27 | newAttrClasses = zeros(nClassIds, size(attrClasses, 2));
28 | for c = 1:nClassIds
29 |     firstRow = find(classes == c, 1);
30 |     if ~isempty(firstRow)
31 |         newAttrClasses(c, :) = attrClasses(firstRow, :);
32 |     end
33 | end
34 | attrClasses = newAttrClasses;
35 | 
--------------------------------------------------------------------------------
/stkernelaPY/getDataInfo.m:
--------------------------------------------------------------------------------
1 | % getDataInfo: for every attribute, the spread of the class means relative
2 | % to the average within-class spread (ra = std(B)./mean(V)), followed by
3 | % the mean of ra over the attributes where it is not NaN.
4 | % Define datasetPath before running the script to use another location.
5 | if ~exist('datasetPath', 'var')
6 |     datasetPath = '/home/bernard/Data/ABC/apascal';
7 | end
8 | attrs = dlmread(fullfile(datasetPath, 'all-perimage.attributes'));
9 | classes = dlmread(fullfile(datasetPath, 'all.classid'));
10 | 
11 | % Loop over the distinct class ids (classes(i) is the class of the i-th
12 | % instance, not the i-th class, so it must not be used as the class here).
13 | classIds = unique(classes);
14 | nClasses = length(classIds);
15 | nAttrs = size(attrs, 2);
16 | V = zeros(nClasses, nAttrs);   % per-class std of each attribute
17 | B = zeros(nClasses, nAttrs);   % per-class mean of each attribute
18 | for i = 1:nClasses
19 |     members = attrs(classes == classIds(i), :);
20 |     % explicit dimension: a class with one instance still yields a row
21 |     B(i, :) = mean(members, 1);
22 |     V(i, :) = std(members, 0, 1);
23 | end
24 | ra = std(B, 0, 1)./mean(V, 1)
25 | ra = ra(~isnan(ra));
26 | mean(ra)
--------------------------------------------------------------------------------
/stkernelaPY/getMinMaxAttributes.m:
--------------------------------------------------------------------------------
1 | % getMinMaxAttributes: three signatures per class from the per-image
2 | % attributes -- mean, mean + 0.5*std and mean - 0.5*std -- stored as rows
3 | % 3c-2, 3c-1 and 3c of classesAttrs for class c (ids shifted to start at 1).
4 | % Define datasetPath before running the script to use another location.
5 | if ~exist('datasetPath', 'var')
6 |     datasetPath = '/home/bernard/Data/ABC/apascal';
7 | end
8 | attrs = dlmread(fullfile(datasetPath, 'all-perimage.attributes'));
9 | classes = dlmread(fullfile(datasetPath, 'all.classid'));
10 | classes = classes+1;
11 | 
12 | nClasses = length(unique(classes));
13 | [nInstances, nAttrs] = size(attrs);
14 | 
15 | classesAttrs = zeros(nClasses*3, nAttrs);
16 | classesCells = cell(1, nClasses);
17 | for c = 1:nClasses
18 |     classesCells{c} = attrs(classes == c, :);   % instances of class c, in file order
19 |     % explicit dimension: a class with a single instance still yields a row
20 |     mu = mean(classesCells{c}, 1);
21 |     sd = std(classesCells{c}, 0, 1);
22 |     classesAttrs(3*c-2, :) = mu;
23 |     classesAttrs(3*c-1, :) = mu + 0.5*sd;
24 |     classesAttrs(3*c, :) = mu - 0.5*sd;
25 | end
26 | 
--------------------------------------------------------------------------------
/stkernelaPY/getSoftAttributes.m:
--------------------------------------------------------------------------------
1 | % getSoftAttributes: mean per-image attribute vector of every class, stored
2 | % at row = class id of classesAttrs (ids shifted to start at 1).
3 | % Define datasetPath before running the script to use another location.
4 | if ~exist('datasetPath', 'var')
5 |     datasetPath = '/home/bernard/Data/ABC/apascal';
6 | end
7 | attrs = dlmread(fullfile(datasetPath, 'all-perimage.attributes'));
8 | classes = dlmread(fullfile(datasetPath, 'all.classid'));
9 | classes = classes+1;
10 | 
11 | % Rows are indexed by class id, so size by the largest id (not the number of
12 | % distinct ids); otherwise rows beyond that count would never be averaged.
13 | nClasses = max(classes);
14 | [nInstances, nAttrs] = size(attrs);
15 | 
16 | classesAttrs = zeros(nClasses, nAttrs);
17 | counters = zeros(1, nClasses);
18 | for c = 1:nClasses
19 |     members = (classes == c);
20 |     counters(c) = nnz(members);
21 |     classesAttrs(c, :) = sum(attrs(members, :), 1)/counters(c);   % NaN row if class c is empty
22 | end
23 | 
--------------------------------------------------------------------------------
/stkernelaPY/mapScores.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/mapScores.mat
--------------------------------------------------------------------------------
/stkernelaPY/meanStdAttrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/meanStdAttrClasses.mat
--------------------------------------------------------------------------------
/stkernelaPY/runExperiment.m:
--------------------------------------------------------------------------------
1 | 
2 | datasetPath='/home/bernard/HD1/ABC/apascal';
3 | addpath('..');
4 | 
5 | lambdas=10.^[-3:3];
6 | sigmas=10.^[-3:3];
7 | tasks=0;
8 | 
9 | validationClasses_generalscript(datasetPath, lambdas, sigmas, tasks);
--------------------------------------------------------------------------------
/stkernelaPY/track.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/track.mat
--------------------------------------------------------------------------------
/validationClasses_generalscript.m:
--------------------------------------------------------------------------------
1 | 
2 | function validationClasses_generalscript(path, lambdas, sigmas, tasks, nRepeats)
3 | % validationClasses_generalscript: kernel ESZSL with hyper-parameters picked
4 | % on held-out training classes.
5 | %
6 | % Each repetition splits the training classes at random (80% mini-train,
7 | % 20% validation), evaluates every (task, sigma, lambda) combination on that
8 | % split, retrains on all training classes with the best one, and appends the
9 | % resulting test accuracy to hist, which is saved to val_hist.mat.
10 | %
11 | %   path     - folder with all.kernel and all-perclass.attributes; added to
12 | %              the MATLAB path so those files can be found
13 | %   lambdas  - candidate values of lambda in (K'K + lambda*K)
14 | %   sigmas   - candidate values of sigma in (S'S + sigma*I)
15 | %   tasks    - candidate numbers of random extra rows appended to the
16 | %              signature matrix S (with matching all-zero columns in Y)
17 | %   nRepeats - optional number of repetitions; default Inf (run until
18 | %              interrupted)
19 | 
20 | if nargin < 5 || isempty(nRepeats)
21 |     nRepeats = Inf;
22 | end
23 | 
24 | addpath(path);
25 | 
26 | % all.kernel: square kernel matrix stored as flat single-precision values.
27 | kernelFID = fopen('all.kernel');
28 | if kernelFID < 0
29 |     error('Cannot open all.kernel (is it in %s?)', path);
30 | end
31 | A = fread(kernelFID, 'float');
32 | fclose(kernelFID);   % release the file handle
33 | K = reshape(A, [sqrt(length(A)), sqrt(length(A))]);
34 | 
35 | perclassattributes = dlmread('all-perclass.attributes');
36 | load experimentIndices   % expected to define the train/test index and label variables
37 | 
38 | % Class signatures: unique rows of the per-class file, unless a local
39 | % attrClasses*.mat is present, in which case its attrClasses is used.
40 | A = unique(perclassattributes, 'rows');
41 | if ~isempty(dir('attrClasses*'))
42 |     load attrClasses
43 |     A = attrClasses;
44 | end
45 | attributesInput = A(trainClassesIndices, :);
46 | Stest = A(testClassesIndices, :);
47 | nAttrs = size(A, 2);
48 | 
49 | % Columns, so the per-class concatenations below work whichever orientation
50 | % experimentIndices.mat stores them in.
51 | trainInstancesIndices = trainInstancesIndices(:);
52 | trainInstancesLabels = trainInstancesLabels(:);
53 | nTrainingClasses = length(trainClassesIndices);
54 | nMiniTrain = floor(nTrainingClasses*0.8);
55 | 
56 | % Kernel blocks of the final model do not change between repetitions.
57 | KFull = K(trainInstancesIndices, trainInstancesIndices);
58 | KFullTest = K(testInstancesIndices, trainInstancesIndices);
59 | KKFull = KFull'*KFull;
60 | 
61 | hist = [];
62 | rep = 0;
63 | while rep < nRepeats
64 |     rep = rep + 1;
65 | 
66 |     % Random split of the training classes into mini-train / validation.
67 |     newOrder = randperm(nTrainingClasses);
68 |     [miniAttributesInput, miniTrainInstancesIndices, miniTrainInstancesLabels] = gatherClasses( ...
69 |         newOrder(1:nMiniTrain), attributesInput, trainInstancesIndices, trainInstancesLabels);
70 |     [Sval, miniValInstancesIndices, miniValInstancesLabels] = gatherClasses( ...
71 |         newOrder(nMiniTrain+1:end), attributesInput, trainInstancesIndices, trainInstancesLabels);
72 | 
73 |     KTrain = K(miniTrainInstancesIndices, miniTrainInstancesIndices);
74 |     KVal = K(miniValInstancesIndices, miniTrainInstancesIndices);
75 |     KK = KTrain'*KTrain;
76 | 
77 |     record = -Inf;   % so the first non-NaN accuracy is always recorded
78 |     recordman = [];
79 |     for t = 1:length(tasks)
80 |         nNewTasks = tasks(t);
81 | 
82 |         disp('Getting labels...');
83 |         Y = oneHot(miniTrainInstancesLabels, nMiniTrain + nNewTasks);
84 | 
85 |         disp('Precalculating statistics...');
86 |         S = [miniAttributesInput; rand(nNewTasks, nAttrs)];
87 |         KYS = KTrain*Y*S;
88 | 
89 |         for s = 1:length(sigmas)
90 |             sigma = sigmas(s);
91 |             KYS_invSS = KYS/(S'*S + sigma*eye(size(S, 2)));
92 | 
93 |             disp('Learning...');
94 |             for lambdaIndex = 1:length(lambdas)
95 |                 lambda = lambdas(lambdaIndex);
96 |                 Alpha = solveAlpha(KK, KTrain, KYS_invSS, lambda);
97 | 
98 |                 disp('Predicting...');
99 |                 pred = Sval*Alpha'*KVal';
100 |                 [~, classPred] = max(pred', [], 2);
101 | 
102 |                 disp('Evaluating...');
103 |                 r = mean(miniValInstancesLabels(:) == classPred(:));
104 |                 [r record]
105 |                 if r > record
106 |                     record = r;
107 |                     recordman = [lambda, sigma, nNewTasks];
108 |                     disp(['Record! ', num2str(record)]);
109 |                 end
110 |             end
111 |         end
112 |     end
113 |     if isempty(recordman)
114 |         error('No hyper-parameter setting produced a valid validation accuracy.');
115 |     end
116 | 
117 |     % Retrain on all training classes with the selected setting and test.
118 |     lambda = recordman(1);
119 |     sigma = recordman(2);
120 |     nNewTasks = recordman(3);
121 | 
122 |     disp('Precalculating statistics...');
123 |     Y = oneHot(trainInstancesLabels, nTrainingClasses + nNewTasks);
124 |     S = [attributesInput; rand(nNewTasks, nAttrs)];
125 |     KYS = KFull*Y*S;
126 |     KYS_invSS = KYS/(S'*S + sigma*eye(size(S, 2)));
127 | 
128 |     disp('Learning...');
129 |     Alpha = solveAlpha(KKFull, KFull, KYS_invSS, lambda);
130 | 
131 |     disp('Predicting...');
132 |     pred = Stest*Alpha'*KFullTest';
133 |     [~, classPred] = max(pred', [], 2);
134 | 
135 |     disp('Evaluating...');
136 |     r = mean(testInstancesLabels(:) == classPred(:))
137 |     hist = [hist, r];
138 |     save('val_hist', 'hist');
139 | end
140 | end
141 | 
142 | function [signatures, instIndices, instLabels] = gatherClasses(classList, attributes, allIndices, allLabels)
143 | % Signature rows of the classes in classList, the indices of their instances
144 | % (column vector) and those instances relabelled 1..numel(classList) (row vector).
145 | signatures = attributes(classList, :);
146 | instIndices = [];
147 | instLabels = [];
148 | for i = 1:numel(classList)
149 |     members = allIndices(allLabels == classList(i));
150 |     instIndices = [instIndices; members(:)];
151 |     instLabels = [instLabels, i*ones(1, numel(members))];
152 | end
153 | end
154 | 
155 | function Y = oneHot(labels, nColumns)
156 | % numel(labels) x nColumns matrix with Y(i, labels(i)) = 1 and zeros elsewhere.
157 | Y = zeros(numel(labels), nColumns);
158 | Y(sub2ind(size(Y), (1:numel(labels))', labels(:))) = 1;
159 | end
160 | 
161 | function Alpha = solveAlpha(KK, KTrain, KYS_invSS, lambda)
162 | % Kernel ESZSL solve (K'K + lambda*K) \ KYS_invSS. Used both for validation
163 | % and for the final model, so the selected lambda is applied to the same
164 | % objective it was selected on.
165 | Alpha = (KK + lambda*KTrain)\KYS_invSS;
166 | end
167 | 
--------------------------------------------------------------------------------