├── LICENSE
├── README.md
├── stkernelAwA
│   ├── attrClasses.mat
│   ├── binAttrClasses.mat
│   ├── classes.mat
│   ├── experimentIndices.mat
│   ├── getClassesAndIndices.m
│   ├── getClassesAndIndicesAndFixClassMess.m
│   ├── getClassesAndIndicesRaw.m
│   └── runExperiment.m
├── stkernelSUNAttributes
│   ├── allSignsGeneralscript.m
│   ├── attrClasses.mat
│   ├── binAttrClasses.mat
│   ├── experimentIndices.mat
│   ├── getDataInfo.m
│   ├── infoPrediction.mat
│   ├── mapScores.mat
│   ├── runExperiment.m
│   └── visualizationScript.m
├── stkernelaPY
│   ├── attrClasses.mat
│   ├── binAttrClasses.mat
│   ├── experimentIndices.mat
│   ├── getClassesAndIndices.m
│   ├── getClassesAndIndicesRaw.m
│   ├── getDataInfo.m
│   ├── getMinMaxAttributes.m
│   ├── getSoftAttributes.m
│   ├── mapScores.mat
│   ├── meanStdAttrClasses.mat
│   ├── runExperiment.m
│   └── track.mat
└── validationClasses_generalscript.m
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Bernardino Romera-Paredes
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Embarrassingly-simple-ZSL
2 | This repository contains the code for the real-data experiments in our paper "An embarrassingly simple approach to zero-shot learning", presented at ICML 2015.
3 |
4 |
5 |
6 | Before running the experiments, download the datasets from http://pub.ist.ac.at/~chl/ABC/.
7 | Uncompress each dataset into a directory of your choice, and set the variable datasetPath in the corresponding runExperiment.m to that directory.
8 | Then move to the directory of the dataset you want to use (stkernelAwA, stkernelSUNAttributes or stkernelaPY) and run runExperiment.m.
9 |
--------------------------------------------------------------------------------
/stkernelAwA/attrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelAwA/attrClasses.mat
--------------------------------------------------------------------------------
/stkernelAwA/binAttrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelAwA/binAttrClasses.mat
--------------------------------------------------------------------------------
/stkernelAwA/classes.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelAwA/classes.mat
--------------------------------------------------------------------------------
/stkernelAwA/experimentIndices.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelAwA/experimentIndices.mat
--------------------------------------------------------------------------------
/stkernelAwA/getClassesAndIndices.m:
--------------------------------------------------------------------------------
% getClassesAndIndices (AwA)  Build the train/test split over classes.
%
% Script, no arguments. Reads AwA_finalFeat.mat and
% predicate-matrix-binary.txt from ../../data/Dinesh-Jayaraman/ (relative to
% the current folder) and leaves the following variables in the workspace:
%   attributesMatrix      - matrix read from predicate-matrix-binary.txt
%   y                     - copy of the variable 'classes'
%   testClassesIndices    - the 10 hard-coded held-out class ids
%   trainClassesIndices   - the remaining ids of 1:50
%   trainInstancesIndices - row vector, positions in y of the instances whose
%                           class is a training class (grouped by class)
%   trainInstancesLabels  - row vector, for each of those instances the
%                           position of its class inside trainClassesIndices
%   testInstancesIndices / testInstancesLabels - same, for the test classes
%
% The files are addressed with full paths instead of cd-ing into the data
% folder, so a failing load cannot leave the session in another directory.

dataDir=fullfile('..','..','data','Dinesh-Jayaraman');
% NOTE(review): AwA_finalFeat.mat is expected to define 'classes' (it is
% read right below and nothing else in this script creates it) -- confirm.
load(fullfile(dataDir,'AwA_finalFeat.mat'));
attributesMatrix=load(fullfile(dataDir,'predicate-matrix-binary.txt'));

y=classes;
testClassesIndices=[25 39 15 6 42 14 18 48 34 24];
trainClassesIndices=setdiff(1:50, testClassesIndices);

trainInstancesIndices=[];
trainInstancesLabels=[];
for i=1:length(trainClassesIndices)
    classId=trainClassesIndices(i);
    % reshape(...,1,[]) yields a row whether y is a row or a column vector
    indices=reshape(find(y==classId),1,[]);
    trainInstancesIndices=[trainInstancesIndices, indices]; %#ok<AGROW>
    trainInstancesLabels=[trainInstancesLabels, i*ones(1,numel(indices))]; %#ok<AGROW>
end

testInstancesIndices=[];
testInstancesLabels=[];
for i=1:length(testClassesIndices)
    classId=testClassesIndices(i);
    indices=reshape(find(y==classId),1,[]);
    testInstancesIndices=[testInstancesIndices, indices]; %#ok<AGROW>
    testInstancesLabels=[testInstancesLabels, i*ones(1,numel(indices))]; %#ok<AGROW>
