├── example.png
├── vlfeat.url
├── .gitignore
├── matconvnet.url
├── nets
    └── nets.url
├── +utls
    ├── readfile.m
    └── provision.m
├── README.md
├── setup.m
├── COPYING
├── example.m
├── DeStride.m
└── DDet.m

/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lenck/ddet/HEAD/example.png
--------------------------------------------------------------------------------
/vlfeat.url:
--------------------------------------------------------------------------------
1 | http://www.vlfeat.org/download/vlfeat-0.9.20-bin.tar.gz
2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | publish_nets.sh
2 | nets/*
3 | *~
4 | matconvnet
5 | vlfeat
6 | 
--------------------------------------------------------------------------------
/matconvnet.url:
--------------------------------------------------------------------------------
1 | https://github.com/vlfeat/matconvnet/archive/master.zip
2 | 
--------------------------------------------------------------------------------
/nets/nets.url:
--------------------------------------------------------------------------------
1 | http://www.robots.ox.ac.uk/~karel/blobs/ddet-nets.tar.gz
2 | 
--------------------------------------------------------------------------------
/+utls/readfile.m:
--------------------------------------------------------------------------------
1 | function data = readfile(path)
2 | % READFILE Read a small text file (e.g. a `.url` file) into a cell array.
3 | %   DATA = READFILE(PATH) reads PATH with TEXTSCAN ('%s' format, newline
4 | %   delimiter) and returns a column cell array of character vectors,
5 | %   each trimmed of leading/trailing whitespace. Empty entries (e.g. from
6 | %   blank lines) are dropped.
7 | %
8 | %   Errors if PATH does not exist, cannot be opened, or contains
9 | %   non-ASCII characters.
10 | 
11 | assert(exist(path, 'file') == 2, 'File %s does not exist.', path);
12 | 
13 | fd = fopen(path, 'r');
14 | if fd < 0
15 |   error('Unable to open file %s.', path);
16 | end
17 | % Close the file even if TEXTSCAN throws.
18 | closeFd = onCleanup(@() fclose(fd));
19 | data = textscan(fd, '%s', 'delimiter', '\n');
20 | clear closeFd; % closes the file now
21 | 
22 | data = data{1};
23 | data = cellfun(@strtrim, data, 'UniformOutput', false);
24 | % Drop empty entries so callers never receive an empty URL.
25 | data = data(~cellfun(@isempty, data));
26 | isAscii = cellfun(@(s) all(s < 128), data);
27 | assert(all(isAscii), 'File %s contains non ascii characters.', path);
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Covariant Feature Detectors
2 | *Karel Lenc and Andrea Vedaldi*
3 | 
4 | Source code and network models for a translation covariant detector presented in
5 | ["Learning Covariant Feature Detectors"](https://arxiv.org/abs/1605.01224).
6 | Written in MATLAB using the MatConvNet library.
7 | 
8 | This code depends on the [VLFeat](http://www.vlfeat.org/) and
9 | [MatConvNet](http://www.vlfeat.org/matconvnet/) libraries. The script `setup.m`
10 | attempts to download and install those if the libraries are not in the MATLAB
11 | path. The script `setup.m` also downloads the model files.
12 | 
13 | Currently contains the following models:
14 | * `./nets/detnet_s1.mat` Densely evaluated model DetNet-S
15 | * `./nets/detnet_s2.mat` DetNet-S evaluated with stride 2
16 | * `./nets/detnet_s4.mat` DetNet-S evaluated with stride 4
17 | 
18 | An example of how to run the detector is shown in `example.m`, which produces the
19 | following figure.
20 | 
21 | ![Example detection](./example.png)
22 | 
--------------------------------------------------------------------------------
/setup.m:
--------------------------------------------------------------------------------
1 | function setup()
2 | % SETUP Setup the environment
3 | %   SETUP() makes VLFeat and MatConvNet available on the MATLAB path and
4 | %   provisions the network models into `./nets`.
5 | %
6 | %   If `vl_covdet` (VLFeat) or `vl_nnconv` (MatConvNet) cannot be found,
7 | %   the library is provisioned from `vlfeat.url` / `matconvnet.url` and
8 | %   its setup script is run. If MatConvNet is not compiled,
9 | %   `vl_compilenn` is run for a CPU-only build.
10 | 
11 | % Copyright (C) 2016 Karel Lenc.
12 | % All rights reserved.
13 | %
14 | % This file is part of the VLFeat library and is made available under
15 | % the terms of the BSD license (see the COPYING file).
16 | 
17 | % Setup VLFeat, if not in path
18 | if ~exist('vl_covdet', 'file')
19 |   utls.provision('vlfeat.url', 'vlfeat');
20 |   run(fullfile(getlatest('vlfeat', 'vlfeat'), 'toolbox', 'vl_setup.m'));
21 | end
22 | 
23 | % Setup MatConvNet, if not in path
24 | if ~exist('vl_nnconv', 'file')
25 |   utls.provision('matconvnet.url', 'matconvnet');
26 |   run(fullfile(getlatest('matconvnet', 'matconvnet'), 'matlab', 'vl_setupnn.m'));
27 | 
28 |   % If the resolved `vl_nnconv` is not a MEX file, MatConvNet is not compiled.
29 |   if isempty(strfind(which('vl_nnconv'), mexext))
30 |     fprintf('MatConvNet not compiled. Attempting to run `vl_compilenn` (CPU ONLY!).\n');
31 |     fprintf('To compile with a GPU support, see `help vl_compilenn`.\n');
32 |     vl_compilenn('EnableImreadJpeg', false);
33 |   end
34 | end
35 | 
36 | utls.provision(fullfile('nets', 'nets.url'), 'nets');
37 | end
38 | 
39 | function out = getlatest(path, name)
40 | % GETLATEST Return PATH/<dir> for the last sub-directory of PATH whose
41 | % name starts with NAME, in sorted name order. The unpacked archives
42 | % presumably contain a versioned top-level folder (e.g. `vlfeat-0.9.20`),
43 | % so the exact name is looked up rather than hard-coded.
44 | sel_dir = dir(fullfile(path, [name '*']));
45 | sel_dir = sel_dir([sel_dir.isdir]);
46 | if isempty(sel_dir)
47 |   error('No directory matching %s* found in %s.', name, path);
48 | end
49 | sel_dir = sort({sel_dir.name});
50 | out = fullfile(path, sel_dir{end});
51 | end
--------------------------------------------------------------------------------
/COPYING:
--------------------------------------------------------------------------------
1 | Copyright (C) 2016 Karel Lenc
2 | All rights reserved.
3 | 
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are
6 | met:
7 | 1. Redistributions of source code must retain the above copyright
8 | notice, this list of conditions and the following disclaimer.
9 | 2. Redistributions in binary form must reproduce the above copyright
10 | notice, this list of conditions and the following disclaimer in the
11 | documentation and/or other materials provided with the
12 | distribution.
13 | 
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
18 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
19 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
20 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
22 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | 
--------------------------------------------------------------------------------
/example.m:
--------------------------------------------------------------------------------
1 | %% Example use of the DDet
2 | % Runs the DetNet-S detector (stride-2 model) on the VLFeat
3 | % `vl_impattern('box')` test image and plots the detections, the vote
4 | % accumulator and the regressed vector field. The figure is saved to
5 | % `example.png` if that file does not exist yet.
6 | setup();
7 | 
8 | %% Detect the features
9 | 
10 | net_name = 'detnet_s2.mat';
11 | net = dagnn.DagNN.loadobj(load(fullfile('nets', net_name)));
12 | 
13 | % Uncomment the following line to compute on a GPU
14 | % (works only if MatConvNet compiled with GPU support)
15 | % gpuDevice(1); net.move('gpu');
16 | 
17 | detector = DDet(net, 'thr', 4);
18 | 
19 | im = vl_impattern('box');
20 | [frames, ~, info] = detector.detect(im);
21 | 
22 | %% Plot the results
23 | 
24 | figure(1); clf;
25 | subplot(2,2,1);
26 | imshow(repmat(im, 1, 1, 3));
27 | title('Original image');
28 | 
29 | subplot(2,2,2);
30 | imshow(repmat(im, 1, 1, 3));
31 | hold on; scatter(frames(1, :), frames(2, :), ...
32 |   info.peakScores, info.peakScores, 'filled');
33 | colormap jet;
34 | title('Detections');
35 | text(0, size(im, 1)+10, 'Area represents feature strength.');
36 | 
37 | subplot(2,2,3);
38 | % Scale the votes by their maximum and blend them (red channel) over a
39 | % darkened copy of the image.
40 | accum = info.im_accum ./ max(info.im_accum(:));
41 | imWeight = 0.15; % not named `alpha` to avoid shadowing MATLAB's alpha()
42 | imshow(cat(3, sqrt(accum)*(1-imWeight) + im*imWeight, im*imWeight, im*imWeight));
43 | axis image;
44 | title('Accumulated locations');
45 | text(0, size(im, 1)+10, 'SQRT for better visibility');
46 | 
47 | subplot(2,2,4);
48 | imshow(repmat(im, 1, 1, 3)); hold on;
49 | quiver(info.vfield(:,:,1), info.vfield(:,:,2), 0);
50 | title('Vector field of the regressed locations');
51 | 
52 | % Only write the figure once, so the image used by the README is kept.
53 | if ~exist('example.png', 'file')
54 |   vl_printsize(1);
55 |   print('-dpng', '-r100', 'example.png');
56 | end
--------------------------------------------------------------------------------
/+utls/provision.m:
--------------------------------------------------------------------------------
1 | function downloaded = provision( url_file, tgt_dir)
2 | % PROVISION Provision a binary file from an archive
3 | % PROVISION(URL_FILE, TGT_DIR) Downloads and unpacks the archive from
4 | % URL_FILE to TGT_DIR folder, if not already done.
5 | %
6 | % URL_FILE is a text file read by UTLS.READFILE; each entry is the URL
7 | % of one archive. Archives are unpacked with MATLAB's UNTAR (`.gz`,
8 | % `.tgz`, `.tar`) or UNZIP (`.zip`); any other extension is an error.
9 | % When `wget` is available, each archive is first downloaded into TGT_DIR
10 | % (and kept there); otherwise, or when wget fails, MATLAB downloads it.
11 | %
12 | % Uses an empty file:
13 | % TGT_DIR/.URL_FILE_NAME.done
14 | % as an indicator that the folder had been already provisioned.
15 | %
16 | % DOWNLOADED is true if this call provisioned TGT_DIR, false if the
17 | % indicator file already existed.
18 | 
19 | % Copyright (C) 2016 Karel Lenc.
20 | % All rights reserved.
21 | %
22 | % This file is part of the VLFeat library and is made available under
23 | % the terms of the BSD license (see the COPYING file).
24 | downloaded = false;
25 | if ~exist(url_file, 'file')
26 |   error('Unable to find the URL file %s.', url_file);
27 | end
28 | [~, url_file_nm] = fileparts(url_file);
29 | done_file = fullfile(tgt_dir, ['.', url_file_nm, '.done']);
30 | if exist(done_file, 'file'), return; end
31 | % Request both outputs so MKDIR does not warn if TGT_DIR already exists.
32 | [~,~] = mkdir(tgt_dir);
33 | url = utls.readfile(url_file);
34 | for ui = 1:numel(url)
35 |   p_unpack(url{ui}, tgt_dir);
36 | end
37 | downloaded = true;
38 | f = fopen(done_file, 'w');
39 | if f < 0
40 |   error('Unable to create the marker file %s.', done_file);
41 | end
42 | fclose(f);
43 | end
44 | 
45 | function p_unpack(url, tgt_dir)
46 | % P_UNPACK Fetch URL into TGT_DIR (with wget if available) and unpack it.
47 | % Falls back to M_UNPACK on the remote URL when wget is missing or the
48 | % wget download failed. Messages go to fid ISDEPLOYED+1, i.e. stdout
49 | % normally and stderr in deployed applications.
50 | [~, wget_p] = system('which wget');
51 | if exist(strtrim(wget_p), 'file')
52 |   [~, fname, ext] = fileparts(url);
53 |   tar_file = fullfile(tgt_dir, [fname, ext]);
54 |   if ~exist(tar_file, 'file')
55 |     fprintf(isdeployed+1, 'Downloading %s -> %s.\n', url, tar_file);
56 |     % Quote both arguments so paths containing spaces reach wget intact.
57 |     ret = system(sprintf('wget "%s" -O "%s"', url, tar_file));
58 |     if ret ~= 0
59 |       fprintf(isdeployed+1, 'wget failed.\n');
60 |       % wget may leave a partial file behind; remove it (if present) so
61 |       % that the MATLAB fallback below is used instead.
62 |       if exist(tar_file, 'file'), delete(tar_file); end
63 |     end
64 |   end
65 |   if exist(tar_file, 'file')
66 |     fprintf(isdeployed+1, 'Unpacking %s -> %s. This may take a while...\n', ...
67 |       tar_file, tgt_dir);
68 |     m_unpack(tar_file, tgt_dir);
69 |   else
70 |     m_unpack(url, tgt_dir);
71 |   end
72 | else
73 |   m_unpack(url, tgt_dir);
74 | end
75 | end
76 | 
77 | function m_unpack(url, tgt_dir)
78 | % M_UNPACK Unpack URL (a remote URL or a local archive path) into TGT_DIR
79 | % with MATLAB's UNTAR/UNZIP, selected by the file extension.
80 | fprintf(isdeployed+1, ...
81 |   'Downloading %s -> %s using MATLAB, this may take a while...\n',...
82 |   url, tgt_dir);
83 | [~, ~, ext] = fileparts(url);
84 | switch lower(ext)
85 |   case {'.gz', '.tgz', '.tar'}
86 |     untar(url, tgt_dir);
87 |   case '.zip'
88 |     unzip(url, tgt_dir);
89 |   otherwise
90 |     % Fail loudly: otherwise PROVISION would mark TGT_DIR as provisioned
91 |     % although nothing was unpacked.
92 |     error('Unsupported archive type "%s" (%s).', ext, url);
93 | end
94 | end
--------------------------------------------------------------------------------
/DeStride.m:
--------------------------------------------------------------------------------
1 | classdef DeStride < dagnn.Filter
2 | % DESTRIDE A basic implementation of the a-trous algorithm
3 | % DESTRIDE allows a network containing an operation with stride > 1 to
4 | % be evaluated densely. It samples the input with every possible offset,
5 | % applies the wrapped LAYER to each sub-sampled input, and interleaves
6 | % the results back. In this way the input and output have the same
7 | % spatial size.
8 | %
9 | % Properties:
10 | %   DESTRIDE - [sx, sy]: columns (x) are sampled with step DESTRIDE(1),
11 | %              rows (y) with step DESTRIDE(2).
12 | %   LAYER    - wrapped layer object; its FORWARD, GETKERNELSIZE and
13 | %              GETOUTPUTSIZES methods are called.
14 | %
15 | % This class defines no BACKWARD method of its own.
16 | 
17 | % Copyright (C) 2016 Karel Lenc.
18 | % All rights reserved.
19 | %
20 | % This file is part of the VLFeat library and is made available under
21 | % the terms of the BSD license (see the COPYING file).
22 | 
23 |   properties
24 |     destride;
25 |     layer;
26 |   end
27 | 
28 |   properties (Hidden, Transient)
29 |     % 2 x prod(destride) list of [x; y] start offsets, built by
30 |     % set.destride.
31 |     offsets;
32 |   end
33 | 
34 |   methods
35 |     function outputs = forward(obj, inputs, params)
36 |       % FORWARD Evaluate LAYER on every offset sub-sampling of the single
37 |       % input and interleave the results into one dense output.
38 |       out = [];
39 |       assert(numel(inputs) == 1);
40 |       strd = obj.destride;
41 |       for oi = 1:size(obj.offsets, 2)
42 |         o = obj.offsets(:, oi);
43 |         nin = {inputs{1}(o(2):strd(2):end, o(1):strd(1):end, :, :)};
44 |         out_r = obj.layer.forward(nin, params);
45 |         out_r = out_r{1};
46 |         if isempty(out)
47 |           % The first offset ([1; 1]) gives the largest sub-sampled
48 |           % input, so its output fixes the size of the dense output.
49 |           out = zeros(size(out_r, 1)*strd(2), ...
50 |             size(out_r, 2)*strd(1), size(out_r, 3), ...
51 |             size(out_r, 4), 'like', out_r);
52 |         end
53 |         % Pick the indexes based on the output size -> zero padding if
54 |         % needed
55 |         oidx_y = ((1:size(out_r, 1)) - 1)*strd(2) + o(2);
56 |         oidx_x = ((1:size(out_r, 2)) - 1)*strd(1) + o(1);
57 |         out(oidx_y, oidx_x, :, :) = out_r;
58 |       end
59 |       outputs = {out};
60 |     end
61 | 
62 |     function kernelSize = getKernelSize(obj)
63 |       % Kernel size of the wrapped layer.
64 |       kernelSize = obj.layer.getKernelSize() ;
65 |     end
66 | 
67 |     function outputSizes = getOutputSizes(obj, inputSizes)
68 |       % GETOUTPUTSIZES Output size consistent with FORWARD: the wrapped
69 |       % layer sees the input sub-sampled from offset [1; 1], and its
70 |       % rows/columns are then multiplied by DESTRIDE(2)/DESTRIDE(1).
71 |       strd = obj.destride;
72 |       inputSizes{1}(1) = numel(1:strd(2):inputSizes{1}(1));
73 |       inputSizes{1}(2) = numel(1:strd(1):inputSizes{1}(2));
74 |       outputSizes = obj.layer.getOutputSizes(inputSizes);
75 |       assert(numel(outputSizes) == 1);
76 |       % Rows (dim 1) scale with strd(2) and columns (dim 2) with strd(1),
77 |       % exactly as the dense output is allocated in FORWARD.
78 |       outputSizes{1}(1:2) = outputSizes{1}(1:2) .* [strd(2), strd(1)];
79 |     end
80 | 
81 |     function set.destride(obj, destride)
82 |       % Store the stride and precompute all [x; y] sampling offsets.
83 |       obj.destride = destride;
84 |       [off_x, off_y] = ndgrid(1:destride(1), 1:destride(2));
85 |       obj.offsets = [off_x(:)'; off_y(:)'];
86 |     end
87 | 
88 |     function rfs = getReceptiveFields(obj)
89 |       % GETRECEPTIVEFIELDS Receptive field of the wrapped kernel spread by
90 |       % DESTRIDE, using the same axis order as FORWARD (rows/y with
91 |       % DESTRIDE(2), columns/x with DESTRIDE(1)).
92 |       ks = obj.getKernelSize() ;
93 |       y1 = 1 - obj.pad(1) ;
94 |       y2 = 1 - obj.pad(1) + ks(1)*obj.destride(2) - 1 ;
95 |       x1 = 1 - obj.pad(3) ;
96 |       x2 = 1 - obj.pad(3) + ks(2)*obj.destride(1) - 1 ;
97 |       h = y2 - y1 + 1 ;
98 |       w = x2 - x1 + 1 ;
99 |       rfs.size = [h, w] ;
100 |       rfs.stride = obj.stride ;
101 |       rfs.offset = [y1+y2, x1+x2]/2 ;
102 |     end
103 |   end
104 | end
105 | 
--------------------------------------------------------------------------------
/DDet.m:
--------------------------------------------------------------------------------
1 | classdef DDet < handle
2 | %DDET Implementation of the Covariant feature detector.
3 | % DDet implements the local feature detection using the network trained on
4 | % regressing relative translation between two patches. Accumulates the
5 | % relative transformations using bilinear voting.
6 | %
7 | % Example (see example.m):
8 | %   detector = DDet(net, 'thr', 4);
9 | %   [frames, ~, info] = detector.detect(im);
10 | 
11 | % Copyright (C) 2016 Karel Lenc.
12 | % All rights reserved.
13 | %
14 | % This file is part of the VLFeat library and is made available under
15 | % the terms of the BSD license (see the COPYING file).
16 | 
17 |   properties (SetAccess=public, GetAccess=public)
18 |     % Detector options (see the constructor), parsed with VL_ARGPARSE.
19 |     Opts = struct('thr',3, 'defscale', 3);
20 |     % dagnn.DagNN evaluated by DETECT.
21 |     Net
22 |     % Constructor arguments not recognised in Opts; forwarded to DDet.eval.
23 |     Args;
24 |   end
25 | 
26 |   methods
27 |     function obj = DDet(net, varargin)
28 |       %DDET Construct DDet object
29 |       %  obj = DDET(NET) constructs a DDET using the network NET. Network
30 |       %  must be a `dagnn.DagNN` object and must have an input `x0a` and
31 |       %  output `feata`.
32 |       %
33 |       %  Accepts the following options:
34 |       %
35 |       %  `thr` :: 3
36 |       %    Detection threshold. A local maximum of the (bilinearly
37 |       %    weighted) vote accumulator must score strictly more than this
38 |       %    value to be considered a valid detection.
39 |       %
40 |       %  `defscale` :: 3
41 |       %    Scale written to the third row of every detected frame.
42 |       %
43 |       %  Any other name-value pairs are passed on to DDet.eval
44 |       %  (`featureLayer`, `outLayer`, `gpu`).
45 |       assert(isa(net, 'dagnn.DagNN'), 'Invalid network');
46 |       assert(ismember('x0a', net.getInputs()), 'Input x0a not found.');
47 |       assert(ismember('feata', net.getOutputs()), 'Output feata not found.');
48 |       obj.Net = net;
49 |       [obj.Opts, obj.Args] = vl_argparse(obj.Opts, varargin);
50 |     end
51 | 
52 |     function [frames, desc, info] = detect(obj, im)
53 |       %DETECT Detect local features in image im.
54 |       %  FRAMES = obj.detect(IM) Detect features in image IM. FRAMES is a
55 |       %  3xN array whose columns are [x; y; scale]: the pixel coordinates
56 |       %  of a vote peak and Opts.defscale.
57 |       %
58 |       %  [FRAMES, DESC, INFO] = obj.detect(IM) DESC is always empty. INFO
59 |       %  is a structure with fields:
60 |       %
61 |       %  `im_accum` - The accumulated votes (size(IM,1) x size(IM,2)).
62 |       %  `vfield` - Regressed vector field: the network offsets scaled by
63 |       %             geom.hsz(1), written at the anchor positions of the
64 |       %             output grid and zero elsewhere.
65 |       %  `peakScores` - Score (accumulated votes) of each detected feature.
66 | 
67 |       desc = [];
68 |       [pts_im, geom, offs_] = obj.eval(obj.Net, im, obj.Args{:});
69 | 
70 |       % Accumulate the votes using a bilinear function: every regressed
71 |       % location splits its confidence among its 4 neighbouring pixels.
72 |       conf = double(pts_im(3, :));
73 |       im_sz = [size(im, 1), size(im, 2)];
74 |       im_accum = zeros(im_sz);
75 |       pts_im = double(pts_im(1:2, :));
76 |       pts_x = floor(pts_im(1,:)); pts_y = floor(pts_im(2,:));
77 |       x_a = pts_im(1,:) - pts_x; x_b = 1 - x_a;
78 |       y_a = pts_im(2,:) - pts_y; y_b = 1 - y_a;
79 |       % DDet.eval keeps only points in [1, width) x [1, height), so the
80 |       % +1 neighbours below are valid pixel indices.
81 |       im_accum = vl_binsum(im_accum, conf .* x_b .* y_b, sub2ind(im_sz, pts_y, pts_x));
82 |       im_accum = vl_binsum(im_accum, conf .* x_a .* y_b, sub2ind(im_sz, pts_y, pts_x+1));
83 |       im_accum = vl_binsum(im_accum, conf .* x_a .* y_a, sub2ind(im_sz, pts_y+1, pts_x+1));
84 |       im_accum = vl_binsum(im_accum, conf .* x_b .* y_a, sub2ind(im_sz, pts_y+1, pts_x));
85 | 
86 |       % Candidate detections are the regional maxima of the vote map.
87 |       pts = imregionalmax(im_accum);
88 |       [pts_y, pts_x] = find(pts);
89 |       pts_value = im_accum(pts);
90 |       pts = [pts_x(:) pts_y(:)]';
91 | 
92 |       frames_sel = pts_value > obj.Opts.thr;
93 |       frames = pts(:, frames_sel);
94 | 
95 |       info = struct('im_accum', im_accum);
96 |       info.vfield = zeros(size(im, 1), size(im, 2), size(offs_, 3));
97 |       info.vfield(floor(geom.y_offs), floor(geom.x_offs), :) = offs_ .* geom.hsz(1);
98 |       info.peakScores = pts_value(frames_sel)';
99 | 
100 |       frames = [frames; obj.Opts.defscale * ones(1, size(frames, 2))];
101 |     end
102 |   end
103 | 
104 |   methods (Static)
105 |     function [ pts, geom, offs ] = eval( evnet, im, varargin )
106 |       %EVAL Evaluate the network and return the regressed locations.
107 |       %  [PTS, GEOM, OFFS] = DDet.eval(EVNET, IM, ...) converts an RGB IM
108 |       %  to grayscale and uint8 to single, subtracts EVNET.meta.data_mean,
109 |       %  runs EVNET and returns:
110 |       %    PTS  - 3xN [x; y; 1] regressed locations in image coordinates;
111 |       %           points outside [1, width) x [1, height) are dropped.
112 |       %    GEOM - struct with `hsz` (net.meta.inputSize(1:2)/2 as a
113 |       %           column), `x_offs`/`y_offs` (anchor coordinates of the
114 |       %           output grid) and `tl_anchors` (2xM anchors as [x; y]).
115 |       %    OFFS - the raw value of the `outLayer` variable.
116 |       %
117 |       %  Options: `featureLayer` ('x0a') input variable, `outLayer`
118 |       %  ('feata') output variable, `gpu` ([]) accepted but unused; the
119 |       %  device of EVNET decides whether a GPU is used.
120 |       opts.featureLayer = 'x0a';
121 |       opts.outLayer = 'feata';
122 |       opts.gpu = [];
123 |       opts = vl_argparse(opts, varargin);
124 | 
125 |       assert(isfield(evnet.meta, 'data_mean'), 'Missing net.meta.data_mean');
126 |       % Mark the input variable `precious` (MatConvNet flag that keeps a
127 |       % variable's value during evaluation).
128 |       evnet.vars(evnet.getVarIndex(opts.featureLayer)).precious = true;
129 | 
130 |       % Image preprocessing
131 |       if size(im, 3) == 3, im = rgb2gray(im); end
132 |       if isa(im, 'uint8'), im = im2single(im); end
133 | 
134 |       m = evnet.meta.data_mean;
135 |       im_n = bsxfun(@minus, im, m);
136 | 
137 |       geom.hsz = (evnet.meta.inputSize(1:2)./2)';
138 |       rfs = evnet.getVarReceptiveFields(opts.featureLayer);
139 |       outsz = evnet.getVarSizes({opts.featureLayer, size(im)});
140 |       outlayerIdx = evnet.getVarIndex(opts.outLayer);
141 | 
142 |       %! In this case the x_offs is the centre of the receptive field
143 |       % NOTE(review): x_offs pairs stride(1) with hsz(2) (y_offs: stride(2)
144 |       % with hsz(1)); this only matters for anisotropic strides/input
145 |       % sizes - confirm the intended axis order before using such nets.
146 |       geom.x_offs = ((1:outsz{outlayerIdx}(2)) - 1) * ...
147 |         rfs(outlayerIdx).stride(1) + geom.hsz(2) + 0.5;
148 |       geom.y_offs = ((1:outsz{outlayerIdx}(1)) - 1) * ...
149 |         rfs(outlayerIdx).stride(2) + geom.hsz(1) + 0.5;
150 | 
151 |       [ys, xs] = ndgrid(geom.y_offs, geom.x_offs);
152 |       geom.tl_anchors = [xs(:)'; ys(:)'];
153 | 
154 |       if strcmp(evnet.device, 'gpu'), im_n = gpuArray(im_n); end
155 | 
156 |       evnet.eval({opts.featureLayer, single(im_n)});
157 |       pts = gather(evnet.vars(evnet.getVarIndex(opts.outLayer)).value);
158 |       pts_size = size(pts);
159 |       % Two values per anchor: row 1 is the x offset and row 2 the y
160 |       % offset (they are added to the [x; y] anchors below).
161 |       pts = reshape(pts, [], 2)';
162 |       offs = reshape(pts', pts_size);
163 | 
164 |       pts = [pts; ones(1, size(pts, 2))];
165 |       % Scale the offsets by half the network input size and move them to
166 |       % their anchors, giving image coordinates.
167 |       pts(1:2,:) = bsxfun(@times, pts(1:2,:), geom.hsz);
168 |       pts(1:2,:) = pts(1:2,:) + geom.tl_anchors;
169 |       % Remove points reaching out of image. The strict upper bound keeps
170 |       % the +1 neighbour used by bilinear voting in DETECT inside the image.
171 |       pts(:, pts(1, :) < 1 | pts(2, :) < 1 | ...
172 |         pts(1,:) >= size(im, 2) | pts(2,:) >= size(im, 1)) = [];
173 |     end
174 | 
175 |   end
176 | end
177 | 
--------------------------------------------------------------------------------