├── .gitignore ├── CNAME ├── CONTRIBUTING.md ├── DataLoader.lua ├── DataSampler.lua ├── DeepMask.lua ├── InferDeepMask.lua ├── InferSharpMask.lua ├── LICENSE ├── PATENTS ├── README.md ├── SharpMask.lua ├── SpatialSymmetricPadding.lua ├── TrainerDeepMask.lua ├── TrainerSharpMask.lua ├── computeProposals.lua ├── data ├── teaser.png └── testImage.jpg ├── demo ├── control.js ├── cropperjs │ ├── cropper.css │ ├── cropper.js │ ├── cropper.min.css │ └── cropper.min.js ├── images │ ├── bot.jpg │ ├── crop-button.png │ ├── edit-logo.gif │ ├── img1.jpg │ ├── img2.jpg │ ├── img3.jpg │ ├── img4.jpg │ ├── img5.jpg │ ├── in.jpg │ ├── logo.jpg │ ├── mid.jpg │ ├── out.jpg │ └── out.png └── style.css ├── evalPerImage.lua ├── evalPerPatch.lua ├── index.html ├── modelUtils.lua ├── train.lua └── trainMeters.lua /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | pretrained/ 3 | data/ 4 | exps/ 5 | -------------------------------------------------------------------------------- /CNAME: -------------------------------------------------------------------------------- 1 | www.objectcropbot.com 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to deepmask 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `master`. 10 | 2. If you haven't already, complete the Contributor License Agreement ("CLA"). 11 | 12 | ## Contributor License Agreement ("CLA") 13 | In order to accept your pull request, we need you to submit a CLA. You only need 14 | to do this once to work on any of Facebook's open source projects. 15 | 16 | Complete your CLA here: <https://code.facebook.com/cla> 17 | 18 | ## Issues 19 | We use GitHub issues to track public bugs.
Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## Coding Style
* 2 spaces for indentation rather than tabs
* 80 character line length

## License
By contributing to deepmask, you agree that your contributions will be licensed
under its [BSD license](https://github.com/facebookresearch/deepmask/blob/master/LICENSE).

--------------------------------------------------------------------------------
/DataLoader.lua:
--------------------------------------------------------------------------------
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.

Multi-threaded data loader
------------------------------------------------------------------------------]]

local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local M = {}
local DataLoader = torch.class('DataLoader', M)

--------------------------------------------------------------------------------
-- function: create the train and val loaders
-- Returns two DataLoader instances, in the order (train, val).
function DataLoader.create(config)
  local trainLoader = M.DataLoader(config, 'train')
  local valLoader = M.DataLoader(config, 'val')
  return trainLoader, valLoader
end

--------------------------------------------------------------------------------
-- function: init
-- Starts config.nthreads worker threads. Each worker seeds torch with
-- config.seed + <thread index>, loads DataSampler.lua and keeps its own
-- sampler in the worker global _G.ds. The epoch size (in samples) is taken
-- from the value reported by the first worker.
function DataLoader:__init(config, split)
  local function initWorker(threadIdx)
    torch.setdefaulttensortype('torch.FloatTensor')
    torch.manualSeed(config.seed + threadIdx)

    paths.dofile('DataSampler.lua')
    _G.ds = DataSampler(config, split)
    return _G.ds:size()
  end

  local pool, workerSizes = Threads(config.nthreads, initWorker)
  self.threads = pool
  self.__size = workerSizes[1][1]
  self.batch, self.hfreq = config.batch, config.hfreq
end

--------------------------------------------------------------------------------
-- function: number of batches per epoch (the last batch may be partial)
function DataLoader:size()
  local nSamples = self.__size
  return math.ceil(nSamples / self.batch)
end

--------------------------------------------------------------------------------
-- function: run
-- Returns an iterator yielding (batchCount, sample), where sample is
-- {inputs = FloatTensor, labels = FloatTensor, head = 1|2}. For each batch,
-- head is 1 (mask sampling) when torch.uniform() > hfreq, otherwise 2
-- (score sampling).
function DataLoader:run()
  local pool = self.threads
  local nSamples, batchSize = self.__size, self.batch

  local nextIdx, latest = 1, nil

  -- queue batch jobs while the pool accepts work and the epoch is not done
  local function fillQueue()
    while nextIdx <= nSamples and pool:acceptsjob() do
      local thisBatch = math.min(batchSize, nSamples - nextIdx + 1)
      pool:addjob(
        -- runs in a worker thread; uses only its arguments and worker globals
        function(nb, headFreq)
          local head = torch.uniform() > headFreq and 1 or 2
          local inputs, labels
          for k = 1, nb do
            local input, label = _G.ds:get(head)
            if inputs == nil then
              local iSz = input:size():totable()
              local mSz = label:size():totable()
              inputs = torch.FloatTensor(nb, table.unpack(iSz))
              labels = torch.FloatTensor(nb, table.unpack(mSz))
            end
            inputs[k]:copy(input)
            labels[k]:copy(label)
          end
          collectgarbage()

          return {inputs = inputs, labels = labels, head = head}
        end,
        -- runs in the main thread once the job is done: keep its result
        function(result) latest = result end,
        thisBatch, self.hfreq
      )
      nextIdx = nextIdx + batchSize
    end
  end

  local nBatches = 0
  return function()
    fillQueue()
    if not pool:hasjob() then return nil end
    pool:dojob()
    if pool:haserror() then pool:synchronize() end
    fillQueue()
    nBatches = nBatches + 1
    return nBatches, latest
  end
end

return M.DataLoader

--------------------------------------------------------------------------------
/DataSampler.lua:
--------------------------------------------------------------------------------
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.

Dataset sampler for training/evaluation of DeepMask and SharpMask
------------------------------------------------------------------------------]]

require 'torch'
require 'image'
local tds = require 'tds'
local coco = require 'coco'

local DataSampler = torch.class('DataSampler')

--------------------------------------------------------------------------------
-- function: init
-- Reads from config: datadir, iSz, gSz, scale, shift, batch, maxload (train),
-- testmaxload (val) and hfreq. The score-sampling structures (self.scales,
-- self.bbStruct) are only built when config.hfreq > 0.
function DataSampler:__init(config,split)
  assert(split == 'train' or split == 'val')

  -- coco api
  local annFile = string.format('%s/annotations/instances_%s2014.json',
    config.datadir,split)
  self.coco = coco.CocoApi(annFile)

  -- mask api
  self.maskApi = coco.MaskApi

  -- mean/std computed from random subset of ImageNet training images
  self.mean, self.std = {0.485, 0.456, 0.406}, {0.229, 0.224, 0.225}

  -- class members
  self.datadir = config.datadir
  self.split = split

  self.iSz = config.iSz
  self.objSz = math.ceil(config.iSz*128/224)
  self.wSz = config.iSz + 32
  self.gSz = config.gSz
  self.scale = config.scale
  self.shift = config.shift

  self.imgIds = self.coco:getImgIds()
  self.annIds = self.coco:getAnnIds()
  self.catIds = self.coco:getCatIds()
  self.nImages = self.imgIds:size(1)

  if split == 'train' then self.__size = config.maxload*config.batch
  elseif split == 'val' then self.__size = config.testmaxload*config.batch end

  if config.hfreq > 0 then
    self.scales = {} -- scale range for score sampling
    for scale = -3,2,.25 do table.insert(self.scales,scale) end
    self:createBBstruct(self.objSz,config.scale)
  end

  collectgarbage()
end
local function log2(x) return math.log(x)/math.log(2) end

--------------------------------------------------------------------------------
-- function: create BB struct of objects for score sampling
-- Entry k holds, for image k, a list `scales` of the (log2) scales at which
-- it has objects and, for each such scale ss, the list of object centres
-- {x, y, category_id=...} expressed in the image rescaled by 2^ss.
function DataSampler:createBBstruct(objSz,scale)
  local bbStruct = tds.Vec()

  for i = 1, self.nImages do
    local annIds = self.coco:getAnnIds({imgId=self.imgIds[i]})
    local bbs = {scales = {}}
    if annIds:dim() ~= 0 then
      for j = 1,annIds:size(1) do
        local ann = self.coco:loadAnns(annIds[j])[1]
        local bbGt = ann.bbox
        local x0,y0,w,h = bbGt[1],bbGt[2],bbGt[3],bbGt[4]
        local xc,yc, maxDim = x0+w/2,y0+h/2, math.max(w,h)

        -- find the scale step s at which the object's largest side is
        -- about objSz pixels
        for s = -32,32,1 do
          if maxDim > objSz*2^((s-1)*scale) and
            maxDim <= objSz*2^((s+1)*(scale)) then
            local ss = -s*scale
            local xcS,ycS = xc*2^ss,yc*2^ss
            if not bbs[ss] then
              bbs[ss] = {}; table.insert(bbs.scales,ss)
            end
            -- COCO annotations store their class in `category_id`
            table.insert(bbs[ss],{xcS,ycS,category_id=ann.category_id})
            break
          end
        end
      end
    end
    bbStruct:insert(tds.Hash(bbs))
  end
  collectgarbage()
  self.bbStruct = bbStruct
end

--------------------------------------------------------------------------------
-- function: get size of epoch
function DataSampler:size()
  return self.__size
end

--------------------------------------------------------------------------------
-- function: get a sample
-- headSampling == 1 draws a (patch, mask) pair; any other value draws a
-- (patch, +1/-1 score) pair. The patch (and the mask) is flipped
-- horizontally with probability 1/2, then normalized with self.mean/std.
function DataSampler:get(headSampling)
  local input,label
  if headSampling == 1 then -- sample masks
    input, label = self:maskSampling()
  else -- sample score
    input,label = self:scoreSampling()
  end

  if torch.uniform() > .5 then
    input = image.hflip(input)
    if headSampling == 1 then label = image.hflip(label) end
  end

  -- normalize input
  for i=1,3 do input:narrow(1,i,1):add(-self.mean[i]):div(self.std[i]) end

  return input,label
end

--------------------------------------------------------------------------------
-- function: mask sampling
-- Picks a category uniformly, then rejection-samples one of its annotations
-- until it is a non-crowd instance with area >= 100 and both sides >= 5.
-- Returns a wSz x wSz patch around a jittered box and a gSz x gSz mask in
-- {-1, 1} covering the central iSz x iSz region of that patch.
function DataSampler:maskSampling()
  local iSz,wSz,gSz = self.iSz,self.wSz,self.gSz

  local nCats = torch.isTensor(self.catIds) and self.catIds:size(1)
    or #self.catIds
  local catId = self.catIds[torch.random(nCats)]
  -- the category is fixed for this sample, so query its annotations once
  local annIds = self.coco:getAnnIds({catId=catId})
  local ann
  while not ann or ann.iscrowd == 1 or ann.area < 100 or ann.bbox[3] < 5
    or ann.bbox[4] < 5 do
    local annid = annIds[torch.random(annIds:size(1))]
    ann = self.coco:loadAnns(annid)[1]
  end
  local bbox = self:jitterBox(ann.bbox)
  local imgName = self.coco:loadImgs(ann.image_id)[1].file_name

  -- input
  local pathImg = string.format('%s/%s2014/%s',self.datadir,self.split,imgName)
  local inp = image.load(pathImg,3)
  local h, w = inp:size(2), inp:size(3)
  inp = self:cropTensor(inp, bbox, 0.5)
  inp = image.scale(inp, wSz, wSz)

  -- label: square of side iSz (in patch pixels) centred on the jittered box
  local iSzR = iSz*(bbox[3]/wSz)
  local xc, yc = bbox[1]+bbox[3]/2, bbox[2]+bbox[4]/2
  local bboxInpSz = {xc-iSzR/2,yc-iSzR/2,iSzR,iSzR}
  local lbl = self:cropMask(ann, bboxInpSz, h, w, gSz)
  lbl:mul(2):add(-1)

  return inp, lbl
end

--------------------------------------------------------------------------------
-- function: score head sampler
-- Picks a random image that has at least one object scale, then returns a
-- wSz x wSz patch centred on an object (label +1) or on a location away from
-- objects (label -1), each with probability 1/2. The `cat` and `imgId`
-- arguments are not used; they are kept for interface compatibility.
local imgPad = torch.Tensor()
function DataSampler:scoreSampling(cat,imgId)
  local idx,bb
  repeat
    idx = torch.random(1,self.nImages)
    bb = self.bbStruct[idx]
  until #bb.scales ~= 0

  local id = self.imgIds[idx]
  local imgName = self.coco:loadImgs(id)[1].file_name
  local pathImg = string.format('%s/%s2014/%s',self.datadir,self.split,imgName)
  local img = image.load(pathImg,3)
  local h,w = img:size(2),img:size(3)

  -- sample central pixel of BB to be used
  local x,y,scale
  local lbl = torch.Tensor(1)
  if torch.uniform() > .5 then
    x,y,scale = self:posSamplingBB(bb)
    lbl:fill(1)
  else
    x,y,scale = self:negSamplingBB(bb,w,h)
    lbl:fill(-1)
  end

  -- map the centre back to original image coordinates
  local s = 2^-scale
  x,y = math.min(math.max(x*s,1),w), math.min(math.max(y*s,1),h)
  local isz = math.max(self.wSz*s,10)
  local bw =isz/2

  --pad/crop/rescale
  imgPad:resize(3,h+2*bw,w+2*bw):fill(.5)
  imgPad:narrow(2,bw+1,h):narrow(3,bw+1,w):copy(img)
  local inp = imgPad:narrow(2,y,isz):narrow(3,x,isz)
  inp = image.scale(inp,self.wSz,self.wSz)

  return inp,lbl
end

--------------------------------------------------------------------------------
-- function: crop bbox `box` = {x, y, w, h} (0-indexed, COCO style) from inp
-- inp is either C x H x W or H x W. Parts of the box outside inp are filled
-- with `pad` (default 0). The caller's box table is NOT modified.
function DataSampler:cropTensor(inp, box, pad)
  pad = pad or 0
  -- work on a rounded, 1-indexed copy so the caller can keep using `box`
  local b = {torch.round(box[1])+1, torch.round(box[2])+1,
    torch.round(box[3]), torch.round(box[4])}

  -- out is (height x width) = (b[4] x b[3])
  local out, ind
  if inp:dim() == 3 then
    ind, out = 2, torch.Tensor(inp:size(1), b[4], b[3]):fill(pad)
  elseif inp:dim() == 2 then
    ind, out = 1, torch.Tensor(b[4], b[3]):fill(pad)
  else
    error('cropTensor: expected a 2D or 3D tensor')
  end
  local h, w = inp:size(ind), inp:size(ind+1)

  local xo1,yo1,xo2,yo2 = b[1],b[2],b[3]+b[1]-1,b[4]+b[2]-1
  local xc1,yc1,xc2,yc2 = 1,1,b[3],b[4]

  -- compute box on binary mask inp and cropped mask out
  if b[1] < 1 then xo1=1; xc1=1+(1-b[1]) end
  if b[2] < 1 then yo1=1; yc1=1+(1-b[2]) end
  if b[1]+b[3]-1 > w then xo2=w; xc2=xc2-(b[1]+b[3]-1-w) end
  if b[2]+b[4]-1 > h then yo2=h; yc2=yc2-(b[2]+b[4]-1-h) end
  local xo, yo, wo, ho = xo1, yo1, xo2-xo1+1, yo2-yo1+1
  local xc, yc, wc, hc = xc1, yc1, xc2-xc1+1, yc2-yc1+1
  if yc+hc-1 > out:size(ind) then hc = out:size(ind )-yc+1 end
  if xc+wc-1 > out:size(ind+1) then wc = out:size(ind+1)-xc+1 end
  if yo+ho-1 > inp:size(ind) then ho = inp:size(ind )-yo+1 end
  if xo+wo-1 > inp:size(ind+1) then wo = inp:size(ind+1)-xo+1 end
  out:narrow(ind,yc,hc):narrow(ind+1,xc,wc):copy(
    inp:narrow(ind,yo,ho):narrow(ind+1,xo,wo))

  return out
end

--------------------------------------------------------------------------------
-- function: crop bbox from mask
-- Rasterizes ann's polygons at the scale that maps bbox[3] to sz pixels,
-- crops bbox and returns a binary sz x sz FloatTensor.
function DataSampler:cropMask(ann, bbox, h, w, sz)
  local mask = torch.FloatTensor(sz,sz)
  local seg = ann.segmentation

  local scale = sz / bbox[3]
  local polS = {}
  for m, segm in pairs(seg) do
    polS[m] = torch.DoubleTensor():resizeAs(segm):copy(segm); polS[m]:mul(scale)
  end
  local bboxS = {}
  for m = 1,#bbox do bboxS[m] = bbox[m]*scale end

  local Rs = self.maskApi.frPoly(polS, h*scale, w*scale)
  local mo = self.maskApi.decode(Rs)
  local mc = self:cropTensor(mo, bboxS)
  mask:copy(image.scale(mc,sz,sz):gt(0.5))

  return mask
end

--------------------------------------------------------------------------------
-- function: jitter bbox
-- Returns a square box of side wSz*2^s around the (shifted) box centre, where
-- s is the object's log2 size relative to objSz plus uniform jitter in
-- [-self.scale, self.scale].
function DataSampler:jitterBox(box)
  local x, y, w, h = box[1], box[2], box[3], box[4]
  local xc, yc = x+w/2, y+h/2
  local maxDim = math.max(w,h)
  local scale = log2(maxDim/self.objSz)
  local s = scale + torch.uniform(-self.scale,self.scale)
  xc = xc + torch.uniform(-self.shift,self.shift)*2^s
  yc = yc + torch.uniform(-self.shift,self.shift)*2^s
  w, h = self.wSz*2^s, self.wSz*2^s
  return {xc-w/2, yc-h/2,w,h}
end

