├── .gitignore
├── CONTRIBUTING.md
├── DataLoader.lua
├── DataSampler.lua
├── DeepMask.lua
├── InferDeepMask.lua
├── InferSharpMask.lua
├── LICENSE
├── PATENTS
├── README.md
├── SharpMask.lua
├── SpatialSymmetricPadding.lua
├── TrainerDeepMask.lua
├── TrainerSharpMask.lua
├── computeProposals.lua
├── data
│   ├── teaser.png
│   └── testImage.jpg
├── evalPerImage.lua
├── evalPerPatch.lua
├── modelUtils.lua
├── train.lua
└── trainMeters.lua


/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | pretrained/
3 | data/
4 | exps/
5 | 


--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # Contributing to deepmask
 2 | We want to make contributing to this project as easy and transparent as
 3 | possible.
 4 | 
 5 | 
 6 | ## Pull Requests
 7 | We actively welcome your pull requests.
 8 | 
 9 | 1. Fork the repo and create your branch from `master`.
10 | 2. If you haven't already, complete the Contributor License Agreement ("CLA").
11 | 
12 | ## Contributor License Agreement ("CLA")
13 | In order to accept your pull request, we need you to submit a CLA. You only need
14 | to do this once to work on any of Facebook's open source projects.
15 | 
16 | Complete your CLA here: <https://code.facebook.com/cla>
17 | 
18 | ## Issues
19 | We use GitHub issues to track public bugs. Please ensure your description is
20 | clear and has sufficient instructions to be able to reproduce the issue.
21 | 
22 | ## Coding Style  
23 | * 2 spaces for indentation rather than tabs
24 | * 80 character line length
25 | 
26 | ## License
27 | By contributing to deepmask, you agree that your contributions will be licensed
28 | under its [BSD license](https://github.com/facebookresearch/deepmask/blob/master/LICENSE).
29 | 


--------------------------------------------------------------------------------
/DataLoader.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Multi-threaded data loader
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | local Threads = require 'threads'
 11 | Threads.serialization('threads.sharedserialize')
 12 | 
 13 | local M = {}
 14 | local DataLoader = torch.class('DataLoader', M)
 15 | 
 16 | --------------------------------------------------------------------------------
 17 | -- function: create train/val data loaders
 18 | function DataLoader.create(config)
 19 |   local loaders = {}
 20 |   for i, split in ipairs{'train', 'val'} do
 21 |     loaders[i] = M.DataLoader(config, split)
 22 |   end
 23 | 
 24 |   return table.unpack(loaders)
 25 | end
 26 | 
 27 | --------------------------------------------------------------------------------
 28 | -- function: init
 29 | function DataLoader:__init(config, split)
 30 |   local function main(idx)
 31 |     torch.setdefaulttensortype('torch.FloatTensor')
 32 |     local seed = config.seed + idx
 33 |     torch.manualSeed(seed)
 34 | 
 35 |     paths.dofile('DataSampler.lua')
 36 |     _G.ds = DataSampler(config, split)
 37 |     return _G.ds:size()
 38 |   end
 39 | 
 40 |   local threads, sizes = Threads(config.nthreads, main)
 41 |   self.threads = threads
 42 |   self.__size = sizes[1][1]
 43 |   self.batch = config.batch
 44 |   self.hfreq = config.hfreq
 45 | end
 46 | 
 47 | --------------------------------------------------------------------------------
 48 | -- function: return size of dataset
 49 | function DataLoader:size()
 50 |   return math.ceil(self.__size / self.batch)
 51 | end
 52 | 
 53 | --------------------------------------------------------------------------------
 54 | -- function: run
 55 | function DataLoader:run()
 56 |   local threads = self.threads
 57 |   local size, batch = self.__size, self.batch
 58 | 
 59 |   local idx, sample = 1, nil
 60 |   local function enqueue()
 61 |     while idx <= size and threads:acceptsjob() do
 62 |       local bsz = math.min(batch, size - idx + 1)
 63 |       threads:addjob(
 64 |         function(bsz, hfreq)
 65 |           local inputs, labels
 66 |           local head -- head sampling
 67 |           if torch.uniform() > hfreq then head = 1 else head = 2 end
 68 | 
 69 |           for i = 1, bsz do
 70 |             local input, label = _G.ds:get(head)
 71 |             if not inputs then
 72 |               local iSz = input:size():totable()
 73 |               local mSz = label:size():totable()
 74 |               inputs = torch.FloatTensor(bsz, table.unpack(iSz))
 75 |               labels = torch.FloatTensor(bsz, table.unpack(mSz))
 76 |             end
 77 |             inputs[i]:copy(input)
 78 |             labels[i]:copy(label)
 79 |           end
 80 |           collectgarbage()
 81 | 
 82 |           return {inputs = inputs, labels = labels, head = head}
 83 |         end,
 84 |         function(_sample_) sample = _sample_ end,
 85 |         bsz, self.hfreq
 86 |       )
 87 |       idx = idx + batch
 88 |     end
 89 |   end
 90 | 
 91 |   local n = 0
 92 |   local function loop()
 93 |     enqueue()
 94 |     if not threads:hasjob() then return nil end
 95 |     threads:dojob()
 96 |     if threads:haserror() then threads:synchronize() end
 97 |     enqueue()
 98 |     n = n + 1
 99 |     return n, sample
100 |   end
101 | 
102 |   return loop
103 | end
104 | 
105 | return M.DataLoader
106 | 
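For orientation, here is a minimal usage sketch of this loader, mirroring the pattern the training code follows; the `config` fields listed are the ones read by `DataLoader`/`DataSampler`, and `processBatch` is a hypothetical placeholder for whatever consumes a batch:

```lua
-- Hedged usage sketch (not part of the repo): build the loaders and run one epoch.
-- `config` must provide the fields read above: nthreads, batch, hfreq, seed,
-- plus the DataSampler options (datadir, iSz, gSz, scale, shift, maxload, ...).
local DataLoader = paths.dofile('DataLoader.lua')
local trainLoader, valLoader = DataLoader.create(config)

for n, sample in trainLoader:run() do
  -- sample.inputs: batch x 3 x wSz x wSz image patches (wSz = iSz + 32)
  -- sample.labels: per-sample gSz x gSz mask targets (head 1) or +/-1 score targets (head 2)
  -- sample.head  : 1 = mask head was sampled, 2 = score head
  processBatch(sample.inputs, sample.labels, sample.head) -- hypothetical consumer
end
```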


--------------------------------------------------------------------------------
/DataSampler.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Dataset sampler for training/evaluation of DeepMask and SharpMask
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | require 'torch'
 11 | require 'image'
 12 | local tds = require 'tds'
 13 | local coco = require 'coco'
 14 | 
 15 | local DataSampler = torch.class('DataSampler')
 16 | 
 17 | --------------------------------------------------------------------------------
 18 | -- function: init
 19 | function DataSampler:__init(config,split)
 20 |   assert(split == 'train' or split == 'val')
 21 | 
 22 |   -- coco api
 23 |   local annFile = string.format('%s/annotations/instances_%s2014.json',
 24 |   config.datadir,split)
 25 |   self.coco = coco.CocoApi(annFile)
 26 | 
 27 |   -- mask api
 28 |   self.maskApi = coco.MaskApi
 29 | 
 30 |   -- mean/std computed from random subset of ImageNet training images
 31 |   self.mean, self.std = {0.485, 0.456, 0.406}, {0.229, 0.224, 0.225}
 32 | 
 33 |   -- class members
 34 |   self.datadir = config.datadir
 35 |   self.split = split
 36 | 
 37 |   self.iSz = config.iSz
 38 |   self.objSz = math.ceil(config.iSz*128/224)
 39 |   self.wSz = config.iSz + 32
 40 |   self.gSz = config.gSz
 41 |   self.scale = config.scale
 42 |   self.shift = config.shift
 43 | 
 44 |   self.imgIds = self.coco:getImgIds()
 45 |   self.annIds = self.coco:getAnnIds()
 46 |   self.catIds = self.coco:getCatIds()
 47 |   self.nImages = self.imgIds:size(1)
 48 | 
 49 |   if split == 'train' then self.__size  = config.maxload*config.batch
 50 |   elseif split == 'val' then self.__size = config.testmaxload*config.batch end
 51 | 
 52 |   if config.hfreq > 0 then
 53 |     self.scales = {} -- scale range for score sampling
 54 |     for scale = -3,2,.25 do table.insert(self.scales,scale) end
 55 |     self:createBBstruct(self.objSz,config.scale)
 56 |   end
 57 | 
 58 |   collectgarbage()
 59 | end
 60 | local function log2(x) return math.log(x)/math.log(2) end
 61 | 
 62 | --------------------------------------------------------------------------------
 63 | -- function: create BB struct of objects for score sampling
 64 | -- each key k contains the scale and bb information of all annotations of
 65 | -- image k
 66 | function DataSampler:createBBstruct(objSz,scale)
 67 |   local bbStruct = tds.Vec()
 68 | 
 69 |   for i = 1, self.nImages do
 70 |     local annIds = self.coco:getAnnIds({imgId=self.imgIds[i]})
 71 |     local bbs = {scales = {}}
 72 |     if annIds:dim() ~= 0 then
 73 |       for i = 1,annIds:size(1) do
 74 |         local annId = annIds[i]
 75 |         local ann = self.coco:loadAnns(annId)[1]
 76 |         local bbGt = ann.bbox
 77 |         local x0,y0,w,h = bbGt[1],bbGt[2],bbGt[3],bbGt[4]
 78 |         local xc,yc, maxDim = x0+w/2,y0+h/2, math.max(w,h)
 79 | 
 80 |         for s = -32,32,1 do
 81 |           if maxDim > objSz*2^((s-1)*scale) and
 82 |             maxDim <= objSz*2^((s+1)*(scale)) then
 83 |             local ss = -s*scale
 84 |             local xcS,ycS = xc*2^ss,yc*2^ss
 85 |             if not bbs[ss] then
 86 |               bbs[ss] = {}; table.insert(bbs.scales,ss)
 87 |             end
 88 |             table.insert(bbs[ss],{xcS,ycS,category_id=ann.category})
 89 |             break
 90 |           end
 91 |         end
 92 |       end
 93 |     end
 94 |     bbStruct:insert(tds.Hash(bbs))
 95 |   end
 96 |   collectgarbage()
 97 |   self.bbStruct = bbStruct
 98 | end
 99 | 
100 | --------------------------------------------------------------------------------
101 | -- function: get size of epoch
102 | function DataSampler:size()
103 |   return self.__size
104 | end
105 | 
106 | --------------------------------------------------------------------------------
107 | -- function: get a sample
108 | function DataSampler:get(headSampling)
109 |   local input,label
110 |   if headSampling == 1 then -- sample masks
111 |     input, label = self:maskSampling()
112 |   else -- sample score
113 |     input,label = self:scoreSampling()
114 |   end
115 | 
116 |   if torch.uniform() > .5 then
117 |     input = image.hflip(input)
118 |     if headSampling == 1 then label = image.hflip(label) end
119 |   end
120 | 
121 |   -- normalize input
122 |   for i=1,3 do input:narrow(1,i,1):add(-self.mean[i]):div(self.std[i]) end
123 | 
124 |   return input,label
125 | end
126 | 
127 | --------------------------------------------------------------------------------
128 | -- function: mask sampling
129 | function DataSampler:maskSampling()
130 |   local iSz,wSz,gSz = self.iSz,self.wSz,self.gSz
131 | 
132 |   local cat,ann = torch.random(80)
133 |   while not ann or ann.iscrowd == 1 or ann.area < 100 or ann.bbox[3] < 5
134 |     or ann.bbox[4] < 5 do
135 |       local catId = self.catIds[cat]
136 |       local annIds = self.coco:getAnnIds({catId=catId})
137 |       local annid = annIds[torch.random(annIds:size(1))]
138 |       ann = self.coco:loadAnns(annid)[1]
139 |   end
140 |   local bbox = self:jitterBox(ann.bbox)
141 |   local imgName = self.coco:loadImgs(ann.image_id)[1].file_name
142 | 
143 |   -- input
144 |   local pathImg = string.format('%s/%s2014/%s',self.datadir,self.split,imgName)
145 |   local inp = image.load(pathImg,3)
146 |   local h, w = inp:size(2), inp:size(3)
147 |   inp = self:cropTensor(inp, bbox, 0.5)
148 |   inp = image.scale(inp, wSz, wSz)
149 | 
150 |   -- label
151 |   local iSzR = iSz*(bbox[3]/wSz)
152 |   local xc, yc = bbox[1]+bbox[3]/2, bbox[2]+bbox[4]/2
153 |   local bboxInpSz = {xc-iSzR/2,yc-iSzR/2,iSzR,iSzR}
154 |   local lbl = self:cropMask(ann, bboxInpSz, h, w, gSz)
155 |   lbl:mul(2):add(-1)
156 | 
157 |   return inp, lbl
158 | end
159 | 
160 | --------------------------------------------------------------------------------
161 | -- function: score head sampler
162 | local imgPad = torch.Tensor()
163 | function DataSampler:scoreSampling(cat,imgId)
164 |   local idx,bb
165 |   repeat
166 |     idx = torch.random(1,self.nImages)
167 |     bb = self.bbStruct[idx]
168 |   until #bb.scales ~= 0
169 | 
170 |   local imgId = self.imgIds[idx]
171 |   local imgName = self.coco:loadImgs(imgId)[1].file_name
172 |   local pathImg = string.format('%s/%s2014/%s',self.datadir,self.split,imgName)
173 |   local img = image.load(pathImg,3)
174 |   local h,w = img:size(2),img:size(3)
175 | 
176 |   -- sample central pixel of BB to be used
177 |   local x,y,scale
178 |   local lbl = torch.Tensor(1)
179 |   if torch.uniform() > .5 then
180 |     x,y,scale = self:posSamplingBB(bb)
181 |     lbl:fill(1)
182 |   else
183 |     x,y,scale = self:negSamplingBB(bb,w,h)
184 |     lbl:fill(-1)
185 |   end
186 | 
187 |   local s = 2^-scale
188 |   x,y  = math.min(math.max(x*s,1),w), math.min(math.max(y*s,1),h)
189 |   local isz = math.max(self.wSz*s,10)
190 |   local bw =isz/2
191 | 
192 |   --pad/crop/rescale
193 |   imgPad:resize(3,h+2*bw,w+2*bw):fill(.5)
194 |   imgPad:narrow(2,bw+1,h):narrow(3,bw+1,w):copy(img)
195 |   local inp = imgPad:narrow(2,y,isz):narrow(3,x,isz)
196 |   inp = image.scale(inp,self.wSz,self.wSz)
197 | 
198 |   return inp,lbl
199 | end
200 | 
201 | --------------------------------------------------------------------------------
202 | -- function: crop bbox b from inp tensor
203 | function DataSampler:cropTensor(inp, b, pad)
204 |   pad = pad or 0
205 |   b[1], b[2] = torch.round(b[1])+1, torch.round(b[2])+1 -- 0 to 1 index
206 |   b[3], b[4] = torch.round(b[3]), torch.round(b[4])
207 | 
208 |   local out, h, w, ind
209 |   if #inp:size() == 3 then
210 |     ind, out = 2, torch.Tensor(inp:size(1), b[3], b[4]):fill(pad)
211 |   elseif #inp:size() == 2 then
212 |     ind, out = 1, torch.Tensor(b[3], b[4]):fill(pad)
213 |   end
214 |   h, w = inp:size(ind), inp:size(ind+1)
215 | 
216 |   local xo1,yo1,xo2,yo2 = b[1],b[2],b[3]+b[1]-1,b[4]+b[2]-1
217 |   local xc1,yc1,xc2,yc2 = 1,1,b[3],b[4]
218 | 
219 |   -- compute box on binary mask inp and cropped mask out
220 |   if b[1] < 1 then xo1=1; xc1=1+(1-b[1]) end
221 |   if b[2] < 1 then yo1=1; yc1=1+(1-b[2]) end
222 |   if b[1]+b[3]-1 > w then xo2=w; xc2=xc2-(b[1]+b[3]-1-w) end
223 |   if b[2]+b[4]-1 > h then yo2=h; yc2=yc2-(b[2]+b[4]-1-h) end
224 |   local xo, yo, wo, ho = xo1, yo1, xo2-xo1+1, yo2-yo1+1
225 |   local xc, yc, wc, hc = xc1, yc1, xc2-xc1+1, yc2-yc1+1
226 |   if yc+hc-1 > out:size(ind)   then hc = out:size(ind  )-yc+1 end
227 |   if xc+wc-1 > out:size(ind+1) then wc = out:size(ind+1)-xc+1 end
228 |   if yo+ho-1 > inp:size(ind)   then ho = inp:size(ind  )-yo+1 end
229 |   if xo+wo-1 > inp:size(ind+1) then wo = inp:size(ind+1)-xo+1 end
230 |   out:narrow(ind,yc,hc); out:narrow(ind+1,xc,wc)
231 |   inp:narrow(ind,yo,ho); inp:narrow(ind+1,xo,wo)
232 |   out:narrow(ind,yc,hc):narrow(ind+1,xc,wc):copy(
233 |   inp:narrow(ind,yo,ho):narrow(ind+1,xo,wo))
234 | 
235 |   return out
236 | end
237 | 
238 | --------------------------------------------------------------------------------
239 | -- function: crop bbox from mask
240 | function DataSampler:cropMask(ann, bbox, h, w, sz)
241 |   local mask = torch.FloatTensor(sz,sz)
242 |   local seg = ann.segmentation
243 | 
244 |   local scale = sz / bbox[3]
245 |   local polS = {}
246 |   for m, segm in pairs(seg) do
247 |     polS[m] = torch.DoubleTensor():resizeAs(segm):copy(segm); polS[m]:mul(scale)
248 |   end
249 |   local bboxS = {}
250 |   for m = 1,#bbox do bboxS[m] = bbox[m]*scale end
251 | 
252 |   local Rs = self.maskApi.frPoly(polS, h*scale, w*scale)
253 |   local mo = self.maskApi.decode(Rs)
254 |   local mc = self:cropTensor(mo, bboxS)
255 |   mask:copy(image.scale(mc,sz,sz):gt(0.5))
256 | 
257 |   return mask
258 | end
259 | 
260 | --------------------------------------------------------------------------------
261 | -- function: jitter bbox
262 | function DataSampler:jitterBox(box)
263 |   local x, y, w, h = box[1], box[2], box[3], box[4]
264 |   local xc, yc = x+w/2, y+h/2
265 |   local maxDim = math.max(w,h)
266 |   local scale = log2(maxDim/self.objSz)
267 |   local s = scale + torch.uniform(-self.scale,self.scale)
268 |   xc = xc + torch.uniform(-self.shift,self.shift)*2^s
269 |   yc = yc + torch.uniform(-self.shift,self.shift)*2^s
270 |   w, h = self.wSz*2^s, self.wSz*2^s
271 |   return {xc-w/2, yc-h/2,w,h}
272 | end
273 | 
274 | --------------------------------------------------------------------------------
275 | --function: posSampling: do positive sampling
276 | function DataSampler:posSamplingBB(bb)
277 |   local r = math.random(1,#bb.scales)
278 |   local scale = bb.scales[r]
279 |   r=torch.random(1,#bb[scale])
280 |   local x,y = bb[scale][r][1], bb[scale][r][2]
281 |   return x,y,scale
282 | end
283 | 
284 | --------------------------------------------------------------------------------
285 | --function: negSampling: do negative sampling
286 | function DataSampler:negSamplingBB(bb,w0,h0)
287 |   local x,y,scale
288 |   local negSample,c = false,0
289 |   while not negSample and c < 100 do
290 |     local r = math.random(1,#self.scales)
291 |     scale = self.scales[r]
292 |     x,y = math.random(1,w0*2^scale),math.random(1,h0*2^scale)
293 |     negSample = true
294 |     for s = -10,10 do
295 |       local ss = scale+s*self.scale
296 |       if bb[ss] then
297 |         for _,c in pairs(bb[ss]) do
298 |           local dist = math.sqrt(math.pow(x-c[1],2)+math.pow(y-c[2],2))
299 |           if dist < 3*self.shift then
300 |             negSample = false
301 |             break
302 |           end
303 |         end
304 |       end
305 |       if negSample == false then break end
306 |     end
307 |     c=c+1
308 |   end
309 |    return x,y,scale
310 | end
311 | 
312 | return DataSampler
313 | 
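A quick standalone sketch of drawing individual samples from this class (assuming `config` carries the fields read in `__init` and that the COCO 2014 images/annotations sit under `config.datadir`):

```lua
-- Hedged usage sketch (not part of the repo).
paths.dofile('DataSampler.lua')
local ds = DataSampler(config, 'train')

-- head 1: mask sample
-- input is a 3 x wSz x wSz normalized patch (wSz = config.iSz + 32),
-- label is a config.gSz x config.gSz target with values in {-1, 1}
local input, label = ds:get(1)

-- head 2: score sample (only available when config.hfreq > 0, since that is
-- what triggers createBBstruct above); score is a 1-element tensor, +1 or -1
local input2, score = ds:get(2)
```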


--------------------------------------------------------------------------------
/DeepMask.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | When initialized, it creates/loads the common trunk, the maskBranch and the
  8 | scoreBranch.
  9 | DeepMask class members:
 10 |   - trunk: the common trunk (modified pre-trained resnet50)
 11 |   - maskBranch: the mask head architecture
 12 |   - scoreBranch: the score head architecture
 13 | ------------------------------------------------------------------------------]]
 14 | 
 15 | require 'nn'
 16 | require 'nnx'
 17 | require 'cunn'
 18 | require 'cudnn'
 19 | paths.dofile('SpatialSymmetricPadding.lua')
 20 | local utils = paths.dofile('modelUtils.lua')
 21 | 
 22 | local DeepMask,_ = torch.class('nn.DeepMask','nn.Container')
 23 | 
 24 | --------------------------------------------------------------------------------
 25 | -- function: constructor
 26 | function DeepMask:__init(config)
 27 |   -- create common trunk
 28 |   self:createTrunk(config)
 29 | 
 30 |   -- create mask head
 31 |   self:createMaskBranch(config)
 32 | 
 33 |   -- create score head
 34 |   self:createScoreBranch(config)
 35 | 
 36 |   -- number of parameters
 37 |   local npt,nps,npm = 0,0,0
 38 |   local p1,p2,p3  = self.trunk:parameters(),
 39 |     self.maskBranch:parameters(),self.scoreBranch:parameters()
 40 |   for k,v in pairs(p1) do npt = npt+v:nElement() end
 41 |   for k,v in pairs(p2) do npm = npm+v:nElement() end
 42 |   for k,v in pairs(p3) do nps = nps+v:nElement() end
 43 |   print(string.format('| number of parameters trunk: %d', npt))
 44 |   print(string.format('| number of parameters mask branch: %d', npm))
 45 |   print(string.format('| number of parameters score branch: %d', nps))
 46 |   print(string.format('| number of parameters total: %d', npt+nps+npm))
 47 | end
 48 | 
 49 | --------------------------------------------------------------------------------
 50 | -- function: create common trunk
 51 | function DeepMask:createTrunk(config)
 52 |   -- size of feature maps at end of trunk
 53 |   self.fSz = config.iSz/16
 54 | 
 55 |   -- load trunk
 56 |   local trunk = torch.load('pretrained/resnet-50.t7')
 57 | 
 58 |   -- remove BN
 59 |   utils.BNtoFixed(trunk, true)
 60 | 
 61 |   -- remove fully connected layers
 62 |   trunk:remove();trunk:remove();trunk:remove();trunk:remove()
 63 | 
 64 |   -- crop central pad
 65 |   trunk:add(nn.SpatialZeroPadding(-1,-1,-1,-1))
 66 | 
 67 |   -- add common extra layers
 68 |   trunk:add(cudnn.SpatialConvolution(1024,128,1,1,1,1))
 69 |   trunk:add(cudnn.ReLU())
 70 |   trunk:add(nn.View(config.batch,128*self.fSz*self.fSz))
 71 |   trunk:add(nn.Linear(128*self.fSz*self.fSz,512))
 72 | 
 73 |   -- from scratch? reset the parameters
 74 |   if config.scratch then
 75 |     for k,m in pairs(trunk.modules) do if m.weight then m:reset() end end
 76 |   end
 77 | 
 78 |   -- symmetricPadding
 79 |   utils.updatePadding(trunk, nn.SpatialSymmetricPadding)
 80 | 
 81 |   self.trunk = trunk:cuda()
 82 |   return trunk
 83 | end
 84 | 
 85 | --------------------------------------------------------------------------------
 86 | -- function: create mask branch
 87 | function DeepMask:createMaskBranch(config)
 88 |   local maskBranch = nn.Sequential()
 89 | 
 90 |   -- maskBranch
 91 |   maskBranch:add(nn.Linear(512,config.oSz*config.oSz))
 92 |   self.maskBranch = nn.Sequential():add(maskBranch:cuda())
 93 | 
 94 |   -- upsampling layer
 95 |   if config.gSz > config.oSz then
 96 |     local upSample = nn.Sequential()
 97 |     upSample:add(nn.Copy('torch.CudaTensor','torch.FloatTensor'))
 98 |     upSample:add(nn.View(config.batch,config.oSz,config.oSz))
 99 |     upSample:add(nn.SpatialReSamplingEx{owidth=config.gSz,oheight=config.gSz,
100 |     mode='bilinear'})
101 |     upSample:add(nn.View(config.batch,config.gSz*config.gSz))
102 |     upSample:add(nn.Copy('torch.FloatTensor','torch.CudaTensor'))
103 |     self.maskBranch:add(upSample)
104 |   end
105 | 
106 |   return self.maskBranch
107 | end
108 | 
109 | --------------------------------------------------------------------------------
110 | -- function: create score branch
111 | function DeepMask:createScoreBranch(config)
112 |   local scoreBranch = nn.Sequential()
113 |   scoreBranch:add(nn.Dropout(.5))
114 |   scoreBranch:add(nn.Linear(512,1024))
115 |   scoreBranch:add(nn.Threshold(0, 1e-6))
116 | 
117 |   scoreBranch:add(nn.Dropout(.5))
118 |   scoreBranch:add(nn.Linear(1024,1))
119 | 
120 |   self.scoreBranch = scoreBranch:cuda()
121 |   return self.scoreBranch
122 | end
123 | 
124 | --------------------------------------------------------------------------------
125 | -- function: training
126 | function DeepMask:training()
127 |   self.trunk:training(); self.maskBranch:training(); self.scoreBranch:training()
128 | end
129 | 
130 | --------------------------------------------------------------------------------
131 | -- function: evaluate
132 | function DeepMask:evaluate()
133 |   self.trunk:evaluate(); self.maskBranch:evaluate(); self.scoreBranch:evaluate()
134 | end
135 | 
136 | --------------------------------------------------------------------------------
137 | -- function: to cuda
138 | function DeepMask:cuda()
139 |   self.trunk:cuda(); self.scoreBranch:cuda(); self.maskBranch:cuda()
140 | end
141 | 
142 | --------------------------------------------------------------------------------
143 | -- function: to float
144 | function DeepMask:float()
145 |   self.trunk:float(); self.scoreBranch:float(); self.maskBranch:float()
146 | end
147 | 
148 | --------------------------------------------------------------------------------
149 | -- function: inference (used for full scene inference)
150 | function DeepMask:inference()
151 |   self.trunk:evaluate()
152 |   self.maskBranch:evaluate()
153 |   self.scoreBranch:evaluate()
154 | 
155 |   utils.linear2convTrunk(self.trunk,self.fSz)
156 |   utils.linear2convHead(self.scoreBranch)
157 |   utils.linear2convHead(self.maskBranch.modules[1])
158 |   self.maskBranch = self.maskBranch.modules[1]
159 | 
160 |   self:cuda()
161 | end
162 | 
163 | --------------------------------------------------------------------------------
164 | -- function: clone
165 | function DeepMask:clone(...)
166 |   local f = torch.MemoryFile("rw"):binary()
167 |   f:writeObject(self); f:seek(1)
168 |   local clone = f:readObject(); f:close()
169 | 
170 |   if select('#',...) > 0 then
171 |     clone.trunk:share(self.trunk,...)
172 |     clone.maskBranch:share(self.maskBranch,...)
173 |     clone.scoreBranch:share(self.scoreBranch,...)
174 |   end
175 | 
176 |   return clone
177 | end
178 | 
179 | return nn.DeepMask
180 | 
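To make the tensor shapes concrete, a hedged construction-and-forward sketch follows; the config numbers are illustrative (they follow the conventions above: fSz = iSz/16, window size iSz + 32), `pretrained/resnet-50.t7` must already be downloaded, and this is not the official training entry point (that is `train.lua`):

```lua
-- Hedged sketch (not part of the repo): build the model and push one batch through.
paths.dofile('DeepMask.lua')
local config = {iSz = 160, oSz = 56, gSz = 112, batch = 8, scratch = false}
local model = nn.DeepMask(config)

-- the window size is iSz + 32, so the trunk sees batch x 3 x 192 x 192 here
local x = torch.CudaTensor(config.batch, 3, config.iSz + 32, config.iSz + 32):zero()
local feat   = model.trunk:forward(x)          -- batch x 512 shared features
local masks  = model.maskBranch:forward(feat)  -- batch x gSz*gSz mask outputs
local scores = model.scoreBranch:forward(feat) -- batch x 1 objectness scores
```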


--------------------------------------------------------------------------------
/InferDeepMask.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Inference module for DeepMask
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | require 'image'
 11 | local argcheck = require 'argcheck'
 12 | 
 13 | local Infer = torch.class('Infer')
 14 | 
 15 | --------------------------------------------------------------------------------
 16 | -- function: unfold the mask output into a matrix of masks
 17 | local function unfoldMasksMatrix(masks)
 18 |   local umasks = {}
 19 |   local oSz = math.sqrt(masks[1]:size(1))
 20 |   for _,mask in pairs(masks) do
 21 |     local umask = mask:reshape(oSz,oSz,mask:size(2),mask:size(3))
 22 |     umask=umask:transpose(1,3):transpose(2,3):transpose(2,4):transpose(3,4)
 23 |     table.insert(umasks,umask)
 24 |   end
 25 |   return umasks
 26 | end
 27 | 
 28 | --------------------------------------------------------------------------------
 29 | -- function: init
 30 | Infer.__init = argcheck{
 31 |   noordered = true,
 32 |   {name="self", type="Infer"},
 33 |   {name="np", type="number",default=500},
 34 |   {name="scales", type="table"},
 35 |   {name="meanstd", type="table"},
 36 |   {name="model", type="nn.Container"},
 37 |   {name="iSz", type="number", default=160},
 38 |   {name="dm", type="boolean", default=true},
 39 |   {name="timer", type="boolean", default=false},
 40 |   call =
 41 |     function(self, np, scales, meanstd, model, iSz, dm, timer)
 42 |       --model
 43 |       self.trunk = model.trunk
 44 |       self.mHead = model.maskBranch
 45 |       self.sHead = model.scoreBranch
 46 | 
 47 |       -- number of proposals
 48 |       self.np = np
 49 | 
 50 |       --mean/std
 51 |       self.mean, self.std = meanstd.mean, meanstd.std
 52 | 
 53 |       -- input size and border width
 54 |       self.iSz, self.bw = iSz, iSz/2
 55 | 
 56 |       -- timer
 57 |       if timer then self.timer = torch.Tensor(6):zero() end
 58 | 
 59 |       -- create  scale pyramid
 60 |       self.scales = scales
 61 |       self.pyramid = nn.ConcatTable()
 62 |       for i = 1,#scales do
 63 |         self.pyramid:add(nn.SpatialReSamplingEx{rwidth=scales[i],
 64 |           rheight=scales[i], mode='bilinear'})
 65 |       end
 66 | 
 67 |       -- allocate topScores and topMasks
 68 |       self.topScores = torch.Tensor()
 69 |       self.topMasks = torch.ByteTensor()
 70 |     end
 71 | }
 72 | 
 73 | --------------------------------------------------------------------------------
 74 | -- function: forward
 75 | local inpPad = torch.CudaTensor()
 76 | function Infer:forward(input)
 77 |   if input:type() == 'torch.CudaTensor' then input = input:float() end
 78 | 
 79 |   -- forward pyramid
 80 |   if self.timer then sys.tic() end
 81 |   local inpPyramid = self.pyramid:forward(input)
 82 |   if self.timer then self.timer:narrow(1,1,1):add(sys.toc()) end
 83 | 
 84 |   -- forward all scales through network
 85 |   local outPyramidMask,outPyramidScore = {},{}
 86 |   for i,_ in pairs(inpPyramid) do
 87 |     local inp = inpPyramid[i]:cuda()
 88 |     local h,w = inp:size(2),inp:size(3)
 89 | 
 90 |     -- padding/normalize
 91 |     if self.timer then sys.tic() end
 92 |     inpPad:resize(1,3,h+2*self.bw,w+2*self.bw):fill(.5)
 93 |     inpPad:narrow(1,1,1):narrow(3,self.bw+1,h):narrow(4,self.bw+1,w):copy(inp)
 94 |     for i=1,3 do inpPad[1][i]:add(-self.mean[i]):div(self.std[i]) end
 95 |     cutorch.synchronize()
 96 |     if self.timer then self.timer:narrow(1,2,1):add(sys.toc()) end
 97 | 
 98 |     -- forward trunk
 99 |     if self.timer then sys.tic() end
100 |     local outTrunk = self.trunk:forward(inpPad):squeeze()
101 |     cutorch.synchronize()
102 |     if self.timer then self.timer:narrow(1,3,1):add(sys.toc()) end
103 | 
104 |     -- forward score branch
105 |     if self.timer then sys.tic() end
106 |     local outScore = self.sHead:forward(outTrunk)
107 |     cutorch.synchronize()
108 |     if self.timer then self.timer:narrow(1,4,1):add(sys.toc()) end
109 |     table.insert(outPyramidScore,outScore:clone():squeeze())
110 | 
111 |     -- forward mask branch
112 |     if self.timer then sys.tic() end
113 |     local outMask = self.mHead:forward(outTrunk)
114 |     cutorch.synchronize()
115 |     if self.timer then self.timer:narrow(1,5,1):add(sys.toc()) end
116 |     table.insert(outPyramidMask,outMask:float():squeeze())
117 |   end
118 | 
119 |   self.mask = unfoldMasksMatrix(outPyramidMask)
120 |   self.score = outPyramidScore
121 | 
122 |   if self.timer then self.timer:narrow(1,6,1):add(1) end
123 | end
124 | 
125 | --------------------------------------------------------------------------------
126 | -- function: get top scores
127 | -- return a k x 4 tensor, where k is the number of top scores.
128 | -- each row contains: the score value, the scale index and the (x,y) position
129 | local sortedScores = torch.Tensor()
130 | local sortedIds = torch.Tensor()
131 | local pos = torch.Tensor()
132 | function Infer:getTopScores()
133 |   local topScores = self.topScores
134 | 
135 |   -- sort scores/ids for each scale
136 |   local nScales=#self.scales
137 |   local rowN=self.score[nScales]:size(1)*self.score[nScales]:size(2)
138 |   sortedScores:resize(rowN,nScales):zero()
139 |   sortedIds:resize(rowN,nScales):zero()
140 |   for s = 1,nScales do
141 |     self.score[s]:mul(-1):exp():add(1):pow(-1) -- scores2prob
142 | 
143 |     local sc = self.score[s]
144 |     local h,w = sc:size(1),sc:size(2)
145 | 
146 |     local sc=sc:view(h*w)
147 |     local sS,sIds=torch.sort(sc,true)
148 |     local sz = sS:size(1)
149 |     sortedScores:narrow(2,s,1):narrow(1,1,sz):copy(sS)
150 |     sortedIds:narrow(2,s,1):narrow(1,1,sz):copy(sIds)
151 |   end
152 | 
153 |   -- get top scores
154 |   local np = self.np
155 |   pos:resize(nScales):fill(1)
156 |   topScores:resize(np,4):fill(1)
157 |   np=math.min(np,rowN)
158 | 
159 |   for i = 1,np do
160 |     local scale,score = 0,0
161 |     for k = 1,nScales do
162 |       if sortedScores[pos[k]][k] > score then
163 |         score = sortedScores[pos[k]][k]
164 |         scale = k
165 |       end
166 |     end
167 |     local temp=sortedIds[pos[scale]][scale]
168 |     local x=math.floor(temp/self.score[scale]:size(2))
169 |     local y=temp%self.score[scale]:size(2)+1
170 |     x,y=math.max(1,x),math.max(1,y)
171 | 
172 |     pos[scale]=pos[scale]+1
173 |     topScores:narrow(1,i,1):copy(torch.Tensor({score,scale,x,y}))
174 |   end
175 | 
176 |   return topScores
177 | end
178 | 
179 | --------------------------------------------------------------------------------
180 | -- function: get top masks.
181 | local imgMask = torch.ByteTensor()
182 | function Infer:getTopMasks(thr,h,w)
183 |   local topMasks = self.topMasks
184 | 
185 |   thr = math.log(thr/(1-thr)) -- 1/(1+e^-s) > th  =>  s > log(th/(1-th))
186 | 
187 |   local masks,topScores,np = self.mask,self.topScores,self.np
188 |   topMasks:resize(np,h,w):zero()
189 |   imgMask:resize(h,w)
190 |   local imgMaskPtr = imgMask:data()
191 | 
192 |   for i = 1,np do
193 |     imgMask:zero()
194 |     local scale,x,y=topScores[i][2], topScores[i][3], topScores[i][4]
195 |     local s=self.scales[scale]
196 |     local sz = math.floor(self.iSz/s)
197 |     local mask = masks[scale]
198 |     x,y = math.min(x,mask:size(1)),math.min(y,mask:size(2))
199 |     mask = mask[x][y]:float()
200 |     local mask = image.scale(mask,sz,sz,'bilinear')
201 |     local mask_ptr = mask:data()
202 | 
203 |     local t = 16/s
204 |     local delta = self.iSz/2/s
205 |     for im =0, sz-1 do
206 |       local ii = math.floor((x-1)*t-delta+im)
207 |       for jm = 0,sz- 1 do
208 |         local jj=math.floor((y-1)*t-delta+jm)
209 |         if  mask_ptr[sz*im + jm] > thr and
210 |           ii >= 0 and ii <= h-1 and jj >= 0 and jj <= w-1 then
211 |           imgMaskPtr[jj+ w*ii]=1
212 |         end
213 |       end
214 |     end
215 |     topMasks:narrow(1,i,1):copy(imgMask)
216 |   end
217 | 
218 |   return topMasks
219 | end
220 | 
221 | --------------------------------------------------------------------------------
222 | -- function: get top proposals
223 | function Infer:getTopProps(thr,h,w)
224 |   self:getTopScores()
225 |   self:getTopMasks(thr,h,w)
226 |   return self.topMasks, self.topScores
227 | end
228 | 
229 | --------------------------------------------------------------------------------
230 | -- function: display timer
231 | function Infer:printTiming()
232 |   local t = self.timer
233 |   t:div(t[t:size(1)])
234 | 
235 |   print('| time pyramid:',t[1])
236 |   print('| time pre-process:',t[2])
237 |   print('| time trunk:',t[3])
238 |   print('| time score branch:',t[4])
239 |   print('| time mask branch:',t[5])
240 |   print('| time total:',t:narrow(1,1,t:size(1)-1):sum())
241 | end
242 | 
243 | return Infer
244 | 
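A hedged end-to-end sketch of using this module for full-scene inference, in the spirit of `computeProposals.lua`; the checkpoint path and layout (a table with a `model` field), the scale pyramid, and the 0.2 binarization threshold are assumptions/illustrative values:

```lua
-- Hedged usage sketch (not part of the repo).
require 'torch'; require 'cutorch'; require 'cunn'; require 'cudnn'
require 'nnx'; require 'image'
local Infer = paths.dofile('InferDeepMask.lua')

local m = torch.load('pretrained/deepmask/model.t7') -- assumed checkpoint layout
local model = m.model
model:inference() -- convert linear layers to conv for dense full-scene application
model:cuda()

local scales = {}
for s = -2.5, 0.5, 0.5 do table.insert(scales, 2^s) end -- illustrative pyramid
local meanstd = {mean = {0.485, 0.456, 0.406}, std = {0.229, 0.224, 0.225}}

local infer = Infer{np = 100, scales = scales, meanstd = meanstd, model = model}
local img = image.load('data/testImage.jpg', 3, 'float')
infer:forward(img)
local masks, scores = infer:getTopProps(0.2, img:size(2), img:size(3))
-- masks: np x h x w binary proposals; scores: np x 4 (score, scale, x, y)
```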


--------------------------------------------------------------------------------
/InferSharpMask.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Inference module for SharpMask
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | require 'image'
 11 | local argcheck = require 'argcheck'
 12 | 
 13 | local Infer = torch.class('Infer')
 14 | 
 15 | --------------------------------------------------------------------------------
 16 | -- function: init
 17 | Infer.__init = argcheck{
 18 |   noordered = true,
 19 |   {name="self", type="Infer"},
 20 |   {name="np", type="number",default=500},
 21 |   {name="scales", type="table"},
 22 |   {name="meanstd", type="table"},
 23 |   {name="model", type="nn.Container"},
 24 |   {name="iSz", type="number", default=160},
 25 |   {name="dm", type="boolean", default=false},
 26 |   {name="timer", type="boolean", default=false},
 27 |   call =
 28 |     function(self, np, scales, meanstd, model, iSz, dm, timer)
 29 |       --model
 30 |       self.trunk = model.trunk
 31 |       self.mBranch = model.maskBranchDM
 32 |       self.sBranch = model.scoreBranch
 33 |       self.refs = model.refs
 34 |       self.neths = model.neths
 35 |       self.skpos = model.skpos
 36 |       self.fSz = model.fSz
 37 |       self.dm = dm -- flag to use deepmask instead of sharpmask
 38 | 
 39 |       -- number of proposals
 40 |       self.np = np
 41 | 
 42 |       --mean/std
 43 |       self.mean, self.std = meanstd.mean, meanstd.std
 44 | 
 45 |       -- input size and border width
 46 |       self.iSz, self.bw = iSz, iSz/2
 47 | 
 48 |       -- timer
 49 |       if timer then self.timer = torch.Tensor(8):zero() end
 50 | 
 51 |       -- create  scale pyramid
 52 |       self.scales = scales
 53 |       self.pyramid = nn.ConcatTable()
 54 |       for i = 1,#scales do
 55 |         self.pyramid:add(nn.SpatialReSamplingEx{rwidth=scales[i],
 56 |           rheight=scales[i], mode='bilinear'})
 57 |       end
 58 | 
 59 |       -- allocate topScores, topMasks and topPatches
 60 |       self.topScores, self.topMasks = torch.Tensor(), torch.ByteTensor()
 61 |       local topPatches
 62 |       if self.dm then
 63 |         topPatches = torch.CudaTensor(self.np,512):zero()
 64 |       else
 65 |         topPatches = {}
 66 |         topPatches[1] = torch.CudaTensor(self.np,512):zero()
 67 |         for j = 1, #model.refs do
 68 |           local sz = model.fSz*2^(j-1)
 69 |           topPatches[j+1] = torch.CudaTensor(self.np,model.ks/2^(j),sz,sz)
 70 |         end
 71 |       end
 72 |       self.topPatches = topPatches
 73 |     end
 74 | }
 75 | 
 76 | --------------------------------------------------------------------------------
 77 | -- function: forward
 78 | local inpPad = torch.CudaTensor()
 79 | function Infer:forward(input,id)
 80 |   if input:type() == 'torch.CudaTensor' then input = input:float() end
 81 | 
 82 |   -- forward pyramid
 83 |   if self.timer then sys.tic() end
 84 |   local inpPyramid = self.pyramid:forward(input)
 85 |   if self.timer then self.timer:narrow(1,1,1):add(sys.toc()) end
 86 | 
 87 |   -- forward all scales through network
 88 |   local outPyramidTrunk,outPyramidScore,outPyramidSkip = {},{},{}
 89 |   for i,_ in pairs(inpPyramid) do
 90 |     local inp = inpPyramid[i]:cuda()
 91 |     local h,w = inp:size(2),inp:size(3)
 92 | 
 93 |     -- padding/normalize
 94 |     if self.timer then sys.tic() end
 95 |     inpPad:resize(1,3,h+2*self.bw,w+2*self.bw):fill(.5)
 96 |     inpPad:narrow(1,1,1):narrow(3,self.bw+1,h):narrow(4,self.bw+1,w):copy(inp)
 97 |     for i=1,3 do inpPad[1][i]:add(-self.mean[i]):div(self.std[i]) end
 98 |     cutorch.synchronize()
 99 |     if self.timer then self.timer:narrow(1,2,1):add(sys.toc()) end
100 | 
101 |     -- forward trunk
102 |     if self.timer then sys.tic() end
103 |     local outTrunk = self.trunk:forward(inpPad)
104 |     cutorch.synchronize()
105 |     if self.timer then self.timer:narrow(1,3,1):add(sys.toc()) end
106 |     table.insert(outPyramidTrunk,outTrunk:clone():squeeze())
107 | 
108 |     -- forward score branch
109 |     if self.timer then sys.tic() end
110 |     local outScore = self.sBranch:forward(outTrunk)
111 |     cutorch.synchronize()
112 |     if self.timer then self.timer:narrow(1,4,1):add(sys.toc()) end
113 |     table.insert(outPyramidScore,outScore:clone():squeeze())
114 | 
115 |     -- forward horizontal nets
116 |     if not self.dm then
117 |       local hOuts = {}
118 |       for k,neth in pairs(self.neths) do
119 |         if self.timer then sys.tic() end
120 |         neth:forward(self.trunk.modules[self.skpos[k]].output)
121 |         cutorch.synchronize()
122 |         if self.timer then self.timer:narrow(1,5,1):add(sys.toc()) end
123 |         hOuts[k] = neth.output:clone()
124 |       end
125 |       outPyramidSkip[i] = hOuts
126 |     end
127 |   end
128 | 
129 |   -- get top scores
130 |   self:getTopScores(outPyramidScore)
131 | 
132 |   -- get top patches and top masks, depending on mode
133 |   local topMasks0
134 |   if self.dm then
135 |     if self.timer then sys.tic() end
136 |     self:getTopPatchesDM(outPyramidTrunk)
137 |     if self.timer then self.timer:narrow(1,6,1):add(sys.toc()) end
138 | 
139 |     if self.timer then sys.tic() end
140 |     topMasks0 = self.mBranch:forward(self.topPatches)
141 |     local osz = math.sqrt(topMasks0:size(2))
142 |     topMasks0 = topMasks0:view(self.np,osz,osz)
143 |     if self.timer then self.timer:narrow(1,7,1):add(sys.toc()) end
144 |   else
145 |     if self.timer then sys.tic() end
146 |     self:getTopPatches(outPyramidTrunk,outPyramidSkip)
147 |     if self.timer then self.timer:narrow(1,6,1):add(sys.toc()) end
148 | 
149 |     if self.timer then sys.tic() end
150 |     topMasks0 = self:forwardRefinement(self.topPatches)
151 |     if self.timer then self.timer:narrow(1,7,1):add(sys.toc()) end
152 |   end
153 |   self.topMasks0 = topMasks0:float():squeeze()
154 | 
155 |   collectgarbage()
156 | 
157 |   if self.timer then self.timer:narrow(1,8,1):add(1) end
158 | end
159 | 
160 | --------------------------------------------------------------------------------
161 | -- function: forward refinement inference
162 | -- input is a table containing the output of bottom-up and the output of all
163 | -- horizontal layers
164 | function Infer:forwardRefinement(input)
165 |   local currentOutput = self.refs[0]:forward(input[1])
166 |   for i = 1,#self.refs do
167 |     currentOutput = self.refs[i]:forward({input[i+1],currentOutput})
168 |   end
169 |   cutorch.synchronize()
170 |   self.output = currentOutput
171 |   return self.output
172 | end
173 | 
174 | --------------------------------------------------------------------------------
175 | -- function: get top patches
176 | function Infer:getTopPatchesDM(outPyramidTrunk)
177 |   local topscores = self.topScores
178 |   local ts_ptr = topscores:data()
179 |   for i = 1, topscores:size(1) do
180 |     local pos = (i-1)*4
181 |     local s,x,y = ts_ptr[pos+1], ts_ptr[pos+2], ts_ptr[pos+3]
182 |     local patch = outPyramidTrunk[s]:narrow(2,x,1):narrow(3,y,1)
183 |     self.topPatches:narrow(1,i,1):copy(patch)
184 |   end
185 | end
186 | 
187 | --------------------------------------------------------------------------------
188 | -- function: get top patches
189 | local t
190 | function Infer:getTopPatches(outPyramidTrunk,outPyramidSkip)
191 |   local topscores = self.topScores
192 |   local ts_ptr = topscores:data()
193 | 
194 |   if not t then t={}; for j = 1, #self.skpos do t[j]=2^(j-1) end end
195 | 
196 |   for i = 1, #self.topPatches do self.topPatches[i]:zero() end
197 |   for i = 1, self.np do
198 |     local pos = (i-1)*4
199 |     local s,x,y = ts_ptr[pos+1], ts_ptr[pos+2], ts_ptr[pos+3]
200 | 
201 |     -- get patches from output outPyramidTrunk
202 |     local patch = outPyramidTrunk[s]:narrow(2,x,1):narrow(3,y,1)
203 |     self.topPatches[1]:narrow(1,i,1):copy(patch)
204 | 
205 |     for j = 1, #self.skpos do
206 |       local isz =(self.fSz)*t[j]
207 |       local xx,yy = (x-1)*t[j]+1 , (y-1)*t[j]+1
208 |       local o = outPyramidSkip[s][j]
209 |       local dx=math.min(isz,o:size(3)-xx+1)
210 |       local dy=math.min(isz,o:size(4)-yy+1)
211 |       local patch = o:narrow(3,xx,dx):narrow(4,yy,dy)
212 |       self.topPatches[j+1]:narrow(1,i,1):narrow(3,1,dx):narrow(4,1,dy)
213 |       :copy(patch)
214 |     end
215 |   end
216 |   cutorch.synchronize()
217 |   collectgarbage()
218 | end
219 | 
220 | --------------------------------------------------------------------------------
221 | -- function: get top scores
222 | -- return a k x 4 tensor, where k is the number of top scores.
223 | -- each row contains: the score value, the scale index and the (x,y) position
224 | local sortedScores = torch.Tensor()
225 | local sortedIds = torch.Tensor()
226 | local pos = torch.Tensor()
227 | function Infer:getTopScores(outPyramidScore)
228 |   local topScores = self.topScores
229 | 
230 |   self.score = outPyramidScore
231 |   local np = self.np
232 |   -- sort scores/ids for each scale
233 |   local nScales=#self.score
234 |   local rowN=self.score[nScales]:size(1)*self.score[nScales]:size(2)
235 |   sortedScores:resize(rowN,nScales):zero()
236 |   sortedIds:resize(rowN,nScales):zero()
237 |   for s = 1,nScales do
238 |     self.score[s]:mul(-1):exp():add(1):pow(-1) -- scores2prob
239 | 
240 |     local sc = self.score[s]
241 |     local h,w = sc:size(1),sc:size(2)
242 | 
243 |     local sc=sc:view(h*w)
244 |     local sS,sIds=torch.sort(sc,true)
245 |     local sz = sS:size(1)
246 |     sortedScores:narrow(2,s,1):narrow(1,1,sz):copy(sS)
247 |     sortedIds:narrow(2,s,1):narrow(1,1,sz):copy(sIds)
248 |   end
249 | 
250 |   -- get top scores
251 |   pos:resize(nScales):fill(1)
252 |   topScores:resize(np,4):fill(1)
253 |   np=math.min(np,rowN)
254 | 
255 |   for i = 1,np do
256 |     local scale,score = 0,0
257 |     for k = 1,nScales do
258 |       if sortedScores[pos[k]][k] > score then
259 |         score = sortedScores[pos[k]][k]
260 |         scale = k
261 |       end
262 |     end
263 |     local temp=sortedIds[pos[scale]][scale]
264 |     local x=math.floor(temp/self.score[scale]:size(2))
265 |     local y=temp%self.score[scale]:size(2)+1
266 |     x,y=math.max(1,x),math.max(1,y)
267 | 
268 |     pos[scale]=pos[scale]+1
269 |     topScores:narrow(1,i,1):copy(torch.Tensor({score,scale,x,y}))
270 |   end
271 | 
272 |   return topScores
273 | end
274 | 
275 | --------------------------------------------------------------------------------
276 | -- function: get top masks.
277 | local topMasks = torch.ByteTensor()
278 | local imgMask = torch.ByteTensor()
279 | function Infer:getTopMasks(thr,h,w)
280 |   thr = math.log(thr/(1-thr)) -- 1/(1+e^-s) > th  =>  s > log(th/(1-th))
281 | 
282 |   local topMasks0,topScores,np = self.topMasks0,self.topScores,self.np
283 |   topMasks:resize(np,h,w):zero()
284 |   imgMask:resize(h,w)
285 |   local imgMaskPtr = imgMask:data()
286 | 
287 |   for i = 1,np do
288 |     imgMask:zero()
289 |     local scale,x,y = topScores[i][2],topScores[i][3],topScores[i][4]
290 |     local s = self.scales[scale]
291 |     local sz = math.floor(self.iSz/s)
292 |     local mask = topMasks0[i]
293 |     local x,y = math.min(x,mask:size(1)),math.min(y,mask:size(2))
294 |     local mask = image.scale(mask,sz,sz,'bilinear')
295 |     local maskPtr = mask:data()
296 | 
297 |     local t,delta = 16/s, self.iSz/2/s
298 |     for im =0, sz-1 do
299 |       local ii = math.floor((x-1)*t-delta+im)
300 |       for jm = 0,sz- 1 do
301 |         local jj=math.floor((y-1)*t-delta+jm)
302 |         if  maskPtr[sz*im + jm] > thr and
303 |         ii >= 0 and ii <= h-1 and jj >= 0 and jj <= w-1 then
304 |           imgMaskPtr[jj+ w*ii]=1
305 |         end
306 |       end
307 |     end
308 | 
309 |     topMasks:narrow(1,i,1):copy(imgMask)
310 |   end
311 | 
312 |   self.topMasks = topMasks
313 |   return topMasks
314 | end
315 | 
316 | --------------------------------------------------------------------------------
317 | -- function: get top proposals
318 | function Infer:getTopProps(thr,h,w)
319 |   self:getTopMasks(thr,h,w)
320 |   return self.topMasks, self.topScores
321 | end
322 | 
323 | --------------------------------------------------------------------------------
324 | -- function: display timer
325 | function Infer:printTiming()
326 |   local t = self.timer
327 |   t:div(t[t:size(1)])
328 | 
329 |   print('\n| timing:')
330 |   print('| time pyramid:',t[1])
331 |   print('| time pre-process:',t[2])
332 |   print('| time trunk:',t[3])
333 |   print('| time score branch:',t[4])
334 |   print('| time skip connections:',t[5])
335 |   print('| time topPatches:',t[6])
336 |   print('| time refinement:',t[7])
337 |   print('| time total:',t:narrow(1,1,t:size(1)-1):sum())
338 | end
339 | 
340 | return Infer
341 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | BSD License
 2 | 
 3 | For deepmask software
 4 | 
 5 | Copyright (c) 2016, Facebook, Inc. All rights reserved.
 6 | 
 7 | Redistribution and use in source and binary forms, with or without modification,
 8 | are permitted provided that the following conditions are met:
 9 | 
10 |  * Redistributions of source code must retain the above copyright notice, this
11 |    list of conditions and the following disclaimer.
12 | 
13 |  * Redistributions in binary form must reproduce the above copyright notice,
14 |    this list of conditions and the following disclaimer in the documentation
15 |    and/or other materials provided with the distribution.
16 | 
17 |  * Neither the name Facebook nor the names of its contributors may be used to
18 |    endorse or promote products derived from this software without specific
19 |    prior written permission.
20 | 
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | 


--------------------------------------------------------------------------------
/PATENTS:
--------------------------------------------------------------------------------
 1 | Additional Grant of Patent Rights Version 2
 2 | 
 3 | "Software" means the deepmask software distributed by Facebook, Inc.
 4 | 
 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
 7 | (subject to the termination provision below) license under any Necessary
 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise
 9 | transfer the Software. For avoidance of doubt, no license is granted under
10 | Facebook’s rights in any patent claims that are infringed by (i) modifications
11 | to the Software made by you or any third party or (ii) the Software in
12 | combination with any software or other technology.
13 | 
14 | The license granted hereunder will terminate, automatically and without notice,
15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate
16 | directly or indirectly, or take a direct financial interest in, any Patent
17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate
18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or
19 | in part from any software, technology, product or service of Facebook or any of
20 | its subsidiaries or corporate affiliates, or (iii) against any party relating
21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its
22 | subsidiaries or corporate affiliates files a lawsuit alleging patent
23 | infringement against you in the first instance, and you respond by filing a
24 | patent infringement counterclaim in that lawsuit against that party that is
25 | unrelated to the Software, the license granted hereunder will not terminate
26 | under section (i) of this paragraph due to such counterclaim.
27 | 
28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is
29 | necessarily infringed by the Software standing alone.
30 | 
31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
32 | or contributory infringement or inducement to infringe any patent, including a
33 | cross-claim or counterclaim.
34 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Introduction
  2 | This repository contains a [Torch](http://torch.ch) implementation for both the [DeepMask](http://arxiv.org/abs/1506.06204) and [SharpMask](http://arxiv.org/abs/1603.08695) object proposal algorithms.
  3 | 
  4 | ![teaser](https://raw.githubusercontent.com/facebookresearch/deepmask/master/data/teaser.png)
  5 | 
  6 | [DeepMask](http://arxiv.org/abs/1506.06204) is trained with two objectives: given an image patch, one branch of the model outputs a class-agnostic segmentation mask, while the other branch outputs how likely the patch is to contain an object. At test time, DeepMask is applied densely to an image and generates a set of object masks, each with a corresponding objectness score. These masks densely cover the objects in an image and can be used as a first step for object detection and other tasks in computer vision.
  7 | 
  8 | [SharpMask](http://arxiv.org/abs/1603.08695) is an extension of DeepMask which generates higher-fidelity masks using an additional top-down refinement step. The idea is to first generate a coarse mask encoding in a feedforward pass, then refine this mask encoding in a top-down pass using features at successively lower layers. This results in masks that better adhere to object boundaries.
  9 | 
 10 | If you use DeepMask/SharpMask in your research, please cite the relevant papers:
 11 | ```
 12 | @inproceedings{DeepMask,
 13 |    title = {Learning to Segment Object Candidates},
 14 |    author = {Pedro O. Pinheiro and Ronan Collobert and Piotr Dollár},
 15 |    booktitle = {NIPS},
 16 |    year = {2015}
 17 | }
 18 | ```
 19 | ```
 20 | @inproceedings{SharpMask,
 21 |    title = {Learning to Refine Object Segments},
 22 |    author = {Pedro O. Pinheiro and Tsung-Yi Lin and Ronan Collobert and Piotr Dollár},
 23 |    booktitle = {ECCV},
 24 |    year = {2016}
 25 | }
 26 | ```
 27 | Note: the version of DeepMask implemented here is the updated version reported in the SharpMask paper. DeepMask takes on average .5s per COCO image, while SharpMask runs at .8s. Runtime roughly doubles for the "zoom" versions of the models.
 28 | 
 29 | # Requirements and Dependencies
 30 | * Mac OS X or Linux
 31 | * NVIDIA GPU with compute capability 3.5+
 32 | * [Torch](http://torch.ch) with packages: [COCO API](https://github.com/pdollar/coco), [image](https://github.com/torch/image), [tds](https://github.com/torch/tds), [cjson](https://github.com/clementfarabet/lua---json), [nnx](https://github.com/clementfarabet/lua---nnx), [optim](https://github.com/torch/optim), [inn](https://github.com/szagoruyko/imagine-nn), [cutorch](https://github.com/torch/cutorch), [cunn](https://github.com/torch/cunn), [cudnn](https://github.com/soumith/cudnn.torch)
 33 | 
 34 | # Quick Start
 35 | To run pretrained DeepMask/SharpMask models to generate object proposals, follow these steps:
 36 | 
 37 | 1. Clone this repository into $DEEPMASK:
 38 | 
 39 |    ```bash
 40 |    DEEPMASK=/desired/absolute/path/to/deepmask/ # set absolute path as desired
 41 |    git clone git@github.com:facebookresearch/deepmask.git $DEEPMASK
 42 |    ```
 43 | 
 44 | 2. Download pre-trained DeepMask and SharpMask models:
 45 | 
 46 |    ```bash
 47 |    mkdir -p $DEEPMASK/pretrained/deepmask; cd $DEEPMASK/pretrained/deepmask
 48 |    wget https://dl.fbaipublicfiles.com/deepmask/models/deepmask/model.t7
 49 |    mkdir -p $DEEPMASK/pretrained/sharpmask; cd $DEEPMASK/pretrained/sharpmask
 50 |    wget https://dl.fbaipublicfiles.com/deepmask/models/sharpmask/model.t7
 51 |    ```
 52 | 
 53 | 3. Run `computeProposals.lua` with a given model and optional target image (specified via the `-img` option):
 54 | 
 55 |    ```bash
 56 |    # apply to a default sample image (data/testImage.jpg)
 57 |    cd $DEEPMASK
 58 |    th computeProposals.lua $DEEPMASK/pretrained/deepmask # run DeepMask
 59 |    th computeProposals.lua $DEEPMASK/pretrained/sharpmask # run SharpMask
 60 |    th computeProposals.lua $DEEPMASK/pretrained/sharpmask -img /path/to/image.jpg
 61 |    ```
 62 | 
 63 | 
 64 | # Training Your Own Model
 65 | To train your own DeepMask/SharpMask models, follow these steps:
 66 | 
 67 | ## Preparation
 68 | 1. If you have not done so already, clone this repository into $DEEPMASK:
 69 | 
 70 |    ```bash
 71 |    DEEPMASK=/desired/absolute/path/to/deepmask/ # set absolute path as desired
 72 |    git clone git@github.com:facebookresearch/deepmask.git $DEEPMASK
 73 |    ```
 74 | 
 75 | 2. Download the Torch [ResNet-50](https://dl.fbaipublicfiles.com/deepmask/models/resnet-50.t7) model pretrained on ImageNet:
 76 | 
 77 |    ```bash
 78 |    mkdir -p $DEEPMASK/pretrained; cd $DEEPMASK/pretrained
 79 |    wget https://dl.fbaipublicfiles.com/deepmask/models/resnet-50.t7
 80 |    ```
 81 | 
 82 | 3. Download and extract the [COCO](http://mscoco.org/) images and annotations:
 83 | 
 84 |    ```bash
 85 |    mkdir -p $DEEPMASK/data; cd $DEEPMASK/data
 86 |    wget http://msvocds.blob.core.windows.net/annotations-1-0-3/instances_train-val2014.zip
 87 |    wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
 88 |    wget http://msvocds.blob.core.windows.net/coco2014/val2014.zip
 89 |    ```
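   After downloading, unpack the archives in place so that the annotations end up under `$DEEPMASK/data/annotations` and the images under `$DEEPMASK/data/train2014` and `$DEEPMASK/data/val2014` (the locations the training and evaluation scripts expect); for example:

   ```bash
   cd $DEEPMASK/data
   unzip instances_train-val2014.zip   # annotations/
   unzip train2014.zip                 # train2014/
   unzip val2014.zip                   # val2014/
   ```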
 90 | 
 91 | ## Training
 92 | To train, launch the `train.lua` script. It accepts several options; to list them, run it with the `--help` flag.
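For example:

```bash
th train.lua --help
```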
 93 | 
 94 | 1. To train DeepMask:
 95 | 
 96 |    ```bash
 97 |    th train.lua
 98 |    ```
 99 | 
100 | 2. To train SharpMask (requires pre-trained DeepMask model):
101 | 
102 |    ```bash
103 |    th train.lua -dm /path/to/trained/deepmask/
104 |    ```
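   Here `-dm` should point to the directory containing the trained DeepMask `model.t7`; SharpMask loads `<dm>/model.t7` when it is constructed (see `SharpMask.lua`).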
105 | 
106 | ## Evaluation
107 | There are two ways to evaluate a model on the COCO dataset.
108 | 
109 | 1. `evalPerPatch.lua` evaluates only the mask generation step, using image patches that contain roughly centered objects. Its usage is as follows:
110 | 
111 |    ```bash
112 |    th evalPerPatch.lua /path/to/trained/deepmask-or-sharpmask/
113 |    ```
114 | 
115 | 2. `evalPerImage.lua` evaluates the full model on COCO images, as reported in the papers. By default, it evaluates performance on the first 5K COCO validation images (run `th evalPerImage.lua --help` to see the options):
116 | 
117 |    ```bash
118 |    th evalPerImage.lua /path/to/trained/deepmask-or-sharpmask/
119 |    ```
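   Results are written under the model directory in `epoch=<N>/t7/` and `epoch=<N>/jsons/`, in chunks of 500 images each (see `evalPerImage.lua`).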
120 | 
121 | 
122 | # Precomputed Proposals
123 | 
124 | You can download pre-computed proposals (1000 per image) for the COCO and PASCAL VOC datasets, as both segmentation and bounding box proposals. The proposals use the COCO JSON [format](http://mscoco.org/dataset/#format) and are divided into chunks of 500 images each (that is, each JSON file contains 1000 proposals per image for 500 images). All proposals correspond to the "zoom" setting in the paper (DeepMaskZoom and SharpMaskZoom), which tends to be most effective for object detection.
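Each entry in the segmentation-proposal JSON files follows the COCO results format written by `evalPerImage.lua`: an `image_id`, a `category_id`, a `score`, and an RLE `segmentation` with `size` and `counts` fields (box proposals carry a bounding box instead). Below is a minimal sketch of inspecting one chunk using the repo's `cjson` dependency; the file name is illustrative:

```lua
local cjson = require 'cjson'

-- illustrative chunk name; actual files follow the pattern props-<start>-<end>.json
local f = assert(io.open('props-1-500.json', 'r'))
local props = cjson.decode(f:read('*a')); f:close()

-- inspect the first proposal in the chunk
local p = props[1]
print(p.image_id, p.category_id, p.score)
print(p.segmentation.size[1], p.segmentation.size[2])  -- mask height and width
```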
125 | 
126 | ## DeepMask
127 | * COCO Boxes: [[train](https://dl.fbaipublicfiles.com/deepmask/boxes/deepmask-coco-train-bbox.tar.gz) | [val](https://dl.fbaipublicfiles.com/deepmask/boxes/deepmask-coco-val-bbox.tar.gz) | [test-dev](https://dl.fbaipublicfiles.com/deepmask/boxes/deepmask-coco-test-dev-bbox.tar.gz) | [test-full](https://dl.fbaipublicfiles.com/deepmask/boxes/deepmask-coco-test-full-bbox.tar.gz)]
128 | * COCO Segments: [[train](https://dl.fbaipublicfiles.com/deepmask/segms/deepmask-coco-train.tar.gz) | [val](https://dl.fbaipublicfiles.com/deepmask/segms/deepmask-coco-val.tar.gz) | [test-dev](https://dl.fbaipublicfiles.com/deepmask/segms/deepmask-coco-test-dev.tar.gz) | [test-full](https://dl.fbaipublicfiles.com/deepmask/segms/deepmask-coco-test-full.tar.gz)]
129 | * PASCAL Boxes: [[train+val+test-2007](https://dl.fbaipublicfiles.com/deepmask/boxes/deepmask-pascal07-bbox.tar.gz) | [train+val+test-2012](https://dl.fbaipublicfiles.com/deepmask/boxes/deepmask-pascal12-bbox.tar.gz)]
130 | * PASCAL Segments: [[train+val+test-2007](https://dl.fbaipublicfiles.com/deepmask/segms/deepmask-pascal07.tar.gz) | [train+val+test-2012](https://dl.fbaipublicfiles.com/deepmask/segms/deepmask-pascal12.tar.gz)]
131 | 
132 | ## SharpMask
133 | * COCO Boxes: [[train](https://dl.fbaipublicfiles.com/deepmask/boxes/sharpmask-coco-train-bbox.tar.gz) | [val](https://dl.fbaipublicfiles.com/deepmask/boxes/sharpmask-coco-val-bbox.tar.gz) | [test-dev](https://dl.fbaipublicfiles.com/deepmask/boxes/sharpmask-coco-test-dev-bbox.tar.gz) | [test-full](https://dl.fbaipublicfiles.com/deepmask/boxes/sharpmask-coco-test-full-bbox.tar.gz)]
134 | * COCO Segments: [[train](https://dl.fbaipublicfiles.com/deepmask/segms/sharpmask-coco-train.tar.gz) | [val](https://dl.fbaipublicfiles.com/deepmask/segms/sharpmask-coco-val.tar.gz) | [test-dev](https://dl.fbaipublicfiles.com/deepmask/segms/sharpmask-coco-test-dev.tar.gz) | [test-full](https://dl.fbaipublicfiles.com/deepmask/segms/sharpmask-coco-test-full.tar.gz)]
135 | * PASCAL Boxes: [[train+val+test-2007](https://dl.fbaipublicfiles.com/deepmask/boxes/sharpmask-pascal07-bbox.tar.gz) | [train+val+test-2012](https://dl.fbaipublicfiles.com/deepmask/boxes/sharpmask-pascal12-bbox.tar.gz)]
136 | * PASCAL Segments: [[train+val+test-2007](https://dl.fbaipublicfiles.com/deepmask/segms/sharpmask-pascal07.tar.gz) | [train+val+test-2012](https://dl.fbaipublicfiles.com/deepmask/segms/sharpmask-pascal12.tar.gz)]
137 | 


--------------------------------------------------------------------------------
/SharpMask.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | When initialized, it loads a pre-trained DeepMask and creates the refinement
  8 | modules.
  9 | SharpMask class members:
 10 |   - self.trunk: common trunk (from trained DeepMask model)
 11 |   - self.scoreBranch: score head architecture (from trained DeepMask model)
 12 |   - self.maskBranchDM: mask head architecture (from trained DeepMask model)
 13 |   - self.refs: ensemble of refinement modules for top-down path
 14 | ------------------------------------------------------------------------------]]
 15 | 
 16 | require 'nn'
 17 | require 'nnx'
 18 | require 'cunn'
 19 | require 'cudnn'
 20 | local utils = paths.dofile('modelUtils.lua')
 21 | 
 22 | local SharpMask, _ = torch.class('nn.SharpMask','nn.Container')
 23 | 
 24 | --------------------------------------------------------------------------------
 25 | -- function: init
 26 | function SharpMask:__init(config)
 27 |   self.km, self.ks = config.km, config.ks
 28 |   assert(self.km >= 16 and self.km%16==0 and self.ks >= 16 and self.ks%16==0)
 29 | 
 30 |   self.skpos = {8,6,5,3} -- positions to forward horizontal nets
 31 |   self.inps = {}
 32 | 
 33 |   -- create bottom-up flow (from deepmask)
 34 |   local m = torch.load(config.dm..'/model.t7')
 35 |   local deepmask = m.model
 36 |   self.trunk = deepmask.trunk
 37 |   self.scoreBranch = deepmask.scoreBranch
 38 |   self.maskBranchDM = deepmask.maskBranch
 39 |   self.fSz = deepmask.fSz
 40 | 
 41 |   -- create refinement modules
 42 |   self:createTopDownRefinement(config)
 43 | 
 44 |   -- number of parameters
 45 |   local nh,nv = 0,0
 46 |   for k,v in pairs(self.neths) do
 47 |     for kk,vv in pairs(v:parameters()) do nh = nh+vv:nElement() end
 48 |   end
 49 |   for k,v in pairs(self.netvs) do
 50 |     for kk,vv in pairs(v:parameters()) do nv = nv+vv:nElement() end
 51 |   end
 52 |   print(string.format('| number of parameters net h: %d', nh))
 53 |   print(string.format('| number of parameters net v: %d', nv))
 54 |   print(string.format('| number of parameters total: %d', nh+nv))
 55 |   self:cuda()
 56 | end
 57 | 
 58 | --------------------------------------------------------------------------------
 59 | -- function: create vertical nets
 60 | function SharpMask:createVertical(config)
 61 |   local netvs = {}
 62 | 
 63 |   local n0 = nn.Sequential()
 64 |   n0:add(nn.Linear(512,self.fSz*self.fSz*self.km))
 65 |   n0:add(nn.View(config.batch,self.km,self.fSz,self.fSz))
 66 |   netvs[0]=n0:cuda()
 67 | 
 68 |   for i = 1, #self.skpos do
 69 |     local netv = nn.Sequential()
 70 |     local nInps = self.km/2^(i-1)
 71 | 
 72 |     netv:add(nn.SpatialSymmetricPadding(1,1,1,1))
 73 |     netv:add(cudnn.SpatialConvolution(nInps,nInps,3,3,1,1))
 74 |     netv:add(cudnn.ReLU())
 75 | 
 76 |     netv:add(nn.SpatialSymmetricPadding(1,1,1,1))
 77 |     netv:add(cudnn.SpatialConvolution(nInps,nInps/2,3,3,1,1))
 78 | 
 79 |     table.insert(netvs,netv:cuda())
 80 |   end
 81 | 
 82 |   self.netvs = netvs
 83 |   return netvs
 84 | end
 85 | 
 86 | --------------------------------------------------------------------------------
 87 | -- function: create horizontal nets
 88 | function SharpMask:createHorizontal(config)
 89 |   local neths = {}
 90 |   local nhu1,nhu2,crop
 91 |   for i =1,#self.skpos do
 92 |     local h = nn.Sequential()
 93 |     local nInps = self.ks/2^(i-1)
 94 | 
 95 |     if i == 1 then nhu1,nhu2,crop=1024,64,0
 96 |     elseif i == 2 then nhu1,nhu2,crop = 512,64,-2
 97 |     elseif i == 3 then nhu1,nhu2,crop = 256,64,-4
 98 |     elseif i == 4 then nhu1,nhu2,crop = 64,32,-8
 99 |     end
100 |     if crop ~= 0 then h:add(nn.SpatialZeroPadding(crop,crop,crop,crop)) end
101 | 
102 |     h:add(nn.SpatialSymmetricPadding(1,1,1,1))
103 |     h:add(cudnn.SpatialConvolution(nhu1,nhu2,3,3,1,1))
104 |     h:add(cudnn.ReLU())
105 | 
106 |     h:add(nn.SpatialSymmetricPadding(1,1,1,1))
107 |     h:add(cudnn.SpatialConvolution(nhu2,nInps,3,3,1,1))
108 |     h:add(cudnn.ReLU())
109 | 
110 |     h:add(nn.SpatialSymmetricPadding(1,1,1,1))
111 |     h:add(cudnn.SpatialConvolution(nInps,nInps/2,3,3,1,1))
112 | 
113 |     table.insert(neths,h:cuda())
114 |   end
115 | 
116 |   self.neths = neths
117 |   return neths
118 | end
119 | 
120 | --------------------------------------------------------------------------------
121 | -- function: create refinement modules
122 | function SharpMask:refinement(neth,netv)
123 |    local ref = nn.Sequential()
124 |    local par = nn.ParallelTable():add(neth):add(netv)
125 |    ref:add(par)
126 |    ref:add(nn.CAddTable(2))
127 |    ref:add(cudnn.ReLU())
128 |    ref:add(nn.SpatialUpSamplingNearest(2))
129 | 
130 |    return ref:cuda()
131 | end
132 | 
133 | function SharpMask:createTopDownRefinement(config)
134 |   -- create horizontal nets
135 |   self:createHorizontal(config)
136 | 
137 |   -- create vertical nets
138 |   self:createVertical(config)
139 | 
140 |   local refs = {}
141 |   refs[0] = self.netvs[0]
142 |   for i = 1, #self.skpos do
143 |     table.insert(refs,self:refinement(self.neths[i],self.netvs[i]))
144 |   end
145 | 
146 |   local finalref = refs[#refs]
147 |   finalref:add(nn.SpatialSymmetricPadding(1,1,1,1))
148 |   finalref:add(cudnn.SpatialConvolution((self.km)/2^(#refs),1,3,3,1,1))
149 |   finalref:add(nn.View(config.batch,config.gSz*config.gSz))
150 | 
151 |   self.refs = refs
152 |   return refs
153 | end
154 | 
155 | --------------------------------------------------------------------------------
156 | -- function: forward
157 | function SharpMask:forward(input)
158 |   -- forward bottom-up
159 |   local currentOutput = self.trunk:forward(input)
160 | 
161 |   -- forward refinement modules
162 |   currentOutput = self.refs[0]:forward(currentOutput)
163 |   for k = 1,#self.refs do
164 |     local F = self.trunk.modules[self.skpos[k]].output
165 |     self.inps[k] = {F,currentOutput}
166 |     currentOutput = self.refs[k]:forward(self.inps[k])
167 |   end
168 |   self.output = currentOutput
169 |   return self.output
170 | end
171 | 
172 | --------------------------------------------------------------------------------
173 | -- function: backward
174 | function SharpMask:backward(input,gradOutput)
175 |   local currentGrad = gradOutput
176 |   for i = #self.refs,1,-1 do
177 |     currentGrad =self.refs[i]:backward(self.inps[i],currentGrad)
178 |     currentGrad = currentGrad[2]
179 |   end
180 |   currentGrad =self.refs[0]:backward(self.trunk.output,currentGrad)
181 | 
182 |   self.gradInput = currentGrad
183 |   return currentGrad
184 | end
185 | 
186 | --------------------------------------------------------------------------------
187 | -- function: zeroGradParameters
188 | function SharpMask:zeroGradParameters()
189 |   for k,v in pairs(self.refs) do self.refs[k]:zeroGradParameters() end
190 | end
191 | 
192 | --------------------------------------------------------------------------------
193 | -- function: updateParameters
194 | function SharpMask:updateParameters(lr)
195 |   for k,n in pairs(self.refs) do self.refs[k]:updateParameters(lr) end
196 | end
197 | 
198 | --------------------------------------------------------------------------------
199 | -- function: training
200 | function SharpMask:training()
201 |   self.trunk:training();self.scoreBranch:training();self.maskBranchDM:training()
202 |   for k,n in pairs(self.refs) do self.refs[k]:training() end
203 | end
204 | 
205 | --------------------------------------------------------------------------------
206 | -- function: evaluate
207 | function SharpMask:evaluate()
208 |   self.trunk:evaluate();self.scoreBranch:evaluate();self.maskBranchDM:evaluate()
209 |   for k,n in pairs(self.refs) do self.refs[k]:evaluate() end
210 | end
211 | 
212 | --------------------------------------------------------------------------------
213 | -- function: to cuda
214 | function SharpMask:cuda()
215 |   self.trunk:cuda();self.scoreBranch:cuda();self.maskBranchDM:cuda()
216 |   for k,n in pairs(self.refs) do self.refs[k]:cuda() end
217 | end
218 | 
219 | --------------------------------------------------------------------------------
220 | -- function: to float
221 | function SharpMask:float()
222 |   self.trunk:float();self.scoreBranch:float();self.maskBranchDM:float()
223 |   for k,n in pairs(self.refs) do self.refs[k]:float() end
224 | end
225 | 
226 | --------------------------------------------------------------------------------
227 | -- function: set number of proposals for inference
228 | function SharpMask:setnpinference(np)
229 |   local vsz = self.refs[0].modules[2].size
230 |   self.refs[0].modules[2]:resetSize(np,vsz[2],vsz[3],vsz[4])
231 | end
232 | 
233 | --------------------------------------------------------------------------------
234 | -- function: inference (used for full scene inference)
235 | function SharpMask:inference(np)
236 |   self:evaluate()
237 | 
238 |   -- remove last view
239 |   self.refs[#self.refs]:remove()
240 | 
241 |   -- remove ZeroPaddings
242 |   self.trunk.modules[8]=nn.Identity():cuda()
243 |   for k = 1, #self.refs do
244 |     local m = self.refs[k].modules[1].modules[1].modules[1]
245 |     if torch.typename(m):find('SpatialZeroPadding') then
246 |       self.refs[k].modules[1].modules[1].modules[1]=nn.Identity():cuda()
247 |     end
248 |   end
249 | 
250 |   -- remove horizontal links, as they are applied convolutionally
251 |   for k = 1, #self.refs do
252 |     self.refs[k].modules[1].modules[1]=nn.Identity():cuda()
253 |   end
254 | 
255 |   -- modify number of batch to np (number of proposals)
256 |   self:setnpinference(np)
257 | 
258 |   -- transform trunk and score branch to conv
259 |   utils.linear2convTrunk(self.trunk,self.fSz)
260 |   utils.linear2convHead(self.scoreBranch)
261 |   self.maskBranchDM = self.maskBranchDM.modules[1]
262 | 
263 |   self:cuda()
264 | end
265 | 
266 | --------------------------------------------------------------------------------
267 | -- function: clone
268 | function SharpMask:clone(...)
269 |   local f = torch.MemoryFile("rw"):binary()
270 |   f:writeObject(self); f:seek(1)
271 |   local clone = f:readObject(); f:close()
272 | 
273 |   if select('#',...) > 0 then
274 |     clone.trunk:share(self.trunk,...)
275 |     clone.maskBranchDM:share(self.maskBranchDM,...)
276 |     clone.scoreBranch:share(self.scoreBranch,...)
277 |     for k,n in pairs(self.netvs) do clone.netvs[k]:share(self.netvs[k],...)end
278 |     for k,n in pairs(self.neths) do clone.neths[k]:share(self.neths[k],...) end
279 |     for k,n in pairs(self.refs)  do clone.refs[k]:share(self.refs[k],...) end
280 |   end
281 | 
282 |   return clone
283 | end
284 | 
285 | return nn.SharpMask
286 | 


--------------------------------------------------------------------------------
/SpatialSymmetricPadding.lua:
--------------------------------------------------------------------------------
 1 | --[[----------------------------------------------------------------------------
 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
 3 | This source code is licensed under the BSD-style license found in the
 4 | LICENSE file in the root directory of this source tree. An additional grant
 5 | of patent rights can be found in the PATENTS file in the same directory.
 6 | 
 7 | SpatialSymmetricPadding module
 8 | 
 9 | The forward(A) pads input tensor A with mirror reflections of itself
 10 | It is the same function as Matlab padarray(A, padsize, 'symmetric')
11 | The padding is of the form: cba[abcd...]
12 | While nn.SpatialReflectionPadding does: dcb[abcd...]
13 | And nn.SpatialReplicationPadding does: aaa[abcd...]
14 | (where [abcd...] is a tensor)
15 | The updateGradInput(input, gradOutput) is inherited from nn.SpatialZeroPadding,
16 | where the padded region is treated as constant and
17 | the gradients are accumulated in the backward pass
18 | ------------------------------------------------------------------------------]]
19 | 
20 | local SpatialSymmetricPadding, parent =
21 |   torch.class('nn.SpatialSymmetricPadding', 'nn.SpatialZeroPadding')
22 | 
23 | function SpatialSymmetricPadding:__init(pad_l, pad_r, pad_t, pad_b)
24 |    parent.__init(self, pad_l, pad_r, pad_t, pad_b)
25 | end
26 | 
27 | function SpatialSymmetricPadding:updateOutput(input)
28 |   assert(input:dim()==4, "only Dimension=4 implemented")
29 |   -- sizes
30 |   local h = input:size(3) + self.pad_t + self.pad_b
31 |   local w = input:size(4) + self.pad_l + self.pad_r
32 |   if w < 1 or h < 1 then error('input is too small') end
33 |   self.output:resize(input:size(1), input:size(2), h, w)
34 |   self.output:zero()
35 |   -- crop input if necessary
36 |   local c_input = input
37 |   if self.pad_t < 0 then
38 |     c_input = c_input:narrow(3, 1 - self.pad_t, c_input:size(3) + self.pad_t)
39 |   end
40 |   if self.pad_b < 0 then
41 |     c_input = c_input:narrow(3, 1, c_input:size(3) + self.pad_b)
42 |   end
43 |   if self.pad_l < 0 then
44 |     c_input = c_input:narrow(4, 1 - self.pad_l, c_input:size(4) + self.pad_l)
45 |   end
46 |   if self.pad_r < 0 then
47 |     c_input = c_input:narrow(4, 1, c_input:size(4) + self.pad_r)
48 |   end
 49 |   -- crop output if necessary
50 |   local c_output = self.output
51 |   if self.pad_t > 0 then
52 |     c_output = c_output:narrow(3, 1 + self.pad_t, c_output:size(3) - self.pad_t)
53 |   end
54 |   if self.pad_b > 0 then
55 |     c_output = c_output:narrow(3, 1, c_output:size(3) - self.pad_b)
56 |   end
57 |   if self.pad_l > 0 then
58 |     c_output = c_output:narrow(4, 1 + self.pad_l, c_output:size(4) - self.pad_l)
59 |   end
60 |   if self.pad_r > 0 then
61 |     c_output = c_output:narrow(4, 1, c_output:size(4) - self.pad_r)
62 |   end
63 |   -- copy input to output
64 |   c_output:copy(c_input)
65 |   -- symmetric padding that fills in values on the padded region
66 |   if w<2*self.pad_l or w<2*self.pad_r or h<2*self.pad_t or h<2*self.pad_b then
67 |     error('input is too small')
68 |   end
69 |   for i=1,self.pad_t do
70 |     self.output:narrow(3,self.pad_t-i+1,1):copy(
71 |     self.output:narrow(3,i+self.pad_t,1))
72 |   end
73 |   for i=1,self.pad_b do
74 |     self.output:narrow(3,self.output:size(3)-self.pad_b+i,1):copy(
75 |     self.output:narrow(3,self.output:size(3)-self.pad_b-i+1,1))
76 |   end
77 |   for i=1,self.pad_l do
78 |     self.output:narrow(4,self.pad_l-i+1,1):copy(
79 |     self.output:narrow(4,i+self.pad_l,1))
80 |   end
81 |   for i=1,self.pad_r do
82 |     self.output:narrow(4,self.output:size(4)-self.pad_r+i,1):copy(
83 |     self.output:narrow(4,self.output:size(4)-self.pad_r-i+1,1))
84 |   end
85 |   return self.output
86 | end
87 | 
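A quick way to see the padding pattern described in the header comment is to pad a single row; a minimal sketch, assuming it is run from the repository root with Torch installed:

```lua
require 'nn'
dofile('SpatialSymmetricPadding.lua')

-- pad the row [1 2 3 4] by two columns on the left:
--   symmetric   -> 2 1 | 1 2 3 4   (this module)
--   reflection  -> 3 2 | 1 2 3 4   (nn.SpatialReflectionPadding)
--   replication -> 1 1 | 1 2 3 4   (nn.SpatialReplicationPadding)
local x = torch.Tensor{{{{1, 2, 3, 4}}}}           -- 1x1x1x4 input
local pad = nn.SpatialSymmetricPadding(2, 0, 0, 0)
print(pad:forward(x))                              -- 1x1x1x6: 2 1 1 2 3 4
```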


--------------------------------------------------------------------------------
/TrainerDeepMask.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Training and testing loop for DeepMask
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | local optim = require 'optim'
 11 | paths.dofile('trainMeters.lua')
 12 | 
 13 | local Trainer = torch.class('Trainer')
 14 | 
 15 | --------------------------------------------------------------------------------
 16 | -- function: init
 17 | function Trainer:__init(model, criterion, config)
 18 |   -- training params
 19 |   self.config = config
 20 |   self.model = model
 21 |   self.maskNet = nn.Sequential():add(model.trunk):add(model.maskBranch)
 22 |   self.scoreNet = nn.Sequential():add(model.trunk):add(model.scoreBranch)
 23 |   self.criterion = criterion
 24 |   self.lr = config.lr
 25 |   self.optimState ={}
 26 |   for k,v in pairs({'trunk','mask','score'}) do
 27 |     self.optimState[v] = {
 28 |       learningRate = config.lr,
 29 |       learningRateDecay = 0,
 30 |       momentum = config.momentum,
 31 |       dampening = 0,
 32 |       weightDecay = config.wd,
 33 |     }
 34 |   end
 35 | 
 36 |   -- params and gradparams
 37 |   self.pt,self.gt = model.trunk:getParameters()
 38 |   self.pm,self.gm = model.maskBranch:getParameters()
 39 |   self.ps,self.gs = model.scoreBranch:getParameters()
 40 | 
 41 |   -- allocate cuda tensors
 42 |   self.inputs, self.labels = torch.CudaTensor(), torch.CudaTensor()
 43 | 
 44 |   -- meters
 45 |   self.lossmeter  = LossMeter()
 46 |   self.maskmeter  = IouMeter(0.5,config.testmaxload*config.batch)
 47 |   self.scoremeter = BinaryMeter()
 48 | 
 49 |   -- log
 50 |   self.modelsv = {model=model:clone('weight', 'bias'),config=config}
 51 |   self.rundir = config.rundir
 52 |   self.log = torch.DiskFile(self.rundir .. '/log', 'rw'); self.log:seekEnd()
 53 | end
 54 | 
 55 | --------------------------------------------------------------------------------
 56 | -- function: train
 57 | function Trainer:train(epoch, dataloader)
 58 |   self.model:training()
 59 |   self:updateScheduler(epoch)
 60 |   self.lossmeter:reset()
 61 | 
 62 |   local timer = torch.Timer()
 63 | 
 64 |   local fevaltrunk = function() return self.model.trunk.output, self.gt end
 65 |   local fevalmask  = function() return self.criterion.output,   self.gm end
 66 |   local fevalscore = function() return self.criterion.output,   self.gs end
 67 | 
 68 |   for n, sample in dataloader:run() do
 69 |     -- copy samples to the GPU
 70 |     self:copySamples(sample)
 71 | 
 72 |     -- forward/backward
 73 |     local model, params, feval, optimState
 74 |     if sample.head == 1 then
 75 |       model, params = self.maskNet, self.pm
 76 |       feval,optimState = fevalmask, self.optimState.mask
 77 |     else
 78 |       model, params = self.scoreNet, self.ps
 79 |       feval,optimState = fevalscore, self.optimState.score
 80 |     end
 81 | 
 82 |     local outputs = model:forward(self.inputs)
 83 |     local lossbatch = self.criterion:forward(outputs, self.labels)
 84 | 
 85 |     model:zeroGradParameters()
 86 |     local gradOutputs = self.criterion:backward(outputs, self.labels)
 87 |     if sample.head == 1 then gradOutputs:mul(self.inputs:size(1)) end
 88 |     model:backward(self.inputs, gradOutputs)
 89 | 
 90 |     -- optimize
 91 |     optim.sgd(fevaltrunk, self.pt, self.optimState.trunk)
 92 |     optim.sgd(feval, params, optimState)
 93 | 
 94 |     -- update loss
 95 |     self.lossmeter:add(lossbatch)
 96 |   end
 97 | 
 98 |   -- write log
 99 |   local logepoch =
100 |     string.format('[train] | epoch %05d | s/batch %04.2f | loss: %07.5f ',
101 |       epoch, timer:time().real/dataloader:size(),self.lossmeter:value())
102 |   print(logepoch)
103 |   self.log:writeString(string.format('%s\n',logepoch))
104 |   self.log:synchronize()
105 | 
106 |   --save model
107 |   torch.save(string.format('%s/model.t7', self.rundir),self.modelsv)
108 |   if epoch%50 == 0 then
109 |     torch.save(string.format('%s/model_%d.t7', self.rundir, epoch),
110 |       self.modelsv)
111 |   end
112 | 
113 |   collectgarbage()
114 | end
115 | 
116 | --------------------------------------------------------------------------------
117 | -- function: test
118 | local maxacc = 0
119 | function Trainer:test(epoch, dataloader)
120 |   self.model:evaluate()
121 |   self.maskmeter:reset()
122 |   self.scoremeter:reset()
123 | 
124 |   for n, sample in dataloader:run() do
125 |     -- copy input and target to the GPU
126 |     self:copySamples(sample)
127 | 
128 |     if sample.head == 1 then
129 |       local outputs = self.maskNet:forward(self.inputs)
130 |       self.maskmeter:add(outputs:view(self.labels:size()),self.labels)
131 |     else
132 |       local outputs = self.scoreNet:forward(self.inputs)
133 |       self.scoremeter:add(outputs, self.labels)
134 |     end
135 |     cutorch.synchronize()
136 | 
137 |   end
138 |   self.model:training()
139 | 
140 |   -- check if bestmodel so far
141 |   local z,bestmodel = self.maskmeter:value('0.7')
142 |   if z > maxacc then
143 |     torch.save(string.format('%s/bestmodel.t7', self.rundir),self.modelsv)
144 |     maxacc = z
145 |     bestmodel = true
146 |   end
147 | 
148 |   -- write log
149 |   local logepoch =
150 |     string.format('[test]  | epoch %05d '..
151 |       '| IoU: mean %06.2f median %06.2f suc@.5 %06.2f suc@.7 %06.2f '..
152 |       '| acc %06.2f | bestmodel %s',
153 |       epoch,
154 |       self.maskmeter:value('mean'),self.maskmeter:value('median'),
155 |       self.maskmeter:value('0.5'), self.maskmeter:value('0.7'),
156 |       self.scoremeter:value(), bestmodel and '*' or 'x')
157 |   print(logepoch)
158 |   self.log:writeString(string.format('%s\n',logepoch))
159 |   self.log:synchronize()
160 | 
161 |   collectgarbage()
162 | end
163 | 
164 | --------------------------------------------------------------------------------
165 | -- function: copy inputs/labels to CUDA tensor
166 | function Trainer:copySamples(sample)
167 |   self.inputs:resize(sample.inputs:size()):copy(sample.inputs)
168 |   self.labels:resize(sample.labels:size()):copy(sample.labels)
169 | end
170 | 
171 | --------------------------------------------------------------------------------
172 | -- function: update training schedule according to epoch
173 | function Trainer:updateScheduler(epoch)
174 |   if self.lr == 0 then
175 |     local regimes = {
176 |       {   1,  50, 1e-3, 5e-4},
177 |       {  51, 120, 5e-4, 5e-4},
178 |       { 121, 1e8, 1e-4, 5e-4}
179 |     }
180 | 
181 |     for _, row in ipairs(regimes) do
182 |       if epoch >= row[1] and epoch <= row[2] then
183 |         for k,v in pairs(self.optimState) do
184 |           v.learningRate=row[3]; v.weightDecay=row[4]
185 |         end
186 |       end
187 |     end
188 |   end
189 | end
190 | 
191 | return Trainer
192 | 


--------------------------------------------------------------------------------
/TrainerSharpMask.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Training and testing loop for SharpMask
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | paths.dofile('trainMeters.lua')
 11 | 
 12 | local Trainer = torch.class('Trainer')
 13 | 
 14 | --------------------------------------------------------------------------------
 15 | -- function: init
 16 | function Trainer:__init(model, criterion, config)
 17 |   -- training params
 18 |   self.model = model
 19 |   self.criterion = criterion
 20 |   self.lr = config.lr
 21 | 
 22 |   -- allocate cuda tensors
 23 |   self.inputs, self.labels = torch.CudaTensor(), torch.CudaTensor()
 24 | 
 25 |   -- meters
 26 |   self.lossmeter  = LossMeter()
 27 |   self.maskmeter  = IouMeter(0.5,config.testmaxload*config.batch)
 28 | 
 29 |   -- log
 30 |   self.modelsv = {model=model:clone('weight', 'bias'),config=config}
 31 |   self.rundir = config.rundir
 32 |   self.log = torch.DiskFile(self.rundir .. '/log', 'rw'); self.log:seekEnd()
 33 | end
 34 | 
 35 | --------------------------------------------------------------------------------
 36 | -- function: train
 37 | function Trainer:train(epoch, dataloader)
 38 |   self.model:training()
 39 |   self:updateScheduler(epoch)
 40 |   self.lossmeter:reset()
 41 | 
 42 |   local timer = torch.Timer()
 43 | 
 44 |   for n, sample in dataloader:run() do
 45 |     -- copy samples to the GPU
 46 |     self:copySamples(sample)
 47 | 
 48 |     -- forward/backward
 49 |     local outputs = self.model:forward(self.inputs)
 50 |     local lossbatch = self.criterion:forward(outputs, self.labels)
 51 | 
 52 |     local gradOutputs = self.criterion:backward(outputs, self.labels)
 53 |     gradOutputs:mul(self.inputs:size(1))
 54 |     self.model:zeroGradParameters()
 55 |     self.model:backward(self.inputs, gradOutputs)
 56 |     self.model:updateParameters(self.lr)
 57 | 
 58 |     -- update loss
 59 |     self.lossmeter:add(lossbatch)
 60 |   end
 61 | 
 62 |   -- write log
 63 |   local logepoch =
 64 |     string.format('[train] | epoch %05d | s/batch %04.2f | loss: %07.5f ',
 65 |       epoch, timer:time().real/dataloader:size(),self.lossmeter:value())
 66 |   print(logepoch)
 67 |   self.log:writeString(string.format('%s\n',logepoch))
 68 |   self.log:synchronize()
 69 | 
 70 |   --save model
 71 |   torch.save(string.format('%s/model.t7', self.rundir),self.modelsv)
 72 |   if epoch%50 == 0 then
 73 |     torch.save(string.format('%s/model_%d.t7', self.rundir, epoch),
 74 |       self.modelsv)
 75 |   end
 76 |   collectgarbage()
 77 | end
 78 | 
 79 | --------------------------------------------------------------------------------
 80 | -- function: test
 81 | local maxacc = 0
 82 | function Trainer:test(epoch, dataloader)
 83 |   self.model:evaluate()
 84 |   self.maskmeter:reset()
 85 | 
 86 |   for n, sample in dataloader:run() do
 87 |     -- copy input and target to the GPU
 88 |     self:copySamples(sample)
 89 | 
 90 |     -- infer mask in batch
 91 |     local outputs = self.model:forward(self.inputs):float()
 92 |     cutorch.synchronize()
 93 | 
 94 |     self.maskmeter:add(outputs:view(sample.labels:size()),sample.labels)
 95 | 
 96 |   end
 97 |   self.model:training()
 98 | 
 99 |   -- check if bestmodel so far
100 |   local z,bestmodel = self.maskmeter:value('0.7')
101 |   if z > maxacc then
102 |     torch.save(string.format('%s/bestmodel.t7', self.rundir),self.modelsv)
103 |     maxacc = z
104 |     bestmodel = true
105 |   end
106 | 
107 |   -- write log
108 |   local logepoch =
109 |     string.format('[test]  | epoch %05d '..
110 |       '| IoU: mean %06.2f median %06.2f suc@.5 %06.2f suc@.7 %06.2f '..
111 |       '| bestmodel %s',
112 |       epoch,
113 |       self.maskmeter:value('mean'),self.maskmeter:value('median'),
114 |       self.maskmeter:value('0.5'), self.maskmeter:value('0.7'),
115 |       bestmodel and '*' or 'x')
116 |   print(logepoch)
117 |   self.log:writeString(string.format('%s\n',logepoch))
118 |   self.log:synchronize()
119 | 
120 |   collectgarbage()
121 | end
122 | 
123 | --------------------------------------------------------------------------------
124 | -- function: copy inputs/labels to CUDA tensor
125 | function Trainer:copySamples(sample)
126 |   self.inputs:resize(sample.inputs:size()):copy(sample.inputs)
127 |   self.labels:resize(sample.labels:size()):copy(sample.labels)
128 | end
129 | 
130 | --------------------------------------------------------------------------------
131 | -- function: update training schedule according to epoch
132 | function Trainer:updateScheduler(epoch)
133 |   if self.lr == 0 then
134 |     local regimes = {
135 |       {  1,     50,  1e-3},
136 |       { 51,     80,  5e-4},
137 |       { 81,    1e8,  1e-4}
138 |     }
139 | 
140 |     for _, row in ipairs(regimes) do
141 |       if epoch >= row[1] and epoch <= row[2] then
142 |         self.lr = row[3]
143 |       end
144 |     end
145 |   end
146 | end
147 | 
148 | return Trainer
149 | 


--------------------------------------------------------------------------------
/computeProposals.lua:
--------------------------------------------------------------------------------
 1 | --[[----------------------------------------------------------------------------
 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
 3 | This source code is licensed under the BSD-style license found in the
 4 | LICENSE file in the root directory of this source tree. An additional grant
 5 | of patent rights can be found in the PATENTS file in the same directory.
 6 | 
 7 | Run full scene inference on a sample image
 8 | ------------------------------------------------------------------------------]]
 9 | 
10 | require 'torch'
11 | require 'cutorch'
12 | require 'image'
13 | 
14 | --------------------------------------------------------------------------------
15 | -- parse arguments
16 | local cmd = torch.CmdLine()
17 | cmd:text()
18 | cmd:text('evaluate deepmask/sharpmask')
19 | cmd:text()
20 | cmd:argument('-model', 'path to model to load')
21 | cmd:text('Options:')
22 | cmd:option('-img','data/testImage.jpg' ,'path/to/test/image')
23 | cmd:option('-gpu', 1, 'gpu device')
24 | cmd:option('-np', 5,'number of proposals to save in test')
25 | cmd:option('-si', -2.5, 'initial scale')
26 | cmd:option('-sf', .5, 'final scale')
27 | cmd:option('-ss', .5, 'scale step')
28 | cmd:option('-dm', false, 'use DeepMask version of SharpMask')
29 | 
30 | local config = cmd:parse(arg)
31 | 
32 | --------------------------------------------------------------------------------
33 | -- various initializations
34 | torch.setdefaulttensortype('torch.FloatTensor')
35 | cutorch.setDevice(config.gpu)
36 | 
37 | local coco = require 'coco'
38 | local maskApi = coco.MaskApi
39 | 
40 | local meanstd = {mean = { 0.485, 0.456, 0.406 }, std = { 0.229, 0.224, 0.225 }}
41 | 
42 | --------------------------------------------------------------------------------
43 | -- load model
44 | paths.dofile('DeepMask.lua')
45 | paths.dofile('SharpMask.lua')
46 | 
47 | print('| loading model file... ' .. config.model)
48 | local m = torch.load(config.model..'/model.t7')
49 | local model = m.model
50 | model:inference(config.np)
51 | model:cuda()
52 | 
53 | --------------------------------------------------------------------------------
54 | -- create inference module
55 | local scales = {}
56 | for i = config.si,config.sf,config.ss do table.insert(scales,2^i) end
57 | 
58 | if torch.type(model)=='nn.DeepMask' then
59 |   paths.dofile('InferDeepMask.lua')
60 | elseif torch.type(model)=='nn.SharpMask' then
61 |   paths.dofile('InferSharpMask.lua')
62 | end
63 | 
64 | local infer = Infer{
65 |   np = config.np,
66 |   scales = scales,
67 |   meanstd = meanstd,
68 |   model = model,
69 |   dm = config.dm,
70 | }
71 | 
72 | --------------------------------------------------------------------------------
73 | -- do it
74 | print('| start')
75 | 
76 | -- load image
77 | local img = image.load(config.img)
78 | local h,w = img:size(2),img:size(3)
79 | 
80 | -- forward all scales
81 | infer:forward(img)
82 | 
83 | -- get top proposals
84 | local masks,_ = infer:getTopProps(.2,h,w)
85 | 
86 | -- save result
87 | local res = img:clone()
88 | maskApi.drawMasks(res, masks, 10)
89 | image.save(string.format('./res.jpg',config.model),res)
90 | 
91 | print('| done')
92 | collectgarbage()
93 | 


--------------------------------------------------------------------------------
/data/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/deepmask/280a716b7a6253698f118843cb960bebd96b7ce4/data/teaser.png


--------------------------------------------------------------------------------
/data/testImage.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/deepmask/280a716b7a6253698f118843cb960bebd96b7ce4/data/testImage.jpg


--------------------------------------------------------------------------------
/evalPerImage.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Full scene evaluation of DeepMask/SharpMask
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | require 'torch'
 11 | require 'cutorch'
 12 | require 'image'
 13 | 
 14 | local cjson = require 'cjson'
 15 | local tds = require 'tds'
 16 | local coco = require 'coco'
 17 | 
 18 | paths.dofile('DeepMask.lua')
 19 | paths.dofile('SharpMask.lua')
 20 | 
 21 | --------------------------------------------------------------------------------
 22 | -- parse arguments
 23 | local cmd = torch.CmdLine()
 24 | cmd:text()
 25 | cmd:text('full scene evaluation of DeepMask/SharpMask')
 26 | cmd:text()
 27 | cmd:argument('-model', 'model to load')
 28 | cmd:text('Options:')
 29 | cmd:option('-datadir', 'data/', 'data directory')
 30 | cmd:option('-seed', 1, 'manually set RNG seed')
 31 | cmd:option('-gpu', 1, 'gpu device')
 32 | cmd:option('-split', 'val', 'dataset split to be used (train/val)')
 33 | cmd:option('-np', 500,'number of proposals')
 34 | cmd:option('-thr', .2, 'mask binary threshold')
 35 | cmd:option('-save', false, 'save top proposals')
 36 | cmd:option('-startAt', 1, 'start image id')
 37 | cmd:option('-endAt', 5000, 'end image id')
 38 | cmd:option('-smin', -2.5, 'min scale')
 39 | cmd:option('-smax', .5, 'max scale')
 40 | cmd:option('-sstep', .5, 'scale step')
 41 | cmd:option('-timer', false, 'breakdown timer')
 42 | cmd:option('-dm', false, 'use DeepMask version of SharpMask')
 43 | 
 44 | local config = cmd:parse(arg)
 45 | 
 46 | --------------------------------------------------------------------------------
 47 | -- various initializations
 48 | torch.setdefaulttensortype('torch.FloatTensor')
 49 | cutorch.setDevice(config.gpu)
 50 | torch.manualSeed(config.seed)
 51 | math.randomseed(config.seed)
 52 | local maskApi = coco.MaskApi
 53 | local meanstd = {mean={ 0.485, 0.456, 0.406 }, std={ 0.229, 0.224, 0.225 }}
 54 | 
 55 | --------------------------------------------------------------------------------
 56 | -- load model and config
 57 | print('| loading model file... ' .. config.model)
 58 | local m = torch.load(config.model..'/model.t7')
 59 | local c = m.config
 60 | for k,v in pairs(c) do if config[k] == nil then config[k] = v end end
 61 | local epoch = 0
 62 | if paths.filep(config.model..'/log') then
 63 |   for line in io.lines(config.model..'/log') do
 64 |     if string.find(line,'train') then epoch = epoch + 1 end
 65 |   end
 66 |   print(string.format('| number of examples seen until now: %d (%d epochs)',
 67 |     epoch*config.maxload*config.batch,epoch))
 68 | end
 69 | 
 70 | local model = m.model
 71 | model:inference(config.np)
 72 | model:cuda()
 73 | 
 74 | --------------------------------------------------------------------------------
 75 | -- directory to save results
 76 | local savedir = string.format('%s/epoch=%d/',config.model,epoch)
 77 | print(string.format('| saving results in %s',savedir))
 78 | os.execute(string.format('mkdir -p %s',savedir))
 79 | os.execute(string.format('mkdir -p %s/t7',savedir))
 80 | os.execute(string.format('mkdir -p %s/jsons',savedir))
 81 | if config.save then os.execute(string.format('mkdir -p %s/res',savedir)) end
 82 | 
 83 | --------------------------------------------------------------------------------
 84 | -- create inference module
 85 | local scales = {}
 86 | for i = config.smin,config.smax,config.sstep do table.insert(scales,2^i) end
 87 | 
 88 | if torch.type(model)=='nn.DeepMask' then
 89 |   paths.dofile('InferDeepMask.lua')
 90 | elseif torch.type(model)=='nn.SharpMask' then
 91 |   paths.dofile('InferSharpMask.lua')
 92 | end
 93 | 
 94 | local infer = Infer{
 95 |   np = config.np,
 96 |   scales = scales,
 97 |   meanstd = meanstd,
 98 |   model = model,
 99 |   iSz = config.iSz,
100 |   dm = config.dm,
101 |   timer = config.timer,
102 | }
103 | 
104 | --------------------------------------------------------------------------------
105 | -- get list of eval images
106 | local annFile = string.format('%s/annotations/instances_%s2014.json',
107 |   config.datadir,config.split)
108 | local coco = coco.CocoApi(annFile)
109 | local imgIds = coco:getImgIds()
110 | imgIds,_ = imgIds:sort()
111 | 
112 | --------------------------------------------------------------------------------
113 | -- function: encode proposals
114 | local function encodeProps(props,np,img,k,masks,scores)
115 |   local t = (k-1)*np
116 |   local enc = maskApi.encode(masks)
117 | 
118 |   for i = 1, np do
119 |     local elem = tds.Hash()
120 |     elem.segmentation = tds.Hash(enc[i])
121 |     elem.image_id=img.id
122 |     elem.category_id=1
123 |     elem.score=scores[i][1]
124 | 
125 |     props[t+i] = elem
126 |   end
127 | end
128 | 
129 | --------------------------------------------------------------------------------
130 | -- function: convert props to json and save
131 | local function saveProps(props,savedir,s,e)
132 |   --t7
133 |   local pathsvt7 = string.format('%s/t7/props-%d-%d.t7', savedir,s,e)
134 |   torch.save(pathsvt7,props)
135 |   --json
136 |   local pathsvjson = string.format('%s/jsons/props-%d-%d.json', savedir,s,e)
137 |   local propsjson = {}
138 |   for _,prop in pairs(props) do -- hash2table
139 |     local elem = {}
140 |     elem.category_id = prop.category_id
141 |     elem.image_id = prop.image_id
142 |     elem.score = prop.score
143 |     elem.segmentation={
144 |       size={prop.segmentation.size[1],prop.segmentation.size[2]},
145 |       counts = prop.segmentation.counts or prop.segmentation.count
146 |     }
147 |     table.insert(propsjson,elem)
148 |   end
149 |   local jsonText = cjson.encode(propsjson)
150 |   local f = io.open(pathsvjson,'w'); f:write(jsonText); f:close()
151 |   collectgarbage()
152 | end
153 | 
154 | --------------------------------------------------------------------------------
155 | -- function: read image
156 | local function readImg(datadir,split,fileName)
157 |   local pathImg = string.format('%s/%s2014/%s',datadir,split,fileName)
158 |   local inp = image.load(pathImg,3)
159 |   return inp
160 | end
161 | 
162 | --------------------------------------------------------------------------------
163 | -- run
164 | print('| start eval')
165 | local props, svcount = tds.Hash(), config.startAt
166 | for k = config.startAt,config.endAt do
167 |   xlua.progress(k,config.endAt)
168 | 
169 |   -- load image
170 |   local img = coco:loadImgs(imgIds[k])[1]
171 |   local input = readImg(config.datadir,config.split,img.file_name)
172 |   local h,w = img.height,img.width
173 | 
174 |   -- forward all scales
175 |   infer:forward(input)
176 | 
177 |   -- get top proposals
178 |   local masks,scores = infer:getTopProps(config.thr,h,w)
179 | 
180 |   -- encode proposals
181 |   encodeProps(props,config.np,img,k,masks,scores)
182 | 
183 |   -- save top masks?
184 |   if config.save then
185 |     local res = input:clone()
186 |     maskApi.drawMasks(res, masks, 10)
187 |     image.save(string.format('%s/res/%d.jpg',savedir,k),res)
188 |   end
189 | 
190 |   -- save proposals
191 |   if k%500 == 0 then
192 |     saveProps(props,savedir,svcount,k); props = tds.Hash(); collectgarbage()
193 |     svcount = svcount + 500
194 |   end
195 | 
196 |   collectgarbage()
197 | end
198 | 
199 | if config.timer then infer:printTiming() end
200 | collectgarbage()
201 | print('| finish')
202 | 


--------------------------------------------------------------------------------
/evalPerPatch.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Per patch evaluation of DeepMask/SharpMask
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | require 'torch'
 11 | require 'cutorch'
 12 | 
 13 | paths.dofile('DeepMask.lua')
 14 | paths.dofile('SharpMask.lua')
 15 | 
 16 | --------------------------------------------------------------------------------
 17 | -- parse arguments
 18 | local cmd = torch.CmdLine()
 19 | cmd:text()
 20 | cmd:text('per patch evaluation of DeepMask/SharpMask')
 21 | cmd:text()
 22 | cmd:argument('-model', 'model to load')
 23 | cmd:text('Options:')
 24 | cmd:option('-seed', 1, 'Manually set RNG seed')
 25 | cmd:option('-gpu', 1, 'gpu device')
 26 | cmd:option('-testmaxload', 200, 'max number of testing batches')
 27 | cmd:option('-save', false, 'save output')
 28 | 
 29 | local config = cmd:parse(arg)
 30 | 
 31 | --------------------------------------------------------------------------------
 32 | -- various initializations
 33 | torch.setdefaulttensortype('torch.FloatTensor')
 34 | cutorch.setDevice(config.gpu)
 35 | torch.manualSeed(config.seed)
 36 | math.randomseed(config.seed)
 37 | 
 38 | local inputs = torch.CudaTensor()
 39 | 
 40 | --------------------------------------------------------------------------------
 41 | -- loading model and config
 42 | print('| loading model file...' .. config.model)
 43 | local m = torch.load(config.model..'/model.t7')
 44 | local c = m.config
 45 | for k,v in pairs(c) do if config[k] == nil then config[k] = v end end
 46 | local epoch = 0
 47 | if paths.filep(config.model..'/log') then
 48 |   for line in io.lines(config.model..'/log') do
 49 |     if string.find(line,'train') then epoch = epoch + 1 end
 50 |   end
 51 |   print(string.format('| number of examples seen until now: %d (%d epochs)',
 52 |     epoch*config.maxload*config.batch,epoch))
 53 | end
 54 | config.hfreq = 0 -- only evaluate masks
 55 | 
 56 | local model = m.model
 57 | if torch.type(model)=='nn.DeepMask' then
 58 |   model=nn.Sequential():add(model.trunk):add(model.maskBranch)
 59 | end
 60 | model:evaluate()
 61 | 
 62 | --------------------------------------------------------------------------------
 63 | -- directory to save results
 64 | local savedir
 65 | if config.save then
 66 |   require 'image'
 67 |   savedir = string.format('%s/epoch=%d/res-patch/',config.model,epoch)
 68 |   os.execute(string.format('mkdir -p %s',savedir))
 69 | end
 70 | 
 71 | --------------------------------------------------------------------------------
 72 | -- initialize data provider and mask meter
 73 | local DataLoader = paths.dofile('DataLoader.lua')
 74 | local _, valLoader = DataLoader.create(config)
 75 | 
 76 | paths.dofile('trainMeters.lua')
 77 | local maskmeter = IouMeter(0.5,config.testmaxload*config.batch)
 78 | 
 79 | --------------------------------------------------------------------------------
 80 | -- function: display output
 81 | local function saveRes(input,target,output,savedir,n)
 82 |   local batch,h,w = target:size(1),config.gSz,config.gSz
 83 | 
 84 |   local input,target,output = input:float(),target:float(),output:float()
 85 |   input = input:narrow(3,16,config.iSz):narrow(4,16,config.iSz)
 86 |   output:mul(-1):exp():add(1):pow(-1) -- transform outs in probability
 87 |   output = output:view(batch,h,w)
 88 | 
 89 |   local imgRGB = torch.Tensor(batch,3,h,w):zero()
 90 |   local outJet = torch.Tensor(batch,3,h,w):zero()
 91 | 
 92 |   for b = 1, batch do
 93 |     imgRGB:narrow(1,b,1):copy(image.scale(input[b],w,h))
 94 |     local oj = torch.floor(output[b]*100):add(1):double()
 95 |     oj = image.scale(oj,w,h); oj = image.y2jet(oj)
 96 |     outJet:narrow(1,b,1):copy(oj)
 97 |     local mask = image.scale(target[b],w,h):ge(0):double()
 98 |     local me = image.erode(mask,torch.DoubleTensor(3,3):fill(1))
 99 |     local md = image.dilate(mask,torch.DoubleTensor(3,3):fill(1))
100 |     local maskf = md - me
101 |     maskf = maskf:eq(1)
102 |     imgRGB:narrow(1,b,1):add(-imgRGB:min()):mul(1/imgRGB:max())
103 |     imgRGB[b][1][maskf]=1; imgRGB[b][2][maskf]=0; imgRGB[b][3][maskf]=0
104 |   end
105 | 
106 |   -- concatenate
107 |   local res = torch.Tensor(3,h*batch,w*2):zero()
108 |   for b = 1, batch do
109 |     res:narrow(2,(b-1)*h+1,h):narrow(3,1,w):copy(imgRGB[b])
110 |     res:narrow(2,(b-1)*h+1,h):narrow(3,w+1,w):copy(outJet[b])
111 |   end
112 | 
113 |   image.save(string.format('%s/%d.jpg',savedir,n),res)
114 | end
115 | 
116 | --------------------------------------------------------------------------------
117 | -- start evaluation
118 | print('| start per batch evaluation')
119 | maskmeter:reset()
120 | sys.tic()
121 | for n, sample in valLoader:run() do
122 |   xlua.progress(n,config.testmaxload)
123 | 
124 |   -- copy input and target to the GPU
125 |   inputs:resize(sample.inputs:size()):copy(sample.inputs)
126 | 
127 |   -- infer mask in batch
128 |   local output = model:forward(inputs):float()
129 |   cutorch.synchronize()
130 |   output = output:view(sample.labels:size())
131 | 
132 |   -- compute IoU
133 |   maskmeter:add(output,sample.labels)
134 | 
135 |   -- save?
136 |   if config.save then
137 |     saveRes(sample.inputs, sample.labels, output, savedir, n)
138 |   end
139 | 
140 |   collectgarbage()
141 | end
142 | cutorch.synchronize()
143 | print('| finish')
144 | 
145 | --------------------------------------------------------------------------------
146 | -- log
147 | print('----------------------------------------------')
148 | local log = string.format('| model: %s\n',config.model)
149 | log = log..string.format('| # epochs: %s\n',epoch)
150 | log = log..string.format(
151 |   '| # samples: %d\n'..
152 |   '| samples/s %7d '..
153 |   '| mean %06.2f median %06.2f '..
154 |   'iou@.5 %06.2f  iou@.7 %06.2f ',
155 |   maskmeter.n,config.batch*config.testmaxload/sys.toc(),
156 |   maskmeter:value('mean'),maskmeter:value('median'),
157 |   maskmeter:value('0.5'), maskmeter:value('0.7')
158 |   )
159 | print(log)
160 | 


--------------------------------------------------------------------------------
/modelUtils.lua:
--------------------------------------------------------------------------------
 1 | --[[----------------------------------------------------------------------------
 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
 3 | This source code is licensed under the BSD-style license found in the
 4 | LICENSE file in the root directory of this source tree. An additional grant
 5 | of patent rights can be found in the PATENTS file in the same directory.
 6 | 
 7 | Utility functions for models
 8 | ------------------------------------------------------------------------------]]
 9 | 
10 | local utils = {}
11 | 
12 | --------------------------------------------------------------------------------
13 | -- all BN modules in ResNet to be transformed into SpatialConstDiagonal
14 | -- (with inn routines)
15 | local inn = require 'inn'
16 | local innutils = require 'inn.utils'
17 | if not nn.SpatialConstDiagonal then
18 |   torch.class('nn.SpatialConstDiagonal', 'inn.ConstAffine')
19 | end
20 | utils.BNtoFixed = innutils.BNtoFixed
21 | 
22 | --------------------------------------------------------------------------------
23 | -- function: linear2convTrunk
24 | function utils.linear2convTrunk(net,fSz)
25 |   return net:replace(function(x)
26 |     if torch.typename(x):find('Linear') then
27 |       local nInp,nOut = x.weight:size(2)/(fSz*fSz),x.weight:size(1)
28 |       local w = torch.reshape(x.weight,nOut,nInp,fSz,fSz)
29 |       local y = cudnn.SpatialConvolution(nInp,nOut,fSz,fSz,1,1)
30 |       y.weight:copy(w); y.gradWeight:copy(w); y.bias:copy(x.bias)
31 |       return y
32 |     elseif torch.typename(x):find('Threshold') then
33 |       return cudnn.ReLU()
34 |     elseif torch.typename(x):find('View') or
35 |        torch.typename(x):find('SpatialZeroPadding') then
36 |       return nn.Identity()
37 |     else
38 |       return x
39 |     end
40 |   end
41 |   )
42 | end
43 | 
44 | --------------------------------------------------------------------------------
45 | -- function: linear2convHead
46 | function utils.linear2convHead(net)
47 |   return net:replace(function(x)
48 |     if torch.typename(x):find('Linear') then
49 |       local nInp,nOut = x.weight:size(2),x.weight:size(1)
50 |       local w = torch.reshape(x.weight,nOut,nInp,1,1)
51 |       local y = cudnn.SpatialConvolution(nInp,nOut,1,1,1,1)
52 |       y.weight:copy(w); y.gradWeight:copy(w); y.bias:copy(x.bias)
53 |       return y
54 |     elseif torch.typename(x):find('Threshold') then
55 |       return cudnn.ReLU()
56 |     elseif not torch.typename(x):find('View') and
57 |       not torch.typename(x):find('Copy') then
58 |       return x
59 |     end
60 |   end
61 |   )
62 | end
63 | 
64 | --------------------------------------------------------------------------------
65 | -- function: replace 0-padding of 3x3 conv into mirror-padding
66 | function utils.updatePadding(net, nn_padding)
67 |   if torch.typename(net) == "nn.Sequential" or
68 |     torch.typename(net) == "nn.ConcatTable" then
69 |     for i = #net.modules,1,-1 do
70 |       local out = utils.updatePadding(net:get(i), nn_padding)
71 |       if out ~= -1 then
72 |         local pw, ph = out[1], out[2]
73 |         net.modules[i] = nn.Sequential():add(nn_padding(pw,pw,ph,ph))
74 |           :add(net.modules[i]):cuda()
75 |       end
76 |     end
77 |   else
78 |     if torch.typename(net) == "nn.SpatialConvolution" or
79 |       torch.typename(net) == "cudnn.SpatialConvolution" then
80 |       if (net.kW == 3 and net.kH == 3) or (net.kW==7 and net.kH==7) then
81 |         local pw, ph = net.padW, net.padH
82 |         net.padW, net.padH = 0, 0
83 |         return {pw,ph}
84 |       end
85 |     end
86 |   end
87 |   return -1
88 | end
89 | 
90 | return utils
91 | 


--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Train DeepMask or SharpMask
  8 | ------------------------------------------------------------------------------]]
  9 | 
 10 | require 'torch'
 11 | require 'cutorch'
 12 | 
 13 | --------------------------------------------------------------------------------
 14 | -- parse arguments
 15 | local cmd = torch.CmdLine()
 16 | cmd:text()
 17 | cmd:text('train DeepMask or SharpMask')
 18 | cmd:text()
 19 | cmd:text('Options:')
 20 | cmd:option('-rundir', 'exps/', 'experiments directory')
 21 | cmd:option('-datadir', 'data/', 'data directory')
 22 | cmd:option('-seed', 1, 'manually set RNG seed')
 23 | cmd:option('-gpu', 1, 'gpu device')
 24 | cmd:option('-nthreads', 2, 'number of threads for DataSampler')
 25 | cmd:option('-reload', '', 'reload a network from given directory')
 26 | cmd:text()
 27 | cmd:text('Training Options:')
 28 | cmd:option('-batch', 32, 'training batch size')
 29 | cmd:option('-lr', 0, 'learning rate (0 uses default lr schedule)')
 30 | cmd:option('-momentum', 0.9, 'momentum')
 31 | cmd:option('-wd', 5e-4, 'weight decay')
 32 | cmd:option('-maxload', 4000, 'max number of training batches per epoch')
 33 | cmd:option('-testmaxload', 500, 'max number of testing batches')
 34 | cmd:option('-maxepoch', 300, 'max number of training epochs')
 35 | cmd:option('-iSz', 160, 'input size')
 36 | cmd:option('-oSz', 56, 'output size')
 37 | cmd:option('-gSz', 112, 'ground truth size')
 38 | cmd:option('-shift', 16, 'shift jitter allowed')
 39 | cmd:option('-scale', .25, 'scale jitter allowed')
 40 | cmd:option('-hfreq', 0.5, 'mask/score head sampling frequency')
 41 | cmd:option('-scratch', false, 'train DeepMask with randomly initialized weights')
 42 | cmd:text()
 43 | cmd:text('SharpMask Options:')
 44 | cmd:option('-dm', '', 'path to trained deepmask (if set, train SharpMask)')
 45 | cmd:option('-km', 32, 'number of mask-encoding channels in refinement (k^m)')
 46 | cmd:option('-ks', 32, 'number of skip-feature channels in refinement (k^s)')
 47 | 
 48 | local config = cmd:parse(arg)
 49 | 
 50 | --------------------------------------------------------------------------------
 51 | -- various initializations
 52 | torch.setdefaulttensortype('torch.FloatTensor')
 53 | cutorch.setDevice(config.gpu)
 54 | torch.manualSeed(config.seed)
 55 | math.randomseed(config.seed)
 56 | 
 57 | local trainSm -- flag to train SharpMask (true) or DeepMask (false)
 58 | if #config.dm > 0 then
 59 |   trainSm = true
 60 |   config.hfreq = 0 -- train only mask head
 61 |   config.gSz = config.iSz -- in sharpmask, ground-truth has same dim as input
 62 | end
 63 | 
 64 | paths.dofile('DeepMask.lua')
 65 | if trainSm then paths.dofile('SharpMask.lua') end
 66 | 
 67 | --------------------------------------------------------------------------------
 68 | -- reload?
 69 | local epoch, model
 70 | if #config.reload > 0 then
 71 |   epoch = 0
 72 |   if paths.filep(config.reload..'/log') then
 73 |     for line in io.lines(config.reload..'/log') do
 74 |       if string.find(line,'train') then epoch = epoch + 1 end
 75 |     end
 76 |   end
 77 |   print(string.format('| reloading experiment %s', config.reload))
 78 |   local m = torch.load(string.format('%s/model.t7', config.reload))
 79 |   model, config = m.model, m.config
 80 | end
 81 | 
 82 | --------------------------------------------------------------------------------
 83 | -- directory to save log and model
 84 | local pathsv = trainSm and 'sharpmask/exp' or 'deepmask/exp'
 85 | config.rundir = cmd:string(
 86 |   paths.concat(config.reload=='' and config.rundir or config.reload, pathsv),
 87 |   config,{rundir=true, gpu=true, reload=true, datadir=true, dm=true} --ignore
 88 | )
 89 | 
 90 | print(string.format('| running in directory %s', config.rundir))
 91 | os.execute(string.format('mkdir -p %s',config.rundir))
 92 | 
 93 | --------------------------------------------------------------------------------
 94 | -- network and criterion
 95 | model = model or (trainSm and nn.SharpMask(config) or nn.DeepMask(config))
 96 | local criterion = nn.SoftMarginCriterion():cuda()
 97 | 
 98 | --------------------------------------------------------------------------------
 99 | -- initialize data loader
100 | local DataLoader = paths.dofile('DataLoader.lua')
101 | local trainLoader, valLoader = DataLoader.create(config)
102 | 
103 | --------------------------------------------------------------------------------
104 | -- initialize Trainer (handles training/testing loop)
105 | if trainSm then
106 |   paths.dofile('TrainerSharpMask.lua')
107 | else
108 |   paths.dofile('TrainerDeepMask.lua')
109 | end
110 | local trainer = Trainer(model, criterion, config)
111 | 
112 | --------------------------------------------------------------------------------
113 | -- do it
114 | epoch = epoch or 1
115 | print('| start training')
116 | for i = 1, config.maxepoch do
117 |   trainer:train(epoch,trainLoader)
118 |   if i%2 == 0 then trainer:test(epoch,valLoader) end
119 |   epoch = epoch + 1
120 | end
121 | 
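
As a usage sketch (assuming the standard Torch launcher, run from the
repository root): "th train.lua" trains DeepMask with the defaults above,
while pointing -dm at a directory holding a trained DeepMask, e.g.
"th train.lua -dm pretrained/deepmask" (the path is purely illustrative),
switches training to SharpMask. Per the reload logic above, an experiment
directory is expected to contain a 'log' file and a 'model.t7' checkpoint,
which is what -reload consumes.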


--------------------------------------------------------------------------------
/trainMeters.lua:
--------------------------------------------------------------------------------
  1 | --[[----------------------------------------------------------------------------
  2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
  3 | This source code is licensed under the BSD-style license found in the
  4 | LICENSE file in the root directory of this source tree. An additional grant
  5 | of patent rights can be found in the PATENTS file in the same directory.
  6 | 
  7 | Contains the three meters used during training/evaluation:
  8 |   - lossmeter: measure the average loss.
  9 |   - binarymeter: measure accuracy of the predicted objectness score
 10 |                 against the ground truth objectness annotation.
 11 |   - ioumeter: measure iou between inferred and ground truth masks.
 12 | ------------------------------------------------------------------------------]]
 13 | 
 14 | --------------------------------------------------------------------------------
 15 | -- loss meter
 16 | do
 17 |   local LossMeter = torch.class('LossMeter')
 18 |   -- init
 19 |   function LossMeter:__init()
 20 |     self:reset()
 21 |   end
 22 | 
 23 |   -- function: reset
 24 |   function LossMeter:reset()
 25 |     self.sum = 0; self.n = 0
 26 |   end
 27 | 
 28 |   -- function: add
 29 |   function LossMeter:add(value,n)
 30 |     n = n or 1
 31 |     self.sum = self.sum + value
 32 |     self.n = self.n + n
 33 |   end
 34 | 
 35 |   -- function: value
 36 |   function LossMeter:value()
 37 |     return self.sum / self.n
 38 |   end
 39 | end
 40 | 
 41 | --------------------------------------------------------------------------------
 42 | -- binary meter
 43 | do
 44 |   local BinaryMeter = torch.class('BinaryMeter')
 45 |   -- init
 46 |   function BinaryMeter:__init()
 47 |     self:reset()
 48 |   end
 49 |   -- function: reset
 50 |   function BinaryMeter:reset()
 51 |     self.acc = 0; self.n = 0
 52 |   end
 53 | 
 54 |   -- function: add
 55 |   function BinaryMeter:add(output, target)
 56 |     target, output = target:squeeze(), output:squeeze()
 57 |     assert(output:nElement() == target:nElement(),
 58 |       'target and output do not match')
 59 | 
 60 |     local acc = torch.cmul(output,target)
 61 |     self.acc = self.acc + acc:ge(0):sum()
 62 |     self.n = self.n + output:size(1)
 63 |   end
 64 | 
 65 |   -- function: value
 66 |   function BinaryMeter:value()
 67 |     local res = self.acc/self.n
 68 |     return res*100
 69 |   end
 70 | end
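
In other words, the binary meter treats the raw objectness score and the +/-1
label as a sign prediction: a sample counts as correct when output*target >= 0,
and value() reports the percentage of correct samples. A minimal illustration
(assumes trainMeters.lua has already been loaded, e.g. via paths.dofile):

  local meter = BinaryMeter()
  local scores = torch.Tensor{ 1.2, -0.3, 0.8, -2.0 }  -- raw objectness scores
  local labels = torch.Tensor{ 1.0,  1.0, 1.0, -1.0 }  -- ground-truth +/-1 labels
  meter:add(scores, labels)
  print(meter:value())  -- 75: the sign matches for 3 of the 4 samples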
 71 | 
 72 | --------------------------------------------------------------------------------
 73 | -- iou meter
 74 | do
 75 |   local IouMeter = torch.class('IouMeter')
 76 |   -- init
 77 |   function IouMeter:__init(thr,sz)
 78 |     self.sz = sz
 79 |     self.iou = torch.Tensor(sz)
 80 |     self.thr = math.log(thr/(1-thr))
 81 |     self:reset()
 82 |   end
 83 | 
 84 |   -- function: reset
 85 |   function IouMeter:reset()
 86 |     self.iou:zero(); self.n = 0
 87 |   end
 88 | 
 89 |   -- function: add
 90 |   function IouMeter:add(output, target)
 91 |     target, output = target:squeeze():float(), output:squeeze():float()
 92 |     assert(output:nElement() == target:nElement(),
 93 |       'target and output do not match')
 94 | 
 95 |     local batch,h,w = output:size(1),output:size(2),output:size(3)
 96 |     local nOuts = h*w
 97 |     local iouptr = self.iou:data()
 98 | 
 99 |     local int,uni
100 |     local pred = output:ge(self.thr)
101 |     local pPtr,tPtr = pred:data(), target:data()
102 |     for b = 0,batch-1 do
103 |       int,uni = 0,0
104 |       for i = 0,nOuts-1 do
105 |         local id = b*nOuts+i
106 |         if pPtr[id] == 1 and tPtr[id] == 1 then int = int + 1 end
107 |         if pPtr[id] == 1 or tPtr[id] == 1 then uni = uni + 1 end
108 |       end
109 |       if uni > 0 then iouptr[self.n+b] = int/uni end
110 |     end
111 |     self.n = self.n + batch
112 |   end
113 | 
114 |   -- function: value
115 |   function IouMeter:value(s)
116 |     if s then
117 |       local res
118 |       local nb = math.max(self.iou:ne(0):sum(),1)
119 |       local iou = self.iou:narrow(1,1,nb)
120 |       if s == 'mean' then
121 |         res = iou:mean()
122 |       elseif s == 'median' then
123 |         res = iou:median():squeeze()
124 |       elseif tonumber(s) then
125 |         local iouSort, _ = iou:sort()
126 |         res = iouSort:ge(tonumber(s)):sum()/nb
127 |       elseif s == 'hist' then
128 |         res = torch.histc(iou,20)/nb
129 |       end
130 | 
131 |       return res*100
132 |     else
133 |       local value = {}
134 |       for _,s in ipairs(self.stats) do
135 |         value[s] = self:value(s)
136 |       end
137 |       return value
138 |     end
139 |   end
140 | end
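
The iou meter binarizes raw mask scores rather than probabilities: the
probability threshold thr passed to the constructor is converted to the
equivalent score (logit) threshold log(thr/(1-thr)), so thr = 0.5 means
thresholding the scores at 0. A small illustration of add/value (assumes
trainMeters.lua has already been loaded; shapes and values are made up):

  local meter = IouMeter(0.5, 2)               -- room for 2 masks, score thr 0
  local pred = torch.Tensor(2, 2, 2):fill(-1)  -- raw scores, all below threshold
  pred[1][1]:fill(1)                           -- mask 1: predict the top row
  local gt = torch.Tensor(2, 2, 2):zero()
  gt[1]:select(2, 1):fill(1)                   -- mask 1: truth is the left column
  meter:add(pred, gt)
  print(meter:value('mean'))                   -- intersection 1, union 3 -> ~33.3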
141 | 


--------------------------------------------------------------------------------