├── .gitignore
├── Contents.m
├── README.md
├── bsd.txt
├── models
└── forest
│ └── .gitignore
├── stChns.m
├── stDemo.m
├── stDetect.m
├── stDetectMex.cpp
├── stEvalBsds.m
├── stGetPatches.m
├── stToEdges.m
└── stTrain.m
/.gitignore:
--------------------------------------------------------------------------------
1 | *.mat
2 | *.mex*
3 | *.tgz
--------------------------------------------------------------------------------
/Contents.m:
--------------------------------------------------------------------------------
1 | % SKETCHTOKENS
2 | % See also
3 | %
4 | % Demo:
5 | % stDemo - Sketch Token demo and usage.
6 | %
7 | % Training code:
8 | % stGetPatches - Sample ground truth edge sketch patches.
9 | % stTrain - Train SketchTokens classifier.
10 | %
11 | % Runtime code:
12 | % stChns - Compute channels for sketch token detection.
13 | % stDetect - Detect sketch tokens in image.
14 | % stEvalBsds - Evaluate sketch token edge detector on BSDS500.
15 | % stToEdges - Convert sketch tokens to edges.
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Sketch Token Toolbox V0.95
2 | ==============================
3 |
4 | This software package provides tools to extract contour-based mid-level features, and to extract contour segmentations from images.
5 | This tool is highly efficient while maintaining high accuracy in contour detection.
6 | Also, [1] shows that the extracted mid-level features provide additional information for object and pedestrian detection.
7 |
8 | Installation
9 | ------------
10 |
11 | -
12 | Download Piotr's Image & Video Matlab Toolbox (http://vision.ucsd.edu/~pdollar/toolbox/doc/)
13 | SketchTokens/toolbox/ should contain channels, classify, filters, images, matlab, etc.
14 |
15 |
16 | -
17 | Download Berkeley Segmentation Data Set and Benchmarks 500 (http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html)
18 | SketchTokens/BSR/ should contain BSDS500, bench, and documentation (the code looks for BSR/ inside the SketchTokens directory).
19 |
20 |
21 | -
22 | Pre-trained models can be downloaded from: http://people.csail.mit.edu/lim/lzd_cvpr2013/st_data.tgz
23 |
24 |
25 | -
26 | See stDemo.m for how to train and test the code.
27 |
28 |
29 |
30 | References
31 | ----------
32 | Please cite the following paper if you end up using the code:
33 | [1] Joseph J. Lim, C. Lawrence Zitnick, and Piotr Dollar. "Sketch Tokens: A Learned Mid-level Representation for Contour and Object Detection," CVPR2013.
34 |
35 | License
36 | -------
37 | Copyright 2013 Joseph Lim [lim@csail.mit.edu]
38 |
39 | Please email me if you find bugs, or have suggestions or questions!
40 |
41 | Licensed under the Simplified BSD License [see bsd.txt]
42 |
43 | Note: There is a patent pending on the ideas presented in this work so this code should only be used for academic purposes.
44 |
--------------------------------------------------------------------------------
/bsd.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2012, Joseph Lim
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 | 2. Redistributions in binary form must reproduce the above copyright notice,
10 | this list of conditions and the following disclaimer in the documentation
11 | and/or other materials provided with the distribution.
12 |
13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 |
24 | The views and conclusions contained in the software and documentation are those
25 | of the authors and should not be interpreted as representing official policies,
26 | either expressed or implied, of the FreeBSD Project.
27 |
--------------------------------------------------------------------------------
/models/forest/.gitignore:
--------------------------------------------------------------------------------
1 | *.mat
2 |
--------------------------------------------------------------------------------
/stChns.m:
--------------------------------------------------------------------------------
function chns = stChns( I, opts )
% Compute channels for sketch token detection.
%
% The image is converted with rgbConvert (to LUV by default), then for
% every entry s of opts.sigmas the image, optionally blurred with a
% Gaussian of that sigma, contributes its gradient magnitude and, when
% opts.nOrients(s)>0, an oriented gradient histogram. All channels are
% stacked along the third dimension and smoothed with convTri.
%
% USAGE
%  chns = stChns( I, opts )
%
% INPUTS
%  I     - [h x w x 3] color input image
%  opts  - sketch token model options; fields read here:
%          .sigmas and .nOrients (one entry per channel set), .radius
%          (Gaussian kernel size), .normRad, .normConst, .chnsSmooth and
%          optionally .inputColorChannel (if 'luv', rgbConvert(I,'orig')
%          is used instead of the LUV conversion, presumably because the
%          input is already LUV)
%
% OUTPUTS
%  chns  - [h x w x nChannel] output channels
%
% EXAMPLE
%
% See also stDetect, gradientMag, gradientHist
%
% Sketch Token Toolbox V0.95
% Copyright 2013 Joseph Lim [lim@csail.mit.edu]
% Please email me if you find bugs, or have suggestions or questions!
% Licensed under the Simplified BSD License [see bsd.txt]

% pick the color conversion
inputIsLuv = isfield(opts,'inputColorChannel') && ...
  strcmp(opts.inputColorChannel,'luv');