--------------------------------------------------------------------------------
--function: posSampling: do positive sampling
-- Uses torch's RNG (seeded per worker) rather than math.random.
function DataSampler:posSamplingBB(bb)
  local scale = bb.scales[torch.random(1,#bb.scales)]
  local r = torch.random(1,#bb[scale])
  local x,y = bb[scale][r][1], bb[scale][r][2]
  return x,y,scale
end

--------------------------------------------------------------------------------
--function: negSampling: do negative sampling
-- Draws up to 100 random (scale, position) pairs and stops at the first one
-- that is at least 3*self.shift away from every object centre at nearby
-- scales. Uses torch's RNG (seeded per worker) rather than math.random.
function DataSampler:negSamplingBB(bb,w0,h0)
  local x,y,scale
  local negSample,nTries = false,0
  while not negSample and nTries < 100 do
    scale = self.scales[torch.random(1,#self.scales)]
    local f = 2^scale
    -- integer upper bounds, at least 1 for very small rescaled images
    x = torch.random(1, math.max(1, math.floor(w0*f)))
    y = torch.random(1, math.max(1, math.floor(h0*f)))
    negSample = true
    for k = -10,10 do
      local centres = bb[scale+k*self.scale]
      if centres then
        for _,c in pairs(centres) do
          local dist = math.sqrt((x-c[1])^2 + (y-c[2])^2)
          if dist < 3*self.shift then
            negSample = false
            break
          end
        end
      end
      if not negSample then break end
    end
    nTries = nTries+1
  end
  return x,y,scale
end

return DataSampler

--------------------------------------------------------------------------------
/DeepMask.lua:
--------------------------------------------------------------------------------
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
An additional grant
of patent rights can be found in the PATENTS file in the same directory.

When initialized, it creates/loads the common trunk, the maskBranch and the
scoreBranch.
DeepMask class members:
  - trunk: the common trunk (modified pre-trained ResNet-18, loaded from
    pretrained/resnet-18.t7)
  - maskBranch: the mask head architecture
  - scoreBranch: the score head architecture
------------------------------------------------------------------------------]]

require 'nn'
require 'nnx'
require 'cunn'
require 'cudnn'
paths.dofile('SpatialSymmetricPadding.lua')
local utils = paths.dofile('modelUtils.lua')

local DeepMask,_ = torch.class('nn.DeepMask','nn.Container')

--------------------------------------------------------------------------------
-- helper: total number of scalar parameters held by a module
local function countParameters(module)
  local n = 0
  for _,p in ipairs(module:parameters() or {}) do n = n + p:nElement() end
  return n
end

--------------------------------------------------------------------------------
-- function: constructor
function DeepMask:__init(config)
  -- create common trunk
  self:createTrunk(config)

  -- create mask head
  self:createMaskBranch(config)

  -- create score head
  self:createScoreBranch(config)

  -- number of parameters
  local npt = countParameters(self.trunk)
  local npm = countParameters(self.maskBranch)
  local nps = countParameters(self.scoreBranch)
  print(string.format('| number of parameters trunk: %d', npt))
  print(string.format('| number of parameters mask branch: %d', npm))
  print(string.format('| number of parameters score branch: %d', nps))
  print(string.format('| number of parameters total: %d', npt+nps+npm))
end

--------------------------------------------------------------------------------
-- function: create common trunk
-- Pre-trained ResNet-18 with its last four modules removed, followed by a
-- 1x1 conv (256 -> 128 channels), ReLU and a linear layer producing a
-- 512-d feature per sample. Feature maps at the end are iSz/16 wide.
function DeepMask:createTrunk(config)
  -- size of feature maps at end of trunk
  self.fSz = config.iSz/16

  -- load trunk
  local trunk = torch.load('pretrained/resnet-18.t7')

  -- remove BN
  utils.BNtoFixed(trunk, true)

  -- drop the last four modules (the classifier head and, judging by the
  -- 256-channel input of the conv below, the final residual stage)
  for _ = 1,4 do trunk:remove() end

  -- crop central pad
  trunk:add(nn.SpatialZeroPadding(-1,-1,-1,-1))

  -- add common extra layers
  trunk:add(cudnn.SpatialConvolution(256,128,1,1,1,1))
  trunk:add(cudnn.ReLU())
  trunk:add(nn.View(config.batch,128*self.fSz*self.fSz))
  trunk:add(nn.Linear(128*self.fSz*self.fSz,512))

  -- from scratch? reset the parameters
  if config.scratch then
    for _,m in pairs(trunk.modules) do if m.weight then m:reset() end end
  end

  -- symmetricPadding
  utils.updatePadding(trunk, nn.SpatialSymmetricPadding)

  self.trunk = trunk:cuda()
  return trunk
end

--------------------------------------------------------------------------------
-- function: create mask branch
-- Linear 512 -> oSz*oSz; when gSz > oSz, a bilinear upsampling stage to
-- gSz*gSz is appended (computed on the CPU via nn.Copy round-trips).
function DeepMask:createMaskBranch(config)
  local maskBranch = nn.Sequential()

  -- maskBranch
  maskBranch:add(nn.Linear(512,config.oSz*config.oSz))
  self.maskBranch = nn.Sequential():add(maskBranch:cuda())

  -- upsampling layer
  if config.gSz > config.oSz then
    local upSample = nn.Sequential()
    upSample:add(nn.Copy('torch.CudaTensor','torch.FloatTensor'))
    upSample:add(nn.View(config.batch,config.oSz,config.oSz))
    upSample:add(nn.SpatialReSamplingEx{owidth=config.gSz,oheight=config.gSz,
      mode='bilinear'})
    upSample:add(nn.View(config.batch,config.gSz*config.gSz))
    upSample:add(nn.Copy('torch.FloatTensor','torch.CudaTensor'))
    self.maskBranch:add(upSample)
  end

  return self.maskBranch
end

--------------------------------------------------------------------------------
-- function: create score branch
-- Dropout -> Linear 512->1024 -> Threshold -> Dropout -> Linear 1024->1.
function DeepMask:createScoreBranch(config)
  local scoreBranch = nn.Sequential()
  scoreBranch:add(nn.Dropout(.5))
  scoreBranch:add(nn.Linear(512,1024))
  -- ReLU-like: values <= 0 are replaced by 1e-6
  scoreBranch:add(nn.Threshold(0, 1e-6))

  scoreBranch:add(nn.Dropout(.5))
  scoreBranch:add(nn.Linear(1024,1))

  self.scoreBranch = scoreBranch:cuda()
  return self.scoreBranch
end

--------------------------------------------------------------------------------
-- function: training
function DeepMask:training()
  self.trunk:training(); self.maskBranch:training(); self.scoreBranch:training()
end

--------------------------------------------------------------------------------
-- function: evaluate
function DeepMask:evaluate()
  self.trunk:evaluate(); self.maskBranch:evaluate(); self.scoreBranch:evaluate()
end

--------------------------------------------------------------------------------
-- function: to cuda (returns self, like nn.Module:cuda)
function DeepMask:cuda()
  self.trunk:cuda(); self.scoreBranch:cuda(); self.maskBranch:cuda()
  return self
end

--------------------------------------------------------------------------------
-- function: to float (returns self, like nn.Module:float)
function DeepMask:float()
  self.trunk:float(); self.scoreBranch:float(); self.maskBranch:float()
  return self
end

--------------------------------------------------------------------------------
-- function: inference (used for full scene inference)
-- Converts the linear layers into convolutions so the network can be run
-- densely over a whole image, and drops the mask upsampling stage.
function DeepMask:inference()
  self.trunk:evaluate()
  self.maskBranch:evaluate()
  self.scoreBranch:evaluate()

  utils.linear2convTrunk(self.trunk,self.fSz)
  utils.linear2convHead(self.scoreBranch)
  utils.linear2convHead(self.maskBranch.modules[1])
  -- keep only the (now convolutional) mask head, without the upsampling
  self.maskBranch = self.maskBranch.modules[1]

  self:cuda()
end

--------------------------------------------------------------------------------
-- function: clone
-- Deep copy through a binary MemoryFile; extra arguments are forwarded to
-- :share() on each sub-network (e.g. 'weight','bias').
function DeepMask:clone(...)
  local f = torch.MemoryFile("rw"):binary()
  f:writeObject(self); f:seek(1)
  local clone = f:readObject(); f:close()

  if select('#',...) > 0 then
    clone.trunk:share(self.trunk,...)
    clone.maskBranch:share(self.maskBranch,...)
    clone.scoreBranch:share(self.scoreBranch,...)
  end

  return clone
end

return nn.DeepMask

--------------------------------------------------------------------------------
/InferDeepMask.lua:
--------------------------------------------------------------------------------
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.

Inference module for DeepMask
------------------------------------------------------------------------------]]

require 'image'
local argcheck = require 'argcheck'

local Infer = torch.class('Infer')

--------------------------------------------------------------------------------
-- function: unfold the mask output into a matrix of masks
-- Each entry of masks is reshaped from (oSz*oSz, H, W) and permuted to
-- (H, W, oSz, oSz), so umask[x][y] is the oSz x oSz mask associated with
-- spatial position (x, y).
local function unfoldMasksMatrix(masks)
  local umasks = {}
  local oSz = math.sqrt(masks[1]:size(1))
  for _,mask in ipairs(masks) do
    local umask = mask:reshape(oSz,oSz,mask:size(2),mask:size(3))
    umask=umask:transpose(1,3):transpose(2,3):transpose(2,4):transpose(3,4)
    table.insert(umasks,umask)
  end
  return umasks
end

--------------------------------------------------------------------------------
-- function: init
Infer.__init = argcheck{
  noordered = true,
  {name="self", type="Infer"},
  {name="np", type="number",default=500},
  {name="scales", type="table"},
  {name="meanstd", type="table"},
  {name="model", type="nn.Container"},
  {name="iSz", type="number", default=160},
  {name="dm", type="boolean", default=true},
  {name="timer", type="boolean", default=false},
  call =
    function(self, np, scales, meanstd, model, iSz, dm, timer)
      --model
      self.trunk = model.trunk
      self.mHead = model.maskBranch
      self.sHead = model.scoreBranch

      -- number of proposals
      self.np = np

      --mean/std
      self.mean, self.std = meanstd.mean, meanstd.std

      -- input size and border width
      self.iSz, self.bw = iSz, iSz/2

      -- timer: 5 stage accumulators + 1 call counter
      if timer then self.timer = torch.Tensor(6):zero() end

      -- create scale pyramid
      self.scales = scales
      self.pyramid = nn.ConcatTable()
      for i = 1,#scales do
        self.pyramid:add(nn.SpatialReSamplingEx{rwidth=scales[i],
          rheight=scales[i], mode='bilinear'})
      end

      -- allocate topScores and topMasks
      self.topScores = torch.Tensor()
      self.topMasks = torch.ByteTensor()
    end
}

--------------------------------------------------------------------------------
-- function: forward
-- Runs every pyramid level through trunk, score head and mask head. Results
-- are stored (in the same order as self.scales) in self.score and self.mask.
local inpPad = torch.CudaTensor()
function Infer:forward(input)
  if input:type() == 'torch.CudaTensor' then input = input:float() end

  -- forward pyramid
  if self.timer then sys.tic() end
  local inpPyramid = self.pyramid:forward(input)
  if self.timer then self.timer:narrow(1,1,1):add(sys.toc()) end

  -- forward all scales through network; ipairs keeps the scale order
  local outPyramidMask,outPyramidScore = {},{}
  for _,scaledInput in ipairs(inpPyramid) do
    local inp = scaledInput:cuda()
    local h,w = inp:size(2),inp:size(3)

    -- padding/normalize
    if self.timer then sys.tic() end
    inpPad:resize(1,3,h+2*self.bw,w+2*self.bw):fill(.5)
    inpPad:narrow(1,1,1):narrow(3,self.bw+1,h):narrow(4,self.bw+1,w):copy(inp)
    for c=1,3 do inpPad[1][c]:add(-self.mean[c]):div(self.std[c]) end
    cutorch.synchronize()
    if self.timer then self.timer:narrow(1,2,1):add(sys.toc()) end

    -- forward trunk
    if self.timer then sys.tic() end
    local outTrunk = self.trunk:forward(inpPad):squeeze()
    cutorch.synchronize()
    if self.timer then self.timer:narrow(1,3,1):add(sys.toc()) end

    -- forward score branch
    if self.timer then sys.tic() end
    local outScore = self.sHead:forward(outTrunk)
    cutorch.synchronize()
    if self.timer then self.timer:narrow(1,4,1):add(sys.toc()) end
    table.insert(outPyramidScore,outScore:clone():squeeze())

    -- forward mask branch
    if self.timer then sys.tic() end
    local outMask = self.mHead:forward(outTrunk)
    cutorch.synchronize()
    if self.timer then self.timer:narrow(1,5,1):add(sys.toc()) end
    table.insert(outPyramidMask,outMask:float():squeeze())
  end

  self.mask = unfoldMasksMatrix(outPyramidMask)
  self.score = outPyramidScore

  if self.timer then self.timer:narrow(1,6,1):add(1) end
end

--------------------------------------------------------------------------------
-- function: get top scores
-- return a tensor k x 4, where k is the number of top scores.
-- each row contains: the score value, the scale index and the (x, y)
-- position in that scale's score map (x = row, y = column, both 1-based).
-- Rows beyond the number of available positions keep score 0.
local sortedScores = torch.Tensor()
local sortedIds = torch.Tensor()
local pos = torch.Tensor()
function Infer:getTopScores()
  local topScores = self.topScores

  -- sort scores/ids for each scale
  local nScales = #self.scales
  -- rowN: number of positions in the largest score map; the sorted columns
  -- of smaller maps are zero-padded up to this length
  local rowN = 0
  for s = 1,nScales do
    rowN = math.max(rowN, self.score[s]:size(1)*self.score[s]:size(2))
  end
  sortedScores:resize(rowN,nScales):zero()
  sortedIds:resize(rowN,nScales):zero()
  for s = 1,nScales do
    self.score[s]:mul(-1):exp():add(1):pow(-1) -- scores2prob (in place)

    local sc = self.score[s]
    local h,w = sc:size(1),sc:size(2)

    local sS,sIds = torch.sort(sc:view(h*w),true)
    local sz = sS:size(1)
    sortedScores:narrow(2,s,1):narrow(1,1,sz):copy(sS)
    sortedIds:narrow(2,s,1):narrow(1,1,sz):copy(sIds)
  end

  -- get top scores: merge the per-scale sorted columns, best first
  local np = self.np
  pos:resize(nScales):fill(1)
  topScores:resize(np,4):fill(1)
  topScores:select(2,1):zero() -- unused rows must not look like top hits
  np = math.min(np,rowN)

  for i = 1,np do
    local scale,score = 0,0
    for k = 1,nScales do
      if sortedScores[pos[k]][k] > score then
        score = sortedScores[pos[k]][k]
        scale = k
      end
    end
    -- sortedIds are 1-based indices into the row-major h*w view
    local id = sortedIds[pos[scale]][scale]
    local w = self.score[scale]:size(2)
    local x = math.floor((id-1)/w) + 1
    local y = (id-1)%w + 1

    pos[scale] = pos[scale]+1
    topScores:narrow(1,i,1):copy(torch.Tensor({score,scale,x,y}))
  end

  return topScores
end

--------------------------------------------------------------------------------
-- function: get top masks.
local imgMask = torch.ByteTensor()
-- Pastes each of the self.np proposals listed in self.topScores (see
-- getTopScores) into its own h x w binary mask. thr is a probability
-- threshold in (0, 1).
function Infer:getTopMasks(thr,h,w)
  local topMasks = self.topMasks

  -- masks hold raw scores s: 1/(1+e^-s) > thr  <=>  s > log(thr/(1-thr))
  thr = math.log(thr/(1-thr))

  local masks,topScores,np = self.mask,self.topScores,self.np
  topMasks:resize(np,h,w):zero()
  imgMask:resize(h,w)
  local imgMaskPtr = imgMask:data()

  for i = 1,np do
    imgMask:zero()
    local scale,x,y = topScores[i][2], topScores[i][3], topScores[i][4]
    local s = self.scales[scale]
    local sz = math.floor(self.iSz/s)
    local scaleMasks = masks[scale]
    x,y = math.min(x,scaleMasks:size(1)),math.min(y,scaleMasks:size(2))
    local mask = image.scale(scaleMasks[x][y]:float(),sz,sz,'bilinear')
    local maskPtr = mask:data()

    -- t: step between neighbouring positions in image pixels (presumably
    -- the trunk's stride of 16, mapped back to this scale); delta: half the
    -- mask window
    local t = 16/s
    local delta = self.iSz/2/s
    for im = 0,sz-1 do
      local ii = math.floor((x-1)*t-delta+im)
      for jm = 0,sz-1 do
        local jj = math.floor((y-1)*t-delta+jm)
        if maskPtr[sz*im + jm] > thr and
          ii >= 0 and ii <= h-1 and jj >= 0 and jj <= w-1 then
          imgMaskPtr[jj + w*ii] = 1
        end
      end
    end
    topMasks:narrow(1,i,1):copy(imgMask)
  end

  return topMasks
end

--------------------------------------------------------------------------------
-- function: get top proposals
-- Returns (topMasks, topScores) for the last forward() call.
function Infer:getTopProps(thr,h,w)
  self:getTopScores()
  self:getTopMasks(thr,h,w)
  return self.topMasks, self.topScores
end

--------------------------------------------------------------------------------
-- function: display timer
-- Prints per-stage times averaged over the number of forward() calls. It
-- works on a copy, so self.timer keeps accumulating raw totals and repeated
-- calls print consistent values.
function Infer:printTiming()
  assert(self.timer, 'printTiming: Infer was created with timer=false')
  local t = self.timer:clone()
  t:div(t[t:size(1)])

  print('| time pyramid:',t[1])
  print('| time pre-process:',t[2])
  print('| time trunk:',t[3])
  print('| time score branch:',t[4])
  print('| time mask branch:',t[5])
  print('| time total:',t:narrow(1,1,t:size(1)-1):sum())
end

| 243 | return Infer 244 | -------------------------------------------------------------------------------- /InferSharpMask.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | Inference module for SharpMask 8 | ------------------------------------------------------------------------------]] 9 | 10 | require 'image' 11 | local argcheck = require 'argcheck' 12 | 13 | local Infer = torch.class('Infer') 14 | 15 | -------------------------------------------------------------------------------- 16 | -- function: init 17 | Infer.__init = argcheck{ 18 | noordered = true, 19 | {name="self", type="Infer"}, 20 | {name="np", type="number",default=500}, 21 | {name="scales", type="table"}, 22 | {name="meanstd", type="table"}, 23 | {name="model", type="nn.Container"}, 24 | {name="iSz", type="number", default=160}, 25 | {name="dm", type="boolean", default=false}, 26 | {name="timer", type="boolean", default=false}, 27 | call = 28 | function(self, np, scales, meanstd, model, iSz, dm, timer) 29 | --model 30 | self.trunk = model.trunk 31 | self.mBranch = model.maskBranchDM 32 | self.sBranch = model.scoreBranch 33 | self.refs = model.refs 34 | self.neths = model.neths 35 | self.skpos = model.skpos 36 | self.fSz = model.fSz 37 | self.dm = dm -- flag to use deepmask instead of sharpmask 38 | 39 | -- number of proposals 40 | self.np = np 41 | 42 | --mean/std 43 | self.mean, self.std = meanstd.mean, meanstd.std 44 | 45 | -- input size and border width 46 | self.iSz, self.bw = iSz, iSz/2 47 | 48 | -- timer 49 | if timer then self.timer = torch.Tensor(8):zero() end 50 | 51
| -- create scale pyramid 52 | self.scales = scales 53 | self.pyramid = nn.ConcatTable() 54 | for i = 1,#scales do 55 | self.pyramid:add(nn.SpatialReSamplingEx{rwidth=scales[i], 56 | rheight=scales[i], mode='bilinear'}) 57 | end 58 | 59 | -- allocate topScores, topMasks and topPatches 60 | self.topScores, self.topMasks = torch.Tensor(), torch.ByteTensor() 61 | local topPatches 62 | if self.dm then 63 | topPatches = torch.CudaTensor(self.np,512):zero() 64 | else 65 | topPatches = {} 66 | topPatches[1] = torch.CudaTensor(self.np,512):zero() 67 | for j = 1, #model.refs do 68 | local sz = model.fSz*2^(j-1) 69 | topPatches[j+1] = torch.CudaTensor(self.np,model.ks/2^(j),sz,sz) 70 | end 71 | end 72 | self.topPatches = topPatches 73 | end 74 | } 75 | 76 | -------------------------------------------------------------------------------- 77 | -- function: forward 78 | local inpPad = torch.CudaTensor() 79 | function Infer:forward(input,id) -- runs pyramid, trunk, score branch and (unless self.dm) horizontal nets, then fills self.topScores and self.topMasks0; id is not used in this body 80 | if input:type() == 'torch.CudaTensor' then input = input:float() end 81 | 82 | -- forward pyramid 83 | if self.timer then sys.tic() end 84 | local inpPyramid = self.pyramid:forward(input) 85 | if self.timer then self.timer:narrow(1,1,1):add(sys.toc()) end 86 | 87 | -- forward all scales through network 88 | local outPyramidTrunk,outPyramidScore,outPyramidSkip = {},{},{} 89 | for i = 1,#inpPyramid do -- numeric loop: every per-scale result below is stored at index i, so scales must be visited in order 90 | local inp = inpPyramid[i]:cuda() 91 | local h,w = inp:size(2),inp:size(3) 92 | 93 | -- padding/normalize 94 | if self.timer then sys.tic() end 95 | inpPad:resize(1,3,h+2*self.bw,w+2*self.bw):fill(.5) 96 | inpPad:narrow(1,1,1):narrow(3,self.bw+1,h):narrow(4,self.bw+1,w):copy(inp) 97 | for c=1,3 do inpPad[1][c]:add(-self.mean[c]):div(self.std[c]) end -- per-channel normalization; c avoids shadowing the scale index i 98 | cutorch.synchronize() 99 | if self.timer then self.timer:narrow(1,2,1):add(sys.toc()) end 100 | 101 | -- forward trunk 102 | if self.timer then sys.tic() end 103 | local outTrunk = self.trunk:forward(inpPad) 104 | cutorch.synchronize() 105 | if self.timer then
self.timer:narrow(1,3,1):add(sys.toc()) end 106 | outPyramidTrunk[i] = outTrunk:clone():squeeze() 107 | 108 | -- forward score branch 109 | if self.timer then sys.tic() end 110 | local outScore = self.sBranch:forward(outTrunk) 111 | cutorch.synchronize() 112 | if self.timer then self.timer:narrow(1,4,1):add(sys.toc()) end 113 | outPyramidScore[i] = outScore:clone():squeeze() 114 | 115 | -- forward horizontal nets 116 | if not self.dm then 117 | local hOuts = {} 118 | for k,neth in pairs(self.neths) do 119 | if self.timer then sys.tic() end 120 | neth:forward(self.trunk.modules[self.skpos[k]].output) 121 | cutorch.synchronize() 122 | if self.timer then self.timer:narrow(1,5,1):add(sys.toc()) end 123 | hOuts[k] = neth.output:clone() 124 | end 125 | outPyramidSkip[i] = hOuts 126 | end 127 | end 128 | 129 | -- get top scores 130 | self:getTopScores(outPyramidScore) 131 | 132 | -- get top patches and top masks, depending on mode 133 | local topMasks0 134 | if self.dm then 135 | if self.timer then sys.tic() end 136 | self:getTopPatchesDM(outPyramidTrunk) 137 | if self.timer then self.timer:narrow(1,6,1):add(sys.toc()) end 138 | 139 | if self.timer then sys.tic() end 140 | topMasks0 = self.mBranch:forward(self.topPatches) 141 | local osz = math.sqrt(topMasks0:size(2)) 142 | topMasks0 = topMasks0:view(self.np,osz,osz) 143 | if self.timer then self.timer:narrow(1,7,1):add(sys.toc()) end 144 | else 145 | if self.timer then sys.tic() end 146 | self:getTopPatches(outPyramidTrunk,outPyramidSkip) 147 | if self.timer then self.timer:narrow(1,6,1):add(sys.toc()) end 148 | 149 | if self.timer then sys.tic() end 150 | topMasks0 = self:forwardRefinement(self.topPatches) 151 | if self.timer then self.timer:narrow(1,7,1):add(sys.toc()) end 152 | end 153 | self.topMasks0 = topMasks0:float():squeeze() 154 | 155 | collectgarbage() 156 | 157 | if self.timer then self.timer:narrow(1,8,1):add(1) end 158 | end 159 | 160 |
-------------------------------------------------------------------------------- 161 | -- function: forward refinement inference 162 | -- input is a table containing the output of bottom-up and the output of all 163 | -- horizontal layers 164 | function Infer:forwardRefinement(input) 165 | local currentOutput = self.refs[0]:forward(input[1]) 166 | for i = 1,#self.refs do 167 | currentOutput = self.refs[i]:forward({input[i+1],currentOutput}) 168 | end 169 | cutorch.synchronize() 170 | self.output = currentOutput 171 | return self.output 172 | end 173 | 174 | -------------------------------------------------------------------------------- 175 | -- function: get top patches 176 | function Infer:getTopPatchesDM(outPyramidTrunk) 177 | local topscores = self.topScores 178 | local ts_ptr = topscores:data() 179 | for i = 1, topscores:size(1) do 180 | local pos = (i-1)*4 181 | local s,x,y = ts_ptr[pos+1], ts_ptr[pos+2], ts_ptr[pos+3] 182 | local patch = outPyramidTrunk[s]:narrow(2,x,1):narrow(3,y,1) 183 | self.topPatches:narrow(1,i,1):copy(patch) 184 | end 185 | end 186 | 187 | -------------------------------------------------------------------------------- 188 | -- function: get top patches 189 | local t 190 | function Infer:getTopPatches(outPyramidTrunk,outPyramidSkip) 191 | local topscores = self.topScores 192 | local ts_ptr = topscores:data() 193 | 194 | if not t then t={}; for j = 1, #self.skpos do t[j]=2^(j-1) end end 195 | 196 | for i = 1, #self.topPatches do self.topPatches[i]:zero() end 197 | for i = 1, self.np do 198 | local pos = (i-1)*4 199 | local s,x,y = ts_ptr[pos+1], ts_ptr[pos+2], ts_ptr[pos+3] 200 | 201 | -- get patches from output outPyramidTrunk 202 | local patch = outPyramidTrunk[s]:narrow(2,x,1):narrow(3,y,1) 203 | self.topPatches[1]:narrow(1,i,1):copy(patch) 204 | 205 | for j = 1, #self.skpos do 206 | local isz =(self.fSz)*t[j] 207 | local xx,yy = (x-1)*t[j]+1 , (y-1)*t[j]+1 208 | local o = outPyramidSkip[s][j] 209 | local
dx=math.min(isz,o:size(3)-xx+1) 210 | local dy=math.min(isz,o:size(4)-yy+1) 211 | local patch = o:narrow(3,xx,dx):narrow(4,yy,dy) 212 | self.topPatches[j+1]:narrow(1,i,1):narrow(3,1,dx):narrow(4,1,dy) 213 | :copy(patch) 214 | end 215 | end 216 | cutorch.synchronize() 217 | collectgarbage() 218 | end 219 | 220 | -------------------------------------------------------------------------------- 221 | -- function: get top scores 222 | -- return a tensor k x 4, where k is the number of top scores. 223 | -- each row contains: the score value, the scale index and the 1-based (x,y) position in that scale's score map 224 | local sortedScores = torch.Tensor() 225 | local sortedIds = torch.Tensor() 226 | local pos = torch.Tensor() 227 | function Infer:getTopScores(outPyramidScore) 228 | local topScores = self.topScores 229 | 230 | self.score = outPyramidScore 231 | local np = self.np 232 | -- sort scores/ids for each scale 233 | local nScales=#self.score 234 | local rowN=self.score[nScales]:size(1)*self.score[nScales]:size(2) -- assumes the last scale has the largest score map; rows of smaller scales stay zero-filled 235 | sortedScores:resize(rowN,nScales):zero() 236 | sortedIds:resize(rowN,nScales):zero() 237 | for s = 1,nScales do 238 | self.score[s]:mul(-1):exp():add(1):pow(-1) -- scores2prob 239 | 240 | local sc = self.score[s] 241 | local h,w = sc:size(1),sc:size(2) 242 | 243 | local sc=sc:view(h*w) 244 | local sS,sIds=torch.sort(sc,true) 245 | local sz = sS:size(1) 246 | sortedScores:narrow(2,s,1):narrow(1,1,sz):copy(sS) 247 | sortedIds:narrow(2,s,1):narrow(1,1,sz):copy(sIds) 248 | end 249 | 250 | -- get top scores 251 | pos:resize(nScales):fill(1) 252 | topScores:resize(np,4):fill(1) 253 | np=math.min(np,rowN) 254 | 255 | for i = 1,np do 256 | local scale,score = 0,0 257 | for k = 1,nScales do 258 | if sortedScores[pos[k]][k] > score then 259 | score = sortedScores[pos[k]][k] 260 | scale = k 261 | end 262 | end 263 | local temp=sortedIds[pos[scale]][scale] 264 | local sw=self.score[scale]:size(2) -- temp is a 1-based index into the row-major (h*w) view of this scale's score map 265 | local x,y=math.floor((temp-1)/sw)+1,(temp-1)%sw+1 266 | x,y=math.max(1,x),math.max(1,y) -- defensive lower bound; a valid temp already yields x,y >= 1
267 | 268 | pos[scale]=pos[scale]+1 269 | topScores:narrow(1,i,1):copy(torch.Tensor({score,scale,x,y})) 270 | end 271 | 272 | return topScores 273 | end 274 | 275 | -------------------------------------------------------------------------------- 276 | -- function: get top masks. 277 | local topMasks = torch.ByteTensor() 278 | local imgMask = torch.ByteTensor() 279 | function Infer:getTopMasks(thr,h,w) 280 | thr = math.log(thr/(1-thr)) -- probability -> logit: 1/(1+e^-s) > thr <=> s > log(thr/(1-thr)) 281 | 282 | local topMasks0,topScores,np = self.topMasks0,self.topScores,self.np 283 | topMasks:resize(np,h,w):zero() 284 | imgMask:resize(h,w) 285 | local imgMaskPtr = imgMask:data() 286 | 287 | for i = 1,np do 288 | imgMask:zero() 289 | local scale,x,y = topScores[i][2],topScores[i][3],topScores[i][4] 290 | local s = self.scales[scale] 291 | local sz = math.floor(self.iSz/s) 292 | local mask = topMasks0[i] 293 | -- x,y are positions in this scale's score map, not in the per-proposal mask, so they are not clamped to the mask size 294 | local mask = image.scale(mask,sz,sz,'bilinear') 295 | local maskPtr = mask:data() 296 | 297 | local t,delta = 16/s, self.iSz/2/s 298 | for im =0, sz-1 do 299 | local ii = math.floor((x-1)*t-delta+im) 300 | for jm = 0,sz- 1 do 301 | local jj=math.floor((y-1)*t-delta+jm) 302 | if maskPtr[sz*im + jm] > thr and 303 | ii >= 0 and ii <= h-1 and jj >= 0 and jj <= w-1 then 304 | imgMaskPtr[jj+ w*ii]=1 305 | end 306 | end 307 | end 308 | 309 | topMasks:narrow(1,i,1):copy(imgMask) 310 | end 311 | 312 | self.topMasks = topMasks 313 | return topMasks 314 | end 315 | 316 | -------------------------------------------------------------------------------- 317 | -- function: get top proposals 318 | function Infer:getTopProps(thr,h,w) 319 | self:getTopMasks(thr,h,w) 320 | return self.topMasks, self.topScores 321 | end 322 | 323 | -------------------------------------------------------------------------------- 324 | -- function: display timer 325 | function Infer:printTiming() 326 | local t = self.timer 327 | t:div(t[t:size(1)]) 328 | 329 |
print('\n| timing:') 330 | print('| time pyramid:',t[1]) 331 | print('| time pre-process:',t[2]) 332 | print('| time trunk:',t[3]) 333 | print('| time score branch:',t[4]) 334 | print('| time skip connections:',t[5]) 335 | print('| time topPatches:',t[6]) 336 | print('| time refinement:',t[7]) 337 | print('| time total:',t:narrow(1,1,t:size(1)-1):sum()) 338 | end 339 | 340 | return Infer 341 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For deepmask software 4 | 5 | Copyright (c) 2016, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the deepmask software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 
13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | See http://www.andreykurenkov.com/projects/hacks/objectcropbot/ 2 | -------------------------------------------------------------------------------- /SharpMask.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 
3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | When initialized, it loads a pre-trained DeepMask and creates the refinement 8 | modules. 9 | SharpMask class members: 10 | - self.trunk: common trunk (from trained DeepMask model) 11 | - self.scoreBranch: score head architecture (from trained DeepMask model) 12 | - self.maskBranchDM: mask head architecture (from trained DeepMask model) 13 | - self.refs: ensemble of refinement modules for top-down path 14 | ------------------------------------------------------------------------------]] 15 | 16 | require 'nn' 17 | require 'nnx' 18 | require 'cunn' 19 | require 'cudnn' 20 | local utils = paths.dofile('modelUtils.lua') 21 | 22 | local SharpMask, _ = torch.class('nn.SharpMask','nn.Container') 23 | 24 | -------------------------------------------------------------------------------- 25 | -- function: init 26 | function SharpMask:__init(config) 27 | self.km, self.ks = config.km, config.ks 28 | assert(self.km >= 16 and self.km%16==0 and self.ks >= 16 and self.ks%16==0) -- km and ks must be positive multiples of 16 29 | 30 | self.skpos = {8,6,5,3} -- positions to forward horizontal nets 31 | self.inps = {} 32 | 33 | -- create bottom-up flow (from deepmask) 34 | local m = torch.load(config.dm..'/model.t7') 35 | local deepmask = m.model 36 | self.trunk = deepmask.trunk 37 | self.scoreBranch = deepmask.scoreBranch 38 | self.maskBranchDM = deepmask.maskBranch 39 | self.fSz = deepmask.fSz 40 | 41 | -- create refinement modules 42 | self:createTopDownRefinement(config) 43 | 44 | -- number of parameters 45 | local nh,nv = 0,0 46 | for k,v in pairs(self.neths) do 47 | for kk,vv in pairs(v:parameters()) do nh = nh+vv:nElement() end 48 | end 49 | for k,v in pairs(self.netvs) do 50 | for kk,vv in pairs(v:parameters()) do nv = nv+vv:nElement() end 51 | end 52 | print(string.format('| number of
paramaters net h: %d', nh)) 53 | print(string.format('| number of paramaters net v: %d', nv)) 54 | print(string.format('| number of paramaters total: %d', nh+nv)) 55 | self:cuda() 56 | end 57 | 58 | -------------------------------------------------------------------------------- 59 | -- function: create vertical nets 60 | function SharpMask:createVertical(config) 61 | local netvs = {} 62 | 63 | local n0 = nn.Sequential() 64 | n0:add(nn.Linear(512,self.fSz*self.fSz*self.km)) 65 | n0:add(nn.View(config.batch,self.km,self.fSz,self.fSz)) 66 | netvs[0]=n0:cuda() -- netvs[0]: Linear on the 512-d trunk feature, viewed as batch x km x fSz x fSz 67 | 68 | for i = 1, #self.skpos do 69 | local netv = nn.Sequential() 70 | local nInps = self.km/2^(i-1) -- channel count halves at every refinement level 71 | 72 | netv:add(nn.SpatialSymmetricPadding(1,1,1,1)) 73 | netv:add(cudnn.SpatialConvolution(nInps,nInps,3,3,1,1)) 74 | netv:add(cudnn.ReLU()) 75 | 76 | netv:add(nn.SpatialSymmetricPadding(1,1,1,1)) 77 | netv:add(cudnn.SpatialConvolution(nInps,nInps/2,3,3,1,1)) 78 | 79 | table.insert(netvs,netv:cuda()) 80 | end 81 | 82 | self.netvs = netvs 83 | return netvs 84 | end 85 | 86 | -------------------------------------------------------------------------------- 87 | -- function: create horizontal nets 88 | function SharpMask:createHorizontal(config) 89 | local neths = {} 90 | local nhu1,nhu2,crop 91 | for i =1,#self.skpos do 92 | local h = nn.Sequential() 93 | local nInps = self.ks/2^(i-1) 94 | 95 | if i == 1 then nhu1,nhu2,crop=1024,64,0 96 | elseif i == 2 then nhu1,nhu2,crop = 512,64,-2 97 | elseif i == 3 then nhu1,nhu2,crop = 256,64,-4 98 | elseif i == 4 then nhu1,nhu2,crop = 64,32,-8 99 | end 100 | if crop ~= 0 then h:add(nn.SpatialZeroPadding(crop,crop,crop,crop)) end -- NOTE(review): negative padding presumably crops the skip feature map; confirm against nn.SpatialZeroPadding 101 | 102 | h:add(nn.SpatialSymmetricPadding(1,1,1,1)) 103 | h:add(cudnn.SpatialConvolution(nhu1,nhu2,3,3,1,1)) 104 | h:add(cudnn.ReLU()) 105 | 106 | h:add(nn.SpatialSymmetricPadding(1,1,1,1)) 107 | h:add(cudnn.SpatialConvolution(nhu2,nInps,3,3,1,1)) 108 | h:add(cudnn.ReLU()) 109 | 110 | h:add(nn.SpatialSymmetricPadding(1,1,1,1)) 111 |
h:add(cudnn.SpatialConvolution(nInps,nInps/2,3,3,1,1)) 112 | 113 | table.insert(neths,h:cuda()) 114 | end 115 | 116 | self.neths = neths 117 | return neths 118 | end 119 | 120 | -------------------------------------------------------------------------------- 121 | -- function: create refinement modules 122 | function SharpMask:refinement(neth,netv) 123 | local ref = nn.Sequential() 124 | local par = nn.ParallelTable():add(neth):add(netv) 125 | ref:add(par) 126 | ref:add(nn.CAddTable(2)) -- NOTE(review): in torch/nn the CAddTable constructor argument is, as far as I know, an in-place flag rather than a dimension; 2 is truthy, so confirm in-place addition is intended 127 | ref:add(cudnn.ReLU()) 128 | ref:add(nn.SpatialUpSamplingNearest(2)) 129 | 130 | return ref:cuda() 131 | end 132 | 133 | function SharpMask:createTopDownRefinement(config) 134 | -- create horizontal nets 135 | self:createHorizontal(config) 136 | 137 | -- create vertical nets 138 | self:createVertical(config) 139 | 140 | local refs = {} 141 | refs[0] = self.netvs[0] 142 | for i = 1, #self.skpos do 143 | table.insert(refs,self:refinement(self.neths[i],self.netvs[i])) 144 | end 145 | 146 | local finalref = refs[#refs] -- the last refinement stage also emits the 1-channel mask, viewed as batch x gSz*gSz 147 | finalref:add(nn.SpatialSymmetricPadding(1,1,1,1)) 148 | finalref:add(cudnn.SpatialConvolution((self.km)/2^(#refs),1,3,3,1,1)) 149 | finalref:add(nn.View(config.batch,config.gSz*config.gSz)) 150 | 151 | self.refs = refs 152 | return refs 153 | end 154 | 155 | -------------------------------------------------------------------------------- 156 | -- function: forward 157 | function SharpMask:forward(input) 158 | -- forward bottom-up 159 | local currentOutput = self.trunk:forward(input) 160 | 161 | -- forward refinement modules 162 | currentOutput = self.refs[0]:forward(currentOutput) 163 | for k = 1,#self.refs do -- refs[0] was applied above; # counts only keys 1..n 164 | local F = self.trunk.modules[self.skpos[k]].output 165 | self.inps[k] = {F,currentOutput} 166 | currentOutput = self.refs[k]:forward(self.inps[k]) 167 | end 168 | self.output = currentOutput 169 | return self.output 170 | end 171 | 172 | -------------------------------------------------------------------------------- 173 | -- function: backward 174
| function SharpMask:backward(input,gradOutput) 175 | local currentGrad = gradOutput 176 | for i = #self.refs,1,-1 do 177 | currentGrad =self.refs[i]:backward(self.inps[i],currentGrad) 178 | currentGrad = currentGrad[2] 179 | end 180 | currentGrad =self.refs[0]:backward(self.trunk.output,currentGrad) 181 | 182 | self.gradInput = currentGrad 183 | return currentGrad 184 | end 185 | 186 | -------------------------------------------------------------------------------- 187 | -- function: zeroGradParameters 188 | function SharpMask:zeroGradParameters() 189 | for k,v in pairs(self.refs) do self.refs[k]:zeroGradParameters() end 190 | end 191 | 192 | -------------------------------------------------------------------------------- 193 | -- function: updateParameters 194 | function SharpMask:updateParameters(lr) 195 | for k,n in pairs(self.refs) do self.refs[k]:updateParameters(lr) end 196 | end 197 | 198 | -------------------------------------------------------------------------------- 199 | -- function: training 200 | function SharpMask:training() 201 | self.trunk:training();self.scoreBranch:training();self.maskBranchDM:training() 202 | for k,n in pairs(self.refs) do self.refs[k]:training() end 203 | end 204 | 205 | -------------------------------------------------------------------------------- 206 | -- function: evaluate 207 | function SharpMask:evaluate() 208 | self.trunk:evaluate();self.scoreBranch:evaluate();self.maskBranchDM:evaluate() 209 | for k,n in pairs(self.refs) do self.refs[k]:evaluate() end 210 | end 211 | 212 | -------------------------------------------------------------------------------- 213 | -- function: to cuda (returns self, like nn.Module:cuda, so model = model:cuda() keeps the model) 214 | function SharpMask:cuda() 215 | self.trunk:cuda();self.scoreBranch:cuda();self.maskBranchDM:cuda() 216 | for k,n in pairs(self.refs) do self.refs[k]:cuda() end return self 217 | end 218 | 219 | -------------------------------------------------------------------------------- 220 | -- function: to float (returns self, like nn.Module:float) 221 | function SharpMask:float() 222 |
self.trunk:float();self.scoreBranch:float();self.maskBranchDM:float() 223 | for k,n in pairs(self.refs) do self.refs[k]:float() end return self 224 | end 225 | 226 | -------------------------------------------------------------------------------- 227 | -- function: set number of proposals for inference 228 | function SharpMask:setnpinference(np) 229 | local vsz = self.refs[0].modules[2].size 230 | self.refs[0].modules[2]:resetSize(np,vsz[2],vsz[3],vsz[4]) 231 | end 232 | 233 | -------------------------------------------------------------------------------- 234 | -- function: inference (used for full scene inference) 235 | function SharpMask:inference(np) 236 | self:evaluate() 237 | 238 | -- remove last view 239 | self.refs[#self.refs]:remove() 240 | 241 | -- remove ZeroPaddings 242 | self.trunk.modules[8]=nn.Identity():cuda() 243 | for k = 1, #self.refs do 244 | local m = self.refs[k].modules[1].modules[1].modules[1] 245 | if torch.typename(m):find('SpatialZeroPadding') then 246 | self.refs[k].modules[1].modules[1].modules[1]=nn.Identity():cuda() 247 | end 248 | end 249 | 250 | -- remove horizontal links, as they are applied convolutionally 251 | for k = 1, #self.refs do 252 | self.refs[k].modules[1].modules[1]=nn.Identity():cuda() 253 | end 254 | 255 | -- modify number of batch to np (number of proposals) 256 | self:setnpinference(np) 257 | 258 | -- transform trunk and score branch to conv 259 | utils.linear2convTrunk(self.trunk,self.fSz) 260 | utils.linear2convHead(self.scoreBranch) 261 | self.maskBranchDM = self.maskBranchDM.modules[1] 262 | 263 | self:cuda() 264 | end 265 | 266 | -------------------------------------------------------------------------------- 267 | -- function: clone 268 | function SharpMask:clone(...) 269 | local f = torch.MemoryFile("rw"):binary() 270 | f:writeObject(self); f:seek(1) 271 | local clone = f:readObject(); f:close() 272 | 273 | if select('#',...) > 0 then 274 | clone.trunk:share(self.trunk,...)
275 | clone.maskBranchDM:share(self.maskBranchDM,...) 276 | clone.scoreBranch:share(self.scoreBranch,...) 277 | for k,n in pairs(self.netvs) do clone.netvs[k]:share(self.netvs[k],...)end 278 | for k,n in pairs(self.neths) do clone.neths[k]:share(self.neths[k],...) end 279 | for k,n in pairs(self.refs) do clone.refs[k]:share(self.refs[k],...) end 280 | end 281 | 282 | return clone 283 | end 284 | 285 | return nn.SharpMask 286 | -------------------------------------------------------------------------------- /SpatialSymmetricPadding.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | SpatialSymmetricPadding module 8 | 9 | The forward(A) pads input tensor A with mirror reflections of itself 10 | It is the same function as Matlab padarray(A, padsize, 'symmetric' ) 11 | The padding is of the form: cba[abcd...] 12 | While nn.SpatialReflectionPadding does: dcb[abcd...] 13 | And nn.SpatialReplicationPadding does: aaa[abcd...] 14 | (where [abcd...]
is a tensor) 15 | The updateGradInput(input, gradOutput) is inherited from nn.SpatialZeroPadding, 16 | where the padded region is treated as constant and 17 | the gradients are accumulated in the backward pass 18 | ------------------------------------------------------------------------------]] 19 | 20 | local SpatialSymmetricPadding, parent = 21 | torch.class('nn.SpatialSymmetricPadding', 'nn.SpatialZeroPadding') 22 | 23 | function SpatialSymmetricPadding:__init(pad_l, pad_r, pad_t, pad_b) 24 | parent.__init(self, pad_l, pad_r, pad_t, pad_b) 25 | end 26 | 27 | function SpatialSymmetricPadding:updateOutput(input) 28 | assert(input:dim()==4, "only Dimension=4 implemented") 29 | -- sizes 30 | local h = input:size(3) + self.pad_t + self.pad_b 31 | local w = input:size(4) + self.pad_l + self.pad_r 32 | if w < 1 or h < 1 then error('input is too small') end 33 | self.output:resize(input:size(1), input:size(2), h, w) 34 | self.output:zero() 35 | -- crop input if necessary 36 | local c_input = input 37 | if self.pad_t < 0 then 38 | c_input = c_input:narrow(3, 1 - self.pad_t, c_input:size(3) + self.pad_t) 39 | end 40 | if self.pad_b < 0 then 41 | c_input = c_input:narrow(3, 1, c_input:size(3) + self.pad_b) 42 | end 43 | if self.pad_l < 0 then 44 | c_input = c_input:narrow(4, 1 - self.pad_l, c_input:size(4) + self.pad_l) 45 | end 46 | if self.pad_r < 0 then 47 | c_input = c_input:narrow(4, 1, c_input:size(4) + self.pad_r) 48 | end 49 | -- crop output if necessary 50 | local c_output = self.output 51 | if self.pad_t > 0 then 52 | c_output = c_output:narrow(3, 1 + self.pad_t, c_output:size(3) - self.pad_t) 53 | end 54 | if self.pad_b > 0 then 55 | c_output = c_output:narrow(3, 1, c_output:size(3) - self.pad_b) 56 | end 57 | if self.pad_l > 0 then 58 | c_output = c_output:narrow(4, 1 + self.pad_l, c_output:size(4) - self.pad_l) 59 | end 60 | if self.pad_r > 0 then 61 | c_output = c_output:narrow(4, 1, c_output:size(4) - self.pad_r) 62 | end 63 | -- copy input to output 64 |
c_output:copy(c_input) 65 | -- symmetric padding mirrors the copied (possibly cropped) input, so every pad must fit within its extent 66 | if c_input:size(4)<self.pad_l or c_input:size(4)<self.pad_r or c_input:size(3)<self.pad_t or c_input:size(3)<self.pad_b then 67 | error('input is too small') 68 | end 69 | for i=1,self.pad_t do 70 | self.output:narrow(3,self.pad_t-i+1,1):copy( 71 | self.output:narrow(3,i+self.pad_t,1)) 72 | end 73 | for i=1,self.pad_b do 74 | self.output:narrow(3,self.output:size(3)-self.pad_b+i,1):copy( 75 | self.output:narrow(3,self.output:size(3)-self.pad_b-i+1,1)) 76 | end 77 | for i=1,self.pad_l do 78 | self.output:narrow(4,self.pad_l-i+1,1):copy( 79 | self.output:narrow(4,i+self.pad_l,1)) 80 | end 81 | for i=1,self.pad_r do 82 | self.output:narrow(4,self.output:size(4)-self.pad_r+i,1):copy( 83 | self.output:narrow(4,self.output:size(4)-self.pad_r-i+1,1)) 84 | end 85 | return self.output 86 | end 87 | -------------------------------------------------------------------------------- /TrainerDeepMask.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory.
Training and testing loop for DeepMask
------------------------------------------------------------------------------]]

