├── .gitignore ├── Contents.m ├── README.md ├── bsd.txt ├── models └── forest │ └── .gitignore ├── stChns.m ├── stDemo.m ├── stDetect.m ├── stDetectMex.cpp ├── stEvalBsds.m ├── stGetPatches.m ├── stToEdges.m └── stTrain.m /.gitignore: -------------------------------------------------------------------------------- 1 | *.mat 2 | *.mex* 3 | *.tgz -------------------------------------------------------------------------------- /Contents.m: -------------------------------------------------------------------------------- 1 | % SKETCHTOKENS 2 | % See also 3 | % 4 | % Demo: 5 | % stDemo - Sketch Token demo and usage. 6 | % 7 | % Training code: 8 | % stGetPatches - Sample ground truth edge sketch patches. 9 | % stTrain - Train SketchTokens classifier. 10 | % 11 | % Runtime code: 12 | % stChns - Compute channels for sketch token detection. 13 | % stDetect - Detect sketch tokens in image. 14 | % stEvalBsds - Evaluate sketch token edge detector on BSDS500. 15 | % stToEdges - Convert sketch tokens to edges. 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Sketch Token Toolbox V0.95 2 | ============================== 3 | 4 | This software package provides tools to extract contour-based mid-level features and contour segmentations from images. 5 | The tool is fast while maintaining high accuracy in contour detection. 6 | In addition, [1] shows that the extracted mid-level features provide additional information for object and pedestrian detection. 7 | 8 | Installation 9 | ------------ 10 |
    11 |
  1. 12 | Download Piotr's Image & Video Matlab Toolbox (http://vision.ucsd.edu/~pdollar/toolbox/doc/)
    13 | SketchTokens/toolbox/ should contain channels, classify, filters, images, matlab, etc. 14 |
  2. 15 | 16 |
  3. 17 | Download Berkeley Segmentation Data Set and Benchmarks 500 (http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html)
    18 | SketchTokens/BSR/ should contain BSDS500, bench, and documentation (the code looks for BSR/ next to the .m files, e.g. BSR/BSDS500/data/ and BSR/bench/benchmarks). 19 |
  4. 20 | 21 |
  5. 22 | Pre-trained models can be downloaded from: http://people.csail.mit.edu/lim/lzd_cvpr2013/st_data.tgz 23 |
  6. 24 | 25 |
  7. 26 | See stDemo.m for an example of how to train and test the code. 27 |
  8. 28 |
29 | 30 | References 31 | ---------- 32 | Please cite the following paper if you end up using the code:
33 | [1] Joseph J. Lim, C. Lawrence Zitnick, and Piotr Dollar. "Sketch Tokens: A Learned Mid-level Representation for Contour and Object Detection," CVPR 2013. 34 | 35 | License 36 | ------- 37 | Copyright 2013 Joseph Lim [lim@csail.mit.edu] 38 | 39 | Please email me if you find bugs, or have suggestions or questions! 40 | 41 | Licensed under the Simplified BSD License [see bsd.txt]
42 | 43 | Note: There is a patent pending on the ideas presented in this work, so this code should only be used for academic purposes.
44 | -------------------------------------------------------------------------------- /bsd.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012, Joseph Lim 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of the FreeBSD Project. 
27 | -------------------------------------------------------------------------------- /models/forest/.gitignore: -------------------------------------------------------------------------------- 1 | *.mat 2 | -------------------------------------------------------------------------------- /stChns.m: --------------------------------------------------------------------------------
function chns = stChns( I, opts )
% Compute channels for sketch token detection.
%
% The image is color-converted with rgbConvert: to 'luv' by default, or
% with the 'orig' flag when opts.inputColorChannel is 'luv'. The color
% channels form the first channel group. Then, for each entry of
% opts.sigmas, the color image is optionally Gaussian-blurred and its
% gradient magnitude is appended. When opts.nOrients(i)>0, an oriented
% gradient histogram is appended as well. All channels are finally
% smoothed with convTri.
%
% USAGE
%  chns = stChns( I, opts )
%
% INPUTS
%  I      - [h x w x 3] color input image
%  opts   - sketch token model options; fields read here:
%           .inputColorChannel - (optional) 'luv' if I is already LUV
%           .sigmas            - blur sigma per channel set (0 = no blur)
%           .radius            - size passed to fspecial('gaussian',...)
%           .nOrients          - histogram orientations per channel set
%           .normRad, .normConst - gradientMag normalization parameters
%           .chnsSmooth        - convTri smoothing radius
%
% OUTPUTS
%  chns   - [h x w x nChannel] output channels
%
% EXAMPLE
%
% See also stDetect, gradientMag, gradientHist
%
% Sketch Token Toolbox V0.95
% Copyright 2013 Joseph Lim [lim@csail.mit.edu]
% Please email me if you find bugs, or have suggestions or questions!
% Licensed under the Simplified BSD License [see bsd.txt]

