├── BatchProviderBase.lua ├── BatchProviderROI.lua ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DataSetJSON.lua ├── ImageDetect.lua ├── LICENSE ├── Makefile ├── PATENTS ├── README.md ├── Tester_FRCNN.lua ├── config.lua ├── data.lua ├── demo.lua ├── donkey.lua ├── engines ├── Optim.lua └── fboptimengine.lua ├── fbcoco.lua ├── loaders ├── concatloader.lua ├── dataloader.lua ├── loader.lua └── narrowloader.lua ├── models ├── alexnet.lua ├── inceptionv3.lua ├── model_utils.lua ├── multipathnet.lua ├── nin.lua ├── resnet.lua └── vgg.lua ├── modules ├── BBoxNorm.lua ├── BBoxRegressionCriterion.lua ├── ContextRegion.lua ├── ConvertFrom.lua ├── Foveal.lua ├── ImageTransformer.lua ├── ModeSwitch.lua ├── ModelParallelTable.lua ├── NoBackprop.lua ├── SelectBoxes.lua ├── SequentialSplitBatch.lua └── test.lua ├── nms.c ├── run_test.lua ├── scripts ├── ec2-install.sh ├── eval_coco.sh ├── eval_fastrcnn_voc2007.sh ├── train_coco.sh ├── train_fastrcnn_voc2007.sh └── train_multipathnet_coco.sh ├── test.lua ├── testCoco ├── coco.lua └── init.lua ├── test_runner.lua ├── train.lua └── utils.lua /BatchProviderBase.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local BatchProviderBase = torch.class('fbcoco.BatchProviderBase') 10 | 11 | function BatchProviderBase:getImages(img_ids, do_flip) 12 | local num_images = img_ids:size(1) 13 | 14 | local imgs = {} 15 | local im_sizes = {} 16 | local im_scales = {} 17 | 18 | for i=1,num_images do 19 | local im = self.dataset:getImage(img_ids[i]) 20 | im = self.image_transformer(im) 21 | local flip = do_flip[i] == 1 22 | if flip then im = image.hflip(im) end 23 | local im_size = im[1]:size() 24 | local im_size_min = math.min(im_size[1],im_size[2]) 25 | local im_size_max = math.max(im_size[1],im_size[2]) 26 | local im_scale = self.scale/im_size_min 27 | local aspect_jitter = 1 + (torch.uniform(-1,1)-0.5)*self.aspect_jitter 28 | local scale_jitter = 1 + (torch.uniform(-1,1)-0.5)*self.scale_jitter 29 | local im_scale = im_scale * scale_jitter 30 | im_scale = {im_scale * math.sqrt(aspect_jitter), im_scale / math.sqrt(aspect_jitter)} 31 | local im_s = {im_size[1]*im_scale[1],im_size[2]*im_scale[1]} 32 | for dim = 1,2 do 33 | if im_s[dim] > self.max_size then 34 | local rat = im_s[dim] / self.max_size 35 | im_s = {im_s[1] / rat, im_s[2] / rat} 36 | im_scale = {im_scale[1] / rat, im_scale[2] / rat} 37 | end 38 | end 39 | table.insert(imgs,image.scale(im,im_s[2],im_s[1])) 40 | table.insert(im_sizes,im_s) 41 | table.insert(im_scales,im_scale) 42 | end 43 | -- create single tensor with all images, padding with zero for different sizes 44 | im_sizes = torch.IntTensor(im_sizes) 45 | local max_shape = im_sizes:max(1)[1] 46 | local images = torch.FloatTensor(num_images,3,max_shape[1],max_shape[2]):zero() 47 | for i,v in ipairs(imgs) do 48 | images[{i, {}, {1,v:size(2)}, {1,v:size(3)}}]:copy(v) 49 | end 50 | return images, im_scales, im_sizes 51 | end 52 | 53 | 54 | function BatchProviderBase.takeSubset(rec, t, i, is_bg) 55 | local idx = torch.type(t) == 'table' and torch.LongTensor(t) or t:long() 56 | local n = idx:numel() 57 | if n == 0 then return end 58 | if idx:dim() == 2 then idx = idx:select(2,1) end 59 | local window = { 60 | indexes = torch.IntTensor(n), 61 | rois = torch.FloatTensor(n,4), 62 | labels = torch.IntTensor(n):fill(1), 63 | gtboxes = torch.FloatTensor(n,4):zero(), 64 | size = function() return n end, 65 | } 66 | window.indexes:fill(i) 67 | window.rois:copy(rec.boxes:index(1,idx)) 68 | if not is_bg then 69 | window.labels:add(rec.label:index(1,idx)) 70 | local corresp = rec.correspondance:index(1,idx) 71 | window.gtboxes:copy(rec.boxes:index(1, corresp)) 72 | end 73 | return window 74 | end 75 | 76 | 77 | function BatchProviderBase.selectBBoxesOne(bboxes, num_max, im_scale, im_size, flip) 78 | local rois = {} 79 | local labels = {} 80 | local gtboxes = {} 81 | 82 | local n = bboxes:size() 83 | local im_scale = torch.FloatTensor(im_scale):repeatTensor(2) 84 | 85 | local function preprocess_bbox(dd, flip) 86 | dd = dd:clone():add(-1):cmul(im_scale):add(1) 87 | if flip then 88 | local tt = dd[1] 89 | dd[1] = im_size[2]-dd[3] +1 90 | dd[3] = im_size[2]-tt +1 91 | end 92 | return dd:view(1,4) 93 | end 94 | 95 | for i=1,math.min(num_max, n) do 96 | local position = torch.random(n) 97 | table.insert(rois, preprocess_bbox(bboxes.rois[position],flip)) 98 | table.insert(gtboxes, preprocess_bbox(bboxes.gtboxes[position], flip)) 99 | table.insert(labels, bboxes.labels[position]) 100 | end 101 | 102 | return { 103 | gtboxes = torch.FloatTensor():cat(gtboxes,1), 104 | rois = torch.FloatTensor():cat(rois,1), 105 | labels = torch.IntTensor(labels), 106 | } 107 | end 108 | 109 | -------------------------------------------------------------------------------- /BatchProviderROI.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local BatchProviderROI, parent = torch.class('fbcoco.BatchProviderROI', 'fbcoco.BatchProviderBase') 10 | local utils = paths.dofile'utils.lua' 11 | local tablex = require'pl.tablex' 12 | 13 | function BatchProviderROI:__init(dataset, imgs_per_batch, scale, max_size, transformer, fg_threshold, bg_threshold) 14 | assert(transformer,'must provide transformer!') 15 | 16 | self.dataset = dataset 17 | 18 | self.batch_size = 128 19 | self.fg_fraction = 0.25 20 | 21 | self.fg_threshold = fg_threshold 22 | self.bg_threshold = bg_threshold 23 | 24 | self.imgs_per_batch = imgs_per_batch or 2 25 | self.scale = scale or 600 26 | self.max_size = max_size or 1000 27 | self.image_transformer = transformer 28 | 29 | self.scale_jitter = scale_jitter or 0 -- uniformly jitter the scale by this frac 30 | self.aspect_jitter = aspect_jitter or 0 -- uniformly jitter the scale by this frac 31 | self.crop_likelihood = crop_likelihood or 0 -- likelihood of doing a random crop (in each dimension, independently) 32 | self.crop_attempts = 10 -- number of attempts to try to find a valid crop 33 | self.crop_min_frac = 0.7 -- a crop must preserve at least this fraction of the iamge 34 | end 35 | 36 | -- Prepare foreground / background rois for one image 37 | -- there is a check if self.bboxes has a table prepared for this image already 38 | -- because we prepare the rois during training to save time on loading 39 | function BatchProviderROI:setupOne(i) 40 | local rec = self.dataset:attachProposals(i) 41 | 42 | local bf = rec.overlap:ge(self.fg_threshold):nonzero() 43 | local bg = rec.overlap:ge(self.bg_threshold[1]):cmul( 44 | rec.overlap:lt(self.bg_threshold[2])):nonzero() 45 | return { 46 | [0] = self.takeSubset(rec, bg, i, true), 47 | [1] = self.takeSubset(rec, bf, i, false) 48 | } 49 | end 50 | 51 | -- Calculate rois and supporting data for the first 1000 images 52 | -- to compute mean/var for bbox regresion 53 | function BatchProviderROI:setupData() 54 | local regression_values = {} 55 | local subset_size = 1000 56 | for i = 1,1000 do 57 | local v = self:setupOne(i)[1] 58 | if v then 59 | table.insert(regression_values, utils.convertTo(v.rois, v.gtboxes)) 60 | end 61 | end 62 | regression_values = torch.FloatTensor():cat(regression_values,1) 63 | 64 | self.bbox_regr = { 65 | mean = regression_values:mean(1), 66 | std = regression_values:std(1) 67 | } 68 | return self.bbox_regr 69 | end 70 | 71 | -- sample until find a valid combination of bg/fg boxes 72 | function BatchProviderROI:permuteIdx() 73 | local boxes, img_idx = {}, {} 74 | for i=1,self.imgs_per_batch do 75 | local curr_idx 76 | local bboxes = {} 77 | while not bboxes[0] or not bboxes[1] do 78 | curr_idx = torch.random(self.dataset:size()) 79 | tablex.update(bboxes, self:setupOne(curr_idx)) 80 | end 81 | table.insert(boxes, bboxes) 82 | table.insert(img_idx, curr_idx) 83 | end 84 | local do_flip = torch.FloatTensor(self.imgs_per_batch):random(0,1) 85 | return torch.IntTensor(img_idx), boxes, do_flip 86 | end 87 | 88 | function BatchProviderROI:selectBBoxes(boxes, im_scales, im_sizes, do_flip) 89 | local rois = {} 90 | local labels = {} 91 | local gtboxes = {} 92 | for im,v in ipairs(boxes) do 93 | local flip = do_flip[im] == 1 94 | 95 | local bg = self.selectBBoxesOne(v[0],self.bg_num_each,im_scales[im],im_sizes[im],flip) 96 | local fg = self.selectBBoxesOne(v[1],self.fg_num_each,im_scales[im],im_sizes[im],flip) 97 | 98 | local imrois = torch.FloatTensor():cat(bg.rois, fg.rois, 1) 99 | imrois = torch.FloatTensor(imrois:size(1),1):fill(im):cat(imrois, 2) 100 | local imgtboxes = torch.FloatTensor():cat(bg.gtboxes, fg.gtboxes, 1) 101 | local imlabels = torch.IntTensor():cat(bg.labels, fg.labels, 1) 102 | 103 | table.insert(rois, imrois) 104 | table.insert(gtboxes, imgtboxes) 105 | table.insert(labels, imlabels) 106 | end 107 | gtboxes = torch.FloatTensor():cat(gtboxes,1) 108 | rois = torch.FloatTensor():cat(rois,1) 109 | labels = torch.IntTensor():cat(labels,1) 110 | return rois, labels, gtboxes 111 | end 112 | 113 | 114 | function BatchProviderROI:sample() 115 | collectgarbage() 116 | self.fg_num_each = self.fg_fraction * self.batch_size 117 | self.bg_num_each = self.batch_size - self.fg_num_each 118 | 119 | local img_idx, boxes, do_flip = self:permuteIdx() 120 | local images, im_scales, im_sizes = self:getImages(img_idx, do_flip) 121 | local rois, labels, gtboxes = self:selectBBoxes(boxes, im_scales, im_sizes, do_flip) 122 | 123 | local bboxregr_vals = torch.FloatTensor(rois:size(1), 4*(self.dataset:getNumClasses() + 1)):zero() 124 | 125 | for i,label in ipairs(labels:totable()) do 126 | if label > 1 then 127 | local out = bboxregr_vals[i]:narrow(1,(label-1)*4 + 1,4) 128 | utils.convertTo(out, rois[i]:narrow(1,2,4), gtboxes[i]) 129 | out:add(-1,self.bbox_regr.mean):cdiv(self.bbox_regr.std) 130 | end 131 | end 132 | 133 | local batches = {images, rois} 134 | local targets = {labels, {labels, bboxregr_vals}, g_donkey_idx} 135 | 136 | return batches, targets 137 | end 138 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.fb.com/codeofconduct) so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to multipathnet 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `master`. 10 | 2. If you haven't already, complete the Contributor License Agreement ("CLA"). 11 | 12 | ## Contributor License Agreement ("CLA") 13 | In order to accept your pull request, we need you to submit a CLA. You only need 14 | to do this once to work on any of Facebook's open source projects. 15 | 16 | Complete your CLA here: 17 | 18 | ## Issues 19 | We use GitHub issues to track public bugs. Please ensure your description is 20 | clear and has sufficient instructions to be able to reproduce the issue. 21 | 22 | ## Coding Style 23 | * 3 spaces for indentation rather than tabs 24 | * 80 character line length 25 | 26 | ## License 27 | By contributing to multipathnet, you agree that your contributions will be licensed 28 | under its BSD license. 29 | -------------------------------------------------------------------------------- /DataSetJSON.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local DataLoader = require 'loaders.dataloader' 10 | local ConcatLoader = require 'loaders.concatloader' 11 | local NarrowLoader = require 'loaders.narrowloader' 12 | 13 | local utils = paths.dofile'utils.lua' 14 | local stringx = require('pl.stringx') 15 | 16 | local DataSetCOCO = {} 17 | 18 | function DataSetCOCO:create(name, roidbfile, nsamples, offset) 19 | local dataset 20 | if name == 'coco_trainval2014' then 21 | local train = DataLoader('coco_train2014') 22 | local val = DataLoader('coco_val2014') 23 | 24 | dataset = ConcatLoader{train, NarrowLoader(val, 5001, val:nImages() - 5000)} 25 | elseif name == 'coco_val5k2014' then 26 | local val = DataLoader('coco_val2014') 27 | dataset = NarrowLoader(val, 1, 5000) 28 | elseif name == 'coco_val35k2014' then 29 | local val = DataLoader('coco_val2014') 30 | dataset = NarrowLoader(val, 5001, val:nImages() - 5000) 31 | elseif name == 'pascal_trainval2007,2012' then 32 | dataset = ConcatLoader{ 33 | DataLoader('pascal_train2007'), 34 | DataLoader('pascal_val2007'), 35 | DataLoader('pascal_train2012'), 36 | DataLoader('pascal_val2012'), 37 | } 38 | elseif name == 'pascal_trainval2007' then 39 | dataset = ConcatLoader{ 40 | DataLoader('pascal_train2007'), 41 | DataLoader('pascal_val2007'), 42 | } 43 | else 44 | dataset = DataLoader(name) 45 | end 46 | 47 | if offset and offset ~= -1 then 48 | local size = dataset:nImages() 49 | nsamples = math.min(size - offset + 1, nsamples) 50 | dataset = NarrowLoader(dataset, offset, nsamples) 51 | end 52 | 53 | self.dataset_name = name 54 | dataset.do_normalize = false 55 | self.dataset = dataset 56 | self.classes = {} 57 | if dataset.categories then -- coco_test2014 does not have categories 58 | for i,v in ipairs(dataset.categories) do self.classes[i] = v.name end 59 | end 60 | self.roidbfile = roidbfile 61 | self.min_area = 0 62 | self.min_proposal_area = 0 63 | self.nsamples = nsamples 64 | 65 | self.sample_n_per_box = 0 66 | self.sample_sigma = 1 67 | self.allow_missing_proposals=true 68 | return self 69 | end 70 | 71 | function DataSetCOCO:allowMissingProposals(allow_missing_proposals) 72 | self.allow_missing_proposals = allow_missing_proposals 73 | return self 74 | end 75 | 76 | function DataSetCOCO:size() 77 | if self.nsamples and self.nsamples >=0 then 78 | return self.nsamples 79 | end 80 | return self.dataset:nImages() 81 | end 82 | 83 | function DataSetCOCO:getImage(i) 84 | return self.dataset:loadImage(i) 85 | end 86 | 87 | function DataSetCOCO:getNumClasses() 88 | return self.dataset:nCategories() 89 | end 90 | 91 | function DataSetCOCO:setMinArea(area) 92 | assert(torch.type(area) == 'number') 93 | self.min_area = area 94 | end 95 | 96 | function DataSetCOCO:setMinProposalArea(area) 97 | assert(torch.type(area) == 'number') 98 | self.min_proposal_area = area 99 | end 100 | 101 | function DataSetCOCO:getAnnotation(i) 102 | local object = {} 103 | for j,a in ipairs(self.dataset:getAnnotationsForImage(i)) do 104 | if a.area > self.min_area then 105 | assert(a.difficult) 106 | local bbox = a.bbox:clone():float() 107 | bbox:narrow(1,3,2):add(bbox:narrow(1,1,2)):add(1) 108 | table.insert(object, {bbox = bbox, class_id = a.category, difficult = a.difficult, iscrowd = a.iscrowd}) 109 | end 110 | end 111 | return object 112 | end 113 | 114 | local function TableConcat(t1,t2) 115 | if not t1 or t1:nElement() == 0 then 116 | return t2:float() 117 | end 118 | if not t2 or t2:nElement() == 0 then 119 | return t1:float() 120 | end 121 | return torch.cat(t1:float(), t2:float(), 1) 122 | end 123 | 124 | function DataSetCOCO:loadAndMergeProposals(roidbfile) 125 | local dt 126 | if type(roidbfile) == 'table' then 127 | dt = {boxes={}, scores={}, images={}} 128 | local img2idx = {} 129 | for i = 1, #roidbfile do 130 | assert(roidbfile[i] and paths.filep(roidbfile[i]),'proposals file ('..roidbfile[i]..') not found') 131 | local dt2 = torch.load(roidbfile[i]) 132 | for k,v in pairs(dt2.images) do 133 | if not img2idx[v] then 134 | table.insert(dt.images, v) 135 | img2idx[v] = #dt.images 136 | end 137 | local idx = img2idx[v] 138 | dt.boxes[idx] = TableConcat(dt.boxes[idx], dt2.boxes[k]) 139 | if dt2.scores then 140 | dt.scores[idx] = TableConcat(dt.scores[idx], dt2.scores[k]) 141 | else 142 | -- lets just score unscored proposals as 0 143 | dt.scores[idx] = TableConcat(dt.scores[idx], 144 | torch.FloatTensor(dt2.boxes[k]:size(1)):zero()) 145 | end 146 | end 147 | end 148 | elseif type(roidbfile) == 'string' then 149 | assert(roidbfile and paths.filep(roidbfile),'proposals file ('..roidbfile..') not found') 150 | dt = torch.load(roidbfile) 151 | else 152 | error("???") 153 | end 154 | return dt 155 | end 156 | 157 | local permute_tensor = torch.LongTensor{2,1,4,3} 158 | 159 | local function filterScore(boxes, scores, best_number) 160 | if not scores then 161 | return boxes 162 | end 163 | if boxes:size(1) > best_number then -- select boxes with best scores 164 | local _,idx = scores:sort(true) 165 | idx = idx:narrow(1,1,best_number) 166 | -- print('scores', scores:size()) 167 | -- print('idx', idx:size()) 168 | boxes = boxes:index(1,idx) 169 | scores = scores:index(1,idx) 170 | end 171 | return boxes, scores 172 | end 173 | 174 | local function filterArea(boxes, scores, area) 175 | if area == 0 then 176 | return boxes, scores 177 | else 178 | assert(boxes:nDimension() == 2) 179 | local wh = boxes:narrow(2,3,2):clone():add(-1, boxes:narrow(2,1,2)) 180 | local s = wh:select(2,1):cmul(wh:select(2,2)) 181 | local idx = s:gt(area):nonzero() 182 | idx = idx:view(idx:nElement()) 183 | local new_boxes = boxes:index(1, idx) 184 | local new_scores = scores and scores:index(1, idx) 185 | -- print("filterArea: reduced proposals from " .. boxes:size(1) .. " to " .. new_boxes:size(1)) 186 | return new_boxes, new_scores 187 | end 188 | end 189 | 190 | function DataSetCOCO:loadROIDB(best_number) 191 | if self.roidb then 192 | return 193 | end 194 | local roidbfile = self.roidbfile 195 | 196 | print("Loading proposals at ", roidbfile) 197 | local dt = self:loadAndMergeProposals(roidbfile) 198 | print("Done loading proposals") 199 | 200 | assert(#dt.boxes == #dt.images) 201 | print('# proposal images', #dt.boxes) 202 | print('# dataset images', self.dataset:nImages()) 203 | -- assert(#dt.boxes >= self.dataset:nImages(), 'proposals have fewer boxes than dataset ' .. #dt.boxes .. ' ' .. self.dataset:nImages()) 204 | if dt.scores then 205 | assert(#dt.boxes == #dt.scores) 206 | assert(best_number and torch.type(best_number) == 'number','best_number has to be a valid number, e.g. 500 or 5000') 207 | end 208 | 209 | self.roidb = {} 210 | self.scoredb = {} 211 | 212 | print('# images', #dt.images) 213 | print('nImages', self.dataset:nImages()) 214 | local im2box = {} 215 | for i = 1,#dt.images do 216 | im2box[dt.images[i] ] = i 217 | end 218 | 219 | for i=1,self.dataset:nImages() do 220 | local file_name = self.dataset:getImage(i).file_name 221 | if not self.allow_missing_proposals then 222 | assert(im2box[file_name], file_name .. " is not in proposals") 223 | elseif not im2box[file_name] then 224 | print("WARNING: " .. i .. " " .. file_name .. " is not in proposals") 225 | end 226 | if im2box[file_name] then --assert(im2box[file_name], file_name .. " is not in proposals") 227 | local boxes = dt.boxes[im2box[file_name] ]:float() 228 | 229 | local scores = dt.scores and dt.scores[im2box[file_name] ]:float() 230 | scores = scores and scores:reshape(scores:nElement()) 231 | boxes, scores = filterArea(boxes, scores, self.min_proposal_area) 232 | boxes, scores = filterScore(boxes, scores, best_number) 233 | 234 | boxes = boxes:size(2) ~= 4 and torch.FloatTensor(0,4) or boxes:index(2,permute_tensor) 235 | self.roidb[i] = boxes 236 | self.scoredb[i] = scores 237 | end 238 | end 239 | end 240 | 241 | function DataSetCOCO:getROIBoxes(i) 242 | if not self.roidb then self:loadROIDB() end 243 | assert(self.roidb[i], "No proposals for image " .. self.dataset:getImage(i).file_name) 244 | return self.roidb[i] 245 | end 246 | 247 | function DataSetCOCO:getROIScores(i) 248 | if not self.roidb then self:loadROIDB() end 249 | return self.scoredb[i] 250 | end 251 | 252 | 253 | function DataSetCOCO:getGTBoxes(i) 254 | local anno = self:getAnnotation(i) 255 | local valid_objects = {} 256 | local gt_boxes = torch.FloatTensor() 257 | local gt_classes = {} 258 | 259 | for idx,obj in ipairs(anno) do 260 | if not obj.difficult or obj.difficult == 0 and not obj.iscrowd then 261 | table.insert(valid_objects,idx) 262 | end 263 | end 264 | 265 | gt_boxes:resize(#valid_objects,4) 266 | for idx0,idx in ipairs(valid_objects) do 267 | gt_boxes[idx0]:copy(anno[idx].bbox) 268 | table.insert(gt_classes, anno[idx].class_id) 269 | end 270 | return gt_boxes,gt_classes,valid_objects,anno 271 | end 272 | 273 | 274 | local function sampleAroundGTBoxes(boxes, n_per_box, sigma) 275 | local samples = torch.repeatTensor(boxes, n_per_box, 1) 276 | return samples:add(torch.FloatTensor(#samples):normal(0,sigma)) 277 | end 278 | 279 | 280 | function DataSetCOCO:attachProposals(i) 281 | if not self.roidb then self:loadROIDB() end 282 | 283 | local boxes = self:getROIBoxes(i) 284 | -- handle 285 | local gt_boxes,gt_classes,valid_objects,anno = self:getGTBoxes(i) 286 | if self.sample_n_per_box > 0 and gt_boxes:numel() > 0 then 287 | local sampled = sampleAroundGTBoxes(gt_boxes, self.sample_n_per_box, self.sample_sigma) 288 | boxes = boxes:cat(sampled, 1) 289 | end 290 | 291 | local all_boxes 292 | if anno then 293 | if #valid_objects > 0 and boxes:dim() > 0 then 294 | all_boxes = torch.cat(gt_boxes,boxes,1) 295 | elseif boxes:dim() == 0 then 296 | all_boxes = gt_boxes 297 | else 298 | all_boxes = boxes 299 | end 300 | else 301 | gt_boxes = torch.FloatTensor(0,4) 302 | all_boxes = boxes 303 | end 304 | 305 | local num_boxes = boxes:dim() > 0 and boxes:size(1) or 0 306 | local num_gt_boxes = #gt_classes 307 | 308 | local rec = {} 309 | if num_gt_boxes > 0 and num_boxes > 0 then 310 | rec.gt = torch.cat(torch.ByteTensor(num_gt_boxes):fill(1), 311 | torch.ByteTensor(num_boxes):fill(0) ) 312 | elseif num_boxes > 0 then 313 | rec.gt = torch.ByteTensor(num_boxes):fill(0) 314 | elseif num_gt_boxes > 0 then 315 | rec.gt = torch.ByteTensor(num_gt_boxes):fill(1) 316 | else 317 | rec.gt = torch.ByteTensor(0) 318 | end 319 | 320 | rec.overlap_class = torch.FloatTensor(num_boxes+num_gt_boxes,self:getNumClasses()):fill(0) 321 | rec.overlap = torch.FloatTensor(num_boxes+num_gt_boxes,num_gt_boxes):fill(0) 322 | for idx=1,num_gt_boxes do 323 | local o = utils.boxoverlap(all_boxes,gt_boxes[idx]) 324 | local tmp = rec.overlap_class[{{},gt_classes[idx]}] -- pointer copy 325 | tmp[tmp:lt(o)] = o[tmp:lt(o)] 326 | rec.overlap[{{},idx}] = utils.boxoverlap(all_boxes,gt_boxes[idx]) 327 | end 328 | -- get max class overlap 329 | --rec.overlap,rec.label = rec.overlap:max(2) 330 | --rec.overlap = torch.squeeze(rec.overlap,2) 331 | --rec.label = torch.squeeze(rec.label,2) 332 | --rec.label[rec.overlap:eq(0)] = 0 333 | 334 | if num_gt_boxes > 0 then 335 | rec.overlap,rec.correspondance = rec.overlap:max(2) 336 | rec.overlap = torch.squeeze(rec.overlap,2) 337 | rec.correspondance = torch.squeeze(rec.correspondance,2) 338 | rec.correspondance[rec.overlap:eq(0)] = 0 339 | else 340 | rec.overlap = torch.FloatTensor(num_boxes+num_gt_boxes):fill(0) 341 | rec.correspondance = torch.LongTensor(num_boxes+num_gt_boxes):fill(0) 342 | end 343 | rec.label = torch.IntTensor(num_boxes+num_gt_boxes):fill(0) 344 | 345 | do -- handle crowds 346 | -- find crowd boxes 347 | local crowds = {} 348 | for i,v in ipairs(anno) do 349 | if v.iscrowd then table.insert(crowds, v.bbox)end 350 | end 351 | if #crowds > 0 then 352 | -- compute intersections of all objects with each crowd 353 | local inters = torch.FloatTensor(#crowds, all_boxes:size(1)) 354 | for i,v in ipairs(crowds) do 355 | inters[i] = utils.intersection(all_boxes, v) 356 | end 357 | local maxinters = inters:max(1) 358 | local mask = maxinters:gt(0.7):select(1,1) 359 | -- don't want to exclude ground truth boxes 360 | mask:narrow(1,1,num_gt_boxes):fill(0) 361 | rec.overlap:maskedFill(mask, -1) 362 | end 363 | end 364 | 365 | for idx=1,(num_boxes+num_gt_boxes) do 366 | local corr = rec.correspondance[idx] 367 | if corr > 0 then 368 | local obj = anno[valid_objects[corr] ] 369 | rec.label[idx] = obj.class_id 370 | end 371 | end 372 | 373 | rec.boxes = all_boxes 374 | if num_gt_boxes > 0 and num_boxes > 0 then 375 | rec.class = torch.cat(torch.CharTensor(gt_classes), 376 | torch.CharTensor(num_boxes):fill(0)) 377 | elseif num_boxes > 0 then 378 | rec.class = torch.CharTensor(num_boxes):fill(0) 379 | elseif num_gt_boxes > 0 then 380 | rec.class = torch.CharTensor(gt_classes) 381 | else 382 | rec.class = torch.CharTensor(0) 383 | end 384 | 385 | function rec:size() 386 | return (num_boxes+num_gt_boxes) 387 | end 388 | 389 | return rec 390 | end 391 | 392 | return DataSetCOCO 393 | -------------------------------------------------------------------------------- /ImageDetect.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local utils = paths.dofile'utils.lua' 10 | local ImageDetect = torch.class('fbcoco.ImageDetect') 11 | 12 | function ImageDetect:__init(model, transformer, scale, max_size) 13 | assert(model, 'must provide model!') 14 | assert(transformer, 'must provide transformer!') 15 | self.model = model 16 | self.image_transformer = transformer 17 | self.scale = scale or {600} 18 | self.max_size = max_size or 1000 19 | self.sm = nn.SoftMax():cuda() 20 | end 21 | 22 | local function getImages(self,images,im) 23 | local num_scales = #self.scale 24 | 25 | local imgs = {} 26 | local im_sizes = {} 27 | local im_scales = {} 28 | 29 | im = self.image_transformer:forward(im) 30 | 31 | local im_size = im[1]:size() 32 | local im_size_min = math.min(im_size[1],im_size[2]) 33 | local im_size_max = math.max(im_size[1],im_size[2]) 34 | for i=1,num_scales do 35 | local im_scale = self.scale[i]/im_size_min 36 | if torch.round(im_scale*im_size_max) > self.max_size then 37 | im_scale = self.max_size/im_size_max 38 | end 39 | local im_s = {im_size[1]*im_scale,im_size[2]*im_scale} 40 | table.insert(imgs,image.scale(im,im_s[2],im_s[1])) 41 | table.insert(im_sizes,im_s) 42 | table.insert(im_scales,im_scale) 43 | end 44 | -- create single tensor with all images, padding with zero for different sizes 45 | im_sizes = torch.IntTensor(im_sizes) 46 | local max_shape = im_sizes:max(1)[1] 47 | images:resize(num_scales,3,max_shape[1],max_shape[2]):zero() 48 | for i=1,num_scales do 49 | images[i][{{},{1,imgs[i]:size(2)},{1,imgs[i]:size(3)}}]:copy(imgs[i]) 50 | end 51 | return im_scales 52 | end 53 | 54 | local function project_im_rois(im_rois,scales) 55 | local levels 56 | local rois = torch.FloatTensor() 57 | if #scales > 1 then 58 | local scales = torch.FloatTensor(scales) 59 | local widths = im_rois[{{},3}] - im_rois[{{},1}] + 1 60 | local heights = im_rois[{{},4}] - im_rois[{{}, 2}] + 1 61 | 62 | local areas = widths * heights 63 | local scaled_areas = areas:view(-1,1) * torch.pow(scales:view(1,-1),2) 64 | local diff_areas = torch.abs(scaled_areas - 224 * 224) 65 | levels = select(2, diff_areas:min(2)) 66 | else 67 | levels = torch.FloatTensor() 68 | rois:resize(im_rois:size(1),5) 69 | rois[{{},1}]:fill(1) 70 | rois[{{},{2,5}}]:copy(im_rois):add(-1):mul(scales[1]):add(1) 71 | end 72 | return rois 73 | end 74 | 75 | local function recursiveSplit(x, bs, dim) 76 | if type(x) == 'table' then 77 | local res = {} 78 | for k,v in pairs(x) do 79 | local tmp = v:split(bs,dim) 80 | for i=1,#tmp do 81 | if not res[i] then res[i] = {} end 82 | res[i][k] = tmp[i] 83 | end 84 | end 85 | return res 86 | else 87 | return x:split(bs, dim) 88 | end 89 | end 90 | 91 | function ImageDetect:memoryEfficientForward(model, input, bs, recompute_features) 92 | local images = input[1] 93 | local rois = input[2] 94 | local recompute_features = recompute_features == nil and true or recompute_features 95 | assert(model.output[1]:numel() > 0) 96 | 97 | local rest = nn.Sequential() 98 | for i=2,#model.modules do rest:add(model:get(i)) end 99 | local final = model:get(#model.modules) 100 | 101 | -- assuming the net has bbox regression part 102 | self.output = self.output or {torch.CudaTensor(), torch.CudaTensor()} 103 | local num_classes = self.model.output[1]:size(2) 104 | self.output[1]:resize(rois:size(1), num_classes) 105 | self.output[2]:resize(rois:size(1), num_classes * 4) 106 | 107 | if recompute_features then 108 | model:get(1):forward{images,rois} 109 | else 110 | model:get(1).output[2] = rois 111 | end 112 | 113 | local features = model:get(1).output 114 | assert(features[2]:size(1) == rois:size(1)) 115 | 116 | local roi_split = features[2]:split(bs,1) 117 | local output1_split = self.output[1]:split(bs,1) 118 | local output2_split = self.output[2]:split(bs,1) 119 | 120 | for i,v in ipairs(roi_split) do 121 | local out = rest:forward({features[1], v}) 122 | output1_split[i]:copy(out[1]) 123 | output2_split[i]:copy(out[2]) 124 | end 125 | 126 | local function test() 127 | local output_full = model:forward({images,rois}) 128 | 129 | local output_split = self.output 130 | assert((output_full[1] - output_split[1]):abs():max() == 0) 131 | assert((output_full[2] - output_split[2]):abs():max() == 0) 132 | end 133 | --test() 134 | return self.output 135 | end 136 | 137 | function ImageDetect:computeRawOutputs(im, boxes, min_images, recompute_features) 138 | self.model:evaluate() 139 | 140 | local inputs = {torch.FloatTensor(),torch.FloatTensor()} 141 | local im_scales = getImages(self,inputs[1],im) 142 | inputs[2] = project_im_rois(boxes,im_scales) 143 | if min_images then 144 | assert(inputs[1]:size(1) == 1) 145 | inputs[1] = inputs[1]:expand(min_images, inputs[1]:size(2), inputs[1]:size(3), inputs[1]:size(4)) 146 | end 147 | 148 | self.inputs_cuda = self.inputs_cuda or {torch.CudaTensor(),torch.CudaTensor()} 149 | self.inputs_cuda[1]:resize(inputs[1]:size()):copy(inputs[1]) 150 | self.inputs_cuda[2]:resize(inputs[2]:size()):copy(inputs[2]) 151 | 152 | return self.model:forward(self.inputs_cuda) 153 | end 154 | 155 | -- supposes boxes is in [x1,y1,x2,y2] format 156 | function ImageDetect:detect(im, boxes, min_images, recompute_features) 157 | self.model:evaluate() 158 | 159 | local inputs = {torch.FloatTensor(),torch.FloatTensor()} 160 | local im_scales = getImages(self,inputs[1],im) 161 | inputs[2] = project_im_rois(boxes,im_scales) 162 | if min_images then 163 | assert(inputs[1]:size(1) == 1) 164 | inputs[1] = inputs[1]:expand(min_images, inputs[1]:size(2), inputs[1]:size(3), inputs[1]:size(4)) 165 | end 166 | 167 | self.inputs_cuda = self.inputs_cuda or {torch.CudaTensor(),torch.CudaTensor()} 168 | self.inputs_cuda[1]:resize(inputs[1]:size()):copy(inputs[1]) 169 | self.inputs_cuda[2]:resize(inputs[2]:size()):copy(inputs[2]) 170 | 171 | local output0 172 | if opt and opt.disable_memory_efficient_forward then 173 | print('memory efficient forward disabled') 174 | output0 = self.model:forward(self.inputs_cuda) 175 | else 176 | output0 = self:memoryEfficientForward(self.model, self.inputs_cuda, 500, recompute_features) 177 | end 178 | 179 | local class_values, bbox_values 180 | if torch.type(output0) == 'table' then 181 | class_values= output0[1] 182 | bbox_values = output0[2]:float() 183 | for i,v in ipairs(bbox_values:split(4,2)) do 184 | utils.convertFrom(v,boxes,v) 185 | end 186 | else 187 | class_values = output0 188 | end 189 | if not self.model.noSoftMax then 190 | class_values = self.sm:forward(class_values) 191 | end 192 | return class_values:float(), bbox_values 193 | end 194 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For multipathnet software 4 | 5 | Copyright (c) 2016, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | TORCH_INCLUDE = $(shell dirname `which th`)/../include 2 | 3 | libnms.so: 4 | gcc -I$(TORCH_INCLUDE) nms.c -fPIC -std=c99 -shared -O3 -o $@ 5 | clean: 6 | rm libnms.so 7 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the multipathnet software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | MultiPath Network training code 2 | ========== 3 | 4 | The code provides functionality to train Fast R-CNN and MultiPath Networks in [Torch-7](http://torch.ch).
5 | Corresponding paper: **A MultiPath Network for Object Detection** http://arxiv.org/abs/1604.02135 6 | 7 | ![sheep](https://cloud.githubusercontent.com/assets/4953728/17826153/442d027a-666e-11e6-9a1e-2fac95a2d3ba.jpg) 8 | 9 | If you use MultiPathNet in your research, please cite the relevant papers: 10 | 11 | ``` 12 | @INPROCEEDINGS{Zagoruyko2016Multipath, 13 | author = {S. Zagoruyko and A. Lerer and T.-Y. Lin and P. O. Pinheiro and S. Gross and S. Chintala and P. Doll{\'{a}}r}, 14 | title = {A MultiPath Network for Object Detection}, 15 | booktitle = {BMVC} 16 | year = {2016} 17 | } 18 | ``` 19 | 20 | ## Requirements 21 | 22 | * Linux 23 | * NVIDIA GPU with compute capability 3.5+ 24 | 25 | ## Installation 26 | 27 | The code depends on Torch-7, fb.python and several other easy-to-install torch packages.
28 | To install Torch, follow http://torch.ch/docs/getting-started.html
29 | Then install additional packages: 30 | 31 | ```bash 32 | luarocks install inn 33 | luarocks install torchnet 34 | luarocks install fbpython 35 | luarocks install class 36 | ``` 37 | 38 | Evaluation relies on COCO API calls via python interface, because lua interface doesn't support it. 39 | Lua API is used to load annotation files in \*json to COCO API data structures. This doesn't work for proposal 40 | files as they're too big, so we provide converted proposals for sharpmask and selective search in torch format. 41 | 42 | First, clone https://github.com/pdollar/coco: 43 | 44 | ``` 45 | git clone https://github.com/pdollar/coco 46 | ``` 47 | 48 | Then install LuaAPI: 49 | 50 | ``` 51 | cd coco 52 | luarocks make LuaAPI/rocks/coco-scm-1.rockspec 53 | ``` 54 | 55 | And PythonAPI: 56 | 57 | ``` 58 | cd coco/PythonAPI 59 | make 60 | ``` 61 | 62 | You might need to install Cython for this: 63 | 64 | ``` 65 | sudo apt-get install python-pip 66 | sudo pip install Cython 67 | ``` 68 | 69 | You will have to add the path to PythonAPI to `PYTHONPATH`. Note that this won't work with anaconda as it ships 70 | with it's own libraries which conflict with torch. 71 | 72 | ### EC2 installation script 73 | 74 | Thanks to @DeegC there is [scripts/ec2-install.sh](scripts/ec2-install.sh) script for quick EC2 setup. 75 | 76 | ## Data preparation 77 | 78 | The root folder should have a folder `data` with the following subfolders: 79 | 80 | ``` 81 | models/ 82 | annotations/ 83 | proposals/ 84 | ``` 85 | 86 | `models` folder should contain AlexNet and VGG pretrained imagenet files downloaded from [here](#training). ResNets can resident in other places specified by `resnet_path` env variable. 87 | 88 | `annotations` should contain \*json files downloaded from http://mscoco.org/external. There are \*json annotation files for 89 | PASCAL VOC, MSCOCO, ImageNet and other datasets. 90 | 91 | `proposals` should contain \*t7 files downloaded from here 92 | We provide selective search VOC 2007 and VOC 2012 proposals converted from https://github.com/rbgirshick/fast-rcnn and SharpMask proposals for COCO 2015 converted from https://github.com/facebookresearch/deepmask, which can be used to compute proposals for new images as well. 93 | 94 | Here is an example structure: 95 | 96 | ``` 97 | data 98 | |-- annotations 99 | | |-- instances_train2014.json 100 | | |-- instances_val2014.json 101 | | |-- pascal_test2007.json 102 | | |-- pascal_train2007.json 103 | | |-- pascal_train2012.json 104 | | |-- pascal_val2007.json 105 | | `-- pascal_val2012.json 106 | |-- models 107 | | |-- caffenet_fast_rcnn_iter_40000.t7 108 | | |-- imagenet_pretrained_alexnet.t7 109 | | |-- imagenet_pretrained_vgg.t7 110 | | `-- vgg16_fast_rcnn_iter_40000.t7 111 | `-- proposals 112 | |-- VOC2007 113 | | `-- selective_search 114 | | |-- test.t7 115 | | |-- train.t7 116 | | |-- trainval.t7 117 | | `-- val.t7 118 | `-- coco 119 | `-- sharpmask 120 | |-- train.t7 121 | `-- val.t7 122 | ``` 123 | 124 | Download selective_search proposals for VOC2007: 125 | 126 | ```bash 127 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/VOC2007/selective_search/train.t7 128 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/VOC2007/selective_search/val.t7 129 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/VOC2007/selective_search/trainval.t7 130 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/VOC2007/selective_search/test.t7 131 | ``` 132 | 133 | Download sharpmask proposals for COCO: 134 | 135 | ```bash 136 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/coco/sharpmask/train.t7 137 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/coco/sharpmask/val.t7 138 | ``` 139 | 140 | As for the images themselves, provide paths to VOCDevkit and COCO in [config.lua](config.lua) 141 | 142 | ## Running DeepMask with MultiPathNet on provided image 143 | 144 | We provide an example of how to extract DeepMask or SharpMask proposals from an image and run recognition MultiPathNet 145 | to classify them, then do non-maximum suppression and draw the found objects. 146 | 147 | 1. Clone DeepMask project into the root directory: 148 | 149 | ```bash 150 | git clone https://github.com/facebookresearch/deepmask 151 | ``` 152 | 153 | 2. Download DeepMask or SharpMask network: 154 | 155 | ```bash 156 | cd data/models 157 | # download SharpMask based on ResNet-50 158 | wget https://dl.fbaipublicfiles.com/deepmask/models/sharpmask/model.t7 -O sharpmask.t7 159 | ``` 160 | 161 | 3. Download recognition network: 162 | 163 | ```bash 164 | cd data/models 165 | # download ResNet-18-based model trained on COCO with integral loss 166 | wget https://dl.fbaipublicfiles.com/multipathnet/models/resnet18_integral_coco.t7 167 | ``` 168 | 169 | 4. Make sure you have COCO validation .json files in `data/annotations/instances_val2014.json` 170 | 171 | 5. Pick some image and run the script: 172 | 173 | ```bash 174 | th demo.lua -img ./deepmask/data/testImage.jpg 175 | ``` 176 | 177 | And you should see this image: 178 | 179 | ![iterm2 4jpuod lua_khbaaq](https://cloud.githubusercontent.com/assets/4953728/17951035/69d6cb2e-6a5f-11e6-83b8-767c2ae0ae64.png) 180 | 181 | See file [demo.lua](demo.lua) for details. 182 | 183 | ## Training 184 | 185 | The repository supports training Fast-RCNN and MultiPath networks with data and model multi-GPU paralellism. 186 | Supported base models are the following: 187 | 188 | * AlexNet trained in [caffe](https://github.com/bvlc/caffe) by Ross Girshick, [imagenet_pretrained_alexnet.t7](https://dl.fbaipublicfiles.com/multipathnet/models/imagenet_pretrained_alexnet.t7) 189 | * VGG trained in [caffe](https://github.com/bvlc/caffe) by Ross Girshick, [imagenet_pretrained_vgg.t7](https://dl.fbaipublicfiles.com/multipathnet/models/imagenet_pretrained_vgg.t7) 190 | * ResNets trained in torch with [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) by Sam Gross 191 | * inception-v3 trained in [tensorflow](https://github.com/tensorflow/tensorflow) by Google 192 | * Network-In-Network trained in torch with [imagenet-multiGPU.torch](https://github.com/soumith/imagenet-multiGPU.torch) by Sergey Zagoruyko 193 | 194 | ### PASCAL VOC 195 | 196 | To train Fast-RCNN on VOC2007 trainval with VGG base model and selective search proposals do: 197 | 198 | ```bash 199 | test_nsamples=1000 model=vgg ./scripts/train_fastrcnn_voc2007.sh 200 | ``` 201 | 202 | The resulting mAP is slightly (~2 mAP) higher than original Fast-RCNN number. We should mention that the code is not exactly the same 203 | as we improved ROIPooling by fixing a few bugs, see https://github.com/szagoruyko/imagine-nn/pull/17 204 | 205 | ### COCO 206 | 207 | To train MultiPathNet with VGG-16 base model on 4 GPUs run: 208 | 209 | ```bash 210 | train_nGPU=4 test_nGPU=1 ./scripts/train_multipathnet_coco.sh 211 | ``` 212 | 213 | Here is a graph visualization of the network (click to enlarge): 214 | 215 | 216 | multipathnet 217 | 218 | 219 | To train ResNet-18 on COCO do: 220 | 221 | ```bash 222 | train_nGPU=4 test_nGPU=1 model=resnet resnet_path=./data/models/resnet/resnet-18.t7 ./scripts/train_coco.sh 223 | ``` 224 | 225 | ## Evaluation 226 | 227 | ### PASCAL VOC 228 | 229 | We provide original models from Fast-RCNN paper converted to torch format here: 230 | * [caffenet_fast_rcnn_iter_40000.t7](https://dl.fbaipublicfiles.com/multipathnet/models/caffenet_fast_rcnn_iter_40000.t7) 231 | * [vgg16_fast_rcnn_iter_40000.t7](https://dl.fbaipublicfiles.com/multipathnet/models/vgg16_fast_rcnn_iter_40000.t7) 232 | 233 | To evaluate these models run: 234 | 235 | ```bash 236 | model=data/models/caffenet_fast_rcnn_iter_40000.t7 ./scripts/eval_fastrcnn_voc2007.sh 237 | model=data/models/vgg_fast_rcnn_iter_40000.t7 ./scripts/eval_fastrcnn_voc2007.sh 238 | ``` 239 | 240 | ### COCO 241 | 242 | Evaluate fast ResNet-18-based network trained with integral loss on COCO val5k split ([resnet18_integral_coco.t7](https://dl.fbaipublicfiles.com/multipathnet/models/resnet18_integral_coco.t7) 89MB): 243 | 244 | ```bash 245 | test_nGPU=4 test_nsamples=5000 ./scripts/eval_coco.sh 246 | ``` 247 | 248 | It achieves 24.4 mAP using 400 SharpMask proposals per image: 249 | 250 | ``` 251 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.244 252 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.402 253 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.268 254 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.078 255 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.266 256 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.394 257 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.249 258 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.368 259 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.377 260 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.135 261 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.444 262 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.561 263 | ``` 264 | -------------------------------------------------------------------------------- /Tester_FRCNN.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local utils = paths.dofile('utils.lua') 10 | local tds = require 'tds' 11 | local testCoco = require 'testCoco.init' 12 | require 'sys' 13 | 14 | local Tester = torch.class('fbcoco.Tester_FRCNN') 15 | 16 | local Threads = require 'threads' 17 | Threads.serialization('threads.sharedserialize') 18 | 19 | function Tester:__init(module, transformer, dataset, scale, max_size, opt) 20 | self.dataset = dataset 21 | self.module = module 22 | self.transformer = transformer 23 | if module and transformer then 24 | self.detec = fbcoco.ImageDetect(self.module, self.transformer, scale, max_size) 25 | end 26 | self.num_iter = opt.test_num_iterative_loc or 1 27 | 28 | self.nms_thresh = opt.test_nms_threshold or 0.3 29 | self.bbox_vote_thresh = opt.test_bbox_voting_nms_threshold or 0.5 30 | 31 | self.threads = Threads(10, 32 | function() 33 | require 'torch' 34 | end) 35 | 36 | if module then 37 | module:apply(function(m) 38 | if torch.type(m) == 'nn.DataParallelTable' then 39 | self.data_parallel_n = #m.gpuAssignments 40 | end 41 | end) 42 | print('data_parallel_n', self.data_parallel_n) 43 | 44 | -- to determine num of output classes 45 | local input = {torch.CudaTensor(self.data_parallel_n or 2, 3, 224, 224), 46 | torch.Tensor{1, 1, 1, 100, 100}:view(1, 5):expand(2, 5):cuda()} 47 | module:forward(input) 48 | 49 | self.num_classes = module.output[1]:size(2) - 1 50 | self.thresh = torch.ones(self.num_classes):mul(-1.5) 51 | end 52 | end 53 | 54 | function Tester:testOne(i) 55 | local dataset = self.dataset 56 | local thresh = self.thresh 57 | 58 | local img_boxes = tds.hash() 59 | local timer = torch.Timer() 60 | local timer2 = torch.Timer() 61 | local timer3 = torch.Timer() 62 | 63 | timer:reset() 64 | local boxes = dataset:getROIBoxes(i):float() 65 | -- print('#boxes', boxes:size()) 66 | local im = dataset:getImage(i) 67 | timer3:reset() 68 | 69 | local all_output = {} 70 | local all_bbox_pred = {} 71 | 72 | local output, bbox_pred = self.detec:detect(im, boxes, self.data_parallel_n, true) 73 | local num_classes = output:size(2) - 1 74 | 75 | -- clamp predictions within image 76 | local bbox_pred_tmp = bbox_pred:view(-1, 2) 77 | bbox_pred_tmp:select(2,1):clamp(1, im:size(3)) 78 | bbox_pred_tmp:select(2,2):clamp(1, im:size(2)) 79 | 80 | table.insert(all_output, output) 81 | table.insert(all_bbox_pred, bbox_pred) 82 | for i = 2, self.num_iter do 83 | -- have to copy to cuda because of torch/cutorch LongTensor differences 84 | self.boxselect = self.boxselect or nn.SelectBoxes():cuda() 85 | local new_boxes = self.boxselect:forward{output:cuda(), bbox_pred:cuda()}:float() 86 | output, bbox_pred = self.detec:detect(im, new_boxes, self.data_parallel_n, false) 87 | table.insert(all_output, output) 88 | table.insert(all_bbox_pred, bbox_pred) 89 | end 90 | 91 | if opt.test_use_rbox_scores then 92 | assert(#all_output > 1) 93 | -- we use the scores from iter n+1 for the boxes at iter n 94 | -- this means we lose one iteration worth of boxes 95 | table.remove(all_output, 1) 96 | table.remove(all_bbox_pred) 97 | end 98 | 99 | output = utils.joinTable(all_output, 1) 100 | bbox_pred = utils.joinTable(all_bbox_pred, 1) 101 | 102 | local tt2 = timer3:time().real 103 | 104 | timer2:reset() 105 | local nms_timer = torch.Timer() 106 | for j = 1, num_classes do 107 | local scores = output:select(2, j+1) 108 | local idx = torch.range(1, scores:numel()):long() 109 | local idx2 = scores:gt(thresh[j]) 110 | idx = idx[idx2] 111 | local scored_boxes = torch.FloatTensor(idx:numel(), 5) 112 | if scored_boxes:numel() > 0 then 113 | local bx = scored_boxes:narrow(2, 1, 4) 114 | bx:copy(bbox_pred:narrow(2, j*4+1, 4):index(1, idx)) 115 | scored_boxes:select(2, 5):copy(scores[idx2]) 116 | end 117 | img_boxes[j] = utils.nms(scored_boxes, self.nms_thresh) 118 | if opt.test_bbox_voting then 119 | local rescaled_scored_boxes = scored_boxes:clone() 120 | local scores = rescaled_scored_boxes:select(2,5) 121 | scores:pow(opt.test_bbox_voting_score_pow or 1) 122 | 123 | img_boxes[j] = utils.bbox_vote(img_boxes[j], rescaled_scored_boxes, self.test_bbox_voting_nms_threshold) 124 | end 125 | end 126 | self.threads:synchronize() 127 | local nms_time = nms_timer:time().real 128 | 129 | if i % 1 == 0 then 130 | print(('test: (%s) %5d/%-5d dev: %d, forward time: %.3f, ' 131 | .. 'select time: %.3fs, nms time: %.3fs, ' 132 | .. 'total time: %.3fs'):format(dataset.dataset_name, 133 | i, dataset:size(), 134 | cutorch.getDevice(), 135 | tt2, timer2:time().real, 136 | nms_time, timer:time().real)); 137 | end 138 | return img_boxes, {output, bbox_pred} 139 | end 140 | 141 | function Tester:test() 142 | self.module:evaluate() 143 | self.dataset:loadROIDB() 144 | 145 | local aboxes_t = tds.hash() 146 | 147 | local raw_output = tds.hash() 148 | local raw_bbox_pred = tds.hash() 149 | 150 | for i = 1, self.dataset:size() do 151 | local img_boxes, raw_boxes = self:testOne(i) 152 | aboxes_t[i] = img_boxes 153 | if opt.test_save_raw and opt.test_save_raw ~= '' then 154 | raw_output[i] = raw_boxes[1]:float() 155 | raw_bbox_pred[i] = raw_boxes[2]:float() 156 | end 157 | end 158 | 159 | if opt.test_save_raw and opt.test_save_raw ~= '' then 160 | torch.save(opt.test_save_raw, {raw_output, raw_bbox_pred}) 161 | end 162 | 163 | aboxes_t = self:keepTopKPerImage(aboxes_t, 100) -- coco only accepts 100/image 164 | local aboxes = self:transposeBoxes(aboxes_t) 165 | aboxes_t = nil 166 | 167 | return self:computeAP(aboxes) 168 | end 169 | 170 | function Tester:keepTopKPerImage(aboxes_t, k) 171 | for j = 1,self.dataset:size() do 172 | aboxes_t[j] = utils.keep_top_k(aboxes_t[j], k) 173 | end 174 | return aboxes_t 175 | end 176 | 177 | function Tester:transposeBoxes(aboxes_t) 178 | -- print("Running topk. max= ", self.max_per_set) 179 | local aboxes = tds.hash() 180 | for j = 1, self.num_classes do 181 | aboxes[j] = tds.hash() 182 | for i = 1, self.dataset:size() do 183 | aboxes[j][i] = aboxes_t[i][j] 184 | end 185 | end 186 | return aboxes 187 | end 188 | 189 | function Tester:computeAP(aboxes) 190 | return testCoco.evaluate(self.dataset.dataset_name, aboxes) 191 | end 192 | 193 | -------------------------------------------------------------------------------- /config.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- put your paths to VOC and COCO containing subfolders with images here 10 | local VOCdevkit = '/home/zagoruys/raid/datasets/VOCdevkit' 11 | local coco_dir = '/home/zagoruys/raid/datasets/mscoco' 12 | 13 | return { 14 | pascal_train2007 = paths.concat(VOCdevkit, 'VOC2007/JPEGImages'), 15 | pascal_val2007 = paths.concat(VOCdevkit, 'VOC2007/JPEGImages'), 16 | pascal_test2007 = paths.concat(VOCdevkit, 'VOC2007/JPEGImages'), 17 | pascal_train2012 = paths.concat(VOCdevkit, 'VOC2012/JPEGImages'), 18 | pascal_val2012 = paths.concat(VOCdevkit, 'VOC2012/JPEGImages'), 19 | pascal_test2012 = paths.concat(VOCdevkit, 'VOC2012/JPEGImages'), 20 | coco_train2014 = paths.concat(coco_dir, 'train2014'), 21 | coco_val2014 = paths.concat(coco_dir, 'val2014'), 22 | } 23 | -------------------------------------------------------------------------------- /data.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local tnt = require 'torchnet' 10 | require 'donkey' 11 | 12 | -- create an instance of DataSetJSON to make roidb and scoredb 13 | -- that are passed to threads 14 | local roidb, scoredb 15 | do 16 | local ds = loadDataSet(opt) 17 | ds:loadROIDB(opt.best_proposals_number) 18 | roidb, scoredb = ds.roidb, ds.scoredb 19 | end 20 | 21 | local loader = createTrainLoader(opt, roidb, scoredb, 1) 22 | local bbox_regr = loader:setupData() 23 | g_mean_std = bbox_regr 24 | 25 | local opt = tnt.utils.table.clone(opt) 26 | 27 | local function getIterator() 28 | return tnt.ParallelDatasetIterator{ 29 | nthread = opt.nDonkeys, 30 | init = function(idx) 31 | require 'torchnet' 32 | require 'donkey' 33 | torch.manualSeed(opt.manualSeed + idx) 34 | g_donkey_idx = idx 35 | end, 36 | closure = function() 37 | local loaders = {} 38 | for i=1,(opt.integral and opt.nDonkeys or 1) do 39 | loaders[i] = createTrainLoader(opt, roidb, scoredb, i) 40 | end 41 | 42 | for i,v in ipairs(loaders) do 43 | v.bbox_regr = bbox_regr 44 | end 45 | 46 | return tnt.ListDataset{ 47 | list = torch.range(1,opt.epochSize):long(), 48 | load = function(idx) 49 | return {loaders[torch.random(#loaders)]:sample()} 50 | end, 51 | } 52 | end, 53 | } 54 | end 55 | 56 | return getIterator 57 | -------------------------------------------------------------------------------- /demo.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- DeepMask + MultiPathNet demo 10 | 11 | require 'deepmask.SharpMask' 12 | require 'deepmask.SpatialSymmetricPadding' 13 | require 'deepmask.InferSharpMask' 14 | require 'inn' 15 | require 'fbcoco' 16 | require 'image' 17 | local model_utils = require 'models.model_utils' 18 | local utils = require 'utils' 19 | local coco = require 'coco' 20 | 21 | local cmd = torch.CmdLine() 22 | cmd:option('-np', 5,'number of proposals to save in test') 23 | cmd:option('-si', -2.5, 'initial scale') 24 | cmd:option('-sf', .5, 'final scale') 25 | cmd:option('-ss', .5, 'scale step') 26 | cmd:option('-dm', false, 'use DeepMask version of SharpMask') 27 | cmd:option('-img','./deepmask/data/testImage.jpg' ,'path/to/test/image') 28 | cmd:option('-thr', 0.5, 'multipathnet score threshold [0,1]') 29 | cmd:option('-maxsize', 600, 'resize image dimension') 30 | cmd:option('-sharpmask_path', 'data/models/sharpmask.t7', 'path to sharpmask') 31 | cmd:option('-multipath_path', 'data/models/resnet18_integral_coco.t7', 'path to multipathnet') 32 | 33 | local config = cmd:parse(arg) 34 | 35 | local sharpmask = torch.load(config.sharpmask_path).model 36 | sharpmask:inference(config.np) 37 | 38 | local multipathnet = torch.load(config.multipath_path) 39 | multipathnet:evaluate() 40 | multipathnet:cuda() 41 | model_utils.testModel(multipathnet) 42 | 43 | local detector = fbcoco.ImageDetect(multipathnet, model_utils.ImagenetTransformer()) 44 | 45 | ------------------- Run DeepMask -------------------- 46 | 47 | local meanstd = {mean = { 0.485, 0.456, 0.406 }, std = { 0.229, 0.224, 0.225 }} 48 | 49 | local scales = {} 50 | for i = config.si,config.sf,config.ss do table.insert(scales,2^i) end 51 | print(scales) 52 | 53 | local infer = Infer{ 54 | np = config.np, 55 | scales = scales, 56 | meanstd = meanstd, 57 | model = sharpmask, 58 | dm = config.dm, 59 | } 60 | 61 | local img = image.load(config.img) 62 | img = image.scale(img, config.maxsize) 63 | local h,w = img:size(2),img:size(3) 64 | 65 | infer:forward(img) 66 | 67 | local masks,_ = infer:getTopProps(.2,h,w) 68 | 69 | local Rs = coco.MaskApi.encode(masks) 70 | local bboxes = coco.MaskApi.toBbox(Rs) 71 | bboxes:narrow(2,3,2):add(bboxes:narrow(2,1,2)) -- convert from x,y,w,h to x1,y1,x2,y2 72 | 73 | ------------------- Run MultiPathNet -------------------- 74 | 75 | local detections = detector:detect(img:float(), bboxes:float()) 76 | local prob, maxes = detections:max(2) 77 | 78 | -- remove background detections 79 | local idx = maxes:squeeze():gt(1):cmul(prob:gt(config.thr)):nonzero():select(2,1) 80 | bboxes = bboxes:index(1, idx) 81 | maxes = maxes:index(1, idx) 82 | prob = prob:index(1, idx) 83 | 84 | local scored_boxes = torch.cat(bboxes:float(), prob:float(), 2) 85 | local final_idx = utils.nms_dense(scored_boxes, 0.3) 86 | 87 | ------------------- Draw detections -------------------- 88 | 89 | -- remove suppressed masks 90 | masks = masks:index(1, idx):index(1, final_idx) 91 | 92 | local dataset = paths.dofile'./DataSetJSON.lua':create'coco_val2014' 93 | 94 | local res = img:clone() 95 | coco.MaskApi.drawMasks(res, masks, 10) 96 | for i,v in ipairs(final_idx:totable()) do 97 | local class = maxes[v][1]-1 98 | local x1,y1,x2,y2 = table.unpack(bboxes[v]:totable()) 99 | y2 = math.min(y2, res:size(2)) - 10 100 | local name = dataset.dataset.categories[class] 101 | print(prob[v][1], class, name) 102 | image.drawText(res, name, x1, y2, {bg={255,255,255}, inplace=true}) 103 | end 104 | image.save(string.format('./res.jpg',config.model),res) 105 | 106 | print('| done') 107 | -------------------------------------------------------------------------------- /donkey.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local tnt = require 'torchnet' 10 | local utils = paths.dofile'utils.lua' 11 | require 'fbcoco' 12 | 13 | function loadDataSet(opt) 14 | local dataset_name = opt.dataset..'_'..opt.train_set..opt.year 15 | local folder_name = opt.dataset == 'pascal' and ('VOC'..opt.year) or 'coco' 16 | local proposals_path = utils.makeProposalPath(opt.proposal_dir, folder_name, opt.proposals, opt.train_set, opt.imagenet_classes ~= '') 17 | 18 | local ds = paths.dofile'DataSetJSON.lua':create(dataset_name, proposals_path, opt.train_nsamples) 19 | if opt.imagenet_classes ~= '' then 20 | ds:allowMissingProposals(true) -- workaround 21 | end 22 | 23 | ds.sample_n_per_box = opt.sample_n_per_box 24 | ds.sample_sigma = opt.sample_n_per_box 25 | 26 | ds:setMinProposalArea(opt.train_min_proposal_size) 27 | -- ds:loadROIDB(opt.best_proposals_number) 28 | ds:setMinArea(opt.train_min_gtroi_size) 29 | return ds 30 | end 31 | 32 | function createTrainLoader(opt, roidb, scoredb, loader_idx) 33 | local ds = loadDataSet(opt) 34 | ds.roidb, ds.scoredb = roidb, scoredb 35 | local transformer = torch.load(opt.transformer) 36 | 37 | local fg_threshold, bg_threshold 38 | if opt.integral then 39 | local threshold = opt.bg_threshold_max + (loader_idx - 1) / 20 40 | bg_threshold = {opt.bg_threshold_min, threshold} 41 | fg_threshold = threshold 42 | else 43 | bg_threshold = {opt.bg_threshold_min, opt.bg_threshold_max} 44 | fg_threshold = opt.fg_threshold 45 | end 46 | 47 | local bp = fbcoco.BatchProviderROI(ds, opt.images_per_batch, opt.scale, opt.max_size, transformer, fg_threshold, bg_threshold) 48 | 49 | bp.batch_size = opt.batchSize 50 | bp.class_specific = opt.train_class_specific 51 | 52 | return bp 53 | end 54 | 55 | 56 | -------------------------------------------------------------------------------- /engines/Optim.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2004-present Facebook. All Rights Reserved. 2 | 3 | local pl = require('pl.import_into')() 4 | 5 | -- from fblualib/fb/util/data.lua , copied here because fblualib is not rockspec ready yet. 6 | -- deepcopy routine that assumes the presence of a 'clone' method in user 7 | -- data should be used to deeply copy. This matches the behavior of Torch 8 | -- tensors. 9 | local function deepcopy(x) 10 | local typename = type(x) 11 | if typename == "userdata" then 12 | return x:clone() 13 | end 14 | if typename == "table" then 15 | local retval = { } 16 | for k,v in pairs(x) do 17 | retval[deepcopy(k)] = deepcopy(v) 18 | end 19 | return retval 20 | end 21 | return x 22 | end 23 | 24 | local Optim, parent = torch.class('nn.Optim') 25 | 26 | 27 | -- Returns weight parameters and bias parameters and associated grad parameters 28 | -- for this module. Annotates the return values with flag marking parameter set 29 | -- as bias parameters set 30 | function Optim.weight_bias_parameters(module) 31 | local weight_params, bias_params 32 | if module.weight then 33 | weight_params = {module.weight, module.gradWeight} 34 | weight_params.is_bias = false 35 | end 36 | if module.bias then 37 | bias_params = {module.bias, module.gradBias} 38 | bias_params.is_bias = true 39 | end 40 | return {weight_params, bias_params} 41 | end 42 | 43 | -- The regular `optim` package relies on `getParameters`, which is a 44 | -- beastly abomination before all. This `optim` package uses separate 45 | -- optim state for each submodule of a `nn.Module`. 46 | function Optim:__init(model, optState, checkpoint_data) 47 | assert(model) 48 | assert(checkpoint_data or optState) 49 | assert(not (checkpoint_data and optState)) 50 | 51 | self.model = model 52 | self.modulesToOptState = {} 53 | -- Keep this around so we update it in setParameters 54 | self.originalOptState = optState 55 | 56 | -- Each module has some set of parameters and grad parameters. Since 57 | -- they may be allocated discontinuously, we need separate optState for 58 | -- each parameter tensor. self.modulesToOptState maps each module to 59 | -- a lua table of optState clones. 60 | if not checkpoint_data then 61 | self.model:apply(function(module) 62 | self.modulesToOptState[module] = { } 63 | local params = self.weight_bias_parameters(module) 64 | for i, _ in ipairs(params) do 65 | self.modulesToOptState[module][i] = deepcopy(optState) 66 | if params[i] and params[i].is_bias then 67 | -- never regularize biases 68 | self.modulesToOptState[module][i].weightDecay = 0.0 69 | end 70 | end 71 | assert(module) 72 | assert(self.modulesToOptState[module]) 73 | end) 74 | else 75 | local state = checkpoint_data.optim_state 76 | local modules = {} 77 | self.model:apply(function(m) table.insert(modules, m) end) 78 | assert(pl.tablex.compare_no_order(modules, pl.tablex.keys(state))) 79 | self.modulesToOptState = state 80 | end 81 | end 82 | 83 | function Optim:save() 84 | return { 85 | optim_state = self.modulesToOptState 86 | } 87 | end 88 | 89 | local function _type_all(obj, t) 90 | for k, v in pairs(obj) do 91 | if type(v) == 'table' then 92 | _type_all(v, t) 93 | else 94 | local tn = torch.typename(v) 95 | if tn and tn:find('torch%..+Tensor') then 96 | obj[k] = v:type(t) 97 | end 98 | end 99 | end 100 | end 101 | 102 | function Optim:type(t) 103 | self.model:apply(function(module) 104 | local state= self.modulesToOptState[module] 105 | assert(state) 106 | _type_all(state, t) 107 | end) 108 | end 109 | 110 | local function get_device_for_module(mod) 111 | local dev_id = nil 112 | for name, val in pairs(mod) do 113 | if torch.typename(val) == 'torch.CudaTensor' then 114 | local this_dev = val:getDevice() 115 | if this_dev ~= 0 then 116 | -- _make sure the tensors are allocated consistently 117 | assert(dev_id == nil or dev_id == this_dev) 118 | dev_id = this_dev 119 | end 120 | end 121 | end 122 | return dev_id -- _may still be zero if none are allocated. 123 | end 124 | 125 | local function on_device_for_module(mod, f) 126 | local this_dev = get_device_for_module(mod) 127 | if this_dev ~= nil then 128 | return cutorch.withDevice(this_dev, f) 129 | end 130 | return f() 131 | end 132 | 133 | function Optim:optimize(optimMethod, inputs, targets, criterion) 134 | assert(optimMethod) 135 | assert(inputs) 136 | assert(targets) 137 | assert(criterion) 138 | assert(self.modulesToOptState) 139 | 140 | self.model:zeroGradParameters() 141 | local output = self.model:forward(inputs) 142 | 143 | local err = criterion:forward(output, targets) 144 | 145 | local df_do = criterion:backward(output, targets) 146 | self.model:backward(inputs, df_do) 147 | 148 | self:updateParameters(optimMethod, err) 149 | 150 | return err, output 151 | end 152 | 153 | function Optim:updateParameters(optimMethod, err) 154 | assert(self.modulesToOptState) 155 | 156 | -- We'll set these in the loop that iterates over each module. Get them 157 | -- out here to be captured. 158 | local curGrad 159 | local curParam 160 | local function fEvalMod(x) 161 | return err, curGrad 162 | end 163 | 164 | for curMod, opt in pairs(self.modulesToOptState) do 165 | on_device_for_module(curMod, function() 166 | local curModParams = self.weight_bias_parameters(curMod) 167 | if curModParams then 168 | for i, tensor in ipairs(curModParams) do 169 | if curModParams[i] then 170 | -- expect param, gradParam pair 171 | curParam, curGrad = table.unpack(curModParams[i]) 172 | assert(curParam and curGrad) 173 | optimMethod(fEvalMod, curParam, opt[i]) 174 | end 175 | end 176 | end 177 | end) 178 | end 179 | end 180 | 181 | function Optim:setParameters(newParams) 182 | assert(newParams) 183 | assert(type(newParams) == 'table') 184 | local function splice(dest, src) 185 | for k,v in pairs(src) do 186 | dest[k] = v 187 | end 188 | end 189 | 190 | splice(self.originalOptState, newParams) 191 | for _,optStates in pairs(self.modulesToOptState) do 192 | for i,optState in pairs(optStates) do 193 | assert(type(optState) == 'table') 194 | splice(optState, newParams) 195 | end 196 | end 197 | end 198 | -------------------------------------------------------------------------------- /engines/fboptimengine.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2016-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'nn' 11 | require 'engines.Optim' 12 | 13 | local tnt = require 'torchnet' 14 | local argcheck = require 'argcheck' 15 | 16 | local FBOptimEngine, SGDEngine = torch.class('tnt.FBOptimEngine', 'tnt.SGDEngine', tnt) 17 | 18 | FBOptimEngine.__init = argcheck{ 19 | {name="self", type="tnt.FBOptimEngine"}, 20 | call = 21 | function(self) 22 | SGDEngine.__init(self) 23 | end 24 | } 25 | 26 | FBOptimEngine.train = argcheck{ 27 | {name="self", type="tnt.FBOptimEngine"}, 28 | {name="network", type="nn.Module"}, 29 | {name="criterion", type="nn.Criterion"}, 30 | {name="iterator", type="tnt.DatasetIterator"}, 31 | {name="maxepoch", type="number", default=1000}, 32 | {name="optimMethod", type="function"}, 33 | {name="config", type="table", opt=true}, 34 | call = 35 | function(self, network, criterion, iterator, maxepoch, optimMethod, config) 36 | local state = { 37 | network = network, 38 | criterion = criterion, 39 | iterator = iterator, 40 | maxepoch = maxepoch, 41 | optimMethod = optimMethod, 42 | optimizer = nn.Optim(network, config), 43 | config = config, 44 | sample = {}, 45 | epoch = 0, -- epoch done so far 46 | t = 0, -- samples seen so far 47 | training = true 48 | } 49 | 50 | self.hooks("onStart", state) 51 | while state.epoch < state.maxepoch do 52 | state.network:training() 53 | 54 | self.hooks("onStartEpoch", state) 55 | for sample in state.iterator() do 56 | state.sample = sample 57 | self.hooks("onSample", state) 58 | 59 | state.network:forward(sample.input) 60 | self.hooks("onForward", state) 61 | state.criterion:forward(state.network.output, sample.target) 62 | self.hooks("onForwardCriterion", state) 63 | 64 | state.network:zeroGradParameters() 65 | if state.criterion.zeroGradParameters then 66 | state.criterion:zeroGradParameters() 67 | end 68 | 69 | state.criterion:backward(state.network.output, sample.target) 70 | self.hooks("onBackwardCriterion", state) 71 | state.network:backward(sample.input, state.criterion.gradInput) 72 | self.hooks("onBackward", state) 73 | 74 | state.optimizer:updateParameters(state.optimMethod, criterion.output) 75 | state.t = state.t + 1 76 | self.hooks("onUpdate", state) 77 | end 78 | state.epoch = state.epoch + 1 79 | self.hooks("onEndEpoch", state) 80 | end 81 | self.hooks("onEnd", state) 82 | end 83 | } 84 | -------------------------------------------------------------------------------- /fbcoco.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'nn' 10 | 11 | fbcoco = {} 12 | 13 | require 'testCoco.init' 14 | require 'BatchProviderBase' 15 | require 'BatchProviderROI' 16 | require 'Tester_FRCNN' 17 | require 'ImageDetect' 18 | 19 | require 'modules.ImageTransformer' 20 | require 'modules.ContextRegion' 21 | require 'modules.Foveal' 22 | require 'modules.SelectBoxes' 23 | require 'modules.ConvertFrom' 24 | require 'modules.BBoxRegressionCriterion' 25 | require 'modules.NoBackprop' 26 | require 'modules.BBoxNorm' 27 | require 'modules.ModeSwitch' 28 | require 'modules.SequentialSplitBatch' 29 | 30 | require 'modules.ModelParallelTable' 31 | 32 | return fbcoco 33 | -------------------------------------------------------------------------------- /loaders/concatloader.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- Combines multiple coco.DataLoaders 10 | 11 | local class = require 'class' 12 | 13 | local ConcatLoader = class('ConcatLoader') 14 | 15 | function ConcatLoader:__init(loaders) 16 | self.__loaders = loaders 17 | -- Offsets for images and annotations 18 | self.__imageOffset = {} 19 | self.categories = loaders[1].categories 20 | local i = 0, 0 21 | for _,loader in ipairs(loaders) do 22 | self.__imageOffset[loader] = i 23 | i = i + loader:nImages() 24 | end 25 | end 26 | 27 | function ConcatLoader:__getLoader(idx, sizeFn) 28 | local offset = idx 29 | for _,l in ipairs(self.__loaders) do 30 | local sz = l[sizeFn](l) 31 | 32 | if offset <= sz then 33 | return l, offset 34 | end 35 | offset = offset - sz 36 | end 37 | error('Invalid index: ' .. idx) 38 | end 39 | 40 | -- Remove indices into the data loader 41 | function ConcatLoader.__removeOffsets(res) 42 | if torch.type(res) == 'table' then 43 | if #res > 0 then 44 | for i,r in ipairs(res) do 45 | res[i] = ConcatLoader.__removeOffsets(r) 46 | end 47 | return res 48 | end 49 | res.image = nil 50 | res.idx = nil 51 | res.annotations = nil 52 | end 53 | return res 54 | end 55 | 56 | function ConcatLoader:nCategories() 57 | return #self.categories 58 | end 59 | local function delegate(fn, sizeFn) 60 | return function(self, idx) 61 | local loader, i = self:__getLoader(idx, sizeFn) 62 | local res = loader[fn](loader, i) 63 | 64 | return ConcatLoader.__removeOffsets(res) 65 | end 66 | end 67 | 68 | local function delegateSum(fn) 69 | return function(self) 70 | local res = 0 71 | for _,loader in ipairs(self.__loaders) do 72 | res = res + loader[fn](loader) 73 | end 74 | return res 75 | end 76 | end 77 | 78 | ConcatLoader.getImage = delegate('getImage', 'nImages') 79 | ConcatLoader.loadImage = delegate('loadImage', 'nImages') 80 | ConcatLoader.getAnnotation = delegate('getAnnotation', 'nAnnotations') 81 | ConcatLoader.getAnnotationsForImage = delegate('getAnnotationsForImage', 'nImages') 82 | ConcatLoader.nImages = delegateSum('nImages') 83 | ConcatLoader.nAnnotations = delegateSum('nAnnotations') 84 | 85 | return ConcatLoader 86 | -------------------------------------------------------------------------------- /loaders/dataloader.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local loader = require 'loaders.loader' 10 | 11 | local datasets = {} 12 | 13 | local path_config = require 'config' 14 | 15 | local dataset_path = './data/annotations' 16 | 17 | -- Add COCO datasets 18 | for _,name in ipairs{'train', 'val', 'test'} do 19 | local file = dataset_path .. '/instances_' .. name .. '2014.json' 20 | datasets['coco_' .. name .. '2014'] = file 21 | datasets[name] = file 22 | end 23 | for _,name in ipairs{'test2014', 'test2015-dev', 'test2015-full'} do 24 | local file = dataset_path .. '/instances_' .. name .. '.json' 25 | datasets['coco_' .. name] = file 26 | datasets[name] = file 27 | end 28 | 29 | -- Add Pascal datasets 30 | for _,name in ipairs{'train2007', 'train2012', 'val2007', 'val2012', 'test2007'} do 31 | local file = dataset_path .. '/pascal_' .. name .. '.json' 32 | datasets['pascal_' .. name] = file 33 | end 34 | 35 | -- Add ImageNet detection datasets 36 | for _,name in ipairs{'train2014','val2013'} do 37 | local file = dataset_path .. '/imagenet_' .. name .. '.json' 38 | datasets['imagenet_' .. name] = file 39 | end 40 | 41 | -- e.g. coco.DataLoader('train') or coco.DataLoader('pascal_train2007') 42 | local function DataLoader(dset) 43 | if torch.typename(dset) == 'dataLoader' then return dset end 44 | local file = datasets[dset] 45 | 46 | if not file then 47 | error('invalid dataset: ' .. tostring(dset)) 48 | end 49 | assert(path_config[dset], 'image dir not set in config.lua') 50 | 51 | return loader():load(file, path_config[dset]) 52 | end 53 | 54 | return DataLoader 55 | -------------------------------------------------------------------------------- /loaders/loader.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- Loader for MSCOCO annotations 10 | 11 | local ffi = require 'ffi' 12 | local image = require 'image' 13 | local coco = require 'coco' 14 | local class = require 'class' 15 | 16 | local dataset = class('loaders.dataLoader') 17 | 18 | function createAnnotationList(m) 19 | local annotation_table = {} 20 | local annotations = {} 21 | local map = {} 22 | 23 | -- Find annotations for each image 24 | for i=1,m.images.id:numel() do 25 | map[i] = {} 26 | end 27 | for i=1,m.annotations.image_id:size(1) do 28 | local imageIdx = m.annotations.image_idx[i] 29 | table.insert(map[imageIdx], i) 30 | end 31 | 32 | for i,anns in ipairs(map) do 33 | table.insert(annotation_table, {#annotations + 1, #anns}) 34 | for _,v in ipairs(anns) do 35 | table.insert(annotations, v) 36 | end 37 | end 38 | 39 | return { 40 | table = torch.IntTensor(annotation_table), 41 | annotations = torch.IntTensor(annotations), 42 | } 43 | end 44 | 45 | function dataset:load(path, image_dir) 46 | local cocoApi = coco.CocoApi(path) 47 | for k,v in pairs(cocoApi) do 48 | self[k] = v 49 | end 50 | self.images = self.data.images 51 | self.categories = {} 52 | for i,v in ipairs(cocoApi:getCatIds():totable()) do 53 | self.categories[i] = cocoApi:loadCats(v)[1].name 54 | end 55 | self.image_dir = image_dir 56 | self.annotationList = createAnnotationList(self.data) 57 | return self 58 | end 59 | 60 | -- Gets image properties as a table: 61 | -- width, height, file_name, annotations 62 | function dataset:getImage(idx) 63 | return { 64 | width = self.images.width[idx], 65 | height = self.images.height[idx], 66 | id = self.images.id[idx], 67 | file_name = ffi.string(self.images.file_name[idx]), 68 | annotations = self:getImageAnnotations(idx), 69 | image_dir = self.images.image_dir and self.images.image_dir[idx], 70 | idx = idx, 71 | } 72 | end 73 | 74 | -- Gets the image data as a ByteTensor 75 | function dataset:loadImage(idx) 76 | local metadata = self:getImage(idx) 77 | local dir = self.image_dir or self.image_dir[metadata.image_dir] 78 | local path = paths.concat(dir, metadata.file_name) 79 | return image.load(path, 3, 'double') 80 | end 81 | 82 | -- Gets annotations: 83 | -- bbox, polygons/rle, category, area, image 84 | function dataset:getAnnotation(idx) 85 | local a = self.data.annotations 86 | assert(idx <= a.id:numel(), 'no annotation for '..idx) 87 | local iscrowd = a.iscrowd[idx] == 1 88 | 89 | local annotation = { 90 | bbox = a.bbox[idx], 91 | image = a.image_idx[idx], 92 | area = a.area[idx], 93 | category = a.category_idx[idx], 94 | iscrowd = iscrowd, 95 | idx = idx, 96 | difficult = a.ignore and a.ignore[idx] or 0, 97 | } 98 | 99 | return annotation 100 | end 101 | 102 | -- Category names 103 | function dataset:categoryNames() 104 | local names = {} 105 | for _,cat in ipairs(self.categories) do 106 | table.insert(names, cat.name) 107 | end 108 | return names 109 | end 110 | 111 | -- Total number of categories (i.e. 80) 112 | function dataset:nCategories() 113 | return #self.categories 114 | end 115 | 116 | -- Total number of categories (i.e. 80) 117 | function dataset:nAnnotations() 118 | return self.annotations.image:size(1) 119 | end 120 | 121 | -- Total number of categories (i.e. 80) 122 | function dataset:nImages() 123 | return self.images.id:size(1) 124 | end 125 | 126 | -- Random annotation for a given category 127 | function dataset:randomAnnotation(category) 128 | local list = self.classListSample[category] 129 | local annotationIdx = list[math.ceil(torch.uniform() * list:nElement())] 130 | 131 | return self:getAnnotation(annotationIdx) 132 | end 133 | 134 | -- Indices of all labeled annotations for a given image 135 | function dataset:getImageAnnotations(idx) 136 | local offset, len = table.unpack(self.annotationList.table[idx]:totable()) 137 | if len == 0 then return {} end 138 | return self.annotationList.annotations:narrow(1, offset, len):totable() 139 | end 140 | 141 | -- All annotations for a given image 142 | function dataset:getAnnotationsForImage(idx) 143 | local tbl = {} 144 | for _,annIdx in ipairs(self:getImageAnnotations(idx)) do 145 | table.insert(tbl, self:getAnnotation(annIdx)) 146 | end 147 | return tbl 148 | end 149 | 150 | return dataset 151 | -------------------------------------------------------------------------------- /loaders/narrowloader.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- View of subset of coco.DataLoader 10 | 11 | local class = require 'class' 12 | 13 | local NarrowLoader = class('NarrowLoader') 14 | local ConcatLoader = require 'loaders.concatloader' 15 | 16 | function NarrowLoader:__init(loader, start, len) 17 | assert(start > 0 and start <= loader:nImages(), 'invalid start: ' .. start) 18 | assert(len > 0 and start + len - 1 <= loader:nImages(), 'invalid len: ' .. len) 19 | self.__loader = loader 20 | self.__start = start 21 | self.__len = len 22 | self.categories = loader.categories 23 | end 24 | 25 | local function delegate(name) 26 | return function(self, idx) 27 | assert(idx >= 1 and idx <= self.__len, 'invalid index: ' .. idx) 28 | local res = self.__loader[name](self.__loader, idx + self.__start - 1) 29 | 30 | return ConcatLoader.__removeOffsets(res) 31 | end 32 | end 33 | 34 | NarrowLoader.getImage = delegate('getImage') 35 | NarrowLoader.loadImage = delegate('loadImage') 36 | NarrowLoader.getAnnotationsForImage = delegate('getAnnotationsForImage') 37 | 38 | function NarrowLoader:getAnnotation(idx) 39 | local res = self.__loader:getAnnotation(idx) 40 | return ConcatLoader.__removeOffsets(res) 41 | end 42 | 43 | function NarrowLoader:nImages() 44 | return self.__len 45 | end 46 | 47 | function NarrowLoader:nCategories() 48 | return #self.categories 49 | end 50 | 51 | function NarrowLoader:nAnnotations() 52 | return self.__loader:nAnnotations() 53 | end 54 | 55 | return NarrowLoader 56 | -------------------------------------------------------------------------------- /models/alexnet.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'inn' 10 | require 'cudnn' 11 | require 'fbcoco' 12 | local utils = paths.dofile'model_utils.lua' 13 | 14 | local data = torch.load'data/models/imagenet_pretrained_alexnet.t7' 15 | local features = utils.safe_unpack(data.features) 16 | local top = utils.safe_unpack(data.top) 17 | 18 | local model = nn.Sequential() 19 | :add(nn.ParallelTable() 20 | :add(utils.makeDataParallel(features)) 21 | :add(nn.Identity()) 22 | ) 23 | :add(inn.ROIPooling(6,6,1/16)) 24 | :add(nn.View(-1):setNumInputDims(3)) 25 | :add(top) 26 | :add(utils.classAndBBoxLinear(4096)) 27 | 28 | model:cuda() 29 | 30 | utils.testModel(model) 31 | 32 | return {model, utils.RossTransformer()} 33 | -------------------------------------------------------------------------------- /models/inceptionv3.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- fine-tuning of pretrained inception-v3 trained by google and converted to 10 | -- torch using https://github.com/Moodstocks/inception-v3.torch 11 | 12 | require 'cudnn' 13 | require 'cunn' 14 | require 'inn' 15 | require 'fbcoco' 16 | inn.utils = require 'inn.utils' 17 | local utils = paths.dofile'model_utils.lua' 18 | 19 | local net = torch.load'./data/models/inceptionv3.t7' 20 | 21 | local input = torch.randn(1,3,299,299):cuda() 22 | local output1 = net:forward(input):clone() 23 | utils.BNtoFixed(net, true) 24 | local output2 = net:forward(input):clone() 25 | assert((output1 - output2):abs():max() < 1e-5) 26 | 27 | local features = nn.Sequential() 28 | local classifier = nn.Sequential() 29 | 30 | for i=1,25 do features:add(net:get(i)) end 31 | for i=26,30 do classifier:add(net:get(i)) end 32 | 33 | utils.testSurgery(input, utils.disableFeatureBackprop, features, 16) 34 | utils.testSurgery(input, inn.utils.foldBatchNorm, features:findModules'nn.NoBackprop'[1]) 35 | 36 | local model = nn.Sequential() 37 | :add(nn.ParallelTable() 38 | :add(utils.makeDataParallel(features)) 39 | :add(nn.Identity()) 40 | ) 41 | :add(inn.ROIPooling(17,17):setSpatialScale(17/299)) 42 | :add(utils.makeDataParallel(classifier)) 43 | :add(utils.classAndBBoxLinear(2048)) 44 | 45 | model:cuda() 46 | model.input_size = 299 -- for utils.testModel 47 | 48 | utils.testModel(model) 49 | 50 | return {model, fbcoco.ImageTransformer({1,1,1},nil,2)} 51 | -------------------------------------------------------------------------------- /models/model_utils.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local generateGraph = require 'optnet.graphgen' 10 | -- local iterm = require 'iterm' 11 | -- require 'iterm.dot' 12 | 13 | local utils = {} 14 | 15 | function utils.makeDataParallel(module, nGPU) 16 | nGPU = nGPU or ((opt and opt.train_nGPU) or 1) 17 | if nGPU > 1 then 18 | local dpt = nn.DataParallelTable(1) -- true? 19 | local cur_dev = cutorch.getDevice() 20 | for i = 1, nGPU do 21 | cutorch.setDevice(i) 22 | dpt:add(module:clone():cuda(), i) 23 | end 24 | cutorch.setDevice(cur_dev) 25 | return dpt 26 | else 27 | return nn.Sequential():add(module) 28 | end 29 | end 30 | 31 | function utils.makeDPParallelTable(module, nGPU) 32 | if nGPU > 1 then 33 | local dpt = nn.DPParallelTable() 34 | local cur_dev = cutorch.getDevice() 35 | for i = 1, nGPU do 36 | cutorch.setDevice(i) 37 | dpt:add(module:clone():cuda(), i) 38 | end 39 | cutorch.setDevice(cur_dev) 40 | return dpt 41 | else 42 | return nn.ParallelTable():add(module) 43 | end 44 | end 45 | 46 | -- returns a new Linear layer with less output neurons 47 | function utils.compress(layer, n) 48 | local W = layer.weight 49 | local U,S,V = torch.svd(W:t():float()) 50 | local new = nn.Linear(W:size(2), n):cuda() 51 | new.weight:t():copy(U:narrow(2,1,n) * torch.diag(S:narrow(1,1,n)) * V:narrow(1,1,n):narrow(2,1,n)) 52 | new.bias:zero() 53 | return new 54 | end 55 | 56 | -- returns a Sequential of 2 Linear layers, one biasless with U*diag(S) and one 57 | -- with V and original bias. L is the number of components to keep. 58 | function utils.SVDlinear(layer, L) 59 | local W = layer.weight:double() 60 | local b = layer.bias:double() 61 | 62 | local K, N = W:size(1), W:size(2) 63 | 64 | local U, S, V = torch.svd(W:t(), 'A') 65 | 66 | local US = U:narrow(2,1,L) * torch.diag(S:narrow(1,1,L)) 67 | local Vt = V:narrow(2,1,L) 68 | 69 | local L1 = nn.LinearNB(N, L) 70 | L1.weight:copy(US:t()) 71 | 72 | local L2 = nn.Linear(L, K) 73 | L2.weight:copy(Vt) 74 | L2.bias:copy(b) 75 | 76 | return nn.Sequential():add(L1):add(L2) 77 | end 78 | 79 | 80 | function utils.testSurgery(input, f, net, ...) 81 | local output1 = net:forward(input):clone() 82 | f(net,...) 83 | local output2 = net:forward(input):clone() 84 | print((output1 - output2):abs():max()) 85 | assert((output1 - output2):abs():max() < 1e-5) 86 | end 87 | 88 | 89 | function utils.removeDropouts(net) 90 | net:replace(function(x) 91 | return torch.typename(x):find'nn.Dropout' and nn.Identity() or x 92 | end) 93 | end 94 | 95 | 96 | function utils.disableFeatureBackprop(features, maxLayer) 97 | local noBackpropModules = nn.Sequential() 98 | for i = 1,maxLayer do 99 | noBackpropModules:add(features.modules[1]) 100 | features:remove(1) 101 | end 102 | features:insert(nn.NoBackprop(noBackpropModules):cuda(), 1) 103 | end 104 | 105 | function utils.classAndBBoxLinear(N, N2) 106 | local class_linear = nn.Linear(N,opt and opt.num_classes or 21):cuda() 107 | class_linear.weight:normal(0,0.01) 108 | class_linear.bias:zero() 109 | 110 | local bbox_linear = nn.Linear(N2 or N,(opt and opt.num_classes or 21) * 4):cuda() 111 | bbox_linear.weight:normal(0,0.001) 112 | bbox_linear.bias:zero() 113 | 114 | if N2 then 115 | return nn.ParallelTable():add(class_linear):add(bbox_linear):cuda() 116 | else 117 | return nn.ConcatTable():add(class_linear):add(bbox_linear):cuda() 118 | end 119 | end 120 | 121 | function utils.testModel(model) 122 | input_size = model.input_size or 224 123 | print(model) 124 | model:training() 125 | local batchSz = opt and opt.images_per_batch or 2 126 | local boxes = torch.Tensor(batchSz, 5) 127 | for i = 1, batchSz do 128 | boxes[i]:copy(torch.Tensor({i,1,1,100,100})) 129 | end 130 | local input = {torch.CudaTensor(batchSz,3,input_size,input_size),boxes:cuda()} 131 | local output = model:forward(input) 132 | -- iterm.dot(generateGraph(model, input), opt and opt.save_folder..'/graph.pdf' or paths.tmpname()..'.pdf') 133 | print{output} 134 | print{model:backward(input,output)} 135 | end 136 | 137 | -- used in AlexNet and VGG models trained by Ross 138 | function utils.RossTransformer() 139 | return fbcoco.ImageTransformer({102.9801,115.9465,122.7717}, nil, 255, {3,2,1}) 140 | end 141 | 142 | -- used in ResNet and facebook inceptions 143 | function utils.ImagenetTransformer() 144 | return fbcoco.ImageTransformer( 145 | { -- mean 146 | 0.48462227599918, 147 | 0.45624044862054, 148 | 0.40588363755159, 149 | }, 150 | { -- std 151 | 0.22889466674951, 152 | 0.22446679341259, 153 | 0.22495548344775, 154 | }) 155 | end 156 | 157 | function utils.normalizeBBoxRegr(model, meanstd) 158 | if #model:findModules('nn.BBoxNorm') == 0 then 159 | -- normalize the bbox regression 160 | local regression_layer = model:get(#model.modules):get(2) 161 | if torch.type(regression_layer) == 'nn.Sequential' then 162 | regression_layer = regression_layer:get(#regression_layer.modules) 163 | end 164 | assert(torch.type(regression_layer) == 'nn.Linear') 165 | 166 | local mean_hat = torch.repeatTensor(meanstd.mean,1,opt.num_classes):cuda() 167 | local sigma_hat = torch.repeatTensor(meanstd.std,1,opt.num_classes):cuda() 168 | 169 | regression_layer.weight:cdiv(sigma_hat:t():expandAs(regression_layer.weight)) 170 | regression_layer.bias:add(-mean_hat):cdiv(sigma_hat) 171 | 172 | utils.addBBoxNorm(model, meanstd) 173 | end 174 | end 175 | 176 | function utils.addBBoxNorm(model, meanstd) 177 | if #model:findModules('nn.BBoxNorm') == 0 then 178 | model:add(nn.ParallelTable() 179 | :add(nn.Identity()) 180 | :add(nn.BBoxNorm(meanstd.mean, meanstd.std)):cuda()) 181 | end 182 | end 183 | 184 | function utils.vggSetPhase2(model) 185 | assert(model.phase == 1) 186 | local dpt = model.modules[1].modules[1] 187 | for i = 1, #dpt.modules do 188 | assert(torch.type(dpt.modules[i]) == 'nn.NoBackprop') 189 | dpt.modules[i] = dpt.modules[i].modules[1] 190 | utils.disableFeatureBackprop(dpt.modules[i], 10) 191 | end 192 | model.phase = phase 193 | print("Switched model to phase 2") 194 | print(model) 195 | end 196 | 197 | function utils.vggSetPhase2_outer(model) 198 | assert(model.phase == 1) 199 | model.modules[1].modules[1] = model.modules[1].modules[1].modules[1] 200 | local dpt = model.modules[1].modules[1] 201 | for i = 1, #dpt.modules do 202 | utils.disableFeatureBackprop(dpt.modules[i], 10) 203 | end 204 | model.phase = phase 205 | print("Switched model to phase 2") 206 | print(model) 207 | end 208 | 209 | function utils.conv345Combine(isNormalized, useConv3, useConv4, initCopyConv5) 210 | local totalFeat = 0 211 | 212 | local function make1PoolingLayer(idx, nFeat, spatialScale, normFactor) 213 | local pool1 = nn.Sequential() 214 | :add(nn.ParallelTable():add(nn.SelectTable(idx)):add(nn.Identity())) 215 | :add(inn.ROIPooling(7,7,spatialScale)) 216 | if isNormalized then 217 | pool1:add(nn.View(-1, nFeat*7*7)) 218 | :add(nn.Normalize(2)) 219 | :add(nn.Contiguous()) 220 | :add(nn.View(-1, nFeat, 7, 7)) 221 | else 222 | pool1:add(nn.MulConstant(normFactor)) 223 | end 224 | totalFeat = totalFeat + nFeat 225 | return pool1 226 | end 227 | 228 | local pooling_layer = nn.ConcatTable() 229 | pooling_layer:add(make1PoolingLayer(1, 512, 1/16, 1)) -- conv5 230 | if useConv4 then 231 | pooling_layer:add(make1PoolingLayer(2, 512, 1/8, 1/30)) -- conv4 232 | end 233 | if useConv3 then 234 | pooling_layer:add(make1PoolingLayer(3, 256, 1/4, 1/200)) -- conv3 235 | end 236 | local pooling_join = nn.Sequential() 237 | :add(pooling_layer) 238 | :add(nn.JoinTable(2)) 239 | if isNormalized then 240 | pooling_join:add(nn.MulConstant(1000)) 241 | end 242 | local conv_mix = cudnn.SpatialConvolution(totalFeat, 512, 1, 1, 1, 1) 243 | if initCopyConv5 then 244 | conv_mix.weight:zero() 245 | conv_mix.weight:narrow(2, 1, 512):copy(torch.eye(512)) -- initialize to just copy conv5 246 | end 247 | pooling_join:add(conv_mix) 248 | pooling_join:add(nn.View(-1):setNumInputDims(3)) 249 | 250 | return pooling_join 251 | end 252 | 253 | -- workaround for bytecode incompat functions 254 | function utils.safe_unpack(self) 255 | if self.unpack and self.model then 256 | return self:unpack() 257 | else 258 | local model = self.model 259 | for k,v in ipairs(model:listModules()) do 260 | if v.weight and not v.gradWeight then 261 | v.gradWeight = v.weight:clone() 262 | v.gradBias = v.bias:clone() 263 | end 264 | end 265 | return model 266 | end 267 | end 268 | 269 | function utils.load(path) 270 | local data = torch.load(path) 271 | return data.unpack and data:unpack() or data 272 | end 273 | 274 | -- takes a model, removes last classification layer and inserts integral loss 275 | function utils.integral(model) 276 | local top_cat = model:get(#model.modules) 277 | model:remove(#model.modules) 278 | 279 | assert(torch.type(top_cat) == 'nn.ConcatTable' or 280 | torch.type(top_cat) == 'nn.ParallelTable') 281 | local is_parallel = torch.type(top_cat) == 'nn.ParallelTable' 282 | 283 | local new_cl = nn.ConcatTable() 284 | for i=1,opt.nDonkeys do 285 | new_cl:add(top_cat:get(1):clone()) 286 | end 287 | local new_top = is_parallel and nn.ParallelTable() or nn.ConcatTable() 288 | new_top:add(new_cl):add(top_cat:get(2)) 289 | model:add(new_top) 290 | 291 | integral_selector = nn.SelectTable(1) 292 | local train_branch = nn.ParallelTable() 293 | :add(integral_selector) 294 | :add(nn.Identity()) 295 | 296 | local softmaxes = nn.ParallelTable() 297 | for i=1,opt.nDonkeys do 298 | softmaxes:add( 299 | nn.Sequential() 300 | :add(nn.SoftMax()) 301 | :add(nn.View(1,-1,opt.num_classes)) 302 | ) 303 | end 304 | 305 | local eval_branch = nn.Sequential() 306 | :add(nn.ParallelTable() 307 | :add(nn.Sequential() 308 | :add(softmaxes) 309 | :add(nn.JoinTable(1)) 310 | :add(nn.Mean(1)) 311 | ) 312 | :add(nn.Identity()) 313 | ) 314 | model.noSoftMax = true 315 | model:add(nn.ModeSwitch(train_branch, eval_branch)) 316 | return {integral_selector} 317 | end 318 | 319 | return utils 320 | -------------------------------------------------------------------------------- /models/multipathnet.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'xlua' 10 | require 'inn' 11 | require 'cudnn' 12 | require 'fbcoco' 13 | local utils = paths.dofile'model_utils.lua' 14 | 15 | local model_opt = xlua.envparams{ 16 | model_conv345_norm = true, 17 | model_het = true, 18 | model_foveal_exclude = -1, 19 | } 20 | 21 | print("model_opt") 22 | print(model_opt) 23 | 24 | local N = 4 25 | 26 | local data = torch.load'data/models/imagenet_pretrained_vgg.t7' 27 | local features = utils.safe_unpack(data.features) 28 | local classifier = utils.safe_unpack(data.top) 29 | 30 | local model = nn.Sequential() 31 | 32 | for k,v in ipairs(classifier:findModules'nn.Dropout') do v.inplace = true end 33 | 34 | local skip_features = nn.Sequential() 35 | for i = 1, 16 do 36 | skip_features:add(features:get(i)) 37 | end 38 | local conv4 = nn.Sequential() 39 | for i = 17, 23 do 40 | conv4:add(features:get(i)) 41 | end 42 | 43 | local conv5 = nn.Sequential() 44 | for i = 24, 30 do 45 | conv5:add(features:get(i)) 46 | end 47 | 48 | skip_features:add(nn.ConcatTable() 49 | :add(conv4) 50 | :add(nn.Identity())) 51 | 52 | skip_features:add(nn.ParallelTable() 53 | :add(nn.ConcatTable() 54 | :add(conv5) 55 | :add(nn.Identity())) 56 | :add(nn.Identity())) 57 | 58 | skip_features:add(nn.FlattenTable()) 59 | 60 | model:add(nn.ParallelTable() 61 | :add(nn.NoBackprop(utils.makeDataParallel(skip_features))) 62 | :add(nn.Identity())) 63 | 64 | model:add(nn.ParallelTable() 65 | :add(nn.Identity()) 66 | :add(nn.Sequential() 67 | :add(nn.Foveal()) 68 | :add(nn.View(-1,N,5)) 69 | :add(nn.Transpose({1,2})))) 70 | 71 | -- local towers = nn.ConcatTable() 72 | 73 | local nGPU = opt and opt.train_nGPU or 4 74 | local regions = nn.ModelParallelTable(2) 75 | local oldDev = cutorch.getDevice() 76 | local dev = 1 77 | local Nreg = 0 78 | for i=1,N do 79 | -- local dev = i % nGPU 80 | -- dev = (dev==0) and nGPU or dev 81 | if i ~= model_opt.model_foveal_exclude then 82 | cutorch.setDevice(dev) 83 | print('dev', i, dev) 84 | 85 | local region_instance = nn.Sequential() 86 | region_instance:add(nn.ParallelTable():add(nn.Identity()):add(nn.Select(1,i))) 87 | region_instance:add(utils.conv345Combine( 88 | model_opt.model_conv345_norm, i == 1, i <= 3, not model_opt.model_conv345_norm)) 89 | region_instance:add(classifier:clone()) 90 | region_instance:float():cuda() 91 | regions:add(region_instance, dev) 92 | 93 | dev = dev + 1 94 | dev = (dev > nGPU) and 1 or dev 95 | Nreg = Nreg +1 96 | end 97 | end 98 | 99 | if model_opt.model_het then 100 | -- ooh, doing something weird here to avoid OOM 101 | cutorch.setDevice(nGPU) 102 | local region_instance = nn.Sequential() 103 | region_instance:add(nn.ParallelTable():add(nn.Identity()):add(nn.Select(1,2))) 104 | region_instance:add(utils.conv345Combine( 105 | model_opt.model_conv345_norm, true, true, not model_opt.model_conv345_norm)) 106 | 107 | region_instance:add(classifier:clone()) 108 | region_instance:float():cuda() 109 | 110 | regions:add(region_instance, nGPU) 111 | end 112 | cutorch.setDevice(oldDev) 113 | model:add(regions) 114 | 115 | if model_opt.model_het then 116 | model:add(nn.ConcatTable():add(nn.Narrow(2, 1, Nreg*4096)):add(nn.Narrow(2, Nreg*4096+1, 4096))) 117 | model:add(utils.classAndBBoxLinear(Nreg*4096, 4096)) 118 | else 119 | model:add(utils.classAndBBoxLinear(Nreg*4096)) 120 | end 121 | model:cuda() 122 | 123 | model.phase = 1 124 | model.setPhase2 = utils.vggSetPhase2_outer 125 | 126 | utils.testModel(model) 127 | 128 | return {model, utils.RossTransformer()} 129 | -------------------------------------------------------------------------------- /models/nin.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- Fast Network-In-Network model from https://gist.github.com/szagoruyko/0f5b4c5e2d2b18472854 10 | 11 | require 'inn' 12 | require 'cudnn' 13 | require 'fbcoco' 14 | inn.utils = require 'inn.utils' 15 | local utils = paths.dofile'model_utils.lua' 16 | 17 | local net = utils.load'./data/models/model_bn_final.t7' 18 | net:cuda():evaluate() 19 | cudnn.convert(net, cudnn) 20 | 21 | local input = torch.randn(1,3,224,224):cuda() 22 | 23 | utils.testSurgery(input, utils.BNtoFixed, net, true) 24 | 25 | local features = nn.Sequential() 26 | local classifier = nn.Sequential() 27 | 28 | for i=1,29 do features:add(net:get(i)) end 29 | for i=31,40 do classifier:add(net:get(i)) end 30 | classifier:add(nn.View(-1):setNumInputDims(3)) 31 | 32 | utils.testSurgery(input, utils.disableFeatureBackprop, features, 10) 33 | utils.testSurgery(input, inn.utils.foldBatchNorm, features:findModules'nn.NoBackprop'[1]) 34 | utils.testSurgery(input, utils.BNtoFixed, features, true) 35 | utils.testSurgery(input, utils.BNtoFixed, net, true) 36 | 37 | local model = nn.Sequential() 38 | :add(nn.ParallelTable() 39 | :add(utils.makeDataParallel(features)) 40 | :add(nn.Identity()) 41 | ) 42 | :add(inn.ROIPooling(7,7,1/16)) 43 | :add(classifier) 44 | :add(utils.classAndBBoxLinear(1024)) 45 | 46 | model:cuda() 47 | 48 | utils.testModel(model) 49 | 50 | return {model, utils.ImagenetTransformer()} 51 | -------------------------------------------------------------------------------- /models/resnet.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'cudnn' 10 | require 'cunn' 11 | require 'fbcoco' 12 | require 'xlua' 13 | require 'inn' 14 | inn.utils = require 'inn.utils' 15 | local utils = paths.dofile'model_utils.lua' 16 | 17 | local model_opt = { 18 | resnet_path = './data/models/resnet/resnet-18.t7' 19 | } 20 | model_opt = xlua.envparams(model_opt) 21 | print(model_opt) 22 | if opt then for k,v in pairs(model_opt) do opt[k] = v end end 23 | 24 | local function loadResNet(model_path) 25 | local net = torch.load(model_path) 26 | net:cuda():evaluate() 27 | 28 | local features = nn.Sequential() 29 | for i=1,7 do features:add(net:get(i)) end 30 | 31 | local input = torch.randn(1,3,224,224):cuda() 32 | 33 | utils.testSurgery(input, utils.disableFeatureBackprop, features, 5) 34 | utils.testSurgery(input, inn.utils.foldBatchNorm, features:findModules'nn.NoBackprop'[1]) 35 | utils.testSurgery(input, inn.utils.BNtoFixed, features, true) 36 | utils.testSurgery(input, inn.utils.BNtoFixed, net, true) 37 | 38 | local classifier = nn.Sequential() 39 | for i=8,10 do classifier:add(net:get(i)) end 40 | 41 | local output_dim = classifier.output:size(2) 42 | 43 | local model = nn.Sequential() 44 | :add(nn.ParallelTable() 45 | :add(utils.makeDataParallel(features)) 46 | :add(nn.Identity()) 47 | ) 48 | :add(inn.ROIPooling(14,14,1/16)) 49 | :add(utils.makeDataParallel(classifier)) 50 | :add(utils.classAndBBoxLinear(output_dim)) 51 | 52 | model:cuda() 53 | 54 | utils.testModel(model) 55 | 56 | return {model, utils.ImagenetTransformer()} 57 | end 58 | 59 | return loadResNet(model_opt.resnet_path) 60 | -------------------------------------------------------------------------------- /models/vgg.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'inn' 10 | require 'cudnn' 11 | require 'fbcoco' 12 | local utils = paths.dofile'model_utils.lua' 13 | 14 | local data = torch.load'data/models/imagenet_pretrained_vgg.t7' 15 | local features = utils.safe_unpack(data.features) 16 | local top = utils.safe_unpack(data.top) 17 | 18 | -- kill first 4 conv layers; the convolutions are at {1,3,6,8,11,13} 19 | utils.disableFeatureBackprop(features, 10) 20 | 21 | for k,v in ipairs(top:findModules'nn.Dropout') do v.inplace = true end 22 | 23 | local model = nn.Sequential() 24 | :add(nn.ParallelTable() 25 | :add(utils.makeDataParallel(features)) 26 | :add(nn.Identity()) 27 | ) 28 | :add(inn.ROIPooling(7,7,1/16)) 29 | :add(nn.View(-1):setNumInputDims(3)) 30 | :add(top) 31 | :add(utils.classAndBBoxLinear(4096)) 32 | 33 | model:cuda() 34 | 35 | utils.testModel(model) 36 | 37 | return {model, utils.RossTransformer()} 38 | -------------------------------------------------------------------------------- /modules/BBoxNorm.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local BBoxNorm, parent = torch.class('nn.BBoxNorm','nn.Module') 10 | 11 | function BBoxNorm:__init(mean, std) 12 | assert(mean and std) 13 | parent.__init(self) 14 | self.mean = mean 15 | self.std = std 16 | end 17 | 18 | function BBoxNorm:updateOutput(input) 19 | assert(input:dim() == 2 and input:size(2) % 4 == 0) 20 | self.output:set(input) 21 | if not self.train then 22 | if not input:isContiguous() then 23 | self._output = self._output or input.new() 24 | self._output:resizeAs(input):copy(input) 25 | self.output = self._output 26 | end 27 | 28 | local output = self.output:view(-1, 4) 29 | output:cmul(self.std:expandAs(output)):add(self.mean:expandAs(output)) 30 | end 31 | return self.output 32 | end 33 | 34 | function BBoxNorm:updateGradInput(input, gradOutput) 35 | assert(self.train, 'cannot updateGradInput in evaluate mode') 36 | self.gradInput = gradOutput 37 | return self.gradInput 38 | end 39 | 40 | function BBoxNorm:clearState() 41 | nn.utils.clear(self, '_output') 42 | return parent.clearState(self) 43 | end 44 | -------------------------------------------------------------------------------- /modules/BBoxRegressionCriterion.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local BBoxRegressionCriterion, parent = torch.class('nn.BBoxRegressionCriterion', 'nn.SmoothL1Criterion') 10 | 11 | function BBoxRegressionCriterion:updateOutput(inputs, targets) 12 | local target_classes = targets[1] -- B 13 | local target_boxes = targets[2] -- Bx84 14 | -- inputs : Bx84 15 | 16 | self.sizeAverage = false 17 | 18 | target_classes = torch.type(target_classes) == 'torch.CudaTensor' and target_classes or target_classes:long() 19 | 20 | local B = inputs:size(1) 21 | local N = target_boxes:size(2)/4 22 | 23 | self._buffer1 = self._buffer1 or inputs.new() 24 | self._buffer2 = self._buffer2 or inputs.new() 25 | self._buffer1:resize(B,N):zero() 26 | self._buffer1:scatter(2,target_classes:view(B,1),1) 27 | self._buffer2:resizeAs(inputs):copy(self._buffer1:view(B,N,1):expand(B,N,4)) 28 | self._buffer2:narrow(2,1,4):zero() 29 | self._buffer2:cmul(inputs) 30 | 31 | parent.updateOutput(self, self._buffer2, target_boxes) 32 | self.output = self.output / B 33 | return self.output 34 | end 35 | 36 | function BBoxRegressionCriterion:updateGradInput(inputs, targets) 37 | local B = inputs:size(1) 38 | parent.updateGradInput(self, self._buffer2, targets[2]) 39 | return self.gradInput:div(B) 40 | end 41 | -------------------------------------------------------------------------------- /modules/ContextRegion.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local Context, parent = torch.class('nn.ContextRegion','nn.Module') 10 | 11 | -- Takes (Bx5) input in format {id,x1,x2,y1,y2} 12 | -- and increases or decreases bounding boxes by 'scale' parameter 13 | 14 | function Context:__init(scale) 15 | parent.__init(self) 16 | local a = (1 + scale) / 2 17 | local b = (1 - scale) / 2 18 | self.tr = torch.Tensor{ 19 | {a, 0, b, 0}, 20 | {0, a, 0, b}, 21 | {b, 0, a, 0}, 22 | {0, b, 0, a}, 23 | } 24 | end 25 | 26 | function Context:updateOutput(input) 27 | assert(input:nDimension() == 2) 28 | assert(input:size(2) == 5) 29 | self.output:resizeAs(input):copy(input) 30 | self.output:narrow(2,2,4):mm(input:narrow(2,2,4), self.tr) 31 | return self.output 32 | end 33 | 34 | function Context:updateGradInput(input, gradOutput) 35 | self.gradInput:resizeAs(input):zero() 36 | return self.gradInput 37 | end 38 | 39 | function Context:__tostring__() 40 | return torch.type(self)..'('..(self.tr[1][1]*2 - 1)..')' 41 | end 42 | -------------------------------------------------------------------------------- /modules/ConvertFrom.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'nn' 10 | 11 | -- same as utils.ConvertFrom 12 | -- reparametrization of source bbox -> target bbox relashionship 13 | 14 | local module, parent = torch.class('nn.ConvertFrom','nn.Module') 15 | 16 | function module:__init() 17 | parent.__init(self) 18 | self.gradInput1 = torch.Tensor() 19 | self.gradInput2 = torch.Tensor() 20 | end 21 | 22 | function module:updateOutput(input) 23 | local roi_boxes = input[1] 24 | local y = input[2] 25 | 26 | local bbox = roi_boxes:narrow(2,2,4) 27 | 28 | self.output:resizeAs(roi_boxes):copy(roi_boxes) 29 | local out = self.output:narrow(2,2,4) 30 | 31 | assert(bbox:size(2) == y:size(2)) 32 | assert(bbox:size(2) == out:size(2)) 33 | assert(bbox:size(1) == y:size(1)) 34 | assert(bbox:size(1) == out:size(1)) 35 | 36 | local xc = (bbox[{{},1}] + bbox[{{},3}]) * 0.5 37 | local yc = (bbox[{{},2}] + bbox[{{},4}]) * 0.5 38 | local w = bbox[{{},3}] - bbox[{{},1}] 39 | local h = bbox[{{},4}] - bbox[{{},2}] 40 | 41 | local xtc = torch.addcmul(xc, y[{{},1}], w) 42 | local ytc = torch.addcmul(yc, y[{{},2}], h) 43 | local wt = torch.exp(y[{{},3}]):cmul(w) 44 | local ht = torch.exp(y[{{},4}]):cmul(h) 45 | 46 | out[{{},1}] = xtc - wt * 0.5 47 | out[{{},2}] = ytc - ht * 0.5 48 | out[{{},3}] = xtc + wt * 0.5 49 | out[{{},4}] = ytc + ht * 0.5 50 | 51 | return self.output 52 | end 53 | 54 | function module:updateGradInput(input, gradOutput) 55 | self.gradInput1:resizeAs(input[1]) 56 | self.gradInput2:resizeAs(input[2]) 57 | self.gradInput = {self.gradInput1, self.gradInput2} 58 | return self.gradInput 59 | end 60 | -------------------------------------------------------------------------------- /modules/Foveal.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local Foveal, parent = torch.class('nn.Foveal','nn.Module') 10 | 11 | function Foveal:__init() 12 | parent.__init(self) 13 | end 14 | 15 | function Foveal:updateOutput(input) 16 | assert(input:nDimension() == 2) 17 | assert(input:size(2) == 5) 18 | local N = 4 19 | self.output:resize(input:size(1) * N, input:size(2)) 20 | 21 | local cinput = input:float() 22 | local coutput = self.output:float() 23 | 24 | local output_split = coutput:split(N) 25 | 26 | local function createRegion(id,x,y,w,h) 27 | return torch.FloatTensor{id,x,y,x+w,y+h} 28 | end 29 | 30 | for i=1,input:size(1) do 31 | local box = cinput[i] 32 | local id,x,y,x2,y2 = table.unpack(box:totable()) 33 | local w = x2 - x 34 | local h = y2 - y 35 | local base = output_split[i] 36 | base[1]:copy(box) 37 | base[2]:copy(createRegion(id,x-w*.25,y-h*.25,w*1.5,h*1.5)) 38 | base[3]:copy(createRegion(id,x-w*0.5,y-h*0.5,w*2.0,h*2.0)) 39 | base[4]:copy(createRegion(id,x-w*1.5,y-h*1.5,w*4.0,h*4.0)) 40 | end 41 | 42 | self.output:copy(coutput) 43 | return self.output 44 | end 45 | -------------------------------------------------------------------------------- /modules/ImageTransformer.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local ImageTransformer, parent = torch.class('fbcoco.ImageTransformer', 'nn.Module') 10 | 11 | function ImageTransformer:__init(mean,std,scale,swap) 12 | parent.__init(self) 13 | self.mean = mean 14 | self.std = std 15 | self.scale = scale or 1 16 | self.swap = swap 17 | end 18 | 19 | function ImageTransformer:updateOutput(I) 20 | assert(I:nDimension() == 3) 21 | I = self.swap and I:index(1,torch.LongTensor(self.swap)) or I:clone() 22 | if self.scale ~= 1 then 23 | I:mul(self.scale) 24 | end 25 | for i=1,3 do 26 | I[i]:add(-self.mean[i]) 27 | if self.std then 28 | I[i]:div(self.std[i]) 29 | end 30 | end 31 | self.output = I 32 | return I 33 | end 34 | -------------------------------------------------------------------------------- /modules/ModeSwitch.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local ModeSwitch, parent = torch.class('nn.ModeSwitch','nn.Container') 10 | 11 | function ModeSwitch:__init(train_module, test_module) 12 | self.train = true 13 | self.modules = {train_module, test_module} 14 | end 15 | 16 | function ModeSwitch:updateOutput(input) 17 | local active = self.train and self.modules[1] or self.modules[2] 18 | self.output = active:updateOutput(input) 19 | return self.output 20 | end 21 | 22 | function ModeSwitch:updateGradInput(input, gradOutput) 23 | if self.train then 24 | self.gradInput = self.modules[1]:updateGradInput(input, gradOutput) 25 | else 26 | error'backprop not defined in evaluate mode' 27 | end 28 | return self.gradInput 29 | end 30 | 31 | function ModeSwitch:accGradParameters(input, gradOutput) 32 | end 33 | 34 | function ModeSwitch:__tostring__() 35 | return nn.ParallelTable.__tostring__(self) 36 | end 37 | -------------------------------------------------------------------------------- /modules/ModelParallelTable.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local gpuLocalCopyBuffers = {} 10 | local baseModuleIndex = 1 -- A constant 11 | 12 | -- ***************************************************************************** 13 | -- Helper Functions 14 | -- ***************************************************************************** 15 | -- queryGPUDeviceId - Function to query a tensor or table for the 16 | -- GPUID. For tables we will search the table for CudaTensors, query their 17 | -- device and make sure the deviceIds of ALL CudaTensors are on the same GPU. 18 | local function queryGPUDeviceId(object) 19 | if torch.type(object) == 'torch.CudaTensor' then 20 | return object:getDevice() 21 | end 22 | 23 | local deviceId 24 | 25 | -- Try finding a parameter 26 | local stack = {} -- explicit stack to recurse on tables 27 | for key, param in pairs(object) do 28 | if key ~= 'modules' then 29 | stack[#stack+1] = param -- Push onto the stack 30 | end 31 | end 32 | while #stack > 0 do 33 | local param = stack[#stack]; stack[#stack] = nil -- Pop the stack 34 | if (torch.type(param) == 'table') then 35 | for i = 1, #param do stack[#stack+1] = param[i] end -- Push onto stack 36 | elseif (torch.type(param) == 'torch.CudaTensor') then 37 | if (torch.numel(param) > 0) then 38 | -- Empty tensors are always on GPU "0" 39 | local curId = param:getDevice() 40 | if deviceId == nil then 41 | deviceId = curId 42 | else 43 | assert(deviceId == curId, 44 | 'Found CudaTensor instances from different devices') 45 | end 46 | end 47 | end 48 | end 49 | 50 | return deviceId 51 | end 52 | 53 | -- Get an available GPU buffer for asyncGPUCopy. It is used when the GPU tensor 54 | -- is not contiguous. 55 | local function getBuffer() 56 | local device = cutorch.getDevice() 57 | if not gpuLocalCopyBuffers[device] then 58 | gpuLocalCopyBuffers[device] = torch.CudaTensor() 59 | end 60 | return gpuLocalCopyBuffers[device] 61 | end 62 | 63 | -- setDeviceSafe - Avoid redundant calls to setDevice 64 | local function setDevice(gpuid) 65 | if (cutorch.getDevice() ~= gpuid) then 66 | cutorch.setDevice(gpuid) 67 | end 68 | end 69 | 70 | local function equalSize(sizeTable1, sizeTable2) 71 | if (#sizeTable1 ~= #sizeTable2) then 72 | return false 73 | end 74 | for i = 1, #sizeTable1 do 75 | if sizeTable1[i] ~= sizeTable2[i] then return false end 76 | end 77 | return true 78 | end 79 | 80 | local function equalSize(sizeTable1, sizeTable2) 81 | if (#sizeTable1 ~= #sizeTable2) then 82 | return false 83 | end 84 | for i = 1, #sizeTable1 do 85 | if sizeTable1[i] ~= sizeTable2[i] then return false end 86 | end 87 | return true 88 | end 89 | 90 | -- deepTensorsCopy - perform an elementwise copy of the tensors in the nested 91 | -- table. We assume that the tables are properly initialized (ie same size and 92 | -- structure), although we will assert it. 93 | local function deepTensorsCopy(dst, src) 94 | if (torch.type(src) == 'table') then 95 | assert(torch.type(dst) == 'table' and #src == #dst) 96 | for i = 1, #src do deepTensorsCopy(dst[i], src[i]) end 97 | elseif torch.type(src):find('torch%..+Tensor') then 98 | assert(torch.type(dst):find('torch%..+Tensor')) 99 | assert(dst:isSameSizeAs(src)) 100 | dst:copy(src) 101 | else 102 | error('input must be a nested table of tensors!') 103 | end 104 | end 105 | 106 | -- deepTensorsAdd - perform an elementwise add of the tensors in the nested 107 | -- table. We assume that the tables are properly initialized (ie same size and 108 | -- structure), although we will assert it. 109 | -- 110 | -- Note: this is necessary because add() will malloc new memory on the cuda 111 | -- driver side every time we want to get new memory! Therefore, we actually 112 | -- need to copy src to the dst gpu 113 | local function deepTensorsAdd(dst, src) 114 | if (torch.type(src) == 'table') then 115 | assert(torch.type(dst) == 'table' and #src == #dst) 116 | for i = 1, #src do deepTensorsAdd(dst[i], src[i]) end 117 | elseif torch.type(src):find('torch%..+Tensor') then 118 | assert(torch.type(dst):find('torch%..+Tensor')) 119 | assert(dst:isSameSizeAs(src)) 120 | 121 | local dstGpuid = dst:getDevice() 122 | local srcGpuid = src:getDevice() 123 | local curGpuid = cutorch:getDevice() 124 | setDevice(dstGpuid) 125 | 126 | -- Copy src over to a buffer on the dst GPU 127 | local srcBufferOnDstGpu = src 128 | if (dstGpuid ~= srcGpuid) then 129 | srcBufferOnDstGpu = getBuffer() 130 | srcBufferOnDstGpu:resizeAs(src) 131 | assert(src:isContiguous()) 132 | srcBufferOnDstGpu:copy(src) 133 | end 134 | 135 | -- Perform the actual add 136 | dst:add(srcBufferOnDstGpu) 137 | if (dstGpuid ~= srcGpuid) then 138 | -- Ensures we get to keep the buffer for the duration of the add 139 | cutorch.synchronize() 140 | end 141 | 142 | setDevice(curGpuid) -- Put the GPU id back to what it was 143 | else 144 | error('input must be a nested table of tensors!') 145 | end 146 | end 147 | 148 | -- ***************************************************************************** 149 | -- ModelParallelTable 150 | -- ***************************************************************************** 151 | local ModelParallelTable, parent = torch.class('nn.ModelParallelTable', 152 | 'nn.Container') 153 | 154 | function ModelParallelTable:__init(dimension, noGradInput) 155 | parent.__init(self) 156 | if not dimension then 157 | error "must specify a dimension!" 158 | end 159 | 160 | self.dimension = dimension 161 | self.modules = {} 162 | self.gpuAssignments = {} -- Which gpuid each module sits on 163 | self.inputGpu = {} -- inputs for each gpu 164 | self.gradOutputGpu = {} -- gradOutputs for each gpu 165 | self.outputGpu = {} -- outputs for each gpu 166 | self.gradInputGpu = {} -- gradInput for each gpu 167 | self.gradInputAddBuffer = {} 168 | self.noGradInput = noGradInput or false 169 | end 170 | 171 | -- NOTE: The input should be on the FIRST added GPU device, and similarly the 172 | -- output will be on the FIRST GPU device. 173 | function ModelParallelTable:add(module, gpuid) 174 | local parameters = module:parameters() 175 | for _, param in ipairs(parameters) do 176 | assert(param:getDevice() == gpuid, param:getDevice() .. "~=" .. gpuid) 177 | end 178 | assert(gpuid <= cutorch.getDeviceCount() and gpuid >= 1) 179 | assert(#self.modules == #self.gpuAssignments) 180 | 181 | self.modules[#self.modules + 1] = module 182 | self.gpuAssignments[#self.gpuAssignments + 1] = gpuid 183 | 184 | return self 185 | end 186 | 187 | function ModelParallelTable:__tostring() 188 | return 'ModelParallelTable: ' .. #self.modules .. ' x ' .. tostring(self.modules[1]) 189 | end 190 | 191 | function ModelParallelTable:get(index) 192 | return self.modules[index] 193 | end 194 | 195 | function ModelParallelTable:updateOutput(input) 196 | local baseGpuid = self.gpuAssignments[baseModuleIndex] 197 | -- cutorch.withDevice(baseGpuid, function() print('input', input[1]:mean()) end) 198 | assert(queryGPUDeviceId(input) == baseGpuid, 'Input is not on gpu ' .. 199 | baseGpuid) 200 | 201 | local prevGpuid = cutorch.getDevice() 202 | 203 | -- distribute the input to GPUs 204 | for i = 1, #self.modules do 205 | local gpuid = self.gpuAssignments[i] 206 | -- Copy the tensors in the input nested table to the GPU with gpuid 207 | self.inputGpu[i] = self:_copyTensorRecursive( 208 | input, self.inputGpu[i], 209 | baseGpuid, gpuid 210 | ) 211 | -- cutorch.withDevice(gpuid, function() print('inputGpu', i, self.inputGpu[gpuid][1]:mean()) end) 212 | end 213 | 214 | cutorch.synchronize() 215 | 216 | -- update output for each module asynchronously 217 | for i, module in ipairs(self.modules) do 218 | local gpuid = self.gpuAssignments[i] 219 | setDevice(gpuid) 220 | self.outputGpu[i] = module:updateOutput(self.inputGpu[i]) 221 | -- cutorch.withDevice(gpuid, function() print('outputGpu', i, self.outputGpu[gpuid]:mean()) end) 222 | end 223 | 224 | cutorch.synchronize() 225 | 226 | -- concatenate the outputs to the base GPU 227 | for i = 1, #self.modules do 228 | local gpuid = self.gpuAssignments[i] 229 | -- Merge the tensors in the input nested table to the GPU with gpuid 230 | self.output = self:_concatTensorRecursive( 231 | self.outputGpu[i], self.output, 232 | gpuid, i, baseGpuid, baseModuleIndex, 233 | #self.modules 234 | ) 235 | end 236 | cutorch.synchronize() 237 | 238 | setDevice(prevGpuid) 239 | 240 | -- cutorch.withDevice(baseGpuid, function() print('output', self.output:mean()) end) 241 | return self.output 242 | end 243 | 244 | function ModelParallelTable:updateGradInput(input, gradOutput) 245 | -- We assume that updateOutput has already been called (therefore inputGpu 246 | -- has been populated) 247 | local baseGpuid = self.gpuAssignments[baseModuleIndex] 248 | -- cutorch.withDevice(baseGpuid, function() print('gradOutput', gradOutput:mean()) end) 249 | assert(queryGPUDeviceId(gradOutput) == baseGpuid, 250 | 'gradOutput is not on gpu ' .. baseGpuid) 251 | 252 | local prevGpuid = cutorch.getDevice() 253 | 254 | -- distribute the gradOutput to GPUs 255 | for i = 1, #self.modules do 256 | local gpuid = self.gpuAssignments[i] 257 | -- Split the tensors in the input nested table to the GPU with gpuid 258 | -- _distributeTensorRecursive(src,dst,srcGpuid,srcInd,dstGpuid,dstInd) 259 | self.gradOutputGpu[i] = self:_distributeTensorRecursive(gradOutput, 260 | self.gradOutputGpu[i], baseGpuid, baseGpuIndex, gpuid, i, #self.modules) 261 | end 262 | 263 | cutorch.synchronize() 264 | 265 | -- update gradInput for each module asynchronously 266 | for i, module in ipairs(self.modules) do 267 | local gpuid = self.gpuAssignments[i] 268 | setDevice(gpuid) 269 | self.gradInputGpu[i] = module:updateGradInput(self.inputGpu[i], 270 | self.gradOutputGpu[i]) 271 | end 272 | 273 | -- concatenate the outputs to the base GPU 274 | for i = 1, #self.modules do 275 | local gpuid = self.gpuAssignments[i] 276 | -- Merge the tensors in the input nested table to the GPU with gpuid 277 | self.gradInputAddBuffer[i] = self:_copyTensorRecursive(self.gradInputGpu[i], 278 | self.gradInputAddBuffer[i], gpuid, baseGpuid) 279 | end 280 | 281 | cutorch.synchronize() 282 | setDevice(baseGpuid) 283 | self.gradInput = self:_zeroTensorRecursive(self.gradInputGpu[baseGpuid], self.gradInput) 284 | for i = 1, #self.modules do 285 | self:_accumulateTensorRecursive(self.gradInputAddBuffer[i], self.gradInput) 286 | end 287 | 288 | setDevice(prevGpuid) 289 | 290 | return self.gradInput 291 | end 292 | 293 | function ModelParallelTable:accGradParameters(input, gradOutput, scale) 294 | -- We assume updateGradInput has already been called (so gradOutput has 295 | -- already been populated) 296 | local prevGpuid = cutorch.getDevice() 297 | local baseGpuid = self.gpuAssignments[baseModuleIndex] 298 | 299 | scale = scale or 1 300 | -- Calculate the gradWeight + gradBias on each sub-module 301 | for i, module in ipairs(self.modules) do 302 | local gpuid = self.gpuAssignments[i] 303 | setDevice(gpuid) 304 | module:accGradParameters(self.inputGpu[i], self.gradOutputGpu[i], 305 | scale) 306 | end 307 | 308 | cutorch.synchronize() -- We have to wait until accGradParameters has finished 309 | 310 | setDevice(prevGpuid) 311 | end 312 | 313 | function ModelParallelTable:accUpdateGradParameters(input, gradOutput, lr) 314 | error("accUpdateGradParameters not supported for ModelParallelTable.") 315 | end 316 | 317 | function ModelParallelTable:zeroGradParameters() 318 | local prevGpuid = cutorch.getDevice() 319 | for i, module in ipairs(self.modules) do 320 | setDevice(self.gpuAssignments[i]) 321 | module:zeroGradParameters() 322 | end 323 | setDevice(prevGpuid) 324 | end 325 | 326 | function ModelParallelTable:updateParameters(learningRate) 327 | error("updateParameters not supported for ModelParallelTable.") 328 | end 329 | 330 | function ModelParallelTable:share(mlp,...) 331 | error("Share not supported for ModelParallelTable.") 332 | end 333 | 334 | function ModelParallelTable:clone() 335 | error("clone not supported for ModelParallelTable.") 336 | end 337 | 338 | function ModelParallelTable:reset(stdv) 339 | local prevGpuid = cutorch.getDevice() 340 | for i, module in ipairs(self.modules) do 341 | setDevice(self.gpuAssignments[i]) 342 | module:reset(stdv) 343 | end 344 | setDevice(prevGpuid) 345 | end 346 | 347 | function ModelParallelTable:name() 348 | return 'ModelParallelTable' 349 | end 350 | 351 | function ModelParallelTable:type(typeStr) 352 | if typeStr == "torch.CudaTensor" then 353 | for i, m in ipairs(self.modules) do 354 | m:type(typeStr) 355 | end 356 | else 357 | error("ModelParallelTable only supports CudaTensor, not " .. typeStr) 358 | end 359 | end 360 | 361 | function ModelParallelTable:_getSliceRange(tensor, id, total) 362 | local outerDim = tensor:size(self.dimension) 363 | assert(outerDim % total == 0) -- FIXME get rid of this restriction 364 | local eltsPerMod = outerDim / total 365 | local rangeStart = (id - 1) * eltsPerMod + 1 366 | local rangeEnd = id * eltsPerMod 367 | 368 | return tensor:narrow(self.dimension, rangeStart, rangeEnd-rangeStart+1) 369 | end 370 | 371 | function ModelParallelTable:_copyTensorRecursive(src, dst, srcGpuid, dstGpuid) 372 | if (torch.type(src) == 'table') then 373 | if torch.type(dst) ~= 'table' or #src ~= #dst then 374 | dst = {} 375 | end 376 | 377 | -- Recurse on the table 378 | for i = 1, #src do 379 | dst[i] = self:_copyTensorRecursive(src[i], dst[i], srcGpuid, dstGpuid) 380 | end 381 | 382 | elseif torch.type(src):find('torch%..+Tensor') then 383 | if (dst == nil or torch.type(dst) ~= 'torch.CudaTensor') then 384 | -- Allocate only on startup or when input table structure changes. 385 | -- Otherwise we will just resize the tensor below. 386 | setDevice(dstGpuid) 387 | dst = torch.CudaTensor() 388 | end 389 | 390 | -- Split the tensor 391 | assert(torch.typename(src) == 'torch.CudaTensor') 392 | 393 | if not dst:isSameSizeAs(src) then 394 | setDevice(dstGpuid) 395 | dst:resizeAs(src) 396 | end 397 | 398 | dst:copy(src) 399 | else 400 | error('input must be a nested table of tensors!') 401 | end 402 | 403 | return dst 404 | end 405 | 406 | -- _distributeTensorRecursive - if the src is a tensor then the function slices 407 | -- it long self.dimension and copies each portion into each child module. 408 | -- Otherwise it does a recursive call on tables. 409 | function ModelParallelTable:_distributeTensorRecursive(src, dst, 410 | srcGpuid, srcIndex, dstGpuid, dstIndex, nModules) 411 | if (torch.type(src) == 'table') then 412 | if torch.type(dst) ~= 'table' or #src ~= #dst then 413 | dst = {} 414 | end 415 | 416 | -- Recurse on the table 417 | for i = 1, #src do 418 | dst[i] = self:_distributeTensorRecursive(src[i], dst[i], srcGpuid, 419 | srcIndex, dstGpuid, dstIndex, nModules) 420 | end 421 | 422 | elseif torch.type(src):find('torch%..+Tensor') then 423 | if (dst == nil or torch.type(dst) ~= 'torch.CudaTensor') then 424 | -- Allocate only on startup or when input table structure changes. 425 | -- Otherwise we will just resize the tensor below. 426 | setDevice(dstGpuid) 427 | dst = torch.CudaTensor() 428 | end 429 | 430 | -- Split the tensor 431 | assert(torch.typename(src) == 'torch.CudaTensor') 432 | local slice = self:_getSliceRange(src, dstIndex, nModules) 433 | 434 | if not dst:isSameSizeAs(slice) then 435 | setDevice(dstGpuid) 436 | dst:resizeAs(slice) 437 | end 438 | 439 | dst:copy(slice) 440 | else 441 | error('input must be a nested table of tensors!') 442 | end 443 | 444 | return dst 445 | end 446 | 447 | -- _concatTensorRecursive - if the src is a tensor then the function copies it 448 | -- into the dst slice along self.dimension. 449 | -- Otherwise it does a recursive call on tables. 450 | function ModelParallelTable:_concatTensorRecursive(src, dst, srcGpuid, 451 | srcIndex, dstGpuid, dstIndex, nModules) 452 | if (torch.type(src) == 'table') then 453 | if torch.type(dst) ~= 'table' or #src ~= #dst then 454 | dst = {} 455 | end 456 | 457 | -- Recurse on the table 458 | for i = 1, #src do 459 | dst[i] = self:_concatTensorRecursive(src[i], dst[i], srcGpuid, 460 | srcIndex, dstGpuid, dstIndex, nModules) 461 | end 462 | 463 | elseif torch.type(src):find('torch%..+Tensor') then 464 | if (dst == nil or torch.type(dst) ~= 'torch.CudaTensor') then 465 | -- Allocate only on startup or when input table structure changes. 466 | -- Otherwise we will just resize the tensor below. 467 | setDevice(dstGpuid) 468 | dst = torch.CudaTensor() 469 | end 470 | 471 | if (torch.numel(src) > 0) then 472 | -- Some modules return empty gradInputs if they don't actually return 473 | -- anything. 474 | local dstSize = src:size():totable() 475 | dstSize[self.dimension] = dstSize[self.dimension] * nModules 476 | if not (equalSize(dst:size():totable(), dstSize)) then 477 | assert(srcIndex == 1) 478 | setDevice(dstGpuid) 479 | dst:resize(unpack(dstSize)) 480 | end 481 | 482 | -- Split the tensor 483 | assert(torch.typename(src) == 'torch.CudaTensor') 484 | local slice = self:_getSliceRange(dst, srcIndex, nModules) 485 | slice:copy(src) 486 | end 487 | else 488 | error('input must be a nested table of tensors!') 489 | end 490 | 491 | return dst 492 | end 493 | 494 | function ModelParallelTable:_zeroTensorRecursive(src, dst) 495 | if (torch.type(src) == 'table') then 496 | if torch.type(dst) ~= 'table' or #src ~= #dst then 497 | dst = {} 498 | end 499 | 500 | -- Recurse on the table 501 | for i = 1, #src do 502 | dst[i] = self:_zeroTensorRecursive(src[i], dst[i]) 503 | end 504 | 505 | elseif torch.type(src):find('torch%..+Tensor') then 506 | if (dst == nil or torch.type(dst) ~= 'torch.CudaTensor') then 507 | dst = torch.CudaTensor() 508 | end 509 | 510 | -- Split the tensor 511 | assert(torch.typename(src) == 'torch.CudaTensor') 512 | 513 | if not dst:isSameSizeAs(src) then 514 | dst:resizeAs(src) 515 | end 516 | dst:zero() 517 | else 518 | error('input must be a nested table of tensors!') 519 | end 520 | return dst 521 | end 522 | 523 | function ModelParallelTable:_accumulateTensorRecursive(src, dst) 524 | if (torch.type(src) == 'table') then 525 | -- Recurse on the table 526 | for i = 1, #src do 527 | dst[i] = self:_accumulateTensorRecursive(src[i], dst[i]) 528 | end 529 | elseif torch.type(src):find('torch%..+Tensor') then 530 | dst:add(src) 531 | else 532 | error('input must be a nested table of tensors!') 533 | end 534 | return dst 535 | end 536 | 537 | 538 | -- Backward compatibility purposes 539 | ModelParallelTable.__version = 2 540 | 541 | -- ModelParallelTable.deserializeNGPUs controls how many GPUs to deserialize 542 | -- upon, otherwise will deserialize to as many GPUs as serialized and error 543 | -- out if it doesn't have enough available 544 | function ModelParallelTable:__read(file, version) 545 | -- backwards compatibility 546 | -- TEMPORARY HACK: remove before checking into OSS 547 | if version < 2 then 548 | local var = file:readObject() 549 | for k, v in pairs(var) do 550 | self[k] = v 551 | end 552 | -- hope we didn't run out of memory :) 553 | local gpu = cutorch.getDevice() 554 | for i = 1, #self.gpuAssignments do 555 | -- move each branch to the correct gpu 556 | cutorch.setDevice(self.gpuAssignments[i]) 557 | self.modules[i]:float():cuda() 558 | end 559 | cutorch.setDevice(gpu) 560 | return 561 | end 562 | 563 | self.gpuAssignments = file:readObject() 564 | 565 | if ModelParallelTable.deserializeNGPUs then 566 | if ModelParallelTable.deserializeNGPUs > cutorch.getDeviceCount() then 567 | error('Deserialization requested on too many GPUs: ' .. 568 | ModelParallelTable.deserializeNGPUs .. ' vs ' .. 569 | cutorch.getDeviceCount() .. ' available') 570 | end 571 | -- round-robin branches 572 | for i = 1, #self.gpuAssignments do 573 | self.gpuAssignments[i] = i % ModelParallelTable.deserializeNGPUs 574 | self.gpuAssignments[i] = (self.gpuAssignments[i]==0) and 575 | ModelParallelTable.deserializeNGPUs or self.gpuAssignments[i] 576 | end 577 | end 578 | 579 | -- If ModelParallelTable.deserializeNGPUs, deserialization overrides 580 | -- gpu assignments anyway. If not, we need as many GPUs as the max, 581 | -- there may be holes. 582 | local nGPUs = math.max(unpack(self.gpuAssignments)) 583 | if nGPUs > cutorch.getDeviceCount() then 584 | error('Model was serialized on ' .. 585 | math.max(unpack(self.gpuAssignments)) .. 586 | ' nGPUs, but you are running on ' .. cutorch.getDeviceCount() .. 587 | ' please set ModelParallelTable.deserializeNGPUs to ignore ' .. 588 | ' serialized tower-GPU assignments') 589 | end 590 | 591 | local gpu = cutorch.getDevice() 592 | self.modules = {} 593 | -- deserialize each of the branches on the correct gpu 594 | for i = 1, #self.gpuAssignments do 595 | cutorch.setDevice(self.gpuAssignments[i]) 596 | self.modules[i] = file:readObject() 597 | end 598 | 599 | -- finally deserialize everything else 600 | cutorch.setDevice(gpu) 601 | local var = file:readObject() 602 | for k, v in pairs(var) do 603 | self[k] = v 604 | end 605 | end 606 | 607 | function ModelParallelTable:__write(file) 608 | file:writeObject(self.gpuAssignments) 609 | 610 | -- Write all the branches 611 | local modules = self.modules 612 | local gpuAssignments = self.gpuAssignments 613 | self.modules = nil 614 | self.gpuAssignments = nil 615 | for _, m in ipairs(modules) do 616 | file:writeObject(m) 617 | end 618 | 619 | -- Write everything else as a table 620 | local t = {} 621 | for k, v in pairs(self) do 622 | t[k] = v 623 | end 624 | file:writeObject(t) 625 | 626 | self.gpuAssignments = gpuAssignments 627 | self.modules = modules 628 | end 629 | 630 | function ModelParallelTable:clearState() 631 | self.inputGpu = {} 632 | self.gradOutputGpu = {} 633 | self.outputGpu = {} 634 | self.gradInputGpu = {} 635 | 636 | return parent.clearState(self) 637 | end 638 | -------------------------------------------------------------------------------- /modules/NoBackprop.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local NoBackprop,parent = torch.class('nn.NoBackprop','nn.Container') 10 | 11 | -- was lazy to finish CPU side 12 | function NoBackprop:__init(inner) 13 | parent.__init(self) 14 | assert(inner) 15 | self.modules = {inner} 16 | end 17 | 18 | function NoBackprop:updateOutput(input) 19 | self.output = self.modules[1]:updateOutput(input) 20 | return self.output 21 | end 22 | 23 | function NoBackprop:updateGradInput(input, gradOutput) 24 | self.gradInput:resizeAs(input):zero() 25 | return self.gradInput 26 | end 27 | 28 | function NoBackprop:__tostring() 29 | return 'NoBackprop: ' .. tostring(self.modules[1]) 30 | end 31 | 32 | -- ugh, stupid temporary backwards-compatibility hack 33 | NoBackprop.__version = 2 34 | function NoBackprop:__read(file, version) 35 | -- do the normal read 36 | local var = file:readObject() 37 | for k, v in pairs(var) do 38 | self[k] = v 39 | end 40 | -- fixup module 41 | if version < 2 then 42 | self.modules = {self.inner} 43 | self.inner = nil 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /modules/SelectBoxes.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local SelectBoxes,parent = torch.class('nn.SelectBoxes','nn.Module') 10 | 11 | -- Input: 12 | -- * SoftMax output (eg. 128 x 21) 13 | -- * Bbox regresson output (eg. 128 x 84) 14 | -- Output: 15 | -- * boxes corresponding to the best classes (eg. 128 x 4) 16 | -- Optionally renormalizes by multiplying by std and adding mean 17 | -- if sigma_hat and sigma_mean are present in the self. 18 | 19 | -- was lazy to finish CPU side 20 | function SelectBoxes:__init() 21 | parent.__init(self) 22 | self.gradInput_classes = torch.Tensor() 23 | self.gradInput_boxes = torch.Tensor() 24 | end 25 | 26 | function SelectBoxes:updateOutput(input) 27 | local classes = input[1] 28 | local ys = input[2] 29 | 30 | local B = classes:size(1) 31 | self.maxvals = self.maxvals or classes.new() 32 | self.maxids = self.maxids or classes.new() 33 | self.ids = self.ids or classes.new() 34 | 35 | local maxvals = self.maxvals:resize(B,1) 36 | local maxids = self.maxids:resize(B,1) 37 | local ids = self.ids:resize(B,4) 38 | 39 | torch.max(maxvals, maxids, classes, 2) 40 | 41 | maxids:add(-1):mul(4) 42 | for i=1,4 do ids:select(2,i):fill(i) end 43 | ids:add(maxids:expand(B,4)) 44 | self.output:resize(B,4):gather(ys, 2, ids) 45 | 46 | if not self.std then 47 | --print'dry run, using 0-1 mean-sigma in nn.SelectBoxes' 48 | else 49 | -- renormalize output 50 | local mu = self.mean:expandAs(self.output) 51 | local sigma = self.std:expandAs(self.output) 52 | self.output:cmul(sigma):add(mu) 53 | end 54 | 55 | return self.output 56 | end 57 | 58 | function SelectBoxes:updateGradInput(input,gradOutput) 59 | self.gradInput_classes:resizeAs(input[1]):zero() 60 | self.gradInput_boxes:resizeAs(input[2]):zero() 61 | self.gradInput = {self.gradInput_classes, self.gradInput2} 62 | return self.gradInput 63 | end 64 | 65 | function test() 66 | module = nn.SelectBoxes():cuda() 67 | local classes = torch.Tensor{ 68 | {0,1,0}, 69 | {1,0,0} 70 | }:cuda() 71 | local ys = torch.rand(2,3*4):cuda() 72 | 73 | local output = module:forward{classes, ys}:cuda() 74 | print(ys, output) 75 | end 76 | 77 | --test() 78 | -------------------------------------------------------------------------------- /modules/SequentialSplitBatch.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local Sequential, parent = torch.class('nn.SequentialSplitBatch', 'nn.Sequential') 10 | 11 | function Sequential:__init(size) 12 | parent.__init(self) 13 | self.batch_size = size 14 | end 15 | 16 | function Sequential:updateOutput(input) 17 | if torch.type(input) ~= 'table' then 18 | assert(input:dim() == 2 or input:dim() == 4) 19 | local batch_size = input:size(1) 20 | if batch_size <= self.batch_size then 21 | return parent.updateOutput(self, input) 22 | else 23 | -- propagate small batch to determine output size 24 | local output_size = parent.updateOutput(self,input:narrow(1,1,1)):size() 25 | output_size[1] = batch_size 26 | self.output_ = self.output_ or input.new() 27 | self.parent_output = self.parent_output or input[1].new() 28 | self.output:set(self.parent_output) 29 | self.output_:resize(output_size) 30 | local input_split = input:split(self.batch_size,1) 31 | local output_split = self.output_:split(self.batch_size,1) 32 | for i,v in ipairs(input_split) do 33 | output_split[i]:copy(parent.updateOutput(self,v)) 34 | end 35 | self.output:set(self.output_) 36 | return self.output 37 | end 38 | elseif torch.type(input) == 'table' then 39 | -- only 1-nested tables supported 40 | -- only tensor output so far 41 | local input_sizes = {} 42 | for i,v in ipairs(input) do 43 | assert(v:dim() == 2 or v:dim() == 4) 44 | input_sizes[i] = v:size(1) 45 | end 46 | --assert(torch.Tensor(input_sizes):std() == 0, 'different sizes on input') 47 | local batch_size = input_sizes[2] 48 | 49 | if batch_size <= self.batch_size then 50 | return parent.updateOutput(self, input) 51 | else 52 | -- propagate small batch 53 | local subinput = {} 54 | for i,v in ipairs(input) do subinput[i] = v:narrow(1,1,1) end 55 | local output_size = parent.updateOutput(self, subinput):size() 56 | output_size[1] = batch_size 57 | self.output_ = self.output_ or input[1].new() 58 | self.parent_output = self.parent_output or input[1].new() 59 | self.output:set(self.parent_output) 60 | self.output_:resize(output_size) 61 | local output_split = self.output_:split(self.batch_size,1) 62 | local per_input_splits = {} 63 | for i,v in ipairs(input) do per_input_splits[i] = v:split(self.batch_size,1) end 64 | 65 | assert(self.output_:storage() ~= self.output:storage()) 66 | for k,u in ipairs(output_split) do 67 | local subinput = {input[1], per_input_splits[2][k]} 68 | u:copy(parent.updateOutput(self,subinput)) 69 | end 70 | self.output:set(self.output_) 71 | return self.output 72 | end 73 | end 74 | end 75 | 76 | 77 | -- for updateGradInput, accGradParameters, etc do not do anything. 78 | -------------------------------------------------------------------------------- /modules/test.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'fbcoco' 10 | require 'inn' 11 | 12 | local mytester = torch.Tester() 13 | 14 | local precision = 1e-3 15 | 16 | local nntest = torch.TestSuite() 17 | 18 | local function criterionJacobianTest1D(cri, input, target) 19 | local eps = 1e-6 20 | local _ = cri:forward(input, target) 21 | local dfdx = cri:backward(input, target) 22 | -- for each input perturbation, do central difference 23 | local centraldiff_dfdx = torch.Tensor():resizeAs(dfdx) 24 | local input_s = input:storage() 25 | local centraldiff_dfdx_s = centraldiff_dfdx:storage() 26 | for i=1,input:nElement() do 27 | -- f(xi + h) 28 | input_s[i] = input_s[i] + eps 29 | local fx1 = cri:forward(input, target) 30 | -- f(xi - h) 31 | input_s[i] = input_s[i] - 2*eps 32 | local fx2 = cri:forward(input, target) 33 | -- f'(xi) = (f(xi + h) - f(xi - h)) / 2h 34 | local cdfx = (fx1 - fx2) / (2*eps) 35 | -- store f' in appropriate place 36 | centraldiff_dfdx_s[i] = cdfx 37 | -- reset input[i] 38 | input_s[i] = input_s[i] + eps 39 | end 40 | 41 | -- compare centraldiff_dfdx with :backward() 42 | local err = (centraldiff_dfdx - dfdx):abs():max() 43 | mytester:assertlt(err, precision, 'error in difference between central difference and :backward') 44 | end 45 | 46 | 47 | function nntest.BBoxRegressionCriterion() 48 | local bs = torch.random(16,32) 49 | local input = torch.randn(bs, 84) 50 | local bbox_targets = torch.randn(bs, 84):zero() 51 | local bbox_labels = torch.Tensor(bs):random(2,21) 52 | for i=1,bs do 53 | bbox_targets[i]:narrow(1,(bbox_labels[i]-1)*4 + 1, 4) 54 | end 55 | local target = {bbox_labels, bbox_targets} 56 | local cri = nn.BBoxRegressionCriterion() 57 | criterionJacobianTest1D(cri, input, target) 58 | end 59 | 60 | function nntest.SequentialSplitBatch_ROIPooling() 61 | local input = { 62 | torch.randn(1,512,38,50):cuda(), 63 | torch.randn(40,5):cuda():mul(50), 64 | } 65 | input[2]:select(2,1):fill(1) 66 | 67 | local module = nn.SequentialSplitBatch(25) 68 | :add(inn.ROIPooling(7,7,1/16)) 69 | :add(nn.View(-1):setNumInputDims(3)) 70 | :add(nn.Linear(7*7*512,9)) 71 | :cuda() 72 | 73 | local output_mod = module:forward(input):clone() 74 | output_mod = module:forward(input):clone() 75 | local output_ref = module:replace(function(x) 76 | if torch.typename(x) == 'nn.SequentialSplitBatch' then 77 | torch.setmetatable(x, 'nn.Sequential') 78 | end 79 | return x 80 | end):forward(input):clone() 81 | 82 | mytester:asserteq((output_mod - output_ref):abs():max(), 0, 'SequentialSplitBatch err') 83 | end 84 | 85 | function nntest.SequentialSplitBatch_Tensor() 86 | local input = torch.randn(40,512):cuda() 87 | local module = nn.SequentialSplitBatch(25):add(nn.Linear(512,9)):cuda() 88 | 89 | local output_mod = module:forward(input):clone() 90 | output_mod = module:forward(input):clone() 91 | local output_ref = module:replace(function(x) 92 | if torch.typename(x) == 'nn.SequentialSplitBatch' then 93 | torch.setmetatable(x, 'nn.Sequential') 94 | end 95 | return x 96 | end):forward(input):clone() 97 | 98 | mytester:asserteq((output_mod - output_ref):abs():max(), 0, 'SequentialSplitBatch err') 99 | end 100 | 101 | mytester:add(nntest) 102 | mytester:run() 103 | -------------------------------------------------------------------------------- /nms.c: -------------------------------------------------------------------------------- 1 | /*------------------------------------------------------------------------------ 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------*/ 8 | 9 | #include 10 | 11 | #define MIN(a,b) (((a)<(b))?(a):(b)) 12 | #define MAX(a,b) (((a)>(b))?(a):(b)) 13 | 14 | float overlap(const float *a, const float *b) 15 | { 16 | float a_x1 = a[0]; 17 | float a_y1 = a[1]; 18 | float a_x2 = a[2]; 19 | float a_y2 = a[3]; 20 | 21 | float b_x1 = b[0]; 22 | float b_y1 = b[1]; 23 | float b_x2 = b[2]; 24 | float b_y2 = b[3]; 25 | 26 | float x1 = MAX(a_x1, b_x1); 27 | float y1 = MAX(a_y1, b_y1); 28 | float x2 = MIN(a_x2, b_x2); 29 | float y2 = MIN(a_y2, b_y2); 30 | 31 | float w = x2 - x1 + 1; 32 | float h = y2 - y1 + 1; 33 | 34 | float intersection = w * h; 35 | 36 | float aarea = (a_x2 - a_x1 + 1) * (a_y2 - a_y1 + 1); 37 | float barea = (b_x2 - b_x1 + 1) * (b_y2 - b_y1 + 1); 38 | 39 | float iou = intersection / (aarea + barea - intersection); 40 | return (w <= 0 || h <= 0) ? 0 : iou; 41 | } 42 | 43 | void boxoverlap(THFloatTensor *result, THFloatTensor *a, THFloatTensor *b) 44 | { 45 | int N = a->size[0]; 46 | THFloatTensor_resize1d(result, N); 47 | float *a_data = THFloatTensor_data(a); 48 | float *b_data = THFloatTensor_data(b); 49 | float *r_data = THFloatTensor_data(result); 50 | 51 | for(int i=0;isize[0]; 62 | float **boxes = calloc(N, sizeof(float*)); 63 | float *scored_boxes_data = THFloatTensor_data(scored_boxes); 64 | for(int i=0; i bestS) { 78 | bestS = boxes[i][4]; 79 | best = i; 80 | } 81 | } 82 | float *b = boxes[best]; 83 | float *tmp = boxes[0]; 84 | boxes[0] = boxes[best]; 85 | boxes[best] = tmp; 86 | boxes++; 87 | numNMS++; 88 | 89 | // Remove all bounding boxes where the percent area of overlap is greater than overlap 90 | int numGood = 0; 91 | for(int i = 0; i < num-1; i++) { 92 | float inter_over_union = overlap(b, boxes[i]); 93 | if(inter_over_union <= threshold) { 94 | tmp = boxes[numGood]; 95 | boxes[numGood++] = boxes[i]; 96 | boxes[i] = tmp; 97 | } 98 | } 99 | num = numGood; 100 | } 101 | 102 | THFloatTensor_resize2d(keep, numNMS, 5); 103 | float *keep_data = THFloatTensor_data(keep); 104 | for(int i=0; isize[0]; 118 | int N_boxes = scored_boxes->size[0]; 119 | THAssert(nms_boxes->size[1] == 5); 120 | THAssert(scored_boxes->size[1] == 5); 121 | // THFloatTensor* overlaps = THFloatTensor_newWithSize1d(N_boxes); 122 | 123 | float *nms_data = THFloatTensor_data(nms_boxes); 124 | float *scored_data = THFloatTensor_data(scored_boxes); 125 | float *res_data = THFloatTensor_data(res); 126 | 127 | for(int i=0; i threshold) { 131 | for(int field = 0; field<4; field++) { 132 | res_data[5*i+field] += scored_data[5*j+field] * scored_data[5*j+4]; 133 | } 134 | res_data[5*i+4] += scored_data[5*j+4]; 135 | } 136 | } 137 | for(int field = 0; field<4; field++) { 138 | res_data[5*i+field] /= res_data[5*i+4]; 139 | } 140 | res_data[5*i+4] = nms_data[5*i+4]; 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /run_test.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'torch' 10 | local json = require 'cjson' 11 | local test_runner = paths.dofile('test_runner.lua') 12 | local utils = paths.dofile'utils.lua' 13 | local tds = require 'tds' 14 | 15 | opt = { 16 | dataset = 'pascal', 17 | year = '2007', 18 | proposals = 'deepmask', 19 | proposal_dir = './data/proposals', 20 | transformer = 'RossTransformer', 21 | scale = 600, 22 | max_size = 1000, 23 | test_nGPU = 4, 24 | test_set = 'test', 25 | test_nsamples = -1, -- all samples 26 | test_data_offset = -1, -- ignore the first "offset" samples 27 | test_model = './data/models/caffenet_fast_rcnn_iter_40000.t7', 28 | test_best_proposals_number = 500, 29 | test_load_aboxes = '', 30 | test_save_res_prefix = '', 31 | test_save_res = '', 32 | test_save_raw = '', 33 | test_num_iterative_loc = 1, 34 | disable_memory_efficient_forward = false, 35 | test_add_nosoftmax = false, -- for backwards compatibility with szagoruyko's experiments ONLY 36 | test_use_rbox_scores = false, 37 | test_bbox_voting = false, 38 | test_bbox_voting_score_pow = 1, 39 | test_augment = false, 40 | test_just_save_boxes = false, 41 | test_min_proposal_size = 2, 42 | test_nms_threshold = 0.3, 43 | test_bbox_voting_nms_threshold = 0.5, 44 | } 45 | opt = xlua.envparams(opt) 46 | print(opt) 47 | 48 | local dataset_name = opt.dataset..'_'..opt.test_set..opt.year 49 | local folder_name = opt.dataset == 'pascal' and ('VOC'..opt.year) or 'coco' 50 | local proposals_path = utils.makeProposalPath(opt.proposal_dir, folder_name, opt.proposals, opt.test_set) 51 | 52 | print('dataset:',dataset_name) 53 | print('proposals_path:',proposals_path) 54 | 55 | test_runner:setup(opt.test_nGPU, dataset_name, proposals_path) 56 | 57 | local aboxes 58 | 59 | if opt.test_load_aboxes == '' then 60 | aboxes = test_runner:computeBBoxes() 61 | else 62 | aboxes = torch.load(opt.test_load_aboxes) 63 | end 64 | 65 | local dir = opt.test_save_res 66 | if opt.test_data_offset ~= -1 then 67 | dir = opt.test_data_offset 68 | dir = opt.test_save_res_prefix .. dir 69 | end 70 | 71 | 72 | if dir ~= '' then 73 | print("Saving boxes to " .. dir) 74 | paths.mkdir(dir) 75 | torch.save(('%s/boxes.t7'):format(dir), aboxes) 76 | end 77 | 78 | if not opt.test_just_save_boxes then 79 | local res = test_runner:evaluateBoxes(aboxes) 80 | 81 | if dir ~= '' then 82 | torch.save(dir..'/results.t7', res) 83 | end 84 | end 85 | -------------------------------------------------------------------------------- /scripts/ec2-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This sets up a new AWS EC2 server to run Facebook's multipathnet. 4 | # 5 | # Requires: 6 | # - Amazon EC2 instance with GPU (g2.2xlarge or g2.8xlarge) running AWS Linux. 7 | # - 30 GB of disk space. 8 | 9 | sudo pip install numpy 10 | sudo yum -y install boost-devel 11 | 12 | sudo yum -y install git 13 | sudo yum -y install automake 14 | 15 | wget https://github.com/google/glog/archive/v0.3.3.zip 16 | unzip v0.3.3.zip 17 | cd glog-0.3.3/ 18 | ./configure 19 | make 20 | sudo make install 21 | 22 | # Install torch 23 | cd ~ 24 | git clone https://github.com/torch/distro.git ~/torch --recursive 25 | cd ~/torch; bash install-deps; 26 | ./install.sh -b 27 | . ~/torch/install/bin/torch-activate 28 | 29 | luarocks install inn 30 | luarocks install torchnet 31 | luarocks install fbpython 32 | luarocks install class 33 | luarocks install optnet 34 | 35 | # Install COCO 36 | cd ~ 37 | git clone https://github.com/pdollar/coco.git 38 | cd coco 39 | luarocks make LuaAPI/rocks/coco-scm-1.rockspec 40 | 41 | cd PythonAPI 42 | make 43 | export PYTHONPATH=$PYTHONPATH:~/coco/PythonAPI 44 | 45 | # Install nVidia CUDA 46 | cd ~ 47 | wget http://developer.download.nvidia.com/compute/redist/cudnn/v5.1/cudnn-7.5-linux-x64-v5.1.tgz 48 | tar xvf cudnn-7.5-linux-x64-v5.1.tgz 49 | cd cuda/lib64 50 | export LD_LIBRARY_PATH=`pwd`:$LD_LIBRARY_PATH 51 | 52 | # Install Multipathnet 53 | cd ~ 54 | git clone https://github.com/facebookresearch/multipathnet.git 55 | 56 | cd /tmp 57 | wget http://mscoco.org/static/annotations/PASCAL_VOC.zip 58 | wget http://mscoco.org/static/annotations/ILSVRC2014.zip 59 | wget http://msvocds.blob.core.windows.net/annotations-1-0-3/instances_train-val2014.zip 60 | 61 | export MPROOT=~/multipathnet 62 | mkdir -p $MPROOT/data/annotations 63 | cd $MPROOT/data/annotations 64 | unzip -j /tmp/PASCAL_VOC.zip 65 | unzip -j /tmp/ILSVRC2014.zip 66 | unzip -j /tmp/instances_train-val2014.zip 67 | 68 | mkdir -p $MPROOT/data/proposals/VOC2007/selective_search 69 | cd $MPROOT/data/proposals/VOC2007/selective_search 70 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/VOC2007/selective_search/train.t7 71 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/VOC2007/selective_search/val.t7 72 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/VOC2007/selective_search/trainval.t7 73 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/VOC2007/selective_search/test.t7 74 | 75 | mkdir -p $MPROOT/data/proposals/coco/sharpmask 76 | cd $MPROOT/data/proposals/coco/sharpmask 77 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/coco/sharpmask/train.t7 78 | wget https://dl.fbaipublicfiles.com/multipathnet/proposals/coco/sharpmask/val.t7 79 | 80 | mkdir -p $MPROOT/data/models 81 | cd $MPROOT/data/models 82 | wget https://dl.fbaipublicfiles.com/multipathnet/models/imagenet_pretrained_alexnet.t7 83 | wget https://dl.fbaipublicfiles.com/multipathnet/models/imagenet_pretrained_vgg.t7 84 | wget https://dl.fbaipublicfiles.com/multipathnet/models/vgg16_fast_rcnn_iter_40000.t7 85 | wget https://dl.fbaipublicfiles.com/multipathnet/models/caffenet_fast_rcnn_iter_40000.t7 86 | 87 | if [ ! -f ~/multipathnet/config.lua.backup ]; then 88 | cp ~/multipathnet/config.lua ~/multipathnet/config.lua.backup 89 | fi 90 | 91 | echo " 92 | -- put your paths to VOC and COCO containing subfolders with images here 93 | local VOCdevkit = '$MPROOT/data/proposals' 94 | local coco_dir = '$MPROOT/data/proposals/coco' 95 | 96 | return { 97 | pascal_train2007 = paths.concat(VOCdevkit, 'VOC2007/selective_search'), 98 | pascal_val2007 = paths.concat(VOCdevkit, 'VOC2007/selective_search'), 99 | pascal_test2007 = paths.concat(VOCdevkit, 'VOC2007/selective_search'), 100 | pascal_train2012 = paths.concat(VOCdevkit, 'VOC2007/selective_search'), 101 | pascal_val2012 = paths.concat(VOCdevkit, 'VOC2007/selective_search'), 102 | pascal_test2012 = paths.concat(VOCdevkit, 'VOC2007/selective_search'), 103 | coco_train2014 = paths.concat(coco_dir, 'sharpmask'), 104 | coco_val2014 = paths.concat(coco_dir, 'sharpmask'), 105 | }" > ~/multipathnet/config.lua 106 | 107 | cd $MPROOT 108 | git clone https://github.com/facebookresearch/deepmask.git 109 | 110 | cd $MPROOT/data/models 111 | # download SharpMask based on ResNet-50 112 | wget https://dl.fbaipublicfiles.com/deepmask/models/sharpmask/model.t7 -O sharpmask.t7 113 | wget https://dl.fbaipublicfiles.com/multipathnet/models/resnet18_integral_coco.t7 114 | 115 | echo 116 | echo 'Add the following to your .bashrc: 117 | export PYTHONPATH=~/coco/PythonAPI 118 | export LD_LIBRARY_PATH=~/cuda/lib64:$LD_LIBRARY_PATH' 119 | -------------------------------------------------------------------------------- /scripts/eval_coco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export dataset=coco 4 | export test_set=val 5 | export year=2014 6 | export scale=800 7 | 8 | export transformer=ImagenetTransformer 9 | export test_model=./data/models/resnet18_integral_coco.t7 10 | export proposals=sharpmask 11 | 12 | export test_nsamples=5000 13 | export test_best_proposals_number=400 14 | export max_size=1000 15 | 16 | th run_test.lua 17 | -------------------------------------------------------------------------------- /scripts/eval_fastrcnn_voc2007.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export proposals=selective_search 4 | export test_best_proposals_number=2000 5 | 6 | th run_test.lua 7 | 8 | # model=../data/models/vgg16_fast_rcnn_iter_40000.t7 9 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.345 10 | # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.668 11 | # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.320 12 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.062 13 | # Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.202 14 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.406 15 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.350 16 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.462 17 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.468 18 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.260 19 | # Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.344 20 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.518 21 | 22 | 23 | # model=data/models/caffenet_fast_rcnn_iter_40000.t7 24 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.264 25 | # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.559 26 | # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.217 27 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.034 28 | # Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.132 29 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.318 30 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.304 31 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.400 32 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.408 33 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.140 34 | # Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.285 35 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.456 36 | -------------------------------------------------------------------------------- /scripts/train_coco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export year=2014 4 | export train_set=trainval 5 | export test_set=val 6 | export dataset=coco 7 | 8 | export nDonkeys=6 9 | export integral=true 10 | export images_per_batch=4 11 | export batchSize=64 12 | export scale=800 13 | export weightDecay=0 14 | export test_best_proposals_number=400 15 | export test_nsamples=1000 16 | 17 | export proposals=sharpmask 18 | export nEpochs=3200 19 | export step=2800 20 | export save_folder="logs/coco_${model}_${proposals}_$RANDOM$RANDOM" 21 | 22 | mkdir -p $save_folder 23 | 24 | th train.lua | tee $save_folder/log.txt 25 | 26 | -------------------------------------------------------------------------------- /scripts/train_fastrcnn_voc2007.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export proposals=selective_search 4 | export test_best_proposals_number=2000 5 | export best_proposals_number=2000 6 | 7 | export save_folder=logs/fastrcnn_voc2007_${RANDOM}${RANDOM} 8 | mkdir -p $save_folder 9 | 10 | th train.lua | tee $save_folder/log.txt 11 | 12 | # model=vgg 13 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.358 14 | # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.688 15 | # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.331 16 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.102 17 | # Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.211 18 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.415 19 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.359 20 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.470 21 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.474 22 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.237 23 | # Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.352 24 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.525 25 | 26 | # model=alexnet 27 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.260 28 | # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.558 29 | # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.212 30 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.050 31 | # Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.128 32 | # Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.309 33 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.301 34 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.396 35 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.402 36 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.169 37 | # Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.278 38 | # Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.449 39 | -------------------------------------------------------------------------------- /scripts/train_multipathnet_coco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export year=2014 4 | export train_set=trainval 5 | export test_set=val 6 | export dataset=coco 7 | 8 | export nDonkeys=6 9 | export integral=true 10 | export images_per_batch=4 11 | export batchSize=64 12 | export scale=800 13 | export weightDecay=0 14 | export test_best_proposals_number=400 15 | export test_nsamples=1000 16 | 17 | export model=multipathnet 18 | export proposals=sharpmask 19 | export nEpochs=3200 20 | export step=2800 21 | export save_folder="logs/coco_${model}_${proposals}_$RANDOM$RANDOM" 22 | 23 | mkdir -p $save_folder 24 | 25 | th train.lua | tee $save_folder/log.txt 26 | 27 | -------------------------------------------------------------------------------- /test.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local inn = require 'inn' 10 | require 'fbcoco' 11 | 12 | local utils = paths.dofile'utils.lua' 13 | 14 | local mytester = torch.Tester() 15 | local utiltest = torch.TestSuite() 16 | 17 | function utiltest.bboxregression_parametrization() 18 | local A = torch.rand(2) * 100 19 | local B = torch.rand(2) * 100 20 | local bbox = torch.Tensor{A[1], A[2], A[1] + torch.random(40), A[2] + torch.random(40)} 21 | local tbox = torch.Tensor{B[1], B[2], B[1] + torch.random(40), B[2] + torch.random(40)} 22 | local out = torch.zeros(4) 23 | 24 | -- test 1-dim 25 | utils.convertTo(out, bbox, tbox) 26 | local out1 = torch.zeros(4) 27 | utils.convertFrom(out1, bbox, out) 28 | mytester:assertlt((out1 - tbox):abs():max(), 1e-8) 29 | 30 | -- test 2-dim 31 | local out2 = torch.zeros(1,4) 32 | utils.convertTo(out2, bbox:view(1,4), tbox:view(1,4)) 33 | mytester:assertlt((out2:squeeze() - out):abs():max(), 1e-8) 34 | 35 | local out3 = torch.zeros(1,4) 36 | utils.convertFrom(out3, bbox:view(1,4), out:view(1,4)) 37 | mytester:assertlt((out3 - out1):abs():max(), 1e-8) 38 | end 39 | 40 | function utiltest.boxoverlap() 41 | local a = torch.Tensor{ 42 | {0,0,100,100}, 43 | {0,50,100,150}, 44 | {50,0,150,100}, 45 | {50,50,150,150}, 46 | {100,100,200,200} 47 | } 48 | local b = {50,50,150,150} 49 | 50 | local gt = torch.FloatTensor{1/7, 1/3, 1/3, 1, 1/7} 51 | mytester:assertlt((utils.boxoverlap(a,b) - gt):max(),5e-3) 52 | end 53 | 54 | function utiltest.attachProposals() 55 | local dataset_name = 'pascal_test2007' 56 | local proposals_path = 'data/proposals/VOC2007/selective_search/test.t7' 57 | 58 | local ds = dofile'DataSetJSON.lua':create(dataset_name, proposals_path) 59 | ds:loadROIDB(500) 60 | 61 | mytester:assertgt(ds:size(), 0) 62 | 63 | -- go over some annotations and check that they are in the right format 64 | for i=1,32 do 65 | local id = torch.random(ds:size()) 66 | 67 | -- load an image and check that it has 3 channels 68 | local im = ds:getImage(id) 69 | mytester:asserteq(im:nDimension(), 3) 70 | 71 | -- annotation check 72 | local anno = ds:getAnnotation(1) 73 | local obj = anno[1] 74 | -- check that annotation has 'difficult' field 75 | mytester:assert(obj.difficult ~= nil) 76 | mytester:assertgt(obj.class_id, 0) 77 | -- check that the bbox is x1,y1,x2,y2 78 | mytester:assertgt(obj.bbox[3], obj.bbox[1]) 79 | mytester:assertgt(obj.bbox[4], obj.bbox[2]) 80 | 81 | -- check that proposals are in x1,y1,x2,y2 too 82 | local proposals = ds:getROIBoxes(id) 83 | mytester:assertgt(proposals:select(2,3):gt(proposals:select(2,1)):float():mean(), 0.9) 84 | mytester:assertgt(proposals:select(2,4):gt(proposals:select(2,2)):float():mean(), 0.9) 85 | end 86 | end 87 | 88 | function utiltest.merge_table() 89 | local t1, t2, t3 = {x = 1}, {y = 2}, {z = 3} 90 | local t = utils.merge_table{t1,t2,t3} 91 | assert(t.x == t1.x and t.y == t2.y and t.z == t3.z) 92 | end 93 | 94 | local precision = 1e-3 95 | 96 | local nntest = torch.TestSuite() 97 | 98 | local function criterionJacobianTest1D(cri, input, target) 99 | local eps = 1e-6 100 | local _ = cri:forward(input, target) 101 | local dfdx = cri:backward(input, target) 102 | -- for each input perturbation, do central difference 103 | local centraldiff_dfdx = torch.Tensor():resizeAs(dfdx) 104 | local input_s = input:storage() 105 | local centraldiff_dfdx_s = centraldiff_dfdx:storage() 106 | for i=1,input:nElement() do 107 | -- f(xi + h) 108 | input_s[i] = input_s[i] + eps 109 | local fx1 = cri:forward(input, target) 110 | -- f(xi - h) 111 | input_s[i] = input_s[i] - 2*eps 112 | local fx2 = cri:forward(input, target) 113 | -- f'(xi) = (f(xi + h) - f(xi - h)) / 2h 114 | local cdfx = (fx1 - fx2) / (2*eps) 115 | -- store f' in appropriate place 116 | centraldiff_dfdx_s[i] = cdfx 117 | -- reset input[i] 118 | input_s[i] = input_s[i] + eps 119 | end 120 | 121 | -- compare centraldiff_dfdx with :backward() 122 | local err = (centraldiff_dfdx - dfdx):abs():max() 123 | mytester:assertlt(err, precision, 'error in difference between central difference and :backward') 124 | end 125 | 126 | 127 | function nntest.BBoxRegressionCriterion() 128 | local bs = torch.random(16,32) 129 | local input = torch.randn(bs, 84) 130 | local bbox_targets = torch.randn(bs, 84):zero() 131 | local bbox_labels = torch.Tensor(bs):random(2,21) 132 | for i=1,bs do 133 | bbox_targets[i]:narrow(1,(bbox_labels[i]-1)*4 + 1, 4) 134 | end 135 | local target = {bbox_labels, bbox_targets} 136 | local cri = nn.BBoxRegressionCriterion() 137 | criterionJacobianTest1D(cri, input, target) 138 | end 139 | 140 | function nntest.SequentialSplitBatch_ROIPooling() 141 | local input = { 142 | torch.randn(1,512,38,50):cuda(), 143 | torch.randn(40,5):cuda():mul(50), 144 | } 145 | input[2]:select(2,1):fill(1) 146 | 147 | local module = nn.SequentialSplitBatch(25) 148 | :add(inn.ROIPooling(7,7,1/16)) 149 | :add(nn.View(-1):setNumInputDims(3)) 150 | :add(nn.Linear(7*7*512,9)) 151 | :cuda() 152 | 153 | local output_mod = module:forward(input):clone() 154 | output_mod = module:forward(input):clone() 155 | local output_ref = module:replace(function(x) 156 | if torch.typename(x) == 'nn.SequentialSplitBatch' then 157 | torch.setmetatable(x, 'nn.Sequential') 158 | end 159 | return x 160 | end):forward(input):clone() 161 | 162 | mytester:asserteq((output_mod - output_ref):abs():max(), 0, 'SequentialSplitBatch err') 163 | end 164 | 165 | function nntest.SequentialSplitBatch_Tensor() 166 | local input = torch.randn(40,512):cuda() 167 | local module = nn.SequentialSplitBatch(25):add(nn.Linear(512,9)):cuda() 168 | 169 | local output_mod = module:forward(input):clone() 170 | output_mod = module:forward(input):clone() 171 | local output_ref = module:replace(function(x) 172 | if torch.typename(x) == 'nn.SequentialSplitBatch' then 173 | torch.setmetatable(x, 'nn.Sequential') 174 | end 175 | return x 176 | end):forward(input):clone() 177 | 178 | mytester:asserteq((output_mod - output_ref):abs():max(), 0, 'SequentialSplitBatch err') 179 | end 180 | 181 | mytester:add(utiltest) 182 | mytester:add(nntest) 183 | mytester:run() 184 | -------------------------------------------------------------------------------- /testCoco/coco.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | local class = require 'class' 10 | local py = require 'fb.python' 11 | 12 | local Coco = class('coco') 13 | 14 | function Coco:__init(annFile) 15 | py.exec('import sys') 16 | py.exec('from pycocotools.coco import COCO') 17 | py.exec('from pycocotools.cocoeval import COCOeval') 18 | py.exec([=[ 19 | global cocoGt 20 | cocoGt = COCO(annFile) 21 | ]=], {annFile=annFile}) 22 | end 23 | 24 | function Coco:evaluate(res) 25 | py.exec([=[ 26 | global stats 27 | cocoDt = cocoGt.loadRes(res) 28 | imgIds=sorted(cocoDt.imgToAnns.keys()) 29 | imgIds=imgIds[0:len(imgIds)] 30 | cocoEval = COCOeval(cocoGt,cocoDt) 31 | cocoEval.params.imgIds = imgIds 32 | cocoEval.evaluate() 33 | cocoEval.accumulate() 34 | cocoEval.summarize() 35 | stats = cocoEval.stats 36 | ]=], {res=res}) 37 | return py.eval('stats') 38 | end 39 | 40 | return Coco 41 | -------------------------------------------------------------------------------- /testCoco/init.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- script to run the python coco tester on the saved results file from run_test.lua 10 | -- 11 | 12 | local testCoco = {} 13 | 14 | local Coco = require 'testCoco.coco' 15 | local loader = require 'loaders.dataloader' 16 | require 'xlua' 17 | 18 | local function getAboxes(res, class) 19 | if type(res) == 'string' then -- res_folder 20 | return torch.load(('%s/%.2d.t7'):format(res, class)) 21 | elseif type(res) == 'table' or type(res) == 'cdata' then -- table or tds.hash 22 | return res[class] 23 | else 24 | error("Unknown res object: type " .. type(res)) 25 | end 26 | end 27 | 28 | local annotations_path = 'data/annotations/' 29 | 30 | function testCoco.evaluate(dataset_name, res) 31 | local annFile 32 | if dataset_name == 'coco_val2014' then 33 | annFile = 'instances_val2014.json' 34 | elseif dataset_name == 'pascal_test2007' then 35 | annFile = 'pascal_test2007.json' 36 | end 37 | annFile = paths.concat(annotations_path, annFile) 38 | 39 | local dataset = loader(dataset_name) 40 | 41 | print("Loading COCO image ids...") 42 | local image_ids = {} 43 | for i = 1, dataset:nImages() do 44 | if i % 10000 == 0 then print(" "..i..'/'..dataset:nImages()) end 45 | image_ids[i] = dataset:getImage(i).id 46 | end 47 | print('#image_ids',#image_ids) 48 | 49 | local nClasses = dataset:nCategories() 50 | 51 | print("Loading files to calculate sizes...") 52 | local nboxes = 0 53 | for class = 1, nClasses do 54 | local aboxes = getAboxes(res, class) 55 | 56 | for _,u in pairs(aboxes) do 57 | if u:nDimension() > 0 then 58 | nboxes = nboxes + u:size(1) 59 | end 60 | end 61 | -- xlua.progress(class, nClasses) 62 | end 63 | print("Total boxes: " .. nboxes) 64 | 65 | local boxt = torch.FloatTensor(nboxes, 7) 66 | 67 | print("Loading files to create giant tensor...") 68 | local offset = 1 69 | for class = 1, nClasses do 70 | local aboxes = getAboxes(res, class) 71 | for img,t in pairs(aboxes) do 72 | if t:nDimension() > 0 then 73 | local sub = boxt:narrow(1,offset,t:size(1)) 74 | sub:select(2, 1):fill(image_ids[img]) -- image ID 75 | sub:select(2, 2):copy(t:select(2, 1) - 1) -- x1 0-indexed 76 | sub:select(2, 3):copy(t:select(2, 2) - 1) -- y1 0-indexed 77 | sub:select(2, 4):copy(t:select(2, 3) - t:select(2, 1)) -- w 78 | sub:select(2, 5):copy(t:select(2, 4) - t:select(2, 2)) -- h 79 | sub:select(2, 6):copy(t:select(2, 5)) -- score 80 | sub:select(2, 7):fill(dataset.data.categories.id[class]) -- class 81 | offset = offset + t:size(1) 82 | end 83 | end 84 | -- xlua.progress(class, nClasses) 85 | end 86 | 87 | local coco = Coco(annFile) 88 | return coco:evaluate(boxt) 89 | end 90 | 91 | return testCoco 92 | -------------------------------------------------------------------------------- /test_runner.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | -- uses the 'donkey' pattern 10 | -- constructs threads for running the model on multiple GPUs 11 | 12 | local module = {} 13 | 14 | local Threads = require 'threads' 15 | Threads.serialization('threads.sharedserialize') 16 | local tds = require 'tds' 17 | 18 | local function _setup(dataset_name, proposals_path) 19 | require 'cutorch' 20 | require 'fbcoco' 21 | require 'inn' 22 | require 'cudnn' 23 | require 'nngraph' 24 | local utils = paths.dofile 'utils.lua' 25 | local model_utils = paths.dofile 'models/model_utils.lua' 26 | 27 | nn.DataParallelTable.deserializeNGPUs = cutorch.getDeviceCount() 28 | nn.ModelParallelTable.deserializeNGPUs = cutorch.getDeviceCount() 29 | 30 | local transformer = model_utils[opt.transformer]() 31 | local model = model_utils.load(opt.test_model):cuda() 32 | if opt.test_nGPU > 1 then 33 | utils.removeDataParallel(model) 34 | end 35 | utils.removeDataParallel(model) -- TODO: see why it complains 36 | 37 | model:evaluate() 38 | if opt.test_add_nosoftmax then 39 | print("Setting noSoftMax=true") 40 | model.noSoftMax = true 41 | end 42 | -- patch to use inplace dropout everywhere 43 | for k,v in ipairs(model:findModules'nn.Dropout') do v.inplace = true end 44 | ds = paths.dofile'DataSetJSON.lua':create(dataset_name, proposals_path, opt.test_nsamples, opt.test_data_offset) 45 | ds:setMinProposalArea(opt.test_min_proposal_size) 46 | ds:loadROIDB(opt.test_best_proposals_number) 47 | tester = fbcoco.Tester_FRCNN(model, transformer, ds, {opt.scale}, opt.max_size, opt) 48 | end 49 | 50 | function module:setup(nThreads, dataset_name, proposals_path) 51 | self.nThreads = nThreads 52 | if self.nThreads > 1 then 53 | _setup(dataset_name, proposals_path) 54 | local _opt = opt 55 | self.threads = Threads(self.nThreads, 56 | function() 57 | require 'cutorch' 58 | end, 59 | function(idx) 60 | thread_idx = idx 61 | opt = _opt 62 | local dev = idx % cutorch.getDeviceCount() 63 | dev = (dev==0) and cutorch.getDeviceCount() or dev 64 | cutorch.setDevice(dev) 65 | _setup(dataset_name, proposals_path) 66 | end) 67 | else 68 | self.threads = { 69 | addjob = function(self, f1, f2) 70 | if f2 then 71 | return f2(f1()) 72 | else 73 | f1() 74 | end 75 | end, 76 | synchronize = function() end, 77 | } 78 | require 'cutorch' 79 | _setup(dataset_name, proposals_path) 80 | end 81 | return self 82 | end 83 | 84 | -- go over all images in the dataset and the proposals and extract the 85 | -- classes and bbox predictions 86 | function module:computeBBoxes() 87 | local aboxes_t = {} 88 | local raw_output = tds.hash() 89 | local raw_bbox_pred = tds.hash() 90 | local timer = torch.Timer() 91 | for i=1, ds:size() do 92 | self.threads:addjob( 93 | function() 94 | return tester:testOne(i) 95 | end, 96 | function(res, raw_res) 97 | aboxes_t[i] = res 98 | if opt.test_save_raw ~= '' then 99 | raw_output[i] = raw_res[1]:float() 100 | raw_bbox_pred[i] = raw_res[2]:float() 101 | end 102 | end 103 | ) 104 | end 105 | self.threads:synchronize() 106 | print("Finished with images in " .. timer:time().real .. " s") 107 | 108 | if opt.test_save_raw ~= '' then 109 | torch.save(opt.test_save_raw, {raw_output, raw_bbox_pred}) 110 | print('Saved raw bboxes at: ' , opt.test_save_raw) 111 | end 112 | 113 | for i = 1,self.nThreads do 114 | self.threads:addjob( 115 | function() collectgarbage(); collectgarbage(); end) 116 | end 117 | self.threads:synchronize() 118 | self.threads = nil 119 | collectgarbage(); collectgarbage(); 120 | print("Thread garbage collected") 121 | aboxes_t = tester:keepTopKPerImage(aboxes_t, 100) -- coco only accepts 100/image 122 | local aboxes = tester:transposeBoxes(aboxes_t) 123 | aboxes_t = nil 124 | collectgarbage(); collectgarbage(); 125 | return aboxes 126 | end 127 | 128 | -- validation only function 129 | function module:evaluateBoxes(aboxes) 130 | return tester:computeAP(aboxes) 131 | end 132 | 133 | function module:test() 134 | local aboxes = self:computeBBoxes() 135 | return self:evaluateBoxes(aboxes) 136 | end 137 | 138 | return module 139 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | require 'torch' 10 | require 'nn' 11 | require 'optim' 12 | require 'xlua' 13 | 14 | local tnt = require 'torchnet' 15 | require 'engines.fboptimengine' 16 | 17 | require 'fbcoco' 18 | 19 | local json = require 'cjson' 20 | local utils = paths.dofile 'utils.lua' 21 | local model_utils = paths.dofile 'models/model_utils.lua' 22 | 23 | opt = { 24 | epoch = 1, 25 | dataset = 'pascal', 26 | train_set = 'trainval', 27 | test_set = 'test', 28 | model = 'alexnet', 29 | year = '2007', 30 | proposal_dir = 'data/proposals/', 31 | proposals = 'deepmask', 32 | images_per_batch = 2, 33 | scale = 600, 34 | max_size = 1000, 35 | learningRate = 1e-3, 36 | dampening = 0, 37 | weightDecay = 0.0005, 38 | momentum = 0.9, 39 | learningRateDecay = 0, 40 | nEpochs = 400, 41 | epochSize = 100, 42 | nDonkeys = 4, 43 | batchSize = 128, 44 | manualSeed = 555, 45 | step = 300, 46 | best_proposals_number = 1000, 47 | snapshot = 100, 48 | criterion = 'ce', 49 | decay = 0.1, 50 | bbox_regression = 1, 51 | retrain = 'no', 52 | train_min_gtroi_size = 0, 53 | train_remove_dropouts = false, 54 | retrain_mean_std = '', 55 | train_nGPU = 1, 56 | test_nGPU = 1, 57 | train_nsamples = -1, -- all samples 58 | test_nsamples = -1, -- all samples 59 | test_best_proposals_number = 500, 60 | disable_memory_efficient_forward=false, 61 | checkpoint=false, 62 | resume='', 63 | extra_proposals_file = '', 64 | method='sgd', 65 | sample_n_per_box = 0, 66 | sample_sigma = 1, 67 | train_min_proposal_size = 0, 68 | integral=false, 69 | imagenet_classes='', 70 | test_num_per_image=100, 71 | save_folder='', 72 | 73 | phase2_epoch=-1, 74 | phase2_learningRate=-1, 75 | phase2_step=-1, 76 | phase2_decay=-1, 77 | 78 | fg_threshold = -1, -- if -1, then set to bg_threshold_max 79 | bg_threshold_min = 0.1, 80 | bg_threshold_max = 0.5, 81 | } 82 | opt = xlua.envparams(opt) 83 | 84 | if opt.fg_threshold < 0 then 85 | opt.fg_threshold = opt.bg_threshold_max 86 | end 87 | if opt.manualSeed == -1 then --random 88 | opt.manualSeed = torch.random(10000) 89 | end 90 | print(opt) 91 | model_opt = {} 92 | 93 | require 'cutorch' 94 | math.randomseed(opt.manualSeed) 95 | cutorch.manualSeedAll(opt.manualSeed) 96 | torch.manualSeed(opt.manualSeed) 97 | 98 | --------------------------------------------------------------------------------------- 99 | -- model 100 | --------------------------------------------------------------------------------------- 101 | assert(opt.images_per_batch % opt.train_nGPU == 0, "images_per_batch must be a multiple of train_nGPU") 102 | opt.num_classes = opt.dataset == 'pascal' and 21 or 81 103 | 104 | local model_data = paths.dofile('models/'..opt.model..'.lua') 105 | local model, transformer, info = table.unpack(model_data) 106 | 107 | if opt.train_remove_dropouts then 108 | model_utils.removeDropouts(model) 109 | end 110 | 111 | -- serialize transformer for donkeys and to be loaded for testing 112 | opt.transformer = paths.concat(opt.save_folder, 'transformer.t7') 113 | torch.save(opt.transformer, transformer) 114 | 115 | if opt.retrain ~= 'no' then 116 | print('Loading a retrain model:'..opt.retrain) 117 | model = torch.load(opt.retrain) 118 | transformer = torch.load(opt.transformer) 119 | end 120 | 121 | local getIterator = require 'data' 122 | local iterator = getIterator() 123 | 124 | local integral_switches 125 | if opt.integral then 126 | if opt.retrain == 'no' then 127 | integral_switches = model_utils.integral(model) 128 | else 129 | local switch = model:findModules'nn.ModeSwitch'[1] 130 | integral_switches = switch:get(1):findModules'nn.SelectTable' 131 | end 132 | end 133 | 134 | model:cuda() 135 | 136 | if not opt.bbox_mask_1d then 137 | model_utils.addBBoxNorm(model, g_mean_std) 138 | end 139 | 140 | model_utils.testModel(model) 141 | 142 | -- set up testing 143 | local test_year = (opt.year == '2007,2012') and '2007' or opt.year 144 | local dataset_name = opt.dataset..'_'..opt.test_set..test_year 145 | local test_folder_name = opt.dataset == 'pascal' and ('VOC'..test_year) or 'coco' 146 | local test_proposals_path = utils.makeProposalPath(opt.proposal_dir, test_folder_name, opt.proposals, opt.test_set) 147 | 148 | -------------------------------------------------------------------------- 149 | -- training 150 | -------------------------------------------------------------------------- 151 | 152 | local samples = {} 153 | 154 | local function createCriterion() 155 | criterion = nn.ParallelCriterion() 156 | :add(nn.CrossEntropyCriterion(), 1) 157 | :add(nn.BBoxRegressionCriterion(), opt.bbox_regression) 158 | return criterion:cuda() 159 | end 160 | 161 | local dataTimer = tnt.TimeMeter() 162 | local timer, batchTimer = tnt.TimeMeter({ unit = true }), tnt.TimeMeter() 163 | local trainLoss = tnt.AverageValueMeter() 164 | local primary_loss = tnt.AverageValueMeter() 165 | local bboxregr_loss = tnt.AverageValueMeter() 166 | 167 | 168 | local engine = tnt.FBOptimEngine() 169 | 170 | 171 | local function json_log(t) print('json_stats: '..json.encode(t)) end 172 | 173 | ----------------------------------------------------------------------------- 174 | 175 | local function log(state, extra) 176 | local info = { 177 | epoch = state.epoch + 1, 178 | learningRate = state.learningRate, 179 | decay = state.decay, 180 | train_time = timer.timer:time().real, 181 | train_loss = trainLoss:value(), 182 | primary_loss = primary_loss:value(), 183 | bboxregr_loss = bboxregr_loss:value(), 184 | } 185 | json_log(utils.merge_table{opt, model_opt, extra, info}) 186 | end 187 | 188 | local function save(model, state, epoch) 189 | opt.test_model = 'model_'..epoch..'.t7' 190 | opt.test_state = 'optimState_'..epoch..'.t7' 191 | local model_path = paths.concat(opt.save_folder, opt.test_model) 192 | local state_path = paths.concat(opt.save_folder, opt.test_state) 193 | 194 | print("Saving model to "..model_path) 195 | torch.save(model_path, utils.checkpoint(model)) 196 | print("Saving state to "..state_path) 197 | torch.save(state_path, state) 198 | end 199 | 200 | local function validate(model) 201 | if opt.test_nGPU > 1 then 202 | print("test_nGPU > 1, running tester in separate threads") 203 | local test_runner = paths.dofile'test_runner.lua' 204 | test_runner:setup(opt.test_nGPU, dataset_name, test_proposals_path) 205 | local res = test_runner:test() 206 | test_runner = nil 207 | tester = nil -- global var 208 | return res 209 | else 210 | print("test_nGPU == 1, running tester in main thread") 211 | model:evaluate() 212 | local ds = paths.dofile'DataSetJSON.lua':create(dataset_name, test_proposals_path, opt.test_nsamples) 213 | ds:loadROIDB(opt.test_best_proposals_number) 214 | local tester = fbcoco.Tester_FRCNN(model,transformer,ds,{opt.scale}, opt.max_size, opt) 215 | local res = tester:test() 216 | model:training() 217 | return res 218 | end 219 | end 220 | 221 | engine.hooks.onStart = function(state) 222 | state.learningRate = opt.learningRate 223 | state.decay = opt.decay 224 | state.step = opt.step 225 | utils.cleanupOptim(state) 226 | if opt.checkpoint then 227 | local filename = checkpoint.resume(state) 228 | if filename then 229 | print("WARNING: restarted from checkpoint:", filename) 230 | elseif opt.resume ~= '' then 231 | print("resuming from checkpoint:", opt.resume) 232 | checkpoint.apply(state, opt.resume) 233 | end 234 | end 235 | end 236 | 237 | engine.hooks.onStartEpoch = function(state) 238 | local epoch = state.epoch + 1 239 | if epoch == opt.phase2_epoch then 240 | print("switching to phase 2") 241 | if state.network.setPhase2 then 242 | state.network:setPhase2() 243 | end 244 | if opt.phase2_learningRate >= 0 then 245 | print("setting learning rate to " .. opt.phase2_learningRate) 246 | state.learningRate = opt.phase2_learningRate 247 | 248 | local optimizer = state.optimizer 249 | for k,v in pairs(optimizer.modulesToOptState) do if v[1] then 250 | for i,u in ipairs(v) do 251 | if u.dfdx then 252 | local curdev = cutorch.getDevice() 253 | cutorch.setDevice(u.dfdx:getDevice()) 254 | u.dfdx:zero() 255 | cutorch.setDevice(curdev) 256 | u.learningRate = state.learningRate 257 | end 258 | end 259 | end end 260 | end 261 | if opt.phase2_step >= 0 then 262 | print("setting step to " .. opt.phase2_step) 263 | state.step = opt.phase2_step 264 | end 265 | if opt.phase2_decay >= 0 then 266 | print("setting decay to " .. opt.phase2_decay) 267 | state.decay = opt.phase2_decay 268 | end 269 | end 270 | 271 | if opt.checkpoint and epoch % opt.snapshot == 0 then 272 | checkpoint.checkpoint(state, opt) 273 | end 274 | print("Training epoch " .. epoch .. "/" .. opt.nEpochs) 275 | trainLoss:reset() 276 | primary_loss:reset() 277 | bboxregr_loss:reset() 278 | timer:reset() 279 | state.n = 0 280 | end 281 | 282 | engine.hooks.onSample = function(state) 283 | cutorch.synchronize(); collectgarbage(); 284 | dataTimer:stop() 285 | 286 | utils.recursiveCast(samples, state.sample, 'torch.CudaTensor') 287 | 288 | if opt.integral then 289 | assert(samples[2][3]) 290 | for i,v in ipairs(integral_switches) do 291 | v.index = samples[2][3] 292 | v.gradInput = {} 293 | end 294 | end 295 | 296 | state.sample.input = samples[1] 297 | state.sample.target = samples[2] 298 | end 299 | 300 | engine.hooks.onUpdate = function(state) 301 | cutorch.synchronize(); collectgarbage(); 302 | state.n = state.n + 1 303 | 304 | local err = state.criterion.output 305 | trainLoss:add(err) 306 | primary_loss:add(state.criterion.criterions[1].output) 307 | bboxregr_loss:add(state.criterion.criterions[2].output) 308 | 309 | timer:incUnit() 310 | 311 | print(('Epoch: [%d][%d/%d]\tTime %.3f (%.3f) DataTime %.3f Err %.4f'):format( 312 | state.epoch + 1, state.n, opt.epochSize, batchTimer:value(), timer:value(), dataTimer:value(), err)) 313 | 314 | dataTimer:reset() 315 | dataTimer:resume() 316 | batchTimer:reset() 317 | end 318 | 319 | engine.hooks.onEndEpoch = function(state) 320 | local epoch = state.epoch + 1 321 | if epoch % state.step == 0 then 322 | print('Dropping learning rate') 323 | state.learningRate = state.learningRate * state.decay 324 | local optimizer = state.optimizer 325 | for k,v in pairs(optimizer.modulesToOptState) do if v[1] then 326 | for i,u in ipairs(v) do 327 | if u.dfdx then 328 | local curdev = cutorch.getDevice() 329 | cutorch.setDevice(u.dfdx:getDevice()) 330 | u.dfdx:mul(state.decay) 331 | cutorch.setDevice(curdev) 332 | u.learningRate = u.learningRate * state.decay 333 | end 334 | end 335 | end end 336 | end 337 | log(state, {finished = 0, voc_metric = 0, coco_metric = 0}) 338 | if epoch % opt.snapshot == 0 then 339 | save(state.network, state.optimizer, epoch) 340 | local res = validate(state.network) 341 | log(state, { 342 | voc_metric = res[2], 343 | coco_metric = res[1], 344 | }) 345 | end 346 | end 347 | 348 | engine.hooks.onEnd = function(state) 349 | print("Done training. Running final validation") 350 | 351 | save(state.network, state.optimizer, 'final') 352 | 353 | opt.test_nsamples = 4952 354 | 355 | local res = validate(state.network) 356 | log(state, { 357 | voc_metric = res[2], 358 | coco_metric = res[1], 359 | }) 360 | end 361 | 362 | 363 | engine:train{ 364 | network = model, 365 | criterion = createCriterion(), 366 | config = opt, 367 | maxepoch = opt.nEpochs, 368 | optimMethod = optim[opt.method], 369 | iterator = iterator, 370 | } 371 | -------------------------------------------------------------------------------- /utils.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | ------------------------------------------------------------------------------]] 8 | 9 | stringx = require('pl.stringx') -- must be global or threads will barf :( 10 | 11 | local tnt = require 'torchnet' 12 | 13 | local utils = {} 14 | 15 | local ffi = require 'ffi' 16 | ffi.cdef[[ 17 | void bbox_vote(THFloatTensor *res, THFloatTensor *nms_boxes, THFloatTensor *scored_boxes, float threshold); 18 | void NMS(THFloatTensor *keep, THFloatTensor *scored_boxes, float overlap); 19 | ]] 20 | 21 | local ok, C = pcall(ffi.load, './libnms.so') 22 | if not ok then 23 | os.execute'make' 24 | ok, C = pcall(ffi.load, './libnms.so') 25 | assert(ok, 'run make and check what is wrong') 26 | end 27 | 28 | 29 | function utils.nms(boxes, overlap) 30 | local keep = torch.FloatTensor() 31 | C.NMS(keep:cdata(), boxes:cdata(), overlap) 32 | return keep 33 | end 34 | 35 | function utils.bbox_vote(nms_boxes, scored_boxes, overlap) 36 | local res = torch.FloatTensor() 37 | C.bbox_vote(res:cdata(), nms_boxes:cdata(), scored_boxes:cdata(), overlap) 38 | return res 39 | end 40 | 41 | 42 | -------------------------------------------------------------------------------- 43 | -- utility functions for the evaluation part 44 | -------------------------------------------------------------------------------- 45 | 46 | function utils.joinTable(input,dim) 47 | local size = torch.LongStorage() 48 | local is_ok = false 49 | for i=1,#input do 50 | local currentOutput = input[i] 51 | if currentOutput:numel() > 0 then 52 | if not is_ok then 53 | size:resize(currentOutput:dim()):copy(currentOutput:size()) 54 | is_ok = true 55 | else 56 | size[dim] = size[dim] + currentOutput:size(dim) 57 | end 58 | end 59 | end 60 | local output = input[1].new():resize(size) 61 | local offset = 1 62 | for i=1,#input do 63 | local currentOutput = input[i] 64 | if currentOutput:numel() > 0 then 65 | output:narrow(dim, offset, 66 | currentOutput:size(dim)):copy(currentOutput) 67 | offset = offset + currentOutput:size(dim) 68 | end 69 | end 70 | return output 71 | end 72 | 73 | -------------------------------------------------------------------------------- 74 | 75 | function utils.keep_top_k(boxes,top_k) 76 | local X = utils.joinTable(boxes,1) 77 | if X:numel() == 0 then 78 | return boxes, 0 79 | end 80 | local scores = X[{{},-1}]:sort(1,true) 81 | local thresh = scores[math.min(scores:numel(),top_k)] 82 | for i=1,#boxes do 83 | local bbox = boxes[i] 84 | if bbox:numel() > 0 then 85 | local idx = torch.range(1,bbox:size(1)):long() 86 | local keep = bbox[{{},-1}]:ge(thresh) 87 | idx = idx[keep] 88 | if idx:numel() > 0 then 89 | boxes[i] = bbox:index(1,idx) 90 | else 91 | boxes[i]:resize() 92 | end 93 | end 94 | end 95 | return boxes, thresh 96 | end 97 | 98 | -------------------------------------------------------------------------------- 99 | -- evaluation 100 | -------------------------------------------------------------------------------- 101 | 102 | -------------------------------------------------------------------------------- 103 | 104 | function utils.boxoverlap(a,b) 105 | local b = b.xmin and {b.xmin,b.ymin,b.xmax,b.ymax} or b 106 | local x1 = a:select(2,1):clone() 107 | x1[x1:lt(b[1])] = b[1] 108 | local y1 = a:select(2,2):clone() 109 | y1[y1:lt(b[2])] = b[2] 110 | local x2 = a:select(2,3):clone() 111 | x2[x2:gt(b[3])] = b[3] 112 | local y2 = a:select(2,4):clone() 113 | y2[y2:gt(b[4])] = b[4] 114 | 115 | local w = x2-x1+1; 116 | local h = y2-y1+1; 117 | local inter = torch.cmul(w,h):float() 118 | local aarea = torch.cmul((a:select(2,3)-a:select(2,1)+1) , 119 | (a:select(2,4)-a:select(2,2)+1)):float() 120 | local barea = (b[3]-b[1]+1) * (b[4]-b[2]+1); 121 | 122 | -- intersection over union overlap 123 | local o = torch.cdiv(inter , (aarea+barea-inter)) 124 | -- set invalid entries to 0 overlap 125 | o[w:lt(0)] = 0 126 | o[h:lt(0)] = 0 127 | return o 128 | end 129 | 130 | 131 | function utils.intersection(a,b) 132 | local b = b.xmin and {b.xmin,b.ymin,b.xmax,b.ymax} or b 133 | local x1 = a:select(2,1):clone() 134 | x1[x1:lt(b[1])] = b[1] 135 | local y1 = a:select(2,2):clone() 136 | y1[y1:lt(b[2])] = b[2] 137 | local x2 = a:select(2,3):clone() 138 | x2[x2:gt(b[3])] = b[3] 139 | local y2 = a:select(2,4):clone() 140 | y2[y2:gt(b[4])] = b[4] 141 | 142 | local w = x2-x1+1; 143 | local h = y2-y1+1; 144 | local inter = torch.cmul(w,h):float() 145 | local aarea = torch.cmul((a:select(2,3)-a:select(2,1)+1) , 146 | (a:select(2,4)-a:select(2,2)+1)):float() 147 | return torch.cdiv(inter, aarea) 148 | end 149 | -------------------------------------------------------------------------------- 150 | 151 | function utils.flipBoxes(boxes, image_width) 152 | local flipped = boxes:clone() 153 | flipped:select(2,1):copy( - boxes:select(2,3) + image_width + 1 ) 154 | flipped:select(2,3):copy( - boxes:select(2,1) + image_width + 1 ) 155 | return flipped 156 | end 157 | 158 | -------------------------------------------------------------------------------- 159 | 160 | function utils.merge_table(elements) 161 | local t = {} 162 | for i,u in ipairs(elements) do 163 | for k,v in pairs(u) do 164 | t[k] = v 165 | end 166 | end 167 | return t 168 | end 169 | 170 | -- bbox, tbox: [x1,y1,x2,y2] 171 | local function convertTo(out, bbox, tbox) 172 | if torch.type(out) == 'table' or out:nDimension() == 1 then 173 | local xc = (bbox[1] + bbox[3]) * 0.5 174 | local yc = (bbox[2] + bbox[4]) * 0.5 175 | local w = bbox[3] - bbox[1] 176 | local h = bbox[4] - bbox[2] 177 | local xtc = (tbox[1] + tbox[3]) * 0.5 178 | local ytc = (tbox[2] + tbox[4]) * 0.5 179 | local wt = tbox[3] - tbox[1] 180 | local ht = tbox[4] - tbox[2] 181 | out[1] = (xtc - xc) / w 182 | out[2] = (ytc - yc) / h 183 | out[3] = math.log(wt / w) 184 | out[4] = math.log(ht / h) 185 | else 186 | local xc = (bbox[{{},1}] + bbox[{{},3}]) * 0.5 187 | local yc = (bbox[{{},2}] + bbox[{{},4}]) * 0.5 188 | local w = bbox[{{},3}] - bbox[{{},1}] 189 | local h = bbox[{{},4}] - bbox[{{},2}] 190 | local xtc = (tbox[{{},1}] + tbox[{{},3}]) * 0.5 191 | local ytc = (tbox[{{},2}] + tbox[{{},4}]) * 0.5 192 | local wt = tbox[{{},3}] - tbox[{{},1}] 193 | local ht = tbox[{{},4}] - tbox[{{},2}] 194 | out[{{},1}] = (xtc - xc):cdiv(w) 195 | out[{{},2}] = (ytc - yc):cdiv(h) 196 | out[{{},3}] = wt:cdiv(w):log() 197 | out[{{},4}] = ht:cdiv(h):log() 198 | end 199 | end 200 | 201 | function utils.convertTo(...) 202 | local arg = {...} 203 | if #arg == 3 then 204 | convertTo(...) 205 | else 206 | local x = arg[1]:clone() 207 | convertTo(x, arg[1], arg[2]) 208 | return x 209 | end 210 | end 211 | 212 | function utils.convertFrom(out, bbox, y) 213 | if torch.type(out) == 'table' or out:nDimension() == 1 then 214 | local xc = (bbox[1] + bbox[3]) * 0.5 215 | local yc = (bbox[2] + bbox[4]) * 0.5 216 | local w = bbox[3] - bbox[1] 217 | local h = bbox[4] - bbox[2] 218 | 219 | local xtc = xc + y[1] * w 220 | local ytc = yc + y[2] * h 221 | local wt = w * math.exp(y[3]) 222 | local ht = h * math.exp(y[4]) 223 | 224 | out[1] = xtc - wt/2 225 | out[2] = ytc - ht/2 226 | out[3] = xtc + wt/2 227 | out[4] = ytc + ht/2 228 | else 229 | assert(bbox:size(2) == y:size(2)) 230 | assert(bbox:size(2) == out:size(2)) 231 | assert(bbox:size(1) == y:size(1)) 232 | assert(bbox:size(1) == out:size(1)) 233 | local xc = (bbox[{{},1}] + bbox[{{},3}]) * 0.5 234 | local yc = (bbox[{{},2}] + bbox[{{},4}]) * 0.5 235 | local w = bbox[{{},3}] - bbox[{{},1}] 236 | local h = bbox[{{},4}] - bbox[{{},2}] 237 | 238 | local xtc = torch.addcmul(xc, y[{{},1}], w) 239 | local ytc = torch.addcmul(yc, y[{{},2}], h) 240 | local wt = torch.exp(y[{{},3}]):cmul(w) 241 | local ht = torch.exp(y[{{},4}]):cmul(h) 242 | 243 | out[{{},1}] = xtc - wt * 0.5 244 | out[{{},2}] = ytc - ht * 0.5 245 | out[{{},3}] = xtc + wt * 0.5 246 | out[{{},4}] = ytc + ht * 0.5 247 | end 248 | end 249 | 250 | -- WARNING: DO NOT USE 251 | -- this function is WIP, it doesn't seem to work yet 252 | function utils.setDataParallelN(model, nGPU) 253 | assert(nGPU) 254 | assert(nGPU >= 1 and nGPU <= cutorch.getDeviceCount()) 255 | for _,m in ipairs(model:listModules()) do 256 | if torch.type(m) == 'nn.DataParallelTable' then 257 | if #m.modules ~= nGPU then 258 | assert(#m.modules >= 1) 259 | local inner = m.modules[1] 260 | inner:float() 261 | m:__init(m.dimension, m.noGradInput) -- reinitialize 262 | for i = 1, nGPU do 263 | cutorch.withDevice(i, function() 264 | m:add(inner:clone():cuda(), i) 265 | end) 266 | end 267 | end 268 | end 269 | end 270 | collectgarbage(); collectgarbage(); 271 | end 272 | 273 | function utils.removeDataParallel(model) 274 | for _,m in ipairs(model:listModules()) do 275 | if m.modules then 276 | for j,inner in ipairs(m.modules) do 277 | if torch.type(inner) == 'nn.DataParallelTable' then 278 | assert(#inner.modules >= 1) 279 | m.modules[j] = inner.modules[1]:float():cuda() -- maybe move to the right GPU 280 | end 281 | end 282 | end 283 | end 284 | -- model:float():cuda() -- maybe move to the right GPU 285 | end 286 | 287 | -- Deletes entries in modulesToOptState for modules that don't have parameters 288 | -- in the network. This includes modules in DataParallelTable that aren't on 289 | -- the primary GPU. 290 | function utils.cleanupOptim(state) 291 | local params, gradParams = state.network:parameters() 292 | local map = {} 293 | for _,param in ipairs(params) do 294 | map[param] = true 295 | end 296 | 297 | local optimizer = state.optimizer 298 | for module, _ in pairs(optimizer.modulesToOptState) do 299 | if not map[module.weight] and not map[module.bias] then 300 | optimizer.modulesToOptState[module] = nil 301 | end 302 | end 303 | end 304 | 305 | function utils.makeProposalPath(proposal_dir, dataset, proposals, set, imagenet) 306 | local res = {} 307 | if set == 'val5k' then set = 'val' end 308 | if set == 'val35k' then set = 'val' end 309 | proposals = stringx.split(proposals, ',') 310 | for i = 1, #proposals do 311 | if dataset=='coco' and set=='trainval' then 312 | table.insert(res, paths.concat(proposal_dir, dataset, proposals[i], 'train.t7')) 313 | table.insert(res, paths.concat(proposal_dir, dataset, proposals[i], 'val.t7')) 314 | elseif dataset=='VOC2007,2012' then 315 | table.insert(res, paths.concat(proposal_dir, 'VOC2007', proposals[i], set .. '.t7')) 316 | table.insert(res, paths.concat(proposal_dir, 'VOC2012', proposals[i], set .. '.t7')) 317 | else 318 | table.insert(res, paths.concat(proposal_dir, dataset, proposals[i], set .. '.t7')) 319 | end 320 | end 321 | 322 | if opt and opt.extra_proposals_file ~= '' then 323 | table.insert(res, opt.extra_proposals_file) 324 | end 325 | 326 | if imagenet then 327 | -- deepmask, cuz that's all we got 328 | table.insert(res, paths.concat(proposal_dir, 'imagenet', 'deepmask', 'train.t7')) 329 | end 330 | 331 | 332 | return res 333 | end 334 | 335 | function utils.saveResults(aboxes, dataset, res_file) 336 | 337 | nClasses = #aboxes 338 | nImages = #aboxes[1] 339 | 340 | local size = 0 341 | for class, rc in pairs(aboxes) do 342 | for i, data in pairs(rc) do 343 | if data:nElement() > 0 then 344 | size = size + data:size(1) 345 | end 346 | end 347 | end 348 | 349 | local out = {} 350 | out.dataset = dataset 351 | out.images = torch.range(1,nImages):float() 352 | local det = {} 353 | out.detections = det 354 | det.boxes = torch.FloatTensor(size, 4) 355 | det.scores = torch.FloatTensor(size) 356 | det.categories = torch.FloatTensor(size) 357 | det.images = torch.FloatTensor(size) 358 | local off = 1 359 | for class = 1, #aboxes do 360 | for i = 1, #aboxes[class] do 361 | local data = aboxes[class][i] 362 | if data:nElement() > 0 then 363 | det.boxes:narrow(1, off, data:size(1)):copy(data:narrow(2,1,4)) 364 | det.scores:narrow(1, off, data:size(1)):copy(data:select(2,5)) 365 | det.categories:narrow(1, off, data:size(1)):fill(class) 366 | det.images:narrow(1, off, data:size(1)):fill(i) 367 | off = off + data:size(1) 368 | end 369 | end 370 | end 371 | torch.save(res_file, out) 372 | end 373 | 374 | -- modified nn.utils 375 | -- accepts different types and numbers 376 | function utils.recursiveCopy(t1,t2) 377 | if torch.type(t2) == 'table' then 378 | t1 = (torch.type(t1) == 'table') and t1 or {t1} 379 | for key,_ in pairs(t2) do 380 | t1[key], t2[key] = utils.recursiveCopy(t1[key], t2[key]) 381 | end 382 | elseif torch.isTensor(t2) then 383 | t1 = torch.isTensor(t1) and t1 or t2.new() 384 | t1:resize(t2:size()):copy(t2) 385 | elseif torch.type(t2) == 'number' then 386 | t1 = t2 387 | else 388 | error("expecting nested tensors or tables. Got ".. 389 | torch.type(t1).." and "..torch.type(t2).." instead") 390 | end 391 | return t1, t2 392 | end 393 | 394 | function utils.recursiveCast(dst, src, type) 395 | if #dst == 0 then 396 | tnt.utils.table.copy(dst, nn.utils.recursiveType(src, type)) 397 | end 398 | utils.recursiveCopy(dst, src) 399 | end 400 | 401 | -- another version of nms that returns indexes instead of new boxes 402 | function utils.nms_dense(boxes, overlap) 403 | local n_boxes = boxes:size(1) 404 | 405 | if n_boxes == 0 then 406 | return torch.LongTensor() 407 | end 408 | 409 | -- sort scores in descending order 410 | assert(boxes:size(2) == 5) 411 | local vals, I = torch.sort(boxes:select(2,5), 1, true) 412 | 413 | -- sort the boxes 414 | local boxes_s = boxes:index(1, I):t():contiguous() 415 | 416 | local suppressed = torch.ByteTensor():resize(boxes_s:size(2)):zero() 417 | 418 | local x1 = boxes_s[1] 419 | local y1 = boxes_s[2] 420 | local x2 = boxes_s[3] 421 | local y2 = boxes_s[4] 422 | local s = boxes_s[5] 423 | 424 | local area = torch.cmul((x2-x1+1), (y2-y1+1)) 425 | 426 | local pick = torch.LongTensor(s:size(1)):zero() 427 | 428 | -- these clones are just for setting the size 429 | local xx1 = x1:clone() 430 | local yy1 = x1:clone() 431 | local xx2 = x1:clone() 432 | local yy2 = x1:clone() 433 | local w = x1:clone() 434 | local h = x1:clone() 435 | 436 | local pickIdx = 1 437 | for c = 1, n_boxes do 438 | if suppressed[c] == 0 then 439 | pick[pickIdx] = I[c] 440 | pickIdx = pickIdx + 1 441 | 442 | xx1:copy(x1):clamp(x1[c], math.huge) 443 | yy1:copy(y1):clamp(y1[c], math.huge) 444 | xx2:copy(x2):clamp(0, x2[c]) 445 | yy2:copy(y2):clamp(0, y2[c]) 446 | 447 | w:add(xx2, -1, xx1):add(1):clamp(0, math.huge) 448 | h:add(yy2, -1, yy1):add(1):clamp(0, math.huge) 449 | local inter = w 450 | inter:cmul(h) 451 | local union = xx1 452 | union:add(area, -1, inter):add(area[c]) 453 | local ol = h 454 | torch.cdiv(ol, inter, union) 455 | 456 | suppressed:add(ol:gt(overlap)):clamp(0,1) 457 | end 458 | end 459 | 460 | pick = pick[{{1,pickIdx-1}}] 461 | return pick 462 | end 463 | 464 | local function deepCopy(tbl) 465 | -- creates a copy of a network with new modules and the same tensors 466 | local copy = {} 467 | for k,v in pairs(tbl) do 468 | -- will skip all DPTs. it also causes stack overflow, idk why 469 | if torch.typename(v) == 'nn.DataParallelTable' then 470 | v = v:get(1) 471 | end 472 | if type(v) == 'table' then 473 | copy[k] = deepCopy(v) 474 | else 475 | copy[k] = v 476 | end 477 | end 478 | if torch.typename(tbl) then 479 | torch.setmetatable(copy, torch.typename(tbl)) 480 | end 481 | return copy 482 | end 483 | 484 | utils.deepCopy = deepCopy 485 | 486 | function utils.checkpoint(net) 487 | return deepCopy(net):float():clearState() 488 | end 489 | 490 | 491 | return utils 492 | --------------------------------------------------------------------------------