local optim = require 'optim'
paths.dofile('trainMeters.lua')

local Trainer = torch.class('Trainer')

-- Best suc@.7 IoU reached so far by Trainer:test. This is a file-level
-- upvalue, so it is shared by every Trainer created from this file.
local bestSuccess = 0

-- Automatic schedule applied when config.lr == 0:
-- {first epoch, last epoch, learning rate, weight decay}
local REGIMES = {
  {  1,  50, 1e-3, 5e-4},
  { 51, 120, 5e-4, 5e-4},
  {121, 1e8, 1e-4, 5e-4},
}

--------------------------------------------------------------------------------
-- function: init
-- Builds the mask/score networks (both reuse model.trunk), one SGD state per
-- parameter group (trunk, mask, score), the meters, and the run log.
function Trainer:__init(model, criterion, config)
  self.config = config
  self.model = model
  -- both heads are stacked on the very same trunk module instance
  self.maskNet = nn.Sequential():add(model.trunk):add(model.maskBranch)
  self.scoreNet = nn.Sequential():add(model.trunk):add(model.scoreBranch)
  self.criterion = criterion
  self.lr = config.lr

  -- every parameter group gets its own, independent SGD state table
  local function sgdState()
    return {
      learningRate = config.lr,
      learningRateDecay = 0,
      momentum = config.momentum,
      dampening = 0,
      weightDecay = config.wd,
    }
  end
  self.optimState = {
    trunk = sgdState(),
    mask = sgdState(),
    score = sgdState(),
  }

  -- flattened parameters and gradients of each group
  self.pt, self.gt = model.trunk:getParameters()
  self.pm, self.gm = model.maskBranch:getParameters()
  self.ps, self.gs = model.scoreBranch:getParameters()

  -- reusable GPU buffers, refilled by copySamples
  self.inputs = torch.CudaTensor()
  self.labels = torch.CudaTensor()

  -- meters
  self.lossmeter = LossMeter()
  self.maskmeter = IouMeter(0.5, config.testmaxload * config.batch)
  self.scoremeter = BinaryMeter()

  -- checkpoint payload: a clone that shares weight/bias with the model
  self.modelsv = {model = model:clone('weight', 'bias'), config = config}
  self.rundir = config.rundir
  self.log = torch.DiskFile(self.rundir .. '/log', 'rw')
  self.log:seekEnd()