% color conversion ('orig' when the caller flags the input as LUV)
inputIsLuv = isfield(opts,'inputColorChannel') && ...
  strcmp(opts.inputColorChannel,'luv');
if inputIsLuv
  I = rgbConvert(I,'orig');
else
  I = rgbConvert(I,'luv');
end

% each channel set contributes at most two groups (magnitude + histogram)
nSets = length(opts.sigmas);
groups = cell(1, 1+2*nSets);
nGroups = 1;
groups{nGroups} = I;

% extract gradient magnitude and histogram channels at each blur level
for s = 1:nSets
  sigma = opts.sigmas(s);
  if sigma == 0
    Iblur = I;
  else
    Iblur = imfilter(I, fspecial('gaussian',opts.radius,sigma));
  end
  if opts.nOrients(s) > 0
    % orientation is only requested when a histogram is needed
    [M,O] = gradientMag( Iblur, 0, opts.normRad, opts.normConst );
    nGroups = nGroups+1; groups{nGroups} = M;
    nGroups = nGroups+1;
    groups{nGroups} = gradientHist( M, O, 1, opts.nOrients(s), 0 );
  else
    nGroups = nGroups+1;
    groups{nGroups} = gradientMag( Iblur, 0, opts.normRad, opts.normConst );
  end
end