end
33 |
--------------------------------------------------------------------------------
/stkernelAwA/getClassesAndIndicesAndFixClassMess.m:
--------------------------------------------------------------------------------
% getClassesAndIndicesAndFixClassMess (AwA)  Re-number the classes through
% the classSource/classTarget mapping and build the train/test split.
%
% Script, no arguments. Reads AwA_finalFeat.mat and
% predicate-matrix-binary.txt from ../../data/Dinesh-Jayaraman/ and
% classes.mat from the current folder / MATLAB path. Leaves in the workspace:
%   attributesMatrix, attrClasses - attribute rows re-ordered so that row i
%                                   is the row of uniqueClassTarget(i)
%   y                             - new label vector (column)
%   trainClassesIndices, testClassesIndices
%   trainInstancesIndices, trainInstancesLabels (row vectors)
%   testInstancesIndices,  testInstancesLabels  (row vectors)
%
% The data files are addressed with full paths instead of cd-ing into the
% data folder, so a failing load cannot leave the session elsewhere.

dataDir=fullfile('..','..','data','Dinesh-Jayaraman');
load(fullfile(dataDir,'AwA_finalFeat.mat'));
attributesMatrix=load(fullfile(dataDir,'predicate-matrix-binary.txt'));
% NOTE(review): classes.mat is expected to provide classSource and
% classTarget, which are used below and defined nowhere else -- confirm.
load classes

uniqueClassSource=unique(classSource, 'stable');
uniqueClassTarget=unique(classTarget, 'stable');

% NOTE(review): the count comes from classSource but the loop below indexes
% uniqueClassTarget, so both lists are assumed to have the same length.
nClasses=length(uniqueClassSource);

testClassesIndices=[25 39 15 6 42 14 18 48 34 24];
trainClassesIndices=setdiff(1:50, testClassesIndices);

% Row i of the new attribute matrix is the attribute row of the i-th target class.
newAttributesMatrix=attributesMatrix(uniqueClassTarget(1:nClasses),:);

% NOTE(review): newY is built by stacking one run of label i per class, in
% class order; only the NUMBER of matching instances is used, not their
% positions, so this is only a valid per-instance label vector if the
% instances are stored grouped by class in that same order -- confirm.
newY=[];
for i=1:nClasses
    targetClass=uniqueClassTarget(i);
    indices=find(classSource==targetClass);
    newY=[newY; i*ones(length(indices),1)]; %#ok<AGROW>
end
y=newY;
attributesMatrix=newAttributesMatrix;

trainInstancesIndices=[];
trainInstancesLabels=[];
for i=1:length(trainClassesIndices)
    classId=trainClassesIndices(i);
    indices=reshape(find(y==classId),1,[]);
    trainInstancesIndices=[trainInstancesIndices, indices]; %#ok<AGROW>
    trainInstancesLabels=[trainInstancesLabels, i*ones(1,numel(indices))]; %#ok<AGROW>
end

testInstancesIndices=[];
testInstancesLabels=[];
for i=1:length(testClassesIndices)
    classId=testClassesIndices(i);
    indices=reshape(find(y==classId),1,[]);
    testInstancesIndices=[testInstancesIndices, indices]; %#ok<AGROW>
    testInstancesLabels=[testInstancesLabels, i*ones(1,numel(indices))]; %#ok<AGROW>
end

attrClasses=attributesMatrix;
56 |
--------------------------------------------------------------------------------
/stkernelAwA/getClassesAndIndicesRaw.m:
--------------------------------------------------------------------------------
% getClassesAndIndicesRaw (AwA)  Derive the train/test split and the
% per-class attribute table from the raw dataset files.
%
% Script, no arguments. Reads all-perclass.attributes, split0.mask and
% all.classid from dataDir and leaves in the workspace:
%   trainInstancesIndices - positions where the mask equals 1
%   testInstancesIndices  - positions where the mask equals 0
%   trainClassesIndices / testClassesIndices - class ids of each part, in
%                           order of first appearance
%   trainInstancesLabels / testInstancesLabels - per-instance labels
%                           re-numbered 1..k in order of first appearance,
%                           transposed with respect to the file orientation
%   attrClasses           - attribute table with the row of class id c
%                           stored at row c
%
% Files are addressed with full paths instead of cd-ing into the data
% folder, so a failing read cannot leave the session in another directory.

dataDir='/home/bernard/Data/ABC/animals';   % edit to point at your copy of the dataset
attrClasses=dlmread(fullfile(dataDir,'all-perclass.attributes'));
indices=dlmread(fullfile(dataDir,'split0.mask'));
classes=dlmread(fullfile(dataDir,'all.classid'));

trainMask=(indices==1);
trainInstancesIndices=find(trainMask);
trainInstancesLabels=classes(trainMask);
trainClassesIndices=unique(trainInstancesLabels,'stable');
% Third output of unique(...,'stable') numbers the labels 1..k by first
% appearance. Unlike a hand-written loop that marks handled labels with 0,
% it also works when a class id is 0 or negative.
[~,~,auxLabels]=unique(trainInstancesLabels,'stable');
trainInstancesLabels=reshape(auxLabels,size(trainInstancesLabels))';

testMask=(indices==0);
testInstancesIndices=find(testMask);
testInstancesLabels=classes(testMask);
testClassesIndices=unique(testInstancesLabels,'stable');
[~,~,auxLabels]=unique(testInstancesLabels,'stable');
testInstancesLabels=reshape(auxLabels,size(testInstancesLabels))';