end

--------------------------------------------------------------------------------
-- function: train
-- Runs one epoch. Each batch targets either the mask head (sample.head == 1)
-- or the score head; the trunk is stepped on every batch. Logs the mean loss
-- and saves model.t7 (plus model_<epoch>.t7 every 50 epochs) to rundir.
function Trainer:train(epoch, dataloader)
  self.model:training()
  self:updateScheduler(epoch)
  self.lossmeter:reset()

  local timer = torch.Timer()

  -- optim.sgd closures: they only return gradients already accumulated by
  -- the explicit backward pass inside the loop
  local trunkEval = function() return self.model.trunk.output, self.gt end
  local heads = {
    mask = {
      net = self.maskNet, params = self.pm, state = self.optimState.mask,
      feval = function() return self.criterion.output, self.gm end,
    },
    score = {
      net = self.scoreNet, params = self.ps, state = self.optimState.score,
      feval = function() return self.criterion.output, self.gs end,
    },
  }

  for _, sample in dataloader:run() do
    self:copySamples(sample)

    local isMask = sample.head == 1
    local head = isMask and heads.mask or heads.score

    local outputs = head.net:forward(self.inputs)
    local lossbatch = self.criterion:forward(outputs, self.labels)

    head.net:zeroGradParameters()
    local gradOutputs = self.criterion:backward(outputs, self.labels)
    -- mask-head gradients are scaled by self.inputs:size(1)
    if isMask then gradOutputs:mul(self.inputs:size(1)) end
    head.net:backward(self.inputs, gradOutputs)

    -- step the shared trunk, then the head that produced this batch
    optim.sgd(trunkEval, self.pt, self.optimState.trunk)
    optim.sgd(head.feval, head.params, head.state)

    self.lossmeter:add(lossbatch)
  end

  local secPerBatch = timer:time().real / dataloader:size()
  local logepoch = string.format(
    '[train] | epoch %05d | s/batch %04.2f | loss: %07.5f ',
    epoch, secPerBatch, self.lossmeter:value())
  print(logepoch)
  self.log:writeString(string.format('%s\n', logepoch))
  self.log:synchronize()

  torch.save(string.format('%s/model.t7', self.rundir), self.modelsv)
  if epoch % 50 == 0 then
    torch.save(string.format('%s/model_%d.t7', self.rundir, epoch),
      self.modelsv)
  end

  collectgarbage()
end

--------------------------------------------------------------------------------
-- function: test
-- Evaluates both heads over the loader, writes bestmodel.t7 when the suc@.7
-- IoU beats every previous evaluation, and logs the metrics.
function Trainer:test(epoch, dataloader)
  self.model:evaluate()
  self.maskmeter:reset()
  self.scoremeter:reset()

  for _, sample in dataloader:run() do
    self:copySamples(sample)
    if sample.head == 1 then
      local masks = self.maskNet:forward(self.inputs)
      self.maskmeter:add(masks:view(self.labels:size()), self.labels)
    else
      local scores = self.scoreNet:forward(self.inputs)
      self.scoremeter:add(scores, self.labels)
    end
    cutorch.synchronize()
  end
  self.model:training()

  local success, bestmodel = self.maskmeter:value('0.7')
  if success > bestSuccess then
    torch.save(string.format('%s/bestmodel.t7', self.rundir), self.modelsv)
    bestSuccess = success
    bestmodel = true
  end

  local iouMean = self.maskmeter:value('mean')
  local iouMedian = self.maskmeter:value('median')
  local suc5 = self.maskmeter:value('0.5')
  local suc7 = self.maskmeter:value('0.7')
  local acc = self.scoremeter:value()
  local logepoch = string.format('[test] | epoch %05d '..
      '| IoU: mean %06.2f median %06.2f suc@.5 %06.2f suc@.7 %06.2f '..
      '| acc %06.2f | bestmodel %s',
    epoch, iouMean, iouMedian, suc5, suc7, acc, bestmodel and '*' or 'x')
  print(logepoch)
  self.log:writeString(string.format('%s\n', logepoch))
  self.log:synchronize()

  collectgarbage()
end

--------------------------------------------------------------------------------
-- function: copy inputs/labels to CUDA tensor
-- Resizes each reusable GPU buffer to the sample's shape and copies into it.
function Trainer:copySamples(sample)
  for _, key in ipairs({'inputs', 'labels'}) do
    self[key]:resize(sample[key]:size()):copy(sample[key])
  end
end

--------------------------------------------------------------------------------
-- function: update training schedule according to epoch
-- Only active when the trainer was configured with lr == 0; self.lr itself is
-- never modified here, so the check holds on every epoch.
function Trainer:updateScheduler(epoch)
  if self.lr ~= 0 then return end
  for _, regime in ipairs(REGIMES) do
    local first, last, lr, wd = regime[1], regime[2], regime[3], regime[4]
    if epoch >= first and epoch <= last then
      for _, state in pairs(self.optimState) do
        state.learningRate = lr
        state.weightDecay = wd
      end
    end
  end
end

return Trainer
-------------------------------------------------------------------------------- /TrainerSharpMask.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory.
Training and testing loop for SharpMask
------------------------------------------------------------------------------]]

paths.dofile('trainMeters.lua')

local Trainer = torch.class('Trainer')

-- Automatic learning-rate schedule used when config.lr == 0:
-- {first epoch, last epoch, learning rate}
local LR_REGIMES = {
  {  1,  50, 1e-3},
  { 51,  80, 5e-4},
  { 81, 1e8, 1e-4},
}

--------------------------------------------------------------------------------
-- function: init
-- model: network to train; criterion: loss module; config: options table
-- (reads lr, testmaxload, batch, rundir).
function Trainer:__init(model, criterion, config)
  -- training params
  self.model = model
  self.criterion = criterion
  self.lr = config.lr
  -- config.lr == 0 selects the epoch schedule in updateScheduler. The choice
  -- is stored separately because updateScheduler overwrites self.lr, so
  -- re-testing self.lr == 0 would fail after the first epoch and the
  -- schedule would never decay.
  self.useLrSchedule = (config.lr == 0)

  -- allocate cuda tensors
  self.inputs, self.labels = torch.CudaTensor(), torch.CudaTensor()

  -- meters
  self.lossmeter = LossMeter()
  self.maskmeter = IouMeter(0.5,config.testmaxload*config.batch)

  -- log: checkpoint payload is a clone sharing weight/bias with the model
  self.modelsv = {model=model:clone('weight', 'bias'),config=config}
  self.rundir = config.rundir
  self.log = torch.DiskFile(self.rundir .. '/log', 'rw'); self.log:seekEnd()
end

--------------------------------------------------------------------------------
-- function: train
-- One epoch of plain SGD (model:updateParameters with self.lr). Logs the mean
-- loss and saves model.t7 (and model_<epoch>.t7 every 50 epochs) to rundir.
function Trainer:train(epoch, dataloader)
  self.model:training()
  self:updateScheduler(epoch)
  self.lossmeter:reset()

  local timer = torch.Timer()

  for _, sample in dataloader:run() do
    -- copy samples to the GPU
    self:copySamples(sample)

    -- forward/backward
    local outputs = self.model:forward(self.inputs)
    local lossbatch = self.criterion:forward(outputs, self.labels)

    local gradOutputs = self.criterion:backward(outputs, self.labels)
    -- gradients are scaled by self.inputs:size(1)
    gradOutputs:mul(self.inputs:size(1))
    self.model:zeroGradParameters()
    self.model:backward(self.inputs, gradOutputs)
    self.model:updateParameters(self.lr)

    -- update loss
    self.lossmeter:add(lossbatch)
  end

  -- write log
  local logepoch =
    string.format('[train] | epoch %05d | s/batch %04.2f | loss: %07.5f ',
      epoch, timer:time().real/dataloader:size(),self.lossmeter:value())
  print(logepoch)
  self.log:writeString(string.format('%s\n',logepoch))
  self.log:synchronize()

  --save model
  torch.save(string.format('%s/model.t7', self.rundir),self.modelsv)
  if epoch%50 == 0 then
    torch.save(string.format('%s/model_%d.t7', self.rundir, epoch),
      self.modelsv)
  end
  collectgarbage()
end

--------------------------------------------------------------------------------
-- function: test
-- Best suc@.7 IoU seen so far; file-level, so shared by all Trainer instances.
local maxacc = 0
-- Evaluates mask IoU over the loader and saves bestmodel.t7 on improvement.
function Trainer:test(epoch, dataloader)
  self.model:evaluate()
  self.maskmeter:reset()

  for _, sample in dataloader:run() do
    -- copy input and target to the GPU
    self:copySamples(sample)

    -- infer mask in batch
    local outputs = self.model:forward(self.inputs):float()
    cutorch.synchronize()

    self.maskmeter:add(outputs:view(sample.labels:size()),sample.labels)
  end
  self.model:training()

  -- check if bestmodel so far
  local z,bestmodel = self.maskmeter:value('0.7')
  if z > maxacc then
    torch.save(string.format('%s/bestmodel.t7', self.rundir),self.modelsv)
    maxacc = z
    bestmodel = true
  end

  -- write log
  local logepoch =
    string.format('[test] | epoch %05d '..
      '| IoU: mean %06.2f median %06.2f suc@.5 %06.2f suc@.7 %06.2f '..
      '| bestmodel %s',
      epoch,
      self.maskmeter:value('mean'),self.maskmeter:value('median'),
      self.maskmeter:value('0.5'), self.maskmeter:value('0.7'),
      bestmodel and '*' or 'x')
  print(logepoch)
  self.log:writeString(string.format('%s\n',logepoch))
  self.log:synchronize()

  collectgarbage()
end

--------------------------------------------------------------------------------
-- function: copy inputs/labels to CUDA tensor
function Trainer:copySamples(sample)
  self.inputs:resize(sample.inputs:size()):copy(sample.inputs)
  self.labels:resize(sample.labels:size()):copy(sample.labels)
end

--------------------------------------------------------------------------------
-- function: update training schedule according to epoch
-- When the automatic schedule is enabled (config.lr == 0), sets self.lr to
-- the rate of the regime containing `epoch`; otherwise leaves it untouched.
function Trainer:updateScheduler(epoch)
  if not self.useLrSchedule then return end

  for _, row in ipairs(LR_REGIMES) do
    if epoch >= row[1] and epoch <= row[2] then
      self.lr = row[3]
      break
    end
  end
end

return Trainer
-------------------------------------------------------------------------------- /computeProposals.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory.
Run full scene inference in sample image
------------------------------------------------------------------------------]]

require 'torch'
require 'cutorch'
require 'image'

--------------------------------------------------------------------------------
-- parse arguments
local cmd = torch.CmdLine()
cmd:text()
cmd:text('evaluate deepmask/sharpmask')
cmd:text()
cmd:argument('-model', 'path to model to load')
cmd:text('Options:')
cmd:option('-img','data/testImage.jpg' ,'path/to/test/image')
cmd:option('-gpu', 1, 'gpu device')
cmd:option('-np', 5,'number of proposals to save in test')
cmd:option('-si', -2.5, 'initial scale')
cmd:option('-sf', .5, 'final scale')
cmd:option('-ss', .5, 'scale step')
cmd:option('-dm', false, 'use DeepMask version of SharpMask')
-- new, optional: defaults to the previously hard-coded output path
cmd:option('-out', './res.jpg', 'path/to/output/image')

local config = cmd:parse(arg)

--------------------------------------------------------------------------------
-- various initializations
torch.setdefaulttensortype('torch.FloatTensor')
cutorch.setDevice(config.gpu)

local coco = require 'coco'
local maskApi = coco.MaskApi

-- per-channel mean/std handed to the inference module
local meanstd = {mean = { 0.485, 0.456, 0.406 }, std = { 0.229, 0.224, 0.225 }}

--------------------------------------------------------------------------------
-- load model
paths.dofile('DeepMask.lua')
paths.dofile('SharpMask.lua')

print('| loading model file... ' .. config.model)
local m = torch.load(config.model..'/model.t7')
local model = m.model
model:inference(config.np)
model:cuda()

--------------------------------------------------------------------------------
-- create inference module

