├── FracTAL_ResUNet
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── heads
│   │   │   ├── __init__.py
│   │   │   └── head_cmtsk.py
│   │   └── semanticsegmentation
│   │       ├── FracTAL_ResUNet.py
│   │       ├── FracTAL_ResUNet_features.py
│   │       └── __init__.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── activations
│   │   │   ├── __init__.py
│   │   │   └── sigmoid_crisp.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── combine.py
│   │   │   ├── conv2Dnormed.py
│   │   │   ├── ftnmt.py
│   │   │   └── scale.py
│   │   ├── loss
│   │   │   ├── __init__.py
│   │   │   ├── ftnmt_loss.py
│   │   │   └── mtsk_loss.py
│   │   ├── pooling
│   │   │   ├── __init__.py
│   │   │   └── psp_pooling.py
│   │   └── units
│   │       ├── __init__.py
│   │       └── fractal_resnet.py
│   ├── src
│   │   ├── __init__.py
│   │   └── semseg_aug_cv2.py
│   └── utils
│       ├── __init__.py
│       └── get_norm.py
├── README.md
├── examples
│   ├── Demo_forward_backward.ipynb
│   └── demo_instance_segmentation.ipynb
├── images
│   └── decode.png
├── postprocessing
│   └── instance_segmentation.py
└── requirements.txt
/FracTAL_ResUNet/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/FracTAL_ResUNet/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/FracTAL_ResUNet/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/FracTAL_ResUNet/models/heads/head_cmtsk.py:
--------------------------------------------------------------------------------
1 | from mxnet import gluon
2 | from mxnet.gluon import HybridBlock
3 | 
4 | 
5 | from decode.FracTAL_ResUNet.nn.activations.sigmoid_crisp import *
6 | from decode.FracTAL_ResUNet.nn.pooling.psp_pooling import *
7 | from decode.FracTAL_ResUNet.nn.layers.conv2Dnormed import *
8 | 
9 | # Helper classification head, for a single-layer output
10 | class HeadSingle(HybridBlock):
11 |     def __init__(self, nfilters, NClasses, depth=2, norm_type='BatchNorm', norm_groups=None, **kwargs):
12 |         super().__init__(**kwargs)
13 | 
14 | 
15 |         with self.name_scope():
16 |             self.logits = gluon.nn.HybridSequential()
17 |             for _ in range(depth):
18 |                 self.logits.add(Conv2DNormed(channels=nfilters, kernel_size=(3,3), padding=(1,1), _norm_type=norm_type, norm_groups=norm_groups))
19 |                 self.logits.add(gluon.nn.Activation('relu'))
20 |             self.logits.add(gluon.nn.Conv2D(NClasses, kernel_size=1, padding=0))
21 | 
22 |     def hybrid_forward(self, F, input):
23 |         return self.logits(input)
24 | 
25 | 
26 | 
27 | class Head_CMTSK_BC(HybridBlock):
28 |     # BC: Balanced (features) Crisp (boundaries)
29 |     def __init__(self, _nfilters_init, _NClasses, norm_type='BatchNorm', norm_groups=None, **kwards):
30 |         super().__init__()
31 | 
32 |         self.model_name = "Head_CMTSK_BC"
33 | 
34 |         self.nfilters = _nfilters_init  # Initial number of filters
35 |         self.NClasses = _NClasses
36 | 
37 | 
38 |         with self.name_scope():
39 | 
40 | 
41 |             self.psp_2ndlast = PSP_Pooling(self.nfilters, _norm_type=norm_type, norm_groups=norm_groups)
42 | 
43 |             # bound logits
44 |             self.bound_logits = HeadSingle(self.nfilters, self.NClasses, norm_type=norm_type, norm_groups=norm_groups)
45 |             self.bound_Equalizer = Conv2DNormed(channels=self.nfilters, kernel_size=1, _norm_type=norm_type, norm_groups=norm_groups)
46 | 
47 |             # distance logits -- deeper for better reconstruction
48 |             self.distance_logits = HeadSingle(self.nfilters, self.NClasses, norm_type=norm_type, norm_groups=norm_groups)
49 |             self.dist_Equalizer = Conv2DNormed(channels=self.nfilters, kernel_size=1, _norm_type=norm_type, norm_groups=norm_groups)
50 | 
51 | 
52 |             self.Comb_bound_dist = Conv2DNormed(channels=self.nfilters, kernel_size=1, _norm_type=norm_type, norm_groups=norm_groups)
53 | 
54 | 
55 |             # Segmentation logits -- deeper for better reconstruction
56 |             self.final_segm_logits = HeadSingle(self.nfilters, self.NClasses, norm_type=norm_type, norm_groups=norm_groups)
57 | 
58 | 
59 | 
60 |             self.CrispSigm = SigmoidCrisp()
61 | 
62 |             # Last activation, customization for binary results
63 |             if (self.NClasses == 1):
64 |                 self.ChannelAct = gluon.nn.HybridLambda(lambda F, x: F.sigmoid(x))
65 |             else:
66 |                 self.ChannelAct = gluon.nn.HybridLambda(lambda F, x: F.softmax(x, axis=1))
67 | 
68 |     def hybrid_forward(self, F, UpConv4, conv1):
69 | 
70 | 
71 |         # second last layer
72 |         convl = F.concat(conv1, UpConv4)
73 |         conv = self.psp_2ndlast(convl)
74 |         conv = F.relu(conv)
75 | 
76 | 
77 |         # logits
78 | 
79 |         # 1st find the distance map -- skeleton-like, carries topology info
80 |         dist = self.distance_logits(convl)  # do not use max pooling for distance
81 |         dist = self.ChannelAct(dist)
82 |         distEq = F.relu(self.dist_Equalizer(dist))  # makes nfilters equal to that of conv and convl
83 | 
84 | 
85 |         # Then find boundaries
86 |         bound = F.concat(conv, distEq)
87 |         bound = self.bound_logits(bound)
88 |         bound = self.CrispSigm(bound)  # Boundaries are not mutually exclusive
89 |         boundEq = F.relu(self.bound_Equalizer(bound))
90 | 
91 | 
92 |         # Now combine all predictions in a final segmentation mask
93 |         # Balance first boundary and distance transform, with the features
94 |         comb_bd = self.Comb_bound_dist(F.concat(boundEq, distEq, dim=1))
95 |         comb_bd = F.relu(comb_bd)
96 | 
97 |         all_layers = F.concat(comb_bd, conv)
98 |         final_segm = self.final_segm_logits(all_layers)
99 |         final_segm = self.ChannelAct(final_segm)
100 | 
101 | 
102 |         return final_segm, bound, dist
103 | 
104 | 
105 | 
106 | 
107 | 
108 | 
109 | 
110 | 
--------------------------------------------------------------------------------
/FracTAL_ResUNet/models/semanticsegmentation/FracTAL_ResUNet.py:
--------------------------------------------------------------------------------
1 | from decode.FracTAL_ResUNet.models.heads.head_cmtsk import *
2 | from decode.FracTAL_ResUNet.models.semanticsegmentation.FracTAL_ResUNet_features import *
3 | 
4 | 
5 | # FracTAL_ResUNet + conditioned multitasking.
6 | class FracTAL_ResUNet_cmtsk(HybridBlock):
7 |     def __init__(self, nfilters_init, depth, NClasses, widths=[1], psp_depth=4, verbose=True, norm_type='BatchNorm', norm_groups=None, nheads_start=8, upFuse=False, ftdepth=5, **kwards):
8 |         super().__init__(**kwards)
9 | 
10 |         with self.name_scope():
11 | 
12 |             self.features = FracTAL_ResUNet_features(nfilters_init=nfilters_init, depth=depth, widths=widths, psp_depth=psp_depth, verbose=verbose, norm_type=norm_type, norm_groups=norm_groups, nheads_start=nheads_start, upFuse=upFuse, ftdepth=ftdepth, **kwards)
13 |             self.head = Head_CMTSK_BC(nfilters_init, NClasses, norm_type=norm_type, norm_groups=norm_groups, **kwards)
14 | 
15 |     def hybrid_forward(self, F, input):
16 |         out1, out2 = self.features(input)
17 | 
18 |         return self.head(out1, out2)
19 | 
--------------------------------------------------------------------------------
/FracTAL_ResUNet/models/semanticsegmentation/FracTAL_ResUNet_features.py:
--------------------------------------------------------------------------------
1 | from mxnet import gluon
2 | from mxnet.gluon import HybridBlock
3 | 
4 | from decode.FracTAL_ResUNet.nn.layers.conv2Dnormed import *
5 | from decode.FracTAL_ResUNet.nn.layers.attention import *
6 | from decode.FracTAL_ResUNet.nn.pooling.psp_pooling import *
7 | 
8 | 
9 | from decode.FracTAL_ResUNet.nn.layers.scale import *
10 | from decode.FracTAL_ResUNet.nn.layers.combine import *
11 | 
12 | # FracTALResUnit
13 | from decode.FracTAL_ResUNet.nn.units.fractal_resnet import *
14 | 
15 | """
16 | If upFuse == True then, instead of concatenating the encoder features with the decoder features, the algorithm performs fusion with
17 | relative attention.
18 | """
19 | 
20 | 
21 | class FracTAL_ResUNet_features(HybridBlock):
22 |     def __init__(self, nfilters_init, depth, widths=[1], psp_depth=4, verbose=True, norm_type='BatchNorm', norm_groups=None, nheads_start=8, upFuse=False, ftdepth=5, **kwards):
23 |         super().__init__(**kwards)
24 | 
25 | 
26 |         self.depth = depth
27 | 
28 | 
29 |         if len(widths) == 1 and depth != 1:
30 |             widths = widths * depth
31 |         else:
32 |             assert depth == len(widths), ValueError("depth and length of widths must match, aborting ...")
33 | 
34 |         with self.name_scope():
35 | 
36 |             self.conv_first = Conv2DNormed(nfilters_init, kernel_size=(1,1), _norm_type=norm_type, norm_groups=norm_groups)
37 | 
38 |             # List of convolutions and pooling operators
39 |             self.convs_dn = gluon.nn.HybridSequential()
40 |             self.pools = gluon.nn.HybridSequential()
41 | 
42 | 
43 |             for idx in range(depth):
44 |                 nheads = nheads_start * 2**idx
45 |                 nfilters = nfilters_init * 2**idx
46 |                 if verbose:
47 |                     print ("depth:= {0}, nfilters: 
6 | class FracTAL_ResUNet_cmtsk(HybridBlock): 7 | def __init__(self, nfilters_init, depth, NClasses,widths=[1], psp_depth=4,verbose=True, norm_type='BatchNorm', norm_groups=None,nheads_start=8, upFuse=False, ftdepth=5,**kwards): 8 | super().__init__(**kwards) 9 | 10 | with self.name_scope(): 11 | 12 | self.features = FracTAL_ResUNet_features(nfilters_init=nfilters_init, depth=depth, widths=widths, psp_depth=psp_depth, verbose=verbose, norm_type=norm_type, norm_groups=norm_groups, nheads_start=nheads_start, upFuse=upFuse, ftdepth=ftdepth, **kwards) 13 | self.head = Head_CMTSK_BC(nfilters_init, NClasses, norm_type=norm_type, norm_groups=norm_groups, **kwards) 14 | 15 | def hybrid_forward(self,F,input): 16 | out1, out2= self.features(input) 17 | 18 | return self.head(out1,out2) 19 | 20 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/models/semanticsegmentation/FracTAL_ResUNet_features.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | 4 | from decode.FracTAL_ResUNet.nn.layers.conv2Dnormed import * 5 | from decode.FracTAL_ResUNet.nn.layers.attention import * 6 | from decode.FracTAL_ResUNet.nn.pooling.psp_pooling import * 7 | 8 | 9 | from decode.FracTAL_ResUNet.nn.layers.scale import * 10 | from decode.FracTAL_ResUNet.nn.layers.combine import * 11 | 12 | # FracTALResUnit 13 | from decode.FracTAL_ResUNet.nn.units.fractal_resnet import * 14 | 15 | """ 16 | if upFuse == True, then instead of concatenation of the encoder features with the decoder features, the algorithm performs Fusion with 17 | relative attention. 18 | """ 19 | 20 | 21 | class FracTAL_ResUNet_features(HybridBlock): 22 | def __init__(self, nfilters_init, depth, widths=[1], psp_depth=4, verbose=True, norm_type='BatchNorm', norm_groups=None, nheads_start=8, upFuse=False, ftdepth=5, **kwards): 23 | super().__init__(**kwards) 24 | 25 | 26 | self.depth = depth 27 | 28 | 29 | if len(widths) == 1 and depth != 1: 30 | widths = widths * depth 31 | else: 32 | assert depth == len(widths), ValueError("depth and length of widths must match, aborting ...") 33 | 34 | with self.name_scope(): 35 | 36 | self.conv_first = Conv2DNormed(nfilters_init,kernel_size=(1,1), _norm_type = norm_type, norm_groups=norm_groups) 37 | 38 | # List of convolutions and pooling operators 39 | self.convs_dn = gluon.nn.HybridSequential() 40 | self.pools = gluon.nn.HybridSequential() 41 | 42 | 43 | for idx in range(depth): 44 | nheads = nheads_start * 2**idx # 45 | nfilters = nfilters_init * 2 **idx 46 | if verbose: 47 | print ("depth:= {0}, nfilters: {1}, nheads::{2}, widths::{3}".format(idx,nfilters,nheads,widths[idx])) 48 | tnet = gluon.nn.HybridSequential() 49 | for _ in range(widths[idx]): 50 | tnet.add(FracTALResNet_unit(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 51 | self.convs_dn.add(tnet) 52 | 53 | if idx < depth-1: 54 | self.pools.add(DownSample(nfilters, _norm_type=norm_type, norm_groups=norm_groups)) 55 | # Middle pooling operator 56 | self.middle = PSP_Pooling(nfilters,depth=psp_depth, _norm_type=norm_type,norm_groups=norm_groups) 57 | 58 | 59 | self.convs_up = gluon.nn.HybridSequential() # 1 argument 60 | self.UpCombs = gluon.nn.HybridSequential() # 2 arguments 61 | for idx in range(depth-1,0,-1): 62 | nheads = nheads_start * 2**idx 63 | nfilters = nfilters_init * 2 **(idx-1) 64 | if verbose: 65 | print ("depth:= {0}, nfilters: 
{1}, nheads::{2}, widths::{3}".format(2*depth-idx-1,nfilters,nheads,widths[idx])) 66 | 67 | tnet = gluon.nn.HybridSequential() 68 | for _ in range(widths[idx]): 69 | tnet.add(FracTALResNet_unit(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 70 | self.convs_up.add(tnet) 71 | 72 | if upFuse==True: 73 | self.UpCombs.add(combine_layers_wthFusion(nfilters=nfilters, nheads=nheads, _norm_type=norm_type,norm_groups=norm_groups,ftdepth=ftdepth)) 74 | else: 75 | self.UpCombs.add(combine_layers(nfilters, _norm_type=norm_type,norm_groups=norm_groups)) 76 | 77 | def hybrid_forward(self, F, input): 78 | 79 | conv1_first = self.conv_first(input) 80 | 81 | 82 | # ******** Going down *************** 83 | fusions = [] 84 | 85 | # Workaround of a mxnet bug 86 | # https://github.com/apache/incubator-mxnet/issues/16736 87 | pools = F.identity(conv1_first) 88 | 89 | for idx in range(self.depth): 90 | conv1 = self.convs_dn[idx](pools) 91 | if idx < self.depth-1: 92 | # Evaluate fusions 93 | conv1 = F.identity(conv1) 94 | fusions = fusions + [conv1] 95 | # Evaluate pools 96 | pools = self.pools[idx](conv1) 97 | 98 | # Middle psppooling 99 | middle = self.middle(conv1) 100 | # Activation of middle layer 101 | middle = F.relu(middle) 102 | fusions = fusions + [middle] 103 | 104 | # ******* Coming up **************** 105 | convs_up = middle 106 | for idx in range(self.depth-1): 107 | convs_up = self.UpCombs[idx](convs_up, fusions[-idx-2]) 108 | convs_up = self.convs_up[idx](convs_up) 109 | 110 | return convs_up, conv1_first 111 | 112 | 113 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/models/semanticsegmentation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/activations/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/activations/sigmoid_crisp.py: -------------------------------------------------------------------------------- 1 | from mxnet.gluon import HybridBlock 2 | import mxnet as mx 3 | 4 | 5 | class SigmoidCrisp(HybridBlock): 6 | def __init__(self, smooth=1.e-2,**kwards): 7 | super().__init__(**kwards) 8 | 9 | 10 | self.smooth = smooth 11 | with self.name_scope(): 12 | self.gamma = self.params.get('gamma', shape=(1,), init=mx.init.One()) 13 | 14 | 15 | def hybrid_forward(self, F, input, gamma): 16 | out = self.smooth + F.sigmoid(gamma) 17 | out = F.reciprocal(out) 18 | 19 | out = F.broadcast_mul(input,out) 20 | out = F.sigmoid(out) 21 | return out 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/layers/attention.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | from 
decode.FracTAL_ResUNet.nn.layers.conv2Dnormed import * 4 | from decode.FracTAL_ResUNet.nn.layers.ftnmt import * 5 | 6 | 7 | 8 | class RelFTAttention2D(HybridBlock): 9 | def __init__(self, nkeys, kernel_size=3, padding=1,nheads=1, norm = 'BatchNorm', norm_groups=None,ftdepth=5,**kwards): 10 | super().__init__(**kwards) 11 | 12 | with self.name_scope(): 13 | 14 | self.query = Conv2DNormed(channels=nkeys,kernel_size= kernel_size, padding = padding, _norm_type= norm, norm_groups=norm_groups) 15 | self.key = Conv2DNormed(channels=nkeys,kernel_size= kernel_size, padding = padding, _norm_type= norm, norm_groups=norm_groups) 16 | self.value = Conv2DNormed(channels=nkeys,kernel_size= kernel_size, padding = padding, _norm_type= norm, norm_groups=norm_groups) 17 | 18 | 19 | self.metric_channel = FTanimoto(depth=ftdepth, axis=[2,3]) 20 | self.metric_space = FTanimoto(depth=ftdepth, axis=1) 21 | 22 | self.norm = get_norm(name=norm, axis=1, norm_groups= norm_groups) 23 | 24 | def hybrid_forward(self, F, input1, input2, input3): 25 | 26 | # These should work with ReLU as well 27 | q = F.sigmoid(self.query(input1)) 28 | k = F.sigmoid(self.key(input2))# B,C,H,W 29 | v = F.sigmoid(self.value(input3)) # B,C,H,W 30 | 31 | att_spat = self.metric_space(q,k) # B,1,H,W 32 | v_spat = F.broadcast_mul(att_spat, v) # emphasize spatial features 33 | 34 | att_chan = self.metric_channel(q,k) # B,C,1,1 35 | v_chan = F.broadcast_mul(att_chan, v) # emphasize spatial features 36 | 37 | 38 | v_cspat = 0.5*F.broadcast_add(v_chan, v_spat) # emphasize spatial features 39 | v_cspat = self.norm(v_cspat) 40 | 41 | return v_cspat 42 | 43 | 44 | 45 | class FTAttention2D(HybridBlock): 46 | def __init__(self, nkeys, kernel_size=3, padding=1, nheads=1, norm = 'BatchNorm', norm_groups=None,ftdepth=5,**kwards): 47 | super().__init__(**kwards) 48 | 49 | with self.name_scope(): 50 | self. att = RelFTAttention2D(nkeys=nkeys,kernel_size=kernel_size, padding=padding, nheads=nheads, norm = norm, norm_groups=norm_groups,ftdepth=ftdepth,**kwards) 51 | 52 | 53 | def hybrid_forward(self, F, input): 54 | return self.att(input,input,input) 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/layers/combine.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | 4 | from decode.FracTAL_ResUNet.nn.layers.scale import * 5 | from decode.FracTAL_ResUNet.nn.layers.conv2Dnormed import * 6 | 7 | 8 | """ 9 | For combining layers with Fusion (i.e. relative attention), see ../units/ceecnet.py 10 | """ 11 | 12 | 13 | class combine_layers(HybridBlock): 14 | def __init__(self,_nfilters, _norm_type = 'BatchNorm', norm_groups=None, **kwards): 15 | HybridBlock.__init__(self,**kwards) 16 | 17 | with self.name_scope(): 18 | 19 | # This performs convolution, no BatchNormalization. No need for bias. 
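            # (Note: UpSample, defined in nn/layers/scale.py, nearest-neighbour
            # upsamples the coarse decoder features by a factor of 2 and halves
            # their filter count, so they can be concatenated with the encoder
            # skip connection at the matching resolution.)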
20 | self.up = UpSample(_nfilters, _norm_type = _norm_type, norm_groups=norm_groups) 21 | 22 | self.conv_normed = Conv2DNormed(channels = _nfilters, 23 | kernel_size=(1,1), 24 | padding=(0,0), 25 | _norm_type=_norm_type, 26 | norm_groups=norm_groups) 27 | 28 | 29 | 30 | 31 | def hybrid_forward(self,F,_layer_lo, _layer_hi): 32 | 33 | up = self.up(_layer_lo) 34 | up = F.relu(up) 35 | x = F.concat(up,_layer_hi, dim=1) 36 | x = self.conv_normed(x) 37 | 38 | return x 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/layers/conv2Dnormed.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | from mxnet import gluon 3 | from mxnet.gluon import HybridBlock 4 | from decode.FracTAL_ResUNet.utils.get_norm import * 5 | 6 | 7 | class Conv2DNormed(HybridBlock): 8 | """ 9 | Convenience wrapper layer for 2D convolution followed by a normalization layer 10 | All other keywords are the same as gluon.nn.Conv2D 11 | """ 12 | 13 | def __init__(self, channels, kernel_size, strides=(1, 1), 14 | padding=(0, 0), dilation=(1, 1), activation=None, 15 | weight_initializer=None, in_channels=0, _norm_type = 'BatchNorm', norm_groups=None, axis =1 , groups=1, **kwards): 16 | super().__init__(**kwards) 17 | 18 | with self.name_scope(): 19 | self.conv2d = gluon.nn.Conv2D(channels, kernel_size = kernel_size, 20 | strides= strides, 21 | padding=padding, 22 | dilation= dilation, 23 | activation=activation, 24 | use_bias=False, 25 | weight_initializer = weight_initializer, 26 | groups=groups, 27 | in_channels=0) 28 | 29 | self.norm_layer = get_norm(_norm_type, axis=axis, norm_groups= norm_groups) 30 | 31 | def hybrid_forward(self,F,_x): 32 | 33 | x = self.conv2d(_x) 34 | x = self.norm_layer(x) 35 | 36 | return x 37 | 38 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/layers/ftnmt.py: -------------------------------------------------------------------------------- 1 | from mxnet.gluon import HybridBlock 2 | 3 | 4 | class FTanimoto(HybridBlock): 5 | """ 6 | This is the average fractal Tanimoto set similarity with complement. 7 | """ 8 | def __init__(self, depth=5, smooth=1.0e-5, axis=[2,3],**kwards): 9 | super().__init__(**kwards) 10 | 11 | assert depth >= 0, "Expecting depth >= 0, aborting ..." 12 | 13 | if depth == 0: 14 | self.depth = 1 15 | self.scale = 1. 16 | else: 17 | self.depth = depth 18 | self.scale = 1./depth 19 | 20 | self.smooth = smooth 21 | self.axis=axis 22 | 23 | def inner_prod(self, F, prob, label): 24 | prod = F.broadcast_mul(prob,label) 25 | prod = F.sum(prod,axis=self.axis,keepdims=True) 26 | 27 | return prod 28 | 29 | 30 | 31 | def tnmt_base(self, F, preds, labels): 32 | 33 | tpl = self.inner_prod(F,preds,labels) 34 | tpp = self.inner_prod(F,preds,preds) 35 | tll = self.inner_prod(F,labels,labels) 36 | 37 | 38 | num = tpl + self.smooth 39 | denum = 0.0 40 | 41 | 42 | for d in range(self.depth): 43 | a = 2.**d 44 | b = -(2.*a-1.) 
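                # With a = 2^d and b = -(2a - 1), this iteration adds the d-th
                # fractal Tanimoto term
                #   T^d(p, l) = <p,l> / ( 2^d (<p,p> + <l,l>) - (2^(d+1) - 1) <p,l> )
                # to the accumulator; d = 0 recovers the classic Tanimoto
                # coefficient <p,l> / (<p,p> + <l,l> - <p,l>).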
45 | 46 | denum = denum + F.reciprocal(F.broadcast_add(a*(tpp+tll), b *tpl) + self.smooth) 47 | 48 | return F.broadcast_mul(num,denum)*self.scale 49 | 50 | def hybrid_forward(self, F, preds, labels): 51 | l12 = self.tnmt_base(F,preds,labels) 52 | l12 = l12 + self.tnmt_base(F,1.-preds, 1.-labels) 53 | 54 | return 0.5*l12 55 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/layers/scale.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | 4 | from decode.FracTAL_ResUNet.nn.layers.conv2Dnormed import * 5 | from decode.FracTAL_ResUNet.utils.get_norm import * 6 | 7 | class DownSample(HybridBlock): 8 | def __init__(self, nfilters, factor=2, _norm_type='BatchNorm', norm_groups=None, **kwargs): 9 | super().__init__(**kwargs) 10 | 11 | 12 | # Double the size of filters, since you downscale by 2. 13 | self.factor = factor 14 | self.nfilters = nfilters * self.factor 15 | 16 | self.kernel_size = (3,3) 17 | self.strides = (factor,factor) 18 | self.pad = (1,1) 19 | 20 | with self.name_scope(): 21 | self.convdn = Conv2DNormed(self.nfilters, 22 | kernel_size=self.kernel_size, 23 | strides=self.strides, 24 | padding=self.pad, 25 | _norm_type = _norm_type, 26 | norm_groups=norm_groups) 27 | 28 | 29 | def hybrid_forward(self,F,_xl): 30 | 31 | x = self.convdn(_xl) 32 | 33 | return x 34 | 35 | 36 | class UpSample(HybridBlock): 37 | def __init__(self,nfilters, factor = 2, _norm_type='BatchNorm', norm_groups=None, **kwards): 38 | HybridBlock.__init__(self,**kwards) 39 | 40 | 41 | self.factor = factor 42 | self.nfilters = nfilters // self.factor 43 | 44 | with self.name_scope(): 45 | self.convup_normed = Conv2DNormed(self.nfilters, 46 | kernel_size = (1,1), 47 | _norm_type = _norm_type, 48 | norm_groups=norm_groups) 49 | 50 | def hybrid_forward(self,F,_xl): 51 | x = F.UpSampling(_xl, scale=self.factor, sample_type='nearest') 52 | x = self.convup_normed(x) 53 | 54 | return x 55 | 56 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/loss/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/loss/ftnmt_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fractal Tanimoto (with dual) loss 3 | """ 4 | 5 | from mxnet.gluon.loss import Loss 6 | class ftnmt_loss(Loss): 7 | """ 8 | This function calculates the average fractal tanimoto similarity for d = 0...depth 9 | """ 10 | def __init__(self, depth=5, axis= [1,2,3], smooth = 1.0e-5, batch_axis=0, weight=None, **kwargs): 11 | super().__init__(batch_axis, weight, **kwargs) 12 | 13 | assert depth>= 0, ValueError("depth must be >= 0, aborting...") 14 | 15 | self.smooth = smooth 16 | self.axis=axis 17 | self.depth = depth 18 | 19 | if depth == 0: 20 | self.depth = 1 21 | self.scale = 1. 
22 |         else:
23 |             self.depth = depth
24 |             self.scale = 1./depth
25 | 
26 |     def inner_prod(self, F, prob, label):
27 |         prod = F.broadcast_mul(prob, label)
28 |         prod = F.sum(prod, axis=self.axis)
29 | 
30 |         return prod
31 | 
32 |     def tnmt_base(self, F, preds, labels):
33 | 
34 |         tpl = self.inner_prod(F, preds, labels)
35 |         tpp = self.inner_prod(F, preds, preds)
36 |         tll = self.inner_prod(F, labels, labels)
37 | 
38 | 
39 |         num = tpl + self.smooth
40 |         scale = 1./self.depth
41 |         denum = 0.0
42 |         for d in range(self.depth):
43 |             a = 2.**d
44 |             b = -(2.*a - 1.)
45 | 
46 |             denum = denum + F.reciprocal(F.broadcast_add(a*(tpp+tll), b*tpl) + self.smooth)
47 | 
48 |         result = F.broadcast_mul(num, denum)*scale
49 |         return F.mean(result, axis=0, exclude=True)
50 | 
51 | 
52 |     def hybrid_forward(self, F, preds, labels):
53 | 
54 |         l1 = self.tnmt_base(F, preds, labels)
55 |         l2 = self.tnmt_base(F, 1.-preds, 1.-labels)
56 | 
57 |         result = 0.5*(l1+l2)
58 | 
59 |         return 1. - result
60 | 
61 | 
--------------------------------------------------------------------------------
/FracTAL_ResUNet/nn/loss/mtsk_loss.py:
--------------------------------------------------------------------------------
1 | from decode.FracTAL_ResUNet.nn.loss.ftnmt_loss import *
2 | 
3 | class mtsk_loss(object):
4 |     """
5 |     Here NClasses = 2 by default, for a binary segmentation problem in 1hot representation
6 |     """
7 | 
8 |     def __init__(self, depth=0, NClasses=2):
9 | 
10 |         self.ftnmt = ftnmt_loss(depth=depth)
11 |         self.ftnmt.hybridize()
12 | 
13 |         self.skip = NClasses
14 | 
15 |     def loss(self, _prediction, _label):
16 | 
17 |         pred_segm = _prediction[0]
18 |         pred_bound = _prediction[1]
19 |         pred_dists = _prediction[2]
20 | 
21 |         # In our implementation of the labels, we stack together the [segmentation, boundary, distance] labels,
22 |         # along the channel axis. 
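        # e.g. with the default NClasses=2 the label tensor carries 6 channels:
        # [:, 0:2] segmentation, [:, 2:4] boundary, [:, 4:6] distance transform.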
23 | label_segm = _label[:,:self.skip,:,:] 24 | label_bound = _label[:,self.skip:2*self.skip,:,:] 25 | label_dists = _label[:,2*self.skip:,:,:] 26 | 27 | 28 | loss_segm = self.ftnmt(pred_segm, label_segm) 29 | loss_bound = self.ftnmt(pred_bound, label_bound) 30 | loss_dists = self.ftnmt(pred_dists, label_dists) 31 | 32 | return (loss_segm+loss_bound+loss_dists)/3.0 33 | 34 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/pooling/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/pooling/psp_pooling.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | from decode.FracTAL_ResUNet.nn.layers.conv2Dnormed import * 4 | 5 | class PSP_Pooling(gluon.HybridBlock): 6 | def __init__(self, nfilters, depth=4, _norm_type = 'BatchNorm', norm_groups=None, mob=False, **kwards): 7 | gluon.HybridBlock.__init__(self,**kwards) 8 | 9 | 10 | self.nfilters = nfilters 11 | self.depth = depth 12 | 13 | self.convs = gluon.nn.HybridSequential() 14 | with self.name_scope(): 15 | for _ in range(depth): 16 | self.convs.add(Conv2DNormed(self.nfilters,kernel_size=(1,1),padding=(0,0),_norm_type=_norm_type, norm_groups=norm_groups)) 17 | 18 | self.conv_norm_final = Conv2DNormed(channels = self.nfilters, 19 | kernel_size=(1,1), 20 | padding=(0,0), 21 | _norm_type=_norm_type, 22 | norm_groups=norm_groups) 23 | 24 | 25 | # ******** Utilities functions to avoid calling infer_shape **************** 26 | def HalfSplit(self, F,_a): 27 | """ 28 | Returns a list of half split arrays. Usefull for HalfPoolling 29 | """ 30 | b = F.split(_a,axis=2,num_outputs=2) # Split First dimension 31 | c1 = F.split(b[0],axis=3,num_outputs=2) # Split 2nd dimension 32 | c2 = F.split(b[1],axis=3,num_outputs=2) # Split 2nd dimension 33 | 34 | 35 | d11 = c1[0] 36 | d12 = c1[1] 37 | 38 | d21 = c2[0] 39 | d22 = c2[1] 40 | 41 | return [d11,d12,d21,d22] 42 | 43 | 44 | def QuarterStitch(self, F,_Dss): 45 | """ 46 | INPUT: 47 | A list of [d11,d12,d21,d22] block matrices. 48 | OUTPUT: 49 | A single matrix joined of these submatrices 50 | """ 51 | 52 | temp1 = F.concat(_Dss[0],_Dss[1],dim=-1) 53 | temp2 = F.concat(_Dss[2],_Dss[3],dim=-1) 54 | result = F.concat(temp1,temp2,dim=2) 55 | 56 | return result 57 | 58 | 59 | def HalfPooling(self, F,_a): 60 | Ds = self.HalfSplit(F,_a) 61 | 62 | Dss = [] 63 | for x in Ds: 64 | Dss += [F.broadcast_mul(F.ones_like(x) , F.Pooling(x,global_pool=True))] 65 | 66 | return self.QuarterStitch(F,Dss) 67 | 68 | 69 | 70 | def SplitPooling(self, F, _a, depth): 71 | """ 72 | A recursive function that produces the Pooling you want - in particular depth (powers of 2) 73 | """ 74 | if depth==1: 75 | return self.HalfPooling(F,_a) 76 | else : 77 | D = self.HalfSplit(F,_a) 78 | return self.QuarterStitch(F,[self.SplitPooling(F,d,depth-1) for d in D]) 79 | # *********************************************************************************** 80 | 81 | def hybrid_forward(self,F,_input): 82 | 83 | p = [_input] 84 | # 1st:: Global Max Pooling . 
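        # (F.Pooling with global_pool=True collapses H,W to 1x1 -- max pooling by
        # default -- and multiplying by F.ones_like broadcasts that statistic back
        # to the full spatial grid without an explicit infer_shape call.)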
85 | p += [self.convs[0](F.broadcast_mul(F.ones_like(_input) , F.Pooling(_input,global_pool=True)))] 86 | p += [self.convs[d](self.SplitPooling(F,_input,d)) for d in range(1,self.depth)] 87 | out = F.concat(*p,dim=1) 88 | out = self.conv_norm_final(out) 89 | 90 | return out 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/units/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/nn/units/fractal_resnet.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | from decode.FracTAL_ResUNet.nn.layers.conv2Dnormed import * 4 | from decode.FracTAL_ResUNet.utils.get_norm import * 5 | from decode.FracTAL_ResUNet.nn.layers.attention import * 6 | 7 | class ResNet_v2_block(HybridBlock): 8 | """ 9 | ResNet v2 building block. It is built upon the assumption of ODD kernel 10 | """ 11 | def __init__(self, _nfilters,_kernel_size=(3,3),_dilation_rate=(1,1), 12 | _norm_type='BatchNorm', norm_groups=None, ngroups=1, **kwards): 13 | super().__init__(**kwards) 14 | 15 | self.nfilters = _nfilters 16 | self.kernel_size = _kernel_size 17 | self.dilation_rate = _dilation_rate 18 | 19 | 20 | with self.name_scope(): 21 | 22 | # Ensures padding = 'SAME' for ODD kernel selection 23 | p0 = self.dilation_rate[0] * (self.kernel_size[0] - 1)/2 24 | p1 = self.dilation_rate[1] * (self.kernel_size[1] - 1)/2 25 | p = (int(p0),int(p1)) 26 | 27 | 28 | self.BN1 = get_norm(_norm_type, norm_groups=norm_groups ) 29 | self.conv1 = gluon.nn.Conv2D(self.nfilters,kernel_size = self.kernel_size,padding=p,dilation=self.dilation_rate,use_bias=False,groups=ngroups) 30 | self.BN2 = get_norm(_norm_type, norm_groups= norm_groups) 31 | self.conv2 = gluon.nn.Conv2D(self.nfilters,kernel_size = self.kernel_size,padding=p,dilation=self.dilation_rate,use_bias=True, groups=ngroups) 32 | 33 | 34 | def hybrid_forward(self,F,_input_layer): 35 | 36 | x = self.BN1(_input_layer) 37 | x = F.relu(x) 38 | x = self.conv1(x) 39 | 40 | x = self.BN2(x) 41 | x = F.relu(x) 42 | x = self.conv2(x) 43 | 44 | return x 45 | 46 | class FracTALResNet_unit(HybridBlock): 47 | def __init__(self, nfilters, ngroups=1, nheads=1, kernel_size=(3,3), dilation_rate=(1,1), norm_type = 'BatchNorm', norm_groups=None, ftdepth=5,**kwards): 48 | super().__init__(**kwards) 49 | 50 | with self.name_scope(): 51 | self.block1 = ResNet_v2_block(nfilters,kernel_size,dilation_rate,_norm_type = norm_type, norm_groups=norm_groups, ngroups=ngroups) 52 | self.attn = FTAttention2D(nkeys=nfilters, nheads=nheads, kernel_size=kernel_size, norm = norm_type, norm_groups = norm_groups,ftdepth=ftdepth) 53 | 54 | self.gamma = self.params.get('gamma', shape=(1,), init=mx.init.Zero()) 55 | 56 | def hybrid_forward(self, F, input, gamma): 57 | out1 = self.block1(input) 58 | 59 | 60 | att = self.attn(input) 61 | att= F.broadcast_mul(gamma,att) 62 | 63 | out = F.broadcast_mul((input + out1) , F.ones_like(out1) + att) 64 | return out 65 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/src/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FracTAL_ResUNet/src/semseg_aug_cv2.py: 
--------------------------------------------------------------------------------
1 | import cv2
2 | import itertools
3 | import numpy as np
4 | 
5 | 
6 | class ParamsRange(dict):
7 |     def __init__(self):
8 | 
9 | 
10 |         # Good default values for 256x256 images
11 |         self['center_range'] = [0,256]
12 |         self['rot_range'] = [-85.0,85.0]
13 |         self['zoom_range'] = [0.75,1.25]
14 |         self['noise_mean'] = [0]*5
15 |         self['noise_var'] = [10]*5
16 | 
17 | 
18 | class SemSegAugmentor_CV(object):
19 |     """
20 |     INPUTS:
21 |         parameters range for all transformations
22 |         probability of a transformation taking place - defaults to 1.
23 |         Nrot: number of rotation operations relative to the other transforms (reflections x, y, xy, brightness, shadow). Defaults to 5, so rotations are drawn as often as all other operations combined.
24 |     """
25 |     def __init__(self, params_range, prob=1.0, Nrot=5, norm=None, one_hot=True):
26 | 
27 |         self.norm = norm  # This is a necessary hack to apply brightness normalization
28 |         self.one_hot = one_hot
29 |         self.range = params_range
30 |         self.prob = prob
31 |         assert self.prob <= 1, "prob must be in range [0,1], you gave prob::{}".format(prob)
32 | 
33 | 
34 |         # define a proportion of operations?
35 |         self.operations = [self.reflect_x, self.reflect_y, self.reflect_xy, self.random_brightness, self.random_shadow]
36 |         self.operations += [self.rand_shift_rot_zoom]*Nrot
37 |         self.iterator = itertools.cycle(self.operations)
38 | 
39 | 
40 |     def _shift_rot_zoom(self, _img, _mask, _center, _angle, _scale):
41 |         """
42 |         OpenCV random scale + rotation
43 |         """
44 |         imgT = _img.transpose([1,2,0])
45 |         if (self.one_hot):
46 |             maskT = _mask.transpose([1,2,0])
47 |         else:
48 |             maskT = _mask
49 | 
50 |         cols, rows = imgT.shape[:-1]
51 | 
52 |         # Produces affine rotation matrix, with center, for angle, and optional zoom in/out scale
53 |         tRotMat = cv2.getRotationMatrix2D(_center, _angle, _scale)
54 | 
55 |         img_trans = cv2.warpAffine(imgT, tRotMat, (cols,rows), flags=cv2.INTER_AREA, borderMode=cv2.BORDER_REFLECT_101)
56 |         mask_trans = cv2.warpAffine(maskT, tRotMat, (cols,rows), flags=cv2.INTER_AREA, borderMode=cv2.BORDER_REFLECT_101)
57 | 
58 |         img_trans = img_trans.transpose([2,0,1])
59 |         if (self.one_hot):
60 |             mask_trans = mask_trans.transpose([2,0,1])
61 | 
62 |         return img_trans, mask_trans
63 | 
64 | 
65 |     def reflect_x(self, _img, _mask):
66 | 
67 |         img_z = _img[:,::-1,:]
68 |         if self.one_hot:
69 |             mask_z = _mask[:,::-1,:]  # 1hot representation
70 |         else:
71 |             mask_z = _mask[::-1,:]  # standard (int's representation)
72 | 
73 |         return img_z, mask_z
74 | 
75 |     def reflect_y(self, _img, _mask):
76 |         img_z = _img[:,:,::-1]
77 |         if self.one_hot:
78 |             mask_z = _mask[:,:,::-1]  # 1hot representation
79 |         else:
80 |             mask_z = _mask[:,::-1]  # standard (int's representation)
81 | 
82 |         return img_z, mask_z
83 | 
84 |     def reflect_xy(self, _img, _mask):
85 |         img_z = _img[:,::-1,::-1]
86 |         if self.one_hot:
87 |             mask_z = _mask[:,::-1,::-1]  # 1hot representation
88 |         else:
89 |             mask_z = _mask[::-1,::-1]  # standard (int's representation)
90 | 
91 |         return img_z, mask_z
92 | 
93 | 
94 | 
95 |     def rand_shift_rot_zoom(self, _img, _mask):
96 | 
97 |         center = np.random.randint(low=self.range['center_range'][0],
98 |                                    high=self.range['center_range'][1],
99 |                                    size=2)
100 |         # Angle is in degrees (cv2.getRotationMatrix2D expects degrees)
101 |         angle = np.random.uniform(low=self.range['rot_range'][0],
102 |                                   high=self.range['rot_range'][1])
103 | 
104 |         scale = np.random.uniform(low=self.range['zoom_range'][0],
105 |                                   high=self.range['zoom_range'][1])
106 | 
107 | 
108 |         return self._shift_rot_zoom(_img, _mask, tuple(center), angle, scale)
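    # Minimal usage sketch (hypothetical CHW numpy inputs, not from the repo):
    #   aug = SemSegAugmentor_CV(ParamsRange(), prob=1.0)
    #   img_aug, mask_aug = aug(img, mask)   # applies the next operation in the cycle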
109 | 
110 | 
111 | 
112 |     # ============================================ New additions below =======================================================
113 |     # **************** Random brightness (light/dark) and random shadow polygons *************
114 |     # ******** Taken from: https://medium.freecodecamp.org/image-augmentation-make-it-rain-make-it-snow-how-to-modify-a-photo-with-machine-learning-163c0cb3843f
115 |     # ******** See https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library for library
116 | 
117 |     def random_brightness(self, _img, _mask):
118 |         """
119 |         This function applies only to the first 3 channels (RGB) of an image.
120 |         Input: RGB image, transformed to np.uint8
121 |         Output: RGB image + extra channels.
122 |         """
123 | 
124 |         if self.norm is not None:
125 |             image = self.norm.restore(_img).transpose([1,2,0])[:,:,:3].copy()  # use only three bands
126 |             imgcp = self.norm.restore(_img.copy())  # use only three bands
127 | 
128 |         else:
129 | 
130 |             image = _img.transpose([1,2,0])[:,:,:3].copy().astype(np.uint8)  # use only three bands
131 |             imgcp = _img.copy().astype(np.uint8)  # use only three bands
132 | 
133 |         image_HLS = cv2.cvtColor(image, cv2.COLOR_RGB2HLS)  ## Conversion to HLS
134 |         image_HLS = np.array(image_HLS, dtype=np.float64)
135 |         random_brightness_coefficient = np.random.uniform() + 0.5  ## generates value between 0.5 and 1.5
136 |         image_HLS[:,:,1] = image_HLS[:,:,1]*random_brightness_coefficient  ## scale pixel values up or down for channel 1 (Lightness)
137 |         image_HLS[:,:,1][image_HLS[:,:,1]>255] = 255  ## Sets all values above 255 to 255
138 |         image_HLS = np.array(image_HLS, dtype=np.uint8)
139 |         image_RGB = cv2.cvtColor(image_HLS, cv2.COLOR_HLS2RGB)  ## Conversion back to RGB
140 | 
141 | 
142 |         imgcp[:3,:,:] = image_RGB.transpose([2,0,1])
143 | 
144 |         if self.norm is not None:
145 |             imgcp = self.norm(imgcp)
146 | 
147 |         return imgcp.astype(_img.dtype), _mask
148 | 
149 | 
150 | 
151 |     def _generate_shadow_coordinates(self, imshape, no_of_shadows=1):
152 |         vertices_list = []
153 |         for index in range(no_of_shadows):
154 |             vertex = []
155 |             for dimensions in range(np.random.randint(3,15)):  ## Dimensionality of the shadow polygon
156 |                 vertex.append((imshape[1]*np.random.uniform(), imshape[0]*np.random.uniform()))
157 |             vertices = np.array([vertex], dtype=np.int32)  ## single shadow vertices
158 |             vertices = cv2.convexHull(vertices[0])
159 |             vertices = vertices.transpose([1,0,2])
160 |             vertices_list.append(vertices)
161 |         return vertices_list  ## List of shadow vertices
162 | 
163 |     def _add_shadow(self, image, no_of_shadows=1):
164 |         image_HLS = cv2.cvtColor(image, cv2.COLOR_RGB2HLS)  ## Conversion to HLS
165 |         tmask = np.zeros_like(image[:,:,0])
166 |         imshape = image.shape
167 |         vertices_list = self._generate_shadow_coordinates(imshape, no_of_shadows)  ## getting list of shadow vertices
168 |         for vertices in vertices_list:
169 |             cv2.fillPoly(tmask, vertices, 255)
170 |         image_HLS[:,:,1][tmask[:,:]==255] = image_HLS[:,:,1][tmask[:,:]==255]*0.5
171 |         image_RGB = cv2.cvtColor(image_HLS, cv2.COLOR_HLS2RGB)  ## Conversion to RGB
172 |         return image_RGB
173 | 
174 | 
175 |     def random_shadow(self, _img, _mask):
176 | 
177 | 
178 |         if self.norm is not None:
179 |             image = self.norm.restore(_img).transpose([1,2,0])[:,:,:3].copy()  # use only three bands
180 |             imgcp = self.norm.restore(_img.copy())  # use only three bands
181 | 
182 |         else:
183 | 
184 |             image = _img.transpose([1,2,0])[:,:,:3].copy().astype(np.uint8)  # use only three bands
185 |             imgcp = _img.copy().astype(np.uint8)  # use only three bands
186 | 
187 |         shadow_image = self._add_shadow(image)
188 | 
189 |         imgcp.transpose([1,2,0])[:,:,:3] = shadow_image
190 | 
191 |         if self.norm is not None:
192 |             imgcp = self.norm(imgcp)
193 | 
194 |         return imgcp.astype(_img.dtype), _mask
195 | 
196 |     # =====================================================================================
197 | 
198 | 
199 |     def __call__(self, _img, _mask):
200 | 
201 |         rand = np.random.rand()
202 |         if (rand <= self.prob):
203 |             return next(self.iterator)(_img, _mask)
204 |         else:
205 |             return _img, _mask
206 | 
--------------------------------------------------------------------------------
/FracTAL_ResUNet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/FracTAL_ResUNet/utils/get_norm.py:
--------------------------------------------------------------------------------
1 | import mxnet as mx
2 | from mxnet import gluon
3 | 
4 | def get_norm(name, axis=1, norm_groups=None):
5 |     if (name == 'BatchNorm'):
6 |         return gluon.nn.BatchNorm(axis=axis)
7 |     elif (name == 'InstanceNorm'):
8 |         return gluon.nn.InstanceNorm(axis=axis)
9 |     elif (name == 'LayerNorm'):
10 |         return gluon.nn.LayerNorm(axis=axis)
11 |     elif (name == 'GroupNorm' and norm_groups is not None):
12 |         return gluon.nn.GroupNorm(num_groups=norm_groups)  # applied to channel axis
13 |     else:
14 |         raise NotImplementedError
15 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Scalable satellite-based delineation of field boundaries with DECODE
2 | 
3 | Official [mxnet](https://mxnet.incubator.apache.org/) implementation of the paper: ["Detect, consolidate, delineate: scalable mapping of field boundaries using satellite images"](https://www.mdpi.com/2072-4292/13/11/2197), Waldner et al. (2021). This repository contains the source code for implementing and training the FracTAL ResUNet described in the paper. All models are built with the mxnet DL framework (version < 2.0), under the gluon API. We do not provide pre-trained weights.
4 | 
5 | Inference examples for six areas in Australia.
6 | ![mantis](images/decode.png)
7 | 
8 | 
9 | ### Directory structure:
10 | 
11 | ```
12 | .
13 | ├── examples
14 | ├── images
15 | ├── FracTAL_ResUNet
16 | │   ├── models
17 | │   │   ├── heads
18 | │   │   └── semanticsegmentation
19 | │   ├── nn
20 | │   │   ├── activations
21 | │   │   ├── layers
22 | │   │   ├── loss
23 | │   │   ├── pooling
24 | │   │   └── units
25 | │   ├── src
26 | │   └── utils
27 | └── postprocessing
28 | ```
29 | 
30 | In ```examples```, there are notebooks that 1) show how to initialize a FracTAL ResUNet model and perform forward and multitasking backward passes, and 2) illustrate how to perform instance segmentation using hierarchical watershed segmentation.
31 | 
32 | 
33 | ### License
34 | CSIRO BSD/MIT LICENSE
35 | 
36 | As a condition of this licence, you agree that where you make any adaptations, modifications, further developments, or additional features available to CSIRO or the public in connection with your access to the Software, you do so on the terms of the BSD 3-Clause Licence template, a copy available at: http://opensource.org/licenses/BSD-3-Clause.
37 | 
38 | 
39 | 
40 | ### CITATION
41 | If you find the contents of this repository useful for your research, please cite:
42 | ```
43 | @Article{rs13112197,
44 | AUTHOR = {Waldner, Fran\c{c}ois and Diakogiannis, Foivos I. and Batchelor, Kathryn and Ciccotosto-Camp, Michael and Cooper-Williams, Elizabeth and Herrmann, Chris and Mata, Gonzalo and Toovey, Andrew},
45 | TITLE = {Detect, Consolidate, Delineate: Scalable Mapping of Field Boundaries Using Satellite Images},
46 | JOURNAL = {Remote Sensing},
47 | VOLUME = {13},
48 | YEAR = {2021},
49 | NUMBER = {11},
50 | ARTICLE-NUMBER = {2197},
51 | URL = {https://www.mdpi.com/2072-4292/13/11/2197},
52 | ISSN = {2072-4292},
53 | ABSTRACT = {Digital agriculture services can greatly assist growers to monitor their fields and optimize their use throughout the growing season. Thus, knowing the exact location of fields and their boundaries is a prerequisite. Unlike property boundaries, which are recorded in local council or title records, field boundaries are not historically recorded. As a result, digital services currently ask their users to manually draw their field, which is time-consuming and creates disincentives. Here, we present a generalized method, hereafter referred to as DECODE (DEtect, COnsolidate, and DElinetate), that automatically extracts accurate field boundary data from satellite imagery using deep learning based on spatial, spectral, and temporal cues. We introduce a new convolutional neural network (FracTAL ResUNet) as well as two uncertainty metrics to characterize the confidence of the field detection and field delineation processes. We finally propose a new methodology to compare and summarize field-based accuracy metrics. To demonstrate the performance and scalability of our method, we extracted fields across the Australian grains zone with a pixel-based accuracy of 0.87 and a field-based accuracy of up to 0.88 depending on the metric. We also trained a model on data from South Africa instead of Australia and found it transferred well to unseen Australian landscapes. We conclude that the accuracy, scalability and transferability of DECODE shows that large-scale field boundary extraction based on deep learning has reached operational maturity. This opens the door to new agricultural services that provide routine, near-real time field-based analytics.},
54 | DOI = {10.3390/rs13112197}
55 | }
56 | ```
57 | 
--------------------------------------------------------------------------------
/examples/Demo_forward_backward.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [
8 |     {
9 |      "name": "stdout",
10 |      "output_type": "stream",
11 |      "text": [
12 |       "Populating the interactive namespace from numpy and matplotlib\n"
13 |      ]
14 |     }
15 |    ],
16 |    "source": [
17 |     "%pylab inline"
18 |    ]
19 |   },
20 |   {
21 |    "cell_type": "code",
22 |    "execution_count": 2,
23 |    "metadata": {},
24 |    "outputs": [],
25 |    "source": [
26 |     "import sys\n",
27 |     "sys.path.append('/Your/Location/To/DECODERepo/')"
28 |    ]
29 |   },
30 |   {
31 |    "cell_type": "code",
32 |    "execution_count": 3,
33 |    "metadata": {},
34 |    "outputs": [
35 |     {
36 |      "name": "stderr",
37 |      "output_type": "stream",
38 |      "text": [
39 |       "/usr/local/lib/python3.7/site-packages/joblib/_multiprocessing_helpers.py:45: UserWarning: [Errno 28] No space left on device. joblib will operate in serial mode\n",
40 |       "  warnings.warn('%s. joblib will operate in serial mode' % (e,))\n"
41 |      ]
42 |     }
43 |    ],
44 |    "source": [
45 |     "from mxnet import nd \n",
46 |     "from decode.FracTAL_ResUNet.models.semanticsegmentation.FracTAL_ResUNet import *"
47 |    ]
48 |   },
49 |   {
50 |    "cell_type": "code",
51 |    "execution_count": 4,
52 |    "metadata": {},
53 |    "outputs": [
54 |     {
55 |      "name": "stdout",
56 |      "output_type": "stream",
57 |      "text": [
58 |       "depth:= 0, nfilters: 32, nheads::4, widths::1\n",
59 |       "depth:= 1, nfilters: 64, nheads::8, widths::1\n",
60 |       "depth:= 2, nfilters: 128, nheads::16, widths::1\n",
61 |       "depth:= 3, nfilters: 256, nheads::32, widths::1\n",
62 |       "depth:= 4, nfilters: 512, nheads::64, widths::1\n",
63 |       "depth:= 5, nfilters: 1024, nheads::128, widths::1\n",
64 |       "depth:= 6, nfilters: 512, nheads::128, widths::1\n",
65 |       "depth:= 7, nfilters: 256, nheads::64, widths::1\n",
66 |       "depth:= 8, nfilters: 128, nheads::32, widths::1\n",
67 |       "depth:= 9, nfilters: 64, nheads::16, widths::1\n",
68 |       "depth:= 10, nfilters: 32, nheads::8, widths::1\n"
69 |      ]
70 |     }
71 |    ],
72 |    "source": [
73 |     "# D6nf32 example \n",
74 |     "depth=6\n",
75 |     "norm_type='GroupNorm'\n",
76 |     "norm_groups=4\n",
77 |     "ftdepth=5\n",
78 |     "NClasses=2\n",
79 |     "nfilters_init=32\n",
80 |     "psp_depth=4\n",
81 |     "nheads_start=4\n",
82 |     "\n",
83 |     "\n",
84 |     "net = FracTAL_ResUNet_cmtsk(nfilters_init=nfilters_init, NClasses=NClasses, depth=depth, ftdepth=ftdepth, psp_depth=psp_depth, norm_type=norm_type, norm_groups=norm_groups, nheads_start=nheads_start)\n",
85 |     "net.initialize()"
86 |    ]
87 |   },
88 |   {
89 |    "cell_type": "code",
90 |    "execution_count": 5,
91 |    "metadata": {},
92 |    "outputs": [],
93 |    "source": [
94 |     "BatchSize = 4\n",
95 |     "img_size = 256\n",
96 |     "NChannels = 3\n",
97 |     "\n",
98 |     "input_img = nd.random.uniform(shape=[BatchSize, NChannels, img_size, img_size])"
99 |    ]
100 |   },
101 |   {
102 |    "cell_type": "markdown",
103 |    "metadata": {},
104 |    "source": [
105 |     "## Example of forward operation:\n"
106 |    ]
107 |   },
108 |   {
109 |    "cell_type": "code",
110 |    "execution_count": 6,
111 |    "metadata": {},
112 |    "outputs": [],
113 |    "source": [
114 |     "outs = net(input_img)"
115 |    ]
116 |   },
117 |   {
118 |    "cell_type": "code",
119 |    "execution_count": 7,
120 |    "metadata": {},
121 |    "outputs": [
122 |     {
123 |      "name": "stdout",
124 |      "output_type": "stream",
125 |      "text": [
126 |       "(4, 2, 256, 256)\n",
127 |       "(4, 2, 256, 256)\n",
128 |       "(4, 2, 256, 256)\n"
129 |      ]
130 |     }
131 |    ],
132 |    "source": [
133 |     "# outs is a list of outputs: segmentation, boundary, distance. \n",
134 |     "# Each has shape BatchSize, NClasses, img_size, img_size\n",
135 |     "for out in outs:\n",
136 |     "    print (out.shape)"
137 |    ]
138 |   },
139 |   {
140 |    "cell_type": "markdown",
141 |    "metadata": {},
142 |    "source": [
143 |     "### Example of performing backward with multitasking operation"
144 |    ]
145 |   },
146 |   {
147 |    "cell_type": "code",
148 |    "execution_count": 8,
149 |    "metadata": {},
150 |    "outputs": [],
151 |    "source": [
152 |     "labels_segm = nd.random.uniform(shape=[BatchSize, NClasses, img_size, img_size])\n",
153 |     "labels_segm = labels_segm > 0.5\n",
154 |     "\n",
155 |     "labels_bound = nd.random.uniform(shape=[BatchSize, NClasses, img_size, img_size])\n",
156 |     "labels_bound = labels_bound > 0.5\n",
157 |     "\n",
158 |     "labels_dist = nd.random.uniform(shape=[BatchSize, NClasses, img_size, img_size])\n",
159 |     "\n",
160 |     "\n",
161 |     "labels = [labels_segm, labels_bound, labels_dist]\n",
162 |     "labels = nd.concat(*labels, dim=1)"
163 |    ]
164 |   },
165 |   {
166 |    "cell_type": "code",
167 |    "execution_count": 9,
168 |    "metadata": {},
169 |    "outputs": [],
170 |    "source": [
171 |     "from mxnet import autograd\n",
172 |     "from decode.FracTAL_ResUNet.nn.loss.mtsk_loss import *"
173 |    ]
174 |   },
175 |   {
176 |    "cell_type": "code",
177 |    "execution_count": 10,
178 |    "metadata": {},
179 |    "outputs": [],
180 |    "source": [
181 |     "myMTSKL = mtsk_loss()"
182 |    ]
183 |   },
184 |   {
185 |    "cell_type": "code",
186 |    "execution_count": 12,
187 |    "metadata": {},
188 |    "outputs": [],
189 |    "source": [
190 |     "with autograd.record():\n",
191 |     "    listOfPreds = net(input_img)\n",
192 |     "    loss = myMTSKL.loss(listOfPreds, labels)\n",
193 |     "    loss.backward()"
194 |    ]
195 |   },
196 |   {
197 |    "cell_type": "code",
198 |    "execution_count": 13,
199 |    "metadata": {},
200 |    "outputs": [
201 |     {
202 |      "data": {
203 |       "text/plain": [
204 |        "(4,)"
205 |       ]
206 |      },
207 |      "execution_count": 13,
208 |      "metadata": {},
209 |      "output_type": "execute_result"
210 |     }
211 |    ],
212 |    "source": [
213 |     "loss.shape"
214 |    ]
215 |   },
216 |   {
217 |    "cell_type": "code",
218 |    "execution_count": 28,
219 |    "metadata": {},
220 |    "outputs": [
221 |     {
222 |      "data": {
223 |       "text/plain": [
224 |        "\n",
225 |        "[0.50219935 0.5020496 0.5023406 0.5021815 ]\n",
226 |        ""
227 |       ]
228 |      },
229 |      "execution_count": 28,
230 |      "metadata": {},
231 |      "output_type": "execute_result"
232 |     }
233 |    ],
234 |    "source": [
235 |     "loss"
236 |    ]
237 |   },
238 |   {
239 |    "cell_type": "code",
240 |    "execution_count": null,
241 |    "metadata": {},
242 |    "outputs": [],
243 |    "source": []
244 |   }
245 |  ],
246 |  "metadata": {
247 |   "kernelspec": {
248 |    "display_name": "Python 3",
249 |    "language": "python",
250 |    "name": "python3"
251 |   },
252 |   "language_info": {
253 |    "codemirror_mode": {
254 |     "name": "ipython",
255 |     "version": 3
256 |    },
257 |    "file_extension": ".py",
258 |    "mimetype": "text/x-python",
259 |    "name": "python",
260 |    "nbconvert_exporter": "python",
261 |    "pygments_lexer": "ipython3",
262 |    "version": "3.7.9"
263 |   }
264 |  },
265 |  "nbformat": 4,
266 |  "nbformat_minor": 4
267 | }
268 | 
--------------------------------------------------------------------------------
/images/decode.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/waldnerf/decode/bed6f6b4173f49362a5113207155cc103e8fd139/images/decode.png
--------------------------------------------------------------------------------
/postprocessing/instance_segmentation.py:
--------------------------------------------------------------------------------
1 | import higra as hg
2 | import numpy as np
3 | 
4 | 
5 | def InstSegm(extent, boundary, t_ext=0.4, t_bound=0.2):
6 |     """
7 |     INPUTS:
8 |     extent : extent prediction
9 |     boundary : boundary prediction
10 |     t_ext : threshold for extent
11 |     t_bound : threshold for boundary
12 |     OUTPUT:
13 |     instances
14 |     """
15 | 
16 |     # Threshold extent mask
17 |     ext_binary = np.uint8(extent >= t_ext)
18 | 
19 |     # Artificially create strong boundaries for
20 |     # pixels with non-field labels
21 |     input_hws = np.copy(boundary)
22 |     input_hws[ext_binary == 0] = 1
23 | 
24 |     # Create the 8-adjacency graph
25 |     size = input_hws.shape[:2]
26 |     graph = hg.get_8_adjacency_graph(size)
27 |     edge_weights = hg.weight_graph(
28 |         graph,
29 |         input_hws,
30 |         hg.WeightFunction.mean)
31 | 
32 |     tree, altitudes = hg.watershed_hierarchy_by_dynamics(
33 |         graph,
34 |         edge_weights)
35 | 
36 |     # Get individual fields
37 |     # by cutting the graph using altitude
38 |     instances = hg.labelisation_horizontal_cut_from_threshold(
39 |         tree,
40 |         altitudes,
41 |         threshold=t_bound).astype(float)
42 | 
43 |     instances[ext_binary == 0] = np.nan
44 | 
45 |     return instances
46 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | rasterio
3 | opencv-python
4 | pathos
5 | mxnet < 2.0
6 | higra
7 | 
--------------------------------------------------------------------------------
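For reference, the snippet below sketches how the pieces above fit together end to end: a `FracTAL_ResUNet_cmtsk` forward pass followed by watershed-based instance segmentation with `InstSegm`. It is a minimal sketch, not a file from the repository: the weight file name is hypothetical, randomly initialized weights stand in for trained ones, and reading channel 1 of the `NClasses=2` outputs as the "field" probability is an assumption about the 1-hot label ordering.

```python
# Minimal end-to-end sketch (NOT part of the repository). Assumes the repository
# root is on sys.path, as in examples/Demo_forward_backward.ipynb.
import sys
import numpy as np
from mxnet import nd

sys.path.append('/Your/Location/To/DECODERepo/')

from decode.FracTAL_ResUNet.models.semanticsegmentation.FracTAL_ResUNet import FracTAL_ResUNet_cmtsk
from postprocessing.instance_segmentation import InstSegm

# D6nf32 configuration from the demo notebook
net = FracTAL_ResUNet_cmtsk(nfilters_init=32, depth=6, NClasses=2,
                            norm_type='GroupNorm', norm_groups=4, nheads_start=4)
net.initialize()  # hypothetical: net.load_parameters('decode_weights.params')

img = nd.random.uniform(shape=[1, 3, 256, 256])  # stand-in for a real image tile
segm, bound, dist = net(img)                     # the three multitask outputs

# Assumption: channel 1 holds the 'field' class probability in the 1-hot layout
extent = segm[0, 1].asnumpy()
boundary = bound[0, 1].asnumpy()

instances = InstSegm(extent, boundary, t_ext=0.4, t_bound=0.2)
print('fields found:', np.unique(instances[~np.isnan(instances)]).size)
```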