uniqueClasses=unique(classes, 'stable');
if size(attrClasses,1)==numel(classes)
    % One attribute row per instance: take the row of the first instance of
    % each class. This stays correct even if two classes happen to share the
    % same attribute vector.
    newAttrClasses=zeros(numel(uniqueClasses),size(attrClasses,2));
    for i=1:numel(uniqueClasses)
        firstInstance=find(classes==uniqueClasses(i),1);
        newAttrClasses(uniqueClasses(i),:)=attrClasses(firstInstance,:);
    end
else
    % Fallback: pair the i-th distinct attribute row with the i-th class in
    % order of first appearance. Requires every class to have a distinct row.
    attrClasses=unique(attrClasses, 'rows', 'stable');
    newAttrClasses=zeros(size(attrClasses));
    for i=1:length(uniqueClasses)
        newAttrClasses(uniqueClasses(i),:)=attrClasses(i,:);
    end
end
attrClasses=newAttrClasses;
41 |
--------------------------------------------------------------------------------
/stkernelAwA/runExperiment.m:
--------------------------------------------------------------------------------
1 |
% Entry point for the AwA experiment.
% Makes the parent folder (which holds validationClasses_generalscript.m)
% visible, defines the hyper-parameter grids and launches the experiment.
addpath('..');

datasetPath = '/home/bernard/HD1/ABC/animals';   % folder with the uncompressed dataset
lambdas     = 10.^(-3:3);                        % grid 1e-3 ... 1e3
gammas      = 10.^(-3:3);                        % passed as the third argument ('sigmas') of the callee
tasks       = 0;

validationClasses_generalscript(datasetPath, lambdas, gammas, tasks);
9 |
--------------------------------------------------------------------------------
/stkernelSUNAttributes/allSignsGeneralscript.m:
--------------------------------------------------------------------------------
1 |
function generalscript(path, lambdas, sigmas, tasks)
% GENERALSCRIPT  Grid evaluation of the kernel model with per-image attributes.
%
%   generalscript(path, lambdas, sigmas, tasks)
%
%   path    - folder added to the MATLAB path; 'all.kernel',
%             'all-perclass.attributes' and 'all-perimage.attributes' are
%             then opened by bare name
%   lambdas - values tried for the regulariser applied to the kernel term
%   sigmas  - values tried for the regulariser applied to the attribute term
%   tasks   - list of values; each one is only recorded in the third column
%             of the history (the code that used it is disabled)
%
%   Also loads experimentIndices (and attrClasses, if a file matching
%   'attrClasses*' exists in the current folder). For every
%   (task, sigma, lambda) combination it solves
%       Alpha = (KTrain'*KTrain + lambda*KTrain) \ (KTrain*S / (S'*S + sigma*I))
%   scores the test instances with Stest*Alpha'*KTest', predicts the arg-max
%   class, prints the accuracy and appends [lambda, sigma, task, accuracy]
%   to 'hist', which is saved to hist4.mat after every combination.

addpath(path);

K=readSquareKernel('all.kernel');

perclassattributes = dlmread('all-perclass.attributes');
S = dlmread('all-perimage.attributes');
load experimentIndices

A=unique(perclassattributes, 'rows');
if ~isempty(dir('attrClasses*'))
    % a stored attribute table takes precedence over the one derived above
    load attrClasses
    A=attrClasses;
end

% Attribute signatures of the test classes; one score row per test class.
Stest=A(testClassesIndices,:);
% Per-image attributes of the training instances.
S=S(trainInstancesIndices,:);

hist=[]; % name kept: it is the variable name stored in hist4.mat