-- Scales are 2^e for e = si, si+ss, si+2*ss, ... while e has not passed sf.
-- Exponents come from an integer counter instead of repeatedly adding ss, so
-- a step such as 0.1 cannot lose the final scale to accumulated floating-
-- point error. A zero step would never terminate, so it is rejected.
assert(config.ss ~= 0, 'scale step (-ss) must be non-zero')
local scales = {}
local nSteps = math.floor((config.sf - config.si) / config.ss + 1e-9)
for k = 0, nSteps do
  scales[#scales + 1] = 2^(config.si + k * config.ss)
end

-- The Infer*.lua file loaded here is expected to define the global `Infer`
-- used below; fail clearly instead of with "attempt to call a nil value".
local modelType = torch.type(model)
if modelType == 'nn.DeepMask' then
  paths.dofile('InferDeepMask.lua')
elseif modelType == 'nn.SharpMask' then
  paths.dofile('InferSharpMask.lua')
else
  error('unsupported model type: ' .. modelType)
end

local infer = Infer{
  np = config.np,
  scales = scales,
  meanstd = meanstd,
  model = model,
  dm = config.dm,
}

--------------------------------------------------------------------------------
-- do it
print('| start')

-- load image
local img = image.load(config.img)
local h,w = img:size(2),img:size(3)

-- forward all scales
infer:forward(img)

-- get top proposals
local masks = infer:getTopProps(.2,h,w)

-- save result
local res = img:clone()
maskApi.drawMasks(res, masks, 10)
image.save(config.out, res)

print('| done')
collectgarbage()
-------------------------------------------------------------------------------- /data/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/data/teaser.png -------------------------------------------------------------------------------- /data/testImage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/data/testImage.jpg -------------------------------------------------------------------------------- /demo/control.js: 
-------------------------------------------------------------------------------- 1 | var image; 2 | var cropper; 3 | var online = true; 4 | window.addEventListener('DOMContentLoaded', function () { 5 | var button = document.getElementById('crop-button'); 6 | $.ajax({ 7 | url: "http://ec2-54-219-178-149.us-west-1.compute.amazonaws.com:5000", 8 | success: function (response) { 9 | button.disabled = false; 10 | }, 11 | error: function (xhr, ajaxOptions, thrownError) { 12 | online = false; 13 | button.disabled = true; 14 | window.alert("Cropping functionality offline :(. Give me free AWS credits?"); 15 | console.log(":("); 16 | }, 17 | timeout: 5000 18 | }); 19 | var first = true; 20 | var canvas = document.getElementById('canvas'); 21 | var ctx = canvas.getContext("2d"); 22 | var outImage = document.getElementById('out_img'); 23 | canvas.width = 226; 24 | canvas.height = 218; 25 | ctx.drawImage(outImage, 0, 0); 26 | 27 | image = document.getElementById('cropImg'); 28 | cropper = new Cropper(image, { 29 | autoCropArea:0.5, 30 | ready: function () { 31 | if(first){ 32 | cropper.setCropBoxData({'left': cropper.getContainerData().width*0.24, 33 | 'top': cropper.getContainerData().width*0.17, 34 | 'width': cropper.getContainerData().width*0.38, 35 | 'height': cropper.getContainerData().height*0.46 }); 36 | first = false; 37 | } 38 | } 39 | }); 40 | 41 | 42 | $(".thumbnail").click(function(event){ 43 | var img = document.getElementById(event.target.id); 44 | image.src = img.src; 45 | cropper.replace(img.src); 46 | }); 47 | 48 | var dc_url = 'http://ec2-54-219-178-149.us-west-1.compute.amazonaws.com:5000'; 49 | var loader = document.getElementById('loader'); 50 | $("#crop-button").click(function(event){ 51 | if(!button.disabled){ 52 | button.disabled = true; 53 | var encoded = cropper.getCroppedCanvas().toDataURL("image/jpeg"); 54 | encoded = encoded.substring(encoded.indexOf(',')+1); 55 | canvas.style.display = "none"; 56 | loader.style.display = "block"; 57 | msg = 
{'img_64':encoded}; 58 | $.ajax({ 59 | type: "POST", 60 | url: dc_url, 61 | data: JSON.stringify(msg), 62 | success: function(data) { 63 | var image = new Image(); 64 | loading = false; 65 | image.onload = function() { 66 | canvas.width = image.width; 67 | canvas.height = image.height; 68 | ctx.drawImage(image, 0, 0); 69 | canvas.style.display = ""; 70 | loader.style.display = "none"; 71 | button.disabled = false; 72 | }; 73 | image.src = "data:image/jpg;base64,"+data; 74 | }, 75 | error: function() { 76 | window.alert("Something went wrong :( Try again?"); 77 | canvas.style.display = ""; 78 | loader.style.display = "none"; 79 | button.disabled = false; 80 | }, 81 | dataType: "json", 82 | contentType: "application/json; charset=utf-8", 83 | timeout: 15000 84 | }); 85 | } 86 | }); 87 | }); 88 | 89 | function loadImg() { 90 | var file = document.querySelector('input[type=file]').files[0]; 91 | var reader = new FileReader(); 92 | 93 | reader.addEventListener("load", function () { 94 | image.src = reader.result; 95 | cropper.replace(image.src); 96 | }, false); 97 | if (file) { 98 | reader.readAsDataURL(file); 99 | } 100 | } 101 | 102 | 103 | -------------------------------------------------------------------------------- /demo/cropperjs/cropper.css: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Cropper.js v1.0.0-beta.1 3 | * https://github.com/fengyuanchen/cropperjs 4 | * 5 | * Copyright (c) 2017 Fengyuan Chen 6 | * Released under the MIT license 7 | * 8 | * Date: 2017-01-21T12:28:26.786Z 9 | */ /* NOTE: vendored third-party stylesheet (Cropper.js v1.0.0-beta.1, MIT). Prefer re-syncing from upstream over hand edits; cropper.min.css in this directory appears to carry the same rules in minified form, so keep the two in step. */ 10 | 11 | .cropper-container { 12 | font-size: 0; 13 | line-height: 0; 14 | 15 | position: relative; 16 | 17 | -webkit-user-select: none; 18 | 19 | -moz-user-select: none; 20 | 21 | -ms-user-select: none; 22 | 23 | user-select: none; 24 | 25 | direction: ltr; 26 | -ms-touch-action: none; 27 | touch-action: none 28 | } 29 | 30 | .cropper-container img { 31 | /* Avoid margin top issue (Occur only when margin-top <= -height) */ 32 | display: block; 33 | min-width: 0 !important; 34 | max-width: none !important; 35 | min-height: 0 !important; 36 | max-height: none !important; 37 | width: 100%; 38 | height: 100%; 39 | image-orientation: 0deg 40 | } 41 | 42 | .cropper-wrap-box, 43 | .cropper-canvas, 44 | .cropper-drag-box, 45 | .cropper-crop-box, 46 | .cropper-modal { 47 | position: absolute; 48 | top: 0; 49 | right: 0; 50 | bottom: 0; 51 | left: 0; 52 | } 53 | 54 | .cropper-wrap-box { 55 | overflow: hidden; 56 | } 57 | 58 | .cropper-drag-box { 59 | opacity: 0; 60 | background-color: #fff; 61 | } 62 | 63 | .cropper-modal { 64 | opacity: .5; 65 | background-color: #000; 66 | } 67 | 68 | .cropper-view-box { 69 | display: block; 70 | overflow: hidden; 71 | 72 | width: 100%; 73 | height: 100%; 74 | 75 | outline: 1px solid #39f; 76 | outline-color: rgba(51, 153, 255, 0.75); 77 | } 78 | 79 | .cropper-dashed { 80 | position: absolute; 81 | 82 | display: block; 83 | 84 | opacity: .5; 85 | border: 0 dashed #eee 86 | } 87 | 88 | .cropper-dashed.dashed-h { 89 | top: 33.33333%; 90 | left: 0; 91 | width: 100%; 92 | height: 33.33333%; 93 | border-top-width: 1px; 94 | border-bottom-width: 1px 95 | } 96 | 97 | .cropper-dashed.dashed-v { 98 | top: 0; 99 | left: 33.33333%; 100 | width: 33.33333%; 101 | height: 100%; 102 | border-right-width: 1px; 103 | border-left-width: 1px 104 | }
105 | 106 | .cropper-center { 107 | position: absolute; 108 | top: 50%; 109 | left: 50%; 110 | 111 | display: block; 112 | 113 | width: 0; 114 | height: 0; 115 | 116 | opacity: .75 117 | } 118 | 119 | .cropper-center:before, 120 | .cropper-center:after { 121 | position: absolute; 122 | display: block; 123 | content: ' '; 124 | background-color: #eee 125 | } 126 | 127 | .cropper-center:before { 128 | top: 0; 129 | left: -3px; 130 | width: 7px; 131 | height: 1px 132 | } 133 | 134 | .cropper-center:after { 135 | top: -3px; 136 | left: 0; 137 | width: 1px; 138 | height: 7px 139 | } 140 | 141 | .cropper-face, 142 | .cropper-line, 143 | .cropper-point { 144 | position: absolute; 145 | 146 | display: block; 147 | 148 | width: 100%; 149 | height: 100%; 150 | 151 | opacity: .1; 152 | } 153 | 154 | .cropper-face { 155 | top: 0; 156 | left: 0; 157 | 158 | background-color: #fff; 159 | } 160 | 161 | .cropper-line { 162 | background-color: #39f 163 | } 164 | 165 | .cropper-line.line-e { 166 | top: 0; 167 | right: -3px; 168 | width: 5px; 169 | cursor: e-resize 170 | } 171 | 172 | .cropper-line.line-n { 173 | top: -3px; 174 | left: 0; 175 | height: 5px; 176 | cursor: n-resize 177 | } 178 | 179 | .cropper-line.line-w { 180 | top: 0; 181 | left: -3px; 182 | width: 5px; 183 | cursor: w-resize 184 | } 185 | 186 | .cropper-line.line-s { 187 | bottom: -3px; 188 | left: 0; 189 | height: 5px; 190 | cursor: s-resize 191 | } 192 | 193 | .cropper-point { 194 | width: 5px; 195 | height: 5px; 196 | 197 | opacity: .75; 198 | background-color: #39f 199 | } 200 | 201 | .cropper-point.point-e { 202 | top: 50%; 203 | right: -3px; 204 | margin-top: -3px; 205 | cursor: e-resize 206 | } 207 | 208 | .cropper-point.point-n { 209 | top: -3px; 210 | left: 50%; 211 | margin-left: -3px; 212 | cursor: n-resize 213 | } 214 | 215 | .cropper-point.point-w { 216 | top: 50%; 217 | left: -3px; 218 | margin-top: -3px; 219 | cursor: w-resize 220 | } 221 | 222 | .cropper-point.point-s { 223 | bottom: -3px; 224 | 
left: 50%; 225 | margin-left: -3px; 226 | cursor: s-resize 227 | } 228 | 229 | .cropper-point.point-ne { 230 | top: -3px; 231 | right: -3px; 232 | cursor: ne-resize 233 | } 234 | 235 | .cropper-point.point-nw { 236 | top: -3px; 237 | left: -3px; 238 | cursor: nw-resize 239 | } 240 | 241 | .cropper-point.point-sw { 242 | bottom: -3px; 243 | left: -3px; 244 | cursor: sw-resize 245 | } 246 | 247 | .cropper-point.point-se { 248 | right: -3px; 249 | bottom: -3px; 250 | width: 20px; 251 | height: 20px; 252 | cursor: se-resize; 253 | opacity: 1 254 | } 255 | /* The south-east resize handle starts at 20px and shrinks to 15px, 10px and finally 5px (at opacity .75) once the viewport reaches 768px, 992px and 1200px wide. */ 256 | @media (min-width: 768px) { 257 | 258 | .cropper-point.point-se { 259 | width: 15px; 260 | height: 15px 261 | } 262 | } 263 | 264 | @media (min-width: 992px) { 265 | 266 | .cropper-point.point-se { 267 | width: 10px; 268 | height: 10px 269 | } 270 | } 271 | 272 | @media (min-width: 1200px) { 273 | 274 | .cropper-point.point-se { 275 | width: 5px; 276 | height: 5px; 277 | opacity: .75 278 | } 279 | } 280 | 281 | .cropper-point.point-se:before { 282 | position: absolute; 283 | right: -50%; 284 | bottom: -50%; 285 | display: block; 286 | width: 200%; 287 | height: 200%; 288 | content: ' '; 289 | opacity: 0; 290 | background-color: #39f 291 | } 292 | 293 | .cropper-invisible { 294 | opacity: 0; 295 | } 296 | /* NOTE(review): the url() below is empty; presumably an image data-URI was stripped when this file was vendored. An empty url() may be resolved against the page itself by some browsers, and .cropper-bg then paints no pattern — confirm before relying on it. */ 297 | .cropper-bg { 298 | background-image: url(''); 299 | } 300 | 301 | .cropper-hide { 302 | position: absolute; 303 | 304 | display: block; 305 | 306 | width: 0; 307 | height: 0; 308 | } 309 | 310 | .cropper-hidden { 311 | display: none !important; 312 | } 313 | 314 | .cropper-move { 315 | cursor: move; 316 | } 317 | 318 | .cropper-crop { 319 | cursor: crosshair; 320 | } 321 | 322 | .cropper-disabled .cropper-drag-box, 323 | .cropper-disabled .cropper-face, 324 | .cropper-disabled .cropper-line, 325 | .cropper-disabled .cropper-point { 326 | cursor: not-allowed; 327 | } 328 | 329 | -------------------------------------------------------------------------------- /demo/cropperjs/cropper.min.css: 
-------------------------------------------------------------------------------- 1 | /*! 2 | * Cropper.js v1.0.0-beta.1 3 | * https://github.com/fengyuanchen/cropperjs 4 | * 5 | * Copyright (c) 2017 Fengyuan Chen 6 | * Released under the MIT license 7 | * 8 | * Date: 2017-01-21T12:28:26.786Z 9 | */ 10 | 11 | .cropper-container{font-size:0;line-height:0;position:relative;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none;direction:ltr;-ms-touch-action:none;touch-action:none}.cropper-container img{display:block;min-width:0!important;max-width:none!important;min-height:0!important;max-height:none!important;width:100%;height:100%;image-orientation:0deg}.cropper-canvas,.cropper-crop-box,.cropper-drag-box,.cropper-modal,.cropper-wrap-box{position:absolute;top:0;right:0;bottom:0;left:0}.cropper-wrap-box{overflow:hidden}.cropper-drag-box{opacity:0;background-color:#fff}.cropper-modal{opacity:.5;background-color:#000}.cropper-view-box{display:block;overflow:hidden;width:100%;height:100%;outline:1px solid #39f;outline-color:rgba(51,153,255,.75)}.cropper-dashed{position:absolute;display:block;opacity:.5;border:0 dashed #eee}.cropper-dashed.dashed-h{top:33.33333%;left:0;width:100%;height:33.33333%;border-top-width:1px;border-bottom-width:1px}.cropper-dashed.dashed-v{top:0;left:33.33333%;width:33.33333%;height:100%;border-right-width:1px;border-left-width:1px}.cropper-center{position:absolute;top:50%;left:50%;display:block;width:0;height:0;opacity:.75}.cropper-center:after,.cropper-center:before{position:absolute;display:block;content:" 
";background-color:#eee}.cropper-center:before{top:0;left:-3px;width:7px;height:1px}.cropper-center:after{top:-3px;left:0;width:1px;height:7px}.cropper-face,.cropper-line,.cropper-point{position:absolute;display:block;width:100%;height:100%;opacity:.1}.cropper-face{top:0;left:0;background-color:#fff}.cropper-line{background-color:#39f}.cropper-line.line-e{top:0;right:-3px;width:5px;cursor:e-resize}.cropper-line.line-n{top:-3px;left:0;height:5px;cursor:n-resize}.cropper-line.line-w{top:0;left:-3px;width:5px;cursor:w-resize}.cropper-line.line-s{bottom:-3px;left:0;height:5px;cursor:s-resize}.cropper-point{width:5px;height:5px;opacity:.75;background-color:#39f}.cropper-point.point-e{top:50%;right:-3px;margin-top:-3px;cursor:e-resize}.cropper-point.point-n{top:-3px;left:50%;margin-left:-3px;cursor:n-resize}.cropper-point.point-w{top:50%;left:-3px;margin-top:-3px;cursor:w-resize}.cropper-point.point-s{bottom:-3px;left:50%;margin-left:-3px;cursor:s-resize}.cropper-point.point-ne{top:-3px;right:-3px;cursor:ne-resize}.cropper-point.point-nw{top:-3px;left:-3px;cursor:nw-resize}.cropper-point.point-sw{bottom:-3px;left:-3px;cursor:sw-resize}.cropper-point.point-se{right:-3px;bottom:-3px;width:20px;height:20px;cursor:se-resize;opacity:1}@media (min-width:768px){.cropper-point.point-se{width:15px;height:15px}}@media (min-width:992px){.cropper-point.point-se{width:10px;height:10px}}@media (min-width:1200px){.cropper-point.point-se{width:5px;height:5px;opacity:.75}}.cropper-point.point-se:before{position:absolute;right:-50%;bottom:-50%;display:block;width:200%;height:200%;content:" ";opacity:0;background-color:#39f}.cropper-invisible{opacity:0}.cropper-bg{background-image:url("")}.cropper-hide{position:absolute;display:block;width:0;height:0}.cropper-hidden{display:none!important}.cropper-move{cursor:move}.cropper-crop{cursor:crosshair}.cropper-disabled .cropper-drag-box,.cropper-disabled .cropper-face,.cropper-disabled .cropper-line,.cropper-disabled 
.cropper-point{cursor:not-allowed} -------------------------------------------------------------------------------- /demo/cropperjs/cropper.min.js: -------------------------------------------------------------------------------- 1 | /*! 2 | * Cropper.js v1.0.0-beta.1 3 | * https://github.com/fengyuanchen/cropperjs 4 | * 5 | * Copyright (c) 2017 Fengyuan Chen 6 | * Released under the MIT license 7 | * 8 | * Date: 2017-01-21T12:28:26.786Z 9 | */ 10 | 11 | !function(t,e){"object"==typeof exports&&"undefined"!=typeof module?module.exports=e():"function"==typeof define&&define.amd?define(e):t.Cropper=e()}(this,function(){"use strict";function t(t){return rt.call(t).slice(8,-1).toLowerCase()}function e(t){return"number"==typeof t&&!isNaN(t)}function a(t){return"undefined"==typeof t}function i(t){return"object"===("undefined"==typeof t?"undefined":Z(t))&&null!==t}function o(t){if(!i(t))return!1;try{var e=t.constructor,a=e.prototype;return e&&a&&ht.call(a,"isPrototypeOf")}catch(t){return!1}}function n(e){return"function"===t(e)}function r(e){return Array.isArray?Array.isArray(e):"array"===t(e)}function h(t){return"string"==typeof t&&(t=t.trim?t.trim():t.replace(et,"$1")),t}function c(t,a){if(t&&n(a)){var o=void 0;if(r(t)||e(t.length)){var h=t.length;for(o=0;o1&&(e.shift(),e.forEach(function(t){i(t)&&Object.keys(t).forEach(function(e){o&&i(n[e])?s(!0,n[e],t[e]):n[e]=t[e]})})),n}function d(t,e){for(var a=arguments.length,i=Array(a>2?a-2:0),o=2;o-1}function u(t,a){if(e(t.length))return void c(t,function(t){u(t,a)});if(t.classList)return void t.classList.add(a);var i=h(t.className);i?i.indexOf(a)<0&&(t.className=i+" "+a):t.className=a}function m(t,a){return e(t.length)?void c(t,function(t){m(t,a)}):t.classList?void t.classList.remove(a):void(t.className.indexOf(a)>=0&&(t.className=t.className.replace(a,"")))}function f(t,a,i){return e(t.length)?void c(t,function(t){f(t,a,i)}):void(i?u(t,a):m(t,a))}function g(t){return t.replace(J,"$1-$2").toLowerCase()}function v(t,e){return 
i(t[e])?t[e]:t.dataset?t.dataset[e]:t.getAttribute("data-"+g(e))}function w(t,e,a){i(a)?t[e]=a:t.dataset?t.dataset[e]=a:t.setAttribute("data-"+g(e),a)}function b(t,e){if(i(t[e]))delete t[e];else if(t.dataset)try{delete t.dataset[e]}catch(a){t.dataset[e]=null}else t.removeAttribute("data-"+g(e))}function x(t,e,a){var i=h(e).split(_);return i.length>1?void c(i,function(e){x(t,e,a)}):void(t.removeEventListener?t.removeEventListener(e,a,!1):t.detachEvent&&t.detachEvent("on"+e,a))}function y(t,e,a,i){var o=h(e).split(_),n=a;return o.length>1?void c(o,function(e){y(t,e,a)}):(i&&(a=function(){for(var i=arguments.length,o=Array(i),r=0;r90?180-a:a)*Math.PI/180,o=Math.sin(i),n=Math.cos(i),r=t.width,h=t.height,c=t.aspectRatio,s=void 0,d=void 0;return e?(s=r/(n+o/c),d=s/c):(s=r*n+h*o,d=r*o+h*n),{width:s,height:d}}function z(t,a){var i=T("canvas"),o=i.getContext("2d"),n=0,r=0,h=a.naturalWidth,c=a.naturalHeight,s=a.rotate,d=a.scaleX,l=a.scaleY,p=e(d)&&e(l)&&(1!==d||1!==l),u=e(s)&&0!==s,m=u||p,f=h*Math.abs(d||1),g=c*Math.abs(l||1),v=void 0,w=void 0,b=void 0;return p&&(v=f/2,w=g/2),u&&(b=O({width:f,height:g,degree:s}),f=b.width,g=b.height,v=f/2,w=g/2),i.width=f,i.height=g,m&&(n=-h/2,r=-c/2,o.save(),o.translate(v,w)),u&&o.rotate(s*Math.PI/180),p&&o.scale(d,l),o.drawImage(t,Math.floor(n),Math.floor(r),Math.floor(h),Math.floor(c)),m&&o.restore(),i}function A(t,e,a){var i="",o=e;for(a+=e;o=8&&(d=n+r)))),d)for(a=e.getUint16(d,h),p=0;p
',Z="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol&&t!==Symbol.prototype?"symbol":typeof t},K=function(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")},V=function(){function t(t,e){for(var a=0;aa.width?3===e?c=a.height*h:d=a.width/h:3===e?d=a.width/h:c=a.height*h;var l={naturalWidth:n,naturalHeight:r,aspectRatio:h,width:c,height:d};l.oldLeft=l.left=(a.width-c)/2,l.oldTop=l.top=(a.height-d)/2,t.canvasData=l,t.limited=1===e||2===e,t.limitCanvas(!0,!0),t.initialImageData=s({},i),t.initialCanvasData=s({},l)},limitCanvas:function(t,e){var a=this,i=a.options,o=i.viewMode,n=a.containerData,r=a.canvasData,h=r.aspectRatio,c=a.cropBoxData,s=a.cropped&&c;if(t){var d=Number(i.minCanvasWidth)||0,l=Number(i.minCanvasHeight)||0;o>1?(d=Math.max(d,n.width),l=Math.max(l,n.height),3===o&&(l*h>d?d=l*h:l=d/h)):o>0&&(d?d=Math.max(d,s?c.width:0):l?l=Math.max(l,s?c.height:0):s&&(d=c.width,l=c.height,l*h>d?d=l*h:l=d/h)),d&&l?l*h>d?l=d/h:d=l*h:d?l=d/h:l&&(d=l*h),r.minWidth=d,r.minHeight=l,r.maxWidth=1/0,r.maxHeight=1/0}if(e)if(o){var p=n.width-r.width,u=n.height-r.height;r.minLeft=Math.min(0,p),r.minTop=Math.min(0,u),r.maxLeft=Math.max(0,p),r.maxTop=Math.max(0,u),s&&a.limited&&(r.minLeft=Math.min(c.left,c.left+(c.width-r.width)),r.minTop=Math.min(c.top,c.top+(c.height-r.height)),r.maxLeft=c.left,r.maxTop=c.top,2===o&&(r.width>=n.width&&(r.minLeft=Math.min(0,p),r.maxLeft=Math.max(0,p)),r.height>=n.height&&(r.minTop=Math.min(0,u),r.maxTop=Math.max(0,u))))}else r.minLeft=-r.width,r.minTop=-r.height,r.maxLeft=n.width,r.maxTop=n.height},renderCanvas:function(t){var e=this,a=e.canvasData,i=e.imageData,o=i.rotate,n=void 0,r=void 
0;e.rotated&&(e.rotated=!1,r=O({width:i.width,height:i.height,degree:o}),n=r.width/r.height,n!==a.aspectRatio&&(a.left-=(r.width-a.width)/2,a.top-=(r.height-a.height)/2,a.width=r.width,a.height=r.height,a.aspectRatio=n,a.naturalWidth=i.naturalWidth,a.naturalHeight=i.naturalHeight,o%180&&(r=O({width:i.naturalWidth,height:i.naturalHeight,degree:o}),a.naturalWidth=r.width,a.naturalHeight=r.height),e.limitCanvas(!0,!1))),(a.width>a.maxWidth||a.widtha.maxHeight||a.heighto.width?n.height=n.width/a:n.width=n.height*a),t.cropBoxData=n,t.limitCropBox(!0,!0),n.width=Math.min(Math.max(n.width,n.minWidth),n.maxWidth),n.height=Math.min(Math.max(n.height,n.minHeight),n.maxHeight),n.width=Math.max(n.minWidth,n.width*i),n.height=Math.max(n.minHeight,n.height*i),n.oldLeft=n.left=o.left+(o.width-n.width)/2,n.oldTop=n.top=o.top+(o.height-n.height)/2,t.initialCropBoxData=s({},n)},limitCropBox:function(t,e){var a=this,i=a.options,o=i.aspectRatio,n=a.containerData,r=a.canvasData,h=a.cropBoxData,c=a.limited;if(t){var s=Number(i.minCropBoxWidth)||0,d=Number(i.minCropBoxHeight)||0,l=Math.min(n.width,c?r.width:n.width),p=Math.min(n.height,c?r.height:n.height);s=Math.min(s,n.width),d=Math.min(d,n.height),o&&(s&&d?d*o>s?d=s/o:s=d*o:s?d=s/o:d&&(s=d*o),p*o>l?p=l/o:l=p*o),h.minWidth=Math.min(s,l),h.minHeight=Math.min(d,p),h.maxWidth=l,h.maxHeight=p}e&&(c?(h.minLeft=Math.max(0,r.left),h.minTop=Math.max(0,r.top),h.maxLeft=Math.min(n.width,r.left+r.width)-h.width,h.maxTop=Math.min(n.height,r.top+r.height)-h.height):(h.minLeft=0,h.minTop=0,h.maxLeft=n.width-h.width,h.maxTop=n.height-h.height))},renderCropBox:function(){var t=this,e=t.options,a=t.containerData,i=t.cropBoxData;(i.width>i.maxWidth||i.widthi.maxHeight||i.heightc&&(f=c/n,u=o*f,m=c),l(t,{width:u,height:m}),l(B(t,"img")[0],s({width:r*f,height:h*f},N(s({translateX:-d*f,translateY:-p*f},e))))}))}},pt="undefined"!=typeof window?window.PointerEvent:null,ut=pt?"pointerdown":"touchstart mousedown",mt=pt?"pointermove":"touchmove 
mousemove",ft=pt?" pointerup pointercancel":"touchend touchcancel mouseup",gt="wheel mousewheel DOMMouseScroll",vt="dblclick",wt="resize",bt="cropstart",xt="cropmove",yt="cropend",Mt="crop",Ct="zoom",Dt={bind:function(){var t=this,e=t.options,a=t.element,i=t.cropper;n(e.cropstart)&&y(a,bt,e.cropstart),n(e.cropmove)&&y(a,xt,e.cropmove),n(e.cropend)&&y(a,yt,e.cropend),n(e.crop)&&y(a,Mt,e.crop),n(e.zoom)&&y(a,Ct,e.zoom),y(i,ut,t.onCropStart=d(t.cropStart,t)),e.zoomable&&e.zoomOnWheel&&y(i,gt,t.onWheel=d(t.wheel,t)),e.toggleDragModeOnDblclick&&y(i,vt,t.onDblclick=d(t.dblclick,t)),y(document,mt,t.onCropMove=d(t.cropMove,t)),y(document,ft,t.onCropEnd=d(t.cropEnd,t)),e.responsive&&y(window,wt,t.onResize=d(t.resize,t))},unbind:function(){var t=this,e=t.options,a=t.element,i=t.cropper;n(e.cropstart)&&x(a,bt,e.cropstart),n(e.cropmove)&&x(a,xt,e.cropmove),n(e.cropend)&&x(a,yt,e.cropend),n(e.crop)&&x(a,Mt,e.crop),n(e.zoom)&&x(a,Ct,e.zoom),x(i,ut,t.onCropStart),e.zoomable&&e.zoomOnWheel&&x(i,gt,t.onWheel),e.toggleDragModeOnDblclick&&x(i,vt,t.onDblclick),x(document,mt,t.onCropMove),x(document,ft,t.onCropEnd),e.responsive&&x(window,wt,t.onResize)}},Bt=/^(e|w|s|n|se|sw|ne|nw|all|crop|move|zoom)$/,kt={resize:function(){var t=this,e=t.options.restore,a=t.container,i=t.containerData;if(!t.disabled&&i){var o=a.offsetWidth/i.width,n=void 0,r=void 0;1===o&&a.offsetHeight===i.height||(e&&(n=t.getCanvasData(),r=t.getCropBoxData()),t.render(),e&&(t.setCanvasData(c(n,function(t,e){n[e]=t*o})),t.setCropBoxData(c(r,function(t,e){r[e]=t*o}))))}},dblclick:function(){var t=this;t.disabled||t.setDragMode(p(t.dragBox,"cropper-crop")?"move":"crop")},wheel:function(t){var e=this,a=C(t),i=Number(e.options.wheelZoomRatio)||.1,o=1;e.disabled||(a.preventDefault(),e.wheeling||(e.wheeling=!0,setTimeout(function(){e.wheeling=!1},50),a.deltaY?o=a.deltaY>0?1:-1:a.wheelDelta?o=-a.wheelDelta/120:a.detail&&(o=a.detail>0?1:-1),e.zoom(-o*i,a)))},cropStart:function(t){var e=this;if(!e.disabled){var 
a=e.options,i=e.pointers,o=C(t),n=void 0;if(o.changedTouches?c(o.changedTouches,function(t){i[t.identifier]=U(t)}):i[o.pointerId||0]=U(o),n=Object.keys(i).length>1&&a.zoomable&&a.zoomOnTouch?"zoom":v(o.target,"action"),Bt.test(n)){if(M(e.element,"cropstart",{originalEvent:o,action:n})===!1)return;o.preventDefault(),e.action=n,e.cropping=!1,"crop"===n&&(e.cropping=!0,u(e.dragBox,"cropper-modal"))}}},cropMove:function(t){var e=this,a=e.action;if(!e.disabled&&a){var i=e.pointers,o=C(t);o.preventDefault(),M(e.element,"cropmove",{originalEvent:o,action:a})!==!1&&(o.changedTouches?c(o.changedTouches,function(t){s(i[t.identifier],U(t,!0))}):s(i[o.pointerId||0],U(o,!0)),e.change(o))}},cropEnd:function(t){var e=this,a=e.action;if(!e.disabled&&a){var i=e.pointers,o=C(t);o.preventDefault(),o.changedTouches?c(o.changedTouches,function(t){delete i[t.identifier]}):delete i[o.pointerId||0],Object.keys(i).length||(e.action=""),e.cropping&&(e.cropping=!1,f(e.dragBox,"cropper-modal",e.cropped&&this.options.modal)),M(e.element,"cropend",{originalEvent:o,action:a})}}},Tt="e",Lt="w",Wt="s",Xt="n",Yt="se",Et="sw",Ht="ne",Nt="nw",Ot={change:function(t){var e=this,a=e.options,i=e.containerData,o=e.canvasData,n=e.cropBoxData,r=a.aspectRatio,h=e.action,s=n.width,d=n.height,l=n.left,p=n.top,u=l+s,f=p+d,g=0,v=0,w=i.width,b=i.height,x=!0,y=void 0;!r&&t.shiftKey&&(r=s&&d?s/d:1),e.limited&&(g=n.minLeft,v=n.minTop,w=g+Math.min(i.width,o.width,o.left+o.width),b=v+Math.min(i.height,o.height,o.top+o.height));var M=e.pointers,C=M[Object.keys(M)[0]],B={x:C.endX-C.startX,y:C.endY-C.startY};switch(r&&(B.X=B.y*r,B.Y=B.x/r),h){case"all":l+=B.x,p+=B.y;break;case Tt:if(B.x>=0&&(u>=w||r&&(p<=v||f>=b))){x=!1;break}s+=B.x,r&&(d=s/r,p-=B.Y/2),s<0&&(h=Lt,s=0);break;case Xt:if(B.y<=0&&(p<=v||r&&(l<=g||u>=w))){x=!1;break}d-=B.y,p+=B.y,r&&(s=d*r,l+=B.X/2),d<0&&(h=Wt,d=0);break;case Lt:if(B.x<=0&&(l<=g||r&&(p<=v||f>=b))){x=!1;break}s-=B.x,l+=B.x,r&&(d=s/r,p+=B.Y/2),s<0&&(h=Tt,s=0);break;case 
Wt:if(B.y>=0&&(f>=b||r&&(l<=g||u>=w))){x=!1;break}d+=B.y,r&&(s=d*r,l-=B.X/2),d<0&&(h=Xt,d=0);break;case Ht:if(r){if(B.y<=0&&(p<=v||u>=w)){x=!1;break}d-=B.y,p+=B.y,s=d*r}else B.x>=0?uv&&(d-=B.y,p+=B.y):(d-=B.y,p+=B.y);s<0&&d<0?(h=Et,d=0,s=0):s<0?(h=Nt,s=0):d<0&&(h=Yt,d=0);break;case Nt:if(r){if(B.y<=0&&(p<=v||l<=g)){x=!1;break}d-=B.y,p+=B.y,s=d*r,l+=B.X}else B.x<=0?l>g?(s-=B.x,l+=B.x):B.y<=0&&p<=v&&(x=!1):(s-=B.x,l+=B.x),B.y<=0?p>v&&(d-=B.y,p+=B.y):(d-=B.y,p+=B.y);s<0&&d<0?(h=Yt,d=0,s=0):s<0?(h=Ht,s=0):d<0&&(h=Et,d=0);break;case Et:if(r){if(B.x<=0&&(l<=g||f>=b)){x=!1;break}s-=B.x,l+=B.x,d=s/r}else B.x<=0?l>g?(s-=B.x,l+=B.x):B.y>=0&&f>=b&&(x=!1):(s-=B.x,l+=B.x),B.y>=0?f=0&&(u>=w||f>=b)){x=!1;break}s+=B.x,d=s/r}else B.x>=0?u=0&&f>=b&&(x=!1):s+=B.x,B.y>=0?f0?h=B.y>0?Yt:Ht:B.x<0&&(l-=s,h=B.y>0?Et:Nt),B.y<0&&(p-=d),e.cropped||(m(e.cropBox,"cropper-hidden"),e.cropped=!0,e.limited&&e.limitCropBox(!0,!0))}x&&(n.width=s,n.height=d,n.left=l,n.top=p,e.action=h,e.renderCropBox()),c(M,function(t){t.startX=t.endX,t.startY=t.endY})}},zt={crop:function(){var t=this;return t.ready&&!t.disabled&&(t.cropped||(t.cropped=!0,t.limitCropBox(!0,!0),t.options.modal&&u(t.dragBox,"cropper-modal"),m(t.cropBox,"cropper-hidden")),t.setCropBoxData(t.initialCropBoxData)),t},reset:function(){var t=this;return t.ready&&!t.disabled&&(t.imageData=s({},t.initialImageData),t.canvasData=s({},t.initialCanvasData),t.cropBoxData=s({},t.initialCropBoxData),t.renderCanvas(),t.cropped&&t.renderCropBox()),t},clear:function(){var t=this;return t.cropped&&!t.disabled&&(s(t.cropBoxData,{left:0,top:0,width:0,height:0}),t.cropped=!1,t.renderCropBox(),t.limitCanvas(),t.renderCanvas(),m(t.dragBox,"cropper-modal"),u(t.cropBox,"cropper-hidden")),t},replace:function(t,e){var a=this;return!a.disabled&&t&&(a.isImg&&(a.element.src=t),e?(a.url=t,a.image.src=t,a.ready&&(a.image2.src=t,c(a.previews,function(e){B(e,"img")[0].src=t}))):(a.isImg&&(a.replaced=!0),a.options.data=null,a.load(t))),a},enable:function(){var 
t=this;return t.ready&&(t.disabled=!1,m(t.cropper,"cropper-disabled")),t},disable:function(){var t=this;return t.ready&&(t.disabled=!0,u(t.cropper,"cropper-disabled")),t},destroy:function(){var t=this,e=t.element,a=t.image;return t.loaded?(t.isImg&&t.replaced&&(e.src=t.originalUrl),t.unbuild(),m(e,"cropper-hidden")):t.isImg?x(e,"load",t.start):a&&W(a),b(e,"cropper"),t},move:function(t,e){var i=this,o=i.canvasData;return i.moveTo(a(t)?t:o.left+Number(t),a(e)?e:o.top+Number(e))},moveTo:function(t,i){var o=this,n=o.canvasData,r=!1;return a(i)&&(i=t),t=Number(t),i=Number(i),o.ready&&!o.disabled&&o.options.movable&&(e(t)&&(n.left=t,r=!0),e(i)&&(n.top=i,r=!0),r&&o.renderCanvas(!0)),o},zoom:function(t,e){var a=this,i=a.canvasData;return t=Number(t),t=t<0?1/(1-t):1+t,a.zoomTo(i.width*t/i.naturalWidth,e)},zoomTo:function(t,e){var a=this,i=a.options,o=a.canvasData,n=o.width,r=o.height,h=o.naturalWidth,c=o.naturalHeight;if(t=Number(t),t>=0&&a.ready&&!a.disabled&&i.zoomable){var s=h*t,d=c*t;if(M(a.element,"zoom",{originalEvent:e,oldRatio:n/h,ratio:s/h})===!1)return a;if(e){var l=a.pointers,p=D(a.cropper),u=l&&Object.keys(l).length?P(l):{pageX:e.pageX,pageY:e.pageY};o.left-=(s-n)*((u.pageX-p.left-o.left)/n),o.top-=(d-r)*((u.pageY-p.top-o.top)/r)}else o.left-=(s-n)/2,o.top-=(d-r)/2;o.width=s,o.height=d,a.renderCanvas(!0)}return a},rotate:function(t){var e=this;return e.rotateTo((e.imageData.rotate||0)+Number(t))},rotateTo:function(t){var a=this;return t=Number(t),e(t)&&a.ready&&!a.disabled&&a.options.rotatable&&(a.imageData.rotate=t%360,a.rotated=!0,a.renderCanvas(!0)),a},scale:function(t,i){var o=this,n=o.imageData,r=!1;return a(i)&&(i=t),t=Number(t),i=Number(i),o.ready&&!o.disabled&&o.options.scalable&&(e(t)&&(n.scaleX=t,r=!0),e(i)&&(n.scaleY=i,r=!0),r&&o.renderImage(!0)),o},scaleX:function(t){var a=this,i=a.imageData.scaleY;return a.scale(t,e(i)?i:1)},scaleY:function(t){var a=this,i=a.imageData.scaleX;return a.scale(e(i)?i:1,t)},getData:function(t){var 
e=this,a=e.options,i=e.imageData,o=e.canvasData,n=e.cropBoxData,r=void 0,h=void 0;return e.ready&&e.cropped?(h={x:n.left-o.left,y:n.top-o.top,width:n.width,height:n.height},r=i.width/i.naturalWidth,c(h,function(e,a){e/=r,h[a]=t?Math.round(e):e})):h={x:0,y:0,width:0,height:0},a.rotatable&&(h.rotate=i.rotate||0),a.scalable&&(h.scaleX=i.scaleX||1,h.scaleY=i.scaleY||1),h},setData:function(t){var a=this,i=a.options,r=a.imageData,h=a.canvasData,c={},s=void 0,d=void 0,l=void 0;return n(t)&&(t=t.call(a.element)),a.ready&&!a.disabled&&o(t)&&(i.rotatable&&e(t.rotate)&&t.rotate!==r.rotate&&(r.rotate=t.rotate,a.rotated=s=!0),i.scalable&&(e(t.scaleX)&&t.scaleX!==r.scaleX&&(r.scaleX=t.scaleX,d=!0),e(t.scaleY)&&t.scaleY!==r.scaleY&&(r.scaleY=t.scaleY,d=!0)),s?a.renderCanvas():d&&a.renderImage(),l=r.width/r.naturalWidth,e(t.x)&&(c.left=t.x*l+h.left),e(t.y)&&(c.top=t.y*l+h.top),e(t.width)&&(c.width=t.width*l),e(t.height)&&(c.height=t.height*l),a.setCropBoxData(c)),a},getContainerData:function(){var t=this;return t.ready?t.containerData:{}},getImageData:function(){var t=this;return t.loaded?t.imageData:{}},getCanvasData:function(){var t=this,e=t.canvasData,a={};return t.ready&&c(["left","top","width","height","naturalWidth","naturalHeight"],function(t){a[t]=e[t]}),a},setCanvasData:function(t){var a=this,i=a.canvasData,r=i.aspectRatio;return n(t)&&(t=t.call(a.element)),a.ready&&!a.disabled&&o(t)&&(e(t.left)&&(i.left=t.left),e(t.top)&&(i.top=t.top),e(t.width)?(i.width=t.width,i.height=t.width/r):e(t.height)&&(i.height=t.height,i.width=t.height*r),a.renderCanvas(!0)),a},getCropBoxData:function(){var t=this,e=t.cropBoxData,a=void 0;return t.ready&&t.cropped&&(a={left:e.left,top:e.top,width:e.width,height:e.height}),a||{}},setCropBoxData:function(t){var a=this,i=a.cropBoxData,r=a.options.aspectRatio,h=void 0,c=void 0;return 
n(t)&&(t=t.call(a.element)),a.ready&&a.cropped&&!a.disabled&&o(t)&&(e(t.left)&&(i.left=t.left),e(t.top)&&(i.top=t.top),e(t.width)&&t.width!==i.width&&(h=!0,i.width=t.width),e(t.height)&&t.height!==i.height&&(c=!0,i.height=t.height),r&&(h?i.height=i.width/r:c&&(i.width=i.height*r)),a.renderCropBox()),a},getCroppedCanvas:function(t){var e=this;if(!e.ready||!window.HTMLCanvasElement)return null;if(!e.cropped)return z(e.image,e.imageData);o(t)||(t={});var a=e.getData(),i=a.width,n=a.height,r=i/n,h=void 0,c=void 0,s=void 0;o(t)&&(h=t.width,c=t.height,h?(c=h/r,s=h/i):c&&(h=c*r,s=c/n));var d=Math.floor(h||i),l=Math.floor(c||n),p=T("canvas"),u=p.getContext("2d");p.width=d,p.height=l,t.fillColor&&(u.fillStyle=t.fillColor,u.fillRect(0,0,d,l));var m=function(){var t=z(e.image,e.imageData),o=t.width,r=t.height,h=e.canvasData,c=[t],d=a.x+h.naturalWidth*(Math.abs(a.scaleX||1)-1)/2,l=a.y+h.naturalHeight*(Math.abs(a.scaleY||1)-1)/2,p=void 0,u=void 0,m=void 0,f=void 0,g=void 0,v=void 0;return d<=-i||d>o?d=p=m=g=0:d<=0?(m=-d,d=0,p=g=Math.min(o,i+d)):d<=o&&(m=0,p=g=Math.min(i,o-d)),p<=0||l<=-n||l>r?l=u=f=v=0:l<=0?(f=-l,l=0,u=v=Math.min(r,n+l)):l<=r&&(f=0,u=v=Math.min(n,r-l)),c.push(Math.floor(d),Math.floor(l),Math.floor(p),Math.floor(u)),s&&(m*=s,f*=s,g*=s,v*=s),g>0&&v>0&&c.push(Math.floor(m),Math.floor(f),Math.floor(g),Math.floor(v)),c}();return u.drawImage.apply(u,F(m)),p},setAspectRatio:function(t){var e=this,i=e.options;return e.disabled||a(t)||(i.aspectRatio=Math.max(0,t)||NaN,e.ready&&(e.initCropBox(),e.cropped&&e.renderCropBox())),e},setDragMode:function(t){var e=this,a=e.options,i=e.dragBox,o=e.face,n=void 0,r=void 0;return e.loaded&&!e.disabled&&(n="crop"===t,r=a.movable&&"move"===t,t=n||r?t:"none",w(i,"action",t),f(i,"cropper-crop",n),f(i,"cropper-move",r),a.cropBoxMovable||(w(o,"action",t),f(o,"cropper-crop",n),f(o,"cropper-move",r))),e}},At="cropper",Rt=At+"-hidden",It="error",St="load",Ut="ready",jt="crop",Pt=/^data:/,qt=/^data:image\/jpeg.*;base64,/,$t=void 
0,Zt=function(){function t(e,a){K(this,t);var i=this;i.element=e,i.options=s({},q,o(a)&&a),i.loaded=!1,i.ready=!1,i.complete=!1,i.rotated=!1,i.cropped=!1,i.disabled=!1,i.replaced=!1,i.limited=!1,i.wheeling=!1,i.isImg=!1,i.originalUrl="",i.canvasData=null,i.cropBoxData=null,i.previews=null,i.pointers={},i.init()}return V(t,[{key:"init",value:function(){var t=this,e=t.element,a=e.tagName.toLowerCase(),i=void 0;if(!v(e,At)){if(w(e,At,t),"img"===a){if(t.isImg=!0,t.originalUrl=i=e.getAttribute("src"),!i)return;i=e.src}else"canvas"===a&&window.HTMLCanvasElement&&(i=e.toDataURL());t.load(i)}}},{key:"load",value:function(t){var e=this,a=e.options,i=e.element;if(t){if(e.url=t,e.imageData={},!a.checkOrientation||!window.ArrayBuffer)return void e.clone();if(Pt.test(t))return void(qt?e.read(I(t)):e.clone());var o=new XMLHttpRequest;o.onerror=o.onabort=function(){e.clone()},o.onload=function(){e.read(o.response)},a.checkCrossOrigin&&Y(t)&&i.crossOrigin&&(t=E(t)),o.open("get",t),o.responseType="arraybuffer",o.withCredentials="use-credentials"===i.crossOrigin,o.send()}}},{key:"read",value:function(t){var e=this,a=e.options,i=R(t),o=e.imageData,n=0,r=1,h=1;if(i>1)switch(e.url=S(t),i){case 2:r=-1;break;case 3:n=-180;break;case 4:h=-1;break;case 5:n=90,h=-1;break;case 6:n=90;break;case 7:n=90,r=-1;break;case 8:n=-90}a.rotatable&&(o.rotate=n),a.scalable&&(o.scaleX=r,o.scaleY=h),e.clone()}},{key:"clone",value:function(){var t=this,e=t.element,a=t.url,i=void 0,o=void 0,n=void 0,r=void 0;t.options.checkCrossOrigin&&Y(a)&&(i=e.crossOrigin,i?o=a:(i="anonymous",o=E(a))),t.crossOrigin=i,t.crossOriginUrl=o;var h=T("img");i&&(h.crossOrigin=i),h.src=o||a,t.image=h,t.onStart=n=d(t.start,t),t.onStop=r=d(t.stop,t),t.isImg?e.complete?t.start():y(e,St,n):(y(h,St,n),y(h,It,r),u(h,"cropper-hide"),e.parentNode.insertBefore(h,e.nextSibling))}},{key:"start",value:function(t){var 
e=this,a=e.isImg?e.element:e.image;t&&(x(a,St,e.onStart),x(a,It,e.onStop)),H(a,function(t,a){s(e.imageData,{naturalWidth:t,naturalHeight:a,aspectRatio:t/a}),e.loaded=!0,e.build()})}},{key:"stop",value:function(){var t=this,e=t.image;x(e,St,t.onStart),x(e,It,t.onStop),W(e),t.image=null}},{key:"build",value:function(){var t=this,e=t.options,a=t.element,i=t.image,o=void 0,r=void 0,h=void 0,c=void 0,s=void 0,d=void 0;if(t.loaded){t.ready&&t.unbuild();var l=T("div");l.innerHTML=$,t.container=o=a.parentNode,t.cropper=r=k(l,"cropper-container")[0],t.canvas=h=k(r,"cropper-canvas")[0],t.dragBox=c=k(r,"cropper-drag-box")[0],t.cropBox=s=k(r,"cropper-crop-box")[0],t.viewBox=k(r,"cropper-view-box")[0],t.face=d=k(s,"cropper-face")[0],L(h,i),u(a,Rt),o.insertBefore(r,a.nextSibling),t.isImg||m(i,"cropper-hide"),t.initPreview(),t.bind(),e.aspectRatio=Math.max(0,e.aspectRatio)||NaN,e.viewMode=Math.max(0,Math.min(3,Math.round(e.viewMode)))||0,t.cropped=e.autoCrop,e.autoCrop?e.modal&&u(c,"cropper-modal"):u(s,Rt),e.guides||u(k(s,"cropper-dashed"),Rt),e.center||u(k(s,"cropper-center"),Rt),e.background&&u(r,"cropper-bg"),e.highlight||u(d,"cropper-invisible"), 12 | e.cropBoxMovable&&(u(d,"cropper-move"),w(d,"action","all")),e.cropBoxResizable||(u(k(s,"cropper-line"),Rt),u(k(s,"cropper-point"),Rt)),t.setDragMode(e.dragMode),t.render(),t.ready=!0,t.setData(e.data),t.completing=setTimeout(function(){n(e.ready)&&y(a,Ut,e.ready,!0),M(a,Ut),M(a,jt,t.getData()),t.complete=!0},0)}}},{key:"unbuild",value:function(){var t=this;t.ready&&(t.complete||clearTimeout(t.completing),t.ready=!1,t.complete=!1,t.initialImageData=null,t.initialCanvasData=null,t.initialCropBoxData=null,t.containerData=null,t.canvasData=null,t.cropBoxData=null,t.unbind(),t.resetPreview(),t.previews=null,t.viewBox=null,t.cropBox=null,t.dragBox=null,t.canvas=null,t.container=null,W(t.cropper),t.cropper=null)}}],[{key:"noConflict",value:function(){return 
window.Cropper=$t,t}},{key:"setDefaults",value:function(t){s(q,o(t)&&t)}}]),t}();return s(Zt.prototype,st),s(Zt.prototype,lt),s(Zt.prototype,Dt),s(Zt.prototype,kt),s(Zt.prototype,Ot),s(Zt.prototype,zt),"undefined"!=typeof window&&($t=window.Cropper,window.Cropper=Zt),Zt}); -------------------------------------------------------------------------------- /demo/images/bot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/bot.jpg -------------------------------------------------------------------------------- /demo/images/crop-button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/crop-button.png -------------------------------------------------------------------------------- /demo/images/edit-logo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/edit-logo.gif -------------------------------------------------------------------------------- /demo/images/img1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/img1.jpg -------------------------------------------------------------------------------- /demo/images/img2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/img2.jpg -------------------------------------------------------------------------------- /demo/images/img3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/img3.jpg -------------------------------------------------------------------------------- /demo/images/img4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/img4.jpg -------------------------------------------------------------------------------- /demo/images/img5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/img5.jpg -------------------------------------------------------------------------------- /demo/images/in.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/in.jpg -------------------------------------------------------------------------------- /demo/images/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/logo.jpg -------------------------------------------------------------------------------- /demo/images/mid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/mid.jpg -------------------------------------------------------------------------------- /demo/images/out.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/out.jpg -------------------------------------------------------------------------------- /demo/images/out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreykurenkov/ObjectCropBot/17fc31d1e6b852a47d86a46a7eabf6f94c86ded1/demo/images/out.png -------------------------------------------------------------------------------- /demo/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: Arial, "Helvetica Neue", Helvetica, sans-serif; 3 | } 4 | 5 | .centered { 6 | text-align: center; 7 | } 8 | 9 | .container { 10 | max-width: 640px; 11 | margin: 0px auto; 12 | } 13 | 14 | #controls { 15 | width: 100%; 16 | height: 400px; 17 | } 18 | 19 | #input-container { 20 | width: 36%; 21 | height: 100%; 22 | float:left; 23 | border: 2px solid #008CBA; 24 | } 25 | 26 | .button { 27 | background-color: #4CAF50; /* Green */ 28 | border: none; 29 | color: white; 30 | text-align: center; 31 | vertical-align: middle; 32 | text-decoration: none; 33 | height: 20%; 34 | width: 20%; 35 | margin-left: 2%; 36 | margin-right: 2%; 37 | margin-top: 150px; 38 | font-size: 20px; 39 | -webkit-transition-duration: 0.4s; /* Safari */ 40 | transition-duration: 0.4s; 41 | cursor: pointer; 42 | float:left; 43 | } 44 | 45 | #out-container { 46 | width: 39%; 47 | float:left; 48 | margin: 0px auto; 49 | height: 100%; 50 | border: 2px solid #008CBA; 51 | } 52 | 53 | .loader { 54 | border: 16px solid #f3f3f3; /* Light grey */ 55 | border-top: 16px solid #3498db; /* Blue */ 56 | border-radius: 50%; 57 | width: 120px; 58 | height: 120px; 59 | animation: spin 2s linear infinite; 60 | display: none; 61 | padding: 0; 62 | margin-left: auto; 63 | margin-right: auto; 64 | margin-top: 130px; 65 | } 66 | 67 | @keyframes spin { 68 | 0% { transform: rotate(0deg); } 69 |
100% { transform: rotate(360deg); } 70 | } 71 | 72 | 73 | img { 74 | max-width: 100%; 75 | } 76 | 77 | #canvas { 78 | padding: 0; 79 | margin: auto; 80 | border: 1px solid #511; 81 | } 82 | 83 | .thumbnails img { 84 | height: 80px; 85 | border: 4px solid #555; 86 | padding: 1px; 87 | margin: 0 10px 10px 0; 88 | } 89 | 90 | .thumbnails img:hover { 91 | border: 4px solid #00ccff; 92 | cursor:pointer; 93 | } 94 | 95 | .examples img { 96 | height: 100px; 97 | margin: 0 10px 10px 0; 98 | } 99 | 100 | .preview img { 101 | border: 4px solid #444; 102 | padding: 1px; 103 | width: 800px; 104 | } 105 | 106 | .button { 107 | background-color: white; 108 | color: black; 109 | border: 2px solid #008CBA; 110 | } 111 | 112 | .button:hover { 113 | background-color: #008CBA; 114 | color: white; 115 | } 116 | 117 | .button:disabled { 118 | opacity: 0.65; 119 | cursor: not-allowed; 120 | } 121 | 122 | -------------------------------------------------------------------------------- /evalPerImage.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory.
6 | 7 | Full scene evaluation of DeepMask/SharpMask 8 | ------------------------------------------------------------------------------]] 9 | 10 | require 'torch' 11 | require 'cutorch' 12 | require 'image' 13 | 14 | local cjson = require 'cjson' 15 | local tds = require 'tds' 16 | local coco = require 'coco' 17 | 18 | paths.dofile('DeepMask.lua') 19 | paths.dofile('SharpMask.lua') 20 | 21 | -------------------------------------------------------------------------------- 22 | -- parse arguments 23 | local cmd = torch.CmdLine() 24 | cmd:text() 25 | cmd:text('full scene evaluation of DeepMask/SharpMask') 26 | cmd:text() 27 | cmd:argument('-model', 'model to load') 28 | cmd:text('Options:') 29 | cmd:option('-datadir', 'data/', 'data directory') 30 | cmd:option('-seed', 1, 'manually set RNG seed') 31 | cmd:option('-gpu', 1, 'gpu device') 32 | cmd:option('-split', 'val', 'dataset split to be used (train/val)') 33 | cmd:option('-np', 500,'number of proposals') 34 | cmd:option('-thr', .2, 'mask binary threshold') 35 | cmd:option('-save', false, 'save top proposals') 36 | cmd:option('-startAt', 1, 'start image id') 37 | cmd:option('-endAt', 5000, 'end image id') 38 | cmd:option('-smin', -2.5, 'min scale') 39 | cmd:option('-smax', .5, 'max scale') 40 | cmd:option('-sstep', .5, 'scale step') 41 | cmd:option('-timer', false, 'breakdown timer') 42 | cmd:option('-dm', false, 'use DeepMask version of SharpMask') 43 | 44 | local config = cmd:parse(arg) 45 | 46 | -------------------------------------------------------------------------------- 47 | -- various initializations 48 | torch.setdefaulttensortype('torch.FloatTensor') 49 | cutorch.setDevice(config.gpu) 50 | torch.manualSeed(config.seed) 51 | math.randomseed(config.seed) 52 | local maskApi = coco.MaskApi 53 | local meanstd = {mean={ 0.485, 0.456, 0.406 }, std={ 0.229, 0.224, 0.225 }} 54 | 55 | -------------------------------------------------------------------------------- 56 | -- load model and config 57 | print('| loading
model file... ' .. config.model) 58 | local m = torch.load(config.model..'/model.t7') 59 | local c = m.config 60 | for k,v in pairs(c) do if config[k] == nil then config[k] = v end end 61 | local epoch = 0 62 | if paths.filep(config.model..'/log') then 63 | for line in io.lines(config.model..'/log') do 64 | if string.find(line,'train') then epoch = epoch + 1 end 65 | end 66 | print(string.format('| number of examples seen until now: %d (%d epochs)', 67 | epoch*config.maxload*config.batch,epoch)) 68 | end 69 | 70 | local model = m.model 71 | model:inference(config.np) 72 | model:cuda() 73 | 74 | -------------------------------------------------------------------------------- 75 | -- directory to save results 76 | local savedir = string.format('%s/epoch=%d/',config.model,epoch) 77 | print(string.format('| saving results in %s',savedir)) 78 | os.execute(string.format('mkdir -p %s',savedir)) 79 | os.execute(string.format('mkdir -p %s/t7',savedir)) 80 | os.execute(string.format('mkdir -p %s/jsons',savedir)) 81 | if config.save then os.execute(string.format('mkdir -p %s/res',savedir)) end 82 | 83 | -------------------------------------------------------------------------------- 84 | -- create inference module 85 | local scales = {} 86 | for i = config.smin,config.smax,config.sstep do table.insert(scales,2^i) end 87 | 88 | if torch.type(model)=='nn.DeepMask' then 89 | paths.dofile('InferDeepMask.lua') 90 | elseif torch.type(model)=='nn.SharpMask' then 91 | paths.dofile('InferSharpMask.lua') 92 | else error('unsupported model type: '..torch.type(model)) end 93 | 94 | local infer = Infer{ 95 | np = config.np, 96 | scales = scales, 97 | meanstd = meanstd, 98 | model = model, 99 | iSz = config.iSz, 100 | dm = config.dm, 101 | timer = config.timer, 102 | } 103 | 104 | -------------------------------------------------------------------------------- 105 | -- get list of eval images 106 | local annFile = string.format('%s/annotations/instances_%s2014.json', 107 | config.datadir,config.split) 108 | local coco =
coco.CocoApi(annFile) 109 | local imgIds = coco:getImgIds() 110 | imgIds = imgIds:sort() 111 | 112 | -------------------------------------------------------------------------------- 113 | -- function: encode proposals 114 | local function encodeProps(props,np,img,k,masks,scores) 115 | local t = (k-1)*np 116 | local enc = maskApi.encode(masks) 117 | 118 | for i = 1, np do 119 | local elem = tds.Hash() 120 | elem.segmentation = tds.Hash(enc[i]) 121 | elem.image_id=img.id 122 | elem.category_id=1 123 | elem.score=scores[i][1] 124 | 125 | props[t+i] = elem 126 | end 127 | end 128 | 129 | -------------------------------------------------------------------------------- 130 | -- function: convert props to json and save 131 | local function saveProps(props,savedir,s,e) 132 | --t7 133 | local pathsvt7 = string.format('%s/t7/props-%d-%d.t7', savedir,s,e) 134 | torch.save(pathsvt7,props) 135 | --json 136 | local pathsvjson = string.format('%s/jsons/props-%d-%d.json', savedir,s,e) 137 | local propsjson = {} 138 | for _,prop in pairs(props) do -- hash2table 139 | local elem = {} 140 | elem.category_id = prop.category_id 141 | elem.image_id = prop.image_id 142 | elem.score = prop.score 143 | elem.segmentation={ 144 | size={prop.segmentation.size[1],prop.segmentation.size[2]}, 145 | counts = prop.segmentation.counts or prop.segmentation.count 146 | } 147 | table.insert(propsjson,elem) 148 | end 149 | local jsonText = cjson.encode(propsjson) 150 | local f = assert(io.open(pathsvjson,'w')); f:write(jsonText); f:close() 151 | collectgarbage() 152 | end 153 | 154 | -------------------------------------------------------------------------------- 155 | -- function: read image 156 | local function readImg(datadir,split,fileName) 157 | local pathImg = string.format('%s/%s2014/%s',datadir,split,fileName) 158 | local inp = image.load(pathImg,3) 159 | return inp 160 | end 161 | 162 | -------------------------------------------------------------------------------- 163 | -- run 164 | print('| start
eval') 165 | local props, svcount = tds.Hash(), config.startAt; config.endAt = math.min(config.endAt, imgIds:numel()) -- never index past the last image 166 | for k = config.startAt,config.endAt do 167 | xlua.progress(k,config.endAt) 168 | 169 | -- load image 170 | local img = coco:loadImgs(imgIds[k])[1] 171 | local input = readImg(config.datadir,config.split,img.file_name) 172 | local h,w = img.height,img.width 173 | 174 | -- forward all scales 175 | infer:forward(input) 176 | 177 | -- get top proposals 178 | local masks,scores = infer:getTopProps(config.thr,h,w) 179 | 180 | -- encode proposals 181 | encodeProps(props,config.np,img,k,masks,scores) 182 | 183 | -- save top masks? 184 | if config.save then 185 | local res = input:clone() 186 | maskApi.drawMasks(res, masks, 10) 187 | image.save(string.format('%s/res/%d.jpg',savedir,k),res) 188 | end 189 | 190 | -- save proposals 191 | if k%500 == 0 then 192 | saveProps(props,savedir,svcount,k); props = tds.Hash(); collectgarbage() 193 | svcount = k + 1 -- next chunk starts right after image k (correct even when startAt is not 1) 194 | end 195 | 196 | collectgarbage() 197 | end 198 | if #props > 0 then saveProps(props,savedir,svcount,config.endAt) end -- flush the last partial chunk 199 | if config.timer then infer:printTiming() end 200 | collectgarbage() 201 | print('| finish') 202 | -------------------------------------------------------------------------------- /evalPerPatch.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 3 | This source code is licensed under the BSD-style license found in the 4 | LICENSE file in the root directory of this source tree. An additional grant 5 | of patent rights can be found in the PATENTS file in the same directory.
6 | 7 | Per patch evaluation of DeepMask/SharpMask 8 | ------------------------------------------------------------------------------]] 9 | 10 | require 'torch' 11 | require 'cutorch' 12 | 13 | paths.dofile('DeepMask.lua') 14 | paths.dofile('SharpMask.lua') 15 | 16 | -------------------------------------------------------------------------------- 17 | -- parse arguments 18 | local cmd = torch.CmdLine() 19 | cmd:text() 20 | cmd:text('per patch evaluation of DeepMask/SharpMask') 21 | cmd:text() 22 | cmd:argument('-model', 'model to load') 23 | cmd:text('Options:') 24 | cmd:option('-seed', 1, 'Manually set RNG seed') 25 | cmd:option('-gpu', 1, 'gpu device') 26 | cmd:option('-maxload', 5000, 'max number of training batches per epoch') 27 | cmd:option('-testmaxload', 5000, 'max number of testing batches') 28 | cmd:option('-save', false, 'save output') 29 | 30 | local config = cmd:parse(arg) 31 | 32 | -------------------------------------------------------------------------------- 33 | -- various initializations 34 | torch.setdefaulttensortype('torch.FloatTensor') 35 | cutorch.setDevice(config.gpu) 36 | torch.manualSeed(config.seed) 37 | math.randomseed(config.seed) 38 | local inputs = torch.CudaTensor() 39 | 40 | -------------------------------------------------------------------------------- 41 | -- loading model and config 42 | print('| loading model file...' ..
config.model) 43 | local m = torch.load(config.model..'/model.t7') 44 | local c = m.config 45 | for k,v in pairs(c) do if config[k] == nil then config[k] = v end end 46 | local epoch = 0 47 | if paths.filep(config.model..'/log') then 48 | for line in io.lines(config.model..'/log') do 49 | if string.find(line,'train') then epoch = epoch + 1 end 50 | end 51 | print(string.format('| number of examples seen until now: %d (%d epochs)', 52 | epoch*config.maxload*config.batch,epoch)) 53 | end 54 | config.hfreq = 0 -- only evaluate masks 55 | 56 | local model = m.model 57 | if torch.type(model)=='nn.DeepMask' then 58 | model=nn.Sequential():add(model.trunk):add(model.maskBranch) 59 | end 60 | model:evaluate() 61 | 62 | -------------------------------------------------------------------------------- 63 | -- directory to save results 64 | local savedir 65 | if config.save then 66 | require 'image' 67 | savedir = string.format('%s/epoch=%d/res-patch/',config.model,epoch) 68 | os.execute(string.format('mkdir -p %s',savedir)) 69 | end 70 | 71 | -------------------------------------------------------------------------------- 72 | -- initialize data provider and mask meter 73 | local DataLoader = paths.dofile('DataLoader.lua') 74 | local trainLoader, valLoader = DataLoader.create(config) 75 | local loader 76 | if config.loadfromtrain then 77 | loader = trainLoader 78 | else 79 | loader = valLoader 80 | end 81 | paths.dofile('trainMeters.lua') 82 | local maskmeter = IouMeter(0.5,math.max(config.maxload,config.testmaxload)*config.batch) -- reused for the train and val passes, so size it for the larger one 83 | 84 | -------------------------------------------------------------------------------- 85 | -- function display output 86 | local function saveRes(input,target,output,savedir,n) 87 | local batch,h,w = target:size(1),config.gSz,config.gSz 88 | 89 | local input,target,output = input:float(),target:float(),output:float() 90 | input = input:narrow(3,16,config.iSz):narrow(4,16,config.iSz) 91 | output:mul(-1):exp():add(1):pow(-1) -- transform outs in probability 92 | output =
output:view(batch,h,w) 93 | 94 | local imgRGB = torch.Tensor(batch,3,h,w):zero() 95 | local outJet = torch.Tensor(batch,3,h,w):zero() 96 | 97 | for b = 1, batch do 98 | imgRGB:narrow(1,b,1):copy(image.scale(input[b],w,h)) 99 | local oj = torch.floor(output[b]*100):add(1):double() 100 | oj = image.scale(oj,w,h); oj = image.y2jet(oj) 101 | outJet:narrow(1,b,1):copy(oj) 102 | local mask = image.scale(target[b],w,h):ge(0):double() 103 | local me = image.erode(mask,torch.DoubleTensor(3,3):fill(1)) 104 | local md = image.dilate(mask,torch.DoubleTensor(3,3):fill(1)) 105 | local maskf = md - me 106 | maskf = maskf:eq(1) 107 | local im = imgRGB[b]; im:add(-im:min()); if im:max() > 0 then im:div(im:max()) end -- rescale this image alone to [0,1] (guard constant images) 108 | imgRGB[b][1][maskf]=1; imgRGB[b][2][maskf]=0; imgRGB[b][3][maskf]=0 109 | end 110 | 111 | -- concatenate 112 | local res = torch.Tensor(3,h*batch,w*2):zero() 113 | for b = 1, batch do 114 | res:narrow(2,(b-1)*h+1,h):narrow(3,1,w):copy(imgRGB[b]) 115 | res:narrow(2,(b-1)*h+1,h):narrow(3,w+1,w):copy(outJet[b]) 116 | end 117 | 118 | image.save(string.format('%s/%d.jpg',savedir,n),res) 119 | end 120 | 121 | local names = {'train','test'} 122 | local limits = {config.maxload,config.testmaxload} -- batches per pass, indexed like names 123 | 124 | for i,loader in ipairs{trainLoader,valLoader} do 125 | -------------------------------------------------------------------------------- 126 | -- start evaluation 127 | print(string.format('| start per batch evaluation for %s set',names[i])) 128 | sys.tic() 129 | maskmeter:reset() 130 | for n, sample in loader:run() do 131 | xlua.progress(n,limits[i]) 132 | 133 | -- copy input and target to the GPU 134 | inputs:resize(sample.inputs:size()):copy(sample.inputs) 135 | 136 | -- infer mask in batch 137 | local output = model:forward(inputs):float() 138 | cutorch.synchronize() 139 | output = output:view(sample.labels:size()) 140 | 141 | -- compute IoU 142 | maskmeter:add(output,sample.labels) 143 | 144 | -- save?
145 | if config.save then 146 | saveRes(sample.inputs, sample.labels, output, savedir, n) 147 | end 148 | collectgarbage() 149 | end 150 | 151 | -------------------------------------------------------------------------------- 152 | -- log 153 | print('Results:') 154 | 155 | local log = string.format( 156 | '| # samples: %d\n'.. 157 | '| samples/s %7d '.. 158 | '| mean %06.2f median %06.2f '.. 159 | 'iou@.5 %06.2f iou@.7 %06.2f ', 160 | maskmeter.n,config.batch*config.testmaxload/sys.toc(), 161 | maskmeter:value('mean'),maskmeter:value('median'), 162 | maskmeter:value('0.5'), maskmeter:value('0.7') 163 | ) 164 | print(log) 165 | print('----------------------------------------------') 166 | end 167 | print('| finish') 168 | cutorch.synchronize() 169 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Object Crop Bot 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 |
14 |
15 |
16 |