if( inputIsLuv )
  colorSpace = 'orig';
else
  colorSpace = 'luv';
end
I = rgbConvert(I,colorSpace);

% color channels first, then one or two gradient parts per channel set
nSets = length(opts.sigmas);
parts = cell(1,1+2*nSets);
parts{1} = I;
nParts = 1;
for s = 1:nSets
  sigma = opts.sigmas(s);
  if( sigma==0 )
    J = I;
  else
    J = imfilter(I,fspecial('gaussian',opts.radius,sigma));
  end
  nOrient = opts.nOrients(s);
  if( nOrient>0 )
    [M,O] = gradientMag(J,0,opts.normRad,opts.normConst);
    parts(nParts+1:nParts+2) = {M, gradientHist(M,O,1,nOrient,0)};
    nParts = nParts+2;
  else
    parts{nParts+1} = gradientMag(J,0,opts.normRad,opts.normConst);
    nParts = nParts+1;
  end
end

% stack all channel parts and smooth them
chns = convTri(cat(3,parts{1:nParts}),opts.chnsSmooth);
end
55 |
--------------------------------------------------------------------------------
/stDemo.m:
--------------------------------------------------------------------------------
1 | %% Sketch Token demo and usage.
2 | %%
3 | %% Please cite the following paper if you end up using the code:
4 | %% Joseph J. Lim, C. Lawrence Zitnick, and Piotr Dollar. "Sketch Tokens: A
5 | %% Learned Mid-level Representation for Contour and Object Detection,"
6 | %% CVPR2013.
7 | %%
8 | %% Note: There is a patent pending on the ideas presented in this work so
9 | %% this code should only be used for academic purposes.
10 | %%
11 | %% Sketch Token Toolbox V0.95
12 | %% Copyright 2013 Joseph Lim [lim@csail.mit.edu]
13 | %% Please email me if you find bugs, or have suggestions or questions!
14 | %% Licensed under the Simplified BSD License [see bsd.txt]
15 |
16 |
17 |
%% setup (follow instructions, only need to do once)
cd(fileparts(mfilename('fullpath')))
if( 0 )
  % (1) Download Berkeley Segmentation Data Set and Benchmarks 500
  %   http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/
  %   BSR/ should contain BSDS500, bench, and documentation
  addpath('BSR/bench/benchmarks');
  % (2) Download and compile Piotr's toolbox (directions on website)
  %   http://vision.ucsd.edu/~pdollar/toolbox/doc/index.html
  % (3) Compile code with OpenMP support. The OpenMP flag is compiler
  %   specific: MSVC (Windows) takes /openmp, gcc (Linux) takes -fopenmp.
  %   Passing the MSVC flag to gcc breaks the build, so pick per platform.
  %   To compile a single core variant simply omit the flags.
  if( ispc )
    mex stDetectMex.cpp 'OPTIMFLAGS="$OPTIMFLAGS' '/openmp"'
  elseif( ismac )
    % Apple's default clang typically ships without OpenMP support
    mex stDetectMex.cpp
  else
    mex stDetectMex.cpp 'CXXFLAGS="\$CXXFLAGS' '-fopenmp"' ...
      'LDFLAGS="\$LDFLAGS' '-fopenmp"'
  end
  % (4) Add current directory to path
  addpath(pwd);
end

%% train or load model (see stTrain.m)
if( 0 )
  if( 1 )
    % can be trained in ~20m and requires ~3GB ram
    % BSDS500 performance: ODS=0.721 OIS=0.739 AP=0.768
    opts=struct('nPos',100,'nNeg',80,'modelFnm','modelSmall','nTrees',20);
  else
    % takes ~2.5 hours to train and requires ~27GB ram
    % BSDS500 performance: ODS=0.727 OIS=0.746 AP=0.780
    opts=struct('nPos',1000,'nNeg',800,'modelFnm','modelFull');
  end
  tic, model=stTrain(opts); toc
else
  % Pre-trained models can be downloaded from:
  %   http://people.csail.mit.edu/lim/lzd_cvpr2013/st_data.tgz
  load('models/forest/modelSmall.mat');
end

%% evaluate sketch token model on BSDS500 (see stEvalBsds.m, Linux only)
if( 0 )
  [ODS,OIS,AP]=stEvalBsds( model );
end

%% detect sketch tokens and extract edges (see stDetect.m and stToEdges.m)
I = imread('peppers.png');
tic; st = stDetect( I, model ); toc;
tic, E = stToEdges( st, 1 ); toc