for t=1:length(tasks)
    nNewTasks=tasks(t);

    disp('Getting labels...');

    disp('Precalculating statistics...');
    KTrain=K(trainInstancesIndices, trainInstancesIndices);
    % Does not depend on sigma or lambda, so it is computed once per task.
    KTest=K(testInstancesIndices,trainInstancesIndices);
    KK=KTrain'*KTrain;
    KYS=KTrain*S;
    SS=S'*S;

    for s=1:length(sigmas)
        sigma=sigmas(s);
        KYS_invSS=KYS/(SS+sigma*eye(size(S,2)));

        disp('Learning...');
        for lambdaIndex=1:length(lambdas)
            lambda=lambdas(lambdaIndex);

            Alpha=(KK+lambda*KTrain)\KYS_invSS;

            disp('Predicting...');
            pred=Stest*Alpha'*KTest';
            [~, classPred]=max(pred',[],2);

            disp('Evaluating...');
            gt=testInstancesLabels;
            r=mean(gt==classPred')
            hist=[hist; [lambda, sigma, nNewTasks, r]]; %#ok<AGROW>
            save('hist4', 'hist');
        end
    end
end
end

function K=readSquareKernel(fileName)
% Read a file of 'float' values and reshape it into a square matrix.
fid=fopen(fileName);
if fid<0
    error('generalscript:kernelNotFound', 'Cannot open kernel file "%s".', fileName);
end
closeFile=onCleanup(@() fclose(fid)); %#ok<NASGU> closes the file on every exit path
values=fread(fid, 'float');
n=sqrt(numel(values));
if n~=floor(n)
    error('generalscript:kernelNotSquare', ...
        'Kernel file "%s" holds %d values, which is not a perfect square.', fileName, numel(values));
end
K=reshape(values,[n,n]);
end
98 |
--------------------------------------------------------------------------------
/stkernelSUNAttributes/attrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/attrClasses.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/binAttrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/binAttrClasses.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/experimentIndices.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/experimentIndices.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/getDataInfo.m:
--------------------------------------------------------------------------------
% getDataInfo (SUN)  Between-class versus within-class spread of the
% per-image attributes.
%
% Script, no arguments. Reads all-perimage.attributes and all.classid from
% dataDir. For every class it stores the mean (B) and the standard deviation
% (V) of the attribute vectors of its instances, then prints
%   ra = std(B)./mean(V)   (one value per attribute)
% and the mean of its non-NaN entries.
%
% Fixes with respect to the earlier version:
%   * the loop visits the i-th CLASS id; it used to read classes(i), i.e.
%     the class of the i-th instance, so classes could be repeated or missed;
%   * mean/std are taken along dimension 1 explicitly, so a class with a
%     single instance no longer collapses to a scalar;
%   * files are addressed with full paths instead of cd-ing into the folder.

dataDir='/home/bernard/Data/ABC/sun';   % edit to point at your copy of the dataset
attrs=dlmread(fullfile(dataDir,'all-perimage.attributes'));
classes=dlmread(fullfile(dataDir,'all.classid'));

classIds=unique(classes);
nClasses=length(classIds);
nAttrs=size(attrs, 2);
V=zeros(nClasses,nAttrs);
B=zeros(nClasses,nAttrs);
for i=1:nClasses
    classi=classIds(i);
    indices=find(classes==classi);
    B(i,:)=mean(attrs(indices,:),1);
    V(i,:)=std(attrs(indices,:),0,1);
end
ra=std(B,0,1)./mean(V,1)
ra=ra(~isnan(ra));
mean(ra)
--------------------------------------------------------------------------------
/stkernelSUNAttributes/infoPrediction.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/infoPrediction.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/mapScores.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelSUNAttributes/mapScores.mat
--------------------------------------------------------------------------------
/stkernelSUNAttributes/runExperiment.m:
--------------------------------------------------------------------------------
1 |
% Entry point for the SUN Attributes experiment.
% Makes the parent folder (which holds validationClasses_generalscript.m)
% visible, defines the hyper-parameter grids and launches the experiment.
addpath('..');

datasetPath = '/home/bernard/HD1/ABC/sun';   % folder with the uncompressed dataset
lambdas     = 10.^(-3:3);                    % grid 1e-3 ... 1e3
sigmas      = 10.^(-3:3);
tasks       = 0;

validationClasses_generalscript(datasetPath, lambdas, sigmas, tasks);
9 |
--------------------------------------------------------------------------------
/stkernelSUNAttributes/visualizationScript.m:
--------------------------------------------------------------------------------
% visualizationScript (SUN)  Draw a grid of example images: row i holds
% images taken from the i-th class folder, column j an image that was
% predicted as class j.
%
% Script, no arguments. Loads infoPrediction.mat and reads images from the
% sub-folders of ~/Dropbox/Data/SUNAttributes.
% NOTE(review): infoPrediction.mat is expected to provide 'gt' and
% 'classPred'; the slicing below also assumes classPred is ordered by class
% with nInstancesPerClass entries each, and that entry k of a class matches
% the k-th *.jpg of its folder -- confirm.
%
% Fixes with respect to the earlier version:
%   * class folders are selected explicitly instead of assuming that '.' and
%     '..' are the first two entries returned by dir;
%   * grayscale images are expanded to three channels before being pasted;
%   * the original working folder is restored even if reading fails;
%   * local variables no longer shadow the functions xlim and ylim.

here=pwd;
load infoPrediction
nClasses=max(gt);
nInstancesPerClass=20;
newSize=[70,70];
testClasses={'archive', 'art school', 'chemical plant', 'flea market', 'inn','lab classroom', 'lake', 'mineshaft', 'outhouse', 'shoe shop'};

% White canvas of nClasses x nClasses tiles.
imMatrix=uint8(255+zeros(([newSize*nClasses,3])));

try
    cd('~/Dropbox/Data/SUNAttributes');
    listDirs=dir();
    isClassDir=[listDirs.isdir] & ~ismember({listDirs.name},{'.','..'});
    listDirs=listDirs(isClassDir);

    for i=1:nClasses
        cd(listDirs(i).name);
        % predictions for the block of instances belonging to row i
        classPredAux=classPred((i-1)*nInstancesPerClass+1:i*nInstancesPerClass);
        listImages=dir('*.jpg');
        for j=1:nClasses
            index=find(classPredAux==j);
            if isempty(index)
                continue   % nothing predicted as class j: tile stays white
            end
            % pick one of the matching instances at random
            index=index(floor(rand*length(index))+1);
            im=imread(listImages(index).name);
            im=imresize(im, newSize);
            if size(im,3)==1
                im=repmat(im,[1 1 3]);
            end
            imMatrix(newSize(1)*(i-1)+1:newSize(1)*i,newSize(2)*(j-1)+1:newSize(2)*j,:)=im;
        end
        cd ..
    end
catch err
    cd(here);
    rethrow(err);
end

cd(here);
imshow(imMatrix);
% Row captions, written bottom-up in figure coordinates.
for i=1:10
    annotation('textbox',...
        [0.01 i/11.5+0.065 0.1 0.01],...
        'String', testClasses{11-i}, ...
        'LineStyle','none');
end
% Column captions, truncated to 7 characters and rotated.
for i=1:10
    xLimits=get(gca,'XLim');
    yLimits=get(gca,'YLim');

    name=testClasses{i};
    if length(name)>7
        name=name(1:7);
    end
    ht = text(i/10*xLimits(2)-40,0.1*yLimits(1)+1.11*yLimits(2),name);
    set(ht,'Rotation',90)
end
50 |
--------------------------------------------------------------------------------
/stkernelaPY/attrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/attrClasses.mat
--------------------------------------------------------------------------------
/stkernelaPY/binAttrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/binAttrClasses.mat
--------------------------------------------------------------------------------
/stkernelaPY/experimentIndices.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/experimentIndices.mat
--------------------------------------------------------------------------------
/stkernelaPY/getClassesAndIndices.m:
--------------------------------------------------------------------------------
% getClassesAndIndices (aPY)  Build the train/test split over classes.
%
% Script, no arguments. Reads Farhadi_PCA105.mat from
% ../../data/Dinesh-Jayaraman/ and leaves in the workspace:
%   y, testIndices, trainIndices,
%   trainInstancesIndices, trainInstancesLabels (row vectors),
%   testInstancesIndices,  testInstancesLabels  (row vectors),
%   trainClassesIndices = 1:nTrainClasses, testClassesIndices = 1:nTestClasses.
%
% NOTE(review): Farhadi_PCA105.mat is expected to define 'classes' and
% 'features' (both are read below and created nowhere else) -- confirm.
%
% NOTE(review): in the earlier version the two class lists were assigned
% with their right-hand side commented out ("testClassesIndices=%[...]"),
% which is a syntax error, so the script could not run. The commented values
% were the 50-class lists of the AwA script. Here the lists are instead
% derived from the instance split this script defines itself (first 2644
% instances = test, the rest = train). This presumes the two parts contain
% different classes -- confirm against the dataset before relying on it.

dataDir=fullfile('..','..','data','Dinesh-Jayaraman');
load(fullfile(dataDir,'Farhadi_PCA105.mat'))

y=classes;

testIndices=randperm(2644);
trainIndices=[2645:size(features,1)];

testClassesIndices=reshape(unique(y(testIndices)),1,[]);
trainClassesIndices=reshape(unique(y(trainIndices)),1,[]);

trainInstancesIndices=[];
trainInstancesLabels=[];
for i=1:length(trainClassesIndices)
    classId=trainClassesIndices(i);
    indices=reshape(find(y==classId),1,[]);
    trainInstancesIndices=[trainInstancesIndices, indices]; %#ok<AGROW>
    trainInstancesLabels=[trainInstancesLabels, i*ones(1,numel(indices))]; %#ok<AGROW>
end

testInstancesIndices=[];
testInstancesLabels=[];
for i=1:length(testClassesIndices)
    classId=testClassesIndices(i);
    indices=reshape(find(y==classId),1,[]);
    testInstancesIndices=[testInstancesIndices, indices]; %#ok<AGROW>
    testInstancesLabels=[testInstancesLabels, i*ones(1,numel(indices))]; %#ok<AGROW>
end

% From here on both lists only enumerate positions 1..k.
testClassesIndices=1:length(testClassesIndices);
trainClassesIndices=1:length(trainClassesIndices);
33 |
--------------------------------------------------------------------------------
/stkernelaPY/getClassesAndIndicesRaw.m:
--------------------------------------------------------------------------------
% getClassesAndIndicesRaw (aPY)  Derive the train/test split and the
% per-class attribute table from the raw dataset files.
%
% Script, no arguments. Reads all-perclass.attributes, split0.mask and
% all.classid from dataDir and leaves in the workspace:
%   classes               - class ids from the file, shifted by +1
%   trainInstancesIndices - positions where the mask equals 1
%   testInstancesIndices  - positions where the mask equals 0
%   trainClassesIndices / testClassesIndices - class ids of each part, in
%                           order of first appearance
%   trainInstancesLabels / testInstancesLabels - per-instance labels
%                           re-numbered 1..k in order of first appearance,
%                           transposed with respect to the file orientation
%   attrClasses           - row i is the attribute row of the first instance
%                           of class i
%
% Files are addressed with full paths instead of cd-ing into the data
% folder, so a failing read cannot leave the session in another directory.

dataDir='/home/bernard/Data/ABC/apascal';   % edit to point at your copy of the dataset
attrClasses=dlmread(fullfile(dataDir,'all-perclass.attributes'));
indices=dlmread(fullfile(dataDir,'split0.mask'));
classes=dlmread(fullfile(dataDir,'all.classid'));

% presumably the ids in the file start at 0 -- TODO confirm
classes=classes+1;

trainMask=(indices==1);
trainInstancesIndices=find(trainMask);
trainInstancesLabels=classes(trainMask);
trainClassesIndices=unique(trainInstancesLabels,'stable');
% Third output of unique(...,'stable') numbers the labels 1..k by first
% appearance, replacing the hand-written loop that used 0 as a marker for
% already handled labels.
[~,~,auxLabels]=unique(trainInstancesLabels,'stable');
trainInstancesLabels=reshape(auxLabels,size(trainInstancesLabels))';

testMask=(indices==0);
testInstancesIndices=find(testMask);
testInstancesLabels=classes(testMask);
testClassesIndices=unique(testInstancesLabels,'stable');
[~,~,auxLabels]=unique(testInstancesLabels,'stable');
testInstancesLabels=reshape(auxLabels,size(testInstancesLabels))';

% attrClasses holds one row per instance here (it is indexed with an
% instance position); keep the row of the first instance of every class.
newAttrClasses=[];
for i=1:length(unique(classes))
    newAttrClasses=[newAttrClasses; attrClasses(find(classes==i,1),:)]; %#ok<AGROW>
end
attrClasses=newAttrClasses;
43 |
--------------------------------------------------------------------------------
/stkernelaPY/getDataInfo.m:
--------------------------------------------------------------------------------
% getDataInfo (aPY)  Between-class versus within-class spread of the
% per-image attributes.
%
% Script, no arguments. Reads all-perimage.attributes and all.classid from
% dataDir. For every class it stores the mean (B) and the standard deviation
% (V) of the attribute vectors of its instances, then prints
%   ra = std(B)./mean(V)   (one value per attribute)
% and the mean of its non-NaN entries.
%
% Fixes with respect to the earlier version:
%   * the loop visits the i-th CLASS id; it used to read classes(i), i.e.
%     the class of the i-th instance, so classes could be repeated or missed;
%   * mean/std are taken along dimension 1 explicitly, so a class with a
%     single instance no longer collapses to a scalar;
%   * files are addressed with full paths instead of cd-ing into the folder.

dataDir='/home/bernard/Data/ABC/apascal';   % edit to point at your copy of the dataset
attrs=dlmread(fullfile(dataDir,'all-perimage.attributes'));
classes=dlmread(fullfile(dataDir,'all.classid'));

classIds=unique(classes);
nClasses=length(classIds);
nAttrs=size(attrs, 2);
V=zeros(nClasses,nAttrs);
B=zeros(nClasses,nAttrs);
for i=1:nClasses
    classi=classIds(i);
    indices=find(classes==classi);
    B(i,:)=mean(attrs(indices,:),1);
    V(i,:)=std(attrs(indices,:),0,1);
end
ra=std(B,0,1)./mean(V,1)
ra=ra(~isnan(ra));
mean(ra)
--------------------------------------------------------------------------------
/stkernelaPY/getMinMaxAttributes.m:
--------------------------------------------------------------------------------
% getMinMaxAttributes (aPY)  Three attribute signatures per class: the mean
% of its per-image attributes, the mean plus half a standard deviation and
% the mean minus half a standard deviation.
%
% Script, no arguments. Reads all-perimage.attributes and all.classid from
% dataDir and leaves in the workspace:
%   classesCells - 1 x nClasses cell, entry c holds the attribute rows of
%                  the instances of class c (in file order)
%   classesAttrs - (3*nClasses) x nAttrs; rows 3c-2, 3c-1, 3c hold mean,
%                  mean+0.5*std and mean-0.5*std of class c
%
% NOTE(review): class ids (after the +1 shift) are assumed to be exactly
% 1..nClasses -- confirm.
%
% Changes with respect to the earlier version: instances are gathered with
% one logical index per class instead of growing a matrix row by row;
% mean/std are taken along dimension 1 explicitly so a class with a single
% instance is handled correctly; files are addressed with full paths
% instead of cd-ing into the data folder.

dataDir='/home/bernard/Data/ABC/apascal';   % edit to point at your copy of the dataset
attrs=dlmread(fullfile(dataDir,'all-perimage.attributes'));
classes=dlmread(fullfile(dataDir,'all.classid'));
classes=classes+1;

nClasses=length(unique(classes));
[nInstances, nAttrs]=size(attrs);

classesAttrs=zeros(nClasses*3,nAttrs);
classesCells=cell(1,nClasses);

for classId=1:nClasses
    classesCells{classId}=attrs(classes==classId,:);
end

for classId=1:nClasses
    classMean=mean(classesCells{classId},1);
    halfStd=0.5*std(classesCells{classId},0,1);
    firstRow=3*(classId-1)+1;
    classesAttrs(firstRow,:)=classMean;
    classesAttrs(firstRow+1,:)=classMean+halfStd;
    classesAttrs(firstRow+2,:)=classMean-halfStd;
end
26 |
--------------------------------------------------------------------------------
/stkernelaPY/getSoftAttributes.m:
--------------------------------------------------------------------------------
% getSoftAttributes (aPY)  Per-class average of the per-image attributes.
%
% Script, no arguments. Reads all-perimage.attributes and all.classid from
% dataDir and leaves in the workspace:
%   classesAttrs - nClasses x nAttrs, row c is the mean attribute vector of
%                  the instances of class c
%   counters     - number of instances seen for every class
%
% NOTE(review): class ids (after the +1 shift) are assumed to be exactly
% 1..nClasses -- confirm.
%
% Changes with respect to the earlier version: files are addressed with
% full paths instead of cd-ing into the data folder (a failing read can no
% longer leave the session in another directory) and the loop variable no
% longer shadows the builtin function 'class'.

dataDir='/home/bernard/Data/ABC/apascal';   % edit to point at your copy of the dataset
attrs=dlmread(fullfile(dataDir,'all-perimage.attributes'));
classes=dlmread(fullfile(dataDir,'all.classid'));
classes=classes+1;

nClasses=length(unique(classes));
[nInstances, nAttrs]=size(attrs);

classesAttrs=zeros(nClasses,nAttrs);
counters=zeros(1,nClasses);

% Accumulate the attribute vectors and the instance count of every class.
for i=1:nInstances
    classId=classes(i);
    classesAttrs(classId,:)=classesAttrs(classId,:)+attrs(i,:);
    counters(classId)=counters(classId)+1;
end
% Turn the sums into means.
for i=1:nClasses
    classesAttrs(i,:)=classesAttrs(i,:)/counters(i);
end
23 |
--------------------------------------------------------------------------------
/stkernelaPY/mapScores.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/mapScores.mat
--------------------------------------------------------------------------------
/stkernelaPY/meanStdAttrClasses.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/meanStdAttrClasses.mat
--------------------------------------------------------------------------------
/stkernelaPY/runExperiment.m:
--------------------------------------------------------------------------------
1 |
% Entry point for the aPascal/aYahoo experiment.
% Makes the parent folder (which holds validationClasses_generalscript.m)
% visible, defines the hyper-parameter grids and launches the experiment.
addpath('..');

datasetPath = '/home/bernard/HD1/ABC/apascal';   % folder with the uncompressed dataset
lambdas     = 10.^(-3:3);                        % grid 1e-3 ... 1e3
sigmas      = 10.^(-3:3);
tasks       = 0;

validationClasses_generalscript(datasetPath, lambdas, sigmas, tasks);
--------------------------------------------------------------------------------
/stkernelaPY/track.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bernard24/Embarrassingly-simple-ZSL/700e92b1f7aebaf5a262c061803a789b082cca97/stkernelaPY/track.mat
--------------------------------------------------------------------------------
/validationClasses_generalscript.m:
--------------------------------------------------------------------------------
1 |
function generalscript(path, lambdas, sigmas, tasks, nRepetitions)
% GENERALSCRIPT  Select hyper-parameters on held-out training classes, then
% train on all training classes and evaluate on the test classes.
%
%   generalscript(path, lambdas, sigmas, tasks)
%   generalscript(path, lambdas, sigmas, tasks, nRepetitions)
%
%   path         - folder added to the MATLAB path; 'all.kernel' and
%                  'all-perclass.attributes' are then opened by bare name
%   lambdas      - candidate values for the kernel-side regulariser
%   sigmas       - candidate values for the attribute-side regulariser
%   tasks        - candidate numbers of extra random attribute signatures
%                  appended to the training signatures
%   nRepetitions - optional number of repetitions. When omitted the function
%                  keeps its legacy behaviour: it repeats forever and stops
%                  in 'keyboard' mode after every repetition.
%
%   Also loads experimentIndices (and attrClasses, if a file matching
%   'attrClasses*' exists in the current folder).
%
%   Each repetition
%     1. shuffles the training classes, keeps the first 80% for fitting and
%        the rest for validation;
%     2. evaluates every (task, sigma, lambda) combination on the validation
%        classes and keeps the best one;
%     3. refits on all training classes with that combination, prints the
%        accuracy on the test classes, appends it to 'hist' and saves 'hist'
%        to val_hist.mat.

if nargin<5 || isempty(nRepetitions)
    nRepetitions=Inf;
    pauseAfterRepetition=true;   % legacy interactive behaviour
else
    pauseAfterRepetition=false;
end

addpath(path);

K=readSquareKernel('all.kernel');

perclassattributes = dlmread('all-perclass.attributes');
load experimentIndices

A=unique(perclassattributes, 'rows');
if ~isempty(dir('attrClasses*'))
    % a stored attribute table takes precedence over the one derived above
    load attrClasses
    A=attrClasses;
end
attributesInput=A(trainClassesIndices,:);
attributesOutput=A(testClassesIndices,:);
nAttrs=size(A,2);

nTrainingClasses=length(trainClassesIndices);
nTrainingInstances=length(trainInstancesIndices);

Stest=attributesOutput;

hist=[]; % name kept: it is the variable name stored in val_hist.mat
repetition=0;
while repetition<nRepetitions
    repetition=repetition+1;

    % ---- random 80/20 split of the training classes -------------------
    newOrder=randperm(nTrainingClasses);
    miniNTrainingClasses=floor(nTrainingClasses*0.8);

    [miniAttributesInput, miniTrainInstancesIndices, miniTrainInstancesLabels]= ...
        collectClasses(newOrder(1:miniNTrainingClasses), ...
        attributesInput, trainInstancesIndices, trainInstancesLabels);
    [miniAttributesOutput, miniValInstancesIndices, miniValInstancesLabels]= ...
        collectClasses(newOrder(miniNTrainingClasses+1:end), ...
        attributesInput, trainInstancesIndices, trainInstancesLabels);

    Sval=miniAttributesOutput;

    % ---- grid search on the validation classes ------------------------
    record=0;
    recordman=[];

    for t=1:length(tasks)
        nNewTasks=tasks(t);

        disp('Getting labels...');
        Y=oneHotLabels(miniTrainInstancesLabels, miniNTrainingClasses+nNewTasks);

        disp('Precalculating statistics...');
        KTrain=K(miniTrainInstancesIndices, miniTrainInstancesIndices);
        % Does not depend on sigma or lambda, so it is computed once per task.
        KTest=K(miniValInstancesIndices,miniTrainInstancesIndices);
        KK=KTrain'*KTrain;

        S=[miniAttributesInput; rand(nNewTasks, nAttrs)];
        KYS=KTrain*Y*S;

        for s=1:length(sigmas)
            sigma=sigmas(s);

            KYS_invSS=KYS/(S'*S+sigma*eye(size(S,2)));
            disp('Learning...');
            for lambdaIndex=1:length(lambdas)
                lambda=lambdas(lambdaIndex);

                Alpha=(KK+lambda*eye(size(KTrain)))\KYS_invSS;

                disp('Predicting...');
                pred=Sval*Alpha'*KTest';
                [~, classPred]=max(pred',[],2);

                disp('Evaluating...');
                gt=miniValInstancesLabels;
                r=mean(gt==classPred');
                [r record]
                % The isempty test guarantees that a combination is always
                % selected, even when no accuracy exceeds the initial 0
                % (recordman used to stay empty and the indexing below failed).
                if r>record || isempty(recordman)
                    record=r;
                    recordman=[lambda, sigma, nNewTasks];
                    disp(['Record! ', num2str(record)]);
                end
            end
        end
    end

    lambda=recordman(1);
    sigma=recordman(2);
    nNewTasks=recordman(3);

    % ---- refit on all training classes, evaluate on the test classes --
    Y=oneHotLabels(trainInstancesLabels(1:nTrainingInstances), nTrainingClasses+nNewTasks);

    disp('Precalculating statistics...');
    KTrain=K(trainInstancesIndices, trainInstancesIndices);
    KK=KTrain'*KTrain;

    S=[attributesInput; rand(nNewTasks, nAttrs)];
    KYS=KTrain*Y*S;
    KYS_invSS=KYS/(S'*S+sigma*eye(size(S,2)));
    disp('Learning...');

    % NOTE(review): the search above regularises with lambda*eye(...) while
    % this final fit uses lambda*KTrain, so the selected lambda is applied
    % to a different penalty than the one it was tuned for. Left unchanged
    % because it alters the reported numbers -- confirm which one is intended.
    Alpha=(KK+lambda*KTrain)\KYS_invSS;

    disp('Predicting...');
    KTest=K(testInstancesIndices,trainInstancesIndices);

    pred=Stest*Alpha'*KTest';
    [~, classPred]=max(pred',[],2);

    disp('Evaluating...');
    gt=testInstancesLabels;
    r=mean(gt==classPred')
    if pauseAfterRepetition
        keyboard
    end
    hist=[hist, r]; %#ok<AGROW>
    save('val_hist', 'hist');
end

end

function K=readSquareKernel(fileName)
% Read a file of 'float' values and reshape it into a square matrix.
fid=fopen(fileName);
if fid<0
    error('generalscript:kernelNotFound', 'Cannot open kernel file "%s".', fileName);
end
closeFile=onCleanup(@() fclose(fid)); %#ok<NASGU> closes the file on every exit path
values=fread(fid, 'float');
n=sqrt(numel(values));
if n~=floor(n)
    error('generalscript:kernelNotSquare', ...
        'Kernel file "%s" holds %d values, which is not a perfect square.', fileName, numel(values));
end
K=reshape(values,[n,n]);
end

function [attrs, instanceIndices, instanceLabels]=collectClasses(classList, attributesInput, trainInstancesIndices, trainInstancesLabels)
% Gather the attribute rows and the instances of the listed training classes;
% instances of classList(i) receive the new label i.
attrs=attributesInput(classList,:);
instanceIndices=[];
instanceLabels=[];
for i=1:numel(classList)
    indices=trainInstancesIndices(trainInstancesLabels==classList(i));
    % (:) makes the stacking work for row and column index vectors alike
    instanceIndices=[instanceIndices; indices(:)]; %#ok<AGROW>
    instanceLabels=[instanceLabels, i*ones(1,numel(indices))]; %#ok<AGROW>
end
end

function Y=oneHotLabels(labels, nColumns)
% Indicator matrix with Y(i,labels(i))=1 and zeros elsewhere.
nRows=numel(labels);
Y=zeros(nRows,nColumns);
for row=1:nRows
    Y(row,labels(row))=1;
end
end
147 |
--------------------------------------------------------------------------------