A fun lil' thing that crops out objects from an image using AI.

17 | 18 |
19 |
20 | No Image Loaded 21 |
22 | 25 |
26 |
27 | 28 |
29 | 30 |
31 |
32 | 33 |
34 |

Select an image to play with from below, or upload your own: 35 |

36 | 37 |
38 | 39 | 40 | 41 | 42 | 43 |
44 |
45 | 46 | Move along... 47 |
48 |

Brought to you by Andrey Kurenkov. Built at TreeHacks, Stanford (mostly over one sleepless night). Cropping AI "DeepMask" entirely by Facebook, great cropping UI by Fengyuan Chen. Code on Github.

49 | 50 | 51 | 52 | 53 | 54 |
-------------------------------------------------------------------------------- /modelUtils.lua: --------------------------------------------------------------------------------
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.

Utility functions for models
------------------------------------------------------------------------------]]

local utils = {}

--------------------------------------------------------------------------------
-- BN modules (e.g. in a ResNet trunk) are turned into SpatialConstDiagonal
-- layers with inn's BNtoFixed. The inn modules are loaded first so that
-- inn.ConstAffine exists; nn.SpatialConstDiagonal is declared as its subclass
-- only if no such class is registered yet.
local inn = require 'inn'
local innutils = require 'inn.utils'
if not nn.SpatialConstDiagonal then
  torch.class('nn.SpatialConstDiagonal', 'inn.ConstAffine')
