├── lib.pyc ├── test.pyc ├── val.pyc ├── config.pyc ├── netdef.pyc ├── train.pyc ├── val ├── 7_7.png ├── 7_8.png ├── 7_9.png ├── 7_11.png └── 7_12.png ├── epoch-4.13.png ├── epoch-4.22.png ├── 0.90xx-80epoch.png ├── confusionTree.png ├── trainInterface.pyc ├── deep-unet-short.pdf ├── predictInterface.pyc ├── confusion Tree ├── confusionTree-balanced ├── confusionTree-2nd ├── confusionMatrix └── confusionTree ├── confusionTree ├── README.md ├── lib.py ├── train.py ├── config.py ├── findBestEpoch.py ├── predictInterface.py ├── test.py ├── val.py ├── netdef.py └── trainInterface.py /lib.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/lib.pyc -------------------------------------------------------------------------------- /test.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/test.pyc -------------------------------------------------------------------------------- /val.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/val.pyc -------------------------------------------------------------------------------- /config.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/config.pyc -------------------------------------------------------------------------------- /netdef.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/netdef.pyc -------------------------------------------------------------------------------- /train.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/train.pyc 
-------------------------------------------------------------------------------- /val/7_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/val/7_7.png -------------------------------------------------------------------------------- /val/7_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/val/7_8.png -------------------------------------------------------------------------------- /val/7_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/val/7_9.png -------------------------------------------------------------------------------- /val/7_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/val/7_11.png -------------------------------------------------------------------------------- /val/7_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/val/7_12.png -------------------------------------------------------------------------------- /epoch-4.13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/epoch-4.13.png -------------------------------------------------------------------------------- /epoch-4.22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/epoch-4.22.png -------------------------------------------------------------------------------- /0.90xx-80epoch.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/0.90xx-80epoch.png -------------------------------------------------------------------------------- /confusionTree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/confusionTree.png -------------------------------------------------------------------------------- /trainInterface.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/trainInterface.pyc -------------------------------------------------------------------------------- /deep-unet-short.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/deep-unet-short.pdf -------------------------------------------------------------------------------- /predictInterface.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BigbyNick/TreeSegNet/HEAD/predictInterface.pyc -------------------------------------------------------------------------------- /confusion Tree/confusionTree-balanced: -------------------------------------------------------------------------------- 1 | (dp0 2 | (I0 3 | I1 4 | I2 5 | tp1 6 | (dp2 7 | (I1 8 | I2 9 | tp3 10 | (dp4 11 | (I2 12 | tp5 13 | (dp6 14 | s(I1 15 | tp7 16 | (dp8 17 | ss(I0 18 | tp9 19 | (dp10 20 | ss(I3 21 | I4 22 | I5 23 | tp11 24 | (dp12 25 | (I4 26 | I5 27 | tp13 28 | (dp14 29 | (I5 30 | tp15 31 | (dp16 32 | s(I4 33 | tp17 34 | (dp18 35 | ss(I3 36 | tp19 37 | (dp20 38 | ss. 
The TreeSegNet project, evaluated on the ISPRS 2D semantic labeling Potsdam dataset.
# -*- coding: utf-8 -*-
"""Bootstrap module: locate the shared library directory deep517/lib and
make it importable, then re-export its public interface.

Improvement over the original: the redundant ``absLibpPath is None`` /
``absLibpPath or ...`` double check is collapsed into a single guard, and
the magic ``+7`` offset is spelled as ``len('deep517')``.
"""
import sys
from os.path import abspath, join

# Absolute path of deep517/lib/ — set this manually to override auto-detection.
absLibpPath = None

if absLibpPath is None:
    # Derive it from this file's own location: keep everything up to and
    # including the 'deep517' directory, then descend into 'lib'.
    _path = abspath(__file__)
    absLibpPath = join(_path[:_path.index('deep517') + len('deep517')], 'lib')

if absLibpPath not in sys.path:
    # Prepend so the bundled library shadows any same-named installed package.
    sys.path = [absLibpPath] + sys.path

# Re-export the library's public names for the rest of the project.
from yllibInterface import *
import configManager

if __name__ == '__main__':
    print(absLibpPath)
S'\x05\x1a@\x03\x00\x00\x00\x00o\x80\x12\x00\x00\x00\x00\x00\xf6\xae,\x00\x00\x00\x00\x00\xa1_\x19\x00\x00\x00\x00\x00\x8f\xe4\x03\x00\x00\x00\x00\x00\xa9\xe6\x17\x00\x00\x00\x00\x00}\xb7\x14\x00\x00\x00\x00\x00KP\x1a\x03\x00\x00\x00\x00v6\x0b\x00\x00\x00\x00\x00ni\x03\x00\x00\x00\x00\x00\xb4\x04\x00\x00\x00\x00\x00\x00\x06\x9a\x08\x00\x00\x00\x00\x00DD\x19\x00\x00\x00\x00\x00\x1f\xcf\x07\x00\x00\x00\x00\x00\x12mY\x01\x00\x00\x00\x00E\r7\x00\x00\x00\x00\x00\xbd\x0e\x00\x00\x00\x00\x00\x007\x83\x0b\x00\x00\x00\x00\x00b\x98\n\x00\x00\x00\x00\x00\xee3\x02\x00\x00\x00\x00\x00N\xb0\x1a\x00\x00\x00\x00\x00\xb5\xfa#\x01\x00\x00\x00\x00X\xbc\x00\x00\x00\x00\x00\x00\x11\x9e\x01\x00\x00\x00\x00\x00\x9e\xbb\x03\x00\x00\x00\x00\x00\xa9&\x00\x00\x00\x00\x00\x00\xd3\x0e\x00\x00\x00\x00\x00\x00\x13R\x02\x00\x00\x00\x00\x00`11\x00\x00\x00\x00\x00\x16g\x01\x00\x00\x00\x00\x00U%\x13\x00\x00\x00\x00\x00\x88)\n\x00\x00\x00\x00\x00\x9a\x07\t\x00\x00\x00\x00\x00\xc7\xf7\x05\x00\x00\x00\x00\x00\xb0N\x00\x00\x00\x00\x00\x00\\pO\x00\x00\x00\x00\x00' 35 | p13 36 | tp14 37 | b. 
-------------------------------------------------------------------------------- /confusion Tree/confusionTree: -------------------------------------------------------------------------------- 1 | (dp0 2 | (I0 3 | cnumpy.core.multiarray 4 | scalar 5 | p1 6 | (cnumpy 7 | dtype 8 | p2 9 | (S'i8' 10 | p3 11 | I0 12 | I1 13 | tp4 14 | Rp5 15 | (I3 16 | S'<' 17 | p6 18 | NNNI-1 19 | I-1 20 | I0 21 | tp7 22 | bS'\x01\x00\x00\x00\x00\x00\x00\x00' 23 | p8 24 | tp9 25 | Rp10 26 | g1 27 | (g5 28 | S'\x02\x00\x00\x00\x00\x00\x00\x00' 29 | p11 30 | tp12 31 | Rp13 32 | g1 33 | (g5 34 | S'\x03\x00\x00\x00\x00\x00\x00\x00' 35 | p14 36 | tp15 37 | Rp16 38 | g1 39 | (g5 40 | S'\x05\x00\x00\x00\x00\x00\x00\x00' 41 | p17 42 | tp18 43 | Rp19 44 | tp20 45 | (dp21 46 | (I0 47 | g1 48 | (g5 49 | S'\x02\x00\x00\x00\x00\x00\x00\x00' 50 | p22 51 | tp23 52 | Rp24 53 | g1 54 | (g5 55 | S'\x03\x00\x00\x00\x00\x00\x00\x00' 56 | p25 57 | tp26 58 | Rp27 59 | g1 60 | (g5 61 | S'\x05\x00\x00\x00\x00\x00\x00\x00' 62 | p28 63 | tp29 64 | Rp30 65 | tp31 66 | (dp32 67 | (I0 68 | g1 69 | (g5 70 | S'\x02\x00\x00\x00\x00\x00\x00\x00' 71 | p33 72 | tp34 73 | Rp35 74 | g1 75 | (g5 76 | S'\x03\x00\x00\x00\x00\x00\x00\x00' 77 | p36 78 | tp37 79 | Rp38 80 | tp39 81 | (dp40 82 | (I0 83 | tp41 84 | (dp42 85 | s(I2 86 | g1 87 | (g5 88 | S'\x03\x00\x00\x00\x00\x00\x00\x00' 89 | p43 90 | tp44 91 | Rp45 92 | tp46 93 | (dp47 94 | (I2 95 | tp48 96 | (dp49 97 | s(I3 98 | tp50 99 | (dp51 100 | sss(I5 101 | tp52 102 | (dp53 103 | ss(I1 104 | tp54 105 | (dp55 106 | ss(I4 107 | tp56 108 | (dp57 109 | s. 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | from lib import * 4 | import sys,os 5 | import lib 6 | from lib import dicto, glob, getArgvDic,filename 7 | from lib import show, loga, logl, imread, imsave 8 | from configManager import (getImgGtNames, indexOf, readgt, readimg, 9 | setMod, togt, toimg, makeTrainEnv) 10 | 11 | from config import c, cf 12 | 13 | setMod('train') 14 | 15 | from configManager import args 16 | args.names = getImgGtNames(c.names)[:] 17 | args.prefix = c.weightsPrefix 18 | args.classn = 6 19 | #args.window = (64*10,64*10) 20 | args.window = (64*8,64*8) 21 | #args.window = (64*1,64*1) 22 | [ 20. , 29.96875] 23 | # ============================================================================= 24 | # config BEGIN 25 | # ============================================================================= 26 | args.update( 27 | # batch=8, 28 | # batch=1, 29 | batch=2, #4G*2 30 | # batch=4, #8G*2 31 | # epoch=50, 32 | epoch=80, 33 | resume=0, 34 | epochSize = 10000, 35 | ) 36 | # ============================================================================= 37 | # config END 38 | # ============================================================================= 39 | 40 | 41 | 42 | 43 | argListt, argsFromSys = getArgvDic() 44 | args.update(argsFromSys) 45 | 46 | makeTrainEnv(args) 47 | c.args=(args) 48 | if __name__ == '__main__': 49 | import trainInterface as train 50 | train.train() 51 | pass 52 | 53 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | import lib 4 | from lib import dicto,dirname, basename,os,log,fileJoinPath, pathjoin 5 | from lib import show, 
loga, logl, imread, imsave 6 | 7 | from configManager import (getImgGtNames, indexOf, readgt, readimg, 8 | setMod, togt, toimg) 9 | from configManager import cf,c 10 | # ============================================================================= 11 | # config BEGIN 12 | # ============================================================================= 13 | cf.netdir = 'isprs' 14 | cf.project = None 15 | cf.experment = None 16 | 17 | cf.trainGlob = u'/home/victoria/0-images/isprs/after/train/*_RGB.tif' 18 | cf.toGtPath = lambda path:path.replace('_RGB.tif','_label.png') 19 | cf.val = u'/home/victoria/0-images/isprs/after/val/*_RGB.tif' 20 | 21 | cf.toValGtPath = None 22 | 23 | cf.testGlob = u'/home/victoria/0-images/isprs/test/*_RGB.tif' 24 | # ============================================================================= 25 | # config END 26 | # ============================================================================= 27 | 28 | 29 | filePath = fileJoinPath(__file__) 30 | jobDir = (os.path.split(dirname(filePath))[-1]) 31 | expDir = (os.path.split((filePath))[-1]) 32 | 33 | cf.project = cf.project or jobDir 34 | cf.experment = cf.experment or expDir 35 | 36 | cf.savename = '%s-%s-%s'%(cf.netdir,cf.experment,cf.project) 37 | 38 | cf.toValGtPath = cf.toValGtPath or cf.toGtPath 39 | #cf.valArgs = cf.valArgs or cf.trainArgs 40 | 41 | 42 | 43 | c.update(cf) 44 | c.cf = cf 45 | 46 | 47 | c.weightsPrefix = fileJoinPath(__file__,pathjoin(c.tmpdir,'weights/%s-%s'%(c.netdir,c.experment))) 48 | #show- map(readimg,c.names[:10]) 49 | if __name__ == '__main__': 50 | setMod('train') 51 | img = readimg(c.names[0]) 52 | gt = readgt(c.names[0]) 53 | show(img,gt) 54 | loga(gt) 55 | pass 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /findBestEpoch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Nov 25 
13:33:24 2017 5 | 6 | @author: victoria 7 | """ 8 | from yllab import * 9 | training = False 10 | training = True 11 | findBest = False 12 | findBest = True 13 | #sortKey = None 14 | if training: 15 | pred('training.......') 16 | from train import * 17 | import trainInterface 18 | trainInterface.train() 19 | 20 | from val import * 21 | import val 22 | 23 | #import inferenceInterface 24 | #c.reload = inferenceInterface 25 | import predictInterface 26 | c.predictInterface = predictInterface 27 | 28 | #setMod('train') 29 | evaluFun,sortKey = accEvalu,'acc' 30 | if findBest: 31 | pred('auto Find Best Epoch .......') 32 | c.args.step = None 33 | 34 | epochs=sorted([findints(filename(p))[-1] for p in glob(c.weightsPrefix+'*')]) 35 | df = autoFindBestEpoch(c,evaluFun,sortkey=sortKey, savefig='epoch.png',epochs=epochs) 36 | # row = df.loc[df[sortKey].argmax()] 37 | row = df.loc[df[sortKey].argmax()] 38 | restore =int(row.restore) 39 | df.to_csv('%s:%s_epoch:%s.csv'%(sortKey,row[sortKey],restore)) 40 | else: 41 | restore = -1 42 | #%% 43 | if __name__ == '__main__': 44 | pred('refine .......') 45 | c.args.restore = restore 46 | c.args.step = .2 47 | # reload(c.reload) 48 | # inference = inferenceInterface.inference 49 | reload(predictInterface) 50 | inference = predictInterface.predict 51 | # c.inference = inference 52 | e = Evalu(evaluFun, 53 | # evaluName='restore-%s'%restore, 54 | valNames=c.names, 55 | # loadcsv=1, 56 | # logFormat='acc:{acc:.3f}, loss:{loss:.3f}', 57 | sortkey=sortKey, 58 | # loged=False, 59 | saveResoult=False, 60 | ) 61 | c.names.sort(key=lambda x:readgt(x).shape[0]) 62 | for name in c.names[:]: 63 | img,gt = readimg(name),readgt(name) 64 | prob = inference((name)) 65 | re = prob.argmax(2) 66 | e.evalu(re,gt,name) 67 | 68 | gtc = labelToColor(gt,colors) 69 | rec = labelToColor(re,colors) 70 | smallGap = 10 71 | # show(img[::smallGap,::smallGap],gtc[::smallGap,::smallGap],(gt!=re)[::smallGap,::smallGap],rec[::smallGap,::smallGap]) 72 | # diff = 
binaryDiff(re,gt) 73 | # show(img,diff,re) 74 | # show(img,diff) 75 | # show(diff) 76 | # yellowImg=gt[...,None]*img+(npa-[255,255,0]).astype(np.uint8)*~gt[...,None] 77 | # show(yellowImg,diff) 78 | # imsave(pathjoin(args.out,name+'.png'),uint8(re)) 79 | imsave('val/%s.tif'%name,uint8(rec)) 80 | print args.restore,e[sortKey].mean() 81 | 82 | 83 | 84 | ''' 85 | 86 | ''' -------------------------------------------------------------------------------- /predictInterface.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | ''' 3 | res-unet 4 | 全图预测 自动填充黑边 以适应上下采样 5 | 6 | Parameters 7 | ---------- 8 | ''' 9 | from lib import * 10 | import logging 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | import numpy as np 14 | import mxnet as mx 15 | 16 | from collections import namedtuple 17 | from configManager import args,c 18 | 19 | if __name__ == '__main__': 20 | import val 21 | 22 | sym, arg_params, aux_params = mx.model.load_checkpoint( 23 | args.prefix, args.restore) 24 | # print(sym.list_outputs()) 25 | Batch = namedtuple('Batch', ['data']) 26 | #mod = mx.mod.Module(symbol=sym, label_names=None, context=mx.gpu()) 27 | 28 | args.simgShape = args.window 29 | if not isinstance(args.window,(tuple,list,np.ndarray)): 30 | args.simgShape = (args.window,args.window) 31 | mod = mx.mod.Module(symbol=sym, label_names=None, context=mx.gpu(1)) 32 | hh,ww = args.simgShape 33 | mod.bind(for_training=False, data_shapes=[ 34 | ('data', (1, 5, hh, ww))], label_shapes=None and [ 35 | ('softmax%d_label'%(i+1),(1,hh//2**i,ww//2**i)) for i in (0,)], 36 | force_rebind=False, ) 37 | mod.set_params(arg_params, aux_params, allow_missing=True) 38 | 39 | 40 | def readChannel(name, basenames=None): 41 | kinds = ['_RGB.tif','_IRRG.tif','_dsm.tif'] 42 | dirr = dirname(c['val']) 43 | # dirr = dirname(c['testGlob']) 44 | dirr = dirname(c.trainGlob) 45 | # dirr = dirname(c.testGlob) 46 | if not basenames: 47 | basenames = kinds 48 | imgs = 
[] 49 | if kinds[0] in basenames: 50 | path = pathjoin(dirr,name+kinds[0]) 51 | img = imread(path) 52 | shape = img.shape[:2] 53 | imgs.append(img) 54 | if kinds[1] in basenames: 55 | path = pathjoin(dirr,name+kinds[1]) 56 | img = imread(path) 57 | imgs.append(img[...,:1]) 58 | if kinds[2] in basenames: 59 | path = pathjoin(dirr,name+kinds[2]) 60 | img = imread(path) 61 | if img.shape != shape: 62 | img = cv2.resize(img,shape) 63 | imgs.append(img[...,None]) 64 | if len(imgs) == 1: 65 | return imgs[0] 66 | mimg = reduce(lambda x,y:np.append(x,y,2),imgs) 67 | return mimg 68 | 69 | def predict(name): 70 | img = readChannel(name)/255. 71 | # img = img[::3,::3] 72 | def handleSimg(simg): 73 | simg = simg.transpose(2,0,1) 74 | mod.forward(Batch(data=[mx.nd.array(np.expand_dims( 75 | simg, 0))]), is_train=False) 76 | prob = mod.get_outputs()[0].asnumpy()[0] 77 | re= prob.transpose(1,2,0) 78 | return re 79 | 80 | re = autoSegmentWholeImg(img, args.simgShape, handleSimg, step=args.step, weightCore='gauss') 81 | return re 82 | 83 | if __name__ == '__main__': 84 | pass 85 | from test import * -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | 4 | import sys,os 5 | import numpy as np 6 | import lib 7 | from lib import dicto, glob, getArgvDic, findints,pathjoin 8 | from lib import show, loga, logl, imread, imsave 9 | from lib import Evalu,diceEvalu 10 | from lib import * 11 | from configManager import (getImgGtNames, indexOf, readgt, readimg, 12 | setMod, togt, toimg, makeValEnv, doc) 13 | from train import c, cf, args 14 | 15 | setMod('train') 16 | setMod('test') 17 | 18 | args.out = pathjoin(c.tmpdir,'val/tif-4-22') 19 | 20 | # ============================================================================= 21 | # config BEGIN 22 | # 
============================================================================= 23 | args.update( 24 | restore=-1, 25 | # restore=34, 26 | # step=None, 27 | step=.2, 28 | ) 29 | # ============================================================================= 30 | # config END 31 | # ============================================================================= 32 | 33 | 34 | 35 | if args.restore == -1: 36 | pas = [p[len(args.prefix):] for p in glob(args.prefix+'*')] 37 | args.restore = len(pas) and max(map(lambda s:len(findints(s)) and findints(s)[-1],pas)) 38 | 39 | makeValEnv(args) 40 | accEvalu = lambda re,gt:{'acc':(re==gt).sum()*1./re.size,'loss':(~(re==gt)).sum()*1./re.size} 41 | 42 | colors = np.array([(255, 255, 255),(0, 0, 255),(0, 255, 255),(0, 255, 0), 43 | (255, 255, 0), (255, 0, 0),])/255. 44 | 45 | import predictInterface 46 | c.predictInterface = predictInterface 47 | if __name__ == '__main__': 48 | import predictInterface 49 | c.predictInterface = predictInterface 50 | predict = predictInterface.predict 51 | # c.predict = predict 52 | e = Evalu(accEvalu, 53 | # evaluName='restore-%s'%restore, 54 | valNames=c.names, 55 | # loadcsv=1, 56 | logFormat='acc:{acc:.3f}, loss:{loss:.3f}', 57 | sortkey='loss', 58 | # loged=False, 59 | saveResoult=False, 60 | ) 61 | # c.names.sort(key=lambda x:readgt(x).shape[0]) 62 | for name in c.names[::1]: 63 | img,gt = readimg(name),readgt 64 | prob = predict((name)) 65 | re = prob.argmax(2) 66 | # e.evalu(re,gt,name) 67 | gt = re 68 | gtc = labelToColor(gt,colors) 69 | rec = labelToColor(re,colors) 70 | show(img[::10,::10],gtc[::10,::10],(gt!=re)[::10,::10],rec[::10,::10]) 71 | # diff = binaryDiff(re,gt) 72 | # show(img,diff,re) 73 | # show(img,diff) 74 | # show(diff) 75 | # yellowImg=gt[...,None]*img+(npa-[255,255,0]).astype(np.uint8)*~gt[...,None] 76 | # show(yellowImg,diff) 77 | # imsave(pathjoin(args.out,name+'.tif'),uint8(rec)) 78 | imsave(pathjoin(args.out,name+'.tif'),uint8(rec)) 79 | 80 | # print 
args.restore,e.loss.mean() 81 | 82 | 83 | #map(lambda n:show(readimg(n),e[n],readgt(n)),e.low(80).index[:]) 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | 4 | import sys,os 5 | import numpy as np 6 | import lib 7 | from lib import dicto, glob, getArgvDic, findints,pathjoin 8 | from lib import show, loga, logl, imread, imsave 9 | from lib import Evalu,diceEvalu 10 | from lib import * 11 | from configManager import (getImgGtNames, indexOf, readgt, readimg, 12 | setMod, togt, toimg, makeValEnv, doc) 13 | from train import c, cf, args 14 | 15 | setMod('val') 16 | #setMod('test') 17 | args.out = pathjoin(c.tmpdir,'val/png') 18 | 19 | # ============================================================================= 20 | # config BEGIN 21 | # ============================================================================= 22 | args.update( 23 | # restore=43, 24 | restore=28, 25 | step=None, 26 | # step=.2, 27 | ) 28 | # ============================================================================= 29 | # config END 30 | # ============================================================================= 31 | 32 | 33 | 34 | if args.restore == -1: 35 | pas = [p[len(args.prefix):] for p in glob(args.prefix+'*')] 36 | args.restore = len(pas) and max(map(lambda s:len(findints(s)) and findints(s)[-1],pas)) 37 | 38 | makeValEnv(args) 39 | accEvalu = lambda re,gt:{'acc':(re==gt).sum()*1./re.size,'loss':(~(re==gt)).sum()*1./re.size} 40 | 41 | colors = np.array([(255, 255, 255),(0, 0, 255),(0, 255, 255),(0, 255, 0), 42 | (255, 255, 0), (255, 0, 0),])/255. 
43 | 44 | import predictInterface 45 | c.predictInterface = predictInterface 46 | if __name__ == '__main__': 47 | import predictInterface 48 | c.predictInterface = predictInterface 49 | predict = predictInterface.predict 50 | # c.predict = predict 51 | e = Evalu(accEvalu, 52 | # evaluName='restore-%s'%restore, 53 | valNames=c.names, 54 | # loadcsv=1, 55 | logFormat='acc:{acc:.3f}, loss:{loss:.3f}', 56 | sortkey='loss', 57 | # loged=False, 58 | saveResoult=False, 59 | ) 60 | c.names.sort(key=lambda x:readgt(x).shape[0]) 61 | for name in c.names[:]: 62 | img,gt = readimg(name),readgt(name) 63 | prob = predict((name)) 64 | re = prob.argmax(2) 65 | e.evalu(re,gt,name) 66 | 67 | gtc = labelToColor(gt,colors) 68 | rec = labelToColor(re,colors) 69 | show(img[::10,::10],gtc[::10,::10],(gt!=re)[::10,::10],rec[::10,::10]) 70 | # diff = binaryDiff(re,gt) 71 | # show(img,diff,re) 72 | # show(img,diff) 73 | # show(diff) 74 | # yellowImg=gt[...,None]*img+(npa-[255,255,0]).astype(np.uint8)*~gt[...,None] 75 | # show(yellowImg,diff) 76 | # imsave(pathjoin(args.out,name+'.png'),uint8(re)) 77 | 78 | print args.restore,e.loss.mean() 79 | 80 | 81 | #map(lambda n:show(readimg(n),e[n],readgt(n)),e.low(80).index[:]) 82 | 83 | 84 | 85 | 86 | 87 | 88 | class ArgList(list): 89 | ''' 90 | 标记类 用于标记需要被autoFindBestParams函数迭代的参数列表 91 | ''' 92 | pass 93 | 94 | 95 | def autoFindBestParams(c, args,evaluFun,sortkey=None,savefig=False): 96 | '''遍历args里面 ArgList的所有参数组合 并通过sortkey 找出最佳参数组合 97 | 98 | Parameters 99 | ---------- 100 | c : dicto 101 | 即configManager 生成的测试集的所有环境配置 c 102 | 包含args,数据配置,各类函数等 103 | args : dicto 104 | predict的参数,但需要包含 ArgList 类 将遍历ArgList的所有参数组合 并找出最佳参数组合 105 | evaluFun : Funcation 106 | 用于评测的函数,用于Evalu类 需要返回dict对象 107 | sortkey : str, default None 108 | 用于筛选时候的key 默认为df.columns[-1] 109 | 110 | Return: DataFrame 111 | 每个参数组合及其评价的平均值 112 | ''' 113 | iters = filter(lambda it:isinstance(it[1],ArgList),args.items()) 114 | iters = sorted(iters,key=lambda x:len(x[1]),reverse=True) 115 
| argsraw = args.copy() 116 | argsl = [] 117 | args = dicto() 118 | 119 | k,vs = iters[0] 120 | lenn = len(iters) 121 | deep = 0 122 | tags = [0,]*lenn 123 | while deep>=0: 124 | vs = iters[deep][1] 125 | ind = tags[deep] 126 | if ind != len(vs): 127 | v = vs[ind] 128 | tags[deep]+=1 129 | key = iters[deep][0] 130 | args[key] = v 131 | if deep == lenn-1: 132 | argsl.append(args.copy()) 133 | else: 134 | deep+=1 135 | else: 136 | tags[deep:]=[0]*(lenn-deep) 137 | deep -= 1 138 | assert len(argsl),"args don't have ArgList Values!!" 139 | pds,pddf = pd.Series, pd.DataFrame 140 | edic={} 141 | for arg in argsl: 142 | argsraw.update(arg) 143 | c.args.update(argsraw) 144 | e = Evalu(evaluFun, 145 | evaluName='tmp', 146 | sortkey=sortkey, 147 | loged=False, 148 | saveResoult=False, 149 | ) 150 | reload(c.predictInterface) 151 | predict = c.predictInterface.predict 152 | for name in c.names[::]: 153 | gt = c.readgt(name) 154 | prob = predict((name)) 155 | re = prob.argmax(2) 156 | # from yllab import g 157 | # g.re,g.gt = re,gt 158 | e.evalu(re,gt,name) 159 | # img = readimg(name) 160 | # show(re,gt) 161 | # show(img) 162 | if sortkey is None: 163 | sortkey = e.columns[-1] 164 | keys = tuple(arg.values()) 165 | for k,v in arg.items(): 166 | e[k] = v 167 | edic[keys] = e 168 | print 'arg: %s\n'%str(arg), e.mean() 169 | es = pddf(map(lambda x:pds(x.mean()), edic.values())) 170 | print '-'*20+'\nmax %s:\n'%sortkey,es.loc[es[sortkey].argmax()] 171 | print '\nmin %s:\n'%sortkey,es.loc[es[sortkey].argmin()] 172 | if len(iters) == 1: 173 | k = iters[0][0] 174 | import matplotlib.pyplot as plt 175 | df = es.copy() 176 | df = df.sort_values(k) 177 | plt.plot(df[k],df[sortkey],'--');plt.plot(df[k],df[sortkey],'rx') 178 | plt.xlabel(k);plt.ylabel(sortkey);plt.grid() 179 | if savefig: 180 | plt.savefig(savefig); 181 | plt.close() 182 | else: 183 | plt.show() 184 | return es 185 | 186 | def autoFindBestEpoch(c, evaluFun,sortkey=None,epochs=None,savefig=False): 187 | '''遍历所有epoch的weight 
并通过测试集评估项sortkey 找出最佳epoch 188 | 189 | Parameters 190 | ---------- 191 | c : dicto 192 | 即configManager 生成的测试集的所有环境配置 c 193 | 包含args,数据配置,各类函数等 194 | evaluFun : Funcation 195 | 用于评测的函数,用于Evalu类 需要返回dict对象 196 | sortkey : str, default None 197 | 用于筛选时候的key 默认为df.columns[-1] 198 | 199 | Return: DataFrame 200 | 每个参数组合及其评价的平均值 201 | ''' 202 | args = c.args 203 | if not isinstance(epochs,(tuple,list)) : 204 | pas = [p[len(args.prefix):] for p in glob(args.prefix+'*') if p[-4:]!='json'] 205 | eps = map(lambda s:len(findints(s)) and findints(s)[-1],pas) 206 | maxx = len(eps) and max(eps) 207 | minn = len(eps) and min(eps) 208 | if isinstance(epochs,int): 209 | epochs = range(minn,maxx)[::epochs]+[maxx] 210 | else: 211 | epochs = range(minn,maxx+1) 212 | args['restore'] = ArgList(epochs) 213 | # print epochs 214 | df = autoFindBestParams(c, args, evaluFun,sortkey=sortkey,savefig=savefig) 215 | return df 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /netdef.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import logging 3 | 4 | logging.basicConfig(level=logging.INFO) 5 | 6 | import mxnet as mx 7 | 8 | 9 | 10 | 11 | def bottleneck(inputs, k): 12 | x = mx.sym.BatchNorm(data=inputs, momentum=0.99) 13 | x = mx.sym.Activation(data=x, act_type='relu') 14 | x = mx.sym.Convolution(data=x, kernel=(1,1), stride=(1,1), num_filter=k*4) 15 | x = mx.sym.Dropout(x, p=0.2) 16 | return x 17 | 18 | def composite_function(inputs, dilate): 19 | x = mx.sym.BatchNorm(data=inputs, momentum=0.99) 20 | x = mx.sym.Activation(data=x, act_type='relu') 21 | x = mx.sym.Convolution(data=x, kernel=(3,3), stride=(1,1), pad=dilate, num_filter=k, dilate=dilate) 22 | x = mx.sym.Dropout(x, p=0.2) 23 | return x 24 | 25 | def composite_function_bottleneck(inputs, dilate): 26 | x = bottleneck(inputs, k) 27 | x = composite_function(x, dilate) 28 | return x 29 | 30 | def 
transition(inputs): 31 | x = mx.sym.BatchNorm(data=inputs, momentum=0.99) 32 | x = mx.sym.Convolution(data=x, kernel=(1,1), stride=(1,1), pad=(0,0), num_filter=k) 33 | # x = mx.sym.Dropout(x, p=0.2) 34 | return x 35 | 36 | def dense_block(inputs, dilate): 37 | x1 = composite_function(inputs, dilate) 38 | x2 = composite_function(mx.sym.concat(inputs, x1, dim=1), dilate) 39 | x3 = composite_function(mx.sym.concat(inputs, x1, x2, dim=1), dilate) 40 | x4 = composite_function(mx.sym.concat(inputs, x1, x2, x3, dim=1), dilate) 41 | return mx.sym.concat(x1, x2, x3, x4, dim=1) 42 | 43 | def conv(data, kernel=(3, 3), stride=(1, 1), pad=(0, 0), num_filter=None, name=None): 44 | return mx.sym.Convolution(data=data, kernel=kernel, stride=stride, pad=pad, num_filter=num_filter, name='conv_{}'.format(name)) 45 | 46 | 47 | def bn_relu(data, name): 48 | return mx.sym.Activation(data=mx.sym.BatchNorm(data=data, momentum=0.99, name='bn_{}'.format(name)), act_type='relu', name='relu_{}'.format(name)) 49 | 50 | 51 | def conv_bn_relu(data, kernel=(3, 3), stride=(1, 1), pad=(0, 0), num_filter=None, name=None): 52 | return bn_relu(conv(data, kernel, stride, pad, num_filter, 'conv_{}'.format(name)), 'relu_{}'.format(name)) 53 | 54 | 55 | def down_block(data, f, name): 56 | x = mx.sym.Pooling(data=data, kernel=(2,2), stride=(2,2), pool_type='max') 57 | # temp = conv_bn_relu(data, (3, 3), (2, 2), (1, 1), 58 | # f, 'layer1_{}'.format(name)) 59 | x = conv_bn_relu(x, (3, 3), (1, 1), (1, 1), 60 | f, 'layer2_{}'.format(name)) 61 | bn = mx.sym.BatchNorm(data=conv(x, (3, 3), (1, 1), (1, 1), f, 'layer3_{}'.format( 62 | name)), momentum=0.99, name='layer3_bn_{}'.format(name)) 63 | bn = bn + x 64 | act = mx.sym.Activation(data=bn, act_type='relu', 65 | name='layer3_relu_{}'.format(name)) 66 | return bn, act 67 | 68 | 69 | def up_block(act, bn, f, p, name): 70 | x = mx.sym.UpSampling( 71 | act, num_filter=p, scale=2, sample_type='bilinear', name='upsample_{}'.format(name)) 72 | # temp = 
mx.sym.Deconvolution(data=act, kernel=(3, 3), stride=(2, 2), pad=( 73 | # 1, 1), adj=(1, 1), num_filter=32, name='layer1_dconv_{}'.format(name)) 74 | x = mx.sym.concat(bn, x, dim=1) 75 | x = conv_bn_relu(x, (1,1), (1,1), (0,0), f, 'layer_1x1_{}'.format(name)) 76 | temp = conv_bn_relu(x, (3, 3), (1, 1), (1, 1), 77 | f, 'layer2_{}'.format(name)) 78 | bn = mx.sym.BatchNorm(data=conv(temp, (3, 3), (1, 1), (1, 1), f, 'layer3_{}'.format( 79 | name)), momentum=0.99, name='layer3_bn_{}'.format(name)) 80 | bn = bn + x 81 | return mx.sym.Activation(data=bn, act_type='relu', name='layer3_relu_{}'.format(name)) 82 | 83 | k = 2 84 | def getNet(n): 85 | global k 86 | k = n 87 | data = mx.sym.Variable('data') 88 | global rawData 89 | rawData = data 90 | x = conv_bn_relu(data, (3, 3), (1, 1), (1, 1), 64, 'conv0_1') 91 | net = conv_bn_relu(x, (3, 3), (1, 1), (1, 1), 64, 'conv0_2') 92 | bn1 = mx.sym.BatchNorm(data=conv( 93 | net, (3, 3), (1, 1), (1, 1), 64, 'conv0_3'), momentum=0.99, name='conv0_3_bn') 94 | bn1 = bn1 + x 95 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name='conv0_3_relu') 96 | global ACT1 97 | # ACT1 = resnextBlock(act1,16,(1,1),False,getLayerName('short'),4) 98 | ACT1 = act1 99 | 100 | bn2, act2 = down_block(act1, 128, 'down1') 101 | bn3, act3 = down_block(act2, 256, 'down2') 102 | bn4, act4 = down_block(act3, 512, 'down3') 103 | bn5, act5 = down_block(act4, 512, 'down4') 104 | bn6, act6 = down_block(act5, 512, 'down5') 105 | 106 | bn7, act7 = down_block(act6, 512, 'down6') 107 | 108 | temp = up_block(act7, bn6, 512, 512, 'up6') 109 | temp = up_block(temp, bn5, 512, 512, 'up5') 110 | temp = up_block(temp, bn4, 256, 512, 'up4') 111 | temp = up_block(temp, bn3, 128, 256, 'up3') 112 | temp = up_block(temp, bn2, 64, 128, 'up2') 113 | temp = up_block(temp, bn1, 32, 64, 'up1') 114 | score1 = conv(temp, (1, 1), (1, 1), (0, 0), 6, 'score1') 115 | net1 = mx.sym.SoftmaxOutput(score1, multi_output=True, name='softmax1') 116 | 117 | from yllab import load_data 118 | 
net1 = confusionTree(inputt=temp,tree=load_data('confusionTree')) 119 | return net1 120 | 121 | __NAME_COUNT__ = {} 122 | def getLayerName(name="None"): 123 | n = __NAME_COUNT__.get(name,0) 124 | n = n + 1 125 | __NAME_COUNT__[name] = n 126 | return name+'_%s'%n 127 | 128 | 129 | def resnextBlock(data,num_filter, stride, dim_match, name, 130 | num_group=32, bn_mom=0.9, workspace=256,): 131 | conv1 = mx.sym.Convolution(data=data, num_filter=int(num_filter*0.5), kernel=(1,1), stride=(1,1), pad=(0,0), 132 | no_bias=True, workspace=workspace, name=name + '_conv1') 133 | bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1') 134 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') 135 | 136 | 137 | conv2 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.5), num_group=num_group, kernel=(3,3), stride=stride, pad=(1,1), 138 | no_bias=True, workspace=workspace, name=name + '_conv2') 139 | bn2 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2') 140 | act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') 141 | 142 | 143 | conv3 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True, 144 | workspace=workspace, name=name + '_conv3') 145 | bn3 = mx.sym.BatchNorm(data=conv3, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3') 146 | 147 | if dim_match: 148 | shortcut = data 149 | else: 150 | shortcut_conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True, 151 | workspace=workspace, name=name+'_sc') 152 | shortcut = mx.sym.BatchNorm(data=shortcut_conv, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_sc_bn') 153 | eltwise = bn3 + shortcut 154 | return mx.sym.Activation(data=eltwise, act_type='relu', name=name + '_relu') 155 | 156 | def pipe(data,filters): 157 | if filters == 1: 158 | out = conv_bn_relu(data, (3, 3), (1, 1), 
(1, 1), 159 | 1, getLayerName('conv_bn_relu')) 160 | return out 161 | fs = filters*4 #8--->16 162 | global ACT1 163 | data = mx.sym.concat(data,ACT1, dim=1) 164 | out = resnextBlock(data,fs,(1,1),False,getLayerName('resNext'),min(fs//4,32)) 165 | return out 166 | 167 | def pipe2(inp,filters): 168 | if filters == 1: 169 | out = conv_bn_relu(inp, (3, 3), (1, 1), (1, 1), 170 | 1, getLayerName('conv_bn_relu')) 171 | return out 172 | layer1,layer2 = 5,3 173 | out1 = conv_bn_relu(inp, (3, 3), (1, 1), (1, 1), 174 | filters*layer1, getLayerName('conv_bn_relu')) 175 | out = conv_bn_relu(out1, (3, 3), (1, 1), (1, 1), 176 | filters*layer2, getLayerName('conv_bn_relu')) 177 | # out = conv(inp, (3, 3), (1, 1), (1, 1), 178 | # filters, getLayerName('conv_bn_relu')) 179 | out = mx.sym.concat(inp,out, dim=1) 180 | return out 181 | 182 | def confusionTree(inputt=None,tree=None): 183 | if inputt is None: 184 | inputt = mx.sym.Variable('data') 185 | classn = sum(map(len,tree.keys())) 186 | probs = [0]*classn 187 | 188 | def walkTree(inp, tree, key): 189 | out = pipe(inp,len(key)) 190 | if len(key) == 1: 191 | probs[key[0]]=out 192 | else: 193 | for k,v in tree.items(): 194 | walkTree(out,v,k) 195 | walkTree(inputt,tree,tuple(range(classn))) 196 | out = mx.sym.concat(*probs, dim=1) 197 | net = mx.sym.SoftmaxOutput(out, multi_output=True, name='softmax1') 198 | return net 199 | 200 | if __name__ == '__main__': 201 | pass 202 | # from yllab import * 203 | net = getNet(6) 204 | # net = confusionTree(tree=tre) 205 | mx.viz.plot_network(net, save_format='pdf', shape={ 206 | 'data': (1, 3, 640, 640), 207 | 'softmax1_label': (1, 640, 640), }).render('TresegNet-short') -------------------------------------------------------------------------------- /trainInterface.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | ''' 3 | res-unet1-simg 4 | 取小图训练 5 | 6 | Parameters 7 | ---------- 8 | step : int 9 | 填充黑边 将图片shape 调整为step的整数倍 10 | ''' 11 
from yllab import *
from lib import *
import logging
logging.basicConfig(level=logging.INFO)

# Shorthand: pull an mxnet NDArray back to numpy.
npm = lambda m: m.asnumpy()
npm = FunAddMagicMethod(npm)

import mxnet as mx
import random
from netdef import getNet

if __name__ == '__main__':
    from configManager import c
    from train import args
else:
    from configManager import args, c


class SimpleBatch(object):
    """Minimal DataBatch stand-in: just `data`, `label` and `pad`."""
    def __init__(self, data, label, pad=0):
        self.data = data
        self.label = label
        self.pad = pad


labrgb = lambda lab: cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
randint = lambda x: np.random.randint(-x, x)


def imgAug(image, gt, prob=.5):
    """Randomly mirror image and ground truth together
    (each axis independently, with probability 1 - prob)."""
    if random.random() > prob:
        image = np.fliplr(image)
        gt = np.fliplr(gt)
    if random.random() > prob:
        image = np.flipud(image)
        gt = np.flipud(gt)
    return image, gt


def handleImgGt(imgs, gts,):
    """Augment a batch, normalise images to NCHW float in [0, 1] and wrap
    everything as a SimpleBatch of mx.nd arrays."""
    for i in range(len(imgs)):
        imgs[i], gts[i] = imgAug(imgs[i], gts[i])
    if args.classn == 2:
        gts = gts > .5  # binary task: threshold the ground truth
    g.im = imgs; g.gt = gts  # expose for interactive debugging
    imgs = imgs.transpose(0, 3, 1, 2) / 255.
    mximgs = map(mx.nd.array, [imgs])
    mxgtss = map(mx.nd.array, [gts])
    mxdata = SimpleBatch(mximgs, mxgtss)
    return mxdata


def readChannel(name, basenames=None):
    """Read and stack the input channels for tile `name`.

    basenames selects among '_RGB.tif' (3 channels, colour-jittered),
    '_IRRG.tif' (first channel only) and '_dsm.tif' (1 channel);
    default is all three.  Returns an HxWxC array.
    """
    kinds = ['_RGB.tif', '_IRRG.tif', '_dsm.tif']
    dirr = dirname(c['trainGlob'])
    if not basenames:
        basenames = kinds
    imgs = []
    if kinds[0] in basenames:
        path = pathjoin(dirr, name + kinds[0])
        img = imread(path)
        # Photometric augmentation in HSV space.
        hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        hsv = hsv.astype(np.int32)
        hsv[:, :, 2] += random.randint(-15, 15)   # brightness
        hsv[:, :, 1] += random.randint(-10, 10)   # saturation
        hsv[:, :, 0] += random.randint(-5, 5)     # hue
        # (fix) 8-bit OpenCV hue lives in [0, 180); wrap it instead of
        # clipping to 255, which produced out-of-range hue values.
        hsv[:, :, 0] %= 180
        hsv[:, :, 1:] = np.clip(hsv[:, :, 1:], 0, 255)
        hsv = hsv.astype(np.uint8)
        img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
        imgs.append(img)
    if kinds[1] in basenames:
        path = pathjoin(dirr, name + kinds[1])
        img = imread(path)
        imgs.append(img[..., :1])
    if kinds[2] in basenames:
        path = pathjoin(dirr, name + kinds[2])
        img = imread(path)
        imgs.append(img[..., None])
    if len(imgs) == 1:
        return imgs[0]
    mimg = reduce(lambda x, y: np.append(x, y, 2), imgs)
    return mimg


from collections import Iterator  # NOTE: moved to collections.abc in py3


class GenSimg(Iterator):
    '''
    Iterator yielding random small crops (simg) and their ground truth.
    By default it keeps a multi-GB image cache and stops once the total
    area of generated crops roughly equals the total area of all images.
    '''
    def __init__(self, imggts, simgShape, handleImgGt=None,
                 batch=1, cache=None, iters=None,
                 timesPerRead=1, infinity=False):
        '''
        imggts: zip(jpgs, pngs)
        simgShape: crop shape, int or (h, w)
        handleImgGt: post-process crops with handleImgGt(imgs, gts)
        batch: crops returned per call
        cache: number of images to cache (default: ~5 GB worth)
        timesPerRead: average reuse count per cached image
                      (does not change the total iteration count)
        iters: fixed total number of crops, independent of batch
        infinity: iterate forever
        '''
        if isinstance(simgShape, int):
            simgShape = (simgShape, simgShape)
        self.handleImgGt = handleImgGt
        self.imggts = imggts
        self.simgShape = simgShape
        self.batch = batch
        self._iters = iters
        self.iters = self._iters
        self.infinity = infinity

        hh, ww = simgShape
        jpg, png = imggts[0]
        img = readChannel(jpg)  # probe one image for size/bytes
        h, w = img.shape[:2]
        if cache is None:
            cache = max(1, int(5e9 / img.nbytes))
        cache = min(cache, len(imggts))
        self.maxPerCache = int(cache * (h * w) * 1. / (hh * ww)) * timesPerRead / batch
        self.cache = cache
        self.n = len(imggts)
        self._times = max(1, int(round(self.n * 1. / cache / timesPerRead)))
        self.times = self._times
        self.totaln = self.sn = iters or int((h * w) * self.n * 1. / (hh * ww))
        self.willn = iters or self.maxPerCache * self.times * batch
        self.count = 0
        self.reset()

        self.bytes = img.nbytes
        argsStr = '''imggts=%s pics in dir: %s,
    simgShape=%s,
    handleImgGt=%s,
    batch=%s, cache=%s,iters=%s,
    timesPerRead=%s, infinity=%s''' % (self.n, os.path.dirname(jpg) or './', simgShape, handleImgGt,
                                       batch, cache, iters,
                                       timesPerRead, infinity)
        generatorStr = '''maxPerCache=%s, readTimes=%s
    Will generator maxPerCache*readTimes*batch=%s''' % (self.maxPerCache, self.times,
                                                        self.willn)
        if iters:
            generatorStr = 'Will generator iters=%s' % iters
        self.__describe = '''GenSimg(%s)

Total imgs Could generator %s simgs,
%s simgs.
''' % (argsStr, self.totaln,
       generatorStr,)

    def reset(self):
        """Refill the cache with a fresh random sample of images."""
        if (self.times <= 0 and self.iters is None) and not self.infinity:
            self.times = self._times
            raise StopIteration
        self.now = self.maxPerCache
        inds = np.random.choice(range(len(self.imggts)), self.cache, replace=False)
        datas = {}
        for ind in inds:
            jpg, png = self.imggts[ind]
            img, gt = readChannel(jpg), imread(png)
            datas[jpg] = img, gt
        self.data = self.datas = datas
        self.times -= 1

    def next(self):
        """Return the next batch of (imgs, gts) crops (or handleImgGt's output)."""
        self.count += 1
        if (self.iters is not None) and not self.infinity:
            if self.iters <= 0:
                self.iters = self._iters
                raise StopIteration
            self.iters -= self.batch
        if self.now <= 0:
            self.reset()
        self.now -= 1
        hh, ww = self.simgShape
        datas = self.datas
        imgs, gts = [], []
        for t in range(self.batch):
            # Pick a cached image, then a random crop inside it.
            img, gt = datas[np.random.choice(datas.keys(), 1, replace=False)[0]]
            h, w = img.shape[:2]
            i = np.random.randint(h - hh + 1)
            j = np.random.randint(w - ww + 1)
            (img, gt) = img[i:i + hh, j:j + ww], gt[i:i + hh, j:j + ww]
            imgs.append(img), gts.append(gt)
        (imgs, gts) = map(np.array, (imgs, gts))
        if self.handleImgGt:
            return self.handleImgGt(imgs, gts)
        return (imgs, gts)

    @property
    def imgs(self):
        return [img for img, gt in self.datas.values()]

    @property
    def gts(self):
        return [gt for img, gt in self.datas.values()]

    def __str__(self):
        batch = self.batch
        n = len(self.datas)
        return self.__describe + \
            '''
status:
    iter in %s/%s(%.2f)
    batch in %s/%s(%.2f)
    cache imgs: %s
    cache size: %.2f MB
''' % (self.count * batch, self.willn, self.count * 1. * batch / self.willn,
       self.count, self._times * self.maxPerCache,
       self.count * 1. / (self._times * self.maxPerCache),
       n, (n * self.bytes / 2 ** 20))

    __repr__ = __str__


class GenSimgInMxnet(GenSimg):
    """GenSimg plus the provide_data/provide_label metadata that
    mx.mod.Module expects from a data iterator."""
    @property
    def provide_data(self):
        # NOTE(review): channel count 5 is hard-coded (RGB + IR + DSM);
        # keep in sync with readChannel().
        return [('data', (args.batch, 5, args.simgShape[0], args.simgShape[1]))]

    @property
    def provide_label(self):
        return [('softmax1_label', (args.batch, args.simgShape[0], args.simgShape[1])), ]


def saveNow(name=None):
    """Checkpoint the current module parameters under `name` (as epoch -1)."""
    f = mx.callback.do_checkpoint(name or args.prefix)
    f(-1, mod.symbol, *mod.get_params())


# Fill in any CLI args that were not supplied.
default = dicto(
    gpu=2,
    lr=0.01,
    epochSize=10000,
    step=64,
    window=64 * 2,
    classn=3
)
for k in default.keys():
    if k not in args:
        args[k] = default[k]

args.names = zip(c.names, map(c.togt, c.names))

args.simgShape = args.window
if not isinstance(args.window, (tuple, list, np.ndarray)):
    args.simgShape = (args.window, args.window)

net = getNet(args.classn)

if args.resume:
    print('resume training from epoch {}'.format(args.resume))
    _, arg_params, aux_params = mx.model.load_checkpoint(
        args.prefix, args.resume)
else:
    arg_params = None
    aux_params = None

if 'plot' in args:
    # Just render the network graph and quit.
    mx.viz.plot_network(net, save_format='pdf', shape={
        'data': (1, 5, 640, 640),
        'softmax1_label': (1, 640, 640), }).render(args.prefix)
    exit(0)

mod = mx.mod.Module(
    symbol=net,
    # NOTE(review): with gpu == 1 this deliberately picks device 1, not 0.
    context=[mx.gpu(k) for k in range(args.gpu)] if args.gpu != 1 else [mx.gpu(1)],
    data_names=('data',),
    label_names=('softmax1_label',)
)
c.mod = mod

gen = GenSimgInMxnet(args.names, args.simgShape,
                     handleImgGt=handleImgGt,
                     batch=args.batch,
                     iters=args.epochSize
                     )
g.gen = gen
total_steps = gen.totaln * args.epoch / gen.batch


class Lrs(mx.lr_scheduler.MultiFactorScheduler):
    """MultiFactorScheduler that logs the lr once per new update step."""
    def __init__(self, *l, **kv):
        mx.lr_scheduler.MultiFactorScheduler.__init__(self, *l, **kv)
        self.num_update = None

    def __call__(self, num_update):
        lr = mx.lr_scheduler.MultiFactorScheduler.__call__(self, num_update)
        if self.num_update != num_update:
            stdout('\rstep:%s, lr:%s, ' % (num_update, lr))
            self.num_update = num_update
        return lr


# (fix) A plain MultiFactorScheduler used to be built first and then
# immediately overwritten by this Lrs instance; the dead one was removed.
# Decay lr by 10x at 1/2, 3/4 and 15/16 of training.
lr_sch = Lrs(
    step=[total_steps // 2, total_steps * 3 // 4, total_steps * 15 // 16], factor=0.1)


def train():
    """Run mod.fit with SGD + momentum and the multi-factor lr schedule."""
    mod.fit(
        gen,
        begin_epoch=args.resume,
        arg_params=arg_params,
        aux_params=aux_params,
        batch_end_callback=mx.callback.Speedometer(args.batch),
        epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
        optimizer='sgd',
        optimizer_params=(('learning_rate', args.lr), ('momentum', 0.9),
                          ('lr_scheduler', lr_sch), ('wd', 0.0005)),
        num_epoch=args.epoch)


if __name__ == '__main__':
    pass


if 0:
    # Interactive debugging: fetch one batch and visualise it.
    ne = g.gen.next()
    ds, las = ne.data, ne.label
    d, la = npm - ds[0], npm - las[0]
    im = d.transpose(0, 2, 3, 1)
    show(labrgb(uint8(im[0]))); show(la)