├── LICENSE
├── README.md
├── fastap_github.jpg
├── matlab
│   ├── +imdb
│   │   ├── inshop.m
│   │   ├── products.m
│   │   └── vid.m
│   ├── +models
│   │   ├── googlenet.m
│   │   ├── resnet18.m
│   │   └── resnet50.m
│   ├── +solver
│   │   ├── adadelta.m
│   │   ├── adagrad.m
│   │   ├── adam.m
│   │   └── rmsprop.m
│   ├── FastAP.m
│   ├── README.md
│   ├── evaluate_model.m
│   ├── get_model.m
│   ├── get_opts.m
│   ├── run_demo.m
│   ├── startup.m
│   ├── train_hard.m
│   ├── train_hard_vid.m
│   ├── train_rand.m
│   └── util
│       ├── catstruct.m
│       ├── clearMex.m
│       ├── logInfo.m
│       ├── prepareGPUs.m
│       └── record_diary.m
└── pytorch
    ├── FastAP_loss.py
    └── README.md

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018-2019 Fatih Cakir, Kun He, Xide Xia, Brian Kulis, Stan Sclaroff
4 | 
5 | If used for academic purposes, please cite the following paper:
6 | 
7 | "Deep Metric Learning to Rank"
8 | Fatih Cakir(*), Kun He(*), Xide Xia, Brian Kulis, and Stan Sclaroff
9 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019
10 | 
11 | Permission is hereby granted, free of charge, to any person obtaining a copy
12 | of this software and associated documentation files (the "Software"), to deal
13 | in the Software without restriction, including without limitation the rights
14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | copies of the Software, and to permit persons to whom the Software is
16 | furnished to do so, subject to the following conditions:
17 | 
18 | The above copyright notice and this permission notice shall be included in all
19 | copies or substantial portions of the Software.
20 | 
21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27 | SOFTWARE.
28 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FastAP: Deep Metric Learning to Rank
2 | This repository contains the implementation of the following paper:
3 | 
4 | [Deep Metric Learning to Rank](http://openaccess.thecvf.com/content_CVPR_2019/html/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.html)
5 | [Fatih Cakir](http://cs-people.bu.edu/fcakir/)\*, [Kun He](http://cs-people.bu.edu/hekun/)\*, [Xide Xia](https://xidexia.github.io), [Brian Kulis](http://people.bu.edu/bkulis/), and [Stan Sclaroff](http://www.cs.bu.edu/~sclaroff/) (\*equal contribution)
6 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019
7 | 
8 | ![](fastap_github.jpg)
9 | 
10 | ## Other Implementations
11 | [FastAPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#fastaploss) from [pytorch-metric-learning](https://github.com/KevinMusgrave/pytorch-metric-learning)
12 | 
13 | ## Usage
14 | * **Matlab**: see `matlab/README.md`
15 | * **PyTorch**: see `pytorch/README.md`
16 | 
17 | ## Datasets
18 | * Stanford Online Products
19 |   * Can be downloaded [here](http://cvgl.stanford.edu/projects/lifted_struct/)
20 | * In-Shop Clothes Retrieval
21 |   * Can be downloaded [here](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html)
22 | * PKU VehicleID
23 |   * Please request the dataset from the authors [here](https://pkuml.org/resources/pku-vehicleid.html)
24 | 
25 | ## Reproducibility
26 | * We provide the trained [MatConvNet](https://www.vlfeat.org/matconvnet/quick/) models and experiment logs that produced the results reported in the paper.
27 | * The logs include the parameter settings needed to re-train a model, as well as evaluation results for model checkpoints at certain epochs.
28 | * Table 1: Stanford Online Products
29 |   * FastAP, ResNet-18, M=256, Dim=512: [[model @ epoch 20](https://drive.google.com/file/d/1sPCG34rV4Bqf0aWF7GrFIDUK7DGjcaB5/view?usp=sharing), [log](https://drive.google.com/open?id=14m3fHgeZu8MIAePFHXe141R60KRwH1d8)]
30 |   * FastAP, ResNet-50, M=96, Dim=128: [[model @ epoch 30](https://drive.google.com/open?id=1yGUVTskdERdLeF85GP-lLRnwS0KhpvvL), [log](https://drive.google.com/open?id=1A0G1aUBS7URotInbT7eBbys4xfARvCCe)]
31 |   * FastAP, ResNet-50, M=96, Dim=512: [[model @ epoch 28](https://drive.google.com/file/d/14yEyAYhGzNygBBn8r2RcZut_Ye14mJoK/view?usp=sharing), [log](https://drive.google.com/open?id=19mpLn1OqA2nqpMZtvZ3GvFk_VppOOPTc)]
32 |   * FastAP, ResNet-50, M=256, Dim=512: [[model @ epoch 12](https://drive.google.com/open?id=1WfV1ArXHG4oksHGE8DZRDwsxupoO60sD), [log](https://drive.google.com/open?id=1shvC5qB8O0l6vH1qa2SG1oxX8_jM1gdi)]
33 | * Table 2: In-Shop Clothes
34 |   * FastAP, ResNet-18, M=256, Dim=512: [[model @ epoch 50](https://drive.google.com/open?id=1ZZ-Fpx9uPkRL-QXL-8-RcROjQVOLcbr5), [log](https://drive.google.com/file/d/1osxoHsMy11v-kvUNTuRG3luhsxMMn78B/view?usp=sharing)]
35 |   * FastAP, ResNet-50, M=96, Dim=512: [[model @ epoch 40](https://drive.google.com/open?id=1PyiHog7fJp_InvqdAO0dzyJDRMNvAXxm), [log](https://drive.google.com/open?id=14IPgDfkbKo9PnrgMFFRDSIBRW1xwRYs5)]
36 |   * FastAP, ResNet-50, M=256, Dim=512: [[model @ epoch 35](https://drive.google.com/open?id=1T5IynM63YqnGslnMGppJsmtJdHIWamJv), [log](https://drive.google.com/open?id=1oud9i87FTJE7Ei636bjxqgBXpahysRKK)]
37 | * Table 3: PKU VehicleID
38 |   * FastAP, ResNet-18, M=256, Dim=512: [[model @ epoch 50](https://drive.google.com/open?id=1KsUF2SzkhvBOkHzbrXKj7H5KtN6Z3hRJ), [log](https://drive.google.com/open?id=155Ce-FmI6dmMgJnXdESVHx08unU3jWX2)]
39 |   * FastAP, ResNet-50, M=96, Dim=512: [[model @ epoch 40](https://drive.google.com/open?id=1AblJelRHStBfWwmZeoRM8iEpNRNdOobn), [log](https://drive.google.com/open?id=1twswLE-j9kLxUsk5Ku7vWqBp0Sml65sG)]
40 |   * FastAP, ResNet-50, M=256, Dim=512: [[model @ epoch 30](https://drive.google.com/open?id=1MAimhKEyEfq2LDYnUaDburFH2YsUhrpA), [log](https://drive.google.com/open?id=1CtNk-wxSZToO703OvfK8ndFQzFnpOvVS)]
41 | 
42 | (M = mini-batch size)
43 | * The PyTorch code is a direct port of our MATLAB implementation. 
We haven't tried reproducing the paper results with our PyTorch code. **For reproducibility use the MATLAB version**. 44 | * Note that the mini-batch sampling strategy must also be used alongside the FastAP loss for good results. 45 | 46 | ## Contact 47 | For questions and comments, feel free to contact: kunhe@fb.com or fcakirs@gmail.com 48 | 49 | ## License 50 | MIT 51 | -------------------------------------------------------------------------------- /fastap_github.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunhe/FastAP-metric-learning/ca85b2ab3f22460795d4306c8fa01c655d17027e/fastap_github.jpg -------------------------------------------------------------------------------- /matlab/+imdb/inshop.m: -------------------------------------------------------------------------------- 1 | function imdb = inshop(prefix, opts) 2 | savefn = fullfile(opts.dataDir, [prefix '.mat']); 3 | if exist(savefn, 'file') 4 | imdb = load(savefn); 5 | return; 6 | end 7 | inshopdir = fullfile(opts.dataDir, 'InShopClothes'); 8 | 9 | % meta-classes 10 | metaFile = fullfile(inshopdir, 'meta_classes.txt'); 11 | metaClasses = textread(metaFile, '%s'); 12 | 13 | % parse image names/labels 14 | imgsFile = fullfile(inshopdir, 'Anno/list_bbox_inshop.txt'); 15 | [imgs, cType, pType, ~, ~, ~, ~] = textread(imgsFile, '%s %d %d %d %d %d %d', ... 16 | 'headerlines', 2); 17 | tmp = cellfun(@(x) strsplit(x, '/'), imgs, 'uni', false); 18 | cls = cellfun(@(x) sscanf(x{4}, 'id_%d'), tmp); 19 | meta = cellfun(@(x) [x{2} '/' x{3}], tmp, 'uni', false); 20 | [~, meta] = ismember(meta, metaClasses); 21 | 22 | % train/test 23 | splitFile = fullfile(inshopdir, 'list_eval_partition.txt'); 24 | [img_sp, ~, sp] = textread(splitFile, '%s %s %s', 'headerlines', 2); 25 | [~, loc] = ismember(imgs, img_sp); 26 | [~, set] = ismember(sp(loc), {'train', 'query', 'gallery'}); 27 | 28 | % images are already 256x256, nothing to do 29 | 30 | imdb.images.data = fullfile(inshopdir, imgs); 31 | imdb.images.labels = single([cls meta cType pType]) ; 32 | imdb.images.set = uint8(set) ; 33 | imdb.meta.sets = {'train', 'query', 'gallery'} ; 34 | save(savefn, '-struct', 'imdb'); 35 | end 36 | -------------------------------------------------------------------------------- /matlab/+imdb/products.m: -------------------------------------------------------------------------------- 1 | function imdb = products(prefix, opts) 2 | savefn = fullfile(opts.dataDir, [prefix '.mat']); 3 | if exist(savefn, 'file') 4 | imdb = load(savefn); 5 | return; 6 | end 7 | stanforddir = fullfile(opts.dataDir, 'Stanford_Online_Products'); 8 | 9 | % train/test 10 | [id, y1, y2, x] = textread([stanforddir '/Ebay_train.txt'], '%d %d %d %s', ... 11 | 'headerlines', 1); 12 | fn = x; 13 | labels = [y1 y2]; 14 | set = ones(numel(id), 1); 15 | logInfo('train: %d', numel(x)); 16 | 17 | [id, y1, y2, x] = textread([stanforddir '/Ebay_test.txt'], '%d %d %d %s', ... 
18 |     'headerlines', 1);
19 | fn = [fn; x];
20 | labels = [labels; y1 y2];
21 | set = [set; 3*ones(numel(id), 1)];
22 | logInfo(' test: %d', numel(x));
23 | 
24 | % resize to 256x256
25 | a = tic; tic;
26 | stanford256 = fullfile(stanforddir, 'images_256x256');
27 | if ~exist(stanford256, 'dir'), mkdir(stanford256); end
28 | fn256 = strrep(fn, '/', '_');
29 | fn256 = cellfun(@(x) fullfile(stanford256, x), fn256, 'uniform', false);
30 | for i = 1:numel(fn)
31 |     if ~exist(fn256{i}, 'file')
32 |         im = imread([stanforddir '/' fn{i}]);
33 |         im = imresize(im, [256 256]);
34 |         imwrite(im, fn256{i});
35 |     end
36 |     if toc > 30
37 |         logInfo('%d/%d: %s', i, numel(fn), fn256{i}); tic;
38 |     end
39 | end
40 | toc(a)
41 | 
42 | imdb.images.data = fn256;
43 | imdb.images.labels = single(labels) ;
44 | imdb.images.set = uint8(set) ;
45 | imdb.meta.sets = {'train', 'val', 'test'} ;
46 | save(savefn, '-struct', 'imdb');
47 | end
48 | 
--------------------------------------------------------------------------------
/matlab/+imdb/vid.m:
--------------------------------------------------------------------------------
1 | function imdb = vid(prefix, opts)
2 | savefn = fullfile(opts.dataDir, [prefix '.mat']);
3 | if exist(savefn, 'file')
4 |     imdb = load(savefn);
5 |     return;
6 | end
7 | viddir = fullfile(opts.dataDir, 'VehicleID_V1.0');
8 | 
9 | % imgs, vid
10 | img2vid = fullfile(viddir, 'attribute', 'img2vid.txt');
11 | [imgs, vid] = textread(img2vid, '%d %d');
12 | n = numel(imgs);
13 | 
14 | % model info
15 | mid = zeros(n, 1);
16 | vid2model = fullfile(viddir, 'attribute', 'model_attr.txt');
17 | [vid_m, mid_m] = textread(vid2model, '%d %d');
18 | for i = 1:length(vid_m)
19 |     mid(vid == vid_m(i)) = mid_m(i) + 1;
20 | end
21 | 
22 | % color info
23 | cid = zeros(n, 1);
24 | vid2color = fullfile(viddir, 'attribute', 'color_attr.txt');
25 | [vid_c, cid_c] = textread(vid2color, '%d %d');
26 | for i = 1:length(vid_c)
27 |     cid(vid == vid_c(i)) = cid_c(i) + 1;
28 | end
29 | 
30 | % train/test
31 | sets = zeros(n, 4);  % [train test_small test_medium test_large]
32 | 
33 | trainFile = fullfile(viddir, 'train_test_split', 'train_list.txt');
34 | [~, vid_train] = textread(trainFile, '%d %d');
35 | sets(ismember(vid, vid_train), 1) = 1;
36 | logInfo('train: %d imgs', sum(sets(:, 1)));
37 | 
38 | testSmallFile = fullfile(viddir, 'train_test_split', 'test_list_800.txt');
39 | [~, vid_s] = textread(testSmallFile, '%d %d');
40 | sets(ismember(vid, vid_s), 2) = 1;
41 | logInfo('test-small: %d imgs', sum(sets(:, 2)));
42 | 
43 | testMediumFile = fullfile(viddir, 'train_test_split', 'test_list_1600.txt');
44 | [~, vid_m] = textread(testMediumFile, '%d %d');
45 | sets(ismember(vid, vid_m), 3) = 1;
46 | logInfo('test-medium: %d imgs', sum(sets(:, 3)));
47 | 
48 | testLargeFile = fullfile(viddir, 'train_test_split', 'test_list_2400.txt');
49 | [~, vid_l] = textread(testLargeFile, '%d %d');
50 | sets(ismember(vid, vid_l), 4) = 1;
51 | logInfo('test-large: %d imgs', sum(sets(:, 4)));
52 | 
53 | % resize to 256x256
54 | a = tic; tic;
55 | vid256 = fullfile(viddir, 'image_256x256');
56 | if ~exist(vid256, 'dir'), mkdir(vid256); end
57 | fn256 = cell(n, 1);
58 | for i = 1:n
59 |     fn256{i} = sprintf('%s/%07d.jpg', vid256, imgs(i));
60 |     if ~exist(fn256{i}, 'file')
61 |         im = imread(sprintf('%s/image/%07d.jpg', viddir, imgs(i)));
62 |         im = imresize(im, [256 256]);
63 |         imwrite(im, fn256{i});
64 |     end
65 |     if toc > 20
66 |         logInfo(fn256{i}); tic;
67 |     end
68 | end
69 | toc(a)
70 | 
71 | % imdb
72 | imdb.images.data = fn256;
73 | imdb.images.labels = ...
single([vid mid cid]) ; 74 | imdb.images.set = uint8(sets) ; 75 | imdb.meta.sets = {'train', 'test_small', 'test_medium', 'test_large'} ; 76 | save(savefn, '-struct', 'imdb'); 77 | end 78 | -------------------------------------------------------------------------------- /matlab/+models/googlenet.m: -------------------------------------------------------------------------------- 1 | function [net, opts, in_name, in_dim] = googlenet(opts) 2 | opts.imageSize = 224; 3 | opts.maxGpuImgs = 320; % # of images that a 12GB GPU can hold 4 | 5 | % finetune GoogLeNet 6 | net = load(fullfile(opts.localDir, 'models', 'imagenet-googlenet-dag.mat')); 7 | net = dagnn.DagNN.loadobj(net) ; 8 | 9 | % remove softmax, fc 10 | net.removeLayer('softmax'); 11 | net.removeLayer('cls3_fc'); 12 | 13 | net.removeLayer('cls2_pool'); 14 | net.removeLayer('cls2_reduction'); 15 | net.removeLayer('relu_cls2_reduction'); 16 | net.removeLayer('cls2_fc1'); 17 | net.removeLayer('relu_cls2_fc1'); 18 | net.removeLayer('cls2_fc2'); 19 | 20 | net.removeLayer('cls1_pool'); 21 | net.removeLayer('cls1_reduction'); 22 | net.removeLayer('relu_cls1_reduction'); 23 | net.removeLayer('cls1_fc1'); 24 | net.removeLayer('relu_cls1_fc1'); 25 | net.removeLayer('cls1_fc2'); 26 | 27 | % freeze pretrained layers 28 | if opts.lastLayer 29 | for i = 1:numel(net.params) 30 | net.params(i).learningRate = 0; 31 | net.params(i).weightDecay = 0; 32 | end 33 | elseif isfield(opts, 'ft') && opts.ft >= 1 34 | assert(opts.ft>=1 && opts.ft<=9); 35 | names = {}; 36 | for l = opts.ft : 9 37 | names{end+1} = sprintf('icp%d', l); 38 | end 39 | for i = 1:numel(net.params) 40 | pname = net.params(i).name; 41 | notfound = cellfun(@(x) isempty(strfind(pname, x)), names); 42 | if all(notfound) 43 | net.params(i).learningRate = 0; 44 | net.params(i).weightDecay = 0; 45 | end 46 | end 47 | end 48 | 49 | in_name = 'cls3_pool'; 50 | in_dim = 1024; 51 | end 52 | -------------------------------------------------------------------------------- /matlab/+models/resnet18.m: -------------------------------------------------------------------------------- 1 | function [net, opts, in_name, in_dim] = resnet18(opts) 2 | opts.imageSize = 224; 3 | opts.maxGpuImgs = 256; % # of images that a 12GB GPU can hold 4 | 5 | net = load(fullfile(opts.localDir, 'models', 'resnet18-pt-mcn.mat')); 6 | net = dagnn.DagNN.loadobj(net) ; 7 | 8 | % remove softmax, fc 9 | net.removeLayer('classifier_0'); 10 | 11 | % freeze pretrained layers 12 | if opts.lastLayer 13 | for i = 1:numel(net.params) 14 | net.params(i).learningRate = 0; 15 | net.params(i).weightDecay = 0; 16 | end 17 | end 18 | 19 | in_name = 'classifier_flatten'; 20 | in_dim = 512; 21 | end 22 | -------------------------------------------------------------------------------- /matlab/+models/resnet50.m: -------------------------------------------------------------------------------- 1 | function [net, opts, in_name, in_dim] = resnet50(opts) 2 | opts.imageSize = 224; 3 | opts.maxGpuImgs = 90; % # of images that a 12GB GPU can hold 4 | 5 | net = load(fullfile(opts.localDir, 'models', 'imagenet-resnet-50-dag.mat')); 6 | net = dagnn.DagNN.loadobj(net) ; 7 | 8 | % remove softmax, fc 9 | net.removeLayer('prob'); 10 | net.removeLayer('fc1000'); 11 | 12 | % freeze pretrained layers 13 | if opts.lastLayer 14 | for i = 1:numel(net.params) 15 | net.params(i).learningRate = 0; 16 | net.params(i).weightDecay = 0; 17 | end 18 | end 19 | 20 | in_name = 'pool5'; 21 | in_dim = 2048; 22 | end 23 | 
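% Usage sketch (illustrative note, not part of the original file): every
% function in +models returns the pretrained backbone together with the name
% and dimension of the variable that the embedding head is attached to, e.g.
%
%   [net, opts, in_name, in_dim] = models.resnet50(opts);  % in_name = 'pool5', in_dim = 2048
%
% get_model.m then adds a 1x1-conv FC layer (in_dim -> opts.dim) and an
% L2-normalization layer on top of in_name, followed by the loss layer.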
-------------------------------------------------------------------------------- /matlab/+solver/adadelta.m: -------------------------------------------------------------------------------- 1 | function [w, state] = adadelta(w, state, grad, opts, ~) 2 | %ADADELTA 3 | % Example AdaDelta solver, for use with CNN_TRAIN and CNN_TRAIN_DAG. 4 | % 5 | % AdaDelta sets its own learning rate, so any learning rate set in the 6 | % options of CNN_TRAIN and CNN_TRAIN_DAG will be ignored. 7 | % 8 | % If called without any input argument, returns the default options 9 | % structure. 10 | % 11 | % Solver options: (opts.train.solverOpts) 12 | % 13 | % `epsilon`:: 1e-6 14 | % Small additive constant to regularize variance estimate. 15 | % 16 | % `rho`:: 0.9 17 | % Moving average window for variance update, between 0 and 1 (larger 18 | % values result in slower/more stable updating). 19 | 20 | % Copyright (C) 2016 Joao F. Henriques. 21 | % All rights reserved. 22 | % 23 | % This file is part of the VLFeat library and is made available under 24 | % the terms of the BSD license (see the COPYING file). 25 | 26 | if nargin == 0 % Return the default solver options 27 | w = struct('epsilon', 1e-6, 'rho', 0.9) ; 28 | return ; 29 | end 30 | 31 | if isequal(state, 0) % First iteration, initialize state struct 32 | state = struct('g_sqr', 0, 'delta_sqr', 0) ; 33 | end 34 | 35 | rho = opts.rho ; 36 | 37 | state.g_sqr = state.g_sqr * rho + grad.^2 * (1 - rho) ; 38 | new_delta = -sqrt((state.delta_sqr + opts.epsilon) ./ ... 39 | (state.g_sqr + opts.epsilon)) .* grad ; 40 | state.delta_sqr = state.delta_sqr * rho + new_delta.^2 * (1 - rho) ; 41 | 42 | w = w + new_delta ; 43 | -------------------------------------------------------------------------------- /matlab/+solver/adagrad.m: -------------------------------------------------------------------------------- 1 | function [w, g_sqr] = adagrad(w, g_sqr, grad, opts, lr) 2 | %ADAGRAD 3 | % Example AdaGrad solver, for use with CNN_TRAIN and CNN_TRAIN_DAG. 4 | % 5 | % Set the initial learning rate for AdaGrad in the options for 6 | % CNN_TRAIN and CNN_TRAIN_DAG. Note that a learning rate that works for 7 | % SGD may be inappropriate for AdaGrad; the default is 0.001. 8 | % 9 | % If called without any input argument, returns the default options 10 | % structure. 11 | % 12 | % Solver options: (opts.train.solverOpts) 13 | % 14 | % `epsilon`:: 1e-10 15 | % Small additive constant to regularize variance estimate. 16 | % 17 | % `rho`:: 1 18 | % Moving average window for variance update, between 0 and 1 (larger 19 | % values result in slower/more stable updating). This is similar to 20 | % RHO in AdaDelta and RMSProp. Standard AdaGrad is obtained with a RHO 21 | % value of 1 (use total average instead of a moving average). 22 | % 23 | % A possibly undesirable effect of standard AdaGrad is that the update 24 | % will monotonically decrease to 0, until training eventually stops. This 25 | % is because the AdaGrad update is inversely proportional to the total 26 | % variance of the gradients seen so far. 27 | % With RHO smaller than 1, a moving average is used instead. This 28 | % prevents the final update from monotonically decreasing to 0. 29 | 30 | % Copyright (C) 2016 Joao F. Henriques. 31 | % All rights reserved. 32 | % 33 | % This file is part of the VLFeat library and is made available under 34 | % the terms of the BSD license (see the COPYING file). 
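%
% Update implemented below (restating the code for reference):
%   g_sqr <- rho * g_sqr + grad.^2             (standard AdaGrad when rho == 1)
%   w     <- w - lr * grad ./ (sqrt(g_sqr) + epsilon)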
35 | 36 | if nargin == 0 % Return the default solver options 37 | w = struct('epsilon', 1e-10, 'rho', 1) ; 38 | return ; 39 | end 40 | 41 | g_sqr = g_sqr * opts.rho + grad.^2 ; 42 | 43 | w = w - lr * grad ./ (sqrt(g_sqr) + opts.epsilon) ; 44 | -------------------------------------------------------------------------------- /matlab/+solver/adam.m: -------------------------------------------------------------------------------- 1 | function [w, state] = adam(w, state, grad, opts, lr) 2 | %ADAM 3 | % Adam solver for use with CNN_TRAIN and CNN_TRAIN_DAG 4 | % 5 | % See [Kingma et. al., 2014](http://arxiv.org/abs/1412.6980) 6 | % | ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). 7 | % 8 | % If called without any input argument, returns the default options 9 | % structure. Otherwise provide all input arguments. 10 | % 11 | % W is the vector/matrix/tensor of parameters. It can be single/double 12 | % precision and can be a `gpuArray`. 13 | % 14 | % STATE is as defined below and so are supported OPTS. 15 | % 16 | % GRAD is the gradient of the objective w.r.t W 17 | % 18 | % LR is the learning rate, referred to as \alpha by Algorithm 1 in 19 | % [Kingma et. al., 2014]. 20 | % 21 | % Solver options: (opts.train.solverOpts) 22 | % 23 | % `beta1`:: 0.9 24 | % Decay for 1st moment vector. See algorithm 1 in [Kingma et.al. 2014] 25 | % 26 | % `beta2`:: 0.999 27 | % Decay for 2nd moment vector 28 | % 29 | % `eps`:: 1e-8 30 | % Additive offset when dividing by state.v 31 | % 32 | % The state is initialized as 0 (number) to start with. The first call to 33 | % this function will initialize it with the default state consisting of 34 | % 35 | % `m`:: 0 36 | % First moment vector 37 | % 38 | % `v`:: 0 39 | % Second moment vector 40 | % 41 | % `t`:: 0 42 | % Global iteration number across epochs 43 | % 44 | % This implementation borrowed from torch optim.adam 45 | 46 | % Copyright (C) 2016 Aravindh Mahendran. 47 | % All rights reserved. 48 | % 49 | % This file is part of the VLFeat library and is made available under 50 | % the terms of the BSD license (see the COPYING file). 51 | 52 | if nargin == 0 % Returns the default solver options 53 | w = struct('beta1', 0.9, 'beta2', 0.999, 'eps', 1e-8) ; 54 | return ; 55 | end 56 | 57 | if isequal(state, 0) % start off with state = 0 so as to get default state 58 | state = struct('m', 0, 'v', 0, 't', 0); 59 | end 60 | 61 | % update first moment vector `m` 62 | state.m = opts.beta1 * state.m + (1 - opts.beta1) * grad ; 63 | 64 | % update second moment vector `v` 65 | state.v = opts.beta2 * state.v + (1 - opts.beta2) * grad.^2 ; 66 | 67 | % update the time step 68 | state.t = state.t + 1 ; 69 | 70 | % This implicitly corrects for biased estimates of first and second moment 71 | % vectors 72 | lr_t = lr * (((1 - opts.beta2^state.t)^0.5) / (1 - opts.beta1^state.t)) ; 73 | 74 | % Update `w` 75 | w = w - lr_t * state.m ./ (state.v.^0.5 + opts.eps) ; 76 | -------------------------------------------------------------------------------- /matlab/+solver/rmsprop.m: -------------------------------------------------------------------------------- 1 | function [w, g_sqr] = rmsprop(w, g_sqr, grad, opts, lr) 2 | %RMSPROP 3 | % Example RMSProp solver, for use with CNN_TRAIN and CNN_TRAIN_DAG. 4 | % 5 | % Set the initial learning rate for RMSProp in the options for 6 | % CNN_TRAIN and CNN_TRAIN_DAG. Note that a learning rate that works for 7 | % SGD may be inappropriate for RMSProp; the default is 0.001. 
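% The update below maintains a decaying average of squared gradients
% (restated here for reference):
%   g_sqr <- rho * g_sqr + (1 - rho) * grad.^2
%   w     <- w - lr * grad ./ (sqrt(g_sqr) + epsilon)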
8 | %
9 | % If called without any input argument, returns the default options
10 | % structure.
11 | %
12 | % Solver options: (opts.train.solverOpts)
13 | %
14 | % `epsilon`:: 1e-8
15 | %    Small additive constant to regularize variance estimate.
16 | %
17 | % `rho`:: 0.99
18 | %    Moving average window for variance update, between 0 and 1 (larger
19 | %    values result in slower/more stable updating).
20 | 
21 | % Copyright (C) 2016 Joao F. Henriques.
22 | % All rights reserved.
23 | %
24 | % This file is part of the VLFeat library and is made available under
25 | % the terms of the BSD license (see the COPYING file).
26 | 
27 | if nargin == 0  % Return the default solver options
28 |   w = struct('epsilon', 1e-8, 'rho', 0.99) ;
29 |   return ;
30 | end
31 | 
32 | g_sqr = g_sqr * opts.rho + grad.^2 * (1 - opts.rho) ;
33 | 
34 | w = w - lr * grad ./ (sqrt(g_sqr) + opts.epsilon) ;
35 | 
--------------------------------------------------------------------------------
/matlab/FastAP.m:
--------------------------------------------------------------------------------
1 | classdef FastAP < dagnn.Loss
2 |     properties
3 |         opt
4 |     end
5 | 
6 |     properties (Transient)
7 |         Z
8 |         Delta
9 |         dist2
10 |         I_pos
11 |         I_neg
12 |         h_pos
13 |         h_neg
14 |         H_pos
15 |         N_pos
16 |         h
17 |         H
18 |     end
19 | 
20 |     methods
21 | 
22 |         function obj = FastAP(varargin)
23 |             obj.load(varargin{:});
24 |         end
25 | 
26 | 
27 |         function outputs = forward(obj, inputs, params)
28 |             % forward pass
29 |             X = squeeze(inputs{1});  % features (L2 normalized)
30 |             Y = inputs{2};
31 |             N = size(X, 2);
32 |             assert(size(Y, 1) == N);
33 | 
34 |             opts = obj.opt;
35 |             onGPU = numel(opts.gpus) > 0;
36 | 
37 |             % binary affinity
38 |             Affinity = 2 * bsxfun(@eq, Y(:,1), Y(:,1)') - 1;
39 |             Affinity(logical(eye(N))) = 0;
40 |             I_pos = (Affinity > 0);
41 |             I_neg = (Affinity < 0);
42 |             N_pos = sum(I_pos, 2);
43 | 
44 |             % (squared) pairwise distance matrix
45 |             dist2 = max(0, 2 - 2 * X' * X);
46 | 
47 |             % histogram binning
48 |             Delta = 4 / opts.nbins;
49 |             Z = linspace(0, 4, opts.nbins+1);
50 |             L = length(Z);
51 |             h_pos = zeros(N, L);
52 |             h_neg = zeros(N, L);
53 |             if onGPU
54 |                 h_pos = gpuArray(h_pos);
55 |                 h_neg = gpuArray(h_neg);
56 |             end
57 | 
58 |             for l = 1:L
59 |                 pulse = obj.softBinning(dist2, Z(l), Delta);
60 |                 h_pos(:, l) = sum(pulse .* I_pos, 2);
61 |                 h_neg(:, l) = sum(pulse .* I_neg, 2);
62 |             end
63 |             H_pos = cumsum(h_pos, 2);
64 |             h = h_pos + h_neg;
65 |             H = cumsum(h, 2);
66 | 
67 |             % compute FastAP
68 |             FastAP = h_pos .* H_pos ./ H;
69 |             FastAP(isnan(FastAP)|isinf(FastAP)) = 0;
70 |             FastAP = sum(FastAP, 2) ./ N_pos;
71 |             FastAP = FastAP(~isnan(FastAP));
72 | 
73 |             obj.numAveraged = N;
74 |             obj.average = gather(mean(FastAP));
75 | 
76 |             % output
77 |             outputs{1} = sum(FastAP);
78 |             obj.Z = Z;
79 |             obj.Delta = Delta;
80 |             obj.dist2 = dist2;
81 |             obj.I_pos = I_pos;
82 |             obj.I_neg = I_neg;
83 |             obj.h_pos = h_pos;
84 |             obj.h_neg = h_neg;
85 |             obj.H_pos = H_pos;
86 |             obj.N_pos = N_pos;
87 |             obj.h = h;
88 |             obj.H = H;
89 |         end
90 | 
91 | 
92 |         function [dInputs, dParams] = backward(obj, inputs, params, dOutputs)
93 |             % backward pass
94 |             X = squeeze(inputs{1});
95 |             opts = obj.opt;
96 |             onGPU = numel(opts.gpus) > 0;
97 | 
98 |             L = numel(obj.Z);
99 |             h_pos = obj.h_pos;
100 |             h_neg = obj.h_neg;
101 |             H_pos = obj.H_pos;
102 |             N_pos = obj.N_pos;
103 |             h = obj.h;
104 |             H = obj.H;
105 |             H_neg = H - H_pos;
106 |             H2 = H .^ 2;
107 | 
108 |             % 1. d(FastAP)/d(h+)
109 |             tmp1 = h_pos .* H_neg ./ H2;
110 |             tmp1(isnan(tmp1)) = 0;
111 | 
112 |             d_AP_h_pos = (H_pos .* H + h_pos .* H_neg) ./ H2;
113 |             d_AP_h_pos = d_AP_h_pos + tmp1 * triu(ones(L), 1)';
114 |             d_AP_h_pos = bsxfun(@rdivide, d_AP_h_pos, N_pos);
115 |             d_AP_h_pos(isnan(d_AP_h_pos)|isinf(d_AP_h_pos)) = 0;
116 | 
117 |             % 2. d(FastAP)/d(h-)
118 |             tmp2 = -h_pos .* H_pos ./ H2;
119 |             tmp2(isnan(tmp2)) = 0;
120 | 
121 |             d_AP_h_neg = tmp2 * triu(ones(L))';
122 |             d_AP_h_neg = bsxfun(@rdivide, d_AP_h_neg, N_pos);
123 |             d_AP_h_neg(isnan(d_AP_h_neg)|isinf(d_AP_h_neg)) = 0;
124 | 
125 |             % 3. d(FastAP)/d(x)
126 |             d_AP_x = 0;
127 |             for l = 1:L
128 |                 % NxN matrix of delta_hat(i, j, l) for fixed l
129 |                 dpulse = obj.dSoftBinning(obj.dist2, obj.Z(l), obj.Delta);
130 |                 dpulse(isnan(dpulse)|isinf(dpulse)) = 0;
131 |                 ddp = dpulse .* obj.I_pos;
132 |                 ddn = dpulse .* obj.I_neg;
133 | 
134 |                 alpha_p = diag(d_AP_h_pos(:, l));  % NxN
135 |                 alpha_n = diag(d_AP_h_neg(:, l));
136 |                 Ap = ddp * alpha_p + alpha_p * ddp;
137 |                 An = ddn * alpha_n + alpha_n * ddn;
138 | 
139 |                 % accumulate gradient: (BxN) (NxN) -> (BxN)
140 |                 d_AP_x = d_AP_x - X * (Ap + An);
141 |             end
142 | 
143 |             % output
144 |             dInputs{1} = zeros(size(inputs{1}), 'single');
145 |             if onGPU, dInputs{1} = gpuArray(dInputs{1}); end
146 |             dInputs{1}(1, 1, :, :) = -single(d_AP_x);
147 |             dInputs{2} = [];
148 |             dParams = {};
149 |         end
150 | 
151 | 
152 |         function [grad, objval] = computeGrad(obj, X, Y)
153 |             % helper function to compute gradient matrix
154 |             inputs = {X, Y};
155 |             output = obj.forward(inputs, []);
156 |             objval = obj.average;
157 |             [dInputs, ~] = obj.backward(inputs, [], {'objective', 1});
158 |             grad = dInputs{1};
159 |         end
160 | 
161 |     end
162 | 
163 |     % the binning helpers below do not take an obj argument; declaring them
164 |     % Static lets the obj.softBinning(...) call sites dispatch correctly
165 |     methods (Static)
166 | 
167 |         function y = softBinning(D, mid, delta)
168 |             % D: input matrix of distance values
169 |             % mid: scalar, the center of some histogram bin
170 |             % delta: scalar, histogram bin width
171 |             %
172 |             % For histogram bin mid, compute the contribution y
173 |             % from every element in D.
174 |             y = 1 - abs(D - mid) / delta;
175 |             y = max(0, y);
176 |         end
177 | 
178 | 
179 |         function y = dSoftBinning(D, mid, delta)
180 |             % vectorized version
181 |             % mid: scalar bin center
182 |             % D: can be a matrix
183 |             ind1 = (D > mid-delta) & (D <= mid);
184 |             ind2 = (D > mid) & (D <= mid+delta);
185 |             y = (ind1 - ind2) / delta;
186 |         end
187 | 
188 |     end
189 | end
190 | 
--------------------------------------------------------------------------------
/matlab/README.md:
--------------------------------------------------------------------------------
1 | ## Matlab implementation of "Deep Metric Learning to Rank"
2 | 
3 | ### Requirements
4 | * Matlab R2017b or newer
5 |   * This is to use the built-in [mink](https://www.mathworks.com/help/matlab/ref/mink.html) function.
6 |   * Alternatively, for earlier Matlab versions, you can use [this implementation](https://www.mathworks.com/matlabcentral/fileexchange/23576-min-max-selection) of mink. 
7 | * [MatConvNet](http://www.vlfeat.org/matconvnet/) v1.0-beta25 (with [`vl_contrib`](http://www.vlfeat.org/matconvnet/mfiles/vl_contrib/))
8 |   * [mcnExtraLayers](https://github.com/albanie/mcnExtraLayers) via `vl_contrib setup mcnExtraLayers`
9 |   * [autonn](https://github.com/vlfeat/autonn) via `vl_contrib setup autonn`
10 | 
11 | ### Preparation
12 | * Install/symlink MatConvNet at `./matconvnet` under this directory
13 | * Create or symlink a directory `./cachedir` under this directory
14 | * Create a subdirectory `./cachedir/data`, and create symlinks to the datasets
15 |   * Stanford Online Products at `./cachedir/data/Stanford_Online_Products`
16 |   * In-Shop Clothes Retrieval at `./cachedir/data/InShopClothes`
17 |   * PKU VehicleID at `./cachedir/data/VehicleID_V1.0`
18 | * Create `./cachedir/models` and download the pretrained models
19 |   * [GoogLeNet](http://www.vlfeat.org/matconvnet/models/imagenet-googlenet-dag.mat): download to `./cachedir/models/imagenet-googlenet-dag.mat`
20 |   * [ResNet-18](http://www.robots.ox.ac.uk/~albanie/models/pytorch-imports/resnet18-pt-mcn.mat): download to `./cachedir/models/resnet18-pt-mcn.mat`
21 |   * [ResNet-50](http://www.vlfeat.org/matconvnet/models/imagenet-resnet-50-dag.mat): download to `./cachedir/models/imagenet-resnet-50-dag.mat`
22 | 
23 | ### Usage
24 | We provide a unified interface, `run_demo.m`, to run all experiments conducted in the paper. The general syntax is
25 | ```
26 | run_demo([dataset], [key-value pairs])
27 | ```
28 | where
29 | * `dataset` is one of `'products', 'inshop', 'vid'`
30 | * Various parameters are specified as key-value pairs. The full list can be found by inspecting `get_opts.m`. Some notable ones are:
31 |   * `'gpus'` (int) 1-based GPU index. The current implementation only supports a single GPU.
32 |   * `'arch'` (string) network architecture. Available: `'googlenet', 'resnet18', 'resnet50'`
33 |   * `'dim'` (int) embedding dimensionality, default 512
34 |   * `'nbins'` (int) number of distance quantizations, default 10
35 |   * `'solver'` (string) optimizer. Default: `'adam'`. Others: `'sgd', 'adadelta', 'adagrad', 'rmsprop'`
36 | 
37 | ### Notes
38 | * Our experiments were mainly run on a Titan X Pascal GPU with 12GB of memory. If your GPU has a different amount of memory, you may want to change the default value of `opts.maxGpuImgs` (in `+models/*.m`) accordingly for each model architecture. 
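
For a concrete example, the following call (illustrative; every flag shown is documented under Usage above) trains FastAP on Stanford Online Products with a ResNet-50 backbone:
```
run_demo('products', 'arch', 'resnet50', 'dim', 512, 'nbins', 10, 'gpus', 1)
```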
39 | 
--------------------------------------------------------------------------------
/matlab/evaluate_model.m:
--------------------------------------------------------------------------------
1 | function evaluate_model(net, imdb, batchFunc, opts, varargin)
2 | 
3 | if strcmp(opts.dataset, 'inshop')
4 |     % separate query set and database
5 |     query_id = find(imdb.images.set == 2);
6 |     gallery_id = find(imdb.images.set == 3);
7 |     Yquery = imdb.images.labels(query_id, :);
8 |     Ygallery = imdb.images.labels(gallery_id, :);
9 |     Fquery = cnn_encode(net, imdb, batchFunc, query_id, opts, varargin{:})';
10 |     Fgallery = cnn_encode(net, imdb, batchFunc, gallery_id, opts, varargin{:})';
11 |     whos Fquery Fgallery
12 | 
13 |     logInfo('[%s] Recall ...', opts.dataset);
14 |     evaluate_recall(Fquery, Yquery, Fgallery, Ygallery, opts, false);
15 | 
16 | elseif strcmp(opts.dataset, 'vid')
17 |     % small, medium, large
18 |     test_sets = {'Small', 'Medium', 'Large'};
19 |     for s = 1:3
20 |         test_id = find(imdb.images.set(:, s+1) == 1);
21 |         Ytest = imdb.images.labels(test_id, :);
22 |         Ftest = cnn_encode(net, imdb, batchFunc, test_id, opts, varargin{:})';
23 |         logInfo('[%s] Recall on [%s] ...', opts.dataset, test_sets{s});
24 |         evaluate_recall(Ftest, Ytest, Ftest, Ytest, opts);
25 |     end
26 | 
27 | else
28 |     % query set == database
29 |     test_id = find(imdb.images.set == 3);
30 |     Ytest = imdb.images.labels(test_id, :);
31 |     Ftest = cnn_encode(net, imdb, batchFunc, test_id, opts, varargin{:})';
32 |     whos Ftest
33 | 
34 |     logInfo('[%s] Recall & NMI ...', opts.dataset);
35 |     evaluate_recall(Ftest, Ytest, Ftest, Ytest, opts);
36 | end
37 | 
38 | end
39 | 
40 | % ------------------------------------------------------------------------------
41 | % ------------------------------------------------------------------------------
42 | 
43 | function evaluate_recall(Xq, Yq, Xdb, Ydb, opts, removeDiag)
44 | % features: Nxd matrix
45 | if ~exist('removeDiag', 'var'), removeDiag = true; end
46 | 
47 | [Nq, d] = size(Xq);
48 | assert(Nq == size(Yq, 1));
49 | assert(size(Yq, 2) == size(Ydb, 2));
50 | 
51 | % pairwise distances
52 | tic;
53 | distmatrix = 2 - 2 * Xq * Xdb';  % NxN
54 | if removeDiag
55 |     distmatrix(logical(eye(Nq))) = Inf;
56 | end
57 | 
58 | % recall rate
59 | Ks = sort(opts.Ks, 'ascend');
60 | recallrate = zeros(Nq, numel(Ks));
61 | 
62 | for i = 1:Nq
63 |     % get top max(K) results, in increasing distance
64 |     [~, ind] = mink(distmatrix(i, :), Ks(end));
65 | 
66 |     % recall rates for all K's
67 |     for j = 1:numel(Ks)
68 |         ind_k = ind(1 : Ks(j));
69 |         recallrate(i, j) = any(Ydb(ind_k, 1) == Yq(i, 1));
70 |     end
71 | end
72 | toc;
73 | 
74 | for j = 1:numel(Ks)
75 |     fprintf('K: %4d, Recall: %.3f\n', Ks(j), mean(recallrate(:, j)));
76 | end
77 | 
78 | end
79 | 
80 | % ------------------------------------------------------------------------------
81 | % ------------------------------------------------------------------------------
82 | 
83 | function H = cnn_encode(net, imdb, batchFunc, ids, opts, varName)
84 | if ~exist('varName', 'var')
85 |     varName = 'feats_l2';
86 | end
87 | 
88 | batch_size = 2 * opts.maxGpuImgs;
89 | onGPU = numel(opts.gpus) > 0;
90 | 
91 | logInfo('Testing [%s] on %d -> %s', opts.arch, length(ids), varName);
92 | 
93 | net.mode = 'test';
94 | ind = net.getVarIndex(varName);
95 | net.vars(ind).precious = 1;
96 | if onGPU, net.move('gpu'); end
97 | 
98 | assert(strcmp(varName, 'feats_l2'));
99 | H = zeros(opts.dim, length(ids), 'single');
100 | 
101 | tic;
102 | for t = 1:batch_size:length(ids)
103 |     ed = min(t+batch_size-1, ...
length(ids)); 104 | inputs = batchFunc(imdb, ids(t:ed)); 105 | net.eval(inputs); 106 | 107 | ret = net.vars(ind).value; 108 | ret = squeeze(gather(ret)); 109 | H(:, t:ed) = ret; 110 | end 111 | 112 | if onGPU && isa(net, 'dagnn.DagNN') 113 | net.reset(); 114 | net.move('cpu'); 115 | end 116 | toc; 117 | 118 | end 119 | -------------------------------------------------------------------------------- /matlab/get_model.m: -------------------------------------------------------------------------------- 1 | function [net, opts] = get_model(opts, addFC) 2 | if nargin < 2, addFC = true; end 3 | 4 | t0 = tic; 5 | modelFunc = str2func(sprintf('models.%s', opts.arch)); 6 | [net, opts, in_name, in_dim] = modelFunc(opts); 7 | logInfo('%s in %.2fs', opts.arch, toc(t0)); 8 | 9 | % + FC layer 10 | if addFC 11 | convobj = dagnn.Conv('size', [1 1 in_dim opts.dim], ... 12 | 'pad', 0, 'stride', 1, 'hasBias', true); 13 | params = convobj.initParams(); 14 | net.addLayer('fc', convobj, {in_name}, {'logits'}, {'fc_w', 'fc_b'}); 15 | p1 = net.getParamIndex('fc_w'); 16 | p2 = net.getParamIndex('fc_b'); 17 | net.params(p1).value = params{1}; 18 | net.params(p2).value = params{2}; 19 | net.params(p1).learningRate = opts.lrmult; 20 | net.params(p2).learningRate = opts.lrmult; 21 | 22 | in_name = 'logits'; 23 | in_dim = opts.dim; 24 | end 25 | 26 | % + l2 normalization layer 27 | net.addLayer('L2norm', dagnn.LRN('param', [2*in_dim, 0, 1, 0.5]), ... 28 | {in_name}, {'feats_l2'}); 29 | 30 | % + loss layer 31 | lossobj = str2func(opts.obj); 32 | net.addLayer('loss', lossobj('opt', opts), {'feats_l2', 'labels'}, {'objective'}); 33 | 34 | % print 35 | if 0 36 | net.print({'data', [opts.imageSize opts.imageSize 3 opts.batchSize]}, ... 37 | 'MaxNumColumns', 4, 'Layers', [], 'Parameters', []); 38 | end 39 | 40 | end 41 | -------------------------------------------------------------------------------- /matlab/get_opts.m: -------------------------------------------------------------------------------- 1 | function opts = get_opts(opts, dataset, varargin) 2 | 3 | if nargin == 1 4 | opts = process_opts(opts); 5 | return; 6 | end 7 | 8 | ip = inputParser; 9 | 10 | % model params 11 | ip.addParameter('obj' , 'FastAP'); 12 | ip.addParameter('arch' , 'resnet18'); % CNN model 13 | ip.addParameter('dim' , 512); % embedding vector size 14 | ip.addParameter('nbins' , 10); % quantization granularity 15 | 16 | % SGD 17 | ip.addParameter('solver' , 'adam'); 18 | ip.addParameter('batchSize' , 256); 19 | ip.addParameter('lr' , 1e-5); % base learning rate 20 | ip.addParameter('lrmult' , 10); % learning rate multiplier for last layer 21 | ip.addParameter('lrdecay' , 0.1); % [SGD only] LR decay factor 22 | ip.addParameter('lrstep' , 10); % [SGD only] decay LR after this many epochs 23 | ip.addParameter('wdecay' , 0); % weight decay 24 | 25 | % train 26 | ip.addParameter('testInterval', 3); % test interval 27 | ip.addParameter('epoch' , 30); % num. 
epochs 28 | ip.addParameter('gpus' , 1); % which GPU to use 29 | ip.addParameter('continue' , true); % resume from existing model 30 | ip.addParameter('debug' , false); 31 | ip.addParameter('plot' , false); 32 | 33 | % misc 34 | ip.addParameter('ablation', 0); % 0: none, 1: batchSize, 2: nbins, 3: randMBS 35 | ip.addParameter('randseed', 42); 36 | ip.addParameter('prefix' , []); 37 | 38 | % parse input 39 | ip.KeepUnmatched = true; 40 | ip.parse(varargin{:}); 41 | opts = catstruct(ip.Results, opts); % combine w/ existing opts 42 | opts.dataset = dataset; 43 | 44 | end 45 | 46 | % ======================================================================= 47 | % ======================================================================= 48 | 49 | function opts = process_opts(opts) 50 | % post-parse processing 51 | 52 | opts.expID = sprintf('%s-%s-%d-%s', opts.dataset, opts.obj, opts.dim, opts.arch); 53 | 54 | switch (opts.ablation) 55 | case 1 56 | opts.prefix = 'ablationM'; 57 | case 2 58 | opts.prefix = 'ablationHist'; 59 | case 3 60 | opts.prefix = 'randMBS'; 61 | otherwise 62 | opts.prefix = ''; 63 | end 64 | 65 | % for batch sizes > GPU mem limit: use chunks 66 | switch (opts.arch) 67 | case 'resnet50' 68 | opts.maxGpuImgs = 90; 69 | case 'resnet18' 70 | opts.maxGpuImgs = 256; 71 | case 'googlenet' 72 | opts.maxGpuImgs = 320; 73 | end 74 | 75 | % -------------------------------------------- 76 | % identifier string for the current experiment 77 | ID = sprintf('%dbins-batch%d-%s%.emult%g', opts.nbins, opts.batchSize, ... 78 | opts.solver, opts.lr, opts.lrmult); 79 | 80 | if strcmp(opts.solver, 'sgd') 81 | assert(opts.lrdecay > 0 && opts.lrdecay < 1); 82 | assert(opts.lrstep > 0); 83 | ID = sprintf('%sD%gE%d', ID, opts.lrdecay, opts.lrstep); 84 | end 85 | ID = sprintf('%s-WD%.e', ID, opts.wdecay); 86 | 87 | if isempty(opts.prefix) 88 | % prefix: timestamp 89 | [~, T] = unix(['git log -1 --format=%ci|cut -d " " -f1,2|cut -d "-" -f2,3' ... 
90 | '|tr " " "."|tr -d ":-"']); 91 | opts.prefix = strrep(T, newline, ''); 92 | end 93 | opts.identifier = [opts.prefix '-' ID]; 94 | 95 | % -------------------------------------------- 96 | % mkdirs 97 | opts.localDir = fullfile(pwd, 'cachedir'); % use symlink on linux 98 | if ~exist(opts.localDir, 'file') 99 | error('Please mkdir/symlink cachedir!'); 100 | end 101 | opts.dataDir = fullfile(opts.localDir, 'data'); 102 | opts.imdbPath = fullfile(opts.dataDir, ['imdb_' opts.dataset]); 103 | 104 | opts.expDir = fullfile(opts.localDir, opts.expID, opts.identifier); 105 | if ~exist(opts.expDir, 'dir') 106 | logInfo(['creating opts.expDir: ' opts.expDir]); 107 | mkdir(opts.expDir); 108 | end 109 | 110 | % -------------------------------------------- 111 | % rng 112 | rng(opts.randseed); 113 | 114 | end 115 | -------------------------------------------------------------------------------- /matlab/run_demo.m: -------------------------------------------------------------------------------- 1 | function run_demo(dataset, varargin) 2 | 3 | % ---------------------------------------- 4 | % init 5 | % ---------------------------------------- 6 | ip = inputParser; 7 | 8 | if strcmp(dataset, 'products') 9 | ip.addParameter('Ks', [1 10 100 1000]); 10 | hardTrainFn = @train_hard; 11 | BPMP = 5; 12 | 13 | elseif strcmp(dataset, 'inshop') 14 | ip.addParameter('Ks', [1 10 20 30 40 50]); 15 | hardTrainFn = @train_hard; 16 | BPMP = 2; 17 | 18 | elseif strcmp(dataset, 'vid') 19 | ip.addParameter('Ks', [1 5]); 20 | hardTrainFn = @train_hard_vid; 21 | BPMP = 0; 22 | 23 | else 24 | error('dataset not yet supported'); 25 | end 26 | 27 | ip.KeepUnmatched = true; 28 | ip.parse(varargin{:}); 29 | opts = ip.Results; 30 | opts = get_opts(opts, dataset, varargin{:}); 31 | 32 | % post-parsing 33 | cleanupObj = onCleanup(@() cleanup(opts.gpus)); 34 | 35 | opts = get_opts(opts); % carry out all post-processing on opts 36 | record_diary(opts); 37 | disp(opts); 38 | 39 | % ---------------------------------------- 40 | % model & data 41 | % ---------------------------------------- 42 | [net, opts] = get_model(opts); 43 | 44 | global imdb 45 | imdb = get_imdb(imdb, opts); 46 | disp(imdb.images) 47 | 48 | % use imagenet-pretrained model 49 | imgSize = opts.imageSize; 50 | meanImage = single(net.meta.normalization.averageImage); 51 | batchFunc = @(I, B) batch_imagenet(I, B, imgSize, meanImage); 52 | 53 | % ---------------------------------------- 54 | % train 55 | % ---------------------------------------- 56 | % figure out solver & learning rate 57 | if strcmp(opts.solver, 'sgd') 58 | solverFunc = []; 59 | if opts.lrdecay>0 & opts.lrdecay<1 60 | cur_lr = opts.lr; 61 | lrvec = []; 62 | while length(lrvec) < opts.epoch 63 | lrvec = [lrvec, ones(1, opts.lrstep)*cur_lr]; 64 | cur_lr = cur_lr * opts.lrdecay; 65 | end 66 | elseif opts.lrdecay > 1 67 | % linear decay, lrdecay specifies the # epoch -> 0 68 | assert(mod(opts.lrdecay, 1) == 0); 69 | lrvec = linspace(opts.lr, 0, opts.lrdecay+1); 70 | opts.epoch = min(opts.epoch, opts.lrdecay); 71 | end 72 | else 73 | solverFunc = str2func(['solver.' opts.solver]); 74 | lrvec = opts.lr; 75 | end 76 | 77 | if opts.ablation == 3 78 | % random minibatch sampling 79 | [net, info] = train_rand(net, imdb, batchFunc , ... 80 | 'saveInterval' , opts.testInterval , ... 81 | 'plotStatistics' , opts.plot , ... 82 | 'randomSeed' , opts.randseed , ... 83 | 'gpus' , opts.gpus , ... 84 | 'continue' , opts.continue , ... 85 | 'expDir' , opts.expDir , ... 86 | 'batchSize' , opts.batchSize , ... 
87 | 'weightDecay' , opts.wdecay , ... 88 | 'numEpochs' , opts.epoch , ... 89 | 'learningRate' , lrvec , ... 90 | 'train' , imdb.images.PV , ... 91 | 'val' , NaN , ... 92 | 'solver' , solverFunc , ... 93 | 'postEpochFn' , @postepoch); 94 | else 95 | % hard minibatch sampling 96 | [net, info] = hardTrainFn(net, imdb, batchFunc , ... 97 | 'saveInterval' , opts.testInterval , ... 98 | 'plotStatistics' , opts.plot , ... 99 | 'randomSeed' , opts.randseed , ... 100 | 'gpus' , opts.gpus , ... 101 | 'continue' , opts.continue , ... 102 | 'expDir' , opts.expDir , ... 103 | 'batchSize' , opts.batchSize , ... 104 | 'weightDecay' , opts.wdecay , ... 105 | 'numEpochs' , opts.epoch , ... 106 | 'learningRate' , lrvec , ... 107 | 'val' , NaN , ... 108 | 'solver' , solverFunc , ... 109 | 'postEpochFn' , @postepoch , ... 110 | 'batchesPerMetaPair' , BPMP); 111 | end 112 | 113 | % ---------------------------------------- 114 | % done 115 | % ---------------------------------------- 116 | net.reset(); 117 | net.move('cpu'); 118 | diary('off'); 119 | 120 | end 121 | 122 | % ==================================================================== 123 | % get IMDB 124 | % ==================================================================== 125 | 126 | function imdb = get_imdb(imdb, opts) 127 | imdbName = sprintf('%s_%d', opts.dataset, opts.imageSize); 128 | logInfo('IMDB: %s', imdbName); 129 | 130 | if ~isempty(imdb) && strcmp(imdb.name, imdbName) 131 | return; 132 | end 133 | 134 | % load/compute 135 | t0 = tic; 136 | imdbFunc = str2func(['imdb.' opts.dataset]); 137 | imdb = imdbFunc(['imdb_' imdbName], opts) ; 138 | imdb.name = imdbName; 139 | 140 | if opts.ablation == 3 141 | % random MBS: shuffle training instances 142 | imdb.images.PV = shuffle(imdb, opts); 143 | else 144 | % hard MBS: group by meta-class 145 | if ~isfield(imdb, 'metaclass') || isempty(imdb.metaclass) 146 | logInfo('Analyzing meta-classes...'); tic; 147 | 148 | Ycls = imdb.images.labels(imdb.images.set==1, 1); 149 | Ymeta = imdb.images.labels(imdb.images.set==1, 2); 150 | Umeta = unique(Ymeta); 151 | imdb.metaclass = cell(1, numel(Umeta)); 152 | 153 | for i = 1:numel(Umeta) 154 | % group instances in this metaclass into cells 155 | Im = find(Ymeta == Umeta(i)); % images in this metaclass 156 | Ic = unique(Ycls(Im)); % instances in this metaclass 157 | G = arrayfun(@(x) find(Ycls==x & Ymeta==Umeta(i)), Ic, 'uniform', 0); 158 | 159 | % store in imdb 160 | imdb.metaclass{i} = []; 161 | imdb.metaclass{i}.id = Umeta(i); 162 | imdb.metaclass{i}.num = numel(Im); 163 | imdb.metaclass{i}.groups = G; 164 | end 165 | toc; 166 | end 167 | end 168 | 169 | logInfo('%s in %.2f sec', imdbName, toc(t0)); 170 | end 171 | 172 | % ==================================================================== 173 | % postprocessing after each epoch 174 | % ==================================================================== 175 | 176 | function [lr, params] = postepoch(net, params, state) 177 | lr = []; 178 | if isa(net, 'composite') 179 | net = net{1}; 180 | end 181 | opts = net.layers(end).block.opt; 182 | epoch = params.epoch; 183 | logInfo(opts.expID); 184 | logInfo(opts.identifier); 185 | logInfo(char(datetime)); 186 | 187 | if ~isempty(opts.gpus) 188 | [~, name] = unix('hostname'); 189 | logInfo('GPU #%d on %s', opts.gpus, name); 190 | end 191 | 192 | % evaluate 193 | if (epoch==1) | ~mod(epoch, opts.testInterval); 194 | evaluate_model(net, params.imdb, params.getBatch, opts); 195 | end 196 | 197 | % reshuffle 198 | if isfield(opts, 'ablation') && opts.ablation == 3 
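% (ablation mode 3 = random minibatch sampling, see get_opts.m; re-generate the class-grouped random training order each epoch)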
199 | params.train = shuffle(params.imdb, opts); 200 | end 201 | diary off, diary on 202 | end 203 | 204 | 205 | function PV = shuffle(imdb, opts) 206 | % group by class label, then shuffle 207 | Itrain = find(imdb.images.set == 1); 208 | Ytrain = imdb.images.labels(Itrain, 1); 209 | C = arrayfun(@(y) find(Ytrain==y), unique(Ytrain), 'uniform', false); 210 | PV = []; 211 | 212 | % go through classes in random order 213 | for j = randperm(numel(C)) 214 | % randomize the instances 215 | ind = C{j}; 216 | ind = ind(randperm(length(ind))); 217 | PV = [PV; ind]; 218 | end 219 | 220 | assert(length(PV) == length(Itrain)); 221 | PV = Itrain(PV); 222 | end 223 | 224 | % ==================================================================== 225 | % misc 226 | % ==================================================================== 227 | 228 | function cleanup(gpuid) 229 | diary('off'); 230 | end 231 | -------------------------------------------------------------------------------- /matlab/startup.m: -------------------------------------------------------------------------------- 1 | % util 2 | addpath('util'); 3 | 4 | % mink() 5 | if ~exist('mink', 'builtin') 6 | warning('Please use Matlab R2017b or newer, or follow the README to download the MinMaxSelection implementation.'); 7 | end 8 | 9 | % matconvnet 10 | logInfo('setting up [MatConvNet]'); 11 | run ./matconvnet/matlab/vl_setupnn 12 | 13 | % mcn extra layers 14 | logInfo('setting up [mcnExtraLayers]'); 15 | vl_contrib setup mcnExtraLayers 16 | 17 | % autonn 18 | logInfo('setting up [autonn]'); 19 | vl_contrib setup autonn 20 | 21 | logInfo('Done!'); 22 | -------------------------------------------------------------------------------- /matlab/train_hard.m: -------------------------------------------------------------------------------- 1 | function [net,stats] = train_hard(net, imdb, getBatch, varargin) 2 | %CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper 3 | % CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with 4 | % the DagNN wrapper instead of the SimpleNN wrapper. 5 | 6 | % Copyright (C) 2014-16 Andrea Vedaldi. 7 | % All rights reserved. 8 | % 9 | % This file is part of the VLFeat library and is made available under 10 | % the terms of the BSD license (see the COPYING file). 11 | addpath(fullfile(vl_rootnn, 'examples')); 12 | 13 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%% 14 | opts.saveInterval = 2; 15 | opts.batchesPerMetaPair = 5; 16 | opts.maxGpuImgs = Inf; 17 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%% 18 | 19 | opts.expDir = fullfile('data','exp') ; 20 | opts.continue = true ; 21 | opts.batchSize = 256 ; 22 | opts.numSubBatches = 1 ; 23 | opts.train = [] ; 24 | opts.val = [] ; 25 | opts.gpus = [] ; 26 | opts.prefetch = false ; 27 | opts.epochSize = inf; 28 | opts.numEpochs = 300 ; 29 | opts.learningRate = 0.001 ; 30 | opts.weightDecay = 0.0005 ; 31 | 32 | opts.solver = [] ; % Empty array means use the default SGD solver 33 | [opts, varargin] = vl_argparse(opts, varargin) ; 34 | if ~isempty(opts.solver) 35 | assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,... 
36 | 'Invalid solver; expected a function handle with two outputs.') ; 37 | % Call without input arguments, to get default options 38 | opts.solverOpts = opts.solver() ; 39 | end 40 | 41 | opts.momentum = 0.9 ; 42 | opts.saveSolverState = true ; 43 | opts.nesterovUpdate = false ; 44 | opts.randomSeed = 0 ; 45 | opts.profile = false ; 46 | opts.parameterServer.method = 'mmap' ; 47 | opts.parameterServer.prefix = 'mcn' ; 48 | 49 | opts.derOutputs = {'objective', 1} ; 50 | opts.extractStatsFn = @extractStats ; 51 | opts.plotStatistics = true; 52 | opts.postEpochFn = [] ; % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change 53 | opts = vl_argparse(opts, varargin) ; 54 | 55 | if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end 56 | if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end 57 | if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end 58 | if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train) 59 | opts.train = [] ; 60 | end 61 | if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val) 62 | opts.val = [] ; 63 | end 64 | 65 | % ------------------------------------------------------------------------- 66 | % Initialization 67 | % ------------------------------------------------------------------------- 68 | 69 | evaluateMode = isempty(opts.train) ; 70 | if ~evaluateMode 71 | if isempty(opts.derOutputs) 72 | error('DEROUTPUTS must be specified when training.\n') ; 73 | end 74 | end 75 | 76 | % ------------------------------------------------------------------------- 77 | % Train and validate 78 | % ------------------------------------------------------------------------- 79 | 80 | modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep)); 81 | modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ; 82 | 83 | start = opts.continue * findLastCheckpoint(opts.expDir) ; 84 | if start >= 1 85 | fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ; 86 | % [KH] make sure to use the input opts, saved maybe outdated 87 | opt = net.layers(end).block.opt; 88 | [net, state, stats] = loadState(modelPath(start)) ; 89 | net.layers(end).block.opt = opt; 90 | else 91 | state = [] ; 92 | end 93 | 94 | for epoch=start+1:opts.numEpochs 95 | 96 | % Set the random seed based on the epoch and opts.randomSeed. 97 | % This is important for reproducibility, including when training 98 | % is restarted from a checkpoint. 99 | 100 | rng(epoch + opts.randomSeed) ; 101 | prepareGPUs(opts, epoch == start+1) ; 102 | 103 | % Train for one epoch. 
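% (opts.learningRate may be a vector of per-epoch rates; it is indexed by epoch below, clamped to its last entry)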
104 | params = opts ; 105 | params.epoch = epoch ; 106 | params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ; 107 | params.train = opts.train; %(randperm(numel(opts.train))) ; % shuffle 108 | params.train = params.train; %(1:min(opts.epochSize, numel(opts.train))); 109 | params.val = opts.val; %(randperm(numel(opts.val))) ; 110 | params.imdb = imdb ; 111 | params.getBatch = getBatch ; 112 | 113 | if numel(opts.gpus) <= 1 114 | [net, state] = processEpoch(net, state, params, 'train') ; 115 | [net, state] = processEpoch(net, state, params, 'val') ; 116 | if ~evaluateMode && ~mod(epoch, opts.saveInterval) 117 | saveState(modelPath(epoch), net, state) ; 118 | end 119 | lastStats = state.stats ; 120 | else 121 | spmd 122 | [net, state] = processEpoch(net, state, params, 'train') ; 123 | [net, state] = processEpoch(net, state, params, 'val') ; 124 | if labindex == 1 && ~evaluateMode && ~mod(epoch, opts.saveInterval) 125 | saveState(modelPath(epoch), net, state) ; 126 | end 127 | lastStats = state.stats ; 128 | end 129 | lastStats = accumulateStats(lastStats) ; 130 | end 131 | 132 | stats.train(epoch) = lastStats.train ; 133 | stats.val(epoch) = lastStats.val ; 134 | clear lastStats ; 135 | if ~mod(epoch, opts.saveInterval) 136 | saveStats(modelPath(epoch), stats) ; 137 | end 138 | 139 | if opts.plotStatistics 140 | switchFigure(1) ; clf ; 141 | plots = setdiff(... 142 | cat(2,... 143 | fieldnames(stats.train)', ... 144 | fieldnames(stats.val)'), {'num', 'time'}) ; 145 | for p = plots 146 | p = char(p) ; 147 | values = zeros(0, epoch) ; 148 | leg = {} ; 149 | for f = {'train', 'val'} 150 | f = char(f) ; 151 | if isfield(stats.(f), p) 152 | tmp = [stats.(f).(p)] ; 153 | values(end+1,:) = tmp(1,:)' ; 154 | leg{end+1} = f ; 155 | end 156 | end 157 | subplot(1,numel(plots),find(strcmp(p,plots))) ; 158 | plot(1:epoch, values','o-') ; 159 | xlabel('epoch') ; 160 | title(p) ; 161 | legend(leg{:}) ; 162 | grid on ; 163 | end 164 | drawnow ; 165 | print(1, modelFigPath, '-dpdf') ; 166 | end 167 | 168 | if ~isempty(opts.postEpochFn) 169 | if nargout(opts.postEpochFn) == 0 170 | opts.postEpochFn(net, params, state) ; 171 | else 172 | [lr, params] = opts.postEpochFn(net, params, state) ; 173 | if ~isempty(lr), opts.learningRate = lr; end 174 | if opts.learningRate == 0, break; end 175 | opts.train = params.train; 176 | end 177 | end 178 | end 179 | logInfo('Complete: %d epochs.', opts.numEpochs); 180 | if ~isempty(opts.postEpochFn) 181 | params = opts; 182 | params.imdb = imdb; 183 | params.epoch = params.numEpochs; 184 | params.getBatch = getBatch; 185 | opts.postEpochFn(net, params, state) ; 186 | end 187 | 188 | % With multiple GPUs, return one copy 189 | if isa(net, 'Composite'), net = net{1} ; end 190 | 191 | 192 | % ------------------------------------------------------------------------- 193 | function [net, state] = processEpoch(net, state, params, mode) 194 | % ------------------------------------------------------------------------- 195 | % Note that net is not strictly needed as an output argument as net 196 | % is a handle class. However, this fixes some aliasing issue in the 197 | % spmd caller. 
198 | 199 | % initialize with momentum 0 200 | if isempty(state) || isempty(state.solverState) 201 | state.solverState = cell(1, numel(net.params)) ; 202 | state.solverState(:) = {0} ; 203 | end 204 | 205 | % move CNN to GPU as needed 206 | numGpus = numel(params.gpus) ; 207 | if numGpus >= 1 208 | net.move('gpu') ; 209 | for i = 1:numel(state.solverState) 210 | s = state.solverState{i} ; 211 | if isnumeric(s) 212 | state.solverState{i} = gpuArray(s) ; 213 | elseif isstruct(s) 214 | state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ; 215 | end 216 | end 217 | end 218 | if numGpus > 1 219 | parserv = ParameterServer(params.parameterServer) ; 220 | net.setParameterServer(parserv) ; 221 | else 222 | parserv = [] ; 223 | end 224 | 225 | % profile 226 | if params.profile 227 | if numGpus <= 1 228 | profile clear ; 229 | profile on ; 230 | else 231 | mpiprofile reset ; 232 | mpiprofile on ; 233 | end 234 | end 235 | 236 | num = 0 ; 237 | epoch = params.epoch ; 238 | subset = params.(mode) ; 239 | adjustTime = 0 ; 240 | 241 | stats.num = 0 ; % return something even if subset = [] 242 | stats.time = 0 ; 243 | 244 | % ----------------------------------------- 245 | % ----------------------------------------- 246 | N = numel(subset); 247 | if N > 0 248 | logInfo('Epoch %d/%d: %s LR=%d', epoch, params.numEpochs, ... 249 | func2str(params.solver), params.learningRate); 250 | 251 | % [KH] preprocess: group by meta-class 252 | metaC = params.imdb.metaclass; 253 | Nmeta = numel(metaC); 254 | metaC_idx = cell(1, Nmeta); 255 | for i = 1:Nmeta 256 | G = metaC{i}.groups; 257 | G = G(randperm(numel(G))); 258 | metaC_idx{i} = cat(1, G{:}); 259 | end 260 | batchSize = params.batchSize; 261 | assert(mod(batchSize, 2) == 0); 262 | metaPairs = []; 263 | for m = 1:Nmeta 264 | %metaPairs = [metaPairs; m*ones(Nmeta-m+1, 1), (m:Nmeta)']; 265 | metaPairs = [metaPairs; m*ones(Nmeta-m, 1), (m+1:Nmeta)']; 266 | end 267 | numPairs = size(metaPairs, 1); 268 | metaPairs = metaPairs(randperm(numPairs), :); % opt 269 | numBatches = numPairs * params.batchesPerMetaPair; 270 | logInfo('%d meta-classes, %d batches per meta-class pair', Nmeta, ... 
271 | params.batchesPerMetaPair);
272 |
273 | netopts = net.layers(end).block.opt;
274 | if strcmp(mode, 'train')
275 |   if params.batchSize > netopts.maxGpuImgs
276 |     mode = 'train_staged';
277 |   else
278 |     net.mode = 'normal' ;
279 |     net.conserveMemory = true;
280 |   end
281 | else
282 |   net.mode = 'test' ;
283 | end
284 |
285 | start = tic; a = tic;
286 | objs = [];
287 | cur = zeros(1, Nmeta);
288 | t = 0;
289 | for p = 1:numPairs
290 |   % get 2 meta-classes
291 |   m1 = metaPairs(p, 1);
292 |   m2 = metaPairs(p, 2);
293 |   for k = 1:params.batchesPerMetaPair
294 |     if m1 == m2  % cannot occur with the strictly upper-triangular metaPairs above; kept defensively
295 |       I = cur(m1) + (1:batchSize);
296 |       I = mod(I-1, metaC{m1}.num) + 1;
297 |       cur(m1) = I(end);
298 |       batch = metaC_idx{m1}(I);
299 |     else
300 |       % get an equal # of images from both meta-classes
301 |       I1 = cur(m1) + (1:batchSize/2);
302 |       I2 = cur(m2) + (1:batchSize/2);
303 |       I1 = mod(I1-1, metaC{m1}.num) + 1;
304 |       I2 = mod(I2-1, metaC{m2}.num) + 1;
305 |       cur(m1) = I1(end);
306 |       cur(m2) = I2(end);
307 |       batch = [metaC_idx{m1}(I1); metaC_idx{m2}(I2)];
308 |     end
309 |     t = t + batchSize;
310 |
311 |     % train 1 batch
312 |     if strcmp(mode, 'train')
313 |       % regular train mode (batchSize within GPU limit)
314 |       inputs = params.getBatch(params.imdb, batch) ;
315 |       net.eval(inputs, params.derOutputs) ;
316 |
317 |       if ~isempty(parserv), parserv.sync() ; end
318 |       state = accumulateGradients(net, state, params, batchSize, parserv) ;
319 |
320 |     elseif strcmp(mode, 'train_staged')
321 |       % staged backprop mode
322 |       % stage 1: get embedding matrix, in chunks
323 |       net.mode = 'test';
324 |       ind = net.getVarIndex('feats_l2');
325 |       net.vars(ind).precious = 1;
326 |
327 |       labelMat = []; % label matrix
328 |       embeddingMat = zeros(1, 1, netopts.dim, batchSize); % embedding matrix
329 |       if numel(params.gpus) > 0
330 |         embeddingMat = gpuArray(embeddingMat);
331 |       end
332 |
333 |       subBatchSize = netopts.maxGpuImgs;
334 |       nSub = ceil(batchSize / subBatchSize);
335 |       inputsCache = cell(1, nSub);
336 |       for s = 1:nSub
337 |         sub = (s-1)*subBatchSize+1 : min(batchSize, s*subBatchSize);
338 |         subInputs = params.getBatch(params.imdb, batch(sub));
339 |         inputsCache{s} = subInputs;
340 |
341 |         net.eval(subInputs);
342 |         embeddingMat(:, :, :, sub) = gather(net.vars(ind).value);
343 |         labelMat = [labelMat; subInputs{4}];
344 |       end
345 |
346 |       % stage 2: compute gradient matrix
347 |       [gradientMat, obj_staged] = net.layers(end).block.computeGrad(...
348 |         embeddingMat, labelMat);
349 |
350 |       % stage 3: accumulate gradients
351 |       net.mode = 'normal';
352 |       for s = 1:nSub
353 |         sub = (s-1)*subBatchSize+1 : min(batchSize, s*subBatchSize);
354 |         featsGrad = gradientMat(:, :, :, sub);
355 |         net.eval(inputsCache{s}, {'feats_l2', featsGrad});
356 |
357 |         if ~isempty(parserv), parserv.sync() ; end
358 |         state = accumulateGradients(net, state, params, numel(sub), parserv);
359 |       end
360 |     else
361 |       % eval mode: the 'train' branch above was skipped, so fetch the batch here
362 |       inputs = params.getBatch(params.imdb, batch) ; net.eval(inputs) ;
363 |     end
364 |
365 |     % Get statistics.
366 |     time = toc(start) + adjustTime ;
367 |     batchTime = time - stats.time ;
368 |     stats.num = num ;
369 |     stats.time = time ;
370 |     stats = params.extractStatsFn(stats,net) ;
371 |     currentSpeed = batchSize / batchTime ;
372 |     averageSpeed = t / time ;
373 |     if strcmp(mode, 'train_staged')
374 |       objs = [objs, obj_staged];
375 |     else
376 |       objs = [objs, stats.objective];
377 |     end
378 |     if toc(a) > 30
379 |       %if strcmp(mode, 'train_staged'), fprintf('\n'); end
380 |       fprintf('%s-%s ep%02d: %3d/%3d (%d) %.1fHz', ...
381 |         net.layers(end).block.opt.dataset, mode, epoch, ...
382 | fix(t/batchSize), numBatches, batchSize, averageSpeed) ; 383 | if strcmp(mode, 'train_staged') 384 | fprintf(' obj: %.3f (%.3f)\n', obj_staged, mean(objs(~isnan(objs)))); 385 | else 386 | fprintf(' obj: %.3f (%.3f)\n', stats.objective, mean(objs(~isnan(objs)))); 387 | end 388 | a = tic; 389 | end 390 | end % for k 391 | end % for p 392 | if strcmp(mode, 'train_staged') 393 | mode = 'train'; 394 | end 395 | logInfo('Epoch %d, Avg %s obj = %g', epoch, mode, mean(objs(~isnan(objs)))); 396 | end % if N>0 397 | 398 | % Save back to state. 399 | state.stats.(mode) = stats ; 400 | if params.profile 401 | if numGpus <= 1 402 | state.prof.(mode) = profile('info') ; 403 | profile off ; 404 | else 405 | state.prof.(mode) = mpiprofile('info'); 406 | mpiprofile off ; 407 | end 408 | end 409 | if ~params.saveSolverState 410 | state.solverState = [] ; 411 | else 412 | for i = 1:numel(state.solverState) 413 | s = state.solverState{i} ; 414 | if isnumeric(s) 415 | state.solverState{i} = gather(s) ; 416 | elseif isstruct(s) 417 | state.solverState{i} = structfun(@gather, s, 'UniformOutput', false) ; 418 | end 419 | end 420 | end 421 | 422 | net.reset() ; 423 | net.move('cpu') ; 424 | 425 | % ------------------------------------------------------------------------- 426 | function state = accumulateGradients(net, state, params, batchSize, parserv) 427 | % ------------------------------------------------------------------------- 428 | numGpus = numel(params.gpus) ; 429 | otherGpus = setdiff(1:numGpus, labindex) ; 430 | 431 | for p=1:numel(net.params) 432 | 433 | if ~isempty(parserv) 434 | parDer = parserv.pullWithIndex(p) ; 435 | else 436 | parDer = net.params(p).der ; 437 | end 438 | 439 | switch net.params(p).trainMethod 440 | case 'average' % mainly for batch normalization 441 | thisLR = net.params(p).learningRate ; 442 | net.params(p).value = vl_taccum(... 443 | 1 - thisLR, net.params(p).value, ... 444 | (thisLR/batchSize/net.params(p).fanout), parDer) ; 445 | 446 | case 'gradient' 447 | thisDecay = params.weightDecay * net.params(p).weightDecay ; 448 | thisLR = params.learningRate * net.params(p).learningRate ; 449 | 450 | if thisLR>0 || thisDecay>0 451 | % Normalize gradient and incorporate weight decay. 452 | parDer = vl_taccum(1/batchSize, parDer, ... 453 | thisDecay, net.params(p).value) ; 454 | 455 | if isempty(params.solver) 456 | % Default solver is the optimised SGD. 457 | % Update momentum. 458 | state.solverState{p} = vl_taccum(... 459 | params.momentum, state.solverState{p}, ... 460 | -1, parDer) ; 461 | 462 | % Nesterov update (aka one step ahead). 463 | if params.nesterovUpdate 464 | delta = params.momentum * state.solverState{p} - parDer ; 465 | else 466 | delta = state.solverState{p} ; 467 | end 468 | 469 | % Update parameters. 470 | net.params(p).value = vl_taccum(... 471 | 1, net.params(p).value, thisLR, delta) ; 472 | 473 | else 474 | % call solver function to update weights 475 | [net.params(p).value, state.solverState{p}] = ... 476 | params.solver(net.params(p).value, state.solverState{p}, ... 477 | parDer, params.solverOpts, thisLR) ; 478 | end 479 | end 480 | otherwise 481 | error('Unknown training method ''%s'' for parameter ''%s''.', ... 482 | net.params(p).trainMethod, ... 
483 | net.params(p).name) ; 484 | end 485 | end 486 | 487 | % ------------------------------------------------------------------------- 488 | function stats = accumulateStats(stats_) 489 | % ------------------------------------------------------------------------- 490 | for s = {'train', 'val'} 491 | s = char(s) ; 492 | total = 0 ; 493 | 494 | % initialize stats stucture with same fields and same order as 495 | % stats_{1} 496 | stats__ = stats_{1} ; 497 | names = fieldnames(stats__.(s))' ; 498 | values = zeros(1, numel(names)) ; 499 | fields = cat(1, names, num2cell(values)) ; 500 | stats.(s) = struct(fields{:}) ; 501 | 502 | for g = 1:numel(stats_) 503 | stats__ = stats_{g} ; 504 | num__ = stats__.(s).num ; 505 | total = total + num__ ; 506 | 507 | for f = setdiff(fieldnames(stats__.(s))', 'num') 508 | f = char(f) ; 509 | stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ; 510 | 511 | if g == numel(stats_) 512 | stats.(s).(f) = stats.(s).(f) / total ; 513 | end 514 | end 515 | end 516 | stats.(s).num = total ; 517 | end 518 | 519 | % ------------------------------------------------------------------------- 520 | function stats = extractStats(stats, net) 521 | % ------------------------------------------------------------------------- 522 | sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ; 523 | for i = 1:numel(sel) 524 | if net.layers(sel(i)).block.ignoreAverage, continue; end; 525 | stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ; 526 | end 527 | 528 | % ------------------------------------------------------------------------- 529 | function saveState(fileName, net_, state) 530 | % ------------------------------------------------------------------------- 531 | net = net_.saveobj() ; 532 | save(fileName, 'net', 'state') ; 533 | logInfo('saved: %s', fileName); 534 | 535 | % ------------------------------------------------------------------------- 536 | function saveStats(fileName, stats) 537 | % ------------------------------------------------------------------------- 538 | if exist(fileName) 539 | try, save(fileName, 'stats', '-append') ; end 540 | else 541 | try, save(fileName, 'stats') ; end 542 | end 543 | 544 | % ------------------------------------------------------------------------- 545 | function [net, state, stats] = loadState(fileName) 546 | % ------------------------------------------------------------------------- 547 | load(fileName, 'net', 'state', 'stats') ; 548 | net = dagnn.DagNN.loadobj(net) ; 549 | if isempty(whos('stats')) 550 | warning('Epoch ''%s'' was only partially saved. Delete this file and try again.', ... 
551 | fileName) ; 552 | end 553 | 554 | % ------------------------------------------------------------------------- 555 | function epoch = findLastCheckpoint(modelDir) 556 | % ------------------------------------------------------------------------- 557 | list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ; 558 | tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ; 559 | epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ; 560 | epoch = max([epoch 0]) ; 561 | 562 | % ------------------------------------------------------------------------- 563 | function switchFigure(n) 564 | % ------------------------------------------------------------------------- 565 | if get(0,'CurrentFigure') ~= n 566 | try 567 | set(0,'CurrentFigure',n) ; 568 | catch 569 | figure(n) ; 570 | end 571 | end 572 | 573 | % ------------------------------------------------------------------------- 574 | function clearMex() 575 | % ------------------------------------------------------------------------- 576 | clear vl_tmove vl_imreadjpeg ; 577 | 578 | % ------------------------------------------------------------------------- 579 | function prepareGPUs(opts, cold) 580 | % ------------------------------------------------------------------------- 581 | numGpus = numel(opts.gpus) ; 582 | if numGpus > 1 583 | % check parallel pool integrity as it could have timed out 584 | pool = gcp('nocreate') ; 585 | if ~isempty(pool) && pool.NumWorkers ~= numGpus 586 | delete(pool) ; 587 | end 588 | pool = gcp('nocreate') ; 589 | if isempty(pool) 590 | parpool('local', numGpus) ; 591 | cold = true ; 592 | end 593 | 594 | end 595 | if numGpus >= 1 && cold 596 | fprintf('%s: resetting GPU\n', mfilename) 597 | clearMex() ; 598 | if numGpus == 1 599 | gpuDevice(opts.gpus) 600 | else 601 | spmd 602 | clearMex() ; 603 | gpuDevice(opts.gpus(labindex)) 604 | end 605 | end 606 | end 607 | -------------------------------------------------------------------------------- /matlab/train_hard_vid.m: -------------------------------------------------------------------------------- 1 | function [net,stats] = train_hard_vid(net, imdb, getBatch, varargin) 2 | %CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper 3 | % CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with 4 | % the DagNN wrapper instead of the SimpleNN wrapper. 5 | 6 | % Copyright (C) 2014-16 Andrea Vedaldi. 7 | % All rights reserved. 8 | % 9 | % This file is part of the VLFeat library and is made available under 10 | % the terms of the BSD license (see the COPYING file). 11 | addpath(fullfile(vl_rootnn, 'examples')); 12 | 13 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%% 14 | opts.saveInterval = 2; 15 | opts.batchesPerMetaPair = 0; 16 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%% 17 | 18 | opts.expDir = fullfile('data','exp') ; 19 | opts.continue = true ; 20 | opts.batchSize = 256 ; 21 | opts.numSubBatches = 1 ; 22 | opts.train = [] ; 23 | opts.val = [] ; 24 | opts.gpus = [] ; 25 | opts.prefetch = false ; 26 | opts.epochSize = inf; 27 | opts.numEpochs = 300 ; 28 | opts.learningRate = 0.001 ; 29 | opts.weightDecay = 0.0005 ; 30 | 31 | opts.solver = [] ; % Empty array means use the default SGD solver 32 | [opts, varargin] = vl_argparse(opts, varargin) ; 33 | if ~isempty(opts.solver) 34 | assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,... 
35 | 'Invalid solver; expected a function handle with two outputs.') ; 36 | % Call without input arguments, to get default options 37 | opts.solverOpts = opts.solver() ; 38 | end 39 | 40 | opts.momentum = 0.9 ; 41 | opts.saveSolverState = true ; 42 | opts.nesterovUpdate = false ; 43 | opts.randomSeed = 0 ; 44 | opts.profile = false ; 45 | opts.parameterServer.method = 'mmap' ; 46 | opts.parameterServer.prefix = 'mcn' ; 47 | 48 | opts.derOutputs = {'objective', 1} ; 49 | opts.extractStatsFn = @extractStats ; 50 | opts.plotStatistics = true; 51 | opts.postEpochFn = [] ; % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change 52 | opts = vl_argparse(opts, varargin) ; 53 | 54 | if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end 55 | if isempty(opts.train), opts.train = find(imdb.images.set(:,1)==1) ; end 56 | if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end 57 | if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train) 58 | opts.train = [] ; 59 | end 60 | if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val) 61 | opts.val = [] ; 62 | end 63 | 64 | % ------------------------------------------------------------------------- 65 | % Initialization 66 | % ------------------------------------------------------------------------- 67 | 68 | evaluateMode = isempty(opts.train) ; 69 | if ~evaluateMode 70 | if isempty(opts.derOutputs) 71 | error('DEROUTPUTS must be specified when training.\n') ; 72 | end 73 | end 74 | 75 | % ------------------------------------------------------------------------- 76 | % Train and validate 77 | % ------------------------------------------------------------------------- 78 | 79 | modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep)); 80 | modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ; 81 | 82 | start = opts.continue * findLastCheckpoint(opts.expDir) ; 83 | if start >= 1 84 | fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ; 85 | % [KH] make sure to use the input opts, saved maybe outdated 86 | opt = net.layers(end).block.opt; 87 | [net, state, stats] = loadState(modelPath(start)) ; 88 | net.layers(end).block.opt = opt; 89 | else 90 | state = [] ; 91 | end 92 | 93 | for epoch=start+1:opts.numEpochs 94 | 95 | % Set the random seed based on the epoch and opts.randomSeed. 96 | % This is important for reproducibility, including when training 97 | % is restarted from a checkpoint. 98 | 99 | rng(epoch + opts.randomSeed) ; 100 | prepareGPUs(opts, epoch == start+1) ; 101 | 102 | % Train for one epoch. 
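% The learning rate may be a scalar or a per-epoch vector; a hypothetical
% schedule (values illustrative only):
%   opts.learningRate = [1e-4*ones(1,10), 1e-5*ones(1,10)] ;
% Past the end of the vector, min(epoch, numel(opts.learningRate)) below keeps
% reusing the last entry.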
103 | params = opts ; 104 | params.epoch = epoch ; 105 | params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ; 106 | params.train = opts.train; %(randperm(numel(opts.train))) ; % shuffle 107 | params.train = params.train; %(1:min(opts.epochSize, numel(opts.train))); 108 | params.val = opts.val; %(randperm(numel(opts.val))) ; 109 | params.imdb = imdb ; 110 | params.getBatch = getBatch ; 111 | 112 | if numel(opts.gpus) <= 1 113 | [net, state] = processEpoch(net, state, params, 'train') ; 114 | [net, state] = processEpoch(net, state, params, 'val') ; 115 | if ~evaluateMode && ~mod(epoch, opts.saveInterval) 116 | saveState(modelPath(epoch), net, state) ; 117 | end 118 | lastStats = state.stats ; 119 | else 120 | spmd 121 | [net, state] = processEpoch(net, state, params, 'train') ; 122 | [net, state] = processEpoch(net, state, params, 'val') ; 123 | if labindex == 1 && ~evaluateMode && ~mod(epoch, opts.saveInterval) 124 | saveState(modelPath(epoch), net, state) ; 125 | end 126 | lastStats = state.stats ; 127 | end 128 | lastStats = accumulateStats(lastStats) ; 129 | end 130 | 131 | stats.train(epoch) = lastStats.train ; 132 | stats.val(epoch) = lastStats.val ; 133 | clear lastStats ; 134 | if ~mod(epoch, opts.saveInterval) 135 | saveStats(modelPath(epoch), stats) ; 136 | end 137 | 138 | if opts.plotStatistics 139 | switchFigure(1) ; clf ; 140 | plots = setdiff(... 141 | cat(2,... 142 | fieldnames(stats.train)', ... 143 | fieldnames(stats.val)'), {'num', 'time'}) ; 144 | for p = plots 145 | p = char(p) ; 146 | values = zeros(0, epoch) ; 147 | leg = {} ; 148 | for f = {'train', 'val'} 149 | f = char(f) ; 150 | if isfield(stats.(f), p) 151 | tmp = [stats.(f).(p)] ; 152 | values(end+1,:) = tmp(1,:)' ; 153 | leg{end+1} = f ; 154 | end 155 | end 156 | subplot(1,numel(plots),find(strcmp(p,plots))) ; 157 | plot(1:epoch, values','o-') ; 158 | xlabel('epoch') ; 159 | title(p) ; 160 | legend(leg{:}) ; 161 | grid on ; 162 | end 163 | drawnow ; 164 | print(1, modelFigPath, '-dpdf') ; 165 | end 166 | 167 | if ~isempty(opts.postEpochFn) 168 | if nargout(opts.postEpochFn) == 0 169 | opts.postEpochFn(net, params, state) ; 170 | else 171 | [lr, params] = opts.postEpochFn(net, params, state) ; 172 | if ~isempty(lr), opts.learningRate = lr; end 173 | if opts.learningRate == 0, break; end 174 | opts.train = params.train; 175 | end 176 | end 177 | end 178 | 179 | % With multiple GPUs, return one copy 180 | if isa(net, 'Composite'), net = net{1} ; end 181 | 182 | % ------------------------------------------------------------------------- 183 | function [net, state] = processEpoch(net, state, params, mode) 184 | % ------------------------------------------------------------------------- 185 | % Note that net is not strictly needed as an output argument as net 186 | % is a handle class. However, this fixes some aliasing issue in the 187 | % spmd caller. 
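% Batch layout sketch for this function (assuming batchSize = 8): half of each
% batch comes from metaC_idx{1} (metaclass id 0, i.e. images with no metaclass
% label) and half from one labeled metaclass m > 1:
%   batch = [metaC_idx{1}(I1); metaC_idx{m}(I2)] ;   % 4 + 4 images
% The cursors cur(1) and cur(m) wrap around circularly via mod(I-1, num) + 1.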
188 | 189 | % initialize with momentum 0 190 | if isempty(state) || isempty(state.solverState) 191 | state.solverState = cell(1, numel(net.params)) ; 192 | state.solverState(:) = {0} ; 193 | end 194 | 195 | % move CNN to GPU as needed 196 | numGpus = numel(params.gpus) ; 197 | if numGpus >= 1 198 | net.move('gpu') ; 199 | for i = 1:numel(state.solverState) 200 | s = state.solverState{i} ; 201 | if isnumeric(s) 202 | state.solverState{i} = gpuArray(s) ; 203 | elseif isstruct(s) 204 | state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ; 205 | end 206 | end 207 | end 208 | if numGpus > 1 209 | parserv = ParameterServer(params.parameterServer) ; 210 | net.setParameterServer(parserv) ; 211 | else 212 | parserv = [] ; 213 | end 214 | 215 | % profile 216 | if params.profile 217 | if numGpus <= 1 218 | profile clear ; 219 | profile on ; 220 | else 221 | mpiprofile reset ; 222 | mpiprofile on ; 223 | end 224 | end 225 | 226 | num = 0 ; 227 | epoch = params.epoch ; 228 | subset = params.(mode) ; 229 | adjustTime = 0 ; 230 | 231 | stats.num = 0 ; % return something even if subset = [] 232 | stats.time = 0 ; 233 | 234 | % ----------------------------------------- 235 | % ----------------------------------------- 236 | if numel(subset) > 0 237 | logInfo('Epoch %d/%d: %s LR=%d', epoch, params.numEpochs, ... 238 | func2str(params.solver), params.learningRate); 239 | 240 | % [KH] preprocess: group by meta-class 241 | metaC = params.imdb.metaclass; 242 | assert(metaC{1}.id == 0); % NOTE id==0 means missing metaclass label 243 | Nmeta = numel(metaC); 244 | metaC_idx = cell(1, Nmeta); 245 | 246 | batchSize = params.batchSize; 247 | assert(mod(batchSize, 2) == 0); 248 | numBatches = zeros(Nmeta, 1); 249 | for i = 1:Nmeta 250 | % randomize the groups (individual classes) within this metaclass 251 | G = metaC{i}.groups; 252 | G = G(randperm(numel(G))); 253 | metaC_idx{i} = cat(1, G{:}); 254 | if i > 1 255 | numBatches(i) = ceil(2*metaC{i}.num/batchSize); 256 | end 257 | end 258 | logInfo('%d meta-classes, %d batches', Nmeta-1, sum(numBatches)); 259 | 260 | netopts = net.layers(end).block.opt; 261 | if strcmp(mode, 'train') 262 | if params.batchSize > netopts.maxGpuImgs 263 | mode = 'train_staged'; 264 | else 265 | net.mode = 'normal' ; 266 | net.conserveMemory = true; 267 | end 268 | else 269 | net.mode = 'test' ; 270 | end 271 | 272 | start = tic; a = tic; 273 | objs = []; 274 | cur = zeros(1, Nmeta); 275 | t = 0; 276 | for m = 2:Nmeta, for p = 1:numBatches(m) 277 | % get half from metaclass #m, half from metaclass #1 (id:0) 278 | I1 = cur(1) + (1:batchSize/2); 279 | I2 = cur(m) + (1:batchSize/2); 280 | I1 = mod(I1-1, metaC{1}.num) + 1; 281 | I2 = mod(I2-1, metaC{m}.num) + 1; 282 | cur(1) = I1(end); 283 | cur(m) = I2(end); 284 | batch = [metaC_idx{1}(I1); metaC_idx{m}(I2)]; 285 | t = t + batchSize; 286 | 287 | % assemble batch 288 | inputs = params.getBatch(params.imdb, batch) ; 289 | if strcmp(mode, 'train') 290 | net.eval(inputs, params.derOutputs) ; 291 | 292 | if ~isempty(parserv), parserv.sync() ; end 293 | state = accumulateGradients(net, state, params, batchSize, parserv) ; 294 | 295 | elseif strcmp(mode, 'train_staged') 296 | % staged backprop mode 297 | % stage 1: get embedding matrix, in chunks 298 | net.mode = 'test'; 299 | ind = net.getVarIndex('feats_l2'); 300 | net.vars(ind).precious = 1; 301 | 302 | labelMat = []; % label matrix 303 | embeddingMat = zeros(1, 1, netopts.dim, batchSize); % embedding matrix 304 | if numel(params.gpus) > 0 305 | embeddingMat = 
gpuArray(embeddingMat); 306 | end 307 | 308 | subBatchSize = netopts.maxGpuImgs; 309 | nSub = ceil(batchSize / subBatchSize); 310 | inputsCache = cell(1, nSub); 311 | for s = 1:nSub 312 | sub = (s-1)*subBatchSize+1 : min(batchSize, s*subBatchSize); 313 | subInputs = params.getBatch(params.imdb, batch(sub)); 314 | inputsCache{s} = subInputs; 315 | 316 | net.eval(subInputs); 317 | embeddingMat(:, :, :, sub) = gather(net.vars(ind).value); 318 | labelMat = [labelMat; subInputs{4}]; 319 | end 320 | 321 | % stage 2: compute gradient matrix 322 | [gradientMat, obj_staged] = net.layers(end).block.computeGrad(... 323 | embeddingMat, labelMat); 324 | 325 | % stage 3: accumulate gradients 326 | net.mode = 'normal'; 327 | for s = 1:nSub 328 | sub = (s-1)*subBatchSize+1 : min(batchSize, s*subBatchSize); 329 | featsGrad = gradientMat(:, :, :, sub); 330 | net.eval(inputsCache{s}, {'feats_l2', featsGrad}); 331 | 332 | if ~isempty(parserv), parserv.sync() ; end 333 | state = accumulateGradients(net, state, params, numel(sub), parserv); 334 | end 335 | else 336 | % eval mode 337 | net.eval(inputs) ; 338 | end 339 | 340 | % Get statistics. 341 | time = toc(start) + adjustTime ; 342 | batchTime = time - stats.time ; 343 | stats.num = num ; 344 | stats.time = time ; 345 | stats = params.extractStatsFn(stats,net) ; 346 | averageSpeed = t / time ; 347 | objs = [objs, stats.objective]; 348 | if toc(a) > 30 349 | fprintf('%s-%s ep%02d: %3d/%3d (%d) %.1fHz', ... 350 | net.layers(end).block.opt.dataset, mode, epoch, ... 351 | fix(t/batchSize), sum(numBatches), batchSize, averageSpeed) ; 352 | fprintf(' obj: %.3f (%.3f)\n', stats.objective, mean(objs(~isnan(objs)))); 353 | a = tic; 354 | end 355 | end, end % for m, for p 356 | if strcmp(mode, 'train_staged') 357 | mode = 'train'; 358 | end 359 | logInfo('Epoch %d, Avg %s obj = %g', epoch, mode, mean(objs(~isnan(objs)))); 360 | end % if N>0 361 | 362 | % Save back to state. 363 | state.stats.(mode) = stats ; 364 | if params.profile 365 | if numGpus <= 1 366 | state.prof.(mode) = profile('info') ; 367 | profile off ; 368 | else 369 | state.prof.(mode) = mpiprofile('info'); 370 | mpiprofile off ; 371 | end 372 | end 373 | if ~params.saveSolverState 374 | state.solverState = [] ; 375 | else 376 | for i = 1:numel(state.solverState) 377 | s = state.solverState{i} ; 378 | if isnumeric(s) 379 | state.solverState{i} = gather(s) ; 380 | elseif isstruct(s) 381 | state.solverState{i} = structfun(@gather, s, 'UniformOutput', false) ; 382 | end 383 | end 384 | end 385 | 386 | net.reset() ; 387 | net.move('cpu') ; 388 | 389 | % ------------------------------------------------------------------------- 390 | function state = accumulateGradients(net, state, params, batchSize, parserv) 391 | % ------------------------------------------------------------------------- 392 | numGpus = numel(params.gpus) ; 393 | otherGpus = setdiff(1:numGpus, labindex) ; 394 | 395 | for p=1:numel(net.params) 396 | 397 | if ~isempty(parserv) 398 | parDer = parserv.pullWithIndex(p) ; 399 | else 400 | parDer = net.params(p).der ; 401 | end 402 | 403 | switch net.params(p).trainMethod 404 | case 'average' % mainly for batch normalization 405 | thisLR = net.params(p).learningRate ; 406 | net.params(p).value = vl_taccum(... 407 | 1 - thisLR, net.params(p).value, ... 
408 | (thisLR/batchSize/net.params(p).fanout), parDer) ; 409 | 410 | case 'gradient' 411 | thisDecay = params.weightDecay * net.params(p).weightDecay ; 412 | thisLR = params.learningRate * net.params(p).learningRate ; 413 | 414 | if thisLR>0 || thisDecay>0 415 | % Normalize gradient and incorporate weight decay. 416 | try 417 | parDer = vl_taccum(1/batchSize, parDer, ... 418 | thisDecay, net.params(p).value) ; 419 | catch, net.params(p), keyboard, end 420 | 421 | if isempty(params.solver) 422 | % Default solver is the optimised SGD. 423 | % Update momentum. 424 | state.solverState{p} = vl_taccum(... 425 | params.momentum, state.solverState{p}, ... 426 | -1, parDer) ; 427 | 428 | % Nesterov update (aka one step ahead). 429 | if params.nesterovUpdate 430 | delta = params.momentum * state.solverState{p} - parDer ; 431 | else 432 | delta = state.solverState{p} ; 433 | end 434 | 435 | % Update parameters. 436 | net.params(p).value = vl_taccum(... 437 | 1, net.params(p).value, thisLR, delta) ; 438 | 439 | else 440 | % call solver function to update weights 441 | [net.params(p).value, state.solverState{p}] = ... 442 | params.solver(net.params(p).value, state.solverState{p}, ... 443 | parDer, params.solverOpts, thisLR) ; 444 | end 445 | end 446 | otherwise 447 | error('Unknown training method ''%s'' for parameter ''%s''.', ... 448 | net.params(p).trainMethod, ... 449 | net.params(p).name) ; 450 | end 451 | end 452 | 453 | % ------------------------------------------------------------------------- 454 | function stats = accumulateStats(stats_) 455 | % ------------------------------------------------------------------------- 456 | for s = {'train', 'val'} 457 | s = char(s) ; 458 | total = 0 ; 459 | 460 | % initialize stats stucture with same fields and same order as 461 | % stats_{1} 462 | stats__ = stats_{1} ; 463 | names = fieldnames(stats__.(s))' ; 464 | values = zeros(1, numel(names)) ; 465 | fields = cat(1, names, num2cell(values)) ; 466 | stats.(s) = struct(fields{:}) ; 467 | 468 | for g = 1:numel(stats_) 469 | stats__ = stats_{g} ; 470 | num__ = stats__.(s).num ; 471 | total = total + num__ ; 472 | 473 | for f = setdiff(fieldnames(stats__.(s))', 'num') 474 | f = char(f) ; 475 | stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ; 476 | 477 | if g == numel(stats_) 478 | stats.(s).(f) = stats.(s).(f) / total ; 479 | end 480 | end 481 | end 482 | stats.(s).num = total ; 483 | end 484 | 485 | % ------------------------------------------------------------------------- 486 | function stats = extractStats(stats, net) 487 | % ------------------------------------------------------------------------- 488 | sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ; 489 | for i = 1:numel(sel) 490 | if net.layers(sel(i)).block.ignoreAverage, continue; end; 491 | stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ; 492 | end 493 | 494 | % ------------------------------------------------------------------------- 495 | function saveState(fileName, net_, state) 496 | % ------------------------------------------------------------------------- 497 | net = net_.saveobj() ; 498 | save(fileName, 'net', 'state') ; 499 | logInfo('saved: %s', fileName); 500 | 501 | % ------------------------------------------------------------------------- 502 | function saveStats(fileName, stats) 503 | % ------------------------------------------------------------------------- 504 | if exist(fileName) 505 | try, save(fileName, 'stats', '-append') ; end 506 | else 507 | try, save(fileName, 'stats') 
; end 508 | end 509 | 510 | % ------------------------------------------------------------------------- 511 | function [net, state, stats] = loadState(fileName) 512 | % ------------------------------------------------------------------------- 513 | load(fileName, 'net', 'state', 'stats') ; 514 | net = dagnn.DagNN.loadobj(net) ; 515 | if isempty(whos('stats')) 516 | warning('Epoch ''%s'' was only partially saved. Delete this file and try again.', ... 517 | fileName) ; 518 | end 519 | 520 | % ------------------------------------------------------------------------- 521 | function epoch = findLastCheckpoint(modelDir) 522 | % ------------------------------------------------------------------------- 523 | list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ; 524 | tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ; 525 | epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ; 526 | epoch = max([epoch 0]) ; 527 | 528 | % ------------------------------------------------------------------------- 529 | function switchFigure(n) 530 | % ------------------------------------------------------------------------- 531 | if get(0,'CurrentFigure') ~= n 532 | try 533 | set(0,'CurrentFigure',n) ; 534 | catch 535 | figure(n) ; 536 | end 537 | end 538 | 539 | % ------------------------------------------------------------------------- 540 | function clearMex() 541 | % ------------------------------------------------------------------------- 542 | clear vl_tmove vl_imreadjpeg ; 543 | 544 | % ------------------------------------------------------------------------- 545 | function prepareGPUs(opts, cold) 546 | % ------------------------------------------------------------------------- 547 | numGpus = numel(opts.gpus) ; 548 | if numGpus > 1 549 | % check parallel pool integrity as it could have timed out 550 | pool = gcp('nocreate') ; 551 | if ~isempty(pool) && pool.NumWorkers ~= numGpus 552 | delete(pool) ; 553 | end 554 | pool = gcp('nocreate') ; 555 | if isempty(pool) 556 | parpool('local', numGpus) ; 557 | cold = true ; 558 | end 559 | 560 | end 561 | if numGpus >= 1 && cold 562 | fprintf('%s: resetting GPU\n', mfilename) 563 | clearMex() ; 564 | if numGpus == 1 565 | gpuDevice(opts.gpus) 566 | else 567 | spmd 568 | clearMex() ; 569 | gpuDevice(opts.gpus(labindex)) 570 | end 571 | end 572 | end 573 | -------------------------------------------------------------------------------- /matlab/train_rand.m: -------------------------------------------------------------------------------- 1 | function [net,stats] = train_rand(net, imdb, getBatch, varargin) 2 | %CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper 3 | % CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with 4 | % the DagNN wrapper instead of the SimpleNN wrapper. 5 | 6 | % Copyright (C) 2014-16 Andrea Vedaldi. 7 | % All rights reserved. 8 | % 9 | % This file is part of the VLFeat library and is made available under 10 | % the terms of the BSD license (see the COPYING file). 
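% Example invocation (hypothetical option values; any solver with the
% [w, state] = solver(w, state, grad, solverOpts, lr) signature works, e.g.
% the bundled +solver functions):
%   [net, stats] = train_rand(net, imdb, getBatch, ...
%       'expDir', 'data/exp', 'batchSize', 256, 'gpus', 1, ...
%       'solver', @solver.adam, 'learningRate', 1e-5) ;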
11 | addpath(fullfile(vl_rootnn, 'examples')); 12 | 13 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%% 14 | opts.saveInterval = 2; 15 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%% 16 | 17 | opts.expDir = fullfile('data','exp') ; 18 | opts.continue = true ; 19 | opts.batchSize = 256 ; 20 | opts.numSubBatches = 1 ; 21 | opts.train = [] ; 22 | opts.val = [] ; 23 | opts.gpus = [] ; 24 | opts.prefetch = false ; 25 | opts.epochSize = inf; 26 | opts.numEpochs = 300 ; 27 | opts.learningRate = 0.001 ; 28 | opts.weightDecay = 0.0005 ; 29 | 30 | opts.solver = [] ; % Empty array means use the default SGD solver 31 | [opts, varargin] = vl_argparse(opts, varargin) ; 32 | if ~isempty(opts.solver) 33 | assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,... 34 | 'Invalid solver; expected a function handle with two outputs.') ; 35 | % Call without input arguments, to get default options 36 | opts.solverOpts = opts.solver() ; 37 | end 38 | 39 | opts.momentum = 0.9 ; 40 | opts.saveSolverState = true ; 41 | opts.nesterovUpdate = false ; 42 | opts.randomSeed = 0 ; 43 | opts.profile = false ; 44 | opts.parameterServer.method = 'mmap' ; 45 | opts.parameterServer.prefix = 'mcn' ; 46 | 47 | opts.derOutputs = {'objective', 1} ; 48 | opts.extractStatsFn = @extractStats ; 49 | opts.plotStatistics = true; 50 | opts.postEpochFn = [] ; % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change 51 | opts = vl_argparse(opts, varargin) ; 52 | 53 | if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end 54 | if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end 55 | if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end 56 | if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train) 57 | opts.train = [] ; 58 | end 59 | if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val) 60 | opts.val = [] ; 61 | end 62 | 63 | % ------------------------------------------------------------------------- 64 | % Initialization 65 | % ------------------------------------------------------------------------- 66 | 67 | evaluateMode = isempty(opts.train) ; 68 | if ~evaluateMode 69 | if isempty(opts.derOutputs) 70 | error('DEROUTPUTS must be specified when training.\n') ; 71 | end 72 | end 73 | 74 | % ------------------------------------------------------------------------- 75 | % Train and validate 76 | % ------------------------------------------------------------------------- 77 | 78 | modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep)); 79 | modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ; 80 | 81 | start = opts.continue * findLastCheckpoint(opts.expDir) ; 82 | if start >= 1 83 | fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ; 84 | % [KH] make sure to use the input opts, saved maybe outdated 85 | opt = net.layers(end).block.opt; 86 | [net, state, stats] = loadState(modelPath(start)) ; 87 | net.layers(end).block.opt = opt; 88 | else 89 | state = [] ; 90 | end 91 | 92 | for epoch=start+1:opts.numEpochs 93 | 94 | % Set the random seed based on the epoch and opts.randomSeed. 95 | % This is important for reproducibility, including when training 96 | % is restarted from a checkpoint. 97 | 98 | rng(epoch + opts.randomSeed) ; 99 | prepareGPUs(opts, epoch == start+1) ; 100 | 101 | % Train for one epoch. 
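% The shuffle and epochSize subsampling are disabled below (kept as trailing
% comments); re-enabling them would look like this sketch:
%   params.train = opts.train(randperm(numel(opts.train))) ;
%   params.train = params.train(1:min(opts.epochSize, numel(params.train))) ;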
102 | params = opts ; 103 | params.epoch = epoch ; 104 | params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ; 105 | params.train = opts.train; %(randperm(numel(opts.train))) ; % shuffle 106 | params.train = params.train; %(1:min(opts.epochSize, numel(opts.train))); 107 | params.val = opts.val; %(randperm(numel(opts.val))) ; 108 | params.imdb = imdb ; 109 | params.getBatch = getBatch ; 110 | 111 | if numel(opts.gpus) <= 1 112 | [net, state] = processEpoch(net, state, params, 'train') ; 113 | [net, state] = processEpoch(net, state, params, 'val') ; 114 | if ~evaluateMode && ~mod(epoch, opts.saveInterval) 115 | saveState(modelPath(epoch), net, state) ; 116 | end 117 | lastStats = state.stats ; 118 | else 119 | spmd 120 | [net, state] = processEpoch(net, state, params, 'train') ; 121 | [net, state] = processEpoch(net, state, params, 'val') ; 122 | if labindex == 1 && ~evaluateMode && ~mod(epoch, opts.saveInterval) 123 | saveState(modelPath(epoch), net, state) ; 124 | end 125 | lastStats = state.stats ; 126 | end 127 | lastStats = accumulateStats(lastStats) ; 128 | end 129 | 130 | stats.train(epoch) = lastStats.train ; 131 | stats.val(epoch) = lastStats.val ; 132 | clear lastStats ; 133 | saveStats(modelPath(epoch), stats) ; 134 | 135 | if opts.plotStatistics 136 | switchFigure(1) ; clf ; 137 | plots = setdiff(... 138 | cat(2,... 139 | fieldnames(stats.train)', ... 140 | fieldnames(stats.val)'), {'num', 'time'}) ; 141 | for p = plots 142 | p = char(p) ; 143 | values = zeros(0, epoch) ; 144 | leg = {} ; 145 | for f = {'train', 'val'} 146 | f = char(f) ; 147 | if isfield(stats.(f), p) 148 | tmp = [stats.(f).(p)] ; 149 | values(end+1,:) = tmp(1,:)' ; 150 | leg{end+1} = f ; 151 | end 152 | end 153 | subplot(1,numel(plots),find(strcmp(p,plots))) ; 154 | plot(1:epoch, values','o-') ; 155 | xlabel('epoch') ; 156 | title(p) ; 157 | legend(leg{:}) ; 158 | grid on ; 159 | end 160 | drawnow ; 161 | print(1, modelFigPath, '-dpdf') ; 162 | end 163 | 164 | if ~isempty(opts.postEpochFn) 165 | if nargout(opts.postEpochFn) == 0 166 | opts.postEpochFn(net, params, state) ; 167 | else 168 | [lr, params] = opts.postEpochFn(net, params, state) ; 169 | if ~isempty(lr), opts.learningRate = lr; end 170 | if opts.learningRate == 0, break; end 171 | opts.train = params.train; 172 | end 173 | end 174 | end 175 | 176 | % With multiple GPUs, return one copy 177 | if isa(net, 'Composite'), net = net{1} ; end 178 | 179 | % ------------------------------------------------------------------------- 180 | function [net, state] = processEpoch(net, state, params, mode) 181 | % ------------------------------------------------------------------------- 182 | % Note that net is not strictly needed as an output argument as net 183 | % is a handle class. However, this fixes some aliasing issue in the 184 | % spmd caller. 
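% Sub-batch indexing sketch for the loop below (assuming one MATLAB lab,
% numSubBatches = 2, batchSize = 4, i.e. sub-batches interleave with stride 2):
%   t = 1: sub-batch 1 -> subset([1 3]), sub-batch 2 -> subset([2 4])
% Parameter derivatives are accumulated across sub-batches via
% net.accumulateParamDers before the single accumulateGradients call.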
185 | 186 | % initialize with momentum 0 187 | if isempty(state) || isempty(state.solverState) 188 | state.solverState = cell(1, numel(net.params)) ; 189 | state.solverState(:) = {0} ; 190 | end 191 | 192 | % move CNN to GPU as needed 193 | numGpus = numel(params.gpus) ; 194 | if numGpus >= 1 195 | net.move('gpu') ; 196 | for i = 1:numel(state.solverState) 197 | s = state.solverState{i} ; 198 | if isnumeric(s) 199 | state.solverState{i} = gpuArray(s) ; 200 | elseif isstruct(s) 201 | state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ; 202 | end 203 | end 204 | end 205 | if numGpus > 1 206 | parserv = ParameterServer(params.parameterServer) ; 207 | net.setParameterServer(parserv) ; 208 | else 209 | parserv = [] ; 210 | end 211 | 212 | % profile 213 | if params.profile 214 | if numGpus <= 1 215 | profile clear ; 216 | profile on ; 217 | else 218 | mpiprofile reset ; 219 | mpiprofile on ; 220 | end 221 | end 222 | 223 | num = 0 ; 224 | epoch = params.epoch ; 225 | subset = params.(mode) ; 226 | adjustTime = 0 ; 227 | 228 | stats.num = 0 ; % return something even if subset = [] 229 | stats.time = 0 ; 230 | 231 | if numel(subset) > 0 232 | logInfo('Epoch %d/%d: %s LR=%d', epoch, params.numEpochs, ... 233 | func2str(params.solver), params.learningRate); 234 | start = tic; a = tic; 235 | objs = []; 236 | for t=1:params.batchSize:numel(subset) 237 | batchSize = min(params.batchSize, numel(subset) - t + 1) ; 238 | 239 | for s=1:params.numSubBatches 240 | % get this image batch and prefetch the next 241 | batchStart = t + (labindex-1) + (s-1) * numlabs ; 242 | batchEnd = min(t+params.batchSize-1, numel(subset)) ; 243 | batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ; 244 | num = num + numel(batch) ; 245 | if numel(batch) == 0, continue ; end 246 | 247 | inputs = params.getBatch(params.imdb, batch) ; 248 | 249 | if params.prefetch 250 | if s == params.numSubBatches 251 | batchStart = t + (labindex-1) + params.batchSize ; 252 | batchEnd = min(t+2*params.batchSize-1, numel(subset)) ; 253 | else 254 | batchStart = batchStart + numlabs ; 255 | end 256 | nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ; 257 | params.getBatch(params.imdb, nextBatch) ; 258 | end 259 | 260 | if strcmp(mode, 'train') 261 | net.mode = 'normal' ; 262 | net.conserveMemory = true; 263 | net.accumulateParamDers = (s ~= 1) ; 264 | net.eval(inputs, params.derOutputs, 'holdOn', s < params.numSubBatches) ; 265 | else 266 | net.mode = 'test' ; 267 | net.eval(inputs) ; 268 | end 269 | end 270 | 271 | % Accumulate gradient. 272 | if strcmp(mode, 'train') 273 | if ~isempty(parserv), parserv.sync() ; end 274 | state = accumulateGradients(net, state, params, batchSize, parserv) ; 275 | end 276 | 277 | % Get statistics. 278 | time = toc(start) + adjustTime ; 279 | batchTime = time - stats.time ; 280 | stats.num = num ; 281 | stats.time = time ; 282 | stats = params.extractStatsFn(stats,net) ; 283 | currentSpeed = batchSize / batchTime ; 284 | averageSpeed = (t + batchSize - 1) / time ; 285 | if t == 3*params.batchSize + 1 286 | % compensate for the first three iterations, which are outliers 287 | adjustTime = 4*batchTime - time ; 288 | stats.time = time + adjustTime ; 289 | end 290 | 291 | if toc(a) > 30 292 | fprintf('%s-%s ep%02d: %3d/%3d (%d):', ... 293 | net.layers(end).block.opt.dataset, mode, epoch, ... 294 | fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize), ... 
295 | params.batchSize) ; 296 | %fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ; 297 | fprintf(' %.1fHz', averageSpeed) ; 298 | for f = setdiff(fieldnames(stats)', {'num', 'time'}) 299 | f = char(f) ; 300 | fprintf(' %s: %.3f', f, stats.(f)) ; 301 | end 302 | fprintf('\n') ; 303 | a = tic; 304 | end 305 | objs = [objs, stats.objective]; 306 | end 307 | logInfo('Epoch %d, Avg %s obj = %g', epoch, mode, mean(objs(~isnan(objs)))); 308 | end % if N>0 309 | 310 | % Save back to state. 311 | state.stats.(mode) = stats ; 312 | if params.profile 313 | if numGpus <= 1 314 | state.prof.(mode) = profile('info') ; 315 | profile off ; 316 | else 317 | state.prof.(mode) = mpiprofile('info'); 318 | mpiprofile off ; 319 | end 320 | end 321 | if ~params.saveSolverState 322 | state.solverState = [] ; 323 | else 324 | for i = 1:numel(state.solverState) 325 | s = state.solverState{i} ; 326 | if isnumeric(s) 327 | state.solverState{i} = gather(s) ; 328 | elseif isstruct(s) 329 | state.solverState{i} = structfun(@gather, s, 'UniformOutput', false) ; 330 | end 331 | end 332 | end 333 | 334 | net.reset() ; 335 | net.move('cpu') ; 336 | 337 | % ------------------------------------------------------------------------- 338 | function state = accumulateGradients(net, state, params, batchSize, parserv) 339 | % ------------------------------------------------------------------------- 340 | numGpus = numel(params.gpus) ; 341 | otherGpus = setdiff(1:numGpus, labindex) ; 342 | 343 | for p=1:numel(net.params) 344 | 345 | if ~isempty(parserv) 346 | parDer = parserv.pullWithIndex(p) ; 347 | else 348 | parDer = net.params(p).der ; 349 | end 350 | 351 | switch net.params(p).trainMethod 352 | case 'average' % mainly for batch normalization 353 | thisLR = net.params(p).learningRate ; 354 | net.params(p).value = vl_taccum(... 355 | 1 - thisLR, net.params(p).value, ... 356 | (thisLR/batchSize/net.params(p).fanout), parDer) ; 357 | 358 | case 'gradient' 359 | thisDecay = params.weightDecay * net.params(p).weightDecay ; 360 | thisLR = params.learningRate * net.params(p).learningRate ; 361 | 362 | if thisLR>0 || thisDecay>0 363 | % Normalize gradient and incorporate weight decay. 364 | try 365 | parDer = vl_taccum(1/batchSize, parDer, ... 366 | thisDecay, net.params(p).value) ; 367 | catch, net.params(p), keyboard, end 368 | 369 | if isempty(params.solver) 370 | % Default solver is the optimised SGD. 371 | % Update momentum. 372 | state.solverState{p} = vl_taccum(... 373 | params.momentum, state.solverState{p}, ... 374 | -1, parDer) ; 375 | 376 | % Nesterov update (aka one step ahead). 377 | if params.nesterovUpdate 378 | delta = params.momentum * state.solverState{p} - parDer ; 379 | else 380 | delta = state.solverState{p} ; 381 | end 382 | 383 | % Update parameters. 384 | net.params(p).value = vl_taccum(... 385 | 1, net.params(p).value, thisLR, delta) ; 386 | 387 | else 388 | % call solver function to update weights 389 | [net.params(p).value, state.solverState{p}] = ... 390 | params.solver(net.params(p).value, state.solverState{p}, ... 391 | parDer, params.solverOpts, thisLR) ; 392 | end 393 | end 394 | otherwise 395 | error('Unknown training method ''%s'' for parameter ''%s''.', ... 396 | net.params(p).trainMethod, ... 
397 | net.params(p).name) ; 398 | end 399 | end 400 | 401 | % ------------------------------------------------------------------------- 402 | function stats = accumulateStats(stats_) 403 | % ------------------------------------------------------------------------- 404 | for s = {'train', 'val'} 405 | s = char(s) ; 406 | total = 0 ; 407 | 408 | % initialize stats stucture with same fields and same order as 409 | % stats_{1} 410 | stats__ = stats_{1} ; 411 | names = fieldnames(stats__.(s))' ; 412 | values = zeros(1, numel(names)) ; 413 | fields = cat(1, names, num2cell(values)) ; 414 | stats.(s) = struct(fields{:}) ; 415 | 416 | for g = 1:numel(stats_) 417 | stats__ = stats_{g} ; 418 | num__ = stats__.(s).num ; 419 | total = total + num__ ; 420 | 421 | for f = setdiff(fieldnames(stats__.(s))', 'num') 422 | f = char(f) ; 423 | stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ; 424 | 425 | if g == numel(stats_) 426 | stats.(s).(f) = stats.(s).(f) / total ; 427 | end 428 | end 429 | end 430 | stats.(s).num = total ; 431 | end 432 | 433 | % ------------------------------------------------------------------------- 434 | function stats = extractStats(stats, net) 435 | % ------------------------------------------------------------------------- 436 | sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ; 437 | for i = 1:numel(sel) 438 | if net.layers(sel(i)).block.ignoreAverage, continue; end; 439 | stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ; 440 | end 441 | 442 | % ------------------------------------------------------------------------- 443 | function saveState(fileName, net_, state) 444 | % ------------------------------------------------------------------------- 445 | net = net_.saveobj() ; 446 | save(fileName, 'net', 'state') ; 447 | logInfo('saved: %s', fileName); 448 | 449 | % ------------------------------------------------------------------------- 450 | function saveStats(fileName, stats) 451 | % ------------------------------------------------------------------------- 452 | if exist(fileName) 453 | try, save(fileName, 'stats', '-append') ; end 454 | else 455 | try, save(fileName, 'stats') ; end 456 | end 457 | 458 | % ------------------------------------------------------------------------- 459 | function [net, state, stats] = loadState(fileName) 460 | % ------------------------------------------------------------------------- 461 | load(fileName, 'net', 'state', 'stats') ; 462 | net = dagnn.DagNN.loadobj(net) ; 463 | if isempty(whos('stats')) 464 | warning('Epoch ''%s'' was only partially saved. Delete this file and try again.', ... 
465 | fileName) ; 466 | end 467 | 468 | % ------------------------------------------------------------------------- 469 | function epoch = findLastCheckpoint(modelDir) 470 | % ------------------------------------------------------------------------- 471 | list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ; 472 | tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ; 473 | epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ; 474 | epoch = max([epoch 0]) ; 475 | 476 | % ------------------------------------------------------------------------- 477 | function switchFigure(n) 478 | % ------------------------------------------------------------------------- 479 | if get(0,'CurrentFigure') ~= n 480 | try 481 | set(0,'CurrentFigure',n) ; 482 | catch 483 | figure(n) ; 484 | end 485 | end 486 | 487 | % ------------------------------------------------------------------------- 488 | function clearMex() 489 | % ------------------------------------------------------------------------- 490 | clear vl_tmove vl_imreadjpeg ; 491 | 492 | % ------------------------------------------------------------------------- 493 | function prepareGPUs(opts, cold) 494 | % ------------------------------------------------------------------------- 495 | numGpus = numel(opts.gpus) ; 496 | if numGpus > 1 497 | % check parallel pool integrity as it could have timed out 498 | pool = gcp('nocreate') ; 499 | if ~isempty(pool) && pool.NumWorkers ~= numGpus 500 | delete(pool) ; 501 | end 502 | pool = gcp('nocreate') ; 503 | if isempty(pool) 504 | parpool('local', numGpus) ; 505 | cold = true ; 506 | end 507 | 508 | end 509 | if numGpus >= 1 && cold 510 | fprintf('%s: resetting GPU\n', mfilename) 511 | clearMex() ; 512 | if numGpus == 1 513 | gpuDevice(opts.gpus) 514 | else 515 | spmd 516 | clearMex() ; 517 | gpuDevice(opts.gpus(labindex)) 518 | end 519 | end 520 | end 521 | -------------------------------------------------------------------------------- /matlab/util/catstruct.m: -------------------------------------------------------------------------------- 1 | function A = catstruct(varargin) 2 | % CATSTRUCT Concatenate or merge structures with different fieldnames 3 | % X = CATSTRUCT(S1,S2,S3,...) merges the structures S1, S2, S3 ... 4 | % into one new structure X. X contains all fields present in the various 5 | % structures. An example: 6 | % 7 | % A.name = 'Me' ; 8 | % B.income = 99999 ; 9 | % X = catstruct(A,B) 10 | % % -> X.name = 'Me' ; 11 | % % X.income = 99999 ; 12 | % 13 | % If a fieldname is not unique among structures (i.e., a fieldname is 14 | % present in more than one structure), only the value from the last 15 | % structure with this field is used. In this case, the fields are 16 | % alphabetically sorted. A warning is issued as well. An axample: 17 | % 18 | % S1.name = 'Me' ; 19 | % S2.age = 20 ; S3.age = 30 ; S4.age = 40 ; 20 | % S5.honest = false ; 21 | % Y = catstruct(S1,S2,S3,S4,S5) % use value from S4 22 | % 23 | % The inputs can be array of structures. All structures should have the 24 | % same size. An example: 25 | % 26 | % C(1).bb = 1 ; C(2).bb = 2 ; 27 | % D(1).aa = 3 ; D(2).aa = 4 ; 28 | % CD = catstruct(C,D) % CD is a 1x2 structure array with fields bb and aa 29 | % 30 | % The last input can be the string 'sorted'. In this case, 31 | % CATSTRUCT(S1,S2, ..., 'sorted') will sort the fieldnames alphabetically. 32 | % To sort the fieldnames of a structure A, you could use 33 | % CATSTRUCT(A,'sorted') but I recommend ORDERFIELDS for doing that. 
34 | % 35 | % When there is nothing to concatenate, the result will be an empty 36 | % struct (0x0 struct array with no fields). 37 | % 38 | % NOTE: To concatenate similar arrays of structs, you can use simple 39 | % concatenation: 40 | % A = dir('*.mat') ; B = dir('*.m') ; C = [A ; B] ; 41 | 42 | % NOTE: This function relies on unique. Matlab changed the behavior of 43 | % its set functions since 2013a, so this might cause some backward 44 | % compatibility issues when dulpicated fieldnames are found. 45 | % 46 | % See also CAT, STRUCT, FIELDNAMES, STRUCT2CELL, ORDERFIELDS 47 | 48 | % version 4.1 (feb 2015), tested in R2014a 49 | % (c) Jos van der Geest 50 | % email: jos@jasen.nl 51 | 52 | % History 53 | % Created in 2005 54 | % Revisions 55 | % 2.0 (sep 2007) removed bug when dealing with fields containing cell 56 | % arrays (Thanks to Rene Willemink) 57 | % 2.1 (sep 2008) added warning and error identifiers 58 | % 2.2 (oct 2008) fixed error when dealing with empty structs (thanks to 59 | % Lars Barring) 60 | % 3.0 (mar 2013) fixed problem when the inputs were array of structures 61 | % (thanks to Tor Inge Birkenes). 62 | % Rephrased the help section as well. 63 | % 4.0 (dec 2013) fixed problem with unique due to version differences in 64 | % ML. Unique(...,'last') is no longer the deafult. 65 | % (thanks to Isabel P) 66 | % 4.1 (feb 2015) fixed warning with narginchk 67 | 68 | narginchk(1,Inf) ; 69 | N = nargin ; 70 | 71 | if ~isstruct(varargin{end}), 72 | if isequal(varargin{end},'sorted'), 73 | narginchk(2,Inf) ; 74 | sorted = 1 ; 75 | N = N-1 ; 76 | else 77 | error('catstruct:InvalidArgument','Last argument should be a structure, or the string "sorted".') ; 78 | end 79 | else 80 | sorted = 0 ; 81 | end 82 | 83 | sz0 = [] ; % used to check that all inputs have the same size 84 | 85 | % used to check for a few trivial cases 86 | NonEmptyInputs = false(N,1) ; 87 | NonEmptyInputsN = 0 ; 88 | 89 | % used to collect the fieldnames and the inputs 90 | FN = cell(N,1) ; 91 | VAL = cell(N,1) ; 92 | 93 | % parse the inputs 94 | for ii=1:N, 95 | X = varargin{ii} ; 96 | if ~isstruct(X), 97 | error('catstruct:InvalidArgument',['Argument #' num2str(ii) ' is not a structure.']) ; 98 | end 99 | 100 | if ~isempty(X), 101 | % empty structs are ignored 102 | if ii > 1 && ~isempty(sz0) 103 | if ~isequal(size(X), sz0) 104 | error('catstruct:UnequalSizes','All structures should have the same size.') ; 105 | end 106 | else 107 | sz0 = size(X) ; 108 | end 109 | NonEmptyInputsN = NonEmptyInputsN + 1 ; 110 | NonEmptyInputs(ii) = true ; 111 | FN{ii} = fieldnames(X) ; 112 | VAL{ii} = struct2cell(X) ; 113 | end 114 | end 115 | 116 | if NonEmptyInputsN == 0 117 | % all structures were empty 118 | A = struct([]) ; 119 | elseif NonEmptyInputsN == 1, 120 | % there was only one non-empty structure 121 | A = varargin{NonEmptyInputs} ; 122 | if sorted, 123 | A = orderfields(A) ; 124 | end 125 | else 126 | % there is actually something to concatenate 127 | FN = cat(1,FN{:}) ; 128 | VAL = cat(1,VAL{:}) ; 129 | FN = squeeze(FN) ; 130 | VAL = squeeze(VAL) ; 131 | 132 | 133 | [UFN,ind] = unique(FN, 'last') ; 134 | % If this line errors, due to your matlab version not having UNIQUE 135 | % accept the 'last' input, use the following line instead 136 | % [UFN,ind] = unique(FN) ; % earlier ML versions, like 6.5 137 | 138 | if numel(UFN) ~= numel(FN), 139 | warning('catstruct:DuplicatesFound','Fieldnames are not unique between structures.') ; 140 | sorted = 1 ; 141 | end 142 | 143 | if sorted, 144 | VAL = VAL(ind,:) ; 145 | FN = 
FN(ind,:) ; 146 | end 147 | 148 | A = cell2struct(VAL, FN); 149 | A = reshape(A, sz0) ; % reshape into original format 150 | end 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /matlab/util/clearMex.m: -------------------------------------------------------------------------------- 1 | % ------------------------------------------------------------------------- 2 | function clearMex() 3 | % ------------------------------------------------------------------------- 4 | disp('Clearing mex files') ; 5 | clear mex ; 6 | clear vl_tmove vl_imreadjpeg ; 7 | 8 | -------------------------------------------------------------------------------- /matlab/util/logInfo.m: -------------------------------------------------------------------------------- 1 | function logInfo(str, varargin) 2 | % get caller function ID and display log msg 3 | if nargin > 1 4 | cmd = 'sprintf(str'; 5 | for i = 1:length(varargin) 6 | cmd = sprintf('%s, varargin{%d}', cmd, i); 7 | end 8 | str = eval([cmd, ');']); 9 | end 10 | [st, i] = dbstack(); 11 | caller = st(2).name; 12 | fprintf('%s: %s\n', caller, str); 13 | end 14 | -------------------------------------------------------------------------------- /matlab/util/prepareGPUs.m: -------------------------------------------------------------------------------- 1 | % ------------------------------------------------------------------------- 2 | function prepareGPUs(params, cold) 3 | % ------------------------------------------------------------------------- 4 | numGpus = numel(params.gpus) ; 5 | if numGpus > 1 6 | % check parallel pool integrity as it could have timed out 7 | pool = gcp('nocreate') ; 8 | if ~isempty(pool) && pool.NumWorkers ~= numGpus 9 | delete(pool) ; 10 | end 11 | pool = gcp('nocreate') ; 12 | if isempty(pool) 13 | parpool('local', numGpus) ; 14 | cold = true ; 15 | end 16 | end 17 | if numGpus >= 1 && cold 18 | fprintf('%s: resetting GPU\n', mfilename) ; 19 | %clearMex() ; 20 | if numGpus == 1 21 | disp(gpuDevice(params.gpus)) ; 22 | else 23 | spmd 24 | %clearMex() ; 25 | disp(gpuDevice(params.gpus(labindex))) ; 26 | end 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /matlab/util/record_diary.m: -------------------------------------------------------------------------------- 1 | function record_diary(opts) 2 | diary_path = @(i) sprintf('%s/diary_%03d.txt', opts.expDir, i); 3 | ind = 1; 4 | while exist(diary_path(ind), 'file') 5 | ind = ind + 1; 6 | end 7 | diary(diary_path(ind)); 8 | diary('on'); 9 | end 10 | -------------------------------------------------------------------------------- /pytorch/FastAP_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable, Function 3 | 4 | def softBinning(D, mid, Delta): 5 | y = 1 - torch.abs(D-mid)/Delta 6 | return torch.max(torch.Tensor([0]).cuda(), y) 7 | 8 | def dSoftBinning(D, mid, Delta): 9 | side1 = (D > (mid - Delta)).type(torch.float) 10 | side2 = (D <= mid).type(torch.float) 11 | ind1 = (side1 * side2) #.type(torch.uint8) 12 | 13 | side1 = (D > mid).type(torch.float) 14 | side2 = (D <= (mid + Delta)).type(torch.float) 15 | ind2 = (side1 * side2) #.type(torch.uint8) 16 | 17 | return (ind1 - ind2)/Delta 18 | 19 | 20 | class FastAP(torch.autograd.Function): 21 | """ 22 | FastAP - autograd function definition 23 | 24 | This class implements the FastAP loss from the following paper: 25 | "Deep Metric Learning to 
26 |     F. Cakir, K. He, X. Xia, B. Kulis, S. Sclaroff. CVPR 2019
27 | 
28 |     NOTE:
29 |     Given an input batch, FastAP does not sample triplets from it, as it is
30 |     not a triplet-based method; therefore, FastAP does not take a Sampler as
31 |     input. Rather, we specify how the input batch itself is selected.
32 |     """
33 | 
34 |     @staticmethod
35 |     def forward(ctx, input, target, num_bins):
36 |         """
37 |         Args:
38 |             input: torch.Tensor(N x embed_dim), embedding matrix
39 |             target: torch.Tensor(N), 1-D vector of class labels
40 |             num_bins: int, number of bins in distance histogram
41 |         """
42 |         N = target.size()[0]
43 |         assert input.size()[0] == N, "Batch size doesn't match!"
44 | 
45 |         # 1. get affinity matrix
46 |         Y   = target.unsqueeze(1)
47 |         Aff = 2 * (Y == Y.t()).type(torch.float) - 1
48 |         Aff.masked_fill_(torch.eye(N, N).byte(), 0)  # set diagonal to 0
49 | 
50 |         I_pos = (Aff > 0).type(torch.float).cuda()
51 |         I_neg = (Aff < 0).type(torch.float).cuda()
52 |         N_pos = torch.sum(I_pos, 1)
53 | 
54 |         # 2. compute distances from embeddings
55 |         # squared Euclidean distance with range [0,4], assuming unit-norm rows
56 |         dist2 = 2 - 2 * torch.mm(input, input.t())
57 | 
58 |         # 3. estimate discrete histograms
59 |         Delta = torch.tensor(4. / num_bins).cuda()
60 |         Z     = torch.linspace(0., 4., steps=num_bins+1).cuda()
61 |         L     = Z.size()[0]
62 |         h_pos = torch.zeros((N, L)).cuda()
63 |         h_neg = torch.zeros((N, L)).cuda()
64 |         for l in range(L):
65 |             pulse      = softBinning(dist2, Z[l], Delta)
66 |             h_pos[:,l] = torch.sum(pulse * I_pos, 1)
67 |             h_neg[:,l] = torch.sum(pulse * I_neg, 1)
68 | 
69 |         H_pos = torch.cumsum(h_pos, 1)
70 |         h     = h_pos + h_neg
71 |         H     = torch.cumsum(h, 1)
72 | 
73 |         # 4. compute FastAP
74 |         FastAP = h_pos * H_pos / H
75 |         FastAP[torch.isnan(FastAP) | torch.isinf(FastAP)] = 0
76 |         FastAP = torch.sum(FastAP, 1) / N_pos
77 |         FastAP = FastAP[~torch.isnan(FastAP)]
78 |         loss   = 1 - torch.mean(FastAP)
79 |         if torch.rand(1) > 0.99:
80 |             print("loss value (1-mean(FastAP)): ", loss.item())
81 | 
82 |         # 5. save for backward
83 |         ctx.save_for_backward(input, target)
84 |         ctx.Z     = Z
85 |         ctx.Delta = Delta
86 |         ctx.dist2 = dist2
87 |         ctx.I_pos = I_pos
88 |         ctx.I_neg = I_neg
89 |         ctx.h_pos = h_pos
90 |         ctx.h_neg = h_neg
91 |         ctx.H_pos = H_pos
92 |         ctx.N_pos = N_pos
93 |         ctx.h     = h
94 |         ctx.H     = H
95 |         ctx.L     = torch.tensor(L)
96 | 
97 |         return loss
98 | 
99 | 
100 |     @staticmethod
101 |     def backward(ctx, grad_output):
102 |         input, target = ctx.saved_tensors
103 | 
104 |         Z     = Variable(ctx.Z,     requires_grad=False)
105 |         Delta = Variable(ctx.Delta, requires_grad=False)
106 |         dist2 = Variable(ctx.dist2, requires_grad=False)
107 |         I_pos = Variable(ctx.I_pos, requires_grad=False)
108 |         I_neg = Variable(ctx.I_neg, requires_grad=False)
109 |         h     = Variable(ctx.h,     requires_grad=False)
110 |         H     = Variable(ctx.H,     requires_grad=False)
111 |         h_pos = Variable(ctx.h_pos, requires_grad=False)
112 |         h_neg = Variable(ctx.h_neg, requires_grad=False)
113 |         H_pos = Variable(ctx.H_pos, requires_grad=False)
114 |         N_pos = Variable(ctx.N_pos, requires_grad=False)
115 | 
116 |         L  = Z.size()[0]
117 |         H2 = torch.pow(H, 2)
118 |         H_neg = H - H_pos
119 | 
120 |         # 1. d(FastAP)/d(h+)
121 |         LTM1 = torch.tril(torch.ones(L, L), -1)  # strictly lower triangular: sums bins l > k
122 |         tmp1 = h_pos * H_neg / H2
123 |         tmp1[torch.isnan(tmp1)] = 0
124 | 
125 |         d_AP_h_pos = (H_pos * H + h_pos * H_neg) / H2
126 |         d_AP_h_pos = d_AP_h_pos + torch.mm(tmp1, LTM1.cuda())
127 |         d_AP_h_pos = d_AP_h_pos / N_pos.repeat(L, 1).t()
128 |         d_AP_h_pos[torch.isnan(d_AP_h_pos) | torch.isinf(d_AP_h_pos)] = 0
129 | 
130 | 
131 |         # 2. d(FastAP)/d(h-)
132 |         LTM0 = torch.tril(torch.ones(L, L), 0)  # lower triangular incl. diagonal: sums bins l >= k
133 |         tmp2 = -h_pos * H_pos / H2
134 |         tmp2[torch.isnan(tmp2)] = 0
135 | 
136 |         d_AP_h_neg = torch.mm(tmp2, LTM0.cuda())
137 |         d_AP_h_neg = d_AP_h_neg / N_pos.repeat(L, 1).t()
138 |         d_AP_h_neg[torch.isnan(d_AP_h_neg) | torch.isinf(d_AP_h_neg)] = 0
139 | 
140 | 
141 |         # 3. d(FastAP)/d(embedding)
142 |         d_AP_x = 0
143 |         for l in range(L):
144 |             dpulse = dSoftBinning(dist2, Z[l], Delta)
145 |             dpulse[torch.isnan(dpulse) | torch.isinf(dpulse)] = 0
146 |             ddp = dpulse * I_pos
147 |             ddn = dpulse * I_neg
148 | 
149 |             alpha_p = torch.diag(d_AP_h_pos[:,l])  # N x N diagonal matrix
150 |             alpha_n = torch.diag(d_AP_h_neg[:,l])
151 |             Ap = torch.mm(ddp, alpha_p) + torch.mm(alpha_p, ddp)
152 |             An = torch.mm(ddn, alpha_n) + torch.mm(alpha_n, ddn)
153 | 
154 |             # accumulate gradient
155 |             d_AP_x = d_AP_x - torch.mm(input.t(), (Ap + An))
156 | 
157 |         grad_input = -d_AP_x
158 |         return grad_input.t(), None, None  # no gradients for target and num_bins
159 | 
160 | 
161 | class FastAPLoss(torch.nn.Module):
162 |     """
163 |     FastAP - loss layer definition
164 | 
165 |     This class implements the FastAP loss from the following paper:
166 |     "Deep Metric Learning to Rank",
167 |     F. Cakir, K. He, X. Xia, B. Kulis, S. Sclaroff. CVPR 2019
168 |     """
169 |     def __init__(self, num_bins=10):
170 |         super(FastAPLoss, self).__init__()
171 |         self.num_bins = num_bins
172 | 
173 |     def forward(self, batch, labels):
174 |         return FastAP.apply(batch, labels, self.num_bins)
175 | 
--------------------------------------------------------------------------------
/pytorch/README.md:
--------------------------------------------------------------------------------
1 | ## PyTorch implementation of "Deep Metric Learning to Rank"
2 | 
3 | The PyTorch version of FastAP is implemented within the framework of
4 | [Deep-Metric-Learning-Baselines](https://github.com/kunhe/Deep-Metric-Learning-Baselines).
5 | In this repository we provide a standalone implementation of the loss layer; a minimal usage sketch is included at the end of this README.
6 | 
7 | #### NOTE
8 | - To fully reproduce the results reported in the paper, the FastAP loss needs to be used in conjunction with the proposed minibatch sampling method, which is implemented in the linked repo.
9 | - The loss layer is currently a direct port of the MATLAB implementation; making better use of PyTorch's automatic differentiation remains to be investigated.
10 | 
11 | **TODO**
12 | - [x] implement FastAP's minibatch sampling method
13 | - [ ] reproduce results in the paper
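14 | 
15 | #### Usage sketch
16 | The snippet below is a minimal, hypothetical example of dropping the standalone loss layer into a training loop; it is not the paper's training pipeline. The one-layer embedding "network", batch size, class count, optimizer, and learning rate are placeholder assumptions. Two things are required by the loss itself: a CUDA device (the histograms are allocated on the GPU), and L2-normalized embeddings, so that squared Euclidean distances fall in the [0, 4] range assumed by the histogram binning. Labels should be a 1-D CPU tensor; the loss moves what it needs to the GPU. On recent PyTorch versions, the `.byte()` mask inside `FastAP_loss.py` may need to be changed to `.bool()`.
17 | 
18 | ```python
19 | import torch
20 | import torch.nn.functional as F
21 | from FastAP_loss import FastAPLoss  # assumes this directory is on your path
22 | 
23 | # placeholder embedding "network": a single linear layer, for illustration only
24 | net = torch.nn.Linear(128, 512).cuda()
25 | criterion = FastAPLoss(num_bins=10)
26 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
27 | 
28 | # toy batch: 32 samples over 8 classes; the paper's minibatch sampler
29 | # would instead draw several samples per class
30 | x = torch.randn(32, 128).cuda()
31 | labels = torch.randint(0, 8, (32,))  # 1-D label vector, kept on CPU
32 | 
33 | optimizer.zero_grad()
34 | emb = F.normalize(net(x), p=2, dim=1)  # unit norm => squared dists in [0, 4]
35 | loss = criterion(emb, labels)          # loss = 1 - mean(FastAP) over the batch
36 | loss.backward()
37 | optimizer.step()
38 | ```
39 | 
--------------------------------------------------------------------------------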