├── LICENSE
├── README.md
├── fastap_github.jpg
├── matlab
│   ├── +imdb
│   │   ├── inshop.m
│   │   ├── products.m
│   │   └── vid.m
│   ├── +models
│   │   ├── googlenet.m
│   │   ├── resnet18.m
│   │   └── resnet50.m
│   ├── +solver
│   │   ├── adadelta.m
│   │   ├── adagrad.m
│   │   ├── adam.m
│   │   └── rmsprop.m
│   ├── FastAP.m
│   ├── README.md
│   ├── evaluate_model.m
│   ├── get_model.m
│   ├── get_opts.m
│   ├── run_demo.m
│   ├── startup.m
│   ├── train_hard.m
│   ├── train_hard_vid.m
│   ├── train_rand.m
│   └── util
│       ├── catstruct.m
│       ├── clearMex.m
│       ├── logInfo.m
│       ├── prepareGPUs.m
│       └── record_diary.m
└── pytorch
    ├── FastAP_loss.py
    └── README.md
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018-2019 Fatih Cakir, Kun He, Xide Xia, Brian Kulis, Stan Sclaroff
4 |
5 | If used for academic purposes, please cite the following paper:
6 |
7 | "Deep Metric Learning to Rank"
8 | Fatih Cakir(*), Kun He(*), Xide Xia, Brian Kulis, and Stan Sclaroff
9 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019
10 |
11 | Permission is hereby granted, free of charge, to any person obtaining a copy
12 | of this software and associated documentation files (the "Software"), to deal
13 | in the Software without restriction, including without limitation the rights
14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | copies of the Software, and to permit persons to whom the Software is
16 | furnished to do so, subject to the following conditions:
17 |
18 | The above copyright notice and this permission notice shall be included in all
19 | copies or substantial portions of the Software.
20 |
21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27 | SOFTWARE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FastAP: Deep Metric Learning to Rank
2 | This repository contains the implementation of the following paper:
3 |
4 | [Deep Metric Learning to Rank](http://openaccess.thecvf.com/content_CVPR_2019/html/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.html)
5 | [Fatih Cakir](http://cs-people.bu.edu/fcakir/)\*, [Kun He](http://cs-people.bu.edu/hekun/)\*, [Xide Xia](https://xidexia.github.io), [Brian Kulis](http://people.bu.edu/bkulis/), and [Stan Sclaroff](http://www.cs.bu.edu/~sclaroff/) (*equal contribution)
6 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019
7 |
8 | 
9 |
10 | ## Other Implementations
11 | [FastAPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#fastaploss) from [pytorch-metric-learning](https://github.com/KevinMusgrave/pytorch-metric-learning)
12 |
13 | ## Usage
14 | * **Matlab**: see `matlab/README.md`
15 | * **PyTorch**: see `pytorch/README.md`
16 |
17 | ## Datasets
18 | * Stanford Online Products
19 | * Can be downloaded [here](http://cvgl.stanford.edu/projects/lifted_struct/)
20 | * In-Shop Clothes Retrieval
21 | * Can be downloaded [here](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html)
22 | * PKU VehicleID
23 | * Please request the dataset from the authors [here](https://pkuml.org/resources/pku-vehicleid.html)
24 |
25 | ## Reproducibility
26 | * We provide trained [MatConvNet](https://www.vlfeat.org/matconvnet/quick/) models and experimental logs for the results in the paper. These are the exact models used to produce the numbers in the tables below.
27 | * The logs include the parameter settings needed to re-train a model, as well as evaluation results for model checkpoints at certain epochs.
28 | * Table 1: Stanford Online Products
29 | * FastAP, ResNet-18, M=256, Dim=512: [[model @ epoch 20](https://drive.google.com/file/d/1sPCG34rV4Bqf0aWF7GrFIDUK7DGjcaB5/view?usp=sharing), [log](https://drive.google.com/open?id=14m3fHgeZu8MIAePFHXe141R60KRwH1d8)]
30 | * FastAP, ResNet-50, M=96, Dim=128: [[model @ epoch 30](https://drive.google.com/open?id=1yGUVTskdERdLeF85GP-lLRnwS0KhpvvL), [log](https://drive.google.com/open?id=1A0G1aUBS7URotInbT7eBbys4xfARvCCe)]
31 | * FastAP, ResNet-50, M=96, Dim=512: [[model @ epoch 28](https://drive.google.com/file/d/14yEyAYhGzNygBBn8r2RcZut_Ye14mJoK/view?usp=sharing), [log](https://drive.google.com/open?id=19mpLn1OqA2nqpMZtvZ3GvFk_VppOOPTc)]
32 | * FastAP, ResNet-50, M=256, Dim=512: [[model @ epoch 12](https://drive.google.com/open?id=1WfV1ArXHG4oksHGE8DZRDwsxupoO60sD), [log](https://drive.google.com/open?id=1shvC5qB8O0l6vH1qa2SG1oxX8_jM1gdi)]
33 | * Table 2: In-Shop Clothes
34 | * FastAP, ResNet-18, M=256, Dim=512: [[model @ epoch 50](https://drive.google.com/open?id=1ZZ-Fpx9uPkRL-QXL-8-RcROjQVOLcbr5), [log](https://drive.google.com/file/d/1osxoHsMy11v-kvUNTuRG3luhsxMMn78B/view?usp=sharing)]
35 | * FastAP, ResNet-50, M=96, Dim=512: [[model @ epoch 40](https://drive.google.com/open?id=1PyiHog7fJp_InvqdAO0dzyJDRMNvAXxm), [log](https://drive.google.com/open?id=14IPgDfkbKo9PnrgMFFRDSIBRW1xwRYs5)]
36 | * FastAP, ResNet-50, M=256, Dim=512: [[model @ epoch 35](https://drive.google.com/open?id=1T5IynM63YqnGslnMGppJsmtJdHIWamJv), [log](https://drive.google.com/open?id=1oud9i87FTJE7Ei636bjxqgBXpahysRKK)]
37 | * Table 3: PKU VehicleID
38 | * FastAP, ResNet-18, M=256, Dim=512: [[model @ epoch 50](https://drive.google.com/open?id=1KsUF2SzkhvBOkHzbrXKj7H5KtN6Z3hRJ), [log](https://drive.google.com/open?id=155Ce-FmI6dmMgJnXdESVHx08unU3jWX2)]
39 | * FastAP, ResNet-50, M=96, Dim=512: [[model @ epoch 40](https://drive.google.com/open?id=1AblJelRHStBfWwmZeoRM8iEpNRNdOobn), [log](https://drive.google.com/open?id=1twswLE-j9kLxUsk5Ku7vWqBp0Sml65sG)]
40 | * FastAP, ResNet-50, M=256, Dim=512: [[model @ epoch 30](https://drive.google.com/open?id=1MAimhKEyEfq2LDYnUaDburFH2YsUhrpA), [log](https://drive.google.com/open?id=1CtNk-wxSZToO703OvfK8ndFQzFnpOvVS)]
41 |
42 | (M=mini-batch size)
43 | * The PyTorch code is a direct port of our MATLAB implementation. We have not tried reproducing the paper's results with it. **For reproducibility, use the MATLAB version**.
44 | * Note that the mini-batch sampling strategy from the paper must be used alongside the FastAP loss to get good results.
45 |
46 | ## Contact
47 | For questions and comments, feel free to contact: kunhe@fb.com or fcakirs@gmail.com
48 |
49 | ## License
50 | MIT
51 |
--------------------------------------------------------------------------------
/fastap_github.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunhe/FastAP-metric-learning/ca85b2ab3f22460795d4306c8fa01c655d17027e/fastap_github.jpg
--------------------------------------------------------------------------------
/matlab/+imdb/inshop.m:
--------------------------------------------------------------------------------
1 | function imdb = inshop(prefix, opts)
2 | savefn = fullfile(opts.dataDir, [prefix '.mat']);
3 | if exist(savefn, 'file')
4 | imdb = load(savefn);
5 | return;
6 | end
7 | inshopdir = fullfile(opts.dataDir, 'InShopClothes');
8 |
9 | % meta-classes
10 | metaFile = fullfile(inshopdir, 'meta_classes.txt');
11 | metaClasses = textread(metaFile, '%s');
12 |
13 | % parse image names/labels
14 | imgsFile = fullfile(inshopdir, 'Anno/list_bbox_inshop.txt');
15 | [imgs, cType, pType, ~, ~, ~, ~] = textread(imgsFile, '%s %d %d %d %d %d %d', ...
16 | 'headerlines', 2);
17 | tmp = cellfun(@(x) strsplit(x, '/'), imgs, 'uni', false);
18 | cls = cellfun(@(x) sscanf(x{4}, 'id_%d'), tmp);
19 | meta = cellfun(@(x) [x{2} '/' x{3}], tmp, 'uni', false);
20 | [~, meta] = ismember(meta, metaClasses);
21 |
22 | % train/test
23 | splitFile = fullfile(inshopdir, 'list_eval_partition.txt');
24 | [img_sp, ~, sp] = textread(splitFile, '%s %s %s', 'headerlines', 2);
25 | [~, loc] = ismember(imgs, img_sp);
26 | [~, set] = ismember(sp(loc), {'train', 'query', 'gallery'});
27 |
28 | % images are already 256x256, nothing to do
29 |
30 | imdb.images.data = fullfile(inshopdir, imgs);
31 | imdb.images.labels = single([cls meta cType pType]) ;
32 | imdb.images.set = uint8(set) ;
33 | imdb.meta.sets = {'train', 'query', 'gallery'} ;
34 | save(savefn, '-struct', 'imdb');
35 | end
36 |
--------------------------------------------------------------------------------
/matlab/+imdb/products.m:
--------------------------------------------------------------------------------
1 | function imdb = products(prefix, opts)
2 | savefn = fullfile(opts.dataDir, [prefix '.mat']);
3 | if exist(savefn, 'file')
4 | imdb = load(savefn);
5 | return;
6 | end
7 | stanforddir = fullfile(opts.dataDir, 'Stanford_Online_Products');
8 |
9 | % train/test
10 | [id, y1, y2, x] = textread([stanforddir '/Ebay_train.txt'], '%d %d %d %s', ...
11 | 'headerlines', 1);
12 | fn = x;
13 | labels = [y1 y2];
14 | set = ones(numel(id), 1);
15 | logInfo('train: %d', numel(x));
16 |
17 | [id, y1, y2, x] = textread([stanforddir '/Ebay_test.txt'], '%d %d %d %s', ...
18 | 'headerlines', 1);
19 | fn = [fn; x];
20 | labels = [labels; y1 y2];
21 | set = [set; 3*ones(numel(id), 1)];
22 | logInfo(' test: %d', numel(x));
23 |
24 | % resize to 256x256
25 | a = tic; tic;
26 | stanford256 = fullfile(stanforddir, 'images_256x256');
27 | if ~exist(stanford256, 'dir'), mkdir(stanford256); end
28 | fn256 = strrep(fn, '/', '_');
29 | fn256 = cellfun(@(x) fullfile(stanford256, x), fn256, 'uniform', false);
30 | for i = 1:numel(fn)
31 | if ~exist(fn256{i}, 'file')
32 | im = imread([stanforddir '/' fn{i}]);
33 | im = imresize(im, [256 256]);
34 | imwrite(im, fn256{i});
35 | end
36 | if toc > 30
37 | logInfo('%d/%d: %s', i, numel(fn), fn256{i}); tic;
38 | end
39 | end
40 | toc(a)
41 |
42 | imdb.images.data = fn256;
43 | imdb.images.labels = single(labels) ;
44 | imdb.images.set = uint8(set) ;
45 | imdb.meta.sets = {'train', 'val', 'test'} ;
46 | save(savefn, '-struct', 'imdb');
47 | end
48 |
--------------------------------------------------------------------------------
/matlab/+imdb/vid.m:
--------------------------------------------------------------------------------
1 | function imdb = vid(prefix, opts)
2 | savefn = fullfile(opts.dataDir, [prefix '.mat']);
3 | if exist(savefn, 'file')
4 | imdb = load(savefn);
5 | return;
6 | end
7 | viddir = fullfile(opts.dataDir, 'VehicleID_V1.0');
8 |
9 | % imgs, vid
10 | img2vid = fullfile(viddir, 'attribute', 'img2vid.txt');
11 | [imgs, vid] = textread(img2vid, '%d %d');
12 | n = numel(imgs);
13 |
14 | % model info
15 | mid = zeros(n, 1);
16 | vid2model = fullfile(viddir, 'attribute', 'model_attr.txt');
17 | [vid_m, mid_m] = textread(vid2model, '%d %d');
18 | for i = 1:length(vid_m)
19 | mid(vid == vid_m(i)) = mid_m(i) + 1;
20 | end
21 |
22 | % color info
23 | cid = zeros(n, 1);
24 | vid2color = fullfile(viddir, 'attribute', 'color_attr.txt');
25 | [vid_c, cid_c] = textread(vid2color, '%d %d');
26 | for i = 1:length(vid_c)
27 | cid(vid == vid_c(i)) = cid_c(i) + 1;
28 | end
29 |
30 | % train/test
31 | sets = zeros(n, 4); % [train test_small test_medium test_large]
32 |
33 | trainFile = fullfile(viddir, 'train_test_split', 'train_list.txt');
34 | [~, vid_train] = textread(trainFile, '%d %d');
35 | sets(ismember(vid, vid_train), 1) = 1;
36 | logInfo('train: %d imgs', sum(sets(:, 1)));
37 |
38 | testSmallFile = fullfile(viddir, 'train_test_split', 'test_list_800.txt');
39 | [~, vid_s] = textread(testSmallFile, '%d %d');
40 | sets(ismember(vid, vid_s), 2) = 1;
41 | logInfo('test-small: %d imgs', sum(sets(:, 2)));
42 |
43 | testMediumFile = fullfile(viddir, 'train_test_split', 'test_list_1600.txt');
44 | [~, vid_m] = textread(testMediumFile, '%d %d');
45 | sets(ismember(vid, vid_m), 3) = 1;
46 | logInfo('test-medium: %d imgs', sum(sets(:, 3)));
47 |
48 | testLargeFile = fullfile(viddir, 'train_test_split', 'test_list_2400.txt');
49 | [~, vid_l] = textread(testLargeFile, '%d %d');
50 | sets(ismember(vid, vid_l), 4) = 1;
51 | logInfo('test-large: %d imgs', sum(sets(:, 4)));
52 |
53 | % resize to 256x256
54 | a = tic; tic;
55 | vid256 = fullfile(viddir, 'image_256x256');
56 | if ~exist(vid256, 'dir'), mkdir(vid256); end
57 | fn256 = cell(n, 1);
58 | for i = 1:n
59 | fn256{i} = sprintf('%s/%07d.jpg', vid256, imgs(i));
60 | if ~exist(fn256{i}, 'file')
61 | im = imread(sprintf('%s/image/%07d.jpg', viddir, imgs(i)));
62 | im = imresize(im, [256 256]);
63 | imwrite(im, fn256{i});
64 | end
65 | if toc > 20
66 | logInfo(fn256{i}); tic;
67 | end
68 | end
69 | toc(a)
70 |
71 | % imdb
72 | imdb.images.data = fn256;
73 | imdb.images.labels = single([vid mid cid]) ;
74 | imdb.images.set = uint8(sets) ;
75 | imdb.meta.sets = {'train', 'test_small', 'test_medium', 'test_large'} ;
76 | save(savefn, '-struct', 'imdb');
77 | end
78 |
--------------------------------------------------------------------------------
/matlab/+models/googlenet.m:
--------------------------------------------------------------------------------
1 | function [net, opts, in_name, in_dim] = googlenet(opts)
2 | opts.imageSize = 224;
3 | opts.maxGpuImgs = 320; % # of images that a 12GB GPU can hold
4 |
5 | % finetune GoogLeNet
6 | net = load(fullfile(opts.localDir, 'models', 'imagenet-googlenet-dag.mat'));
7 | net = dagnn.DagNN.loadobj(net) ;
8 |
9 | % remove softmax, fc
10 | net.removeLayer('softmax');
11 | net.removeLayer('cls3_fc');
12 |
13 | net.removeLayer('cls2_pool');
14 | net.removeLayer('cls2_reduction');
15 | net.removeLayer('relu_cls2_reduction');
16 | net.removeLayer('cls2_fc1');
17 | net.removeLayer('relu_cls2_fc1');
18 | net.removeLayer('cls2_fc2');
19 |
20 | net.removeLayer('cls1_pool');
21 | net.removeLayer('cls1_reduction');
22 | net.removeLayer('relu_cls1_reduction');
23 | net.removeLayer('cls1_fc1');
24 | net.removeLayer('relu_cls1_fc1');
25 | net.removeLayer('cls1_fc2');
26 |
27 | % freeze pretrained layers
28 | if opts.lastLayer
29 | for i = 1:numel(net.params)
30 | net.params(i).learningRate = 0;
31 | net.params(i).weightDecay = 0;
32 | end
33 | elseif isfield(opts, 'ft') && opts.ft >= 1
34 | assert(opts.ft>=1 && opts.ft<=9);
35 | names = {};
36 | for l = opts.ft : 9
37 | names{end+1} = sprintf('icp%d', l);
38 | end
39 | for i = 1:numel(net.params)
40 | pname = net.params(i).name;
41 | notfound = cellfun(@(x) isempty(strfind(pname, x)), names);
42 | if all(notfound)
43 | net.params(i).learningRate = 0;
44 | net.params(i).weightDecay = 0;
45 | end
46 | end
47 | end
48 |
49 | in_name = 'cls3_pool';
50 | in_dim = 1024;
51 | end
52 |
--------------------------------------------------------------------------------
/matlab/+models/resnet18.m:
--------------------------------------------------------------------------------
1 | function [net, opts, in_name, in_dim] = resnet18(opts)
2 | opts.imageSize = 224;
3 | opts.maxGpuImgs = 256; % # of images that a 12GB GPU can hold
4 |
5 | net = load(fullfile(opts.localDir, 'models', 'resnet18-pt-mcn.mat'));
6 | net = dagnn.DagNN.loadobj(net) ;
7 |
8 | % remove softmax, fc
9 | net.removeLayer('classifier_0');
10 |
11 | % freeze pretrained layers
12 | if opts.lastLayer
13 | for i = 1:numel(net.params)
14 | net.params(i).learningRate = 0;
15 | net.params(i).weightDecay = 0;
16 | end
17 | end
18 |
19 | in_name = 'classifier_flatten';
20 | in_dim = 512;
21 | end
22 |
--------------------------------------------------------------------------------
/matlab/+models/resnet50.m:
--------------------------------------------------------------------------------
1 | function [net, opts, in_name, in_dim] = resnet50(opts)
2 | opts.imageSize = 224;
3 | opts.maxGpuImgs = 90; % # of images that a 12GB GPU can hold
4 |
5 | net = load(fullfile(opts.localDir, 'models', 'imagenet-resnet-50-dag.mat'));
6 | net = dagnn.DagNN.loadobj(net) ;
7 |
8 | % remove softmax, fc
9 | net.removeLayer('prob');
10 | net.removeLayer('fc1000');
11 |
12 | % freeze pretrained layers
13 | if opts.lastLayer
14 | for i = 1:numel(net.params)
15 | net.params(i).learningRate = 0;
16 | net.params(i).weightDecay = 0;
17 | end
18 | end
19 |
20 | in_name = 'pool5';
21 | in_dim = 2048;
22 | end
23 |
--------------------------------------------------------------------------------
/matlab/+solver/adadelta.m:
--------------------------------------------------------------------------------
1 | function [w, state] = adadelta(w, state, grad, opts, ~)
2 | %ADADELTA
3 | % Example AdaDelta solver, for use with CNN_TRAIN and CNN_TRAIN_DAG.
4 | %
5 | % AdaDelta sets its own learning rate, so any learning rate set in the
6 | % options of CNN_TRAIN and CNN_TRAIN_DAG will be ignored.
7 | %
8 | % If called without any input argument, returns the default options
9 | % structure.
10 | %
11 | % Solver options: (opts.train.solverOpts)
12 | %
13 | % `epsilon`:: 1e-6
14 | % Small additive constant to regularize variance estimate.
15 | %
16 | % `rho`:: 0.9
17 | % Moving average window for variance update, between 0 and 1 (larger
18 | % values result in slower/more stable updating).
19 |
20 | % Copyright (C) 2016 Joao F. Henriques.
21 | % All rights reserved.
22 | %
23 | % This file is part of the VLFeat library and is made available under
24 | % the terms of the BSD license (see the COPYING file).
25 |
26 | if nargin == 0 % Return the default solver options
27 | w = struct('epsilon', 1e-6, 'rho', 0.9) ;
28 | return ;
29 | end
30 |
31 | if isequal(state, 0) % First iteration, initialize state struct
32 | state = struct('g_sqr', 0, 'delta_sqr', 0) ;
33 | end
34 |
35 | rho = opts.rho ;
36 |
37 | state.g_sqr = state.g_sqr * rho + grad.^2 * (1 - rho) ;
38 | new_delta = -sqrt((state.delta_sqr + opts.epsilon) ./ ...
39 | (state.g_sqr + opts.epsilon)) .* grad ;
40 | state.delta_sqr = state.delta_sqr * rho + new_delta.^2 * (1 - rho) ;
41 |
42 | w = w + new_delta ;
43 |
--------------------------------------------------------------------------------
/matlab/+solver/adagrad.m:
--------------------------------------------------------------------------------
1 | function [w, g_sqr] = adagrad(w, g_sqr, grad, opts, lr)
2 | %ADAGRAD
3 | % Example AdaGrad solver, for use with CNN_TRAIN and CNN_TRAIN_DAG.
4 | %
5 | % Set the initial learning rate for AdaGrad in the options for
6 | % CNN_TRAIN and CNN_TRAIN_DAG. Note that a learning rate that works for
7 | % SGD may be inappropriate for AdaGrad; the default is 0.001.
8 | %
9 | % If called without any input argument, returns the default options
10 | % structure.
11 | %
12 | % Solver options: (opts.train.solverOpts)
13 | %
14 | % `epsilon`:: 1e-10
15 | % Small additive constant to regularize variance estimate.
16 | %
17 | % `rho`:: 1
18 | % Moving average window for variance update, between 0 and 1 (larger
19 | % values result in slower/more stable updating). This is similar to
20 | % RHO in AdaDelta and RMSProp. Standard AdaGrad is obtained with a RHO
21 | % value of 1 (use total average instead of a moving average).
22 | %
23 | % A possibly undesirable effect of standard AdaGrad is that the update
24 | % will monotonically decrease to 0, until training eventually stops. This
25 | % is because the AdaGrad update is inversely proportional to the total
26 | % variance of the gradients seen so far.
27 | % With RHO smaller than 1, a moving average is used instead. This
28 | % prevents the final update from monotonically decreasing to 0.
29 |
30 | % Copyright (C) 2016 Joao F. Henriques.
31 | % All rights reserved.
32 | %
33 | % This file is part of the VLFeat library and is made available under
34 | % the terms of the BSD license (see the COPYING file).
35 |
36 | if nargin == 0 % Return the default solver options
37 | w = struct('epsilon', 1e-10, 'rho', 1) ;
38 | return ;
39 | end
40 |
41 | g_sqr = g_sqr * opts.rho + grad.^2 ;
42 |
43 | w = w - lr * grad ./ (sqrt(g_sqr) + opts.epsilon) ;
44 |
--------------------------------------------------------------------------------
/matlab/+solver/adam.m:
--------------------------------------------------------------------------------
1 | function [w, state] = adam(w, state, grad, opts, lr)
2 | %ADAM
3 | % Adam solver for use with CNN_TRAIN and CNN_TRAIN_DAG
4 | %
5 | % See [Kingma et. al., 2014](http://arxiv.org/abs/1412.6980)
6 | % | ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
7 | %
8 | % If called without any input argument, returns the default options
9 | % structure. Otherwise provide all input arguments.
10 | %
11 | % W is the vector/matrix/tensor of parameters. It can be single/double
12 | % precision and can be a `gpuArray`.
13 | %
14 | % STATE is as defined below and so are supported OPTS.
15 | %
16 | % GRAD is the gradient of the objective w.r.t W
17 | %
18 | % LR is the learning rate, referred to as \alpha by Algorithm 1 in
19 | % [Kingma et. al., 2014].
20 | %
21 | % Solver options: (opts.train.solverOpts)
22 | %
23 | % `beta1`:: 0.9
24 | % Decay for 1st moment vector. See algorithm 1 in [Kingma et.al. 2014]
25 | %
26 | % `beta2`:: 0.999
27 | % Decay for 2nd moment vector
28 | %
29 | % `eps`:: 1e-8
30 | % Additive offset when dividing by state.v
31 | %
32 | % The state is initialized as 0 (number) to start with. The first call to
33 | % this function will initialize it with the default state consisting of
34 | %
35 | % `m`:: 0
36 | % First moment vector
37 | %
38 | % `v`:: 0
39 | % Second moment vector
40 | %
41 | % `t`:: 0
42 | % Global iteration number across epochs
43 | %
44 | % This implementation borrowed from torch optim.adam
45 |
46 | % Copyright (C) 2016 Aravindh Mahendran.
47 | % All rights reserved.
48 | %
49 | % This file is part of the VLFeat library and is made available under
50 | % the terms of the BSD license (see the COPYING file).
51 |
52 | if nargin == 0 % Returns the default solver options
53 | w = struct('beta1', 0.9, 'beta2', 0.999, 'eps', 1e-8) ;
54 | return ;
55 | end
56 |
57 | if isequal(state, 0) % start off with state = 0 so as to get default state
58 | state = struct('m', 0, 'v', 0, 't', 0);
59 | end
60 |
61 | % update first moment vector `m`
62 | state.m = opts.beta1 * state.m + (1 - opts.beta1) * grad ;
63 |
64 | % update second moment vector `v`
65 | state.v = opts.beta2 * state.v + (1 - opts.beta2) * grad.^2 ;
66 |
67 | % update the time step
68 | state.t = state.t + 1 ;
69 |
70 | % This implicitly corrects for biased estimates of first and second moment
71 | % vectors
72 | lr_t = lr * (((1 - opts.beta2^state.t)^0.5) / (1 - opts.beta1^state.t)) ;
73 |
74 | % Update `w`
75 | w = w - lr_t * state.m ./ (state.v.^0.5 + opts.eps) ;
76 |
--------------------------------------------------------------------------------
/matlab/+solver/rmsprop.m:
--------------------------------------------------------------------------------
1 | function [w, g_sqr] = rmsprop(w, g_sqr, grad, opts, lr)
2 | %RMSPROP
3 | % Example RMSProp solver, for use with CNN_TRAIN and CNN_TRAIN_DAG.
4 | %
5 | % Set the initial learning rate for RMSProp in the options for
6 | % CNN_TRAIN and CNN_TRAIN_DAG. Note that a learning rate that works for
7 | % SGD may be inappropriate for RMSProp; the default is 0.001.
8 | %
9 | % If called without any input argument, returns the default options
10 | % structure.
11 | %
12 | % Solver options: (opts.train.solverOpts)
13 | %
14 | % `epsilon`:: 1e-8
15 | % Small additive constant to regularize variance estimate.
16 | %
17 | % `rho`:: 0.99
18 | % Moving average window for variance update, between 0 and 1 (larger
19 | % values result in slower/more stable updating).
20 |
21 | % Copyright (C) 2016 Joao F. Henriques.
22 | % All rights reserved.
23 | %
24 | % This file is part of the VLFeat library and is made available under
25 | % the terms of the BSD license (see the COPYING file).
26 |
27 | if nargin == 0 % Return the default solver options
28 | w = struct('epsilon', 1e-8, 'rho', 0.99) ;
29 | return ;
30 | end
31 |
32 | g_sqr = g_sqr * opts.rho + grad.^2 * (1 - opts.rho) ;
33 |
34 | w = w - lr * grad ./ (sqrt(g_sqr) + opts.epsilon) ;
35 |
--------------------------------------------------------------------------------
/matlab/FastAP.m:
--------------------------------------------------------------------------------
1 | classdef FastAP < dagnn.Loss
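% FASTAP  DagNN loss layer for FastAP ("Deep Metric Learning to Rank", CVPR 2019).
%
% Forward inputs: {features, labels}, where features is a 1x1xDxN array of
% L2-normalized embeddings and labels is an NxK matrix whose first column
% holds the instance label. Squared Euclidean distances between unit vectors
% lie in [0, 4]; this range is soft-quantized into opt.nbins histogram bins,
% and average precision is approximated per anchor as
%   sum_l( h_pos(l) * H_pos(l) / H(l) ) / N_pos
% where h_pos/h are soft bin counts and H_pos/H their cumulative sums.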
2 | properties
3 | opt
4 | end
5 |
6 | properties (Transient)
7 | Z
8 | Delta
9 | dist2
10 | I_pos
11 | I_neg
12 | h_pos
13 | h_neg
14 | H_pos
15 | N_pos
16 | h
17 | H
18 | end
19 |
20 | methods
21 |
22 | function obj = FastAP(varargin)
23 | obj.load(varargin{:});
24 | end
25 |
26 |
27 | function outputs = forward(obj, inputs, params)
28 | % forward pass
29 | X = squeeze(inputs{1}); % features (L2 normalized)
30 | Y = inputs{2};
31 | N = size(X, 2);
32 | assert(size(Y, 1) == N);
33 |
34 | opts = obj.opt;
35 | onGPU = numel(opts.gpus) > 0;
36 |
37 | % binary affinity
38 | Affinity = 2 * bsxfun(@eq, Y(:,1), Y(:,1)') - 1;
39 | Affinity(logical(eye(N))) = 0;
40 | I_pos = (Affinity > 0);
41 | I_neg = (Affinity < 0);
42 | N_pos = sum(I_pos, 2);
43 |
44 | % (squared) pairwise distance matrix
45 | dist2 = max(0, 2 - 2 * X' * X);
46 |
47 | % histogram binning
48 | Delta = 4 / opts.nbins;
49 | Z = linspace(0, 4, opts.nbins+1);
50 | L = length(Z);
51 | h_pos = zeros(N, L);
52 | h_neg = zeros(N, L);
53 | if onGPU
54 | h_pos = gpuArray(h_pos);
55 | h_neg = gpuArray(h_neg);
56 | end
57 |
58 | for l = 1:L
59 | pulse = obj.softBinning(dist2, Z(l), Delta);
60 | h_pos(:, l) = sum(pulse .* I_pos, 2);
61 | h_neg(:, l) = sum(pulse .* I_neg, 2);
62 | end
63 | H_pos = cumsum(h_pos, 2);
64 | h = h_pos + h_neg;
65 | H = cumsum(h, 2);
66 |
67 | % compute FastAP
68 | FastAP = h_pos .* H_pos ./ H;
69 | FastAP(isnan(FastAP)|isinf(FastAP)) = 0;
70 | FastAP = sum(FastAP, 2) ./ N_pos;
71 | FastAP = FastAP(~isnan(FastAP));
72 |
73 | obj.numAveraged = N;
74 | obj.average = gather(mean(FastAP));
75 |
76 | % output
77 | outputs{1} = sum(FastAP);
78 | obj.Z = Z;
79 | obj.Delta = Delta;
80 | obj.dist2 = dist2;
81 | obj.I_pos = I_pos;
82 | obj.I_neg = I_neg;
83 | obj.h_pos = h_pos;
84 | obj.h_neg = h_neg;
85 | obj.H_pos = H_pos;
86 | obj.N_pos = N_pos;
87 | obj.h = h;
88 | obj.H = H;
89 | end
90 |
91 |
92 | function [dInputs, dParams] = backward(obj, inputs, params, dOutputs)
93 | % backward pass
94 | X = squeeze(inputs{1});
95 | opts = obj.opt;
96 | onGPU = numel(opts.gpus) > 0;
97 |
98 | L = numel(obj.Z);
99 | h_pos = obj.h_pos;
100 | h_neg = obj.h_neg;
101 | H_pos = obj.H_pos;
102 | N_pos = obj.N_pos;
103 | h = obj.h;
104 | H = obj.H;
105 | H_neg = H - H_pos;
106 | H2 = H .^ 2;
107 |
108 | % 1. d(FastAP)/d(h+)
109 | tmp1 = h_pos .* H_neg ./ H2;
110 | tmp1(isnan(tmp1)) = 0;
111 |
112 | d_AP_h_pos = (H_pos .* H + h_pos .* H_neg) ./ H2;
113 | d_AP_h_pos = d_AP_h_pos + tmp1 * triu(ones(L), 1)';
114 | d_AP_h_pos = bsxfun(@rdivide, d_AP_h_pos, N_pos);
115 | d_AP_h_pos(isnan(d_AP_h_pos)|isinf(d_AP_h_pos)) = 0;
116 |
117 | % 2. d(FastAP)/d(h-)
118 | tmp2 = -h_pos .* H_pos ./ H2;
119 | tmp2(isnan(tmp2)) = 0;
120 |
121 | d_AP_h_neg = tmp2 * triu(ones(L))';
122 | d_AP_h_neg = bsxfun(@rdivide, d_AP_h_neg, N_pos);
123 | d_AP_h_neg(isnan(d_AP_h_neg)|isinf(d_AP_h_neg)) = 0;
124 |
125 | % 3. d(FastAP)/d(x)
126 | d_AP_x = 0;
127 | for l = 1:L
128 | % NxN matrix of delta_hat(i, j, l) for fixed l
129 | dpulse = obj.dSoftBinning(obj.dist2, obj.Z(l), obj.Delta);
130 | dpulse(isnan(dpulse)|isinf(dpulse)) = 0;
131 | ddp = dpulse .* obj.I_pos;
132 | ddn = dpulse .* obj.I_neg;
133 |
134 | alpha_p = diag(d_AP_h_pos(:, l)); % NxN
135 | alpha_n = diag(d_AP_h_neg(:, l));
136 | Ap = ddp * alpha_p + alpha_p * ddp;
137 | An = ddn * alpha_n + alpha_n * ddn;
138 |
139 | % accumulate gradient % (BxN) (NxN) -> (BxN)
140 | d_AP_x = d_AP_x - X * (Ap + An);
141 | end
142 |
143 | % output
144 | dInputs{1} = zeros(size(inputs{1}), 'single');
145 | if onGPU, dInputs{1} = gpuArray(dInputs{1}); end
146 | dInputs{1}(1, 1, :, :) = -single(d_AP_x);
147 | dInputs{2} = [];
148 | dParams = {};
149 | end
150 |
151 |
152 | function [grad, objval] = computeGrad(obj, X, Y)
153 | % helper function to compute gradient matrix
154 | inputs = {X, Y};
155 | output = obj.forward(inputs, []);
156 | objval = obj.average;
157 | [dInputs, ~] = obj.backward(inputs, [], {'objective', 1});
158 | grad = dInputs{1};
159 | end
160 |
161 |
162 | function y = softBinning(~, D, mid, delta) % '~' absorbs obj in obj.softBinning(...) calls
163 | % D: input matrix of distance values
164 | % mid: scalar, the center of some histogram bin
165 | % delta: scalar, histogram bin width
166 | %
167 | % For histogram bin mid, compute the contribution y
168 | % from every element in D.
169 | y = 1 - abs(D - mid) / delta;
170 | y = max(0, y);
171 | end
172 |
173 |
174 | function y = dSoftBinning(~, D, mid, delta) % '~' absorbs obj, as above
175 | % derivative of softBinning w.r.t. the distance matrix D
176 | % mid: scalar bin center
177 | % delta: scalar histogram bin width
178 | ind1 = (D > mid-delta) & (D <= mid);
179 | ind2 = (D > mid) & (D <= mid+delta);
180 | y = (ind1 - ind2) / delta;
181 | end
182 |
183 | end
184 | end
185 |
--------------------------------------------------------------------------------
/matlab/README.md:
--------------------------------------------------------------------------------
1 | ## Matlab implementation of "Deep Metric Learning to Rank"
2 |
3 | ### Requirements
4 | * Matlab R2017b or newer
5 | * This is to use the built-in [mink](https://www.mathworks.com/help/matlab/ref/mink.html) function.
6 | * Alternatively, for earlier Matlab versions, you can use [this implementation](https://www.mathworks.com/matlabcentral/fileexchange/23576-min-max-selection) of mink.
7 | * [MatConvNet](http://www.vlfeat.org/matconvnet/) v1.0-beta25 (with [`vl_contrib`](http://www.vlfeat.org/matconvnet/mfiles/vl_contrib/))
8 | * [mcnExtraLayers](https://github.com/albanie/mcnExtraLayers) via `vl_contrib setup mcnExtraLayers`
9 | * [autonn](https://github.com/vlfeat/autonn) via `vl_contrib setup autonn`
10 |
11 | ### Preparation
12 | * Install/symlink MatConvNet at `./matconvnet` under this directory
13 | * Create or symlink a directory `./cachedir` under this directory
14 | * Create a subdirectory `./cachedir/data`, and create symlinks to the datasets
15 | * Stanford Online Products at `./cachedir/data/Stanford_Online_Products`
16 | * In-Shop Clothes Retrieval at `./cachedir/data/InShopClothes`
17 | * PKU VehicleID at `./cachedir/data/VehicleID_V1.0`
18 | * Create `./cachedir/models` and download pretrained models
19 | * [GoogLeNet](http://www.vlfeat.org/matconvnet/models/imagenet-googlenet-dag.mat): download to `./cachedir/models/imagenet-googlenet-dag.mat`
20 | * [ResNet-18](http://www.robots.ox.ac.uk/~albanie/models/pytorch-imports/resnet18-pt-mcn.mat): download to `./cachedir/models/resnet18-pt-mcn.mat`
21 | * [ResNet-50](http://www.vlfeat.org/matconvnet/models/imagenet-resnet-50-dag.mat): download to `./cachedir/models/imagenet-resnet-50-dag.mat`
22 |
23 | ### Usage
24 | We provide a unified interface `run_demo.m` to run all experiments conducted in the paper. The general syntax is
25 | ```
26 | run_demo([dataset], [key-value pairs])
27 | ```
28 | where
29 | * `dataset` is one of `'products', 'inshop', 'vid'`
30 | * Various parameters are specified as key-value pairs; the full list can be found by inspecting `get_opts.m`. Some notable ones (used in the example below) are:
31 | * `'gpus'` (int) 1-based GPU index. The current implementation supports a single GPU only.
32 | * `'arch'` (string) network architecture. Available: `'googlenet', 'resnet18', 'resnet50'`
33 | * `'dim'` (int) embedding dimensionality, default 512
34 | * `'nbins'` (int) number of distance quantization bins, default 10
35 | * `'solver'` (string) optimizer. Default: `'adam'`. Others: `'sgd', 'adadelta', 'adagrad', 'rmsprop'`
36 |
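For example, the following trains a ResNet-50 FastAP model on Stanford Online Products. Every key-value pair is optional; the values shown (other than `'arch'`, whose default is `'resnet18'`) match the defaults in `get_opts.m`:
```
run_demo('products', 'arch', 'resnet50', 'dim', 512, 'nbins', 10, ...
         'solver', 'adam', 'gpus', 1);
```
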
37 | ### Notes
38 | * Our experiments were mainly run on a Titan X Pascal GPU with 12GB of memory. If your GPU has a different amount of memory, you may want to adjust the default value of `opts.maxGpuImgs` (in `+models/*.m`) for each model architecture; training batches larger than `opts.maxGpuImgs` are processed in chunks.
39 |
--------------------------------------------------------------------------------
/matlab/evaluate_model.m:
--------------------------------------------------------------------------------
1 | function evaluate_model(net, imdb, batchFunc, opts, varargin)
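% EVALUATE_MODEL  Encode images with the trained net and report Recall@K.
% Uses the protocol appropriate for each dataset: a query/gallery split for
% 'inshop', three test set sizes for 'vid', and query == database otherwise
% (self-matches are removed in that case). K values are taken from opts.Ks.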
2 |
3 | if strcmp(opts.dataset, 'inshop')
4 | % separate query set and database
5 | query_id = find(imdb.images.set == 2);
6 | gallery_id = find(imdb.images.set == 3);
7 | Yquery = imdb.images.labels(query_id, :);
8 | Ygallery = imdb.images.labels(gallery_id, :);
9 | Fquery = cnn_encode(net, imdb, batchFunc, query_id, opts, varargin{:})';
10 | Fgallery = cnn_encode(net, imdb, batchFunc, gallery_id, opts, varargin{:})';
11 | whos Fquery Fgallery
12 |
13 | logInfo('[%s] Recall ...', opts.dataset);
14 | evaluate_recall(Fquery, Yquery, Fgallery, Ygallery, opts, false);
15 |
16 | elseif strcmp(opts.dataset, 'vid')
17 | % small, medium, large
18 | test_sets = {'Small', 'Medium', 'Large'};
19 | for s = 1:3
20 | test_id = find(imdb.images.set(:, s+1) == 1);
21 | Ytest = imdb.images.labels(test_id, :);
22 | Ftest = cnn_encode(net, imdb, batchFunc, test_id, opts, varargin{:})';
23 | logInfo('[%s] Recall on [%s] ...', opts.dataset, test_sets{s});
24 | evaluate_recall(Ftest, Ytest, Ftest, Ytest, opts);
25 | end
26 |
27 | else
28 | % query set == database
29 | test_id = find(imdb.images.set == 3);
30 | Ytest = imdb.images.labels(test_id, :);
31 | Ftest = cnn_encode(net, imdb, batchFunc, test_id, opts, varargin{:})';
32 | whos Ftest
33 |
34 | logInfo('[%s] Recall & NMI ...', opts.dataset);
35 | evaluate_recall(Ftest, Ytest, Ftest, Ytest, opts);
36 | end
37 |
38 | end
39 |
40 | % ------------------------------------------------------------------------------
41 | % ------------------------------------------------------------------------------
42 |
43 | function evaluate_recall(Xq, Yq, Xdb, Ydb, opts, removeDiag)
44 | % features: Nxd matrix
45 | if ~exist('removeDiag', 'var'), removeDiag = true; end
46 |
47 | [Nq, d] = size(Xq);
48 | assert(Nq == size(Yq, 1));
49 | assert(size(Yq, 2) == size(Ydb, 2));
50 |
51 | % pairwise distances
52 | tic;
53 | distmatrix = 2 - 2 * Xq * Xdb'; % NxN
54 | if removeDiag
55 | distmatrix(logical(eye(Nq))) = Inf;
56 | end
57 |
58 | % recall rate
59 | Ks = sort(opts.Ks, 'ascend');
60 | recallrate = zeros(Nq, numel(Ks));
61 |
62 | for i = 1:Nq
63 | % get top max(K) results, in increasing distance
64 | [val, ind] = mink(distmatrix(i, :), Ks(end));
65 |
66 | % recall rates for all K's
67 | for j = 1:numel(Ks)
68 | ind_k = ind(1 : Ks(j));
69 | recallrate(i, j) = any(Ydb(ind_k, 1) == Yq(i, 1));
70 | end
71 | end
72 | toc;
73 |
74 | for j = 1:numel(Ks)
75 | fprintf('K: %4d, Recall: %.3f\n', Ks(j), mean(recallrate(:, j)));
76 | end
77 |
78 | end
79 |
80 | % ------------------------------------------------------------------------------
81 | % ------------------------------------------------------------------------------
82 |
83 | function H = cnn_encode(net, imdb, batchFunc, ids, opts, varName)
84 | if ~exist('varName', 'var')
85 | varName = 'feats_l2';
86 | end
87 |
88 | batch_size = 2 * opts.maxGpuImgs;
89 | onGPU = numel(opts.gpus) > 0;
90 |
91 | logInfo('Testing [%s] on %d -> %s', opts.arch, length(ids), varName);
92 |
93 | net.mode = 'test';
94 | ind = net.getVarIndex(varName);
95 | net.vars(ind).precious = 1;
96 | if onGPU, net.move('gpu'); end
97 |
98 | assert(strcmp(varName, 'feats_l2'));
99 | H = zeros(opts.dim, length(ids), 'single');
100 |
101 | tic;
102 | for t = 1:batch_size:length(ids)
103 | ed = min(t+batch_size-1, length(ids));
104 | inputs = batchFunc(imdb, ids(t:ed));
105 | net.eval(inputs);
106 |
107 | ret = net.vars(ind).value;
108 | ret = squeeze(gather(ret));
109 | H(:, t:ed) = ret;
110 | end
111 |
112 | if onGPU && isa(net, 'dagnn.DagNN')
113 | net.reset();
114 | net.move('cpu');
115 | end
116 | toc;
117 |
118 | end
119 |
--------------------------------------------------------------------------------
/matlab/get_model.m:
--------------------------------------------------------------------------------
1 | function [net, opts] = get_model(opts, addFC)
2 | if nargin < 2, addFC = true; end
3 |
4 | t0 = tic;
5 | modelFunc = str2func(sprintf('models.%s', opts.arch));
6 | [net, opts, in_name, in_dim] = modelFunc(opts);
7 | logInfo('%s in %.2fs', opts.arch, toc(t0));
8 |
9 | % + FC layer
10 | if addFC
11 | convobj = dagnn.Conv('size', [1 1 in_dim opts.dim], ...
12 | 'pad', 0, 'stride', 1, 'hasBias', true);
13 | params = convobj.initParams();
14 | net.addLayer('fc', convobj, {in_name}, {'logits'}, {'fc_w', 'fc_b'});
15 | p1 = net.getParamIndex('fc_w');
16 | p2 = net.getParamIndex('fc_b');
17 | net.params(p1).value = params{1};
18 | net.params(p2).value = params{2};
19 | net.params(p1).learningRate = opts.lrmult;
20 | net.params(p2).learningRate = opts.lrmult;
21 |
22 | in_name = 'logits';
23 | in_dim = opts.dim;
24 | end
25 |
26 | % + l2 normalization layer
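% MatConvNet's LRN layer computes x ./ (kappa + alpha*sum(x.^2)).^beta, where
% the sum runs over a sliding window of N channels and param = [N kappa alpha
% beta]. With param = [2*in_dim 0 1 0.5] the window covers all channels, so
% this layer performs exact L2 normalization of the embedding.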
27 | net.addLayer('L2norm', dagnn.LRN('param', [2*in_dim, 0, 1, 0.5]), ...
28 | {in_name}, {'feats_l2'});
29 |
30 | % + loss layer
31 | lossobj = str2func(opts.obj);
32 | net.addLayer('loss', lossobj('opt', opts), {'feats_l2', 'labels'}, {'objective'});
33 |
34 | % print
35 | if 0
36 | net.print({'data', [opts.imageSize opts.imageSize 3 opts.batchSize]}, ...
37 | 'MaxNumColumns', 4, 'Layers', [], 'Parameters', []);
38 | end
39 |
40 | end
41 |
--------------------------------------------------------------------------------
/matlab/get_opts.m:
--------------------------------------------------------------------------------
1 | function opts = get_opts(opts, dataset, varargin)
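% GET_OPTS  Parse and post-process experiment options.
% Called twice from run_demo.m: first as get_opts(opts, dataset, varargin)
% to parse key-value pairs, then as get_opts(opts) to carry out the
% post-processing in process_opts below.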
2 |
3 | if nargin == 1
4 | opts = process_opts(opts);
5 | return;
6 | end
7 |
8 | ip = inputParser;
9 |
10 | % model params
11 | ip.addParameter('obj' , 'FastAP');
12 | ip.addParameter('arch' , 'resnet18'); % CNN model
13 | ip.addParameter('dim' , 512); % embedding vector size
14 | ip.addParameter('nbins' , 10); % quantization granularity
15 |
16 | % SGD
17 | ip.addParameter('solver' , 'adam');
18 | ip.addParameter('batchSize' , 256);
19 | ip.addParameter('lr' , 1e-5); % base learning rate
20 | ip.addParameter('lrmult' , 10); % learning rate multiplier for last layer
21 | ip.addParameter('lrdecay' , 0.1); % [SGD only] LR decay factor
22 | ip.addParameter('lrstep' , 10); % [SGD only] decay LR after this many epochs
23 | ip.addParameter('wdecay' , 0); % weight decay
24 |
25 | % train
26 | ip.addParameter('testInterval', 3); % test interval
27 | ip.addParameter('epoch' , 30); % num. epochs
28 | ip.addParameter('gpus' , 1); % which GPU to use
29 | ip.addParameter('continue' , true); % resume from existing model
30 | ip.addParameter('debug' , false);
31 | ip.addParameter('plot' , false);
32 |
33 | % misc
34 | ip.addParameter('ablation', 0); % 0: none, 1: batchSize, 2: nbins, 3: randMBS
35 | ip.addParameter('randseed', 42);
36 | ip.addParameter('prefix' , []);
37 |
38 | % parse input
39 | ip.KeepUnmatched = true;
40 | ip.parse(varargin{:});
41 | opts = catstruct(ip.Results, opts); % combine w/ existing opts
42 | opts.dataset = dataset;
43 |
44 | end
45 |
46 | % =======================================================================
47 | % =======================================================================
48 |
49 | function opts = process_opts(opts)
50 | % post-parse processing
51 |
52 | opts.expID = sprintf('%s-%s-%d-%s', opts.dataset, opts.obj, opts.dim, opts.arch);
53 |
54 | switch (opts.ablation)
55 | case 1
56 | opts.prefix = 'ablationM';
57 | case 2
58 | opts.prefix = 'ablationHist';
59 | case 3
60 | opts.prefix = 'randMBS';
61 | otherwise
62 | opts.prefix = '';
63 | end
64 |
65 | % for batch sizes > GPU mem limit: use chunks
66 | switch (opts.arch)
67 | case 'resnet50'
68 | opts.maxGpuImgs = 90;
69 | case 'resnet18'
70 | opts.maxGpuImgs = 256;
71 | case 'googlenet'
72 | opts.maxGpuImgs = 320;
73 | end
74 |
75 | % --------------------------------------------
76 | % identifier string for the current experiment
77 | ID = sprintf('%dbins-batch%d-%s%.emult%g', opts.nbins, opts.batchSize, ...
78 | opts.solver, opts.lr, opts.lrmult);
79 |
80 | if strcmp(opts.solver, 'sgd')
81 | assert(opts.lrdecay > 0 && opts.lrdecay < 1);
82 | assert(opts.lrstep > 0);
83 | ID = sprintf('%sD%gE%d', ID, opts.lrdecay, opts.lrstep);
84 | end
85 | ID = sprintf('%s-WD%.e', ID, opts.wdecay);
86 |
87 | if isempty(opts.prefix)
88 | % prefix: timestamp
89 | [~, T] = unix(['git log -1 --format=%ci|cut -d " " -f1,2|cut -d "-" -f2,3' ...
90 | '|tr " " "."|tr -d ":-"']);
91 | opts.prefix = strrep(T, newline, '');
92 | end
93 | opts.identifier = [opts.prefix '-' ID];
94 |
95 | % --------------------------------------------
96 | % mkdirs
97 | opts.localDir = fullfile(pwd, 'cachedir'); % use symlink on linux
98 | if ~exist(opts.localDir, 'file')
99 | error('Please mkdir/symlink cachedir!');
100 | end
101 | opts.dataDir = fullfile(opts.localDir, 'data');
102 | opts.imdbPath = fullfile(opts.dataDir, ['imdb_' opts.dataset]);
103 |
104 | opts.expDir = fullfile(opts.localDir, opts.expID, opts.identifier);
105 | if ~exist(opts.expDir, 'dir')
106 | logInfo(['creating opts.expDir: ' opts.expDir]);
107 | mkdir(opts.expDir);
108 | end
109 |
110 | % --------------------------------------------
111 | % rng
112 | rng(opts.randseed);
113 |
114 | end
115 |
--------------------------------------------------------------------------------
/matlab/run_demo.m:
--------------------------------------------------------------------------------
1 | function run_demo(dataset, varargin)
2 |
3 | % ----------------------------------------
4 | % init
5 | % ----------------------------------------
6 | ip = inputParser;
7 |
8 | if strcmp(dataset, 'products')
9 | ip.addParameter('Ks', [1 10 100 1000]);
10 | hardTrainFn = @train_hard;
11 | BPMP = 5;
12 |
13 | elseif strcmp(dataset, 'inshop')
14 | ip.addParameter('Ks', [1 10 20 30 40 50]);
15 | hardTrainFn = @train_hard;
16 | BPMP = 2;
17 |
18 | elseif strcmp(dataset, 'vid')
19 | ip.addParameter('Ks', [1 5]);
20 | hardTrainFn = @train_hard_vid;
21 | BPMP = 0;
22 |
23 | else
24 | error('dataset not yet supported');
25 | end
26 |
27 | ip.KeepUnmatched = true;
28 | ip.parse(varargin{:});
29 | opts = ip.Results;
30 | opts = get_opts(opts, dataset, varargin{:});
31 |
32 | % post-parsing
33 | cleanupObj = onCleanup(@() cleanup(opts.gpus));
34 |
35 | opts = get_opts(opts); % carry out all post-processing on opts
36 | record_diary(opts);
37 | disp(opts);
38 |
39 | % ----------------------------------------
40 | % model & data
41 | % ----------------------------------------
42 | [net, opts] = get_model(opts);
43 |
44 | global imdb
45 | imdb = get_imdb(imdb, opts);
46 | disp(imdb.images)
47 |
48 | % use imagenet-pretrained model
49 | imgSize = opts.imageSize;
50 | meanImage = single(net.meta.normalization.averageImage);
51 | batchFunc = @(I, B) batch_imagenet(I, B, imgSize, meanImage);
52 |
53 | % ----------------------------------------
54 | % train
55 | % ----------------------------------------
56 | % figure out solver & learning rate
57 | if strcmp(opts.solver, 'sgd')
58 | solverFunc = [];
59 | if opts.lrdecay > 0 && opts.lrdecay < 1
60 | cur_lr = opts.lr;
61 | lrvec = [];
62 | while length(lrvec) < opts.epoch
63 | lrvec = [lrvec, ones(1, opts.lrstep)*cur_lr];
64 | cur_lr = cur_lr * opts.lrdecay;
65 | end
66 | elseif opts.lrdecay > 1
67 | % linear decay, lrdecay specifies the # epoch -> 0
68 | assert(mod(opts.lrdecay, 1) == 0);
69 | lrvec = linspace(opts.lr, 0, opts.lrdecay+1);
70 | opts.epoch = min(opts.epoch, opts.lrdecay);
71 | end
72 | else
73 | solverFunc = str2func(['solver.' opts.solver]);
74 | lrvec = opts.lr;
75 | end
76 |
77 | if opts.ablation == 3
78 | % random minibatch sampling
79 | [net, info] = train_rand(net, imdb, batchFunc , ...
80 | 'saveInterval' , opts.testInterval , ...
81 | 'plotStatistics' , opts.plot , ...
82 | 'randomSeed' , opts.randseed , ...
83 | 'gpus' , opts.gpus , ...
84 | 'continue' , opts.continue , ...
85 | 'expDir' , opts.expDir , ...
86 | 'batchSize' , opts.batchSize , ...
87 | 'weightDecay' , opts.wdecay , ...
88 | 'numEpochs' , opts.epoch , ...
89 | 'learningRate' , lrvec , ...
90 | 'train' , imdb.images.PV , ...
91 | 'val' , NaN , ...
92 | 'solver' , solverFunc , ...
93 | 'postEpochFn' , @postepoch);
94 | else
95 | % hard minibatch sampling
96 | [net, info] = hardTrainFn(net, imdb, batchFunc , ...
97 | 'saveInterval' , opts.testInterval , ...
98 | 'plotStatistics' , opts.plot , ...
99 | 'randomSeed' , opts.randseed , ...
100 | 'gpus' , opts.gpus , ...
101 | 'continue' , opts.continue , ...
102 | 'expDir' , opts.expDir , ...
103 | 'batchSize' , opts.batchSize , ...
104 | 'weightDecay' , opts.wdecay , ...
105 | 'numEpochs' , opts.epoch , ...
106 | 'learningRate' , lrvec , ...
107 | 'val' , NaN , ...
108 | 'solver' , solverFunc , ...
109 | 'postEpochFn' , @postepoch , ...
110 | 'batchesPerMetaPair' , BPMP);
111 | end
112 |
113 | % ----------------------------------------
114 | % done
115 | % ----------------------------------------
116 | net.reset();
117 | net.move('cpu');
118 | diary('off');
119 |
120 | end
121 |
122 | % ====================================================================
123 | % get IMDB
124 | % ====================================================================
125 |
126 | function imdb = get_imdb(imdb, opts)
127 | imdbName = sprintf('%s_%d', opts.dataset, opts.imageSize);
128 | logInfo('IMDB: %s', imdbName);
129 |
130 | if ~isempty(imdb) && strcmp(imdb.name, imdbName)
131 | return;
132 | end
133 |
134 | % load/compute
135 | t0 = tic;
136 | imdbFunc = str2func(['imdb.' opts.dataset]);
137 | imdb = imdbFunc(['imdb_' imdbName], opts) ;
138 | imdb.name = imdbName;
139 |
140 | if opts.ablation == 3
141 | % random MBS: shuffle training instances
142 | imdb.images.PV = shuffle(imdb, opts);
143 | else
144 | % hard MBS: group by meta-class
145 | if ~isfield(imdb, 'metaclass') || isempty(imdb.metaclass)
146 | logInfo('Analyzing meta-classes...'); tic;
147 |
148 | Ycls = imdb.images.labels(imdb.images.set==1, 1);
149 | Ymeta = imdb.images.labels(imdb.images.set==1, 2);
150 | Umeta = unique(Ymeta);
151 | imdb.metaclass = cell(1, numel(Umeta));
152 |
153 | for i = 1:numel(Umeta)
154 | % group instances in this metaclass into cells
155 | Im = find(Ymeta == Umeta(i)); % images in this metaclass
156 | Ic = unique(Ycls(Im)); % instances in this metaclass
157 | G = arrayfun(@(x) find(Ycls==x & Ymeta==Umeta(i)), Ic, 'uniform', 0);
158 |
159 | % store in imdb
160 | imdb.metaclass{i} = [];
161 | imdb.metaclass{i}.id = Umeta(i);
162 | imdb.metaclass{i}.num = numel(Im);
163 | imdb.metaclass{i}.groups = G;
164 | end
165 | toc;
166 | end
167 | end
168 |
169 | logInfo('%s in %.2f sec', imdbName, toc(t0));
170 | end
171 |
172 | % ====================================================================
173 | % postprocessing after each epoch
174 | % ====================================================================
175 |
176 | function [lr, params] = postepoch(net, params, state)
177 | lr = [];
178 | if isa(net, 'Composite')
179 | net = net{1};
180 | end
181 | opts = net.layers(end).block.opt;
182 | epoch = params.epoch;
183 | logInfo(opts.expID);
184 | logInfo(opts.identifier);
185 | logInfo(char(datetime));
186 |
187 | if ~isempty(opts.gpus)
188 | [~, name] = unix('hostname');
189 | logInfo('GPU #%d on %s', opts.gpus, name);
190 | end
191 |
192 | % evaluate
193 | if epoch == 1 || ~mod(epoch, opts.testInterval)
194 | evaluate_model(net, params.imdb, params.getBatch, opts);
195 | end
196 |
197 | % reshuffle
198 | if isfield(opts, 'ablation') && opts.ablation == 3
199 | params.train = shuffle(params.imdb, opts);
200 | end
201 | diary off, diary on
202 | end
203 |
204 |
205 | function PV = shuffle(imdb, opts)
206 | % group by class label, then shuffle
207 | Itrain = find(imdb.images.set == 1);
208 | Ytrain = imdb.images.labels(Itrain, 1);
209 | C = arrayfun(@(y) find(Ytrain==y), unique(Ytrain), 'uniform', false);
210 | PV = [];
211 |
212 | % go through classes in random order
213 | for j = randperm(numel(C))
214 | % randomize the instances
215 | ind = C{j};
216 | ind = ind(randperm(length(ind)));
217 | PV = [PV; ind];
218 | end
219 |
220 | assert(length(PV) == length(Itrain));
221 | PV = Itrain(PV);
222 | end
223 |
224 | % ====================================================================
225 | % misc
226 | % ====================================================================
227 |
228 | function cleanup(gpuid)
229 | diary('off');
230 | end
231 |
--------------------------------------------------------------------------------
/matlab/startup.m:
--------------------------------------------------------------------------------
1 | % util
2 | addpath('util');
3 |
4 | % mink()
5 | if ~exist('mink', 'builtin')
6 | warning('Please use Matlab R2017b or newer, or follow the README to download the MinMaxSelection implementation.');
7 | end
8 |
9 | % matconvnet
10 | logInfo('setting up [MatConvNet]');
11 | run ./matconvnet/matlab/vl_setupnn
12 |
13 | % mcn extra layers
14 | logInfo('setting up [mcnExtraLayers]');
15 | vl_contrib setup mcnExtraLayers
16 |
17 | % autonn
18 | logInfo('setting up [autonn]');
19 | vl_contrib setup autonn
20 |
21 | logInfo('Done!');
22 |
--------------------------------------------------------------------------------
/matlab/train_hard.m:
--------------------------------------------------------------------------------
1 | function [net,stats] = train_hard(net, imdb, getBatch, varargin)
2 | %TRAIN_HARD Train a DagNN model with hard minibatch sampling.
3 | % Modified from MatConvNet's CNN_TRAIN_DAG, which is similar to
4 | % CNN_TRAIN() but works with the DagNN wrapper instead of SimpleNN.
5 |
6 | % Copyright (C) 2014-16 Andrea Vedaldi.
7 | % All rights reserved.
8 | %
9 | % This file is part of the VLFeat library and is made available under
10 | % the terms of the BSD license (see the COPYING file).
11 | addpath(fullfile(vl_rootnn, 'examples'));
12 |
13 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%%
14 | opts.saveInterval = 2;
15 | opts.batchesPerMetaPair = 5;
16 | opts.maxGpuImgs = Inf;
17 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%%
18 |
19 | opts.expDir = fullfile('data','exp') ;
20 | opts.continue = true ;
21 | opts.batchSize = 256 ;
22 | opts.numSubBatches = 1 ;
23 | opts.train = [] ;
24 | opts.val = [] ;
25 | opts.gpus = [] ;
26 | opts.prefetch = false ;
27 | opts.epochSize = inf;
28 | opts.numEpochs = 300 ;
29 | opts.learningRate = 0.001 ;
30 | opts.weightDecay = 0.0005 ;
31 |
32 | opts.solver = [] ; % Empty array means use the default SGD solver
33 | [opts, varargin] = vl_argparse(opts, varargin) ;
34 | if ~isempty(opts.solver)
35 | assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,...
36 | 'Invalid solver; expected a function handle with two outputs.') ;
37 | % Call without input arguments, to get default options
38 | opts.solverOpts = opts.solver() ;
39 | end
40 |
41 | opts.momentum = 0.9 ;
42 | opts.saveSolverState = true ;
43 | opts.nesterovUpdate = false ;
44 | opts.randomSeed = 0 ;
45 | opts.profile = false ;
46 | opts.parameterServer.method = 'mmap' ;
47 | opts.parameterServer.prefix = 'mcn' ;
48 |
49 | opts.derOutputs = {'objective', 1} ;
50 | opts.extractStatsFn = @extractStats ;
51 | opts.plotStatistics = true;
52 | opts.postEpochFn = [] ; % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change
53 | opts = vl_argparse(opts, varargin) ;
54 |
55 | if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
56 | if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
57 | if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
58 | if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train)
59 | opts.train = [] ;
60 | end
61 | if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val)
62 | opts.val = [] ;
63 | end
64 |
65 | % -------------------------------------------------------------------------
66 | % Initialization
67 | % -------------------------------------------------------------------------
68 |
69 | evaluateMode = isempty(opts.train) ;
70 | if ~evaluateMode
71 | if isempty(opts.derOutputs)
72 | error('DEROUTPUTS must be specified when training.\n') ;
73 | end
74 | end
75 |
76 | % -------------------------------------------------------------------------
77 | % Train and validate
78 | % -------------------------------------------------------------------------
79 |
80 | modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
81 | modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;
82 |
83 | start = opts.continue * findLastCheckpoint(opts.expDir) ;
84 | if start >= 1
85 | fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
86 | % [KH] make sure to use the input opts; the saved ones may be outdated
87 | opt = net.layers(end).block.opt;
88 | [net, state, stats] = loadState(modelPath(start)) ;
89 | net.layers(end).block.opt = opt;
90 | else
91 | state = [] ;
92 | end
93 |
94 | for epoch=start+1:opts.numEpochs
95 |
96 | % Set the random seed based on the epoch and opts.randomSeed.
97 | % This is important for reproducibility, including when training
98 | % is restarted from a checkpoint.
99 |
100 | rng(epoch + opts.randomSeed) ;
101 | prepareGPUs(opts, epoch == start+1) ;
102 |
103 | % Train for one epoch.
104 | params = opts ;
105 | params.epoch = epoch ;
106 | params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
107 | params.train = opts.train; %(randperm(numel(opts.train))) ; % shuffle
108 | params.train = params.train; %(1:min(opts.epochSize, numel(opts.train)));
109 | params.val = opts.val; %(randperm(numel(opts.val))) ;
110 | params.imdb = imdb ;
111 | params.getBatch = getBatch ;
112 |
113 | if numel(opts.gpus) <= 1
114 | [net, state] = processEpoch(net, state, params, 'train') ;
115 | [net, state] = processEpoch(net, state, params, 'val') ;
116 | if ~evaluateMode && ~mod(epoch, opts.saveInterval)
117 | saveState(modelPath(epoch), net, state) ;
118 | end
119 | lastStats = state.stats ;
120 | else
121 | spmd
122 | [net, state] = processEpoch(net, state, params, 'train') ;
123 | [net, state] = processEpoch(net, state, params, 'val') ;
124 | if labindex == 1 && ~evaluateMode && ~mod(epoch, opts.saveInterval)
125 | saveState(modelPath(epoch), net, state) ;
126 | end
127 | lastStats = state.stats ;
128 | end
129 | lastStats = accumulateStats(lastStats) ;
130 | end
131 |
132 | stats.train(epoch) = lastStats.train ;
133 | stats.val(epoch) = lastStats.val ;
134 | clear lastStats ;
135 | if ~mod(epoch, opts.saveInterval)
136 | saveStats(modelPath(epoch), stats) ;
137 | end
138 |
139 | if opts.plotStatistics
140 | switchFigure(1) ; clf ;
141 | plots = setdiff(...
142 | cat(2,...
143 | fieldnames(stats.train)', ...
144 | fieldnames(stats.val)'), {'num', 'time'}) ;
145 | for p = plots
146 | p = char(p) ;
147 | values = zeros(0, epoch) ;
148 | leg = {} ;
149 | for f = {'train', 'val'}
150 | f = char(f) ;
151 | if isfield(stats.(f), p)
152 | tmp = [stats.(f).(p)] ;
153 | values(end+1,:) = tmp(1,:)' ;
154 | leg{end+1} = f ;
155 | end
156 | end
157 | subplot(1,numel(plots),find(strcmp(p,plots))) ;
158 | plot(1:epoch, values','o-') ;
159 | xlabel('epoch') ;
160 | title(p) ;
161 | legend(leg{:}) ;
162 | grid on ;
163 | end
164 | drawnow ;
165 | print(1, modelFigPath, '-dpdf') ;
166 | end
167 |
168 | if ~isempty(opts.postEpochFn)
169 | if nargout(opts.postEpochFn) == 0
170 | opts.postEpochFn(net, params, state) ;
171 | else
172 | [lr, params] = opts.postEpochFn(net, params, state) ;
173 | if ~isempty(lr), opts.learningRate = lr; end
174 | if opts.learningRate == 0, break; end
175 | opts.train = params.train;
176 | end
177 | end
178 | end
179 | logInfo('Complete: %d epochs.', opts.numEpochs);
180 | if ~isempty(opts.postEpochFn)
181 | params = opts;
182 | params.imdb = imdb;
183 | params.epoch = params.numEpochs;
184 | params.getBatch = getBatch;
185 | opts.postEpochFn(net, params, state) ;
186 | end
187 |
188 | % With multiple GPUs, return one copy
189 | if isa(net, 'Composite'), net = net{1} ; end
190 |
191 |
192 | % -------------------------------------------------------------------------
193 | function [net, state] = processEpoch(net, state, params, mode)
194 | % -------------------------------------------------------------------------
195 | % Note that net is not strictly needed as an output argument as net
196 | % is a handle class. However, this fixes some aliasing issue in the
197 | % spmd caller.
198 |
199 | % initialize with momentum 0
200 | if isempty(state) || isempty(state.solverState)
201 | state.solverState = cell(1, numel(net.params)) ;
202 | state.solverState(:) = {0} ;
203 | end
204 |
205 | % move CNN to GPU as needed
206 | numGpus = numel(params.gpus) ;
207 | if numGpus >= 1
208 | net.move('gpu') ;
209 | for i = 1:numel(state.solverState)
210 | s = state.solverState{i} ;
211 | if isnumeric(s)
212 | state.solverState{i} = gpuArray(s) ;
213 | elseif isstruct(s)
214 | state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ;
215 | end
216 | end
217 | end
218 | if numGpus > 1
219 | parserv = ParameterServer(params.parameterServer) ;
220 | net.setParameterServer(parserv) ;
221 | else
222 | parserv = [] ;
223 | end
224 |
225 | % profile
226 | if params.profile
227 | if numGpus <= 1
228 | profile clear ;
229 | profile on ;
230 | else
231 | mpiprofile reset ;
232 | mpiprofile on ;
233 | end
234 | end
235 |
236 | num = 0 ;
237 | epoch = params.epoch ;
238 | subset = params.(mode) ;
239 | adjustTime = 0 ;
240 |
241 | stats.num = 0 ; % return something even if subset = []
242 | stats.time = 0 ;
243 |
244 | % -----------------------------------------
245 | % -----------------------------------------
246 | N = numel(subset);
247 | if N > 0
248 |     logInfo('Epoch %d/%d: %s LR=%g', epoch, params.numEpochs, ...
249 | func2str(params.solver), params.learningRate);
250 |
251 | % [KH] preprocess: group by meta-class
252 | metaC = params.imdb.metaclass;
253 | Nmeta = numel(metaC);
254 | metaC_idx = cell(1, Nmeta);
255 | for i = 1:Nmeta
256 | G = metaC{i}.groups;
257 | G = G(randperm(numel(G)));
258 | metaC_idx{i} = cat(1, G{:});
259 | end
260 | batchSize = params.batchSize;
261 | assert(mod(batchSize, 2) == 0);
262 | metaPairs = [];
263 | for m = 1:Nmeta
264 | %metaPairs = [metaPairs; m*ones(Nmeta-m+1, 1), (m:Nmeta)'];
265 | metaPairs = [metaPairs; m*ones(Nmeta-m, 1), (m+1:Nmeta)'];
266 | end
267 | numPairs = size(metaPairs, 1);
268 | metaPairs = metaPairs(randperm(numPairs), :); % opt
269 | numBatches = numPairs * params.batchesPerMetaPair;
270 | logInfo('%d meta-classes, %d batches per meta-class pair', Nmeta, ...
271 | params.batchesPerMetaPair);
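  |     % Each batch draws batchSize/2 images from each of two distinct
  |     % meta-classes; negative pairs that share a meta-class tend to be
  |     % harder, which is the point of this 'hard' batch construction.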
272 |
273 | netopts = net.layers(end).block.opt;
274 | if strcmp(mode, 'train')
275 | if params.batchSize > netopts.maxGpuImgs
276 | mode = 'train_staged';
277 | else
278 | net.mode = 'normal' ;
279 | net.conserveMemory = true;
280 | end
281 | else
282 | net.mode = 'test' ;
283 | end
284 |
285 | start = tic; a = tic;
286 | objs = [];
287 | cur = zeros(1, Nmeta);
288 | t = 0;
289 | for p = 1:numPairs
290 | % get 2 meta-classes
291 | m1 = metaPairs(p, 1);
292 | m2 = metaPairs(p, 2);
293 | for k = 1:params.batchesPerMetaPair
294 | if m1 == m2
295 | I = cur(m1) + (1:batchSize);
296 | I = mod(I-1, metaC{m1}.num) + 1;
297 | cur(m1) = I(end);
298 | batch = metaC_idx{m1}(I);
299 | else
300 | % get equal # of imgs from both meta-class
301 | I1 = cur(m1) + (1:batchSize/2);
302 | I2 = cur(m2) + (1:batchSize/2);
303 | I1 = mod(I1-1, metaC{m1}.num) + 1;
304 | I2 = mod(I2-1, metaC{m2}.num) + 1;
305 | cur(m1) = I1(end);
306 | cur(m2) = I2(end);
307 | batch = [metaC_idx{m1}(I1); metaC_idx{m2}(I2)];
308 | end
309 | t = t + batchSize;
310 |
311 | % train 1 batch
312 | if strcmp(mode, 'train')
313 | % regular train mode (batchSize within GPU limit)
314 | inputs = params.getBatch(params.imdb, batch) ;
315 | net.eval(inputs, params.derOutputs) ;
316 |
317 | if ~isempty(parserv), parserv.sync() ; end
318 | state = accumulateGradients(net, state, params, batchSize, parserv) ;
319 |
320 | elseif strcmp(mode, 'train_staged')
321 | % staged backprop mode
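  |                 % The full batch does not fit on the GPU, so:
  |                 %   1) forward all sub-batches in test mode to collect the
  |                 %      whole embedding matrix ('feats_l2');
  |                 %   2) compute the FastAP gradient w.r.t. the embeddings
  |                 %      once, over the full batch;
  |                 %   3) re-forward each sub-batch and backpropagate its
  |                 %      slice of that gradient.
  |                 % This is valid because the loss depends on the network
  |                 % only through the embedding layer.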
322 | % stage 1: get embedding matrix, in chunks
323 | net.mode = 'test';
324 | ind = net.getVarIndex('feats_l2');
325 | net.vars(ind).precious = 1;
326 |
327 | labelMat = []; % label matrix
328 | embeddingMat = zeros(1, 1, netopts.dim, batchSize); % embedding matrix
329 | if numel(params.gpus) > 0
330 | embeddingMat = gpuArray(embeddingMat);
331 | end
332 |
333 | subBatchSize = netopts.maxGpuImgs;
334 | nSub = ceil(batchSize / subBatchSize);
335 | inputsCache = cell(1, nSub);
336 | for s = 1:nSub
337 | sub = (s-1)*subBatchSize+1 : min(batchSize, s*subBatchSize);
338 | subInputs = params.getBatch(params.imdb, batch(sub));
339 | inputsCache{s} = subInputs;
340 |
341 | net.eval(subInputs);
342 | embeddingMat(:, :, :, sub) = gather(net.vars(ind).value);
343 | labelMat = [labelMat; subInputs{4}];
344 | end
345 |
346 | % stage 2: compute gradient matrix
347 | [gradientMat, obj_staged] = net.layers(end).block.computeGrad(...
348 | embeddingMat, labelMat);
349 |
350 | % stage 3: accumulate gradients
351 | net.mode = 'normal';
352 | for s = 1:nSub
353 | sub = (s-1)*subBatchSize+1 : min(batchSize, s*subBatchSize);
354 | featsGrad = gradientMat(:, :, :, sub);
355 | net.eval(inputsCache{s}, {'feats_l2', featsGrad});
356 |
357 | if ~isempty(parserv), parserv.sync() ; end
358 | state = accumulateGradients(net, state, params, numel(sub), parserv);
359 | end
360 |             else
361 |                 % eval mode: fetch inputs here (only the train branch loads them above)
  |                 inputs = params.getBatch(params.imdb, batch) ;
362 |                 net.eval(inputs) ;
363 | end
364 |
365 | % Get statistics.
366 | time = toc(start) + adjustTime ;
367 | batchTime = time - stats.time ;
368 | stats.num = num ;
369 | stats.time = time ;
370 | stats = params.extractStatsFn(stats,net) ;
371 | currentSpeed = batchSize / batchTime ;
372 | averageSpeed = t / time ;
373 | if strcmp(mode, 'train_staged')
374 | objs = [objs, obj_staged];
375 | else
376 | objs = [objs, stats.objective];
377 | end
378 | if toc(a) > 30
379 | %if strcmp(mode, 'train_staged'), fprintf('\n'); end
380 | fprintf('%s-%s ep%02d: %3d/%3d (%d) %.1fHz', ...
381 | net.layers(end).block.opt.dataset, mode, epoch, ...
382 | fix(t/batchSize), numBatches, batchSize, averageSpeed) ;
383 | if strcmp(mode, 'train_staged')
384 | fprintf(' obj: %.3f (%.3f)\n', obj_staged, mean(objs(~isnan(objs))));
385 | else
386 | fprintf(' obj: %.3f (%.3f)\n', stats.objective, mean(objs(~isnan(objs))));
387 | end
388 | a = tic;
389 | end
390 | end % for k
391 | end % for p
392 | if strcmp(mode, 'train_staged')
393 | mode = 'train';
394 | end
395 | logInfo('Epoch %d, Avg %s obj = %g', epoch, mode, mean(objs(~isnan(objs))));
396 | end % if N>0
397 |
398 | % Save back to state.
399 | state.stats.(mode) = stats ;
400 | if params.profile
401 | if numGpus <= 1
402 | state.prof.(mode) = profile('info') ;
403 | profile off ;
404 | else
405 | state.prof.(mode) = mpiprofile('info');
406 | mpiprofile off ;
407 | end
408 | end
409 | if ~params.saveSolverState
410 | state.solverState = [] ;
411 | else
412 | for i = 1:numel(state.solverState)
413 | s = state.solverState{i} ;
414 | if isnumeric(s)
415 | state.solverState{i} = gather(s) ;
416 | elseif isstruct(s)
417 | state.solverState{i} = structfun(@gather, s, 'UniformOutput', false) ;
418 | end
419 | end
420 | end
421 |
422 | net.reset() ;
423 | net.move('cpu') ;
424 |
425 | % -------------------------------------------------------------------------
426 | function state = accumulateGradients(net, state, params, batchSize, parserv)
427 | % -------------------------------------------------------------------------
428 | numGpus = numel(params.gpus) ;
429 | otherGpus = setdiff(1:numGpus, labindex) ;
430 |
431 | for p=1:numel(net.params)
432 |
433 | if ~isempty(parserv)
434 | parDer = parserv.pullWithIndex(p) ;
435 | else
436 | parDer = net.params(p).der ;
437 | end
438 |
439 | switch net.params(p).trainMethod
440 | case 'average' % mainly for batch normalization
441 | thisLR = net.params(p).learningRate ;
442 | net.params(p).value = vl_taccum(...
443 | 1 - thisLR, net.params(p).value, ...
444 | (thisLR/batchSize/net.params(p).fanout), parDer) ;
445 |
446 | case 'gradient'
447 | thisDecay = params.weightDecay * net.params(p).weightDecay ;
448 | thisLR = params.learningRate * net.params(p).learningRate ;
449 |
450 | if thisLR>0 || thisDecay>0
451 | % Normalize gradient and incorporate weight decay.
452 | parDer = vl_taccum(1/batchSize, parDer, ...
453 | thisDecay, net.params(p).value) ;
454 |
455 | if isempty(params.solver)
456 | % Default solver is the optimised SGD.
457 | % Update momentum.
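  |                         % v <- momentum*v - parDer, then
  |                         %   Nesterov:  w <- w + lr*(momentum*v - parDer)
  |                         %   standard:  w <- w + lr*v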
458 | state.solverState{p} = vl_taccum(...
459 | params.momentum, state.solverState{p}, ...
460 | -1, parDer) ;
461 |
462 | % Nesterov update (aka one step ahead).
463 | if params.nesterovUpdate
464 | delta = params.momentum * state.solverState{p} - parDer ;
465 | else
466 | delta = state.solverState{p} ;
467 | end
468 |
469 | % Update parameters.
470 | net.params(p).value = vl_taccum(...
471 | 1, net.params(p).value, thisLR, delta) ;
472 |
473 | else
474 | % call solver function to update weights
475 | [net.params(p).value, state.solverState{p}] = ...
476 | params.solver(net.params(p).value, state.solverState{p}, ...
477 | parDer, params.solverOpts, thisLR) ;
478 | end
479 | end
480 | otherwise
481 | error('Unknown training method ''%s'' for parameter ''%s''.', ...
482 | net.params(p).trainMethod, ...
483 | net.params(p).name) ;
484 | end
485 | end
486 |
487 | % -------------------------------------------------------------------------
488 | function stats = accumulateStats(stats_)
489 | % -------------------------------------------------------------------------
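  | % stats_ holds one stats struct per worker (a Composite under spmd); each
  | % field is averaged over workers, weighted by the sample count 'num'.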
490 | for s = {'train', 'val'}
491 | s = char(s) ;
492 | total = 0 ;
493 |
494 | % initialize stats structure with same fields and same order as
495 | % stats_{1}
496 | stats__ = stats_{1} ;
497 | names = fieldnames(stats__.(s))' ;
498 | values = zeros(1, numel(names)) ;
499 | fields = cat(1, names, num2cell(values)) ;
500 | stats.(s) = struct(fields{:}) ;
501 |
502 | for g = 1:numel(stats_)
503 | stats__ = stats_{g} ;
504 | num__ = stats__.(s).num ;
505 | total = total + num__ ;
506 |
507 | for f = setdiff(fieldnames(stats__.(s))', 'num')
508 | f = char(f) ;
509 | stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;
510 |
511 | if g == numel(stats_)
512 | stats.(s).(f) = stats.(s).(f) / total ;
513 | end
514 | end
515 | end
516 | stats.(s).num = total ;
517 | end
518 |
519 | % -------------------------------------------------------------------------
520 | function stats = extractStats(stats, net)
521 | % -------------------------------------------------------------------------
522 | sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ;
523 | for i = 1:numel(sel)
524 | if net.layers(sel(i)).block.ignoreAverage, continue; end;
525 | stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
526 | end
527 |
528 | % -------------------------------------------------------------------------
529 | function saveState(fileName, net_, state)
530 | % -------------------------------------------------------------------------
531 | net = net_.saveobj() ;
532 | save(fileName, 'net', 'state') ;
533 | logInfo('saved: %s', fileName);
534 |
535 | % -------------------------------------------------------------------------
536 | function saveStats(fileName, stats)
537 | % -------------------------------------------------------------------------
538 | if exist(fileName, 'file')
539 | try, save(fileName, 'stats', '-append') ; end
540 | else
541 | try, save(fileName, 'stats') ; end
542 | end
543 |
544 | % -------------------------------------------------------------------------
545 | function [net, state, stats] = loadState(fileName)
546 | % -------------------------------------------------------------------------
547 | load(fileName, 'net', 'state', 'stats') ;
548 | net = dagnn.DagNN.loadobj(net) ;
549 | if isempty(whos('stats'))
550 | warning('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
551 | fileName) ;
552 | end
553 |
554 | % -------------------------------------------------------------------------
555 | function epoch = findLastCheckpoint(modelDir)
556 | % -------------------------------------------------------------------------
557 | list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
558 | tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
559 | epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
560 | epoch = max([epoch 0]) ;
561 |
562 | % -------------------------------------------------------------------------
563 | function switchFigure(n)
564 | % -------------------------------------------------------------------------
565 | if get(0,'CurrentFigure') ~= n
566 | try
567 | set(0,'CurrentFigure',n) ;
568 | catch
569 | figure(n) ;
570 | end
571 | end
572 |
573 | % -------------------------------------------------------------------------
574 | function clearMex()
575 | % -------------------------------------------------------------------------
576 | clear vl_tmove vl_imreadjpeg ;
577 |
578 | % -------------------------------------------------------------------------
579 | function prepareGPUs(opts, cold)
580 | % -------------------------------------------------------------------------
581 | numGpus = numel(opts.gpus) ;
582 | if numGpus > 1
583 | % check parallel pool integrity as it could have timed out
584 | pool = gcp('nocreate') ;
585 | if ~isempty(pool) && pool.NumWorkers ~= numGpus
586 | delete(pool) ;
587 | end
588 | pool = gcp('nocreate') ;
589 | if isempty(pool)
590 | parpool('local', numGpus) ;
591 | cold = true ;
592 | end
593 |
594 | end
595 | if numGpus >= 1 && cold
596 | fprintf('%s: resetting GPU\n', mfilename)
597 | clearMex() ;
598 | if numGpus == 1
599 | gpuDevice(opts.gpus)
600 | else
601 | spmd
602 | clearMex() ;
603 | gpuDevice(opts.gpus(labindex))
604 | end
605 | end
606 | end
607 |
--------------------------------------------------------------------------------
/matlab/train_hard_vid.m:
--------------------------------------------------------------------------------
1 | function [net,stats] = train_hard_vid(net, imdb, getBatch, varargin)
2 | %TRAIN_HARD_VID Trains a CNN with FastAP using meta-class batches (VehicleID)
3 | %    Adapted from MatConvNet's CNN_TRAIN_DAG(), which is similar to
4 | %    CNN_TRAIN() but works with the DagNN wrapper instead of SimpleNN.
5 |
6 | % Copyright (C) 2014-16 Andrea Vedaldi.
7 | % All rights reserved.
8 | %
9 | % This file is part of the VLFeat library and is made available under
10 | % the terms of the BSD license (see the COPYING file).
11 | addpath(fullfile(vl_rootnn, 'examples'));
12 |
13 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%%
14 | opts.saveInterval = 2;
15 | opts.batchesPerMetaPair = 0;
16 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%%
17 |
18 | opts.expDir = fullfile('data','exp') ;
19 | opts.continue = true ;
20 | opts.batchSize = 256 ;
21 | opts.numSubBatches = 1 ;
22 | opts.train = [] ;
23 | opts.val = [] ;
24 | opts.gpus = [] ;
25 | opts.prefetch = false ;
26 | opts.epochSize = inf;
27 | opts.numEpochs = 300 ;
28 | opts.learningRate = 0.001 ;
29 | opts.weightDecay = 0.0005 ;
30 |
31 | opts.solver = [] ; % Empty array means use the default SGD solver
32 | [opts, varargin] = vl_argparse(opts, varargin) ;
33 | if ~isempty(opts.solver)
34 | assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,...
35 | 'Invalid solver; expected a function handle with two outputs.') ;
36 | % Call without input arguments, to get default options
37 | opts.solverOpts = opts.solver() ;
38 | end
39 |
40 | opts.momentum = 0.9 ;
41 | opts.saveSolverState = true ;
42 | opts.nesterovUpdate = false ;
43 | opts.randomSeed = 0 ;
44 | opts.profile = false ;
45 | opts.parameterServer.method = 'mmap' ;
46 | opts.parameterServer.prefix = 'mcn' ;
47 |
48 | opts.derOutputs = {'objective', 1} ;
49 | opts.extractStatsFn = @extractStats ;
50 | opts.plotStatistics = true;
51 | opts.postEpochFn = [] ; % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change
52 | opts = vl_argparse(opts, varargin) ;
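  | % Example (hypothetical) postEpochFn that halves the learning rate every
  | % 10 epochs and keeps the training set unchanged:
  | %
  | %   function [lr, params] = myPostEpoch(net, params, state)
  | %       lr = params.learningRate;
  | %       if mod(params.epoch, 10) == 0, lr = lr / 2; end
  | %   end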
53 |
54 | if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
55 | if isempty(opts.train), opts.train = find(imdb.images.set(:,1)==1) ; end
56 | if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
57 | if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train)
58 | opts.train = [] ;
59 | end
60 | if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val)
61 | opts.val = [] ;
62 | end
63 |
64 | % -------------------------------------------------------------------------
65 | % Initialization
66 | % -------------------------------------------------------------------------
67 |
68 | evaluateMode = isempty(opts.train) ;
69 | if ~evaluateMode
70 | if isempty(opts.derOutputs)
71 | error('DEROUTPUTS must be specified when training.\n') ;
72 | end
73 | end
74 |
75 | % -------------------------------------------------------------------------
76 | % Train and validate
77 | % -------------------------------------------------------------------------
78 |
79 | modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
80 | modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;
81 |
82 | start = opts.continue * findLastCheckpoint(opts.expDir) ;
83 | if start >= 1
84 | fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
85 | % [KH] make sure to use the input opts, saved maybe outdated
86 | opt = net.layers(end).block.opt;
87 | [net, state, stats] = loadState(modelPath(start)) ;
88 | net.layers(end).block.opt = opt;
89 | else
90 | state = [] ;
91 | end
92 |
93 | for epoch=start+1:opts.numEpochs
94 |
95 | % Set the random seed based on the epoch and opts.randomSeed.
96 | % This is important for reproducibility, including when training
97 | % is restarted from a checkpoint.
98 |
99 | rng(epoch + opts.randomSeed) ;
100 | prepareGPUs(opts, epoch == start+1) ;
101 |
102 | % Train for one epoch.
103 | params = opts ;
104 | params.epoch = epoch ;
105 | params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
106 | params.train = opts.train; %(randperm(numel(opts.train))) ; % shuffle
107 | params.train = params.train; %(1:min(opts.epochSize, numel(opts.train)));
108 | params.val = opts.val; %(randperm(numel(opts.val))) ;
109 | params.imdb = imdb ;
110 | params.getBatch = getBatch ;
111 |
112 | if numel(opts.gpus) <= 1
113 | [net, state] = processEpoch(net, state, params, 'train') ;
114 | [net, state] = processEpoch(net, state, params, 'val') ;
115 | if ~evaluateMode && ~mod(epoch, opts.saveInterval)
116 | saveState(modelPath(epoch), net, state) ;
117 | end
118 | lastStats = state.stats ;
119 | else
120 | spmd
121 | [net, state] = processEpoch(net, state, params, 'train') ;
122 | [net, state] = processEpoch(net, state, params, 'val') ;
123 | if labindex == 1 && ~evaluateMode && ~mod(epoch, opts.saveInterval)
124 | saveState(modelPath(epoch), net, state) ;
125 | end
126 | lastStats = state.stats ;
127 | end
128 | lastStats = accumulateStats(lastStats) ;
129 | end
130 |
131 | stats.train(epoch) = lastStats.train ;
132 | stats.val(epoch) = lastStats.val ;
133 | clear lastStats ;
134 | if ~mod(epoch, opts.saveInterval)
135 | saveStats(modelPath(epoch), stats) ;
136 | end
137 |
138 | if opts.plotStatistics
139 | switchFigure(1) ; clf ;
140 | plots = setdiff(...
141 | cat(2,...
142 | fieldnames(stats.train)', ...
143 | fieldnames(stats.val)'), {'num', 'time'}) ;
144 | for p = plots
145 | p = char(p) ;
146 | values = zeros(0, epoch) ;
147 | leg = {} ;
148 | for f = {'train', 'val'}
149 | f = char(f) ;
150 | if isfield(stats.(f), p)
151 | tmp = [stats.(f).(p)] ;
152 | values(end+1,:) = tmp(1,:)' ;
153 | leg{end+1} = f ;
154 | end
155 | end
156 | subplot(1,numel(plots),find(strcmp(p,plots))) ;
157 | plot(1:epoch, values','o-') ;
158 | xlabel('epoch') ;
159 | title(p) ;
160 | legend(leg{:}) ;
161 | grid on ;
162 | end
163 | drawnow ;
164 | print(1, modelFigPath, '-dpdf') ;
165 | end
166 |
167 | if ~isempty(opts.postEpochFn)
168 | if nargout(opts.postEpochFn) == 0
169 | opts.postEpochFn(net, params, state) ;
170 | else
171 | [lr, params] = opts.postEpochFn(net, params, state) ;
172 | if ~isempty(lr), opts.learningRate = lr; end
173 | if opts.learningRate == 0, break; end
174 | opts.train = params.train;
175 | end
176 | end
177 | end
178 |
179 | % With multiple GPUs, return one copy
180 | if isa(net, 'Composite'), net = net{1} ; end
181 |
182 | % -------------------------------------------------------------------------
183 | function [net, state] = processEpoch(net, state, params, mode)
184 | % -------------------------------------------------------------------------
185 | % Note that net is not strictly needed as an output argument as net
186 | % is a handle class. However, this fixes an aliasing issue in the
187 | % spmd caller.
188 |
189 | % initialize with momentum 0
190 | if isempty(state) || isempty(state.solverState)
191 | state.solverState = cell(1, numel(net.params)) ;
192 | state.solverState(:) = {0} ;
193 | end
194 |
195 | % move CNN to GPU as needed
196 | numGpus = numel(params.gpus) ;
197 | if numGpus >= 1
198 | net.move('gpu') ;
199 | for i = 1:numel(state.solverState)
200 | s = state.solverState{i} ;
201 | if isnumeric(s)
202 | state.solverState{i} = gpuArray(s) ;
203 | elseif isstruct(s)
204 | state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ;
205 | end
206 | end
207 | end
208 | if numGpus > 1
209 | parserv = ParameterServer(params.parameterServer) ;
210 | net.setParameterServer(parserv) ;
211 | else
212 | parserv = [] ;
213 | end
214 |
215 | % profile
216 | if params.profile
217 | if numGpus <= 1
218 | profile clear ;
219 | profile on ;
220 | else
221 | mpiprofile reset ;
222 | mpiprofile on ;
223 | end
224 | end
225 |
226 | num = 0 ;
227 | epoch = params.epoch ;
228 | subset = params.(mode) ;
229 | adjustTime = 0 ;
230 |
231 | stats.num = 0 ; % return something even if subset = []
232 | stats.time = 0 ;
233 |
234 | % -----------------------------------------
235 | % -----------------------------------------
236 | if numel(subset) > 0
237 |     logInfo('Epoch %d/%d: %s LR=%g', epoch, params.numEpochs, ...
238 | func2str(params.solver), params.learningRate);
239 |
240 | % [KH] preprocess: group by meta-class
241 | metaC = params.imdb.metaclass;
242 | assert(metaC{1}.id == 0); % NOTE id==0 means missing metaclass label
243 | Nmeta = numel(metaC);
244 | metaC_idx = cell(1, Nmeta);
245 |
246 | batchSize = params.batchSize;
247 | assert(mod(batchSize, 2) == 0);
248 | numBatches = zeros(Nmeta, 1);
249 | for i = 1:Nmeta
250 | % randomize the groups (individual classes) within this metaclass
251 | G = metaC{i}.groups;
252 | G = G(randperm(numel(G)));
253 | metaC_idx{i} = cat(1, G{:});
254 | if i > 1
255 | numBatches(i) = ceil(2*metaC{i}.num/batchSize);
256 | end
257 | end
258 | logInfo('%d meta-classes, %d batches', Nmeta-1, sum(numBatches));
259 |
260 | netopts = net.layers(end).block.opt;
261 | if strcmp(mode, 'train')
262 | if params.batchSize > netopts.maxGpuImgs
263 | mode = 'train_staged';
264 | else
265 | net.mode = 'normal' ;
266 | net.conserveMemory = true;
267 | end
268 | else
269 | net.mode = 'test' ;
270 | end
271 |
272 | start = tic; a = tic;
273 | objs = [];
274 | cur = zeros(1, Nmeta);
275 | t = 0;
276 | for m = 2:Nmeta, for p = 1:numBatches(m)
277 | % get half from metaclass #m, half from metaclass #1 (id:0)
278 | I1 = cur(1) + (1:batchSize/2);
279 | I2 = cur(m) + (1:batchSize/2);
280 | I1 = mod(I1-1, metaC{1}.num) + 1;
281 | I2 = mod(I2-1, metaC{m}.num) + 1;
282 | cur(1) = I1(end);
283 | cur(m) = I2(end);
284 | batch = [metaC_idx{1}(I1); metaC_idx{m}(I2)];
285 | t = t + batchSize;
286 |
287 | % assemble batch
288 | inputs = params.getBatch(params.imdb, batch) ;
289 | if strcmp(mode, 'train')
290 | net.eval(inputs, params.derOutputs) ;
291 |
292 | if ~isempty(parserv), parserv.sync() ; end
293 | state = accumulateGradients(net, state, params, batchSize, parserv) ;
294 |
295 | elseif strcmp(mode, 'train_staged')
296 | % staged backprop mode
297 | % stage 1: get embedding matrix, in chunks
298 | net.mode = 'test';
299 | ind = net.getVarIndex('feats_l2');
300 | net.vars(ind).precious = 1;
301 |
302 | labelMat = []; % label matrix
303 | embeddingMat = zeros(1, 1, netopts.dim, batchSize); % embedding matrix
304 | if numel(params.gpus) > 0
305 | embeddingMat = gpuArray(embeddingMat);
306 | end
307 |
308 | subBatchSize = netopts.maxGpuImgs;
309 | nSub = ceil(batchSize / subBatchSize);
310 | inputsCache = cell(1, nSub);
311 | for s = 1:nSub
312 | sub = (s-1)*subBatchSize+1 : min(batchSize, s*subBatchSize);
313 | subInputs = params.getBatch(params.imdb, batch(sub));
314 | inputsCache{s} = subInputs;
315 |
316 | net.eval(subInputs);
317 | embeddingMat(:, :, :, sub) = gather(net.vars(ind).value);
318 | labelMat = [labelMat; subInputs{4}];
319 | end
320 |
321 | % stage 2: compute gradient matrix
322 | [gradientMat, obj_staged] = net.layers(end).block.computeGrad(...
323 | embeddingMat, labelMat);
324 |
325 | % stage 3: accumulate gradients
326 | net.mode = 'normal';
327 | for s = 1:nSub
328 | sub = (s-1)*subBatchSize+1 : min(batchSize, s*subBatchSize);
329 | featsGrad = gradientMat(:, :, :, sub);
330 | net.eval(inputsCache{s}, {'feats_l2', featsGrad});
331 |
332 | if ~isempty(parserv), parserv.sync() ; end
333 | state = accumulateGradients(net, state, params, numel(sub), parserv);
334 | end
335 | else
336 | % eval mode
337 | net.eval(inputs) ;
338 | end
339 |
340 | % Get statistics.
341 | time = toc(start) + adjustTime ;
342 | batchTime = time - stats.time ;
343 | stats.num = num ;
344 | stats.time = time ;
345 | stats = params.extractStatsFn(stats,net) ;
346 | averageSpeed = t / time ;
347 | objs = [objs, stats.objective];
348 | if toc(a) > 30
349 | fprintf('%s-%s ep%02d: %3d/%3d (%d) %.1fHz', ...
350 | net.layers(end).block.opt.dataset, mode, epoch, ...
351 | fix(t/batchSize), sum(numBatches), batchSize, averageSpeed) ;
352 | fprintf(' obj: %.3f (%.3f)\n', stats.objective, mean(objs(~isnan(objs))));
353 | a = tic;
354 | end
355 |     end, end % for p, for m
356 | if strcmp(mode, 'train_staged')
357 | mode = 'train';
358 | end
359 | logInfo('Epoch %d, Avg %s obj = %g', epoch, mode, mean(objs(~isnan(objs))));
360 | end % if N>0
361 |
362 | % Save back to state.
363 | state.stats.(mode) = stats ;
364 | if params.profile
365 | if numGpus <= 1
366 | state.prof.(mode) = profile('info') ;
367 | profile off ;
368 | else
369 | state.prof.(mode) = mpiprofile('info');
370 | mpiprofile off ;
371 | end
372 | end
373 | if ~params.saveSolverState
374 | state.solverState = [] ;
375 | else
376 | for i = 1:numel(state.solverState)
377 | s = state.solverState{i} ;
378 | if isnumeric(s)
379 | state.solverState{i} = gather(s) ;
380 | elseif isstruct(s)
381 | state.solverState{i} = structfun(@gather, s, 'UniformOutput', false) ;
382 | end
383 | end
384 | end
385 |
386 | net.reset() ;
387 | net.move('cpu') ;
388 |
389 | % -------------------------------------------------------------------------
390 | function state = accumulateGradients(net, state, params, batchSize, parserv)
391 | % -------------------------------------------------------------------------
392 | numGpus = numel(params.gpus) ;
393 | otherGpus = setdiff(1:numGpus, labindex) ;
394 |
395 | for p=1:numel(net.params)
396 |
397 | if ~isempty(parserv)
398 | parDer = parserv.pullWithIndex(p) ;
399 | else
400 | parDer = net.params(p).der ;
401 | end
402 |
403 | switch net.params(p).trainMethod
404 | case 'average' % mainly for batch normalization
405 | thisLR = net.params(p).learningRate ;
406 | net.params(p).value = vl_taccum(...
407 | 1 - thisLR, net.params(p).value, ...
408 | (thisLR/batchSize/net.params(p).fanout), parDer) ;
409 |
410 | case 'gradient'
411 | thisDecay = params.weightDecay * net.params(p).weightDecay ;
412 | thisLR = params.learningRate * net.params(p).learningRate ;
413 |
414 | if thisLR>0 || thisDecay>0
415 | % Normalize gradient and incorporate weight decay.
417 |                     parDer = vl_taccum(1/batchSize, parDer, ...
418 |                         thisDecay, net.params(p).value) ;
420 |
421 | if isempty(params.solver)
422 | % Default solver is the optimised SGD.
423 | % Update momentum.
424 | state.solverState{p} = vl_taccum(...
425 | params.momentum, state.solverState{p}, ...
426 | -1, parDer) ;
427 |
428 | % Nesterov update (aka one step ahead).
429 | if params.nesterovUpdate
430 | delta = params.momentum * state.solverState{p} - parDer ;
431 | else
432 | delta = state.solverState{p} ;
433 | end
434 |
435 | % Update parameters.
436 | net.params(p).value = vl_taccum(...
437 | 1, net.params(p).value, thisLR, delta) ;
438 |
439 | else
440 | % call solver function to update weights
441 | [net.params(p).value, state.solverState{p}] = ...
442 | params.solver(net.params(p).value, state.solverState{p}, ...
443 | parDer, params.solverOpts, thisLR) ;
444 | end
445 | end
446 | otherwise
447 | error('Unknown training method ''%s'' for parameter ''%s''.', ...
448 | net.params(p).trainMethod, ...
449 | net.params(p).name) ;
450 | end
451 | end
452 |
453 | % -------------------------------------------------------------------------
454 | function stats = accumulateStats(stats_)
455 | % -------------------------------------------------------------------------
456 | for s = {'train', 'val'}
457 | s = char(s) ;
458 | total = 0 ;
459 |
460 | % initialize stats structure with same fields and same order as
461 | % stats_{1}
462 | stats__ = stats_{1} ;
463 | names = fieldnames(stats__.(s))' ;
464 | values = zeros(1, numel(names)) ;
465 | fields = cat(1, names, num2cell(values)) ;
466 | stats.(s) = struct(fields{:}) ;
467 |
468 | for g = 1:numel(stats_)
469 | stats__ = stats_{g} ;
470 | num__ = stats__.(s).num ;
471 | total = total + num__ ;
472 |
473 | for f = setdiff(fieldnames(stats__.(s))', 'num')
474 | f = char(f) ;
475 | stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;
476 |
477 | if g == numel(stats_)
478 | stats.(s).(f) = stats.(s).(f) / total ;
479 | end
480 | end
481 | end
482 | stats.(s).num = total ;
483 | end
484 |
485 | % -------------------------------------------------------------------------
486 | function stats = extractStats(stats, net)
487 | % -------------------------------------------------------------------------
488 | sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ;
489 | for i = 1:numel(sel)
490 | if net.layers(sel(i)).block.ignoreAverage, continue; end;
491 | stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
492 | end
493 |
494 | % -------------------------------------------------------------------------
495 | function saveState(fileName, net_, state)
496 | % -------------------------------------------------------------------------
497 | net = net_.saveobj() ;
498 | save(fileName, 'net', 'state') ;
499 | logInfo('saved: %s', fileName);
500 |
501 | % -------------------------------------------------------------------------
502 | function saveStats(fileName, stats)
503 | % -------------------------------------------------------------------------
504 | if exist(fileName, 'file')
505 | try, save(fileName, 'stats', '-append') ; end
506 | else
507 | try, save(fileName, 'stats') ; end
508 | end
509 |
510 | % -------------------------------------------------------------------------
511 | function [net, state, stats] = loadState(fileName)
512 | % -------------------------------------------------------------------------
513 | load(fileName, 'net', 'state', 'stats') ;
514 | net = dagnn.DagNN.loadobj(net) ;
515 | if isempty(whos('stats'))
516 | warning('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
517 | fileName) ;
518 | end
519 |
520 | % -------------------------------------------------------------------------
521 | function epoch = findLastCheckpoint(modelDir)
522 | % -------------------------------------------------------------------------
523 | list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
524 | tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
525 | epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
526 | epoch = max([epoch 0]) ;
527 |
528 | % -------------------------------------------------------------------------
529 | function switchFigure(n)
530 | % -------------------------------------------------------------------------
531 | if get(0,'CurrentFigure') ~= n
532 | try
533 | set(0,'CurrentFigure',n) ;
534 | catch
535 | figure(n) ;
536 | end
537 | end
538 |
539 | % -------------------------------------------------------------------------
540 | function clearMex()
541 | % -------------------------------------------------------------------------
542 | clear vl_tmove vl_imreadjpeg ;
543 |
544 | % -------------------------------------------------------------------------
545 | function prepareGPUs(opts, cold)
546 | % -------------------------------------------------------------------------
547 | numGpus = numel(opts.gpus) ;
548 | if numGpus > 1
549 | % check parallel pool integrity as it could have timed out
550 | pool = gcp('nocreate') ;
551 | if ~isempty(pool) && pool.NumWorkers ~= numGpus
552 | delete(pool) ;
553 | end
554 | pool = gcp('nocreate') ;
555 | if isempty(pool)
556 | parpool('local', numGpus) ;
557 | cold = true ;
558 | end
559 |
560 | end
561 | if numGpus >= 1 && cold
562 | fprintf('%s: resetting GPU\n', mfilename)
563 | clearMex() ;
564 | if numGpus == 1
565 | gpuDevice(opts.gpus)
566 | else
567 | spmd
568 | clearMex() ;
569 | gpuDevice(opts.gpus(labindex))
570 | end
571 | end
572 | end
573 |
--------------------------------------------------------------------------------
/matlab/train_rand.m:
--------------------------------------------------------------------------------
1 | function [net,stats] = train_rand(net, imdb, getBatch, varargin)
2 | %TRAIN_RAND Trains a CNN with FastAP using plain (non meta-class) mini-batches
3 | %    Adapted from MatConvNet's CNN_TRAIN_DAG(), which is similar to
4 | %    CNN_TRAIN() but works with the DagNN wrapper instead of SimpleNN.
5 |
6 | % Copyright (C) 2014-16 Andrea Vedaldi.
7 | % All rights reserved.
8 | %
9 | % This file is part of the VLFeat library and is made available under
10 | % the terms of the BSD license (see the COPYING file).
11 | addpath(fullfile(vl_rootnn, 'examples'));
12 |
13 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%%
14 | opts.saveInterval = 2;
15 | %%%%%%%%%%%%% new fields %%%%%%%%%%%%%
16 |
17 | opts.expDir = fullfile('data','exp') ;
18 | opts.continue = true ;
19 | opts.batchSize = 256 ;
20 | opts.numSubBatches = 1 ;
21 | opts.train = [] ;
22 | opts.val = [] ;
23 | opts.gpus = [] ;
24 | opts.prefetch = false ;
25 | opts.epochSize = inf;
26 | opts.numEpochs = 300 ;
27 | opts.learningRate = 0.001 ;
28 | opts.weightDecay = 0.0005 ;
29 |
30 | opts.solver = [] ; % Empty array means use the default SGD solver
31 | [opts, varargin] = vl_argparse(opts, varargin) ;
32 | if ~isempty(opts.solver)
33 | assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,...
34 | 'Invalid solver; expected a function handle with two outputs.') ;
35 | % Call without input arguments, to get default options
36 | opts.solverOpts = opts.solver() ;
37 | end
38 |
39 | opts.momentum = 0.9 ;
40 | opts.saveSolverState = true ;
41 | opts.nesterovUpdate = false ;
42 | opts.randomSeed = 0 ;
43 | opts.profile = false ;
44 | opts.parameterServer.method = 'mmap' ;
45 | opts.parameterServer.prefix = 'mcn' ;
46 |
47 | opts.derOutputs = {'objective', 1} ;
48 | opts.extractStatsFn = @extractStats ;
49 | opts.plotStatistics = true;
50 | opts.postEpochFn = [] ; % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change
51 | opts = vl_argparse(opts, varargin) ;
52 |
53 | if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
54 | if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
55 | if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
56 | if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train)
57 | opts.train = [] ;
58 | end
59 | if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val)
60 | opts.val = [] ;
61 | end
62 |
63 | % -------------------------------------------------------------------------
64 | % Initialization
65 | % -------------------------------------------------------------------------
66 |
67 | evaluateMode = isempty(opts.train) ;
68 | if ~evaluateMode
69 | if isempty(opts.derOutputs)
70 | error('DEROUTPUTS must be specified when training.\n') ;
71 | end
72 | end
73 |
74 | % -------------------------------------------------------------------------
75 | % Train and validate
76 | % -------------------------------------------------------------------------
77 |
78 | modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
79 | modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;
80 |
81 | start = opts.continue * findLastCheckpoint(opts.expDir) ;
82 | if start >= 1
83 | fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
84 | % [KH] make sure to use the input opts, saved maybe outdated
85 | opt = net.layers(end).block.opt;
86 | [net, state, stats] = loadState(modelPath(start)) ;
87 | net.layers(end).block.opt = opt;
88 | else
89 | state = [] ;
90 | end
91 |
92 | for epoch=start+1:opts.numEpochs
93 |
94 | % Set the random seed based on the epoch and opts.randomSeed.
95 | % This is important for reproducibility, including when training
96 | % is restarted from a checkpoint.
97 |
98 | rng(epoch + opts.randomSeed) ;
99 | prepareGPUs(opts, epoch == start+1) ;
100 |
101 | % Train for one epoch.
102 | params = opts ;
103 | params.epoch = epoch ;
104 | params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
105 | params.train = opts.train; %(randperm(numel(opts.train))) ; % shuffle
106 | params.train = params.train; %(1:min(opts.epochSize, numel(opts.train)));
107 | params.val = opts.val; %(randperm(numel(opts.val))) ;
108 | params.imdb = imdb ;
109 | params.getBatch = getBatch ;
110 |
111 | if numel(opts.gpus) <= 1
112 | [net, state] = processEpoch(net, state, params, 'train') ;
113 | [net, state] = processEpoch(net, state, params, 'val') ;
114 | if ~evaluateMode && ~mod(epoch, opts.saveInterval)
115 | saveState(modelPath(epoch), net, state) ;
116 | end
117 | lastStats = state.stats ;
118 | else
119 | spmd
120 | [net, state] = processEpoch(net, state, params, 'train') ;
121 | [net, state] = processEpoch(net, state, params, 'val') ;
122 | if labindex == 1 && ~evaluateMode && ~mod(epoch, opts.saveInterval)
123 | saveState(modelPath(epoch), net, state) ;
124 | end
125 | lastStats = state.stats ;
126 | end
127 | lastStats = accumulateStats(lastStats) ;
128 | end
129 |
130 | stats.train(epoch) = lastStats.train ;
131 | stats.val(epoch) = lastStats.val ;
132 | clear lastStats ;
133 | saveStats(modelPath(epoch), stats) ;
134 |
135 | if opts.plotStatistics
136 | switchFigure(1) ; clf ;
137 | plots = setdiff(...
138 | cat(2,...
139 | fieldnames(stats.train)', ...
140 | fieldnames(stats.val)'), {'num', 'time'}) ;
141 | for p = plots
142 | p = char(p) ;
143 | values = zeros(0, epoch) ;
144 | leg = {} ;
145 | for f = {'train', 'val'}
146 | f = char(f) ;
147 | if isfield(stats.(f), p)
148 | tmp = [stats.(f).(p)] ;
149 | values(end+1,:) = tmp(1,:)' ;
150 | leg{end+1} = f ;
151 | end
152 | end
153 | subplot(1,numel(plots),find(strcmp(p,plots))) ;
154 | plot(1:epoch, values','o-') ;
155 | xlabel('epoch') ;
156 | title(p) ;
157 | legend(leg{:}) ;
158 | grid on ;
159 | end
160 | drawnow ;
161 | print(1, modelFigPath, '-dpdf') ;
162 | end
163 |
164 | if ~isempty(opts.postEpochFn)
165 | if nargout(opts.postEpochFn) == 0
166 | opts.postEpochFn(net, params, state) ;
167 | else
168 | [lr, params] = opts.postEpochFn(net, params, state) ;
169 | if ~isempty(lr), opts.learningRate = lr; end
170 | if opts.learningRate == 0, break; end
171 | opts.train = params.train;
172 | end
173 | end
174 | end
175 |
176 | % With multiple GPUs, return one copy
177 | if isa(net, 'Composite'), net = net{1} ; end
178 |
179 | % -------------------------------------------------------------------------
180 | function [net, state] = processEpoch(net, state, params, mode)
181 | % -------------------------------------------------------------------------
182 | % Note that net is not strictly needed as an output argument as net
183 | % is a handle class. However, this fixes an aliasing issue in the
184 | % spmd caller.
185 |
186 | % initialize with momentum 0
187 | if isempty(state) || isempty(state.solverState)
188 | state.solverState = cell(1, numel(net.params)) ;
189 | state.solverState(:) = {0} ;
190 | end
191 |
192 | % move CNN to GPU as needed
193 | numGpus = numel(params.gpus) ;
194 | if numGpus >= 1
195 | net.move('gpu') ;
196 | for i = 1:numel(state.solverState)
197 | s = state.solverState{i} ;
198 | if isnumeric(s)
199 | state.solverState{i} = gpuArray(s) ;
200 | elseif isstruct(s)
201 | state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ;
202 | end
203 | end
204 | end
205 | if numGpus > 1
206 | parserv = ParameterServer(params.parameterServer) ;
207 | net.setParameterServer(parserv) ;
208 | else
209 | parserv = [] ;
210 | end
211 |
212 | % profile
213 | if params.profile
214 | if numGpus <= 1
215 | profile clear ;
216 | profile on ;
217 | else
218 | mpiprofile reset ;
219 | mpiprofile on ;
220 | end
221 | end
222 |
223 | num = 0 ;
224 | epoch = params.epoch ;
225 | subset = params.(mode) ;
226 | adjustTime = 0 ;
227 |
228 | stats.num = 0 ; % return something even if subset = []
229 | stats.time = 0 ;
230 |
231 | if numel(subset) > 0
232 |     logInfo('Epoch %d/%d: %s LR=%g', epoch, params.numEpochs, ...
233 | func2str(params.solver), params.learningRate);
234 | start = tic; a = tic;
235 | objs = [];
236 | for t=1:params.batchSize:numel(subset)
237 | batchSize = min(params.batchSize, numel(subset) - t + 1) ;
238 |
239 | for s=1:params.numSubBatches
240 | % get this image batch and prefetch the next
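  |             % samples are interleaved across GPU labs and sub-batches: each
  |             % lab takes every (numSubBatches*numlabs)-th element of subset,
  |             % offset by its labindex and the sub-batch index s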
241 | batchStart = t + (labindex-1) + (s-1) * numlabs ;
242 | batchEnd = min(t+params.batchSize-1, numel(subset)) ;
243 | batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
244 | num = num + numel(batch) ;
245 | if numel(batch) == 0, continue ; end
246 |
247 | inputs = params.getBatch(params.imdb, batch) ;
248 |
249 | if params.prefetch
250 | if s == params.numSubBatches
251 | batchStart = t + (labindex-1) + params.batchSize ;
252 | batchEnd = min(t+2*params.batchSize-1, numel(subset)) ;
253 | else
254 | batchStart = batchStart + numlabs ;
255 | end
256 | nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
257 | params.getBatch(params.imdb, nextBatch) ;
258 | end
259 |
260 | if strcmp(mode, 'train')
261 | net.mode = 'normal' ;
262 | net.conserveMemory = true;
263 | net.accumulateParamDers = (s ~= 1) ;
264 | net.eval(inputs, params.derOutputs, 'holdOn', s < params.numSubBatches) ;
265 | else
266 | net.mode = 'test' ;
267 | net.eval(inputs) ;
268 | end
269 | end
270 |
271 | % Accumulate gradient.
272 | if strcmp(mode, 'train')
273 | if ~isempty(parserv), parserv.sync() ; end
274 | state = accumulateGradients(net, state, params, batchSize, parserv) ;
275 | end
276 |
277 | % Get statistics.
278 | time = toc(start) + adjustTime ;
279 | batchTime = time - stats.time ;
280 | stats.num = num ;
281 | stats.time = time ;
282 | stats = params.extractStatsFn(stats,net) ;
283 | currentSpeed = batchSize / batchTime ;
284 | averageSpeed = (t + batchSize - 1) / time ;
285 | if t == 3*params.batchSize + 1
286 | % compensate for the first three iterations, which are outliers
287 | adjustTime = 4*batchTime - time ;
288 | stats.time = time + adjustTime ;
289 | end
290 |
291 | if toc(a) > 30
292 | fprintf('%s-%s ep%02d: %3d/%3d (%d):', ...
293 | net.layers(end).block.opt.dataset, mode, epoch, ...
294 | fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize), ...
295 | params.batchSize) ;
296 | %fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
297 | fprintf(' %.1fHz', averageSpeed) ;
298 | for f = setdiff(fieldnames(stats)', {'num', 'time'})
299 | f = char(f) ;
300 | fprintf(' %s: %.3f', f, stats.(f)) ;
301 | end
302 | fprintf('\n') ;
303 | a = tic;
304 | end
305 | objs = [objs, stats.objective];
306 | end
307 | logInfo('Epoch %d, Avg %s obj = %g', epoch, mode, mean(objs(~isnan(objs))));
308 | end % if N>0
309 |
310 | % Save back to state.
311 | state.stats.(mode) = stats ;
312 | if params.profile
313 | if numGpus <= 1
314 | state.prof.(mode) = profile('info') ;
315 | profile off ;
316 | else
317 | state.prof.(mode) = mpiprofile('info');
318 | mpiprofile off ;
319 | end
320 | end
321 | if ~params.saveSolverState
322 | state.solverState = [] ;
323 | else
324 | for i = 1:numel(state.solverState)
325 | s = state.solverState{i} ;
326 | if isnumeric(s)
327 | state.solverState{i} = gather(s) ;
328 | elseif isstruct(s)
329 | state.solverState{i} = structfun(@gather, s, 'UniformOutput', false) ;
330 | end
331 | end
332 | end
333 |
334 | net.reset() ;
335 | net.move('cpu') ;
336 |
337 | % -------------------------------------------------------------------------
338 | function state = accumulateGradients(net, state, params, batchSize, parserv)
339 | % -------------------------------------------------------------------------
340 | numGpus = numel(params.gpus) ;
341 | otherGpus = setdiff(1:numGpus, labindex) ;
342 |
343 | for p=1:numel(net.params)
344 |
345 | if ~isempty(parserv)
346 | parDer = parserv.pullWithIndex(p) ;
347 | else
348 | parDer = net.params(p).der ;
349 | end
350 |
351 | switch net.params(p).trainMethod
352 | case 'average' % mainly for batch normalization
353 | thisLR = net.params(p).learningRate ;
354 | net.params(p).value = vl_taccum(...
355 | 1 - thisLR, net.params(p).value, ...
356 | (thisLR/batchSize/net.params(p).fanout), parDer) ;
357 |
358 | case 'gradient'
359 | thisDecay = params.weightDecay * net.params(p).weightDecay ;
360 | thisLR = params.learningRate * net.params(p).learningRate ;
361 |
362 | if thisLR>0 || thisDecay>0
363 | % Normalize gradient and incorporate weight decay.
365 |                     parDer = vl_taccum(1/batchSize, parDer, ...
366 |                         thisDecay, net.params(p).value) ;
368 |
369 | if isempty(params.solver)
370 | % Default solver is the optimised SGD.
371 | % Update momentum.
372 | state.solverState{p} = vl_taccum(...
373 | params.momentum, state.solverState{p}, ...
374 | -1, parDer) ;
375 |
376 | % Nesterov update (aka one step ahead).
377 | if params.nesterovUpdate
378 | delta = params.momentum * state.solverState{p} - parDer ;
379 | else
380 | delta = state.solverState{p} ;
381 | end
382 |
383 | % Update parameters.
384 | net.params(p).value = vl_taccum(...
385 | 1, net.params(p).value, thisLR, delta) ;
386 |
387 | else
388 | % call solver function to update weights
389 | [net.params(p).value, state.solverState{p}] = ...
390 | params.solver(net.params(p).value, state.solverState{p}, ...
391 | parDer, params.solverOpts, thisLR) ;
392 | end
393 | end
394 | otherwise
395 | error('Unknown training method ''%s'' for parameter ''%s''.', ...
396 | net.params(p).trainMethod, ...
397 | net.params(p).name) ;
398 | end
399 | end
400 |
401 | % -------------------------------------------------------------------------
402 | function stats = accumulateStats(stats_)
403 | % -------------------------------------------------------------------------
404 | for s = {'train', 'val'}
405 | s = char(s) ;
406 | total = 0 ;
407 |
408 | % initialize stats structure with same fields and same order as
409 | % stats_{1}
410 | stats__ = stats_{1} ;
411 | names = fieldnames(stats__.(s))' ;
412 | values = zeros(1, numel(names)) ;
413 | fields = cat(1, names, num2cell(values)) ;
414 | stats.(s) = struct(fields{:}) ;
415 |
416 | for g = 1:numel(stats_)
417 | stats__ = stats_{g} ;
418 | num__ = stats__.(s).num ;
419 | total = total + num__ ;
420 |
421 | for f = setdiff(fieldnames(stats__.(s))', 'num')
422 | f = char(f) ;
423 | stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;
424 |
425 | if g == numel(stats_)
426 | stats.(s).(f) = stats.(s).(f) / total ;
427 | end
428 | end
429 | end
430 | stats.(s).num = total ;
431 | end
432 |
433 | % -------------------------------------------------------------------------
434 | function stats = extractStats(stats, net)
435 | % -------------------------------------------------------------------------
436 | sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ;
437 | for i = 1:numel(sel)
438 | if net.layers(sel(i)).block.ignoreAverage, continue; end;
439 | stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
440 | end
441 |
442 | % -------------------------------------------------------------------------
443 | function saveState(fileName, net_, state)
444 | % -------------------------------------------------------------------------
445 | net = net_.saveobj() ;
446 | save(fileName, 'net', 'state') ;
447 | logInfo('saved: %s', fileName);
448 |
449 | % -------------------------------------------------------------------------
450 | function saveStats(fileName, stats)
451 | % -------------------------------------------------------------------------
452 | if exist(fileName, 'file')
453 | try, save(fileName, 'stats', '-append') ; end
454 | else
455 | try, save(fileName, 'stats') ; end
456 | end
457 |
458 | % -------------------------------------------------------------------------
459 | function [net, state, stats] = loadState(fileName)
460 | % -------------------------------------------------------------------------
461 | load(fileName, 'net', 'state', 'stats') ;
462 | net = dagnn.DagNN.loadobj(net) ;
463 | if isempty(whos('stats'))
464 | warning('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
465 | fileName) ;
466 | end
467 |
468 | % -------------------------------------------------------------------------
469 | function epoch = findLastCheckpoint(modelDir)
470 | % -------------------------------------------------------------------------
471 | list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
472 | tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
473 | epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
474 | epoch = max([epoch 0]) ;
475 |
476 | % -------------------------------------------------------------------------
477 | function switchFigure(n)
478 | % -------------------------------------------------------------------------
479 | if get(0,'CurrentFigure') ~= n
480 | try
481 | set(0,'CurrentFigure',n) ;
482 | catch
483 | figure(n) ;
484 | end
485 | end
486 |
487 | % -------------------------------------------------------------------------
488 | function clearMex()
489 | % -------------------------------------------------------------------------
490 | clear vl_tmove vl_imreadjpeg ;
491 |
492 | % -------------------------------------------------------------------------
493 | function prepareGPUs(opts, cold)
494 | % -------------------------------------------------------------------------
495 | numGpus = numel(opts.gpus) ;
496 | if numGpus > 1
497 | % check parallel pool integrity as it could have timed out
498 | pool = gcp('nocreate') ;
499 | if ~isempty(pool) && pool.NumWorkers ~= numGpus
500 | delete(pool) ;
501 | end
502 | pool = gcp('nocreate') ;
503 | if isempty(pool)
504 | parpool('local', numGpus) ;
505 | cold = true ;
506 | end
507 |
508 | end
509 | if numGpus >= 1 && cold
510 | fprintf('%s: resetting GPU\n', mfilename)
511 | clearMex() ;
512 | if numGpus == 1
513 | gpuDevice(opts.gpus)
514 | else
515 | spmd
516 | clearMex() ;
517 | gpuDevice(opts.gpus(labindex))
518 | end
519 | end
520 | end
521 |
--------------------------------------------------------------------------------
/matlab/util/catstruct.m:
--------------------------------------------------------------------------------
1 | function A = catstruct(varargin)
2 | % CATSTRUCT Concatenate or merge structures with different fieldnames
3 | % X = CATSTRUCT(S1,S2,S3,...) merges the structures S1, S2, S3 ...
4 | % into one new structure X. X contains all fields present in the various
5 | % structures. An example:
6 | %
7 | % A.name = 'Me' ;
8 | % B.income = 99999 ;
9 | % X = catstruct(A,B)
10 | % % -> X.name = 'Me' ;
11 | % % X.income = 99999 ;
12 | %
13 | % If a fieldname is not unique among structures (i.e., a fieldname is
14 | % present in more than one structure), only the value from the last
15 | % structure with this field is used. In this case, the fields are
16 | %   alphabetically sorted. A warning is issued as well. An example:
17 | %
18 | % S1.name = 'Me' ;
19 | % S2.age = 20 ; S3.age = 30 ; S4.age = 40 ;
20 | % S5.honest = false ;
21 | % Y = catstruct(S1,S2,S3,S4,S5) % use value from S4
22 | %
23 | %   The inputs can be arrays of structures. All structures should have the
24 | % same size. An example:
25 | %
26 | % C(1).bb = 1 ; C(2).bb = 2 ;
27 | % D(1).aa = 3 ; D(2).aa = 4 ;
28 | % CD = catstruct(C,D) % CD is a 1x2 structure array with fields bb and aa
29 | %
30 | % The last input can be the string 'sorted'. In this case,
31 | % CATSTRUCT(S1,S2, ..., 'sorted') will sort the fieldnames alphabetically.
32 | % To sort the fieldnames of a structure A, you could use
33 | % CATSTRUCT(A,'sorted') but I recommend ORDERFIELDS for doing that.
34 | %
35 | % When there is nothing to concatenate, the result will be an empty
36 | % struct (0x0 struct array with no fields).
37 | %
38 | % NOTE: To concatenate similar arrays of structs, you can use simple
39 | % concatenation:
40 | % A = dir('*.mat') ; B = dir('*.m') ; C = [A ; B] ;
41 |
42 | % NOTE: This function relies on unique. Matlab changed the behavior of
43 | % its set functions since 2013a, so this might cause some backward
44 | %   compatibility issues when duplicated fieldnames are found.
45 | %
46 | % See also CAT, STRUCT, FIELDNAMES, STRUCT2CELL, ORDERFIELDS
47 |
48 | % version 4.1 (feb 2015), tested in R2014a
49 | % (c) Jos van der Geest
50 | % email: jos@jasen.nl
51 |
52 | % History
53 | % Created in 2005
54 | % Revisions
55 | % 2.0 (sep 2007) removed bug when dealing with fields containing cell
56 | % arrays (Thanks to Rene Willemink)
57 | % 2.1 (sep 2008) added warning and error identifiers
58 | % 2.2 (oct 2008) fixed error when dealing with empty structs (thanks to
59 | % Lars Barring)
60 | % 3.0 (mar 2013) fixed problem when the inputs were array of structures
61 | % (thanks to Tor Inge Birkenes).
62 | % Rephrased the help section as well.
63 | % 4.0 (dec 2013) fixed problem with unique due to version differences in
64 | %           ML. Unique(...,'last') is no longer the default.
65 | % (thanks to Isabel P)
66 | % 4.1 (feb 2015) fixed warning with narginchk
67 |
68 | narginchk(1,Inf) ;
69 | N = nargin ;
70 |
71 | if ~isstruct(varargin{end}),
72 | if isequal(varargin{end},'sorted'),
73 | narginchk(2,Inf) ;
74 | sorted = 1 ;
75 | N = N-1 ;
76 | else
77 | error('catstruct:InvalidArgument','Last argument should be a structure, or the string "sorted".') ;
78 | end
79 | else
80 | sorted = 0 ;
81 | end
82 |
83 | sz0 = [] ; % used to check that all inputs have the same size
84 |
85 | % used to check for a few trivial cases
86 | NonEmptyInputs = false(N,1) ;
87 | NonEmptyInputsN = 0 ;
88 |
89 | % used to collect the fieldnames and the inputs
90 | FN = cell(N,1) ;
91 | VAL = cell(N,1) ;
92 |
93 | % parse the inputs
94 | for ii=1:N,
95 | X = varargin{ii} ;
96 | if ~isstruct(X),
97 | error('catstruct:InvalidArgument',['Argument #' num2str(ii) ' is not a structure.']) ;
98 | end
99 |
100 | if ~isempty(X),
101 | % empty structs are ignored
102 | if ii > 1 && ~isempty(sz0)
103 | if ~isequal(size(X), sz0)
104 | error('catstruct:UnequalSizes','All structures should have the same size.') ;
105 | end
106 | else
107 | sz0 = size(X) ;
108 | end
109 | NonEmptyInputsN = NonEmptyInputsN + 1 ;
110 | NonEmptyInputs(ii) = true ;
111 | FN{ii} = fieldnames(X) ;
112 | VAL{ii} = struct2cell(X) ;
113 | end
114 | end
115 |
116 | if NonEmptyInputsN == 0
117 | % all structures were empty
118 | A = struct([]) ;
119 | elseif NonEmptyInputsN == 1,
120 | % there was only one non-empty structure
121 | A = varargin{NonEmptyInputs} ;
122 | if sorted,
123 | A = orderfields(A) ;
124 | end
125 | else
126 | % there is actually something to concatenate
127 | FN = cat(1,FN{:}) ;
128 | VAL = cat(1,VAL{:}) ;
129 | FN = squeeze(FN) ;
130 | VAL = squeeze(VAL) ;
131 |
132 |
133 | [UFN,ind] = unique(FN, 'last') ;
134 | % If this line errors, due to your matlab version not having UNIQUE
135 | % accept the 'last' input, use the following line instead
136 | % [UFN,ind] = unique(FN) ; % earlier ML versions, like 6.5
137 |
138 | if numel(UFN) ~= numel(FN),
139 | warning('catstruct:DuplicatesFound','Fieldnames are not unique between structures.') ;
140 | sorted = 1 ;
141 | end
142 |
143 | if sorted,
144 | VAL = VAL(ind,:) ;
145 | FN = FN(ind,:) ;
146 | end
147 |
148 | A = cell2struct(VAL, FN);
149 | A = reshape(A, sz0) ; % reshape into original format
150 | end
151 |
152 |
153 |
154 |
--------------------------------------------------------------------------------
/matlab/util/clearMex.m:
--------------------------------------------------------------------------------
1 | % -------------------------------------------------------------------------
2 | function clearMex()
3 | % -------------------------------------------------------------------------
4 | disp('Clearing mex files') ;
5 | clear mex ;
6 | clear vl_tmove vl_imreadjpeg ;
7 |
8 |
--------------------------------------------------------------------------------
/matlab/util/logInfo.m:
--------------------------------------------------------------------------------
1 | function logInfo(str, varargin)
2 | % display log msg prefixed with the caller function's name
3 | if nargin > 1
4 |     % format directly; no need to build the sprintf call via eval
5 |     str = sprintf(str, varargin{:});
6 | end
7 | st = dbstack();
8 | caller = st(2).name;
9 | fprintf('%s: %s\n', caller, str);
10 | end
14 |
--------------------------------------------------------------------------------
/matlab/util/prepareGPUs.m:
--------------------------------------------------------------------------------
1 | % -------------------------------------------------------------------------
2 | function prepareGPUs(params, cold)
3 | % -------------------------------------------------------------------------
4 | numGpus = numel(params.gpus) ;
5 | if numGpus > 1
6 | % check parallel pool integrity as it could have timed out
7 | pool = gcp('nocreate') ;
8 | if ~isempty(pool) && pool.NumWorkers ~= numGpus
9 | delete(pool) ;
10 | end
11 | pool = gcp('nocreate') ;
12 | if isempty(pool)
13 | parpool('local', numGpus) ;
14 | cold = true ;
15 | end
16 | end
17 | if numGpus >= 1 && cold
18 | fprintf('%s: resetting GPU\n', mfilename) ;
19 | %clearMex() ;
20 | if numGpus == 1
21 | disp(gpuDevice(params.gpus)) ;
22 | else
23 | spmd
24 | %clearMex() ;
25 | disp(gpuDevice(params.gpus(labindex))) ;
26 | end
27 | end
28 | end
29 |
--------------------------------------------------------------------------------
/matlab/util/record_diary.m:
--------------------------------------------------------------------------------
1 | function record_diary(opts)
2 | diary_path = @(i) sprintf('%s/diary_%03d.txt', opts.expDir, i);
3 | ind = 1;
4 | while exist(diary_path(ind), 'file')
5 | ind = ind + 1;
6 | end
7 | diary(diary_path(ind));
8 | diary('on');
9 | end
10 |
--------------------------------------------------------------------------------
/pytorch/FastAP_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 | def softBinning(D, mid, Delta):
5 | y = 1 - torch.abs(D-mid)/Delta
6 | return torch.max(torch.Tensor([0]).cuda(), y)
7 |
8 | def dSoftBinning(D, mid, Delta):
9 | side1 = (D > (mid - Delta)).type(torch.float)
10 | side2 = (D <= mid).type(torch.float)
11 | ind1 = (side1 * side2) #.type(torch.uint8)
12 |
13 | side1 = (D > mid).type(torch.float)
14 | side2 = (D <= (mid + Delta)).type(torch.float)
15 | ind2 = (side1 * side2) #.type(torch.uint8)
16 |
17 | return (ind1 - ind2)/Delta
18 |
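# Note on the binning functions above: softBinning implements the triangular
# kernel delta(d, mid) = max(0, 1 - |d - mid|/Delta), so each distance d
# contributes linearly to its two nearest bin centers. For example, with
# num_bins = 4 the centers are Z = [0, 1, 2, 3, 4] and Delta = 1, so d = 1.25
# gets weight 0.75 in the bin at 1 and 0.25 in the bin at 2. dSoftBinning is
# the corresponding (sub)gradient w.r.t. d: +1/Delta on (mid-Delta, mid],
# -1/Delta on (mid, mid+Delta], and 0 elsewhere.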
19 |
20 | class FastAP(torch.autograd.Function):
21 | """
22 | FastAP - autograd function definition
23 |
24 | This class implements the FastAP loss from the following paper:
25 | "Deep Metric Learning to Rank",
26 | F. Cakir, K. He, X. Xia, B. Kulis, S. Sclaroff. CVPR 2019
27 |
28 | NOTE:
29 | Given an input batch, FastAP does not sample triplets from it, since it is
30 | not a triplet-based method; consequently, it does not take a Sampler as
31 | input. Instead, we specify how the input batch itself is selected.
32 | """
33 |
34 | @staticmethod
35 | def forward(ctx, input, target, num_bins):
36 | """
37 | Args:
38 | input: torch.Tensor(N x embed_dim), embedding matrix
39 | target: torch.Tensor(N x 1), class labels
40 | num_bins: int, number of bins in distance histogram
41 | """
42 | N = target.size()[0]
43 | assert input.size()[0] == N, "Batch size doesn't match!"
44 |
45 | # 1. get affinity matrix
46 | Y = target.unsqueeze(1)
47 | Aff = 2 * (Y == Y.t()).type(torch.float) - 1
48 | Aff.masked_fill_(torch.eye(N, N).byte(), 0) # set diagonal to 0
49 |
50 | I_pos = (Aff > 0).type(torch.float).cuda()
51 | I_neg = (Aff < 0).type(torch.float).cuda()
52 | N_pos = torch.sum(I_pos, 1)
53 |
54 | # 2. compute distances from embeddings
55 | # squared Euclidean distance with range [0,4]
56 | dist2 = 2 - 2 * torch.mm(input, input.t())
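# (the identity ||x - y||^2 = 2 - 2*x.y holds only for unit-norm rows,
#  so the embedding matrix is assumed to be L2-normalized)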
57 |
58 | # 3. estimate discrete histograms
59 | Delta = torch.tensor(4. / num_bins).cuda()
60 | Z = torch.linspace(0., 4., steps=num_bins+1).cuda()
61 | L = Z.size()[0]
62 | h_pos = torch.zeros((N, L)).cuda()
63 | h_neg = torch.zeros((N, L)).cuda()
64 | for l in range(L):
65 | pulse = softBinning(dist2, Z[l], Delta)
66 | h_pos[:,l] = torch.sum(pulse * I_pos, 1)
67 | h_neg[:,l] = torch.sum(pulse * I_neg, 1)
68 |
69 | H_pos = torch.cumsum(h_pos, 1)
70 | h = h_pos + h_neg
71 | H = torch.cumsum(h, 1)
72 |
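# FastAP approximates each query's Average Precision from the histograms:
#   AP_i ~= (1/N_pos_i) * sum_l h_pos[i,l] * H_pos[i,l] / H[i,l]
# i.e., precision at each distance bin, weighted by the positives that bin adds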
73 | # 4. compute FastAP
74 | FastAP = h_pos * H_pos / H
75 | FastAP[torch.isnan(FastAP) | torch.isinf(FastAP)] = 0
76 | FastAP = torch.sum(FastAP,1)/N_pos
77 | FastAP = FastAP[ ~torch.isnan(FastAP) ]
78 | loss = 1 - torch.mean(FastAP)
79 | if torch.rand(1) > 0.99:  # log the loss value on ~1% of calls
80 | print("loss value (1-mean(FastAP)): ", loss.item())
81 |
82 | # 5. save for backward
83 | ctx.save_for_backward(input, target)
84 | ctx.Z = Z
85 | ctx.Delta = Delta
86 | ctx.dist2 = dist2
87 | ctx.I_pos = I_pos
88 | ctx.I_neg = I_neg
89 | ctx.h_pos = h_pos
90 | ctx.h_neg = h_neg
91 | ctx.H_pos = H_pos
92 | ctx.N_pos = N_pos
93 | ctx.h = h
94 | ctx.H = H
95 | ctx.L = torch.tensor(L)
96 |
97 | return loss
98 |
99 |
100 | @staticmethod
101 | def backward(ctx, grad_output):
102 | input, target = ctx.saved_tensors
103 |
104 | Z = Variable(ctx.Z , requires_grad = False)
105 | Delta = Variable(ctx.Delta , requires_grad = False)
106 | dist2 = Variable(ctx.dist2 , requires_grad = False)
107 | I_pos = Variable(ctx.I_pos , requires_grad = False)
108 | I_neg = Variable(ctx.I_neg , requires_grad = False)
109 | h = Variable(ctx.h , requires_grad = False)
110 | H = Variable(ctx.H , requires_grad = False)
111 | h_pos = Variable(ctx.h_pos , requires_grad = False)
112 | h_neg = Variable(ctx.h_neg , requires_grad = False)
113 | H_pos = Variable(ctx.H_pos , requires_grad = False)
114 | N_pos = Variable(ctx.N_pos , requires_grad = False)
115 |
116 | L = Z.size()[0]
117 | H2 = torch.pow(H,2)
118 | H_neg = H - H_pos
119 |
120 | # 1. d(FastAP)/d(h+)
121 | LTM1 = torch.tril(torch.ones(L,L), -1) # strictly lower triangular matrix
122 | tmp1 = h_pos * H_neg / H2
123 | tmp1[torch.isnan(tmp1)] = 0
124 |
125 | d_AP_h_pos = (H_pos * H + h_pos * H_neg) / H2
126 | d_AP_h_pos = d_AP_h_pos + torch.mm(tmp1, LTM1.cuda())
127 | d_AP_h_pos = d_AP_h_pos / N_pos.repeat(L,1).t()
128 | d_AP_h_pos[torch.isnan(d_AP_h_pos) | torch.isinf(d_AP_h_pos)] = 0
129 |
130 |
131 | # 2. d(FastAP)/d(h-)
132 | LTM0 = torch.tril(torch.ones(L,L), 0) # lower triangular matrix
133 | tmp2 = -h_pos * H_pos / H2
134 | tmp2[torch.isnan(tmp2)] = 0
135 |
136 | d_AP_h_neg = torch.mm(tmp2, LTM0.cuda())
137 | d_AP_h_neg = d_AP_h_neg / N_pos.repeat(L,1).t()
138 | d_AP_h_neg[torch.isnan(d_AP_h_neg) | torch.isinf(d_AP_h_neg)] = 0
139 |
140 |
141 | # 3. d(FastAP)/d(embedding)
142 | d_AP_x = 0
143 | for l in range(L):
144 | dpulse = dSoftBinning(dist2, Z[l], Delta)
145 | dpulse[torch.isnan(dpulse) | torch.isinf(dpulse)] = 0
146 | ddp = dpulse * I_pos
147 | ddn = dpulse * I_neg
148 |
149 | alpha_p = torch.diag(d_AP_h_pos[:,l]) # N*N
150 | alpha_n = torch.diag(d_AP_h_neg[:,l])
151 | Ap = torch.mm(ddp, alpha_p) + torch.mm(alpha_p, ddp)
152 | An = torch.mm(ddn, alpha_n) + torch.mm(alpha_n, ddn)
153 |
154 | # accumulate gradient
155 | d_AP_x = d_AP_x - torch.mm(input.t(), (Ap+An))
156 |
157 | grad_input = -d_AP_x
158 | return grad_input.t(), None, None
159 |
160 |
161 | class FastAPLoss(torch.nn.Module):
162 | """
163 | FastAP - loss layer definition
164 |
165 | This class implements the FastAP loss from the following paper:
166 | "Deep Metric Learning to Rank",
167 | F. Cakir, K. He, X. Xia, B. Kulis, S. Sclaroff. CVPR 2019
168 | """
169 | def __init__(self, num_bins=10):
170 | super(FastAPLoss, self).__init__()
171 | self.num_bins = num_bins
172 |
173 | def forward(self, batch, labels):
174 | return FastAP.apply(batch, labels, self.num_bins)
175 |
--------------------------------------------------------------------------------
/pytorch/README.md:
--------------------------------------------------------------------------------
1 | ## PyTorch implementation of "Deep Metric Learning to Rank"
2 |
3 | The PyTorch version of FastAP is implemented within the framework of
4 | [Deep-Metric-Learning-Baselines](https://github.com/kunhe/Deep-Metric-Learning-Baselines).
5 | In this repository we provide a standalone implementation of the loss layer.
6 |
7 | #### NOTE
8 | - To completely reproduce the results reported in the paper, the FastAP loss needs to be used in conjunction with the proposed minibatch sampling method, which is implemented in the linked repo (a minimal usage sketch of the standalone loss is given below).
9 | - The loss layer is currently a direct port of the Matlab implementation; making better use of PyTorch's automatic differentiation, rather than the hand-written backward pass, remains to be investigated.
10 |
11 | **TODO**
12 | - [x] implement FastAP's minibatch sampling method
13 | - [ ] reproduce results in the paper
14 |
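#### Example
A minimal usage sketch of the standalone loss layer (an illustration with hypothetical tensor shapes, not code from the paper). `FastAP_loss.py` computes squared distances as `2 - 2 * X @ X.T`, which assumes L2-normalized embedding rows, and it allocates its histograms on the GPU, so the embeddings must be on the GPU as well (labels can stay on the CPU).

```python
import torch
import torch.nn.functional as F
from FastAP_loss import FastAPLoss

# toy batch: 32 embeddings of dimension 128, labels from 8 classes
embeddings = torch.randn(32, 128, device="cuda", requires_grad=True)
labels = torch.randint(0, 8, (32,))  # on CPU; the loss moves its masks to GPU

criterion = FastAPLoss(num_bins=10)
loss = criterion(F.normalize(embeddings, dim=1), labels)  # unit-norm rows
loss.backward()
```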
--------------------------------------------------------------------------------