├── .gitignore ├── LICENSE ├── README.md ├── chopchop ├── Data_slicing.py └── chopchop2rec.py ├── demo └── Demo_forward_backward.ipynb ├── doc ├── csse_talk_diakogiannis.pdf └── manuscript.pdf ├── images ├── img_1.png ├── img_101.png ├── img_2.png ├── img_21.png ├── img_3.png ├── img_71.png ├── img_72.png ├── img_77.png ├── img_80.png └── mantis.jpg ├── models ├── changedetection │ └── mantis │ │ ├── mantis_dn.py │ │ └── mantis_dn_features.py ├── heads │ └── head_cmtsk.py └── semanticsegmentation │ └── x_unet │ ├── x_dn.py │ └── x_dn_features.py ├── nn ├── __init__.py ├── activations │ ├── __init__.py │ └── sigmoid_crisp.py ├── layers │ ├── __init__.py │ ├── attention.py │ ├── combine.py │ ├── conv2Dnormed.py │ ├── ftnmt.py │ └── scale.py ├── loss │ ├── __init__.py │ ├── ftnmt_loss.py │ └── mtsk_loss.py ├── pooling │ ├── __init__.py │ └── psp_pooling.py └── units │ ├── __init__.py │ ├── ceecnet.py │ └── fractal_resnet.py ├── requirements.txt ├── src ├── LVRCDDataset.py ├── LVRCDNormal.py └── semseg_aug_cv2.py └── utils ├── __init__.py └── get_norm.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License) 2 | Copyright (c) ceecnet, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. 3 | All rights reserved. CSIRO is willing to grant you a licence to ceecnet on the following terms, except where otherwise indicated for third party material. 4 | Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | * Neither the name of CSIRO nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission of CSIRO. 8 | EXCEPT AS EXPRESSLY STATED IN THIS AGREEMENT AND TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE IS PROVIDED "AS-IS". CSIRO MAKES NO REPRESENTATIONS, WARRANTIES OR CONDITIONS OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY REPRESENTATIONS, WARRANTIES OR CONDITIONS REGARDING THE CONTENTS OR ACCURACY OF THE SOFTWARE, OR OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, THE ABSENCE OF LATENT OR OTHER DEFECTS, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. 9 | TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL CSIRO BE LIABLE ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, IN AN ACTION FOR BREACH OF CONTRACT, NEGLIGENCE OR OTHERWISE) FOR ANY CLAIM, LOSS, DAMAGES OR OTHER LIABILITY HOWSOEVER INCURRED. 
WITHOUT LIMITING THE SCOPE OF THE PREVIOUS SENTENCE THE EXCLUSION OF LIABILITY SHALL INCLUDE: LOSS OF PRODUCTION OR OPERATION TIME, LOSS, DAMAGE OR CORRUPTION OF DATA OR RECORDS; OR LOSS OF ANTICIPATED SAVINGS, OPPORTUNITY, REVENUE, PROFIT OR GOODWILL, OR OTHER ECONOMIC LOSS; OR ANY SPECIAL, INCIDENTAL, INDIRECT, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES, ARISING OUT OF OR IN CONNECTION WITH THIS AGREEMENT, ACCESS OF THE SOFTWARE OR ANY OTHER DEALINGS WITH THE SOFTWARE, EVEN IF CSIRO HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH CLAIM, LOSS, DAMAGES OR OTHER LIABILITY.
10 | APPLICABLE LEGISLATION SUCH AS THE AUSTRALIAN CONSUMER LAW MAY APPLY REPRESENTATIONS, WARRANTIES, OR CONDITIONS, OR IMPOSES OBLIGATIONS OR LIABILITY ON CSIRO THAT CANNOT BE EXCLUDED, RESTRICTED OR MODIFIED TO THE FULL EXTENT SET OUT IN THE EXPRESS TERMS OF THIS CLAUSE ABOVE "CONSUMER GUARANTEES". TO THE EXTENT THAT SUCH CONSUMER GUARANTEES CONTINUE TO APPLY, THEN TO THE FULL EXTENT PERMITTED BY THE APPLICABLE LEGISLATION, THE LIABILITY OF CSIRO UNDER THE RELEVANT CONSUMER GUARANTEE IS LIMITED (WHERE PERMITTED AT CSIRO'S OPTION) TO ONE OF FOLLOWING REMEDIES OR SUBSTANTIALLY EQUIVALENT REMEDIES:
11 | (a) THE REPLACEMENT OF THE SOFTWARE, THE SUPPLY OF EQUIVALENT SOFTWARE, OR SUPPLYING RELEVANT SERVICES AGAIN;
12 | (b) THE REPAIR OF THE SOFTWARE;
13 | (c) THE PAYMENT OF THE COST OF REPLACING THE SOFTWARE, OF ACQUIRING EQUIVALENT SOFTWARE, HAVING THE RELEVANT SERVICES SUPPLIED AGAIN, OR HAVING THE SOFTWARE REPAIRED.
14 | IN THIS CLAUSE, CSIRO INCLUDES ANY THIRD PARTY AUTHOR OR OWNER OF ANY PART OF THE SOFTWARE OR MATERIAL DISTRIBUTED WITH IT. CSIRO MAY ENFORCE ANY RIGHTS ON BEHALF OF THE RELEVANT THIRD PARTY.
15 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ** UPDATE **
2 | We added a pure semantic segmentation model that uses CEECNetV1, CEECNetV2 or FracTAL ResNet micro-topologies, based on a single encoder/decoder macro-topology (unet-like). These can be found under ```models/semanticsegmentation/x_unet```. A minimal usage sketch is given at the end of this README.
3 | 
4 | # Looking for change? Roll the Dice and demand Attention
5 | ![mantis](images/img_3.png)
6 | 
7 | Official [mxnet](https://mxnet.incubator.apache.org/) implementation of the paper: ["Looking for change? Roll the Dice and demand Attention" (arxiv version)](https://arxiv.org/abs/2009.02062), [Diakogiannis et al. (2020 - journal version)](https://www.mdpi.com/2072-4292/13/18/3707). This repository contains source code for implementing and training the mantis ceecnet/FracTAL ResNet models described in our manuscript. All models are built with the mxnet DL framework (version < 2.0), under the gluon API. We do not provide pre-trained weights.
8 | 
9 | Inference examples for the task of building change detection with the model mantis ceecnetV1 (fractal Tanimoto loss depth ftdepth=5). From left to right: input image date 1, input image date 2, ground truth, inference, confidence heat map for the segmentation task.
10 | ![mantis](images/img_1.png)
11 | ![mantis](images/img_2.png)
12 | 
13 | Inference examples for the task of building change detection with the model mantis ceecnetV2 (ftdepth=5). The order of figures is as above. These are for the difficult tiles 21, 71, 72, 77, 80, 101 from the test set.
14 | ![mantis](images/img_21.png)
15 | ![mantis](images/img_71.png)
16 | ![mantis](images/img_72.png)
17 | ![mantis](images/img_77.png)
18 | ![mantis](images/img_80.png)
19 | ![mantis](images/img_101.png)
20 | 
21 | 
22 | ### Directory structure:
23 | 
24 | ```
25 | .
26 | ├── chopchop
27 | ├── demo
28 | ├── doc
29 | ├── images
30 | ├── models
31 | │   ├── changedetection
32 | │   │   └── mantis
33 | │   ├── heads
34 | │   └── semanticsegmentation
35 | │       └── x_unet
36 | ├── nn
37 | │   ├── activations
38 | │   ├── layers
39 | │   ├── loss
40 | │   ├── pooling
41 | │   └── units
42 | ├── src
43 | └── utils
44 | ```
45 | 
46 | The directory ```chopchop``` contains code for splitting triplets of raster files (date 1, date 2, ground truth) into small training patches; it is tailored to the LEVIR CD dataset. ```demo``` contains a notebook that shows how to initialize a mantis ceecnet model and perform forward and multitasking backward operations. ```models/changedetection/mantis``` contains a generic model definition, for arbitrary depth and number of filters, as described in our manuscript. ```nn``` contains all the building blocks necessary to construct the models we present, as well as loss function definitions. In particular, in ```nn/loss``` we provide the average fractal Tanimoto loss with dual (file ```nn/loss/ftnmt_loss.py```), as well as a class that can be used for multitasking loss training. Users of this method may want to write their own custom implementation for multitasking training, based on the ```ftnmt_loss.py``` file. See ```demo``` for example usage with a specific ground truth labels configuration. In ```src``` we provide an mxnet Dataset class, as well as a normalization class. Finally, ```utils``` contains a function for selecting BatchNorm or GroupNorm via a parameter.
47 | 
48 | 
49 | ### Datasets
50 | Users can find the datasets used in this publication in the following locations:
51 | LEVIR CD Dataset: https://justchenhao.github.io/LEVIR/
52 | WHU Dataset: http://gpcv.whu.edu.cn/data/building_dataset.html
53 | 
54 | 
55 | ### License
56 | CSIRO BSD/MIT LICENSE
57 | 
58 | As a condition of this licence, you agree that where you make any adaptations, modifications, further developments,
59 | or additional features available to CSIRO or the public in connection with your access to the Software, you do so on the terms of the BSD 3-Clause Licence template, a copy of which is available at: http://opensource.org/licenses/BSD-3-Clause.
60 | 
61 | 
62 | 
63 | ### CITATION
64 | If you find the contents of this repository useful for your research, please cite:
65 | ```
66 | @Article{rs13183707,
67 | AUTHOR = {Diakogiannis, Foivos I. and Waldner, François and Caccetta, Peter},
68 | TITLE = {Looking for Change? Roll the Dice and Demand Attention},
69 | JOURNAL = {Remote Sensing},
70 | VOLUME = {13},
71 | YEAR = {2021},
72 | NUMBER = {18},
73 | ARTICLE-NUMBER = {3707},
74 | URL = {https://www.mdpi.com/2072-4292/13/18/3707},
75 | ISSN = {2072-4292},
76 | ABSTRACT = {Change detection, i.e., the identification per pixel of changes for some classes of interest from a set of bi-temporal co-registered images, is a fundamental task in the field of remote sensing. It remains challenging due to unrelated forms of change that appear at different times in input images. Here, we propose a deep learning framework for the task of semantic change detection in very high-resolution aerial images. Our framework consists of a new loss function, a new attention module, new feature extraction building blocks, and a new backbone architecture that is tailored for the task of semantic change detection. Specifically, we define a new form of set similarity that is based on an iterative evaluation of a variant of the Dice coefficient. We use this similarity metric to define a new loss function as well as a new, memory efficient, spatial and channel convolution Attention layer: the FracTAL. We introduce two new efficient self-contained feature extraction convolution units: the CEECNet and FracTALResNet units. Further, we propose a new encoder/decoder scheme, a network macro-topology, that is tailored for the task of change detection. The key insight in our approach is to facilitate the use of relative attention between two convolution layers in order to fuse them. We validate our approach by showing excellent performance and achieving state-of-the-art scores (F1 and Intersection over Union-hereafter IoU) on two building change detection datasets, namely, the LEVIRCD (F1: 0.918, IoU: 0.848) and the WHU (F1: 0.938, IoU: 0.882) datasets.},
77 | DOI = {10.3390/rs13183707}
78 | }
79 | ```
80 | 
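### Minimal usage sketch
The snippet below is a minimal sketch of initializing the semantic segmentation variant and running a forward pass. It mirrors the demo notebook; the hyperparameter values are illustrative only (not prescribed values from the manuscript), and it assumes the repository root is importable as ```ceecnet```:
```
from mxnet import nd
from ceecnet.models.semanticsegmentation.x_unet.x_dn import *

# Illustrative hyperparameters -- swap in your own configuration.
net = X_dn_cmtsk(nfilters_init=32, depth=6, NClasses=2, model='CEECNetV2')
net.initialize()

img = nd.random.uniform(shape=[4, 3, 256, 256])   # BatchSize, NChannels, H, W
segm, bound, dist = net(img)                      # each of shape (4, 2, 256, 256)
```
The change detection model ```mantis_dn_cmtsk``` is used in the same way, except that its forward call takes two co-registered inputs, one per date (see ```demo/Demo_forward_backward.ipynb```).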
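# Split-logic note: with length_scale = 0.317 in chopchop2rec.WriteDataRecordIO,
# patches whose top-left corner falls in both the last ~31.7% of rows and the
# last ~31.7% of columns of a tile are written to the validation record; since
# 0.317 * 0.317 ~= 0.1, this reserves roughly 10% of each tile's area.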
8 | 
9 | 
10 | from chopchop2rec import *
11 | 
12 | # Read the triplets of filenames (image date 1, image date 2, change label)
13 | import glob
14 | flname_prefix_data = r'/Location/Of/LEVIRCD/Data/'
15 | flnames_images_A = sorted(glob.glob(flname_prefix_data + r'train/A/*.png'))
16 | flnames_images_A += sorted(glob.glob(flname_prefix_data + r'val/A/*.png'))
17 | flnames_images_A = sorted(flnames_images_A)
18 | 
19 | 
20 | flnames_images_B = sorted(glob.glob(flname_prefix_data + r'train/B/*.png'))
21 | flnames_images_B += sorted(glob.glob(flname_prefix_data + r'val/B/*.png'))
22 | flnames_images_B = sorted(flnames_images_B)
23 | 
24 | 
25 | flnames_images_chng = sorted(glob.glob(flname_prefix_data + r'train/label/*.png'))
26 | flnames_images_chng += sorted(glob.glob(flname_prefix_data + r'val/label/*.png'))
27 | flnames_images_chng = sorted(flnames_images_chng)
28 | 
29 | 
30 | listOfAll123 = list(zip(flnames_images_A,flnames_images_B,flnames_images_chng))
31 | 
32 | 
33 | if __name__ == '__main__':
34 | 
35 |     print ("Starting Chopping...")
36 |     mywriter = WriteDataRecordIO(listOfAll123)
37 |     mywriter.chop_all()
38 |     print ("Done!")
39 | 
40 | 
41 | 
--------------------------------------------------------------------------------
/chopchop/chopchop2rec.py:
--------------------------------------------------------------------------------
1 | # ============================== Helper Functions ==================================
2 | # Helper functions to create the boundary and distance transform
3 | # ground truth labels in 1hot format
4 | import cv2
5 | import glob
6 | import numpy as np
7 | 
8 | 
9 | def get_boundary(labels, _kernel_size = (3,3)):
10 | 
11 |     label = labels.copy()
12 |     for channel in range(label.shape[0]):
13 |         temp = cv2.Canny(label[channel],0,1)
14 |         label[channel] = cv2.dilate(temp, cv2.getStructuringElement(cv2.MORPH_CROSS,_kernel_size) ,iterations = 1)
15 | 
16 |     label = label.astype(np.float32)
17 |     label /= 255.
18 |     label = label.astype(np.uint8)
19 |     return label
20 | 
21 | def get_distance(labels):
22 |     label = labels.copy()
23 |     dists = np.empty_like(label,dtype=np.float32)
24 |     for channel in range(label.shape[0]):
25 |         dist = cv2.distanceTransform(label[channel], cv2.DIST_L2, 0)
26 |         dist = cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
27 |         dists[channel] = dist
28 | 
29 |     dists = dists * 100.
30 | dists = dists.astype(np.uint8) 31 | return dists 32 | # ==================================================================================== 33 | 34 | 35 | 36 | import mxnet as mx 37 | import pickle # this is necessary for translating from/to string for writing/reading in mxnet recordio 38 | from multiprocessing import Lock 39 | from pathos.pools import ThreadPool as pp 40 | import rasterio 41 | 42 | 43 | class WriteDataRecordIO(object): 44 | def __init__(self, 45 | ListOfFlnames123, 46 | flname_prefix_write= r'/Location/Of/Your/LEVIRCD/Files/', 47 | NClasses=2, # 1hot encoding 48 | Filter=256, 49 | stride_div=2, 50 | length_scale = 0.317): 51 | 52 | 53 | self.listOfFlnames123 = ListOfFlnames123 54 | self.lock = Lock() 55 | 56 | self.Filter=Filter 57 | self.stride = Filter//stride_div 58 | 59 | self.teye_label = np.eye(NClasses,dtype=np.uint8) 60 | self.global_train_idx = 0 61 | self.global_valid_idx = 0 62 | 63 | flname_train_idx = flname_prefix_write + r'training_LVRCD_F{}.idx'.format(Filter) 64 | flname_train_rec = flname_prefix_write + r'training_LVRCD_F{}.rec'.format(Filter) 65 | 66 | flname_valid_idx = flname_prefix_write + r'validation_LVRCD_F{}.idx'.format(Filter) 67 | flname_valid_rec = flname_prefix_write + r'validation_LVRCD_F{}.rec'.format(Filter) 68 | 69 | self.record_train = mx.recordio.MXIndexedRecordIO(idx_path=flname_train_idx, uri=flname_train_rec , flag='w') 70 | self.record_valid = mx.recordio.MXIndexedRecordIO(idx_path=flname_valid_idx, uri=flname_valid_rec , flag='w') 71 | 72 | self.length_scale=length_scale 73 | 74 | self.Filter = Filter 75 | self.stride = Filter // stride_div 76 | 77 | def update_imgs_mask(self, flnames123): 78 | name_img1, name_img2, name_mask = flnames123 79 | 80 | 81 | with rasterio.open(name_img1,mode='r',driver='png') as src1: 82 | img1 = src1.read() 83 | 84 | with rasterio.open(name_img2,mode='r',driver='png') as src2: 85 | img2 = src2.read() 86 | 87 | with rasterio.open(name_mask,mode='r',driver='png') as srcm: 88 | self.label = srcm.read(1) 89 | self.label[ self.label > 0 ] = 1 90 | 91 | self.img = np.concatenate((img1,img2),axis=0) 92 | 93 | # Constants that relate to rows, columns 94 | self.nTimesRows = int((self.img.shape[1] - self.Filter)//self.stride + 1) 95 | self.nTimesCols = int((self.img.shape[2] - self.Filter)//self.stride + 1) 96 | 97 | 98 | self.nTimesRows_val = int((1.0-self.length_scale)*self.nTimesRows) 99 | self.nTimesCols_val = int((1.0-self.length_scale)*self.nTimesCols) 100 | 101 | 102 | def _2D21H(self,tmask_label): 103 | 104 | tmask_label_1h = self.teye_label[tmask_label] 105 | tmask_label_1h = tmask_label_1h.transpose([2,0,1]) 106 | distance_map = get_distance(tmask_label_1h) 107 | bounds_map = get_boundary(tmask_label_1h) 108 | tlabels_all = np.concatenate([tmask_label_1h, bounds_map, distance_map],axis=0) 109 | 110 | return tlabels_all 111 | 112 | def chop_all(self): 113 | # For all triples in list of filenames 114 | for idx,name123 in enumerate(self.listOfFlnames123): 115 | print ("============================") 116 | print ("----------------------------") 117 | print ("Processing:: {}/{} triplets".format(idx, len(self.listOfFlnames123))) 118 | print ("----------------------------") 119 | for name in name123: 120 | print("Processing File:{}".format(name)) 121 | print ("****************************") 122 | 123 | # read image and mask 124 | self.update_imgs_mask(name123) 125 | 126 | # Do the chop on specific images 127 | self.thread_chop() 128 | 129 | 130 | self.record_train.close() 131 | self.record_valid.close() 
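    # Record layout note (derived from extract_patch below): each record is a
    # pickled uint8 array of shape (2*3 + 3*NClasses, Filter, Filter), i.e.
    # (12, 256, 256) with the defaults -- bands 0-2 are image date 1, bands 3-5
    # image date 2, 6-7 the one-hot change mask, 8-9 its boundary and 10-11 its
    # distance transform.
    # A read-back sketch (hedged; index/record paths as configured in __init__):
    #     rec = mx.recordio.MXIndexedRecordIO('training_LVRCD_F256.idx',
    #                                         'training_LVRCD_F256.rec', 'r')
    #     patch = pickle.loads(rec.read_idx(0))   # np.uint8, shape (12, 256, 256)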
132 | 133 | # Change here nthread to maximum available threads you have (or less) 134 | def thread_chop(self,nthread=24): 135 | """ 136 | Extracts patches in parallel from a single tuple of (raster, label, group_label) 137 | """ 138 | RowsCols = [(row, col) for row in range(self.nTimesRows-1) for col in range(self.nTimesCols-1)] 139 | Rows = [row for row in range(self.nTimesRows-1)] 140 | Cols = [col for col in range(self.nTimesCols-1)] 141 | 142 | pool = pp(nodes=nthread) 143 | result1 = pool.map(self.extract_patch,RowsCols) 144 | result2 = pool.map(self.extract_last_Col,Rows) 145 | result3 = pool.map(self.extract_last_Row,Cols) 146 | 147 | 148 | def extract_patch(self, RowCol): 149 | """ 150 | Single chip extraction. 151 | """ 152 | row, col = RowCol 153 | # Extract temporary 154 | tmask_label = self.label[row*self.stride:row*self.stride+self.Filter, col*self.stride:col*self.stride+self.Filter].copy().astype(np.uint8) 155 | timg = self.img[ :, row*self.stride:row*self.stride+self.Filter, col*self.stride:col*self.stride+self.Filter].copy() 156 | 157 | tlabels_all = self._2D21H(tmask_label) 158 | 159 | timg = np.concatenate((timg,tlabels_all),axis=0).astype(np.uint8) 160 | timg = pickle.dumps(timg) 161 | 162 | self.lock.acquire() 163 | 164 | if row >= self.nTimesRows_val and col >= self.nTimesCols_val : 165 | self.record_valid.write_idx(self.global_valid_idx,timg) 166 | self.global_valid_idx += 1 167 | else: 168 | self.record_train.write_idx(self.global_train_idx,timg) 169 | self.global_train_idx += 1 170 | 171 | self.lock.release() 172 | 173 | 174 | 175 | 176 | def extract_last_Col(self,row): 177 | # Keep the overlapping non integer final row/column images as validation images as well 178 | rev_col = self.img.shape[2] - self.Filter 179 | timg = self.img[:, row*self.stride:row*self.stride+self.Filter, rev_col:].copy() 180 | 181 | tmask_label = self.label[row*self.stride:row*self.stride+self.Filter, rev_col:].copy().astype(np.uint8) 182 | 183 | tlabels_all = self._2D21H(tmask_label) 184 | 185 | timg = np.concatenate((timg,tlabels_all),axis=0).astype(np.uint8) 186 | timg = pickle.dumps(timg) 187 | 188 | self.lock.acquire() 189 | self.record_valid.write_idx(self.global_valid_idx,timg) 190 | self.global_valid_idx += 1 191 | self.lock.release() 192 | 193 | 194 | def extract_last_Row(self,col): 195 | # Keep the overlapping non integer final row/column images as validation images as well 196 | rev_row = self.img.shape[1] - self.Filter 197 | 198 | timg = self.img[ :, rev_row:, col*self.stride:col*self.stride+self.Filter].copy() 199 | tmask_label = self.label[rev_row:, col*self.stride:col*self.stride+self.Filter].copy().astype(np.uint8) 200 | 201 | tlabels_all = self._2D21H(tmask_label) 202 | 203 | timg = np.concatenate((timg,tlabels_all),axis=0).astype(np.uint8) 204 | timg = pickle.dumps(timg) 205 | 206 | self.lock.acquire() 207 | self.record_valid.write_idx(self.global_valid_idx,timg) 208 | self.global_valid_idx += 1 209 | self.lock.release() 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /demo/Demo_forward_backward.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Populating the interactive namespace from numpy and matplotlib\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "%pylab inline" 18 | ] 19 | }, 20 | { 
21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import sys\n", 27 | "sys.path.append('/Your/Location/To/CEECNetRepo/')" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stderr", 37 | "output_type": "stream", 38 | "text": [ 39 | "/usr/local/lib/python3.7/site-packages/joblib/_multiprocessing_helpers.py:45: UserWarning: [Errno 28] No space left on device. joblib will operate in serial mode\n", 40 | " warnings.warn('%s. joblib will operate in serial mode' % (e,))\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "from mxnet import nd \n", 46 | "from ceecnet.models.changedetection.mantis.mantis_dn import *" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "depth:= 0, nfilters: 32, nheads::4, widths::1\n", 59 | "depth:= 1, nfilters: 64, nheads::8, widths::1\n", 60 | "depth:= 2, nfilters: 128, nheads::16, widths::1\n", 61 | "depth:= 3, nfilters: 256, nheads::32, widths::1\n", 62 | "depth:= 4, nfilters: 512, nheads::64, widths::1\n", 63 | "depth:= 5, nfilters: 1024, nheads::128, widths::1\n", 64 | "depth:= 6, nfilters: 512, nheads::128, widths::1\n", 65 | "depth:= 7, nfilters: 256, nheads::64, widths::1\n", 66 | "depth:= 8, nfilters: 128, nheads::32, widths::1\n", 67 | "depth:= 9, nfilters: 64, nheads::16, widths::1\n", 68 | "depth:= 10, nfilters: 32, nheads::8, widths::1\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "# D6nf32 example \n", 74 | "depth=6\n", 75 | "norm_type='GroupNorm'\n", 76 | "norm_groups=4\n", 77 | "ftdepth=5\n", 78 | "NClasses=2\n", 79 | "nfilters_init=32\n", 80 | "psp_depth=4\n", 81 | "nheads_start=4\n", 82 | "\n", 83 | "\n", 84 | "net = mantis_dn_cmtsk(nfilters_init=nfilters_init, NClasses=NClasses,depth=depth, ftdepth=ftdepth, model='CEECNetV1',psp_depth=psp_depth,norm_type=norm_type,norm_groups=norm_groups,nheads_start=nheads_start)\n", 85 | "net.initialize()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "BatchSize = 4\n", 95 | "img_size=256\n", 96 | "NChannels = 3\n", 97 | "\n", 98 | "input_img_1 = nd.random.uniform(shape=[BatchSize, NChannels, img_size, img_size])\n", 99 | "input_img_2 = nd.random.uniform(shape=[BatchSize, NChannels, img_size, img_size])" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## Example of forward operation:\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 6, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "outs = net(input_img_1, input_img_2)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "(4, 2, 256, 256)\n", 128 | "(4, 2, 256, 256)\n", 129 | "(4, 2, 256, 256)\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "# outs is a list of outputs, segmentation, boundary, distance. 
\n", 135 | "# Each has shape BatchSize, NClasses, img_size, img_size\n", 136 | "for out in outs:\n", 137 | " print (out.shape)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "### Example of performing backward with multitasking operation" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 8, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "labels_segm = nd.random.uniform(shape=[BatchSize, NClasses, img_size, img_size])\n", 154 | "labels_segm = labels_segm > 0.5\n", 155 | "\n", 156 | "labels_bound = nd.random.uniform(shape=[BatchSize, NClasses, img_size, img_size])\n", 157 | "labels_bound = labels_bound > 0.5\n", 158 | "\n", 159 | "labels_dist = nd.random.uniform(shape=[BatchSize, NClasses, img_size, img_size])\n", 160 | "\n", 161 | "\n", 162 | "labels = [labels_segm,labels_bound,labels_dist]\n", 163 | "labels = nd.concat(*labels,dim=1)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 9, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "from mxnet import autograd\n", 173 | "from ceecnet.nn.loss.mtsk_loss import *" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 10, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "myMTSKL = mtsk_loss()" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 12, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "with autograd.record():\n", 192 | " listOfPreds = net(input_img_1, input_img_2)\n", 193 | " loss = myMTSKL.loss(listOfPreds,labels)\n", 194 | " loss.backward()" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 13, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "(4,)" 206 | ] 207 | }, 208 | "execution_count": 13, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "loss.shape" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 28, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "text/plain": [ 225 | "\n", 226 | "[0.50219935 0.5020496 0.5023406 0.5021815 ]\n", 227 | "" 228 | ] 229 | }, 230 | "execution_count": 28, 231 | "metadata": {}, 232 | "output_type": "execute_result" 233 | } 234 | ], 235 | "source": [ 236 | "loss" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "Python 3", 250 | "language": "python", 251 | "name": "python3" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.7.6" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 4 268 | } 269 | -------------------------------------------------------------------------------- /doc/csse_talk_diakogiannis.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/doc/csse_talk_diakogiannis.pdf -------------------------------------------------------------------------------- /doc/manuscript.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/doc/manuscript.pdf -------------------------------------------------------------------------------- /images/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_1.png -------------------------------------------------------------------------------- /images/img_101.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_101.png -------------------------------------------------------------------------------- /images/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_2.png -------------------------------------------------------------------------------- /images/img_21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_21.png -------------------------------------------------------------------------------- /images/img_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_3.png -------------------------------------------------------------------------------- /images/img_71.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_71.png -------------------------------------------------------------------------------- /images/img_72.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_72.png -------------------------------------------------------------------------------- /images/img_77.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_77.png -------------------------------------------------------------------------------- /images/img_80.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/img_80.png -------------------------------------------------------------------------------- /images/mantis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feevos/ceecnet/9dc76f8cd16d44b264cae8c5846eefb8fcf6162d/images/mantis.jpg -------------------------------------------------------------------------------- /models/changedetection/mantis/mantis_dn.py: -------------------------------------------------------------------------------- 1 | from ceecnet.models.heads.head_cmtsk import * 2 | from ceecnet.models.changedetection.mantis.mantis_dn_features import * 3 | 4 | 5 | # Mantis conditioned multitasking. 
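# The network consumes two co-registered inputs (image date 1 and image date 2),
# extracts and fuses their features with mantis_dn_features, and returns the
# three per-pixel outputs of Head_CMTSK_BC: segmentation, boundary and
# distance-transform predictions.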
6 | class mantis_dn_cmtsk(HybridBlock): 7 | def __init__(self, nfilters_init, depth, NClasses,widths=[1], psp_depth=4,verbose=True, norm_type='BatchNorm', norm_groups=None,nheads_start=8, model='CEECNetV1', upFuse=False, ftdepth=5,**kwards): 8 | super().__init__(**kwards) 9 | 10 | with self.name_scope(): 11 | 12 | self.features = mantis_dn_features(nfilters_init=nfilters_init, depth=depth, widths=widths, psp_depth=psp_depth, verbose=verbose, norm_type=norm_type, norm_groups=norm_groups, nheads_start=nheads_start, model=model, upFuse=upFuse, ftdepth=ftdepth, **kwards) 13 | self.head = Head_CMTSK_BC(nfilters_init,NClasses, norm_type=norm_type, norm_groups=norm_groups, **kwards) 14 | 15 | def hybrid_forward(self,F,input_t1, input_t2): 16 | out1, out2= self.features(input_t1,input_t2) 17 | 18 | return self.head(out1,out2) 19 | 20 | -------------------------------------------------------------------------------- /models/changedetection/mantis/mantis_dn_features.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | 4 | from ceecnet.nn.layers.conv2Dnormed import * 5 | from ceecnet.nn.layers.attention import * 6 | from ceecnet.nn.pooling.psp_pooling import * 7 | 8 | 9 | from ceecnet.nn.layers.scale import * 10 | from ceecnet.nn.layers.combine import * 11 | 12 | # CEEC units 13 | from ceecnet.nn.units.ceecnet import * 14 | 15 | # FracTALResUnit 16 | from ceecnet.nn.units.fractal_resnet import * 17 | 18 | """ 19 | if upFuse == True, then instead of concatenation of the encoder features with the decoder features, the algorithm performs Fusion with 20 | relative attention. 21 | """ 22 | 23 | 24 | class mantis_dn_features(HybridBlock): 25 | def __init__(self, nfilters_init, depth, widths=[1], psp_depth=4, verbose=True, norm_type='BatchNorm', norm_groups=None, nheads_start=8, model='CEECNetV1', upFuse=False, ftdepth=5, **kwards): 26 | super().__init__(**kwards) 27 | 28 | 29 | self.depth = depth 30 | 31 | 32 | if len(widths) == 1 and depth != 1: 33 | widths = widths * depth 34 | else: 35 | assert depth == len(widths), ValueError("depth and length of widths must match, aborting ...") 36 | 37 | with self.name_scope(): 38 | 39 | self.conv_first = Conv2DNormed(nfilters_init,kernel_size=(1,1), _norm_type = norm_type, norm_groups=norm_groups) 40 | self.fuse_first = Fusion(nfilters_init, norm=norm_type, norm_groups=norm_groups) 41 | 42 | # List of convolutions and pooling operators 43 | self.convs_dn = gluon.nn.HybridSequential() 44 | self.pools = gluon.nn.HybridSequential() 45 | self.fuse = gluon.nn.HybridSequential() 46 | 47 | 48 | for idx in range(depth): 49 | nheads = nheads_start * 2**idx # 50 | nfilters = nfilters_init * 2 **idx 51 | if verbose: 52 | print ("depth:= {0}, nfilters: {1}, nheads::{2}, widths::{3}".format(idx,nfilters,nheads,widths[idx])) 53 | tnet = gluon.nn.HybridSequential() 54 | for _ in range(widths[idx]): 55 | if model == 'CEECNetV1': 56 | tnet.add(CEEC_unit_v1(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 57 | elif model == 'CEECNetV2': 58 | tnet.add(CEEC_unit_v2(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 59 | elif model == 'FracTALResNet': 60 | tnet.add(FracTALResNet_unit(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 61 | else: 62 | raise ValueError("I don't know 
requested model, aborting ... - Given model::{}".format(model)) 63 | self.convs_dn.add(tnet) 64 | 65 | if idx < depth-1: 66 | self.fuse.add( Fusion( nfilters=nfilters, nheads = nheads , norm = norm_type, norm_groups=norm_groups) ) 67 | self.pools.add(DownSample(nfilters, _norm_type=norm_type, norm_groups=norm_groups)) 68 | # Middle pooling operator 69 | self.middle = PSP_Pooling(nfilters,depth=psp_depth, _norm_type=norm_type,norm_groups=norm_groups) 70 | 71 | 72 | self.convs_up = gluon.nn.HybridSequential() # 1 argument 73 | self.UpCombs = gluon.nn.HybridSequential() # 2 arguments 74 | for idx in range(depth-1,0,-1): 75 | nheads = nheads_start * 2**idx 76 | nfilters = nfilters_init * 2 **(idx-1) 77 | if verbose: 78 | print ("depth:= {0}, nfilters: {1}, nheads::{2}, widths::{3}".format(2*depth-idx-1,nfilters,nheads,widths[idx])) 79 | 80 | tnet = gluon.nn.HybridSequential() 81 | for _ in range(widths[idx]): 82 | if model == 'CEECNetV1': 83 | tnet.add(CEEC_unit_v1(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 84 | elif model == 'CEECNetV2': 85 | tnet.add(CEEC_unit_v2(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 86 | elif model == 'FracTALResNet': 87 | tnet.add(FracTALResNet_unit(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 88 | else: 89 | raise ValueError("I don't know requested model, aborting ... - Given model::{}".format(model)) 90 | self.convs_up.add(tnet) 91 | 92 | if upFuse==True: 93 | self.UpCombs.add(combine_layers_wthFusion(nfilters=nfilters, nheads=nheads, _norm_type=norm_type,norm_groups=norm_groups,ftdepth=ftdepth)) 94 | else: 95 | self.UpCombs.add(combine_layers(nfilters, _norm_type=norm_type,norm_groups=norm_groups)) 96 | 97 | def hybrid_forward(self, F, input_t1, input_t2): 98 | 99 | conv1_t1 = self.conv_first(input_t1) 100 | conv1_t2 = self.conv_first(input_t2) 101 | 102 | fuse1 = self.fuse_first(conv1_t1,conv1_t2) 103 | 104 | # ******** Going down *************** 105 | fusions = [] 106 | 107 | # Workaround of a mxnet bug 108 | # https://github.com/apache/incubator-mxnet/issues/16736 109 | pools1 = F.identity(conv1_t1) 110 | pools2 = F.identity(conv1_t2) 111 | 112 | for idx in range(self.depth): 113 | conv1 = self.convs_dn[idx](pools1) 114 | conv2 = self.convs_dn[idx](pools2) 115 | 116 | 117 | if idx < self.depth-1: 118 | # Evaluate fusions 119 | conv1 = F.identity(conv1) 120 | conv2 = F.identity(conv2) 121 | fusions = fusions + [self.fuse[idx](conv1,conv2)] 122 | # Evaluate pools 123 | pools1 = self.pools[idx](conv1) 124 | pools2 = self.pools[idx](conv2) 125 | 126 | # Middle psppooling 127 | middle = self.middle(F.concat(conv1,conv2, dim=1)) 128 | # Activation of middle layer 129 | middle = F.relu(middle) 130 | fusions = fusions + [middle] 131 | 132 | # ******* Coming up **************** 133 | convs_up = middle 134 | for idx in range(self.depth-1): 135 | convs_up = self.UpCombs[idx](convs_up, fusions[-idx-2]) 136 | convs_up = self.convs_up[idx](convs_up) 137 | 138 | return convs_up, fuse1 139 | 140 | 141 | -------------------------------------------------------------------------------- /models/heads/head_cmtsk.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | 4 | 5 | from ceecnet.nn.activations.sigmoid_crisp import * 6 | from 
ceecnet.nn.pooling.psp_pooling import * 7 | from ceecnet.nn.layers.conv2Dnormed import * 8 | 9 | # Helper classification head, for a single layer output 10 | class HeadSingle(HybridBlock): 11 | def __init__(self, nfilters, NClasses, depth=2, norm_type='BatchNorm',norm_groups=None, **kwargs): 12 | super().__init__(**kwargs) 13 | 14 | 15 | with self.name_scope(): 16 | self.logits = gluon.nn.HybridSequential() 17 | for _ in range(depth): 18 | self.logits.add( Conv2DNormed(channels = nfilters,kernel_size = (3,3),padding=(1,1), _norm_type=norm_type, norm_groups=norm_groups)) 19 | self.logits.add( gluon.nn.Activation('relu')) 20 | self.logits.add( gluon.nn.Conv2D(NClasses,kernel_size=1,padding=0)) 21 | 22 | def hybrid_forward(self,F,input): 23 | return self.logits(input) 24 | 25 | 26 | 27 | class Head_CMTSK_BC(HybridBlock): 28 | # BC: Balanced (features) Crisp (boundaries) 29 | def __init__(self, _nfilters_init, _NClasses, norm_type = 'BatchNorm', norm_groups=None, **kwards): 30 | super().__init__() 31 | 32 | self.model_name = "Head_CMTSK_BC" 33 | 34 | self.nfilters = _nfilters_init # Initial number of filters 35 | self.NClasses = _NClasses 36 | 37 | 38 | with self.name_scope(): 39 | 40 | 41 | self.psp_2ndlast = PSP_Pooling(self.nfilters, _norm_type = norm_type, norm_groups=norm_groups) 42 | 43 | # bound logits 44 | self.bound_logits = HeadSingle(self.nfilters, self.NClasses, norm_type = norm_type, norm_groups=norm_groups) 45 | self.bound_Equalizer = Conv2DNormed(channels = self.nfilters,kernel_size =1, _norm_type=norm_type, norm_groups=norm_groups) 46 | 47 | # distance logits -- deeper for better reconstruction 48 | self.distance_logits = HeadSingle(self.nfilters, self.NClasses, norm_type = norm_type, norm_groups=norm_groups) 49 | self.dist_Equalizer = Conv2DNormed(channels = self.nfilters,kernel_size =1, _norm_type=norm_type, norm_groups=norm_groups) 50 | 51 | 52 | self.Comb_bound_dist = Conv2DNormed(channels = self.nfilters,kernel_size =1, _norm_type=norm_type, norm_groups=norm_groups) 53 | 54 | 55 | # Segmenetation logits -- deeper for better reconstruction 56 | self.final_segm_logits = HeadSingle(self.nfilters, self.NClasses, norm_type = norm_type, norm_groups=norm_groups) 57 | 58 | 59 | 60 | self.CrispSigm = SigmoidCrisp() 61 | 62 | # Last activation, customization for binary results 63 | if ( self.NClasses == 1): 64 | self.ChannelAct = gluon.nn.HybridLambda(lambda F,x: F.sigmoid(x)) 65 | else: 66 | self.ChannelAct = gluon.nn.HybridLambda(lambda F,x: F.softmax(x,axis=1)) 67 | 68 | def hybrid_forward(self,F, UpConv4, conv1): 69 | 70 | 71 | # second last layer 72 | convl = F.concat(conv1,UpConv4) 73 | conv = self.psp_2ndlast(convl) 74 | conv = F.relu(conv) 75 | 76 | 77 | # logits 78 | 79 | # 1st find distance map, skeleton like, topology info 80 | dist = self.distance_logits(convl) # do not use max pooling for distance 81 | dist = self.ChannelAct(dist) 82 | distEq = F.relu(self.dist_Equalizer(dist)) # makes nfilters equals to conv and convl 83 | 84 | 85 | # Then find boundaries 86 | bound = F.concat(conv, distEq) 87 | bound = self.bound_logits(bound) 88 | bound = self.CrispSigm(bound) # Boundaries are not mutually exclusive 89 | boundEq = F.relu(self.bound_Equalizer(bound)) 90 | 91 | 92 | # Now combine all predictions in a final segmentation mask 93 | # Balance first boundary and distance transform, with the features 94 | comb_bd = self.Comb_bound_dist(F.concat(boundEq, distEq,dim=1)) 95 | comb_bd = F.relu(comb_bd) 96 | 97 | all_layers = F.concat(comb_bd, conv) 98 | final_segm = 
self.final_segm_logits(all_layers) 99 | final_segm = self.ChannelAct(final_segm) 100 | 101 | 102 | return final_segm, bound, dist 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /models/semanticsegmentation/x_unet/x_dn.py: -------------------------------------------------------------------------------- 1 | from ceecnet.models.heads.head_cmtsk import * 2 | from ceecnet.models.semanticsegmentation.x_unet.x_dn_features import * 3 | 4 | 5 | # Mantis conditioned multitasking. 6 | class X_dn_cmtsk(HybridBlock): 7 | def __init__(self, nfilters_init, depth, NClasses,widths=[1], psp_depth=4,verbose=True, norm_type='BatchNorm', norm_groups=None,nheads_start=8, model='CEECNetV1', upFuse=False, ftdepth=5,**kwards): 8 | super().__init__(**kwards) 9 | 10 | with self.name_scope(): 11 | 12 | self.features = X_dn_features(nfilters_init=nfilters_init, depth=depth, widths=widths, psp_depth=psp_depth, verbose=verbose, norm_type=norm_type, norm_groups=norm_groups, nheads_start=nheads_start, model=model, upFuse=upFuse, ftdepth=ftdepth, **kwards) 13 | self.head = Head_CMTSK_BC(nfilters_init,NClasses, norm_type=norm_type, norm_groups=norm_groups, **kwards) 14 | 15 | def hybrid_forward(self,F,input): 16 | out1, out2= self.features(input) 17 | 18 | return self.head(out1,out2) 19 | 20 | -------------------------------------------------------------------------------- /models/semanticsegmentation/x_unet/x_dn_features.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | 4 | from ceecnet.nn.layers.conv2Dnormed import * 5 | from ceecnet.nn.layers.attention import * 6 | from ceecnet.nn.pooling.psp_pooling import * 7 | 8 | 9 | from ceecnet.nn.layers.scale import * 10 | from ceecnet.nn.layers.combine import * 11 | 12 | # CEEC units 13 | from ceecnet.nn.units.ceecnet import * 14 | 15 | # FracTALResUnit 16 | from ceecnet.nn.units.fractal_resnet import * 17 | 18 | """ 19 | if upFuse == True, then instead of concatenation of the encoder features with the decoder features, the algorithm performs Fusion with 20 | relative attention. 
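The Fusion operation (defined in nn/units/ceecnet.py) computes this relative attention
with the fractal Tanimoto attention layers of nn/layers/attention.py.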
21 | """ 22 | 23 | 24 | class X_dn_features(HybridBlock): 25 | def __init__(self, nfilters_init, depth, widths=[1], psp_depth=4, verbose=True, norm_type='BatchNorm', norm_groups=None, nheads_start=8, model='CEECNetV1', upFuse=False, ftdepth=5, **kwards): 26 | super().__init__(**kwards) 27 | 28 | 29 | self.depth = depth 30 | 31 | 32 | if len(widths) == 1 and depth != 1: 33 | widths = widths * depth 34 | else: 35 | assert depth == len(widths), ValueError("depth and length of widths must match, aborting ...") 36 | 37 | with self.name_scope(): 38 | 39 | self.conv_first = Conv2DNormed(nfilters_init,kernel_size=(1,1), _norm_type = norm_type, norm_groups=norm_groups) 40 | 41 | # List of convolutions and pooling operators 42 | self.convs_dn = gluon.nn.HybridSequential() 43 | self.pools = gluon.nn.HybridSequential() 44 | 45 | 46 | for idx in range(depth): 47 | nheads = nheads_start * 2**idx # 48 | nfilters = nfilters_init * 2 **idx 49 | if verbose: 50 | print ("depth:= {0}, nfilters: {1}, nheads::{2}, widths::{3}".format(idx,nfilters,nheads,widths[idx])) 51 | tnet = gluon.nn.HybridSequential() 52 | for _ in range(widths[idx]): 53 | if model == 'CEECNetV1': 54 | tnet.add(CEEC_unit_v1(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 55 | elif model == 'CEECNetV2': 56 | tnet.add(CEEC_unit_v2(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 57 | elif model == 'FracTALResNet': 58 | tnet.add(FracTALResNet_unit(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 59 | else: 60 | raise ValueError("I don't know requested model, available options: CEECNetV1, CEECNetV2, FracTALResNet - Given model::{}, aborting ...".format(model)) 61 | self.convs_dn.add(tnet) 62 | 63 | if idx < depth-1: 64 | self.pools.add(DownSample(nfilters, _norm_type=norm_type, norm_groups=norm_groups)) 65 | # Middle pooling operator 66 | self.middle = PSP_Pooling(nfilters,depth=psp_depth, _norm_type=norm_type,norm_groups=norm_groups) 67 | 68 | 69 | self.convs_up = gluon.nn.HybridSequential() # 1 argument 70 | self.UpCombs = gluon.nn.HybridSequential() # 2 arguments 71 | for idx in range(depth-1,0,-1): 72 | nheads = nheads_start * 2**idx 73 | nfilters = nfilters_init * 2 **(idx-1) 74 | if verbose: 75 | print ("depth:= {0}, nfilters: {1}, nheads::{2}, widths::{3}".format(2*depth-idx-1,nfilters,nheads,widths[idx])) 76 | 77 | tnet = gluon.nn.HybridSequential() 78 | for _ in range(widths[idx]): 79 | if model == 'CEECNetV1': 80 | tnet.add(CEEC_unit_v1(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 81 | elif model == 'CEECNetV2': 82 | tnet.add(CEEC_unit_v2(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 83 | elif model == 'FracTALResNet': 84 | tnet.add(FracTALResNet_unit(nfilters=nfilters, nheads = nheads, ngroups = nheads , norm_type = norm_type, norm_groups=norm_groups,ftdepth=ftdepth)) 85 | else: 86 | raise ValueError("I don't know requested model, available options: CEECNetV1, CEECNetV2, FracTALResNet - Given model::{}, aborting ...".format(model)) 87 | self.convs_up.add(tnet) 88 | 89 | if upFuse==True: 90 | self.UpCombs.add(combine_layers_wthFusion(nfilters=nfilters, nheads=nheads, _norm_type=norm_type,norm_groups=norm_groups,ftdepth=ftdepth)) 91 | else: 92 | 
self.UpCombs.add(combine_layers(nfilters, _norm_type=norm_type,norm_groups=norm_groups)) 93 | 94 | def hybrid_forward(self, F, input): 95 | 96 | conv1_first = self.conv_first(input) 97 | 98 | 99 | # ******** Going down *************** 100 | fusions = [] 101 | 102 | # Workaround of a mxnet bug 103 | # https://github.com/apache/incubator-mxnet/issues/16736 104 | pools = F.identity(conv1_first) 105 | 106 | for idx in range(self.depth): 107 | conv1 = self.convs_dn[idx](pools) 108 | if idx < self.depth-1: 109 | # Evaluate fusions 110 | conv1 = F.identity(conv1) 111 | fusions = fusions + [conv1] 112 | # Evaluate pools 113 | pools = self.pools[idx](conv1) 114 | 115 | # Middle psppooling 116 | middle = self.middle(conv1) 117 | # Activation of middle layer 118 | middle = F.relu(middle) 119 | fusions = fusions + [middle] 120 | 121 | # ******* Coming up **************** 122 | convs_up = middle 123 | for idx in range(self.depth-1): 124 | convs_up = self.UpCombs[idx](convs_up, fusions[-idx-2]) 125 | convs_up = self.convs_up[idx](convs_up) 126 | 127 | return convs_up, conv1_first 128 | 129 | 130 | -------------------------------------------------------------------------------- /nn/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nn/activations/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nn/activations/sigmoid_crisp.py: -------------------------------------------------------------------------------- 1 | from mxnet.gluon import HybridBlock 2 | import mxnet as mx 3 | 4 | 5 | class SigmoidCrisp(HybridBlock): 6 | def __init__(self, smooth=1.e-2,**kwards): 7 | super().__init__(**kwards) 8 | 9 | 10 | self.smooth = smooth 11 | with self.name_scope(): 12 | self.gamma = self.params.get('gamma', shape=(1,), init=mx.init.One()) 13 | 14 | 15 | def hybrid_forward(self, F, input, gamma): 16 | out = self.smooth + F.sigmoid(gamma) 17 | out = F.reciprocal(out) 18 | 19 | out = F.broadcast_mul(input,out) 20 | out = F.sigmoid(out) 21 | return out 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /nn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nn/layers/attention.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | from ceecnet.nn.layers.conv2Dnormed import * 4 | from ceecnet.nn.layers.ftnmt import * 5 | 6 | 7 | 8 | class RelFTAttention2D(HybridBlock): 9 | def __init__(self, nkeys, kernel_size=3, padding=1,nheads=1, norm = 'BatchNorm', norm_groups=None,ftdepth=5,**kwards): 10 | super().__init__(**kwards) 11 | 12 | with self.name_scope(): 13 | 14 | self.query = Conv2DNormed(channels=nkeys,kernel_size= kernel_size, padding = padding, _norm_type= norm, norm_groups=norm_groups, groups=nheads) 15 | self.key = Conv2DNormed(channels=nkeys,kernel_size= kernel_size, padding = padding, _norm_type= norm, norm_groups=norm_groups, groups=nheads) 16 | self.value = Conv2DNormed(channels=nkeys,kernel_size= kernel_size, padding = padding, _norm_type= norm, norm_groups=norm_groups, groups=nheads) 17 | 18 | 19 | self.metric_channel = FTanimoto(depth=ftdepth, axis=[2,3]) 20 | 
self.metric_space = FTanimoto(depth=ftdepth, axis=1)
21 | 
22 |             self.norm = get_norm(name=norm, axis=1, norm_groups= norm_groups)
23 | 
24 |     def hybrid_forward(self, F, input1, input2, input3):
25 | 
26 |         # These should work with ReLU as well
27 |         q = F.sigmoid(self.query(input1))
28 |         k = F.sigmoid(self.key(input2))  # B,C,H,W
29 |         v = F.sigmoid(self.value(input3))  # B,C,H,W
30 | 
31 |         att_spat = self.metric_space(q,k)  # B,1,H,W
32 |         v_spat = F.broadcast_mul(att_spat, v)  # emphasize spatial features
33 | 
34 |         att_chan = self.metric_channel(q,k)  # B,C,1,1
35 |         v_chan = F.broadcast_mul(att_chan, v)  # emphasize channel features
36 | 
37 | 
38 |         v_cspat = 0.5*F.broadcast_add(v_chan, v_spat)  # combine spatial and channel attention
39 |         v_cspat = self.norm(v_cspat)
40 | 
41 |         return v_cspat
42 | 
43 | 
44 | 
45 | class FTAttention2D(HybridBlock):
46 |     def __init__(self, nkeys, kernel_size=3, padding=1, nheads=1, norm = 'BatchNorm', norm_groups=None,ftdepth=5,**kwards):
47 |         super().__init__(**kwards)
48 | 
49 |         with self.name_scope():
50 |             self.att = RelFTAttention2D(nkeys=nkeys,kernel_size=kernel_size, padding=padding, nheads=nheads, norm = norm, norm_groups=norm_groups, ftdepth=ftdepth,**kwards)
51 | 
52 | 
53 |     def hybrid_forward(self, F, input):
54 |         return self.att(input,input,input)
55 | 
56 | 
57 | 
58 | 
59 | 
60 | 
61 | 
62 | 
63 | 
64 | 
65 | 
--------------------------------------------------------------------------------
/nn/layers/combine.py:
--------------------------------------------------------------------------------
1 | from mxnet import gluon
2 | from mxnet.gluon import HybridBlock
3 | 
4 | from ceecnet.nn.layers.scale import *
5 | from ceecnet.nn.layers.conv2Dnormed import *
6 | 
7 | 
8 | """
9 | For combining layers with Fusion (i.e. relative attention), see ../units/ceecnet.py
10 | """
11 | 
12 | 
13 | class combine_layers(HybridBlock):
14 |     def __init__(self,_nfilters, _norm_type = 'BatchNorm', norm_groups=None, **kwards):
15 |         HybridBlock.__init__(self,**kwards)
16 | 
17 |         with self.name_scope():
18 | 
19 |             # This performs convolution, no BatchNormalization. No need for bias.
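            # UpSample (nn/layers/scale.py) nearest-neighbour upscales by a factor
            # of 2 and halves the number of filters; combine_layers then concatenates
            # the result with the higher-resolution features and projects back to
            # _nfilters with a normed 1x1 convolution.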
20 | self.up = UpSample(_nfilters, _norm_type = _norm_type, norm_groups=norm_groups) 21 | 22 | self.conv_normed = Conv2DNormed(channels = _nfilters, 23 | kernel_size=(1,1), 24 | padding=(0,0), 25 | _norm_type=_norm_type, 26 | norm_groups=norm_groups) 27 | 28 | 29 | 30 | 31 | def hybrid_forward(self,F,_layer_lo, _layer_hi): 32 | 33 | up = self.up(_layer_lo) 34 | up = F.relu(up) 35 | x = F.concat(up,_layer_hi, dim=1) 36 | x = self.conv_normed(x) 37 | 38 | return x 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /nn/layers/conv2Dnormed.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | from mxnet import gluon 3 | from mxnet.gluon import HybridBlock 4 | from ceecnet.utils.get_norm import * 5 | 6 | 7 | class Conv2DNormed(HybridBlock): 8 | """ 9 | Convenience wrapper layer for 2D convolution followed by a normalization layer 10 | All other keywords are the same as gluon.nn.Conv2D 11 | """ 12 | 13 | def __init__(self, channels, kernel_size, strides=(1, 1), 14 | padding=(0, 0), dilation=(1, 1), activation=None, 15 | weight_initializer=None, in_channels=0, _norm_type = 'BatchNorm', norm_groups=None, axis =1 , groups=1, **kwards): 16 | super().__init__(**kwards) 17 | 18 | with self.name_scope(): 19 | self.conv2d = gluon.nn.Conv2D(channels, kernel_size = kernel_size, 20 | strides= strides, 21 | padding=padding, 22 | dilation= dilation, 23 | activation=activation, 24 | use_bias=False, 25 | weight_initializer = weight_initializer, 26 | groups=groups, 27 | in_channels=0) 28 | 29 | self.norm_layer = get_norm(_norm_type, axis=axis, norm_groups= norm_groups) 30 | 31 | def hybrid_forward(self,F,_x): 32 | 33 | x = self.conv2d(_x) 34 | x = self.norm_layer(x) 35 | 36 | return x 37 | 38 | -------------------------------------------------------------------------------- /nn/layers/ftnmt.py: -------------------------------------------------------------------------------- 1 | from mxnet.gluon import HybridBlock 2 | 3 | 4 | class FTanimoto(HybridBlock): 5 | """ 6 | This is the average fractal Tanimoto set similarity with complement. 7 | """ 8 | def __init__(self, depth=5, smooth=1.0e-5, axis=[2,3],**kwards): 9 | super().__init__(**kwards) 10 | 11 | assert depth >= 0, "Expecting depth >= 0, aborting ..." 12 | 13 | if depth == 0: 14 | self.depth = 1 15 | self.scale = 1. 16 | else: 17 | self.depth = depth 18 | self.scale = 1./depth 19 | 20 | self.smooth = smooth 21 | self.axis=axis 22 | 23 | def inner_prod(self, F, prob, label): 24 | prod = F.broadcast_mul(prob,label) 25 | prod = F.sum(prod,axis=self.axis,keepdims=True) 26 | 27 | return prod 28 | 29 | 30 | 31 | def tnmt_base(self, F, preds, labels): 32 | 33 | tpl = self.inner_prod(F,preds,labels) 34 | tpp = self.inner_prod(F,preds,preds) 35 | tll = self.inner_prod(F,labels,labels) 36 | 37 | 38 | num = tpl + self.smooth 39 | denum = 0.0 40 | 41 | 42 | for d in range(self.depth): 43 | a = 2.**d 44 | b = -(2.*a-1.) 
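            # Each term below adds 1 / (2^d * (<p,p> + <l,l>) - (2^(d+1) - 1) * <p,l> + smooth),
            # the reciprocal denominator of the fractal Tanimoto similarity T_d;
            # multiplying by num and scale afterwards averages T_d over d = 0..depth-1.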
45 | 46 | denum = denum + F.reciprocal(F.broadcast_add(a*(tpp+tll), b *tpl) + self.smooth) 47 | 48 | return F.broadcast_mul(num,denum)*self.scale 49 | 50 | def hybrid_forward(self, F, preds, labels): 51 | l12 = self.tnmt_base(F,preds,labels) 52 | l12 = l12 + self.tnmt_base(F,1.-preds, 1.-labels) 53 | 54 | return 0.5*l12 55 | -------------------------------------------------------------------------------- /nn/layers/scale.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | 4 | from ceecnet.nn.layers.conv2Dnormed import * 5 | from ceecnet.utils.get_norm import * 6 | 7 | class DownSample(HybridBlock): 8 | def __init__(self, nfilters, factor=2, _norm_type='BatchNorm', norm_groups=None, **kwargs): 9 | super().__init__(**kwargs) 10 | 11 | 12 | # Double the size of filters, since you downscale by 2. 13 | self.factor = factor 14 | self.nfilters = nfilters * self.factor 15 | 16 | self.kernel_size = (3,3) 17 | self.strides = (factor,factor) 18 | self.pad = (1,1) 19 | 20 | with self.name_scope(): 21 | self.convdn = Conv2DNormed(self.nfilters, 22 | kernel_size=self.kernel_size, 23 | strides=self.strides, 24 | padding=self.pad, 25 | _norm_type = _norm_type, 26 | norm_groups=norm_groups) 27 | 28 | 29 | def hybrid_forward(self,F,_xl): 30 | 31 | x = self.convdn(_xl) 32 | 33 | return x 34 | 35 | 36 | class UpSample(HybridBlock): 37 | def __init__(self,nfilters, factor = 2, _norm_type='BatchNorm', norm_groups=None, **kwards): 38 | HybridBlock.__init__(self,**kwards) 39 | 40 | 41 | self.factor = factor 42 | self.nfilters = nfilters // self.factor 43 | 44 | with self.name_scope(): 45 | self.convup_normed = Conv2DNormed(self.nfilters, 46 | kernel_size = (1,1), 47 | _norm_type = _norm_type, 48 | norm_groups=norm_groups) 49 | 50 | def hybrid_forward(self,F,_xl): 51 | x = F.UpSampling(_xl, scale=self.factor, sample_type='nearest') 52 | x = self.convup_normed(x) 53 | 54 | return x 55 | 56 | -------------------------------------------------------------------------------- /nn/loss/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nn/loss/ftnmt_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fractal Tanimoto (with dual) loss 3 | """ 4 | 5 | from mxnet.gluon.loss import Loss 6 | class ftnmt_loss(Loss): 7 | """ 8 | This function calculates the average fractal tanimoto similarity for d = 0...depth 9 | """ 10 | def __init__(self, depth=5, axis= [1,2,3], smooth = 1.0e-5, batch_axis=0, weight=None, **kwargs): 11 | super().__init__(batch_axis, weight, **kwargs) 12 | 13 | assert depth>= 0, ValueError("depth must be >= 0, aborting...") 14 | 15 | self.smooth = smooth 16 | self.axis=axis 17 | self.depth = depth 18 | 19 | if depth == 0: 20 | self.depth = 1 21 | self.scale = 1. 22 | else: 23 | self.depth = depth 24 | self.scale = 1./depth 25 | 26 | def inner_prod(self, F, prob, label): 27 | prod = F.broadcast_mul(prob,label) 28 | prod = F.sum(prod,axis=self.axis) 29 | 30 | return prod 31 | 32 | def tnmt_base(self, F, preds, labels): 33 | 34 | tpl = self.inner_prod(F,preds,labels) 35 | tpp = self.inner_prod(F,preds,preds) 36 | tll = self.inner_prod(F,labels,labels) 37 | 38 | 39 | num = tpl + self.smooth 40 | scale = 1./self.depth 41 | denum = 0.0 42 | for d in range(self.depth): 43 | a = 2.**d 44 | b = -(2.*a-1.) 
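            # Same fractal Tanimoto recursion as in nn/layers/ftnmt.py; here the
            # inner products reduce over axis=[1,2,3], giving one score per sample.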
45 | 46 | denum = denum + F.reciprocal(F.broadcast_add(a*(tpp+tll), b*tpl) + self.smooth) 47 | 48 | result = F.broadcast_mul(num,denum)*scale 49 | return F.mean(result, axis=0,exclude=True) 50 | 51 | 52 | def hybrid_forward(self,F, preds, labels): 53 | 54 | l1 = self.tnmt_base(F,preds,labels) 55 | l2 = self.tnmt_base(F,1.-preds, 1.-labels) 56 | 57 | result = 0.5*(l1+l2) 58 | 59 | return 1. - result 60 | 61 | -------------------------------------------------------------------------------- /nn/loss/mtsk_loss.py: -------------------------------------------------------------------------------- 1 | from ceecnet.nn.loss.ftnmt_loss import * 2 | 3 | class mtsk_loss(object): 4 | """ 5 | Here NClasses = 2 by default, for a binary segmentation problem in one-hot representation 6 | """ 7 | 8 | def __init__(self,depth=0, NClasses=2): 9 | 10 | self.ftnmt = ftnmt_loss(depth=depth) 11 | self.ftnmt.hybridize() 12 | 13 | self.skip = NClasses 14 | 15 | def loss(self,_prediction,_label): 16 | 17 | pred_segm = _prediction[0] 18 | pred_bound = _prediction[1] 19 | pred_dists = _prediction[2] 20 | 21 | # In our implementation of the labels, we stack together the [segmentation, boundary, distance] labels, 22 | # along the channel axis. 23 | label_segm = _label[:,:self.skip,:,:] 24 | label_bound = _label[:,self.skip:2*self.skip,:,:] 25 | label_dists = _label[:,2*self.skip:,:,:] 26 | 27 | 28 | loss_segm = self.ftnmt(pred_segm, label_segm) 29 | loss_bound = self.ftnmt(pred_bound, label_bound) 30 | loss_dists = self.ftnmt(pred_dists, label_dists) 31 | 32 | return (loss_segm+loss_bound+loss_dists)/3.0 33 | 34 | -------------------------------------------------------------------------------- /nn/pooling/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nn/pooling/psp_pooling.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | from ceecnet.nn.layers.conv2Dnormed import * 4 | 5 | class PSP_Pooling(gluon.HybridBlock): 6 | def __init__(self, nfilters, depth=4, _norm_type = 'BatchNorm', norm_groups=None, mob=False, **kwards): 7 | gluon.HybridBlock.__init__(self,**kwards) 8 | 9 | 10 | self.nfilters = nfilters 11 | self.depth = depth 12 | 13 | self.convs = gluon.nn.HybridSequential() 14 | with self.name_scope(): 15 | for _ in range(depth): 16 | self.convs.add(Conv2DNormed(self.nfilters,kernel_size=(1,1),padding=(0,0),_norm_type=_norm_type, norm_groups=norm_groups)) 17 | 18 | self.conv_norm_final = Conv2DNormed(channels = self.nfilters, 19 | kernel_size=(1,1), 20 | padding=(0,0), 21 | _norm_type=_norm_type, 22 | norm_groups=norm_groups) 23 | 24 | 25 | # ******** Utility functions to avoid calling infer_shape **************** 26 | def HalfSplit(self, F,_a): 27 | """ 28 | Returns a list of half-split arrays. Useful for HalfPooling 29 | """ 30 | b = F.split(_a,axis=2,num_outputs=2) # Split along height (axis=2) 31 | c1 = F.split(b[0],axis=3,num_outputs=2) # Split along width (axis=3) 32 | c2 = F.split(b[1],axis=3,num_outputs=2) # Split along width (axis=3) 33 | 34 | 35 | d11 = c1[0] 36 | d12 = c1[1] 37 | 38 | d21 = c2[0] 39 | d22 = c2[1] 40 | 41 | return [d11,d12,d21,d22] 42 | 43 | 44 | def QuarterStitch(self, F,_Dss): 45 | """ 46 | INPUT: 47 | A list of [d11,d12,d21,d22] block matrices.
48 | OUTPUT: 49 | A single matrix assembled from these submatrices 50 | """ 51 | 52 | temp1 = F.concat(_Dss[0],_Dss[1],dim=-1) 53 | temp2 = F.concat(_Dss[2],_Dss[3],dim=-1) 54 | result = F.concat(temp1,temp2,dim=2) 55 | 56 | return result 57 | 58 | 59 | def HalfPooling(self, F,_a): 60 | Ds = self.HalfSplit(F,_a) 61 | 62 | Dss = [] 63 | for x in Ds: 64 | Dss += [F.broadcast_mul(F.ones_like(x) , F.Pooling(x,global_pool=True))] 65 | 66 | return self.QuarterStitch(F,Dss) 67 | 68 | 69 | 70 | def SplitPooling(self, F, _a, depth): 71 | """ 72 | A recursive function that produces pooling at a particular depth (in powers of 2) 73 | """ 74 | if depth==1: 75 | return self.HalfPooling(F,_a) 76 | else : 77 | D = self.HalfSplit(F,_a) 78 | return self.QuarterStitch(F,[self.SplitPooling(F,d,depth-1) for d in D]) 79 | # *********************************************************************************** 80 | 81 | def hybrid_forward(self,F,_input): 82 | 83 | p = [_input] 84 | # 1st: global max pooling 85 | p += [self.convs[0](F.broadcast_mul(F.ones_like(_input) , F.Pooling(_input,global_pool=True)))] 86 | p += [self.convs[d](self.SplitPooling(F,_input,d)) for d in range(1,self.depth)] 87 | out = F.concat(*p,dim=1) 88 | out = self.conv_norm_final(out) 89 | 90 | return out 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /nn/units/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nn/units/ceecnet.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | from ceecnet.nn.layers.conv2Dnormed import * 4 | from ceecnet.utils.get_norm import * 5 | from ceecnet.nn.layers.attention import * 6 | 7 | 8 | class ResizeLayer(HybridBlock): 9 | """ 10 | Applies bilinear up/down sampling in the spatial dims and changes the number of filters as well 11 | """ 12 | def __init__(self, nfilters, height, width, _norm_type = 'BatchNorm', norm_groups=None, **kwards): 13 | super().__init__(**kwards) 14 | 15 | self.height=height 16 | self.width = width 17 | 18 | with self.name_scope(): 19 | 20 | self.conv2d = Conv2DNormed(channels=nfilters,kernel_size=3,padding=1, _norm_type=_norm_type, norm_groups = norm_groups, **kwards) 21 | 22 | 23 | def hybrid_forward(self, F, input): 24 | out = F.contrib.BilinearResize2D(input,height=self.height,width=self.width) 25 | out = self.conv2d(out) 26 | 27 | return out 28 | 29 | class ExpandLayer(HybridBlock): 30 | def __init__(self,nfilters, _norm_type = 'BatchNorm', norm_groups=None, ngroups=1,**kwards): 31 | super().__init__(**kwards) 32 | 33 | 34 | with self.name_scope(): 35 | self.conv1 = Conv2DNormed(channels=nfilters,kernel_size=3,padding=1,groups=ngroups, _norm_type=_norm_type, norm_groups = norm_groups, **kwards) 36 | self.conv2 = Conv2DNormed(channels=nfilters,kernel_size=3,padding=1,groups=ngroups,_norm_type=_norm_type, norm_groups = norm_groups,**kwards) 37 | 38 | def hybrid_forward(self, F, input): 39 | 40 | out = F.contrib.BilinearResize2D(input,scale_height=2.,scale_width=2.)
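# Upsample height/width by a factor of 2 with bilinear interpolation, then refine the
# result with two normed 3x3 convolutions separated by ReLU activations.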
41 | out = self.conv1(out) 42 | out = F.relu(out) 43 | out = self.conv2(out) 44 | out = F.relu(out) 45 | 46 | return out 47 | 48 | class ExpandNCombine(HybridBlock): 49 | def __init__(self,nfilters, _norm_type = 'BatchNorm', norm_groups=None,ngroups=1,**kwards): 50 | super().__init__(**kwards) 51 | 52 | with self.name_scope(): 53 | self.conv1 = Conv2DNormed(channels=nfilters,kernel_size=3,padding=1,groups=ngroups,_norm_type=_norm_type, norm_groups = norm_groups,**kwards) 54 | self.conv2 = Conv2DNormed(channels=nfilters,kernel_size=3,padding=1,groups=ngroups,_norm_type=_norm_type, norm_groups = norm_groups,**kwards) 55 | 56 | def hybrid_forward(self, F, input1, input2): 57 | 58 | out = F.contrib.BilinearResize2D(input1,scale_height=2.,scale_width=2.) 59 | out = self.conv1(out) 60 | out = F.relu(out) 61 | out2 = self.conv2(F.concat(out,input2,dim=1)) 62 | out2 = F.relu(out2) 63 | 64 | return out2 65 | 66 | 67 | 68 | class CEEC_unit_v1(HybridBlock): 69 | def __init__(self, nfilters, nheads= 1, ngroups=1, norm_type='BatchNorm', norm_groups=None, ftdepth=5, **kwards): 70 | super().__init__(**kwards) 71 | 72 | 73 | with self.name_scope(): 74 | nfilters_init = nfilters//2 75 | self.conv_init_1 = Conv2DNormed(channels=nfilters_init, kernel_size=3,padding=1,strides=1, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups, **kwards) 76 | self.compr11 = Conv2DNormed(channels=nfilters_init*2, kernel_size=3,padding=1,strides=2, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups, **kwards) 77 | self.compr12 = Conv2DNormed(channels=nfilters_init*2, kernel_size=3,padding=1,strides=1, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups,**kwards) 78 | self.expand1 = ExpandNCombine(nfilters_init,_norm_type = norm_type, norm_groups=norm_groups,ngroups=ngroups) 79 | 80 | # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 81 | 82 | self.conv_init_2 = Conv2DNormed(channels=nfilters_init, kernel_size=3,padding=1,strides=1, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups, **kwards)#half size 83 | 84 | self.expand2 = ExpandLayer(nfilters_init//2 ,_norm_type = norm_type, norm_groups=norm_groups,ngroups=ngroups ) 85 | self.compr21 = Conv2DNormed(channels=nfilters_init, kernel_size=3,padding=1,strides=2, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups,**kwards) 86 | self.compr22 = Conv2DNormed(channels=nfilters_init, kernel_size=3,padding=1,strides=1, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups,**kwards) 87 | 88 | # Will join with master input with concatenation -- IMPORTANT: ngroups = 1 !!!! 
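# --- explanatory note ---
# CEEC_unit_v1 processes the input with two half-width branches: a mini "UNet" branch
# (compress then expand) and a mirrored "capNet" branch (expand then compress). Each branch
# is re-weighted by relative fractal-Tanimoto attention against the other branch's output
# (its "memory"), the two results are concatenated and collected back to full width, and the
# sum with the input is emphasised by self-attention. The gamma parameters below are
# zero-initialised, so the unit starts out as a plain residual block and learns how much
# attention to mix in.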
89 | self.collect = Conv2DNormed(channels=nfilters, kernel_size=3,padding=1,strides=1, groups=1, _norm_type=norm_type, norm_groups=norm_groups,**kwards) 90 | 91 | 92 | self.att = FTAttention2D(nkeys=nfilters,nheads=nheads,norm=norm_type, norm_groups = norm_groups,ftdepth=ftdepth) 93 | self.ratt122 = RelFTAttention2D(nkeys=nfilters_init, nheads=nheads,norm=norm_type, norm_groups = norm_groups,ftdepth=ftdepth) 94 | self.ratt211 = RelFTAttention2D(nkeys=nfilters_init, nheads=nheads,norm=norm_type, norm_groups = norm_groups,ftdepth=ftdepth) 95 | 96 | 97 | self.gamma1 = self.params.get('gamma1', shape=(1,), init=mx.init.Zero()) 98 | self.gamma2 = self.params.get('gamma2', shape=(1,), init=mx.init.Zero()) 99 | self.gamma3 = self.params.get('gamma3', shape=(1,), init=mx.init.Zero()) 100 | 101 | 102 | def hybrid_forward(self, F, input, gamma1, gamma2, gamma3): 103 | 104 | # =========== UNet branch =========== 105 | out10 = self.conv_init_1(input) 106 | out1 = self.compr11(out10) 107 | out1 = F.relu(out1) 108 | out1 = self.compr12(out1) 109 | out1 = F.relu(out1) 110 | out1 = self.expand1(out1,out10) 111 | out1 = F.relu(out1) 112 | 113 | 114 | # =========== \capNet branch =========== 115 | input = F.identity(input) # Solves a mxnet bug 116 | 117 | out20 = self.conv_init_2(input) 118 | out2 = self.expand2(out20) 119 | out2 = F.relu(out2) 120 | out2 = self.compr21(out2) 121 | out2 = F.relu(out2) 122 | out2 = self.compr22(F.concat(out2,out20,dim=1)) 123 | out2 = F.relu(out2) 124 | 125 | att = F.broadcast_mul(gamma1,self.att(input)) 126 | ratt122 = F.broadcast_mul(gamma2,self.ratt122(out1,out2,out2)) 127 | ratt211 = F.broadcast_mul(gamma3,self.ratt211(out2,out1,out1)) 128 | 129 | ones1 = F.ones_like(out10) 130 | ones2 = F.ones_like(input) 131 | 132 | # Enhanced output of 1, based on memory of 2 133 | out122 = F.broadcast_mul(out1,ones1 + ratt122) 134 | # Enhanced output of 2, based on memory of 1 135 | out211 = F.broadcast_mul(out2,ones1 + ratt211) 136 | 137 | out12 = F.relu(self.collect(F.concat(out122,out211,dim=1))) 138 | 139 | # Emphasize residual output from memory on input 140 | out_res = F.broadcast_mul(input + out12, ones2 + att) 141 | return out_res 142 | 143 | 144 | 145 | 146 | 147 | 148 | # ======= Definitions for CEEC unit v2 (replace concatenations with Fusion ========================= 149 | # -------------------------------------- helper functions ------------------------------------------- 150 | 151 | class Fusion(HybridBlock): 152 | def __init__(self,nfilters, kernel_size=3, padding=1,nheads=1, norm = 'BatchNorm', norm_groups=None, ftdepth=5,**kwards): 153 | super().__init__(**kwards) 154 | 155 | 156 | with self.name_scope(): 157 | self.fuse = Conv2DNormed(nfilters,kernel_size= kernel_size, padding = padding, _norm_type= norm, norm_groups=norm_groups, groups=nheads,**kwards) 158 | # Or shall I use the same? 
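# i.e. whether relatt12 and relatt21 could share weights; as written they are two
# independent attention blocks, one per direction (t1 queried against t2, and vice versa).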
159 | self.relatt12 = RelFTAttention2D(nkeys=nfilters, kernel_size=kernel_size, padding=padding, nheads=nheads, norm =norm, norm_groups=norm_groups,ftdepth=ftdepth,**kwards) 160 | self.relatt21 = RelFTAttention2D(nkeys=nfilters, kernel_size=kernel_size, padding=padding, nheads=nheads, norm =norm, norm_groups=norm_groups,ftdepth=ftdepth,**kwards) 161 | 162 | 163 | self.gamma1 = self.params.get('gamma1', shape=(1,), init=mx.init.Zero()) 164 | self.gamma2 = self.params.get('gamma2', shape=(1,), init=mx.init.Zero()) 165 | 166 | 167 | def hybrid_forward(self, F, input_t1, input_t2, gamma1, gamma2): 168 | # These inputs must have the same dimensionality , t1, t2 169 | relatt12 = F.broadcast_mul(gamma1,self.relatt12(input_t1,input_t2,input_t2)) 170 | relatt21 = F.broadcast_mul(gamma2,self.relatt21(input_t2,input_t1,input_t1)) 171 | 172 | ones = F.ones_like(input_t1) 173 | 174 | # Enhanced output of 1, based on memory of 2 175 | out12 = F.broadcast_mul(input_t1,ones + relatt12) 176 | # Enhanced output of 2, based on memory of 1 177 | out21 = F.broadcast_mul(input_t2,ones + relatt21) 178 | 179 | 180 | fuse = self.fuse(F.concat(out12, out21,dim=1)) 181 | fuse = F.relu(fuse) 182 | 183 | return fuse 184 | 185 | 186 | 187 | class CATFusion(HybridBlock): 188 | """ 189 | Alternative to concatenation followed by normed convolution: improves performance. 190 | """ 191 | def __init__(self,nfilters_out, nfilters_in, kernel_size=3, padding=1,nheads=1, norm = 'BatchNorm', norm_groups=None, ftdepth=5,**kwards): 192 | super().__init__(**kwards) 193 | 194 | 195 | with self.name_scope(): 196 | self.fuse = Conv2DNormed(nfilters_out,kernel_size= kernel_size, padding = padding, _norm_type= norm, norm_groups=norm_groups, groups=nheads,**kwards) 197 | # Or shall I use the same? 
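# Unlike Fusion above, CATFusion decouples the attention width (nfilters_in, matching both
# inputs) from the fused output width (nfilters_out), so it can change the number of
# filters while fusing the two streams.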
198 | self.relatt12 = RelFTAttention2D(nkeys=nfilters_in, kernel_size=kernel_size, padding=padding, nheads=nheads, norm =norm, norm_groups=norm_groups,ftdepth=ftdepth,**kwards) 199 | self.relatt21 = RelFTAttention2D(nkeys=nfilters_in, kernel_size=kernel_size, padding=padding, nheads=nheads, norm =norm, norm_groups=norm_groups,ftdepth=ftdepth,**kwards) 200 | 201 | 202 | self.gamma1 = self.params.get('gamma1', shape=(1,), init=mx.init.Zero()) 203 | self.gamma2 = self.params.get('gamma2', shape=(1,), init=mx.init.Zero()) 204 | 205 | 206 | 207 | def hybrid_forward(self, F, input_t1, input_t2, gamma1, gamma2): 208 | # These inputs must have the same dimensionality , t1, t2 209 | relatt12 = F.broadcast_mul(gamma1,self.relatt12(input_t1,input_t2,input_t2)) 210 | relatt21 = F.broadcast_mul(gamma2,self.relatt21(input_t2,input_t1,input_t1)) 211 | 212 | ones = F.ones_like(input_t1) 213 | 214 | # Enhanced output of 1, based on memory of 2 215 | out12 = F.broadcast_mul(input_t1,ones + relatt12) 216 | # Enhanced output of 2, based on memory of 1 217 | out21 = F.broadcast_mul(input_t2,ones + relatt21) 218 | 219 | 220 | fuse = self.fuse(F.concat(out12, out21,dim=1)) 221 | fuse = F.relu(fuse) 222 | 223 | return fuse 224 | 225 | 226 | 227 | 228 | class combine_layers_wthFusion(HybridBlock): 229 | def __init__(self,nfilters, nheads=1, _norm_type = 'BatchNorm', norm_groups=None,ftdepth=5, **kwards): 230 | HybridBlock.__init__(self,**kwards) 231 | 232 | with self.name_scope(): 233 | 234 | self.conv1 = Conv2DNormed(channels=nfilters,kernel_size=3,padding=1, groups=nheads, _norm_type=_norm_type, norm_groups = norm_groups, **kwards)# restore help 235 | self.conv3 = Fusion(nfilters=nfilters, kernel_size=3, padding=1, nheads=nheads, norm=_norm_type, norm_groups = norm_groups, ftdepth=ftdepth,**kwards) # process 236 | 237 | def hybrid_forward(self,F,_layer_lo, _layer_hi): 238 | 239 | up = F.contrib.BilinearResize2D(_layer_lo,scale_height=2.,scale_width=2.) 240 | up = self.conv1(up) 241 | up = F.relu(up) 242 | x = self.conv3(up,_layer_hi) 243 | 244 | return x 245 | 246 | 247 | class ExpandNCombine_V3(HybridBlock): 248 | def __init__(self,nfilters, _norm_type = 'BatchNorm', norm_groups=None,ngroups=1,ftdepth=5,**kwards): 249 | super().__init__(**kwards) 250 | 251 | 252 | with self.name_scope(): 253 | self.conv1 = Conv2DNormed(channels=nfilters,kernel_size=3,padding=1,groups=ngroups,_norm_type=_norm_type, norm_groups = norm_groups,**kwards)# restore help 254 | self.conv2 = Conv2DNormed(channels=nfilters,kernel_size=3,padding=1,groups=ngroups,_norm_type=_norm_type, norm_groups = norm_groups,**kwards)# restore help 255 | self.conv3 = Fusion(nfilters=nfilters,kernel_size=3,padding=1,nheads=ngroups,norm=_norm_type, norm_groups = norm_groups,ftdepth=ftdepth,**kwards) # process 256 | 257 | def hybrid_forward(self, F, input1, input2): 258 | 259 | out = F.contrib.BilinearResize2D(input1,scale_height=2.,scale_width=2.) 
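# Same upsample-then-combine pattern as ExpandNCombine, except that the upsampled branch
# and the skip connection are merged with attention-based Fusion (conv3) instead of a
# plain concatenation followed by a convolution.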
260 | out = self.conv1(out) 261 | out1 = F.relu(out) 262 | 263 | out2 = self.conv2(input2) 264 | out2 = F.relu(out2) 265 | 266 | outf = self.conv3(out1,out2) 267 | outf = F.relu(outf) 268 | 269 | return outf 270 | 271 | 272 | 273 | 274 | # ------------------------------------------------------------------------------------------------------------------- 275 | 276 | class CEEC_unit_v2(HybridBlock): 277 | def __init__(self, nfilters, nheads= 1, ngroups=1, norm_type='BatchNorm', norm_groups=None, ftdepth=5, **kwards): 278 | super().__init__(**kwards) 279 | 280 | 281 | with self.name_scope(): 282 | nfilters_init = nfilters//2 283 | self.conv_init_1 = Conv2DNormed(channels=nfilters_init, kernel_size=3,padding=1,strides=1, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups, **kwards)#half size 284 | self.compr11 = Conv2DNormed(channels=nfilters_init*2, kernel_size=3,padding=1,strides=2, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups, **kwards)#half size 285 | self.compr12 = Conv2DNormed(channels=nfilters_init*2, kernel_size=3,padding=1,strides=1, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups,**kwards)# process 286 | self.expand1 = ExpandNCombine_V3(nfilters_init,_norm_type = norm_type, norm_groups=norm_groups,ngroups=ngroups,ftdepth=ftdepth) # restore original size + process 287 | 288 | 289 | self.conv_init_2 = Conv2DNormed(channels=nfilters_init, kernel_size=3,padding=1,strides=1, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups, **kwards)#half size 290 | self.expand2 = ExpandLayer(nfilters_init//2 ,_norm_type = norm_type, norm_groups=norm_groups,ngroups=ngroups ) 291 | self.compr21 = Conv2DNormed(channels=nfilters_init, kernel_size=3,padding=1,strides=2, groups=ngroups, _norm_type=norm_type, norm_groups=norm_groups,**kwards) 292 | self.compr22 = Fusion(nfilters=nfilters_init, kernel_size=3,padding=1, nheads=ngroups, norm=norm_type, norm_groups=norm_groups,ftdepth=ftdepth,**kwards) 293 | 294 | self.collect = CATFusion(nfilters_out=nfilters, nfilters_in=nfilters_init, kernel_size=3,padding=1,nheads=1, norm=norm_type, norm_groups=norm_groups,ftdepth=ftdepth,**kwards) 295 | 296 | self.att = FTAttention2D(nkeys=nfilters,nheads=nheads,norm=norm_type, norm_groups = norm_groups, ftdepth=ftdepth) 297 | self.ratt122 = RelFTAttention2D(nkeys=nfilters_init, nheads=nheads,norm=norm_type, norm_groups = norm_groups, ftdepth=ftdepth) 298 | self.ratt211 = RelFTAttention2D(nkeys=nfilters_init, nheads=nheads,norm=norm_type, norm_groups = norm_groups, ftdepth=ftdepth) 299 | 300 | 301 | self.gamma1 = self.params.get('gamma1', shape=(1,), init=mx.init.Zero()) 302 | self.gamma2 = self.params.get('gamma2', shape=(1,), init=mx.init.Zero()) 303 | self.gamma3 = self.params.get('gamma3', shape=(1,), init=mx.init.Zero()) 304 | 305 | def hybrid_forward(self, F, input, gamma1, gamma2, gamma3): 306 | 307 | # =========== UNet branch =========== 308 | out10 = self.conv_init_1(input) 309 | out1 = self.compr11(out10) 310 | out1 = F.relu(out1) 311 | #print (out1.shape) 312 | out1 = self.compr12(out1) 313 | out1 = F.relu(out1) 314 | #print (out1.shape) 315 | out1 = self.expand1(out1,out10) 316 | out1 = F.relu(out1) 317 | 318 | 319 | # =========== \capNet branch =========== 320 | input = F.identity(input) # Solves a mxnet bug 321 | 322 | out20 = self.conv_init_2(input) 323 | out2 = self.expand2(out20) 324 | out2 = F.relu(out2) 325 | out2 = self.compr21(out2) 326 | out2 = F.relu(out2) 327 | out2 = self.compr22(out2,out20) 328 | 329 | 330 | 331 | input = F.identity(input) # 
Solves an mxnet bug 332 | 333 | att = F.broadcast_mul(gamma1,self.att(input)) 334 | ratt122 = F.broadcast_mul(gamma2,self.ratt122(out1,out2,out2)) 335 | ratt211 = F.broadcast_mul(gamma3,self.ratt211(out2,out1,out1)) 336 | 337 | ones1 = F.ones_like(out10) 338 | ones2 = F.ones_like(input) 339 | 340 | # Enhanced output of 1, based on memory of 2 341 | out122 = F.broadcast_mul(out1,ones1 + ratt122) 342 | # Enhanced output of 2, based on memory of 1 343 | out211 = F.broadcast_mul(out2,ones1 + ratt211) 344 | 345 | 346 | out12 = self.collect(out122,out211) # includes a ReLU, since CATFusion ends with one 347 | 348 | out_res = F.broadcast_mul(input + out12, ones2 + att) 349 | return out_res 350 | 351 | 352 | -------------------------------------------------------------------------------- /nn/units/fractal_resnet.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | from mxnet.gluon import HybridBlock 3 | from ceecnet.nn.layers.conv2Dnormed import * 4 | from ceecnet.utils.get_norm import * 5 | from ceecnet.nn.layers.attention import * 6 | 7 | class ResNet_v2_block(HybridBlock): 8 | """ 9 | ResNet v2 building block. It is built upon the assumption of ODD kernel sizes. 10 | """ 11 | def __init__(self, _nfilters,_kernel_size=(3,3),_dilation_rate=(1,1), 12 | _norm_type='BatchNorm', norm_groups=None, ngroups=1, **kwards): 13 | super().__init__(**kwards) 14 | 15 | self.nfilters = _nfilters 16 | self.kernel_size = _kernel_size 17 | self.dilation_rate = _dilation_rate 18 | 19 | 20 | with self.name_scope(): 21 | 22 | # Ensures padding = 'SAME' for ODD kernel selection 23 | p0 = self.dilation_rate[0] * (self.kernel_size[0] - 1)/2 24 | p1 = self.dilation_rate[1] * (self.kernel_size[1] - 1)/2 25 | p = (int(p0),int(p1)) 26 | 27 | 28 | self.BN1 = get_norm(_norm_type, norm_groups=norm_groups ) 29 | self.conv1 = gluon.nn.Conv2D(self.nfilters,kernel_size = self.kernel_size,padding=p,dilation=self.dilation_rate,use_bias=False,groups=ngroups) 30 | self.BN2 = get_norm(_norm_type, norm_groups= norm_groups) 31 | self.conv2 = gluon.nn.Conv2D(self.nfilters,kernel_size = self.kernel_size,padding=p,dilation=self.dilation_rate,use_bias=True, groups=ngroups) 32 | 33 | 34 | def hybrid_forward(self,F,_input_layer): 35 | 36 | x = self.BN1(_input_layer) 37 | x = F.relu(x) 38 | x = self.conv1(x) 39 | 40 | x = self.BN2(x) 41 | x = F.relu(x) 42 | x = self.conv2(x) 43 | 44 | return x 45 | 46 | class FracTALResNet_unit(HybridBlock): 47 | def __init__(self, nfilters, ngroups=1, nheads=1, kernel_size=(3,3), dilation_rate=(1,1), norm_type = 'BatchNorm', norm_groups=None, ftdepth=5,**kwards): 48 | super().__init__(**kwards) 49 | 50 | with self.name_scope(): 51 | self.block1 = ResNet_v2_block(nfilters,kernel_size,dilation_rate,_norm_type = norm_type, norm_groups=norm_groups, ngroups=ngroups) 52 | self.attn = FTAttention2D(nkeys=nfilters, nheads=nheads, kernel_size=kernel_size, norm = norm_type, norm_groups = norm_groups,ftdepth=ftdepth) 53 | 54 | self.gamma = self.params.get('gamma', shape=(1,), init=mx.init.Zero()) 55 | 56 | def hybrid_forward(self, F, input, gamma): 57 | out1 = self.block1(input) 58 | 59 | 60 | att = self.attn(input) 61 | att= F.broadcast_mul(gamma,att) 62 | 63 | out = F.broadcast_mul((input + out1) , F.ones_like(out1) + att) 64 | return out 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | rasterio 3 |
opencv-python 4 | pathos 5 | mxnet 6 | 7 | -------------------------------------------------------------------------------- /src/LVRCDDataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | DataSet reader for the LEVIRCD dataset. 3 | """ 4 | 5 | import numpy as np 6 | import glob 7 | from mxnet.gluon.data import dataset 8 | import cv2 9 | import mxnet as mx 10 | import pickle 11 | 12 | class LVRCDDataset(dataset.Dataset): 13 | def __init__(self, root=r'/Location/Of/Your/LEVIRCD/Files/', mode='train', mtsk = True, transform=None, norm=None, pMessUp=0.0, Filter=256, prob_swap=0.5, prob_zero_change=0.5): 14 | 15 | self.NClasses=2 16 | self.NChannels=3 # RGB 17 | # Transformation of augmented data 18 | self._mode = mode 19 | self.mtsk = mtsk 20 | 21 | self.prob_swap = prob_swap 22 | self.prob_zero_change = prob_zero_change 23 | 24 | 25 | self._transform = transform 26 | self._norm = norm # Normalization of img 27 | 28 | if (root[-1]!='/'): 29 | root = root + r'/' 30 | 31 | if mode == 'train': 32 | flname_idx = glob.glob(root + r'training_LVRCD_F{}.idx'.format(Filter))[0] 33 | flname_rec = glob.glob(root + r'training_LVRCD_F{}.rec'.format(Filter))[0] 34 | elif mode == 'val': 35 | flname_idx = glob.glob(root + r'validation_LVRCD_F{}.idx'.format(Filter))[0] 36 | flname_rec = glob.glob(root + r'validation_LVRCD_F{}.rec'.format(Filter))[0] 37 | else: 38 | raise Exception('Inconsistent mode given, available choices: {train, val}, aborting ...') 39 | 40 | 41 | self.record = mx.recordio.MXIndexedRecordIO(idx_path=flname_idx, uri=flname_rec , flag='r') 42 | 43 | def get_boundary(self, labels, _kernel_size = (3,3)): 44 | 45 | label = labels.copy().astype(np.uint8) 46 | for channel in range(label.shape[0]): 47 | temp = cv2.Canny(label[channel],0,1) 48 | label[channel] = cv2.dilate(temp, cv2.getStructuringElement(cv2.MORPH_CROSS,_kernel_size) ,iterations = 1) 49 | 50 | label = label.astype(np.float32) 51 | label /= 255. 52 | return label 53 | 54 | def get_distance(self,labels): 55 | label = labels.copy().astype(np.uint8) 56 | dists = np.empty_like(label,dtype=np.float32) 57 | for channel in range(label.shape[0]): 58 | dist = cv2.distanceTransform(label[channel], cv2.DIST_L2, 0) 59 | dist = cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX) 60 | dists[channel] = dist 61 | 62 | return dists 63 | 64 | 65 | def __getitem__(self, idx): 66 | 67 | key = self.record.keys[idx] 68 | imgall = pickle.loads(self.record.read_idx(key)) 69 | 70 | base = imgall[:self.NChannels*2].astype(np.float32) 71 | mask = imgall[self.NChannels*2:].astype(np.float32) 72 | mask[self.NClasses*2:] = mask[self.NClasses*2:]/100.
# Bring the distance transform to 0,1 scale 73 | 74 | 75 | if self.mtsk == False: 76 | mask = mask[:self.NClasses,:,:] 77 | 78 | if self._transform is not None: 79 | base, mask = self._transform(base, mask) 80 | # RGB images, 3 bands 81 | t1 = base[:self.NChannels] 82 | t2 = base[self.NChannels:] 83 | 84 | 85 | # @@@@@@@@@@@@@@@@@@@@@ TWO ESSENTIAL transformations @@@@@@@@@@@@@@@ 86 | # Select randomly NOCHANGE or Great Scott 87 | if np.random.rand() >= 0.5: 88 | # Great Scott: random time ordering 89 | if np.random.rand() >= self.prob_swap: 90 | temp = t2.copy() 91 | t2 = t1 92 | t1 = temp 93 | else: 94 | # NOCHANGE to help avoid learning buildings as a mask 95 | if np.random.rand() >= self.prob_zero_change: 96 | if np.random.rand() >= 0.5: 97 | t2 = t1 98 | else: 99 | t1 = t2 100 | # Segmentation is all NOCHANGE now fix mask 101 | mask = mask[:self.NClasses,:,:] 102 | # No CHANGE 103 | mask[0] = 1 104 | mask[1] = 0 105 | boundaries = self.get_boundary(mask) 106 | dists = self.get_distance(mask) 107 | mask = np.concatenate([mask,boundaries,dists],axis=0) 108 | mask = mask.astype(np.float32) 109 | # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 110 | 111 | 112 | 113 | if self._norm is not None: 114 | t1 = self._norm(t1.astype(np.float32)) 115 | t2 = self._norm(t2.astype(np.float32)) 116 | 117 | return t1.astype(np.float32), t2.astype(np.float32), mask.astype(np.float32) 118 | 119 | else: 120 | # RGB images, 3 bands 121 | t1 = base[:self.NChannels] 122 | t2 = base[self.NChannels:] 123 | if self._norm is not None: 124 | t1 = self._norm(t1.astype(np.float32)) 125 | t2 = self._norm(t2.astype(np.float32)) 126 | 127 | 128 | return t1.astype(np.float32), t2.astype(np.float32), mask.astype(np.float32) 129 | 130 | def __len__(self): 131 | return len(self.record.keys) 132 | 133 | 134 | -------------------------------------------------------------------------------- /src/LVRCDNormal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class for normalizing the sliced images for the LEVIRCD dataset 3 | """ 4 | 5 | 6 | import numpy as np 7 | 8 | 9 | # Class to normalize images 10 | class LVRCDNormal(object): 11 | """ 12 | class for Normalization of images, per channel, in format CHW 13 | """ 14 | def __init__(self): 15 | 16 | # Normalization constants for image -- calculated from training images 17 | self._mean = np.array([100.90723866, 99.52347812, 84.97354742]) 18 | self._std = np.array ([ 42.8782652 , 40.90759297, 38.31541013 ]) 19 | 20 | 21 | 22 | def __call__(self,img): 23 | 24 | temp = img.astype(np.float32) 25 | temp2 = temp.T 26 | temp2 -= self._mean 27 | temp2 /= self._std 28 | 29 | temp = temp2.T 30 | 31 | return temp 32 | 33 | 34 | 35 | def restore(self,normed_img): 36 | 37 | d2 = normed_img.T * self._std 38 | d2 = d2 + self._mean 39 | d2 = d2.T 40 | d2 = np.round(d2) 41 | d2 = d2.astype('uint8') 42 | 43 | return d2 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /src/semseg_aug_cv2.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import itertools 3 | import numpy as np 4 | 5 | 6 | class ParamsRange(dict): 7 | def __init__(self): 8 | 9 | 10 | # Good default values for 256x256 images 11 | self['center_range'] =[0,256] 12 | self['rot_range'] =[-85.0,85.0] 13 | self['zoom_range'] = [0.75,1.25] 14 | self['noise_mean'] = [0]*5 15 | self['noise_var'] = [10]*5 16 | 17 | 18 | class SemSegAugmentor_CV(object): 19 | """ 20 | 
INPUTS: 21 | params_range: parameter ranges for all transformations 22 | prob: probability that a transformation takes place - defaults to 1. 23 | Nrot: number of rotations relative to the reflections x, y, xy. Defaults to 5 so rotations roughly balance the reflections. 24 | """ 25 | def __init__(self, params_range, prob = 1.0, Nrot=5, norm = None, one_hot = True): 26 | 27 | self.norm = norm # This is a necessary hack to apply brightness normalization 28 | self.one_hot = one_hot 29 | self.range = params_range 30 | self.prob = prob 31 | assert self.prob <= 1 , "prob must be in range [0,1], you gave prob::{}".format(prob) 32 | 33 | 34 | # define a proportion of operations? 35 | self.operations = [self.reflect_x, self.reflect_y, self.reflect_xy,self.random_brightness, self.random_shadow] 36 | self.operations += [self.rand_shift_rot_zoom]*Nrot 37 | self.iterator = itertools.cycle(self.operations) 38 | 39 | 40 | def _shift_rot_zoom(self,_img, _mask, _center, _angle, _scale): 41 | """ 42 | OpenCV random scale+rotation 43 | """ 44 | imgT = _img.transpose([1,2,0]) 45 | if (self.one_hot): 46 | maskT = _mask.transpose([1,2,0]) 47 | else: 48 | maskT = _mask 49 | 50 | rows, cols = imgT.shape[:-1] # height, width; cv2 dsize expects (width, height) 51 | 52 | # Produces affine rotation matrix, with center, for angle, and optional zoom in/out scale 53 | tRotMat = cv2.getRotationMatrix2D(_center, _angle, _scale) 54 | 55 | img_trans = cv2.warpAffine(imgT,tRotMat,(cols,rows),flags=cv2.INTER_AREA, borderMode=cv2.BORDER_REFLECT_101) # alternative: flags=cv2.INTER_CUBIC 56 | mask_trans= cv2.warpAffine(maskT,tRotMat,(cols,rows),flags=cv2.INTER_AREA, borderMode=cv2.BORDER_REFLECT_101) 57 | 58 | img_trans = img_trans.transpose([2,0,1]) 59 | if (self.one_hot): 60 | mask_trans = mask_trans.transpose([2,0,1]) 61 | 62 | return img_trans, mask_trans 63 | 64 | 65 | def reflect_x(self,_img,_mask): 66 | 67 | img_z = _img[:,::-1,:] 68 | if self.one_hot: 69 | mask_z = _mask[:,::-1,:] # 1hot representation 70 | else: 71 | mask_z = _mask[::-1,:] # standard (int's representation) 72 | 73 | return img_z, mask_z 74 | 75 | def reflect_y(self,_img,_mask): 76 | img_z = _img[:,:,::-1] 77 | if self.one_hot: 78 | mask_z = _mask[:,:,::-1] # 1hot representation 79 | else: 80 | mask_z = _mask[:,::-1] # standard (int's representation) 81 | 82 | return img_z, mask_z 83 | 84 | def reflect_xy(self,_img,_mask): 85 | img_z = _img[:,::-1,::-1] 86 | if self.one_hot: 87 | mask_z = _mask[:,::-1,::-1] # 1hot representation 88 | else: 89 | mask_z = _mask[::-1,::-1] # standard (int's representation) 90 | 91 | return img_z, mask_z 92 | 93 | 94 | 95 | def rand_shift_rot_zoom(self,_img,_mask): 96 | 97 | center = np.random.randint(low=self.range['center_range'][0], 98 | high=self.range['center_range'][1], 99 | size=2) 100 | # Angle in degrees: cv2.getRotationMatrix2D expects degrees, not radians 101 | angle = np.random.uniform(low=self.range['rot_range'][0], 102 | high=self.range['rot_range'][1]) 103 | 104 | scale = np.random.uniform(low=self.range['zoom_range'][0], 105 | high=self.range['zoom_range'][1]) 106 | 107 | 108 | return self._shift_rot_zoom(_img,_mask,tuple(center),angle,scale) 109 | 110 | 111 | 112 | # ============================================ New additions below ======================================================= 113 | # **************** Random brightness (light/dark) and random shadow polygons ************* 114 | # ******** Taken from: https://medium.freecodecamp.org/image-augmentation-make-it-rain-make-it-snow-how-to-modify-a-photo-with-machine-learning-163c0cb3843f 115 | # ******** See
https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library for the library 116 | 117 | def random_brightness(self,_img, _mask): 118 | """ 119 | This function applies only to the first 3 channels (RGB) of an image. 120 | Input: RGB image, transforms to np.uint8 121 | Output: RGB image + extra channels. 122 | """ 123 | 124 | if self.norm is not None: 125 | image = self.norm.restore(_img).transpose([1,2,0])[:,:,:3].copy() # use only three bands 126 | imgcp = self.norm.restore(_img.copy()) # use only three bands 127 | 128 | else: 129 | 130 | image = _img.transpose([1,2,0])[:,:,:3].copy().astype(np.uint8) # use only three bands 131 | imgcp = _img.copy().astype(np.uint8) # use only three bands 132 | 133 | image_HLS = cv2.cvtColor(image,cv2.COLOR_RGB2HLS) ## Conversion to HLS 134 | image_HLS = np.array(image_HLS, dtype = np.float64) 135 | random_brightness_coefficient = np.random.uniform()+0.5 ## generates value between 0.5 and 1.5 136 | image_HLS[:,:,1] = image_HLS[:,:,1]*random_brightness_coefficient ## scale pixel values up or down for channel 1 (Lightness) 137 | image_HLS[:,:,1][image_HLS[:,:,1]>255] = 255 ## sets all values above 255 to 255 138 | image_HLS = np.array(image_HLS, dtype = np.uint8) 139 | image_RGB = cv2.cvtColor(image_HLS,cv2.COLOR_HLS2RGB) ## Conversion back to RGB 140 | 141 | 142 | imgcp[:3,:,:] = image_RGB.transpose([2,0,1]) 143 | 144 | if self.norm is not None: 145 | imgcp = self.norm(imgcp) 146 | 147 | return imgcp.astype(_img.dtype), _mask 148 | 149 | 150 | 151 | def _generate_shadow_coordinates(self,imshape, no_of_shadows=1): 152 | vertices_list=[] 153 | for index in range(no_of_shadows): 154 | vertex=[] 155 | for dimensions in range(np.random.randint(3,15)): ## Dimensionality of the shadow polygon 156 | vertex.append(( imshape[1]*np.random.uniform(), imshape[0]*np.random.uniform())) 157 | vertices = np.array([vertex], dtype=np.int32) ## single shadow vertices 158 | vertices = cv2.convexHull(vertices[0]) 159 | vertices = vertices.transpose([1,0,2]) 160 | vertices_list.append(vertices) 161 | return vertices_list ## List of shadow vertices 162 | 163 | def _add_shadow(self, image, no_of_shadows=1): 164 | image_HLS = cv2.cvtColor(image,cv2.COLOR_RGB2HLS) ## Conversion to HLS 165 | tmask = np.zeros_like(image[:,:,0]) 166 | imshape = image.shape 167 | vertices_list= self._generate_shadow_coordinates(imshape, no_of_shadows) # getting list of shadow vertices 168 | for vertices in vertices_list: 169 | cv2.fillPoly(tmask, vertices, 255) 170 | image_HLS[:,:,1][tmask[:,:]==255] = image_HLS[:,:,1][tmask[:,:]==255]*0.5 ## darken the Lightness channel inside the shadow polygons 171 | image_RGB = cv2.cvtColor(image_HLS,cv2.COLOR_HLS2RGB) ## Conversion back to RGB 172 | return image_RGB 173 | 174 | 175 | def random_shadow(self,_img, _mask): 176 | 177 | 178 | if self.norm is not None: 179 | image = self.norm.restore(_img).transpose([1,2,0])[:,:,:3].copy() # use only three bands 180 | imgcp = self.norm.restore(_img.copy()) # use only three bands 181 | 182 | else: 183 | 184 | image = _img.transpose([1,2,0])[:,:,:3].copy().astype(np.uint8) # use only three bands 185 | imgcp = _img.copy().astype(np.uint8) # use only three bands 186 | 187 | shadow_image = self._add_shadow(image) 188 | 189 | imgcp.transpose([1,2,0])[:,:,:3] = shadow_image 190 | 191 | if self.norm is not None: 192 | imgcp = self.norm(imgcp) 193 | 194 | return imgcp.astype(_img.dtype), _mask 195 | 196 | # ===================================================================================== 197 | 198 | 199 | def __call__(self,_img, _mask): 200 | 201 | rand = np.random.rand() 202 | if
(rand <= self.prob): 203 | return next(self.iterator)(_img,_mask) 204 | else : 205 | return _img, _mask 206 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/get_norm.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | from mxnet import gluon 3 | 4 | def get_norm(name, axis=1, norm_groups=None): 5 | if (name == 'BatchNorm'): 6 | return gluon.nn.BatchNorm(axis=axis) 7 | elif (name == 'InstanceNorm'): 8 | return gluon.nn.InstanceNorm(axis=axis) 9 | elif (name == 'LayerNorm'): 10 | return gluon.nn.LayerNorm(axis=axis) 11 | elif (name == 'GroupNorm' and norm_groups is not None): 12 | return gluon.nn.GroupNorm(num_groups = norm_groups) # applied to channel axis 13 | else: 14 | raise NotImplementedError 15 | --------------------------------------------------------------------------------
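As a quick sanity check of the loss modules above, the following minimal sketch builds random predictions and labels in the layout ```LVRCDDataset``` produces and evaluates the multitasking loss. It assumes mxnet 1.x is installed and that the repository root is importable as ```ceecnet``` (the shapes and values here are illustrative only, not taken from the repository):

```python
import mxnet as mx
from ceecnet.nn.loss.mtsk_loss import mtsk_loss

B, C, H, W = 4, 2, 256, 256  # batch, classes (one-hot), tile height/width

# Fake network output: [segmentation, boundary, distance], each of shape (B, C, H, W)
preds = [mx.nd.random.uniform(shape=(B, C, H, W)) for _ in range(3)]

# Labels stacked along the channel axis, the layout LVRCDDataset produces
labels = mx.nd.random.uniform(shape=(B, 3 * C, H, W))

criterion = mtsk_loss(depth=0, NClasses=C)  # depth=0 is the classic (non-fractal) Tanimoto
loss = criterion.loss(preds, labels)

print(loss.shape)  # (4,): ftnmt_loss averages over all but the batch axis
```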