%% visualize edge detection results
figure(1);
im(I);
figure(2);
im(E);
colormap jet;
--------------------------------------------------------------------------------
/stDetect.m:
--------------------------------------------------------------------------------
function S = stDetect( I, model, stride, rescale_back )
% Detect sketch tokens in image.
%
% The image is padded by model.opts.radius (symmetric), channels are
% computed with stChns, and the random forest in model is evaluated by
% stDetectMex at every stride-th pixel.
%
% USAGE
%  S = stDetect( I, model, [stride], [rescale_back] )
%
% INPUTS
%  I            - [h x w x 3] color input image
%  model        - sketch token model trained with stTrain
%  stride       - [2] stride at which to compute sketch tokens
%                 (pass [] to use the default)
%  rescale_back - [true] if true, upsample the strided maps by stride
%                 (imResample) and crop them to the input size; if false,
%                 return the strided maps directly (pass [] for default)
%
% OUTPUTS
%  S            - sketch token probability maps (forest output scaled by
%                 1/opts.nTrees):
%                 [h x w x (nTokens+1)] if rescale_back is true,
%                 [ceil(h/stride) x ceil(w/stride) x (nTokens+1)] otherwise
%
% EXAMPLE
%  st = stDetect( imread('peppers.png'), model ); E = stToEdges( st, 1 );
%
% See also stTrain, stChns
%
% Sketch Token Toolbox V0.95
% Copyright 2013 Joseph Lim [lim@csail.mit.edu]
% Please email me if you find bugs, or have suggestions or questions!
% Licensed under the Simplified BSD License [see bsd.txt]

if( nargin<3 || isempty(stride) )
  stride=2;
end
if( nargin<4 || isempty(rescale_back) )
  rescale_back=true;
end

% compute features on the image padded by opts.radius on every side
sizeOrig=size(I);
opts=model.opts;
I = imPad(I,opts.radius,'symmetric');
chns = stChns( I, opts );
[cids1,cids2] = computeCids(size(chns),opts);
if( opts.nCells )
  % box-filtered channels are only needed for self-similarity features
  chnsSs = convBox(chns,opts.cellRad);
else
  chnsSs = [];
end

% run forest on image
S = stDetectMex( chns, chnsSs, model.thrs, model.fids, model.child, ...
  model.distr, cids1, cids2, stride, opts.radius, opts.nChnFtrs );

% finalize sketch token probability maps: move the token dimension last
% and scale by 1/nTrees (presumably averaging the per-tree outputs)
S = permute(S,[2 3 1]) * (1/opts.nTrees);
if( rescale_back )
  % upsample by stride, then drop the extra rows/cols introduced by the
  % ceil in the strided grid so S matches the input's spatial size
  S = imResample( S, stride );
  cr=size(S); cr=cr(1:2)-sizeOrig(1:2);
  if( any(cr) )
    S=S(1:end-cr(1),1:end-cr(2),:);
  end
end