% stack all groups along the third dimension and smooth them
chns = convTri( cat(3,groups{1:nGroups}), opts.chnsSmooth );
end
-------------------------------------------------------------------------------- /stDemo.m:
-------------------------------------------------------------------------------- 1 | %% Sketch Token demo and usage. 2 | %% 3 | %% Please cite the following paper if you end up using the code: 4 | %% Joseph J. Lim, C. Lawrence Zitnick, and Piotr Dollar. "Sketch Tokens: A 5 | %% Learned Mid-level Representation for Contour and and Object Detection," 6 | %% CVPR2013. 7 | %% 8 | %% Note: There is a patent pending on the ideas presented in this work so 9 | %% this code should only be used for academic purposes. 10 | %% 11 | %% Sketch Token Toolbox V0.95 12 | %% Copyright 2013 Joseph Lim [lim@csail.mit.edu] 13 | %% Please email me if you find bugs, or have suggestions or questions! 14 | %% Licensed under the Simplified BSD License [see bsd.txt] 15 | 16 | 17 | 18 | %% setup (follow instructions, only need to do once) 19 | cd(fileparts(mfilename('fullpath'))) 20 | if( 0 ) 21 | % (1) Download Berkeley Segmentation Data Set and Benchmarks 500 22 | % http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/ 23 | % BSR/ should contain BSDS500, bench, and documentation 24 | addpath('BSR/bench/benchmarks'); 25 | % (2) Download and compile Piotr's toolbox (directions on website) 26 | % http://vision.ucsd.edu/~pdollar/toolbox/doc/index.html 27 | % (3) Compile code (removed OPTIMFLAGS to compile single core variant) 28 | mex stDetectMex.cpp 'OPTIMFLAGS="$OPTIMFLAGS' '/openmp"' 29 | % (4) Add current directory to path 30 | addpath(pwd); 31 | end 32 | 33 | %% train or load model (see stTrain.m) 34 | if( 0 ) 35 | if( 1 ) 36 | % can be trained in ~20m and requires ~3GB ram 37 | % BSDS500 performance: ODS=0.721 OIS=0.739 AP=0.768 38 | opts=struct('nPos',100,'nNeg',80,'modelFnm','modelSmall','nTrees',20); 39 | else 40 | % takes ~2.5 hours to train and requires ~27GB ram 41 | % BSDS500 performance: ODS=0.727 OIS=0.746 AP=0.780 42 | opts=struct('nPos',1000,'nNeg',800,'modelFnm','modelFull'); 43 | end 44 | tic, model=stTrain(opts); toc 45 | else 46 | % Pre-trained models can be 
downloaded from: 47 | % http://people.csail.mit.edu/lim/lzd_cvpr2013/st_data.tgz 48 | load('models/forest/modelSmall.mat'); 49 | end 50 | 51 | %% evaluate sketch token model on BSDS500 (see stEvalBsds.m) 52 | if (0), 53 | [ODS,OIS,AP]=stEvalBsds( model ); 54 | end 55 | 56 | %% detect sketch tokens and extract edges (see stDetect.m and stToEdges.m) 57 | I = imread('peppers.png'); 58 | tic; st = stDetect( I, model ); toc; 59 | tic, E = stToEdges( st, 1 ); toc 60 | 61 | %% visualize edge detection results 62 | figure(1); 63 | im(I); 64 | figure(2); 65 | im(E); 66 | colormap jet; 67 | -------------------------------------------------------------------------------- /stDetect.m: -------------------------------------------------------------------------------- 1 | function S = stDetect( I, model, stride, rescale_back ) 2 | % Detect sketch tokens in image. 3 | % 4 | % USAGE 5 | % S = stDetect( I, model, [stride] ) 6 | % 7 | % INPUTS 8 | % I - [h x w x 3] color input image 9 | % model - sketch token model trained with stTrain 10 | % stride - [2] stride at which to compute sketch tokens 11 | % rescale_back - [true] rescale after running stride 12 | % 13 | % OUTPUTS 14 | % S - [h x w x (nTokens+1)] sketch token probability maps 15 | % 16 | % EXAMPLE 17 | % 18 | % See also stTrain, stChns 19 | % 20 | % Sketch Token Toolbox V0.95 21 | % Copyright 2013 Joseph Lim [lim@csail.mit.edu] 22 | % Please email me if you find bugs, or have suggestions or questions! 
23 | % Licensed under the Simplified BSD License [see bsd.txt] 24 | 25 | if nargin<3 26 | stride=2; 27 | end 28 | if nargin<4 29 | rescale_back = true; 30 | end 31 | 32 | % compute features 33 | sizeOrig=size(I); 34 | opts=model.opts; 35 | %opts.inputColorChannel = 'luv'; 36 | %opts.inputColorChannel = 'rgb'; 37 | I = imPad(I,opts.radius,'symmetric'); 38 | chns = stChns( I, opts ); 39 | [cids1,cids2] = computeCids(size(chns),opts); 40 | if opts.nCells 41 | chnsSs = convBox(chns,opts.cellRad); 42 | else 43 | chnsSs = []; 44 | end 45 | 46 | % run forest on image 47 | S = stDetectMex( chns, chnsSs, model.thrs, model.fids, model.child, ... 48 | model.distr, cids1, cids2, stride, opts.radius, opts.nChnFtrs ); 49 | 50 | % finalize sketch token probability maps 51 | S = permute(S,[2 3 1]) * (1/opts.nTrees); 52 | if ~rescale_back 53 | %keyboard; 54 | else 55 | S = imResample( S, stride ); 56 | cr=size(S); cr=cr(1:2)-sizeOrig(1:2); 57 | if any(cr) 58 | S=S(1:end-cr(1),1:end-cr(2),:); 59 | end 60 | end 61 | 62 | end 63 | 64 | function [cids1,cids2] = computeCids( siz, opts ) 65 | % construct cids lookup for standard features 66 | radius=opts.radius; 67 | s=opts.patchSiz; 68 | nChns=opts.nChns; 69 | 70 | ht=siz(1); 71 | wd=siz(2); 72 | assert(siz(3)==nChns); 73 | 74 | nChnFtrs=s*s*nChns; 75 | fids=uint32(0:nChnFtrs-1); 76 | rs=mod(fids,s); 77 | fids=(fids-rs)/s; 78 | cs=mod(fids,s); 79 | ch=(fids-cs)/s; 80 | cids = rs + cs*ht + ch*ht*wd; 81 | 82 | % construct cids1/cids2 lookup for self-similarity features 83 | n=opts.nCells; 84 | m=opts.cellStep; 85 | nCellTotal=(n*n)*(n*n-1)/2; 86 | 87 | assert((n==0) || (mod(n,2)==1)); 88 | n1=(n-1)/2; 89 | nSimFtrs=nCellTotal*nChns; 90 | fids=uint32(0:nSimFtrs-1); 91 | ind=mod(fids,nCellTotal); 92 | ch=(fids-ind)/nCellTotal; 93 | 94 | ind1 = []; ind2 = []; 95 | k=0; 96 | for i=1:n*n-1, 97 | k1=n*n-i; 98 | ind1(k+1:k+k1)=(0:k1-1); 99 | k=k+k1; 100 | end 101 | k=0; 102 | for i=1:n*n-1, 103 | k1=n*n-i; 104 | ind2(k+1:k+k1)=(0:k1-1)+i; 105 | 
k=k+k1; 106 | end 107 | 108 | ind1=ind1(ind+1); 109 | rs1=mod(ind1,n); 110 | cs1=(ind1-rs1)/n; 111 | ind2=ind2(ind+1); 112 | rs2=mod(ind2,n); 113 | cs2=(ind2-rs2)/n; 114 | 115 | rs1=uint32((rs1-n1)*m+radius); 116 | cs1=uint32((cs1-n1)*m+radius); 117 | rs2=uint32((rs2-n1)*m+radius); 118 | cs2=uint32((cs2-n1)*m+radius); 119 | 120 | cids1 = rs1 + cs1*ht + ch*ht*wd; 121 | cids2 = rs2 + cs2*ht + ch*ht*wd; 122 | 123 | % combine cids for standard and self-similarity features 124 | cids1=[cids cids1]; 125 | cids2=[zeros(1,nChnFtrs) cids2]; 126 | end 127 | -------------------------------------------------------------------------------- /stDetectMex.cpp: -------------------------------------------------------------------------------- 1 | /******************************************************************************* 2 | * Sketch Token Toolbox V0.95 3 | * Copyright 2013 Joseph Lim [lim@csail.mit.edu] 4 | * Please email me if you find bugs, or have suggestions or questions! 5 | * Licensed under the Simplified BSD License [see bsd.txt] 6 | *******************************************************************************/ 7 | #include "mex.h" 8 | #include 9 | #include 10 | 11 | typedef unsigned int uint32; 12 | 13 | void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) 14 | { 15 | // get inputs 16 | float *chns = (float*) mxGetData(prhs[0]); 17 | float *chnsSs = (float*) mxGetData(prhs[1]); 18 | float *thrs = (float*) mxGetData(prhs[2]); 19 | uint32 *fids = (uint32*) mxGetData(prhs[3]); 20 | uint32 *child = (uint32*) mxGetData(prhs[4]); 21 | float *distr = (float*) mxGetData(prhs[5]); 22 | uint32 *cids1 = (uint32*) mxGetData(prhs[6]); 23 | uint32 *cids2 = (uint32*) mxGetData(prhs[7]); 24 | const int stride = (int) mxGetScalar(prhs[8]); 25 | const int rad = (int) mxGetScalar(prhs[9]); 26 | const int nChnFtrs = (int) mxGetScalar(prhs[10]); 27 | 28 | // get dimensions and constants 29 | const mwSize *chnsSize = mxGetDimensions(prhs[0]); 30 | const int height 
= (int) chnsSize[0]; 31 | const int width = (int) chnsSize[1]; 32 | const mwSize *distrSize = mxGetDimensions(prhs[5]); 33 | const int nTokens = (int) distrSize[0]; 34 | const int nTreeNodes = (int) distrSize[1]; 35 | const int nTrees = (int) distrSize[2]; 36 | const int heightOut = (int) ceil((height-rad*2.0)/stride); 37 | const int widthOut = (int) ceil((width-rad*2.0)/stride); 38 | 39 | // create output 40 | const int outDims[3]={nTokens,heightOut,widthOut}; 41 | float *S = (float*) mxCalloc(nTokens*heightOut*widthOut,sizeof(float)); 42 | plhs[0] = mxCreateNumericMatrix(0,0,mxSINGLE_CLASS,mxREAL); 43 | mxSetData(plhs[0],S); 44 | mxSetDimensions(plhs[0],outDims,3); 45 | 46 | // apply forest to each patch 47 | #pragma omp parallel for 48 | for( int c=0; c 88 | 89 | I = imread([imgDir id '.jpg']); 90 | st = max(min(1,stDetect(I,model)),0); 91 | E = stToEdges(st,1); 92 | 93 | imwrite(uint8(E*255),[resDir id '.png']); 94 | end 95 | 96 | % perform evaluation on each image (Linux only, slow) 97 | if(ispc), 98 | error('Evaluation code runs on Linux ONLY.'); 99 | end 100 | do=false(1,n); 101 | jobs=cell(1,n); 102 | for i=1:n, 103 | do(i)=~exist([evalDir ids{i} '_ev1.txt'],'file'); 104 | end 105 | for i=1:n, 106 | id=ids{i}; 107 | jobs{i}={[resDir id '.png'],... 
108 | [gtDir id '.mat'],[evalDir id '_ev1.txt'],p.nThresh}; 109 | end 110 | if(~exist(evalDir,'dir')), 111 | mkdir(evalDir); 112 | end 113 | fevalDistr('evaluation_bdry_image',jobs(do),p.pDistr{:}); 114 | 115 | % collect results and display 116 | [ODS,~,~,~,OIS,~,~,AP]=collect_eval_bdry(evalDir); 117 | fprintf('ODS=%.3f OIS=%.3f AP=%.3f\n',ODS,OIS,AP); 118 | if( p.show ), 119 | plot_eval(evalDir,'r'); 120 | end 121 | if( p.cleanup ), 122 | delete([evalDir '/*_ev1.txt']); 123 | delete([evalDir '/eval_bdry_img.txt']); 124 | delete([evalDir '/eval_bdry_thr.txt']); 125 | delete([resDir '/*.png']); rmdir(resDir); 126 | end 127 | 128 | end 129 | -------------------------------------------------------------------------------- /stGetPatches.m: -------------------------------------------------------------------------------- 1 | function clusters = stGetPatches( radius, nPatches, bsdsDir ) 2 | % Sample ground truth edge sketch patches. 3 | % 4 | % Calling stGetPatches() is *optional* as the code package comes with 5 | % pre-computed clusters.mat. stGetPatches() only needs to be called 6 | % if you want to generate alternate clusters to the ones provided. 7 | % 8 | % stGetPatches() is the first step in generating sketch tokens classes: it 9 | % is used to sample patches of ground truth edge maps from the training set 10 | % of the Berkeley Segmentation Dataset. After sampling, the extracted 11 | % patches must be clustered to generate the sketch token classes. 12 | % Clustering code is *NOT* provided. It should be easy to implement, see 13 | % the paper for more details. After clustering, the additional fields 14 | % "clusterId" and "clusters" should be initialized in the clusters struct. 15 | % "clusterId" should indicate cluster membership of each extracted patch 16 | % (integer between 1 and K) and "cluters" should be the mean of the patches 17 | % belonging to the given cluster. 
After these two fields are added to the 18 | % clusters struct, the sketch token model is ready to be trained via 19 | % stTrain() (see parameter "clusterFnm" in stTrain.m). 20 | % 21 | % USAGE 22 | % clusters = stGetPatches( [radius], [nPatches], [bsdsDir] ) 23 | % 24 | % INPUTS 25 | % radius - [15] radius of sketch token patches 26 | % nPatches - [inf] maximum number of patches to sample 27 | % bsdsDir - ['BSR/BSDS500/data/'] location of BSDS dataset 28 | % 29 | % OUTPUTS 30 | % clusters - extracted ground truth info w the following fields 31 | % .x - [Nx1] x-coordinate each sampled patch 32 | % .y - [Nx1] y-coordinate each sampled patch 33 | % .gtId - [Nx1] integer ground turth labeler of each sampled patch 34 | % .imId - [Nx1] integer image id of each sampled patch 35 | % .patches - [PxPxN] binary images of sampled patches, P=2*radius+1 36 | % .clusterId - [Nx1] cluster membership in [1,K] (NOT COMPUTED) 37 | % .clusters - [PxPxK] cluster images (NOT COMPUTED) 38 | % 39 | % EXAMPLE 40 | % 41 | % See also stTrain 42 | % 43 | % Sketch Token Toolbox V0.95 44 | % Copyright 2013 Joseph Lim [lim@csail.mit.edu] 45 | % Please email me if you find bugs, or have suggestions or questions! 46 | % Licensed under the Simplified BSD License [see bsd.txt] 47 | 48 | if( nargin<1 ), 49 | radius=15; 50 | end 51 | if( nargin<2 ), 52 | nPatches=inf; 53 | end 54 | if( nargin<3 ), 55 | bsdsDir='BSR/BSDS500/data/'; 56 | end 57 | 58 | % location of ground truth 59 | trnImgDir = [bsdsDir '/images/train/']; 60 | trnGtDir = [bsdsDir '/groundTruth/train/']; 61 | imgIds=dir([trnImgDir '*.jpg']); 62 | imgIds={imgIds.name}; 63 | nImgs=length(imgIds); 64 | for i=1:nImgs, 65 | imgIds{i}=imgIds{i}(1:end-4); 66 | end 67 | 68 | % loop over ground truth and collect samples 69 | clusters=struct('x',[],'y',[],'gtId',[],'imId',[],'patches',[],... 
70 | 'clusterId',[],'clusters',[]); 71 | clusters.patches = false(radius*2+1,radius*2+1,9000*5*nImgs); 72 | tid = ticStatus('data collection'); 73 | cnt=0; 74 | for i = 1:nImgs 75 | gt=load([trnGtDir imgIds{i} '.mat']); 76 | gt=gt.groundTruth; 77 | for j=1:length(gt) 78 | if(isempty(gt{j}.Boundaries)), 79 | continue; 80 | end 81 | M0 = gt{j}.Boundaries; 82 | M=M0; 83 | M([1:radius end-radius+1:end],:)=0; 84 | M(:,[1:radius end-radius+1:end])=0; 85 | [y,x]=find(M); 86 | cnt1=length(y); 87 | clusters.y = [clusters.y; int32(y)]; 88 | clusters.x = [clusters.x; int32(x)]; 89 | clusters.gtId = [clusters.gtId; ones(cnt1,1,'int32')*j]; 90 | clusters.imId = [clusters.imId; ones(cnt1,1,'int32')*i]; 91 | for k=1:cnt1, 92 | clusters.patches(:,:,cnt+k) = M0(y(k)-radius:y(k)+radius,x(k)-radius:x(k)+radius); 93 | end 94 | cnt = cnt + cnt1; 95 | end 96 | tocStatus(tid, i/nImgs); 97 | end 98 | clusters.patches = clusters.patches(:,:,1:cnt); 99 | 100 | % optionally sample patches 101 | if( nPatches0; 91 | Dy(F)=-Dy(F); 92 | O=mod(atan2(Dy,Dx),pi); 93 | end 94 | -------------------------------------------------------------------------------- /stTrain.m: -------------------------------------------------------------------------------- 1 | function model = stTrain( varargin ) 2 | % Train SketchTokens classifier. 3 | % 4 | % See stDemo for a full demo that include both traininga and application. 5 | % 6 | % Pre-trained models can be downloaded from: 7 | % http://people.csail.mit.edu/lim/lzd_cvpr2013/st_data.tgz 8 | % 9 | % Please cite the following paper if you end up using the code: 10 | % Joseph J. Lim, C. Lawrence Zitnick, and Piotr Dollar. "Sketch Tokens: A 11 | % Learned Mid-level Representation for Contour and and Object Detection," 12 | % CVPR2013. 13 | % 14 | % Note: There is a patent pending on the ideas presented in this work so 15 | % this code should only be used for academic purposes. 
16 | % 17 | % USAGE 18 | % model = stTrain( opts ) 19 | % 20 | % INPUTS 21 | % opts - parameters (struct or name/value pairs) 22 | % (1) parameters for model and data: 23 | % .nClusters - [150] number of clusters to train with 24 | % .nTrees - [25] number of trees in forest to train 25 | % .radius - [17] radius of sketch token patches 26 | % .nPos - [1000] number of positive patches per cluster 27 | % .nNeg - [800] number of negative patches per image 28 | % .negDist - [2] distance from closest contour defining a negative 29 | % .minCount - [4] minimum number of training examples per node 30 | % (2) parameters for features: 31 | % .nCells - [5] number of self similarity cells 32 | % .normRad - [5] normalization radius (see gradientMag) 33 | % .normConst - [.01] normalization constant (see gradientMag) 34 | % .nOrients - [4 4 0] number of orientations for each channel set 35 | % .sigmas - [0 1.5 5] gaussian blur for each channel set 36 | % .chnsSmooth - [2] radius for channel smoothing (using convTri) 37 | % .fracFtrs - [1] fraction of features to use to train each tree 38 | % (3) other parameters: 39 | % .seed - [1] seed for random stream (for reproducibility) 40 | % .modelDir - ['models/'] target directory for storing models 41 | % .modelFnm - ['model'] model filename 42 | % .clusterFnm - ['clusters.mat'] file containing cluster info 43 | % .bsdsDir - ['BSR/BSDS500/data/'] location of BSDS dataset 44 | % 45 | % OUTPUTS 46 | % model - trained sketch token detector w the following fields 47 | % .trees - learned forest model struct array (see forestTrain) 48 | % .opts - input parameters and constants 49 | % .clusters - actual cluster centers used to learn tokens 50 | % 51 | % EXAMPLE 52 | % 53 | % See also stGetPatches, stDetect, forestTrain, chnsCompute, gradientMag 54 | % 55 | % Sketch Token Toolbox V0.95 56 | % Copyright 2013 Joseph Lim [lim@csail.mit.edu] 57 | % Please email me if you find bugs, or have suggestions or questions! 
58 | % Licensed under the Simplified BSD License [see bsd.txt] 59 | 60 | % get default parameters 61 | dfs={'nClusters',150, 'nTrees',25, 'radius',17, 'nPos',1000, 'nNeg',800,... 62 | 'negDist',2, 'minCount',4, 'nCells',5, 'normRad',5, 'normConst',0.01, ... 63 | 'nOrients',[4 4 0], 'sigmas',[0 1.5 5], 'chnsSmooth',2, 'fracFtrs',1, ... 64 | 'seed',1, 'modelDir','models/', 'modelFnm','model', ... 65 | 'clusterFnm','clusters.mat', 'bsdsDir','BSR/BSDS500/data/'}; 66 | opts = getPrmDflt(varargin,dfs,1); 67 | 68 | % if forest exists load it and return 69 | cd(fileparts(mfilename('fullpath'))); 70 | forestDir = [opts.modelDir '/forest/']; 71 | forestFn = [forestDir opts.modelFnm]; 72 | if exist([forestFn '.mat'], 'file') 73 | load([forestFn '.mat']); 74 | return; 75 | end 76 | 77 | % compute constants and store in opts 78 | nTrees=opts.nTrees; 79 | nCells=opts.nCells; 80 | 81 | patchSiz=opts.radius*2+1; 82 | opts.patchSiz=patchSiz; 83 | 84 | nChns = size(stChns(ones(2,2,3),opts),3); 85 | opts.nChns=nChns; 86 | 87 | opts.nChnFtrs = patchSiz*patchSiz*nChns; 88 | opts.nSimFtrs = (nCells*nCells)*(nCells*nCells-1)/2*nChns; 89 | opts.nTotFtrs = opts.nChnFtrs + opts.nSimFtrs; 90 | opts.cellRad = round(patchSiz/nCells/2); 91 | tmp=opts.cellRad*2+1; 92 | opts.cellStep = tmp-ceil((nCells*tmp-patchSiz)/(nCells-1)); disp(opts); 93 | assert( (nCells == 0) || (mod(nCells,2)==1 && (nCells-1)*opts.cellStep+tmp <= patchSiz )); 94 | 95 | % generate stream for reproducibility of model 96 | stream=RandStream('mrg32k3a','Seed',opts.seed); 97 | 98 | % train nTrees random trees (can be trained with parfor if enough memory) 99 | for i=1:nTrees 100 | stTrainTree( opts, stream, i ); 101 | end 102 | 103 | % accumulate trees and merge into final model 104 | treeFn = [opts.modelDir '/tree/' opts.modelFnm '_tree']; 105 | for i=1:nTrees 106 | t=load([treeFn int2str2(i,3) '.mat'],'tree'); 107 | t=t.tree; 108 | if (i==1) 109 | trees=t(ones(1,nTrees)); 110 | else 111 | trees(i)=t; 112 | end 113 | end 114 
| nNodes=0; 115 | for i=1:nTrees 116 | nNodes=max(nNodes,size(trees(i).fids,1)); 117 | end 118 | model.thrs=zeros(nNodes,nTrees,'single'); 119 | Z=zeros(nNodes,nTrees,'uint32'); 120 | model.fids=Z; 121 | model.child=Z; 122 | model.count=Z; 123 | model.depth=Z; 124 | model.distr=zeros(nNodes,size(trees(1).distr,2),nTrees,'single'); 125 | for i=1:nTrees, tree=trees(i); nNodes1=size(tree.fids,1); 126 | model.fids(1:nNodes1,i) = tree.fids; 127 | model.thrs(1:nNodes1,i) = tree.thrs; 128 | model.child(1:nNodes1,i) = tree.child; 129 | model.distr(1:nNodes1,:,i) = tree.distr; 130 | model.count(1:nNodes1,i) = tree.count; 131 | model.depth(1:nNodes1,i) = tree.depth; 132 | end 133 | model.distr = permute(model.distr, [2 1 3]); 134 | 135 | clusters=load(opts.clusterFnm); 136 | clusters=clusters.clusters; 137 | 138 | model.opts = opts; 139 | model.clusters=clusters.clusters; 140 | if ~exist(forestDir,'dir') 141 | mkdir(forestDir); 142 | end 143 | save([forestFn '.mat'], 'model', '-v7.3'); 144 | 145 | end 146 | 147 | function stTrainTree( opts, stream, treeInd ) 148 | % Train a single tree in forest model. 
149 | 150 | % location of ground truth 151 | trnImgDir = [opts.bsdsDir '/images/train/']; 152 | trnGtDir = [opts.bsdsDir '/groundTruth/train/']; 153 | imgIds=dir([trnImgDir '*.jpg']); 154 | imgIds={imgIds.name}; 155 | nImgs=length(imgIds); 156 | for i=1:nImgs, 157 | imgIds{i}=imgIds{i}(1:end-4); 158 | end 159 | 160 | % extract commonly used options 161 | radius=opts.radius; 162 | patchSiz=opts.patchSiz; 163 | nChns=opts.nChns; 164 | nTotFtrs=opts.nTotFtrs; 165 | nClusters=opts.nClusters; 166 | nPos=opts.nPos; 167 | nNeg=opts.nNeg; 168 | 169 | % finalize setup 170 | treeDir = [opts.modelDir '/tree/']; 171 | treeFn = [treeDir opts.modelFnm '_tree']; 172 | if exist([treeFn int2str2(treeInd,3) '.mat'],'file') 173 | return; 174 | end 175 | fprintf('\n-------------------------------------------\n'); 176 | fprintf('Training tree %d of %d\n',treeInd,opts.nTrees); 177 | tStart=clock; 178 | 179 | % set global stream to stream with given substream (will undo at end) 180 | streamOrig = RandStream.getGlobalStream(); 181 | set(stream,'Substream',treeInd); 182 | RandStream.setGlobalStream( stream ); 183 | 184 | % sample nPos positive patch locations per cluster 185 | clstr=load(opts.clusterFnm); 186 | clstr=clstr.clusters; 187 | for i = 1:nClusters 188 | if i==1 189 | centers=[]; 190 | end 191 | ids = find(clstr.clusterId == i); 192 | ids = ids(randperm(length(ids),min(nPos,length(ids)))); 193 | centers = [centers; [clstr.x(ids),clstr.y(ids),clstr.imId(ids),... 
194 | clstr.clusterId(ids),clstr.gtId(ids)]]; %#ok 195 | end 196 | 197 | % collect positive and negative patches and compute features 198 | fids=sort(randperm(nTotFtrs,round(nTotFtrs*opts.fracFtrs))); 199 | k = size(centers,1)+nNeg*nImgs; 200 | ftrs = zeros(k,length(fids),'single'); 201 | labels = zeros(k,1); k = 0; 202 | tid = ticStatus('Collecting data',1,1); 203 | for i = 1:nImgs 204 | % get image and compute channels 205 | gt=load([trnGtDir imgIds{i} '.mat']); 206 | gt=gt.groundTruth; 207 | 208 | I = imread([trnImgDir imgIds{i} '.jpg']); 209 | I = imPad(I,radius,'symmetric'); 210 | chns = stChns(I,opts); 211 | 212 | % sample positive patch locations 213 | centers1=centers(centers(:,3)==i,:); 214 | lbls1=centers1(:,4); 215 | xy1=single(centers1(:,[1 2])); 216 | 217 | % sample negative patch locations 218 | M=false(size(I,1)-2*radius,size(I,2)-2*radius); 219 | nGt=length(gt); 220 | for j=1:nGt 221 | M1=gt{j}.Boundaries; 222 | if ~isempty(M1) 223 | M=M | M1; 224 | end 225 | end 226 | M(bwdist(M)0) = fids(tree.fids(tree.child>0)+1)-1; 266 | tree=pruneTree(tree,opts.minCount); %#ok 267 | if ~exist(treeDir,'dir') 268 | mkdir(treeDir); 269 | end 270 | save([treeFn int2str2(treeInd,3) '.mat'],'tree'); 271 | e=etime(clock,tStart); 272 | fprintf('Training of tree %d complete (time=%.1fs).\n',treeInd,e); 273 | RandStream.setGlobalStream( streamOrig ); 274 | 275 | end 276 | 277 | function tree = pruneTree( tree, minCount ) 278 | % Prune all nodes whose count is less than minCount. 
279 | 280 | % mark all internal nodes if either child has count<=minCount 281 | mark = [0; tree.count<=minCount]; 282 | mark = mark(tree.child+1) | mark(tree.child+2); 283 | 284 | % list of nodes to be discarded / kept 285 | disc=tree.child(mark); 286 | disc=[disc; disc+1]; 287 | n=length(tree.fids); 288 | keep=1:n; 289 | keep(disc)=[]; 290 | 291 | % prune tree 292 | tree.fids=tree.fids(keep); 293 | tree.thrs=tree.thrs(keep); 294 | tree.child=tree.child(keep); 295 | tree.distr=tree.distr(keep,:); 296 | tree.count=tree.count(keep); 297 | tree.depth=tree.depth(keep); 298 | assert(all(tree.count>minCount)) 299 | 300 | % re-index children 301 | route=zeros(1,n); 302 | route(keep)=1:length(keep); 303 | tree.child(tree.child>0) = route(tree.child(tree.child>0)); 304 | end 305 | 306 | function ftrs = stComputeSimFtrs( chns, opts ) 307 | % Compute self-similarity features. 308 | n=opts.nCells; 309 | if(n==0), 310 | ftrs=[]; 311 | return; 312 | end 313 | nSimFtrs=opts.nSimFtrs; 314 | nChns=opts.nChns; 315 | m=size(chns,4); 316 | 317 | inds = ((1:n)-(n+1)/2)*opts.cellStep+opts.radius+1; 318 | chns=reshape(chns,opts.patchSiz,opts.patchSiz,nChns*m); 319 | chns=convBox(chns,opts.cellRad); 320 | chns=reshape(chns(inds,inds,:,:),n*n,nChns,m); 321 | ftrs=zeros(nSimFtrs/nChns,nChns,m,'single'); 322 | k=0; 323 | for i=1:n*n-1 324 | k1=n*n-i; 325 | ftrs(k+1:k+k1,:,:)=chns(1:end-i,:,:)-chns(i+1:end,:,:); 326 | k=k+k1; 327 | end 328 | ftrs = reshape(ftrs,nSimFtrs,m)'; 329 | % % For m=1, the above should be identical to the following: 330 | % [cids1,cids2]=computeCids(size(chns),opts); % see stDetect.m 331 | % chns=convBox(chns,opts.cellRad); k=opts.nChnFtrs; 332 | % cids1=cids1(k+1:end)-k+1; cids2=cids2(k+1:end)-k+1; 333 | % ftrs=chns(cids1)-chns(cids2); 334 | end 335 | --------------------------------------------------------------------------------