end
utils.BNtoFixed = innutils.BNtoFixed

--------------------------------------------------------------------------------
-- function: linear2convTrunk
-- Rewrite `net` (via net:replace) so that it can run fully convolutionally:
--   * every Linear layer becomes an fSz x fSz cudnn.SpatialConvolution (stride
--     1) whose kernel is the Linear weight reshaped to (nOut,nInp,fSz,fSz);
--     gradWeight receives the same values and the bias is copied over,
--   * every Threshold layer becomes cudnn.ReLU,
--   * every View or SpatialZeroPadding layer becomes nn.Identity,
--   * anything else is kept as is.
-- Returns whatever net:replace returns.
function utils.linear2convTrunk(net,fSz)
  local function toConv(layer)
    local kind = torch.typename(layer)

    if kind:find('Linear') then
      local nOut = layer.weight:size(1)
      local nInp = layer.weight:size(2)/(fSz*fSz)
      local kernel = torch.reshape(layer.weight,nOut,nInp,fSz,fSz)
      local conv = cudnn.SpatialConvolution(nInp,nOut,fSz,fSz,1,1)
      conv.weight:copy(kernel)
      conv.gradWeight:copy(kernel)
      conv.bias:copy(layer.bias)
      return conv
    end

    if kind:find('Threshold') then
      return cudnn.ReLU()
    end

    if kind:find('View') or kind:find('SpatialZeroPadding') then
      return nn.Identity()
    end

    return layer
  end

  return net:replace(toConv)