end
63 |
function [cids1,cids2] = computeCids( siz, opts )
% Build the per-feature channel offsets passed to stDetectMex.
%
% All offsets are 0-based linear indices into the [ht x wd x nChns]
% channel array and lie within a patchSiz x patchSiz window (presumably
% added by stDetectMex to each patch's top-left index). The first
% patchSiz*patchSiz*nChns entries address raw channel pixels, row fastest,
% then column, then channel; cids2 is 0 for these. The remaining entries
% address the two cells of every self-similarity pair (p,p+d), ordered by
% d=1..n*n-1 and then p, and repeated once per channel. cids1 holds the
% first cell and cids2 the second, matching stComputeSimFtrs in stTrain.m.
%
% INPUTS
%  siz   - size of the (padded) channel array, [ht wd nChns]
%  opts  - model options (.radius, .patchSiz, .nChns, .nCells, .cellStep)
%
% OUTPUTS
%  cids1 - [1 x nFtrs] uint32 offsets
%  cids2 - [1 x nFtrs] uint32 offsets (0 for the raw channel features)

radius=opts.radius;
s=opts.patchSiz;
nChns=opts.nChns;
ht=siz(1);
wd=siz(2);
assert(siz(3)==nChns);

% raw channel features: row varies fastest, then column, then channel
nChnFtrs=s*s*nChns;
[r,c,chn] = ndgrid(0:s-1,0:s-1,0:nChns-1);
chnCids = uint32(r(:)' + c(:)'*ht + chn(:)'*ht*wd);

% self-similarity features
n=opts.nCells;
m=opts.cellStep;
nCellTotal=(n*n)*(n*n-1)/2;
assert((n==0) || (mod(n,2)==1));
n1=(n-1)/2;
nSimFtrs=nCellTotal*nChns;

% list every cell pair (p,p+d) in offset-major order
nGrid=n*n;
cellA=zeros(1,nCellTotal);
cellB=zeros(1,nCellTotal);
pos=0;
for d=1:nGrid-1
  p=0:nGrid-d-1;
  cellA(pos+1:pos+numel(p))=p;
  cellB(pos+1:pos+numel(p))=p+d;
  pos=pos+numel(p);
end
cellA=repmat(cellA,1,nChns);
cellB=repmat(cellB,1,nChns);
chnOff=uint32(floor((0:nSimFtrs-1)/nCellTotal))*ht*wd;

% cell grid coordinate -> pixel offset inside the patch
toOff=@(v) uint32((v-n1)*m+radius);
rowA=mod(cellA,n);
rowB=mod(cellB,n);
simA=toOff(rowA) + toOff((cellA-rowA)/n)*ht + chnOff;
simB=toOff(rowB) + toOff((cellB-rowB)/n)*ht + chnOff;

% raw channel features first, self-similarity features after
cids1=[chnCids simA];
cids2=[zeros(1,nChnFtrs) simB];
end
127 |
--------------------------------------------------------------------------------
/stDetectMex.cpp:
--------------------------------------------------------------------------------
1 | /*******************************************************************************
2 | * Sketch Token Toolbox V0.95
3 | * Copyright 2013 Joseph Lim [lim@csail.mit.edu]
4 | * Please email me if you find bugs, or have suggestions or questions!
5 | * Licensed under the Simplified BSD License [see bsd.txt]
6 | *******************************************************************************/
7 | #include "mex.h"
8 | #include
9 | #include
10 |
11 | typedef unsigned int uint32;
12 |
13 | void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] )
14 | {
15 | // get inputs
16 | float *chns = (float*) mxGetData(prhs[0]);
17 | float *chnsSs = (float*) mxGetData(prhs[1]);
18 | float *thrs = (float*) mxGetData(prhs[2]);
19 | uint32 *fids = (uint32*) mxGetData(prhs[3]);
20 | uint32 *child = (uint32*) mxGetData(prhs[4]);
21 | float *distr = (float*) mxGetData(prhs[5]);
22 | uint32 *cids1 = (uint32*) mxGetData(prhs[6]);
23 | uint32 *cids2 = (uint32*) mxGetData(prhs[7]);
24 | const int stride = (int) mxGetScalar(prhs[8]);
25 | const int rad = (int) mxGetScalar(prhs[9]);
26 | const int nChnFtrs = (int) mxGetScalar(prhs[10]);
27 |
28 | // get dimensions and constants
29 | const mwSize *chnsSize = mxGetDimensions(prhs[0]);
30 | const int height = (int) chnsSize[0];
31 | const int width = (int) chnsSize[1];
32 | const mwSize *distrSize = mxGetDimensions(prhs[5]);
33 | const int nTokens = (int) distrSize[0];
34 | const int nTreeNodes = (int) distrSize[1];
35 | const int nTrees = (int) distrSize[2];
36 | const int heightOut = (int) ceil((height-rad*2.0)/stride);
37 | const int widthOut = (int) ceil((width-rad*2.0)/stride);
38 |
39 | // create output
40 | const int outDims[3]={nTokens,heightOut,widthOut};
41 | float *S = (float*) mxCalloc(nTokens*heightOut*widthOut,sizeof(float));
42 | plhs[0] = mxCreateNumericMatrix(0,0,mxSINGLE_CLASS,mxREAL);
43 | mxSetData(plhs[0],S);
44 | mxSetDimensions(plhs[0],outDims,3);
45 |
46 | // apply forest to each patch
47 | #pragma omp parallel for
48 | for( int c=0; c
88 |
89 | I = imread([imgDir id '.jpg']);
90 | st = max(min(1,stDetect(I,model)),0);
91 | E = stToEdges(st,1);
92 |
93 | imwrite(uint8(E*255),[resDir id '.png']);
94 | end
95 |
96 | % perform evaluation on each image (Linux only, slow)
97 | if(ispc),
98 | error('Evaluation code runs on Linux ONLY.');
99 | end
100 | do=false(1,n);
101 | jobs=cell(1,n);
102 | for i=1:n,
103 | do(i)=~exist([evalDir ids{i} '_ev1.txt'],'file');
104 | end
105 | for i=1:n,
106 | id=ids{i};
107 | jobs{i}={[resDir id '.png'],...
108 | [gtDir id '.mat'],[evalDir id '_ev1.txt'],p.nThresh};
109 | end
110 | if(~exist(evalDir,'dir')),
111 | mkdir(evalDir);
112 | end
113 | fevalDistr('evaluation_bdry_image',jobs(do),p.pDistr{:});
114 |
115 | % collect results and display
116 | [ODS,~,~,~,OIS,~,~,AP]=collect_eval_bdry(evalDir);
117 | fprintf('ODS=%.3f OIS=%.3f AP=%.3f\n',ODS,OIS,AP);
118 | if( p.show ),
119 | plot_eval(evalDir,'r');
120 | end
121 | if( p.cleanup ),
122 | delete([evalDir '/*_ev1.txt']);
123 | delete([evalDir '/eval_bdry_img.txt']);
124 | delete([evalDir '/eval_bdry_thr.txt']);
125 | delete([resDir '/*.png']); rmdir(resDir);
126 | end
127 |
128 | end
129 |
--------------------------------------------------------------------------------
/stGetPatches.m:
--------------------------------------------------------------------------------
1 | function clusters = stGetPatches( radius, nPatches, bsdsDir )
2 | % Sample ground truth edge sketch patches.
3 | %
4 | % Calling stGetPatches() is *optional* as the code package comes with
5 | % pre-computed clusters.mat. stGetPatches() only needs to be called
6 | % if you want to generate alternate clusters to the ones provided.
7 | %
8 | % stGetPatches() is the first step in generating sketch tokens classes: it
9 | % is used to sample patches of ground truth edge maps from the training set
10 | % of the Berkeley Segmentation Dataset. After sampling, the extracted
11 | % patches must be clustered to generate the sketch token classes.
12 | % Clustering code is *NOT* provided. It should be easy to implement, see
13 | % the paper for more details. After clustering, the additional fields
14 | % "clusterId" and "clusters" should be initialized in the clusters struct.
15 | % "clusterId" should indicate cluster membership of each extracted patch
16 | % (integer between 1 and K) and "clusters" should be the mean of the patches
17 | % belonging to the given cluster. After these two fields are added to the
18 | % clusters struct, the sketch token model is ready to be trained via
19 | % stTrain() (see parameter "clusterFnm" in stTrain.m).
20 | %
21 | % USAGE
22 | % clusters = stGetPatches( [radius], [nPatches], [bsdsDir] )
23 | %
24 | % INPUTS
25 | % radius - [15] radius of sketch token patches
26 | % nPatches - [inf] maximum number of patches to sample
27 | % bsdsDir - ['BSR/BSDS500/data/'] location of BSDS dataset
28 | %
29 | % OUTPUTS
30 | % clusters - extracted ground truth info w the following fields
31 | % .x - [Nx1] x-coordinate each sampled patch
32 | % .y - [Nx1] y-coordinate each sampled patch
33 | % .gtId - [Nx1] integer ground truth labeler of each sampled patch
34 | % .imId - [Nx1] integer image id of each sampled patch
35 | % .patches - [PxPxN] binary images of sampled patches, P=2*radius+1
36 | % .clusterId - [Nx1] cluster membership in [1,K] (NOT COMPUTED)
37 | % .clusters - [PxPxK] cluster images (NOT COMPUTED)
38 | %
39 | % EXAMPLE
40 | %
41 | % See also stTrain
42 | %
43 | % Sketch Token Toolbox V0.95
44 | % Copyright 2013 Joseph Lim [lim@csail.mit.edu]
45 | % Please email me if you find bugs, or have suggestions or questions!
46 | % Licensed under the Simplified BSD License [see bsd.txt]
47 |
48 | if( nargin<1 ),
49 | radius=15;
50 | end
51 | if( nargin<2 ),
52 | nPatches=inf;
53 | end
54 | if( nargin<3 ),
55 | bsdsDir='BSR/BSDS500/data/';
56 | end
57 |
58 | % location of ground truth
59 | trnImgDir = [bsdsDir '/images/train/'];
60 | trnGtDir = [bsdsDir '/groundTruth/train/'];
61 | imgIds=dir([trnImgDir '*.jpg']);
62 | imgIds={imgIds.name};
63 | nImgs=length(imgIds);
64 | for i=1:nImgs,
65 | imgIds{i}=imgIds{i}(1:end-4);
66 | end
67 |
68 | % loop over ground truth and collect samples
69 | clusters=struct('x',[],'y',[],'gtId',[],'imId',[],'patches',[],...
70 | 'clusterId',[],'clusters',[]);
71 | clusters.patches = false(radius*2+1,radius*2+1,9000*5*nImgs);
72 | tid = ticStatus('data collection');
73 | cnt=0;
74 | for i = 1:nImgs
75 | gt=load([trnGtDir imgIds{i} '.mat']);
76 | gt=gt.groundTruth;
77 | for j=1:length(gt)
78 | if(isempty(gt{j}.Boundaries)),
79 | continue;
80 | end
81 | M0 = gt{j}.Boundaries;
82 | M=M0;
83 | M([1:radius end-radius+1:end],:)=0;
84 | M(:,[1:radius end-radius+1:end])=0;
85 | [y,x]=find(M);
86 | cnt1=length(y);
87 | clusters.y = [clusters.y; int32(y)];
88 | clusters.x = [clusters.x; int32(x)];
89 | clusters.gtId = [clusters.gtId; ones(cnt1,1,'int32')*j];
90 | clusters.imId = [clusters.imId; ones(cnt1,1,'int32')*i];
91 | for k=1:cnt1,
92 | clusters.patches(:,:,cnt+k) = M0(y(k)-radius:y(k)+radius,x(k)-radius:x(k)+radius);
93 | end
94 | cnt = cnt + cnt1;
95 | end
96 | tocStatus(tid, i/nImgs);
97 | end
98 | clusters.patches = clusters.patches(:,:,1:cnt);
99 |
100 | % optionally sample patches
101 | if( nPatches0;
91 | Dy(F)=-Dy(F);
92 | O=mod(atan2(Dy,Dx),pi);
93 | end
94 |
--------------------------------------------------------------------------------
/stTrain.m:
--------------------------------------------------------------------------------
function model = stTrain( varargin )
% Train SketchTokens classifier.
%
% See stDemo for a full demo that includes both training and application.
%
% Pre-trained models can be downloaded from:
%  http://people.csail.mit.edu/lim/lzd_cvpr2013/st_data.tgz
%
% Please cite the following paper if you end up using the code:
%  Joseph J. Lim, C. Lawrence Zitnick, and Piotr Dollar. "Sketch Tokens: A
%  Learned Mid-level Representation for Contour and Object Detection,"
%  CVPR2013.
%
% Note: There is a patent pending on the ideas presented in this work so
% this code should only be used for academic purposes.
%
% Caching: each tree is saved to [modelDir '/tree/' modelFnm '_treeXXX.mat']
% and the merged forest to [modelDir '/forest/' modelFnm '.mat']. If the
% merged forest file already exists it is loaded and returned as is, and
% the remaining options are ignored. stTrain changes the current directory
% to the folder containing stTrain.m, so relative paths in opts (modelDir,
% clusterFnm, bsdsDir) are resolved from there.
%
% USAGE
%  model = stTrain( opts )
%
% INPUTS
%  opts       - parameters (struct or name/value pairs)
%   (1) parameters for model and data:
%   .nClusters  - [150] number of clusters to train with
%   .nTrees     - [25] number of trees in forest to train
%   .radius     - [17] radius of sketch token patches
%   .nPos       - [1000] number of positive patches per cluster
%   .nNeg       - [800] number of negative patches per image
%   .negDist    - [2] distance from closest contour defining a negative
%   .minCount   - [4] minimum number of training examples per node
%   (2) parameters for features:
%   .nCells     - [5] number of self similarity cells (0 or odd)
%   .normRad    - [5] normalization radius (see gradientMag)
%   .normConst  - [.01] normalization constant (see gradientMag)
%   .nOrients   - [4 4 0] number of orientations for each channel set
%   .sigmas     - [0 1.5 5] gaussian blur for each channel set
%   .chnsSmooth - [2] radius for channel smoothing (using convTri)
%   .fracFtrs   - [1] fraction of features to use to train each tree
%   (3) other parameters:
%   .seed       - [1] seed for random stream (for reproducibility)
%   .modelDir   - ['models/'] target directory for storing models
%   .modelFnm   - ['model'] model filename
%   .clusterFnm - ['clusters.mat'] file containing cluster info
%   .bsdsDir    - ['BSR/BSDS500/data/'] location of BSDS dataset
%
% OUTPUTS
%  model      - trained sketch token detector w the following fields
%   .thrs     - [nNodes x nTrees] single split thresholds
%   .fids     - [nNodes x nTrees] uint32 split feature ids
%   .child    - [nNodes x nTrees] uint32 child index (0 at leaves)
%   .count    - [nNodes x nTrees] uint32 training examples per node
%   .depth    - [nNodes x nTrees] uint32 node depth
%   .distr    - [nClasses x nNodes x nTrees] single per-node distributions
%   .opts     - input parameters and derived constants
%   .clusters - actual cluster centers used to learn tokens
%   Trees with fewer than nNodes nodes are zero padded.
%
% EXAMPLE
%
% See also stGetPatches, stDetect, forestTrain, chnsCompute, gradientMag
%
% Sketch Token Toolbox V0.95
% Copyright 2013 Joseph Lim [lim@csail.mit.edu]
% Please email me if you find bugs, or have suggestions or questions!
% Licensed under the Simplified BSD License [see bsd.txt]

% get default parameters
dfs={'nClusters',150, 'nTrees',25, 'radius',17, 'nPos',1000, 'nNeg',800,...
  'negDist',2, 'minCount',4, 'nCells',5, 'normRad',5, 'normConst',0.01, ...
  'nOrients',[4 4 0], 'sigmas',[0 1.5 5], 'chnsSmooth',2, 'fracFtrs',1, ...
  'seed',1, 'modelDir','models/', 'modelFnm','model', ...
  'clusterFnm','clusters.mat', 'bsdsDir','BSR/BSDS500/data/'};
opts = getPrmDflt(varargin,dfs,1);

% if forest exists load it and return
cd(fileparts(mfilename('fullpath')));
forestDir = [opts.modelDir '/forest/'];
forestFn = [forestDir opts.modelFnm];
if exist([forestFn '.mat'], 'file')
  % load only the expected variable instead of every variable in the file
  cached = load([forestFn '.mat'],'model');
  model = cached.model;
  return;
end

% compute constants and store in opts
nTrees=opts.nTrees;
nCells=opts.nCells;

patchSiz=opts.radius*2+1;
opts.patchSiz=patchSiz;

% number of channels produced by stChns for these options
nChns = size(stChns(ones(2,2,3),opts),3);
opts.nChns=nChns;

opts.nChnFtrs = patchSiz*patchSiz*nChns;
opts.nSimFtrs = (nCells*nCells)*(nCells*nCells-1)/2*nChns;
opts.nTotFtrs = opts.nChnFtrs + opts.nSimFtrs;
if( nCells>0 )
  opts.cellRad = round(patchSiz/nCells/2);
  tmp=opts.cellRad*2+1;
  opts.cellStep = tmp-ceil((nCells*tmp-patchSiz)/(nCells-1));
else
  % no self-similarity features; the formulas above would divide by zero
  % (cellRad=Inf, cellStep=NaN), and neither value is used when nCells==0
  opts.cellRad = 0;
  opts.cellStep = 0;
  tmp = 1;
end
disp(opts);
assert( (nCells == 0) || (mod(nCells,2)==1 && (nCells-1)*opts.cellStep+tmp <= patchSiz ));

% generate stream for reproducibility of model
stream=RandStream('mrg32k3a','Seed',opts.seed);

% train nTrees random trees (can be trained with parfor if enough memory)
for i=1:nTrees
  stTrainTree( opts, stream, i );
end

% accumulate trees and merge into final model
treeFn = [opts.modelDir '/tree/' opts.modelFnm '_tree'];
for i=1:nTrees
  t=load([treeFn int2str2(i,3) '.mat'],'tree');
  if( i==1 )
    trees=t.tree(ones(1,nTrees));  % preallocate the struct array
  else
    trees(i)=t.tree;
  end
end
nNodes=0;
for i=1:nTrees
  nNodes=max(nNodes,size(trees(i).fids,1));
end
model.thrs=zeros(nNodes,nTrees,'single');
Z=zeros(nNodes,nTrees,'uint32');
model.fids=Z;
model.child=Z;
model.count=Z;
model.depth=Z;
model.distr=zeros(nNodes,size(trees(1).distr,2),nTrees,'single');
for i=1:nTrees
  tree=trees(i);
  nNodes1=size(tree.fids,1);
  model.fids(1:nNodes1,i) = tree.fids;
  model.thrs(1:nNodes1,i) = tree.thrs;
  model.child(1:nNodes1,i) = tree.child;
  model.distr(1:nNodes1,:,i) = tree.distr;
  model.count(1:nNodes1,i) = tree.count;
  model.depth(1:nNodes1,i) = tree.depth;
end
% class dimension first, as read by stDetectMex
model.distr = permute(model.distr, [2 1 3]);

% attach options and cluster centers, then cache the merged model
clusters=load(opts.clusterFnm);
clusters=clusters.clusters;
model.opts = opts;
model.clusters=clusters.clusters;
if ~exist(forestDir,'dir')
  mkdir(forestDir);
end
save([forestFn '.mat'], 'model', '-v7.3');

end
146 |
147 | function stTrainTree( opts, stream, treeInd )
148 | % Train a single tree in forest model.
149 |
150 | % location of ground truth
151 | trnImgDir = [opts.bsdsDir '/images/train/'];
152 | trnGtDir = [opts.bsdsDir '/groundTruth/train/'];
153 | imgIds=dir([trnImgDir '*.jpg']);
154 | imgIds={imgIds.name};
155 | nImgs=length(imgIds);
156 | for i=1:nImgs,
157 | imgIds{i}=imgIds{i}(1:end-4);
158 | end
159 |
160 | % extract commonly used options
161 | radius=opts.radius;
162 | patchSiz=opts.patchSiz;
163 | nChns=opts.nChns;
164 | nTotFtrs=opts.nTotFtrs;
165 | nClusters=opts.nClusters;
166 | nPos=opts.nPos;
167 | nNeg=opts.nNeg;
168 |
169 | % finalize setup
170 | treeDir = [opts.modelDir '/tree/'];
171 | treeFn = [treeDir opts.modelFnm '_tree'];
172 | if exist([treeFn int2str2(treeInd,3) '.mat'],'file')
173 | return;
174 | end
175 | fprintf('\n-------------------------------------------\n');
176 | fprintf('Training tree %d of %d\n',treeInd,opts.nTrees);
177 | tStart=clock;
178 |
179 | % set global stream to stream with given substream (will undo at end)
180 | streamOrig = RandStream.getGlobalStream();
181 | set(stream,'Substream',treeInd);
182 | RandStream.setGlobalStream( stream );
183 |
184 | % sample nPos positive patch locations per cluster
185 | clstr=load(opts.clusterFnm);
186 | clstr=clstr.clusters;
187 | for i = 1:nClusters
188 | if i==1
189 | centers=[];
190 | end
191 | ids = find(clstr.clusterId == i);
192 | ids = ids(randperm(length(ids),min(nPos,length(ids))));
193 | centers = [centers; [clstr.x(ids),clstr.y(ids),clstr.imId(ids),...
194 | clstr.clusterId(ids),clstr.gtId(ids)]]; %#ok
195 | end
196 |
197 | % collect positive and negative patches and compute features
198 | fids=sort(randperm(nTotFtrs,round(nTotFtrs*opts.fracFtrs)));
199 | k = size(centers,1)+nNeg*nImgs;
200 | ftrs = zeros(k,length(fids),'single');
201 | labels = zeros(k,1); k = 0;
202 | tid = ticStatus('Collecting data',1,1);
203 | for i = 1:nImgs
204 | % get image and compute channels
205 | gt=load([trnGtDir imgIds{i} '.mat']);
206 | gt=gt.groundTruth;
207 |
208 | I = imread([trnImgDir imgIds{i} '.jpg']);
209 | I = imPad(I,radius,'symmetric');
210 | chns = stChns(I,opts);
211 |
212 | % sample positive patch locations
213 | centers1=centers(centers(:,3)==i,:);
214 | lbls1=centers1(:,4);
215 | xy1=single(centers1(:,[1 2]));
216 |
217 | % sample negative patch locations
218 | M=false(size(I,1)-2*radius,size(I,2)-2*radius);
219 | nGt=length(gt);
220 | for j=1:nGt
221 | M1=gt{j}.Boundaries;
222 | if ~isempty(M1)
223 | M=M | M1;
224 | end
225 | end
226 | M(bwdist(M)0) = fids(tree.fids(tree.child>0)+1)-1;
266 | tree=pruneTree(tree,opts.minCount); %#ok
267 | if ~exist(treeDir,'dir')
268 | mkdir(treeDir);
269 | end
270 | save([treeFn int2str2(treeInd,3) '.mat'],'tree');
271 | e=etime(clock,tStart);
272 | fprintf('Training of tree %d complete (time=%.1fs).\n',treeInd,e);
273 | RandStream.setGlobalStream( streamOrig );
274 |
275 | end
276 |
function tree = pruneTree( tree, minCount )
% Prune all nodes whose count is less than or equal to minCount.
%
% Starting at the root (node 1), an internal node k with children child(k)
% and child(k)+1 is turned into a leaf if either child has
% count<=minCount. Otherwise both children are kept and visited in turn.
% Every node not reached this way is removed. Compared with discarding
% only the direct children of such nodes, this approach:
%  - never reads a "child flag" for leaves (child==0), which previously
%    looked up the root's flag;
%  - also removes whole subtrees that would otherwise remain as
%    unreachable entries.
% Node arrays are compacted and tree.child is re-indexed.
%
% INPUTS
%  tree     - single tree; per-node fields fids, thrs, child, count, depth
%             (vectors) and distr (one row per node)
%  minCount - nodes with count<=minCount are pruned
%
% OUTPUTS
%  tree     - pruned tree (same fields and classes)

n = length(tree.fids);
if( n==0 )
  return;
end
child = double(tree.child(:));
small = tree.count(:) <= minCount;

% walk the tree from the root, deciding which nodes survive
keep = false(n,1);
keep(1) = true;
stack = 1;
while( ~isempty(stack) )
  k = stack(end);
  stack(end) = [];
  c = child(k);
  if( c==0 )
    continue;  % leaf
  end
  if( small(c) || small(c+1) )
    child(k) = 0;  % drop both children: node k becomes a leaf
  else
    keep([c c+1]) = true;
    stack = [stack c c+1]; %#ok
  end
end

% compact per-node arrays
tree.fids=tree.fids(keep);
tree.thrs=tree.thrs(keep);
tree.distr=tree.distr(keep,:);
tree.count=tree.count(keep);
tree.depth=tree.depth(keep);
assert(all(tree.count>minCount))

% re-index children into the compacted numbering (class of child kept)
route=zeros(n,1);
route(keep)=1:nnz(keep);
child=child(keep);
child(child>0)=route(child(child>0));
tree.child=tree.child(keep);
tree.child(:)=child;
end
305 |
function ftrs = stComputeSimFtrs( chns, opts )
% Compute self-similarity features.
%
% The channels of every patch are box filtered (radius opts.cellRad) and
% sampled at an nCells x nCells grid of cell centers, spaced
% opts.cellStep apart around the patch center. Per channel, the
% difference between every pair of cells (p,p+d) is taken, ordered by
% d=1..n*n-1 and then p, which is the same order as computeCids in
% stDetect.m.
%
% INPUTS
%  chns  - [patchSiz x patchSiz x nChns x m] channels of m patches
%  opts  - model options (.nCells, .nSimFtrs, .nChns, .cellStep, .radius,
%          .patchSiz, .cellRad)
%
% OUTPUTS
%  ftrs  - [m x nSimFtrs] single features ([] if opts.nCells==0)

nCells=opts.nCells;
if( nCells==0 )
  ftrs=[];
  return;
end
nPatches=size(chns,4);
nChns=opts.nChns;
nGrid=nCells*nCells;

% 1-based cell center coordinates along each patch axis
ctr=opts.radius+1+((1:nCells)-(nCells+1)/2)*opts.cellStep;

% box filter all channels of all patches, then sample the cell centers
chns=convBox(reshape(chns,opts.patchSiz,opts.patchSiz,nChns*nPatches),opts.cellRad);
cells=reshape(chns(ctr,ctr,:,:),nGrid,nChns,nPatches);

% cell index pairs (p,p+d) in offset-major order
first=cell(1,nGrid-1);
second=cell(1,nGrid-1);
for d=1:nGrid-1
  first{d}=1:nGrid-d;
  second{d}=first{d}+d;
end
first=[first{:}];
second=[second{:}];

% pairwise differences, flattened to one row per patch
diffs=single(cells(first,:,:)-cells(second,:,:));
ftrs=reshape(diffs,opts.nSimFtrs,nPatches)';
% % For m=1, the above should be identical to the following:
% [cids1,cids2]=computeCids(size(chns),opts); % see stDetect.m
% chns=convBox(chns,opts.cellRad); k=opts.nChnFtrs;
% cids1=cids1(k+1:end)-k+1; cids2=cids2(k+1:end)-k+1;
% ftrs=chns(cids1)-chns(cids2);
end
335 |
--------------------------------------------------------------------------------