├── __init__.py
├── models
│   ├── __init__.py
│   ├── resunet_d6_encoder.py
│   ├── resunet_d7_encoder.py
│   ├── resunet_d6_causal_mtskcolor_ddist.py
│   └── resunet_d7_causal_mtskcolor_ddist.py
├── nn
│   ├── BBlocks
│   │   ├── __init__.py
│   │   └── resnet_blocks.py
│   ├── __init__.py
│   ├── pooling
│   │   ├── __init__.py
│   │   ├── psp_pooling_understanding_nonHybrid.py
│   │   └── psp_pooling.py
│   ├── Units
│   │   ├── __init__.py
│   │   ├── resnet_units.py
│   │   └── resnet_atrous_units.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── combine.py
│   │   ├── conv2Dnormed.py
│   │   └── scale.py
│   └── loss
│       ├── __init__.py
│       └── loss.py
├── src
│   ├── __init__.py
│   ├── bound_dist.py
│   ├── ISPRSNormal.py
│   ├── ISPRSDataset.py
│   ├── semseg_aug_cv2.py
│   └── chopchop_run.py
├── images
│   └── inference_all_tasks_1.png
├── LICENSE.txt
└── readme.md

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 

--------------------------------------------------------------------------------
/nn/BBlocks/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/nn/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 

--------------------------------------------------------------------------------
/nn/pooling/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 

--------------------------------------------------------------------------------
/nn/Units/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 

--------------------------------------------------------------------------------
/nn/layers/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 

--------------------------------------------------------------------------------
/nn/loss/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 

--------------------------------------------------------------------------------
/images/inference_all_tasks_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feevos/resuneta/HEAD/images/inference_all_tasks_1.png

--------------------------------------------------------------------------------
/src/bound_dist.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import numpy as np
 3 | def get_boundary(label, kernel_size = (3,3)):
 4 |     tlabel = label.astype(np.uint8)
 5 |     temp = cv2.Canny(tlabel,0,1)
 6 |     tlabel = cv2.dilate(
 7 |         temp,
 8 |         cv2.getStructuringElement(
 9 |             cv2.MORPH_CROSS,
10 |             kernel_size),
11 |         iterations = 1)
12 |     tlabel = tlabel.astype(np.float32)
13 |     tlabel /= 255.
14 |     return tlabel
15 | 
16 | 
17 | def get_distance(label):
18 |     tlabel = label.astype(np.uint8)
19 |     dist = cv2.distanceTransform(tlabel,
20 |                                  cv2.DIST_L2,
21 |                                  0)
22 |     dist = cv2.normalize(dist,
23 |                          dist,
24 |                          0, 1.0,
25 |                          cv2.NORM_MINMAX)
26 |     return dist
27 | 

--------------------------------------------------------------------------------
/nn/Units/resnet_units.py:
--------------------------------------------------------------------------------
 1 | from resuneta.nn.BBlocks import resnet_blocks
 2 | from mxnet.gluon import HybridBlock
 3 | 
 4 | 
 5 | 
 6 | class ResNet_v2_unit(HybridBlock):
 7 |     """
 8 |     Following He et al. 2016 -- there is the option to replace BatchNormalization with Instance normalization.
 9 |     """
10 |     def __init__(self, _nfilters,_kernel_size=(3,3),_dilation_rate=(1,1), _norm_type = 'BatchNorm', **kwards):
11 |         super(ResNet_v2_unit,self).__init__(**kwards)
12 | 
13 |         with self.name_scope():
14 |             self.ResBlock1 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate,_norm_type = _norm_type)
15 | 
16 | 
17 |     def hybrid_forward(self,F,_xl):
18 | 
19 | 
20 |         # x = self.ResBlock1(_xl) + _xl # Imperative programming only
21 |         x = F.broadcast_add(self.ResBlock1(_xl), _xl) # Uniform description for both Symbol and NDArray
22 | 
23 |         return x
24 | 
25 | 

--------------------------------------------------------------------------------
/src/ISPRSNormal.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Class for normalizing the sliced images for the ISPRS Potsdam competition
 3 | """
 4 | 
 5 | 
 6 | import numpy as np
 7 | 
 8 | 
 9 | class ISPRSNormal(object):
10 |     def __init__(self, mean=None, std=None):
11 | 
12 |         if (mean is None or std is None):
13 |             self._mean = np.array([ 85.48596573, 91.41396302, 84.60300113, 96.89973231, 46.04194328])
14 |             self._std = np.array ([35.624903855445062, 34.882833894659328, 36.222623905578963,
15 |                                    36.663837159102393, 54.91177108287215])
16 | 
17 | 
18 |         else :
19 |             self._mean = mean
20 |             self._std = std
21 | 
22 | 
23 |     def __call__(self,img):
24 | 
25 |         temp = img.astype(np.float32)
26 |         temp2 = temp.T
27 |         temp2 -= self._mean
28 |         temp2 /= self._std
29 | 
30 |         temp = temp2.T
31 | 
32 |         return temp
33 | 
34 | 
35 | 
36 |     def restore(self,normed_img):
37 | 
38 |         d2 = normed_img.T * self._std
39 |         d2 = d2 + self._mean
40 |         d2 = d2.T
41 |         d2 = np.round(d2)
42 |         d2 = d2.astype('uint8')
43 | 
44 |         return d2
45 | 
46 | 
47 | 
48 | 

--------------------------------------------------------------------------------
/nn/layers/combine.py:
--------------------------------------------------------------------------------
 1 | from mxnet import gluon
 2 | from mxnet.gluon import HybridBlock
 3 | 
 4 | from resuneta.nn.layers.scale import *
 5 | from resuneta.nn.layers.conv2Dnormed import *
 6 | 
 7 | class combine_layers(HybridBlock):
 8 |     """
 9 |     This block combines two layers: a lower-resolution layer (which is upsampled) and a higher-resolution one.
10 |     The philosophy is similar to the skip-connection combination found in the UNet architecture.
11 |     It is used in both UNet and ResUNet models.
12 |     """
13 | 
14 | 
15 |     def __init__(self,_nfilters, _norm_type = 'BatchNorm', **kwards):
16 |         HybridBlock.__init__(self,**kwards)
17 | 
18 |         with self.name_scope():
19 | 
20 |             # This performs convolution, no BatchNormalization. No need for bias.
21 |             self.up = UpSample(_nfilters, _norm_type = _norm_type )
22 | 
23 |             self.conv_normed = Conv2DNormed(channels = _nfilters,
24 |                                             kernel_size=(1,1),
25 |                                             padding=(0,0), _norm_type=_norm_type)
26 | 
27 | 
28 | 
29 | 
30 |     def hybrid_forward(self,F,_layer_lo, _layer_hi):
31 | 
32 |         up = self.up(_layer_lo)
33 |         up = F.relu(up)
34 |         x = F.concat(up,_layer_hi, dim=1) # Concat along CHANNEL axis
35 |         x = self.conv_normed(x)
36 | 
37 |         return x
38 | 
39 | 

--------------------------------------------------------------------------------
/nn/BBlocks/resnet_blocks.py:
--------------------------------------------------------------------------------
 1 | from mxnet import gluon
 2 | from mxnet.gluon import HybridBlock
 3 | 
 4 | 
 5 | class ResNet_v2_block(HybridBlock):
 6 |     """
 7 |     ResNet v2 building block. It is built upon the assumption of an ODD kernel size.
 8 |     """
 9 |     def __init__(self, _nfilters,_kernel_size=(3,3),_dilation_rate=(1,1),
10 |                  _norm_type='BatchNorm', **kwards):
11 |         HybridBlock.__init__(self,**kwards)
12 | 
13 |         self.nfilters = _nfilters
14 |         self.kernel_size = _kernel_size
15 |         self.dilation_rate = _dilation_rate
16 | 
17 | 
18 |         if (_norm_type == 'BatchNorm'):
19 |             self.norm = gluon.nn.BatchNorm
20 |             _prefix = "_BN"
21 |         elif (_norm_type == 'InstanceNorm'):
22 |             self.norm = gluon.nn.InstanceNorm
23 |             _prefix = "_IN"
24 |         elif (_norm_type == 'LayerNorm'):
25 |             self.norm = gluon.nn.LayerNorm
26 |             _prefix = "_LN"
27 |         else:
28 |             raise NotImplementedError
29 | 
30 | 
31 |         with self.name_scope():
32 | 
33 |             # Ensures padding = 'SAME' for ODD kernel selection
34 |             p0 = self.dilation_rate[0] * (self.kernel_size[0] - 1)/2
35 |             p1 = self.dilation_rate[1] * (self.kernel_size[1] - 1)/2
36 |             p = (int(p0),int(p1))
37 | 
38 | 
39 |             self.BN1 = self.norm(axis=1, prefix = _prefix+"1_")
40 |             self.conv1 = gluon.nn.Conv2D(self.nfilters,kernel_size = self.kernel_size,padding=p,dilation=self.dilation_rate,use_bias=False,prefix="_conv1_")
41 |             self.BN2 = self.norm(axis=1,prefix= _prefix + "2_")
42 |             self.conv2 = gluon.nn.Conv2D(self.nfilters,kernel_size = self.kernel_size,padding=p,dilation=self.dilation_rate,use_bias=True,prefix="_conv2_")
43 | 
44 | 
45 |     def hybrid_forward(self,F,_input_layer):
46 | 
47 | 
48 |         x = self.BN1(_input_layer)
49 |         x = F.relu(x)
50 |         x = self.conv1(x)
51 | 
52 |         x = self.BN2(x)
53 |         x = F.relu(x)
54 |         x = self.conv2(x)
55 | 
56 |         return x
57 | 
58 | 
59 | 

--------------------------------------------------------------------------------
/nn/layers/conv2Dnormed.py:
--------------------------------------------------------------------------------
 1 | from mxnet import gluon
 2 | from mxnet.gluon import HybridBlock
 3 | 
 4 | class Conv2DNormed(HybridBlock):
 5 |     """
 6 |     Convenience wrapper layer for 2D convolution followed by a normalization layer
 7 |     (BatchNorm, SyncBatchNorm, InstanceNorm or LayerNorm).
 8 |     _norm_type: one of the strings 'BatchNorm' (default), 'SyncBatchNorm', 'InstanceNorm' or 'LayerNorm'.
 9 |     axis : the axis the normalization layer acts on (default 1, the channel axis).
10 | All other keywords are the same as gluon.nn.Conv2D 11 | """ 12 | 13 | def __init__(self, channels, kernel_size, strides=(1, 1), 14 | padding=(0, 0), dilation=(1, 1), activation=None, 15 | weight_initializer=None, in_channels=0, _norm_type = 'BatchNorm', axis =1 ,**kwards): 16 | HybridBlock.__init__(self,**kwards) 17 | 18 | if (_norm_type == 'BatchNorm'): 19 | self.norm = gluon.nn.BatchNorm 20 | elif (_norm_type == 'SyncBatchNorm'): 21 | self.norm = gluon.contrib.nn.SyncBatchNorm 22 | _prefix = "_SyncBN" 23 | elif (_norm_type == 'InstanceNorm'): 24 | self.norm = gluon.nn.InstanceNorm 25 | 26 | elif (_norm_type == 'LayerNorm'): 27 | self.norm = gluon.nn.LayerNorm 28 | else: 29 | raise NotImplementedError 30 | 31 | 32 | with self.name_scope(): 33 | self.conv2d = gluon.nn.Conv2D(channels, kernel_size = kernel_size, 34 | strides= strides, 35 | padding=padding, 36 | dilation= dilation, 37 | activation=activation, 38 | use_bias=False, 39 | weight_initializer = weight_initializer, 40 | in_channels=0) 41 | 42 | self.norm_layer = self.norm(axis=axis) 43 | 44 | 45 | def hybrid_forward(self,F,_x): 46 | 47 | x = self.conv2d(_x) 48 | x = self.norm_layer(x) 49 | 50 | return x 51 | 52 | 53 | -------------------------------------------------------------------------------- /nn/pooling/psp_pooling_understanding_nonHybrid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use this only to understand the psp pooling. This code is not hybridizable. 3 | 4 | 5 | TODO: Currently there is a problem: I need the layer size at runtime, but I cannot get it for Symbol, 6 | only for ndarray. This needs to be fixed!!! 7 | """ 8 | 9 | 10 | from mxnet import gluon 11 | from mxnet.gluon import HybridBlock 12 | from mxnet.ndarray import NDArray 13 | from resuneta.nn.layers.conv2Dnormed import * 14 | 15 | class PSP_Pooling(HybridBlock): 16 | 17 | """ 18 | Pyramid Scene Parsing pooling layer, as defined in Zhao et al. 2017 (https://arxiv.org/abs/1612.01105) 19 | This is only the pyramid pooling module. 20 | INPUT: 21 | layer of size Nbatch, Nchannel, H, W 22 | OUTPUT: 23 | layer of size Nbatch, Nchannel, H, W. 24 | 25 | """ 26 | 27 | def __init__(self, _nfilters, _norm_type = 'BatchNorm', **kwards): 28 | HybridBlock.__init__(self,**kwards) 29 | 30 | self.nfilters = _nfilters 31 | 32 | # This is used as a container (list) of layers 33 | self.convs = gluon.nn.HybridSequential() 34 | with self.name_scope(): 35 | 36 | self.convs.add(Conv2DNormed(self.nfilters//4,kernel_size=(1,1),padding=(0,0), prefix="_conv1_")) 37 | self.convs.add(Conv2DNormed(self.nfilters//4,kernel_size=(1,1),padding=(0,0), prefix="_conv2_")) 38 | self.convs.add(Conv2DNormed(self.nfilters//4,kernel_size=(1,1),padding=(0,0), prefix="_conv3_")) 39 | self.convs.add(Conv2DNormed(self.nfilters//4,kernel_size=(1,1),padding=(0,0), prefix="_conv4_")) 40 | 41 | 42 | self.conv_norm_final = Conv2DNormed(channels = self.nfilters, 43 | kernel_size=(1,1), 44 | padding=(0,0), 45 | _norm_type=_norm_type) 46 | 47 | 48 | 49 | def hybrid_forward(self,F,_input): 50 | 51 | # This if statement could be slowing down the performance. 
52 |         if isinstance(_input,NDArray):
53 |             layer_size = _input.shape[2]
54 |         else :
55 |             raise NotImplementedError
56 |             #layer_size = _input.infer_shape()
57 | 
58 |         p = [_input]
59 |         for i in range(4):
60 | 
61 |             pool_size = layer_size // (2**i) # Need this to be integer
62 |             x = F.Pooling(_input,kernel=[pool_size,pool_size],stride=[pool_size,pool_size],pool_type='max')
63 |             x = F.UpSampling(x,sample_type='nearest',scale=pool_size)
64 |             x = self.convs[i](x)
65 |             p += [x]
66 | 
67 |         out = F.concat(p[0],p[1],p[2],p[3],p[4],dim=1)
68 | 
69 |         out = self.conv_norm_final(out)
70 | 
71 |         return out
72 | 
73 | 

--------------------------------------------------------------------------------
/nn/layers/scale.py:
--------------------------------------------------------------------------------
 1 | from mxnet import gluon
 2 | from mxnet.gluon import HybridBlock
 3 | 
 4 | from resuneta.nn.layers.conv2Dnormed import *
 5 | 
 6 | 
 7 | class DownSample(HybridBlock):
 8 |     """
 9 |     DownSample a convolutional layer by half, and at the same time double the number of filters.
10 |     """
11 |     def __init__(self,_nfilters, _factor=2, _norm_type='BatchNorm', **kwards):
12 |         HybridBlock.__init__(self, **kwards)
13 | 
14 | 
15 |         # Double the number of filters, since you will downscale by 2.
16 |         self.factor = _factor
17 |         self.nfilters = _nfilters * self.factor
18 |         # I was using a kernel size of 1x1; this has nothing to do with max pooling or selecting the most dominant value. Now changing that.
19 |         # There is a bug somewhere: if I use kernel_size = 2, the code crashes with an illegal memory access.
20 |         # I am not sure whether it is my bug or something mxnet related.
21 | 
22 |         # Kernel = 3, padding = 1 works fine; no bug here in the latest version of mxnet.
23 |         self.kernel_size = (3,3)
24 |         self.strides = (2,2)
25 |         self.pad = (1,1)
26 | 
27 | 
28 |         with self.name_scope():
29 |             self.convdn = gluon.nn.Conv2D(self.nfilters,
30 |                                           kernel_size=self.kernel_size,
31 |                                           strides=self.strides,
32 |                                           padding = self.pad,
33 |                                           use_bias=False,
34 |                                           prefix="_convdn_")
35 | 
36 | 
37 |     def hybrid_forward(self,F,_xl):
38 | 
39 |         x = self.convdn(_xl)
40 | 
41 |         return x
42 | 
43 | 
44 | 
45 | # This will go to the decoder architecture
46 | class UpSample(HybridBlock):
47 |     """
48 |     UpSample by resizing and a k=1 convolution to halve the number of filters. The point here is to get away
49 |     from the transposed convolution.
50 |     """
51 | 
52 |     def __init__(self,_nfilters, factor = 2, _norm_type='BatchNorm', **kwards):
53 |         HybridBlock.__init__(self,**kwards)
54 | 
55 | 
56 |         self.factor = factor
57 |         self.nfilters = _nfilters // self.factor
58 | 
59 |         with self.name_scope():
60 |             self.convup_normed = Conv2DNormed(self.nfilters,
61 |                                               kernel_size = (1,1),
62 |                                               _norm_type = _norm_type,
63 |                                               prefix="_convdn_")
64 | 
65 |     def hybrid_forward(self,F,_xl):
66 |         # I need to add bilinear upsampling, but I get an error, so for now it will be 'nearest' till the
67 |         # issue is resolved (opened ticket on github).
68 |         # See https://stackoverflow.com/questions/47897924/implementing-bilinear-interpolation-with-mxnet-ndarray-upsampling/48013886#48013886
69 |         x = F.UpSampling(_xl, scale=self.factor, sample_type='nearest')
70 |         x = self.convup_normed(x)
71 | 
72 |         return x
73 | 

--------------------------------------------------------------------------------
/nn/loss/loss.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from mxnet.gluon.loss import Loss
 3 | 
 4 | 
 5 | 
 6 | 
 7 | class Tanimoto(Loss):
 8 |     def __init__(self, _smooth=1.0e-5, _axis=[2,3], _weight = None, _batch_axis= 0, **kwards):
 9 |         Loss.__init__(self,weight=_weight, batch_axis = _batch_axis, **kwards)
10 | 
11 |         self.axis = _axis
12 |         self.smooth = _smooth
13 | 
14 |     def hybrid_forward(self,F,_preds, _label):
15 | 
16 |         # Evaluate the mean volume of class per batch
17 |         Vli = F.mean(F.sum(_label,axis=self.axis),axis=0)
18 |         #wli = 1.0/Vli**2 # weighting scheme
19 |         wli = F.reciprocal(Vli**2) # weighting scheme
20 | 
21 |         # ---------------------This line is taken from niftyNet package --------------
22 |         # ref: https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py, lines:170 -- 172
23 |         # new_weights = tf.where(tf.is_inf(weights), tf.zeros_like(weights), weights)
24 |         # weights = tf.where(tf.is_inf(weights), tf.ones_like(weights) * tf.reduce_max(new_weights), weights)
25 |         # --------------------------------------------------------------------
26 | 
27 |         # ***********************************************************************************************
28 |         # First turn inf elements to zero, then replace them with the maximum weight value
29 |         new_weights = F.where(wli == float('inf'), F.zeros_like(wli), wli )
30 |         wli = F.where( wli == float('inf'), F.broadcast_mul(F.ones_like(wli),F.max(new_weights)) , wli)
31 |         # ************************************************************************************************
32 | 
33 | 
34 |         rl_x_pl = F.sum( F.broadcast_mul(_label , _preds), axis=self.axis)
35 |         # These are sums of squares
36 |         l = F.sum( F.broadcast_mul(_label , _label), axis=self.axis)
37 |         r = F.sum( F.broadcast_mul( _preds , _preds ) , axis=self.axis)
38 | 
39 |         rl_p_pl = l + r - rl_x_pl
40 | 
41 |         tnmt = (F.sum( F.broadcast_mul(wli , rl_x_pl),axis=1) + self.smooth)/ ( F.sum( F.broadcast_mul(wli,(rl_p_pl)),axis=1) + self.smooth)
42 | 
43 |         return tnmt # This returns the tnmt for EACH data point, i.e. a vector of values equal to the batch size
44 | 
45 | 
46 | 
47 | # This is the loss used in the manuscript of resuneta
48 | class Tanimoto_wth_dual(Loss):
49 |     """
50 |     Tanimoto coefficient with dual from: Diakogiannis et al 2019 (https://arxiv.org/abs/1904.00592)
51 |     Note: to use it in deep learning training use: return 1.
- 0.5*(loss1+loss2) 52 | """ 53 | def __init__(self, _smooth=1.0e-5, _axis=[2,3], _weight = None, _batch_axis= 0, **kwards): 54 | Loss.__init__(self,weight=_weight, batch_axis = _batch_axis, **kwards) 55 | 56 | with self.name_scope(): 57 | self.Loss = Tanimoto(_smooth = _smooth, _axis = _axis) 58 | 59 | 60 | def hybrid_forward(self,F,_preds,_label): 61 | 62 | # measure of overlap 63 | loss1 = self.Loss(_preds,_label) 64 | 65 | # measure of non-overlap as inner product 66 | preds_dual = 1.0-_preds 67 | labels_dual = 1.0-_label 68 | loss2 = self.Loss(preds_dual,labels_dual) 69 | 70 | 71 | return 0.5*(loss1+loss2) 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License) 2 | Copyright (c) resuneta, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. 3 | All rights reserved. CSIRO is willing to grant you a licence to resuneta on the following terms, except where otherwise indicated for third party material. 4 | Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | * Neither the name of CSIRO nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission of CSIRO. 8 | EXCEPT AS EXPRESSLY STATED IN THIS AGREEMENT AND TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE IS PROVIDED "AS-IS". CSIRO MAKES NO REPRESENTATIONS, WARRANTIES OR CONDITIONS OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY REPRESENTATIONS, WARRANTIES OR CONDITIONS REGARDING THE CONTENTS OR ACCURACY OF THE SOFTWARE, OR OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, THE ABSENCE OF LATENT OR OTHER DEFECTS, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. 9 | TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL CSIRO BE LIABLE ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, IN AN ACTION FOR BREACH OF CONTRACT, NEGLIGENCE OR OTHERWISE) FOR ANY CLAIM, LOSS, DAMAGES OR OTHER LIABILITY HOWSOEVER INCURRED. WITHOUT LIMITING THE SCOPE OF THE PREVIOUS SENTENCE THE EXCLUSION OF LIABILITY SHALL INCLUDE: LOSS OF PRODUCTION OR OPERATION TIME, LOSS, DAMAGE OR CORRUPTION OF DATA OR RECORDS; OR LOSS OF ANTICIPATED SAVINGS, OPPORTUNITY, REVENUE, PROFIT OR GOODWILL, OR OTHER ECONOMIC LOSS; OR ANY SPECIAL, INCIDENTAL, INDIRECT, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES, ARISING OUT OF OR IN CONNECTION WITH THIS AGREEMENT, ACCESS OF THE SOFTWARE OR ANY OTHER DEALINGS WITH THE SOFTWARE, EVEN IF CSIRO HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH CLAIM, LOSS, DAMAGES OR OTHER LIABILITY. 
10 | APPLICABLE LEGISLATION SUCH AS THE AUSTRALIAN CONSUMER LAW MAY APPLY REPRESENTATIONS, WARRANTIES, OR CONDITIONS, OR IMPOSES OBLIGATIONS OR LIABILITY ON CSIRO THAT CANNOT BE EXCLUDED, RESTRICTED OR MODIFIED TO THE FULL EXTENT SET OUT IN THE EXPRESS TERMS OF THIS CLAUSE ABOVE "CONSUMER GUARANTEES". TO THE EXTENT THAT SUCH CONSUMER GUARANTEES CONTINUE TO APPLY, THEN TO THE FULL EXTENT PERMITTED BY THE APPLICABLE LEGISLATION, THE LIABILITY OF CSIRO UNDER THE RELEVANT CONSUMER GUARANTEE IS LIMITED (WHERE PERMITTED AT CSIRO'S OPTION) TO ONE OF FOLLOWING REMEDIES OR SUBSTANTIALLY EQUIVALENT REMEDIES:
11 | (a) THE REPLACEMENT OF THE SOFTWARE, THE SUPPLY OF EQUIVALENT SOFTWARE, OR SUPPLYING RELEVANT SERVICES AGAIN;
12 | (b) THE REPAIR OF THE SOFTWARE;
13 | (c) THE PAYMENT OF THE COST OF REPLACING THE SOFTWARE, OF ACQUIRING EQUIVALENT SOFTWARE, HAVING THE RELEVANT SERVICES SUPPLIED AGAIN, OR HAVING THE SOFTWARE REPAIRED.
14 | IN THIS CLAUSE, CSIRO INCLUDES ANY THIRD PARTY AUTHOR OR OWNER OF ANY PART OF THE SOFTWARE OR MATERIAL DISTRIBUTED WITH IT. CSIRO MAY ENFORCE ANY RIGHTS ON BEHALF OF THE RELEVANT THIRD PARTY.
15 | 

--------------------------------------------------------------------------------
/src/ISPRSDataset.py:
--------------------------------------------------------------------------------
 1 | """
 2 | DataSet reader for the ISPRS data competition. It assumes the structure under the root directory
 3 | where the data are saved
 4 | /root/
 5 |     /training/
 6 |         /imgs/
 7 |         /masks/
 8 |     /validation/
 9 |         /imgs/
10 |         /masks/
11 | 
12 | """
13 | 
14 | import os
15 | import numpy as np
16 | 
17 | from mxnet.gluon.data import dataset
18 | import cv2
19 | 
20 | class ISPRSDataset(dataset.Dataset):
21 |     def __init__(self, root, mode='train', mtsk = True, color=True, transform=None, norm=None):
22 | 
23 |         self._mode = mode
24 |         self.mtsk = mtsk
25 |         self.color = color
26 |         if (color):
27 |             self.colornorm = np.array([1./179, 1./255, 1./255])
28 | 
29 |         self._transform = transform
30 |         self._norm = norm # Normalization of img
31 | 
32 |         if (root[-1]=='/'):
33 |             self._root_train = root+'training/'
34 |             self._root_valid = root + 'validation/'
35 |         else :
36 |             self._root_train = root+'/training/'
37 |             self._root_valid = root + '/validation/'
38 | 
39 | 
40 |         if mode == 'train':
41 |             self._root_img = self._root_train + r'imgs/'
42 |             self._root_mask = self._root_train + r'masks/'
43 |         elif mode == 'val':
44 |             self._root_img = self._root_valid + r'imgs/'
45 |             self._root_mask = self._root_valid + r'masks/'
46 |         else:
47 |             raise Exception ('I was given an inconsistent mode, available choices: {train, val}, aborting ...')
48 | 
49 | 
50 | 
51 |         self._img_list = sorted(os.listdir(self._root_img))
52 |         self._mask_list = sorted(os.listdir(self._root_mask))
53 | 
54 |         assert len(self._img_list) == len(self._mask_list), "Images and masks differ in number, error"
55 | 
56 |         self.img_names = list(zip(self._img_list, self._mask_list))
57 | 
58 | 
59 |     def __getitem__(self, idx):
60 | 
61 |         base_filepath = os.path.join(self._root_img, self.img_names[idx][0])
62 |         mask_filepath = os.path.join(self._root_mask, self.img_names[idx][1])
63 | 
64 |         # load in float32
65 |         base = np.load(base_filepath)
66 |         if self.color:
67 |             timg = base.transpose([1,2,0])[:,:,:3].astype(np.uint8)
68 |             base_hsv = cv2.cvtColor(timg,cv2.COLOR_RGB2HSV)
69 |             base_hsv = base_hsv *self.colornorm
70 |             base_hsv = base_hsv.transpose([2,0,1]).astype(np.float32)
71 | 
72 | 
73 |         base = base.astype(np.float32)
74 | 
75 |         mask = np.load(mask_filepath)
76 |         mask = mask.astype(np.float32)
77 | 
78 | 
79 |         if self.color:
80 |             mask = np.concatenate([mask,base_hsv],axis=0)
81 | 
82 |         if self.mtsk == False:
83 |             mask = mask[:6,:,:]
84 | 
85 |         if self._transform is not None:
86 |             base, mask = self._transform(base, mask)
87 |             if self._norm is not None:
88 |                 base = self._norm(base.astype(np.float32))
89 | 
90 |             return base.astype(np.float32), mask.astype(np.float32)
91 | 
92 |         else:
93 |             if self._norm is not None:
94 |                 base = self._norm(base.astype(np.float32))
95 | 
96 |             return base.astype(np.float32), mask.astype(np.float32)
97 | 
98 |     def __len__(self):
99 |         return len(self.img_names)
100 | 

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
 1 | # ResUNet-a: a deep learning framework for semantic segmentation of remotely sensed data
 2 | 
 3 | This repository contains source code for some of the models used in the manuscript of the [ResUNet-a](https://arxiv.org/abs/1904.00592) paper. ResUNet-a is built with the [mxnet](https://mxnet.incubator.apache.org/) DL framework, under the gluon API.
 4 | 
 5 | 
 6 | ![inference example](images/inference_all_tasks_1.png)
 7 | 
 8 | ### Requirements
 9 | 1. mxnet (latest version, tests run with mxnet_cu92-1.5.0b20190613)
10 | 2. opencv
11 | 3. rasterio
12 | 4. glob
13 | 5. pathos
14 | 6. [ISPRS Potsdam data](http://www2.isprs.org/commissions/comm3/wg4/2d-sem-label-potsdam.html), appropriately preprocessed into slices of 256x256 patches.
15 | 
16 | 
17 | 
18 | ### Directory Structure
19 | 
20 | Complete models live in the ```models``` directory, specifically models d6 and d7 (conditioned multitasking).
21 | These are built from modules that live in the ```resuneta/nn``` directory. The Tanimoto loss function (with complement) is defined in ```resuneta/nn/loss/loss.py``` (see the usage sketch below). An inference demo (```.ipynb```) can be found in the ```demo``` directory. The ```nn``` directory contains all modules necessary for building resuneta models, while the ```src``` directory holds the dataset definitions. In addition, ```src/chopchop_run.py``` is an executable that slices the original data into patches of size 256x256. Please see the source code for modifications based on your directory structure.
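
A minimal usage sketch for the Tanimoto dual loss (illustrative only: the mock tensors below are not from the repo; the class returns the per-sample similarity ```0.5*(loss1+loss2)```, so the quantity to minimise during training is its complement, as noted in ```loss.py```):

```python
from mxnet import nd
from resuneta.nn.loss.loss import Tanimoto_wth_dual

tnmt_dual = Tanimoto_wth_dual()

# Mock predictions (softmax over 6 classes) and matching one-hot masks, NCHW layout.
preds  = nd.softmax(nd.random.uniform(shape=[4, 6, 256, 256]), axis=1)
labels = nd.one_hot(nd.argmax(preds, axis=1), 6).transpose([0, 3, 1, 2])

# Training loss = 1 - (dual-averaged Tanimoto similarity), one value per batch element.
loss = 1.0 - tnmt_dual(preds, labels)
print(loss.shape)  # (4,)
```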
22 | 23 | ``` 24 | ├── demo 25 | ├── images 26 | ├── models 27 | ├── nn 28 | │   ├── BBlocks 29 | │   ├── layers 30 | │   ├── loss 31 | │   ├── pooling 32 | │   └── Units 33 | └── src 34 | ``` 35 | 36 | ### Example of model usage 37 | See also ```demo/*.ipynb``` 38 | 39 | ```python 40 | In [1]: from resuneta.models.resunet_d7_causal_mtskcolor_ddist import * 41 | ...: from mxnet import nd 42 | ...: 43 | 44 | In [2]: Nfilters_init = 32 45 | ...: NClasses = 6 46 | ...: net = ResUNet_d7(Nfilters_init,NClasses) 47 | ...: net.initialize() 48 | ...: 49 | depth:= 0, nfilters: 32 50 | depth:= 1, nfilters: 64 51 | depth:= 2, nfilters: 128 52 | depth:= 3, nfilters: 256 53 | depth:= 4, nfilters: 512 54 | depth:= 5, nfilters: 1024 55 | depth:= 6, nfilters: 2048 56 | depth:= 7, nfilters: 1024 57 | depth:= 8, nfilters: 512 58 | depth:= 9, nfilters: 256 59 | depth:= 10, nfilters: 128 60 | depth:= 11, nfilters: 64 61 | depth:= 12, nfilters: 32 62 | 63 | In [3]: xx = nd.random.uniform(shape=[1,5,256,256]) 64 | 65 | In [4]: out = net(xx) 66 | ``` 67 | 68 | 69 | ### License 70 | CSIRO BSTD/MIT LICENSE 71 | 72 | As a condition of this licence, you agree that where you make any adaptations, modifications, further developments, or additional features available to CSIRO or the public in connection with your access to the Software, you do so on the terms of the BSD 3-Clause Licence template, a copy available at: http://opensource.org/licenses/BSD-3-Clause. 73 | 74 | ### Citation 75 | If you find the contents of this repository useful for your research, please cite: 76 | 77 | ``` 78 | @article{DIAKOGIANNIS202094, 79 | title = "ResUNet-a: A deep learning framework for semantic segmentation of remotely sensed data", 80 | journal = "ISPRS Journal of Photogrammetry and Remote Sensing", 81 | volume = "162", 82 | pages = "94 - 114", 83 | year = "2020", 84 | issn = "0924-2716", 85 | doi = "https://doi.org/10.1016/j.isprsjprs.2020.01.013", 86 | url = "http://www.sciencedirect.com/science/article/pii/S0924271620300149", 87 | author = "Foivos I. Diakogiannis and François Waldner and Peter Caccetta and Chen Wu", 88 | keywords = "Convolutional neural network, Loss function, Architecture, Data augmentation, Very high spatial resolution" 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /nn/pooling/psp_pooling.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | from resuneta.nn.layers.conv2Dnormed import * 4 | 5 | class PSP_Pooling(gluon.HybridBlock): 6 | """ 7 | This is the PSPPooling layer, defined recursively so as to avoid calling ndarray.shape. This form is hybridizable. 8 | """ 9 | 10 | def __init__(self, nfilters, depth=4, _norm_type = 'BatchNorm',**kwards): 11 | gluon.HybridBlock.__init__(self,**kwards) 12 | 13 | 14 | 15 | self.nfilters = nfilters 16 | self.depth = depth 17 | 18 | # This is used as a container (list) of layers 19 | self.convs = gluon.nn.HybridSequential() 20 | with self.name_scope(): 21 | for _ in range(depth): 22 | self.convs.add(Conv2DNormed(self.nfilters//self.depth,kernel_size=(1,1),padding=(0,0),_norm_type=_norm_type)) 23 | 24 | self.conv_norm_final = Conv2DNormed(channels = self.nfilters, 25 | kernel_size=(1,1), 26 | padding=(0,0), 27 | _norm_type=_norm_type) 28 | 29 | 30 | # ******** Utilities functions to avoid calling infer_shape **************** 31 | def HalfSplit(self, F,_a): 32 | """ 33 | Returns a list of half split arrays. 
Usefull for HalfPoolling 34 | """ 35 | b = F.split(_a,axis=2,num_outputs=2) # Split First dimension 36 | c1 = F.split(b[0],axis=3,num_outputs=2) # Split 2nd dimension 37 | c2 = F.split(b[1],axis=3,num_outputs=2) # Split 2nd dimension 38 | 39 | 40 | d11 = c1[0] 41 | d12 = c1[1] 42 | 43 | d21 = c2[0] 44 | d22 = c2[1] 45 | 46 | return [d11,d12,d21,d22] 47 | 48 | 49 | def QuarterStitch(self, F,_Dss): 50 | """ 51 | INPUT: 52 | A list of [d11,d12,d21,d22] block matrices. 53 | OUTPUT: 54 | A single matrix joined of these submatrices 55 | """ 56 | 57 | temp1 = F.concat(_Dss[0],_Dss[1],dim=-1) 58 | temp2 = F.concat(_Dss[2],_Dss[3],dim=-1) 59 | result = F.concat(temp1,temp2,dim=2) 60 | 61 | return result 62 | 63 | 64 | def HalfPooling(self, F,_a): 65 | """ 66 | Tested, produces consinstent results. 67 | """ 68 | Ds = self.HalfSplit(F,_a) 69 | 70 | Dss = [] 71 | for x in Ds: 72 | Dss += [F.broadcast_mul(F.ones_like(x) , F.Pooling(x,global_pool=True))] 73 | 74 | return self.QuarterStitch(F,Dss) 75 | 76 | 77 | 78 | #from functools import lru_cache 79 | #@lru_cache(maxsize=None) # This increases by a LOT the performance 80 | # Can't make it to work with symbol though (yet) 81 | def SplitPooling(self, F, _a, depth): 82 | #print("Calculating F", "(", depth, ")\n") 83 | """ 84 | A recursive function that produces the Pooling you want - in particular depth (powers of 2) 85 | """ 86 | if depth==1: 87 | return self.HalfPooling(F,_a) 88 | else : 89 | D = self.HalfSplit(F,_a) 90 | return self.QuarterStitch(F,[self.SplitPooling(F,d,depth-1) for d in D]) 91 | 92 | 93 | # *********************************************************************************** 94 | 95 | def hybrid_forward(self,F,_input): 96 | 97 | p = [_input] 98 | # 1st:: Global Max Pooling . 99 | p += [self.convs[0](F.broadcast_mul(F.ones_like(_input) , F.Pooling(_input,global_pool=True)))] 100 | p += [self.convs[d](self.SplitPooling(F,_input,d)) for d in range(1,self.depth)] 101 | out = F.concat(*p,dim=1) 102 | out = self.conv_norm_final(out) 103 | 104 | return out 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /src/semseg_aug_cv2.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import itertools 3 | import numpy as np 4 | 5 | 6 | class ParamsRange(dict): 7 | def __init__(self): 8 | 9 | # Good default values for 256x256 images 10 | self['center_range'] =[0,256] 11 | self['rot_range'] =[-85.0,85.0] 12 | self['zoom_range'] = [0.25,1.25] 13 | 14 | 15 | class SemSegAugmentor_CV(object): 16 | """ 17 | INPUTS: 18 | parameters range for all transformations 19 | probability of transformation to take place - default to 1. 20 | Nrot: number of rotations in comparison with reflections x,y,xy. Default to equal the number of reflections. 21 | """ 22 | def __init__(self, params_range, prob = 1.0, Nrot=3, one_hot = True): 23 | 24 | self.one_hot = one_hot 25 | self.range = params_range 26 | self.prob = prob 27 | assert self.prob <= 1 , "prob must be in range [0,1], you gave prob::{}".format(prob) 28 | 29 | 30 | # define a proportion of operations? 
31 |         self.operations = [self.reflect_x, self.reflect_y, self.reflect_xy]
32 |         self.operations += [self.rand_shift_rot_zoom]*Nrot
33 |         self.iterator = itertools.cycle(self.operations)
34 | 
35 | 
36 |     def _shift_rot_zoom(self,_img, _mask, _center, _angle, _scale):
37 |         """
38 |         OpenCV random scale+rotation
39 |         """
40 |         imgT = _img.transpose([1,2,0])
41 |         if (self.one_hot):
42 |             maskT = _mask.transpose([1,2,0])
43 |         else:
44 |             maskT = _mask
45 | 
46 |         rows, cols = imgT.shape[:-1]
47 | 
48 |         # Produces an affine rotation matrix, centred on _center, for angle _angle, with optional zoom in/out scale
49 |         tRotMat = cv2.getRotationMatrix2D(_center, _angle, _scale)
50 | 
51 |         img_trans = cv2.warpAffine(imgT,tRotMat,(cols,rows),flags=cv2.INTER_AREA, borderMode=cv2.BORDER_REFLECT_101) # """,flags=cv2.INTER_CUBIC,"""
52 |         mask_trans= cv2.warpAffine(maskT,tRotMat,(cols,rows),flags=cv2.INTER_AREA, borderMode=cv2.BORDER_REFLECT_101)
53 | 
54 |         img_trans = img_trans.transpose([2,0,1])
55 |         if (self.one_hot):
56 |             mask_trans = mask_trans.transpose([2,0,1])
57 | 
58 |         return img_trans, mask_trans
59 | 
60 | 
61 |     def reflect_x(self,_img,_mask):
62 | 
63 |         img_z = _img[:,::-1,:]
64 |         if self.one_hot:
65 |             mask_z = _mask[:,::-1,:] # 1hot representation
66 |         else:
67 |             mask_z = _mask[::-1,:] # standard (int's representation)
68 | 
69 |         return img_z, mask_z
70 | 
71 |     def reflect_y(self,_img,_mask):
72 |         img_z = _img[:,:,::-1]
73 |         if self.one_hot:
74 |             mask_z = _mask[:,:,::-1] # 1hot representation
75 |         else:
76 |             mask_z = _mask[:,::-1] # standard (int's representation)
77 | 
78 |         return img_z, mask_z
79 | 
80 |     def reflect_xy(self,_img,_mask):
81 |         img_z = _img[:,::-1,::-1]
82 |         if self.one_hot:
83 |             mask_z = _mask[:,::-1,::-1] # 1hot representation
84 |         else:
85 |             mask_z = _mask[::-1,::-1] # standard (int's representation)
86 | 
87 |         return img_z, mask_z
88 | 
89 | 
90 | 
91 |     def rand_shift_rot_zoom(self,_img,_mask):
92 | 
93 |         center = np.random.randint(low=self.range['center_range'][0],
94 |                                    high=self.range['center_range'][1],
95 |                                    size=2)
96 |         # Angle is in degrees (cv2.getRotationMatrix2D convention)
97 |         angle = np.random.uniform(low=self.range['rot_range'][0],
98 |                                   high=self.range['rot_range'][1])
99 | 
100 |         scale = np.random.uniform(low=self.range['zoom_range'][0],
101 |                                   high=self.range['zoom_range'][1])
102 | 
103 | 
104 |         return self._shift_rot_zoom(_img,_mask,tuple(center),angle,scale) #, tuple(center),angle,scale
105 | 
106 | 
107 | 
108 |     def __call__(self,_img, _mask):
109 | 
110 |         rand = np.random.rand()
111 |         if (rand <= self.prob):
112 |             return next(self.iterator)(_img,_mask)
113 |         else :
114 |             return _img, _mask
115 | 

--------------------------------------------------------------------------------
/models/resunet_d6_encoder.py:
--------------------------------------------------------------------------------
 1 | import mxnet as mx
 2 | from mxnet import gluon
 3 | from mxnet.gluon import HybridBlock
 4 | 
 5 | 
 6 | from resuneta.nn.Units.resnet_units import *
 7 | from resuneta.nn.Units.resnet_atrous_units import *
 8 | from resuneta.nn.pooling.psp_pooling import *
 9 | from resuneta.nn.layers.scale import *
10 | from resuneta.nn.layers.combine import *
11 | from resuneta.nn.layers.conv2Dnormed import *
12 | 
13 | 
14 | 
15 | class ResUNet_d6_encoder(HybridBlock):
16 |     """
17 |     This will be used for 256x256 image input, so the atrous convolutions should be determined by the depth
18 |     """
19 | 
20 |     def __init__(self,_nfilters_init, _NClasses, verbose=True, _norm_type = 'BatchNorm', **kwards):
21 |         HybridBlock.__init__(self,**kwards)
22 | 
23 |         self.model_name = "ResUNet_d6_encoder"
24 | 25 | self.depth = 6 26 | 27 | self.nfilters = _nfilters_init # Initial number of filters 28 | self.NClasses = _NClasses 29 | 30 | 31 | 32 | with self.name_scope(): 33 | 34 | # First convolution Layer 35 | # Starting with first convolutions to make the input "channel" dim equal to the number of initial filters 36 | self.conv_first_normed = Conv2DNormed(channels=self.nfilters, 37 | kernel_size=(1,1), 38 | _norm_type = _norm_type, 39 | prefix="_conv_first_") 40 | 41 | 42 | # Progressively reducing the dilation_rate of Atrous convolutions (the deeper the smaller). 43 | 44 | # Usually 32 45 | nfilters = self.nfilters * 2**(0) 46 | if verbose: 47 | print ("depth:= {0}, nfilters: {1}".format(0,nfilters)) 48 | self.Dn1 = ResNet_atrous_unit(nfilters, _norm_type = _norm_type) 49 | self.pool1 = DownSample(nfilters, _norm_type = _norm_type) 50 | 51 | # Usually 64 52 | nfilters = self.nfilters * 2**(1) 53 | if verbose: 54 | print ("depth:= {0}, nfilters: {1}".format(1,nfilters)) 55 | self.Dn2 = ResNet_atrous_unit(nfilters, _norm_type = _norm_type) 56 | self.pool2 = DownSample(nfilters, _norm_type = _norm_type) 57 | 58 | # Usually 128 59 | nfilters = self.nfilters * 2**(2) 60 | if verbose: 61 | print ("depth:= {0}, nfilters: {1}".format(2,nfilters)) 62 | self.Dn3 = ResNet_atrous_2_unit(nfilters, _norm_type = _norm_type) 63 | self.pool3 = DownSample(nfilters, _norm_type = _norm_type) 64 | 65 | # Usually 256 66 | nfilters = self.nfilters * 2**(3) 67 | if verbose: 68 | print ("depth:= {0}, nfilters: {1}".format(3,nfilters)) 69 | self.Dn4 = ResNet_atrous_2_unit(nfilters,_dilation_rates=[3,5], _norm_type = _norm_type) 70 | self.pool4 = DownSample(nfilters, _norm_type = _norm_type) 71 | 72 | # Usually 512 73 | nfilters = self.nfilters * 2**(4) 74 | if verbose: 75 | print ("depth:= {0}, nfilters: {1}".format(4,nfilters)) 76 | self.Dn5 = ResNet_v2_unit(nfilters, _norm_type = _norm_type) 77 | self.pool5 = DownSample(nfilters) 78 | 79 | # Usually 1024 80 | nfilters = self.nfilters * 2**(5) 81 | if verbose: 82 | print ("depth:= {0}, nfilters: {1}".format(5,nfilters)) 83 | self.Dn6 = ResNet_v2_unit(nfilters) 84 | 85 | 86 | # Same number of filters, with new definition 87 | self.middle = PSP_Pooling(nfilters, _norm_type = _norm_type) 88 | 89 | 90 | 91 | def hybrid_forward(self,F,_input): 92 | 93 | # First convolution 94 | conv1 = self.conv_first_normed(_input) 95 | conv1 = F.relu(conv1) 96 | 97 | 98 | Dn1 = self.Dn1(conv1) 99 | pool1 = self.pool1(Dn1) 100 | 101 | Dn2 = self.Dn2(pool1) 102 | pool2 = self.pool2(Dn2) 103 | 104 | Dn3 = self.Dn3(pool2) 105 | pool3 = self.pool3(Dn3) 106 | 107 | Dn4 = self.Dn4(pool3) 108 | pool4 = self.pool4(Dn4) 109 | 110 | Dn5 = self.Dn5(pool4) 111 | pool5 = self.pool5(Dn5) 112 | 113 | 114 | Dn6 = self.Dn6(pool5) 115 | 116 | middle = self.middle(Dn6) 117 | middle = F.relu(middle) # Activation of middle layers 118 | 119 | 120 | return middle 121 | 122 | -------------------------------------------------------------------------------- /nn/Units/resnet_atrous_units.py: -------------------------------------------------------------------------------- 1 | from resuneta.nn.BBlocks import resnet_blocks 2 | from mxnet.gluon import HybridBlock 3 | 4 | 5 | 6 | # TODO: write a more sofisticated version, using HybridBlock as a container 7 | class ResNet_atrous_unit(HybridBlock): 8 | def __init__(self, _nfilters, _kernel_size=(3,3), _dilation_rates=[3,15,31], _norm_type = 'BatchNorm', **kwards): 9 | super(ResNet_atrous_unit,self).__init__(**kwards) 10 | 11 | 12 | # mxnet doesn't like wrapping 
things inside a list: it shadows the HybridBlock, remove list 13 | with self.name_scope(): 14 | 15 | self.ResBlock1 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(1,1), _norm_type = _norm_type, prefix="_ResNetv2block_1_") 16 | 17 | d = _dilation_rates[0] 18 | self.ResBlock2 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(d,d), _norm_type = _norm_type, prefix="_ResNetv2block_2_") 19 | 20 | d = _dilation_rates[1] 21 | self.ResBlock3 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(d,d), _norm_type = _norm_type, prefix="_ResNetv2block_3_") 22 | 23 | d = _dilation_rates[2] 24 | self.ResBlock4 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(d,d), _norm_type = _norm_type, prefix="_ResNetv2block_4_") 25 | 26 | 27 | 28 | def hybrid_forward(self,F,_xl): 29 | 30 | # First perform a standard ResNet block with dilation_rate = 1 31 | 32 | x = _xl 33 | 34 | """ 35 | # These are great for Imperative programming only, 36 | x = x + self.ResBlock1(_xl) 37 | x = x + self.ResBlock2(_xl) 38 | x = x + self.ResBlock3(_xl) 39 | x = x + self.ResBlock4(_xl) 40 | # """ 41 | 42 | # Uniform description for both Symbol and NDArray 43 | x = F.broadcast_add( x , self.ResBlock1(_xl) ) 44 | x = F.broadcast_add( x , self.ResBlock2(_xl) ) 45 | x = F.broadcast_add( x , self.ResBlock3(_xl) ) 46 | x = F.broadcast_add( x , self.ResBlock4(_xl) ) 47 | 48 | return x 49 | 50 | 51 | 52 | 53 | 54 | # Two atrous in parallel 55 | class ResNet_atrous_2_unit(HybridBlock): 56 | def __init__(self, _nfilters, _kernel_size=(3,3), _dilation_rates=[3,15], _norm_type = 'BatchNorm', **kwards): 57 | super(ResNet_atrous_2_unit,self).__init__(**kwards) 58 | 59 | 60 | # mxnet doesn't like wrapping things inside a list: it shadows the HybridBlock, remove list 61 | with self.name_scope(): 62 | 63 | self.ResBlock1 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(1,1), _norm_type = _norm_type, prefix="_ResNetv2block_1_") 64 | 65 | d = _dilation_rates[0] 66 | self.ResBlock2 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(d,d), _norm_type = _norm_type, prefix="_ResNetv2block_2_") 67 | 68 | d = _dilation_rates[1] 69 | self.ResBlock3 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(d,d), _norm_type = _norm_type, prefix="_ResNetv2block_3_") 70 | 71 | 72 | 73 | def hybrid_forward(self,F,_xl): 74 | 75 | # First perform a standard ResNet block with dilation_rate = 1 76 | x = _xl 77 | 78 | """ 79 | # Imperative program only 80 | x = x + self.ResBlock1(_xl) 81 | x = x + self.ResBlock2(_xl) 82 | x = x + self.ResBlock3(_xl) 83 | # """ 84 | 85 | # Uniform description for both Symbol and NDArray 86 | x = F.broadcast_add( x , self.ResBlock1(_xl) ) 87 | x = F.broadcast_add( x , self.ResBlock2(_xl) ) 88 | x = F.broadcast_add( x , self.ResBlock3(_xl) ) 89 | 90 | return x 91 | 92 | 93 | 94 | 95 | # One atrous in parallel 96 | class ResNet_atrous_1_unit(HybridBlock): 97 | def __init__(self, _nfilters, _kernel_size=(3,3), _dilation_rates=[3], _norm_type = 'BatchNorm', **kwards): 98 | super(ResNet_atrous_1_unit,self).__init__(**kwards) 99 | 100 | 101 | # mxnet doesn't like wrapping things inside a list: it shadows the HybridBlock, remove list 102 | with self.name_scope(): 103 | 104 | self.ResBlock1 = resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(1,1), _norm_type = _norm_type, prefix="_ResNetv2block_1_") 105 | 106 | d = _dilation_rates[0] 107 | self.ResBlock2 = 
resnet_blocks.ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(d,d), _norm_type = _norm_type, prefix="_ResNetv2block_2_") 108 | 109 | 110 | 111 | 112 | def hybrid_forward(self,F,_xl): 113 | 114 | # First perform a standard ResNet block with dilation_rate = 1 115 | x = _xl 116 | 117 | """ 118 | # Imperative program only 119 | x = x + self.ResBlock1(_xl) 120 | x = x + self.ResBlock2(_xl) 121 | # """ 122 | 123 | 124 | x = F.broadcast_add( x , self.ResBlock1(_xl) ) 125 | x = F.broadcast_add( x , self.ResBlock2(_xl) ) 126 | 127 | 128 | return x 129 | 130 | -------------------------------------------------------------------------------- /models/resunet_d7_encoder.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | from mxnet import gluon 3 | from mxnet.gluon import HybridBlock 4 | 5 | 6 | from resuneta.nn.Units.resnet_units import * 7 | from resuneta.nn.Units.resnet_atrous_units import * 8 | from resuneta.nn.pooling.psp_pooling import * 9 | from resuneta.nn.layers.scale import * 10 | from resuneta.nn.layers.combine import * 11 | from resuneta.nn.layers.conv2Dnormed import * 12 | 13 | 14 | 15 | class ResUNet_d7_encoder(HybridBlock): 16 | """ 17 | This will be used for 256x256 image input, so the atrous convolutions should be determined by the depth 18 | """ 19 | 20 | def __init__(self,_nfilters_init, _NClasses, verbose=True, psp_depth=3, _norm_type = 'BatchNorm', **kwards): 21 | HybridBlock.__init__(self,**kwards) 22 | 23 | self.model_name = "ResUNet_d7_encoder" 24 | 25 | self.depth = 7 26 | 27 | self.nfilters = _nfilters_init # Initial number of filters 28 | self.NClasses = _NClasses 29 | 30 | 31 | 32 | with self.name_scope(): 33 | 34 | # First convolution Layer 35 | # Starting with first convolutions to make the input "channel" dim equal to the number of initial filters 36 | self.conv_first_normed = Conv2DNormed(channels=self.nfilters, 37 | kernel_size=(1,1), 38 | _norm_type = _norm_type, 39 | prefix="_conv_first_") 40 | 41 | 42 | # Progressively reducing the dilation_rate of Atrous convolutions (the deeper the smaller). 
43 | 44 | # Usually 32 45 | nfilters = self.nfilters * 2**(0) 46 | if verbose: 47 | print ("depth:= {0}, nfilters: {1}".format(0,nfilters)) 48 | self.Dn1 = ResNet_atrous_unit(nfilters, _norm_type = _norm_type) 49 | self.pool1 = DownSample(nfilters, _norm_type = _norm_type) 50 | 51 | # Usually 64 52 | nfilters = self.nfilters * 2**(1) 53 | if verbose: 54 | print ("depth:= {0}, nfilters: {1}".format(1,nfilters)) 55 | self.Dn2 = ResNet_atrous_unit(nfilters, _norm_type = _norm_type) 56 | self.pool2 = DownSample(nfilters, _norm_type = _norm_type) 57 | 58 | # Usually 128 59 | nfilters = self.nfilters * 2**(2) 60 | if verbose: 61 | print ("depth:= {0}, nfilters: {1}".format(2,nfilters)) 62 | self.Dn3 = ResNet_atrous_2_unit(nfilters, _norm_type = _norm_type) 63 | self.pool3 = DownSample(nfilters, _norm_type = _norm_type) 64 | 65 | # Usually 256 66 | nfilters = self.nfilters * 2**(3) 67 | if verbose: 68 | print ("depth:= {0}, nfilters: {1}".format(3,nfilters)) 69 | self.Dn4 = ResNet_atrous_2_unit(nfilters,_dilation_rates=[3,5], _norm_type = _norm_type) 70 | self.pool4 = DownSample(nfilters, _norm_type = _norm_type) 71 | 72 | # Usually 512 73 | nfilters = self.nfilters * 2**(4) 74 | if verbose: 75 | print ("depth:= {0}, nfilters: {1}".format(4,nfilters)) 76 | self.Dn5 = ResNet_v2_unit(nfilters, _norm_type = _norm_type) 77 | self.pool5 = DownSample(nfilters, _norm_type = _norm_type) 78 | 79 | # Usually 1024 80 | nfilters = self.nfilters * 2**(5) 81 | if verbose: 82 | print ("depth:= {0}, nfilters: {1}".format(5,nfilters)) 83 | self.Dn6 = ResNet_v2_unit(nfilters, _norm_type = _norm_type) 84 | self.pool6 = DownSample(nfilters, _norm_type = _norm_type) 85 | 86 | 87 | # Usually 2048 88 | nfilters = self.nfilters * 2**(6) 89 | if verbose: 90 | print ("depth:= {0}, nfilters: {1}".format(6,nfilters)) 91 | self.Dn7 = ResNet_v2_unit(nfilters, _norm_type = _norm_type) 92 | 93 | 94 | # Same number of filters, with new definition 95 | self.middle = PSP_Pooling(nfilters, _norm_type = _norm_type, depth = psp_depth) 96 | 97 | def hybrid_forward(self,F,_input): 98 | 99 | # First convolution 100 | conv1 = self.conv_first_normed(_input) 101 | conv1 = F.relu(conv1) 102 | 103 | 104 | Dn1 = self.Dn1(conv1) 105 | pool1 = self.pool1(Dn1) 106 | 107 | Dn2 = self.Dn2(pool1) 108 | pool2 = self.pool2(Dn2) 109 | 110 | Dn3 = self.Dn3(pool2) 111 | pool3 = self.pool3(Dn3) 112 | 113 | Dn4 = self.Dn4(pool3) 114 | pool4 = self.pool4(Dn4) 115 | 116 | Dn5 = self.Dn5(pool4) 117 | pool5 = self.pool5(Dn5) 118 | 119 | 120 | Dn6 = self.Dn6(pool5) 121 | pool6 = self.pool6(Dn6) 122 | 123 | Dn7 = self.Dn7(pool6) 124 | 125 | #pool7 = self.pool7(Dn7) 126 | #pool7 = F.UpSampling(pool7, scale=2, sample_type='nearest') 127 | #out = F.concat(Dn7,pool7,dim=1) 128 | #out = F.relu(out) 129 | 130 | middle = self.middle(Dn7) 131 | middle = F.relu(middle) # Restore channels 132 | 133 | 134 | return middle 135 | 136 | -------------------------------------------------------------------------------- /models/resunet_d6_causal_mtskcolor_ddist.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | from mxnet import gluon 3 | from mxnet.gluon import HybridBlock 4 | 5 | 6 | from resuneta.nn.Units.resnet_units import * 7 | from resuneta.nn.Units.resnet_atrous_units import * 8 | from resuneta.nn.pooling.psp_pooling import * 9 | from resuneta.nn.layers.scale import * 10 | from resuneta.nn.layers.combine import * 11 | 12 | 13 | from resuneta.models.resunet_d6_encoder import * 14 | 15 | 16 | class 
ResUNet_d6(HybridBlock): 17 | """ 18 | This will be used for 256x256 image input, so the atrous convolutions should be determined by the depth 19 | """ 20 | 21 | def __init__(self,_nfilters_init, _NClasses, verbose=True, _norm_type = 'BatchNorm', **kwards): 22 | HybridBlock.__init__(self,**kwards) 23 | 24 | self.model_name = "ResUNet_d6" 25 | 26 | self.depth = 6 27 | 28 | self.nfilters = _nfilters_init # Initial number of filters 29 | self.NClasses = _NClasses 30 | 31 | # Provide a flexibility in Normalization layers, test both 32 | #self.NormLayer = InstanceNorm 33 | #self.NormLayer = gluon.nn.BatchNorm 34 | 35 | 36 | with self.name_scope(): 37 | 38 | 39 | self.encoder = ResUNet_d6_encoder(self.nfilters, self.NClasses,_norm_type=_norm_type, verbose=verbose) 40 | 41 | 42 | nfilters = self.nfilters * 2 ** (self.depth - 1 -1) 43 | if verbose: 44 | print ("depth:= {0}, nfilters: {1}".format(6,nfilters)) 45 | self.UpComb1 = combine_layers(nfilters) 46 | self.UpConv1 = ResNet_atrous_2_unit(nfilters,_dilation_rates=[3,5]) 47 | 48 | nfilters = self.nfilters * 2 ** (self.depth - 1 -2) 49 | if verbose: 50 | print ("depth:= {0}, nfilters: {1}".format(7,nfilters)) 51 | self.UpComb2 = combine_layers(nfilters) 52 | self.UpConv2 = ResNet_atrous_2_unit(nfilters) 53 | 54 | nfilters = self.nfilters * 2 ** (self.depth - 1 -3) 55 | if verbose: 56 | print ("depth:= {0}, nfilters: {1}".format(8,nfilters)) 57 | self.UpComb3 = combine_layers(nfilters) 58 | self.UpConv3 = ResNet_atrous_unit(nfilters) 59 | 60 | nfilters = self.nfilters * 2 ** (self.depth - 1 -4) 61 | if verbose: 62 | print ("depth:= {0}, nfilters: {1}".format(9,nfilters)) 63 | self.UpComb4 = combine_layers(nfilters) 64 | self.UpConv4 = ResNet_atrous_unit(nfilters) 65 | 66 | 67 | nfilters = self.nfilters * 2 ** (self.depth - 1 -5) 68 | if verbose: 69 | print ("depth:= {0}, nfilters: {1}".format(10,nfilters)) 70 | self.UpComb5 = combine_layers(nfilters) 71 | self.UpConv5 = ResNet_atrous_unit(nfilters) 72 | 73 | 74 | self.psp_2ndlast = PSP_Pooling(self.nfilters, _norm_type = _norm_type) 75 | 76 | # Segmenetation logits -- deeper for better reconstruction 77 | self.logits = gluon.nn.HybridSequential() 78 | self.logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1))) 79 | self.logits.add( gluon.nn.Activation('relu')) 80 | self.logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1))) 81 | self.logits.add( gluon.nn.Activation('relu')) 82 | self.logits.add( gluon.nn.Conv2D(self.NClasses,kernel_size=1,padding=0)) 83 | 84 | # bound logits 85 | self.bound_logits = gluon.nn.HybridSequential() 86 | self.bound_logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1))) 87 | self.bound_logits.add( gluon.nn.Activation('relu')) 88 | self.bound_logits.add( gluon.nn.Conv2D(self.NClasses,kernel_size=1,padding=0)) 89 | 90 | 91 | # distance logits -- deeper for better reconstruction 92 | self.distance_logits = gluon.nn.HybridSequential() 93 | self.distance_logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1))) 94 | self.distance_logits.add( gluon.nn.Activation('relu')) 95 | self.distance_logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1))) 96 | self.distance_logits.add( gluon.nn.Activation('relu')) 97 | self.distance_logits.add( gluon.nn.Conv2D(self.NClasses,kernel_size=1,padding=0)) 98 | 99 | 100 | # This layer is trying to identify the exact coloration on HSV scale (cv2 devined) 101 | self.color_logits = 
gluon.nn.Conv2D(3,kernel_size=1,padding=0) 102 | 103 | 104 | 105 | # Last activation, customization for binary results 106 | if ( self.NClasses == 1): 107 | self.ChannelAct = gluon.nn.HybridLambda(lambda F,x: F.sigmoid(x)) 108 | else: 109 | self.ChannelAct = gluon.nn.HybridLambda(lambda F,x: F.softmax(x,axis=1)) 110 | 111 | def hybrid_forward(self,F,_input): 112 | 113 | # First convolution 114 | conv1 = self.encoder.conv_first_normed(_input) 115 | conv1 = F.relu(conv1) 116 | 117 | 118 | Dn1 = self.encoder.Dn1(conv1) 119 | pool1 = self.encoder.pool1(Dn1) 120 | 121 | 122 | Dn2 = self.encoder.Dn2(pool1) 123 | pool2 = self.encoder.pool2(Dn2) 124 | 125 | Dn3 = self.encoder.Dn3(pool2) 126 | pool3 = self.encoder.pool3(Dn3) 127 | 128 | Dn4 = self.encoder.Dn4(pool3) 129 | pool4 = self.encoder.pool4(Dn4) 130 | 131 | Dn5 = self.encoder.Dn5(pool4) 132 | pool5 = self.encoder.pool5(Dn5) 133 | 134 | 135 | Dn6 = self.encoder.Dn6(pool5) 136 | 137 | middle = self.encoder.middle(Dn6) 138 | middle = F.relu(middle) # Activation of middle layers 139 | 140 | 141 | UpComb1 = self.UpComb1(middle,Dn5) 142 | UpConv1 = self.UpConv1(UpComb1) 143 | 144 | UpComb2 = self.UpComb2(UpConv1,Dn4) 145 | UpConv2 = self.UpConv2(UpComb2) 146 | 147 | UpComb3 = self.UpComb3(UpConv2,Dn3) 148 | UpConv3 = self.UpConv3(UpComb3) 149 | 150 | UpComb4 = self.UpComb4(UpConv3,Dn2) 151 | UpConv4 = self.UpConv4(UpComb4) 152 | 153 | UpComb5 = self.UpComb5(UpConv4,Dn1) 154 | UpConv5 = self.UpConv5(UpComb5) 155 | 156 | # second last layer 157 | convl = F.concat(conv1,UpConv5) 158 | conv = self.psp_2ndlast(convl) 159 | conv = F.relu(conv) 160 | 161 | # logits 162 | # 1st find distance map, skeleton like, topology info 163 | dist = self.distance_logits(convl) # Modification here, do not use max pooling for distance 164 | #dist = F.softmax(dist,axis=1) 165 | dist = self.ChannelAct(dist) 166 | 167 | # Then find boundaries 168 | bound = F.concat(conv,dist) 169 | bound = self.bound_logits(bound) 170 | bound = F.sigmoid(bound) # Boundaries are not mutually exclusive the way I am creating them. 
171 | # Color prediction (HSV - cv2) 172 | convc = self.color_logits(convl) 173 | # HSV (cv2) color prediction 174 | convc = F.sigmoid(convc) # This will be for self-supervised as well 175 | 176 | # Finally, find segmentation mask 177 | logits = F.concat(conv,bound,dist) 178 | logits = self.logits(logits) 179 | #logits = F.softmax(logits,axis=1) 180 | logits = self.ChannelAct(logits) 181 | 182 | return logits, bound, dist, convc 183 | 184 | 185 | -------------------------------------------------------------------------------- /models/resunet_d7_causal_mtskcolor_ddist.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | from mxnet import gluon 3 | from mxnet.gluon import HybridBlock 4 | 5 | 6 | from resuneta.nn.Units.resnet_units import * 7 | from resuneta.nn.Units.resnet_atrous_units import * 8 | from resuneta.nn.pooling.psp_pooling import * 9 | from resuneta.nn.layers.scale import * 10 | from resuneta.nn.layers.combine import * 11 | 12 | from resuneta.models.resunet_d7_encoder import * 13 | 14 | 15 | class ResUNet_d7(HybridBlock): 16 | """ 17 | This will be used for 256x256 image input, so the atrous convolutions should be determined by the depth 18 | """ 19 | 20 | def __init__(self,_nfilters_init, _NClasses, verbose=True, _norm_type = 'BatchNorm', **kwards): 21 | HybridBlock.__init__(self,**kwards) 22 | 23 | self.model_name = "ResUNet_d7_cmtskc" 24 | 25 | self.depth = 7 26 | 27 | self.nfilters = _nfilters_init # Initial number of filters 28 | self.NClasses = _NClasses 29 | 30 | 31 | 32 | with self.name_scope(): 33 | 34 | 35 | self.encoder = ResUNet_d7_encoder(self.nfilters, self.NClasses,_norm_type=_norm_type) 36 | 37 | 38 | nfilters = self.nfilters * 2 ** (self.depth - 1 -1) 39 | if verbose: 40 | print ("depth:= {0}, nfilters: {1}".format(7,nfilters)) 41 | self.UpComb1 = combine_layers(nfilters,_norm_type=_norm_type) 42 | self.UpConv1 = ResNet_v2_unit(nfilters,_norm_type=_norm_type) 43 | 44 | nfilters = self.nfilters * 2 ** (self.depth - 1 -2) 45 | if verbose: 46 | print ("depth:= {0}, nfilters: {1}".format(8,nfilters)) 47 | self.UpComb2 = combine_layers(nfilters,_norm_type=_norm_type) 48 | self.UpConv2 = ResNet_atrous_2_unit(nfilters,_norm_type=_norm_type) 49 | 50 | nfilters = self.nfilters * 2 ** (self.depth - 1 -3) 51 | if verbose: 52 | print ("depth:= {0}, nfilters: {1}".format(9,nfilters)) 53 | self.UpComb3 = combine_layers(nfilters,_norm_type=_norm_type) 54 | self.UpConv3 = ResNet_atrous_unit(nfilters,_norm_type=_norm_type) 55 | 56 | nfilters = self.nfilters * 2 ** (self.depth - 1 -4) 57 | if verbose: 58 | print ("depth:= {0}, nfilters: {1}".format(10,nfilters)) 59 | self.UpComb4 = combine_layers(nfilters,_norm_type=_norm_type) 60 | self.UpConv4 = ResNet_atrous_unit(nfilters,_norm_type=_norm_type) 61 | 62 | 63 | nfilters = self.nfilters * 2 ** (self.depth - 1 -5) 64 | if verbose: 65 | print ("depth:= {0}, nfilters: {1}".format(11,nfilters)) 66 | self.UpComb5 = combine_layers(nfilters,_norm_type=_norm_type) 67 | self.UpConv5 = ResNet_atrous_unit(nfilters,_norm_type=_norm_type) 68 | 69 | nfilters = self.nfilters * 2 ** (self.depth - 1 -6) 70 | if verbose: 71 | print ("depth:= {0}, nfilters: {1}".format(12,nfilters)) 72 | self.UpComb6 = combine_layers(nfilters,_norm_type=_norm_type) 73 | self.UpConv6 = ResNet_atrous_unit(nfilters,_norm_type=_norm_type) 74 | 75 | 76 | self.psp_2ndlast = PSP_Pooling(self.nfilters, _norm_type = _norm_type) 77 | 78 | # Segmenetation logits -- deeper for better reconstruction 79 | self.logits = 
80 |             self.logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1)))
81 |             self.logits.add( gluon.nn.Activation('relu'))
82 |             self.logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1)))
83 |             self.logits.add( gluon.nn.Activation('relu'))
84 |             self.logits.add( gluon.nn.Conv2D(self.NClasses,kernel_size=1,padding=0))
85 |
86 |             # bound logits
87 |             self.bound_logits = gluon.nn.HybridSequential()
88 |             self.bound_logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1)))
89 |             self.bound_logits.add( gluon.nn.Activation('relu'))
90 |             self.bound_logits.add( gluon.nn.Conv2D(self.NClasses,kernel_size=1,padding=0))
91 |
92 |
93 |             # distance logits -- deeper for better reconstruction
94 |             self.distance_logits = gluon.nn.HybridSequential()
95 |             self.distance_logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1)))
96 |             self.distance_logits.add( gluon.nn.Activation('relu'))
97 |             self.distance_logits.add( Conv2DNormed(channels = self.nfilters,kernel_size = (3,3),padding=(1,1)))
98 |             self.distance_logits.add( gluon.nn.Activation('relu'))
99 |             self.distance_logits.add( gluon.nn.Conv2D(self.NClasses,kernel_size=1,padding=0))
100 |
101 |
102 |
103 |
104 |             # This layer is trying to identify the exact coloration on HSV scale (cv2 defined)
105 |             self.color_logits = gluon.nn.Conv2D(3,kernel_size=1,padding=0)
106 |
107 |
108 |
109 |             # Last activation, customization for binary results
110 |             if ( self.NClasses == 1):
111 |                 self.ChannelAct = gluon.nn.HybridLambda(lambda F,x: F.sigmoid(x))
112 |             else:
113 |                 self.ChannelAct = gluon.nn.HybridLambda(lambda F,x: F.softmax(x,axis=1))
114 |
115 |
116 |
117 |     def hybrid_forward(self,F,_input):
118 |
119 |         # First convolution
120 |         conv1 = self.encoder.conv_first_normed(_input)
121 |         conv1 = F.relu(conv1)
122 |
123 |
124 |         Dn1 = self.encoder.Dn1(conv1)
125 |         pool1 = self.encoder.pool1(Dn1)
126 |
127 |
128 |         Dn2 = self.encoder.Dn2(pool1)
129 |         pool2 = self.encoder.pool2(Dn2)
130 |
131 |         Dn3 = self.encoder.Dn3(pool2)
132 |         pool3 = self.encoder.pool3(Dn3)
133 |
134 |         Dn4 = self.encoder.Dn4(pool3)
135 |         pool4 = self.encoder.pool4(Dn4)
136 |
137 |         Dn5 = self.encoder.Dn5(pool4)
138 |         pool5 = self.encoder.pool5(Dn5)
139 |
140 |
141 |         Dn6 = self.encoder.Dn6(pool5)
142 |         pool6 = self.encoder.pool6(Dn6)
143 |
144 |
145 |         Dn7 = self.encoder.Dn7(pool6)
146 |         middle = self.encoder.middle(Dn7)
147 |         middle = F.relu(middle) # Activation of middle layers
148 |
149 |
150 |         UpComb1 = self.UpComb1(middle,Dn6)
151 |         UpConv1 = self.UpConv1(UpComb1)
152 |
153 |         UpComb2 = self.UpComb2(UpConv1,Dn5)
154 |         UpConv2 = self.UpConv2(UpComb2)
155 |
156 |         UpComb3 = self.UpComb3(UpConv2,Dn4)
157 |         UpConv3 = self.UpConv3(UpComb3)
158 |
159 |         UpComb4 = self.UpComb4(UpConv3,Dn3)
160 |         UpConv4 = self.UpConv4(UpComb4)
161 |
162 |         UpComb5 = self.UpComb5(UpConv4,Dn2)
163 |         UpConv5 = self.UpConv5(UpComb5)
164 |
165 |         UpComb6 = self.UpComb6(UpConv5,Dn1)
166 |         UpConv6 = self.UpConv6(UpComb6)
167 |
168 |         # second last layer
169 |         convl = F.concat(conv1,UpConv6)
170 |         conv = self.psp_2ndlast(convl)
171 |         conv = F.relu(conv)
172 |
173 |         # logits
174 |
175 |         # 1st find distance map, skeleton-like, topology info
176 |         dist = self.distance_logits(convl) # Modification here, do not use max pooling for distance
177 |         #dist = F.softmax(dist,axis=1)
178 |         dist = self.ChannelAct(dist)
179 |
180 |         # Then find boundaries
181 |         bound = F.concat(conv,dist)
182 |         bound = self.bound_logits(bound)
183 |         bound = F.sigmoid(bound) # Boundaries are not mutually exclusive the way I am creating them.
184 |         # Color prediction (HSV - cv2)
185 |         convc = self.color_logits(convl)
186 |         # HSV (cv2) color prediction
187 |         convc = F.sigmoid(convc) # This will be for self-supervised as well
188 |
189 |         # Finally, find segmentation mask
190 |         logits = F.concat(conv,bound,dist)
191 |         logits = self.logits(logits)
192 |         #logits = F.softmax(logits,axis=1)
193 |         logits = self.ChannelAct(logits)
194 |
195 |         return logits, bound, dist, convc
196 |
197 |
--------------------------------------------------------------------------------
/src/chopchop_run.py:
--------------------------------------------------------------------------------
1 | """
2 | Code: slices large raster images into image patches of window size F (= 256). In this code, ~10% of the area of each image
3 | is kept as validation data. To achieve this we keep the lowest (bottom-right) 10% of each tile as validation data, by using
4 | all the patch indices that fall in the last ~30% of the length of each side (i.e. after ~70% of each side's length).
5 |
6 | Area_test = (0.3 * Height) * (0.3 * Width) ~ 0.1 * Height*Width
7 | """
8 |
9 |
10 | import rasterio
11 | import numpy as np
12 | import glob
13 | import cv2
14 | import uuid
15 | from pathos.pools import ThreadPool as pp
16 |
17 |
18 | # ********************************** CONSTANTS *************************************
19 | # Class definitions
20 | # New fast access, from stackoverflow: https://stackoverflow.com/questions/53059201/how-to-convert-3d-rgb-label-imagein-semantic-segmentation-to-2d-gray-image-an
21 | # ******************************************************************
22 | NClasses = 6 # Looking at the data I am treating "background" as a separate class.
23 | Background = np.array([255,0,0]) #:{'name':'Background','cType':0},
24 | ImSurf = np.array ([255,255,255])# :{'name':'ImSurf','cType':1},
25 | Car = np.array([255,255,0]) # :{'name':'Car','cType':2},
26 | Building = np.array([0,0,255]) #:{'name':'Building','cType':3},
27 | LowVeg = np.array([0,255,255]) # :{'name':'LowVeg','cType':4},
28 | Tree = np.array([0,255,0]) # :{'name':'Tree','cType':5}
29 | # ******************************************************************
30 |
31 |
32 | # READING DATA
33 | # @@@@@@@@@@@@@@@@@ REPLACE THIS WITH YOUR DATA DIRECTORY
34 | read_prefix= r'/flush1/dia021/isprs_potsdam/raw_data/'
35 | prefix_imgs = r'4_Ortho_RGBIR/'
36 | prefix_dems = r'1_DSM_normalisation/'
37 | prefix_labels = r'5_Labels_for_participants/'
38 |
39 |
40 |
41 | flnames_imgs = sorted(glob.glob(read_prefix+prefix_imgs+'*.tif'))
42 | flnames_dems = sorted(glob.glob(read_prefix+prefix_dems+'*.tif'))
43 | flnames_labels = sorted(glob.glob(read_prefix+prefix_labels+'*.tif'))
44 |
45 |
46 | IDs = []
47 | for name in flnames_labels:
48 |     IDs +=[name.replace(read_prefix + '5_Labels_for_participants/top_potsdam_','').replace('_label.tif','')]
49 |
50 |
51 |
52 |
53 | # Helper functions to create boundary and distance transform
54 | # Expect ground truth label in 1hot format
55 | # img_transform is necessary for fixing an error in the data (mis-sized tiles):
56 | def img_transform(_img):
57 |     new_size = 6000
58 |     _nchannels=_img.shape[0]
59 |     img = np.transpose(_img,[1,2,0])
60 |     img = cv2.resize(img,(new_size,new_size),interpolation= cv2.INTER_NEAREST)
61 |     #img = transform.resize(img,(new_size,new_size,_nchannels),preserve_range=True)
62 |     img = np.transpose(img,[2,0,1])
63 |     #img = img.astype('uint8')
64 |
65 |     return img
66 |
67 |
68 | def get_boundary(labels, _kernel_size = (3,3)):
69 |
70 |     label = labels.copy()
71 |     for channel in range(label.shape[0]):
72 |         temp = cv2.Canny(label[channel],0,1)
73 |         label[channel] = cv2.dilate(temp, cv2.getStructuringElement(cv2.MORPH_CROSS,_kernel_size) ,iterations = 1)
74 |
75 |     label = label.astype(np.float32)
76 |     label /= 255.
77 |     #label = label.astype(np.uint8)
78 |     return label
79 |
80 | def get_distance(labels):
81 |     label = labels.copy()
82 |     #print (label.shape)
83 |     dists = np.empty_like(label,dtype=np.float32)
84 |     for channel in range(label.shape[0]):
85 |         dist = cv2.distanceTransform(label[channel], cv2.DIST_L2, 0)
86 |         dist = cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
87 |         dists[channel] = dist
88 |
89 |     return dists
90 |
91 |
92 |
93 |
94 |
95 | def ID_2_filenames(_ID):
96 |
97 |     if (len(_ID[2:]) == 1 ):
98 |         ID_dsm = '0'+_ID[0]+'_'+'0'+_ID[2]
99 |     else:
100 |         ID_dsm = '0'+_ID
101 |
102 |     label_name = r'top_potsdam_' + _ID + '_label.tif'
103 |     dsm_name = r'dsm_potsdam_' + ID_dsm + '_normalized_lastools.jpg'
104 |     img_name = r'top_potsdam_'+_ID + '_RGBIR.tif'
105 |
106 |     return label_name, img_name, dsm_name
107 |
108 |
109 |
110 | def read_n_stack(_ID):
111 |     """
112 |     Given an ID string, returns the stacked img (RGBIR+DEMS) and label (RGB).
113 |     It also fixes a bug in the data: some tiles have dimensions one pixel smaller than 6000x6000.
114 |     """
115 |     tflname_label, tflname_img, tflname_dems = ID_2_filenames(_ID)
116 |
117 |     tflname_label = read_prefix + prefix_labels + tflname_label
118 |     tflname_img = read_prefix + prefix_imgs + tflname_img
119 |     tflname_dems = read_prefix + prefix_dems + tflname_dems
120 |
121 |
122 |     # read label:
123 |     with rasterio.open(tflname_label) as src:
124 |         label = src.read()
125 |     if label.shape[1:] != (6000,6000):
126 |         label = img_transform(label)
127 |
128 |     # read image
129 |     with rasterio.open(tflname_img) as src:
130 |         img = src.read()
131 |     if img.shape[1:] != (6000,6000):
132 |         img = img_transform(img)
133 |
134 |     # read DEMs
135 |     with rasterio.open(tflname_dems) as src:
136 |         dems = src.read()
137 |     if dems.shape[1:] != (6000,6000):
138 |         dems = img_transform(dems)
139 |
140 |     img = np.concatenate([img,dems],axis=0)
141 |
142 |     return img, label
143 |
144 |
145 | # Fast version to translate class RGB tuples to integer indices
146 | def rgb_to_2D_label(_label):
147 |     label_seg = np.zeros(_label.shape[1:],dtype=np.uint8)
148 |     label_seg [np.all(_label.transpose([1,2,0])==Background,axis=-1)] = 0
149 |     label_seg [np.all(_label.transpose([1,2,0])==ImSurf,axis=-1)] = 1
150 |     label_seg [np.all(_label.transpose([1,2,0])==Car,axis=-1)] = 2
151 |     label_seg [np.all(_label.transpose([1,2,0])==Building,axis=-1)] = 3
152 |     label_seg [np.all(_label.transpose([1,2,0])==LowVeg,axis=-1)] = 4
153 |     label_seg [np.all(_label.transpose([1,2,0])==Tree,axis=-1)] = 5
154 |
155 |     return label_seg
156 |
157 |
158 | # Translates an RGB label image to 1H (one-hot) encoding
159 | def rgb_to_1Hlabel(_label):
160 |     teye = np.eye(NClasses,dtype=np.uint8)
161 |
162 |     label_seg = np.zeros([*_label.shape[1:],NClasses],dtype=np.uint8)
163 |     label_seg [np.all(_label.transpose([1,2,0])==Background,axis=-1)] = teye[0]
164 |     label_seg [np.all(_label.transpose([1,2,0])==ImSurf,axis=-1)] = teye[1]
165 |     label_seg [np.all(_label.transpose([1,2,0])==Car,axis=-1)] = teye[2]
166 |     label_seg [np.all(_label.transpose([1,2,0])==Building,axis=-1)] = teye[3]
167 |     label_seg [np.all(_label.transpose([1,2,0])==LowVeg,axis=-1)] = teye[4]
168 |     label_seg [np.all(_label.transpose([1,2,0])==Tree,axis=-1)] = teye[5]
169 |
170 |     return label_seg.transpose([2,0,1])
171 |
172 | def ID_preprocessing(_ID):
173 |
174 |     img, label = read_n_stack(_ID)
175 |     label = rgb_to_1Hlabel(label) # This makes label in 1H
176 |
177 |     return img,label
178 |
179 |
180 | # Slicing constants: window size, stride and validation split.
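# A quick check of the validation split defined below: patches whose row AND
# column indices fall in the last length_scale = 0.317 fraction of the patch
# grid are written to the validation set, so the held-out area is roughly
# 0.317**2 ~ 0.100 of each tile -- the ~10% quoted in the module docstring.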
181 | Filter = 256 # Window size of patches
182 | stride = Filter // 2 # Half-window stride, so that patches overlap and edge effects are captured.
183 | length_scale = 0.317 # when squared gives ~ 0.1, i.e. 10% of all area
184 |
185 |
186 | ###prefix_global_write = r'/flush1/dia021/isprs_potsdam/Data_6k/'
187 | prefix_global_write = r'/scratch1/dia021/isprs_potsdam/Data_6k/'
188 |
189 |
190 | # ********************** END OF CONSTANTS ******************************************
191 |
192 |
193 |
194 | # This function is used for parallelization
195 | def img_n_tens_slice(_img_ID):
196 |
197 |     # Training images location
198 |     write_dir_img_train = prefix_global_write + 'training/imgs/'
199 |     write_dir_label_train = prefix_global_write + 'training/masks/'
200 |
201 |     # Validation images location
202 |     write_dir_img_val = prefix_global_write + 'validation/imgs/'
203 |     write_dir_label_val = prefix_global_write + 'validation/masks/'
204 |
205 |
206 |     print ("reading img,masks ID::{}".format(_img_ID))
207 |
208 |     _img, _masks = ID_preprocessing(_img_ID)
209 |
210 |
211 |     nTimesRows = int((_img.shape[1] - Filter)//stride + 1)
212 |     nTimesCols = int((_img.shape[2] - Filter)//stride + 1)
213 |
214 |
215 |     # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
216 |     # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
217 |     # This is for keeping a validation set
218 |     nTimesRows_val = int((1.0-length_scale)*nTimesRows)
219 |     nTimesCols_val = int((1.0-length_scale)*nTimesCols)
220 |     # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
221 |     # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
222 |
223 |     for row in range(nTimesRows-1):
224 |         for col in range(nTimesCols-1):
225 |
226 |             # Extract temporary
227 |             timg = _img[:, row*stride:row*stride+Filter, col*stride:col*stride+Filter]
228 |             tmask_1hot = _masks[:,row*stride:row*stride+Filter, col*stride:col*stride+Filter]
229 |
230 |             # TODO: create boundary/distance on the fly?
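            # The next lines build the saved target by stacking three masks
            # channel-wise: the one-hot segmentation (NClasses channels), its
            # boundary mask (NClasses) and its distance transform (NClasses),
            # i.e. 3*NClasses = 18 channels in total for the Potsdam classes.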
231 |             tbound = get_boundary(tmask_1hot)
232 |             tdist = get_distance(tmask_1hot)
233 |             # Aggregate all masks together in a single entity
234 |             tmask_all = np.concatenate([tmask_1hot,tbound,tdist],axis=0)
235 |
236 |             run_ID = str(uuid.uuid4())
237 |             if row >= nTimesRows_val and col >= nTimesCols_val :
238 |                 timg_name = write_dir_img_val + 'img-' + run_ID +'.npy'
239 |                 tmask_name = write_dir_label_val + 'img-'+ run_ID +'-mask.npy'
240 |             else:
241 |                 timg_name = write_dir_img_train + 'img-' + run_ID +'.npy'
242 |                 tmask_name = write_dir_label_train + 'img-'+ run_ID +'-mask.npy'
243 |
244 |
245 |             np.save(timg_name, timg)
246 |             np.save(tmask_name, tmask_all)
247 |
248 |     # Keep the final overlapping row/column patches (which do not fall on the stride grid) as validation images as well
249 |     rev_row = _img.shape[1] - Filter
250 |     rev_col = _img.shape[2] - Filter
251 |     for row in range(nTimesRows-1):
252 |         timg = _img [:, row*stride:row*stride+Filter, rev_col:]
253 |         tmask_1hot = _masks[:, row*stride:row*stride+Filter, rev_col:]
254 |
255 |         tbound = get_boundary(tmask_1hot)
256 |         tdist = get_distance(tmask_1hot)
257 |         # Aggregate all masks together in a single entity
258 |         tmask_all = np.concatenate([tmask_1hot,tbound,tdist],axis=0)
259 |         run_ID = str(uuid.uuid4())
260 |         timg_name = write_dir_img_val + 'img-' + run_ID +'.npy'
261 |         tmask_name = write_dir_label_val + 'img-'+ run_ID +'-mask.npy'
262 |
263 |         np.save(timg_name, timg)
264 |         np.save(tmask_name, tmask_all)
265 |
266 |
267 |     for col in range(nTimesCols-1):
268 |         timg = _img [:, rev_row:, col*stride:col*stride + Filter]
269 |         tmask_1hot = _masks[:, rev_row:, col*stride:col*stride + Filter]
270 |
271 |         tbound = get_boundary(tmask_1hot)
272 |         tdist = get_distance(tmask_1hot)
273 |         # Aggregate all masks together in a single entity
274 |         tmask_all = np.concatenate([tmask_1hot,tbound,tdist],axis=0)
275 |         run_ID = str(uuid.uuid4())
276 |         timg_name = write_dir_img_val + 'img-' + run_ID +'.npy'
277 |         tmask_name = write_dir_label_val + 'img-'+ run_ID +'-mask.npy'
278 |
279 |         np.save(timg_name, timg)
280 |         np.save(tmask_name, tmask_all)
281 |
282 |
283 |
284 |
285 |
286 | if __name__ == '__main__':
287 |
288 |     # Process each node in parallel
289 |     nnodes = int(16) # Change to match the number of your CPUs
290 |     pool = pp(nodes=nnodes)
291 |
292 |
293 |     # These are the training images -- processing in parallel
294 |     pool.map(img_n_tens_slice,IDs)
295 |
296 |
297 |
298 |
299 |
300 |
--------------------------------------------------------------------------------
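A minimal usage sketch for the ResUNet_d7 model above, assuming the package is importable as resuneta (consistent with the imports in the files) and that inputs are 5-channel 256x256 patches (RGBIR + DSM), as produced by src/chopchop_run.py. The Xavier initializer and the random dummy batch are illustrative assumptions, not prescribed by the repository.

import mxnet as mx
from resuneta.models.resunet_d7_causal_mtskcolor_ddist import ResUNet_d7

# 32 initial filters, 6 classes (ISPRS Potsdam, background treated as a class)
net = ResUNet_d7(_nfilters_init=32, _NClasses=6)
net.initialize(mx.init.Xavier())  # illustrative choice of initializer
net.hybridize()                   # all blocks are HybridBlocks

# Dummy batch: (batch, channels, height, width) = (1, 5, 256, 256)
x = mx.nd.random.uniform(shape=(1, 5, 256, 256))

logits, bound, dist, convc = net(x)
print(logits.shape)  # (1, 6, 256, 256) -- softmax segmentation mask
print(bound.shape)   # (1, 6, 256, 256) -- sigmoid boundary mask
print(dist.shape)    # (1, 6, 256, 256) -- distance-transform map
print(convc.shape)   # (1, 3, 256, 256) -- HSV (cv2) color reconstruction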