end
--------------------------------------------------------------------------------
-- function: linear2convHead
-- Rewrite a head network (via net:replace) for fully convolutional use:
--   * every Linear layer becomes a 1x1 cudnn.SpatialConvolution (stride 1)
--     carrying the same weights (also copied into gradWeight) and bias,
--   * every Threshold layer becomes cudnn.ReLU,
--   * View and Copy layers are dropped: the callback returns nil for them,
--   * anything else is kept as is.
-- Returns whatever net:replace returns.
function utils.linear2convHead(net)
  local function toConv(layer)
    local kind = torch.typename(layer)

    if kind:find('Linear') then
      local nOut, nInp = layer.weight:size(1), layer.weight:size(2)
      local kernel = torch.reshape(layer.weight,nOut,nInp,1,1)
      local conv = cudnn.SpatialConvolution(nInp,nOut,1,1,1,1)
      conv.weight:copy(kernel)
      conv.gradWeight:copy(kernel)
      conv.bias:copy(layer.bias)
      return conv
    end

    if kind:find('Threshold') then
      return cudnn.ReLU()
    end

    if kind:find('View') or kind:find('Copy') then
      return nil
    end

    return layer
  end

  return net:replace(toConv)
end

--------------------------------------------------------------------------------
-- function: updatePadding
-- Move the built-in padding of 3x3 and 7x7 convolutions into a separate
-- module built by nn_padding(left,right,top,bottom), e.g. a mirror/symmetric
-- padding layer.
-- Containers (nn.Sequential, nn.ConcatTable) are walked recursively from the
-- last child to the first. Each matching convolution gets padW/padH reset to 0
-- and is replaced in its parent by
--   nn.Sequential():add(nn_padding(padW,padW,padH,padH)):add(conv):cuda().
-- Returns {padW,padH} for a matching convolution (so the caller can wrap it),
-- and -1 otherwise.
function utils.updatePadding(net, nn_padding)
  local kind = torch.typename(net)

  if kind == "nn.Sequential" or kind == "nn.ConcatTable" then
    for idx = #net.modules, 1, -1 do
      local pad = utils.updatePadding(net:get(idx), nn_padding)
      if pad ~= -1 then
        local padW, padH = pad[1], pad[2]
        local wrapper = nn.Sequential()
        wrapper:add(nn_padding(padW,padW,padH,padH))
        wrapper:add(net.modules[idx])
        net.modules[idx] = wrapper:cuda()
      end
    end
    return -1
  end

  local isConv = kind == "nn.SpatialConvolution" or
    kind == "cudnn.SpatialConvolution"
  if isConv then
    local is3x3 = net.kW == 3 and net.kH == 3
    local is7x7 = net.kW == 7 and net.kH == 7
    if is3x3 or is7x7 then
      local padW, padH = net.padW, net.padH
      net.padW, net.padH = 0, 0
      return {padW,padH}
    end
  end

  return -1
end

return utils
-------------------------------------------------------------------------------- /train.lua: --------------------------------------------------------------------------------
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.

Train DeepMask or SharpMask
------------------------------------------------------------------------------]]

require 'torch'
require 'cutorch'

--------------------------------------------------------------------------------
-- parse arguments
local cmd = torch.CmdLine()
cmd:text()
cmd:text('train DeepMask or SharpMask')
cmd:text()
cmd:text('Options:')
cmd:option('-rundir', 'exps/', 'experiments directory')
cmd:option('-datadir', 'data/', 'data directory')
cmd:option('-seed', 1, 'manually set RNG seed')
cmd:option('-gpu', 1, 'gpu device')
cmd:option('-nthreads', 2, 'number of threads for DataSampler')
cmd:option('-reload', '', 'reload a network from given directory')
cmd:text()
cmd:text('Training Options:')
cmd:option('-batch', 32, 'training batch size')
cmd:option('-lr', 0, 'learning rate (0 uses default lr schedule)')
cmd:option('-momentum', 0.9, 'momentum')
cmd:option('-wd', 5e-4, 'weight decay')
cmd:option('-maxload', 2000, 'max number of training batches per epoch')
cmd:option('-testmaxload', 500, 'max number of testing batches')
cmd:option('-maxepoch', 300, 'max number of training epochs')
cmd:option('-iSz', 160, 'input size')
cmd:option('-oSz', 56, 'output size')
cmd:option('-gSz', 112, 'ground truth size')
cmd:option('-shift', 16, 'shift jitter allowed')
cmd:option('-scale', .25, 'scale jitter allowed')
cmd:option('-hfreq', 0.5, 'mask/score head sampling frequency')
cmd:option('-scratch', false, 'train DeepMask with randomly initialize weights')
cmd:text()
cmd:text('SharpMask Options:')
cmd:option('-dm', '', 'path to trained deepmask (if dm, then train SharpMask)')
cmd:option('-km', 32, 'km')
cmd:option('-ks', 32, 'ks')

local config = cmd:parse(arg)

--------------------------------------------------------------------------------
-- various initializations
torch.setdefaulttensortype('torch.FloatTensor')
cutorch.setDevice(config.gpu)
torch.manualSeed(config.seed)
math.randomseed(config.seed)

local trainSm -- flag to train SharpMask (true) or DeepMask (false)
if #config.dm > 0 then
  trainSm = true
  config.hfreq = 0 -- train only mask head
  config.gSz = config.iSz -- in sharpmask, ground-truth has same dim as input
end

paths.dofile('DeepMask.lua')
if trainSm then paths.dofile('SharpMask.lua') end

--------------------------------------------------------------------------------
-- reload?
-- The saved model and its saved config replace the freshly parsed ones.
local epoch, model
if #config.reload > 0 then
  -- number of completed training epochs = lines containing 'train' in the
  -- log (NOTE(review): assumes the Trainer writes one such line per epoch)
  local completed = 0
  if paths.filep(config.reload..'/log') then
    for line in io.lines(config.reload..'/log') do
      if string.find(line,'train') then completed = completed + 1 end
    end
  end
  -- resume at the epoch after the last completed one; a fresh run starts at
  -- epoch 1, so a reload with nothing completed must start at 1 as well
  epoch = completed + 1
  print(string.format('| reloading experiment %s', config.reload))
  local m = torch.load(string.format('%s/model.t7', config.reload))
  model, config = m.model, m.config
end

--------------------------------------------------------------------------------
-- directory to save log and model
local pathsv = trainSm and 'sharpmask/exp' or 'deepmask/exp'
config.rundir = cmd:string(
  paths.concat(config.reload=='' and config.rundir or config.reload, pathsv),
  config,{rundir=true, gpu=true, reload=true, datadir=true, dm=true} --ignore
)

print(string.format('| running in directory %s', config.rundir))
os.execute(string.format('mkdir -p %s',config.rundir))

--------------------------------------------------------------------------------
-- network and criterion
model = model or (trainSm and nn.SharpMask(config) or nn.DeepMask(config))
local criterion = nn.SoftMarginCriterion():cuda()

--------------------------------------------------------------------------------
-- initialize data loader
local DataLoader = paths.dofile('DataLoader.lua')
local trainLoader, valLoader = DataLoader.create(config)

--------------------------------------------------------------------------------
-- initialize Trainer (handles training/testing loop)
if trainSm then
  paths.dofile('TrainerSharpMask.lua')
else
  paths.dofile('TrainerDeepMask.lua')
end
local trainer = Trainer(model, criterion, config)

--------------------------------------------------------------------------------
-- do it: config.maxepoch epochs from `epoch` on, testing every second one
epoch = epoch or 1
print('| start training')
for i = 1, config.maxepoch do
  trainer:train(epoch,trainLoader)
  if i%2 == 0 then trainer:test(epoch,valLoader) end
  epoch = epoch + 1
end
-------------------------------------------------------------------------------- /trainMeters.lua: --------------------------------------------------------------------------------
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.

Contains the three metrics used during training/evaluation:
  - lossmeter: measure the average loss.
  - binarymeter: measure the accuracy (in %) of the predicted objectness score
    against the ground-truth objectness annotation.
  - ioumeter: measure the IoU between inferred and ground-truth masks.
12 | ------------------------------------------------------------------------------]] 13 | 14 | -------------------------------------------------------------------------------- 15 | -- loss meter 16 | do 17 | local LossMeter = torch.class('LossMeter') 18 | -- init 19 | function LossMeter:__init() 20 | self:reset() 21 | end 22 | 23 | -- function: reset 24 | function LossMeter:reset() 25 | self.sum = 0; self.n = 0 26 | end 27 | 28 | -- function: add 29 | function LossMeter:add(value,n) 30 | n = n or 1 31 | self.sum = self.sum + value 32 | self.n = self.n + n 33 | end 34 | 35 | -- function: value 36 | function LossMeter:value() 37 | return self.sum / self.n 38 | end 39 | end 40 | 41 | -------------------------------------------------------------------------------- 42 | -- binary meter 43 | do 44 | local BinaryMeter = torch.class('BinaryMeter') 45 | -- init 46 | function BinaryMeter:__init() 47 | self:reset() 48 | end 49 | -- function: reset 50 | function BinaryMeter:reset() 51 | self.acc = 0; self.n = 0 52 | end 53 | 54 | -- function: add 55 | function BinaryMeter:add(output, target) 56 | target, output = target:squeeze(), output:squeeze() 57 | assert(output:nElement() == target:nElement(), 58 | 'target and output do not match') 59 | 60 | local acc = torch.cmul(output,target) 61 | self.acc = self.acc + acc:ge(0):sum() 62 | self.n = self.n + output:size(1) 63 | end 64 | 65 | -- function: value 66 | function BinaryMeter:value() 67 | local res = self.acc/self.n 68 | return res*100 69 | end 70 | end 71 | 72 | -------------------------------------------------------------------------------- 73 | -- iou meter 74 | do 75 | local IouMeter = torch.class('IouMeter') 76 | -- init 77 | function IouMeter:__init(thr,sz) 78 | self.sz = sz 79 | self.iou = torch.Tensor(sz) 80 | self.thr = math.log(thr/(1-thr)) 81 | self:reset() 82 | end 83 | 84 | -- function: reset 85 | function IouMeter:reset() 86 | self.iou:zero(); self.n = 0 87 | end 88 | 89 | -- function: add 90 | function 
IouMeter:add(output, target) 91 | target, output = target:squeeze():float(), output:squeeze():float() 92 | assert(output:nElement() == target:nElement(), 93 | 'target and output do not match') 94 | 95 | local batch,h,w = output:size(1),output:size(2),output:size(3) 96 | local nOuts = h*w 97 | local iouptr = self.iou:data() 98 | 99 | local int,uni 100 | local pred = output:ge(self.thr) 101 | local pPtr,tPtr = pred:data(), target:data() 102 | for b = 0,batch-1 do 103 | int,uni = 0,0 104 | for i = 0,nOuts-1 do 105 | local id = b*nOuts+i 106 | if pPtr[id] == 1 and tPtr[id] == 1 then int = int + 1 end 107 | if pPtr[id] == 1 or tPtr[id] == 1 then uni = uni + 1 end 108 | end 109 | if uni > 0 then iouptr[self.n+b] = int/uni end 110 | end 111 | self.n = self.n + batch 112 | end 113 | 114 | -- function: value 115 | function IouMeter:value(s) 116 | if s then 117 | local res 118 | local nb = math.max(self.iou:ne(0):sum(),1) 119 | local iou = self.iou:narrow(1,1,nb) 120 | if s == 'mean' then 121 | res = iou:mean() 122 | elseif s == 'median' then 123 | res = iou:median():squeeze() 124 | elseif tonumber(s) then 125 | local iouSort, _ = iou:sort() 126 | res = iouSort:ge(tonumber(s)):sum()/nb 127 | elseif s == 'hist' then 128 | res = torch.histc(iou,20)/nb 129 | end 130 | 131 | return res*100 132 | else 133 | local value = {} 134 | for _,s in ipairs(self.stats) do 135 | value[s] = self:value(s) 136 | end 137 | return value 138 | end 139 | end 140 | end 141 | --------------------------------------------------------------------------------